
Revise, clean up MinTox implementation. Part 1 (#96)

Can Balioglu 1 year ago
parent commit 0bdc7b60ac

+ 50 - 0
src/seamless_communication/cards/mintox.yaml

@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: mintox
+etox_dataset: https://dl.fbaipublicfiles.com/nllb/NLLB-200_TWL/nllb-200_twl.zip
+etox_lang_variants:
+  - kas_Arab
+  - kas_Deva
+  - knc_Arab
+  - knc_Latn
+  - min_Arab
+  - min_Latn
+  - zho_Hans
+  - zho_Hant
+
+sp_model: https://huggingface.co/facebook/seamless-m4t-medium/resolve/main/tokenizer.model
+
+# For the following languages, bad words are matched on SentencePiece tokens
+# rather than on whitespace-delimited words.
+sp_langs:
+  - asm
+  - ben
+  - cmn
+  - guj
+  - mya
+  - hin
+  - gom
+  - ibo
+  - jpn
+  - kan
+  - khm
+  - kor
+  - lao
+  - mai
+  - mal
+  - mar
+  - mni
+  - npi
+  - oan
+  - ory
+  - pan
+  - rwr
+  - sat
+  - tam
+  - tel
+  - tha
+  - wuu
+  - yue
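
The card is consumed through fairseq2's asset API. A minimal sketch of how these fields are read, assuming the card above is registered with the default `asset_store` under the name `mintox`, as the loader further below expects:

    from fairseq2.assets import asset_store

    # Look up the card by name and read its fields with typed accessors,
    # mirroring what BadWordCheckerLoader does in bad_word_checker.py below.
    card = asset_store.retrieve_card("mintox")

    etox_uri = card.field("etox_dataset").as_uri()
    variants = card.field("etox_lang_variants").as_set(str)
    sp_langs = card.field("sp_langs").as_set(str)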

+ 51 - 0
src/seamless_communication/inference/translator.py

@@ -38,6 +38,10 @@ from seamless_communication.models.unity import (
     unity_archs,
 )
 from seamless_communication.models.vocoder import load_vocoder_model
+from seamless_communication.toxicity import (
+    load_bad_word_checker,
+)
+from seamless_communication.toxicity.mintox import mintox_pipeline
 
 logging.basicConfig(
     level=logging.INFO,
@@ -79,6 +83,7 @@ class Translator(nn.Module):
         vocoder_name_or_card: Union[str, AssetCard, None],
         device: Device,
         text_tokenizer: Optional[TextTokenizer] = None,
+        apply_mintox: bool = False,
         dtype: DataType = torch.float16,
         input_modality: Optional[Modality] = None,
         output_modality: Optional[Modality] = None,
@@ -121,6 +126,13 @@ class Translator(nn.Module):
         if self.model.t2u_model is not None:
             self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
 
+        if apply_mintox:
+            self.bad_word_checker = load_bad_word_checker("mintox")
+        else:
+            self.bad_word_checker = None
+
+        self.apply_mintox = apply_mintox
+
         self.device = device
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
         self.convert_to_fbank = WaveformToFbankConverter(
@@ -259,6 +271,9 @@ class Translator(nn.Module):
         """
         input_modality, output_modality = self.get_modalities_from_task_str(task_str)
 
+        if self.apply_mintox and src_lang is None:
+            raise ValueError("`src_lang` must be specified when `apply_mintox` is `True`.")
+
         if isinstance(input, dict):
             src = cast(SequenceData, input)
         elif input_modality == Modality.SPEECH:
@@ -317,6 +332,42 @@ class Translator(nn.Module):
             prosody_encoder_input=prosody_encoder_input,
         )
 
+        if self.apply_mintox and task_str != Task.ASR.name:
+            if input_modality == Modality.SPEECH:
+                # Transcribe in the source language; the transcript is checked
+                # against the source-language word lists.
+                asr_text, _ = self.predict(
+                    input=input,
+                    task_str=Task.ASR.name,
+                    tgt_lang=src_lang,
+                    src_lang=src_lang,
+                    text_generation_opts=text_generation_opts,
+                    unit_generation_opts=unit_generation_opts,
+                    spkr=spkr,
+                    sample_rate=sample_rate,
+                    unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+                )
+                src_texts = [asr_text]
+            else:
+                src_texts = [input]
+
+            text_output, unit_output = mintox_pipeline(
+                model=self.model,
+                text_tokenizer=self.text_tokenizer,
+                unit_tokenizer=self.unit_tokenizer,
+                device=self.device,
+                src_lang=src_lang,
+                tgt_lang=tgt_lang,
+                model_input=src,
+                input_modality=input_modality,
+                output_modality=output_modality,
+                src_texts=src_texts,
+                original_text_out=text_output,
+                original_unit_out=unit_output,
+                unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+                text_generation_opts=text_generation_opts,
+                unit_generation_opts=unit_generation_opts,
+                bad_word_checker=self.bad_word_checker,
+            )
+
         if output_modality == Modality.TEXT:
             return text_output.sentences, None
         else:
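
A usage sketch of the new flag; the card names and the task string "T2TT" are illustrative assumptions, not taken from this commit:

    import torch

    from seamless_communication.inference.translator import Translator

    translator = Translator(
        "seamlessM4T_medium",   # assumed model card name
        "vocoder_36langs",      # assumed vocoder card name
        device=torch.device("cuda:0"),
        apply_mintox=True,
    )

    # `src_lang` is now mandatory whenever `apply_mintox=True`.
    sentences, _ = translator.predict(
        input="An input sentence.",
        task_str="T2TT",
        tgt_lang="fra",
        src_lang="eng",
    )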

+ 8 - 0
src/seamless_communication/toxicity/__init__.py

@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Explicit `X as X` re-exports mark these names as the package's public API
+# for type checkers.
+from seamless_communication.toxicity.bad_word_checker import BadWordChecker as BadWordChecker
+from seamless_communication.toxicity.bad_word_checker import load_bad_word_checker as load_bad_word_checker

+ 193 - 0
src/seamless_communication/toxicity/bad_word_checker.py

@@ -0,0 +1,193 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import codecs
+import re
+from pathlib import Path
+from typing import Dict, List, Set, Union
+
+from fairseq2.assets import (
+    AssetCard,
+    AssetDownloadManager,
+    AssetStore,
+    asset_store,
+    download_manager,
+)
+from fairseq2.data import StringLike
+from fairseq2.data.text import SentencePieceEncoder, SentencePieceModel
+
+
+class BadWordChecker:
+    bad_words: Dict[str, List[str]]
+    bad_word_variants: Dict[str, Dict[str, List[str]]]
+    sp_encoder: SentencePieceEncoder
+    sp_langs: Set[str]
+
+    def __init__(
+        self,
+        bad_words: Dict[str, List[str]],
+        bad_word_variants: Dict[str, Dict[str, List[str]]],
+        sp_encoder: SentencePieceEncoder,
+        sp_langs: Set[str],
+    ):
+        self.bad_words = bad_words
+        self.bad_word_variants = bad_word_variants
+        self.sp_encoder = sp_encoder
+        self.sp_langs = sp_langs
+
+    def extract_bad_words(
+        self,
+        source_text: str,
+        target_text: str,
+        source_lang: str,
+        target_lang: str,
+    ) -> List[str]:
+        bad_words_in_target_text = self._get_bad_words(target_text, target_lang)
+
+        # If there are no bad words in the target text, do nothing.
+        if len(bad_words_in_target_text) == 0:
+            return []
+
+        bad_words_in_source_text = self._get_bad_words(source_text, source_lang)
+
+        # If the bad words already appear in the source text, the toxicity is
+        # not added by the translation; do nothing.
+        if len(bad_words_in_source_text) > 0:
+            return []
+
+        bad_words: List[str] = []
+
+        for word in bad_words_in_target_text:
+            bad_words.extend(self.bad_word_variants[target_lang][word])
+
+        return bad_words
+
+    def _get_bad_words(self, text: str, lang: str) -> List[str]:
+        try:
+            bad_words = self.bad_words[lang]
+        except KeyError:
+            raise RuntimeError(f"MinTox model does not support {lang}.")
+
+        text = self._preprocess(text)
+
+        if lang in self.sp_langs:
+            return self._find_bad_words_in_sp(text, bad_words)
+
+        return self._find_bad_words(text, bad_words)
+
+    @staticmethod
+    def _preprocess(text: str) -> str:
+        return re.sub(r"\W+", " ", text.lower())
+
+    @staticmethod
+    def _find_bad_words(text: str, bad_words: List[str]) -> List[str]:
+        output: List[str] = []
+
+        text = " " + text.lower() + " "
+
+        bad_words = [" " + word.lower() + " " for word in bad_words]
+
+        for word in bad_words:
+            if word in text:
+                output.append(word)
+
+        return [word.strip(" ") for word in output]
+
+    def _find_bad_words_in_sp(self, text: str, bad_words: List[str]) -> List[str]:
+        text_tokens = self.sp_encoder.encode_as_tokens(text.lower())
+
+        output: List[str] = []
+
+        for word in bad_words:
+            word_tokens = self.sp_encoder.encode_as_tokens(word.lower())
+
+            if self._contains_tokens(text_tokens, word_tokens):
+                output.append(str(word))
+
+        return output
+
+    @staticmethod
+    def _contains_tokens(
+        text_tokens: List[StringLike], word_tokens: List[StringLike]
+    ) -> bool:
+        for i in range(len(text_tokens) - len(word_tokens) + 1):
+            for j in range(len(word_tokens)):
+                if text_tokens[i + j] != word_tokens[j]:
+                    break
+            else:
+                return True
+
+        return False
+
+
+class BadWordCheckerLoader:
+    asset_store: AssetStore
+    download_manager: AssetDownloadManager
+
+    def __init__(
+        self, asset_store: AssetStore, download_manager: AssetDownloadManager
+    ) -> None:
+        self.asset_store = asset_store
+        self.download_manager = download_manager
+
+    def __call__(self, model_name_or_card: Union[str, AssetCard]) -> BadWordChecker:
+        if isinstance(model_name_or_card, AssetCard):
+            card = model_name_or_card
+        else:
+            card = asset_store.retrieve_card(model_name_or_card)
+
+        bad_words: Dict[str, List[str]] = {}
+
+        bad_word_variants: Dict[str, Dict[str, List[str]]] = {}
+
+        etox_lang_variants = card.field("etox_lang_variants").as_set(str)
+
+        etox_ds_uri = card.field("etox_dataset").as_uri()
+
+        etox_ds_path = self.download_manager.download_dataset(etox_ds_uri, "etox")
+
+        for word_file in etox_ds_path.iterdir():
+            # File names begin with the language code; the first eight
+            # characters capture script-qualified variants such as "kas_Arab".
+            lang = word_file.name[:8]
+
+            if lang not in etox_lang_variants:
+                lang = lang[:3]
+
+            words = self._load_words(word_file)
+
+            bad_words[lang] = words
+
+            bad_word_variants[lang] = {}
+
+            for word in words:
+                bad_word_variants[lang][word] = [
+                    word.lower(),
+                    word.upper(),
+                    word.capitalize(),
+                ]
+
+        sp_uri = card.field("sp_model").as_uri()
+
+        sp_pathname = self.download_manager.download_tokenizer(sp_uri, card.name)
+
+        sp_model = SentencePieceModel(sp_pathname)
+
+        sp_encoder = SentencePieceEncoder(sp_model)
+
+        sp_langs = card.field("sp_langs").as_set(str)
+
+        return BadWordChecker(bad_words, bad_word_variants, sp_encoder, sp_langs)
+
+    @staticmethod
+    def _load_words(pathname: Path) -> List[str]:
+        words: List[str] = []
+
+        with open(pathname) as fp:
+            for line in fp.readlines():
+                # The lists are stored ROT13-encoded; ROT13 is its own
+                # inverse, so encoding once more decodes each entry.
+                words.append(codecs.encode(line, "rot_13").rstrip("\n"))
+
+        return list(set(words))  # Dedup.
+
+
+load_bad_word_checker = BadWordCheckerLoader(asset_store, download_manager)
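
A short sketch of the checker's semantics ("badword" is a placeholder, not a real list entry). Only added toxicity is reported: matches in the target text that have no counterpart in the source, expanded to their lower/upper/capitalized variants:

    from seamless_communication.toxicity import load_bad_word_checker

    checker = load_bad_word_checker("mintox")  # downloads the ETOX lists on first use

    added = checker.extract_bad_words(
        source_text="a harmless sentence",
        target_text="a sentence with badword",
        source_lang="eng",
        target_lang="fra",
    )
    # `added` is empty if the source already contains the matched words;
    # otherwise it holds the variants to ban during re-generation.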

+ 242 - 0
src/seamless_communication/toxicity/mintox.py

@@ -0,0 +1,242 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple, Union
+
+from torch import Tensor
+import torch
+from torch.nn import functional as F
+
+
+from seamless_communication.inference.generator import (
+    SequenceGeneratorOptions,
+    SequenceToUnitOutput,
+)
+from seamless_communication.toxicity.bad_word_checker import (
+    BadWordChecker,
+)
+from fairseq2.generation import SequenceToTextOutput, BannedSequenceProcessor
+from fairseq2.data.text.text_tokenizer import TextTokenizer, TextTokenEncoder
+from fairseq2.data.typing import StringLike
+from fairseq2.typing import Device
+from fairseq2.data import SequenceData
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+from seamless_communication.models.unity import (
+    UnitTokenizer,
+    UnitYModel,
+)
+
+
+def _extract_bad_words_with_batch_indices(
+    source_texts: List[StringLike],
+    target_texts: List[StringLike],
+    source_lang: str,
+    target_lang: str,
+    bad_word_checker: BadWordChecker,
+) -> Tuple[List[str], List[int]]:
+    all_bad_words, batch_indices = [], []
+
+    for idx, (source_text, target_text) in enumerate(zip(source_texts, target_texts)):
+        bad_words = bad_word_checker.extract_bad_words(
+            str(source_text), str(target_text), source_lang, target_lang
+        )
+
+        if bad_words:
+            batch_indices.append(idx)
+
+            all_bad_words.extend(bad_words)
+
+    return all_bad_words, batch_indices
+
+
+def _replace_with_new_text_output_in_batch(
+    original_text_out: SequenceToTextOutput,
+    indices_with_toxicity: List[int],
+    indices_with_toxicity_tensor: Tensor,
+    new_text_output: SequenceToTextOutput,
+    batch_size: int,
+) -> None:
+    original_text_out.encoder_output[
+        indices_with_toxicity_tensor
+    ] = new_text_output.encoder_output
+    if original_text_out.encoder_padding_mask is not None:
+        assert new_text_output.encoder_padding_mask is not None
+
+        original_text_out.encoder_padding_mask.seq_lens[
+            indices_with_toxicity_tensor
+        ] = new_text_output.encoder_padding_mask.seq_lens
+
+    new_i = 0
+    for original_i in range(batch_size):
+        if (
+            original_i in indices_with_toxicity
+        ):  # indices_with_toxicity is small; a linear membership check is fast enough.
+            original_text_out.sentences[original_i] = new_text_output.sentences[new_i]
+            original_text_out.generator_output.results[
+                original_i
+            ] = new_text_output.generator_output.results[new_i]
+            new_i += 1
+
+
+def _replace_with_new_unit_output_in_batch(
+    unit_tokenizer: UnitTokenizer,
+    original_unit_out: SequenceToUnitOutput,
+    indices_with_toxicity: List[int],
+    indices_with_toxicity_tensor: Tensor,
+    new_unit_output: SequenceToUnitOutput,
+    batch_size: int,
+) -> None:
+    original_units_length = original_unit_out.units.size(1)
+    new_units_length = new_unit_output.units.size(1)
+    length_diff = abs(new_units_length - original_units_length)
+    nb_pads = (0, length_diff)
+    pad_idx = unit_tokenizer.vocab_info.pad_idx or 1
+    if new_units_length > original_units_length:
+        # pad on the original units
+        original_unit_out.units = F.pad(
+            original_unit_out.units,
+            pad=nb_pads,
+            mode="constant",
+            value=pad_idx,
+        )
+    else:
+        # pad on the new units
+        new_unit_output.units = F.pad(
+            new_unit_output.units,
+            pad=nb_pads,
+            mode="constant",
+            value=pad_idx,
+        )
+    original_unit_out.units[indices_with_toxicity_tensor] = new_unit_output.units
+
+    new_i = 0
+    if original_unit_out.generator_output is not None:
+        for original_i in range(batch_size):
+            if (
+                original_i in indices_with_toxicity
+                and new_unit_output.generator_output is not None
+            ):
+                original_unit_out.generator_output.results[
+                    original_i
+                ] = new_unit_output.generator_output.results[new_i]
+                new_i += 1
+
+
+def mintox_pipeline(
+    model: UnitYModel,
+    text_tokenizer: TextTokenizer,
+    unit_tokenizer: UnitTokenizer,
+    device: Device,
+    src_lang: str,
+    tgt_lang: str,
+    model_input: SequenceData,
+    input_modality: "Modality",
+    output_modality: "Modality",
+    src_texts: List[StringLike],
+    original_text_out: SequenceToTextOutput,
+    original_unit_out: Optional[SequenceToUnitOutput] = None,
+    unit_generation_ngram_filtering: bool = False,
+    text_generation_opts: SequenceGeneratorOptions = SequenceGeneratorOptions(
+        beam_size=5, soft_max_seq_len=(1, 200)
+    ),
+    unit_generation_opts: Optional[
+        SequenceGeneratorOptions
+    ] = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(25, 50)),
+    bad_word_checker: Optional[BadWordChecker] = None,
+) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
+    """MinTox: Mitigation at INference time of added TOXicity."""
+    from seamless_communication.inference.translator import Modality, Translator
+
+    def _get_banned_sequence_processor(
+        banned_sequences: List[str],
+    ) -> BannedSequenceProcessor:
+        text_encoder = text_tokenizer.create_raw_encoder(device=device)
+
+        banned_seqs = [text_encoder(b) for b in banned_sequences]
+        # A banned string often appears after punctuation or symbols, so we
+        # include that sequence of token ids as well: this bans not only the
+        # string "shit" but also "*shit", ",shit", etc.
+        banned_seqs += [text_encoder(f"★{x}")[1:] for x in banned_sequences]
+        return BannedSequenceProcessor(banned_seqs)
+
+    bad_words, indices_with_toxicity = _extract_bad_words_with_batch_indices(
+        src_texts,
+        original_text_out.sentences,
+        src_lang,
+        tgt_lang,
+        bad_word_checker,
+    )
+
+    if len(indices_with_toxicity) == 0:
+        # If no added toxicity is found, return the original output.
+        if output_modality == Modality.TEXT:
+            return original_text_out, None
+        else:
+            return original_text_out, original_unit_out
+    else:
+        # otherwise, redo the prediction with a list of bad words to ban
+        banned_sequence_processor = _get_banned_sequence_processor(
+            banned_sequences=list(set(bad_words)),
+        )
+        text_generation_opts.step_processor = banned_sequence_processor
+        # select only the sources with toxicity
+        indices_with_toxicity_tensor = torch.tensor(
+            indices_with_toxicity, device=device
+        )
+        if model_input["is_ragged"]:
+            model_input["seqs"] = torch.index_select(
+                input=model_input["seqs"],
+                dim=0,
+                index=indices_with_toxicity_tensor,
+            )
+            model_input["seq_lens"] = torch.index_select(
+                input=model_input["seq_lens"],
+                dim=0,
+                index=indices_with_toxicity_tensor,
+            )
+        seqs, padding_mask = get_seqs_and_padding_mask(model_input)
+        # redo the prediction
+        new_text_output, new_unit_output = Translator.get_prediction(
+            model=model,
+            text_tokenizer=text_tokenizer,
+            unit_tokenizer=unit_tokenizer,
+            seqs=seqs,
+            padding_mask=padding_mask,
+            input_modality=input_modality,
+            output_modality=output_modality,
+            tgt_lang=tgt_lang,
+            unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+            text_generation_opts=text_generation_opts,
+            unit_generation_opts=unit_generation_opts,
+        )
+        batch_size = len(original_text_out.sentences)
+        if batch_size > 1:
+            # reconstruct the text output by updating the original one in place
+            _replace_with_new_text_output_in_batch(
+                original_text_out,
+                indices_with_toxicity,
+                indices_with_toxicity_tensor,
+                new_text_output,
+                batch_size,
+            )
+            final_text_output = original_text_out
+        else:
+            final_text_output = new_text_output
+
+        if output_modality == Modality.TEXT:
+            return final_text_output, None
+        else:
+            if batch_size > 1:
+                # reconstruct the unit output by updating the original one in place
+                _replace_with_new_unit_output_in_batch(
+                    unit_tokenizer,
+                    original_unit_out,
+                    indices_with_toxicity,
+                    indices_with_toxicity_tensor,
+                    new_unit_output,
+                    batch_size,
+                )
+                final_unit_out = original_unit_out
+            else:
+                final_unit_out = new_unit_output
+            return final_text_output, final_unit_out
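
A toy illustration of the in-place batch repair performed by `_replace_with_new_unit_output_in_batch`: the shorter unit tensor is right-padded with the pad index so the re-generated rows can be written back by batch index:

    import torch
    import torch.nn.functional as F

    pad_idx = 1

    # Two original rows of unit ids; row 1 was flagged as toxic and
    # re-generated as a shorter sequence.
    original = torch.tensor([[5, 6, 7, 8], [5, 6, 7, 8]])
    new = torch.tensor([[9, 9]])

    # Right-pad the shorter side so widths match, then scatter by index.
    new = F.pad(new, pad=(0, original.size(1) - new.size(1)), value=pad_idx)
    original[torch.tensor([1])] = new

    # original is now [[5, 6, 7, 8], [9, 9, 1, 1]]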