
fix MultiheadAttention_forward

Guillaume Wenzek 1 year ago
Parent
Commit
28ed039370
4 changed files with 56 additions and 34 deletions
  1. ggml/examples/unity/fairseq2.cpp (+22 -25)
  2. ggml/examples/unity/fairseq2.h (+11 -0)
  3. ggml/ggml.py (+16 -2)
  4. ggml/test_unity_cpp.py (+7 -7)

+ 22 - 25
ggml/examples/unity/fairseq2.cpp

@@ -93,14 +93,11 @@ extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
 }
 
 
-/// Merge the given dimension and the previous one in the tensor.
-/// (..., num_heads, N, ...) -> (..., num_heads * N, ...)
-/// dim is the position of the resulting merged dimension
-/// ggml_flatten_1d(x, d) <==> torch.flatten(x, -1-d-1, -1-d)
 ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
     int n_dims = x->n_dims;
     GGML_ASSERT(dim >= 0);
     GGML_ASSERT(dim < n_dims);
+    GGML_ASSERT(ggml_is_contiguous(x));
     // Nothing to do
     if (dim == n_dims - 1) return x;
 
@@ -123,9 +120,6 @@ ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
     }
 }
 
-/// Split the given dimension.
-/// (..., K * N, ...) -> (..., K, N, ...)
-/// dim is the position of the output dimension with the given number of element (N).
 ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
     int n_dims = x->n_dims;
     GGML_ASSERT(dim >= 0);
@@ -137,7 +131,7 @@ ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int n
         if (dim == 0) {
             return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
         } else { // dim == 1
-            return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
+            return ggml_reshape_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el);
         }
     } else { // (n_dims == 3)
         if (dim == 0) {
@@ -154,8 +148,9 @@ ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int n
 ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
     // (B, S, dim) -> (B, S, H, H_dim)
     x = ggml_unflatten_1d(ctx, x, 0, head_dim);
-    // (B?, S, H, H_dim) -> (B?, H, S, H_dim)
-    x = ggml_permute(ctx, x, 0, 2, 1, 3);
+    x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
+    x = ggml_cont(ctx, x);
+    x = ggml_flatten_1d(ctx, x, 2);  // (B * H, S, H_dim)
     return x;
 }
 
@@ -164,6 +159,8 @@ ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int hea
     // (B, Sk, dim) -> (B, Sk, H, H_dim)
     v = ggml_unflatten_1d(ctx, v, 0, head_dim);
     v = ggml_permute(ctx, v, 1, 2, 0, 3);  // (B?, H, H_dim, Sk)
+    v = ggml_cont(ctx, v);
+    v = ggml_flatten_1d(ctx, v, 2);  // (B * H, H_dim, Sk)
     return v;
 }
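
The net effect of the two helpers above is that the batch and head dimensions are collapsed into one leading dimension before the matmuls, so ggml_mul_mat runs a batched matrix product per (batch, head) pair. A rough PyTorch sketch of the intended layouts (illustrative only; sizes and names here are assumptions, not repo code):

    import torch

    B, S, H, H_dim = 2, 11, 16, 64          # hypothetical sizes

    # _reshape_num_head: (B, S, H * H_dim) -> (B * H, S, H_dim)
    q = torch.randn(B, S, H * H_dim)
    q = q.unflatten(-1, (H, H_dim))         # (B, S, H, H_dim)
    q = q.permute(0, 2, 1, 3).contiguous()  # (B, H, S, H_dim)
    q = q.flatten(0, 1)                     # (B * H, S, H_dim)

    # _reshape_num_head_values: (B, Sk, H * H_dim) -> (B * H, H_dim, Sk)
    v = torch.randn(B, S, H * H_dim)
    v = v.unflatten(-1, (H, H_dim))         # (B, Sk, H, H_dim)
    v = v.permute(0, 2, 3, 1).contiguous()  # (B, H, H_dim, Sk)
    v = v.flatten(0, 1)                     # (B * H, H_dim, Sk)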
 
@@ -186,27 +183,27 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     GGML_ASSERT(model_dim % num_heads == 0);
 
     ggml_context* ctx = model.ctx;
-    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
-    q = _reshape_num_head(ctx, q, head_dim);  // (B, H, S, H_dim)
+    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
     ggml_set_name(q, "q");
+    q = _reshape_num_head(ctx, q, head_dim);  // (B * H, S, H_dim)
     ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
-    k = _reshape_num_head(ctx, k, head_dim);  // (B, H, Sk, H_dim)
     ggml_set_name(k, "k");
+    k = _reshape_num_head(ctx, k, head_dim);  // (B * H, Sk, H_dim)
 
     ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
-    v = _reshape_num_head_values(ctx, v, head_dim); // (B, H, H_dim, Sk)
-    v = ggml_cont(ctx, v);
     ggml_set_name(v, "v");
+    v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
+    v = ggml_cont(ctx, v);
 
 #if UNITY_FLASH_ATTN
     // For flash_attn, we assume either no masks, or triangular masks.
-    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr);  // (H, S, H_dim)
+    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr);  // (B * H, S, H_dim)
     ggml_set_name(attn, "attn");
+    // TODO test !
+    attn = ggml_unflatten_1d(ctx, attn, 2, num_heads);  // (B, H, S, H_dim)
     attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
-    attn = ggml_cont(ctx, attn);
-    attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
 #else
-    // (B, H, Sk, H_dim) x (B, H, S, H_dim) -> (B, H, S, Sk)
+    // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
     ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
     ggml_set_name(qk, "qk");
     ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
@@ -217,17 +214,17 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     // TODO: Should we replace this by ggml_diag_mask_inf ?
     if (mask) qk = ggml_add(ctx, qk, mask);
     // TODO: upgrade qk to float32 if needed
-    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk);  // (B, H, S, Sk)
+    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk);  // (B * H, S, Sk)
     ggml_set_name(attn_weights, "attn_weights");
 
-    // (B, H, S, Sk) x (B, H, H_dim, Sk) -> (B, H, H_dim, S)
+    // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
     ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
     ggml_set_name(attn, "attn");
-    attn = ggml_flatten_1d(ctx, attn, 1); // (B, H * H_dim, S)
-    attn = ggml_transpose(ctx, attn); // (B, S, H * H_dim)
-    // // I'm not sure why this one is needed ...
-    attn = ggml_cont(ctx, attn);
+    attn = ggml_unflatten_1d(ctx, attn, 2, num_heads);  // (B, H, H_dim, S)
+    attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
 #endif  // UNITY_FLASH_ATTN
+    attn = ggml_cont(ctx, attn);
+    attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
     // out -> (B, S, d_out)
     ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
     ggml_set_name(out, "out");
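
For reference, the non-flash branch now follows this shape flow end to end: q/k/v come in with the merged (B * H, ...) layout and the result is reshaped back to (B, S, H * H_dim) for the output projection. A hedged PyTorch sketch of that flow (an illustration of the shapes, not a transcription of the ggml calls):

    import torch

    B, H, S, Sk, H_dim = 2, 16, 11, 21, 64        # hypothetical sizes

    q = torch.randn(B * H, S, H_dim)
    k = torch.randn(B * H, Sk, H_dim)
    v = torch.randn(B * H, H_dim, Sk)

    qk = q @ k.transpose(1, 2) / H_dim ** 0.5     # (B * H, S, Sk)
    attn_weights = qk.softmax(dim=-1)             # (B * H, S, Sk)
    attn = attn_weights @ v.transpose(1, 2)       # (B * H, S, H_dim)

    attn = attn.unflatten(0, (B, H))              # (B, H, S, H_dim)
    attn = attn.permute(0, 2, 1, 3).contiguous()  # (B, S, H, H_dim)
    attn = attn.flatten(-2)                       # (B, S, H * H_dim)
    assert attn.shape == (B, S, H * H_dim)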

+ 11 - 0
ggml/examples/unity/fairseq2.h

@@ -35,6 +35,17 @@ extern "C" ggml_tensor* ggml_slice(
     int64_t end
 );
 
+/// Merge the given dimension and the previous one in the tensor.
+/// (..., num_heads, N, ...) -> (..., num_heads * N, ...)
+/// dim is the position of the resulting merged dimension
+/// ggml_flatten_1d(x, d) <==> torch.flatten(x, -1-d-1, -1-d)
+extern "C" ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim);
+
+/// Split the given dimension.
+/// (..., K * N, ...) -> (..., K, N, ...)
+/// dim is the position of the output dimension with the given number of element (N).
+extern "C" ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el);
+
 extern "C" ggml_tensor* Linear_forward(
     fairseq2_model& model,
     const std::string &prefix,

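Since the doc comments now live in the header, here is a hedged PyTorch restatement of the documented semantics; d counts from ggml's contiguous dimension ne[0], i.e. from the last numpy/torch axis (illustrative only, not repo code):

    import torch

    x = torch.randn(2, 16, 11, 64)            # ggml ne = (64, 11, 16, 2)
    d = 2

    # ggml_flatten_1d(x, d)  <==>  torch.flatten(x, -d - 2, -d - 1)
    flat = torch.flatten(x, -d - 2, -d - 1)   # merges dims -4 and -3 -> (32, 11, 64)

    # ggml_unflatten_1d(x, d, num_el)  <==>  x.unflatten(-1 - d, (-1, num_el))
    unflat = flat.unflatten(-1 - d, (-1, 16)) # splits dim -3 -> (2, 16, 11, 64)

    assert unflat.shape == x.shape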
+ 16 - 2
ggml/ggml.py

@@ -141,7 +141,7 @@ def _strided_to_numpy(tensor_p: ggml_tensor_p) -> np.ndarray:
     res = _void_p_to_np_array(tensor.data, tuple(full_shape), numpy_dtype(tensor.type))
 
     # Extract the correct slice
-    res = res.__getitem__(*[slice(0, n) for n in t_shape])
+    res = res.__getitem__(tuple(slice(0, n) for n in t_shape))
     # TODO: we could handle transposition here
 
     return res
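
This small change matters because numpy's __getitem__ takes a single key: several slices must be wrapped in one tuple, while unpacking them as separate arguments fails for any tensor with more than one dimension. A minimal standalone illustration (not repo code):

    import numpy as np

    res = np.arange(24).reshape(4, 6)
    t_shape = (2, 3)

    # Fixed form: one tuple of slices, same as res[0:2, 0:3]
    ok = res.__getitem__(tuple(slice(0, n) for n in t_shape))
    assert ok.shape == (2, 3)

    # Old form: __getitem__ only accepts one key, so unpacking the list
    # raises TypeError as soon as t_shape has more than one dimension.
    try:
        res.__getitem__(*[slice(0, n) for n in t_shape])
    except TypeError:
        pass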
@@ -175,7 +175,7 @@ def _shape_to_ne(shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
     # in GGML ne[0] indicates the contiguous dimension, ie the last one in numpy and torch
     ne = shape[::-1]
     if len(ne) >= GGML_MAX_DIMS:
-        return ne # type: ignore
+        return ne  # type: ignore
 
     # ne is always of the same length
     padding = (1,) * (GGML_MAX_DIMS - len(ne))
@@ -388,6 +388,20 @@ def ggml_slice(
     ...
 
 
+@c_fn(lib)
+def ggml_flatten_1d(
+    ctx: ggml_context_p, a: Ptr[ggml_tensor], dim: int
+) -> Ptr[ggml_tensor]:
+    return a
+
+
+@c_fn(lib)
+def ggml_unflatten_1d(
+    ctx: ggml_context_p, a: Ptr[ggml_tensor], dim: int, num_el: int
+) -> Ptr[ggml_tensor]:
+    return a
+
+
 @c_struct
 class SequenceGeneratorOptions:
     beam_size: int

+ 7 - 7
ggml/test_unity_cpp.py

@@ -167,7 +167,7 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any)
     # Note: we use different lengths for queries and keys,
     # this tests the implementation in decoding context too.
     # Note2: ggml_flash_attn requires that we have more keys than queries
-    gxq = ggml.from_numpy(ctx, x[:, :11, :])
+    gxq = ggml.from_numpy(ctx, x[:, :11, :].contiguous())
     gx = ggml.from_numpy(ctx, x)
     ggml.ggml_set_name(gx, b"x")
     gy = ggml.forward(
@@ -182,7 +182,7 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any)
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
-    q_exp = self_attn._project_q(x[:, :11, :], None, None).numpy()
+    q_exp = self_attn.q_proj(x[:, :11, :]).numpy()
 
     y = ggml.to_numpy(gy)
     nodes = {}
@@ -206,13 +206,13 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any)
     if not UNITY_FLASH_ATTN:
         attn_weights = nodes[b"attn_weights"]
         [attn_weights_exp] = attn_weights_hook._storage
-        # Fix the shape of attn_weights_exp
-        attn_weights_exp = attn_weights_exp.unflatten(0, (2, 16)).numpy()
+        attn_weights_exp = attn_weights_exp.numpy()
         assert attn_weights_exp.shape == attn_weights.shape
-        # GGML is very agressively reducing small softmax weights to 0.
-        # assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
+        # GGML is very aggressively reducing small softmax weights to 0,
+        # so the error isn't that small
+        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
         # But the sums should be close to 1
-        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2, 16, 11)))
+        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 11)))
         # And the maximum index should match the original ones.
         assert np.allclose(
             np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)