Bladeren bron

mintox - Add option to consume pretranscribed text + log mintox for cloudwatch (#131)

Pierre Andrews 1 jaar geleden
bovenliggende
commit
87e10d101b
2 gewijzigde bestanden met toevoegingen van 33 en 17 verwijderingen
  1. 23 17
      src/seamless_communication/inference/translator.py
  2. 10 0
      src/seamless_communication/toxicity/mintox.py

+ 23 - 17
src/seamless_communication/inference/translator.py

@@ -232,6 +232,7 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
+        src_text: Optional[str] = None 
     ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
@@ -254,17 +255,19 @@ class Translator(nn.Module):
         :param unit_generation_ngram_filtering:
             If True, removes consecutive repeated ngrams
             from the decoded unit output.
-
+        :param src_text:
+            Optional source transcript (obtained by ASR for instance). This is used for
+            applying mintox toxicity mitigation. If this is not specify and apply_mintox=True
+            then src_lang must be specified and ASR will be run on the audio source.
+            
         :returns:
             - Batched list of Translated text.
             - Translated BatchedSpeechOutput.
         """
         input_modality, output_modality = self.get_modalities_from_task_str(task_str)
 
-        if self.apply_mintox and src_lang is None:
-            raise ValueError(
-                "`src_lang` must be specified when `apply_mintox` is `True`."
-            )
+        if self.apply_mintox and not (src_lang is not None or src_text is not None) :
+            raise ValueError("`src_lang` must be specified when `apply_mintox` is `True` or you need to specify src_text.")
 
         if isinstance(input, dict):
             src = cast(SequenceData, input)
@@ -326,18 +329,21 @@ class Translator(nn.Module):
 
         if self.apply_mintox and task_str != Task.ASR.name:
             if input_modality == Modality.SPEECH:
-                asr_text, _, = self.predict(
-                    input=input,
-                    task_str=Task.ASR.name,
-                    tgt_lang=tgt_lang,
-                    src_lang=src_lang,
-                    text_generation_opts=text_generation_opts,
-                    unit_generation_opts=unit_generation_opts,
-                    spkr=spkr,
-                    sample_rate=sample_rate,
-                    unit_generation_ngram_filtering=unit_generation_ngram_filtering,
-                )
-                src_texts = [asr_text]
+                if src_text is not None:
+                    src_texts = [src_text]
+                else:
+                    asr_text, _, = self.predict(
+                        input=input,
+                        task_str=Task.ASR.name,
+                        tgt_lang=tgt_lang,
+                        src_lang=src_lang,
+                        text_generation_opts=text_generation_opts,
+                        unit_generation_opts=unit_generation_opts,
+                        spkr=spkr,
+                        sample_rate=sample_rate,
+                        unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+                    )
+                    src_texts = [asr_text]
             else:
                 src_texts = [input]
 

+ 10 - 0
src/seamless_communication/toxicity/mintox.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import List, Optional, Tuple
 
 from torch import Tensor
@@ -30,6 +31,9 @@ from seamless_communication.models.unity import (
 )
 
 
+logger = logging.getLogger(__name__)
+
+
 def _extract_bad_words_with_batch_indices(
     source_texts: List[StringLike],
     target_texts: List[StringLike],
@@ -177,6 +181,12 @@ def mintox_pipeline(
         else:
             return original_text_out, original_unit_out
     else:
+        logger.info(
+            "TOX src_lang=%s tgt_lang=%s added_tox=%d",
+            src_lang,
+            tgt_lang,
+            len(indices_with_toxicity),
+        )
         # otherwise, redo the prediction with a list of bad words to ban
         banned_sequence_processor = _get_banned_sequence_processor(
             banned_sequences=list(set(bad_words)),