
StandardTransformerEncoderLayer_forward

Guillaume Wenzek, 1 year ago
Parent commit: b81061704c
3 files changed with 117 additions and 34 deletions
  1. ggml/examples/unity/fairseq2.cpp (+59 / -23)
  2. ggml/examples/unity/fairseq2.h (+14 / -0)
  3. ggml/test_unity_cpp.py (+44 / -11)

+ 59 - 23
ggml/examples/unity/fairseq2.cpp

@@ -40,7 +40,9 @@ extern "C" ggml_tensor* Linear_forward(
 ) {
     // Note: for now we assumed un-batched input
     ggml_tensor* weight = model.tensors[prefix + ".weight"];  // (d_in, d_out)
+    GGML_ASSERT(weight != nullptr);
     ggml_tensor* bias = model.tensors[prefix + ".bias"];  // (d_out)
+    GGML_ASSERT(bias != nullptr);
 
     return ggml_add(
         model.ctx,
@@ -54,7 +56,9 @@ extern "C" ggml_tensor* LayerNorm_forward(
     const std::string &prefix,
     ggml_tensor* input) {
     ggml_tensor* weight = model.tensors[prefix + ".weight"];
+    GGML_ASSERT(weight != nullptr);
     ggml_tensor* bias = model.tensors[prefix + ".bias"];
+    GGML_ASSERT(bias != nullptr);
 
     auto ctx = model.ctx;
     // TODO: should `eps` be part of unity hparams ?
@@ -162,33 +166,65 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     return out;
 }
 
-// ggml_tensor* attn_weights = ggml_mul_mat(ctx, q, k);  // (H, S, S)
-//     attn_weights = ggm_mul * (q.size(-1) ** -0.5)
 
-//     if mask is not None:
-//         attn_weights = attn_weights + mask
+bool has_layer(fairseq2_model& model, const std::string& name) {
+    return model.tensors.find(name) != model.tensors.end();
+}
+
+
+extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
+    ggml_tensor* seqs,
+    ggml_tensor* padding_mask
+) {
+    ggml_context* ctx = model.ctx;
+    // TODO: read norm_order from model
+    auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
+
+    // _forward_self_attn(seqs, padding_mask)
+    auto residual = seqs;
+    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
+        seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
+
+    // TODO: add padding_mask to MultiheadAttention_forward
+    GGML_ASSERT(padding_mask == nullptr);
+    seqs = MultiheadAttention_forward(
+        model,
+        prefix + ".self_attn",
+        seqs,
+        seqs,
+        seqs,
+        /*attention masks=*/nullptr
+    );
 
-//     # For numerical stability run in single precision.
-//     attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)
+    if (has_layer(model, prefix + ".self_attn_norm.weight"))
+        seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
 
-//     attn_weights = attn_weights.type_as(q)
+    // TODO: seqs = self.self_attn_dropout(seqs)
 
-//     if training and dropout_p > 0.0:
-//         attn_weights = dropout(attn_weights, dropout_p)
+    seqs = ggml_add(ctx, seqs, residual);
 
-//     # (*, S, S_kv) @ (*, S_kv, V) = (*, S, V)
-//     attn = torch.matmul(attn_weights, values)
+    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
+        seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
 
-//     return attn, attn_weights if needs_weights else None
+    // _forward_ffn(seqs)
+    residual = seqs;
 
-// extern "C" ggml_tensor* // (d_out, seq_len)
-// SDPA_forward(
-//     fairseq2_model& model,
-//     const std::string &prefix,
-//     ggml_tensor* queries,  // (d_in, len_q)
-//     ggml_tensor* keys,  // (d_in, len_k)
-//     ggml_tensor* values,  // (d_out, len_k)
-//     ggml_tensor* mask // (seq_len, len_q)
-// ) {
-//     return queries;
-// }
+    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
+        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
+
+    seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
+
+    // TODO:
+    // seqs = self.ffn_dropout(seqs)
+    // if self.residual_scale is not None:
+    // residual = self.residual_scale * residual
+
+    seqs = ggml_add(ctx, seqs, residual);
+
+    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
+        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
+
+    return seqs;
+}
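
For orientation, here is a minimal usage sketch (not part of this commit) of how the new forward function could be driven from C++ once a fairseq2_model has been loaded; the graph calls mirror the ones used in the Python tests below, and the layer prefix and helper name are only examples:

    // Sketch only: run one encoder layer on a (seq_len, model_dim) input tensor.
    // Assumes model was loaded beforehand and model.ctx is a valid ggml_context.
    ggml_tensor* encode_one_layer(fairseq2_model& model, ggml_tensor* seqs) {
        // The padding mask is not supported yet, so pass nullptr (the function asserts this).
        ggml_tensor* out = StandardTransformerEncoderLayer_forward(
            model, "text_encoder.layers.0", seqs, /*padding_mask=*/nullptr);

        // Build the forward graph and compute it with a single thread.
        ggml_cgraph gf = ggml_build_forward(out);
        ggml_graph_compute_with_ctx(model.ctx, &gf, /*n_threads=*/1);
        return out;
    }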

+ 14 - 0
ggml/examples/unity/fairseq2.h

@@ -54,3 +54,17 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* values,  // (klen, d_out)
     ggml_tensor* _ // (klen, slen)  TODO: do we need to pass mask here ?
 );
+
+extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
+    ggml_tensor* seqs,
+    ggml_tensor* padding_mask
+);
+
+// Specifies the Layer Normalization order.
+enum TransformerNormOrder {
+    TRANSFORMER_NORM_ORDER_POST = 0,
+    TRANSFORMER_NORM_ORDER_PRE = 1,
+    TRANSFORMER_NORM_ORDER_PRE_WITH_NORMFORMER = 2
+};
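
As a reading aid for the norm_order branches in the forward function above (schematic only, not committed code), the two orderings differ in where layer normalization sits relative to the residual connection:

    // Pre-norm (TRANSFORMER_NORM_ORDER_PRE): normalize the sub-block input,
    // then add the residual to the sub-block output.
    //   seqs = residual + self_attn(layer_norm(seqs))
    //   seqs = residual + ffn(layer_norm(seqs))
    //
    // Post-norm (TRANSFORMER_NORM_ORDER_POST): add the residual first,
    // then normalize the sum.
    //   seqs = layer_norm(residual + self_attn(seqs))
    //   seqs = layer_norm(residual + ffn(seqs))
    //
    // The NormFormer variant (TRANSFORMER_NORM_ORDER_PRE_WITH_NORMFORMER) additionally
    // normalizes the attention output, which is what the has_layer check on
    // ".self_attn_norm" in StandardTransformerEncoderLayer_forward covers.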

+ 44 - 11
ggml/test_unity_cpp.py

@@ -6,6 +6,7 @@ import numpy as np
 import torch
 import fairseq2.nn
 import fairseq2.nn.transformer
+from ctypes import c_void_p
 from typing import Any
 from pathlib import Path
 from typing import Iterator
@@ -260,7 +261,7 @@ def test_ning_model_load(ctx: Ctx) -> None:
 
 
 @pytest.fixture(scope="module")
-def g_model_once() -> Iterator[ctypes.c_void_p]:
+def g_model_once() -> Iterator[c_void_p]:
     model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
     if not model_file.exists():
         convert_model("seamlessM4T_medium", model_file)
@@ -269,7 +270,7 @@ def g_model_once() -> Iterator[ctypes.c_void_p]:
 
 
 @pytest.fixture()
-def g_model(ctx: Ctx, g_model_once: ctypes.c_void_p) -> ctypes.c_void_p:
+def g_model(ctx: Ctx, g_model_once: c_void_p) -> c_void_p:
     ggml.lib.fairseq2_model_set_inference_ctx(g_model_once, ctx)
     return g_model_once
 
@@ -363,7 +364,7 @@ def test_ggml_softmax_vs_torch(ctx: Ctx) -> None:
     assert np.allclose(y_exp, y, rtol=1e-3)
 
 
-def test_forward_ffn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
+def test_forward_ffn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
     x = torch.empty((21, 1024))  # (seq_len, model_dim)
     torch.nn.init.uniform_(x, -1 / 32, 1 / 32)
 
@@ -380,7 +381,7 @@ def test_forward_ffn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
     assert np.allclose(y_exp, y, rtol=2e-2, atol=1e-4)
 
 
-def test_forward_layer_norm(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
+def test_forward_layer_norm(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
     x = torch.empty((21, 1024))
     torch.nn.init.uniform_(x, -1, 1)
 
@@ -396,20 +397,17 @@ def test_forward_layer_norm(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None
 
 def _name(tensor: ggml.ggml_tensor_p) -> bytes:
     try:
-        return tensor.contents.name
+        return tensor.contents.name  # type: ignore[no-any-return]
     except ValueError:
         return b"???"
 
 
-def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
+def test_forward_self_attn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
     x = torch.empty((1, 21, 1024))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
 
     self_attn = pt_model.text_encoder.layers[0].self_attn
-    # Replace spda by just returning queries
-    # TODO: implement spda
-    # self_attn.spda = lambda *qkv, **kwargs: qkv[0]
 
     # Note: we use different lengths for queries and keys,
     # this tests the implementation in decoding context too.
@@ -424,7 +422,7 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
         gxq,
         gx,
         gx,
-        None,  # attention mask
+        None,  # TODO: tests with causal attention masks
     )
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
@@ -450,7 +448,9 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
     assert q.shape == q_exp.shape
     assert np.allclose(q_exp, q, atol=1e-5)
 
-    attn_exp, attn_weights_exp = map(lambda t: t.squeeze(0).numpy(), attn_weights_hook._storage[0])
+    attn_exp, attn_weights_exp = map(
+        lambda t: t.squeeze(0).numpy(), attn_weights_hook._storage[0]
+    )
 
     # with flash_attn we don't have attn_weights
     flash_attn = b"attn_weights" not in nodes
@@ -471,3 +471,36 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
 
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)
+
+
+def test_StandardTransformerEncoderLayer_forward(
+    ctx: Ctx, g_model: c_void_p, pt_model: Any
+) -> None:
+    x = torch.empty((1, 21, 1024))
+    padding_mask = torch.ones((1, 21))
+    torch.random.manual_seed(0)
+    torch.nn.init.uniform_(x, -1, 1)
+
+    layer = pt_model.text_encoder.layers[0]
+
+    gx = ggml.from_numpy(ctx, x[0])
+    ggml.ggml_set_name(gx, b"x")
+    gpad = ggml.from_numpy(ctx, padding_mask[0])
+    ggml.ggml_set_name(gpad, b"padding_mask")
+    gy = ggml.forward(
+        "StandardTransformerEncoderLayer",
+        g_model,
+        "text_encoder.layers.0",
+        gx,
+        None,  # TODO support padding mask
+    )
+    gf = ggml.ggml_build_forward(gy)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+    y = ggml.to_numpy(gy)
+
+    y_exp, _ = layer(x, padding_mask)
+    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension
+
+    assert y.shape == y_exp.shape
+    assert np.allclose(y_exp, y, atol=1e-4)
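
The new test can be run on its own with pytest's -k filter, e.g. pytest ggml/test_unity_cpp.py -k StandardTransformerEncoderLayer, provided the seamlessM4T_medium.ggml checkpoint used by the g_model fixture is available (it is converted on first use).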