ASR-ETOX script (#157)

* rebranding

* etox pipeline

* make it work

* README+whisper working

* Add the integration tests for the mintox (#155)

* fix tests

* unused import

* add ETOX script

---------

Co-authored-by: hitchhicker <yubokai8@gmail.com>
Pierre Andrews 1 year ago
parent
commit
d6fcd32b6b

+ 78 - 0
src/seamless_communication/cli/toxicity/README.md

@@ -0,0 +1,78 @@
+# Tool to compute toxicity in speech (ASR-ETOX) and text (ETOX)
+
+In this tool, we combine an ASR model (M4T or whisper) with the ETOX toxicity detection tool
+to compute a toxicity score for speech segments.
+
+ETOX was developed as part of the NLLB project and provides a wordlist detection mechanism for 200 languages. By applying ETOX detection to ASR transcripts, we can detect toxicity in speech. You can find a description of the toxicity detection wordlists in the paper cited below.
+
+## ASR-ETOX Usage
+
+The script takes a TSV file as input. The TSV needs a header row with column names and can have multiple columns. By default, the script reads the audio file name from the `audio` column; this can be overridden with `--audio_column`.
+The file paths in the TSV can be absolute or relative to a root directory specified by `--audio_root_dir`. They can also use the audiozip format with the appropriate byte offset and length, e.g.: `fleurs_en_us_ogg_16khz.zip:89474600:49079`.
+
+You can choose which ASR model to use; the default is `seamlessM4T_v2_large`. If you prefer to use [whisper](https://github.com/openai/whisper), specify a `--model_name` that starts with `whisper_` and ends with the whisper model name (e.g. `whisper_large`).
+
+## Outputs
+
+The output of the script is a new TSV file with three columns:
+- `text` the transcription
+- `toxicity` the number of toxic words detected
+- `bad_words` a list of toxic words, separated by `,`
+
+## Sample Command
+
+**ASR-ETOX**
+
+- using M4T:
+```bash
+python -m seamless_communication.cli.toxicity.asr_etox --lang deu --audio_column ref_tgt_audio s2t/en-xx/deu.tsv ~/etox.tsv
+```
+
+- using Whisper:
+```bash
+python -m seamless_communication.cli.toxicity.asr_etox --model_name whisper_large --lang fra --audio_column ref_tgt_audio s2t/en-xx/fra.tsv ~/etox.test.tsv
+```
+
+**ETOX**
+
+If you only care about getting the toxicity of text, you can use the etox.py script, with one text per line, specifying the language as the first argument.
+
+```bash
+cut -f 4 fleurs/s2t/en-xx/deu.tsv | python -m seamless_communication.cli.toxicity.etox deu > deu.toxicity.txt
+```
+
+You can also specify an input and output file:
+```bash
+python -m seamless_communication.cli.toxicity.etox deu deu.txt deu.toxicity.txt
+```
+
+
+# Citation
+If you use ETOX, ASR-ETOX and SeamlessM4T in your work, please cite:
+
+
+```bibtex
+@misc{costajussà2023toxicity,
+      title={Toxicity in Multilingual Machine Translation at Scale},
+      author={Marta R. Costa-jussà and Eric Smith and Christophe Ropers and Daniel Licht and Jean Maillard and Javier Ferrando and Carlos Escolano},
+      year={2023},
+      eprint={2210.03070},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL}
+}
+```
+
+and
+
+```bibtex
+@article{seamlessm4t2023,
+  title={SeamlessM4T—Massively Multilingual \& Multimodal Machine Translation},
+  author={{Seamless Communication}, Lo\"{i}c Barrault, Yu-An Chung, Mariano Cora Meglioli, David Dale, Ning Dong, Paul-Ambroise Duquenne, Hady Elsahar, Hongyu Gong, Kevin Heffernan, John Hoffman, Christopher Klaiber, Pengwei Li, Daniel Licht, Jean Maillard, Alice Rakotoarison, Kaushik Ram Sadagopan, Guillaume Wenzek, Ethan Ye, Bapi Akula, Peng-Jen Chen, Naji El Hachem, Brian Ellis, Gabriel Mejia Gonzalez, Justin Haaheim, Prangthip Hansanti, Russ Howes, Bernie Huang, Min-Jae Hwang, Hirofumi Inaguma, Somya Jain, Elahe Kalbassi, Amanda Kallet, Ilia Kulikov, Janice Lam, Daniel Li, Xutai Ma, Ruslan Mavlyutov, Benjamin Peloquin, Mohamed Ramadan, Abinesh Ramakrishnan, Anna Sun, Kevin Tran, Tuan Tran, Igor Tufanov, Vish Vogeti, Carleigh Wood, Yilin Yang, Bokai Yu, Pierre Andrews, Can Balioglu, Marta R. Costa-juss\`{a} \footnotemark[3], Onur \c{C}elebi, Maha Elbayad, Cynthia Gao, Francisco Guzm\'an, Justine Kao, Ann Lee, Alexandre Mourachko, Juan Pino, Sravya Popuri, Christophe Ropers, Safiyyah Saleem, Holger Schwenk, Paden Tomasello, Changhan Wang, Jeff Wang, Skyler Wang},
+  journal={ArXiv},
+  year={2023}
+}
+```
+
+# License
+
+seamless_communication is CC-BY-NC 4.0 licensed, as found in the LICENSE file.
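
The audiozip path format mentioned in the README above (`fleurs_en_us_ogg_16khz.zip:89474600:49079`) can be illustrated with a small sketch. It assumes the layout is `<archive path>:<byte offset>:<length>`, as the README's wording suggests. The helpers `parse_audio_spec` and `read_audio_bytes` are hypothetical names used for illustration only; in the script itself, path resolution is delegated to fairseq2's `FileMapper`.

```python
from typing import Optional, Tuple


def parse_audio_spec(spec: str) -> Tuple[str, Optional[int], Optional[int]]:
    """Split `path[:offset:length]` into its parts (hypothetical helper)."""
    # Split from the right so that colons inside the path itself are preserved.
    parts = spec.rsplit(":", 2)
    if len(parts) == 3 and parts[1].isdigit() and parts[2].isdigit():
        return parts[0], int(parts[1]), int(parts[2])
    # Plain path: no byte range attached.
    return spec, None, None


def read_audio_bytes(spec: str) -> bytes:
    """Read the bytes a spec points at: a slice of the archive or the whole file."""
    path, offset, length = parse_audio_spec(spec)
    with open(path, "rb") as f:
        if offset is None:
            return f.read()
        f.seek(offset)
        return f.read(length)
```

Reading a byte range directly like this only makes sense when the archive stores its members uncompressed, which is an assumption of this sketch rather than something the README states.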

+ 255 - 0
src/seamless_communication/cli/toxicity/asr_etox.py

@@ -0,0 +1,255 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import tempfile
+import typing as tp
+import torchaudio
+from tqdm import tqdm
+from seamless_communication.cli.eval_utils.compute_metrics import init_whisper_model
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
+from seamless_communication.inference.translator import Modality
+import torch
+
+from pathlib import Path
+from seamless_communication.inference import Translator
+from fairseq2.data import Collater, DataPipeline, FileMapper
+from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+from fairseq2.data.text import StrSplitter, read_text
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.toxicity import load_etox_bad_word_checker
+
+from whisper.model import Whisper
+
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="ASR ETOX will compute the toxicity level of speech inputs."
+    )
+    parser.add_argument(
+        "data_file",
+        type=Path,
+        help="Path to the input TSV manifest that lists the audio files.",
+    )
+    parser.add_argument(
+        "output_file",
+        type=Path,
+        help="Path to the TSV file where the results are saved.",
+    )
+    parser.add_argument(
+        "--lang",
+        type=str,
+        help="Language of the speech to transcribe.",
+        required=True,
+    )
+    parser.add_argument(
+        "--audio_root_dir",
+        type=str,
+        help="Root directory for the audio filenames in the data file.",
+    )
+    parser.add_argument(
+        "--audio_column",
+        type=str,
+        help="Name of the column that lists the audio file in the input TSV.",
+        default="audio",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help=(
+            "Base model name (`seamlessM4T_medium`, "
+            "`seamlessM4T_large`, `seamlessM4T_v2_large`), "
+            "or a whisper model, e.g. `whisper_large`"
+        ),
+        default="seamlessM4T_v2_large",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--n_parallel",
+        type=int,
+        help="Number of parallel data-loading calls.",
+        default=4,
+    )
+    args, _unknown = parser.parse_known_args()
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float16
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    whisper_model = None
+    translator = None
+    is_whisper = False
+
+    if args.model_name.startswith("whisper_"):
+        logger.info("loading whisper model.")
+        _, model_name = args.model_name.split("_", maxsplit=1)
+        whisper_model = init_whisper_model(device, model_name)
+        is_whisper = True
+    else:
+        logger.info(f"loading {args.model_name} model.")
+        translator = Translator(
+            args.model_name,
+            None,
+            device,
+            text_tokenizer=None,
+            dtype=dtype,
+            input_modality=Modality.SPEECH,
+            output_modality=Modality.TEXT,
+            apply_mintox=False,
+        )
+
+    logger.info("loading etox.")
+    bad_word_checker = load_etox_bad_word_checker("mintox")
+
+    pipeline = build_data_pipeline(
+        data_file=args.data_file,
+        audio_root_dir=args.audio_root_dir,
+        batch_size=args.batch_size,
+        is_whisper=is_whisper,
+        device=device,
+        dtype=dtype,
+        n_parallel=args.n_parallel,
+        audio_column=args.audio_column,
+    )
+
+    logger.info("running ASR-ETOX.")
+    with open(args.output_file, "w", encoding="utf-8") as outf:
+        print("text", "toxicity", "bad_words", file=outf, sep="\t")
+        for example in tqdm(pipeline, unit="line"):
+            texts = get_text(
+                lang=args.lang,
+                example=example,
+                whisper_model=whisper_model,
+                translator=translator,
+                audio_column=args.audio_column,
+            )
+            for t in texts:
+                bad_words = bad_word_checker.get_bad_words(
+                    text=str(t),
+                    lang=args.lang,
+                )
+                print(
+                    t,
+                    len(bad_words),
+                    ",".join(bad_words),
+                    file=outf,
+                    sep="\t",
+                )
+
+
+def get_text(
+    lang: str,
+    example: tp.Dict[str, tp.Any],
+    whisper_model: Whisper,
+    translator: Translator,
+    audio_column: str,
+):
+    if whisper_model:
+        with tempfile.NamedTemporaryFile(suffix=".wav") as temp:
+            torchaudio.save(
+                temp.name,
+                example[audio_column]["data"]["waveform"]["seqs"][0]
+                .transpose(0, 1)
+                .cpu(),
+                int(example[audio_column]["data"]["sample_rate"][0]),
+                format="wav",
+            )
+            results = whisper_model.transcribe(
+                temp.name,
+                language=LANG3_LANG2[lang],
+            )
+            return [results["text"]]
+    else:
+        (text_output, _speech_output) = translator.predict(
+            example[audio_column]["data"]["fbank"],
+            "ASR",
+            lang,
+            src_lang=lang,
+        )
+        return text_output
+
+
+def build_data_pipeline(
+    data_file: Path,
+    audio_root_dir: str,
+    batch_size: int,
+    is_whisper: bool,
+    device: Device,
+    dtype: DataType,
+    audio_column: str = "audio",
+    n_parallel: int = 4,
+) -> DataPipeline:
+    with data_file.open("r", encoding="utf-8") as f:
+        header = f.readline().strip("\n").split("\t")
+
+    split_tsv = StrSplitter(names=header)
+
+    pipeline_builder = read_text(data_file, rtrim=True).skip(1).map(split_tsv)
+
+    map_file = FileMapper(root_dir=audio_root_dir, cached_fd_count=10)
+
+    pipeline_builder.map(
+        map_file,
+        selector=audio_column,
+        num_parallel_calls=n_parallel,
+    )
+
+    decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+
+    convert_to_fbank = WaveformToFbankConverter(
+        num_mel_bins=80,
+        waveform_scale=2**15,
+        channel_last=True,
+        standardize=True,
+        device=device,
+        dtype=dtype,
+    )
+
+    # get tensor in waveform
+    steps = [decode_audio]
+    if not is_whisper:
+        # also get the fbanks
+        steps.append(convert_to_fbank)
+
+    pipeline_builder.map(
+        steps,
+        selector=f"{audio_column}.data",
+        num_parallel_calls=n_parallel,
+    )
+
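+    if not is_whisper:
+        # Batch only for M4T; whisper transcribes one example at a time
+        # (get_text reads only the first waveform of each example).
+        pipeline_builder.bucket(bucket_size=batch_size)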
+
+    collate = Collater(pad_value=0, pad_to_multiple=1)
+
+    pipeline_builder.map(collate, num_parallel_calls=n_parallel)
+
+    pipeline_builder.prefetch(4)
+
+    return pipeline_builder.and_return()
+
+
+if __name__ == "__main__":
+    main()
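
The model selection in `main()` above relies on a naming convention: a `whisper_` prefix selects whisper, and anything else is treated as a SeamlessM4T model name. A minimal sketch of that dispatch, isolated from model loading, is shown here. The function `split_model_name` and the backend labels `"whisper"` and `"seamless"` are hypothetical names introduced for illustration.

```python
from typing import Tuple


def split_model_name(model_name: str) -> Tuple[str, str]:
    """Return (backend, name) for a --model_name value (hypothetical helper)."""
    if model_name.startswith("whisper_"):
        # Same split as in main(): drop the prefix, keep everything after it.
        _, whisper_name = model_name.split("_", maxsplit=1)
        return "whisper", whisper_name
    return "seamless", model_name
```

Using `maxsplit=1` matters here: a name with further underscores after the prefix is kept whole instead of being truncated at the second underscore.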

+ 43 - 0
src/seamless_communication/cli/toxicity/etox.py

@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import sys
+
+from seamless_communication.toxicity import load_etox_bad_word_checker
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="ETOX will compute the toxicity level of text inputs (STDIN > STDOUT)."
+    )
+    parser.add_argument(
+        "lang",
+        type=str,
+        help="Language of the input text.",
+    )
+    parser.add_argument(
+        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
+    )
+    parser.add_argument(
+        "output", nargs="?", type=argparse.FileType("w"), default=sys.stdout
+    )
+    args, _unknown = parser.parse_known_args()
+
+    bad_word_checker = load_etox_bad_word_checker("mintox")
+
+    print("text", "toxicity", "bad_words", sep="\t", file=args.output)
+    for line in args.input:
+        text = line.rstrip()
+        bad_words = bad_word_checker.get_bad_words(
+            text=text,
+            lang=args.lang,
+        )
+        print(text, len(bad_words), ",".join(bad_words), sep="\t", file=args.output)
+
+
+if __name__ == "__main__":
+    main()
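
Both scripts above write the same three-column TSV (`text`, `toxicity`, `bad_words`), so the results can be post-processed with the standard library. The following is a minimal sketch of such a summary, not part of the commit. The function name `summarize_toxicity` is hypothetical, and the sample rows use placeholder words in place of real wordlist entries.

```python
import csv
import io
from collections import Counter
from typing import TextIO, Tuple


def summarize_toxicity(tsv_file: TextIO) -> Tuple[int, int, Counter]:
    """Count rows, toxic rows, and per-word hits in an ETOX output TSV."""
    # The scripts print raw tab-separated fields, so disable quote handling.
    reader = csv.DictReader(tsv_file, delimiter="\t", quoting=csv.QUOTE_NONE)
    total = 0
    toxic = 0
    word_counts: Counter = Counter()
    for row in reader:
        total += 1
        if int(row["toxicity"]) > 0:
            toxic += 1
            word_counts.update(row["bad_words"].split(","))
    return total, toxic, word_counts


# Placeholder data in the same layout as the script output.
sample = io.StringIO(
    "text\ttoxicity\tbad_words\n"
    "hello world\t0\t\n"
    "some foo text\t1\tfoo\n"
    "foo and bar\t2\tfoo,bar\n"
)
total, toxic, word_counts = summarize_toxicity(sample)
print(total, toxic, dict(word_counts))
```

Because the scripts do not escape tabs inside the transcription, a transcript containing a literal tab would shift the columns; this sketch assumes that does not occur.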

+ 4 - 4
src/seamless_communication/inference/translator.py

@@ -39,8 +39,8 @@ from seamless_communication.models.unity import (
 )
 from seamless_communication.models.vocoder import load_vocoder_model
 from seamless_communication.toxicity import (
-    BadWordChecker,
-    load_bad_word_checker,
+    ETOXBadWordChecker,
+    load_etox_bad_word_checker,
 )
 from seamless_communication.toxicity.mintox import mintox_pipeline
 
@@ -127,9 +127,9 @@ class Translator(nn.Module):
         if self.model.t2u_model is not None:
             self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
 
-        self.bad_word_checker: Optional[BadWordChecker] = None
+        self.bad_word_checker: Optional[ETOXBadWordChecker] = None
         if apply_mintox:
-            self.bad_word_checker = load_bad_word_checker("mintox")
+            self.bad_word_checker = load_etox_bad_word_checker("mintox")
 
         self.apply_mintox = apply_mintox
 

+ 6 - 2
src/seamless_communication/toxicity/__init__.py

@@ -4,5 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from seamless_communication.toxicity.bad_word_checker import BadWordChecker as BadWordChecker
-from seamless_communication.toxicity.bad_word_checker import load_bad_word_checker as load_bad_word_checker
+from seamless_communication.toxicity.etox_bad_word_checker import (
+    ETOXBadWordChecker as ETOXBadWordChecker,
+)
+from seamless_communication.toxicity.etox_bad_word_checker import (
+    load_etox_bad_word_checker as load_etox_bad_word_checker,
+)

+ 34 - 15
src/seamless_communication/toxicity/bad_word_checker.py → src/seamless_communication/toxicity/etox_bad_word_checker.py

@@ -13,14 +13,14 @@ from fairseq2.assets import (
     AssetCard,
     AssetDownloadManager,
     AssetStore,
-    asset_store,
-    download_manager,
+    asset_store as base_asset_store,
+    download_manager as base_download_manager,
 )
 from fairseq2.data import StringLike
 from fairseq2.data.text import SentencePieceEncoder, SentencePieceModel
 
 
-class BadWordChecker:
+class ETOXBadWordChecker:
     bad_words: Dict[str, List[str]]
     bad_word_variants: Dict[str, Dict[str, List[str]]]
     sp_encoder: SentencePieceEncoder
@@ -45,13 +45,19 @@ class BadWordChecker:
         source_lang: str,
         target_lang: str,
     ) -> List[str]:
-        bad_words_in_target_text = self._get_bad_words(target_text, target_lang)
+        bad_words_in_target_text = self.get_bad_words(
+            target_text,
+            target_lang,
+        )
 
         # If there are no bad words in the target text, do nothing.
         if len(bad_words_in_target_text) == 0:
             return []
 
-        bad_words_in_source_text = self._get_bad_words(source_text, source_lang)
+        bad_words_in_source_text = self.get_bad_words(
+            source_text,
+            source_lang,
+        )
 
         # If there are bad words in the source text, do nothing.
         if len(bad_words_in_source_text) > 0:
@@ -64,11 +70,11 @@ class BadWordChecker:
 
         return bad_words
 
-    def _get_bad_words(self, text: str, lang: str) -> List[str]:
+    def get_bad_words(self, text: str, lang: str) -> List[str]:
         try:
             bad_words = self.bad_words[lang]
-        except KeyError:
-            raise RuntimeError(f"MinTox model does not support {lang}.")
+        except KeyError as e:
+            raise RuntimeError(f"MinTox model does not support {lang}.") from e
 
         text = self._preprocess(text)
 
@@ -122,21 +128,26 @@ class BadWordChecker:
         return False
 
 
-class BadWordCheckerLoader:
+class ETOXBadWordCheckerLoader:
     asset_store: AssetStore
     download_manager: AssetDownloadManager
 
     def __init__(
-        self, asset_store: AssetStore, download_manager: AssetDownloadManager
+        self,
+        asset_store: AssetStore,
+        download_manager: AssetDownloadManager,
     ) -> None:
         self.asset_store = asset_store
         self.download_manager = download_manager
 
-    def __call__(self, model_name_or_card: Union[str, AssetCard]) -> BadWordChecker:
+    def __call__(
+        self,
+        model_name_or_card: Union[str, AssetCard],
+    ) -> ETOXBadWordChecker:
         if isinstance(model_name_or_card, AssetCard):
             card = model_name_or_card
         else:
-            card = asset_store.retrieve_card(model_name_or_card)
+            card = self.asset_store.retrieve_card(model_name_or_card)
 
         bad_words: Dict[str, List[str]] = {}
 
@@ -177,17 +188,25 @@ class BadWordCheckerLoader:
 
         sp_langs = card.field("sp_langs").as_set(str)
 
-        return BadWordChecker(bad_words, bad_word_variants, sp_encoder, sp_langs)
+        return ETOXBadWordChecker(
+            bad_words,
+            bad_word_variants,
+            sp_encoder,
+            sp_langs,
+        )
 
     @staticmethod
     def _load_words(pathname: Path) -> List[str]:
         words: List[str] = []
 
-        with open(pathname) as fp:
+        with open(pathname, "r", encoding="utf-8") as fp:
             for line in fp.readlines():
                 words.append(codecs.encode(line, "rot_13").rstrip("\n"))
 
         return list(set(words))  # Dedup.
 
 
-load_bad_word_checker = BadWordCheckerLoader(asset_store, download_manager)
+load_etox_bad_word_checker = ETOXBadWordCheckerLoader(
+    base_asset_store,
+    base_download_manager,
+)
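
`_load_words` above decodes each wordlist line with ROT13 before deduplicating, which implies the wordlist files are stored in ROT13-obfuscated form. A small self-contained sketch of that decoding step follows. The function `decode_words` is a hypothetical name, the sample entries are harmless placeholders, and the result is sorted only to make the output deterministic (the original returns an unordered list).

```python
import codecs
from typing import Iterable, List


def decode_words(lines: Iterable[str]) -> List[str]:
    """ROT13-decode wordlist lines, strip newlines, and deduplicate."""
    words = {codecs.encode(line, "rot_13").rstrip("\n") for line in lines}
    return sorted(words)


# "uryyb" and "jbeyq" are the ROT13 forms of "hello" and "world".
obfuscated = ["uryyb\n", "jbeyq\n", "uryyb\n"]
decoded = decode_words(obfuscated)
print(decoded)
```

ROT13 is its own inverse, so the same `codecs.encode(..., "rot_13")` call serves for both obfuscating and decoding.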

+ 4 - 4
src/seamless_communication/toxicity/mintox.py

@@ -16,8 +16,8 @@ from seamless_communication.inference.generator import (
     SequenceToUnitOutput,
     SequenceGeneratorOptions,
 )
-from seamless_communication.toxicity.bad_word_checker import (
-    BadWordChecker,
+from seamless_communication.toxicity.etox_bad_word_checker import (
+    ETOXBadWordChecker,
 )
 from fairseq2.generation import SequenceToTextOutput, BannedSequenceProcessor
 from fairseq2.data.text.text_tokenizer import TextTokenizer
@@ -39,7 +39,7 @@ def _extract_bad_words_with_batch_indices(
     target_texts: List[StringLike],
     source_lang: str,
     target_lang: str,
-    bad_word_checker: BadWordChecker,
+    bad_word_checker: ETOXBadWordChecker,
 ) -> Tuple[List[str], List[int]]:
     all_bad_words, batch_indices = [], []
 
@@ -149,7 +149,7 @@ def mintox_pipeline(
     unit_generation_opts: Optional[SequenceGeneratorOptions] = SequenceGeneratorOptions(
         beam_size=5, soft_max_seq_len=(25, 50)
     ),
-    bad_word_checker: BadWordChecker = None,
+    bad_word_checker: ETOXBadWordChecker = None,
     duration_factor: float = 1.0,
     prosody_encoder_input: Optional[SequenceData] = None,
 ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
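
The checker that `mintox_pipeline` receives reports "added" toxicity: per the comments in `etox_bad_word_checker.py` above, it does nothing when the target text has no bad words, and nothing when the source text already contains bad words. The sketch below reproduces only that decision rule. `toy_get_bad_words`, `TOY_LEXICON`, and `added_toxicity` are hypothetical stand-ins with placeholder entries, and any further processing the real method performs on lines not shown in this diff is omitted.

```python
from typing import Callable, Dict, List, Set

# Toy lexicon with placeholder entries; the real ETOX wordlists are not reproduced.
TOY_LEXICON: Dict[str, Set[str]] = {"eng": {"foo"}, "deu": {"bar"}}


def toy_get_bad_words(text: str, lang: str) -> List[str]:
    """Hypothetical stand-in for `ETOXBadWordChecker.get_bad_words`."""
    return [w for w in text.lower().split() if w in TOY_LEXICON[lang]]


def added_toxicity(
    source_text: str,
    target_text: str,
    source_lang: str,
    target_lang: str,
    get_bad_words: Callable[[str, str], List[str]] = toy_get_bad_words,
) -> List[str]:
    """Return target-side bad words only when the source side has none."""
    target_hits = get_bad_words(target_text, target_lang)
    if not target_hits:
        # Nothing toxic in the output: nothing to mitigate.
        return []
    if get_bad_words(source_text, source_lang):
        # Toxicity is already present in the input, so it was not added.
        return []
    return target_hits
```

Checking the target first avoids scanning the source at all in the common case where the output is clean.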

+ 11 - 5
tests/integration/inference/test_mintox.py

@@ -6,13 +6,20 @@
 
 from fairseq2.assets import download_manager
 from seamless_communication.inference.translator import Translator
+from seamless_communication.toxicity.etox_bad_word_checker import ETOXBadWordChecker
 from seamless_communication.toxicity.mintox import _extract_bad_words_with_batch_indices
 from tests.common import device, get_default_dtype
-from seamless_communication.toxicity import load_bad_word_checker
+from seamless_communication.toxicity import load_etox_bad_word_checker
 
+import pytest
 
-def test_mintox_s2tt():
-    bad_words_checker = load_bad_word_checker("mintox")
+
+@pytest.fixture
+def bad_words_checker() -> ETOXBadWordChecker:
+    return load_etox_bad_word_checker("mintox")
+
+
+def test_mintox_s2tt(bad_words_checker: ETOXBadWordChecker):
     model_name = "seamlessM4T_v2_large"
     vocoder_name = "vocoder_v2"
     src_text = "The strategy proved effective, cutting off vital military and civilian supplies, although this blockade violated generally accepted international law codified by several international agreements of the past two centuries."
@@ -66,8 +73,7 @@ def test_mintox_s2tt():
     assert batch_indices == []
 
 
-def test_mintox_t2tt():
-    bad_words_checker = load_bad_word_checker("mintox")
+def test_mintox_t2tt(bad_words_checker: ETOXBadWordChecker):
     model_name = "seamlessM4T_v2_large"
     vocoder_name = "vocoder_v2"
     src_text = "I wonder what it'd be like to be a doff parent."