trainer.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from fairseq2.models.sequence import SequenceModelOutput
from fairseq2.optim.lr_scheduler import MyleLR
from torch.optim import Adam

from m4t_scripts.train import dataloader, dist_utils
from m4t_scripts.train.configs import TrainingParams
from seamless_communication.models.unity import UnitYModel, UnitYT2UModel

logger = logging.getLogger(__name__)


class UnitYTrainWrapper(nn.Module):
    """Convenience wrapper that does a forward pass
    and returns S2T and T2U logits"""

    def __init__(self, model: UnitYModel):
        super().__init__()
        self.model: UnitYModel = model
        if isinstance(self.model.t2u_model, UnitYT2UModel):
            self.t2u: UnitYT2UModel = self.model.t2u_model
        else:
            raise NotImplementedError(
                "UnitYTrainWrapper supports only instances of UnitYT2UModel as t2u"
            )

    def forward(self, batch: dataloader.MultimodalSeqsBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass, computes S2T and T2U logits"""
        assert self.model.t2u_model is not None
        assert batch.speech_to_text.src_tokens is not None
        # s2t: encode speech, decode text conditioned on the speech encoder output
        speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
            seqs=batch.speech_to_text.src_tokens,
            seq_lens=batch.speech_to_text.src_lengths,
        )
        assert batch.speech_to_text.prev_output_tokens is not None
        text_decoder_out, text_decoder_padding_mask = self.model.decode(
            seqs=batch.speech_to_text.prev_output_tokens,
            seq_lens=batch.speech_to_text.target_lengths,
            encoder_output=speech_encoder_out,
            encoder_padding_mask=speech_encoder_padding_mask,
        )
        text_logits = self.model.final_proj(text_decoder_out)
        # t2u: encode the text decoder output, decode discrete units
        (
            unit_encoder_out,
            unit_encoder_padding_mask,
        ) = self.t2u.encode(
            text_decoder_output=text_decoder_out,
            text_decoder_padding_mask=text_decoder_padding_mask,
        )
        unit_decoder_out, _ = self.t2u.decode(
            seqs=batch.text_to_units.prev_output_tokens,
            seq_lens=batch.text_to_units.target_lengths,
            encoder_output=unit_encoder_out,
            encoder_padding_mask=unit_encoder_padding_mask,
        )
        unit_logits = self.t2u.final_proj(unit_decoder_out)
        return (text_logits, unit_logits)


class CalcLoss:
    """Calculates per-token negative log likelihood loss for S2T and T2U"""

    def __init__(
        self,
        label_smoothing: float,
        s2t_pad_idx: Optional[int],
        t2u_pad_idx: Optional[int],
        s2t_skip_langtok_loss: bool = False,
    ):
        self.label_smoothing = label_smoothing
        self.s2t_pad_idx = s2t_pad_idx
        self.t2u_pad_idx = t2u_pad_idx
        self.s2t_ignore_prefix_size = 1 if s2t_skip_langtok_loss else 0
        self.t2u_ignore_prefix_size = 1

    def __call__(
        self,
        batch: dataloader.MultimodalSeqsBatch,
        text_logits: torch.Tensor,
        unit_logits: torch.Tensor,
    ) -> torch.Tensor:
        assert batch.speech_to_text.target_lengths is not None
        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(text_logits.device)
        s2t_loss = SequenceModelOutput(logits=text_logits, pad_idx=self.s2t_pad_idx).compute_loss(
            targets=batch.speech_to_text.target_tokens.to(text_logits.device),
            ignore_prefix_size=self.s2t_ignore_prefix_size,
            label_smoothing=self.label_smoothing,
        )
        assert batch.text_to_units.target_lengths is not None
        s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
        s2u_loss = SequenceModelOutput(logits=unit_logits, pad_idx=self.t2u_pad_idx).compute_loss(
            targets=batch.text_to_units.target_tokens.to(unit_logits.device),
            ignore_prefix_size=self.t2u_ignore_prefix_size,
            label_smoothing=self.label_smoothing,
        )
        # Normalize each task loss by its own token count, then sum the two tasks.
        return s2t_loss / s2t_numel + s2u_loss / s2u_numel


class LossCollector:
    """Aggregates loss history across nodes"""

    def __init__(self, device: Optional[torch.device] = None, reduce_op: str = "avg"):
        self.n_samples: float = 0
        self.val_sum: float = 0.0
        self.reduce_op = reduce_op
        self.device = device
        self.is_distributed = dist_utils.is_dist_initialized()

    def reset(self) -> None:
        self.n_samples = 0
        self.val_sum = 0.0

    def update(self, n_samples: int, batch_loss: float) -> None:
        self.n_samples += n_samples
        self.val_sum += batch_loss

    def reduce(self) -> float:
        n_samples, val_sum = self._collect()
        if self.reduce_op == "avg":
            return val_sum / (n_samples + 1)
        if self.reduce_op == "sum":
            return val_sum
        raise ValueError(f"Unknown reduce op: {self.reduce_op}")

    def _collect(self) -> Tuple[float, float]:
        if not self.is_distributed:
            return self.n_samples, self.val_sum
        local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
        all_vals = [torch.zeros((1, 2), device=self.device) for _ in range(dist_utils.get_world_size())]
        dist.all_gather(all_vals, local_val)
        losses = torch.concat(all_vals, dim=0)
        reduced = torch.sum(losses, dim=0).reshape(2).cpu()
        return reduced[0].item(), reduced[1].item()
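

# Illustrative sketch (not part of the original module): single-process use of
# LossCollector, assuming dist_utils.is_dist_initialized() returns False so that
# _collect() simply returns the locally accumulated counters.
#
#     hist = LossCollector(device=torch.device("cpu"))
#     hist.update(n_samples=1, batch_loss=0.9)
#     hist.update(n_samples=1, batch_loss=0.7)
#     avg = hist.reduce()  # the "avg" op divides the summed loss by (n_samples + 1)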


class UnitYTrainer:
    CHECKPOINT_BEST = "checkpoint_best.pt"

    def __init__(
        self,
        model: UnitYModel,
        params: TrainingParams,
        train_data_loader: dataloader.UnityDataLoader,
        eval_data_loader: Optional[dataloader.UnityDataLoader],
        chck_save_dir: str,
        device: torch.device,
    ):
        self.params = params
        self.device = device
        self.float_dtype = self._get_float_dtype(self.params.float_dtype)
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.chck_save_dir = chck_save_dir
        assert model.t2u_model is not None
        self.calc_loss = CalcLoss(
            label_smoothing=self.params.label_smoothing,
            s2t_pad_idx=model.pad_idx,
            t2u_pad_idx=model.t2u_model.pad_idx,
        )
        self._try_load_checkpoint(model=model)
        self.model = self._wrap_model_for_training(model=model)
        # TODO: make tweakable
        self.optimizer = Adam(
            params=self.model.parameters(),
            lr=self.params.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-08,
            maximize=False,
            weight_decay=0.0,
            fused=True,
        )
        # Gradient scaling is only needed for fp16 mixed precision.
        self.grad_scaler = torch.cuda.amp.GradScaler() if self.float_dtype == torch.float16 else None  # type: ignore
        # TODO: allow scheduler selection
        self.lr_scheduler = MyleLR(
            optimizer=self.optimizer,
            num_warmup_steps=self.params.warmup_steps,
            start_lr=self.params.start_learning_rate,
        )
        self.train_loss_hist = LossCollector(device=self.device)
        self.epoch_idx: int = 0
        self.update_idx: int = 0
        self.patience_left: int = self.params.patience
        self.last_eval_loss: Optional[float] = None
        self.best_eval_loss: Optional[float] = None
        self.is_best_state: bool = False
        self.batch_sizes: List[int] = []
        self.gpu_usage: List[float] = []

    def _try_load_checkpoint(self, model: torch.nn.Module) -> None:
        chck_path = self.get_best_checkpoint_path()
        if os.path.exists(chck_path):
            logger.info(f"Loading state dict from {chck_path}")
            state_dict = torch.load(chck_path)
            model.load_state_dict(state_dict)

    @classmethod
    def _get_float_dtype(cls, float_dtype: str) -> torch.dtype:
        if float_dtype == "fp16":
            return torch.float16
        elif float_dtype == "fp32":
            return torch.float32
        elif float_dtype == "bf16":
            return torch.bfloat16
        else:
            raise ValueError(f"Unknown dtype literal: {float_dtype}")

    def _reset_stats(self) -> None:
        self.train_loss_hist.reset()
        self.epoch_idx = 0
        self.update_idx = 0
        self.patience_left = self.params.patience
        self.last_eval_loss = None
        self.best_eval_loss = None
        self.is_best_state = False
        self._reset_log_stats()

    def _reset_log_stats(self) -> None:
        self.batch_sizes.clear()
        self.gpu_usage.clear()
        self.ts = time.time()
        self.last_update_idx = self.update_idx

    def _record_gpu_usage(self) -> None:
        # Reserved CUDA memory in GiB (bytes >> 20 gives MiB, / 1024 gives GiB).
        gb = (torch.cuda.memory_reserved(self.device) >> 20) / 1024.0
        self.gpu_usage.append(gb)

    def _get_avg_bsz(self) -> float:
        """Avg training batch size"""
        return sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0.0

    def _get_ups(self) -> float:
        """Updates per second"""
        ts_delta = time.time() - self.ts
        return (self.update_idx - self.last_update_idx) / ts_delta

    def _get_avg_gpu_usage(self) -> float:
        return sum(self.gpu_usage) / len(self.gpu_usage) if self.gpu_usage else 0.0

    def _wrap_model_for_training(self, model: UnitYModel) -> nn.Module:
        wrapped_model = UnitYTrainWrapper(model=model)
        if not dist_utils.is_dist_initialized():
            return wrapped_model
        return nn.parallel.DistributedDataParallel(
            wrapped_model,
            device_ids=[dist_utils.get_local_rank()],
            find_unused_parameters=True,
        )

    def _update_eval_stats(self, eval_loss: float) -> None:
        self.last_eval_loss = eval_loss
        self.is_best_state = self.best_eval_loss is None or eval_loss < self.best_eval_loss
        self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
        self.patience_left = self.params.patience if self.is_best_state else self.patience_left - 1
        logger.info(
            f"Eval after {self.update_idx} updates: "
            f"loss={eval_loss:.4f} "
            f"best_loss={self.best_eval_loss:.4f} "
            f"patience_steps_left={self.patience_left}"
        )

    def _eval_model(self) -> None:
        """Calc avg loss on eval dataset and update evaluation stats"""
        if self.eval_data_loader is None:
            return
        logger.info("Run evaluation")
        loss_hist = LossCollector(device=self.device)
        self.model.eval()
        with torch.no_grad():
            self.eval_data_loader.reset()
            for batch in self.eval_data_loader.iterate_batches():
                assert batch.speech_to_text.src_tokens is not None
                loss = self.calc_loss(batch, *self.model(batch))
                if loss.isnan():
                    logger.warning("Eval loss value is NaN, setting to inf")
                    loss_val = float("Inf")
                else:
                    loss_val = loss.item()
                del batch  # force memory release
                loss_hist.update(1, loss_val)
        eval_loss = loss_hist.reduce()
        self._update_eval_stats(eval_loss)

    def _train_step_log(self) -> None:
        """Log train stats"""
        if (self.update_idx + 1) % self.params.log_steps == 0:
            avg_loss = self.train_loss_hist.reduce()
            self.train_loss_hist.reset()
            logger.info(
                f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
                f"update {str(self.update_idx + 1).zfill(5)}: "
                f"train loss={avg_loss:.4f} "
                f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E} "
                f"bsz_avg={self._get_avg_bsz():.1f} "
                f"ups={self._get_ups():.2f} "
                f"gpu_avg={self._get_avg_gpu_usage():.2f}GiB"
            )
            self._reset_log_stats()

    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
        """Run one train step"""
        self.model.train()
        self.optimizer.zero_grad()
        tokens, units = self.model(batch)
        loss = self.calc_loss(batch, tokens, units)
        # Record GPU memory near its per-step peak (right after the forward pass).
        self._record_gpu_usage()
        if self.grad_scaler is not None:
            self.grad_scaler.scale(loss).backward()  # type: ignore
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss.backward()
            self.optimizer.step()
        self.lr_scheduler.step()
        assert batch.speech_to_text.src_tokens is not None
        self.train_loss_hist.update(1, loss.item())
        self.batch_sizes.append(batch.speech_to_text.src_tokens.shape[0])
        self._train_step_log()

    def _get_state(self) -> Dict[str, Any]:
        model_state_dict = self.model.state_dict()
        # Strip the DDP + train-wrapper prefix ("module.model.") so the keys
        # match a bare UnitYModel state dict.
        model_state_dict = {
            key.replace("module.model.", ""): value for key, value in model_state_dict.items()
        }
        return model_state_dict

    def _get_chck_path(self) -> str:
        ts = str(int(time.time()))
        epoch = str(self.epoch_idx).zfill(3)
        update = str(self.update_idx).zfill(6)
        eval_loss = f"{self.last_eval_loss:.4f}"
        name = f"{ts}_{epoch}_{update}_{eval_loss}.pt"
        return os.path.join(self.chck_save_dir, name)

    def _get_best_checkpoint_link_path(self) -> str:
        return os.path.join(self.chck_save_dir, self.CHECKPOINT_BEST)

    def get_best_checkpoint_path(self) -> str:
        return os.path.realpath(self._get_best_checkpoint_link_path())

    def _save_model(self) -> None:
        if dist_utils.is_main_process():
            state_dict = self._get_state()
            save_path = self._get_chck_path()
            logger.info(f"Saving checkpoint to {save_path}")
            torch.save(state_dict, save_path)
            if self.is_best_state:
                best_link_path = self._get_best_checkpoint_link_path()
                if os.path.exists(best_link_path):
                    os.unlink(best_link_path)
                os.symlink(save_path, best_link_path)
                logger.info(f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}")
        if dist_utils.is_dist_initialized():
            dist.barrier()

    def run(self) -> None:
        logger.info("Start training")
        self._reset_stats()
        self._eval_model()
        while self.epoch_idx < self.params.max_epochs and self.patience_left:
            for train_batch in self.train_data_loader.iterate_batches():
                self._train_step(batch=train_batch)
                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
                    self._eval_model()
                    if self.is_best_state:
                        self._save_model()
                    elif not self.patience_left:
                        no_improve_steps = self.params.eval_steps * self.params.patience
                        logger.info(
                            f"Early termination, as eval loss did not improve over last {no_improve_steps} updates"
                        )
                        break
                self.update_idx += 1
            self.train_data_loader.reset()
            self.epoch_idx += 1
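

# Illustrative wiring sketch (not part of the original module). Only the
# UnitYTrainer constructor signature defined above is known from this file;
# how `model`, `params`, and the data loaders are built elsewhere is an
# assumption and therefore left abstract here.
#
#     device = torch.device("cuda", dist_utils.get_local_rank())
#     trainer = UnitYTrainer(
#         model=model,                     # a seamless_communication UnitYModel
#         params=params,                   # a TrainingParams instance
#         train_data_loader=train_loader,  # dataloader.UnityDataLoader
#         eval_data_loader=eval_loader,    # Optional[dataloader.UnityDataLoader]
#         chck_save_dir="/path/to/checkpoints",
#         device=device,
#     )
#     trainer.run()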