|  | @@ -39,12 +39,16 @@ class ModelBuilder:
 | 
	
		
			
    def __init__(
        self,
        config: ModelConfig,
        override_s2t_vocabulary_info: Optional[VocabularyInfo] = None,
        override_t2u_vocabulary_info: Optional[VocabularyInfo] = None,
        dtype: torch.dtype = torch.float16,
        device: torch.device = CPU_DEVICE,
    ):
        """Record the configuration and options the builder works from.

        Args:
            config: Model configuration, kept as ``self.config``.
            override_s2t_vocabulary_info: Optional vocabulary info kept as
                ``self.override_s2t_vocabulary_info``. It is passed as the
                reference vocabulary for the NLLB text vocabulary when the
                custom model config is built.
            override_t2u_vocabulary_info: Optional vocabulary info kept as
                ``self.override_t2u_vocabulary_info``. It is passed as the
                reference vocabulary for the unit vocabulary when the custom
                model config is built.
            dtype: Torch dtype kept as ``self.dtype`` (default ``float16``).
            device: Torch device kept as ``self.device`` (default
                ``CPU_DEVICE``).
        """
        # Reference vocabularies; both may legitimately be None here.
        self.override_s2t_vocabulary_info = override_s2t_vocabulary_info
        self.override_t2u_vocabulary_info = override_t2u_vocabulary_info
        # Core configuration and tensor placement options.
        self.config = config
        self.dtype, self.device = dtype, device
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @classmethod
 | 
	
		
			
				|  |  |      def _sel_and_upd_prefix(
 | 
	
	
		
			
				|  | @@ -287,6 +291,18 @@ class ModelBuilder:
 | 
	
		
			
				|  |  |              pp += nn
 | 
	
		
			
				|  |  |          return pp
 | 
	
		
			
				|  |  |  
 | 
	
		
			
    def _build_vocab_info(
        self, vocab_size: int, ref_vocab_info: Optional[VocabularyInfo]
    ) -> VocabularyInfo:
        """Build a ``VocabularyInfo`` of ``vocab_size`` entries.

        The special-symbol indices (unk, bos, eos, pad) are copied from
        ``ref_vocab_info``; only the size is taken from ``vocab_size``.

        Args:
            vocab_size: Number of entries in the resulting vocabulary.
            ref_vocab_info: Vocabulary whose special-symbol indices are reused.
                Typed as optional because callers pass the builder's override
                attributes, which default to ``None``.

        Returns:
            A new ``VocabularyInfo`` with the requested size and the reference
            vocabulary's special-symbol indices.

        Raises:
            ValueError: If ``ref_vocab_info`` is ``None``.
        """
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would turn a missing reference into an opaque
        # AttributeError on `None.unk_idx` below.
        if ref_vocab_info is None:
            raise ValueError(
                "A reference VocabularyInfo is required to build the vocabulary "
                "info, but None was given; pass the corresponding "
                "override_*_vocabulary_info argument to ModelBuilder."
            )
        return VocabularyInfo(
            size=vocab_size,
            unk_idx=ref_vocab_info.unk_idx,
            bos_idx=ref_vocab_info.bos_idx,
            eos_idx=ref_vocab_info.eos_idx,
            pad_idx=ref_vocab_info.pad_idx,
        )
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def _build_custom_model_config(self) -> UnitYConfig:
 | 
	
		
			
				|  |  |          assert self.config.custom_params is not None
 | 
	
		
			
				|  |  |          config: CustomModelParams = self.config.custom_params
 | 
	
	
		
			
				|  | @@ -341,12 +357,9 @@ class ModelBuilder:
 | 
	
		
			
				|  |  |              mt_model_config=NllbConfig(
 | 
	
		
			
				|  |  |                  model_dim=config.model_embed_dim,
 | 
	
		
			
				|  |  |                  max_seq_len=1024,
 | 
	
		
			
				|  |  | -                vocab_info=VocabularyInfo(
 | 
	
		
			
				|  |  | -                    size=config.nllb_vocabulary_size,
 | 
	
		
			
				|  |  | -                    unk_idx=1,
 | 
	
		
			
				|  |  | -                    bos_idx=2,
 | 
	
		
			
				|  |  | -                    eos_idx=3,
 | 
	
		
			
				|  |  | -                    pad_idx=0,
 | 
	
		
			
				|  |  | +                vocab_info=self._build_vocab_info(
 | 
	
		
			
				|  |  | +                    vocab_size=config.nllb_vocabulary_size,
 | 
	
		
			
				|  |  | +                    ref_vocab_info=self.override_s2t_vocabulary_info,
 | 
	
		
			
				|  |  |                  ),
 | 
	
		
			
				|  |  |                  num_encoder_layers=config.nllb_encoder_layers,
 | 
	
		
			
				|  |  |                  num_decoder_layers=config.nllb_decoder_layers,
 | 
	
	
		
			
				|  | @@ -362,12 +375,9 @@ class ModelBuilder:
 | 
	
		
			
				|  |  |                  prosody_encoder_dim=0,
 | 
	
		
			
				|  |  |                  model_dim=config.model_embed_dim,
 | 
	
		
			
				|  |  |                  unit_max_seq_len=2048,
 | 
	
		
			
				|  |  | -                target_vocab_info=VocabularyInfo(
 | 
	
		
			
				|  |  | -                    size=config.unit_vocabulary_size,
 | 
	
		
			
				|  |  | -                    unk_idx=3,
 | 
	
		
			
				|  |  | -                    bos_idx=0,
 | 
	
		
			
				|  |  | -                    eos_idx=2,
 | 
	
		
			
				|  |  | -                    pad_idx=1,
 | 
	
		
			
				|  |  | +                target_vocab_info=self._build_vocab_info(
 | 
	
		
			
				|  |  | +                    vocab_size=config.unit_vocabulary_size,
 | 
	
		
			
				|  |  | +                    ref_vocab_info=self.override_t2u_vocabulary_info,
 | 
	
		
			
				|  |  |                  ),
 | 
	
		
			
				|  |  |                  num_encoder_layers=config.t2u_encoder_layers,
 | 
	
		
			
				|  |  |                  num_decoder_layers=config.t2u_decoder_layers,
 |