
Fix issues in Autocast (#400)

* Enable user-specified device

* CUDA should be default device

* Log model name

* Wrap eval in autocast

* Running on CPU

* TQDM in eval
Alisamar Husain, 1 year ago
Parent commit: 1971db5dd6

2 changed files with 20 additions and 10 deletions:
  1. src/seamless_communication/cli/m4t/finetune/finetune.py (+11 -6)
  2. src/seamless_communication/cli/m4t/finetune/trainer.py (+9 -4)

src/seamless_communication/cli/m4t/finetune/finetune.py (+11 -6)

@@ -119,20 +119,25 @@ def init_parser() -> argparse.ArgumentParser:
             "* `SPEECH_TO_TEXT` -- finetune only S2T"
         ),
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help=("Device to fine-tune on. See `torch.device`."),
+    )
     return parser
 
 
 def main() -> None:
     args = init_parser().parse_args()
     dist_utils.init_distributed([logger, trainer.logger])
-    device = torch.device("cuda")
-    float_dtype = torch.float16
     text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
     unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
     finetune_params = trainer.FinetuneParams(
         finetune_mode=args.mode,
         save_model_path=args.save_model_to,
-        device=device,
+        device=torch.device(args.device),
+        float_dtype=torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16,
         train_batch_size=args.batch_size,
         eval_batch_size=args.batch_size,
         patience=args.patience,
@@ -156,7 +161,7 @@ def main() -> None:
     if model.text_encoder is not None:
         model.text_encoder = None
     model = model.to(finetune_params.device)
-    logger.info(f"Model {model}")
+    logger.info(f"<{args.model_name}> {model}")
 
     train_dataloader = dataloader.UnitYDataLoader(
         text_tokenizer=text_tokenizer,
@@ -166,7 +171,7 @@ def main() -> None:
             rank=dist_utils.get_rank(),
             world_size=dist_utils.get_world_size(),
             max_audio_length_sec=15.0,
-            float_dtype=float_dtype,
+            float_dtype=finetune_params.float_dtype,
         ),
         dataset_manifest_path=args.train_dataset,
     )
@@ -178,7 +183,7 @@ def main() -> None:
             rank=dist_utils.get_rank(),
             world_size=dist_utils.get_world_size(),
             max_audio_length_sec=100.0,
-            float_dtype=float_dtype,
+            float_dtype=finetune_params.float_dtype,
         ),
         dataset_manifest_path=args.eval_dataset,
     )
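
Taken together, the hunks above route the CLI's new --device choice into FinetuneParams and derive the autocast dtype from it: float16 on accelerators, bfloat16 on CPU (CPU autocast generally expects bfloat16 rather than float16). Below is a minimal, self-contained sketch of that selection; pick_dtype is a hypothetical helper used only for illustration, not part of the CLI:

    import torch

    def pick_dtype(device: torch.device) -> torch.dtype:
        # Mirror the diff's choice: use float16 under autocast on accelerators,
        # and fall back to bfloat16 whenever the target device is the CPU.
        return torch.float16 if device.type != "cpu" else torch.bfloat16

    assert pick_dtype(torch.device("cuda")) is torch.float16   # e.g. --device cuda (the default)
    assert pick_dtype(torch.device("cpu")) is torch.bfloat16   # e.g. --device cpu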

src/seamless_communication/cli/m4t/finetune/trainer.py (+9 -4)

@@ -9,6 +9,7 @@ import logging
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
+from tqdm import tqdm
 from pathlib import Path
 from typing import Optional, Tuple
 
@@ -44,6 +45,9 @@ class FinetuneParams:
 
     finetune_mode: FinetuneMode = FinetuneMode.TEXT_TO_SPEECH
     """Allows to freeze S2T or T2U part of the model"""
+    
+    float_dtype: torch.dtype = torch.float16
+    """Float Dtype"""
 
     max_epochs: int = 10
     """ Maximum number of trainign epochs"""
@@ -260,7 +264,7 @@ class UnitYFinetune:
             eps=1e-08,
             maximize=False,
             weight_decay=0.0,
-            fused=True,
+            fused=(self.params.device.type == "cuda"),
         )
         self.grad_scaler = torch.cuda.amp.GradScaler()  # type: ignore
         self.lr_scheduler = MyleLR(
@@ -321,9 +325,10 @@ class UnitYFinetune:
         loss_hist = LossCollector(device=self.params.device)
         self.model.eval()
         with torch.no_grad():
-            for batch in self.eval_data_loader.get_dataloader():
+            for batch in tqdm(self.eval_data_loader.get_dataloader()):
                 assert batch.speech_to_text.src_tokens is not None
-                loss = self.calc_loss(batch, *self.model(batch))
+                with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
+                    loss = self.calc_loss(batch, *self.model(batch))
                 if loss.isnan():
                     logger.warning("Eval loss value is NaN, setting to inf")
                     loss_val = float("Inf")
@@ -350,7 +355,7 @@ class UnitYFinetune:
         """Run one train step"""
         self.model.train()
         self.optimizer.zero_grad()
-        with torch.autocast(device_type=self.params.device):
+        with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
             tokens, units = self.model(batch)
         loss = self.calc_loss(batch, tokens, units)
         if loss.isnan().any().item():
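
The trainer changes follow the same pattern: the fused AdamW kernel is only requested when the device is CUDA, and both the train step and the (now tqdm-wrapped) eval loop run the forward pass under torch.autocast keyed on the configured device type and dtype. A minimal sketch of that eval pattern, assuming stand-in model, eval_loader, and calc_loss objects rather than the trainer's real members:

    import torch
    from tqdm import tqdm

    def evaluate(model, eval_loader, calc_loss,
                 device: torch.device, dtype: torch.dtype) -> float:
        # Evaluate without gradients, with the forward pass under autocast for
        # the configured device/dtype and a tqdm progress bar over batches.
        model.eval()
        total = 0.0
        with torch.no_grad():
            for batch in tqdm(eval_loader):
                with torch.autocast(device_type=device.type, dtype=dtype):
                    total += float(calc_loss(batch, model(batch)))
        return total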