@@ -25,23 +25,35 @@ extern "C" void fairseq2_kv_cache_alloc(const fairseq2_model& model, int beam_si
     // Note: we only allocate the cache for the decoder attention.
     // For encoder attention since we compute it all at once,
     // the allocation is delayed to the first forward pass, to not over allocate.
-    auto layer_glob_c = "*decoder.*attn.k_proj.weight";
+    auto attn_glob = "*decoder.*_attn.k_proj.weight";
+    auto self_attn_glob = "*decoder.*self_attn.k_proj.weight";
     ggml_tensor* self_attn_mask = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, max_seq_len, max_seq_len);
-    self_attn_mask = ggml_diag_mask_inf(model.ctx, self_attn_mask, 0);
+    self_attn_mask = ggml_diag_mask_inf_inplace(model.ctx, self_attn_mask, 0);
+    ggml_format_name(self_attn_mask, "self_attn_mask[%d]", max_seq_len);

     for (auto named_tensor : model.tensors) {
         const std::string& name = named_tensor.first;
-        if (::fnmatch(layer_glob_c, name.c_str(), 0) == FNM_NOMATCH)
+        if (::fnmatch(attn_glob, name.c_str(), 0) == FNM_NOMATCH)
             continue;
+        // create a cache entry without the ".k_proj.weight" suffix
+        const std::string& shortname = name.substr(0, name.size() - 14);
+        KeyValueTensor& kv = model.kv_cache[shortname];
+        kv.step_nr = 0;
+
+        if (::fnmatch(self_attn_glob, name.c_str(), 0) == FNM_NOMATCH) {
+            // enc_dec_attn
+            // the tensors will be allocated during the first forward
+            continue;
+        }
+
+        // self_attn
         ggml_tensor* k_proj = named_tensor.second;
         int model_dim = k_proj->ne[0];
-        // remove the ".k_proj.weight" suffix
-        model.kv_cache[name.substr(0, name.size() - 14)] = KeyValueTensor {
-            ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size),
-            ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size),
-            self_attn_mask,
-            0,
-        };
+        kv.full_k = ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size);
+        kv.full_v = ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size);
+        kv.self_attn_mask = self_attn_mask;
+        ggml_format_name(kv.full_k, "%s.k_cache", shortname.c_str());
+        ggml_format_name(kv.full_v, "%s.v_cache", shortname.c_str());
     }
 }
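For context, the loop above now fills the cache entry field by field instead of aggregate-initializing it; the entry is assumed to look roughly like the sketch below (the real struct in this port may carry more members). Default-null tensors are what lets the encoder-decoder entries stay empty until the first forward pass. A weight name such as `decoder.layers.0.self_attn.k_proj.weight` (hypothetical) maps to the cache key `decoder.layers.0.self_attn` once the 14-character `.k_proj.weight` suffix is stripped.

```cpp
// Minimal sketch of the per-layer cache entry assumed by the code above.
struct KeyValueTensor {
    ggml_tensor* full_k = nullptr;          // (model_dim, max_seq_len, beam_size)
    ggml_tensor* full_v = nullptr;          // (model_dim, max_seq_len, beam_size)
    ggml_tensor* self_attn_mask = nullptr;  // causal mask, only set for self-attention entries
    int step_nr = 0;                        // number of positions already cached
};
```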
@@ -54,6 +66,7 @@ bool has_kv_cache(const fairseq2_model& model) {
 // kv.full_v[step_nr] = v;
 void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
     KeyValueTensor& kv = model.kv_cache[prefix];
+    GGML_ASSERT(kv.full_k != nullptr);  // key not found!
     int step_nr = kv.step_nr;

     ggml_tensor* full_k = kv.full_k;
@@ -66,6 +79,8 @@ void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, g

     *k = ggml_slice(model.ctx, updated_k, 1, 0, step_nr + 1);
     *v = ggml_slice(model.ctx, updated_v, 1, 0, step_nr + 1);
+    ggml_format_name(*k, "%s (step=%d)", full_k->name, step_nr);
+    ggml_format_name(*v, "%s (step=%d)", full_v->name, step_nr);

     // qk is (B * H, Sq, Sk) == (B*H, 1, Sk) in incremental mode
     // we return the Sq slice of the (Sq, Sk) attention mask
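The two comments above describe the mask handling in incremental decoding: with a single new query position, only one row of the precomputed causal mask is needed. A sketch of that slicing, written with the `ggml_slice` helper from this patch (the exact code inside `append_to_prev_kv` may differ):

```cpp
// Sketch: in incremental mode Sq == 1, so keep query row `step_nr` of the
// (max_seq_len, max_seq_len) mask, limited to the step_nr + 1 key positions
// that are actually cached so far.
ggml_tensor* mask = kv.self_attn_mask;
mask = ggml_slice(model.ctx, mask, /*axis=*/1, step_nr, step_nr + 1);  // one query row
mask = ggml_slice(model.ctx, mask, /*axis=*/0, 0, step_nr + 1);        // valid key columns
*self_attn_mask = mask;
```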
@@ -97,13 +112,17 @@ ggml_tensor* ggml_get_rows2(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {


 void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, ggml_tensor* new_order) {
-    ggml_detach(kv.full_k);
-    kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
-    ggml_build_forward_expand(gf, kv.full_k);
+    if (kv.full_k != nullptr) {
+        ggml_detach(kv.full_k);
+        kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
+        ggml_build_forward_expand(gf, kv.full_k);
+    }

-    ggml_detach(kv.full_v);
-    kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
-    ggml_build_forward_expand(gf, kv.full_v);
+    if (kv.full_v != nullptr) {
+        ggml_detach(kv.full_v);
+        kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
+        ggml_build_forward_expand(gf, kv.full_v);
+    }
 }
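With the null checks in place, encoder-decoder entries whose tensors are still lazily unallocated can be skipped safely when the whole cache is reordered after a beam-search step. One plausible call site, sketched with a hypothetical driver function:

```cpp
// Hypothetical driver: after a search step, make every layer's cache follow
// the surviving beams. `beam_order` is an I32 tensor where beam_order[i] is
// the old beam index that new beam i continues from.
void kv_cache_reorder_all(fairseq2_model& model, ggml_cgraph* gf, ggml_tensor* beam_order) {
    for (auto& entry : model.kv_cache)
        _reorder_kv_cache(model.ctx, gf, entry.second, beam_order);
}
```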
@@ -333,19 +352,27 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
         KeyValueTensor& kv_cache = model.kv_cache[prefix];
         if (kv_cache.step_nr == 0) {
             k = Linear_forward(model, prefix + ".k_proj", keys);
-            ggml_set_name(k, "k");
+            ggml_format_name(k, "%s.k_cache", prefix.c_str());
             v = Linear_forward(model, prefix + ".v_proj", values);
-            ggml_set_name(v, "v");
-            model.kv_cache[prefix] = KeyValueTensor{k, v, nullptr, 1};
+            ggml_format_name(v, "%s.v_cache", prefix.c_str());
+            // TODO: encoder_padding_mask
+            kv_cache.full_k = k;
+            kv_cache.full_v = v;
+            kv_cache.step_nr = keys->ne[1];
         } else {
             k = kv_cache.full_k;
             v = kv_cache.full_v;
+            // This is a cache collision. TODO: fairseq2_kv_cache_reset
+            GGML_ASSERT(keys->ne[1] == k->ne[1]);
+            GGML_ASSERT(values->ne[1] == v->ne[1]);
         }
     } else { // self attention
         // (1, K) -> (N, 1, K_proj)
         k = Linear_forward(model, prefix + ".k_proj", keys);
+        ggml_set_name(k, "k");
         // (1, V) -> (N, 1, V_proj)
         v = Linear_forward(model, prefix + ".v_proj", values);
+        ggml_set_name(v, "v");

         append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
     }
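The TODO above points at a missing reset hook: reusing the same decoder cache with a different source sequence currently trips the new assertions. Purely as a hedged sketch (not part of this patch, and the real `fairseq2_kv_cache_reset` may well look different), one way such a reset could behave, assuming encoder-decoder entries are the ones without a `self_attn_mask`:

```cpp
// Hypothetical reset: rewind all step counters and drop the lazily cached
// encoder-decoder projections so the next forward pass recomputes them for
// the new source sequence. Preallocated self-attention buffers are kept.
extern "C" void fairseq2_kv_cache_reset(fairseq2_model& model) {
    for (auto& entry : model.kv_cache) {
        KeyValueTensor& kv = entry.second;
        kv.step_nr = 0;
        if (kv.self_attn_mask == nullptr) {  // encoder-decoder entry
            kv.full_k = nullptr;
            kv.full_v = nullptr;
        }
    }
}
```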
@@ -776,11 +803,13 @@ struct ggml_tensor * ggml_slice(
     GGML_ASSERT(start <= end);
     GGML_ASSERT(end <= ne[axis]);

+
     ne[axis] = end - start;
     size_t offset = a->nb[axis] * start;

     size_t* nb = a->nb;
     ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
+    ggml_format_name(result, "%s [(%d)%ld:%ld]", a->name, axis, start, end);
     result->n_dims = a->n_dims;
     return result;
 }
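For illustration, this is how the helper is used elsewhere in the patch, together with the name it now produces; the variable names follow the snippets above:

```cpp
// Take the first step_nr + 1 cached positions of a
// (model_dim, max_seq_len, beam_size) cache tensor along axis 1.
// The result is a zero-copy view over the same data.
ggml_tensor* k_used = ggml_slice(model.ctx, kv.full_k, /*axis=*/1, 0, step_nr + 1);
// With the ggml_format_name call above, k_used is named e.g.
// "decoder.layers.0.self_attn.k_cache [(1)0:5]" when step_nr == 4
// (layer name hypothetical).
```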