finetune.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
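
# Illustrative launch command (an assumption, not part of the original file):
# the rank/world-size aware data loaders and `dist_utils.init_distributed`
# suggest a torchrun-style launch. All paths below are placeholders.
#
#   torchrun --standalone --nproc-per-node=1 finetune.py \
#       --mode SPEECH_TO_TEXT \
#       --train_dataset train_manifest.json \
#       --eval_dataset eval_manifest.json \
#       --model_name seamlessM4T_medium \
#       --save_model_to checkpoint.pt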
import argparse
import logging
import os
from pathlib import Path

import dataloader
import dist_utils
import torch
import trainer
from fairseq2.models.nllb.tokenizer import NllbTokenizer
from seamless_communication.models.unity import (
    UnitTokenizer,
    UnitYModel,
    load_unity_model,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)

logging.basicConfig(
    level=logging.INFO,
    format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
)

logger = logging.getLogger("finetune")


def init_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Example finetuning script for M4T models"
    )
    parser.add_argument(
        "--train_dataset",
        type=Path,
        required=True,
        help="Path to manifest with train samples",
    )
    parser.add_argument(
        "--eval_dataset",
        type=Path,
        required=True,
        help="Path to manifest with eval samples",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="seamlessM4T_medium",
        help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)",
    )
    parser.add_argument(
        "--save_model_to",
        type=Path,
        required=True,
        help="Path to save best finetuned model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=2343,
        help="Randomizer seed value",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=5,
        help="Batch size for training and evaluation",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=3,
        help=(
            "Set early termination after `patience` number of evaluations "
            "without eval loss improvements"
        ),
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=10,
        help="Max number of training epochs",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-7,
        help="Finetuning learning rate",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=100,
        help="Number of steps with linearly increasing learning rate",
    )
    parser.add_argument(
        "--eval_steps",
        type=int,
        default=50,
        help="Get eval loss after each `eval_steps` training steps",
    )
    parser.add_argument(
        "--log_steps",
        type=int,
        default=10,
        help="Log inner loss after each `log_steps` training steps",
    )
    parser.add_argument(
        "--mode",
        type=trainer.FinetuneMode,
        choices=list(trainer.FinetuneMode),
        default=trainer.FinetuneMode.TEXT_TO_SPEECH,
        help=(
            "* `SPEECH_TO_SPEECH` -- finetune S2T and T2U parts of the model; "
            "* `TEXT_TO_SPEECH` -- finetune only T2U; "
            "* `SPEECH_TO_TEXT` -- finetune only S2T"
        ),
    )
    return parser


def main() -> None:
    args = init_parser().parse_args()
    # Set up distributed training state for this process.
    dist_utils.init_distributed([logger, trainer.logger])
    device = torch.device("cuda")
    # Load the text and unit tokenizers that match the base model.
    text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
    unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
    finetune_params = trainer.FinetuneParams(
        finetune_mode=args.mode,
        save_model_path=args.save_model_to,
        device=device,
        train_batch_size=args.batch_size,
        eval_batch_size=args.batch_size,
        patience=args.patience,
        max_epochs=args.max_epochs,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        eval_steps=args.eval_steps,
        log_steps=args.log_steps,
    )
    logger.info(f"Finetune params: {finetune_params}")
    model: UnitYModel = load_unity_model(
        args.model_name, device=finetune_params.device, dtype=torch.float16
    )
    logger.info(f"Model {model}")
    # Sanity-check that the model and tokenizers agree on padding indices.
    assert model.pad_idx == text_tokenizer.vocab_info.pad_idx
    assert model.t2u_model is not None
    assert model.t2u_model.pad_idx == unit_tokenizer.vocab_info.pad_idx

    # Build rank-aware train/eval data loaders from the dataset manifests.
    train_dataloader = dataloader.UnitYDataLoader(
        text_tokenizer=text_tokenizer,
        unit_tokenizer=unit_tokenizer,
        batching_config=dataloader.BatchingConfig(
            batch_size=finetune_params.train_batch_size,
            rank=dist_utils.get_rank(),
            world_size=dist_utils.get_world_size(),
        ),
        dataset_manifest_path=args.train_dataset,
    )
    eval_dataloader = dataloader.UnitYDataLoader(
        text_tokenizer=text_tokenizer,
        unit_tokenizer=unit_tokenizer,
        batching_config=dataloader.BatchingConfig(
            batch_size=finetune_params.eval_batch_size,
            rank=dist_utils.get_rank(),
            world_size=dist_utils.get_world_size(),
        ),
        dataset_manifest_path=args.eval_dataset,
    )
    # Run finetuning; the best checkpoint is saved to `--save_model_to`.
    finetune = trainer.UnitYFinetune(
        model=model,
        params=finetune_params,
        train_data_loader=train_dataloader,
        eval_data_loader=eval_dataloader,
    )
    finetune.run()


if __name__ == "__main__":
    main()