Ver código fonte

bug fix & formatter

Yilin Yang — 1 year ago
parent commit: 791004d8bb

+ 0 - 1
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -35,7 +35,6 @@ from seamless_communication.cli.m4t.predict import (
 )
 from seamless_communication.inference import BatchedSpeechOutput, ExpressiveTranslator
 from seamless_communication.models.unity import load_unity_unit_tokenizer
-
 from seamless_communication.store import add_gated_assets
 
 logging.basicConfig(

+ 15 - 9
src/seamless_communication/inference/expressive_translator.py

@@ -3,16 +3,16 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import cast
 from copy import deepcopy
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, cast
 
 import torch
 import torchaudio
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import SequenceData, StringLike
 from fairseq2.data.audio import WaveformToFbankConverter
+from fairseq2.nn.padding import apply_padding_mask, get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch.nn import Module
 
@@ -109,19 +109,25 @@ class ExpressiveTranslator(Module):
             src = cast(SequenceData, input)
             src_gcmvn = deepcopy(src)
 
-            fbank = src["seqs"]
+            fbank, padding_mask = get_seqs_and_padding_mask(src)
             gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
-            std, mean = torch.std_mean(fbank, dim=[0, 1]) # B x T x mel
-            ucmvn_fbank = fbank.subtract(mean).divide(std)
+            # due to padding, batched std_mean calculation is wrong
+            mean = torch.zeros_like(fbank[:, 0])
+            std = torch.zeros_like(fbank[:, 0])
+            for i, (i_fbank, i_seq_len) in enumerate(zip(fbank, src["seq_lens"])):
+                std[i], mean[i] = torch.std_mean(i_fbank[:i_seq_len], dim=0)
 
-            src["seqs"] = ucmvn_fbank
-            src_gcmvn["seqs"] = gcmvn_fbank
+            ucmvn_fbank = fbank.subtract(mean.unsqueeze(1)).divide(std.unsqueeze(1))
+            src["seqs"] = apply_padding_mask(ucmvn_fbank, padding_mask)
+            src_gcmvn["seqs"] = apply_padding_mask(gcmvn_fbank, padding_mask)
 
         elif isinstance(input, Path):
             # TODO: Replace with fairseq2.data once re-sampling is implemented.
-            wav, sample_rate = torchaudio.load(path)
+            wav, sample_rate = torchaudio.load(input)
             wav = torchaudio.functional.resample(
-                wav, orig_freq=sample_rate, new_freq=AUDIO_SAMPLE_RATE,
+                wav,
+                orig_freq=sample_rate,
+                new_freq=AUDIO_SAMPLE_RATE,
             )
             wav = wav.transpose(0, 1)
 

+ 3 - 10
src/seamless_communication/inference/pretssel_generator.py

@@ -6,17 +6,11 @@
 from typing import List
 
 import torch
-from torch.nn import Module
-
-from fairseq2.typing import DataType, Device
-
 from fairseq2.assets import asset_store
-from fairseq2.data import (
-    Collater,
-    SequenceData,
-    VocabularyInfo,
-)
+from fairseq2.data import Collater, SequenceData, VocabularyInfo
 from fairseq2.nn.padding import get_seqs_and_padding_mask
+from fairseq2.typing import DataType, Device
+from torch.nn import Module
 
 from seamless_communication.inference.translator import BatchedSpeechOutput
 from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
@@ -60,7 +54,6 @@ class PretsselGenerator(Module):
         tgt_lang: str,
         prosody_encoder_input: SequenceData,
     ) -> BatchedSpeechOutput:
-
         units_batch, durations = [], []
         for u in units:
             unit = torch.tensor(u).to(self.unit_eos_token)

+ 9 - 18
src/seamless_communication/toxicity/mintox.py

@@ -7,26 +7,19 @@
 import logging
 from typing import List, Optional, Tuple
 
-from torch import Tensor
 import torch
-from torch.nn import functional as F
-
-
-from seamless_communication.inference.generator import SequenceGeneratorOptions
-from seamless_communication.toxicity.etox_bad_word_checker import (
-    ETOXBadWordChecker,
-)
-from fairseq2.generation import BannedSequenceProcessor
+from fairseq2.data import SequenceData
 from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
-from fairseq2.typing import Device
-from fairseq2.data import SequenceData
+from fairseq2.generation import BannedSequenceProcessor
 from fairseq2.nn.padding import get_seqs_and_padding_mask
-from seamless_communication.models.unity import (
-    UnitTokenizer,
-    UnitYModel,
-)
+from fairseq2.typing import Device
+from torch import Tensor
+from torch.nn import functional as F
 
+from seamless_communication.inference.generator import SequenceGeneratorOptions
+from seamless_communication.models.unity import UnitTokenizer, UnitYModel
+from seamless_communication.toxicity.etox_bad_word_checker import ETOXBadWordChecker
 
 logger = logging.getLogger(__name__)
 
@@ -84,9 +77,7 @@ def _replace_with_new_unit_output_in_batch(
         )
     else:
         # pad on the new units
-        new_units = F.pad(
-            new_units, pad=nb_pads, mode="constant", value=pad_idx
-        )
+        new_units = F.pad(new_units, pad=nb_pads, mode="constant", value=pad_idx)
     original_units[indices_with_toxicity_tensor] = new_units