@@ -119,20 +119,25 @@ def init_parser() -> argparse.ArgumentParser:
             "* `SPEECH_TO_TEXT` -- finetune only S2T"
         ),
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help=("Device to fine-tune on. See `torch.device`."),
+    )
     return parser


 def main() -> None:
     args = init_parser().parse_args()
     dist_utils.init_distributed([logger, trainer.logger])
-    device = torch.device("cuda")
-    float_dtype = torch.float16
     text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
     unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
     finetune_params = trainer.FinetuneParams(
         finetune_mode=args.mode,
         save_model_path=args.save_model_to,
-        device=device,
+        device=torch.device(args.device),
+        float_dtype=torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16,
         train_batch_size=args.batch_size,
         eval_batch_size=args.batch_size,
         patience=args.patience,
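
The dtype fallback added above is the subtle part of this hunk: fp16 stays the
default on accelerators, but CPU runs switch to bf16, since fp16 kernels are
poorly supported on CPU. A minimal standalone sketch of that selection logic
(`pick_float_dtype` is a hypothetical helper for illustration, not part of the
patch):

    import torch

    def pick_float_dtype(device_str: str) -> torch.dtype:
        # fp16 on CUDA and other accelerators; bf16 on CPU, where fp16
        # kernels are poorly supported and bf16 keeps fp32's exponent range.
        return torch.float16 if torch.device(device_str).type != "cpu" else torch.bfloat16

    assert pick_float_dtype("cpu") is torch.bfloat16
    assert pick_float_dtype("cuda") is torch.float16
    assert pick_float_dtype("cuda:1") is torch.float16  # `.type` drops the device index
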
@@ -156,7 +161,7 @@ def main() -> None:
     if model.text_encoder is not None:
         model.text_encoder = None
     model = model.to(finetune_params.device)
-    logger.info(f"Model {model}")
+    logger.info(f"<{args.model_name}> {model}")

     train_dataloader = dataloader.UnitYDataLoader(
         text_tokenizer=text_tokenizer,
@@ -166,7 +171,7 @@ def main() -> None:
             rank=dist_utils.get_rank(),
             world_size=dist_utils.get_world_size(),
             max_audio_length_sec=15.0,
-            float_dtype=float_dtype,
+            float_dtype=finetune_params.float_dtype,
         ),
         dataset_manifest_path=args.train_dataset,
     )
@@ -178,7 +183,7 @@ def main() -> None:
             rank=dist_utils.get_rank(),
             world_size=dist_utils.get_world_size(),
             max_audio_length_sec=100.0,
-            float_dtype=float_dtype,
+            float_dtype=finetune_params.float_dtype,
         ),
         dataset_manifest_path=args.eval_dataset,
     )
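
With the dtype now read from `finetune_params.float_dtype` in both dataloaders,
a CPU run should flow end to end in bf16. A quick sanity sketch of what
`--device cpu` implies for the tensors the dataloaders produce (the batch
shape here is made up for illustration):

    import torch

    device = torch.device("cpu")  # what `--device cpu` resolves to
    float_dtype = torch.float16 if device.type != "cpu" else torch.bfloat16

    # Any audio batch built under this config now carries the derived dtype:
    batch = torch.randn(2, 80, 128, dtype=float_dtype, device=device)
    print(batch.dtype, batch.device)  # torch.bfloat16 cpu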
|