
Merge pull request #12 from facebookresearch/ruslan_finetune_scripts_only_s2t

Finetune scripts only s2t
Ruslan Mavlyutov 2 years ago
commit 2a24781475

+ 2 - 0
requirements.txt

@@ -1,3 +1,5 @@
 pre-commit
 datasets
 torchaudio
+soundfile
+librosa
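
The two new dependencies handle audio I/O and resampling in the data pipeline. A minimal sketch of how they are typically used together (the file name and target rate below are illustrative, not taken from the scripts):

```python
import librosa
import soundfile as sf

# Load an audio file as a float32 waveform plus its native sample rate.
waveform, sample_rate = sf.read("sample.wav", dtype="float32")

# Resample to 16 kHz, the rate speech models such as M4T typically consume.
waveform_16k = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
```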

+ 14 - 13
scripts/m4t/finetune/README.md

@@ -1,14 +1,14 @@
 ## Finetuning scripts for M4T
 
-This section demonstrates an example of how M4T model can be finetuned for a subset of translation directions or modalities.
+This section demonstrates an example of M4T finetuning on a single translation direction: English-to-Korean.
 
-Shared implementations of trainer and dataloader are not efficient and/or exhaustive. They were intentionally made simple in order to not obscure the specifics of data representation and optimization criteria during training.
+The trainer and dataloader were designed mainly for demonstration purposes. Their simplicity should facilitate code transparency and portability.
 
 ## Data preparation
 
-M4T training data is a multimodal parallel corpus. Each training sample has four parts: audio and text representation of a sample in source language, and corresponding audio and text representation of a sample in target language.
+The M4T training dataset is a multimodal parallel corpus. Each training sample has four parts: the audio and text representation of the sample in the source language, and the corresponding audio and text representation in the target language.
 
-This kind of dataset can be prepared using `dataset.py` script that downloads FLEURS dataset from [HuggingFace datastes hub](https://huggingface.co/datasets/google/fleurs), extracts units from target audio samples and prepares a manifest consumable by `finetune.py`. Manifest is a text file where each line represents information about a single dataset sample, serialized in JSON format.
+Such a dataset can be prepared using the `dataset.py` script, which downloads the FLEURS dataset from the [HuggingFace datasets hub](https://huggingface.co/datasets/google/fleurs), (optionally) extracts units from the target audio samples, and prepares a manifest consumable by `finetune.py`. The manifest is a text file where each line holds the information for a single dataset sample, serialized in JSON format.
 
 List of input arguments for `dataset.py`:
 
@@ -23,12 +23,12 @@ List of input arguments for `dataset.py`:
 
 Language codes should follow the notation adopted by M4T models.
 
-Below is an example bash script that prepares a training and evaluation dataset for language pair English->Korean:
+Below is an example bash script that prepares a training and evaluation dataset for the translation direction English-to-Korean:
 
 ```bash
-mkdir -p datasets && cd datasets
-export DATASET_DIR=`pwd`
-cd -
+export DATASET_DIR=~/m4t_dataset
+mkdir -p $DATASET_DIR
+
 python scripts/m4t/finetune/dataset.py \
   --source_lang eng \
   --target_lang kor \
@@ -42,13 +42,13 @@ python scripts/m4t/finetune/dataset.py \
 ```
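
Each line of a manifest is one JSON-serialized sample record, so the file can be consumed with a plain line-by-line reader. A minimal sketch (the exact field schema is whatever `dataset.py` emits and is not reproduced here):

```python
import json

# Hypothetical reader: one JSON object per manifest line.
with open("train_manifest.json") as manifest:
    for line in manifest:
        sample = json.loads(line)
        # Inspect or filter the sample record here.
        print(sample.keys())
```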
 
 
-Output manifests will be stored in `$DATASET_DIR/train_manifest.json` and `$DATASET_DIR/validation_manifest.json`.
+Output manifests will be stored in `${DATASET_DIR}/train_manifest.json` and `${DATASET_DIR}/validation_manifest.json`.
 
 
 ## Finetuning
 
-`finetune.py` is an example finetuning script that initializes dataloaders, and launches training loop with periodic scoring against validation dataset.
-It is recommended to launch it with `torchrun`. Multi-gpu and multi-node training are supported out of the box.
+`finetune.py` is an example finetuning script that initializes dataloaders and launches the training loop with periodic scoring against the validation dataset.
+It is recommended to launch it with [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html). Multi-GPU and multi-node training are supported out of the box.
 
 List of input arguments for `finetune.py`:
 
@@ -86,7 +86,7 @@ The script supports three modes of finetuning:
 - `TEXT_TO_SPEECH`: only text-to-unit part of the model will be engaged in the finetuning, other weights will be frozen;
 - `SPEECH_TO_TEXT`: only speech-to-text part of the model will be engaged in the finetuning.
 
-The referenced finetuning script does not support finetuning of the text encoder. Though the code expantion should be trivial.
+The referenced finetuning script does not support finetuning of the text encoder.
 
 
 Below is an example bash script that launches finetuning of M4T-large on the dataset prepared earlier, using a single node with eight GPUs:
@@ -98,6 +98,7 @@ torchrun \
    --nnodes=1 \
    --nproc-per-node=8  \
   scripts/m4t/finetune/finetune.py \
+   --mode SPEECH_TO_TEXT \
    --train_dataset $DATASET_DIR/train_manifest.json  \
    --eval_dataset $DATASET_DIR/validation_manifest.json \
    --learning_rate 1e-6 \
@@ -105,7 +106,7 @@ torchrun \
    --max_epochs 10 \
    --patience 3 \
    --model_name seamlessM4T_large \
-   --save_model_to $WORKDIR/checkpoint_lr_1e-6_full.pt
+   --save_model_to $DATASET_DIR/checkpoint.pt
 ```
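
With `--mode SPEECH_TO_TEXT`, only the speech-to-text weights receive gradient updates. A hypothetical sketch of that style of freezing (the actual script implements it inside its trainer wrapper, not with this loop; the sub-model name follows the `t2u_model` attribute seen in `trainer.py`):

```python
import torch.nn as nn

def freeze_t2u(model: nn.Module) -> None:
    """Illustrative helper: train only the speech-to-text branch by
    freezing every parameter of the text-to-unit (t2u) sub-model."""
    for name, param in model.named_parameters():
        param.requires_grad = not name.startswith("t2u_model")
```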
 
 Excerpt from an example finetuning log:

+ 25 - 13
scripts/m4t/finetune/dataloader.py

@@ -34,9 +34,9 @@ logger = logging.getLogger(__name__)
 class SeqsBatch:
     src_tokens: Optional[Tensor]
     src_lengths: Optional[Tensor]
-    target_tokens: Tensor
-    prev_output_tokens: Tensor
-    target_lengths: Tensor
+    target_tokens: Optional[Tensor]
+    prev_output_tokens: Optional[Tensor]
+    target_lengths: Optional[Tensor]
 
     def __del__(self) -> None:
         """Explicitly delete tensors
@@ -136,8 +136,10 @@ class UnitYDataLoader:
         tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
         return tokens
 
-    def _get_tokenized_units(self, sample: LangPairSample) -> Tensor:
+    def _get_tokenized_units(self, sample: LangPairSample) -> Optional[Tensor]:
         """Expected sequence is [<eos>, <lang_tok> , ..unit tokens.., <eos>]"""
+        if sample.target.units is None:
+            return None
         target_lang = sample.target.lang
         if target_lang not in self.unit_encoders_per_lang:
             self.unit_encoders_per_lang[
@@ -185,15 +187,25 @@ class UnitYDataLoader:
             [tokens.shape[0] - 1 for tokens in text_tokens_list]
         )
         # output units
-        units_list = [self._get_tokenized_units(sample) for sample in samples]
-        units_pad_idx = self.unit_tokenizer.vocab_info.pad_idx
-        prev_outputs_units = self._batch_tensors(
-            [tokens[:-1] for tokens in units_list], pad_value=units_pad_idx
-        )
-        target_units = self._batch_tensors(
-            [tokens[1:] for tokens in units_list], pad_value=units_pad_idx
-        )
-        units_lengths = torch.LongTensor([tokens.shape[0] - 1 for tokens in units_list])
+        units_list_raw = [self._get_tokenized_units(sample) for sample in samples]
+        if None in units_list_raw:
+            prev_outputs_units = None
+            target_units = None
+            units_lengths = None
+        else:
+            units_list: List[Tensor] = [
+                value for value in units_list_raw if value is not None
+            ]
+            units_pad_idx = self.unit_tokenizer.vocab_info.pad_idx
+            prev_outputs_units = self._batch_tensors(
+                [tokens[:-1] for tokens in units_list], pad_value=units_pad_idx
+            )
+            target_units = self._batch_tensors(
+                [tokens[1:] for tokens in units_list], pad_value=units_pad_idx
+            )
+            units_lengths = torch.LongTensor(
+                [tokens.shape[0] - 1 for tokens in units_list]
+            )
         return MultimodalSeqsBatch(
             speech_to_text=SeqsBatch(
                 src_tokens=src_tokens,
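
When any sample in a batch lacks extracted units, the unit targets above are dropped wholesale (`None`) instead of being partially batched. The padding itself is delegated to `_batch_tensors`; a minimal sketch of what such a helper could look like, assuming it right-pads variable-length 1-D token tensors into a single batch (the repo's actual implementation may differ):

```python
from typing import List

import torch
from torch import Tensor

def batch_tensors(tensors: List[Tensor], pad_value: int) -> Tensor:
    """Right-pad 1-D token tensors to a common length and stack them."""
    return torch.nn.utils.rnn.pad_sequence(
        tensors, batch_first=True, padding_value=pad_value
    )

# Example: sequences of lengths 3 and 2 become a 2x3 batch padded with 0.
batch = batch_tensors([torch.tensor([5, 6, 7]), torch.tensor([8, 9])], pad_value=0)
```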

+ 1 - 9
scripts/m4t/finetune/dataset.py

@@ -13,9 +13,6 @@ import os
 from argparse import Namespace
 from pathlib import Path
 
-from stopes.hub import load_config
-from stopes.speech.tokenizers import SpeechTokenizer, SpeechTokenizerConfig
-
 from seamless_communication.datasets.huggingface import (
     Speech2SpeechFleursDatasetBuilder,
 )
@@ -99,15 +96,11 @@ def download_fleurs_dataset(
     source_lang: str,
     target_lang: str,
     split: str,
-    unit_extractor_config: str,
     save_directory: str,
 ) -> str:
     _check_lang_code_mapping(source_lang)
     _check_lang_code_mapping(target_lang)
-    tokenizer_conf: SpeechTokenizerConfig = load_config(
-        unit_extractor_config, namespace=""
-    )
-    tokenizer: SpeechTokenizer = SpeechTokenizer.build(tokenizer_conf)
+    tokenizer = None
     dataset_iterator = Speech2SpeechFleursDatasetBuilder(
         source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
         target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
@@ -168,7 +161,6 @@ def main(args: Namespace) -> None:
     manifest_path = download_fleurs_dataset(
         source_lang=args.source_lang,
         target_lang=args.target_lang,
-        unit_extractor_config="lang41_10k_xlsr_lyr35.yaml",
         split=args.split,
         save_directory=args.save_dir,
     )
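
With the stopes-based unit extractor removed, `tokenizer=None` means no target units are produced, which the dataloader changes above handle by leaving the unit fields as `None`. For reference, a hypothetical reconstruction of the `_check_lang_code_mapping` validation, assuming it verifies that an M4T language code has a FLEURS counterpart in `UNITY_TO_FLEURS_LANG_MAPPING` (mapping excerpt assumed, not copied from the script):

```python
# Hypothetical reconstruction, not the script's verbatim code.
UNITY_TO_FLEURS_LANG_MAPPING = {"eng": "en_us", "kor": "ko_kr"}  # assumed excerpt

def _check_lang_code_mapping(lang: str) -> None:
    if lang not in UNITY_TO_FLEURS_LANG_MAPPING:
        raise ValueError(
            f"No FLEURS mapping for language code '{lang}'. "
            f"Known codes: {sorted(UNITY_TO_FLEURS_LANG_MAPPING)}"
        )
```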

+ 9 - 4
scripts/m4t/finetune/trainer.py

@@ -86,9 +86,8 @@ class UnitYFinetuneWrapper(nn.Module):
 
     def forward(
         self, batch: dataloader.MultimodalSeqsBatch
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert self.model.t2u_model is not None
-
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
             assert batch.speech_to_text.src_tokens is not None
@@ -96,6 +95,7 @@ class UnitYFinetuneWrapper(nn.Module):
                 seqs=batch.speech_to_text.src_tokens.to(self.device),
                 seq_lens=batch.speech_to_text.src_lengths.to(self.device),
             )
+            assert batch.speech_to_text.prev_output_tokens is not None
             text_decoder_out, text_decoder_padding_mask = self.model.decode(
                 seqs=batch.speech_to_text.prev_output_tokens.to(self.device),
                 seq_lens=batch.speech_to_text.target_lengths.to(self.device),
@@ -103,7 +103,8 @@ class UnitYFinetuneWrapper(nn.Module):
                 encoder_padding_mask=speech_encoder_padding_mask,
             )
             text_logits = self.model.final_proj(text_decoder_out)
-
+        if batch.text_to_units.prev_output_tokens is None:
+            return (text_logits, None)
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_t2u else dummy_context:  # type:ignore
             (
@@ -141,8 +142,9 @@ class CalcLoss:
         self,
         batch: dataloader.MultimodalSeqsBatch,
         text_logits: torch.Tensor,
-        unit_logits: torch.Tensor,
+        unit_logits: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        assert batch.speech_to_text.target_lengths is not None
         s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
             text_logits.device
         )
@@ -153,6 +155,9 @@ class CalcLoss:
             ignore_prefix_size=1,
             label_smoothing=self.label_smoothing,
         )
+        if unit_logits is None:
+            return s2t_loss / s2t_numel
+        assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
             logits=unit_logits, pad_idx=self.t2u_pad_idx
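
The `dummy_context = contextmanager(lambda: iter([None]))()` idiom above builds a one-shot no-op context manager; `contextlib.nullcontext` is the idiomatic equivalent. A minimal sketch of the freeze pattern the trainer relies on (module and flag are illustrative):

```python
from contextlib import nullcontext

import torch
import torch.nn as nn

encoder = nn.Linear(8, 8)
freeze_encoder = True  # mirrors freeze_s2t / freeze_t2u in UnitYFinetuneWrapper

# When frozen, run the forward pass under no_grad() so no gradients are
# recorded; otherwise use a no-op context and train as usual.
with torch.no_grad() if freeze_encoder else nullcontext():
    out = encoder(torch.randn(2, 8))
```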