
test m4t finetuning on gigaspeech

Ruslan Mavlyutov 1 year ago
Parent
Commit
09a7d5be2e

+ 23 - 9
src/seamless_communication/cli/m4t/finetune/dataset.py

@@ -16,6 +16,7 @@ import torch
 
 from seamless_communication.datasets.huggingface import (
     Speech2SpeechFleursDatasetBuilder,
+    Speech2SpeechGigaSpeechDatasetBuilder,
     SpeechTokenizer,
 )
 from seamless_communication.models.unit_extractor import UnitExtractor
@@ -123,6 +124,7 @@ def download_fleurs_dataset(
     target_lang: str,
     split: str,
     save_directory: str,
+    max_samples: int = 100_000,
 ) -> str:
     _check_lang_code_mapping(source_lang)
     _check_lang_code_mapping(target_lang)
@@ -130,18 +132,23 @@ def download_fleurs_dataset(
         torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
     )
     tokenizer = UnitSpeechTokenizer(device=device)
-    dataset_iterator = Speech2SpeechFleursDatasetBuilder(
-        source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
-        target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
-        dataset_cache_dir=save_directory,
-        speech_tokenizer=tokenizer,
-        skip_source_audio=True,  # don't extract units from source audio
-        skip_target_audio=False,
-        split=split,
-    )
+    if 1:
+        dataset_iterator = Speech2SpeechGigaSpeechDatasetBuilder(split=split, dataset_cache_dir=save_directory)
+    else:
+        dataset_iterator = Speech2SpeechFleursDatasetBuilder(
+            source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
+            target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
+            dataset_cache_dir=save_directory,
+            speech_tokenizer=tokenizer,
+            skip_source_audio=True,  # don't extract units from source audio
+            skip_target_audio=False,
+            split=split,
+        )
     manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
     with open(manifest_path, "w") as fp_out:
         for idx, sample in enumerate(dataset_iterator.__iter__(), start=1):
+            if idx > max_samples:
+                break
             # correction, as FleursDatasetBuilder returns FLEURS lang codes
             sample.source.lang = source_lang
             sample.target.lang = target_lang
@@ -183,6 +190,12 @@ def init_parser() -> argparse.ArgumentParser:
         required=True,
         help="Directory where the datastets will be stored with HuggingFace datasets cache files",
     )
+    parser.add_argument(
+        "--max_samples",
+        type=int,
+        default=100_000,
+        help="Maximum number of samples to write to the manifest",
+    )
     return parser
 
 
@@ -193,6 +206,7 @@ def main() -> None:
         target_lang=args.target_lang,
         split=args.split,
         save_directory=args.save_dir,
+        max_samples=args.max_samples,
     )
     logger.info(f"Manifest saved to: {manifest_path}")
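
The dataset-builder switch above is hard-wired behind "if 1:", which is fine for a quick test but hides the GigaSpeech/FLEURS choice from the command line. A minimal sketch of a more explicit selection, assuming a hypothetical --dataset argument that is not part of this commit:

# Hypothetical sketch, not part of this commit: pick the builder from an
# explicit flag instead of the hard-coded "if 1:" toggle. Names mirror
# dataset.py; "dataset" would come from a --dataset argument with
# choices ("gigaspeech", "fleurs").
def build_dataset_iterator(
    dataset: str,
    source_lang: str,
    target_lang: str,
    split: str,
    save_directory: str,
    tokenizer: SpeechTokenizer,
):
    if dataset == "gigaspeech":
        return Speech2SpeechGigaSpeechDatasetBuilder(
            split=split,
            dataset_cache_dir=save_directory,
        )
    return Speech2SpeechFleursDatasetBuilder(
        source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
        target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
        dataset_cache_dir=save_directory,
        speech_tokenizer=tokenizer,
        skip_source_audio=True,  # don't extract units from source audio
        skip_target_audio=False,
        split=split,
    )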
 

+ 3 - 1
src/seamless_communication/cli/m4t/finetune/finetune.py

@@ -133,11 +133,13 @@ def main() -> None:
     dist_utils.init_distributed([logger, trainer.logger])
     text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
     unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
+    float_dtype = torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16
+    logger.info(f"Training precision: {float_dtype}")
     finetune_params = trainer.FinetuneParams(
         finetune_mode=args.mode,
         save_model_path=args.save_model_to,
         device=torch.device(args.device),
-        float_dtype=torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16,
+        float_dtype=float_dtype,
         train_batch_size=args.batch_size,
         eval_batch_size=args.batch_size,
         patience=args.patience,
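
Factoring float_dtype out of the FinetuneParams call lets the chosen precision be logged before training starts: float16 on accelerators, bfloat16 on CPU, where bfloat16 is generally the better-supported autocast dtype. A small self-contained illustration (not project code) of how such a device-dependent dtype is consumed by torch.autocast:

# Illustration only: a device-dependent dtype feeding torch.autocast,
# mirroring the precision selection above.
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
float_dtype = torch.float16 if device.type != "cpu" else torch.bfloat16

model = torch.nn.Linear(16, 4).to(device)
x = torch.randn(8, 16, device=device)

with torch.autocast(device_type=device.type, dtype=float_dtype):
    y = model(x)  # matmul runs in fp16 on GPU, bf16 on CPU
print(y.dtype)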

+ 86 - 0
src/seamless_communication/cli/m4t/finetune/mini_eval.py

@@ -0,0 +1,86 @@
+import torch
+from datasets import load_dataset
+from jiwer import wer
+import os
+from typing import Tuple, Iterable, Dict, Any
+import logging
+
+logging.basicConfig(level=logging.INFO)
+
+from seamless_communication.models.unity import UnitYModel
+from seamless_communication.inference import Translator
+
+log = logging.getLogger("l")
+
+TOKEN = "<YOUR HF TOKEN HERE>"
+MAX_SAMPLES = 100
+CHCK_PATH = os.path.expanduser("~/tune_chck/chck.pt")
+
+
+def _iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
+    ds = load_dataset(
+        "speechcolab/gigaspeech",
+        "xs",
+        token=os.environ.get("HF_TOKEN", TOKEN),
+        split="test",
+        streaming=True,
+        trust_remote_code=True,
+    )
+    for idx, item in enumerate(ds):
+        if idx >= MAX_SAMPLES:
+            break
+        assert item["audio"]["sampling_rate"] == 16000
+        yield (torch.from_numpy(item["audio"]["array"]), item["text"])
+
+
+def _eval(translator: Translator) -> float:
+    references = []
+    predictions = []
+    for idx, (wav, text) in enumerate(_iterate_test_ds()):
+        references.append(text)
+        prediction = str(
+            translator.predict(
+                input=wav,
+                task_str="s2tt",
+                tgt_lang="eng",
+                src_lang="eng",
+            )[0][0]
+        )
+        log.info(idx)
+        log.info(f"REF: {text}")
+        log.info(f"PRE: {prediction}")
+        log.info("----")
+        predictions.append(prediction)
+    return wer(reference=references, hypothesis=predictions)
+
+
+def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
+    return {key.replace(prefix, ""): value for key, value in state_dict.items() if key.startswith(prefix)}
+
+
+def load_checkpoint(model: UnitYModel, chck_path: str) -> None:
+    state_dict = torch.load(chck_path, map_location="cpu")
+    model.speech_encoder_frontend.load_state_dict(_select_keys(state_dict, "model.speech_encoder_frontend."))
+    model.speech_encoder.load_state_dict(_select_keys(state_dict, "model.speech_encoder."))
+    assert model.text_decoder_frontend is not None
+    model.text_decoder_frontend.load_state_dict(_select_keys(state_dict, "model.text_decoder_frontend."))
+    assert model.text_decoder is not None
+    model.text_decoder.load_state_dict(_select_keys(state_dict, "model.text_decoder."))
+
+
+def main() -> None:
+    translator = Translator(
+        model_name_or_card="seamlessM4T_medium",
+        vocoder_name_or_card=None,
+        device=torch.device("cuda"),
+    )
+    non_tuned_wer = _eval(translator)
+
+    load_checkpoint(translator.model, CHCK_PATH)
+    tuned_wer = _eval(translator)
+
+    log.info(f"WER non-tuned: {non_tuned_wer:.3f}")
+    log.info(f"WER tuned: {tuned_wer:.3f}")
+
+if __name__ == "__main__":
+    main()
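
mini_eval.py measures GigaSpeech test WER with the stock seamlessM4T_medium model and again after loading the finetuned checkpoint; running it requires a HuggingFace token (HF_TOKEN or the TOKEN constant) and a checkpoint at CHCK_PATH. Only the speech-to-text submodules are reloaded: _select_keys keeps the entries under one "model.<submodule>." prefix and strips that prefix so each submodule's load_state_dict accepts them. A toy illustration of that filtering, with made-up keys:

# Toy illustration of the prefix filtering done by _select_keys
# (keys below are made up, not real checkpoint keys).
state_dict = {
    "model.speech_encoder.layers.0.weight": "a",
    "model.text_decoder.layers.0.weight": "b",
}
prefix = "model.speech_encoder."
encoder_sd = {
    key[len(prefix):]: value  # equivalent to key.replace(prefix, "") here
    for key, value in state_dict.items()
    if key.startswith(prefix)
}
assert encoder_sd == {"layers.0.weight": "a"}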

+ 40 - 24
src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -11,7 +11,7 @@ from dataclasses import dataclass
 from enum import Enum
 from tqdm import tqdm
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -21,7 +21,7 @@ from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
-from torch.optim import AdamW
+from torch.optim import AdamW, Adam
 
 from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
 from seamless_communication.models.unity import (
@@ -88,11 +88,17 @@ class UnitYFinetuneWrapper(nn.Module):
     def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
         super().__init__()
         self.model: UnitYModel = model
+        #self._freeze_module(self.model.speech_encoder_frontend)
+        #self._freeze_module(self.model.speech_encoder)
         self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
         self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
         logger.info(f"Freeze s2t: {self.freeze_s2t}, freeze t2u: {self.freeze_t2u}")
         self.device = device
 
+    def _freeze_module(self, module: torch.nn.Module) -> None:
+        for param in module.parameters():
+            param.requires_grad = False
+
     def forward(
         self, batch: dataloader.MultimodalSeqsBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -329,12 +335,11 @@ class UnitYFinetune:
                 assert batch.speech_to_text.src_tokens is not None
                 with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
                     loss = self.calc_loss(batch, *self.model(batch))
-                if loss.isnan():
-                    logger.warning("Eval loss value is NaN, setting to inf")
-                    loss_val = float("Inf")
-                else:
-                    loss_val = loss.item()
                 del batch  # force memory release
+                if loss.isnan():
+                    logger.warning(".. batch loss value is NaN, skipping")
+                    continue
+                loss_val = loss.item()
                 loss_hist.update(1, loss_val)
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
@@ -351,13 +356,18 @@ class UnitYFinetune:
                 f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
             )
 
-    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
+    def _train_step(self, batches: List[dataloader.MultimodalSeqsBatch]) -> None:
         """Run one train step"""
         self.model.train()
         self.optimizer.zero_grad()
-        with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
-            tokens, units = self.model(batch)
-        loss = self.calc_loss(batch, tokens, units)
+        # logger.info(f"forward start {torch.cuda.memory_allocated(0) >> 30}g")
+        losses = []
+        for batch in batches:
+            with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
+                tokens, units = self.model(batch)
+            # logger.info(f"forward done {torch.cuda.memory_allocated(0) >> 30}g")
+            losses.append(self.calc_loss(batch, tokens, units))
+        loss = sum(losses) / len(losses)
         if loss.isnan().any().item():
             logger.error(batch.speech_to_text)
             raise RuntimeError("Loss is Nan. Terminating.")
@@ -365,6 +375,7 @@ class UnitYFinetune:
         self.grad_scaler.step(self.optimizer)
         self.grad_scaler.update()
         self.lr_scheduler.step()
+        # logger.info(f"backward done {torch.cuda.memory_allocated(0) >> 30}g")
         assert batch.speech_to_text.src_tokens is not None
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
@@ -385,19 +396,24 @@ class UnitYFinetune:
         self._reset_stats()
         self._eval_model()
         batch_itr = self.train_data_loader.get_dataloader()
+        batches_per_iter = 1
         while self.epoch_idx < self.params.max_epochs and self.patience_left:
+            train_batches = []
             for train_batch in batch_itr:
-                self._train_step(batch=train_batch)
-                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
-                    self._eval_model()
-                    if self.is_best_state:
-                        self._save_model()
-                    elif not self.patience_left:
-                        no_improve_steps = self.params.eval_steps * self.params.patience
-                        logger.info(
-                            "Early termination, as eval loss did not improve "
-                            f"over last {no_improve_steps} updates"
-                        )
-                        break
-                self.update_idx += 1
+                train_batches.append(train_batch)
+                if len(train_batches) > batches_per_iter:
+                    self._train_step(batches=train_batches)
+                    train_batches = []
+                    if self.update_idx and self.update_idx % self.params.eval_steps == 0:
+                        self._eval_model()
+                        if self.is_best_state:
+                            self._save_model()
+                        elif not self.patience_left:
+                            no_improve_steps = self.params.eval_steps * self.params.patience
+                            logger.info(
+                                "Early termination, as eval loss did not improve "
+                                f"over last {no_improve_steps} updates"
+                            )
+                            break
+                    self.update_idx += 1
             self.epoch_idx += 1
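
The reworked train loop groups batches into a list, runs a forward pass for each, and averages their losses before one scaled backward and optimizer step; with batches_per_iter set to 1, the strict ">" comparison means two batches are grouped per update. Because the averaged loss is backpropagated in one go, the autograd graphs of all grouped batches stay alive until that point. For comparison, a generic PyTorch gradient-accumulation sketch (not this commit's implementation) that frees each micro-batch's activations right after its own backward pass:

# Generic gradient-accumulation sketch, not the commit's implementation:
# each micro-batch gets its own backward() so its activations are released
# immediately, and the optimizer steps once every accum_steps batches.
import torch

def accumulate_and_step(model, optimizer, scaler, batches, device, dtype, accum_steps=2):
    model.train()
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches, start=1):
        with torch.autocast(device_type=device.type, dtype=dtype):
            loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps
        scaler.scale(loss).backward()  # gradients accumulate across micro-batches
        if i % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()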