
Fixed finetuning trainer and enabled freezing layers (#449)

* Model saving

* Load a model from disk

* Refactored training loop

* Fix GPU OOM in eval

* Fix issues with eval

* Log exception

* Enable freezing

* Freeze instead of defreeze

* Formatting

* drop overflow

* Logging

* Address reviews

* Don't continue on GPU OOM

* Remove OOM logic

---------

Co-authored-by: Ruslan Mavlyutov <mavlyutov@meta.com>
Alisamar Husain 1 year ago
parent
commit
ba7f6d0725

+ 69 - 25
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -32,13 +32,13 @@ from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
+from seamless_communication.models.unity import UnitYModel
 from seamless_communication.inference import (
     BatchedSpeechOutput,
     Modality,
     SequenceGeneratorOptions,
     Translator,
 )
-from seamless_communication.models.unity import load_unity_text_tokenizer

 logging.basicConfig(
     level=logging.INFO,
@@ -247,14 +247,14 @@ def adjust_output_for_corrupted_inputs(

 def run_eval(
     translator: Translator,
-    text_tokenizer: TextTokenizer,
     ctx: EvalContext,
     whisper_model_name: str,
+    n_samples = None
 ) -> None:
-    pipeline = build_data_pipeline(ctx, text_tokenizer)
+    pipeline = build_data_pipeline(ctx, translator.text_tokenizer)

     total_steps = count_lines(ctx.data_file) - 1
-    progress_bar = tqdm(total=total_steps)
+    progress_bar = tqdm(total=n_samples or total_steps)

     output_path = ctx.output_path / ctx.data_file.stem
     output_path.mkdir(parents=True, exist_ok=True)
@@ -294,15 +294,21 @@ def run_eval(

             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (text_output, speech_output,) = translator.predict(
-                    src,
-                    ctx.task,
-                    ctx.target_lang,
-                    src_lang=ctx.source_lang,
-                    text_generation_opts=ctx.text_generation_opts,
-                    unit_generation_opts=ctx.unit_generation_opts,
-                    unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
-                )
+                # HACK:: Fix this bad handling
+                # RuntimeError: The sequence generator returned no hypothesis at index 2. Please file a bug report.
+                try:
+                    (text_output, speech_output,) = translator.predict(
+                        src,
+                        ctx.task,
+                        ctx.target_lang,
+                        src_lang=ctx.source_lang,
+                        text_generation_opts=ctx.text_generation_opts,
+                        unit_generation_opts=ctx.unit_generation_opts,
+                        unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
+                    )
+                except RuntimeError as e:
+                    logger.exception(f"Caught RuntimeError: {e}")
+                    continue
             else:
                 text_output = []
                 if ctx.output_modality == Modality.SPEECH:
@@ -338,6 +344,10 @@ def run_eval(

                 sample_id += 1
                 progress_bar.update(1)
+                if n_samples and progress_bar.n == n_samples:
+                    break
+            if n_samples and progress_bar.n == n_samples:
+                break

     progress_bar.close()
     logger.info(f"Processed {sample_id} samples")
@@ -352,6 +362,26 @@ def run_eval(
     )


+def load_checkpoint(model: UnitYModel, path: str, device = torch.device("cpu")) -> None:
+    saved_model = torch.load(path, map_location=device)["model"]
+    saved_model = { k.replace("model.", ""): v for k, v in saved_model.items() }
+
+    def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
+        return {key.replace(prefix, ""): value for key, value in state_dict.items() if key.startswith(prefix)}
+
+    model.speech_encoder_frontend.load_state_dict(_select_keys(saved_model, "model.speech_encoder_frontend."))
+    model.speech_encoder.load_state_dict(_select_keys(saved_model, "model.speech_encoder."))
+
+    assert model.text_decoder_frontend is not None
+    model.text_decoder_frontend.load_state_dict(_select_keys(saved_model, "model.text_decoder_frontend."))
+
+    assert model.text_decoder is not None
+    model.text_decoder.load_state_dict(_select_keys(saved_model, "model.text_decoder."))
+
+    assert model.final_proj is not None
+    model.final_proj.load_state_dict(_select_keys(saved_model, "model.final_proj."))
+
+
 def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
     parser = argparse.ArgumentParser(
         description="M4T evaluation for tasks supported by Translator."
@@ -362,8 +392,20 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         help="Data file to be evaluated, either TSV file or manifest JSON file."
         "Format of the manifest JSON file should be that as produced by `m4t_prepare_dataset`"
     )
+    parser.add_argument(
+        "--load_checkpoint",
+        type=str,
+        help="Load a local Checkpoint",
+        default=None
+    )

     parser = add_inference_arguments(parser)
+    parser.add_argument(
+        "--device",
+        type=str,
+        help="Device",
+        default="cuda" if torch.cuda.is_available() else "cpu",
+    )
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -388,7 +430,13 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         help="Whisper model to be used for ASR-BLEU scoring",
         default="large",
     )
-    args, unknown = parser.parse_known_args()
+    parser.add_argument(
+        "--n_samples",
+        type=int,
+        help="Number of Samples to run eval on. All if None.",
+        default=None,
+    )
+    args, _ = parser.parse_known_args()
     default_args = vars(args)
     default_args.update(optional_args) if optional_args else default_args
     args = Namespace(**default_args)
@@ -412,15 +460,9 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         raise ValueError(
             f"Invalid audio_root_dir: {args.audio_root_dir} for speech input."
         )
-
-    if torch.cuda.is_available():
-        device = torch.device("cuda:0")
-        dtype = torch.float16
-    else:
-        device = torch.device("cpu")
-        dtype = torch.float32
-
-    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+    
+    device = torch.device(args.device)
+    dtype = torch.float16 if device.type == "cuda" else torch.float32

     # TODO: Avoid loading the T2U model, vocoder when the output
     # modality is text.
@@ -428,11 +470,13 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         args.model_name,
         args.vocoder_name,
         device,
-        text_tokenizer=text_tokenizer,
         dtype=dtype,
         input_modality=input_modality,
         output_modality=output_modality,
     )
+    
+    if args.load_checkpoint:
+        load_checkpoint(translator.model, path=args.load_checkpoint, device=device)

     text_generation_opts, unit_generation_opts = set_generation_opts(args)

@@ -465,7 +509,7 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
     # fmt: on
     logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")

-    run_eval(translator, text_tokenizer, ctx, args.whisper_model_name)
+    run_eval(translator, ctx, args.whisper_model_name, n_samples=args.n_samples)


 if __name__ == "__main__":
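Editor's note: since `main` merges `optional_args` into the parsed defaults, the new flags can also be driven from Python. A minimal sketch; the `data_file`, `task`, and `target_lang` keys are assumed to match the argument names defined by the parser and `add_inference_arguments`, and the paths are placeholders, not values from this commit.

    from seamless_communication.cli.m4t.evaluate.evaluate import main

    main(optional_args={
        "data_file": "dev_manifest.json",   # assumed name of the positional data-file argument
        "task": "S2TT",                     # assumed keys coming from add_inference_arguments
        "target_lang": "eng",
        "load_checkpoint": "finetuned.pt",  # new flag: restore a locally finetuned checkpoint
        "device": "cuda",                   # new flag: overrides the default device selection
        "n_samples": 100,                   # new flag: evaluate only the first 100 samples
    })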

+ 48 - 18
src/seamless_communication/cli/m4t/finetune/dataloader.py

@@ -8,7 +8,7 @@
 import json
 import logging
 from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple

 import numpy as np
 import torch
@@ -100,6 +100,7 @@ class UnitYDataLoader:
         unit_tokenizer: UnitTokenizer,
         dataset_manifest_path: str,
         batching_config: BatchingConfig,
+        max_src_tokens_per_batch: int = 100000
     ):
         self.text_tokenizer = text_tokenizer
         self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
@@ -115,6 +116,7 @@ class UnitYDataLoader:
             "dtype": self.batching_config.float_dtype,
         }
         self.dataset = self._load_manifest(dataset_manifest_path)
+        self.max_src_tokens_per_batch = max_src_tokens_per_batch

     def get_dataloader(self) -> DataLoader[SeqsBatch]:
         subset = split_dataset_by_node(
@@ -156,9 +158,9 @@
         """Expected sequence is [<eos>, <lang_tok> , ..text tokens.., <eos>]"""
         target_lang = sample.target.lang
         if target_lang not in self.text_encoders_per_lang:
-            self.text_encoders_per_lang[
-                target_lang
-            ] = self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            self.text_encoders_per_lang[target_lang] = (
+                self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            )
         tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
         eos_idx = self.text_tokenizer.vocab_info.eos_idx
         tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
@@ -170,9 +172,9 @@
             return None
         target_lang = sample.target.lang
         if target_lang not in self.unit_encoders_per_lang:
-            self.unit_encoders_per_lang[
-                target_lang
-            ] = self.unit_tokenizer.create_encoder(lang=target_lang)
+            self.unit_encoders_per_lang[target_lang] = (
+                self.unit_tokenizer.create_encoder(lang=target_lang)
+            )
         tokens = self.unit_encoders_per_lang[target_lang](
             torch.LongTensor(sample.target.units).unsqueeze(0)
         )
@@ -191,30 +193,58 @@
         return torch.stack([tensor for tensor in padded_tensors], dim=0)

     def _is_long_src_audio(self, sample: LangPairSample) -> bool:
-        wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
-        length_s: float = max(wav.shape) / sample_rate
-        return length_s > self.batching_config.max_audio_length_sec
+        # HACK:: causes errored audios to be excluded but this is difficult to follow
+        try:
+            wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
+            length_s: float = max(wav.shape) / sample_rate
+            return length_s > self.batching_config.max_audio_length_sec
+        except:
+            logger.exception(f"Failed to load sample path: {sample.source.audio_local_path}")
+            return True
+
+    def _drop_overflow_samples(
+        self, samples_with_fbanks: List[Tuple[LangPairSample, torch.Tensor]]
+    ) -> List[Tuple[LangPairSample, torch.Tensor]]:
+        # filter by src_tokens length (reverse)
+        samples_with_fbanks = sorted(
+            samples_with_fbanks, key=lambda sb: -sb[1].shape[0]
+        )
+        bwd = samples_with_fbanks[0][1].shape[0]
+        max_samples_for_batch = max(1, self.max_src_tokens_per_batch // bwd)
+        if max_samples_for_batch < len(samples_with_fbanks):
+            samples_with_fbanks = samples_with_fbanks[:max_samples_for_batch]
+        return samples_with_fbanks

     def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
         samples = [LangPairSample.from_json(sample) for sample in raw_samples]
         # input speech
+        
         #  - filter long audio samples
-        filtered_samples = [sample for sample in samples if not self._is_long_src_audio(sample)]
-        samples = filtered_samples if filtered_samples else [samples[0]]  # keep at least one sample
-        src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
+        filtered_samples = [
+            sample for sample in samples if not self._is_long_src_audio(sample)
+        ]
+        samples = (
+            filtered_samples if filtered_samples else [samples[0]]
+        )  # keep at least one sample
+        with_fbanks = [(sample, self._get_source_fbank(sample)) for sample in samples]
         #  - filter NaNs in fbanks
-        with_nans = [fbank.isnan().any().item() for fbank in src_tokens_list]
-        samples = [sample for sample, skip in zip(samples, with_nans) if not skip]
-        assert len(samples) > 0
-        src_tokens_list = [
-            src_toks for src_toks, skip in zip(src_tokens_list, with_nans) if not skip
+        filtered = [
+            (sample, fbank)
+            for sample, fbank in with_fbanks
+            if not fbank.isnan().any().item()
         ]
+        filtered = self._drop_overflow_samples(filtered)
+
+        samples = [sample for sample, _ in filtered]
+        src_tokens_list = [src_tokens for _, src_tokens in filtered]
+        assert len(samples) > 0
         src_tokens = self._batch_tensors(
             src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
         ).to(self.batching_config.float_dtype)
         src_lengths = torch.LongTensor(
             [src_tokens.shape[0] for src_tokens in src_tokens_list]
         )
+        
         # output text
         text_tokens_list = [
             self._get_tokenized_target_text(sample) for sample in samples
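Editor's note: for intuition on the new token budget, `_drop_overflow_samples` sorts the batch by fbank length (longest first) and keeps only as many samples as fit under `max_src_tokens_per_batch`, measured against the longest item. A small standalone sketch with made-up lengths:

    import torch

    max_src_tokens_per_batch = 7000  # the default the finetune CLI passes via --max_src_tokens
    fbanks = [torch.zeros(n, 80) for n in (1200, 900, 850, 700, 600, 550)]  # dummy fbanks

    longest = max(f.shape[0] for f in fbanks)           # 1200 frames
    keep = max(1, max_src_tokens_per_batch // longest)  # 7000 // 1200 == 5
    print(f"keeping {keep} of {len(fbanks)} samples")   # -> keeping 5 of 6 samples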

+ 38 - 17
src/seamless_communication/cli/m4t/finetune/finetune.py

@@ -10,12 +10,9 @@ import os
 from pathlib import Path

 import torch
-from fairseq2.models.nllb.tokenizer import NllbTokenizer

 from seamless_communication.cli.m4t.finetune import dataloader, dist_utils, trainer
 from seamless_communication.models.unity import (
-    UnitTokenizer,
-    UnitYModel,
     load_unity_model,
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
@@ -108,6 +105,12 @@
         default=10,
         help=("Log inner loss after each `log_steps` training steps"),
     )
+    parser.add_argument(
+        "--max_src_tokens",
+        type=int,
+        default=7000,
+        help=("Maximum number of src_tokens per batch, used to avoid GPU OOM and maximize the effective batch size"),
+    )
     parser.add_argument(
         "--mode",
         type=trainer.FinetuneMode,
@@ -119,6 +122,14 @@
             "* `SPEECH_TO_TEXT` -- finetune only S2T"
         ),
     )
+    parser.add_argument(
+        "--freeze_layers",
+        nargs="*",
+        required=False,
+        default=None,
+        # TODO: better description
+        help=("A list of modules to freeze in the model. If empty, everything will be trained."),
+    )
     parser.add_argument(
         "--device",
         type=str,
@@ -130,14 +141,19 @@

 def main() -> None:
     args = init_parser().parse_args()
+    
     dist_utils.init_distributed([logger, trainer.logger])
-    text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
-    unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
+    float_dtype = torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16
+    
+    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
+    
     finetune_params = trainer.FinetuneParams(
+        model_name=args.model_name,
         finetune_mode=args.mode,
         save_model_path=args.save_model_to,
         device=torch.device(args.device),
-        float_dtype=torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16,
+        float_dtype=float_dtype,
         train_batch_size=args.batch_size,
         eval_batch_size=args.batch_size,
         patience=args.patience,
@@ -147,22 +163,25 @@
         eval_steps=args.eval_steps,
         log_steps=args.log_steps,
     )
-    logger.info(f"Finetune params: {finetune_params}")
-    model: UnitYModel = load_unity_model(
-        args.model_name, device=torch.device("cpu"), dtype=torch.float32
-    )
+    
+    logger.info(f"Finetune Params: {finetune_params}")
+    
+    model = load_unity_model(args.model_name, device=torch.device("cpu"), dtype=torch.float32)
     assert model.target_vocab_info == text_tokenizer.vocab_info
-    # (optional) delete unused params to reduce GPU memory consumption
+    
     if (
         finetune_params.finetune_mode == trainer.FinetuneMode.SPEECH_TO_TEXT
         and model.t2u_model is not None
     ):
         model.t2u_model = None
+    
     if model.text_encoder is not None:
         model.text_encoder = None
+    
+    # Put model on selected device
     model = model.to(finetune_params.device)
-    logger.info(f"<{args.model_name}> {model}")

+    # TODO: delete unused params to reduce GPU memory consumption
     train_dataloader = dataloader.UnitYDataLoader(
         text_tokenizer=text_tokenizer,
         unit_tokenizer=unit_tokenizer,
@@ -174,7 +193,8 @@
             float_dtype=finetune_params.float_dtype,
         ),
         dataset_manifest_path=args.train_dataset,
-    )
+        max_src_tokens_per_batch=args.max_src_tokens)
+    
     eval_dataloader = dataloader.UnitYDataLoader(
         text_tokenizer=text_tokenizer,
         unit_tokenizer=unit_tokenizer,
@@ -182,17 +202,18 @@
             batch_size=finetune_params.eval_batch_size,
             rank=dist_utils.get_rank(),
             world_size=dist_utils.get_world_size(),
-            max_audio_length_sec=100.0,
+            max_audio_length_sec=75.0,
             float_dtype=finetune_params.float_dtype,
         ),
-        dataset_manifest_path=args.eval_dataset,
-    )
+        dataset_manifest_path=args.eval_dataset)
+    
     finetune = trainer.UnitYFinetune(
         model=model,
         params=finetune_params,
         train_data_loader=train_dataloader,
         eval_data_loader=eval_dataloader,
-    )
+        freeze_modules=args.freeze_layers)
+    
     finetune.run()
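Editor's note: an illustrative invocation of the updated finetune CLI using the new `--max_src_tokens` and `--freeze_layers` flags; the manifest paths, output path, and module prefixes are placeholders, and the other flag spellings are inferred from the `args.*` attributes used above, not confirmed by this commit.

    import sys
    from seamless_communication.cli.m4t.finetune import finetune

    sys.argv = [
        "m4t_finetune",
        "--train_dataset", "train_manifest.json",   # placeholder manifests
        "--eval_dataset", "eval_manifest.json",
        "--save_model_to", "finetuned.pt",
        "--max_src_tokens", "7000",                 # new: cap src tokens per batch to limit GPU memory
        "--freeze_layers", "model.speech_encoder_frontend", "model.speech_encoder",  # new: prefixes to freeze (illustrative names)
    ]
    finetune.main()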

+ 76 - 40
src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -6,12 +6,13 @@


 import logging
+import time
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
 from tqdm import tqdm
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -40,6 +41,9 @@ class FinetuneMode(Enum):

 @dataclass
 class FinetuneParams:
+    model_name: str
+    """Model name of model being finetuned."""
+    
     save_model_path: Path
     """Path were to save finetuned model."""

@@ -245,6 +249,7 @@ class UnitYFinetune:
         params: FinetuneParams,
         train_data_loader: dataloader.UnitYDataLoader,
         eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
+        freeze_modules: Optional[List[Union[str, torch.nn.Module]]] = None
     ):
         self.params = params
         self.calc_loss = CalcLoss(
@@ -254,9 +259,15 @@
             if model.t2u_model is not None
             else None,
         )
+        
         self.model = self._wrap_model_for_trainining(model=model)
+        if freeze_modules:
+            self._freeze_modules(freeze_modules)
+        
         self.train_data_loader = train_data_loader
         self.eval_data_loader = eval_data_loader
+        
+        self.grad_scaler = torch.cuda.amp.GradScaler()  # type: ignore
         self.optimizer = AdamW(
             params=self.model.parameters(),
             lr=self.params.learning_rate,
@@ -266,7 +277,6 @@
             weight_decay=0.0,
             fused=(self.params.device.type == "cuda"),
         )
-        self.grad_scaler = torch.cuda.amp.GradScaler()  # type: ignore
         self.lr_scheduler = MyleLR(
             optimizer=self.optimizer,
             num_warmup_steps=self.params.warmup_steps,
@@ -301,6 +311,14 @@
             device_ids=[dist_utils.get_local_rank()],
             find_unused_parameters=find_unused,
         )
+        
+    def _freeze_modules(self, frozen_modules: List[str] = []) -> None:
+        for icecube in frozen_modules:
+            for (name, module) in self.model.named_modules():
+                if name.startswith(icecube):
+                    logger.info(f"Freezing Module: {name}")
+                    for param in module.parameters():
+                        param.requires_grad = False

     def _update_eval_stats(self, eval_loss: float) -> None:
         self.is_best_state = (
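Editor's note: `_freeze_modules` matches the given names as prefixes against `named_modules()` of the wrapped model, so the exact strings to pass depend on how the model is wrapped (for example a DDP `module.` prefix). A hedged way to confirm what actually got frozen is to count the parameters that still require gradients; the variable and module names below are only illustrative.

    import torch

    def report_trainable(model: torch.nn.Module) -> None:
        # Summarize, per top-level child, how many parameters still require grad.
        for name, child in model.named_children():
            n_trainable = sum(p.numel() for p in child.parameters() if p.requires_grad)
            n_total = sum(p.numel() for p in child.parameters())
            print(f"{name}: {n_trainable}/{n_total} trainable")

    # e.g. after UnitYFinetune(..., freeze_modules=["model.speech_encoder"]):
    # report_trainable(finetune.model)   # 'finetune' is an illustrative variable name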
@@ -317,25 +335,26 @@
             f"patience_steps_left={self.patience_left}"
         )

-    def _eval_model(self) -> None:
+    @torch.no_grad()
+    def _eval_model(self, n_batches: int) -> None:
         """Calc avg loss on eval dataset and update evaluation stats"""
         if self.eval_data_loader is None:
             return
-        logger.info("Run evaluation")
+        logger.info(f"Evaluation Step {self.update_idx // self.params.eval_steps}...")
         loss_hist = LossCollector(device=self.params.device)
         self.model.eval()
-        with torch.no_grad():
-            for batch in tqdm(self.eval_data_loader.get_dataloader()):
-                assert batch.speech_to_text.src_tokens is not None
-                with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
-                    loss = self.calc_loss(batch, *self.model(batch))
-                if loss.isnan():
-                    logger.warning("Eval loss value is NaN, setting to inf")
-                    loss_val = float("Inf")
-                else:
-                    loss_val = loss.item()
-                del batch  # force memory release
-                loss_hist.update(1, loss_val)
+        for batch in self.eval_data_loader.get_dataloader():
+            if n_batches == 0:
+                break
+            assert batch.speech_to_text.src_tokens is not None
+            with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
+                loss = self.calc_loss(batch, *self.model(batch))
+            if loss.isnan():
+                logger.warning("Eval batch loss value is NaN, skipping")
+                continue
+            del batch  # force memory release
+            loss_hist.update(1, loss.item())
+            n_batches -= 1
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)

@@ -351,53 +370,70 @@
                 f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
             )

-    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
+    def _train_step(self, batch: List[dataloader.MultimodalSeqsBatch]) -> None:
         """Run one train step"""
         self.model.train()
         self.optimizer.zero_grad()
         with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
             tokens, units = self.model(batch)
+        
         loss = self.calc_loss(batch, tokens, units)
         if loss.isnan().any().item():
             logger.error(batch.speech_to_text)
-            raise RuntimeError("Loss is Nan. Terminating.")
+            raise RuntimeError("Train loss is NaN! Something is wrong in the model!")
+        
         self.grad_scaler.scale(loss).backward()
         self.grad_scaler.step(self.optimizer)
         self.grad_scaler.update()
         self.lr_scheduler.step()
+        
         assert batch.speech_to_text.src_tokens is not None
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
+        self.update_idx += 1

     def _save_model(self) -> None:
         logger.info("Saving model")
         if dist_utils.is_main_process():
-            state_dict = {
-                key.replace("module.model.", ""): value
-                for key, value in self.model.state_dict().items()
-            }
-            torch.save(state_dict, self.params.save_model_path)
+            torch.save({
+                "model_name": self.params.model_name,
+                "model": {
+                    key.replace("module.model.model.", ""): value
+                    for key, value in self.model.state_dict().items()
+                }
+            }, self.params.save_model_path)
         if dist_utils.is_dist_initialized():
             dist.barrier()

     def run(self) -> None:
-        logger.info("Start finetuning")
+        logger.info("Start Finetuning")
         self._reset_stats()
         self._eval_model()
-        batch_itr = self.train_data_loader.get_dataloader()
+        
+        train_dataloader = self.train_data_loader.get_dataloader()
+        
         while self.epoch_idx < self.params.max_epochs and self.patience_left:
-            for train_batch in batch_itr:
-                self._train_step(batch=train_batch)
-                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
-                    self._eval_model()
-                    if self.is_best_state:
-                        self._save_model()
-                    elif not self.patience_left:
-                        no_improve_steps = self.params.eval_steps * self.params.patience
-                        logger.info(
-                            "Early termination, as eval loss did not improve "
-                            f"over last {no_improve_steps} updates"
-                        )
-                        break
-                self.update_idx += 1
-            self.epoch_idx += 1
+            for train_batch in tqdm(train_dataloader, desc="Training Steps"):
+                # Run batch through train step
+                self._train_step(train_batch)
+                
+                # Perform eval if its time to eval
+                if not self.update_idx or self.update_idx % self.params.eval_steps != 0:
+                    continue
+                
+                # Clear GPU memory for eval
+                torch.cuda.empty_cache()
+                self._eval_model(n_batches=100)
+                
+                # Save the current model if its the best we've ever had
+                if self.is_best_state:
+                    self._save_model()
+                elif not self.patience_left:
+                    no_improve_steps = self.params.eval_steps * self.params.patience
+                    logger.info(
+                        "Early termination, as eval loss did not improve "
+                        f"over last {no_improve_steps} updates"
+                    )
+                    break
+                
+            self.epoch_idx += 1
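Editor's note: `_save_model` now writes a small envelope around the state dict instead of a bare state dict, which is what `load_checkpoint` in evaluate.py expects. A minimal sketch of inspecting such a file; the path is a placeholder.

    import torch

    ckpt = torch.load("finetuned.pt", map_location="cpu")  # placeholder path
    print(ckpt["model_name"])    # name of the base model that was finetuned
    state_dict = ckpt["model"]   # parameter tensors keyed by (stripped) module path
    print(len(state_dict), "tensors saved")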

+ 108 - 2
src/seamless_communication/datasets/huggingface.py

@@ -28,7 +28,7 @@ class SpeechTokenizer:
 class Speech2SpeechFleursDatasetBuilder:
     """Assembles speech2speech dataset from google/fleurs on HuggingFace"""

-    HF_FLEURS_DATASET_NAME = "google/fleurs"
+    DATASET_NAME = "google/fleurs"

     def __init__(
         self,
@@ -91,7 +91,113 @@

     def iterate_lang_audio_samples(self, lang: str) -> Iterable[MultimodalSample]:
         ds = load_dataset(
-            self.HF_FLEURS_DATASET_NAME,
+            self.DATASET_NAME,
+            lang,
+            split=self.split,
+            cache_dir=self.dataset_cache_dir,
+            streaming=False,
+            trust_remote_code=True,
+        )
+        for item in ds:
+            audio_path = os.path.join(
+                os.path.dirname(item["path"]), item["audio"]["path"]
+            )
+            (sample_id, audio_local_path, waveform, sampling_rate, text) = (
+                item["id"],
+                audio_path,
+                item["audio"]["array"],
+                item["audio"]["sampling_rate"],
+                item["transcription"],
+            )
+            yield self._prepare_sample(
+                sample_id=sample_id,
+                audio_local_path=audio_local_path,
+                waveform_npy=waveform,
+                sampling_rate=sampling_rate,
+                text=text,
+                lang=lang,
+            )
+
+    def __iter__(self) -> Iterable[LangPairSample]:
+        logger.info(f"Loading {self.target_lang} samples")
+        target_samples: Dict[int, MultimodalSample] = {}
+        for idx, sample in enumerate(
+            self.iterate_lang_audio_samples(lang=self.target_lang)
+        ):
+            if idx and idx % 100 == 0:
+                logger.info(f"..loaded {idx} target samples")
+            target_samples[sample.id] = sample
+
+        logger.info(f"Loading {self.source_lang} samples")
+        for idx, sample in enumerate(
+            self.iterate_lang_audio_samples(lang=self.source_lang)
+        ):
+            if idx and idx % 100 == 0:
+                logger.info(f"..loaded {idx} source samples")
+            if sample.id in target_samples:
+                yield LangPairSample(source=sample, target=target_samples[sample.id])
+
+
+class Speech2TextGigaspeechDatasetBuilder:
+    """ Assembles speech2speech dataset from google/fleurs on HuggingFace.
+        This dataset requires signing an license agreement and using an auth token.
+    """
+
+    DATASET_NAME = "speechcolab/gigaspeech"
+
+    def __init__(
+        self,
+        auth_token: str,
+        split: str = "test",
+        skip_source_audio: bool = True,
+        skip_target_audio: bool = True,
+        audio_dtype: torch.dtype = torch.float32,
+        dataset_cache_dir: Optional[str] = None,
+        speech_tokenizer: Optional[SpeechTokenizer] = None,
+    ):
+        self.auth_token = auth_token
+        self.split = split
+        self.dataset_cache_dir = dataset_cache_dir
+        self.audio_dtype = audio_dtype
+        self.skip_source_audio = skip_source_audio
+        self.skip_target_audio = skip_target_audio
+        self.speech_tokenizer = speech_tokenizer
+
+    def _prepare_sample(
+        self,
+        sample_id: int,
+        lang: str,
+        text: str,
+        audio_local_path: Optional[str] = None,
+        waveform_npy: Optional[np.ndarray] = None,
+        sampling_rate: Optional[int] = None,
+    ) -> MultimodalSample:
+        if waveform_npy is not None:
+            waveform = torch.from_numpy(waveform_npy).to(self.audio_dtype)
+        else:
+            waveform = None
+        if self.speech_tokenizer is not None and waveform_npy is not None:
+            assert waveform is not None
+            assert sampling_rate is not None
+            units_tensor = self.speech_tokenizer.encode(
+                waveform, sampling_rate
+            ).reshape(-1)
+            units = units_tensor.tolist()
+        else:
+            units = None
+        return MultimodalSample(
+            id=sample_id,
+            lang=lang,
+            text=text.strip(),
+            audio_local_path=audio_local_path,
+            waveform=waveform,
+            sampling_rate=sampling_rate,
+            units=units,
+        )
+
+    def iterate_lang_audio_samples(self, lang: str) -> Iterable[MultimodalSample]:
+        ds = load_dataset(
+            self.DATASET_NAME,
             lang,
             split=self.split,
             cache_dir=self.dataset_cache_dir,
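Editor's note: a minimal usage sketch of the Fleurs builder whose `__iter__` is shown above; the constructor parameter names (`source_lang`, `target_lang`, `split`) are assumptions based on the attributes referenced in this diff, and the language codes are placeholders.

    from seamless_communication.datasets.huggingface import Speech2SpeechFleursDatasetBuilder

    builder = Speech2SpeechFleursDatasetBuilder(
        source_lang="hi_in",   # assumed constructor arguments
        target_lang="en_us",
        split="validation",
    )
    for pair in builder:       # yields LangPairSample objects with matching sample ids
        print(pair.source.id, pair.source.text, "->", pair.target.text)
        break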