Procházet zdrojové kódy

Fix dtype issues with CPU inference, minor refactoring to Translator.

Kaushik Ram Sadagopan před 2 roky
rodič
revize
418134aadf

+ 27 - 25
src/seamless_communication/models/inference/translator.py

@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -15,7 +15,7 @@ from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
 from fairseq2.memory import MemoryBlock
-from fairseq2.typing import Device
+from fairseq2.typing import DataType, Device
 from torch import Tensor
 from enum import Enum, auto
 from seamless_communication.models.inference.ngram_repeat_block_processor import (
@@ -56,11 +56,11 @@ class Translator(nn.Module):
         sample_rate: int = 16000,
     ):
         super().__init__()
+        dtype = torch.float16 if "cuda" in device.type else torch.float32
         # Load the model.
-        self.model: UnitYModel = load_unity_model(
-            model_name_or_card, device=device, dtype=torch.float16
+        self.model: UnitYModel = self.load_model_for_inference(
+            load_unity_model, model_name_or_card, device, dtype
         )
-        self.model.eval()
         self.text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
         self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
         self.device = device
@@ -71,15 +71,27 @@ class Translator(nn.Module):
             channel_last=True,
             standardize=True,
             device=device,
-            dtype=torch.float16,
+            dtype=dtype,
         )
         self.collate = Collater(
             pad_idx=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
         )
         # Load the vocoder.
-        self.vocoder: Vocoder = load_vocoder_model(vocoder_name_or_card, device=device)
-        self.vocoder.eval()
-        self.sr = sample_rate
+        self.vocoder: Vocoder = self.load_model_for_inference(
+            load_vocoder_model, vocoder_name_or_card, device, torch.float32
+        )
+        self.sample_rate = sample_rate
+
+    @staticmethod
+    def load_model_for_inference(
+        load_model_fn: Callable[..., nn.Module],
+        model_name_or_card: Union[str, AssetCard],
+        device: Device,
+        dtype: DataType,
+    ) -> nn.Module:
+        model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
+        model.eval()
+        return model
 
     @classmethod
     def get_prediction(
@@ -136,17 +148,7 @@ class Translator(nn.Module):
         else:
             return Modality.TEXT, Modality.SPEECH
 
-    @torch.no_grad()
-    def synthesize_speech(
-        self,
-        code: List[int],
-        lang: str,
-        speaker: Optional[int] = None,
-        dur_prediction: Optional[bool] = True,
-    ) -> Tuple[List[Tensor], int]:
-        return self.vocoder(code, lang, speaker, dur_prediction), self.sr
-
-    @torch.no_grad()
+    @torch.inference_mode()
     def predict(
         self,
         input: Union[str, torch.Tensor],
@@ -173,8 +175,8 @@ class Translator(nn.Module):
 
         :returns:
             - Translated text.
-            - Audio waveform.
-            - Sampling rate of audio waveform.
+            - Generated output audio waveform corresponding to the translated text.
+            - Sample rate of output audio waveform.
         """
         try:
             task = Task[task_str.upper()]
@@ -192,7 +194,7 @@ class Translator(nn.Module):
             else:
                 decoded_audio = {
                     "waveform": audio,
-                    "sample_rate": self.sr,
+                    "sample_rate": self.sample_rate,
                     "format": -1,
                 }
             src = self.collate(self.convert_to_fbank(decoded_audio))["fbank"]
@@ -223,5 +225,5 @@ class Translator(nn.Module):
             return text_out.sentences[0], None, None
         else:
             units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
-            wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
-            return text_out.sentences[0], wav_out, sr_out
+            wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
+            return text_out.sentences[0], wav_out, self.sample_rate