
fix PositionalEmbedding with cache

Guillaume Wenzek, 1 year ago
commit d5b035f230
3 changed files with 151 additions and 111 deletions
  1. ggml/examples/unity/fairseq2.cpp (+40, -4)
  2. ggml/ggml.py (+18, -0)
  3. ggml/test_unity_cpp.py (+93, -107)

+ 40 - 4
ggml/examples/unity/fairseq2.cpp

@@ -57,6 +57,12 @@ extern "C" void fairseq2_kv_cache_alloc(const fairseq2_model& model, int beam_si
     }
 }
 
+extern "C" void fairseq2_kv_cache_reset(const fairseq2_model& model) {
+    // TODO: use a dedicated allocator, so that kv_cache.clear actually frees the memory
+    model.kv_cache.clear();
+}
+
+
 bool has_kv_cache(const fairseq2_model& model) {
     return model.kv_cache.size() > 0;
 }
@@ -814,6 +820,31 @@ struct ggml_tensor * ggml_slice(
     return result;
 }
 
+struct ggml_tensor * ggml_select(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    int axis,
+    int64_t index
+) {
+    int64_t ne[GGML_MAX_DIMS];
+    std::copy(a->ne, a->ne + GGML_MAX_DIMS, ne);
+
+    if (axis < 0) axis = a->n_dims + axis;
+    if (index < 0) index = ne[axis] + index;
+    GGML_ASSERT(0 <= index);
+    GGML_ASSERT(index < ne[axis]);
+
+    std::copy(a->ne + axis + 1, a->ne + GGML_MAX_DIMS, ne + axis);
+
+    size_t offset = a->nb[axis] * index;
+    size_t* nb = a->nb;
+    GGML_ASSERT(GGML_MAX_DIMS == 4);
+    ggml_tensor* result = ggml_view_3d(ctx, a, ne[0], ne[1], ne[2], nb[1], nb[2], offset);
+    ggml_format_name(result, "%s [(%d)%ld]", a->name, axis, index);
+    result->n_dims = a->n_dims - 1;
+    return result;
+}
+
 
 extern "C" ggml_tensor* PositionalEmbedding_forward(
     fairseq2_model& model,
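
The new ggml_select helper added above builds a view of a single index along an axis and drops that axis from the result (axis and index may be negative, counting from the end); it is not referenced in the hunks shown here. Conceptually it behaves like numpy.take with a scalar index. A rough sketch of the intended semantics in numpy terms (ignoring ggml's reversed axis ordering and the fact that ggml returns a strided view rather than a copy):

```python
import numpy as np

def select(a: np.ndarray, axis: int, index: int) -> np.ndarray:
    """Pick one index along `axis` and drop that axis (rank shrinks by one)."""
    if axis < 0:
        axis += a.ndim
    if index < 0:
        index += a.shape[axis]
    assert 0 <= index < a.shape[axis]
    return np.take(a, index, axis=axis)

x = np.zeros((2, 3, 4))
assert select(x, axis=1, index=0).shape == (2, 4)
assert select(x, axis=-1, index=-1).shape == (2, 3)
```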
@@ -823,7 +854,12 @@ extern "C" ggml_tensor* PositionalEmbedding_forward(
     // This only works with the simple pos encoders
     int seq_len = embeds->ne[1];
     ggml_tensor* full_pos_embeds = model.tensors[prefix];
-    ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, 0, seq_len);
+
+    int start_step = 0;
+    if (has_kv_cache(model)) {
+        start_step = model.kv_cache[prefix].step_nr++;
+    }
+    ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, start_step, seq_len + start_step);
     return ggml_add(model.ctx, embeds, pos_embeds);
 }
 
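
With a KV cache allocated, PositionalEmbedding_forward now slices the precomputed table starting at the current decoding step instead of always starting at row 0, and advances the per-prefix step counter on every call, which is what lets single-token incremental decoding pick up the correct positions. A minimal numpy sketch of that behaviour (function and argument names are illustrative, not the actual bindings):

```python
import numpy as np

def positional_embedding_forward(full_pos_embeds, embeds, step_nr=None):
    """Sketch of the cached positional lookup.

    full_pos_embeds: (max_seq_len, dim) precomputed positional table
    embeds:          (batch, seq_len, dim) embeddings for the current call
    step_nr:         current decoding step when a KV cache is active, else None
    """
    seq_len = embeds.shape[1]
    start = 0 if step_nr is None else step_nr
    # Without a cache the slice always started at 0; with a cache it starts
    # at the current step, so a one-token call gets the row for that step.
    pos = full_pos_embeds[start : start + seq_len]
    return embeds + pos  # broadcasts over the batch dimension
```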
@@ -831,7 +867,6 @@ extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
     fairseq2_model& model,
     const std::string& prefix,
     ggml_tensor* seqs
-    // TODO: state_bag
 ) {
     GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
     ggml_context* ctx = model.ctx;
@@ -1264,8 +1299,8 @@ extern "C" Hypothesis* generate_sequence(
 
     // TODO: memory management, there should be a per-step ggml_context for intermediary results
     for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
-        ggml_tensor* decoder_input = ggml_slice(ctx, seqs, 0, step_nr, step_nr + 1);
-        decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", decoder_input);
+        ggml_tensor* prev_token = ggml_slice(ctx, seqs, 0, step_nr, step_nr + 1);
+        ggml_tensor* decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", prev_token);
         ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
             model,
             "text_decoder",
@@ -1408,6 +1443,7 @@ end_of_beam_search:
         [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
     );
 
+    fairseq2_kv_cache_reset(model);
     return finished_searches_begin;
 }
 

+ 18 - 0
ggml/ggml.py

@@ -9,6 +9,8 @@ import torch
 import functools
 import logging
 import dataclasses
+import contextlib
+from typing import Iterator
 from typing import NamedTuple
 from pathlib import Path
 from typing import Dict
@@ -489,3 +491,19 @@ def fairseq2_kv_cache_alloc(
     model: ctypes.c_void_p, beam_size: int, max_seq_len: int
 ) -> None:
     pass
+
+
+@c_fn(lib)
+def fairseq2_kv_cache_reset(model: ctypes.c_void_p) -> None:
+    pass
+
+
+@contextlib.contextmanager
+def model_kv_cache_alloc(
+    model: ctypes.c_void_p, beam_size: int, max_seq_len: int
+) -> Iterator[None]:
+    fairseq2_kv_cache_alloc(model, beam_size, max_seq_len)
+    try:
+        yield
+    finally:
+        fairseq2_kv_cache_reset(model)
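
The new model_kv_cache_alloc context manager pairs fairseq2_kv_cache_alloc with fairseq2_kv_cache_reset, so the cache is cleared even if the body raises (for example a failing assertion in the tests below). A minimal usage sketch, assuming g_model is a loaded fairseq2 ggml model handle as in those tests:

```python
import ggml

with ggml.model_kv_cache_alloc(g_model, beam_size=2, max_seq_len=21):
    # run incremental decoding here; the KV cache is reset on exit,
    # even if an exception escapes the block
    ...
```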

+ 93 - 107
ggml/test_unity_cpp.py

@@ -245,43 +245,43 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
 
     state_bag = fairseq2.nn.IncrementalStateBag()
 
-    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
-    # Incremental decoding
-    for t in range(3):
-        xq = x[:, t : t + 1]
-        y_exp = attn(xq, None, xq, xq, state_bag=state_bag).numpy()
-        assert y_exp.shape == (2, 1, 1024)
-
-        gxq = ggml.from_numpy(ctx, xq.contiguous())
-        ggml.ggml_set_name(gxq, b"xq")
-        gy = ggml.forward(
-            "MultiheadAttention",
-            g_model,
-            "text_decoder.layers.0.self_attn",
-            gxq,
-            gxq,
-            gxq,
-            None,  # type: ignore
-        )
-        gf = ggml.ggml_build_forward(gy)
-        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+    with ggml.model_kv_cache_alloc(g_model, 2, 21):
+        # Incremental decoding
+        for t in range(3):
+            xq = x[:, t : t + 1]
+            y_exp = attn(xq, None, xq, xq, state_bag=state_bag).numpy()
+            assert y_exp.shape == (2, 1, 1024)
+
+            gxq = ggml.from_numpy(ctx, xq.contiguous())
+            ggml.ggml_set_name(gxq, b"xq")
+            gy = ggml.forward(
+                "MultiheadAttention",
+                g_model,
+                "text_decoder.layers.0.self_attn",
+                gxq,
+                gxq,
+                gxq,
+                None,  # type: ignore
+            )
+            gf = ggml.ggml_build_forward(gy)
+            ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
-        nodes = ggml.nodes(gf)
-        state = state_bag.get_state(
-            attn, fairseq2.nn.transformer.MultiheadAttentionState
-        )
-        state_bag.increment_step()
-        assert state is not None
-        assert np.allclose(
-            state.prev_k.numpy(),
-            ggml.to_numpy(
-                nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
-            ),
-            atol=1e-3,
-        )
+            nodes = ggml.nodes(gf)
+            state = state_bag.get_state(
+                attn, fairseq2.nn.transformer.MultiheadAttentionState
+            )
+            state_bag.increment_step()
+            assert state is not None
+            assert np.allclose(
+                state.prev_k.numpy(),
+                ggml.to_numpy(
+                    nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
+                ),
+                atol=1e-3,
+            )
 
-        y = ggml.to_numpy(gy)
-        assert np.allclose(y, y_exp, atol=1e-2)
+            y = ggml.to_numpy(gy)
+            assert np.allclose(y, y_exp, atol=1e-2)
 
 
 def test_MultiheadAttention_forward_cross_attn_with_cache(
@@ -296,49 +296,49 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
 
     state_bag = fairseq2.nn.IncrementalStateBag()
 
-    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
-    # Incremental decoding, the keys come from the encoder, and don't change during decoding
-    xk = x[:, :11]
-    gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")
-
-    for t in range(3):
-        xq = x[:, t : t + 1]
-
-        gxq = ggml.from_numpy(ctx, xq.contiguous())
-        ggml.ggml_set_name(gxq, b"xq")
-        gy = ggml.forward(
-            "MultiheadAttention",
-            g_model,
-            "text_decoder.layers.0.encoder_decoder_attn",
-            gxq,
-            gxk,
-            gxk,
-            None,  # type: ignore
-        )
-        gf = ggml.ggml_build_forward(gy)
-        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
-        y = ggml.to_numpy(gy)
-        nodes = ggml.nodes(gf)
-        leaves = ggml.leafs(gf)
-
-        if t > 0:
-            # the cache only appear in the graph during the second call
-            state = state_bag.get_state(
-                attn, fairseq2.nn.transformer.MultiheadAttentionState
-            )
-            assert state is not None
-            assert np.allclose(
-                state.prev_k.numpy(),
-                ggml.to_numpy(
-                    nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
-                ),
-                atol=1e-3,
+    with ggml.model_kv_cache_alloc(g_model, 2, 21):
+        # Incremental decoding: the keys come from the encoder and don't change during decoding
+        xk = x[:, :11]
+        gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")
+
+        for t in range(3):
+            xq = x[:, t : t + 1]
+
+            gxq = ggml.from_numpy(ctx, xq.contiguous())
+            ggml.ggml_set_name(gxq, b"xq")
+            gy = ggml.forward(
+                "MultiheadAttention",
+                g_model,
+                "text_decoder.layers.0.encoder_decoder_attn",
+                gxq,
+                gxk,
+                gxk,
+                None,  # type: ignore
             )
-
-        state_bag.increment_step()
-        y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
-        assert y_exp.shape == (2, 1, 1024)
-        assert np.allclose(y, y_exp, atol=1e-2)
+            gf = ggml.ggml_build_forward(gy)
+            ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+            y = ggml.to_numpy(gy)
+            nodes = ggml.nodes(gf)
+            leaves = ggml.leafs(gf)
+
+            if t > 0:
+                # the cache only appears in the graph from the second call onwards
+                state = state_bag.get_state(
+                    attn, fairseq2.nn.transformer.MultiheadAttentionState
+                )
+                assert state is not None
+                assert np.allclose(
+                    state.prev_k.numpy(),
+                    ggml.to_numpy(
+                        nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
+                    ),
+                    atol=1e-3,
+                )
+
+            state_bag.increment_step()
+            y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
+            assert y_exp.shape == (2, 1, 1024)
+            assert np.allclose(y, y_exp, atol=1e-2)
 
 
 def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
@@ -534,20 +534,6 @@ def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
     assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error is from standardization
 
 
-def test_causal_attention_mask(ctx: Ctx):
-    x = torch.zeros((5, 10))
-    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
-    mask_exp = generator(x)
-
-    gx = ggml.from_numpy(ctx, x)
-    gmask = ggml.causal_attention_mask(ctx, gx)
-    mask = ggml.to_numpy(gmask)
-
-    assert mask_exp.shape == (10, 10)
-    assert mask.shape == (10, 10)
-    assert np.allclose(mask, mask_exp)
-
-
 def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
     seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
     # this _legacy_pad_idx is suspicious. Shouldn't the model use 1 ? But
@@ -574,22 +560,22 @@ def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) ->
     pos_encoder.eval()
     state_bag = fairseq2.nn.IncrementalStateBag()
 
-    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
-    # Incremental decoding
-    for t in range(20):
-        gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
-        ggml.ggml_set_name(gseq, b"seq")
-        gy = ggml.forward(
-            "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
-        )
-        gf = ggml.ggml_build_forward(gy)
-        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
-        y = ggml.to_numpy(gy)
-
-        y_exp = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag).numpy()
-        state_bag.increment_step()
-        assert y.shape == y_exp.shape
-        assert np.allclose(y_exp, y, atol=1e-6)
+    with ggml.model_kv_cache_alloc(g_model, 2, 21):
+        # Incremental decoding
+        for t in range(20):
+            gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
+            ggml.ggml_set_name(gseq, b"seq")
+            gy = ggml.forward(
+                "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
+            )
+            gf = ggml.ggml_build_forward(gy)
+            ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+            y = ggml.to_numpy(gy)
+
+            y_exp = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag).numpy()
+            state_bag.increment_step()
+            assert y.shape == y_exp.shape
+            assert np.allclose(y_exp, y, atol=1e-6)
 
 
 def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None: