|
@@ -3,16 +3,16 @@
|
|
|
# This source code is licensed under the license found in the
|
|
|
# MIT_LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
-from typing import cast
|
|
|
from copy import deepcopy
|
|
|
from pathlib import Path
|
|
|
-from typing import List, Optional, Tuple, Union
|
|
|
+from typing import List, Optional, Tuple, Union, cast
|
|
|
|
|
|
import torch
|
|
|
import torchaudio
|
|
|
from fairseq2.assets.card import AssetCard
|
|
|
from fairseq2.data import SequenceData, StringLike
|
|
|
from fairseq2.data.audio import WaveformToFbankConverter
|
|
|
+from fairseq2.nn.padding import apply_padding_mask, get_seqs_and_padding_mask
|
|
|
from fairseq2.typing import DataType, Device
|
|
|
from torch.nn import Module
|
|
|
|
|
@@ -109,19 +109,25 @@ class ExpressiveTranslator(Module):
|
|
|
src = cast(SequenceData, input)
|
|
|
src_gcmvn = deepcopy(src)
|
|
|
|
|
|
- fbank = src["seqs"]
|
|
|
+ fbank, padding_mask = get_seqs_and_padding_mask(src)
|
|
|
gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
|
|
|
- std, mean = torch.std_mean(fbank, dim=[0, 1]) # B x T x mel
|
|
|
- ucmvn_fbank = fbank.subtract(mean).divide(std)
|
|
|
+ # due to padding, batched std_mean calculation is wrong
|
|
|
+ mean = torch.zeros_like(fbank[:, 0])
|
|
|
+ std = torch.zeros_like(fbank[:, 0])
|
|
|
+ for i, (i_fbank, i_seq_len) in enumerate(zip(fbank, src["seq_lens"])):
|
|
|
+ std[i], mean[i] = torch.std_mean(i_fbank[:i_seq_len], dim=0)
|
|
|
|
|
|
- src["seqs"] = ucmvn_fbank
|
|
|
- src_gcmvn["seqs"] = gcmvn_fbank
|
|
|
+ ucmvn_fbank = fbank.subtract(mean.unsqueeze(1)).divide(std.unsqueeze(1))
|
|
|
+ src["seqs"] = apply_padding_mask(ucmvn_fbank, padding_mask)
|
|
|
+ src_gcmvn["seqs"] = apply_padding_mask(gcmvn_fbank, padding_mask)
|
|
|
|
|
|
elif isinstance(input, Path):
|
|
|
# TODO: Replace with fairseq2.data once re-sampling is implemented.
|
|
|
- wav, sample_rate = torchaudio.load(path)
|
|
|
+ wav, sample_rate = torchaudio.load(input)
|
|
|
wav = torchaudio.functional.resample(
|
|
|
- wav, orig_freq=sample_rate, new_freq=AUDIO_SAMPLE_RATE,
|
|
|
+ wav,
|
|
|
+ orig_freq=sample_rate,
|
|
|
+ new_freq=AUDIO_SAMPLE_RATE,
|
|
|
)
|
|
|
wav = wav.transpose(0, 1)
|
|
|
|