소스 검색

Fix imports in finetuning scripts

Ruslan Mavlyutov 2 년 전
부모
커밋
16bc4a5d88
3개의 변경된 파일 — 4개의 추가작업 그리고 6개의 삭제작업
  1. 1 3
      scripts/m4t/finetune/finetune.py
  2. 3 3
      scripts/m4t/finetune/trainer.py
  3. 0 0
      scripts/m4t/predict/__init__.py

+ 1 - 3
scripts/m4t/finetune/finetune.py

@@ -9,11 +9,9 @@ import logging
 import os
 from pathlib import Path
 
-import dataloader
-import dist_utils
 import torch
-import trainer
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
+from m4t_scripts.finetune import dataloader, dist_utils, trainer
 
 from seamless_communication.models.unity import (
     UnitTokenizer,

+ 3 - 3
scripts/m4t/finetune/trainer.py

@@ -12,17 +12,17 @@ from enum import Enum
 from pathlib import Path
 from typing import Optional, Tuple
 
-import dataloader
-import dist_utils
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from fairseq2.models.sequence import SequenceModelOutput
-from fairseq2.models.unity import UnitYModel
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
+from m4t_scripts.finetune import dataloader, dist_utils
 from torch.optim import Adam
 
+from seamless_communication.models.unity import UnitYModel
+
 logger = logging.getLogger(__name__)
 
 

+ 0 - 0
scripts/m4t/predict/__init__.py