ソースを参照

Fix float precision in Finetune CLI (#399)

* Fix float precision in Finetune CLI

* Load model in f32

* Incorrect line
Alisamar Husain 1 年間 前
コミット
5effa1f2a2

+ 1 - 1
src/seamless_communication/cli/m4t/finetune/finetune.py

@@ -144,7 +144,7 @@ def main() -> None:
     )
     logger.info(f"Finetune params: {finetune_params}")
     model: UnitYModel = load_unity_model(
-        args.model_name, device=torch.device("cpu"), dtype=float_dtype
+        args.model_name, device=torch.device("cpu"), dtype=torch.float32
     )
     assert model.target_vocab_info == text_tokenizer.vocab_info
     # (optional) delete unused params to reduce GPU memory consumption

+ 2 - 1
src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -350,7 +350,8 @@ class UnitYFinetune:
         """Run one train step"""
         self.model.train()
         self.optimizer.zero_grad()
-        tokens, units = self.model(batch)
+        with torch.autocast(device_type=self.params.device):
+            tokens, units = self.model(batch)
         loss = self.calc_loss(batch, tokens, units)
         if loss.isnan().any().item():
             logger.error(batch.speech_to_text)