Browse Source

Adjust M4T example finetuning script to align with seamless_communication/fairseq2 changes in Nov'23 (#326)

Ruslan Mavlyutov · 1 year ago
commit 9f883a6b4f

+ 26 - 17
src/seamless_communication/cli/m4t/finetune/README.md

@@ -105,8 +105,8 @@ torchrun \
    --learning_rate 1e-6 \
    --warmup_steps 100 \
    --max_epochs 10 \
-   --patience 3 \
-   --model_name seamlessM4T_large \
+   --patience 5 \
+   --model_name seamlessM4T_v2_large \
    --save_model_to $DATASET_DIR/checkpoint.pt
 ```
 
@@ -114,20 +114,29 @@ Excerpt from an example finetuning log:
 
 ```
 ...
-2023-08-21 14:46:16,936 INFO -- trainer.1100368: Eval after 300 updates: loss=8.7755 best_loss=8.7755 patience_steps_left=3
-2023-08-21 14:46:16,936 INFO -- trainer.1100368: Saving model
-2023-08-21 14:46:35,863 INFO -- trainer.1100368: Epoch 006 / update 00310: train loss=16.3768 last lr=5.68E-08
-2023-08-21 14:46:42,610 INFO -- trainer.1100368: Epoch 006 / update 00320: train loss=16.3730 last lr=5.59E-08
-2023-08-21 14:46:48,285 INFO -- trainer.1100368: Epoch 006 / update 00330: train loss=16.4598 last lr=5.50E-08
-2023-08-21 14:46:54,390 INFO -- trainer.1100368: Epoch 006 / update 00340: train loss=16.4218 last lr=5.42E-08
-2023-08-21 14:47:08,461 INFO -- trainer.1100368: Epoch 006 / update 00350: train loss=16.3906 last lr=5.35E-08
-2023-08-21 14:47:09,067 INFO -- trainer.1100368: Run evaluation
-2023-08-21 14:47:19,205 INFO -- trainer.1100368: Eval after 350 updates: loss=8.7462 best_loss=8.7462 patience_steps_left=3
-2023-08-21 14:47:19,205 INFO -- trainer.1100368: Saving model
-2023-08-21 14:47:44,981 INFO -- trainer.1100368: Epoch 007 / update 00360: train loss=16.4267 last lr=5.27E-08
-2023-08-21 14:47:51,383 INFO -- trainer.1100368: Epoch 007 / update 00370: train loss=16.3630 last lr=5.20E-08
-2023-08-21 14:47:58,305 INFO -- trainer.1100368: Epoch 007 / update 00380: train loss=16.3666 last lr=5.13E-08
-2023-08-21 14:48:04,396 INFO -- trainer.1100368: Epoch 007 / update 00390: train loss=16.3605 last lr=5.06E-08
-2023-08-21 14:48:10,630 INFO -- trainer.1100368: Epoch 007 / update 00400: train loss=16.3518 last lr=5.00E-08
+2024-01-17 03:13:12,608 INFO -- trainer: Eval after 200 updates: loss=4.5721 best_loss=4.4743 patience_steps_left=7
+2024-01-17 03:13:19,859 INFO -- trainer: Epoch 004 / update 00210: train loss=4.4922 last lr=6.90E-07
+2024-01-17 03:13:27,946 INFO -- trainer: Epoch 004 / update 00220: train loss=4.4694 last lr=6.74E-07
+2024-01-17 03:13:36,320 INFO -- trainer: Epoch 004 / update 00230: train loss=4.4760 last lr=6.59E-07
+2024-01-17 03:14:08,554 INFO -- trainer: Epoch 005 / update 00240: train loss=4.3438 last lr=6.45E-07
+2024-01-17 03:14:16,529 INFO -- trainer: Epoch 005 / update 00250: train loss=4.2979 last lr=6.32E-07
+2024-01-17 03:14:17,382 INFO -- trainer: Run evaluation
+2024-01-17 03:14:31,172 INFO -- trainer: Eval after 250 updates: loss=4.4967 best_loss=4.4743 patience_steps_left=6
+2024-01-17 03:14:38,497 INFO -- trainer: Epoch 005 / update 00260: train loss=4.2690 last lr=6.20E-07
+2024-01-17 03:14:46,505 INFO -- trainer: Epoch 005 / update 00270: train loss=4.2489 last lr=6.09E-07
+2024-01-17 03:14:54,796 INFO -- trainer: Epoch 005 / update 00280: train loss=4.2422 last lr=5.98E-07
+2024-01-17 03:15:02,976 INFO -- trainer: Epoch 005 / update 00290: train loss=4.1874 last lr=5.87E-07
+2024-01-17 03:15:34,510 INFO -- trainer: Epoch 006 / update 00300: train loss=4.1768 last lr=5.77E-07
+2024-01-17 03:15:35,329 INFO -- trainer: Run evaluation
+2024-01-17 03:15:49,634 INFO -- trainer: Eval after 300 updates: loss=4.4688 best_loss=4.4688 patience_steps_left=10
+2024-01-17 03:15:49,634 INFO -- trainer: Saving model
+2024-01-17 03:16:08,825 INFO -- trainer: Epoch 006 / update 00310: train loss=4.1509 last lr=5.68E-07
+2024-01-17 03:16:16,979 INFO -- trainer: Epoch 006 / update 00320: train loss=4.0949 last lr=5.59E-07
+2024-01-17 03:16:25,142 INFO -- trainer: Epoch 006 / update 00330: train loss=4.1053 last lr=5.50E-07
+2024-01-17 03:16:32,966 INFO -- trainer: Epoch 006 / update 00340: train loss=4.1237 last lr=5.42E-07
+2024-01-17 03:16:53,995 INFO -- trainer: Epoch 006 / update 00350: train loss=4.0980 last lr=5.35E-07
+2024-01-17 03:16:54,690 INFO -- trainer: Run evaluation
+2024-01-17 03:17:08,073 INFO -- trainer: Eval after 350 updates: loss=4.4463 best_loss=4.4463 patience_steps_left=10
+2024-01-17 03:17:08,074 INFO -- trainer: Saving model
 ...
 ```
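The command and log above target the `seamlessM4T_v2_large` checkpoint. Before launching a long run, it can be worth confirming that the checkpoint and its tokenizers load in your environment. A minimal sketch follows; the import path is assumed to mirror the finetuning script (`finetune.py` below), so adjust it if your installed version differs.

```
# Sanity-check sketch: load the v2 model and tokenizers on CPU before finetuning.
# Import path assumed to match finetune.py; model name taken from the README above.
import torch

from seamless_communication.models.unity import (
    load_unity_model,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)

model_name = "seamlessM4T_v2_large"
text_tokenizer = load_unity_text_tokenizer(model_name)
unit_tokenizer = load_unity_unit_tokenizer(model_name)
model = load_unity_model(model_name, device=torch.device("cpu"), dtype=torch.float16)
print(type(model).__name__, model.target_vocab_info)
```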

+ 48 - 6
src/seamless_communication/cli/m4t/finetune/dataloader.py

@@ -13,11 +13,11 @@ from typing import Any, Dict, Iterable, List, Optional
 import numpy as np
 import torch
 import torchaudio
-import torchaudio.compliance.kaldi as ta_kaldi
 from datasets import Dataset
 from datasets.distributed import split_dataset_by_node
 from fairseq2.data.text import TextTokenEncoder
 from fairseq2.models.nllb import NllbTokenizer
+from fairseq2.data.audio import WaveformToFbankConverter
 from torch import Tensor
 from torch.nn.functional import pad as pad_tensor
 from torch.utils.data import DataLoader
@@ -69,6 +69,10 @@ class BatchingConfig:
     """The pad index to use in fbanks batching."""
 
     batch_size: int = 5
+    """Fixed batch size to use"""
+
+    max_audio_length_sec: float = 15.0
+    """Drop samples whose source audio is longer than this threshold, in seconds."""
 
     rank: int = 0
     """The rank of this worker in the process group."""
@@ -83,11 +87,13 @@ class BatchingConfig:
     """Select between fp16/fp32 for float tensors """
 
 
-def worker_init_fn(worker_id):
-    np.random.seed(np.random.get_state()[1][0] + worker_id)
+def worker_init_fn(worker_id: int) -> None:
+    np.random.seed(np.random.get_state()[1][0] + worker_id)  # type: ignore
 
 
 class UnitYDataLoader:
+    SAMPLE_RATE = 16_000
+
     def __init__(
         self,
         text_tokenizer: NllbTokenizer,
@@ -100,9 +106,17 @@ class UnitYDataLoader:
         self.unit_tokenizer = unit_tokenizer
         self.unit_encoders_per_lang: Dict[str, UnitTokenEncoder] = {}
         self.batching_config = batching_config
+        self._fbank_extract_params = {
+            "num_mel_bins": 80,
+            "waveform_scale": 32768,
+            "channel_last": True,
+            "standardize": True,
+            "device": torch.device("cpu"),
+            "dtype": self.batching_config.float_dtype,
+        }
         self.dataset = self._load_manifest(dataset_manifest_path)
 
-    def get_dataloader(self) -> DataLoader:
+    def get_dataloader(self) -> DataLoader[SeqsBatch]:
         subset = split_dataset_by_node(
             self.dataset,
             rank=self.batching_config.rank,
@@ -122,8 +136,21 @@ class UnitYDataLoader:
         return self.get_dataloader().__iter__()
 
     def _get_source_fbank(self, sample: LangPairSample) -> Tensor:
-        audio_input = torchaudio.load(sample.source.audio_local_path)[0]
-        return ta_kaldi.fbank(audio_input, num_mel_bins=80)
+        wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
+        assert (
+            int(sample_rate) == self.SAMPLE_RATE
+        ), f"sample rate != {self.SAMPLE_RATE}, please resample"
+        assert len(wav.shape) in (1, 2)
+        if len(wav.shape) == 1:
+            wav = wav.unsqueeze(-1)
+        elif wav.shape[0] <= 2:  # channel is first, should be second
+            wav = wav.transpose(0, 1)
+        return WaveformToFbankConverter(**self._fbank_extract_params)(  # type: ignore
+            {
+                "waveform": wav,
+                "sample_rate": self.SAMPLE_RATE,
+            }
+        )["fbank"]
 
     def _get_tokenized_target_text(self, sample: LangPairSample) -> Tensor:
         """Expected sequence is [<eos>, <lang_tok> , ..text tokens.., <eos>]"""
@@ -163,10 +190,25 @@ class UnitYDataLoader:
             padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
         return torch.stack([tensor for tensor in padded_tensors], dim=0)
 
+    def _is_long_src_audio(self, sample: LangPairSample) -> bool:
+        wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
+        length_s: float = max(wav.shape) / sample_rate
+        return length_s > self.batching_config.max_audio_length_sec
+
     def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
         samples = [LangPairSample.from_json(sample) for sample in raw_samples]
         # input speech
+        #  - filter long audio samples
+        filtered_samples = [sample for sample in samples if not self._is_long_src_audio(sample)]
+        samples = filtered_samples if filtered_samples else [samples[0]]  # keep at least one sample
         src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
+        #  - filter NaNs in fbanks
+        with_nans = [fbank.isnan().any().item() for fbank in src_tokens_list]
+        samples = [sample for sample, skip in zip(samples, with_nans) if not skip]
+        assert len(samples) > 0
+        src_tokens_list = [
+            src_toks for src_toks, skip in zip(src_tokens_list, with_nans) if not skip
+        ]
         src_tokens = self._batch_tensors(
             src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
         ).to(self.batching_config.float_dtype)
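The dataloader change replaces `torchaudio.compliance.kaldi.fbank` with fairseq2's `WaveformToFbankConverter`, which expects a channel-last waveform at 16 kHz and returns a dict with an `"fbank"` key. Below is a condensed, standalone sketch of that extraction path outside the `UnitYDataLoader` class; the audio file path and the `float32` dtype are illustrative.

```
# Standalone sketch of the new fbank path: load audio, enforce 16 kHz,
# move channels last, then run fairseq2's WaveformToFbankConverter.
# "example_16khz.wav" is a placeholder path.
import torch
import torchaudio
from fairseq2.data.audio import WaveformToFbankConverter

SAMPLE_RATE = 16_000
convert_to_fbank = WaveformToFbankConverter(
    num_mel_bins=80,
    waveform_scale=32768,
    channel_last=True,
    standardize=True,
    device=torch.device("cpu"),
    dtype=torch.float32,
)

wav, sample_rate = torchaudio.load("example_16khz.wav")
assert int(sample_rate) == SAMPLE_RATE, f"sample rate != {SAMPLE_RATE}, please resample"
if wav.dim() == 1:
    wav = wav.unsqueeze(-1)    # (time,) -> (time, 1)
elif wav.shape[0] <= 2:
    wav = wav.transpose(0, 1)  # (channels, time) -> (time, channels)
fbank = convert_to_fbank({"waveform": wav, "sample_rate": SAMPLE_RATE})["fbank"]
print(fbank.shape)             # (num_frames, 80)
```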

+ 16 - 4
src/seamless_communication/cli/m4t/finetune/finetune.py

@@ -126,6 +126,7 @@ def main() -> None:
     args = init_parser().parse_args()
     dist_utils.init_distributed([logger, trainer.logger])
     device = torch.device("cuda")
+    float_dtype = torch.float16
     text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
     unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
     finetune_params = trainer.FinetuneParams(
@@ -143,12 +144,19 @@ def main() -> None:
     )
     logger.info(f"Finetune params: {finetune_params}")
     model: UnitYModel = load_unity_model(
-        args.model_name, device=finetune_params.device, dtype=torch.float16
+        args.model_name, device=torch.device("cpu"), dtype=float_dtype
     )
-    logger.info(f"Model {model}")
     assert model.target_vocab_info == text_tokenizer.vocab_info
-    assert model.t2u_model is not None
-    assert model.t2u_model.target_vocab_info == unit_tokenizer.vocab_info
+    # (optional) delete unused params to reduce GPU memory consumption
+    if (
+        finetune_params.finetune_mode == trainer.FinetuneMode.SPEECH_TO_TEXT
+        and model.t2u_model is not None
+    ):
+        model.t2u_model = None
+    if model.text_encoder is not None:
+        model.text_encoder = None
+    model = model.to(finetune_params.device)
+    logger.info(f"Model {model}")
 
     train_dataloader = dataloader.UnitYDataLoader(
         text_tokenizer=text_tokenizer,
@@ -157,6 +165,8 @@ def main() -> None:
             batch_size=finetune_params.train_batch_size,
             rank=dist_utils.get_rank(),
             world_size=dist_utils.get_world_size(),
+            max_audio_length_sec=15.0,
+            float_dtype=float_dtype,
         ),
         dataset_manifest_path=args.train_dataset,
     )
@@ -167,6 +177,8 @@ def main() -> None:
             batch_size=finetune_params.eval_batch_size,
             rank=dist_utils.get_rank(),
             world_size=dist_utils.get_world_size(),
+            max_audio_length_sec=100.0,
+            float_dtype=float_dtype,
         ),
         dataset_manifest_path=args.eval_dataset,
     )
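The memory-saving step above drops submodules that the selected finetune mode never touches before moving the model to the GPU. A hedged sketch of the same idea, detached from the argument parsing, is shown below; the model name and the chosen mode are illustrative.

```
# Sketch of the "delete unused params" step: drop submodules that
# SPEECH_TO_TEXT finetuning never uses, then move the trimmed model to GPU.
# Model name and mode are examples; requires a CUDA device for the final .to().
import torch

from seamless_communication.cli.m4t.finetune import trainer
from seamless_communication.models.unity import load_unity_model

model = load_unity_model(
    "seamlessM4T_v2_large", device=torch.device("cpu"), dtype=torch.float16
)
finetune_mode = trainer.FinetuneMode.SPEECH_TO_TEXT  # example mode
if finetune_mode == trainer.FinetuneMode.SPEECH_TO_TEXT and model.t2u_model is not None:
    model.t2u_model = None      # the T2U stack is not trained in S2T mode
if model.text_encoder is not None:
    model.text_encoder = None   # S2T finetuning uses the speech encoder only
model = model.to(torch.device("cuda"))
```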

+ 44 - 21
src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -20,10 +20,13 @@ from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
-from torch.optim import Adam
+from torch.optim import AdamW
 
 from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
-from seamless_communication.models.unity import UnitYModel
+from seamless_communication.models.unity import (
+    UnitYModel,
+    UnitYT2UModel,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -80,26 +83,27 @@ class UnitYFinetuneWrapper(nn.Module):
 
     def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
         super().__init__()
-        assert model.t2u_model is not None
         self.model: UnitYModel = model
         self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
         self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
+        logger.info(f"Freeze s2t: {self.freeze_s2t}, freeze t2u: {self.freeze_t2u}")
         self.device = device
 
     def forward(
         self, batch: dataloader.MultimodalSeqsBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        assert self.model.t2u_model is not None
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
             assert batch.speech_to_text.src_tokens is not None
             seqs = batch.speech_to_text.src_tokens.to(self.device)
+            assert batch.speech_to_text.src_lengths is not None
             seq_lens = batch.speech_to_text.src_lengths.to(self.device)
             speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
                 seqs=seqs, padding_mask=PaddingMask(seq_lens, seqs.size(1))
             )
             assert batch.speech_to_text.prev_output_tokens is not None
             seqs = batch.speech_to_text.prev_output_tokens.to(self.device)
+            assert batch.speech_to_text.target_lengths is not None
             seq_lens = batch.speech_to_text.target_lengths.to(self.device)
             text_decoder_out, text_decoder_padding_mask = self.model.decode(
                 seqs=seqs,
@@ -107,19 +111,27 @@ class UnitYFinetuneWrapper(nn.Module):
                 encoder_output=speech_encoder_out,
                 encoder_padding_mask=speech_encoder_padding_mask,
             )
+            assert self.model.final_proj is not None
             text_logits = self.model.final_proj(text_decoder_out)
-        if batch.text_to_units.prev_output_tokens is None:
+        if self.freeze_t2u:
             return (text_logits, None)
+        assert self.model.t2u_model is not None
+        assert batch.text_to_units.prev_output_tokens is not None
         dummy_context = contextmanager(lambda: iter([None]))()
         with torch.no_grad() if self.freeze_t2u else dummy_context:  # type:ignore
+            if not isinstance(self.model.t2u_model, UnitYT2UModel):
+                raise NotImplementedError(
+                    "T2U finetuning implemented only for UnitYT2UModel"
+                )
             (
                 unit_encoder_out,
                 unit_encoder_padding_mask,
             ) = self.model.t2u_model.encode(
-                text_decoder_output=text_decoder_out,
-                text_decoder_padding_mask=text_decoder_padding_mask,
+                seqs=text_decoder_out,
+                padding_mask=text_decoder_padding_mask,
             )
             seqs = batch.text_to_units.prev_output_tokens.to(self.device)
+            assert batch.text_to_units.target_lengths is not None
             seq_lens = batch.text_to_units.target_lengths.to(self.device)
             unit_decoder_out, _ = self.model.t2u_model.decode(
                 seqs=seqs,
@@ -139,7 +151,7 @@ class CalcLoss:
         self,
         label_smoothing: float,
         s2t_vocab_info: VocabularyInfo,
-        t2u_vocab_info: VocabularyInfo,
+        t2u_vocab_info: Optional[VocabularyInfo],
     ):
         self.label_smoothing = label_smoothing
         self.s2t_vocab_info = s2t_vocab_info
@@ -152,25 +164,31 @@ class CalcLoss:
         unit_logits: Optional[torch.Tensor],
     ) -> torch.Tensor:
         assert batch.speech_to_text.target_lengths is not None
-        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
+        prefix_skip_len = 1  # language tokens to skip
+        s2t_numel = torch.sum(batch.speech_to_text.target_lengths - prefix_skip_len).to(
             text_logits.device
         )
+        assert batch.speech_to_text.target_tokens is not None
         s2t_loss = SequenceModelOutput(
             logits=text_logits, vocab_info=self.s2t_vocab_info
         ).compute_loss(
             targets=batch.speech_to_text.target_tokens.to(text_logits.device),
-            ignore_prefix_size=1,
+            ignore_prefix_size=prefix_skip_len,
             label_smoothing=self.label_smoothing,
         )
         if unit_logits is None:
             return s2t_loss / s2t_numel
         assert batch.text_to_units.target_lengths is not None
-        s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
+        s2u_numel = torch.sum(batch.text_to_units.target_lengths - prefix_skip_len).to(
+            unit_logits.device
+        )
+        assert batch.text_to_units.target_tokens is not None
+        assert self.t2u_vocab_info is not None
         s2u_loss = SequenceModelOutput(
             logits=unit_logits, vocab_info=self.t2u_vocab_info
         ).compute_loss(
             targets=batch.text_to_units.target_tokens.to(unit_logits.device),
-            ignore_prefix_size=1,
+            ignore_prefix_size=prefix_skip_len,
             label_smoothing=self.label_smoothing,
         )
         return s2t_loss / s2t_numel + s2u_loss / s2u_numel
@@ -225,17 +243,17 @@ class UnitYFinetune:
         eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
     ):
         self.params = params
-
-        assert model.t2u_model is not None
         self.calc_loss = CalcLoss(
             label_smoothing=self.params.label_smoothing,
             s2t_vocab_info=model.target_vocab_info,
-            t2u_vocab_info=model.t2u_model.target_vocab_info,
+            t2u_vocab_info=model.t2u_model.target_vocab_info
+            if model.t2u_model is not None
+            else None,
         )
         self.model = self._wrap_model_for_trainining(model=model)
         self.train_data_loader = train_data_loader
         self.eval_data_loader = eval_data_loader
-        self.optimizer = Adam(
+        self.optimizer = AdamW(
             params=self.model.parameters(),
             lr=self.params.learning_rate,
             betas=(0.9, 0.98),
@@ -244,7 +262,7 @@ class UnitYFinetune:
             weight_decay=0.0,
             fused=True,
         )
-        self.grad_scaler = torch.cuda.amp.GradScaler()
+        self.grad_scaler = torch.cuda.amp.GradScaler()  # type: ignore
         self.lr_scheduler = MyleLR(
             optimizer=self.optimizer,
             num_warmup_steps=self.params.warmup_steps,
@@ -257,6 +275,7 @@ class UnitYFinetune:
         self.patience_left: int = self.params.patience
         self.best_eval_loss: Optional[float] = None
         self.is_best_state: bool = False
+        torch.set_float32_matmul_precision("high")
 
     def _reset_stats(self) -> None:
         self.train_loss_hist.reset()
@@ -272,10 +291,11 @@ class UnitYFinetune:
         )
         if not dist_utils.is_dist_initialized():
             return wrapped_model
+        find_unused = self.params.finetune_mode == FinetuneMode.TEXT_TO_SPEECH
         return nn.parallel.DistributedDataParallel(
             wrapped_model,
             device_ids=[dist_utils.get_local_rank()],
-            find_unused_parameters=True,
+            find_unused_parameters=find_unused,
         )
 
     def _update_eval_stats(self, eval_loss: float) -> None:
@@ -314,7 +334,7 @@ class UnitYFinetune:
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
 
-    def _train_step_log(self):
+    def _train_step_log(self) -> None:
         """Log train stats"""
         if (self.update_idx + 1) % self.params.log_steps == 0:
             avg_loss = self.train_loss_hist.reduce()
@@ -332,6 +352,9 @@ class UnitYFinetune:
         self.optimizer.zero_grad()
         tokens, units = self.model(batch)
         loss = self.calc_loss(batch, tokens, units)
+        if loss.isnan().any().item():
+            logger.error(batch.speech_to_text)
+            raise RuntimeError("Loss is Nan. Terminating.")
         self.grad_scaler.scale(loss).backward()
         self.grad_scaler.step(self.optimizer)
         self.grad_scaler.update()
@@ -340,7 +363,7 @@ class UnitYFinetune:
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
 
-    def _save_model(self):
+    def _save_model(self) -> None:
         logger.info("Saving model")
         if dist_utils.is_main_process():
             state_dict = {
@@ -351,7 +374,7 @@ class UnitYFinetune:
         if dist_utils.is_dist_initialized():
             dist.barrier()
 
-    def run(self):
+    def run(self) -> None:
         logger.info("Start finetuning")
         self._reset_stats()
         self._eval_model()
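The `CalcLoss` change keeps the normalization consistent with `ignore_prefix_size=1`: the language-token prefix is excluded both from the loss and from the token count used as the denominator. A toy numeric illustration follows; the lengths are made up.

```
# Toy illustration of the CalcLoss normalization change: when the first
# (language) token is ignored in the loss, it is also excluded from the
# token count used as the denominator. Lengths below are made-up examples.
import torch

prefix_skip_len = 1                       # language token to skip
target_lengths = torch.tensor([7, 5, 9])  # hypothetical S2T target lengths

old_numel = torch.sum(target_lengths)                    # 21: counts the ignored prefix
new_numel = torch.sum(target_lengths - prefix_skip_len)  # 18: matches ignore_prefix_size=1
print(int(old_numel), int(new_numel))
```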

+ 1 - 0
src/seamless_communication/datasets/huggingface.py

@@ -96,6 +96,7 @@ class Speech2SpeechFleursDatasetBuilder:
             split=self.split,
             cache_dir=self.dataset_cache_dir,
             streaming=False,
+            trust_remote_code=True,
         )
         for item in ds:
             audio_path = os.path.join(
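Recent `datasets` releases require an explicit opt-in before running a dataset's loading script, which is why `trust_remote_code=True` is now passed. A minimal sketch of the same call outside the FLEURS builder is shown below; the dataset repo id and the `"en_us"` config are illustrative, while the builder above supplies its own name, split, and cache directory.

```
# Minimal sketch of load_dataset with the new trust_remote_code flag.
# "google/fleurs" and "en_us" are illustrative choices for this example.
from datasets import load_dataset

ds = load_dataset(
    "google/fleurs",
    "en_us",
    split="test",
    streaming=False,
    trust_remote_code=True,  # opt in to running the dataset's loading script
)
print(next(iter(ds)).keys())
```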