| 
					
				 | 
			
			
				@@ -39,12 +39,16 @@ class ModelBuilder: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def __init__(
				         self,
				         config: ModelConfig,
				         override_s2t_vocabulary_info: Optional[VocabularyInfo] = None,
				         override_t2u_vocabulary_info: Optional[VocabularyInfo] = None,
				         dtype: torch.dtype = torch.float16,
				         device: torch.device = CPU_DEVICE,
				     ):
				         """Record the settings used by the builder.

				         Args:
				             config: Model configuration; stored as ``self.config``.
				             override_s2t_vocabulary_info: Optional reference vocabulary
				                 whose special-token indices are reused for the text
				                 (NLLB) vocabulary when a custom model config is built.
				             override_t2u_vocabulary_info: Optional reference vocabulary
				                 whose special-token indices are reused for the unit
				                 (T2U target) vocabulary when a custom model config is
				                 built.
				             dtype: Stored as ``self.dtype``.
				             device: Stored as ``self.device``.

				         Note:
				             Both overrides default to ``None``, but
				             ``_build_vocab_info`` rejects ``None``, so they must be
				             supplied before a custom model config is built.
				         """
				         # Reference vocabularies consumed by _build_custom_model_config.
				         self.override_s2t_vocabulary_info = override_s2t_vocabulary_info
				         self.override_t2u_vocabulary_info = override_t2u_vocabulary_info

				         # Core build settings.
				         self.config, self.dtype, self.device = config, dtype, device
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     @classmethod 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _sel_and_upd_prefix( 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -287,6 +291,18 @@ class ModelBuilder: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             pp += nn 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return pp 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _build_vocab_info( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self, vocab_size: int, ref_vocab_info: Optional[VocabularyInfo] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) -> VocabularyInfo: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        assert ref_vocab_info is not None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return VocabularyInfo( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            size=vocab_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            unk_idx=ref_vocab_info.unk_idx, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            bos_idx=ref_vocab_info.bos_idx, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            eos_idx=ref_vocab_info.eos_idx, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            pad_idx=ref_vocab_info.pad_idx, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _build_custom_model_config(self) -> UnitYConfig: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         assert self.config.custom_params is not None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         config: CustomModelParams = self.config.custom_params 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -341,12 +357,9 @@ class ModelBuilder: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             mt_model_config=NllbConfig( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 model_dim=config.model_embed_dim, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 max_seq_len=1024, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                vocab_info=VocabularyInfo( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    size=config.nllb_vocabulary_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    unk_idx=1, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    bos_idx=2, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    eos_idx=3, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    pad_idx=0, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                vocab_info=self._build_vocab_info( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    vocab_size=config.nllb_vocabulary_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    ref_vocab_info=self.override_s2t_vocabulary_info, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 num_encoder_layers=config.nllb_encoder_layers, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 num_decoder_layers=config.nllb_decoder_layers, 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -362,12 +375,9 @@ class ModelBuilder: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 prosody_encoder_dim=0, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 model_dim=config.model_embed_dim, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 unit_max_seq_len=2048, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                target_vocab_info=VocabularyInfo( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    size=config.unit_vocabulary_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    unk_idx=3, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    bos_idx=0, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    eos_idx=2, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    pad_idx=1, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                target_vocab_info=self._build_vocab_info( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    vocab_size=config.unit_vocabulary_size, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    ref_vocab_info=self.override_t2u_vocabulary_info, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 num_encoder_layers=config.t2u_encoder_layers, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 num_decoder_layers=config.t2u_decoder_layers, 
			 |