trainer.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from fairseq2.models.sequence import SequenceModelOutput
from fairseq2.nn.padding import PaddingMask
from fairseq2.optim.lr_scheduler import MyleLR
from torch.optim import Adam

from seamless_communication.models.unity import UnitYModel, UnitYT2UModel

from m4t_scripts.train import dataloader, dist_utils
from m4t_scripts.train.configs import TrainingParams

logger = logging.getLogger(__name__)


class UnitYTrainWrapper(nn.Module):
    """Convenience wrapper that runs a forward pass
    and returns S2T and T2U logits."""

    def __init__(self, model: UnitYModel):
        super().__init__()
        self.model: UnitYModel = model
        if isinstance(self.model.t2u_model, UnitYT2UModel):
            self.t2u: UnitYT2UModel = self.model.t2u_model
        else:
            raise NotImplementedError(
                "UnitYTrainWrapper only supports instances of UnitYT2UModel as t2u"
            )

    def forward(
        self, batch: dataloader.MultimodalSeqsBatch
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass, computes S2T and T2U logits."""
        assert self.model.t2u_model is not None
        assert batch.speech_to_text.src_tokens is not None
        # s2t: speech encoder -> text decoder
        speech_padding_mask = PaddingMask(
            seq_lens=batch.speech_to_text.src_lengths,
            batch_seq_len=int(torch.max(batch.speech_to_text.src_lengths).item()),
        )
        speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
            seqs=batch.speech_to_text.src_tokens,
            padding_mask=speech_padding_mask,
        )
        assert batch.speech_to_text.prev_output_tokens is not None
        s2t_prev_out_tokens_padding_mask = PaddingMask(
            seq_lens=batch.speech_to_text.target_lengths,
            batch_seq_len=int(torch.max(batch.speech_to_text.target_lengths).item()),
        )
        text_decoder_out, text_decoder_padding_mask = self.model.decode(
            seqs=batch.speech_to_text.prev_output_tokens,
            padding_mask=s2t_prev_out_tokens_padding_mask,
            encoder_output=speech_encoder_out,
            encoder_padding_mask=speech_encoder_padding_mask,
        )
        text_logits = self.model.final_proj(text_decoder_out)
        # t2u: text decoder output -> unit decoder
        unit_encoder_out, unit_encoder_padding_mask = self.t2u.encode(
            text_decoder_output=text_decoder_out,
            text_decoder_padding_mask=text_decoder_padding_mask,
        )
        t2u_prev_out_tokens_padding_mask = PaddingMask(
            seq_lens=batch.text_to_units.target_lengths,
            batch_seq_len=int(torch.max(batch.text_to_units.target_lengths).item()),
        )
        unit_decoder_out, _ = self.t2u.decode(
            seqs=batch.text_to_units.prev_output_tokens,
            padding_mask=t2u_prev_out_tokens_padding_mask,
            encoder_output=unit_encoder_out,
            encoder_padding_mask=unit_encoder_padding_mask,
        )
        unit_logits = self.t2u.final_proj(unit_decoder_out)
        return (text_logits, unit_logits)
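
# Note (assumption): both returned tensors are logits over the text and unit
# vocabularies respectively, expected to have shape
# (batch, target_seq_len, vocab_size); CalcLoss below feeds them into
# fairseq2's SequenceModelOutput to compute the NLL losses.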


class CalcLoss:
    """Calculates per-token negative log likelihood loss for S2T and T2U"""

    def __init__(
        self,
        label_smoothing: float,
        s2t_pad_idx: Optional[int],
        t2u_pad_idx: Optional[int],
        s2t_skip_langtok_loss: bool = False,
    ):
        self.label_smoothing = label_smoothing
        self.s2t_pad_idx = s2t_pad_idx
        self.t2u_pad_idx = t2u_pad_idx
        self.s2t_ignore_prefix_size = 1 if s2t_skip_langtok_loss else 0
        self.t2u_ignore_prefix_size = 1

    def __call__(
        self,
        batch: dataloader.MultimodalSeqsBatch,
        text_logits: torch.Tensor,
        unit_logits: torch.Tensor,
    ) -> torch.Tensor:
        assert batch.speech_to_text.target_lengths is not None
        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
            text_logits.device
        )
        s2t_loss = SequenceModelOutput(
            logits=text_logits, pad_idx=self.s2t_pad_idx
        ).compute_loss(
            targets=batch.speech_to_text.target_tokens.to(text_logits.device),
            ignore_prefix_size=self.s2t_ignore_prefix_size,
            label_smoothing=self.label_smoothing,
        )
        assert batch.text_to_units.target_lengths is not None
        s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
        s2u_loss = SequenceModelOutput(
            logits=unit_logits, pad_idx=self.t2u_pad_idx
        ).compute_loss(
            targets=batch.text_to_units.target_tokens.to(unit_logits.device),
            ignore_prefix_size=self.t2u_ignore_prefix_size,
            label_smoothing=self.label_smoothing,
        )
        return s2t_loss / s2t_numel + s2u_loss / s2u_numel
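
# Note (assumption): the code above treats SequenceModelOutput.compute_loss as a
# summed NLL over the non-ignored target positions, so dividing by s2t_numel and
# s2u_numel (the target token counts) yields per-token losses that can be added
# on a comparable scale, as the class docstring states.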


class LossCollector:
    """Aggregates loss history across nodes"""

    def __init__(self, device: Optional[torch.device] = None, reduce_op: str = "avg"):
        self.n_samples: float = 0
        self.val_sum: float = 0.0
        self.reduce_op = reduce_op
        self.device = device
        self.is_distributed = dist_utils.is_dist_initialized()

    def reset(self) -> None:
        self.n_samples = 0
        self.val_sum = 0.0

    def update(self, n_samples: int, batch_loss: float) -> None:
        self.n_samples += n_samples
        self.val_sum += batch_loss

    def reduce(self) -> float:
        n_samples, val_sum = self._collect()
        if self.reduce_op == "avg":
            return val_sum / (n_samples + 1)
        if self.reduce_op == "sum":
            return val_sum
        raise ValueError(f"Unknown reduce_op: {self.reduce_op}")

    def _collect(self) -> Tuple[float, float]:
        if not self.is_distributed:
            return self.n_samples, self.val_sum
        local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
        all_vals = [
            torch.zeros((1, 2), device=self.device)
            for _ in range(dist_utils.get_world_size())
        ]
        dist.all_gather(all_vals, local_val)
        losses = torch.concat(all_vals, dim=0)
        reduced = torch.sum(losses, dim=0).reshape(2).cpu()
        return reduced[0].item(), reduced[1].item()
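
# Hedged example of LossCollector in a single-process (non-distributed) run; the
# numbers below are illustrative only:
#
#   collector = LossCollector(device=torch.device("cpu"))
#   collector.update(n_samples=1, batch_loss=2.0)
#   collector.update(n_samples=1, batch_loss=4.0)
#   collector.reduce()  # "avg" -> 6.0 / (2 + 1) = 2.0
#
# Note that the "avg" reduction divides by (n_samples + 1), which avoids a zero
# denominator before the first update but slightly underestimates the true mean.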


class UnitYTrainer:
    CHECKPOINT_BEST = "checkpoint_best.pt"

    def __init__(
        self,
        model: UnitYModel,
        params: TrainingParams,
        train_data_loader: dataloader.UnityDataLoader,
        eval_data_loader: Optional[dataloader.UnityDataLoader],
        chck_save_dir: str,
        device: torch.device,
    ):
        self.params = params
        self.device = device
        self.float_dtype = self._get_float_dtype(self.params.float_dtype)
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.chck_save_dir = chck_save_dir
        assert model.t2u_model is not None
        self.calc_loss = CalcLoss(
            label_smoothing=self.params.label_smoothing,
            s2t_pad_idx=model.pad_idx,
            t2u_pad_idx=model.t2u_model.pad_idx,
        )
        self._try_load_checkpoint(model=model)
        self.model = self._wrap_model_for_training(model=model)
        # TODO: make tweakable
        self.optimizer = Adam(
            params=self.model.parameters(),
            lr=self.params.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-08,
            maximize=False,
            weight_decay=0.0,
            fused=True,
        )
        self.grad_scaler = torch.cuda.amp.GradScaler() if self.float_dtype == torch.float16 else None  # type: ignore
        # TODO: allow scheduler selection
        self.lr_scheduler = MyleLR(
            optimizer=self.optimizer,
            num_warmup_steps=self.params.warmup_steps,
            start_lr=self.params.start_learning_rate,
        )
        self.train_loss_hist = LossCollector(device=self.device)
        self.epoch_idx: int = 0
        self.update_idx: int = 0
        self.patience_left: int = self.params.patience
        self.last_eval_loss: Optional[float] = None
        self.best_eval_loss: Optional[float] = None
        self.is_best_state: bool = False
        self.batch_sizes: List[int] = []
        self.gpu_usage: List[float] = []

    def _try_load_checkpoint(self, model: torch.nn.Module) -> None:
        chck_path = self.get_best_checkpoint_path()
        if os.path.exists(chck_path):
            logger.info(f"Loading state dict from {chck_path}")
            state_dict = torch.load(chck_path)
            model.load_state_dict(state_dict)

    @classmethod
    def _get_float_dtype(cls, float_dtype: str) -> torch.dtype:
        if float_dtype == "fp16":
            return torch.float16
        elif float_dtype == "fp32":
            return torch.float32
        elif float_dtype == "bf16":
            return torch.bfloat16
        else:
            raise ValueError(f"Unknown dtype literal: {float_dtype}")

    def _reset_stats(self) -> None:
        self.train_loss_hist.reset()
        self.epoch_idx = 0
        self.update_idx = 0
        self.patience_left = self.params.patience
        self.last_eval_loss = None
        self.best_eval_loss = None
        self.is_best_state = False
        self._reset_log_stats()

    def _reset_log_stats(self) -> None:
        self.batch_sizes.clear()
        self.gpu_usage.clear()
        self.ts = time.time()
        self.last_update_idx = self.update_idx

    def _record_gpu_usage(self) -> None:
        # memory_reserved is in bytes; >> 20 converts to MiB, / 1024 to GiB
        gb = (torch.cuda.memory_reserved(self.device) >> 20) / 1024.0
        self.gpu_usage.append(gb)

    def _get_avg_bsz(self) -> float:
        """Avg training batch size"""
        return (
            sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0.0
        )

    def _get_ups(self) -> float:
        """Updates per second"""
        ts_delta = time.time() - self.ts
        return (self.update_idx - self.last_update_idx) / ts_delta

    def _get_avg_gpu_usage(self) -> float:
        return sum(self.gpu_usage) / len(self.gpu_usage) if self.gpu_usage else 0.0

    def _wrap_model_for_training(self, model: UnitYModel) -> nn.Module:
        wrapped_model = UnitYTrainWrapper(model=model)
        if not dist_utils.is_dist_initialized():
            return wrapped_model
        return nn.parallel.DistributedDataParallel(
            wrapped_model,
            device_ids=[dist_utils.get_local_rank()],
            find_unused_parameters=True,
        )

    def _update_eval_stats(self, eval_loss: float) -> None:
        self.last_eval_loss = eval_loss
        self.is_best_state = (
            self.best_eval_loss is None or eval_loss < self.best_eval_loss
        )
        self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
        self.patience_left = (
            self.params.patience if self.is_best_state else self.patience_left - 1
        )
        logger.info(
            f"Eval after {self.update_idx} updates: "
            f"loss={eval_loss:.4f} "
            f"best_loss={self.best_eval_loss:.4f} "
            f"patience_steps_left={self.patience_left}"
        )

    def _eval_model(self) -> None:
        """Calc avg loss on eval dataset and update evaluation stats"""
        if self.eval_data_loader is None:
            return
        logger.info("Run evaluation")
        loss_hist = LossCollector(device=self.device)
        self.model.eval()
        with torch.no_grad():
            self.eval_data_loader.reset()
            for batch in self.eval_data_loader.iterate_batches():
                assert batch.speech_to_text.src_tokens is not None
                loss = self.calc_loss(batch, *self.model(batch))
                if loss.isnan():
                    logger.warning("Eval loss value is NaN, setting to inf")
                    loss_val = float("Inf")
                else:
                    loss_val = loss.item()
                self._release_memory(batch)
                loss_hist.update(1, loss_val)
        eval_loss = loss_hist.reduce()
        self._update_eval_stats(eval_loss)

    def _train_step_log(self) -> None:
        """Log train stats"""
        if (self.update_idx + 1) % self.params.log_steps == 0:
            avg_loss = self.train_loss_hist.reduce()
            self.train_loss_hist.reset()
            logger.info(
                f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
                f"update {str(self.update_idx + 1).zfill(5)}: "
                f"train loss={avg_loss:.4f} "
                f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E} "
                f"bsz_avg={self._get_avg_bsz():.1f} "
                f"ups={self._get_ups():.2f} "
                f"gpu_avg={self._get_avg_gpu_usage():.2f}GiB"
            )
            self._reset_log_stats()

    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
        """Run one train step"""
        self.model.train()
        self.optimizer.zero_grad()
        tokens, units = self.model(batch)
        loss = self.calc_loss(batch, tokens, units)
        # snapshot of reserved GPU memory (approximates peak usage)
        self._record_gpu_usage()
        if self.grad_scaler is not None:
            self.grad_scaler.scale(loss).backward()  # type: ignore
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss.backward()
            self.optimizer.step()
        self.lr_scheduler.step()
        assert batch.speech_to_text.src_tokens is not None
        self.train_loss_hist.update(1, loss.item())
        self.batch_sizes.append(batch.speech_to_text.src_tokens.shape[0])
        self._train_step_log()
        self._release_memory(batch)

    def _release_memory(self, batch: dataloader.MultimodalSeqsBatch) -> None:
        """Explicitly release large memory consumers"""
        del batch

    def _strip_state_key_prefixes(self, key: str) -> str:
        """Removes state_dict key prefixes added by model wrappers,
        e.g. "module.model.encoder.weight" -> "encoder.weight"
        ("module." from DDP, "model." from UnitYTrainWrapper)."""
        to_strip = ["module.", "model."]
        for prefix in to_strip:
            if key.startswith(prefix):
                key = key[len(prefix) :]
        return key

    def _get_state(self) -> Dict[str, Any]:
        model_state_dict = self.model.state_dict()
        model_state_dict = {
            self._strip_state_key_prefixes(key): value
            for key, value in model_state_dict.items()
        }
        return model_state_dict

    def _get_chck_path(self) -> str:
        ts = str(int(time.time()))
        epoch = str(self.epoch_idx).zfill(3)
        update = str(self.update_idx).zfill(6)
        eval_loss = f"{self.last_eval_loss:.4f}"
        name = f"{ts}_{epoch}_{update}_{eval_loss}.pt"
        return os.path.join(self.chck_save_dir, name)

    def _get_best_checkpoint_link_path(self) -> str:
        return os.path.join(self.chck_save_dir, self.CHECKPOINT_BEST)

    def get_best_checkpoint_path(self) -> str:
        return os.path.realpath(self._get_best_checkpoint_link_path())

    def _save_model(self) -> None:
        if dist_utils.is_main_process():
            state_dict = self._get_state()
            save_path = self._get_chck_path()
            logger.info(f"Saving checkpoint to {save_path}")
            torch.save(state_dict, save_path)
            if self.is_best_state:
                best_link_path = self._get_best_checkpoint_link_path()
                if os.path.exists(best_link_path):
                    os.unlink(best_link_path)
                os.symlink(save_path, best_link_path)
                logger.info(
                    f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}"
                )
        if dist_utils.is_dist_initialized():
            dist.barrier()

    def run(self) -> None:
        logger.info("Start training")
        self._reset_stats()
        self._eval_model()
        while self.epoch_idx < self.params.max_epochs and self.patience_left:
            for train_batch in self.train_data_loader.iterate_batches():
                self._train_step(batch=train_batch)
                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
                    self._eval_model()
                    if self.is_best_state:
                        self._save_model()
                    elif not self.patience_left:
                        no_improve_steps = self.params.eval_steps * self.params.patience
                        logger.info(
                            f"Early termination, as eval loss did not improve over last {no_improve_steps} updates"
                        )
                        break
                self.update_idx += 1
            self.train_data_loader.reset()
            self.epoch_idx += 1
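
# A minimal usage sketch (hedged): the surrounding training entry point is expected
# to wire things up roughly as below. The TrainingParams field values, dataloader
# constructor arguments, and the checkpoint directory are illustrative assumptions,
# not verbatim APIs.
#
#   model: UnitYModel = ...  # built/loaded elsewhere and moved to `device`
#   trainer = UnitYTrainer(
#       model=model,
#       params=TrainingParams(...),  # learning_rate, warmup_steps, patience, ...
#       train_data_loader=dataloader.UnityDataLoader(...),
#       eval_data_loader=dataloader.UnityDataLoader(...),
#       chck_save_dir="/path/to/checkpoints",  # hypothetical path
#       device=torch.device("cuda"),
#   )
#   trainer.run()  # trains until max_epochs or patience is exhausted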