trainer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import logging
  7. from contextlib import contextmanager
  8. from dataclasses import dataclass
  9. from enum import Enum
  10. from pathlib import Path
  11. from typing import Optional, Tuple
  12. import dataloader
  13. import dist_utils
  14. import torch
  15. import torch.distributed as dist
  16. import torch.nn as nn
  17. from fairseq2.models.sequence import SequenceModelOutput
  18. from fairseq2.models.unity import UnitYModel
  19. from fairseq2.optim.lr_scheduler import MyleLR
  20. from fairseq2.typing import Device
  21. from torch.optim import Adam
  22. logger = logging.getLogger(__name__)
  23. class FinetuneMode(Enum):
  24. SPEECH_TO_SPEECH = "SPEECH_TO_SPEECH"
  25. SPEECH_TO_TEXT = "SPEECH_TO_TEXT"
  26. TEXT_TO_SPEECH = "TEXT_TO_SPEECH"
  27. @dataclass
  28. class FinetuneParams:
  29. save_model_path: Path
  30. """Path were to save finetuned model."""
  31. finetune_mode: FinetuneMode = FinetuneMode.TEXT_TO_SPEECH
  32. """Allows to freeze S2T or T2U part of the model"""
  33. max_epochs: int = 10
  34. """ Maximum number of trainign epochs"""
  35. label_smoothing: float = 0.2
  36. """ Label smoothing coefficient for nll_loss """
  37. warmup_steps: int = 100
  38. """ Number of steps with linearly increasing LR"""
  39. log_steps: int = 10
  40. """ Log inner loss after each `log_steps` training steps"""
  41. eval_steps: int = 50
  42. """ Get eval loss after each `eval_steps` training steps """
  43. patience: int = 3
  44. """ Terminate if eval loss did not improve
  45. over the last `patience * eval_steps` training steps"""
  46. learning_rate: float = 1e-5
  47. """ Optimizer learining rate """
  48. train_batch_size: int = 5
  49. """The batch size during train steps"""
  50. eval_batch_size: int = 5
  51. """The batch size during evaluation."""
  52. device: Device = torch.device("cuda")
  53. """ Where to run computation"""
  54. class UnitYFinetuneWrapper(nn.Module):
  55. """Convenience wrapper that does a forward pass
  56. and returns S2T and T2U logits"""
  57. def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
  58. super().__init__()
  59. assert model.t2u_model is not None
  60. self.model: UnitYModel = model
  61. self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
  62. self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
  63. self.device = device
  64. def forward(
  65. self, batch: dataloader.MultimodalSeqsBatch
  66. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  67. assert self.model.t2u_model is not None
  68. dummy_context = contextmanager(lambda: iter([None]))()
  69. with torch.no_grad() if self.freeze_s2t else dummy_context: # type:ignore
  70. assert batch.speech_to_text.src_tokens is not None
  71. speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
  72. seqs=batch.speech_to_text.src_tokens.to(self.device),
  73. seq_lens=batch.speech_to_text.src_lengths.to(self.device),
  74. )
  75. assert batch.speech_to_text.prev_output_tokens is not None
  76. text_decoder_out, text_decoder_padding_mask = self.model.decode(
  77. seqs=batch.speech_to_text.prev_output_tokens.to(self.device),
  78. seq_lens=batch.speech_to_text.target_lengths.to(self.device),
  79. encoder_output=speech_encoder_out,
  80. encoder_padding_mask=speech_encoder_padding_mask,
  81. )
  82. text_logits = self.model.final_proj(text_decoder_out)
  83. if batch.text_to_units.prev_output_tokens is None:
  84. return (text_logits, None)
  85. dummy_context = contextmanager(lambda: iter([None]))()
  86. with torch.no_grad() if self.freeze_t2u else dummy_context: # type:ignore
  87. (
  88. unit_encoder_out,
  89. unit_encoder_padding_mask,
  90. ) = self.model.t2u_model.encode(
  91. text_decoder_output=text_decoder_out,
  92. text_decoder_padding_mask=text_decoder_padding_mask,
  93. )
  94. unit_decoder_out, _ = self.model.t2u_model.decode(
  95. seqs=batch.text_to_units.prev_output_tokens.to(self.device),
  96. seq_lens=batch.text_to_units.target_lengths.to(self.device),
  97. encoder_output=unit_encoder_out,
  98. encoder_padding_mask=unit_encoder_padding_mask,
  99. )
  100. unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)
  101. return (text_logits, unit_logits)
  102. class CalcLoss:
  103. """Calculates negative log likelihood loss for S2T and T2U"""
  104. def __init__(
  105. self,
  106. label_smoothing: float,
  107. s2t_pad_idx: Optional[int],
  108. t2u_pad_idx: Optional[int],
  109. ):
  110. self.label_smoothing = label_smoothing
  111. self.s2t_pad_idx = s2t_pad_idx
  112. self.t2u_pad_idx = t2u_pad_idx
  113. def __call__(
  114. self,
  115. batch: dataloader.MultimodalSeqsBatch,
  116. text_logits: torch.Tensor,
  117. unit_logits: Optional[torch.Tensor],
  118. ) -> torch.Tensor:
  119. assert batch.speech_to_text.target_lengths is not None
  120. s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
  121. text_logits.device
  122. )
  123. s2t_loss = SequenceModelOutput(
  124. logits=text_logits, pad_idx=self.s2t_pad_idx
  125. ).compute_loss(
  126. targets=batch.speech_to_text.target_tokens.to(text_logits.device),
  127. ignore_prefix_size=1,
  128. label_smoothing=self.label_smoothing,
  129. )
  130. if unit_logits is None:
  131. return s2t_loss / s2t_numel
  132. assert batch.text_to_units.target_lengths is not None
  133. s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
  134. s2u_loss = SequenceModelOutput(
  135. logits=unit_logits, pad_idx=self.t2u_pad_idx
  136. ).compute_loss(
  137. targets=batch.text_to_units.target_tokens.to(unit_logits.device),
  138. ignore_prefix_size=1,
  139. label_smoothing=self.label_smoothing,
  140. )
  141. return s2t_loss / s2t_numel + s2u_loss / s2u_numel
  142. class LossCollector:
  143. """Aggregrates loss history across nodes"""
  144. def __init__(self, device: Optional[Device] = None, reduce_op: str = "avg"):
  145. self.n_samples: float = 0
  146. self.val_sum: float = 0.0
  147. self.reduce_op = reduce_op
  148. self.device = device
  149. self.is_distributed = dist_utils.is_dist_initialized()
  150. def reset(self) -> None:
  151. self.n_samples = 0
  152. self.val_sum = 0.0
  153. def update(self, n_samples: int, batch_loss: float) -> None:
  154. self.n_samples += n_samples
  155. self.val_sum += batch_loss
  156. def reduce(self) -> float:
  157. n_samples, val_sum = self._collect()
  158. if self.reduce_op == "avg":
  159. return val_sum / (n_samples + 1)
  160. if self.reduce_op == "sum":
  161. return val_sum
  162. raise ValueError()
  163. def _collect(self) -> Tuple[float, float]:
  164. if not self.is_distributed:
  165. return self.n_samples, self.val_sum
  166. local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
  167. all_vals = [
  168. torch.zeros((1, 2), device=self.device)
  169. for _ in range(dist_utils.get_world_size())
  170. ]
  171. dist.all_gather(all_vals, local_val)
  172. losses = torch.concat(all_vals, dim=0)
  173. reduced = torch.sum(losses, dim=0).reshape(2).cpu()
  174. return reduced[0].item(), reduced[1].item()
  175. class UnitYFinetune:
  176. def __init__(
  177. self,
  178. model: UnitYModel,
  179. params: FinetuneParams,
  180. train_data_loader: dataloader.UnitYDataLoader,
  181. eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
  182. ):
  183. self.params = params
  184. assert model.t2u_model is not None
  185. self.calc_loss = CalcLoss(
  186. label_smoothing=self.params.label_smoothing,
  187. s2t_pad_idx=model.pad_idx,
  188. t2u_pad_idx=model.t2u_model.pad_idx,
  189. )
  190. self.model = self._wrap_model_for_trainining(model=model)
  191. self.train_data_loader = train_data_loader
  192. self.eval_data_loader = eval_data_loader
  193. self.optimizer = Adam(
  194. params=self.model.parameters(),
  195. lr=self.params.learning_rate,
  196. betas=(0.9, 0.98),
  197. eps=1e-08,
  198. maximize=False,
  199. weight_decay=0.0,
  200. fused=True,
  201. )
  202. self.grad_scaler = torch.cuda.amp.GradScaler()
  203. self.lr_scheduler = MyleLR(
  204. optimizer=self.optimizer,
  205. num_warmup_steps=self.params.warmup_steps,
  206. start_lr=1e-9,
  207. )
  208. self.train_loss_hist = LossCollector(device=params.device)
  209. self.epoch_idx: int = 0
  210. self.update_idx: int = 0
  211. self.patience_left: int = self.params.patience
  212. self.best_eval_loss: Optional[float] = None
  213. self.is_best_state: bool = False
  214. def _reset_stats(self) -> None:
  215. self.train_loss_hist.reset()
  216. self.epoch_idx = 0
  217. self.update_idx = 0
  218. self.patience_left = self.params.patience
  219. self.best_eval_loss = None
  220. self.is_best_state = False
  221. def _wrap_model_for_trainining(self, model: UnitYModel) -> nn.Module:
  222. wrapped_model = UnitYFinetuneWrapper(
  223. model=model, mode=self.params.finetune_mode, device=self.params.device
  224. )
  225. if not dist_utils.is_dist_initialized():
  226. return wrapped_model
  227. return nn.parallel.DistributedDataParallel(
  228. wrapped_model,
  229. device_ids=[dist_utils.get_local_rank()],
  230. find_unused_parameters=True,
  231. )
  232. def _update_eval_stats(self, eval_loss: float) -> None:
  233. self.is_best_state = (
  234. self.best_eval_loss is None or eval_loss < self.best_eval_loss
  235. )
  236. self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
  237. self.patience_left = (
  238. self.params.patience if self.is_best_state else self.patience_left - 1
  239. )
  240. logger.info(
  241. f"Eval after {self.update_idx} updates: "
  242. f"loss={eval_loss:.4f} "
  243. f"best_loss={self.best_eval_loss:.4f} "
  244. f"patience_steps_left={self.patience_left}"
  245. )
  246. def _eval_model(self) -> None:
  247. """Calc avg loss on eval dataset and update evaluation stats"""
  248. if self.eval_data_loader is None:
  249. return
  250. logger.info("Run evaluation")
  251. loss_hist = LossCollector(device=self.params.device)
  252. self.model.eval()
  253. with torch.no_grad():
  254. for batch in self.eval_data_loader.get_dataloader():
  255. assert batch.speech_to_text.src_tokens is not None
  256. loss = self.calc_loss(batch, *self.model(batch))
  257. if loss.isnan():
  258. logger.warning("Eval loss value is NaN, setting to inf")
  259. loss_val = float("Inf")
  260. else:
  261. loss_val = loss.item()
  262. del batch # force memory release
  263. loss_hist.update(1, loss_val)
  264. eval_loss = loss_hist.reduce()
  265. self._update_eval_stats(eval_loss)
  266. def _train_step_log(self):
  267. """Log train stats"""
  268. if (self.update_idx + 1) % self.params.log_steps == 0:
  269. avg_loss = self.train_loss_hist.reduce()
  270. self.train_loss_hist.reset()
  271. logger.info(
  272. f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
  273. f"update {str(self.update_idx + 1).zfill(5)}: "
  274. f"train loss={avg_loss:.4f} "
  275. f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
  276. )
  277. def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
  278. """Run one train step"""
  279. self.model.train()
  280. self.optimizer.zero_grad()
  281. tokens, units = self.model(batch)
  282. loss = self.calc_loss(batch, tokens, units)
  283. self.grad_scaler.scale(loss).backward()
  284. self.grad_scaler.step(self.optimizer)
  285. self.grad_scaler.update()
  286. self.lr_scheduler.step()
  287. assert batch.speech_to_text.src_tokens is not None
  288. self.train_loss_hist.update(1, loss.item())
  289. self._train_step_log()
  290. def _save_model(self):
  291. logger.info("Saving model")
  292. if dist_utils.is_main_process():
  293. state_dict = {
  294. key.replace("module.model.", ""): value
  295. for key, value in self.model.state_dict().items()
  296. }
  297. torch.save(state_dict, self.params.save_model_path)
  298. if dist_utils.is_dist_initialized():
  299. dist.barrier()
  300. def run(self):
  301. logger.info("Start finetuning")
  302. self._reset_stats()
  303. self._eval_model()
  304. batch_itr = self.train_data_loader.get_dataloader()
  305. while self.epoch_idx < self.params.max_epochs and self.patience_left:
  306. for train_batch in batch_itr:
  307. self._train_step(batch=train_batch)
  308. if self.update_idx and self.update_idx % self.params.eval_steps == 0:
  309. self._eval_model()
  310. if self.is_best_state:
  311. self._save_model()
  312. elif not self.patience_left:
  313. no_improve_steps = self.params.eval_steps * self.params.patience
  314. logger.info(
  315. "Early termination, as eval loss did not improve "
  316. f"over last {no_improve_steps} updates"
  317. )
  318. break
  319. self.update_idx += 1
  320. self.epoch_idx += 1