trainer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
import logging
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from fairseq2.models.sequence import SequenceModelOutput
from fairseq2.optim.lr_scheduler import MyleLR
from fairseq2.typing import Device
from torch.optim import Adam

from m4t_scripts.finetune import dataloader, dist_utils
from seamless_communication.models.unity import UnitYModel
  21. logger = logging.getLogger(__name__)
  22. class FinetuneMode(Enum):
  23. SPEECH_TO_SPEECH = "SPEECH_TO_SPEECH"
  24. SPEECH_TO_TEXT = "SPEECH_TO_TEXT"
  25. TEXT_TO_SPEECH = "TEXT_TO_SPEECH"
  26. @dataclass
  27. class FinetuneParams:
  28. save_model_path: Path
  29. """Path were to save finetuned model."""
  30. finetune_mode: FinetuneMode = FinetuneMode.TEXT_TO_SPEECH
  31. """Allows to freeze S2T or T2U part of the model"""
  32. max_epochs: int = 10
  33. """ Maximum number of trainign epochs"""
  34. label_smoothing: float = 0.2
  35. """ Label smoothing coefficient for nll_loss """
  36. warmup_steps: int = 100
  37. """ Number of steps with linearly increasing LR"""
  38. log_steps: int = 10
  39. """ Log inner loss after each `log_steps` training steps"""
  40. eval_steps: int = 50
  41. """ Get eval loss after each `eval_steps` training steps """
  42. patience: int = 3
  43. """ Terminate if eval loss did not improve
  44. over the last `patience * eval_steps` training steps"""
  45. learning_rate: float = 1e-5
  46. """ Optimizer learining rate """
  47. train_batch_size: int = 5
  48. """The batch size during train steps"""
  49. eval_batch_size: int = 5
  50. """The batch size during evaluation."""
  51. device: Device = torch.device("cuda")
  52. """ Where to run computation"""
  53. class UnitYFinetuneWrapper(nn.Module):
  54. """Convenience wrapper that does a forward pass
  55. and returns S2T and T2U logits"""
  56. def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
  57. super().__init__()
  58. assert model.t2u_model is not None
  59. self.model: UnitYModel = model
  60. self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
  61. self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
  62. self.device = device
  63. def forward(
  64. self, batch: dataloader.MultimodalSeqsBatch
  65. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  66. assert self.model.t2u_model is not None
  67. dummy_context = contextmanager(lambda: iter([None]))()
  68. with torch.no_grad() if self.freeze_s2t else dummy_context: # type:ignore
  69. assert batch.speech_to_text.src_tokens is not None
  70. speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
  71. seqs=batch.speech_to_text.src_tokens.to(self.device),
  72. seq_lens=batch.speech_to_text.src_lengths.to(self.device),
  73. )
  74. assert batch.speech_to_text.prev_output_tokens is not None
  75. text_decoder_out, text_decoder_padding_mask = self.model.decode(
  76. seqs=batch.speech_to_text.prev_output_tokens.to(self.device),
  77. seq_lens=batch.speech_to_text.target_lengths.to(self.device),
  78. encoder_output=speech_encoder_out,
  79. encoder_padding_mask=speech_encoder_padding_mask,
  80. )
  81. text_logits = self.model.final_proj(text_decoder_out)
  82. if batch.text_to_units.prev_output_tokens is None:
  83. return (text_logits, None)
  84. dummy_context = contextmanager(lambda: iter([None]))()
  85. with torch.no_grad() if self.freeze_t2u else dummy_context: # type:ignore
  86. (
  87. unit_encoder_out,
  88. unit_encoder_padding_mask,
  89. ) = self.model.t2u_model.encode(
  90. text_decoder_output=text_decoder_out,
  91. text_decoder_padding_mask=text_decoder_padding_mask,
  92. )
  93. unit_decoder_out, _ = self.model.t2u_model.decode(
  94. seqs=batch.text_to_units.prev_output_tokens.to(self.device),
  95. seq_lens=batch.text_to_units.target_lengths.to(self.device),
  96. encoder_output=unit_encoder_out,
  97. encoder_padding_mask=unit_encoder_padding_mask,
  98. )
  99. unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)
  100. return (text_logits, unit_logits)
  101. class CalcLoss:
  102. """Calculates negative log likelihood loss for S2T and T2U"""
  103. def __init__(
  104. self,
  105. label_smoothing: float,
  106. s2t_pad_idx: Optional[int],
  107. t2u_pad_idx: Optional[int],
  108. ):
  109. self.label_smoothing = label_smoothing
  110. self.s2t_pad_idx = s2t_pad_idx
  111. self.t2u_pad_idx = t2u_pad_idx
  112. def __call__(
  113. self,
  114. batch: dataloader.MultimodalSeqsBatch,
  115. text_logits: torch.Tensor,
  116. unit_logits: Optional[torch.Tensor],
  117. ) -> torch.Tensor:
  118. assert batch.speech_to_text.target_lengths is not None
  119. s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(
  120. text_logits.device
  121. )
  122. s2t_loss = SequenceModelOutput(
  123. logits=text_logits, pad_idx=self.s2t_pad_idx
  124. ).compute_loss(
  125. targets=batch.speech_to_text.target_tokens.to(text_logits.device),
  126. ignore_prefix_size=1,
  127. label_smoothing=self.label_smoothing,
  128. )
  129. if unit_logits is None:
  130. return s2t_loss / s2t_numel
  131. assert batch.text_to_units.target_lengths is not None
  132. s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
  133. s2u_loss = SequenceModelOutput(
  134. logits=unit_logits, pad_idx=self.t2u_pad_idx
  135. ).compute_loss(
  136. targets=batch.text_to_units.target_tokens.to(unit_logits.device),
  137. ignore_prefix_size=1,
  138. label_smoothing=self.label_smoothing,
  139. )
  140. return s2t_loss / s2t_numel + s2u_loss / s2u_numel
  141. class LossCollector:
  142. """Aggregrates loss history across nodes"""
  143. def __init__(self, device: Optional[Device] = None, reduce_op: str = "avg"):
  144. self.n_samples: float = 0
  145. self.val_sum: float = 0.0
  146. self.reduce_op = reduce_op
  147. self.device = device
  148. self.is_distributed = dist_utils.is_dist_initialized()
  149. def reset(self) -> None:
  150. self.n_samples = 0
  151. self.val_sum = 0.0
  152. def update(self, n_samples: int, batch_loss: float) -> None:
  153. self.n_samples += n_samples
  154. self.val_sum += batch_loss
  155. def reduce(self) -> float:
  156. n_samples, val_sum = self._collect()
  157. if self.reduce_op == "avg":
  158. return val_sum / (n_samples + 1)
  159. if self.reduce_op == "sum":
  160. return val_sum
  161. raise ValueError()
  162. def _collect(self) -> Tuple[float, float]:
  163. if not self.is_distributed:
  164. return self.n_samples, self.val_sum
  165. local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
  166. all_vals = [
  167. torch.zeros((1, 2), device=self.device)
  168. for _ in range(dist_utils.get_world_size())
  169. ]
  170. dist.all_gather(all_vals, local_val)
  171. losses = torch.concat(all_vals, dim=0)
  172. reduced = torch.sum(losses, dim=0).reshape(2).cpu()
  173. return reduced[0].item(), reduced[1].item()
  174. class UnitYFinetune:
  175. def __init__(
  176. self,
  177. model: UnitYModel,
  178. params: FinetuneParams,
  179. train_data_loader: dataloader.UnitYDataLoader,
  180. eval_data_loader: Optional[dataloader.UnitYDataLoader] = None,
  181. ):
  182. self.params = params
  183. assert model.t2u_model is not None
  184. self.calc_loss = CalcLoss(
  185. label_smoothing=self.params.label_smoothing,
  186. s2t_pad_idx=model.pad_idx,
  187. t2u_pad_idx=model.t2u_model.pad_idx,
  188. )
  189. self.model = self._wrap_model_for_trainining(model=model)
  190. self.train_data_loader = train_data_loader
  191. self.eval_data_loader = eval_data_loader
  192. self.optimizer = Adam(
  193. params=self.model.parameters(),
  194. lr=self.params.learning_rate,
  195. betas=(0.9, 0.98),
  196. eps=1e-08,
  197. maximize=False,
  198. weight_decay=0.0,
  199. fused=True,
  200. )
  201. self.grad_scaler = torch.cuda.amp.GradScaler()
  202. self.lr_scheduler = MyleLR(
  203. optimizer=self.optimizer,
  204. num_warmup_steps=self.params.warmup_steps,
  205. start_lr=1e-9,
  206. )
  207. self.train_loss_hist = LossCollector(device=params.device)
  208. self.epoch_idx: int = 0
  209. self.update_idx: int = 0
  210. self.patience_left: int = self.params.patience
  211. self.best_eval_loss: Optional[float] = None
  212. self.is_best_state: bool = False
  213. def _reset_stats(self) -> None:
  214. self.train_loss_hist.reset()
  215. self.epoch_idx = 0
  216. self.update_idx = 0
  217. self.patience_left = self.params.patience
  218. self.best_eval_loss = None
  219. self.is_best_state = False
  220. def _wrap_model_for_trainining(self, model: UnitYModel) -> nn.Module:
  221. wrapped_model = UnitYFinetuneWrapper(
  222. model=model, mode=self.params.finetune_mode, device=self.params.device
  223. )
  224. if not dist_utils.is_dist_initialized():
  225. return wrapped_model
  226. return nn.parallel.DistributedDataParallel(
  227. wrapped_model,
  228. device_ids=[dist_utils.get_local_rank()],
  229. find_unused_parameters=True,
  230. )
  231. def _update_eval_stats(self, eval_loss: float) -> None:
  232. self.is_best_state = (
  233. self.best_eval_loss is None or eval_loss < self.best_eval_loss
  234. )
  235. self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
  236. self.patience_left = (
  237. self.params.patience if self.is_best_state else self.patience_left - 1
  238. )
  239. logger.info(
  240. f"Eval after {self.update_idx} updates: "
  241. f"loss={eval_loss:.4f} "
  242. f"best_loss={self.best_eval_loss:.4f} "
  243. f"patience_steps_left={self.patience_left}"
  244. )
  245. def _eval_model(self) -> None:
  246. """Calc avg loss on eval dataset and update evaluation stats"""
  247. if self.eval_data_loader is None:
  248. return
  249. logger.info("Run evaluation")
  250. loss_hist = LossCollector(device=self.params.device)
  251. self.model.eval()
  252. with torch.no_grad():
  253. for batch in self.eval_data_loader.get_dataloader():
  254. assert batch.speech_to_text.src_tokens is not None
  255. loss = self.calc_loss(batch, *self.model(batch))
  256. if loss.isnan():
  257. logger.warning("Eval loss value is NaN, setting to inf")
  258. loss_val = float("Inf")
  259. else:
  260. loss_val = loss.item()
  261. del batch # force memory release
  262. loss_hist.update(1, loss_val)
  263. eval_loss = loss_hist.reduce()
  264. self._update_eval_stats(eval_loss)
  265. def _train_step_log(self):
  266. """Log train stats"""
  267. if (self.update_idx + 1) % self.params.log_steps == 0:
  268. avg_loss = self.train_loss_hist.reduce()
  269. self.train_loss_hist.reset()
  270. logger.info(
  271. f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
  272. f"update {str(self.update_idx + 1).zfill(5)}: "
  273. f"train loss={avg_loss:.4f} "
  274. f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
  275. )
  276. def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
  277. """Run one train step"""
  278. self.model.train()
  279. self.optimizer.zero_grad()
  280. tokens, units = self.model(batch)
  281. loss = self.calc_loss(batch, tokens, units)
  282. self.grad_scaler.scale(loss).backward()
  283. self.grad_scaler.step(self.optimizer)
  284. self.grad_scaler.update()
  285. self.lr_scheduler.step()
  286. assert batch.speech_to_text.src_tokens is not None
  287. self.train_loss_hist.update(1, loss.item())
  288. self._train_step_log()
  289. def _save_model(self):
  290. logger.info("Saving model")
  291. if dist_utils.is_main_process():
  292. state_dict = {
  293. key.replace("module.model.", ""): value
  294. for key, value in self.model.state_dict().items()
  295. }
  296. torch.save(state_dict, self.params.save_model_path)
  297. if dist_utils.is_dist_initialized():
  298. dist.barrier()
  299. def run(self):
  300. logger.info("Start finetuning")
  301. self._reset_stats()
  302. self._eval_model()
  303. batch_itr = self.train_data_loader.get_dataloader()
  304. while self.epoch_idx < self.params.max_epochs and self.patience_left:
  305. for train_batch in batch_itr:
  306. self._train_step(batch=train_batch)
  307. if self.update_idx and self.update_idx % self.params.eval_steps == 0:
  308. self._eval_model()
  309. if self.is_best_state:
  310. self._save_model()
  311. elif not self.patience_left:
  312. no_improve_steps = self.params.eval_steps * self.params.patience
  313. logger.info(
  314. "Early termination, as eval loss did not improve "
  315. f"over last {no_improve_steps} updates"
  316. )
  317. break
  318. self.update_idx += 1
  319. self.epoch_idx += 1