
add kv_cache to fairseq2_model

Guillaume Wenzek · 1 year ago · commit 2c543185e2
2 changed files with 106 additions and 32 deletions:
  1. ggml/examples/unity/fairseq2.cpp (+96, -31)
  2. ggml/examples/unity/fairseq2.h (+10, -1)

ggml/examples/unity/fairseq2.cpp  (+96, -31)

@@ -6,7 +6,7 @@
 #include <unordered_map>
 #include <algorithm>
 #include <iostream>
-
+#include <fnmatch.h>
 
 /// allocate the fairseq2 model and hyperparameters
 extern "C" fairseq2_model* fairseq2_model_alloc() {
@@ -16,10 +16,54 @@ extern "C" fairseq2_model* fairseq2_model_alloc() {
     return model;
 }
 
+void fairseq2_kv_cache_alloc(const fairseq2_model& model, std::size_t beam_size, std::size_t max_seq_len) {
+    // Note: we only pre-allocate the cache for the decoder self-attention.
+    // For the encoder-decoder attention, k and v are computed all at once on the
+    // first forward pass, so their allocation is delayed until then to avoid
+    // over-allocating.
+    auto layer_glob_c = "*decoder.*attn.k_proj.weight";
+    for (auto named_tensor : model.tensors) {
+        const std::string& name = named_tensor.first;
+        if (::fnmatch(layer_glob_c, name.c_str(), 0) == FNM_NOMATCH)
+            continue;
+        ggml_tensor* k_proj = named_tensor.second;
+        int model_dim = k_proj->ne[0];
+        // remove the ".k_proj.weight" suffix
+        model.kv_cache[name.substr(0, name.size() - 14)] = KeyValueTensor {
+            ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size),
+            ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size),
+            0,
+        };
+    }
+}
+
+bool has_kv_cache(const fairseq2_model& model) {
+    return model.kv_cache.size() > 0;
+}
+
+// Append k and v for the current step to the kv cache, conceptually:
+//   kv.full_k[step_nr] = k;
+//   kv.full_v[step_nr] = v;
+// then expose the first step_nr + 1 entries back through *k and *v.
+void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v) {
+    KeyValueTensor& kv = model.kv_cache[prefix];
+    int step_nr = kv.step_nr;
+
+    ggml_tensor* full_k = kv.full_k;
+    ggml_tensor* full_v = kv.full_v;
+
+    ggml_tensor* updated_k = ggml_set_2d_inplace(model.ctx, full_k, *k, full_k->nb[2], full_k->nb[1] * step_nr);
+    ggml_tensor* updated_v = ggml_set_2d_inplace(model.ctx, full_v, *v, full_v->nb[2], full_v->nb[1] * step_nr);
+
+    *k = ggml_slice(model.ctx, updated_k, 1, 0, step_nr + 1);
+    *v = ggml_slice(model.ctx, updated_v, 1, 0, step_nr + 1);
+    kv.step_nr = step_nr + 1;
+}
+
+
 
 inline double model_layer_config_d(const fairseq2_model& model, std::string name) {
     const std::int64_t* data = &model.layer_config.at(name);
-    return *(double*)data;
+    double val = *(const double*)data;
+    return val;
 }
 
 extern "C" double fairseq2_model_layer_config_double(const fairseq2_model& model, const char* name) {
@@ -155,21 +199,23 @@ ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int n
     GGML_ASSERT(dim >= 0);
     GGML_ASSERT(dim < n_dims);
     GGML_ASSERT(n_dims < 4);
+    GGML_ASSERT(x->ne[dim] % num_el == 0);
+    GGML_ASSERT(x->nb[dim + 1] == x->nb[dim] * x->ne[dim]);  // `x` must be contiguous along `dim`
     if (n_dims == 1) {
-        return ggml_reshape_2d(ctx, x, num_el, x->ne[0] / num_el);
+        return ggml_view_2d(ctx, x, num_el, x->ne[0] / num_el, x->nb[0] * num_el, 0);
     } else if (n_dims == 2) {
         if (dim == 0) {
-            return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
+            return ggml_view_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->nb[0] * num_el, x->nb[1], 0);
         } else { // dim == 1
-            return ggml_reshape_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el);
+            return ggml_view_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->nb[1], num_el * x->nb[1], 0);
         }
     } else { // (n_dims == 3)
         if (dim == 0) {
-            return ggml_reshape_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2]);
+            return ggml_view_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2], x->nb[0] * num_el, x->nb[1], x->nb[2], 0);
         } else if (dim == 1) {
-            return ggml_reshape_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2]);
+            return ggml_view_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2], x->nb[1], num_el * x->nb[1], x->nb[2], 0);
         } else { // dim == 2
-            return ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el);
+            return ggml_view_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el, x->nb[1], x->nb[2], num_el * x->nb[2], 0);
         }
     }
 }
@@ -216,12 +262,40 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
     ggml_set_name(q, "q");
     q = _reshape_num_head(ctx, q, head_dim);  // (B * H, S, H_dim)
-    ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
-    ggml_set_name(k, "k");
-    k = _reshape_num_head(ctx, k, head_dim);  // (B * H, Sk, H_dim)
 
-    ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
-    ggml_set_name(v, "v");
+    ggml_tensor *k, *v;
+    if (!has_kv_cache(model)) {
+        k = Linear_forward(model, prefix + ".k_proj", keys);
+        ggml_set_name(k, "k");
+        v = Linear_forward(model, prefix + ".v_proj", values);
+        ggml_set_name(v, "v");
+    } else {
+        bool encoder_decoder_attn = keys == values && keys != queries;
+        if (encoder_decoder_attn) {
+            // The K and V tensors of an encoder-decoder attention (i.e. the
+            // projected encoder outputs) remain static during evaluation.
+
+            KeyValueTensor& kv_cache = model.kv_cache[prefix];
+            if (kv_cache.step_nr == 0) {
+                k = Linear_forward(model, prefix + ".k_proj", keys);
+                ggml_set_name(k, "k");
+                v = Linear_forward(model, prefix + ".v_proj", values);
+                ggml_set_name(v, "v");
+                model.kv_cache[prefix] = KeyValueTensor{k, v, 1};
+            } else {
+                k = kv_cache.full_k;
+                v = kv_cache.full_v;
+            }
+        } else {
+            // (1, K) -> (N, 1, K_proj)
+            k = Linear_forward(model, prefix + ".k_proj", keys);
+            // (1, V) -> (N, 1, V_proj)
+            v = Linear_forward(model, prefix + ".v_proj", values);
+
+            append_to_prev_kv(model, prefix, &k, &v);
+        }
+    }
+    k = _reshape_num_head(ctx, k, head_dim);  // (B * H, Sk, H_dim)
     v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
     v = ggml_cont(ctx, v);
 
@@ -851,8 +925,6 @@ extern "C" ggml_tensor* StandardTransformerDecoder_forward(
     return seqs;
 }
 
-using IncrementalStateBag = std::unordered_map<ggml_tensor*, ggml_tensor*>*;
-
 
 int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
     auto opts = job.opts;
@@ -926,8 +998,7 @@ void _bootstrap_seqs_and_scores(
     ggml_tensor* full_seqs,
     ggml_tensor* scores,
     ggml_tensor* encoder_output,
-    ggml_tensor* encoder_padding_mask,
-    IncrementalStateBag state_bag
+    ggml_tensor* encoder_padding_mask
 ) {
     int prefix_seq_len = job.prefix_seq->ne[0];
     int max_seq_len = scores->ne[0];
@@ -959,7 +1030,6 @@ void _bootstrap_seqs_and_scores(
         /*padding_mask*/ nullptr,
         encoder_output,
         encoder_padding_mask
-        // TODO: state_bag
     );
     // TODO state_bag.increment_step(prefix_seq_len - 1)
 
@@ -1010,7 +1080,9 @@ int topk(
 
 void ggml_detach(ggml_tensor* a) {
     a->op = GGML_OP_NONE;
-    a->src[0] = nullptr;
+    for (int i = 0; i < GGML_MAX_SRC; ++i) {
+        a->src[i] = nullptr;
+    }
 }
 
 
@@ -1079,6 +1151,8 @@ extern "C" Hypothesis* generate_sequence(
     int source_seq_len = encoder_output->ne[1];
     int max_seq_len = _determine_max_seq_len(job, source_seq_len);
 
+    fairseq2_kv_cache_alloc(model, beam_size, max_seq_len);
+
     // (S_enc, M) -> (B, S_enc, M)
     _fan_out_encoder_output(ctx, &encoder_output, &encoder_padding_mask, beam_size);
 
@@ -1096,9 +1170,8 @@ extern "C" Hypothesis* generate_sequence(
     ggml_set_name(scores, "scores_0");
     ggml_set_f32(scores, 0.0);
 
-    IncrementalStateBag state_bag = {};
     _bootstrap_seqs_and_scores(
-        model, job, seqs, scores, encoder_output, encoder_padding_mask, state_bag
+        model, job, seqs, scores, encoder_output, encoder_padding_mask
     );
     int prefix_seq_len = job.prefix_seq->ne[0];
     int start_step = prefix_seq_len - 1;
@@ -1116,9 +1189,7 @@ extern "C" Hypothesis* generate_sequence(
 
     // TODO: memory management, there should be a per-step ggml_context for intermediary results
     for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
-        // because of no IncrementalStateBag we pass input from the start
-        // decoder_input = seqs[:, 0 : step_nr + 1]
-        ggml_tensor* decoder_input = ggml_slice(ctx, seqs, 0, 0, step_nr + 1);
+        ggml_tensor* decoder_input = ggml_slice(ctx, seqs, 0, step_nr, step_nr + 1);
         decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", decoder_input);
         ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
             model,
@@ -1127,15 +1198,9 @@ extern "C" Hypothesis* generate_sequence(
             nullptr,  // We never generate PAD.
             encoder_output,
             encoder_padding_mask
-            // state_bag=state_bag,
-        ); // (B, S, D)
-
-        // state_bag.increment_step()
+        ); // (B, 1, D)
 
-        // Because of no IncrementalStateBag decoder_output here is of shape (B, S, D)
         // Just look at the last token.
-        decoder_output = ggml_slice(ctx, decoder_output, 1, step_nr, step_nr+1);
-        decoder_output = ggml_cont(ctx, decoder_output);
         decoder_output = ggml_flatten_1d(ctx, decoder_output, 0);  // (B, model_dim)
         ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);  // (B, vocab_size)
         ggml_tensor* lprobs = ggml_log_softmax(ctx, logits);
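
The bookkeeping that append_to_prev_kv performs can be pictured with plain buffers: each cache entry is a pre-allocated block of max_seq_len rows, the new step's projected key/value row is written at row step_nr, and attention then reads only the first step_nr + 1 rows. The stand-alone sketch below illustrates that idea with a toy single-beam cache (std::vector, invented names); it is not the ggml code from this commit.

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Toy, single-beam stand-in for KeyValueTensor: one row per decoding step.
    struct ToyKvCache {
        std::size_t model_dim;
        std::size_t max_seq_len;
        std::size_t step_nr = 0;
        std::vector<float> full_k;  // max_seq_len rows of model_dim floats
        std::vector<float> full_v;

        ToyKvCache(std::size_t dim, std::size_t max_len)
            : model_dim(dim), max_seq_len(max_len),
              full_k(dim * max_len), full_v(dim * max_len) {}

        // Write this step's projected k/v at row step_nr, then advance step_nr.
        // Attention only ever reads the first step_nr rows (the "slice").
        void append(const std::vector<float>& k, const std::vector<float>& v) {
            std::copy(k.begin(), k.end(), full_k.begin() + step_nr * model_dim);
            std::copy(v.begin(), v.end(), full_v.begin() + step_nr * model_dim);
            ++step_nr;
        }
    };

    int main() {
        ToyKvCache cache(/*dim=*/4, /*max_len=*/16);
        cache.append(std::vector<float>(4, 1.0f), std::vector<float>(4, 2.0f));
        std::printf("valid rows after one step: %zu\n", cache.step_nr);
    }

In the ggml version above, ggml_set_2d_inplace plays the role of the row copy, ggml_slice exposes the first step_nr + 1 rows, and one such buffer exists per beam.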

ggml/examples/unity/fairseq2.h  (+10, -1)

@@ -6,6 +6,12 @@
 #include "ggml.h"
 #include "kaldi-native-fbank/csrc/feature-fbank.h"
 
+struct KeyValueTensor {
+    ggml_tensor* full_k;
+    ggml_tensor* full_v;
+    int step_nr;
+    // ggml_tensor* key_padding_mask;
+};
 
 struct fairseq2_model {
     // Context containing all tensors memory
@@ -21,6 +27,9 @@ struct fairseq2_model {
     // Normally those can be inferred from hparams, but it avoids doing this logic in GGML
     std::unordered_map<std::string, std::int64_t> layer_config;
 
+    // KV cache for attention layers
+    mutable std::unordered_map<std::string, KeyValueTensor> kv_cache;
+
     // an inference context, not managed by this object
     // TODO: is this the best place to store this or should we also pass this to all forward methods ?
     ggml_context* ctx;
@@ -91,7 +100,7 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* queries,  // (slen, d_in)
     ggml_tensor* keys,  // (klen, d_in)
     ggml_tensor* values,  // (klen, d_out)
-    ggml_tensor* _ // (klen, slen)  TODO: do we need to pass mask here ?
+    ggml_tensor* attn_mask // (klen, slen)
 );
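
Taken together with fairseq2.cpp, the intended call pattern appears to be: allocate the cache once per generation with the beam size and maximum sequence length, then let MultiheadAttention_forward fill and consume it step by step. The sketch below only inspects the cache after allocation; it assumes a declaration of fairseq2_kv_cache_alloc is visible to the caller (the function is defined in fairseq2.cpp) and that the model was loaded elsewhere.

    #include <cstdio>
    #include "fairseq2.h"  // KeyValueTensor, fairseq2_model (path assumed)

    // Assumed to be declared for the caller; defined in fairseq2.cpp.
    void fairseq2_kv_cache_alloc(const fairseq2_model& model, std::size_t beam_size, std::size_t max_seq_len);

    void inspect_kv_cache(const fairseq2_model& model) {
        // One KeyValueTensor per decoder attention layer, keyed by its name prefix.
        fairseq2_kv_cache_alloc(model, /*beam_size=*/5, /*max_seq_len=*/128);

        for (const auto& entry : model.kv_cache) {
            const KeyValueTensor& kv = entry.second;
            // full_k / full_v were created as (model_dim, max_seq_len, beam_size);
            // step_nr starts at 0 and is advanced by append_to_prev_kv().
            std::printf("%s: step_nr=%d, k dims=(%lld, %lld, %lld)\n",
                        entry.first.c_str(), kv.step_nr,
                        (long long)kv.full_k->ne[0],
                        (long long)kv.full_k->ne[1],
                        (long long)kv.full_k->ne[2]);
        }
    }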