Browse Source

Merge pull request #136 from facebookresearch/api_update

Update API for translator
Ning 1 năm trước cách đây
mục cha
commit
21a7fb5a41
1 tập tin đã thay đổi với 3 bổ sung1 xóa
  1. 3 1
      src/seamless_communication/models/inference/translator.py

+ 3 - 1
src/seamless_communication/models/inference/translator.py

@@ -53,10 +53,12 @@ class Translator(nn.Module):
         model_name_or_card: Union[str, AssetCard],
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
-        dtype: DataType,
+        dtype: DataType = torch.float16,
     ):
         super().__init__()
         # Load the model.
+        if device == torch.device("cpu"):
+            dtype = torch.float32
         self.model: UnitYModel = self.load_model_for_inference(
             load_unity_model, model_name_or_card, device, dtype
         )