Vocoder batch inference (#154)

* Enable M4T vocoder inference on the GPU; remove unnecessary device-to-host syncs in the Translator.

* Batched inference for the vocoder (see the usage sketch below).

* Revert unnecessary local changes.

* Revert unnecessary changes.

---------

Co-authored-by: Kaushik Ram Sadagopan <kaushikram2811@gmail.com>
Ning 1 year ago
parent
commit
64a98f0bc5
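For reference, a minimal sketch of the batched vocoder call this commit enables. It assumes the repo's load_vocoder_model loader and the "vocoder_36langs" card (adjust names to your checkout); the units are random placeholders:

    import torch
    from seamless_communication.models.vocoder import load_vocoder_model

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vocoder = load_vocoder_model("vocoder_36langs", device=device)
    vocoder.eval()

    # Two fake unit sequences, padded to a common length (B x T).
    units = torch.randint(0, 10000, (2, 50), device=device)
    with torch.inference_mode():
        # One language and one speaker id per row; -1 selects the
        # default (first) speaker registered for that language.
        wavs = vocoder(units, ["eng", "spa"], [-1, -1])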

+ 12 - 6
src/seamless_communication/inference/translator.py

@@ -393,13 +393,19 @@ class Translator(nn.Module):
                     units[i] != self.model.t2u_model.target_vocab_info.pad_idx
                 )
                 u = units[i][padding_mask]
-                if self.vocoder is not None:
-                    # TODO: Implement batched inference for vocoder.
-                    translated_audio_wav = self.vocoder(
-                        u, tgt_lang, spkr, dur_prediction=duration_prediction
-                    )
-                    audio_wavs.append(translated_audio_wav)
                 speech_units.append(u.tolist())
+
+            if self.vocoder is not None:
+                translated_audio_wav = self.vocoder(
+                    units, tgt_lang, spkr, dur_prediction=duration_prediction
+                )
+                for i in range(len(units)):
+                    unpadded_len = int(
+                        translated_audio_wav.size(-1) * len(speech_units[i]) / len(units[i])
+                    )
+                    audio_wavs.append(
+                        translated_audio_wav[i, :, :unpadded_len].unsqueeze(0)
+                    )
             return (
                 text_output.sentences,
                 BatchedSpeechOutput(
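The slicing above assumes the waveform length scales linearly with unit count, so the valid prefix of sample i is proportional to its unpadded length (with duration prediction the true boundary can deviate slightly, since per-unit durations vary). A toy check of the arithmetic:

    # Sample i has 37 real units out of a padded length of 50; if each
    # row of the batched waveform is 16000 samples long, keep the first
    # int(16000 * 37 / 50) = 11840 samples.
    wav_len, n_real, n_padded = 16000, 37, 50
    keep = int(wav_len * n_real / n_padded)
    assert keep == 11840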

+ 14 - 7
src/seamless_communication/models/vocoder/codehifigan.py

@@ -77,20 +77,27 @@ class CodeGenerator(Generator):
         x = self.dict(x).transpose(1, 2)
 
         if self.dur_predictor and dur_prediction:
-            assert x.size(0) == 1, "only support single sample"
             log_dur_pred = self.dur_predictor(x.transpose(1, 2), None)
             dur_out = torch.clamp(
                 torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1
             )
             # B x C x T
-            x = torch.repeat_interleave(x, dur_out.view(-1), dim=2)
-
+            repeat_interleaved_x = []
+            for i in range(x.size(0)):
+                repeat_interleaved_x.append(
+                    torch.repeat_interleave(x[i].unsqueeze(0), dur_out[i].view(-1), dim=2)
+                )
+            x = torch.cat(repeat_interleaved_x)
+        upsampled_spkr = []
+        upsampled_lang = []
         spkr = self.spkr(sample["spkr"]).transpose(1, 2)
-        spkr = self._upsample(spkr, x.shape[-1])
-        x = torch.cat([x, spkr], dim=1)
-
         lang = self.lang(sample["lang"]).transpose(1, 2)
-        lang = self._upsample(lang, x.shape[-1])
+        for i in range(x.size(0)):
+            upsampled_spkr.append(self._upsample(spkr[i], x.shape[-1]))
+            upsampled_lang.append(self._upsample(lang[i], x.shape[-1]))
+        spkr = torch.cat(upsampled_spkr, dim=1).transpose(0, 1)
+        lang = torch.cat(upsampled_lang, dim=1).transpose(0, 1)
+        x = torch.cat([x, spkr], dim=1)
         x = torch.cat([lang, x], dim=1)
 
         return super().forward(x)
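torch.repeat_interleave only accepts a single 1-D repeats vector shared across the batch, so the hunk above expands each row separately and concatenates the results; the torch.cat then requires every row to expand to the same total length. A self-contained toy run of the same pattern:

    import torch

    x = torch.arange(12.0).view(2, 2, 3)        # B=2, C=2, T=3
    dur = torch.tensor([[1, 2, 3], [2, 2, 2]])  # both rows sum to 6
    rows = [
        torch.repeat_interleave(x[i].unsqueeze(0), dur[i].view(-1), dim=2)
        for i in range(x.size(0))
    ]
    out = torch.cat(rows)                       # 2 x 2 x 6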

+ 22 - 12
src/seamless_communication/models/vocoder/vocoder.py

@@ -4,8 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, Optional
-
+from typing import Any, Dict, List, Optional, Union
 import torch
 from torch import Tensor
 from torch.nn import Module
@@ -26,18 +25,29 @@ class Vocoder(Module):
     def forward(
         self,
         units: Tensor,
-        lang: str,
-        spkr: Optional[int] = -1,
+        lang_list: Union[List[str], str],
+        spkr_list: Union[Optional[List[int]], int] = None,
         dur_prediction: bool = True,
     ) -> Tensor:
-        lang_idx = self.lang_spkr_idx_map["multilingual"][lang]
-        spkr_list = self.lang_spkr_idx_map["multispkr"][lang]
-        if not spkr:
-            spkr = -1
-        spkr = spkr_list[0] if spkr == -1 else spkr
+        # TODO: Do we need this backward compatibility, or just update all calling sites?
+        if len(units.shape) == 1:
+            units = units.unsqueeze(0)  # add batch dim
+        if isinstance(lang_list, str):
+            lang_list = [lang_list] * units.size(0)
+        if isinstance(spkr_list, int):
+            spkr_list = [spkr_list] * units.size(0)
+        lang_idx_list = [self.lang_spkr_idx_map["multilingual"][l] for l in lang_list]
+        if not spkr_list:
+            spkr_list = [-1 for _ in range(len(lang_list))]
+        spkr_list = [
+            self.lang_spkr_idx_map["multispkr"][lang_list[i]][0]
+            if spkr_list[i] == -1
+            else spkr_list[i]
+            for i in range(len(spkr_list))
+        ]
         x = {
-            "code": units.view(1, -1),
-            "spkr": torch.tensor([[spkr]], device=units.device),
-            "lang": torch.tensor([[lang_idx]], device=units.device),
+            "code": units.view(units.size(0), -1),
+            "spkr": torch.tensor([spkr_list], device=units.device).t(),
+            "lang": torch.tensor([lang_idx_list], device=units.device).t(),
+
         }
         return self.code_generator(x, dur_prediction)  # type: ignore[no-any-return]
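Both calling conventions the updated forward() accepts, sketched with the same hypothetical setup as the first example (random placeholder units):

    import torch
    from seamless_communication.models.vocoder import load_vocoder_model

    vocoder = load_vocoder_model("vocoder_36langs", device=torch.device("cpu"))
    vocoder.eval()

    # Old style still works: the 1-D unit tensor gains a batch dim and
    # the scalar lang/spkr arguments are broadcast over a batch of one.
    wav = vocoder(torch.randint(0, 10000, (42,)), "eng", -1)

    # New style: a padded B x T unit batch, one lang and spkr per row.
    wavs = vocoder(
        torch.randint(0, 10000, (3, 42)),
        ["eng", "fra", "spa"],
        [-1, -1, -1],
    )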