Browse Source

Anoter round of passes over finetuning scripts

Ruslan Mavlyutov 2 years ago
parent
commit
32df8a397f
5 changed files with 165 additions and 40 deletions
  1. 1 1
      README.md
  2. 1 0
      requirements.txt
  3. 95 30
      scripts/m4t/finetune/README.md
  4. 64 5
      scripts/m4t/finetune/dataset.py
  5. 4 4
      scripts/m4t/finetune/finetune.py

+ 1 - 1
README.md

@@ -73,7 +73,7 @@ To reproduce our results, or to evaluate using the same metrics over your own te
 
 ## Finetuning SeamlessM4T models
 
-TODO
+Please check out [README under scripts/m4t/finetune](scripts/m4t/finetune/README.md).
 
 ## On-device models
 Apart from Seamless-M4T large (2.3B) and medium (1.2B) models, we are also releasing a small model (281M) targeted for on-device inference. To learn more about the usage and model details check out [README here](https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/on_device_README.md)

+ 1 - 0
requirements.txt

@@ -1,2 +1,3 @@
 pre-commit
 datasets
+torchaudio

+ 95 - 30
scripts/m4t/finetune/README.md

@@ -2,38 +2,94 @@
 
 This section demonstrates an example of how M4T model can be finetuned for a subset of translation directions or modalities.
 
-Shared implementations of trainer and dataloader are not exhaustive. They were intentionally made simple in order to not obscure the specifics of data representation and optimization criteria during training.
+Shared implementations of trainer and dataloader are not efficient and/or exhaustive. They were intentionally made simple in order to not obscure the specifics of data representation and optimization criteria during training.
 
 ## Data preparation
 
 M4T training data is a multimodal parallel corpus. Each training sample has four parts: audio and text representation of a sample in source language, and corresponding audio and text representation of a sample in target language.
 
-This kind of dataset can be prepared using `dataset.py` script that downloads FLEURS dataset from [HuggingFace datastes hub](https://huggingface.co/datasets/google/fleurs), extracts units from target audio samples and prepares a manifest consumable by `finetune.py`.
+This kind of dataset can be prepared using `dataset.py` script that downloads FLEURS dataset from [HuggingFace datastes hub](https://huggingface.co/datasets/google/fleurs), extracts units from target audio samples and prepares a manifest consumable by `finetune.py`. Manifest is a text file where each line represents information about a single dataset sample, serialized in JSON format.
 
-Example run command that prepares a training dataset for language pair English->Korean:
+List of input arguments for `dataset.py`:
 
 ```bash
-python scripts/m4t/finetune/dataset.py \
- --source_lang eng \
- --target_lang kor \
- --split train \
- --save_dir /tmp
+  --source_lang SOURCE_LANG
+                        M4T langcode of the dataset SOURCE language
+  --target_lang TARGET_LANG
+                        M4T langcode of the dataset TARGET language
+  --split SPLIT         Dataset split/shard to download (`train`, `test`)
+  --save_dir SAVE_DIR   Directory where the datastets will be stored with HuggingFace datasets cache files
 ```
-Path to the output manifest will be logged in the end of the command output:
+
+Language codes should follow the notation adopted by M4T models.
+
+Below is an example bash script that prepares a training and evaluation dataset for language pair English->Korean:
 
 ```bash
-...
-2023-08-19 03:23 INFO dataset - ..loaded 2600 source samples
-2023-08-19 03:23 INFO dataset - Manifest saved to: /tmp/train_manifest.json
+mkdir -p datasets && cd datasets
+export DATASET_DIR=`pwd`
+cd -
+python scripts/m4t/finetune/dataset.py \
+  --source_lang eng \
+  --target_lang kor \
+  --split train \
+  --save_dir $DATASET_DIR
+ python scripts/m4t/finetune/dataset.py \
+  --source_lang eng \
+  --target_lang kor \
+  --split validation \
+  --save_dir $DATASET_DIR
 ```
 
-Manifest is a text file where each line represents information about a single dataset sample, serialized in JSON format.
+
+Output manifests will be stored in `$DATASET_DIR/train_manifest.json` and `$DATASET_DIR/validation_manifest.json`.
+
 
 ## Finetuning
 
-`finetune.py` is an example finetuning script that initializes dataloader, and launches a training loop with periodic evaluations on evaluation dataset. `torchrun` is the recommended way of launching it.
+`finetune.py` is an example finetuning script that initializes dataloaders, and launches training loop with periodic scoring against validation dataset.
+It is recommended to launch it with `torchrun`. Multi-gpu and multi-node training are supported out of the box.
+
+List of input arguments for `finetune.py`:
+
+```bash
+  --train_dataset TRAIN_DATASET
+                        Path to manifest with train samples
+  --eval_dataset EVAL_DATASET
+                        Path to manifest with eval samples
+  --model_name MODEL_NAME
+                        Base model name (e.g, `seamlessM4T_medium`, `seamlessM4T_large`)
+  --save_model_to SAVE_MODEL_TO
+                        Path to save best finetuned model
+  --seed SEED           Randomizer seed value
+  --batch_size BATCH_SIZE
+                        Batch size for training and evaluation
+  --patience PATIENCE   Set early termination after `patience` number of evaluations without eval loss improvements
+  --max_epochs MAX_EPOCHS
+                        Max number of training epochs
+  --learning_rate LEARNING_RATE
+                        Finetuning learning rate
+  --warmup_steps WARMUP_STEPS
+                        Number of steps with linearly increasing learning rate
+  --eval_steps EVAL_STEPS
+                        Get eval loss after each `eval_steps` training steps
+  --log_steps LOG_STEPS
+                        Log inner loss after each `log_steps` training steps
+  --mode {FinetuneMode.SPEECH_TO_SPEECH,FinetuneMode.SPEECH_TO_TEXT,FinetuneMode.TEXT_TO_SPEECH}
+                        * `SPEECH_TO_SPEECH` -- finetune S2T and T2U parts of the model;
+                        * `TEXT_TO_SPEECH` -- finetune only T2U;
+                        * `SPEECH_TO_TEXT` -- finetune only S2T
+```
+
+The scripts supports three modes of finetuning:
+- `SPEECH_TO_SPEECH`: in this case all model weights except the text encoder will be engaged;
+- `TEXT_TO_SPEECH`: only text-to-unit part of the model will be engaged in the finetuning, other weights will be frozen;
+- `SPEECH_TO_TEXT`: only speech-to-text part of the model will be engaged in the finetuning.
+
+The referenced finetuning script does not support finetuning of the text encoder. Though the code expantion should be trivial.
+
 
-Example launch command on a single node with 8 gpus:
+Below is an example bash script that launches finetuning of M4T-large on the dataset prepared earlier, using a single node with eight GPUs:
 
 ```
 torchrun \
@@ -41,26 +97,35 @@ torchrun \
    --rdzv-endpoint=localhost:0 \
    --nnodes=1 \
    --nproc-per-node=8  \
- ./scripts/m4t/finetune/finetune.py \
-   --train_dataset '<PATH TO TRAIN MANIFEST>' \
-   --eval_dataset '<PATH TO EVAL MANIFEST>' \
+  scripts/m4t/finetune/finetune.py \
+   --train_dataset $DATASET_DIR/train_manifest.json  \
+   --eval_dataset $DATASET_DIR/validation_manifest.json \
+   --learning_rate 1e-6 \
+   --warmup_steps 100 \
+   --max_epochs 10 \
+   --patience 3 \
    --model_name seamlessM4T_large \
-   --save_model_to /tmp/checkpoint.pt
+   --save_model_to $WORKDIR/checkpoint_lr_1e-6_full.pt
 ```
 
-Example of a training log:
+Excerpt from an example finetuning log:
 
 ```
 ...
-2023-08-19 02:27:06,009 INFO -- trainer.1871488: Eval after 350 updates: loss=8.7876 best_loss=8.7876 patience_steps_left=3
-2023-08-19 02:27:06,009 INFO -- trainer.1871488: Saving model
-2023-08-19 02:27:31,100 INFO -- trainer.1871488: Epoch 007 / update 00360: train loss=16.3779 last lr=5.27E-08
-2023-08-19 02:27:38,249 INFO -- trainer.1871488: Epoch 007 / update 00370: train loss=16.3482 last lr=5.20E-08
-2023-08-19 02:27:45,164 INFO -- trainer.1871488: Epoch 007 / update 00380: train loss=16.4406 last lr=5.13E-08
-2023-08-19 02:27:52,521 INFO -- trainer.1871488: Epoch 007 / update 00390: train loss=16.3556 last lr=5.06E-08
-2023-08-19 02:27:59,300 INFO -- trainer.1871488: Epoch 007 / update 00400: train loss=16.3055 last lr=5.00E-08
-2023-08-19 02:27:59,919 INFO -- trainer.1871488: Run evaluation
-2023-08-19 02:28:12,761 INFO -- trainer.1871488: Eval after 400 updates: loss=8.7711 best_loss=8.7711 patience_steps_left=3
-2023-08-19 02:28:12,762 INFO -- trainer.1871488: Saving model
+2023-08-21 14:46:16,936 INFO -- trainer.1100368: Eval after 300 updates: loss=8.7755 best_loss=8.7755 patience_steps_left=3
+2023-08-21 14:46:16,936 INFO -- trainer.1100368: Saving model
+2023-08-21 14:46:35,863 INFO -- trainer.1100368: Epoch 006 / update 00310: train loss=16.3768 last lr=5.68E-08
+2023-08-21 14:46:42,610 INFO -- trainer.1100368: Epoch 006 / update 00320: train loss=16.3730 last lr=5.59E-08
+2023-08-21 14:46:48,285 INFO -- trainer.1100368: Epoch 006 / update 00330: train loss=16.4598 last lr=5.50E-08
+2023-08-21 14:46:54,390 INFO -- trainer.1100368: Epoch 006 / update 00340: train loss=16.4218 last lr=5.42E-08
+2023-08-21 14:47:08,461 INFO -- trainer.1100368: Epoch 006 / update 00350: train loss=16.3906 last lr=5.35E-08
+2023-08-21 14:47:09,067 INFO -- trainer.1100368: Run evaluation
+2023-08-21 14:47:19,205 INFO -- trainer.1100368: Eval after 350 updates: loss=8.7462 best_loss=8.7462 patience_steps_left=3
+2023-08-21 14:47:19,205 INFO -- trainer.1100368: Saving model
+2023-08-21 14:47:44,981 INFO -- trainer.1100368: Epoch 007 / update 00360: train loss=16.4267 last lr=5.27E-08
+2023-08-21 14:47:51,383 INFO -- trainer.1100368: Epoch 007 / update 00370: train loss=16.3630 last lr=5.20E-08
+2023-08-21 14:47:58,305 INFO -- trainer.1100368: Epoch 007 / update 00380: train loss=16.3666 last lr=5.13E-08
+2023-08-21 14:48:04,396 INFO -- trainer.1100368: Epoch 007 / update 00390: train loss=16.3605 last lr=5.06E-08
+2023-08-21 14:48:10,630 INFO -- trainer.1100368: Epoch 007 / update 00400: train loss=16.3518 last lr=5.00E-08
 ...
 ```

+ 64 - 5
scripts/m4t/finetune/dataset.py

@@ -28,15 +28,73 @@ logging.basicConfig(
 logger = logging.getLogger("dataset")
 
 
-# List of FLEURS langcodes is available at https://huggingface.co/datasets/google/fleurs
-# List of M4T langcodes is available in yaml: src/seamless_communication/assets/cards/unity_nllb-100.yaml
+# Full list of FLEURS langcodes is available at https://huggingface.co/datasets/google/fleurs
+# Full list of M4T langcodes is available
+# in paper "SeamlessM4T—Massively Multilingual & Multimodal Machine Translation" (Table 5)
 UNITY_TO_FLEURS_LANG_MAPPING = {
     "eng": "en_us",
     "ita": "it_it",
+    "afr": "af_za",
+    "asm": "as_in",
+    "bel": "be_by",
+    "bul": "bg_bg",
+    "ben": "bn_in",
+    "cat": "ca_es",
+    "ces": "cs_cz",
+    "dan": "da_dk",
+    "deu": "de_de",
+    "ell": "el_gr",
+    "fin": "fi_fi",
+    "fra": "fr_fr",
+    "glg": "gl_es",
+    "heb": "he_il",
+    "hin": "hi_in",
+    "hrv": "hr_hr",
+    "hun": "hu_hu",
+    "ind": "id_id",
+    "ibo": "ig_ng",
+    "isl": "is_is",
+    "ita": "it_it",
+    "jpn": "ja_jp",
+    "jav": "jv_id",
+    "kaz": "kk_kz",
+    "kan": "kn_in",
+    "kir": "ky_kg",
     "kor": "ko_kr",
+    "lit": "lt_lt",
+    "mkd": "mk_mk",
+    "mlt": "mt_mt",
+    "mya": "my_mm",
+    "nld": "nl_nl",
+    "pan": "pa_in",
+    "pol": "pl_pl",
+    "ron": "ro_ro",
+    "rus": "ru_ru",
+    "snd": "sd_in",
+    "slk": "sk_sk",
+    "srp": "sr_rs",
+    "swh": "sw_ke",
+    "tam": "ta_in",
+    "tel": "te_in",
+    "tha": "th_th",
+    "tur": "tr_tr",
+    "ukr": "uk_ua",
+    "urd": "ur_pk",
+    "uzn": "uz_uz",
+    "vie": "vi_vn",
+    "yor": "yo_ng",
+    "zul": "zu_za",
 }
 
 
+def _check_lang_code_mapping(lang: str) -> None:
+    if lang not in UNITY_TO_FLEURS_LANG_MAPPING:
+        raise ValueError(
+            f"No language code mapping for {lang}(M4T)->??(FLEURs). "
+            "Please expand `UNITY_TO_FLEURS_LANG_MAPPING`"
+        )
+
+
 def download_fleurs_dataset(
     source_lang: str,
     target_lang: str,
@@ -44,6 +102,8 @@ def download_fleurs_dataset(
     unit_extractor_config: str,
     save_directory: str,
 ) -> str:
+    _check_lang_code_mapping(source_lang)
+    _check_lang_code_mapping(target_lang)
     tokenizer_conf: SpeechTokenizerConfig = load_config(
         unit_extractor_config, namespace=""
     )
@@ -93,7 +153,7 @@ def init_parser() -> argparse.ArgumentParser:
         "--split",
         type=str,
         required=True,
-        help="Dataset split/shard to download (`train`, `test`)",
+        help="Dataset split/shard to download (`train`, `validation`, `test`)",
     )
     parser.add_argument(
         "--save_dir",
@@ -108,8 +168,7 @@ def main(args: Namespace) -> None:
     manifest_path = download_fleurs_dataset(
         source_lang=args.source_lang,
         target_lang=args.target_lang,
-        # TODO: remove hardcoded path
-        unit_extractor_config="/checkpoint/krs/unit_extraction/xlsr1b/lang41_10k_xlsr_lyr35.yaml",
+        unit_extractor_config="lang41_10k_xlsr_lyr35.yaml",
         split=args.split,
         save_directory=args.save_dir,
     )

+ 4 - 4
scripts/m4t/finetune/finetune.py

@@ -46,7 +46,7 @@ def init_parser() -> argparse.ArgumentParser:
         "--eval_dataset",
         type=Path,
         required=True,
-        help="Path to manifest with train samples",
+        help="Path to manifest with eval samples",
     )
     parser.add_argument(
         "--model_name",
@@ -117,9 +117,9 @@ def init_parser() -> argparse.ArgumentParser:
         choices=list(trainer.FinetuneMode),
         default=trainer.FinetuneMode.TEXT_TO_SPEECH,
         help=(
-            "* SPEECH_TO_SPEECH -- finetune S2T and T2U parts of the model;\n"
-            "* TEXT_TO_SPEECH -- finetune only T2U;\n"
-            "* SPEECH_TO_TEXT -- finetune only S2T"
+            "* `SPEECH_TO_SPEECH` -- finetune S2T and T2U parts of the model; "
+            "* `TEXT_TO_SPEECH` -- finetune only T2U; "
+            "* `SPEECH_TO_TEXT` -- finetune only S2T"
         ),
     )
     return parser