@@ -20,53 +20,11 @@ from typing import List
 from fairseq2.assets import AssetCard
 from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
 from fairseq2.nn import SinusoidalPositionEncoder
+from fairseq2.nn.transformer import RelativePositionalEncoding
 from seamless_communication.models.unity import load_unity_config, load_unity_model
 
 Preprocessor = Callable[[Any], Any]
 
-def pos_enc(max_seq_len=4096, encoding_dim=1024):
-    weight = torch.empty(
-        ((max_seq_len * 2) - 1, encoding_dim), dtype=torch.float32
-    )
-    # copied from https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L22
-    dtype = torch.float32
-    weight = weight.to(dtype)
-
-    positive_w = weight[: max_seq_len]
-    negative_w = weight[max_seq_len :]
-
-    device = weight.device
-
-    # (E / 2)
-    indices = torch.arange(0, encoding_dim, step=2, device=device, dtype=dtype)
-
-    # (1, E / 2)
-    indices = indices.unsqueeze(0)
-
-    # (S)
-    steps = torch.arange(max_seq_len, device=device, dtype=dtype)
-
-    # (S, 1)
-    steps = steps.unsqueeze(1)
-
-    factors = torch.exp(indices * -math.log(10000) / encoding_dim)
-
-    # (S, 1) x (1, E / 2) -> (S, E / 2)
-    factors = torch.matmul(steps, factors)
-
-    flipped_factors = factors.flip([0])
-
-    # A mirrored matrix of sinusoidal positive and negative positional
-    # encodings to use in shift trick.
-    #
-    # [max, ...,  3,  2,  1,  0, -1, -2, -3, ..., min]
-    torch.sin(flipped_factors, out=positive_w[:, 0::2])
-    torch.cos(flipped_factors, out=positive_w[:, 1::2])
-
-    torch.sin(-1 * factors[1:], out=negative_w[:, 0::2])
-    torch.cos(-1 * factors[1:], out=negative_w[:, 1::2])
-
-    return weight
 
 def convert_model(
     model_name: Union[str, torch.nn.Module],
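
The `pos_enc` helper deleted above was a hand-maintained copy of the sinusoidal table built by fairseq2's `RelativePositionalEncoding` (see the URL in the removed comment); the patch drops it and reads the table off the module instead, as the `fixup_model` hunk below shows. A minimal sketch of the equivalence, assuming the constructor signature `RelativePositionalEncoding(model_dim, max_seq_len)` from the fairseq2 version this patch targets (the `.weight` attribute itself is confirmed by the assert in that hunk):

```python
import torch
from fairseq2.nn.transformer import RelativePositionalEncoding

# Sizes match the deleted defaults (max_seq_len=4096, encoding_dim=1024);
# the keyword names are assumed from fairseq2's relative_attention.py.
rel_pos_enc = RelativePositionalEncoding(model_dim=1024, max_seq_len=4096)

# Mirrored table of positive and negative offsets for the shift trick,
# [max, ..., 1, 0, -1, ..., min], hence (2 * max_seq_len - 1) rows.
assert rel_pos_enc.weight.shape == (2 * 4096 - 1, 1024)
```
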
@@ -82,14 +40,18 @@ def convert_model(
         if "unity" in model_name or "seamlessM4T" in model_name:
             if hparams is None:
                 model_config = load_unity_config(model_name)
-                hparams = flatten_config(dataclasses.asdict(model_config), separator="__")
+                hparams = flatten_config(
+                    dataclasses.asdict(model_config), separator="__"
+                )
                 print(hparams)
             model = load_unity_model(model_name)
         else:
             raise ValueError(f"Unsupported model type: {model_name}")
     else:
         # Use the model passed explicitly
-        assert out is not None, "output path is required when explicitly passing a module"
+        assert (
+            out is not None
+        ), "output path is required when explicitly passing a module"
         hparams = hparams or {}
         model = model_name
 
@@ -149,7 +111,16 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) ->
         assert name not in state_dict
         state_dict[name] = pos_encoder.weight
 
-    state_dict["speech_encoder.pos_enc"] = pos_enc()
+
+    relative_pos_encs = find_children(model, RelativePositionalEncoding)
+    # speech_encoder has several copies of the relative_pos_enc module.
+    # For efficiency reasons we only make one copy of it to GGML.
+    if relative_pos_encs:
+        print("Merging all speech_encoder RelativePositionalEncoding into one.")
+        _, rel_pos_enc = relative_pos_encs[0]
+        assert isinstance(rel_pos_enc.weight, torch.Tensor)
+        state_dict["speech_encoder.pos_enc"] = rel_pos_enc.weight
+
 
 def write_ggml_file(
     out: BufferedWriter,
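
The hunk above calls a `find_children` helper defined elsewhere in this file. Judging from the unpacking `_, rel_pos_enc = relative_pos_encs[0]`, it yields `(name, module)` pairs for every submodule of a given type; a minimal compatible sketch (the real helper may differ):

```python
from typing import List, Tuple, Type

import torch

def find_children(
    model: torch.nn.Module, cls: Type[torch.nn.Module]
) -> List[Tuple[str, torch.nn.Module]]:
    # Walk all submodules and keep those of the requested type, as
    # (qualified name, module) pairs like the hunk above unpacks.
    return [(name, m) for name, m in model.named_modules() if isinstance(m, cls)]
```
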
@@ -213,7 +184,7 @@ def write_state_dict(out: BufferedWriter, state_dict: Dict[str, torch.Tensor]) -
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
             # GGML broadcasting isn't as strong as numpy
             value = value.reshape(1, -1)
-        if "pointwise_conv" in key: # pointwise_conv / depthwise_conv
+        if "pointwise_conv" in key:  # pointwise_conv / depthwise_conv
             value = value.squeeze(-1)
         if "depthwise_conv" in key:
             value = value.squeeze(1)
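
The two squeezes exist because PyTorch `Conv1d` weights are 3-D while the converter wants 2-D tensors here: a pointwise convolution (kernel size 1) carries a trailing singleton kernel dimension, and a depthwise convolution (`groups == in_channels`) a singleton per-group input-channel dimension. A quick illustration of the shapes involved:

```python
import torch

# Pointwise conv: weight is (out, in, 1); dropping the trailing kernel dim
# leaves a plain (out, in) matrix that GGML can use as a mat-mul weight.
pointwise = torch.nn.Conv1d(in_channels=4, out_channels=8, kernel_size=1)
assert pointwise.weight.shape == (8, 4, 1)
assert pointwise.weight.squeeze(-1).shape == (8, 4)

# Depthwise conv: weight is (channels, 1, kernel); here the singleton
# per-group input-channel dim (dim 1) is the one to drop.
depthwise = torch.nn.Conv1d(in_channels=8, out_channels=8, kernel_size=3, groups=8)
assert depthwise.weight.shape == (8, 1, 3)
assert depthwise.weight.squeeze(1).shape == (8, 3)
```
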