finetune.py
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
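#
# Example command-line entry point for finetuning SeamlessM4T (UnitY) models.
# Batching, distributed setup, and the training loop live in the helper modules
# imported below from m4t_scripts.finetune.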

import argparse
import logging
import os
from pathlib import Path

import torch
from fairseq2.models.nllb.tokenizer import NllbTokenizer
from m4t_scripts.finetune import dataloader, dist_utils, trainer
from seamless_communication.models.unity import (
    UnitTokenizer,
    UnitYModel,
    load_unity_model,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)
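
# The log format includes the process id so that log lines from different
# distributed workers can be told apart.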
logging.basicConfig(
    level=logging.INFO,
    format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
)

logger = logging.getLogger("finetune")
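

# Every knob of the finetuning run (data manifests, base model, training
# schedule, early stopping) is exposed as a CLI flag with the defaults below.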
def init_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Example finetuning script for M4T models"
    )
    parser.add_argument(
        "--train_dataset",
        type=Path,
        required=True,
        help="Path to manifest with train samples",
    )
    parser.add_argument(
        "--eval_dataset",
        type=Path,
        required=True,
        help="Path to manifest with eval samples",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="seamlessM4T_medium",
        help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)",
    )
    parser.add_argument(
        "--save_model_to",
        type=Path,
        required=True,
        help="Path to save best finetuned model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=2343,
        help="Randomizer seed value",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=5,
        help="Batch size for training and evaluation",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=3,
        help=(
            "Set early termination after `patience` number of evaluations "
            "without eval loss improvements"
        ),
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=10,
        help="Max number of training epochs",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-7,
        help="Finetuning learning rate",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=100,
        help="Number of steps with linearly increasing learning rate",
    )
    parser.add_argument(
        "--eval_steps",
        type=int,
        default=50,
        help="Get eval loss after each `eval_steps` training steps",
    )
    parser.add_argument(
        "--log_steps",
        type=int,
        default=10,
        help="Log inner loss after each `log_steps` training steps",
    )
    parser.add_argument(
        "--mode",
        type=trainer.FinetuneMode,
        choices=list(trainer.FinetuneMode),
        default=trainer.FinetuneMode.SPEECH_TO_TEXT,
        help=(
            "* `SPEECH_TO_SPEECH` -- finetune S2T and T2U parts of the model; "
            "* `TEXT_TO_SPEECH` -- finetune only T2U; "
            "* `SPEECH_TO_TEXT` -- finetune only S2T"
        ),
    )
    return parser


def main() -> None:
    args = init_parser().parse_args()
    dist_utils.init_distributed([logger, trainer.logger])
    device = torch.device("cuda")
    text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
    unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
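
    # Collect all hyperparameters into a single FinetuneParams object that is
    # handed to the trainer.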
    finetune_params = trainer.FinetuneParams(
        finetune_mode=args.mode,
        save_model_path=args.save_model_to,
        device=device,
        train_batch_size=args.batch_size,
        eval_batch_size=args.batch_size,
        patience=args.patience,
        max_epochs=args.max_epochs,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        eval_steps=args.eval_steps,
        log_steps=args.log_steps,
    )
    logger.info(f"Finetune params: {finetune_params}")
    model: UnitYModel = load_unity_model(
        args.model_name, device=finetune_params.device, dtype=torch.float16
    )
    logger.info(f"Model {model}")
    assert model.pad_idx == text_tokenizer.vocab_info.pad_idx
    assert model.t2u_model is not None
    assert model.t2u_model.pad_idx == unit_tokenizer.vocab_info.pad_idx
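
    # The train and eval dataloaders read the dataset manifests and shard
    # batches across distributed workers via rank/world_size.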
    train_dataloader = dataloader.UnitYDataLoader(
        text_tokenizer=text_tokenizer,
        unit_tokenizer=unit_tokenizer,
        batching_config=dataloader.BatchingConfig(
            batch_size=finetune_params.train_batch_size,
            rank=dist_utils.get_rank(),
            world_size=dist_utils.get_world_size(),
        ),
        dataset_manifest_path=args.train_dataset,
    )
    eval_dataloader = dataloader.UnitYDataLoader(
        text_tokenizer=text_tokenizer,
        unit_tokenizer=unit_tokenizer,
        batching_config=dataloader.BatchingConfig(
            batch_size=finetune_params.eval_batch_size,
            rank=dist_utils.get_rank(),
            world_size=dist_utils.get_world_size(),
        ),
        dataset_manifest_path=args.eval_dataset,
    )
    finetune = trainer.UnitYFinetune(
        model=model,
        params=finetune_params,
        train_data_loader=train_dataloader,
        eval_data_loader=eval_dataloader,
    )
    finetune.run()


if __name__ == "__main__":
    main()