# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from fairseq2.models.sequence import SequenceModelOutput
from fairseq2.optim.lr_scheduler import MyleLR
from torch.optim import Adam

from m4t_scripts.train import dataloader, dist_utils
from m4t_scripts.train.configs import TrainingParams
from seamless_communication.models.unity import UnitYModel, UnitYT2UModel

logger = logging.getLogger(__name__)


class UnitYTrainWrapper(nn.Module):
    """Convenience wrapper that runs the full forward pass
    and returns S2T and T2U logits."""

    def __init__(self, model: UnitYModel):
        super().__init__()
        self.model: UnitYModel = model
        if isinstance(self.model.t2u_model, UnitYT2UModel):
            self.t2u: UnitYT2UModel = self.model.t2u_model
        else:
            raise NotImplementedError(
                "UnitYTrainWrapper supports only instances of UnitYT2UModel as t2u"
            )

    def forward(self, batch: dataloader.MultimodalSeqsBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass that returns S2T and T2U logits."""
        assert self.model.t2u_model is not None
        assert batch.speech_to_text.src_tokens is not None
        # s2t: encode speech, then decode text conditioned on the speech encoder output
        speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
            seqs=batch.speech_to_text.src_tokens,
            seq_lens=batch.speech_to_text.src_lengths,
        )
        assert batch.speech_to_text.prev_output_tokens is not None
        text_decoder_out, text_decoder_padding_mask = self.model.decode(
            seqs=batch.speech_to_text.prev_output_tokens,
            seq_lens=batch.speech_to_text.target_lengths,
            encoder_output=speech_encoder_out,
            encoder_padding_mask=speech_encoder_padding_mask,
        )
        text_logits = self.model.final_proj(text_decoder_out)
        # t2u: encode the text decoder states, then decode units
        (
            unit_encoder_out,
            unit_encoder_padding_mask,
        ) = self.t2u.encode(
            text_decoder_output=text_decoder_out,
            text_decoder_padding_mask=text_decoder_padding_mask,
        )
        unit_decoder_out, _ = self.t2u.decode(
            seqs=batch.text_to_units.prev_output_tokens,
            seq_lens=batch.text_to_units.target_lengths,
            encoder_output=unit_encoder_out,
            encoder_padding_mask=unit_encoder_padding_mask,
        )
        unit_logits = self.t2u.final_proj(unit_decoder_out)
        return (text_logits, unit_logits)


class CalcLoss:
    """Calculates per-token negative log-likelihood loss for S2T and T2U."""

    def __init__(
        self,
        label_smoothing: float,
        s2t_pad_idx: Optional[int],
        t2u_pad_idx: Optional[int],
        s2t_skip_langtok_loss: bool = False,
    ):
        self.label_smoothing = label_smoothing
        self.s2t_pad_idx = s2t_pad_idx
        self.t2u_pad_idx = t2u_pad_idx
        self.s2t_ignore_prefix_size = 1 if s2t_skip_langtok_loss else 0
        self.t2u_ignore_prefix_size = 1

    def __call__(
        self,
        batch: dataloader.MultimodalSeqsBatch,
        text_logits: torch.Tensor,
        unit_logits: torch.Tensor,
    ) -> torch.Tensor:
        assert batch.speech_to_text.target_lengths is not None
        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(text_logits.device)
        s2t_loss = SequenceModelOutput(logits=text_logits, pad_idx=self.s2t_pad_idx).compute_loss(
            targets=batch.speech_to_text.target_tokens.to(text_logits.device),
            ignore_prefix_size=self.s2t_ignore_prefix_size,
            label_smoothing=self.label_smoothing,
        )
        assert batch.text_to_units.target_lengths is not None
        s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
        s2u_loss = SequenceModelOutput(logits=unit_logits, pad_idx=self.t2u_pad_idx).compute_loss(
            targets=batch.text_to_units.target_tokens.to(unit_logits.device),
            ignore_prefix_size=self.t2u_ignore_prefix_size,
            label_smoothing=self.label_smoothing,
        )
        # Normalize each summed loss by its own token count so the two tasks
        # contribute per-token averages of comparable scale.
        return s2t_loss / s2t_numel + s2u_loss / s2u_numel
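

# A minimal usage sketch for the two classes above (hedged: the pad indices,
# smoothing value, and the `unity_model`/`batch` objects are illustrative
# placeholders; `batch` would come from dataloader.UnityDataLoader):
#
#   wrapper = UnitYTrainWrapper(model=unity_model)
#   text_logits, unit_logits = wrapper(batch)
#   calc_loss = CalcLoss(label_smoothing=0.2, s2t_pad_idx=0, t2u_pad_idx=1)
#   loss = calc_loss(batch, text_logits, unit_logits)  # scalar per-token NLL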


class LossCollector:
    """Aggregates loss history across nodes."""

    def __init__(self, device: Optional[torch.device] = None, reduce_op: str = "avg"):
        self.n_samples: float = 0
        self.val_sum: float = 0.0
        self.reduce_op = reduce_op
        self.device = device
        self.is_distributed = dist_utils.is_dist_initialized()

    def reset(self) -> None:
        self.n_samples = 0
        self.val_sum = 0.0

    def update(self, n_samples: int, batch_loss: float) -> None:
        self.n_samples += n_samples
        self.val_sum += batch_loss

    def reduce(self) -> float:
        n_samples, val_sum = self._collect()
        if self.reduce_op == "avg":
            # max(..., 1) guards against an empty history without biasing
            # non-empty averages.
            return val_sum / max(n_samples, 1)
        if self.reduce_op == "sum":
            return val_sum
        raise ValueError(f"Unknown reduce op: {self.reduce_op}")

    def _collect(self) -> Tuple[float, float]:
        if not self.is_distributed:
            return self.n_samples, self.val_sum
        # Gather the local (n_samples, val_sum) pair from every worker,
        # then sum across workers.
        local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
        all_vals = [torch.zeros((1, 2), device=self.device) for _ in range(dist_utils.get_world_size())]
        dist.all_gather(all_vals, local_val)
        losses = torch.concat(all_vals, dim=0)
        reduced = torch.sum(losses, dim=0).reshape(2).cpu()
        return reduced[0].item(), reduced[1].item()
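

# Worked example of the "avg" reduction (single process, so no all_gather;
# the loss values are illustrative):
#
#   hist = LossCollector(device=torch.device("cpu"))
#   for batch_loss in (0.9, 1.1, 1.0):
#       hist.update(1, batch_loss)
#   hist.reduce()  # -> (0.9 + 1.1 + 1.0) / 3 = 1.0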


class UnitYTrainer:
    CHECKPOINT_BEST = "checkpoint_best.pt"

    def __init__(
        self,
        model: UnitYModel,
        params: TrainingParams,
        train_data_loader: dataloader.UnityDataLoader,
        eval_data_loader: Optional[dataloader.UnityDataLoader],
        chck_save_dir: str,
        device: torch.device,
    ):
        self.params = params
        self.device = device
        self.float_dtype = self._get_float_dtype(self.params.float_dtype)
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.chck_save_dir = chck_save_dir
        assert model.t2u_model is not None
        self.calc_loss = CalcLoss(
            label_smoothing=self.params.label_smoothing,
            s2t_pad_idx=model.pad_idx,
            t2u_pad_idx=model.t2u_model.pad_idx,
        )
        self._try_load_checkpoint(model=model)
        self.model = self._wrap_model_for_training(model=model)
        # TODO: make tweakable
        self.optimizer = Adam(
            params=self.model.parameters(),
            lr=self.params.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-08,
            maximize=False,
            weight_decay=0.0,
            fused=True,
        )
        # Gradient scaling is only needed for fp16; bf16 and fp32 have enough
        # dynamic range to train without it.
        self.grad_scaler = torch.cuda.amp.GradScaler() if self.float_dtype == torch.float16 else None  # type: ignore
        # TODO: allow scheduler selection
        self.lr_scheduler = MyleLR(
            optimizer=self.optimizer,
            num_warmup_steps=self.params.warmup_steps,
            start_lr=self.params.start_learning_rate,
        )
        self.train_loss_hist = LossCollector(device=self.device)
        self.epoch_idx: int = 0
        self.update_idx: int = 0
        self.patience_left: int = self.params.patience
        self.last_eval_loss: Optional[float] = None
        self.best_eval_loss: Optional[float] = None
        self.is_best_state: bool = False
        self.batch_sizes: List[int] = []
        self.gpu_usage: List[float] = []

    def _try_load_checkpoint(self, model: torch.nn.Module) -> None:
        chck_path = self.get_best_checkpoint_path()
        if os.path.exists(chck_path):
            logger.info(f"Loading state dict from {chck_path}")
            state_dict = torch.load(chck_path)
            model.load_state_dict(state_dict)

    @classmethod
    def _get_float_dtype(cls, float_dtype: str) -> torch.dtype:
        if float_dtype == "fp16":
            return torch.float16
        elif float_dtype == "fp32":
            return torch.float32
        elif float_dtype == "bf16":
            return torch.bfloat16
        else:
            raise ValueError(f"Unknown dtype literal: {float_dtype}")

    def _reset_stats(self) -> None:
        self.train_loss_hist.reset()
        self.epoch_idx = 0
        self.update_idx = 0
        self.patience_left = self.params.patience
        self.last_eval_loss = None
        self.best_eval_loss = None
        self.is_best_state = False
        self._reset_log_stats()

    def _reset_log_stats(self) -> None:
        self.batch_sizes.clear()
        self.gpu_usage.clear()
        self.ts = time.time()
        self.last_update_idx = self.update_idx

    def _record_gpu_usage(self) -> None:
        # Reserved CUDA memory in GiB: >> 20 converts bytes to MiB,
        # / 1024 converts MiB to GiB.
        gb = (torch.cuda.memory_reserved(self.device) >> 20) / 1024.0
        self.gpu_usage.append(gb)

    def _get_avg_bsz(self) -> float:
        """Average training batch size."""
        return sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0.0

    def _get_ups(self) -> float:
        """Updates per second since the last log reset."""
        ts_delta = time.time() - self.ts
        return (self.update_idx - self.last_update_idx) / ts_delta

    def _get_avg_gpu_usage(self) -> float:
        return sum(self.gpu_usage) / len(self.gpu_usage) if self.gpu_usage else 0.0

    def _wrap_model_for_training(self, model: UnitYModel) -> nn.Module:
        wrapped_model = UnitYTrainWrapper(model=model)
        if not dist_utils.is_dist_initialized():
            return wrapped_model
        return nn.parallel.DistributedDataParallel(
            wrapped_model,
            device_ids=[dist_utils.get_local_rank()],
            # Tolerate parameters that receive no gradient in a given step.
            find_unused_parameters=True,
        )

    def _update_eval_stats(self, eval_loss: float) -> None:
        self.last_eval_loss = eval_loss
        self.is_best_state = self.best_eval_loss is None or eval_loss < self.best_eval_loss
        self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
        self.patience_left = self.params.patience if self.is_best_state else self.patience_left - 1
        logger.info(
            f"Eval after {self.update_idx} updates: "
            f"loss={eval_loss:.4f} "
            f"best_loss={self.best_eval_loss:.4f} "
            f"patience_steps_left={self.patience_left}"
        )

    def _eval_model(self) -> None:
        """Compute the average loss over the eval dataset and update eval stats."""
        if self.eval_data_loader is None:
            return
        logger.info("Run evaluation")
        loss_hist = LossCollector(device=self.device)
        self.model.eval()
        with torch.no_grad():
            self.eval_data_loader.reset()
            for batch in self.eval_data_loader.iterate_batches():
                assert batch.speech_to_text.src_tokens is not None
                loss = self.calc_loss(batch, *self.model(batch))
                if loss.isnan():
                    logger.warning("Eval loss value is NaN, setting to inf")
                    loss_val = float("inf")
                else:
                    loss_val = loss.item()
                del batch  # force memory release
                loss_hist.update(1, loss_val)
        eval_loss = loss_hist.reduce()
        self._update_eval_stats(eval_loss)

    def _train_step_log(self) -> None:
        """Log train stats."""
        if (self.update_idx + 1) % self.params.log_steps == 0:
            avg_loss = self.train_loss_hist.reduce()
            self.train_loss_hist.reset()
            logger.info(
                f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
                f"update {str(self.update_idx + 1).zfill(5)}: "
                f"train loss={avg_loss:.4f} "
                f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E} "
                f"bsz_avg={self._get_avg_bsz():.1f} "
                f"ups={self._get_ups():.2f} "
                f"gpu_avg={self._get_avg_gpu_usage():.2f}GiB"
            )
            self._reset_log_stats()

    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
        """Run one train step."""
        self.model.train()
        self.optimizer.zero_grad()
        tokens, units = self.model(batch)
        loss = self.calc_loss(batch, tokens, units)
        # Record GPU usage right after the forward pass, near the memory peak.
        self._record_gpu_usage()
        if self.grad_scaler is not None:
            # fp16 path: scale the loss to avoid gradient underflow; step()
            # unscales before the update and update() adjusts the scale factor.
            self.grad_scaler.scale(loss).backward()  # type: ignore
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss.backward()
            self.optimizer.step()
        self.lr_scheduler.step()
        assert batch.speech_to_text.src_tokens is not None
        self.train_loss_hist.update(1, loss.item())
        self.batch_sizes.append(batch.speech_to_text.src_tokens.shape[0])
        self._train_step_log()

    def _get_state(self) -> Dict[str, Any]:
        model_state_dict = self.model.state_dict()
        # Strip the DDP/wrapper prefix so checkpoints load directly
        # into a bare UnitYModel.
        model_state_dict = {key.replace("module.model.", ""): value for key, value in model_state_dict.items()}
        return model_state_dict

    def _get_chck_path(self) -> str:
        # e.g. "1700000000_003_001500_1.2345.pt" (timestamp_epoch_update_loss)
        ts = str(int(time.time()))
        epoch = str(self.epoch_idx).zfill(3)
        update = str(self.update_idx).zfill(6)
        eval_loss = f"{self.last_eval_loss:.4f}"
        name = f"{ts}_{epoch}_{update}_{eval_loss}.pt"
        return os.path.join(self.chck_save_dir, name)

    def _get_best_checkpoint_link_path(self) -> str:
        return os.path.join(self.chck_save_dir, self.CHECKPOINT_BEST)

    def get_best_checkpoint_path(self) -> str:
        return os.path.realpath(self._get_best_checkpoint_link_path())

    def _save_model(self) -> None:
        if dist_utils.is_main_process():
            state_dict = self._get_state()
            save_path = self._get_chck_path()
            logger.info(f"Saving checkpoint to {save_path}")
            torch.save(state_dict, save_path)
            if self.is_best_state:
                best_link_path = self._get_best_checkpoint_link_path()
                # lexists also catches a dangling symlink left by a deleted checkpoint
                if os.path.lexists(best_link_path):
                    os.unlink(best_link_path)
                os.symlink(save_path, best_link_path)
                logger.info(f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}")
        # All ranks must reach the barrier, so it stays outside the main-process branch.
        if dist_utils.is_dist_initialized():
            dist.barrier()

    def run(self) -> None:
        logger.info("Start training")
        self._reset_stats()
        self._eval_model()
        while self.epoch_idx < self.params.max_epochs and self.patience_left:
            for train_batch in self.train_data_loader.iterate_batches():
                self._train_step(batch=train_batch)
                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
                    self._eval_model()
                    if self.is_best_state:
                        self._save_model()
                    elif not self.patience_left:
                        no_improve_steps = self.params.eval_steps * self.params.patience
                        logger.info(
                            f"Early termination, as eval loss did not improve over last {no_improve_steps} updates"
                        )
                        break
                self.update_idx += 1
            self.train_data_loader.reset()
            self.epoch_idx += 1
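

# A minimal end-to-end sketch (hedged: `unity_model`, the two loaders, and the
# TrainingParams values are illustrative placeholders, not project defaults):
#
#   trainer = UnitYTrainer(
#       model=unity_model,
#       params=TrainingParams(...),
#       train_data_loader=train_loader,
#       eval_data_loader=eval_loader,
#       chck_save_dir="/checkpoints/run_001",
#       device=torch.device("cuda"),
#   )
#   trainer.run()  # evaluates once up front, then trains with early stopping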