trainer.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from fairseq2.models.sequence import SequenceModelOutput
from fairseq2.nn.padding import PaddingMask
from fairseq2.optim.lr_scheduler import MyleLR
from torch.optim import Adam

from m4t_scripts.train import dataloader, dist_utils
from m4t_scripts.train.configs import TrainingParams
from seamless_communication.models.unity import UnitYModel, UnitYT2UModel

logger = logging.getLogger(__name__)


class UnitYTrainWrapper(nn.Module):
    """Convenience wrapper that does a forward pass
    and returns S2T and T2U logits"""

    def __init__(self, model: UnitYModel):
        super().__init__()
        self.model: UnitYModel = model
        if isinstance(self.model.t2u_model, UnitYT2UModel):
            self.t2u: UnitYT2UModel = self.model.t2u_model
        else:
            raise NotImplementedError(
                "UnitYTrainWrapper supports only instances of UnitYT2UModel as t2u"
            )

    def forward(
        self, batch: dataloader.MultimodalSeqsBatch
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass, computes S2T and T2U logits"""
        assert self.model.t2u_model is not None
        assert batch.speech_to_text.src_tokens is not None
        # s2t: speech encoder -> text decoder -> text logits
        speech_padding_mask = PaddingMask(
            seq_lens=batch.speech_to_text.src_lengths,
            batch_seq_len=int(torch.max(batch.speech_to_text.src_lengths).item()),
        )
        speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
            seqs=batch.speech_to_text.src_tokens,
            padding_mask=speech_padding_mask,
        )
        assert batch.speech_to_text.prev_output_tokens is not None
        s2t_prev_out_tokens_padding_mask = PaddingMask(
            seq_lens=batch.speech_to_text.target_lengths,
            batch_seq_len=int(torch.max(batch.speech_to_text.target_lengths).item()),
        )
        text_decoder_out, text_decoder_padding_mask = self.model.decode(
            seqs=batch.speech_to_text.prev_output_tokens,
            padding_mask=s2t_prev_out_tokens_padding_mask,
            encoder_output=speech_encoder_out,
            encoder_padding_mask=speech_encoder_padding_mask,
        )
        text_logits = self.model.final_proj(text_decoder_out)
        # t2u: text decoder states -> unit encoder -> unit decoder -> unit logits
        (
            unit_encoder_out,
            unit_encoder_padding_mask,
        ) = self.t2u.encode(
            text_decoder_output=text_decoder_out,
            text_decoder_padding_mask=text_decoder_padding_mask,
        )
        t2u_prev_out_tokens_padding_mask = PaddingMask(
            seq_lens=batch.text_to_units.target_lengths,
            batch_seq_len=int(torch.max(batch.text_to_units.target_lengths).item()),
        )
        unit_decoder_out, _ = self.t2u.decode(
            seqs=batch.text_to_units.prev_output_tokens,
            padding_mask=t2u_prev_out_tokens_padding_mask,
            encoder_output=unit_encoder_out,
            encoder_padding_mask=unit_encoder_padding_mask,
        )
        unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)
        return (text_logits, unit_logits)
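

# Usage sketch (illustrative; mirrors what UnitYTrainer does further below):
#
#   wrapper = UnitYTrainWrapper(model=unity_model)     # unity_model: UnitYModel
#   text_logits, unit_logits = wrapper(batch)          # batch: dataloader.MultimodalSeqsBatch
#   loss = calc_loss(batch, text_logits, unit_logits)  # calc_loss: CalcLoss (defined next)
#
# With fairseq2's batch-first convention, text_logits has shape
# (batch, s2t_target_len, text_vocab_size) and unit_logits has shape
# (batch, t2u_target_len, unit_vocab_size).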


class CalcLoss:
    """Calculates per-token negative log likelihood loss for S2T and T2U"""

    def __init__(
        self,
        label_smoothing: float,
        s2t_pad_idx: Optional[int],
        t2u_pad_idx: Optional[int],
        s2t_skip_langtok_loss: bool = False,
    ):
        self.label_smoothing = label_smoothing
        self.s2t_pad_idx = s2t_pad_idx
        self.t2u_pad_idx = t2u_pad_idx
        # optionally skip the leading target-language token in the S2T loss;
        # the T2U targets always have a single prefix token excluded from the loss
        self.s2t_ignore_prefix_size = 1 if s2t_skip_langtok_loss else 0
        self.t2u_ignore_prefix_size = 1

    def __call__(
        self,
        batch: dataloader.MultimodalSeqsBatch,
        text_logits: torch.Tensor,
        unit_logits: torch.Tensor,
    ) -> torch.Tensor:
        assert batch.speech_to_text.target_lengths is not None
        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
            text_logits.device
        )
        s2t_loss = SequenceModelOutput(
            logits=text_logits, pad_idx=self.s2t_pad_idx
        ).compute_loss(
            targets=batch.speech_to_text.target_tokens.to(text_logits.device),
            ignore_prefix_size=self.s2t_ignore_prefix_size,
            label_smoothing=self.label_smoothing,
        )
        assert batch.text_to_units.target_lengths is not None
        s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
        s2u_loss = SequenceModelOutput(
            logits=unit_logits, pad_idx=self.t2u_pad_idx
        ).compute_loss(
            targets=batch.text_to_units.target_tokens.to(unit_logits.device),
            ignore_prefix_size=self.t2u_ignore_prefix_size,
            label_smoothing=self.label_smoothing,
        )
        # per-token NLL of each task, weighted equally
        return s2t_loss / s2t_numel + s2u_loss / s2u_numel
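

# Worked example for the combined loss (illustrative numbers; assumes
# compute_loss returns the NLL summed over target positions, which the
# division by token counts above implies): if the S2T loss sums to 120.0 over
# 40 target tokens and the T2U loss sums to 300.0 over 100 target units, the
# returned value is 120.0 / 40 + 300.0 / 100 = 3.0 + 3.0 = 6.0, i.e. both
# tasks contribute their per-token loss with equal weight.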


class LossCollector:
    """Aggregates loss history across nodes"""

    def __init__(self, device: Optional[torch.device] = None, reduce_op: str = "avg"):
        self.n_samples: float = 0
        self.val_sum: float = 0.0
        self.reduce_op = reduce_op
        self.device = device
        self.is_distributed = dist_utils.is_dist_initialized()

    def reset(self) -> None:
        self.n_samples = 0
        self.val_sum = 0.0

    def update(self, n_samples: int, batch_loss: float) -> None:
        self.n_samples += n_samples
        self.val_sum += batch_loss

    def reduce(self) -> float:
        n_samples, val_sum = self._collect()
        if self.reduce_op == "avg":
            # +1 keeps the division defined when nothing has been collected yet
            return val_sum / (n_samples + 1)
        if self.reduce_op == "sum":
            return val_sum
        raise ValueError(f"Unknown reduce op: {self.reduce_op}")

    def _collect(self) -> Tuple[float, float]:
        if not self.is_distributed:
            return self.n_samples, self.val_sum
        # gather the per-rank [n_samples, val_sum] pairs and sum them across ranks
        local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
        all_vals = [
            torch.zeros((1, 2), device=self.device)
            for _ in range(dist_utils.get_world_size())
        ]
        dist.all_gather(all_vals, local_val)
        losses = torch.concat(all_vals, dim=0)
        reduced = torch.sum(losses, dim=0).reshape(2).cpu()
        return reduced[0].item(), reduced[1].item()
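

# Reduction example (illustrative): with two ranks where rank 0 has called
# update(2, 1.0) and rank 1 has called update(2, 3.0), _collect() all-gathers
# the per-rank [n_samples, val_sum] pairs into [[2, 1.0], [2, 3.0]] and sums
# them to (4.0, 4.0); reduce() with the default "avg" op then returns
# 4.0 / (4 + 1) = 0.8.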


class UnitYTrainer:
    CHECKPOINT_BEST = "checkpoint_best.pt"

    def __init__(
        self,
        model: UnitYModel,
        params: TrainingParams,
        train_data_loader: dataloader.UnityDataLoader,
        eval_data_loader: Optional[dataloader.UnityDataLoader],
        chck_save_dir: str,
        device: torch.device,
    ):
        self.params = params
        self.device = device
        self.float_dtype = self._get_float_dtype(self.params.float_dtype)
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.chck_save_dir = chck_save_dir
        assert model.t2u_model is not None
        self.calc_loss = CalcLoss(
            label_smoothing=self.params.label_smoothing,
            s2t_pad_idx=model.pad_idx,
            t2u_pad_idx=model.t2u_model.pad_idx,
        )
        self._try_load_checkpoint(model=model)
        self.model = self._wrap_model_for_training(model=model)
        # TODO: make tweakable
        self.optimizer = Adam(
            params=self.model.parameters(),
            lr=self.params.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-08,
            maximize=False,
            weight_decay=0.0,
            fused=True,
        )
        # gradient scaling is only needed for fp16 mixed precision
        self.grad_scaler = torch.cuda.amp.GradScaler() if self.float_dtype == torch.float16 else None  # type: ignore
        # TODO: allow scheduler selection
        self.lr_scheduler = MyleLR(
            optimizer=self.optimizer,
            num_warmup_steps=self.params.warmup_steps,
            start_lr=self.params.start_learning_rate,
        )
        self.train_loss_hist = LossCollector(device=self.device)
        self.epoch_idx: int = 0
        self.update_idx: int = 0
        self.patience_left: int = self.params.patience
        self.last_eval_loss: Optional[float] = None
        self.best_eval_loss: Optional[float] = None
        self.is_best_state: bool = False
        self.batch_sizes: List[int] = []
        self.gpu_usage: List[float] = []

    def _try_load_checkpoint(self, model: torch.nn.Module):
        chck_path = self.get_best_checkpoint_path()
        if os.path.exists(chck_path):
            logger.info(f"Loading state dict from {chck_path}")
            state_dict = torch.load(chck_path)
            model.load_state_dict(state_dict)

    @classmethod
    def _get_float_dtype(cls, float_dtype: str) -> torch.dtype:
        if float_dtype == "fp16":
            return torch.float16
        elif float_dtype == "fp32":
            return torch.float32
        elif float_dtype == "bf16":
            return torch.bfloat16
        else:
            raise ValueError(f"Unknown dtype literal: {float_dtype}")

    def _reset_stats(self) -> None:
        self.train_loss_hist.reset()
        self.epoch_idx = 0
        self.update_idx = 0
        self.patience_left = self.params.patience
        self.last_eval_loss = None
        self.best_eval_loss = None
        self.is_best_state = False
        self._reset_log_stats()

    def _reset_log_stats(self) -> None:
        self.batch_sizes.clear()
        self.gpu_usage.clear()
        self.ts = time.time()
        self.last_update_idx = self.update_idx

    def _record_gpu_usage(self) -> None:
        # reserved CUDA memory in GiB
        gb = (torch.cuda.memory_reserved(self.device) >> 20) / 1024.0
        self.gpu_usage.append(gb)

    def _get_avg_bsz(self) -> float:
        """Avg training batch size"""
        return (
            sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0.0
        )

    def _get_ups(self) -> float:
        """Updates per second"""
        ts_delta = time.time() - self.ts
        return (self.update_idx - self.last_update_idx) / ts_delta

    def _get_avg_gpu_usage(self) -> float:
        return sum(self.gpu_usage) / len(self.gpu_usage) if self.gpu_usage else 0.0

    def _wrap_model_for_training(self, model: UnitYModel) -> nn.Module:
        wrapped_model = UnitYTrainWrapper(model=model)
        if not dist_utils.is_dist_initialized():
            return wrapped_model
        return nn.parallel.DistributedDataParallel(
            wrapped_model,
            device_ids=[dist_utils.get_local_rank()],
            find_unused_parameters=True,
        )

    def _update_eval_stats(self, eval_loss: float) -> None:
        self.last_eval_loss = eval_loss
        self.is_best_state = (
            self.best_eval_loss is None or eval_loss < self.best_eval_loss
        )
        self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
        self.patience_left = (
            self.params.patience if self.is_best_state else self.patience_left - 1
        )
        logger.info(
            f"Eval after {self.update_idx} updates: "
            f"loss={eval_loss:.4f} "
            f"best_loss={self.best_eval_loss:.4f} "
            f"patience_steps_left={self.patience_left}"
        )

    def _eval_model(self) -> None:
        """Calc avg loss on the eval dataset and update evaluation stats"""
        if self.eval_data_loader is None:
            return
        logger.info("Run evaluation")
        loss_hist = LossCollector(device=self.device)
        self.model.eval()
        with torch.no_grad():
            self.eval_data_loader.reset()
            for batch in self.eval_data_loader.iterate_batches():
                assert batch.speech_to_text.src_tokens is not None
                loss = self.calc_loss(batch, *self.model(batch))
                if loss.isnan():
                    logger.warning("Eval loss value is NaN, setting to inf")
                    loss_val = float("Inf")
                else:
                    loss_val = loss.item()
                self._release_memory(batch)
                loss_hist.update(1, loss_val)
        eval_loss = loss_hist.reduce()
        self._update_eval_stats(eval_loss)

    def _train_step_log(self):
        """Log train stats"""
        if (self.update_idx + 1) % self.params.log_steps == 0:
            avg_loss = self.train_loss_hist.reduce()
            self.train_loss_hist.reset()
            logger.info(
                f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
                f"update {str(self.update_idx + 1).zfill(5)}: "
                f"train loss={avg_loss:.4f} "
                f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E} "
                f"bsz_avg={self._get_avg_bsz():.1f} "
                f"ups={self._get_ups():.2f} "
                f"gpu_avg={self._get_avg_gpu_usage():.2f}GiB"
            )
            self._reset_log_stats()

    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
        """Run one train step"""
        self.model.train()
        self.optimizer.zero_grad()
        tokens, units = self.model(batch)
        loss = self.calc_loss(batch, tokens, units)
        # record GPU memory usage; reserved memory approximates the peak
        self._record_gpu_usage()
        if self.grad_scaler is not None:
            self.grad_scaler.scale(loss).backward()  # type: ignore
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss.backward()
            self.optimizer.step()
        self.lr_scheduler.step()
        assert batch.speech_to_text.src_tokens is not None
        self.train_loss_hist.update(1, loss.item())
        self.batch_sizes.append(batch.speech_to_text.src_tokens.shape[0])
        self._train_step_log()
        self._release_memory(batch)

    def _release_memory(self, batch: dataloader.MultimodalSeqsBatch) -> None:
        """Explicitly release large memory consumers"""
        del batch

    def _strip_state_key_prefixes(self, key: str) -> str:
        """Removes state_dict key prefixes associated with model wrappers"""
        to_strip = ["module.", "model."]
        for prefix in to_strip:
            if key.startswith(prefix):
                key = key[len(prefix):]
        return key

    def _get_state(self) -> Dict[str, Any]:
        model_state_dict = self.model.state_dict()
        model_state_dict = {
            self._strip_state_key_prefixes(key): value
            for key, value in model_state_dict.items()
        }
        return model_state_dict

    def _get_chck_path(self) -> str:
        ts = str(int(time.time()))
        epoch = str(self.epoch_idx).zfill(3)
        update = str(self.update_idx).zfill(6)
        eval_loss = f"{self.last_eval_loss:.4f}"
        name = f"{ts}_{epoch}_{update}_{eval_loss}.pt"
        return os.path.join(self.chck_save_dir, name)

    def _get_best_checkpoint_link_path(self) -> str:
        return os.path.join(self.chck_save_dir, self.CHECKPOINT_BEST)

    def get_best_checkpoint_path(self) -> str:
        return os.path.realpath(self._get_best_checkpoint_link_path())

    def _save_model(self):
        if dist_utils.is_main_process():
            state_dict = self._get_state()
            save_path = self._get_chck_path()
            logger.info(f"Saving checkpoint to {save_path}")
            torch.save(state_dict, save_path)
            if self.is_best_state:
                best_link_path = self._get_best_checkpoint_link_path()
                if os.path.exists(best_link_path):
                    os.unlink(best_link_path)
                os.symlink(save_path, best_link_path)
                logger.info(
                    f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}"
                )
        if dist_utils.is_dist_initialized():
            dist.barrier()

    def run(self):
        logger.info("Start training")
        self._reset_stats()
        self._eval_model()
        while self.epoch_idx < self.params.max_epochs and self.patience_left:
            for train_batch in self.train_data_loader.iterate_batches():
                self._train_step(batch=train_batch)
                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
                    self._eval_model()
                    if self.is_best_state:
                        self._save_model()
                    elif not self.patience_left:
                        no_improve_steps = self.params.eval_steps * self.params.patience
                        logger.info(
                            f"Early termination, as eval loss did not improve over last {no_improve_steps} updates"
                        )
                        break
                self.update_idx += 1
            self.train_data_loader.reset()
            self.epoch_idx += 1
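

# Wiring sketch (hypothetical; the actual entry point lives outside this module).
# It assumes TrainingParams is a dataclass accepting the fields read in this file
# as keyword arguments; the values are illustrative, and the data-loader
# construction is elided because its signature is defined in
# m4t_scripts.train.dataloader:
#
#   params = TrainingParams(
#       float_dtype="bf16",
#       label_smoothing=0.1,
#       learning_rate=1e-4,
#       start_learning_rate=1e-7,
#       warmup_steps=1000,
#       patience=5,
#       max_epochs=10,
#       eval_steps=1000,
#       log_steps=100,
#   )
#   trainer = UnitYTrainer(
#       model=unity_model,                # a loaded UnitYModel
#       params=params,
#       train_data_loader=train_loader,   # dataloader.UnityDataLoader
#       eval_data_loader=eval_loader,     # dataloader.UnityDataLoader or None
#       chck_save_dir="/path/to/checkpoints",
#       device=torch.device("cuda"),
#   )
#   trainer.run()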