
remove duplicated code

Guillaume Wenzek 1 year ago
parent
commit
4aee25223a
1 file changed with 18 additions and 47 deletions
  1. ggml/ggml_convert.py  +18 -47

+ 18 - 47
ggml/ggml_convert.py

@@ -20,53 +20,11 @@ from typing import List
 from fairseq2.assets import AssetCard
 from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
 from fairseq2.nn import SinusoidalPositionEncoder
+from fairseq2.nn.transformer import RelativePositionalEncoding
 from seamless_communication.models.unity import load_unity_config, load_unity_model
 
 Preprocessor = Callable[[Any], Any]
 
-def pos_enc(max_seq_len=4096, encoding_dim=1024):
-    weight = torch.empty(
-        ((max_seq_len * 2) - 1, encoding_dim), dtype=torch.float32
-    )
-    # copied from https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L22
-    dtype = torch.float32
-    weight = weight.to(dtype)
-
-    positive_w = weight[: max_seq_len]
-    negative_w = weight[max_seq_len :]
-
-    device = weight.device
-
-    # (E / 2)
-    indices = torch.arange(0, encoding_dim, step=2, device=device, dtype=dtype)
-
-    # (1, E / 2)
-    indices = indices.unsqueeze(0)
-
-    # (S)
-    steps = torch.arange(max_seq_len, device=device, dtype=dtype)
-
-    # (S, 1)
-    steps = steps.unsqueeze(1)
-
-    factors = torch.exp(indices * -math.log(10000) / encoding_dim)
-
-    # (S, 1) x (1, E / 2) -> (S, E / 2)
-    factors = torch.matmul(steps, factors)
-
-    flipped_factors = factors.flip([0])
-
-    # A mirrored matrix of sinusoidal positive and negative positional
-    # encodings to use in shift trick.
-    #
-    # [max, ...,  3,  2,  1,  0, -1, -2, -3, ..., min]
-    torch.sin(flipped_factors, out=positive_w[:, 0::2])
-    torch.cos(flipped_factors, out=positive_w[:, 1::2])
-
-    torch.sin(-1 * factors[1:], out=negative_w[:, 0::2])
-    torch.cos(-1 * factors[1:], out=negative_w[:, 1::2])
-
-    return weight
 
 def convert_model(
     model_name: Union[str, torch.nn.Module],
@@ -82,14 +40,18 @@ def convert_model(
         if "unity" in model_name or "seamlessM4T" in model_name:
             if hparams is None:
                 model_config = load_unity_config(model_name)
-                hparams = flatten_config(dataclasses.asdict(model_config), separator="__")
+                hparams = flatten_config(
+                    dataclasses.asdict(model_config), separator="__"
+                )
                 print(hparams)
             model = load_unity_model(model_name)
         else:
             raise ValueError(f"Unsupported model type: {model_name}")
     else:
         # Use the model passed explicitly
-        assert out is not None, "output path is required when explicitly passing a module"
+        assert (
+            out is not None
+        ), "output path is required when explicitly passing a module"
         hparams = hparams or {}
         model = model_name
 
@@ -149,7 +111,16 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) ->
         assert name not in state_dict
         state_dict[name] = pos_encoder.weight
 
-    state_dict["speech_encoder.pos_enc"] = pos_enc()
+
+    relative_pos_encs = find_children(model, RelativePositionalEncoding)
+    # speech_encoder has several copies of the relative_pos_enc module.
+    # For efficiency reasons we only make one copy of it to GGML.
+    if relative_pos_encs:
+        print("Merging all speech_encoder RelativePositionalEncoding into one.")
+        _, rel_pos_enc = relative_pos_encs[0]
+        assert isinstance(rel_pos_enc.weight, torch.Tensor)
+        state_dict["speech_encoder.pos_enc"] = rel_pos_enc.weight
+
 
 def write_ggml_file(
     out: BufferedWriter,
@@ -213,7 +184,7 @@ def write_state_dict(out: BufferedWriter, state_dict: Dict[str, torch.Tensor]) -
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
             # GGML broadcasting isn't as strong as numpy
             value = value.reshape(1, -1)
-        if "pointwise_conv" in key: # pointwise_conv / depthwise_conv
+        if "pointwise_conv" in key:  # pointwise_conv / depthwise_conv
             value = value.squeeze(-1)
         if "depthwise_conv" in key:
             value = value.squeeze(1)
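
Note: the removed pos_enc() helper was a hand-rolled copy of the table that fairseq2's RelativePositionalEncoding already builds, which is why fixup_model() can now export that module's buffer directly. Below is a minimal sketch of the equivalence, not part of the commit; the (encoding_dim, max_seq_len) constructor arguments are assumed from fairseq2's relative_attention.py, while the .weight buffer and its shape come from the diff above.

import torch
from fairseq2.nn.transformer import RelativePositionalEncoding

# Constructor signature assumed from fairseq2's relative_attention.py.
rel_pos_enc = RelativePositionalEncoding(encoding_dim=1024, max_seq_len=4096)

# The buffer is the mirrored sinusoidal table the removed helper computed:
# shape (2 * max_seq_len - 1, encoding_dim), positive offsets followed by
# the negative offsets used in the shift trick.
assert isinstance(rel_pos_enc.weight, torch.Tensor)
assert rel_pos_enc.weight.shape == (2 * 4096 - 1, 1024)

# fixup_model() now writes this single tensor as "speech_encoder.pos_enc"
# (one shared copy for all speech_encoder layers) instead of rebuilding it.
state_dict = {"speech_encoder.pos_enc": rel_pos_enc.weight}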