Forráskód Böngészése

Update API for translator

cndn 1 éve
szülő
commit
8814dce9f6

+ 3 - 1
src/seamless_communication/models/inference/translator.py

@@ -53,10 +53,12 @@ class Translator(nn.Module):
         model_name_or_card: Union[str, AssetCard],
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
-        dtype: DataType,
+        dtype: DataType = torch.float16,
     ):
         super().__init__()
         # Load the model.
+        if device == torch.device("cpu"):
+            dtype = torch.float32
         self.model: UnitYModel = self.load_model_for_inference(
             load_unity_model, model_name_or_card, device, dtype
         )