@@ -232,6 +232,7 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
+        src_text: Optional[str] = None,
     ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
@@ -254,17 +255,19 @@ class Translator(nn.Module):
         :param unit_generation_ngram_filtering:
             If True, removes consecutive repeated ngrams
             from the decoded unit output.
-
+        :param src_text:
+            Optional source transcript (obtained via ASR, for instance). It is used when
+            applying MinTox toxicity mitigation. If it is not specified and apply_mintox=True,
+            then src_lang must be specified and ASR will be run on the audio source.
+
         :returns:
             - Batched list of Translated text.
             - Translated BatchedSpeechOutput.
         """
         input_modality, output_modality = self.get_modalities_from_task_str(task_str)

-        if self.apply_mintox and src_lang is None:
-            raise ValueError(
-                "`src_lang` must be specified when `apply_mintox` is `True`."
-            )
+        if self.apply_mintox and src_lang is None and src_text is None:
+            raise ValueError("Either `src_lang` or `src_text` must be specified when `apply_mintox` is `True`.")

         if isinstance(input, dict):
             src = cast(SequenceData, input)
@@ -326,18 +329,21 @@ class Translator(nn.Module):

         if self.apply_mintox and task_str != Task.ASR.name:
             if input_modality == Modality.SPEECH:
-                asr_text, _, = self.predict(
-                    input=input,
-                    task_str=Task.ASR.name,
-                    tgt_lang=tgt_lang,
-                    src_lang=src_lang,
-                    text_generation_opts=text_generation_opts,
-                    unit_generation_opts=unit_generation_opts,
-                    spkr=spkr,
-                    sample_rate=sample_rate,
-                    unit_generation_ngram_filtering=unit_generation_ngram_filtering,
-                )
-                src_texts = [asr_text]
+                if src_text is not None:
+                    src_texts = [src_text]
+                else:
+                    asr_text, _, = self.predict(
+                        input=input,
+                        task_str=Task.ASR.name,
+                        tgt_lang=tgt_lang,
+                        src_lang=src_lang,
+                        text_generation_opts=text_generation_opts,
+                        unit_generation_opts=unit_generation_opts,
+                        spkr=spkr,
+                        sample_rate=sample_rate,
+                        unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+                    )
+                    src_texts = [asr_text]
             else:
                 src_texts = [input]
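Below is a minimal usage sketch of the new `src_text` parameter (not part of the diff itself). Only the `predict()` keyword arguments shown above come from this change; the import path, the model and vocoder card names, and the `apply_mintox` constructor flag are assumptions about the surrounding API and may differ by version. Supplying a known transcript lets `predict()` skip the internal ASR pass it would otherwise run for MinTox.

import torch

# Assumed import path; older releases expose Translator under a different module.
from seamless_communication.inference import Translator

# Assumed constructor arguments (model card, vocoder card, device, apply_mintox flag);
# adjust them to your installed version.
translator = Translator(
    "seamlessM4T_v2_large",
    "vocoder_v2",
    device=torch.device("cuda:0"),
    apply_mintox=True,
)

# S2TT with MinTox enabled: pass the transcript we already have so that predict()
# does not recursively call itself with task_str="ASR" on the audio input.
translated_texts, _ = translator.predict(
    input="input_audio.wav",
    task_str="S2TT",
    tgt_lang="fra",
    src_lang="eng",
    src_text="An externally produced transcript of the input audio.",
)
print(translated_texts[0])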