Enable M4T vocoder inference on the GPU in fp16. (#151)

* Enable M4T vocoder inference on the GPU; remove unnecessary device-to-host syncs in the Translator.

* Address review comment; disable the non-deterministic GPU integration test.
Kaushik Ram Sadagopan
commit fb59ee0a49
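
With this commit the vocoder is loaded in the same dtype as the rest of the pipeline instead of being pinned to torch.float32, so the whole speech path can run in half precision on the GPU. A minimal sketch of the intended usage, assuming a CUDA device is available (the constructor arguments mirror the integration tests below):

    import torch
    from fairseq2.typing import Device
    from seamless_communication.inference import Translator

    # The vocoder now inherits `dtype` from the Translator, so the text,
    # unit, and vocoder models are all loaded in fp16 on the GPU.
    translator = Translator(
        "seamlessM4T_v2_large",
        "vocoder_v2",
        Device("cuda:0"),
        dtype=torch.float16,
    )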

+ 12 - 12
src/seamless_communication/inference/translator.py

@@ -151,7 +151,7 @@ class Translator(nn.Module):
             output_modality is None or output_modality == Modality.SPEECH
         ):
             self.vocoder = load_vocoder_model(
-                vocoder_name_or_card, device=device, dtype=torch.float32
+                vocoder_name_or_card, device=device, dtype=dtype
             )
             self.vocoder.eval()

@@ -232,7 +232,7 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-        src_text: Optional[str] = None
+        src_text: Optional[str] = None,
     ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
@@ -259,15 +259,17 @@ class Translator(nn.Module):
             Optional source transcript (obtained by ASR for instance). This is used for
             applying mintox toxicity mitigation. If this is not specify and apply_mintox=True
             then src_lang must be specified and ASR will be run on the audio source.
-
+
         :returns:
             - Batched list of Translated text.
             - Translated BatchedSpeechOutput.
         """
         input_modality, output_modality = self.get_modalities_from_task_str(task_str)

-        if self.apply_mintox and not (src_lang is not None or src_text is not None) :
-            raise ValueError("`src_lang` must be specified when `apply_mintox` is `True` or you need to specify src_text.")
+        if self.apply_mintox and not (src_lang is not None or src_text is not None):
+            raise ValueError(
+                "`src_lang` must be specified when `apply_mintox` is `True` or you need to specify src_text."
+            )

         if isinstance(input, dict):
             src = cast(SequenceData, input)
@@ -384,20 +386,18 @@

             audio_wavs = []
             speech_units = []
-            for i in range(len(unit_output.units)):
-                u = units[i].cpu().numpy().tolist()
-                index_of_first_one = next(
-                    (index for index, value in enumerate(u) if value == 1), len(u)
+            for i in range(len(units)):
+                padding_mask = (
+                    units[i] != self.model.t2u_model.target_vocab_info.pad_idx
                 )
-                u = u[:index_of_first_one]
-                speech_units.append(u)
+                u = units[i][padding_mask]
                 if self.vocoder is not None:
                     # TODO: Implement batched inference for vocoder.
                     translated_audio_wav = self.vocoder(
                         u, tgt_lang, spkr, dur_prediction=duration_prediction
                     )
                     audio_wavs.append(translated_audio_wav)
-
+                speech_units.append(u.tolist())
             return (
                 text_output.sentences,
                 BatchedSpeechOutput(
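
The loop above is the core device-to-host sync removal: instead of copying every unit sequence to the host and scanning Python-side for the first pad unit, the new code masks padding on the device and defers the single .tolist() call to the end. A standalone sketch of the two approaches; pad_idx is an illustrative value (the real code reads self.model.t2u_model.target_vocab_info.pad_idx), and the equivalence assumes padding only occurs as a trailing run:

    import torch

    pad_idx = 1  # illustrative value only
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    units = torch.tensor([17, 99, 42, pad_idx, pad_idx], device=device)

    # Before: a device-to-host copy per sequence, then a Python-level scan.
    u_list = units.cpu().tolist()
    first_pad = next((i for i, v in enumerate(u_list) if v == pad_idx), len(u_list))
    old_u = u_list[:first_pad]

    # After: boolean masking stays on the device; no host transfer is needed
    # until the trimmed units are serialized for the output container.
    new_u = units[units != pad_idx]

    assert new_u.tolist() == old_u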

+ 5 - 10
src/seamless_communication/models/unit_extractor/unit_extractor.py

@@ -5,9 +5,8 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-from itertools import groupby
 from pathlib import Path
-from typing import List, Union
+from typing import Union

 import torch
 import torch.nn.functional as F
@@ -103,15 +102,11 @@ class UnitExtractor(nn.Module):
         units: Tensor,
         src_lang: str,
         device: Device,
-        vocoder_name: str = "vocoder_36langs",
+        dtype: DataType,
+        vocoder_name: str = "vocoder_v2",
     ) -> Tensor:
-        def reduce_list(lst: List[Tensor]) -> List[Tensor]:
-            return [key for key, _ in groupby(lst)]
-
-        reduced_units = reduce_list(units.cpu().tolist())
-
-        vocoder = load_vocoder_model(vocoder_name, device=device, dtype=torch.float32)
+        vocoder = load_vocoder_model(vocoder_name, device=device, dtype=dtype)
         vocoder.eval()
         assert isinstance(vocoder, Vocoder)
-        wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
+        wav = vocoder(units, src_lang, spkr=-1, dur_prediction=True)
         return wav  # type: ignore[no-any-return]
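
resynthesize_audio now takes the unit tensor and a dtype directly, dropping both the groupby-based unit reduction and the CPU round-trip. A sketch of the new call shape, treating it as the static helper its signature suggests and using dummy units in place of real UnitExtractor.predict() output:

    import torch
    from fairseq2.typing import Device
    from seamless_communication.models.unit_extractor import UnitExtractor

    device = Device("cuda:0") if torch.cuda.is_available() else Device("cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    # Dummy units for illustration only.
    units = torch.randint(0, 10000, (50,), device=device)

    wav = UnitExtractor.resynthesize_audio(units, src_lang="eng", device=device, dtype=dtype)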

+ 5 - 6
src/seamless_communication/models/vocoder/builder.py

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.typing import DataType, Device
@@ -32,7 +32,7 @@ class VocoderConfig:
     num_langs: int
     spkr_embedding_dim: int
     num_spkrs: int
-    lang_spkr_idx_map: Dict
+    lang_spkr_idx_map: Dict[str, Any]


 vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_code_hifigan")
@@ -93,7 +93,6 @@ class VocoderBuilder:
             The data type of module parameters and buffers.
         """
         self.config = config
-
         self.device, self.dtype = device, dtype

     def build_model(self) -> Vocoder:
@@ -114,8 +113,8 @@ class VocoderBuilder:
             self.config.spkr_embedding_dim,
             self.config.num_spkrs,
         )
+        code_generator.to(device=self.device, dtype=self.dtype)
         vocoder = Vocoder(code_generator, self.config.lang_spkr_idx_map)
-        vocoder.to(dtype=self.dtype)
         return vocoder


@@ -163,7 +162,7 @@


 @mel_vocoder_arch("24khz_mel")
-def _base_mel_vocoder() -> VocoderConfig:
+def _24khz_mel_vocoder() -> VocoderConfig:
     return VocoderConfig(
         upsample_rates=[5, 4, 4, 3],
         upsample_kernel_sizes=[10, 8, 8, 6],
@@ -206,7 +205,7 @@ class MelVocoderBuilder:
             self.config.resblock_dilation_sizes,
             self.config.model_in_dim,
         )
-        generator.to(dtype=self.dtype, device=self.device)
+        generator.to(device=self.device, dtype=self.dtype)
         return generator



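The build_model change above places the CodeGenerator's weights on the target device and dtype in one call at build time; the old code cast only the dtype, leaving device placement implicit. A generic illustration of why the combined .to() matters:

    import torch
    from torch.nn import Linear

    m = Linear(4, 4)

    # dtype-only cast: parameters are fp16 but still on the CPU, so GPU
    # inputs would mismatch or force copies on every forward call.
    m.to(dtype=torch.float16)
    print(m.weight.device, m.weight.dtype)  # cpu torch.float16

    # One call that both places and casts the parameters.
    if torch.cuda.is_available():
        m.to(device=torch.device("cuda"), dtype=torch.float16)
        print(m.weight.device, m.weight.dtype)  # cuda:0 torch.float16
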
+ 3 - 3
src/seamless_communication/models/vocoder/codehifigan.py

@@ -73,7 +73,7 @@ class CodeGenerator(Generator):
         return signal

     def forward(self, sample: Dict[str, Any], dur_prediction: bool) -> Tensor:  # type: ignore
-        x = sample["code"].clone().to(device=self.dict.weight.device)
+        x = sample["code"]
         x = self.dict(x).transpose(1, 2)

         if self.dur_predictor and dur_prediction:
@@ -85,11 +85,11 @@ class CodeGenerator(Generator):
             # B x C x T
             x = torch.repeat_interleave(x, dur_out.view(-1), dim=2)

-        spkr = self.spkr(sample["spkr"].to(self.spkr.weight.device)).transpose(1, 2)
+        spkr = self.spkr(sample["spkr"]).transpose(1, 2)
         spkr = self._upsample(spkr, x.shape[-1])
         x = torch.cat([x, spkr], dim=1)

-        lang = self.lang(sample["lang"].to(self.lang.weight.device)).transpose(1, 2)
+        lang = self.lang(sample["lang"]).transpose(1, 2)
         lang = self._upsample(lang, x.shape[-1])
         x = torch.cat([lang, x], dim=1)


+ 18 - 10
src/seamless_communication/models/vocoder/hifigan.py

@@ -6,15 +6,23 @@

 from typing import List, Optional

+import logging
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+
 from torch import Tensor
-from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn import Conv1d, ConvTranspose1d, Module, ModuleList
 from torch.nn.utils.weight_norm import remove_weight_norm, weight_norm

 LRELU_SLOPE = 0.1

+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+

 def init_weights(m, mean: float = 0.0, std: float = 0.01) -> None:  # type: ignore
     classname = m.__class__.__name__
@@ -26,12 +34,12 @@ def get_padding(kernel_size: int, dilation: int = 1) -> int:
     return (kernel_size * dilation - dilation) // 2


-class ResBlock(torch.nn.Module):
+class ResBlock(Module):
     def __init__(
         self, channels: int, kernel_size: int = 3, dilation: List[int] = [1, 3, 5]
     ):
         super(ResBlock, self).__init__()
-        self.convs1 = nn.ModuleList(
+        self.convs1 = ModuleList(
             [
                 weight_norm(
                     Conv1d(
@@ -67,7 +75,7 @@ class ResBlock(torch.nn.Module):
         )
         self.convs1.apply(init_weights)

-        self.convs2 = nn.ModuleList(
+        self.convs2 = ModuleList(
             [
                 weight_norm(
                     Conv1d(
@@ -119,7 +127,7 @@ class ResBlock(torch.nn.Module):
             remove_weight_norm(layer)


-class Generator(torch.nn.Module):
+class Generator(Module):
     def __init__(
         self,
         upsample_rates: List[int],
@@ -130,7 +138,7 @@ class Generator(torch.nn.Module):
         model_in_dim: Optional[int],
         add_ups_out_pad: bool = False,
     ):
-        super(Generator, self).__init__()
+        super().__init__()
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
         self.conv_pre = weight_norm(
@@ -143,7 +151,7 @@ class Generator(torch.nn.Module):
             )
         )

-        self.ups = nn.ModuleList()
+        self.ups = ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
             out_pad = u % 2 if add_ups_out_pad else 0
             self.ups.append(
@@ -159,7 +167,7 @@ class Generator(torch.nn.Module):
                 )
             )

-        self.resblocks = nn.ModuleList()
+        self.resblocks = ModuleList()
         for i in range(len(self.ups)):
             ch = upsample_initial_channel // (2 ** (i + 1))
             for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
@@ -188,7 +196,7 @@ class Generator(torch.nn.Module):
         return x

     def remove_weight_norm(self) -> None:
-        print("Removing weight norm...")
+        logger.info("Removing weight norm in Generator.")
         for layer in self.ups:
             remove_weight_norm(layer)
         for layer in self.resblocks:

+ 1 - 1
src/seamless_communication/models/vocoder/loader.py

@@ -34,7 +34,7 @@ def convert_vocoder_checkpoint(
     for key in old_state_dict:
         new_key = f"code_generator.{key}"
         new_state_dict[new_key] = old_state_dict[key]
-    checkpoint["model"] = new_state_dict
+    checkpoint["model"] = new_state_dict  # type: ignore
     del checkpoint["generator"]  # type: ignore
     return checkpoint


+ 16 - 12
src/seamless_communication/models/vocoder/vocoder.py

@@ -4,36 +4,40 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict, List, Optional
+from typing import Any, Dict, Optional

 import torch
-import torch.nn as nn
 from torch import Tensor
+from torch.nn import Module

 from seamless_communication.models.vocoder.codehifigan import CodeGenerator


-class Vocoder(nn.Module):
-    def __init__(self, code_generator: CodeGenerator, lang_spkr_idx_map: Dict):
-        super(Vocoder, self).__init__()
+class Vocoder(Module):
+    def __init__(
+        self,
+        code_generator: CodeGenerator,
+        lang_spkr_idx_map: Dict[str, Any],
+    ):
+        super().__init__()
         self.code_generator = code_generator
         self.lang_spkr_idx_map = lang_spkr_idx_map

     def forward(
         self,
-        code: List[int],
+        units: Tensor,
         lang: str,
         spkr: Optional[int] = -1,
         dur_prediction: bool = True,
     ) -> Tensor:
-        x = {
-            "code": torch.LongTensor(code).view(1, -1),
-        }
         lang_idx = self.lang_spkr_idx_map["multilingual"][lang]
         spkr_list = self.lang_spkr_idx_map["multispkr"][lang]
         if not spkr:
             spkr = -1
         spkr = spkr_list[0] if spkr == -1 else spkr
-        x["spkr"] = torch.tensor([[spkr]])
-        x["lang"] = torch.tensor([[lang_idx]])
-        return self.code_generator(x, dur_prediction)
+        x = {
+            "code": units.view(1, -1),
+            "spkr": torch.tensor([[spkr]], device=units.device),
+            "lang": torch.tensor([[lang_idx]], device=units.device),
+        }
+        return self.code_generator(x, dur_prediction)  # type: ignore[no-any-return]
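
Vocoder.forward now takes the units as a tensor and builds the speaker and language ID tensors on the units' device, so callers no longer pass a host-side List[int]. A sketch of the new call shape, assuming load_vocoder_model is importable from the vocoder package as it is used in unit_extractor.py above, with dummy units:

    import torch
    from seamless_communication.models.vocoder import load_vocoder_model

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    vocoder = load_vocoder_model("vocoder_v2", device=device, dtype=dtype)
    vocoder.eval()

    units = torch.randint(0, 10000, (64,), device=device)  # dummy units
    with torch.inference_mode():
        wav = vocoder(units, "eng", spkr=-1, dur_prediction=True)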

+ 4 - 5
src/seamless_communication/streaming/agents/online_vocoder.py

@@ -37,15 +37,14 @@ class VocoderAgent(TextToSpeechAgent):  # type: ignore
                 return ReadAction()

         tgt_lang = states.tgt_lang if states.tgt_lang else self.tgt_lang
-        u = units[0][0].tolist()
-        wav_samples = self.vocoder(u, tgt_lang, self.speaker_id, dur_prediction=False)[
-            0
-        ][0].tolist()
+        u = units[0][0]
+
+        wav = self.vocoder(u, tgt_lang, self.speaker_id, dur_prediction=False)
         states.source = []

         return WriteAction(
             SpeechSegment(
-                content=wav_samples,
+                content=wav[0][0].tolist(),
                 finished=states.source_finished,
                 sample_rate=self.sample_rate,
                 tgt_lang=tgt_lang,

+ 0 - 3
tests/integration/inference/test_translator.py

@@ -6,9 +6,6 @@

 from typing import Final

-import torch
-from fairseq2.typing import Device
-
 from seamless_communication.inference import Translator
 from tests.common import device, get_default_dtype


+ 1 - 1
tests/integration/models/test_pretssel_vocoder.py

@@ -27,7 +27,7 @@ def test_pretssel_vocoder(example_rate16k_audio: AudioDecoderOutput) -> None:

     feat = convert_to_collated_fbank(audio_dict, dtype=dtype)["seqs"][0]

-    vocoder = load_mel_vocoder_model("vocoder_mel", device=device, dtype=torch.float32)
+    vocoder = load_mel_vocoder_model("vocoder_mel", device=device, dtype=dtype)
     vocoder.eval()

     with torch.inference_mode():

+ 5 - 3
tests/integration/models/test_unit_extractor.py

@@ -7,12 +7,12 @@
 from typing import Final

 import torch
-from fairseq2.typing import Device
 from torch import tensor

+from fairseq2.typing import Device
 from seamless_communication.inference import Translator
 from seamless_communication.models.unit_extractor import UnitExtractor
-from tests.common import assert_equal, device, get_default_dtype
+from tests.common import assert_equal

 # fmt: off
 REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 8119, 8119, 4335, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 8063, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 2839, 6756, 5665, 4191, 5205, 5205, 9568, 9568, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
@@ -23,7 +23,9 @@ def test_unit_extractor() -> None:
     model_name = "seamlessM4T_v2_large"
     english_text = "Hello! I hope you're all doing well."

-    dtype = get_default_dtype()
+    # We can't test on the GPU since the output is non-deterministic.
+    device = Device("cpu")
+    dtype = torch.float32

     translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
     translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)