Pierre Andrews 1 рік тому
батько
коміт
07279891ea

+ 2 - 1
README.md

@@ -303,7 +303,8 @@ The following non-generative components are MIT licensed as found in [MIT_LICENS
 - Code
 - Text only part of the mExpresso dataset found in the [SeamlessExpressive README](docs/expressive/README.md).
 - UnitY2 forced alignment extractor found in the [UnitY2 Aligner README](docs/m4t/unity2_aligner_README.md).
-- Speech toxicity tool with the etox dataset found in the [Toxicity README](src/seamless_communication/cli/toxicity).
+- Speech toxicity tool with the etox dataset found in the [ETOX README](src/seamless_communication/cli/toxicity/etox).
+- MuTox: Universal MUltilingual Audio-based TOXicity Dataset and Zero-shot Detector [Mutox README](src/seamless_communication/cli/toxicity/mutox)
 
 The following models are CC-BY-NC 4.0 licensed as found in the [LICENSE](LICENSE):
 - SeamlessM4T models (v1 and v2).

+ 1 - 0
setup.py

@@ -27,6 +27,7 @@ setup(
         "librosa",
         "openai-whisper",
         "simuleval~=1.1.3",
+        "sonar-space==0.2.*",
         "soundfile",
         "scipy",
         "torchaudio",

+ 11 - 0
src/seamless_communication/cards/mutox.yaml

@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: mutox
+model_type: mutox_classifier
+model_arch: mutox
+checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt"
+input_size: 1024

+ 0 - 0
src/seamless_communication/cli/toxicity/README.md → src/seamless_communication/cli/toxicity/etox/README.md


+ 0 - 0
src/seamless_communication/cli/toxicity/asr_etox.py → src/seamless_communication/cli/toxicity/etox/asr_etox.py


+ 0 - 0
src/seamless_communication/cli/toxicity/etox.py → src/seamless_communication/cli/toxicity/etox/etox.py


+ 87 - 0
src/seamless_communication/cli/toxicity/mutox/README.md

@@ -0,0 +1,87 @@
+# MuTox: MuTox: Universal MUltilingual Audio-based TOXicity Dataset and Zero-shot Detector
+
+MuTox, the first highly multilingual audio-based dataset with toxicity labels.
+The dataset consists of 20k audio utterances for English and Spanish, and 4k for
+the other 19 languages. To showcase the quality of this dataset, we train the
+MuTox audio-based toxicity classifier, which allows zero-shot toxicity detection
+across a broad range of languages. This classifier outperforms existing
+text-based trainable classifiers by more than 1% AUC, while increasing the
+language coverage from 8 to 100+ languages. When compared to a wordlist-based
+classifier that covers a similar number of languages, MuTox improves precision
+and recall by ∼2.5 times.
+
+## License
+
+The mutox code and model are licensed under the MIT license (see MIT_LICENSE
+file at the root of seamless_communication). The mutox model depends on SONAR
+encoders, most are under the MIT license but a few are under CC-BY-NC license.
+See the [SONAR repository](https://github.com/facebookresearch/SONAR) for
+details.
+
+## Dataset Languages.
+
+- English,
+- Spanish,
+- Arabic,
+- Bengali,
+- Mandarin Chinese,
+- Dutch,
+- French,
+- German,
+- Hindi,
+- Indonesian,
+- Italian,
+- Japanese,
+- Korean,
+- Portuguese,
+- Russian,
+- Swahili,
+- Tagalog,
+- Thai,
+- Turkish,
+- Urdu,
+- Vietnamese
+
+## Classifier details.
+
+We use multi-modal and multilingual
+[SONAR](https://github.com/facebookresearch/SONAR) encoders from (Duquenne et
+al., 2023). For the classifier, we use variable input sizes for the 3
+feedforward layers (1024, 512, and 128).
+
+## Classifier Quick Start
+
+This introduces the MuTox speech toxicity model, this relies on computing the
+sonar embedding and then classifying it through the MuTox model. The
+`cli/mutox/mutox.py` provides an example of reading a TSV, computing the SONAR
+embedding and running the classifier on the results:
+
+```bash
+python -m seamless_communication.cli.toxicity.mutox.mutox_speech --lang fra --audio_column ref_tgt_audio /checkpoint/bokai/seamless/toxity_mitigation/exps_v5/joined_etox/fleurs/s2t/en-xx/fra.tsv /tmp/tesmortt.tsv
+```
+
+You can also work with text:
+
+```bash
+python -m seamless_communication.cli.toxicity.mutox.mutox_text --lang fra_Latn sentences.txt
+```
+
+You can also check the mutox example notebook in this directory.
+
+## Dataset
+
+The dataset is available in this [file](https://dl.fbaipublicfiles.com/seamless/datasets/mutox.csv). The dataset is licensed under the MIT license (see MIT_LICENSE
+file at the root of seamless_communication).
+
+## Citation
+
+```bitex
+@misc{costajussà2023mutox,
+      title={MuTox: Universal MUltilingual Audio-based TOXicity Dataset and Zero-shot Detector},
+      author={ Marta R. Costa-jussà, Mariano Coria Meglioli, Pierre Andrews, David Dale, Prangthip Hansanti, Elahe Kalbassi, Alex Mourachko, Christophe Ropers, Carleigh Wood},
+      year={2023},
+      eprint={},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL}
+}
+```

+ 245 - 0
src/seamless_communication/cli/toxicity/mutox/mutox_example.ipynb

@@ -0,0 +1,245 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Copyright (c) Meta Platforms, Inc. and affiliates\n",
+    "# All rights reserved.\n",
+    "#\n",
+    "# This source code is licensed under the license found in the\n",
+    "# MIT_LICENSE file in the root directory of this source tree."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# MUTOX toxicity classification\n",
+    "\n",
+    "Mutox lets you score speech and text toxicity using a classifier that can score sonar embeddings. In this notebook, we provide an example of encoding speech and text and classifying that."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from pathlib import Path\n",
+    "\n",
+    "if torch.cuda.is_available():\n",
+    "    device = torch.device(\"cuda:0\")\n",
+    "    dtype = torch.float16\n",
+    "else:\n",
+    "    device = torch.device(\"cpu\")\n",
+    "    dtype = torch.float32"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Speech Scoring"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "1. download some demo audio segments\n",
+    "2. create a tsv file to feed to the speech scoring pipeline\n",
+    "3. load the model and build the pipeline\n",
+    "4. go through the batches in the pipeline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# get demo file\n",
+    "import urllib.request\n",
+    "import tempfile\n",
+    "\n",
+    "files = [\n",
+    "    (\"https://dl.fbaipublicfiles.com/seamless/tests/commonvoice_example_en_clocks.wav\", \"commonvoice_example_en_clocks.wav\"),\n",
+    "    (\"https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav\", \"LJ037-0171_sr16k.wav\")\n",
+    "]\n",
+    "\n",
+    "tmpdir = Path(tempfile.mkdtemp())\n",
+    "tsv_file = (tmpdir / 'data.tsv')\n",
+    "with tsv_file.open('w') as tsv_file_p:\n",
+    "    print('path', file=tsv_file_p)\n",
+    "    for (uri, name) in files:\n",
+    "        dl = tmpdir / name\n",
+    "        urllib.request.urlretrieve(uri, dl)\n",
+    "        print(dl, file=tsv_file_p)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sonar.inference_pipelines.speech import SpeechInferenceParams\n",
+    "from seamless_communication.toxicity.mutox.speech_pipeline import MutoxSpeechClassifierPipeline\n",
+    "\n",
+    "pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n",
+    "    mutox_classifier_name =\"mutox\",\n",
+    "    encoder_name=f\"sonar_speech_encoder_eng\",\n",
+    "    device=device,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipeline = pipeline_builder.build_pipeline(SpeechInferenceParams(\n",
+    "    data_file=tsv_file,\n",
+    "    audio_root_dir=None,\n",
+    "    audio_path_index=0,\n",
+    "    target_lang=\"eng\",\n",
+    "    batch_size=4,\n",
+    "    pad_idx=0,\n",
+    "    device=device,\n",
+    "    fbank_dtype=torch.float32,\n",
+    "    n_parallel=4\n",
+    "))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/tmp/tmpqasvhgx6/commonvoice_example_en_clocks.wav\t-42.40079116821289\n",
+      "/tmp/tmpqasvhgx6/LJ037-0171_sr16k.wav\t-47.90427780151367\n"
+     ]
+    }
+   ],
+   "source": [
+    "for batch in pipeline:\n",
+    "    ex = batch['audio']\n",
+    "    for idx, path in enumerate(ex['path']):\n",
+    "        print(str(path), ex[\"data\"][idx].item(), sep=\"\\t\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# cleanup tmp dir\n",
+    "import shutil\n",
+    "shutil.rmtree(tmpdir)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Text Scoring\n",
+    "\n",
+    "1. load the sonar text encoder\n",
+    "2. load the mutox classifier model\n",
+    "3. compute embedding for a sentence\n",
+    "4. score this embedding"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Using the cached checkpoint of mutox. Set `force` to `True` to download again.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from seamless_communication.toxicity.mutox.loader import load_mutox_model\n",
+    "from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n",
+    "\n",
+    "t2vec_model = TextToEmbeddingModelPipeline(\n",
+    "    encoder=\"text_sonar_basic_encoder\",\n",
+    "    tokenizer=\"text_sonar_basic_encoder\",\n",
+    ")\n",
+    "text_column='lang_txt'\n",
+    "classifier = load_mutox_model(\n",
+    "    \"mutox\",\n",
+    "    device=device,\n",
+    "    dtype=dtype,\n",
+    ").eval()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "with torch.inference_mode():\n",
+    "    emb = t2vec_model.predict([\"De peur que le pays ne se prostitue et ne se remplisse de crimes.\"], source_lang='fra_Latn')\n",
+    "    x = classifier(emb.to(device).half())\n",
+    "\n",
+    "x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "sc_fr2",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

+ 140 - 0
src/seamless_communication/cli/toxicity/mutox/mutox_speech.py

@@ -0,0 +1,140 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import argparse
+
+import torch
+from tqdm import tqdm
+from pathlib import Path
+
+from sonar.inference_pipelines.speech import (
+    SpeechInferenceParams,
+)
+from seamless_communication.toxicity.mutox.speech_pipeline import (
+    MutoxSpeechClassifierPipeline,
+)
+
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Mutox speech will compute a toxicity score for each speech segment it is provided."
+    )
+    parser.add_argument(
+        "data_file",
+        type=Path,
+        help="Path to the input TSV manifest that list the audio files.",
+    )
+    parser.add_argument(
+        "output_file",
+        type=Path,
+        help="Path to a TSV file where to save the results.",
+    )
+    parser.add_argument(
+        "--lang",
+        type=str,
+        help="Language, language of the speech being passed as input, three letter code",
+        required=True,
+    )
+    parser.add_argument(
+        "--audio_root_dir",
+        type=str,
+        help="Root directory for the audio filenames in the data file.",
+    )
+    parser.add_argument(
+        "--audio_path_index",
+        type=int,
+        help="Index of the column where the audiofile is listed in the input tsv.",
+        default="audio",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--n_parallel",
+        type=int,
+        help="Number of data loading in parallel.",
+        default=4,
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        help="name of the device to use with torch.",
+        required=False,
+    )
+    args, _unknown = parser.parse_known_args()
+
+    if args.device is not None:
+        device = torch.device(args.device)
+        dtype = torch.float32
+        if device.type == "cuda":
+            dtype = torch.float16
+    elif torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float16
+        logger.info("using cuda:0, %s", dtype)
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+        logger.info("no gpu, using cpu")
+
+    logger.info("loading models.")
+
+    pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(
+        mutox_classifier_name="mutox",
+        encoder_name=f"sonar_speech_encoder_{args.lang}",
+        device=device,
+    )
+
+    pipeline = pipeline_builder.build_pipeline(
+        SpeechInferenceParams(
+            data_file=args.data_file,
+            audio_root_dir=args.audio_root_dir,
+            audio_path_index=args.audio_path_index,
+            target_lang=args.lang,
+            batch_size=args.batch_size,
+            pad_idx=0,
+            device=device,
+            fbank_dtype=torch.float32,
+            n_parallel=args.n_parallel,
+        )
+    )
+
+    logger.info("processing.")
+
+    with open(args.output_file, "w", encoding="utf-8") as outf:
+        print(
+            "input_audio_path",
+            "score",
+            sep="\t",
+            file=outf,
+        )
+        for example in tqdm(pipeline):
+            ex = example["audio"]
+            for idx, path in enumerate(ex["path"]):
+                print(
+                    str(path),
+                    ex["data"][idx].item(),
+                    sep="\t",
+                    file=outf,
+                )
+
+    logger.info(f"Done, outputs are in {args.output_file}.")
+
+
+if __name__ == "__main__":
+    main()

+ 98 - 0
src/seamless_communication/cli/toxicity/mutox/mutox_text.py

@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import argparse
+import sys
+
+import torch
+from seamless_communication.toxicity.mutox.loader import load_mutox_model
+from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
+
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+CPU_DEVICE = torch.device("cpu")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Mutox Text will compute a toxicity score for each sentence it is passed."
+    )
+
+    parser.add_argument(
+        "lang",
+        type=str,
+        help="Language of the input text, nllb format with script.",
+    )
+    parser.add_argument(
+        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
+    )
+    parser.add_argument(
+        "output", nargs="?", type=argparse.FileType("w"), default=sys.stdout
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        help="name of the device to use with torch.",
+        required=False,
+    )
+    args, _unknown = parser.parse_known_args()
+
+    if args.device is not None:
+        device = torch.device(args.device)
+        dtype = torch.float32
+        if device.type == "cuda":
+            dtype = torch.float16
+    elif torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float16
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    t2vec_model = TextToEmbeddingModelPipeline(
+        encoder="text_sonar_basic_encoder",
+        tokenizer="text_sonar_basic_encoder",
+        device=device,
+    )
+
+    classifier = load_mutox_model(
+        "mutox",
+        device=device,
+        dtype=dtype,
+    ).eval()
+
+    def write_result(batch):
+        emb = t2vec_model.predict(batch, source_lang=args.lang)
+        scores = classifier(emb.half())
+        for s, t in zip(scores, batch):
+            print(t, s.item(), sep="\t", file=args.output)
+
+    with torch.inference_mode():
+        print("text", "score", sep="\t", file=args.output)
+        batch = []
+        for line in args.input:
+            batch.append(line.rstrip())
+            if len(batch) >= args.batch_size:
+                write_result(batch)
+                batch = []
+
+        if len(batch):
+            write_result(batch)
+
+
+if __name__ == "__main__":
+    main()

+ 91 - 0
src/seamless_communication/toxicity/mutox/builder.py

@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import typing as tp
+from seamless_communication.toxicity.mutox.classifier import (
+    MutoxClassifier,
+    MutoxConfig,
+)
+import torch
+from torch import nn
+from fairseq2.typing import DataType, Device
+
+
+class MutoxClassifierBuilder:
+    """
+    Builder module for MutoxClassifier model
+    """
+
+    config: MutoxConfig
+    device: tp.Optional[Device]
+    dtype: tp.Optional[DataType]
+
+    def __init__(
+        self,
+        config: MutoxConfig,
+        *,
+        device: tp.Optional[Device] = None,
+        dtype: tp.Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> MutoxClassifier:
+        model_h1 = nn.Sequential(
+            nn.Dropout(0.01),
+            nn.Linear(self.config.input_size, 512),
+        )
+
+        model_h2 = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(512, 128),
+        )
+
+        model_h3 = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(128, 1),
+        )
+
+        model_all = nn.Sequential(
+            model_h1,
+            model_h2,
+            model_h3,
+        )
+
+        return MutoxClassifier(model_all,).to(
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+def create_mutox_model(
+    config: MutoxConfig,
+    device: tp.Optional[Device] = None,
+    dtype: tp.Optional[DataType] = None,
+) -> MutoxClassifier:
+    """Create a Mutox Classifier model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+
+    return MutoxClassifierBuilder(
+        config,
+        device=device,
+        dtype=dtype,
+    ).build_model()

+ 36 - 0
src/seamless_communication/toxicity/mutox/classifier.py

@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+import torch
+from torch import nn
+from fairseq2.typing import DataType, Device
+
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from typing import Optional
+
+
+class MutoxClassifier(nn.Module):
+    def __init__(
+        self,
+        model_all,
+    ):
+        super().__init__()
+        self.model_all = model_all
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        return self.model_all(inputs)
+
+
+@dataclass
+class MutoxConfig:
+    """Holds the configuration of a Mutox Classifier model."""
+
+    # size of the input embedding supported by this model
+    input_size: int
+
+
+mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier")

+ 46 - 0
src/seamless_communication/toxicity/mutox/loader.py

@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+
+from fairseq2.assets import asset_store, download_manager
+from fairseq2.models.utils import ConfigLoader, ModelLoader
+from seamless_communication.toxicity.mutox.builder import create_mutox_model
+from seamless_communication.toxicity.mutox.classifier import (
+    MutoxClassifier,
+    MutoxConfig,
+    mutox_archs,
+)
+
+import typing as tp
+
+
+@mutox_archs.decorator("mutox")
+def _base_mutox() -> MutoxConfig:
+    return MutoxConfig(
+        input_size=1024,
+    )
+
+
+def convert_mutox_checkpoint(
+    checkpoint: tp.Mapping[str, tp.Any], config: MutoxConfig
+) -> tp.Mapping[str, tp.Any]:
+    new_dict = {}
+    for key in checkpoint:
+        if key.startswith("model_all."):
+            new_dict[key] = checkpoint[key]
+    return {"model": new_dict}
+
+
+load_mutox_config = ConfigLoader[MutoxConfig](asset_store, mutox_archs)
+
+
+load_mutox_model = ModelLoader[MutoxClassifier, MutoxConfig](
+    asset_store,
+    download_manager,
+    load_mutox_config,
+    create_mutox_model,
+    convert_mutox_checkpoint,
+)

+ 61 - 0
src/seamless_communication/toxicity/mutox/speech_pipeline.py

@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import torch
+from seamless_communication.toxicity.mutox.classifier import MutoxClassifier
+from seamless_communication.toxicity.mutox.loader import load_mutox_model
+from sonar.models.sonar_speech.loader import load_sonar_speech_model
+
+from sonar.inference_pipelines.speech import (
+    SpeechToEmbeddingPipeline,
+    SpeechInferenceParams,
+)
+
+from fairseq2.data import (
+    DataPipelineBuilder,
+)
+
+from typing import Union
+
+from seamless_communication.toxicity.mutox.classifier import MutoxClassifier
+from sonar.models.encoder_model import SonarEncoderModel
+from fairseq2.typing import Device
+
+
+CPU_DEVICE = torch.device("cpu")
+
+
+class MutoxSpeechClassifierPipeline(SpeechToEmbeddingPipeline):
+    def __init__(
+        self,
+        mutox_classifier: Union[str, MutoxClassifier],
+        encoder: Union[str, SonarEncoderModel],
+        device: Device = CPU_DEVICE,
+    ) -> None:
+        super().__init__(encoder)
+        self.model.to(device).eval()
+        self.mutox_classifier = mutox_classifier.to(device).eval()
+
+    @classmethod
+    def load_model_from_name(
+        cls,
+        mutox_classifier_name: str,
+        encoder_name: str,
+        device: Device = CPU_DEVICE,
+    ) -> "SpeechToEmbeddingPipeline":
+        encoder = load_sonar_speech_model(encoder_name, device=device, progress=False)
+        mutox_classifier = load_mutox_model(
+            mutox_classifier_name, device=device, progress=False
+        )
+        return cls(mutox_classifier=mutox_classifier, encoder=encoder, device=device)
+
+    def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuilder:
+        pipeline_builder = super().prebuild_pipeline(context)
+        return pipeline_builder.map(self._run_classifier, selector="audio.data")
+
+    @torch.inference_mode()
+    def _run_classifier(self, data: dict):
+        return self.mutox_classifier(data.sentence_embeddings)