Adding dtype to Translator's __init__().

Kaushik Ram Sadagopan · commit 6c02a4fb08

+ 1 - 1
scripts/m4t/predict/README.md

@@ -63,7 +63,7 @@ from seamless_communication.models.inference import Translator
 
 
 # Initialize a Translator object with a multitask model and vocoder on the GPU.
-translator = Translator("seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0"))
+translator = Translator("seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0"), torch.float16)
 ```
 
Now `predict()` can be used to run inference as many times as needed on any of the supported tasks.
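
For example, a text-to-text call might look like the following. This is a minimal sketch: the constructor arguments and the `(translated_text, wav, sr)` return triple match the diffs on this page, while the task string, language codes, and the `src_lang` keyword are illustrative assumptions, not confirmed by this commit.

```python
import torch
from seamless_communication.models.inference import Translator

translator = Translator(
    "seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0"), torch.float16
)

# T2TT: text in, text out. Task string and language codes are
# illustrative; wav/sr are presumably unused for text-only tasks.
translated_text, wav, sr = translator.predict(
    "Bonjour le monde!",  # illustrative input sentence
    "t2tt",               # assumed task identifier
    "eng",                # assumed target-language code
    src_lang="fra",       # assumed keyword for the source language
)
print(translated_text)
```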

+ 5 - 3
scripts/m4t/predict/predict.py

@@ -61,12 +61,14 @@ def main():
 
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
-        logger.info("Running inference on the GPU.")
+        dtype = torch.float16
+        logger.info(f"Running inference on the GPU in {dtype}.")
     else:
         device = torch.device("cpu")
-        logger.info("Running inference on the CPU.")
+        dtype = torch.float32
+        logger.info(f"Running inference on the CPU in {dtype}.")
 
-    translator = Translator(args.model_name, args.vocoder_name, device)
+    translator = Translator(args.model_name, args.vocoder_name, device, dtype)
     translated_text, wav, sr = translator.predict(
         args.input,
         args.task,

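The selection logic above pairs each device with a dtype: fp16 on GPU, where it roughly halves activation memory and enables faster kernels, and fp32 on CPU, where fast fp16 kernels are generally unavailable. A hypothetical helper (not part of this commit) that factors out the pairing could look like:

```python
import torch
from typing import Tuple

def default_device_and_dtype() -> Tuple[torch.device, torch.dtype]:
    """Hypothetical helper mirroring main() above: fp16 on GPU, fp32 on CPU."""
    if torch.cuda.is_available():
        # fp16 halves memory traffic and uses faster GPU kernels.
        return torch.device("cuda:0"), torch.float16
    # Most CPU ops have no fast fp16 path, so stay in full precision.
    return torch.device("cpu"), torch.float32
```
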
+ 1 - 1
src/seamless_communication/models/inference/translator.py

@@ -53,10 +53,10 @@ class Translator(nn.Module):
         model_name_or_card: Union[str, AssetCard],
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
+        dtype: DataType,
         sample_rate: int = 16000,
     ):
         super().__init__()
-        dtype = torch.float16 if "cuda" in device.type else torch.float32
         # Load the model.
         self.model: UnitYModel = self.load_model_for_inference(
             load_unity_model, model_name_or_card, device, dtype
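
With `dtype` promoted to an explicit constructor argument, callers are no longer locked to the old device-based default that the removed line hard-coded. As one hypothetical use, fp32 could be forced on a GPU to check whether an output issue stems from fp16 rounding (the constructor call below follows this commit's signature; the debugging scenario is an assumption):

```python
import torch
from seamless_communication.models.inference import Translator

# Full precision on GPU: slower and roughly double the memory,
# but rules out fp16 numerics as a source of artifacts.
translator = Translator(
    "seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0"), torch.float32
)
```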