|
@@ -221,12 +221,8 @@ class Translator(nn.Module):
|
|
|
task_str: str,
|
|
|
tgt_lang: str,
|
|
|
src_lang: Optional[str] = None,
|
|
|
- text_generation_opts: SequenceGeneratorOptions = SequenceGeneratorOptions(
|
|
|
- beam_size=5, soft_max_seq_len=(1, 200)
|
|
|
- ),
|
|
|
- unit_generation_opts: Optional[
|
|
|
- SequenceGeneratorOptions
|
|
|
- ] = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(25, 50)),
|
|
|
+ text_generation_opts: Optional[SequenceGeneratorOptions] = None,
|
|
|
+ unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
|
|
|
spkr: Optional[int] = -1,
|
|
|
sample_rate: int = 16000,
|
|
|
unit_generation_ngram_filtering: bool = False,
|
|
@@ -313,6 +309,15 @@ class Translator(nn.Module):
|
|
|
|
|
|
seqs, padding_mask = get_seqs_and_padding_mask(src)
|
|
|
|
|
|
+ if text_generation_opts is None:
|
|
|
+ text_generation_opts = SequenceGeneratorOptions(
|
|
|
+ beam_size=5, soft_max_seq_len=(1, 200)
|
|
|
+ )
|
|
|
+ if unit_generation_opts is None:
|
|
|
+ unit_generation_opts = SequenceGeneratorOptions(
|
|
|
+ beam_size=5, soft_max_seq_len=(25, 50)
|
|
|
+ )
|
|
|
+
|
|
|
text_output, unit_output = self.get_prediction(
|
|
|
self.model,
|
|
|
self.text_tokenizer,
|
|
@@ -345,9 +350,9 @@ class Translator(nn.Module):
|
|
|
sample_rate=sample_rate,
|
|
|
unit_generation_ngram_filtering=unit_generation_ngram_filtering,
|
|
|
)
|
|
|
- src_texts = [asr_text]
|
|
|
+ src_texts = [str(asr_text)]
|
|
|
else:
|
|
|
- src_texts = [input]
|
|
|
+ src_texts = [str(input)]
|
|
|
|
|
|
text_output, unit_output = mintox_pipeline(
|
|
|
model=self.model,
|
|
@@ -389,18 +394,27 @@ class Translator(nn.Module):
|
|
|
audio_wavs = []
|
|
|
speech_units = []
|
|
|
for i in range(len(units)):
|
|
|
- padding_mask = (
|
|
|
+ assert self.model.t2u_model is not None
|
|
|
+ unit_padding_mask = (
|
|
|
units[i] != self.model.t2u_model.target_vocab_info.pad_idx
|
|
|
)
|
|
|
- u = units[i][padding_mask]
|
|
|
+ u = units[i][unit_padding_mask]
|
|
|
speech_units.append(u.tolist())
|
|
|
-
|
|
|
+
|
|
|
if self.vocoder is not None:
|
|
|
translated_audio_wav = self.vocoder(
|
|
|
units, tgt_lang, spkr, dur_prediction=duration_prediction
|
|
|
)
|
|
|
for i in range(len(units)):
|
|
|
- padding_removed_audio_wav = translated_audio_wav[i, :, :int(translated_audio_wav.size(-1)*len(speech_units[i])/len(units[i]))].unsqueeze(0)
|
|
|
+ padding_removed_audio_wav = translated_audio_wav[
|
|
|
+ i,
|
|
|
+ :,
|
|
|
+ : int(
|
|
|
+ translated_audio_wav.size(-1)
|
|
|
+ * len(speech_units[i])
|
|
|
+ / len(units[i])
|
|
|
+ ),
|
|
|
+ ].unsqueeze(0)
|
|
|
audio_wavs.append(padding_removed_audio_wav)
|
|
|
return (
|
|
|
text_output.sentences,
|