@@ -11,9 +11,11 @@ import json
 import logging
 import os
 from pathlib import Path
+from tqdm import tqdm
 
 import torch
 
+from datasets import load_dataset
 from seamless_communication.datasets.huggingface import (
     Speech2SpeechFleursDatasetBuilder,
     SpeechTokenizer,
@@ -28,6 +30,10 @@ logging.basicConfig(
 logger = logging.getLogger("dataset")
 
 
+SUPPORTED_DATASETS = ["google/fleurs", "speechcolab/gigaspeech"]
+"""List of Hugging Face datasets currently supported by this script.
+"""
+
 # Full list of FLEURS langcodes is available at https://huggingface.co/datasets/google/fleurs
 # Full list of M4T langcodes is available
 # in paper "SeamlessM4T—Massively Multilingual & Multimodal Machine Translation" (Table 5)
@@ -118,12 +124,12 @@ class UnitSpeechTokenizer(SpeechTokenizer):
         )
 
 
-def download_fleurs_dataset(
+def download_fleurs(
     source_lang: str,
     target_lang: str,
     split: str,
     save_directory: str,
-) -> str:
+) -> None:
     _check_lang_code_mapping(source_lang)
     _check_lang_code_mapping(target_lang)
     device = (
@@ -148,17 +154,47 @@ def download_fleurs_dataset(
             sample.target.waveform = None  # already extracted units
             fp_out.write(json.dumps(dataclasses.asdict(sample)) + "\n")
     logger.info(f"Saved {idx} samples for split={split} to {manifest_path}")
-    return manifest_path
+    logger.info(f"Manifest saved to: {manifest_path}")
+
+
+def download_gigaspeech(subset: str, huggingface_token: str, save_directory: str) -> None:
+    ds = load_dataset("speechcolab/gigaspeech", subset, cache_dir=f"gigaspeech/{subset}", token=huggingface_token)
+    for split in ds:
+        manifest_path = os.path.join(save_directory, f"{subset}_{split}_manifest.json")
+        logger.info(f"Preparing {split} split...")
+        with open(manifest_path, "w") as f:
+            for sample in tqdm(ds[split]):
+                f.write(json.dumps({
+                    "source": {
+                        "id": sample["segment_id"],
+                        "text": sample["text"],
+                        "lang": "eng",
+                        "audio_local_path": sample["audio"]["path"],
+                        "sampling_rate": sample["audio"]["sampling_rate"],
+                    },
+                    "target": {
+                        "id": sample["segment_id"],
+                        "text": sample["text"],
+                        "lang": "eng",
+                    }
+                }) + "\n")
+        logger.info(f"Manifest for GigaSpeech-{subset}-{split} saved to: {manifest_path}")
 
 
 def init_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
         description=(
-            "Helper script to download training/evaluation dataset (FLEURS),"
+            "Helper script to download training/evaluation dataset (FLEURS or GigaSpeech), "
             "extract units from target audio and save the dataset as a manifest "
             "consumable by `finetune.py`."
         )
     )
+    parser.add_argument(
+        "--name",
+        type=str,
+        required=True,
+        help="HuggingFace name of the dataset to prepare.",
+    )
     parser.add_argument(
         "--source_lang",
         type=str,
@@ -183,18 +219,27 @@ def init_parser() -> argparse.ArgumentParser:
         required=True,
help="Directory where the datastets will be stored with HuggingFace datasets cache files",
|
|
|
     )
+    parser.add_argument(
+        "--huggingface_token",
+        type=str,
+        required=False,
+        default=None,
+        help="Your HuggingFace token, which is necessary for some datasets like GigaSpeech.",
+    )
     return parser
 
 
 def main() -> None:
     args = init_parser().parse_args()
-    manifest_path = download_fleurs_dataset(
-        source_lang=args.source_lang,
-        target_lang=args.target_lang,
-        split=args.split,
-        save_directory=args.save_dir,
-    )
-    logger.info(f"Manifest saved to: {manifest_path}")
+    assert args.name in SUPPORTED_DATASETS, \
+        f"The only supported datasets are `{SUPPORTED_DATASETS}`. Please use one of these in `--name`."
+
+    if args.name == "google/fleurs":
+        download_fleurs(args.source_lang, args.target_lang, args.split, args.save_dir)
+    elif args.name == "speechcolab/gigaspeech":
+        assert args.huggingface_token is not None, \
+            "Your HuggingFace token is necessary for GigaSpeech. Please read the GigaSpeech agreement."
+        download_gigaspeech(args.split, args.huggingface_token, args.save_dir)
 
 
 if __name__ == "__main__":
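
For reference, a minimal sketch of reading back one record written by the new `download_gigaspeech` path; the file-name pattern and field names are taken from the diff above, while the subset (`xs`), split (`train`) and output directory are assumed placeholders:

```python
# Sketch only: inspect the first line of a manifest produced by download_gigaspeech.
# "xs", "train" and "gigaspeech_manifests" are placeholder values, not from the PR.
import json
import os

manifest_path = os.path.join("gigaspeech_manifests", "xs_train_manifest.json")
with open(manifest_path) as fp:
    record = json.loads(fp.readline())

# Each line is a source/target pair; GigaSpeech is English-only, so both sides share
# the same text and lang="eng", and only the source carries audio metadata.
print(record["source"]["audio_local_path"], record["source"]["sampling_rate"])
print(record["target"]["text"])
```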
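
The two helpers can also be exercised directly from Python. This is only a usage sketch: the import path, language codes, subset name and token below are assumptions, not values from the diff:

```python
# Sketch only: direct calls mirroring what main() now dispatches to.
# The module path is an assumption; adjust it to wherever this script lives in the repo.
from dataset import download_fleurs, download_gigaspeech

# FLEURS: source/target languages use M4T codes, e.g. eng -> kor (placeholders).
download_fleurs("eng", "kor", "train", "fleurs_manifests")

# GigaSpeech: a HuggingFace token is required for the gated dataset; "xs" is a placeholder subset.
download_gigaspeech("xs", "hf_xxx", "gigaspeech_manifests")
```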