Browse Source

bug fix & formatter

Yilin Yang 1 năm trước
mục cha
commit
791004d8bb

+ 0 - 1
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -35,7 +35,6 @@ from seamless_communication.cli.m4t.predict import (
 )
 from seamless_communication.inference import BatchedSpeechOutput, ExpressiveTranslator
 from seamless_communication.models.unity import load_unity_unit_tokenizer
-
 from seamless_communication.store import add_gated_assets
 
 
 logging.basicConfig(

+ 15 - 9
src/seamless_communication/inference/expressive_translator.py

@@ -3,16 +3,16 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
-from typing import cast
 from copy import deepcopy
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, cast
 
 import torch
 import torchaudio
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import SequenceData, StringLike
 from fairseq2.data.audio import WaveformToFbankConverter
+from fairseq2.nn.padding import apply_padding_mask, get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch.nn import Module
 
@@ -109,19 +109,25 @@ class ExpressiveTranslator(Module):
             src = cast(SequenceData, input)
             src_gcmvn = deepcopy(src)
 
-            fbank = src["seqs"]
+            fbank, padding_mask = get_seqs_and_padding_mask(src)
             gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
-            std, mean = torch.std_mean(fbank, dim=[0, 1]) # B x T x mel
-            ucmvn_fbank = fbank.subtract(mean).divide(std)
+            # due to padding, batched std_mean calculation is wrong
+            mean = torch.zeros_like(fbank[:, 0])
+            std = torch.zeros_like(fbank[:, 0])
+            for i, (i_fbank, i_seq_len) in enumerate(zip(fbank, src["seq_lens"])):
+                std[i], mean[i] = torch.std_mean(i_fbank[:i_seq_len], dim=0)
 
-            src["seqs"] = ucmvn_fbank
-            src_gcmvn["seqs"] = gcmvn_fbank
+            ucmvn_fbank = fbank.subtract(mean.unsqueeze(1)).divide(std.unsqueeze(1))
+            src["seqs"] = apply_padding_mask(ucmvn_fbank, padding_mask)
+            src_gcmvn["seqs"] = apply_padding_mask(gcmvn_fbank, padding_mask)
 
         elif isinstance(input, Path):
             # TODO: Replace with fairseq2.data once re-sampling is implemented.
-            wav, sample_rate = torchaudio.load(path)
+            wav, sample_rate = torchaudio.load(input)
             wav = torchaudio.functional.resample(
-                wav, orig_freq=sample_rate, new_freq=AUDIO_SAMPLE_RATE,
+                wav,
+                orig_freq=sample_rate,
+                new_freq=AUDIO_SAMPLE_RATE,
             )
             wav = wav.transpose(0, 1)
 

+ 3 - 10
src/seamless_communication/inference/pretssel_generator.py

@@ -6,17 +6,11 @@
 from typing import List
 
 import torch
-from torch.nn import Module
-
-from fairseq2.typing import DataType, Device
-
 from fairseq2.assets import asset_store
-from fairseq2.data import (
-    Collater,
-    SequenceData,
-    VocabularyInfo,
-)
+from fairseq2.data import Collater, SequenceData, VocabularyInfo
 from fairseq2.nn.padding import get_seqs_and_padding_mask
+from fairseq2.typing import DataType, Device
+from torch.nn import Module
 
 from seamless_communication.inference.translator import BatchedSpeechOutput
 from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
@@ -60,7 +54,6 @@ class PretsselGenerator(Module):
         tgt_lang: str,
         prosody_encoder_input: SequenceData,
     ) -> BatchedSpeechOutput:
-
         units_batch, durations = [], []
         for u in units:
             unit = torch.tensor(u).to(self.unit_eos_token)

+ 9 - 18
src/seamless_communication/toxicity/mintox.py

@@ -7,26 +7,19 @@
 import logging
 from typing import List, Optional, Tuple
 
-from torch import Tensor
 import torch
-from torch.nn import functional as F
-
-
-from seamless_communication.inference.generator import SequenceGeneratorOptions
-from seamless_communication.toxicity.etox_bad_word_checker import (
-    ETOXBadWordChecker,
-)
-from fairseq2.generation import BannedSequenceProcessor
+from fairseq2.data import SequenceData
 from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
-from fairseq2.typing import Device
-from fairseq2.data import SequenceData
+from fairseq2.generation import BannedSequenceProcessor
 from fairseq2.nn.padding import get_seqs_and_padding_mask
-from seamless_communication.models.unity import (
-    UnitTokenizer,
-    UnitYModel,
-)
+from fairseq2.typing import Device
+from torch import Tensor
+from torch.nn import functional as F
 
+from seamless_communication.inference.generator import SequenceGeneratorOptions
+from seamless_communication.models.unity import UnitTokenizer, UnitYModel
+from seamless_communication.toxicity.etox_bad_word_checker import ETOXBadWordChecker
 
 logger = logging.getLogger(__name__)
 
@@ -84,9 +77,7 @@ def _replace_with_new_unit_output_in_batch(
         )
     else:
         # pad on the new units
-        new_units = F.pad(
-            new_units, pad=nb_pads, mode="constant", value=pad_idx
-        )
+        new_units = F.pad(new_units, pad=nb_pads, mode="constant", value=pad_idx)
     original_units[indices_with_toxicity_tensor] = new_units
     original_units[indices_with_toxicity_tensor] = new_units