|
@@ -146,9 +146,9 @@ def main() -> None:
|
|
|
args.model_name, device=finetune_params.device, dtype=torch.float16
|
|
|
)
|
|
|
logger.info(f"Model {model}")
|
|
|
- assert model.pad_idx == text_tokenizer.vocab_info.pad_idx
|
|
|
+ assert model.target_vocab_info == text_tokenizer.vocab_info
|
|
|
assert model.t2u_model is not None
|
|
|
- assert model.t2u_model.pad_idx == unit_tokenizer.vocab_info.pad_idx
|
|
|
+ assert model.t2u_model.target_vocab_info == unit_tokenizer.vocab_info
|
|
|
|
|
|
train_dataloader = dataloader.UnitYDataLoader(
|
|
|
text_tokenizer=text_tokenizer,
|