
Fixed finetuning trainer and enabled freezing layers (#449)

* Model saving

* Load a model from disk

* Refactored training loop

* Fix GPU OOM in eval

* Fix issues with eval

* Log exception

* Enable freezing

* Freeze instead of defreeze

* Formatting

* drop overflow

* Logging

* Address reviews

* Don't continue on GPU OOM

* Remove OOM logic

---------

Co-authored-by: Ruslan Mavlyutov <mavlyutov@meta.com>
Alisamar Husain, 1 year ago
commit ba7f6d0725

+ 69 - 25
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -32,13 +32,13 @@ from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
+from seamless_communication.models.unity import UnitYModel
 from seamless_communication.inference import (
     BatchedSpeechOutput,
     Modality,
     SequenceGeneratorOptions,
     Translator,
 )
-from seamless_communication.models.unity import load_unity_text_tokenizer
 
 logging.basicConfig(
     level=logging.INFO,
@@ -247,14 +247,14 @@ def adjust_output_for_corrupted_inputs(
 
 def run_eval(
     translator: Translator,
-    text_tokenizer: TextTokenizer,
     ctx: EvalContext,
     whisper_model_name: str,
+    n_samples = None
 ) -> None:
-    pipeline = build_data_pipeline(ctx, text_tokenizer)
+    pipeline = build_data_pipeline(ctx, translator.text_tokenizer)
 
     total_steps = count_lines(ctx.data_file) - 1
-    progress_bar = tqdm(total=total_steps)
+    progress_bar = tqdm(total=n_samples or total_steps)
 
     output_path = ctx.output_path / ctx.data_file.stem
     output_path.mkdir(parents=True, exist_ok=True)
@@ -294,15 +294,21 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (text_output, speech_output,) = translator.predict(
-                    src,
-                    ctx.task,
-                    ctx.target_lang,
-                    src_lang=ctx.source_lang,
-                    text_generation_opts=ctx.text_generation_opts,
-                    unit_generation_opts=ctx.unit_generation_opts,
-                    unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
-                )
+                # HACK:: Fix this bad handling
+                # RuntimeError: The sequence generator returned no hypothesis at index 2. Please file a bug report.
+                try:
+                    (text_output, speech_output,) = translator.predict(
+                        src,
+                        ctx.task,
+                        ctx.target_lang,
+                        src_lang=ctx.source_lang,
+                        text_generation_opts=ctx.text_generation_opts,
+                        unit_generation_opts=ctx.unit_generation_opts,
+                        unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
+                    )
+                except RuntimeError as e:
+                    logger.exception(f"Caught RuntimeError: {e}")
+                    continue
             else:
                 text_output = []
                 if ctx.output_modality == Modality.SPEECH:
@@ -338,6 +344,10 @@ def run_eval(
 
                 sample_id += 1
                 progress_bar.update(1)
+                if n_samples and progress_bar.n == n_samples:
+                    break
+            if n_samples and progress_bar.n == n_samples:
+                break
 
     progress_bar.close()
     logger.info(f"Processed {sample_id} samples")
@@ -352,6 +362,26 @@ def run_eval(
     )
 
 
+def load_checkpoint(model: UnitYModel, path: str, device = torch.device("cpu")) -> None:
+    saved_model = torch.load(path, map_location=device)["model"]
+    saved_model = { k.replace("model.", ""): v for k, v in saved_model.items() }
+
+    def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
+        return {key.replace(prefix, ""): value for key, value in state_dict.items() if key.startswith(prefix)}
+
+    model.speech_encoder_frontend.load_state_dict(_select_keys(saved_model, "model.speech_encoder_frontend."))
+    model.speech_encoder.load_state_dict(_select_keys(saved_model, "model.speech_encoder."))
+
+    assert model.text_decoder_frontend is not None
+    model.text_decoder_frontend.load_state_dict(_select_keys(saved_model, "model.text_decoder_frontend."))
+
+    assert model.text_decoder is not None
+    model.text_decoder.load_state_dict(_select_keys(saved_model, "model.text_decoder."))
+
+    assert model.final_proj is not None
+    model.final_proj.load_state_dict(_select_keys(saved_model, "model.final_proj."))
+
+
 def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
     parser = argparse.ArgumentParser(
         description="M4T evaluation for tasks supported by Translator."
@@ -362,8 +392,20 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         help="Data file to be evaluated, either TSV file or manifest JSON file."
         "Format of the manifest JSON file should be that as produced by `m4t_prepare_dataset`"
     )
+    parser.add_argument(
+        "--load_checkpoint", 
+        type=str,
+        help="Load a local Checkpoint",
+        default=None
+    )
 
     parser = add_inference_arguments(parser)
+    parser.add_argument(
+        "--device",
+        type=str,
+        help="Device",
+        default="cuda" if torch.cuda.is_available() else "cpu",
+    )
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -388,7 +430,13 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         help="Whisper model to be used for ASR-BLEU scoring",
         default="large",
     )
-    args, unknown = parser.parse_known_args()
+    parser.add_argument(
+        "--n_samples",
+        type=int,
+        help="Number of Samples to run eval on. All if None.",
+        default=None,
+    )
+    args, _ = parser.parse_known_args()
     default_args = vars(args)
     default_args.update(optional_args) if optional_args else default_args
     args = Namespace(**default_args)
@@ -412,15 +460,9 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         raise ValueError(
             f"Invalid audio_root_dir: {args.audio_root_dir} for speech input."
         )
-
-    if torch.cuda.is_available():
-        device = torch.device("cuda:0")
-        dtype = torch.float16
-    else:
-        device = torch.device("cpu")
-        dtype = torch.float32
-
-    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+    
+    device = torch.device(args.device)
+    dtype = torch.float16 if device.type == "cuda" else torch.float32
 
     # TODO: Avoid loading the T2U model, vocoder when the output
     # modality is text.
@@ -428,11 +470,13 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         args.model_name,
         args.vocoder_name,
         device,
-        text_tokenizer=text_tokenizer,
         dtype=dtype,
         input_modality=input_modality,
         output_modality=output_modality,
     )
+    
+    if args.load_checkpoint:
+        load_checkpoint(translator.model, path=args.load_checkpoint, device=device)
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
 
@@ -465,7 +509,7 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
     # fmt: on
     logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")
 
-    run_eval(translator, text_tokenizer, ctx, args.whisper_model_name)
+    run_eval(translator, ctx, args.whisper_model_name, n_samples=args.n_samples)
 
 
 if __name__ == "__main__":
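
For reference, a minimal sketch of exercising the new --load_checkpoint path programmatically instead of via the CLI. The model and vocoder card names, the checkpoint path, and the import location of load_checkpoint are assumptions for illustration, not guarantees from this commit.

import torch
from seamless_communication.inference import Translator
# Assumed import path for the helper added above in evaluate.py
from seamless_communication.cli.m4t.evaluate.evaluate import load_checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Placeholder cards; use the same base model the checkpoint was finetuned from
translator = Translator("seamlessM4T_medium", "vocoder_36langs", device, dtype=dtype)

# Overwrite the finetuned submodules' weights with the local checkpoint
load_checkpoint(translator.model, path="finetuned_checkpoint.pt", device=device)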

+ 48 - 18
src/seamless_communication/cli/m4t/finetune/dataloader.py

@@ -8,7 +8,7 @@
 import json
 import logging
 from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -100,6 +100,7 @@ class UnitYDataLoader:
         unit_tokenizer: UnitTokenizer,
         dataset_manifest_path: str,
         batching_config: BatchingConfig,
+        max_src_tokens_per_batch: int = 100000
     ):
         self.text_tokenizer = text_tokenizer
         self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
@@ -115,6 +116,7 @@ class UnitYDataLoader:
             "dtype": self.batching_config.float_dtype,
         }
         self.dataset = self._load_manifest(dataset_manifest_path)
+        self.max_src_tokens_per_batch = max_src_tokens_per_batch
 
     def get_dataloader(self) -> DataLoader[SeqsBatch]:
         subset = split_dataset_by_node(
@@ -156,9 +158,9 @@ class UnitYDataLoader:
         """Expected sequence is [<eos>, <lang_tok> , ..text tokens.., <eos>]"""
         target_lang = sample.target.lang
         if target_lang not in self.text_encoders_per_lang:
-            self.text_encoders_per_lang[
-                target_lang
-            ] = self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            self.text_encoders_per_lang[target_lang] = (
+                self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            )
         tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
         eos_idx = self.text_tokenizer.vocab_info.eos_idx
         tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
@@ -170,9 +172,9 @@ class UnitYDataLoader:
             return None
         target_lang = sample.target.lang
         if target_lang not in self.unit_encoders_per_lang:
-            self.unit_encoders_per_lang[
-                target_lang
-            ] = self.unit_tokenizer.create_encoder(lang=target_lang)
+            self.unit_encoders_per_lang[target_lang] = (
+                self.unit_tokenizer.create_encoder(lang=target_lang)
+            )
         tokens = self.unit_encoders_per_lang[target_lang](
             torch.LongTensor(sample.target.units).unsqueeze(0)
         )
@@ -191,30 +193,58 @@ class UnitYDataLoader:
         return torch.stack([tensor for tensor in padded_tensors], dim=0)
 
     def _is_long_src_audio(self, sample: LangPairSample) -> bool:
-        wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
-        length_s: float = max(wav.shape) / sample_rate
-        return length_s > self.batching_config.max_audio_length_sec
+        # HACK:: causes errored audios to be excluded but this is difficult to follow
+        try:
+            wav, sample_rate = torchaudio.load(sample.source.audio_local_path)
+            length_s: float = max(wav.shape) / sample_rate
+            return length_s > self.batching_config.max_audio_length_sec
+        except:
+            logger.exception(f"Failed to load sample path: {sample.source.audio_local_path}")
+            return True
+
+    def _drop_overflow_samples(
+        self, samples_with_fbanks: List[Tuple[LangPairSample, torch.Tensor]]
+    ) -> List[Tuple[LangPairSample, torch.Tensor]]:
+        # filter by src_tokens length (reverse)
+        samples_with_fbanks = sorted(
+            samples_with_fbanks, key=lambda sb: -sb[1].shape[0]
+        )
+        bwd = samples_with_fbanks[0][1].shape[0]
+        max_samples_for_batch = max(1, self.max_src_tokens_per_batch // bwd)
+        if max_samples_for_batch < len(samples_with_fbanks):
+            samples_with_fbanks = samples_with_fbanks[:max_samples_for_batch]
+        return samples_with_fbanks
 
     def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
         samples = [LangPairSample.from_json(sample) for sample in raw_samples]
         # input speech
+        
         #  - filter long audio samples
-        filtered_samples = [sample for sample in samples if not self._is_long_src_audio(sample)]
-        samples = filtered_samples if filtered_samples else [samples[0]]  # keep at least one sample
-        src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
+        filtered_samples = [
+            sample for sample in samples if not self._is_long_src_audio(sample)
+        ]
+        samples = (
+            filtered_samples if filtered_samples else [samples[0]]
+        )  # keep at least one sample
+        with_fbanks = [(sample, self._get_source_fbank(sample)) for sample in samples]
         #  - filter NaNs in fbanks
-        with_nans = [fbank.isnan().any().item() for fbank in src_tokens_list]
-        samples = [sample for sample, skip in zip(samples, with_nans) if not skip]
-        assert len(samples) > 0
-        src_tokens_list = [
-            src_toks for src_toks, skip in zip(src_tokens_list, with_nans) if not skip
+        filtered = [
+            (sample, fbank)
+            for sample, fbank in with_fbanks
+            if not fbank.isnan().any().item()
         ]
+        filtered = self._drop_overflow_samples(filtered)
+
+        samples = [sample for sample, _ in filtered]
+        src_tokens_list = [src_tokens for _, src_tokens in filtered]
+        assert len(samples) > 0
         src_tokens = self._batch_tensors(
             src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
         ).to(self.batching_config.float_dtype)
         src_lengths = torch.LongTensor(
             [src_tokens.shape[0] for src_tokens in src_tokens_list]
         )
+        
         # output text
         text_tokens_list = [
             self._get_tokenized_target_text(sample) for sample in samples
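
The new _drop_overflow_samples keeps each batch within a source-token budget: samples are ranked by fbank length, and the batch is truncated so that the longest kept sample's length times the number of kept samples stays under max_src_tokens_per_batch. A self-contained illustration of that rule, with toy shapes and a made-up budget:

import torch

max_src_tokens_per_batch = 12                         # toy budget for the example
fbanks = [torch.zeros(n, 80) for n in (5, 3, 7, 2)]   # (frames, features) per sample

# Rank longest-first; after padding, every kept sample costs as much as the
# longest one, so the budget divided by that length bounds the batch size.
ranked = sorted(fbanks, key=lambda fb: -fb.shape[0])
longest = ranked[0].shape[0]                          # 7 frames
keep = max(1, max_src_tokens_per_batch // longest)    # 12 // 7 -> 1 sample
batch = ranked[:keep]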

+ 38 - 17
src/seamless_communication/cli/m4t/finetune/finetune.py

@@ -10,12 +10,9 @@ import os
 from pathlib import Path
 
 import torch
-from fairseq2.models.nllb.tokenizer import NllbTokenizer
 
 from seamless_communication.cli.m4t.finetune import dataloader, dist_utils, trainer
 from seamless_communication.models.unity import (
-    UnitTokenizer,
-    UnitYModel,
     load_unity_model,
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
@@ -108,6 +105,12 @@ def init_parser() -> argparse.ArgumentParser:
         default=10,
         help=("Log inner loss after each `log_steps` training steps"),
     )
+    parser.add_argument(
+        "--max_src_tokens",
+        type=int,
+        default=7000,
+        help=("Maximum number of src_tokens per batch, used to avoid GPU OOM and maximize the effective batch size"),
+    )
     parser.add_argument(
         "--mode",
         type=trainer.FinetuneMode,
@@ -119,6 +122,14 @@ def init_parser() -> argparse.ArgumentParser:
             "* `SPEECH_TO_TEXT` -- finetune only S2T"
         ),
     )
+    parser.add_argument(
+        "--freeze_layers",
+        nargs="*",
+        required=False,
+        default=None,
+        # TODO: better description
+        help=("A list of modules to freeze in the model. If empty, everything will be trained."),
+    )
     parser.add_argument(
         "--device",
         type=str,
@@ -130,14 +141,19 @@ def init_parser() -> argparse.ArgumentParser:
 
 def main() -> None:
     args = init_parser().parse_args()
+    
     dist_utils.init_distributed([logger, trainer.logger])
-    text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
-    unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
+    float_dtype = torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16
+    
+    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
+    
     finetune_params = trainer.FinetuneParams(
+        model_name=args.model_name,
         finetune_mode=args.mode,
         save_model_path=args.save_model_to,
         device=torch.device(args.device),
-        float_dtype=torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16,
+        float_dtype=float_dtype,
         train_batch_size=args.batch_size,
         eval_batch_size=args.batch_size,
         patience=args.patience,
@@ -147,22 +163,25 @@ def main() -> None:
         eval_steps=args.eval_steps,
         log_steps=args.log_steps,
     )
-    logger.info(f"Finetune params: {finetune_params}")
-    model: UnitYModel = load_unity_model(
-        args.model_name, device=torch.device("cpu"), dtype=torch.float32
-    )
+    
+    logger.info(f"Finetune Params: {finetune_params}")
+    
+    model = load_unity_model(args.model_name, device=torch.device("cpu"), dtype=torch.float32)
     assert model.target_vocab_info == text_tokenizer.vocab_info
-    # (optional) delete unused params to reduce GPU memory consumption
+    
     if (
         finetune_params.finetune_mode == trainer.FinetuneMode.SPEECH_TO_TEXT
         and model.t2u_model is not None
     ):
         model.t2u_model = None
+    
     if model.text_encoder is not None:
         model.text_encoder = None
+    
+    # Put model on selected device
     model = model.to(finetune_params.device)
-    logger.info(f"<{args.model_name}> {model}")
 
+    # TODO: delete unused params to reduce GPU memory consumption
     train_dataloader = dataloader.UnitYDataLoader(
         text_tokenizer=text_tokenizer,
         unit_tokenizer=unit_tokenizer,
@@ -174,7 +193,8 @@ def main() -> None:
             float_dtype=finetune_params.float_dtype,
         ),
         dataset_manifest_path=args.train_dataset,
-    )
+        max_src_tokens_per_batch=args.max_src_tokens)
+    
     eval_dataloader = dataloader.UnitYDataLoader(
         text_tokenizer=text_tokenizer,
         unit_tokenizer=unit_tokenizer,
@@ -182,17 +202,18 @@ def main() -> None:
             batch_size=finetune_params.eval_batch_size,
             rank=dist_utils.get_rank(),
             world_size=dist_utils.get_world_size(),
-            max_audio_length_sec=100.0,
+            max_audio_length_sec=75.0,
             float_dtype=finetune_params.float_dtype,
         ),
-        dataset_manifest_path=args.eval_dataset,
-    )
+        dataset_manifest_path=args.eval_dataset)
+    
     finetune = trainer.UnitYFinetune(
         model=model,
         params=finetune_params,
         train_data_loader=train_dataloader,
         eval_data_loader=eval_dataloader,
-    )
+        freeze_modules=args.freeze_layers)
+    
     finetune.run()
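
--freeze_layers takes a list of module-name prefixes that the trainer matches against named_modules(). A sketch for discovering candidate prefixes from a locally loaded model; the card name is a placeholder, and the exact prefixes seen at train time depend on how the model is wrapped (e.g. a leading "model." or "module.model."), so treat the printed names as a starting point.

import torch
from seamless_communication.models.unity import load_unity_model

model = load_unity_model("seamlessM4T_medium", device=torch.device("cpu"), dtype=torch.float32)
# Top-level children such as speech_encoder_frontend, speech_encoder,
# text_decoder_frontend, text_decoder, final_proj are the natural freeze targets.
for name, _ in model.named_modules():
    if name and "." not in name:
        print(name)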
 
 

+ 76 - 40
src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -6,12 +6,13 @@
 
 
 import logging
+import time
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
 from tqdm import tqdm
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -40,6 +41,9 @@ class FinetuneMode(Enum):
 
 @dataclass
 class FinetuneParams:
+    model_name: str
+    """Model name of model being finetuned."""
+    
     save_model_path: Path
     """Path were to save finetuned model."""
 
@@ -245,6 +249,7 @@ class UnitYFinetune:
         params: FinetuneParams,
         train_data_loader: dataloader.UnitYDataLoader,
         eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
+        freeze_modules: Optional[List[Union[str, torch.nn.Module]]] = None
     ):
         self.params = params
         self.calc_loss = CalcLoss(
@@ -254,9 +259,15 @@ class UnitYFinetune:
             if model.t2u_model is not None
             else None,
         )
+        
         self.model = self._wrap_model_for_trainining(model=model)
+        if freeze_modules:
+            self._freeze_modules(freeze_modules)
+        
         self.train_data_loader = train_data_loader
         self.eval_data_loader = eval_data_loader
+        
+        self.grad_scaler = torch.cuda.amp.GradScaler()  # type: ignore
         self.optimizer = AdamW(
             params=self.model.parameters(),
             lr=self.params.learning_rate,
@@ -266,7 +277,6 @@ class UnitYFinetune:
             weight_decay=0.0,
             fused=(self.params.device.type == "cuda"),
         )
-        self.grad_scaler = torch.cuda.amp.GradScaler()  # type: ignore
         self.lr_scheduler = MyleLR(
             optimizer=self.optimizer,
             num_warmup_steps=self.params.warmup_steps,
@@ -301,6 +311,14 @@ class UnitYFinetune:
             device_ids=[dist_utils.get_local_rank()],
             find_unused_parameters=find_unused,
         )
+        
+    def _freeze_modules(self, frozen_modules: List[str] = []) -> None:
+        for icecube in frozen_modules:
+            for (name, module) in self.model.named_modules():
+                if name.startswith(icecube):
+                    logger.info(f"Freezing Module: {name}")
+                    for param in module.parameters():
+                        param.requires_grad = False
 
     def _update_eval_stats(self, eval_loss: float) -> None:
         self.is_best_state = (
@@ -317,25 +335,26 @@ class UnitYFinetune:
             f"patience_steps_left={self.patience_left}"
         )
 
-    def _eval_model(self) -> None:
+    @torch.no_grad()
+    def _eval_model(self, n_batches: int) -> None:
         """Calc avg loss on eval dataset and update evaluation stats"""
         if self.eval_data_loader is None:
             return
-        logger.info("Run evaluation")
+        logger.info(f"Evaluation Step {self.update_idx // self.params.eval_steps}...")
         loss_hist = LossCollector(device=self.params.device)
         self.model.eval()
-        with torch.no_grad():
-            for batch in tqdm(self.eval_data_loader.get_dataloader()):
-                assert batch.speech_to_text.src_tokens is not None
-                with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
-                    loss = self.calc_loss(batch, *self.model(batch))
-                if loss.isnan():
-                    logger.warning("Eval loss value is NaN, setting to inf")
-                    loss_val = float("Inf")
-                else:
-                    loss_val = loss.item()
-                del batch  # force memory release
-                loss_hist.update(1, loss_val)
+        for batch in self.eval_data_loader.get_dataloader():
+            if n_batches == 0:
+                break
+            assert batch.speech_to_text.src_tokens is not None
+            with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
+                loss = self.calc_loss(batch, *self.model(batch))
+            if loss.isnan():
+                logger.warning("Eval batch loss value is NaN, skipping")
+                continue
+            del batch  # force memory release
+            loss_hist.update(1, loss.item())
+            n_batches -= 1
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
 
@@ -351,53 +370,70 @@ class UnitYFinetune:
                 f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
             )
 
-    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
+    def _train_step(self, batch: List[dataloader.MultimodalSeqsBatch]) -> None:
         """Run one train step"""
         self.model.train()
         self.optimizer.zero_grad()
         with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
             tokens, units = self.model(batch)
+        
         loss = self.calc_loss(batch, tokens, units)
         if loss.isnan().any().item():
             logger.error(batch.speech_to_text)
-            raise RuntimeError("Loss is Nan. Terminating.")
+            raise RuntimeError("Train loss is NaN! Something is wrong in the model!")
+        
         self.grad_scaler.scale(loss).backward()
         self.grad_scaler.step(self.optimizer)
         self.grad_scaler.update()
         self.lr_scheduler.step()
+        
         assert batch.speech_to_text.src_tokens is not None
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
+        self.update_idx += 1
 
     def _save_model(self) -> None:
         logger.info("Saving model")
         if dist_utils.is_main_process():
-            state_dict = {
-                key.replace("module.model.", ""): value
-                for key, value in self.model.state_dict().items()
-            }
-            torch.save(state_dict, self.params.save_model_path)
+            torch.save({
+                "model_name": self.params.model_name,
+                "model": {
+                    key.replace("module.model.model.", ""): value
+                    for key, value in self.model.state_dict().items()
+                }
+            }, self.params.save_model_path)
         if dist_utils.is_dist_initialized():
             dist.barrier()
 
     def run(self) -> None:
-        logger.info("Start finetuning")
+        logger.info("Start Finetuning")
         self._reset_stats()
         self._eval_model()
-        batch_itr = self.train_data_loader.get_dataloader()
+        
+        train_dataloader = self.train_data_loader.get_dataloader()
+        
         while self.epoch_idx < self.params.max_epochs and self.patience_left:
-            for train_batch in batch_itr:
-                self._train_step(batch=train_batch)
-                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
-                    self._eval_model()
-                    if self.is_best_state:
-                        self._save_model()
-                    elif not self.patience_left:
-                        no_improve_steps = self.params.eval_steps * self.params.patience
-                        logger.info(
-                            "Early termination, as eval loss did not improve "
-                            f"over last {no_improve_steps} updates"
-                        )
-                        break
-                self.update_idx += 1
-            self.epoch_idx += 1
+            for train_batch in tqdm(train_dataloader, desc="Training Steps"):
+                # Run batch through train step
+                self._train_step(train_batch)
+                
+                # Perform eval if its time to eval
+                if not self.update_idx or self.update_idx % self.params.eval_steps != 0:
+                    continue
+                
+                # Clear GPU memory for eval
+                torch.cuda.empty_cache()
+                self._eval_model(n_batches=100)
+                    
+                # Save the current model if its the best we've ever had
+                if self.is_best_state:
+                    self._save_model()
+                elif not self.patience_left:
+                    no_improve_steps = self.params.eval_steps * self.params.patience
+                    logger.info(
+                        "Early termination, as eval loss did not improve "
+                        f"over last {no_improve_steps} updates"
+                    )
+                    break
+                
+            self.epoch_idx += 1
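
With this change, _save_model writes a small dictionary rather than a bare state dict: a "model_name" field recording the base model, and a "model" field holding the parameters with the "module.model.model." wrapper prefix stripped. A sketch of inspecting such a file (the path is a placeholder):

import torch

ckpt = torch.load("finetuned_checkpoint.pt", map_location="cpu")
print(ckpt["model_name"])         # base model card the finetuning started from
state_dict = ckpt["model"]        # tensors keyed without the DDP/wrapper prefix
print(len(state_dict), "parameter tensors")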

+ 108 - 2
src/seamless_communication/datasets/huggingface.py

@@ -28,7 +28,7 @@ class SpeechTokenizer:
 class Speech2SpeechFleursDatasetBuilder:
     """Assembles speech2speech dataset from google/fleurs on HuggingFace"""
 
-    HF_FLEURS_DATASET_NAME = "google/fleurs"
+    DATASET_NAME = "google/fleurs"
 
     def __init__(
         self,
@@ -91,7 +91,113 @@ class Speech2SpeechFleursDatasetBuilder:
 
     def iterate_lang_audio_samples(self, lang: str) -> Iterable[MultimodalSample]:
         ds = load_dataset(
-            self.HF_FLEURS_DATASET_NAME,
+            self.DATASET_NAME,
+            lang,
+            split=self.split,
+            cache_dir=self.dataset_cache_dir,
+            streaming=False,
+            trust_remote_code=True,
+        )
+        for item in ds:
+            audio_path = os.path.join(
+                os.path.dirname(item["path"]), item["audio"]["path"]
+            )
+            (sample_id, audio_local_path, waveform, sampling_rate, text) = (
+                item["id"],
+                audio_path,
+                item["audio"]["array"],
+                item["audio"]["sampling_rate"],
+                item["transcription"],
+            )
+            yield self._prepare_sample(
+                sample_id=sample_id,
+                audio_local_path=audio_local_path,
+                waveform_npy=waveform,
+                sampling_rate=sampling_rate,
+                text=text,
+                lang=lang,
+            )
+
+    def __iter__(self) -> Iterable[LangPairSample]:
+        logger.info(f"Loading {self.target_lang} samples")
+        target_samples: Dict[int, MultimodalSample] = {}
+        for idx, sample in enumerate(
+            self.iterate_lang_audio_samples(lang=self.target_lang)
+        ):
+            if idx and idx % 100 == 0:
+                logger.info(f"..loaded {idx} target samples")
+            target_samples[sample.id] = sample
+
+        logger.info(f"Loading {self.source_lang} samples")
+        for idx, sample in enumerate(
+            self.iterate_lang_audio_samples(lang=self.source_lang)
+        ):
+            if idx and idx % 100 == 0:
+                logger.info(f"..loaded {idx} source samples")
+            if sample.id in target_samples:
+                yield LangPairSample(source=sample, target=target_samples[sample.id])
+
+
+class Speech2TextGigaspeechDatasetBuilder:
+    """ Assembles speech2speech dataset from google/fleurs on HuggingFace.
+        This dataset requires signing an license agreement and using an auth token.
+    """
+
+    DATASET_NAME = "speechcolab/gigaspeech"
+
+    def __init__(
+        self,
+        auth_token: str,
+        split: str = "test",
+        skip_source_audio: bool = True,
+        skip_target_audio: bool = True,
+        audio_dtype: torch.dtype = torch.float32,
+        dataset_cache_dir: Optional[str] = None,
+        speech_tokenizer: Optional[SpeechTokenizer] = None,
+    ):
+        self.auth_token = auth_token
+        self.split = split
+        self.dataset_cache_dir = dataset_cache_dir
+        self.audio_dtype = audio_dtype
+        self.skip_source_audio = skip_source_audio
+        self.skip_target_audio = skip_target_audio
+        self.speech_tokenizer = speech_tokenizer
+
+    def _prepare_sample(
+        self,
+        sample_id: int,
+        lang: str,
+        text: str,
+        audio_local_path: Optional[str] = None,
+        waveform_npy: Optional[np.ndarray] = None,
+        sampling_rate: Optional[int] = None,
+    ) -> MultimodalSample:
+        if waveform_npy is not None:
+            waveform = torch.from_numpy(waveform_npy).to(self.audio_dtype)
+        else:
+            waveform = None
+        if self.speech_tokenizer is not None and waveform_npy is not None:
+            assert waveform is not None
+            assert sampling_rate is not None
+            units_tensor = self.speech_tokenizer.encode(
+                waveform, sampling_rate
+            ).reshape(-1)
+            units = units_tensor.tolist()
+        else:
+            units = None
+        return MultimodalSample(
+            id=sample_id,
+            lang=lang,
+            text=text.strip(),
+            audio_local_path=audio_local_path,
+            waveform=waveform,
+            sampling_rate=sampling_rate,
+            units=units,
+        )
+
+    def iterate_lang_audio_samples(self, lang: str) -> Iterable[MultimodalSample]:
+        ds = load_dataset(
+            self.DATASET_NAME,
             lang,
             split=self.split,
             cache_dir=self.dataset_cache_dir,