# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from contextlib import nullcontext
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple

import dataloader
import dist_utils
import torch
import torch.distributed as dist
import torch.nn as nn
from fairseq2.models.sequence import SequenceModelOutput
from fairseq2.models.unity import UnitYModel
from fairseq2.optim.lr_scheduler import MyleLR
from fairseq2.typing import Device
from torch.optim import Adam

logger = logging.getLogger(__name__)


class FinetuneMode(Enum):
    SPEECH_TO_SPEECH = "SPEECH_TO_SPEECH"
    SPEECH_TO_TEXT = "SPEECH_TO_TEXT"
    TEXT_TO_SPEECH = "TEXT_TO_SPEECH"
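
# How each mode maps onto the freezing logic in UnitYFinetuneWrapper below:
#   SPEECH_TO_SPEECH -- both the S2T and T2U parts of the model are trained;
#   SPEECH_TO_TEXT   -- the T2U part is frozen, only S2T is trained;
#   TEXT_TO_SPEECH   -- the S2T part is frozen, only T2U is trained.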


@dataclass
class FinetuneParams:
    save_model_path: Path
    """Path where to save the finetuned model."""

    finetune_mode: FinetuneMode = FinetuneMode.TEXT_TO_SPEECH
    """Allows freezing the S2T or T2U part of the model."""

    max_epochs: int = 10
    """Maximum number of training epochs."""

    label_smoothing: float = 0.2
    """Label smoothing coefficient for nll_loss."""

    warmup_steps: int = 100
    """Number of steps with linearly increasing LR."""

    log_steps: int = 10
    """Log the running training loss after every `log_steps` training steps."""

    eval_steps: int = 50
    """Compute eval loss after every `eval_steps` training steps."""

    patience: int = 3
    """Terminate if eval loss did not improve
    over the last `patience * eval_steps` training steps."""

    learning_rate: float = 1e-5
    """Optimizer learning rate."""

    train_batch_size: int = 5
    """Batch size during train steps."""

    eval_batch_size: int = 5
    """Batch size during evaluation."""

    device: Device = torch.device("cuda")
    """Device to run the computation on."""


class UnitYFinetuneWrapper(nn.Module):
    """Convenience wrapper that does a forward pass
    and returns S2T and T2U logits."""

    def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
        super().__init__()
        assert model.t2u_model is not None
        self.model: UnitYModel = model
        self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
        self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
        self.device = device

    def forward(
        self, batch: dataloader.MultimodalSeqsBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert self.model.t2u_model is not None
        # Disable gradients for the S2T branch when it is frozen.
        with torch.no_grad() if self.freeze_s2t else nullcontext():
            assert batch.speech_to_text.src_tokens is not None
            speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
                seqs=batch.speech_to_text.src_tokens.to(self.device),
                seq_lens=batch.speech_to_text.src_lengths.to(self.device),
            )
            assert batch.speech_to_text.prev_output_tokens is not None
            text_decoder_out, text_decoder_padding_mask = self.model.decode(
                seqs=batch.speech_to_text.prev_output_tokens.to(self.device),
                seq_lens=batch.speech_to_text.target_lengths.to(self.device),
                encoder_output=speech_encoder_out,
                encoder_padding_mask=speech_encoder_padding_mask,
            )
            text_logits = self.model.final_proj(text_decoder_out)
        if batch.text_to_units.prev_output_tokens is None:
            return (text_logits, None)
        # Disable gradients for the T2U branch when it is frozen.
        with torch.no_grad() if self.freeze_t2u else nullcontext():
            (
                unit_encoder_out,
                unit_encoder_padding_mask,
            ) = self.model.t2u_model.encode(
                text_decoder_output=text_decoder_out,
                text_decoder_padding_mask=text_decoder_padding_mask,
            )
            unit_decoder_out, _ = self.model.t2u_model.decode(
                seqs=batch.text_to_units.prev_output_tokens.to(self.device),
                seq_lens=batch.text_to_units.target_lengths.to(self.device),
                encoder_output=unit_encoder_out,
                encoder_padding_mask=unit_encoder_padding_mask,
            )
            unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)
        return (text_logits, unit_logits)
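
# Note: `text_logits` are the S2T decoder's scores over the text vocabulary,
# and `unit_logits` are the T2U decoder's scores over the unit vocabulary;
# `unit_logits` is None when the batch carries no unit targets.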


class CalcLoss:
    """Calculates negative log-likelihood loss for S2T and T2U."""

    def __init__(
        self,
        label_smoothing: float,
        s2t_pad_idx: Optional[int],
        t2u_pad_idx: Optional[int],
    ):
        self.label_smoothing = label_smoothing
        self.s2t_pad_idx = s2t_pad_idx
        self.t2u_pad_idx = t2u_pad_idx

    def __call__(
        self,
        batch: dataloader.MultimodalSeqsBatch,
        text_logits: torch.Tensor,
        unit_logits: Optional[torch.Tensor],
    ) -> torch.Tensor:
        assert batch.speech_to_text.target_lengths is not None
        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
            text_logits.device
        )
        s2t_loss = SequenceModelOutput(
            logits=text_logits, pad_idx=self.s2t_pad_idx
        ).compute_loss(
            targets=batch.speech_to_text.target_tokens.to(text_logits.device),
            ignore_prefix_size=1,
            label_smoothing=self.label_smoothing,
        )
        if unit_logits is None:
            return s2t_loss / s2t_numel
        assert batch.text_to_units.target_lengths is not None
        s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(
            unit_logits.device
        )
        s2u_loss = SequenceModelOutput(
            logits=unit_logits, pad_idx=self.t2u_pad_idx
        ).compute_loss(
            targets=batch.text_to_units.target_tokens.to(unit_logits.device),
            ignore_prefix_size=1,
            label_smoothing=self.label_smoothing,
        )
        return s2t_loss / s2t_numel + s2u_loss / s2u_numel
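
# The combined objective is the sum of per-token-normalized NLL losses:
#   loss = s2t_loss / n_text_tokens + s2u_loss / n_unit_tokens
# where each numel is the total number of target tokens in the batch, keeping
# the two terms on a comparable scale regardless of sequence lengths.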


class LossCollector:
    """Aggregates loss history across nodes."""

    def __init__(self, device: Optional[Device] = None, reduce_op: str = "avg"):
        self.n_samples: float = 0
        self.val_sum: float = 0.0
        self.reduce_op = reduce_op
        self.device = device
        self.is_distributed = dist_utils.is_dist_initialized()

    def reset(self) -> None:
        self.n_samples = 0
        self.val_sum = 0.0

    def update(self, n_samples: int, batch_loss: float) -> None:
        self.n_samples += n_samples
        self.val_sum += batch_loss

    def reduce(self) -> float:
        n_samples, val_sum = self._collect()
        if self.reduce_op == "avg":
            # The +1 in the denominator avoids division by zero when reduce()
            # is called before any update(), at the cost of a slight bias.
            return val_sum / (n_samples + 1)
        if self.reduce_op == "sum":
            return val_sum
        raise ValueError(f"Unknown reduce_op: {self.reduce_op}")

    def _collect(self) -> Tuple[float, float]:
        if not self.is_distributed:
            return self.n_samples, self.val_sum
        local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
        all_vals = [
            torch.zeros((1, 2), device=self.device)
            for _ in range(dist_utils.get_world_size())
        ]
        dist.all_gather(all_vals, local_val)
        losses = torch.concat(all_vals, dim=0)
        reduced = torch.sum(losses, dim=0).reshape(2).cpu()
        return reduced[0].item(), reduced[1].item()
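
# Typical flow (sketch): every rank calls update() once per batch with its
# local loss, then reduce() all-gathers the (count, sum) pairs so that each
# rank computes the same global average.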


class UnitYFinetune:
    def __init__(
        self,
        model: UnitYModel,
        params: FinetuneParams,
        train_data_loader: dataloader.UnitYDataLoader,
        eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
    ):
        self.params = params

        assert model.t2u_model is not None
        self.calc_loss = CalcLoss(
            label_smoothing=self.params.label_smoothing,
            s2t_pad_idx=model.pad_idx,
            t2u_pad_idx=model.t2u_model.pad_idx,
        )
        self.model = self._wrap_model_for_training(model=model)
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.optimizer = Adam(
            params=self.model.parameters(),
            lr=self.params.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-08,
            maximize=False,
            weight_decay=0.0,
            fused=True,
        )
        self.grad_scaler = torch.cuda.amp.GradScaler()
        self.lr_scheduler = MyleLR(
            optimizer=self.optimizer,
            num_warmup_steps=self.params.warmup_steps,
            start_lr=1e-9,
        )

        self.train_loss_hist = LossCollector(device=params.device)
        self.epoch_idx: int = 0
        self.update_idx: int = 0
        self.patience_left: int = self.params.patience
        self.best_eval_loss: Optional[float] = None
        self.is_best_state: bool = False

    def _reset_stats(self) -> None:
        self.train_loss_hist.reset()
        self.epoch_idx = 0
        self.update_idx = 0
        self.patience_left = self.params.patience
        self.best_eval_loss = None
        self.is_best_state = False

    def _wrap_model_for_training(self, model: UnitYModel) -> nn.Module:
        wrapped_model = UnitYFinetuneWrapper(
            model=model, mode=self.params.finetune_mode, device=self.params.device
        )
        if not dist_utils.is_dist_initialized():
            return wrapped_model
        return nn.parallel.DistributedDataParallel(
            wrapped_model,
            device_ids=[dist_utils.get_local_rank()],
            find_unused_parameters=True,
        )
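
    # find_unused_parameters=True is passed above presumably because a frozen
    # S2T or T2U branch (or a batch without unit targets) leaves some
    # parameters without gradients, which stock DDP would flag as an error.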

    def _update_eval_stats(self, eval_loss: float) -> None:
        self.is_best_state = (
            self.best_eval_loss is None or eval_loss < self.best_eval_loss
        )
        self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
        self.patience_left = (
            self.params.patience if self.is_best_state else self.patience_left - 1
        )
        logger.info(
            f"Eval after {self.update_idx} updates: "
            f"loss={eval_loss:.4f} "
            f"best_loss={self.best_eval_loss:.4f} "
            f"patience_steps_left={self.patience_left}"
        )

    def _eval_model(self) -> None:
        """Compute the average loss on the eval dataset
        and update the evaluation stats."""
        if self.eval_data_loader is None:
            return
        logger.info("Run evaluation")
        loss_hist = LossCollector(device=self.params.device)
        self.model.eval()
        with torch.no_grad():
            for batch in self.eval_data_loader.get_dataloader():
                assert batch.speech_to_text.src_tokens is not None
                loss = self.calc_loss(batch, *self.model(batch))
                if loss.isnan():
                    logger.warning("Eval loss value is NaN, setting to inf")
                    loss_val = float("inf")
                else:
                    loss_val = loss.item()
                del batch  # force memory release
                loss_hist.update(1, loss_val)

        eval_loss = loss_hist.reduce()
        self._update_eval_stats(eval_loss)

    def _train_step_log(self) -> None:
        """Log train stats"""
        if (self.update_idx + 1) % self.params.log_steps == 0:
            avg_loss = self.train_loss_hist.reduce()
            self.train_loss_hist.reset()
            logger.info(
                f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
                f"update {str(self.update_idx + 1).zfill(5)}: "
                f"train loss={avg_loss:.4f} "
                f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
            )

    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
        """Run one train step"""
        self.model.train()
        self.optimizer.zero_grad()
        tokens, units = self.model(batch)
        loss = self.calc_loss(batch, tokens, units)
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        self.lr_scheduler.step()
        assert batch.speech_to_text.src_tokens is not None
        self.train_loss_hist.update(1, loss.item())
        self._train_step_log()

    def _save_model(self) -> None:
        logger.info("Saving model")
        if dist_utils.is_main_process():
            # Strip the DDP/wrapper prefix so the checkpoint can be loaded
            # back into a bare UnitYModel.
            state_dict = {
                key.replace("module.model.", ""): value
                for key, value in self.model.state_dict().items()
            }
            torch.save(state_dict, self.params.save_model_path)
        if dist_utils.is_dist_initialized():
            dist.barrier()

    def run(self) -> None:
        logger.info("Start finetuning")
        self._reset_stats()
        self._eval_model()
        batch_itr = self.train_data_loader.get_dataloader()
        while self.epoch_idx < self.params.max_epochs and self.patience_left:
            for train_batch in batch_itr:
                self._train_step(batch=train_batch)
                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
                    self._eval_model()
                    if self.is_best_state:
                        self._save_model()
                    elif not self.patience_left:
                        no_improve_steps = self.params.eval_steps * self.params.patience
                        logger.info(
                            "Early termination, as eval loss did not improve "
                            f"over last {no_improve_steps} updates"
                        )
                        break
                self.update_idx += 1
            self.epoch_idx += 1
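

# Illustrative wiring sketch (comment-only, since constructing the UnitYModel
# and the UnitYDataLoader instances lives outside this module; the names below
# are placeholders, not a prescribed API):
#
#   params = FinetuneParams(save_model_path=Path("checkpoint.pt"))
#   finetune = UnitYFinetune(
#       model=model,                     # a loaded UnitYModel
#       params=params,
#       train_data_loader=train_loader,  # dataloader.UnitYDataLoader
#       eval_data_loader=eval_loader,    # optional
#   )
#   finetune.run()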