Enable M4T vocoder inference on the GPU in fp16. (#151)

* Enable M4T vocoder inference on the GPU, remove unnecessary device-to-host syncs in the Translator.

* Address comment, disable non-deterministic GPU integ test.
Kaushik Ram Sadagopan 1 year ago
commit fb59ee0a49
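
As a usage sketch (not part of the diff): with this change the vocoder inherits the Translator's device and dtype, so the whole pipeline can run in fp16 on the GPU. The model and vocoder card names below come from the tests in this commit; the `predict` call and the "t2st" task string are assumptions based on the library's public API.

import torch
from fairseq2.typing import Device

from seamless_communication.inference import Translator

device = Device("cuda:0")
dtype = torch.float16  # previously the vocoder was pinned to torch.float32

translator = Translator("seamlessM4T_v2_large", "vocoder_v2", device, dtype=dtype)

# Text-to-speech translation; the generated waveform stays on the GPU.
text_output, speech_output = translator.predict(
    "Hello! I hope you're all doing well.",
    task_str="t2st",
    tgt_lang="fra",
    src_lang="eng",
)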

+ 12 - 12
src/seamless_communication/inference/translator.py

@@ -151,7 +151,7 @@ class Translator(nn.Module):
             output_modality is None or output_modality == Modality.SPEECH
         ):
             self.vocoder = load_vocoder_model(
-                vocoder_name_or_card, device=device, dtype=torch.float32
+                vocoder_name_or_card, device=device, dtype=dtype
             )
             self.vocoder.eval()
 
@@ -232,7 +232,7 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-        src_text: Optional[str] = None 
+        src_text: Optional[str] = None,
     ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
@@ -259,15 +259,17 @@ class Translator(nn.Module):
             Optional source transcript (obtained by ASR for instance). This is used for
             applying mintox toxicity mitigation. If this is not specified and apply_mintox=True
             then src_lang must be specified and ASR will be run on the audio source.
-            
+
         :returns:
             - Batched list of Translated text.
             - Translated BatchedSpeechOutput.
         """
         input_modality, output_modality = self.get_modalities_from_task_str(task_str)
 
-        if self.apply_mintox and not (src_lang is not None or src_text is not None) :
-            raise ValueError("`src_lang` must be specified when `apply_mintox` is `True` or you need to specify src_text.")
+        if self.apply_mintox and not (src_lang is not None or src_text is not None):
+            raise ValueError(
+                "`src_lang` must be specified when `apply_mintox` is `True` or you need to specify src_text."
+            )
 
         if isinstance(input, dict):
             src = cast(SequenceData, input)
@@ -384,20 +386,18 @@ class Translator(nn.Module):
 
             audio_wavs = []
             speech_units = []
-            for i in range(len(unit_output.units)):
-                u = units[i].cpu().numpy().tolist()
-                index_of_first_one = next(
-                    (index for index, value in enumerate(u) if value == 1), len(u)
+            for i in range(len(units)):
+                padding_mask = (
+                    units[i] != self.model.t2u_model.target_vocab_info.pad_idx
                 )
-                u = u[:index_of_first_one]
-                speech_units.append(u)
+                u = units[i][padding_mask]
                 if self.vocoder is not None:
                     # TODO: Implement batched inference for vocoder.
                     translated_audio_wav = self.vocoder(
                         u, tgt_lang, spkr, dur_prediction=duration_prediction
                     )
                     audio_wavs.append(translated_audio_wav)
-
+                speech_units.append(u.tolist())
             return (
                 text_output.sentences,
                 BatchedSpeechOutput(
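
For reference, a standalone sketch of the trimming idiom introduced above: instead of copying units to the CPU and scanning a Python list for the first pad token, a boolean mask is built against the pad index and applied on the tensor's own device. `pad_idx = 1` here mirrors the old code's hard-coded `value == 1` check and stands in for `self.model.t2u_model.target_vocab_info.pad_idx`.

import torch

pad_idx = 1  # stand-in for target_vocab_info.pad_idx
units = torch.tensor([8976, 8299, 42, pad_idx, pad_idx])  # use device="cuda" for GPU

padding_mask = units != pad_idx  # boolean mask, computed where the units live
trimmed = units[padding_mask]    # still a tensor on the same device; no .cpu() sync
print(trimmed.tolist())          # [8976, 8299, 42]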

+ 5 - 10
src/seamless_communication/models/unit_extractor/unit_extractor.py

@@ -5,9 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from itertools import groupby
 from pathlib import Path
-from typing import List, Union
+from typing import Union
 
 import torch
 import torch.nn.functional as F
@@ -103,15 +102,11 @@ class UnitExtractor(nn.Module):
         units: Tensor,
         src_lang: str,
         device: Device,
-        vocoder_name: str = "vocoder_36langs",
+        dtype: DataType,
+        vocoder_name: str = "vocoder_v2",
     ) -> Tensor:
-        def reduce_list(lst: List[Tensor]) -> List[Tensor]:
-            return [key for key, _ in groupby(lst)]
-
-        reduced_units = reduce_list(units.cpu().tolist())
-
-        vocoder = load_vocoder_model(vocoder_name, device=device, dtype=torch.float32)
+        vocoder = load_vocoder_model(vocoder_name, device=device, dtype=dtype)
         vocoder.eval()
         assert isinstance(vocoder, Vocoder)
-        wav = vocoder(reduced_units, src_lang, spkr=-1, dur_prediction=True)
+        wav = vocoder(units, src_lang, spkr=-1, dur_prediction=True)
         return wav  # type: ignore[no-any-return]
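
A hedged sketch of the updated call site, assuming this hunk belongs to UnitExtractor.resynthesize_audio (the method name is not visible in the hunk) and that it is callable as shown: units are now passed to the vocoder as a tensor, without the removed groupby run-length reduction, and the vocoder is loaded in the caller's dtype rather than fp32.

import torch
from fairseq2.typing import Device

from seamless_communication.models.unit_extractor import UnitExtractor

device = Device("cuda:0")
dtype = torch.float16

units = torch.tensor([8976, 8299, 0, 0, 9692], device=device)  # example unit IDs
wav = UnitExtractor.resynthesize_audio(
    units, src_lang="eng", device=device, dtype=dtype, vocoder_name="vocoder_v2"
)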

+ 5 - 6
src/seamless_communication/models/vocoder/builder.py

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.typing import DataType, Device
@@ -32,7 +32,7 @@ class VocoderConfig:
     num_langs: int
     spkr_embedding_dim: int
     num_spkrs: int
-    lang_spkr_idx_map: Dict
+    lang_spkr_idx_map: Dict[str, Any]
 
 
 vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_code_hifigan")
@@ -93,7 +93,6 @@ class VocoderBuilder:
             The data type of module parameters and buffers.
         """
         self.config = config
-
         self.device, self.dtype = device, dtype
 
     def build_model(self) -> Vocoder:
@@ -114,8 +113,8 @@ class VocoderBuilder:
             self.config.spkr_embedding_dim,
             self.config.num_spkrs,
         )
+        code_generator.to(device=self.device, dtype=self.dtype)
         vocoder = Vocoder(code_generator, self.config.lang_spkr_idx_map)
-        vocoder.to(dtype=self.dtype)
         return vocoder
 
 
@@ -163,7 +162,7 @@ def _base_mel_vocoder() -> VocoderConfig:
 
 
 @mel_vocoder_arch("24khz_mel")
-def _base_mel_vocoder() -> VocoderConfig:
+def _24khz_mel_vocoder() -> VocoderConfig:
     return VocoderConfig(
         upsample_rates=[5, 4, 4, 3],
         upsample_kernel_sizes=[10, 8, 8, 6],
@@ -206,7 +205,7 @@ class MelVocoderBuilder:
             self.config.resblock_dilation_sizes,
             self.config.model_in_dim,
         )
-        generator.to(dtype=self.dtype, device=self.device)
+        generator.to(device=self.device, dtype=self.dtype)
         return generator
 
 

+ 3 - 3
src/seamless_communication/models/vocoder/codehifigan.py

@@ -73,7 +73,7 @@ class CodeGenerator(Generator):
         return signal
 
     def forward(self, sample: Dict[str, Any], dur_prediction: bool) -> Tensor:  # type: ignore
-        x = sample["code"].clone().to(device=self.dict.weight.device)
+        x = sample["code"]
         x = self.dict(x).transpose(1, 2)
 
         if self.dur_predictor and dur_prediction:
@@ -85,11 +85,11 @@ class CodeGenerator(Generator):
             # B x C x T
             x = torch.repeat_interleave(x, dur_out.view(-1), dim=2)
 
-        spkr = self.spkr(sample["spkr"].to(self.spkr.weight.device)).transpose(1, 2)
+        spkr = self.spkr(sample["spkr"]).transpose(1, 2)
         spkr = self._upsample(spkr, x.shape[-1])
         x = torch.cat([x, spkr], dim=1)
 
-        lang = self.lang(sample["lang"].to(self.lang.weight.device)).transpose(1, 2)
+        lang = self.lang(sample["lang"]).transpose(1, 2)
         lang = self._upsample(lang, x.shape[-1])
         x = torch.cat([lang, x], dim=1)
 

+ 18 - 10
src/seamless_communication/models/vocoder/hifigan.py

@@ -6,15 +6,23 @@
 
 from typing import List, Optional
 
+import logging
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+
 from torch import Tensor
-from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn import Conv1d, ConvTranspose1d, Module, ModuleList
 from torch.nn.utils.weight_norm import remove_weight_norm, weight_norm
 
 LRELU_SLOPE = 0.1
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
 
 def init_weights(m, mean: float = 0.0, std: float = 0.01) -> None:  # type: ignore
     classname = m.__class__.__name__
@@ -26,12 +34,12 @@ def get_padding(kernel_size: int, dilation: int = 1) -> int:
     return (kernel_size * dilation - dilation) // 2
 
 
-class ResBlock(torch.nn.Module):
+class ResBlock(Module):
     def __init__(
         self, channels: int, kernel_size: int = 3, dilation: List[int] = [1, 3, 5]
     ):
         super(ResBlock, self).__init__()
-        self.convs1 = nn.ModuleList(
+        self.convs1 = ModuleList(
             [
                 weight_norm(
                     Conv1d(
@@ -67,7 +75,7 @@ class ResBlock(torch.nn.Module):
         )
         self.convs1.apply(init_weights)
 
-        self.convs2 = nn.ModuleList(
+        self.convs2 = ModuleList(
             [
                 weight_norm(
                     Conv1d(
@@ -119,7 +127,7 @@ class ResBlock(torch.nn.Module):
             remove_weight_norm(layer)
 
 
-class Generator(torch.nn.Module):
+class Generator(Module):
     def __init__(
         self,
         upsample_rates: List[int],
@@ -130,7 +138,7 @@ class Generator(torch.nn.Module):
         model_in_dim: Optional[int],
         add_ups_out_pad: bool = False,
     ):
-        super(Generator, self).__init__()
+        super().__init__()
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
         self.conv_pre = weight_norm(
@@ -143,7 +151,7 @@ class Generator(torch.nn.Module):
             )
         )
 
-        self.ups = nn.ModuleList()
+        self.ups = ModuleList()
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
             out_pad = u % 2 if add_ups_out_pad else 0
             self.ups.append(
@@ -159,7 +167,7 @@ class Generator(torch.nn.Module):
                 )
             )
 
-        self.resblocks = nn.ModuleList()
+        self.resblocks = ModuleList()
         for i in range(len(self.ups)):
             ch = upsample_initial_channel // (2 ** (i + 1))
             for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
@@ -188,7 +196,7 @@ class Generator(torch.nn.Module):
         return x
 
     def remove_weight_norm(self) -> None:
-        print("Removing weight norm...")
+        logger.info("Removing weight norm in Generator.")
         for layer in self.ups:
             remove_weight_norm(layer)
         for layer in self.resblocks:

+ 1 - 1
src/seamless_communication/models/vocoder/loader.py

@@ -34,7 +34,7 @@ def convert_vocoder_checkpoint(
     for key in old_state_dict:
         new_key = f"code_generator.{key}"
         new_state_dict[new_key] = old_state_dict[key]
-    checkpoint["model"] = new_state_dict
+    checkpoint["model"] = new_state_dict  # type: ignore
     del checkpoint["generator"]  # type: ignore
     return checkpoint
 

+ 16 - 12
src/seamless_communication/models/vocoder/vocoder.py

@@ -4,36 +4,40 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict, List, Optional
+from typing import Any, Dict, Optional
 
 import torch
-import torch.nn as nn
 from torch import Tensor
+from torch.nn import Module
 
 from seamless_communication.models.vocoder.codehifigan import CodeGenerator
 
 
-class Vocoder(nn.Module):
-    def __init__(self, code_generator: CodeGenerator, lang_spkr_idx_map: Dict):
-        super(Vocoder, self).__init__()
+class Vocoder(Module):
+    def __init__(
+        self,
+        code_generator: CodeGenerator,
+        lang_spkr_idx_map: Dict[str, Any],
+    ):
+        super().__init__()
         self.code_generator = code_generator
         self.lang_spkr_idx_map = lang_spkr_idx_map
 
     def forward(
         self,
-        code: List[int],
+        units: Tensor,
         lang: str,
         spkr: Optional[int] = -1,
         dur_prediction: bool = True,
     ) -> Tensor:
-        x = {
-            "code": torch.LongTensor(code).view(1, -1),
-        }
         lang_idx = self.lang_spkr_idx_map["multilingual"][lang]
         spkr_list = self.lang_spkr_idx_map["multispkr"][lang]
         if not spkr:
             spkr = -1
         spkr = spkr_list[0] if spkr == -1 else spkr
-        x["spkr"] = torch.tensor([[spkr]])
-        x["lang"] = torch.tensor([[lang_idx]])
-        return self.code_generator(x, dur_prediction)
+        x = {
+            "code": units.view(1, -1),
+            "spkr": torch.tensor([[spkr]], device=units.device),
+            "lang": torch.tensor([[lang_idx]], device=units.device),
+        }
+        return self.code_generator(x, dur_prediction)  # type: ignore[no-any-return]
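
A sketch of the new calling convention, assuming load_vocoder_model is importable as below: the vocoder now takes unit IDs as a tensor that is already resident on the target device and builds its `spkr`/`lang` inputs there, so no host round trip is needed before generation.

import torch
from fairseq2.typing import Device

from seamless_communication.models.vocoder.loader import load_vocoder_model

device = Device("cuda:0")
vocoder = load_vocoder_model("vocoder_v2", device=device, dtype=torch.float16)
vocoder.eval()

units = torch.tensor([8976, 8299, 9692], device=device)
with torch.inference_mode():
    wav = vocoder(units, "eng", spkr=-1, dur_prediction=True)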

+ 4 - 5
src/seamless_communication/streaming/agents/online_vocoder.py

@@ -37,15 +37,14 @@ class VocoderAgent(TextToSpeechAgent):  # type: ignore
                 return ReadAction()
 
         tgt_lang = states.tgt_lang if states.tgt_lang else self.tgt_lang
-        u = units[0][0].tolist()
-        wav_samples = self.vocoder(u, tgt_lang, self.speaker_id, dur_prediction=False)[
-            0
-        ][0].tolist()
+        u = units[0][0]
+
+        wav = self.vocoder(u, tgt_lang, self.speaker_id, dur_prediction=False)
         states.source = []
 
         return WriteAction(
             SpeechSegment(
-                content=wav_samples,
+                content=wav[0][0].tolist(),
                 finished=states.source_finished,
                 sample_rate=self.sample_rate,
                 tgt_lang=tgt_lang,

+ 0 - 3
tests/integration/inference/test_translator.py

@@ -6,9 +6,6 @@
 
 from typing import Final
 
-import torch
-from fairseq2.typing import Device
-
 from seamless_communication.inference import Translator
 from tests.common import device, get_default_dtype
 

+ 1 - 1
tests/integration/models/test_pretssel_vocoder.py

@@ -27,7 +27,7 @@ def test_pretssel_vocoder(example_rate16k_audio: AudioDecoderOutput) -> None:
 
     feat = convert_to_collated_fbank(audio_dict, dtype=dtype)["seqs"][0]
 
-    vocoder = load_mel_vocoder_model("vocoder_mel", device=device, dtype=torch.float32)
+    vocoder = load_mel_vocoder_model("vocoder_mel", device=device, dtype=dtype)
     vocoder.eval()
 
     with torch.inference_mode():

+ 5 - 3
tests/integration/models/test_unit_extractor.py

@@ -7,12 +7,12 @@
 from typing import Final
 
 import torch
-from fairseq2.typing import Device
 from torch import tensor
 
+from fairseq2.typing import Device
 from seamless_communication.inference import Translator
 from seamless_communication.models.unit_extractor import UnitExtractor
-from tests.common import assert_equal, device, get_default_dtype
+from tests.common import assert_equal
 
 # fmt: off
 REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 8119, 8119, 4335, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 8063, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 2839, 6756, 5665, 4191, 5205, 5205, 9568, 9568, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
@@ -23,7 +23,9 @@ def test_unit_extractor() -> None:
     model_name = "seamlessM4T_v2_large"
     english_text = "Hello! I hope you're all doing well."
 
-    dtype = get_default_dtype()
+    # We can't test on the GPU since the output is non-deterministic.
+    device = Device("cpu")
+    dtype = torch.float32
 
     translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)