@@ -20,53 +20,11 @@ from typing import List
 from fairseq2.assets import AssetCard
 from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
 from fairseq2.nn import SinusoidalPositionEncoder
+from fairseq2.nn.transformer import RelativePositionalEncoding
 from seamless_communication.models.unity import load_unity_config, load_unity_model
 
 Preprocessor = Callable[[Any], Any]
 
-def pos_enc(max_seq_len=4096, encoding_dim=1024):
-    weight = torch.empty(
-        ((max_seq_len * 2) - 1, encoding_dim), dtype=torch.float32
-    )
-    # copied from https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L22
-    dtype = torch.float32
-    weight = weight.to(dtype)
-
-    positive_w = weight[: max_seq_len]
-    negative_w = weight[max_seq_len :]
-
-    device = weight.device
-
-    # (E / 2)
-    indices = torch.arange(0, encoding_dim, step=2, device=device, dtype=dtype)
-
-    # (1, E / 2)
-    indices = indices.unsqueeze(0)
-
-    # (S)
-    steps = torch.arange(max_seq_len, device=device, dtype=dtype)
-
-    # (S, 1)
-    steps = steps.unsqueeze(1)
-
-    factors = torch.exp(indices * -math.log(10000) / encoding_dim)
-
-    # (S, 1) x (1, E / 2) -> (S, E / 2)
-    factors = torch.matmul(steps, factors)
-
-    flipped_factors = factors.flip([0])
-
-    # A mirrored matrix of sinusoidal positive and negative positional
-    # encodings to use in shift trick.
-    #
-    # [max, ...,  3,  2,  1,  0, -1, -2, -3, ..., min]
-    torch.sin(flipped_factors, out=positive_w[:, 0::2])
-    torch.cos(flipped_factors, out=positive_w[:, 1::2])
-
-    torch.sin(-1 * factors[1:], out=negative_w[:, 0::2])
-    torch.cos(-1 * factors[1:], out=negative_w[:, 1::2])
-
-    return weight
 
 def convert_model(
     model_name: Union[str, torch.nn.Module],
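
For context, the deleted pos_enc helper built the mirrored relative-position table that the fairseq2 RelativePositionalEncoding module (imported above) already maintains as its weight: a ((2 * S) - 1, E) sinusoidal matrix whose rows run over relative offsets [S-1, ..., 1, 0, -1, ..., -(S-1)], the layout the shift trick expects. A condensed, self-contained sketch of the same computation (toy default sizes; the names are ours, the math mirrors the removed code):

import math

import torch

def mirrored_rel_pos_table(max_seq_len: int = 8, encoding_dim: int = 4) -> torch.Tensor:
    # Angles for offsets 0..S-1: (S, 1) broadcast against (E / 2,) -> (S, E / 2).
    steps = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
    indices = torch.arange(0, encoding_dim, step=2, dtype=torch.float32)
    factors = steps * torch.exp(indices * -math.log(10000) / encoding_dim)

    weight = torch.empty(((max_seq_len * 2) - 1, encoding_dim))
    # Rows 0..S-1 hold offsets S-1 down to 0 (hence the flip); the remaining
    # rows hold offsets -1 down to -(S-1), giving [max, ..., 1, 0, -1, ..., min].
    weight[:max_seq_len, 0::2] = torch.sin(factors.flip([0]))
    weight[:max_seq_len, 1::2] = torch.cos(factors.flip([0]))
    weight[max_seq_len:, 0::2] = torch.sin(-factors[1:])
    weight[max_seq_len:, 1::2] = torch.cos(-factors[1:])
    return weight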
@@ -82,14 +40,18 @@ def convert_model(
         if "unity" in model_name or "seamlessM4T" in model_name:
             if hparams is None:
                 model_config = load_unity_config(model_name)
-                hparams = flatten_config(dataclasses.asdict(model_config), separator="__")
+                hparams = flatten_config(
+                    dataclasses.asdict(model_config), separator="__"
+                )
                 print(hparams)
             model = load_unity_model(model_name)
         else:
             raise ValueError(f"Unsupported model type: {model_name}")
     else:
         # Use the model passed explicitly
-        assert out is not None, "output path is required when explicitly passing a module"
+        assert (
+            out is not None
+        ), "output path is required when explicitly passing a module"
         hparams = hparams or {}
         model = model_name
 
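
flatten_config is defined elsewhere in this file; judging from the call above, it collapses the nested config (already turned into a plain dict by dataclasses.asdict) into a flat mapping whose keys join the nesting levels with the separator. A hypothetical sketch of that behaviour, not the actual implementation:

from typing import Any, Dict

def flatten_config(
    config: Dict[str, Any], separator: str = "__", prefix: str = ""
) -> Dict[str, Any]:
    # Recursively join nested keys, e.g. {"w2v": {"encoder_layers": 24}}
    # becomes {"w2v__encoder_layers": 24}.
    flat: Dict[str, Any] = {}
    for key, value in config.items():
        name = f"{prefix}{separator}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_config(value, separator, prefix=name))
        else:
            flat[name] = value
    return flat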
@@ -149,7 +111,16 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) ->
         assert name not in state_dict
         state_dict[name] = pos_encoder.weight
 
-    state_dict["speech_encoder.pos_enc"] = pos_enc()
+
+    relative_pos_encs = find_children(model, RelativePositionalEncoding)
+    # speech_encoder has several copies of the relative_pos_enc module.
+    # For efficiency reasons we only export one copy of it to GGML.
+    if relative_pos_encs:
+        print("Merging all speech_encoder RelativePositionalEncoding into one.")
+        _, rel_pos_enc = relative_pos_encs[0]
+        assert isinstance(rel_pos_enc.weight, torch.Tensor)
+        state_dict["speech_encoder.pos_enc"] = rel_pos_enc.weight
+
 
 def write_ggml_file(
     out: BufferedWriter,
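
find_children is a helper defined elsewhere in the file; the unpacking _, rel_pos_enc = relative_pos_encs[0] above suggests it returns (name, module) pairs. A plausible sketch built on torch.nn.Module.named_modules (an assumption, not the actual helper):

from typing import List, Tuple, Type

import torch

def find_children(
    model: torch.nn.Module, t: Type[torch.nn.Module]
) -> List[Tuple[str, torch.nn.Module]]:
    # Walk every submodule and keep the (qualified name, module) pairs
    # whose module is an instance of the requested type.
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, t)
    ]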
@@ -213,7 +184,7 @@ def write_state_dict(out: BufferedWriter, state_dict: Dict[str, torch.Tensor]) -
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
             # GGML broadcasting isn't as strong as numpy
             value = value.reshape(1, -1)
-        if "pointwise_conv" in key: # pointwise_conv / depthwise_conv
+        if "pointwise_conv" in key:  # pointwise_conv / depthwise_conv
             value = value.squeeze(-1)
         if "depthwise_conv" in key:
             value = value.squeeze(1)