cndn 1 год назад
Родитель
Сommit
8814dce9f6
1 измененных файлов с 3 добавлено и 1 удалено
  1. 3 1
      src/seamless_communication/models/inference/translator.py

+ 3 - 1
src/seamless_communication/models/inference/translator.py

@@ -53,10 +53,12 @@ class Translator(nn.Module):
         model_name_or_card: Union[str, AssetCard],
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
-        dtype: DataType,
+        dtype: DataType = torch.float16,
     ):
         super().__init__()
         # Load the model.
+        if device == torch.device("cpu"):
+            dtype = torch.float32
         self.model: UnitYModel = self.load_model_for_inference(
             load_unity_model, model_name_or_card, device, dtype
         )