# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
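"""Helper script to download the FLEURS dataset, extract discrete units from the
target audio, and save the dataset as a manifest consumable by `finetune.py`."""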
 
import argparse
import dataclasses
import json
import logging
import os
from pathlib import Path

import torch

from seamless_communication.datasets.huggingface import (
    Speech2SpeechFleursDatasetBuilder,
    SpeechTokenizer,
)
from seamless_communication.models.unit_extraction import UnitExtractor

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger("dataset")
 
# Full list of FLEURS langcodes is available at https://huggingface.co/datasets/google/fleurs
# Full list of M4T langcodes is available in the paper
# "SeamlessM4T—Massively Multilingual & Multimodal Machine Translation" (Table 5)
UNITY_TO_FLEURS_LANG_MAPPING = {
    "eng": "en_us",
    "afr": "af_za",
    "asm": "as_in",
    "bel": "be_by",
    "bul": "bg_bg",
    "ben": "bn_in",
    "cat": "ca_es",
    "ces": "cs_cz",
    "dan": "da_dk",
    "deu": "de_de",
    "ell": "el_gr",
    "fin": "fi_fi",
    "fra": "fr_fr",
    "glg": "gl_es",
    "heb": "he_il",
    "hin": "hi_in",
    "hrv": "hr_hr",
    "hun": "hu_hu",
    "ind": "id_id",
    "ibo": "ig_ng",
    "isl": "is_is",
    "ita": "it_it",
    "jpn": "ja_jp",
    "jav": "jv_id",
    "kaz": "kk_kz",
    "kan": "kn_in",
    "kir": "ky_kg",
    "kor": "ko_kr",
    "lit": "lt_lt",
    "mkd": "mk_mk",
    "mlt": "mt_mt",
    "mya": "my_mm",
    "nld": "nl_nl",
    "pan": "pa_in",
    "pol": "pl_pl",
    "ron": "ro_ro",
    "rus": "ru_ru",
    "snd": "sd_in",
    "slk": "sk_sk",
    "srp": "sr_rs",
    "swh": "sw_ke",
    "tam": "ta_in",
    "tel": "te_in",
    "tha": "th_th",
    "tur": "tr_tr",
    "ukr": "uk_ua",
    "urd": "ur_pk",
    "uzn": "uz_uz",
    "vie": "vi_vn",
    "yor": "yo_ng",
    "zul": "zu_za",
}
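# Example: UNITY_TO_FLEURS_LANG_MAPPING["eng"] == "en_us" (the FLEURS dataset config name).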
 

def _check_lang_code_mapping(lang: str) -> None:
    if lang not in UNITY_TO_FLEURS_LANG_MAPPING:
        raise ValueError(
            f"No language code mapping for {lang}(M4T)->??(FLEURS). "
            "Please expand `UNITY_TO_FLEURS_LANG_MAPPING`."
        )
 

class UnitSpeechTokenizer(SpeechTokenizer):
    """Speech tokenizer that maps a waveform to discrete units by quantizing
    layer-34 features of the xlsr2_1b_v2 model with a 10k-centroid k-means codebook."""

    MODEL_NAME = "xlsr2_1b_v2"
    KMEANS_MODEL_URI = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
    OUTPUT_LAYER_IDX = 34

    def __init__(self, device: torch.device):
        super().__init__()
        self.device = device
        self.unit_extractor = UnitExtractor(
            model_name_or_card=self.MODEL_NAME,
            kmeans_uri=self.KMEANS_MODEL_URI,
            device=self.device,
        )

    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        return self.unit_extractor.predict(
            wav.to(self.device),
            out_layer_idx=self.OUTPUT_LAYER_IDX,
            sample_rate=sample_rate,
        )
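# Illustrative usage of the tokenizer above (hypothetical `wav` tensor;
# FLEURS audio is 16 kHz):
#   tokenizer = UnitSpeechTokenizer(device=torch.device("cpu"))
#   units = tokenizer.encode(wav, sample_rate=16000)  # tensor of discrete unit ids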
 

def download_fleurs_dataset(
    source_lang: str,
    target_lang: str,
    split: str,
    save_directory: str,
) -> str:
    _check_lang_code_mapping(source_lang)
    _check_lang_code_mapping(target_lang)
    device = (
        torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
    )
    tokenizer = UnitSpeechTokenizer(device=device)
    dataset_iterator = Speech2SpeechFleursDatasetBuilder(
        source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
        target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
        dataset_cache_dir=save_directory,
        speech_tokenizer=tokenizer,
        skip_source_audio=True,  # don't extract units from source audio
        skip_target_audio=False,
        split=split,
    )
    manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
    idx = 0  # guard against an empty split
    with open(manifest_path, "w") as fp_out:
        for idx, sample in enumerate(dataset_iterator, start=1):
            # correction, since the builder returns FLEURS lang codes
            sample.source.lang = source_lang
            sample.target.lang = target_lang
            sample.target.waveform = None  # units are already extracted
            fp_out.write(json.dumps(dataclasses.asdict(sample)) + "\n")
    logger.info(f"Saved {idx} samples for split={split} to {manifest_path}")
    return manifest_path
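# The manifest written above is JSON Lines: one serialized sample per line.
# A consumer such as `finetune.py` might read it back roughly like this
# (a sketch; the exact field names depend on the builder's sample dataclass):
#   with open(manifest_path) as fp:
#       samples = [json.loads(line) for line in fp]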
 

def init_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description=(
            "Helper script to download a training/evaluation dataset (FLEURS), "
            "extract units from target audio and save the dataset as a manifest "
            "consumable by `finetune.py`."
        )
    )
    parser.add_argument(
        "--source_lang",
        type=str,
        required=True,
        help="M4T langcode of the dataset SOURCE language",
    )
    parser.add_argument(
        "--target_lang",
        type=str,
        required=True,
        help="M4T langcode of the dataset TARGET language",
    )
    parser.add_argument(
        "--split",
        type=str,
        required=True,
        help="Dataset split/shard to download (`train`, `validation`, `test`)",
    )
    parser.add_argument(
        "--save_dir",
        type=Path,
        required=True,
        help="Directory where the dataset and HuggingFace datasets cache files will be stored",
    )
    return parser
 

def main() -> None:
    args = init_parser().parse_args()
    manifest_path = download_fleurs_dataset(
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        split=args.split,
        save_directory=args.save_dir,
    )
    logger.info(f"Manifest saved to: {manifest_path}")
 

if __name__ == "__main__":
    main()
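# Example invocation (hypothetical language pair and output directory):
#   python dataset.py --source_lang eng --target_lang kor --split train --save_dir /tmp/fleurs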
 
 