浏览代码

Fix imports in finetuning scripts

Ruslan Mavlyutov 2 年之前
父节点
当前提交
16bc4a5d88
共有 3 个文件被更改,包括 4 次插入和 6 次删除
  1. 1 3
      scripts/m4t/finetune/finetune.py
  2. 3 3
      scripts/m4t/finetune/trainer.py
  3. 0 0
      scripts/m4t/predict/__init__.py

+ 1 - 3
scripts/m4t/finetune/finetune.py

@@ -9,11 +9,9 @@ import logging
 import os
 from pathlib import Path
 
-import dataloader
-import dist_utils
 import torch
-import trainer
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
+from m4t_scripts.finetune import dataloader, dist_utils, trainer
 
 from seamless_communication.models.unity import (
     UnitTokenizer,

+ 3 - 3
scripts/m4t/finetune/trainer.py

@@ -12,17 +12,17 @@ from enum import Enum
 from pathlib import Path
 from typing import Optional, Tuple
 
-import dataloader
-import dist_utils
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from fairseq2.models.sequence import SequenceModelOutput
-from fairseq2.models.unity import UnitYModel
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
+from m4t_scripts.finetune import dataloader, dist_utils
 from torch.optim import Adam
 
+from seamless_communication.models.unity import UnitYModel
+
 logger = logging.getLogger(__name__)
 
 

+ 0 - 0
scripts/m4t/predict/__init__.py