
Enable downloading GigaSpeech (#443)

* Enable downloading GigaSpeech

* Fix GS issues

* Prepare local
2 changed files with 57 additions and 12 deletions
  1. setup.py: +1 -1
  2. src/seamless_communication/cli/m4t/finetune/dataset.py: +56 -11

+ 1 - 1
setup.py

@@ -21,7 +21,7 @@ setup(
     url="https://github.com/facebookresearch/seamless_communication",
     license="Creative Commons",
     install_requires=[
-        "datasets",
+        "datasets==2.18.0",
         "fairseq2==0.2.*",
         "fire",
         "librosa",

+ 56 - 11
src/seamless_communication/cli/m4t/finetune/dataset.py

@@ -11,9 +11,11 @@ import json
 import logging
 import os
 from pathlib import Path
 
 import torch
 
+from datasets import load_dataset
+from tqdm import tqdm
 from seamless_communication.datasets.huggingface import (
     Speech2SpeechFleursDatasetBuilder,
     SpeechTokenizer,
@@ -28,6 +30,10 @@ logging.basicConfig(
 logger = logging.getLogger("dataset")
 
 
+SUPPORTED_DATASETS = ['google/fleurs', 'speechcolab/gigaspeech']
+""" List of Huggingface Datasets that we support at the moment
+"""
+
 # Full list of FLEURS langcodes is available at https://huggingface.co/datasets/google/fleurs
 # Full list of M4T langcodes is available
 # in paper "SeamlessM4T—Massively Multilingual & Multimodal Machine Translation" (Table 5)
@@ -118,12 +124,12 @@ class UnitSpeechTokenizer(SpeechTokenizer):
         )
 
 
-def download_fleurs_dataset(
+def download_fleurs(
     source_lang: str,
     target_lang: str,
     split: str,
     save_directory: str,
-) -> str:
+) -> None:
     _check_lang_code_mapping(source_lang)
     _check_lang_code_mapping(target_lang)
     device = (
@@ -148,17 +154,47 @@ def download_fleurs_dataset(
             sample.target.waveform = None  # already extracted units
             fp_out.write(json.dumps(dataclasses.asdict(sample)) + "\n")
     logger.info(f"Saved {idx} samples for split={split} to {manifest_path}")
-    return manifest_path
+    logger.info(f"Manifest saved to: {manifest_path}")
+
+
+def download_gigaspeech(subset: str, huggingface_token: str, save_directory: str) -> None:
+    ds = load_dataset(
+        "speechcolab/gigaspeech",
+        subset,
+        cache_dir=os.path.join(save_directory, "gigaspeech", subset),
+        token=huggingface_token,
+    )
+    for split in ds:
+        manifest_path = os.path.join(save_directory, f"{subset}_{split}_manifest.json")
+        logger.info(f"Preparing {split} split...")
+        with open(manifest_path, "w") as f:
+            for sample in tqdm(ds[split]):
+                f.write(json.dumps({
+                    "source": {
+                        "id": sample["segment_id"],
+                        "text": sample["text"],
+                        "lang": "eng",
+                        "audio_local_path": sample["audio"]["path"],
+                        "sampling_rate": sample["audio"]["sampling_rate"],
+                    },
+                    "target": {
+                        "id": sample["segment_id"],
+                        "text": sample["text"],
+                        "lang": "eng",
+                    },
+                }) + "\n")
+        logger.info(f"Manifest for GigaSpeech-{subset}-{split} saved to: {manifest_path}")
 
 
 def init_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
         description=(
-            "Helper script to download training/evaluation dataset (FLEURS),"
+            "Helper script to download training/evaluation dataset (FLEURS or GigaSpeech),"
             "extract units from target audio and save the dataset as a manifest "
             "consumable by `finetune.py`."
         )
     )
+    parser.add_argument(
+        "--name",
+        type=str,
+        required=True,
+        help="HuggingFace name of the dataset to prepare.",
+    )
     parser.add_argument(
         "--source_lang",
         type=str,
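
For reference, each line the GigaSpeech branch writes is a self-contained JSON object following the schema built in download_gigaspeech above. A sketch of one manifest line with illustrative values (the segment id, text, and path are made up, not taken from a real sample):

    {"source": {"id": "POD0000000001_S0000001", "text": "example transcript", "lang": "eng", "audio_local_path": "/data/gigaspeech/xs/audio/POD0000000001_S0000001.wav", "sampling_rate": 16000}, "target": {"id": "POD0000000001_S0000001", "text": "example transcript", "lang": "eng"}}

Only the source side carries audio fields; the target side is text-only, since GigaSpeech is an English ASR corpus rather than a translation pair.
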
@@ -183,18 +219,27 @@ def init_parser() -> argparse.ArgumentParser:
         required=True,
         help="Directory where the datastets will be stored with HuggingFace datasets cache files",
     )
+    parser.add_argument(
+        "--huggingface_token",
+        type=str,
+        required=False,
+        default=None,
+        help="Your HuggingFace token, this is necessary for some datasets like GigaSpeech.",
+    )
     return parser
 
 
 def main() -> None:
     args = init_parser().parse_args()
-    manifest_path = download_fleurs_dataset(
-        source_lang=args.source_lang,
-        target_lang=args.target_lang,
-        split=args.split,
-        save_directory=args.save_dir,
-    )
-    logger.info(f"Manifest saved to: {manifest_path}")
+    assert args.name in SUPPORTED_DATASETS, \
+        f"The only supported datasets are `{SUPPORTED_DATASETS}`. Please use one of these in `--name`."
+
+    if args.name == 'google/fleurs':
+        download_fleurs(args.source_lang, args.target_lang, args.split, args.save_dir)
+    elif args.name == 'speechcolab/gigaspeech':
+        assert args.huggingface_token is not None, \
+            "Your HuggingFace token is necessary for GigaSpeech. Please read the GigaSpeech agreement."
+        download_gigaspeech(args.split, args.huggingface_token, args.save_dir)
 
 
 if __name__ == "__main__":
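     main()

With these changes the script handles both datasets. A sketch of the two invocations, assuming the module is run directly (an installed console entry point may differ); the token value and save directory are placeholders, and --source_lang/--target_lang may still be required by the parser even though the GigaSpeech branch ignores them:

    # FLEURS, now selected explicitly via --name
    python -m seamless_communication.cli.m4t.finetune.dataset \
        --name google/fleurs \
        --source_lang eng --target_lang kor \
        --split train --save_dir /data/manifests

    # GigaSpeech: --split selects the subset (xs, s, m, l, xl),
    # and a token is needed because the dataset is gated on HuggingFace.
    python -m seamless_communication.cli.m4t.finetune.dataset \
        --name speechcolab/gigaspeech \
        --source_lang eng --target_lang eng \
        --split xs --save_dir /data/manifests \
        --huggingface_token hf_your_token_here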