Parcourir la source

Update API for translator

cndn il y a 1 an
Parent
commit
8814dce9f6
1 fichiers modifiés avec 3 ajouts et 1 suppressions
  1. 3 1
      src/seamless_communication/models/inference/translator.py

+ 3 - 1
src/seamless_communication/models/inference/translator.py

@@ -53,10 +53,12 @@ class Translator(nn.Module):
         model_name_or_card: Union[str, AssetCard],
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
-        dtype: DataType,
+        dtype: DataType = torch.float16,
     ):
         super().__init__()
         # Load the model.
+        if device == torch.device("cpu"):
+            dtype = torch.float32
         self.model: UnitYModel = self.load_model_for_inference(
             load_unity_model, model_name_or_card, device, dtype
         )