finetune.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
from argparse import Namespace
from pathlib import Path

import dataloader
import dist_utils
import torch
import trainer
from fairseq2.models.nllb.tokenizer import NllbTokenizer
from seamless_communication.models.unity import (
    UnitTokenizer,
    UnitYModel,
    load_unity_model,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)

logging.basicConfig(
    level=logging.INFO,
    format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
)

logger = logging.getLogger("finetune")


def init_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Example finetuning script for M4T models"
    )
    parser.add_argument(
        "--train_dataset",
        type=Path,
        required=True,
        help="Path to manifest with train samples",
    )
    parser.add_argument(
        "--eval_dataset",
        type=Path,
        required=True,
        help="Path to manifest with eval samples",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="seamlessM4T_medium",
        help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)",
    )
    parser.add_argument(
        "--save_model_to",
        type=Path,
        required=True,
        help="Path to save best finetuned model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=2343,
        help="Randomizer seed value",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=5,
        help="Batch size for training and evaluation",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=3,
        help=(
            "Set early termination after `patience` number of evaluations "
            "without eval loss improvements"
        ),
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=10,
        help=("Max number of training epochs"),
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-7,
        help=("Finetuning learning rate"),
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=100,
        help=("Number of steps with linearly increasing learning rate"),
    )
    parser.add_argument(
        "--eval_steps",
        type=int,
        default=50,
        help=("Get eval loss after each `eval_steps` training steps"),
    )
    parser.add_argument(
        "--log_steps",
        type=int,
        default=10,
        help=("Log inner loss after each `log_steps` training steps"),
    )
    parser.add_argument(
        "--mode",
        type=trainer.FinetuneMode,
        choices=list(trainer.FinetuneMode),
        default=trainer.FinetuneMode.TEXT_TO_SPEECH,
        help=(
            "* `SPEECH_TO_SPEECH` -- finetune S2T and T2U parts of the model; "
            "* `TEXT_TO_SPEECH` -- finetune only T2U; "
            "* `SPEECH_TO_TEXT` -- finetune only S2T"
        ),
    )
    return parser


def run_finetune(args: Namespace) -> None:
    # Set up distributed training state (rank, world size) and per-rank logging.
    dist_utils.init_distributed([logger, trainer.logger])
    device = torch.device("cuda")

    # Load the text (NLLB) and unit tokenizers that match the base model.
    text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
    unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)

    # Collect finetuning hyperparameters from the CLI arguments.
    finetune_params = trainer.FinetuneParams(
        finetune_mode=args.mode,
        save_model_path=args.save_model_to,
        device=device,
        train_batch_size=args.batch_size,
        eval_batch_size=args.batch_size,
        patience=args.patience,
        max_epochs=args.max_epochs,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        eval_steps=args.eval_steps,
        log_steps=args.log_steps,
    )
    logger.info(f"Finetune params: {finetune_params}")

    # Load the base UnitY model in fp16 on the target device.
    model: UnitYModel = load_unity_model(
        args.model_name, device=finetune_params.device, dtype=torch.float16
    )
    logger.info(f"Model {model}")

    # Sanity checks: padding indices must agree between the model and the
    # tokenizers, and the model must have a T2U component.
    assert model.pad_idx == text_tokenizer.vocab_info.pad_idx
    assert model.t2u_model is not None
    assert model.t2u_model.pad_idx == unit_tokenizer.vocab_info.pad_idx

    # Train/eval dataloaders shard batches across workers via rank/world_size.
    train_dataloader = dataloader.UnitYDataLoader(
        text_tokenizer=text_tokenizer,
        unit_tokenizer=unit_tokenizer,
        batching_config=dataloader.BatchingConfig(
            batch_size=finetune_params.train_batch_size,
            rank=dist_utils.get_rank(),
            world_size=dist_utils.get_world_size(),
        ),
        dataset_manifest_path=args.train_dataset,
    )
    eval_dataloader = dataloader.UnitYDataLoader(
        text_tokenizer=text_tokenizer,
        unit_tokenizer=unit_tokenizer,
        batching_config=dataloader.BatchingConfig(
            batch_size=finetune_params.eval_batch_size,
            rank=dist_utils.get_rank(),
            world_size=dist_utils.get_world_size(),
        ),
        dataset_manifest_path=args.eval_dataset,
    )

    # Run the finetuning loop; early termination is controlled by `patience`.
    finetune = trainer.UnitYFinetune(
        model=model,
        params=finetune_params,
        train_data_loader=train_dataloader,
        eval_data_loader=eval_dataloader,
    )
    finetune.run()


if __name__ == "__main__":
    parser = init_parser()
    run_finetune(parser.parse_args())
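
# Example launch (a sketch, not taken from this repo's docs): the script
# requires a CUDA device and expects a distributed context, so it is typically
# started with `torchrun`. The GPU count and the manifest/checkpoint paths
# below are placeholders; only --train_dataset, --eval_dataset and
# --save_model_to are required arguments.
#
#   torchrun --standalone --nnodes=1 --nproc_per_node=2 finetune.py \
#       --train_dataset <path/to/train_manifest> \
#       --eval_dataset <path/to/eval_manifest> \
#       --model_name seamlessM4T_medium \
#       --save_model_to <path/to/checkpoint.pt>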