|  | @@ -4,7 +4,7 @@
 | 
	
		
			
				|  |  |  # LICENSE file in the root directory of this source tree.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from pathlib import Path
 | 
	
		
			
				|  |  | -from typing import Dict, List, Optional, Tuple, Union
 | 
	
		
			
				|  |  | +from typing import Callable, Dict, List, Optional, Tuple, Union
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import torch
 | 
	
		
			
				|  |  |  import torch.nn as nn
 | 
	
	
		
			
				|  | @@ -15,7 +15,7 @@ from fairseq2.data.text.text_tokenizer import TextTokenizer
 | 
	
		
			
				|  |  |  from fairseq2.data.typing import StringLike
 | 
	
		
			
				|  |  |  from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
 | 
	
		
			
				|  |  |  from fairseq2.memory import MemoryBlock
 | 
	
		
			
				|  |  | -from fairseq2.typing import Device
 | 
	
		
			
				|  |  | +from fairseq2.typing import DataType, Device
 | 
	
		
			
				|  |  |  from torch import Tensor
 | 
	
		
			
				|  |  |  from enum import Enum, auto
 | 
	
		
			
				|  |  |  from seamless_communication.models.inference.ngram_repeat_block_processor import (
 | 
	
	
		
			
				|  | @@ -53,14 +53,14 @@ class Translator(nn.Module):
 | 
	
		
			
				|  |  |          model_name_or_card: Union[str, AssetCard],
 | 
	
		
			
				|  |  |          vocoder_name_or_card: Union[str, AssetCard],
 | 
	
		
			
				|  |  |          device: Device,
 | 
	
		
			
				|  |  | +        dtype: DataType,
 | 
	
		
			
				|  |  |          sample_rate: int = 16000,
 | 
	
		
			
				|  |  |      ):
 | 
	
		
			
				|  |  |          super().__init__()
 | 
	
		
			
				|  |  |          # Load the model.
 | 
	
		
			
				|  |  | -        self.model: UnitYModel = load_unity_model(
 | 
	
		
			
				|  |  | -            model_name_or_card, device=device, dtype=torch.float16
 | 
	
		
			
				|  |  | +        self.model: UnitYModel = self.load_model_for_inference(
 | 
	
		
			
				|  |  | +            load_unity_model, model_name_or_card, device, dtype
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  | -        self.model.eval()
 | 
	
		
			
				|  |  |          self.text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
 | 
	
		
			
				|  |  |          self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
 | 
	
		
			
				|  |  |          self.device = device
 | 
	
	
		
			
				|  | @@ -71,15 +71,27 @@ class Translator(nn.Module):
 | 
	
		
			
				|  |  |              channel_last=True,
 | 
	
		
			
				|  |  |              standardize=True,
 | 
	
		
			
				|  |  |              device=device,
 | 
	
		
			
				|  |  | -            dtype=torch.float16,
 | 
	
		
			
				|  |  | +            dtype=dtype,
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |          self.collate = Collater(
 | 
	
		
			
				|  |  |              pad_idx=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  |          # Load the vocoder.
 | 
	
		
			
				|  |  | -        self.vocoder: Vocoder = load_vocoder_model(vocoder_name_or_card, device=device)
 | 
	
		
			
				|  |  | -        self.vocoder.eval()
 | 
	
		
			
				|  |  | -        self.sr = sample_rate
 | 
	
		
			
				|  |  | +        self.vocoder: Vocoder = self.load_model_for_inference(
 | 
	
		
			
				|  |  | +            load_vocoder_model, vocoder_name_or_card, device, torch.float32
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  | +        self.sample_rate = sample_rate
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    def load_model_for_inference(
 | 
	
		
			
				|  |  | +        load_model_fn: Callable[..., nn.Module],
 | 
	
		
			
				|  |  | +        model_name_or_card: Union[str, AssetCard],
 | 
	
		
			
				|  |  | +        device: Device,
 | 
	
		
			
				|  |  | +        dtype: DataType,
 | 
	
		
			
				|  |  | +    ) -> nn.Module:
 | 
	
		
			
				|  |  | +        model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
 | 
	
		
			
				|  |  | +        model.eval()
 | 
	
		
			
				|  |  | +        return model
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @classmethod
 | 
	
		
			
				|  |  |      def get_prediction(
 | 
	
	
		
			
				|  | @@ -136,17 +148,7 @@ class Translator(nn.Module):
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              return Modality.TEXT, Modality.SPEECH
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    @torch.no_grad()
 | 
	
		
			
				|  |  | -    def synthesize_speech(
 | 
	
		
			
				|  |  | -        self,
 | 
	
		
			
				|  |  | -        code: List[int],
 | 
	
		
			
				|  |  | -        lang: str,
 | 
	
		
			
				|  |  | -        speaker: Optional[int] = None,
 | 
	
		
			
				|  |  | -        dur_prediction: Optional[bool] = True,
 | 
	
		
			
				|  |  | -    ) -> Tuple[List[Tensor], int]:
 | 
	
		
			
				|  |  | -        return self.vocoder(code, lang, speaker, dur_prediction), self.sr
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    @torch.no_grad()
 | 
	
		
			
				|  |  | +    @torch.inference_mode()
 | 
	
		
			
				|  |  |      def predict(
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          input: Union[str, torch.Tensor],
 | 
	
	
		
			
				|  | @@ -173,8 +175,8 @@ class Translator(nn.Module):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          :returns:
 | 
	
		
			
				|  |  |              - Translated text.
 | 
	
		
			
				|  |  | -            - Audio waveform.
 | 
	
		
			
				|  |  | -            - Sampling rate of audio waveform.
 | 
	
		
			
				|  |  | +            - Generated output audio waveform corresponding to the translated text.
 | 
	
		
			
				|  |  | +            - Sample rate of output audio waveform.
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  |              task = Task[task_str.upper()]
 | 
	
	
		
			
				|  | @@ -192,7 +194,7 @@ class Translator(nn.Module):
 | 
	
		
			
				|  |  |              else:
 | 
	
		
			
				|  |  |                  decoded_audio = {
 | 
	
		
			
				|  |  |                      "waveform": audio,
 | 
	
		
			
				|  |  | -                    "sample_rate": self.sr,
 | 
	
		
			
				|  |  | +                    "sample_rate": self.sample_rate,
 | 
	
		
			
				|  |  |                      "format": -1,
 | 
	
		
			
				|  |  |                  }
 | 
	
		
			
				|  |  |              src = self.collate(self.convert_to_fbank(decoded_audio))["fbank"]
 | 
	
	
		
			
				|  | @@ -223,5 +225,5 @@ class Translator(nn.Module):
 | 
	
		
			
				|  |  |              return text_out.sentences[0], None, None
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
 | 
	
		
			
				|  |  | -            wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
 | 
	
		
			
				|  |  | -            return text_out.sentences[0], wav_out, sr_out
 | 
	
		
			
				|  |  | +            wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
 | 
	
		
			
				|  |  | +            return text_out.sentences[0], wav_out, self.sample_rate
 |