|
@@ -0,0 +1,245 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 9,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Copyright (c) Meta Platforms, Inc. and affiliates\n",
|
|
|
+ "# All rights reserved.\n",
|
|
|
+ "#\n",
|
|
|
+ "# This source code is licensed under the license found in the\n",
|
|
|
+ "# MIT_LICENSE file in the root directory of this source tree."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "# MUTOX toxicity classification\n",
|
|
|
+ "\n",
|
|
|
+ "Mutox lets you score speech and text toxicity using a classifier that can score sonar embeddings. In this notebook, we provide an example of encoding speech and text and classifying that."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 1,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import torch\n",
|
|
|
+ "from pathlib import Path\n",
|
|
|
+ "\n",
|
|
|
+ "if torch.cuda.is_available():\n",
|
|
|
+ " device = torch.device(\"cuda:0\")\n",
|
|
|
+ " dtype = torch.float16\n",
|
|
|
+ "else:\n",
|
|
|
+ " device = torch.device(\"cpu\")\n",
|
|
|
+ " dtype = torch.float32"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "# Speech Scoring"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "1. download some demo audio segments\n",
|
|
|
+ "2. create a tsv file to feed to the speech scoring pipeline\n",
|
|
|
+ "3. load the model and build the pipeline\n",
|
|
|
+ "4. go through the batches in the pipeline"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 2,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# get demo file\n",
|
|
|
+ "import urllib.request\n",
|
|
|
+ "import tempfile\n",
|
|
|
+ "\n",
|
|
|
+ "files = [\n",
|
|
|
+ " (\"https://dl.fbaipublicfiles.com/seamless/tests/commonvoice_example_en_clocks.wav\", \"commonvoice_example_en_clocks.wav\"),\n",
|
|
|
+ " (\"https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav\", \"LJ037-0171_sr16k.wav\")\n",
|
|
|
+ "]\n",
|
|
|
+ "\n",
|
|
|
+ "tmpdir = Path(tempfile.mkdtemp())\n",
|
|
|
+ "tsv_file = (tmpdir / 'data.tsv')\n",
|
|
|
+ "with tsv_file.open('w') as tsv_file_p:\n",
|
|
|
+ " print('path', file=tsv_file_p)\n",
|
|
|
+ " for (uri, name) in files:\n",
|
|
|
+ " dl = tmpdir / name\n",
|
|
|
+ " urllib.request.urlretrieve(uri, dl)\n",
|
|
|
+ " print(dl, file=tsv_file_p)\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 3,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "from sonar.inference_pipelines.speech import SpeechInferenceParams\n",
|
|
|
+ "from seamless_communication.toxicity.mutox.speech_pipeline import MutoxSpeechClassifierPipeline\n",
|
|
|
+ "\n",
|
|
|
+ "pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n",
|
|
|
+ " mutox_classifier_name =\"mutox\",\n",
|
|
|
+ " encoder_name=f\"sonar_speech_encoder_eng\",\n",
|
|
|
+ " device=device,\n",
|
|
|
+ ")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "pipeline = pipeline_builder.build_pipeline(SpeechInferenceParams(\n",
|
|
|
+ " data_file=tsv_file,\n",
|
|
|
+ " audio_root_dir=None,\n",
|
|
|
+ " audio_path_index=0,\n",
|
|
|
+ " target_lang=\"eng\",\n",
|
|
|
+ " batch_size=4,\n",
|
|
|
+ " pad_idx=0,\n",
|
|
|
+ " device=device,\n",
|
|
|
+ " fbank_dtype=torch.float32,\n",
|
|
|
+ " n_parallel=4\n",
|
|
|
+ "))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 7,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "/tmp/tmpqasvhgx6/commonvoice_example_en_clocks.wav\t-42.40079116821289\n",
|
|
|
+ "/tmp/tmpqasvhgx6/LJ037-0171_sr16k.wav\t-47.90427780151367\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "for batch in pipeline:\n",
|
|
|
+ " ex = batch['audio']\n",
|
|
|
+ " for idx, path in enumerate(ex['path']):\n",
|
|
|
+ " print(str(path), ex[\"data\"][idx].item(), sep=\"\\t\")\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 8,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# cleanup tmp dir\n",
|
|
|
+ "import shutil\n",
|
|
|
+ "shutil.rmtree(tmpdir)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "# Text Scoring\n",
|
|
|
+ "\n",
|
|
|
+ "1. load the sonar text encoder\n",
|
|
|
+ "2. load the mutox classifier model\n",
|
|
|
+ "3. compute embedding for a sentence\n",
|
|
|
+ "4. score this embedding"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 7,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Using the cached checkpoint of mutox. Set `force` to `True` to download again.\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "from seamless_communication.toxicity.mutox.loader import load_mutox_model\n",
|
|
|
+ "from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n",
|
|
|
+ "\n",
|
|
|
+ "t2vec_model = TextToEmbeddingModelPipeline(\n",
|
|
|
+ " encoder=\"text_sonar_basic_encoder\",\n",
|
|
|
+ " tokenizer=\"text_sonar_basic_encoder\",\n",
|
|
|
+ ")\n",
|
|
|
+ "text_column='lang_txt'\n",
|
|
|
+ "classifier = load_mutox_model(\n",
|
|
|
+ " \"mutox\",\n",
|
|
|
+ " device=device,\n",
|
|
|
+ " dtype=dtype,\n",
|
|
|
+ ").eval()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 10,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/plain": [
|
|
|
+ "tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 10,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "with torch.inference_mode():\n",
|
|
|
+ " emb = t2vec_model.predict([\"De peur que le pays ne se prostitue et ne se remplisse de crimes.\"], source_lang='fra_Latn')\n",
|
|
|
+ " x = classifier(emb.to(device).half())\n",
|
|
|
+ "\n",
|
|
|
+ "x"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": []
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "sc_fr2",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.10.13"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 2
|
|
|
+}
|