
MultiheadAttention_forward

Guillaume Wenzek, 1 year ago
commit b07a08102a
3 files changed with 98 additions and 27 deletions
  1. ggml/examples/unity/fairseq2.cpp (+35 -11)
  2. ggml/examples/unity/model_loader.cpp (+1 -0)
  3. ggml/test_unity_cpp.py (+62 -16)

+ 35 - 11
ggml/examples/unity/fairseq2.cpp

@@ -1,6 +1,8 @@
+#include <math.h>
 #include "ggml.h"
 #include "fairseq2.h"
 
+
 /// allocate the fairseq2 model and hyperparameters
 extern "C" fairseq2_model* fairseq2_model_alloc() {
     // pre-allocate some memory to write hyperparameters and tensors pointers
@@ -104,29 +106,51 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* queries,  // (slen, d_in)
     ggml_tensor* keys,  // (klen, d_in)
     ggml_tensor* values,  // (klen, d_out)
-    ggml_tensor* _ // (klen, slen)  TODO: do we need to pass mask here ?
+    ggml_tensor* mask // (klen, slen)  TODO: do we need to pass mask here ?
 ) {
     int slen = queries->ne[1];
+    int slenk = keys->ne[1];
     int num_heads = 16;
     int head_dim = queries->ne[0] / num_heads;
     ggml_context* ctx = model.ctx;
     ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
     q = reshape_num_head(ctx, q, num_heads);  // (H, S, H_dim)
+    ggml_set_name(q, "q");
     ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
-    k = reshape_num_head(ctx, k, num_heads);  // (H, S, H_dim)
+    k = reshape_num_head(ctx, k, num_heads);  // (H, Sk, H_dim)
+    ggml_set_name(k, "k");
+
     ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
-    v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slen); // (S, H, H_dim)
-    // v = ggml_permute(ctx, v, 1, 2, 0, 3);  // (H, H_dim, S)
-    v = ggml_permute(ctx, v, 1, 0, 2, 3);  // (S, H_dim, H)
+    v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slenk); // (Sk, H, H_dim)
+    v = ggml_permute(ctx, v, 1, 2, 0, 3);  // (H, H_dim, Sk)
     v = ggml_cont(ctx, v);
-
-    // ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/false);  // (H, S, H_dim)
-    attn = ggml_permute(ctx, attn, 0, 2, 1, 3);  // (S, H, H_dim)
+    ggml_set_name(v, "v");
+
+    // (H, Sk, H_dim) x (H, S, H_dim) -> (H, S, Sk)
+    ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
+    ggml_set_name(qk, "qk");
+    ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
+    ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
+    qk = ggml_scale(ctx, qk, qk_scale);
+    ggml_set_name(qk, "qk_scaled");
+
+    if (mask) qk = ggml_add(ctx, qk, mask);
+    // TODO: upgrade qk to float32 if needed
+    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk);  // (H, S, Sk)
+    ggml_set_name(attn_weights, "attn_weights");
+
+    // (H, S, Sk) x (H, H_dim, Sk) -> (H, H_dim, S)
+    ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
+    ggml_set_name(attn, "attn");
+    attn = ggml_reshape_2d(ctx, attn, slen, num_heads * head_dim); // (H * H_dim, S)
+    attn = ggml_transpose(ctx, attn); // (S, H * H_dim)
+    // I'm not sure why this ggml_cont is needed ...
     attn = ggml_cont(ctx, attn);
-    attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen);   // (S, H * V_h)
-    attn = Linear_forward(model, prefix + ".output_proj", attn);              // (S, d_out)
+    // out -> (S, d_out)
+    ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
+    ggml_set_name(out, "out");
 
-    return attn;
+    return out;
 }
 
 // ggml_tensor* attn_weights = ggml_mul_mat(ctx, q, k);  // (H, S, S)
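
For reference, the graph built above is plain scaled dot-product attention: project q/k/v, split into 16 heads, compute q k^T / sqrt(head_dim), add the optional additive mask, softmax over the keys, weight v, then merge the heads and apply output_proj. A minimal NumPy sketch of the per-head math (the function name and the unbatched, per-head layout are illustrative assumptions; the biases applied by Linear_forward are omitted):

    import numpy as np

    def scaled_dot_product_attention(q, k, v, mask=None):
        # q: (num_heads, slen, head_dim); k, v: (num_heads, klen, head_dim)
        head_dim = q.shape[-1]
        scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)    # (num_heads, slen, klen)
        if mask is not None:
            scores = scores + mask                                # additive mask, -inf on blocked positions
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)   # softmax over the key axis
        return weights @ v                                        # (num_heads, slen, head_dim)

The ggml graph stores the same quantities under the names "qk_scaled", "attn_weights" and "attn"; the heads are then merged back into a (S, H * H_dim) matrix and passed through the ".output_proj" Linear.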

+ 1 - 0
ggml/examples/unity/model_loader.cpp

@@ -35,6 +35,7 @@ model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
             printf("Error while reading tensor %s\n", name.c_str() );
             return 1;
         }
+        ggml_set_name(tensor, name.c_str());
         model.tensors[name] = tensor;
         if (DEBUG_MODEL_LOAD) {
             printf("%s [%5ld, %5ld], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), tensor->ne[0], tensor->ne[1], ggml_type_name(tensor->type), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));

+ 62 - 16
ggml/test_unity_cpp.py

@@ -5,6 +5,7 @@ import pytest
 import numpy as np
 import torch
 import fairseq2.nn
+import fairseq2.nn.transformer
 from typing import Any
 from pathlib import Path
 from typing import Iterator
@@ -346,6 +347,21 @@ def test_torch_spda_vs_ggml_flash_attn(ctx: Ctx) -> None:
     assert np.allclose(y_exp, y)
 
 
+def test_ggml_softmax_vs_torch(ctx: Ctx) -> None:
+    x = torch.empty((5, 8, 4))
+    torch.nn.init.uniform_(x, -1, 1)
+    y_exp = torch.softmax(x, dim=-1).numpy()
+
+    gx = ggml.from_numpy(ctx, x.numpy())
+    gy = ggml.ggml_soft_max(ctx, gx)
+    y = ggml.to_numpy(gy)
+
+    gf = ggml.ggml_build_forward(gy)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+    assert np.allclose(y_exp, y, rtol=1e-3)
+
+
 def test_forward_ffn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
     x = torch.empty((21, 1024))  # (seq_len, model_dim)
     torch.nn.init.uniform_(x, -1 / 32, 1 / 32)
@@ -377,6 +393,13 @@ def test_forward_layer_norm(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None
     assert np.allclose(y_exp, y, rtol=1e-3, atol=1e-4)
 
 
+def _name(tensor: ggml.ggml_tensor_p) -> bytes:
+    try:
+        return tensor.contents.name
+    except ValueError:
+        return b"???"
+
+
 def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
     x = torch.empty((1, 21, 1024))
     torch.random.manual_seed(0)
@@ -387,32 +410,55 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
     # TODO: implement spda
     # self_attn.spda = lambda *qkv, **kwargs: qkv[0]
 
-
-    gx = ggml.from_numpy(ctx, x)
+    gx = ggml.from_numpy(ctx, x[0])
+    gxk = ggml.from_numpy(ctx, x[0, :11, :])
+    gxv = ggml.from_numpy(ctx, x[0, :11, :])
+    ggml.ggml_set_name(gx, b"x")
     gy = ggml.forward(
         "MultiheadAttention",
         g_model,
         "text_encoder.layers.0.self_attn",
         gx,
-        gx,
-        gx,
+        gxk,
+        gxv,
         None,
     )
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+    q_exp = self_attn._project_q(x, None, None).squeeze(0).numpy()
+
     y = ggml.to_numpy(gy)
-    names = "ql,q,qt,qp,kl,k,kt,kp,vl,v,vt,vp,v_cont,attn,attn_p,attn_cont,attn_reshape,outl,out"
-    assert gf.n_nodes == len(names.split(","))
-    gf_nodes = {}
-    for i, name in enumerate(names.split(",")):
-        mid = ggml.to_numpy(gf.nodes[i])
-        # print(name, mid.shape, mid)
-        gf_nodes[name] = mid
-
-    breakpoint()
-    y_exp = self_attn(x, None, x, x).numpy()
+    nodes = {}
+
+    for i in range(gf.n_nodes):
+        name = _name(gf.nodes[i])
+        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
+        print(name, f"op({gf.nodes[i].contents.op})", children)
+        nodes[name] = ggml.to_numpy(gf.nodes[i])
+
+    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
+    self_attn.register_attn_weight_hook(attn_weights_hook)
+
+    y_exp = self_attn(x, None, x[:, :11, :], x[:, :11, :]).numpy()
     y_exp = y_exp.squeeze(0)  # remove batch dimension
 
+    q = nodes[b"q"]
+    assert q.shape == q_exp.shape
+    assert np.allclose(q_exp, q, atol=1e-5)
+
+    attn_exp, attn_weights_exp = map(lambda t: t.squeeze(0).numpy(), attn_weights_hook._storage[0])
+    attn_weights = nodes[b"attn_weights"]
+    assert attn_weights_exp.shape == attn_weights.shape
+    # GGML is very aggressively reducing small softmax weights to 0.
+    # Not sure what causes this.
+    assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
+
+    attn_exp = attn_exp.transpose(0, 2, 1)
+    attn = nodes[b"attn"]
+    assert attn_exp.shape == attn.shape
+    # Because of rounding errors in softmax, it's even worse here.
+    assert np.allclose(attn_exp, attn, atol=1e-2)
+
     assert y.shape == y_exp.shape
-    abs_diff = np.max(np.abs(y - y_exp))
-    assert np.allclose(y_exp, y)
+    assert np.allclose(y_exp, y, atol=1e-2)
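
The tests above always pass None for the mask argument. When a mask is used, MultiheadAttention_forward treats it as additive: it is summed onto qk before the softmax (ggml_add followed by ggml_soft_max), so blocked positions should carry -inf and allowed positions 0. A small NumPy illustration of that convention for a causal mask (additive_causal_mask is illustrative only, assumes slen == klen, and is not part of this commit):

    import numpy as np

    def additive_causal_mask(seq_len: int) -> np.ndarray:
        # 0 where a query may attend to a key, -inf strictly above the
        # diagonal so position i cannot attend to positions j > i.
        mask = np.full((seq_len, seq_len), -np.inf, dtype=np.float32)
        return np.triu(mask, k=1)

Adding such a mask to the scaled scores before the softmax drives the weights of the blocked positions to exactly zero, which is what the if (mask) qk = ggml_add(ctx, qk, mask); branch relies on.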