|
@@ -39,12 +39,16 @@ class ModelBuilder:
|
|
|
def __init__(
|
|
|
self,
|
|
|
config: ModelConfig,
|
|
|
+ override_s2t_vocabulary_info: Optional[VocabularyInfo] = None,
|
|
|
+ override_t2u_vocabulary_info: Optional[VocabularyInfo] = None,
|
|
|
dtype: torch.dtype = torch.float16,
|
|
|
device: torch.device = CPU_DEVICE,
|
|
|
):
|
|
|
self.config = config
|
|
|
self.dtype = dtype
|
|
|
self.device = device
|
|
|
+ self.override_s2t_vocabulary_info = override_s2t_vocabulary_info
|
|
|
+ self.override_t2u_vocabulary_info = override_t2u_vocabulary_info
|
|
|
|
|
|
@classmethod
|
|
|
def _sel_and_upd_prefix(
|
|
@@ -287,6 +291,18 @@ class ModelBuilder:
|
|
|
pp += nn
|
|
|
return pp
|
|
|
|
|
|
+ def _build_vocab_info(
|
|
|
+ self, vocab_size: int, ref_vocab_info: Optional[VocabularyInfo]
|
|
|
+ ) -> VocabularyInfo:
|
|
|
+ assert ref_vocab_info is not None
|
|
|
+ return VocabularyInfo(
|
|
|
+ size=vocab_size,
|
|
|
+ unk_idx=ref_vocab_info.unk_idx,
|
|
|
+ bos_idx=ref_vocab_info.bos_idx,
|
|
|
+ eos_idx=ref_vocab_info.eos_idx,
|
|
|
+ pad_idx=ref_vocab_info.pad_idx,
|
|
|
+ )
|
|
|
+
|
|
|
def _build_custom_model_config(self) -> UnitYConfig:
|
|
|
assert self.config.custom_params is not None
|
|
|
config: CustomModelParams = self.config.custom_params
|
|
@@ -341,12 +357,9 @@ class ModelBuilder:
|
|
|
mt_model_config=NllbConfig(
|
|
|
model_dim=config.model_embed_dim,
|
|
|
max_seq_len=1024,
|
|
|
- vocab_info=VocabularyInfo(
|
|
|
- size=config.nllb_vocabulary_size,
|
|
|
- unk_idx=1,
|
|
|
- bos_idx=2,
|
|
|
- eos_idx=3,
|
|
|
- pad_idx=0,
|
|
|
+ vocab_info=self._build_vocab_info(
|
|
|
+ vocab_size=config.nllb_vocabulary_size,
|
|
|
+ ref_vocab_info=self.override_s2t_vocabulary_info,
|
|
|
),
|
|
|
num_encoder_layers=config.nllb_encoder_layers,
|
|
|
num_decoder_layers=config.nllb_decoder_layers,
|
|
@@ -362,12 +375,9 @@ class ModelBuilder:
|
|
|
prosody_encoder_dim=0,
|
|
|
model_dim=config.model_embed_dim,
|
|
|
unit_max_seq_len=2048,
|
|
|
- target_vocab_info=VocabularyInfo(
|
|
|
- size=config.unit_vocabulary_size,
|
|
|
- unk_idx=3,
|
|
|
- bos_idx=0,
|
|
|
- eos_idx=2,
|
|
|
- pad_idx=1,
|
|
|
+ target_vocab_info=self._build_vocab_info(
|
|
|
+ vocab_size=config.unit_vocabulary_size,
|
|
|
+ ref_vocab_info=self.override_t2u_vocabulary_info,
|
|
|
),
|
|
|
num_encoder_layers=config.t2u_encoder_layers,
|
|
|
num_decoder_layers=config.t2u_decoder_layers,
|