
fix incremental decoding !

Guillaume Wenzek 1 year ago
parent
commit
f6d810543d
3 changed files with 158 additions and 49 deletions
  1. ggml/examples/unity/fairseq2.cpp (+48, -19)
  2. ggml/ggml.py (+32, -6)
  3. ggml/test_unity_cpp.py (+78, -24)

+ 48 - 19
ggml/examples/unity/fairseq2.cpp

@@ -25,23 +25,35 @@ extern "C" void fairseq2_kv_cache_alloc(const fairseq2_model& model, int beam_si
     // Note: we only allocate the cache for the decoder attention.
     // For the encoder/decoder attention, since it is computed all at once,
     // the allocation is delayed to the first forward pass to avoid over-allocating.
-    auto layer_glob_c = "*decoder.*attn.k_proj.weight";
+    auto attn_glob = "*decoder.*_attn.k_proj.weight";
+    auto self_attn_glob = "*decoder.*self_attn.k_proj.weight";
     ggml_tensor* self_attn_mask = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, max_seq_len, max_seq_len);
-    self_attn_mask = ggml_diag_mask_inf(model.ctx, self_attn_mask, 0);
+    self_attn_mask = ggml_diag_mask_inf_inplace(model.ctx, self_attn_mask, 0);
+    ggml_format_name(self_attn_mask, "self_attn_mask[%d]", max_seq_len);
 
     for (auto named_tensor : model.tensors) {
         const std::string& name = named_tensor.first;
-        if (::fnmatch(layer_glob_c, name.c_str(), 0) == FNM_NOMATCH)
+        if (::fnmatch(attn_glob, name.c_str(), 0) == FNM_NOMATCH)
             continue;
+        // create a cache entry without the ".k_proj.weight" suffix
+        const std::string& shortname = name.substr(0, name.size() - 14);
+        KeyValueTensor& kv = model.kv_cache[shortname];
+        kv.step_nr = 0;
+
+        if (::fnmatch(self_attn_glob, name.c_str(), 0) == FNM_NOMATCH) {
+            // enc_dec_attn
+            // the tensors will be allocated during the first forward
+            continue;
+        }
+
+        // self_attn
         ggml_tensor* k_proj = named_tensor.second;
         int model_dim = k_proj->ne[0];
-        // remove the ".k_proj.weight" suffix
-        model.kv_cache[name.substr(0, name.size() - 14)] = KeyValueTensor {
-            ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size),
-            ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size),
-            self_attn_mask,
-            0,
-        };
+        kv.full_k = ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size);
+        kv.full_v = ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size);
+        kv.self_attn_mask = self_attn_mask;
+        ggml_format_name(kv.full_k, "%s.k_cache", shortname.c_str());
+        ggml_format_name(kv.full_v, "%s.v_cache", shortname.c_str());
     }
 }
 
@@ -54,6 +66,7 @@ bool has_kv_cache(const fairseq2_model& model) {
 // kv.full_v[step_nr] = v;
 void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
     KeyValueTensor& kv = model.kv_cache[prefix];
+    GGML_ASSERT(kv.full_k != nullptr); // key not found !
     int step_nr = kv.step_nr;
 
     ggml_tensor* full_k = kv.full_k;
@@ -66,6 +79,8 @@ void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, g
 
     *k = ggml_slice(model.ctx, updated_k, 1, 0, step_nr + 1);
     *v = ggml_slice(model.ctx, updated_v, 1, 0, step_nr + 1);
+    ggml_format_name(*k, "%s (step=%d)", full_k->name, step_nr);
+    ggml_format_name(*v, "%s (step=%d)", full_v->name, step_nr);
 
     // qk is (B * H, Sq, Sk) == (B*H, 1, Sk) in incremental mode
     // we return the Sq slice of the (Sq, Sk) attention mask
@@ -97,13 +112,17 @@ ggml_tensor* ggml_get_rows2(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
 
 
 void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, ggml_tensor* new_order) {
-    ggml_detach(kv.full_k);
-    kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
-    ggml_build_forward_expand(gf, kv.full_k);
+    if (kv.full_k != nullptr) {
+        ggml_detach(kv.full_k);
+        kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
+        ggml_build_forward_expand(gf, kv.full_k);
+    }
 
-    ggml_detach(kv.full_v);
-    kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
-    ggml_build_forward_expand(gf, kv.full_v);
+    if (kv.full_v != nullptr) {
+        ggml_detach(kv.full_v);
+        kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
+        ggml_build_forward_expand(gf, kv.full_v);
+    }
 }
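
Conceptually, the guarded reorder above is just a row gather over the beam dimension of each cached tensor, skipped while the cross-attention cache is still unallocated. A minimal numpy sketch follows; shapes and names are illustrative, not the ggml API.

import numpy as np

def reorder_kv_cache_sketch(full_k, full_v, new_order):
    # full_k / full_v: (beam, max_seq_len, model_dim) caches, or None if the
    # cross-attention cache has not been allocated yet.
    # new_order: (beam,) int array mapping each target beam to its source beam.
    if full_k is not None:
        full_k = full_k[new_order]   # ggml_get_rows2(ctx, kv.full_k, new_order)
    if full_v is not None:
        full_v = full_v[new_order]
    return full_k, full_v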
 
 
@@ -333,19 +352,27 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
             KeyValueTensor& kv_cache = model.kv_cache[prefix];
             if (kv_cache.step_nr == 0) {
                 k = Linear_forward(model, prefix + ".k_proj", keys);
-                ggml_set_name(k, "k");
+                ggml_format_name(k, "%s.k_cache", prefix.c_str());
                 v = Linear_forward(model, prefix + ".v_proj", values);
-                ggml_set_name(v, "v");
-                model.kv_cache[prefix] = KeyValueTensor{k, v, nullptr, 1};
+                ggml_format_name(v, "%s.v_cache", prefix.c_str());
+                // TODO: encoder_padding_mask
+                kv_cache.full_k = k;
+                kv_cache.full_v = v;
+                kv_cache.step_nr = keys->ne[1];
             } else {
                 k = kv_cache.full_k;
                 v = kv_cache.full_v;
+                // A shape mismatch here means a cache collision. TODO: fairseq2_kv_cache_reset
+                GGML_ASSERT(keys->ne[1] == k->ne[1]);
+                GGML_ASSERT(values->ne[1] == v->ne[1]);
             }
         } else { // self attention
             // (1, K) -> (N, 1, K_proj)
             k = Linear_forward(model, prefix + ".k_proj", keys);
+            ggml_set_name(k, "k");
             // (1, V) -> (N, 1, V_proj)
             v = Linear_forward(model, prefix + ".v_proj", values);
+            ggml_set_name(v, "v");
 
             append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
         }
@@ -776,11 +803,13 @@ struct ggml_tensor * ggml_slice(
     GGML_ASSERT(start <= end);
     GGML_ASSERT(end <= ne[axis]);
 
+
     ne[axis] = end - start;
     size_t offset = a->nb[axis] * start;
 
     size_t* nb = a->nb;
     ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
+    ggml_format_name(result, "%s [(%d)%ld:%ld]", a->name, axis, start, end);
     result->n_dims = a->n_dims;
     return result;
 }
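
For reference, here is a plain numpy sketch of the self-attention caching that append_to_prev_kv implements (shapes and names are illustrative, not the actual ggml API): the new step's k/v projections are written into the pre-allocated caches at step_nr, the slices [0, step_nr + 1) are returned, and only the current row of the causal mask is kept.

import numpy as np

def append_to_prev_kv_sketch(full_k, full_v, self_attn_mask, step_nr, k, v):
    # full_k / full_v: (beam, max_seq_len, model_dim) caches allocated by
    # fairseq2_kv_cache_alloc; k / v: (beam, 1, model_dim) projections of the new step.
    full_k[:, step_nr, :] = k[:, 0, :]      # kv.full_k[step_nr] = k
    full_v[:, step_nr, :] = v[:, 0, :]      # kv.full_v[step_nr] = v
    k_out = full_k[:, : step_nr + 1, :]     # ggml_slice(ctx, updated_k, 1, 0, step_nr + 1)
    v_out = full_v[:, : step_nr + 1, :]
    # qk is (B*H, 1, Sk) in incremental mode, so only the Sq-th row of the
    # (Sq, Sk) attention mask is needed.
    mask_out = self_attn_mask[step_nr : step_nr + 1, : step_nr + 1]
    return k_out, v_out, mask_out, step_nr + 1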

+ 32 - 6
ggml/ggml.py

@@ -7,6 +7,7 @@ import numpy as np
 import ctypes
 import torch
 import functools
+import logging
 from pathlib import Path
 from typing import Dict
 from typing import Callable
@@ -129,10 +130,14 @@ def _strided_to_numpy(tensor_p: ggml_tensor_p) -> np.ndarray:
     # TODO make this work for transposed array
     n = 1
     total_elements = 1
-    for d in range(n_dim - 1):
-        n = num_bytes[d + 1] // type_size // n
-        full_shape.append(n)
-        total_elements *= n
+    try:
+        for d in range(n_dim - 1):
+            n = num_bytes[d + 1] // type_size // n
+            full_shape.append(n)
+            total_elements *= n
+    except ZeroDivisionError:
+        logging.warning("Can't convert permuted GGML tensor back to numpy")
+        return None
     # We don't need to guess for the first dimension, since this doesn't impact striding.
     full_shape.append(t_shape[0])
     total_elements *= t_shape[0]
@@ -193,7 +198,7 @@ def _compute_nbytes(
 
 
 def from_numpy(
-    ctx: ggml_context_p, array: Union[np.ndarray, "torch.Tensor"]
+    ctx: ggml_context_p, array: Union[np.ndarray, "torch.Tensor"], name: bytes = b""
 ) -> ggml_tensor_p:
     if type(array).__name__ == "Tensor":
         array = array.numpy()
@@ -212,6 +217,8 @@ def from_numpy(
 
     # prevent the underlying numpy array to be freed
     setattr(tensor_p, "__data", array)
+    if name:
+        ggml_set_name(tensor_p, name)
     return tensor_p
 
 
@@ -225,6 +232,22 @@ def ggml_can_mul_mat(t0: ggml_tensor_p, t1: ggml_tensor_p) -> bool:
     )
 
 
+def nodes(gf: ggml_cgraph) -> Dict[bytes, ggml_tensor_p]:
+    res = {}
+    for i in range(gf.n_nodes):
+        name = gf.nodes[i].contents.name
+        res[name] = gf.nodes[i]
+    return res
+
+
+def leafs(gf: ggml_cgraph) -> Dict[bytes, ggml_tensor_p]:
+    res = {}
+    for i in range(gf.n_leafs):
+        name = gf.leafs[i].contents.name
+        res[name] = gf.leafs[i]
+    return res
+
+
 class NativeObj:
     AllocFn = Callable[[], ctypes.c_void_p]
     FreeFn = Callable[[ctypes.c_void_p], None]
@@ -455,6 +478,9 @@ def _testing_return_hypothesis_ptr(ctx: ggml_context_p) -> Ptr[Hypothesis]:
 def fairseq2_model_layer_config_int(model: ctypes.c_void_p, name: str) -> int:
     return -1
 
+
 @c_fn(lib)
-def fairseq2_kv_cache_alloc(model: ctypes.c_void_p, beam_size: int, max_seq_len: int) -> None:
+def fairseq2_kv_cache_alloc(
+    model: ctypes.c_void_p, beam_size: int, max_seq_len: int
+) -> None:
     pass
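
A minimal usage sketch for the new nodes() / leafs() helpers and the name= argument of from_numpy. This assumes ctx is an existing ggml_context_p (e.g. the test fixture) and that the binding also exposes ggml_add; the tensor names are made up.

import ctypes
import numpy as np
import ggml

def inspect_graph(ctx):
    # Name inputs on creation so they can be looked up again after compute.
    a = ggml.from_numpy(ctx, np.ones((8,), dtype=np.float32), name=b"a")
    b = ggml.from_numpy(ctx, np.full((8,), 2.0, dtype=np.float32), name=b"b")
    out = ggml.ggml_add(ctx, a, b)  # assumes ggml_add is exposed by the binding
    ggml.ggml_set_name(out, b"out")

    gf = ggml.ggml_build_forward(out)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    named_nodes = ggml.nodes(gf)   # e.g. {b"out": <tensor_p>}
    named_leafs = ggml.leafs(gf)   # e.g. {b"a": ..., b"b": ...}
    assert np.allclose(ggml.to_numpy(named_nodes[b"out"]), 3.0)
    return named_nodes, named_leafs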

+ 78 - 24
ggml/test_unity_cpp.py

@@ -160,13 +160,6 @@ def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p) -> None:
     assert np.allclose(y_exp, y, atol=1e-5)
 
 
-def _name(tensor: ggml.ggml_tensor_p) -> bytes:
-    try:
-        return tensor.contents.name  # type: ignore[no-any-return]
-    except ValueError:
-        return b"???"
-
-
 @pytest.mark.parametrize("lengths", [(11, 21), (21, 13)])
 def test_MultiheadAttention_forward(
     ctx: Ctx, g_model: c_void_p, lengths: Tuple[int, int]
@@ -205,27 +198,21 @@ def test_MultiheadAttention_forward(
     q_exp = self_attn.q_proj(xq).numpy()
 
     y = ggml.to_numpy(gy)
-    nodes = {}
-
-    for i in range(gf.n_nodes):
-        name = _name(gf.nodes[i])
-        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
-        print(name, f"op({gf.nodes[i].contents.op})", children)
-        nodes[name] = ggml.to_numpy(gf.nodes[i])
+    nodes = ggml.nodes(gf)
 
     attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
     self_attn.register_attn_weight_hook(attn_weights_hook)
 
     y_exp = self_attn(xq, None, xk, xk).numpy()
 
-    q = nodes[b"q"]
+    q = ggml.to_numpy(nodes[b"q"])
     assert q.shape == q_exp.shape
     assert np.allclose(q_exp, q, atol=1e-5)
 
     # with flash_attn we don't have attn_weights
     naive_attn = b"attn_weights" in nodes
     if naive_attn:
-        attn_weights = nodes[b"attn_weights"]
+        attn_weights = ggml.to_numpy(nodes[b"attn_weights"])
         [attn_weights_exp] = attn_weights_hook._storage
         attn_weights_exp = attn_weights_exp.numpy()
         assert attn_weights_exp.shape == attn_weights.shape
@@ -242,9 +229,11 @@ def test_MultiheadAttention_forward(
     assert np.allclose(y_exp, y, atol=1e-2 if naive_attn else 1e-4)
 
 
-def test_MultiheadAttention_forward_with_state_bag(ctx: Ctx, g_model: c_void_p) -> None:
+def test_MultiheadAttention_forward_self_attn_with_cache(
+    ctx: Ctx, g_model: c_void_p
+) -> None:
     pt_model = load_pt_model()
-    self_attn = pt_model.text_encoder.layers[0].self_attn
+    attn = pt_model.text_decoder.layers[0].self_attn
 
     x = torch.empty((2, 21, 1024))
     torch.random.manual_seed(0)
@@ -255,17 +244,65 @@ def test_MultiheadAttention_forward_with_state_bag(ctx: Ctx, g_model: c_void_p)
     ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
     # Incremental decoding
     for t in range(3):
-        xq, xk = x[:, t : t + 1], x[:, t : t + 1]
-        y_exp = self_attn(xq, None, xk, xk, state_bag=state_bag).numpy()
+        xq = x[:, t : t + 1]
+        y_exp = attn(xq, None, xq, xq, state_bag=state_bag).numpy()
         assert y_exp.shape == (2, 1, 1024)
 
         gxq = ggml.from_numpy(ctx, xq.contiguous())
-        gxk = ggml.from_numpy(ctx, xk.contiguous())
-        ggml.ggml_set_name(gxk, b"xk")
+        ggml.ggml_set_name(gxq, b"xq")
+        gy = ggml.forward(
+            "MultiheadAttention",
+            g_model,
+            "text_decoder.layers.0.self_attn",
+            gxq,
+            gxq,
+            gxq,
+            None,  # type: ignore
+        )
+        gf = ggml.ggml_build_forward(gy)
+        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+        nodes = ggml.nodes(gf)
+        state = state_bag.get_state(
+            attn, fairseq2.nn.transformer.MultiheadAttentionState
+        )
+        assert state is not None
+        assert np.allclose(
+            state.prev_k.numpy(),
+            ggml.to_numpy(nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]),
+            atol=1e-3,
+        )
+
+        y = ggml.to_numpy(gy)
+        assert np.allclose(y, y_exp, atol=1e-2)
+
+
+def test_MultiheadAttention_forward_cross_attn_with_cache(
+    ctx: Ctx, g_model: c_void_p
+) -> None:
+    pt_model = load_pt_model()
+    attn = pt_model.text_decoder.layers[0].encoder_decoder_attn
+
+    x = torch.empty((2, 21, 1024))
+    torch.random.manual_seed(0)
+    torch.nn.init.uniform_(x, -1, 1)
+
+    state_bag = fairseq2.nn.IncrementalStateBag()
+
+    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
+    # Incremental decoding: the keys come from the encoder and don't change during decoding
+    xk = x[:, :11]
+    gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")
+
+    for t in range(3):
+        xq = x[:, t : t + 1]
+
+        gxq = ggml.from_numpy(ctx, xq.contiguous())
+        ggml.ggml_set_name(gxq, b"xq")
         gy = ggml.forward(
             "MultiheadAttention",
             g_model,
-            "text_encoder.layers.0.self_attn",
+            "text_decoder.layers.0.encoder_decoder_attn",
             gxq,
             gxk,
             gxk,
@@ -273,8 +310,24 @@ def test_MultiheadAttention_forward_with_state_bag(ctx: Ctx, g_model: c_void_p)
         )
         gf = ggml.ggml_build_forward(gy)
         ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
-
         y = ggml.to_numpy(gy)
+        nodes = ggml.nodes(gf)
+        leaves = ggml.leafs(gf)
+
+        if t > 0:
+            # the cache only appears in the graph from the second call onwards
+            state = state_bag.get_state(
+                attn, fairseq2.nn.transformer.MultiheadAttentionState
+            )
+            assert state is not None
+            assert np.allclose(
+                state.prev_k.numpy(),
+                ggml.to_numpy(nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]),
+                atol=1e-3,
+            )
+
+        y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
+        assert y_exp.shape == (2, 1, 1024)
         assert np.allclose(y, y_exp, atol=1e-2)
 
 
@@ -338,6 +391,7 @@ def test_StandardConformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> N
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=2e-3)
 
+
 def test_StandardConformerEncoderAdaptorLayer_forward(
     ctx: Ctx, g_model: c_void_p
 ) -> None:
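
The cross-attention branch exercised by test_MultiheadAttention_forward_cross_attn_with_cache boils down to the plain-Python sketch below (illustrative names only; the real logic lives in MultiheadAttention_forward in fairseq2.cpp): project the encoder output once on the first step, then reuse it on every later step.

def cross_attn_kv(kv_cache, k_proj, v_proj, encoder_out):
    # kv_cache: entry created by fairseq2_kv_cache_alloc, with step_nr == 0
    # until the first forward pass; k_proj / v_proj: the layer's projections.
    if kv_cache.step_nr == 0:
        # First decoding step: project the encoder output and keep it around.
        kv_cache.full_k = k_proj(encoder_out)
        kv_cache.full_v = v_proj(encoder_out)
        kv_cache.step_nr = encoder_out.shape[1]  # source sequence length
    else:
        # A different source length means the cache was reused without a reset
        # (see the fairseq2_kv_cache_reset TODO).
        assert encoder_out.shape[1] == kv_cache.full_k.shape[1]
    return kv_cache.full_k, kv_cache.full_v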