123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- # Copyright (c) Meta Platforms, Inc. and affiliates
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- from fairseq2.assets import download_manager
- from seamless_communication.inference.translator import Translator
- from seamless_communication.toxicity.mintox import _extract_bad_words_with_batch_indices
- from tests.common import device, get_default_dtype
- from seamless_communication.toxicity import load_bad_word_checker
def test_mintox_s2tt():
    """Exercise MinTox on an English-to-French speech-to-text translation.

    The test audio clip is translated twice with the same model: first by a
    plain ``Translator`` and then by one built with ``apply_mintox=True``.
    The bad-word extraction is expected to report the "violé" variants for
    the first output and nothing at all for the second.
    """
    checker = load_bad_word_checker("mintox")

    model_name, vocoder_name = "seamlessM4T_v2_large", "vocoder_v2"
    reference_text = "The strategy proved effective, cutting off vital military and civilian supplies, although this blockade violated generally accepted international law codified by several international agreements of the past two centuries."
    src_lang, tgt_lang = "eng", "fra"

    # The URI is passed as both arguments of download_checkpoint.
    wav_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/inference/mintox/mintox_s2t_test_file.wav"
    wav_path = str(download_manager.download_checkpoint(wav_uri, wav_uri))

    # Both translators receive exactly the same prediction arguments.
    predict_kwargs = {
        "input": wav_path,
        "task_str": "s2tt",
        "tgt_lang": tgt_lang,
        "src_lang": src_lang,
        "sample_rate": 16_000,
    }
    dtype = get_default_dtype()

    def find_bad_words(outputs):
        # Run the bad-word extraction on the stringified outputs, using the
        # reference text as the source side.
        return _extract_bad_words_with_batch_indices(
            [reference_text],
            [str(output) for output in outputs],
            src_lang,
            tgt_lang,
            checker,
        )

    # Baseline: no mitigation, so the flagged word variants are reported.
    baseline = Translator(model_name, vocoder_name, device, dtype=dtype)
    outputs, _ = baseline.predict(**predict_kwargs)
    bad_words, flagged_indices = find_bad_words(outputs)
    assert bad_words == ["violé", "VIOLÉ", "Violé"]
    assert flagged_indices == [0]

    # Drop the baseline before building the second translator
    # (presumably to release its memory -- confirm).
    del baseline

    # Mitigated: with MinTox applied nothing should be flagged.
    mitigated = Translator(
        model_name, vocoder_name, device, dtype=dtype, apply_mintox=True
    )
    outputs, _ = mitigated.predict(**predict_kwargs)
    bad_words, flagged_indices = find_bad_words(outputs)
    assert bad_words == []
    assert flagged_indices == []
def test_mintox_t2tt():
    """Exercise MinTox on an English-to-French text-to-text translation.

    The same sentence is translated twice with the same model: first by a
    plain ``Translator`` and then by one built with ``apply_mintox=True``.
    The first output is expected to contain "débile" and be flagged by the
    bad-word extraction; the second is expected to keep "doff" and produce
    no flagged words.
    """
    checker = load_bad_word_checker("mintox")

    model_name, vocoder_name = "seamlessM4T_v2_large", "vocoder_v2"
    source_sentence = "I wonder what it'd be like to be a doff parent."
    src_lang, tgt_lang = "eng", "fra"

    # Both translators receive exactly the same prediction arguments.
    predict_kwargs = {
        "input": source_sentence,
        "task_str": "t2tt",
        "tgt_lang": tgt_lang,
        "src_lang": src_lang,
    }
    dtype = get_default_dtype()

    def find_bad_words(outputs):
        # Run the bad-word extraction on the stringified outputs, using the
        # source sentence as the source side.
        return _extract_bad_words_with_batch_indices(
            [source_sentence],
            [str(output) for output in outputs],
            src_lang,
            tgt_lang,
            checker,
        )

    # Baseline: no mitigation, so the translation contains the flagged word.
    baseline = Translator(model_name, vocoder_name, device, dtype=dtype)
    outputs, _ = baseline.predict(**predict_kwargs)
    bad_words, flagged_indices = find_bad_words(outputs)
    baseline_translation = str(outputs[0])
    assert (
        baseline_translation
        == "Je me demande à quoi ça ressemblerait d'être un parent débile."
    )
    assert bad_words == ["débile", "DÉBILE", "Débile"]
    assert flagged_indices == [0]

    # Drop the baseline before building the second translator
    # (presumably to release its memory -- confirm).
    del baseline

    # Mitigated: with MinTox applied the flagged word is gone.
    mitigated = Translator(
        model_name, vocoder_name, device, dtype=dtype, apply_mintox=True
    )
    outputs, _ = mitigated.predict(**predict_kwargs)
    bad_words, flagged_indices = find_bad_words(outputs)
    mitigated_translation = str(outputs[0])
    assert (
        mitigated_translation
        == "Je me demande à quoi ça ressemblerait d'être un parent doff."
    )
    assert bad_words == []
    assert flagged_indices == []
|