@@ -114,6 +114,7 @@ void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, g
     bool no_alloc_save = ggml_get_no_alloc(ctx);
     ggml_set_no_alloc(ctx, false);
     int n_steps = (*k)->ne[1];
+    // printf("Prefix: %s n_steps: %d\n", prefix.c_str(), n_steps);
     int k_proj, batch_size;
 
     if (kv.full_k != nullptr) {
@@ -136,6 +137,7 @@ void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, g
     ggml_format_name(kv.full_k, "%s.k (step=%d)", prefix.c_str(), step_nr);
     ggml_format_name(kv.full_v, "%s.v (step=%d)", prefix.c_str(), step_nr);
     step_nr += n_steps;
+    // printf("Prefix: %s step_nr: %d\n", prefix.c_str(), step_nr);
 
     GGML_ASSERT_SHAPE(kv.full_k, k_proj, step_nr, batch_size, 1);
 
@@ -147,7 +149,7 @@ void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, g
             1, step_nr - 1, step_nr
         );
     }
-
+
     kv.step_nr = step_nr;
     ggml_set_no_alloc(ctx, no_alloc_save);
 }
@@ -191,8 +193,9 @@ void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, g
 void reorder_kv_cache(const fairseq2_model& model, ggml_context* ctx, ggml_cgraph* gf, ggml_tensor* new_order) {
     auto self_attn_glob = "*.self_attn";
     for (auto& named_kv : model.kv_cache) {
-        if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH)
+        if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH) {
             continue;
+        }
 
         _reorder_kv_cache(ctx, gf, named_kv.second, new_order);
     }
@@ -438,12 +441,14 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
                 ggml_set_name(v, "v");
                 // Note we are only storing a pointer to the buffer, not the full graph
                 kv_cache.full_k = ggml_detach(ggml_dup_inplace(model.ctx, k));
+                printf("prefix: %s, k: %d %d %d\n", prefix.c_str(), kv_cache.full_k->ne[0], kv_cache.full_k->ne[1], kv_cache.full_k->ne[2]);
                 ggml_format_name(kv_cache.full_k, "%s.k_cache", prefix.c_str());
                 kv_cache.full_v = ggml_detach(ggml_dup_inplace(model.ctx, v));
                 ggml_format_name(kv_cache.full_v, "%s.v_cache", prefix.c_str());
                 kv_cache.step_nr = keys->ne[1];
                 model.ctx = ctx;
             } else {
+                printf("prefix: %s, k: %d %d %d\n", prefix.c_str(), kv_cache.full_k->ne[0], kv_cache.full_k->ne[1], kv_cache.full_k->ne[2]);
                 k = kv_cache.full_k;
                 v = kv_cache.full_v;
                 GGML_ASSERT(keys->ne[1] == k->ne[1]);  // cache content doesn't match the input sequence
@@ -451,19 +456,40 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
             }
         } else { // self attention
             // (1, K) -> (N, 1, K_proj)
+            for (auto& named_kv : model.kv_cache) {
+                auto enc_dec_attn_glob = "*.encoder_decoder_attn";
+                if (::fnmatch(enc_dec_attn_glob, named_kv.first.c_str(), 0) != FNM_NOMATCH) {
+                    printf("HERE BEFORE CULPRIT LINE prefix: %s\n", named_kv.first.c_str());
+                    if (named_kv.second.full_k != nullptr)
+                        printf("HERE BEFORE CULPRIT LINE k: %d\n", named_kv.second.full_k->ne[0]);
+                }
+            }
+
             k = Linear_forward(model, prefix + ".k_proj", keys);
+
+            for (auto& named_kv : model.kv_cache) {
+                auto enc_dec_attn_glob = "*.encoder_decoder_attn";
+                if (::fnmatch(enc_dec_attn_glob, named_kv.first.c_str(), 0) != FNM_NOMATCH) {
+                    printf("HERE AFTER CULPRIT LINE prefix: %s\n", named_kv.first.c_str());
+                    if (named_kv.second.full_k != nullptr)
+                        printf("HERE AFTER CULPRIT LINE k: %d\n", named_kv.second.full_k->ne[0]);
+                }
+            }
             ggml_set_name(k, "k");
             // (1, V) -> (N, 1, V_proj)
             v = Linear_forward(model, prefix + ".v_proj", values);
             ggml_set_name(v, "v");
 
+
             append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
+
         }
     }
     k = _reshape_num_head(ctx, k, head_dim);  // (B * H, Sk, H_dim)
     v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
     v = ggml_cont(ctx, v);
 
+
 #if UNITY_FLASH_ATTN
     // For flash_attn, we assume either no masks, or triangular masks.
     ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr);  // (B * H, S, H_dim)
@@ -496,6 +522,7 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
     ggml_set_name(out, "out");
 
+
     return out;
 }
 
@@ -1025,7 +1052,6 @@ extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
     if (norm_order != TRANSFORMER_NORM_ORDER_POST)
         seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
 
-
     seqs = MultiheadAttention_forward(
         model,
         prefix + ".encoder_decoder_attn",
@@ -1443,8 +1469,8 @@ extern "C" Hypothesis* generate_sequence(
 
     int prefix_seq_len = job.prefix_seq->ne[0];
     int start_step = prefix_seq_len - 1;
-    ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[0]);
-    ggml_context* step_ctx = ctx_from_buffer(local_bufs[1]);
+    ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[1]);
+    ggml_context* step_ctx = ctx_from_buffer(local_bufs[0]);
     GGML_ASSERT(step_ctx != search_ctx);
     GGML_ASSERT(prev_step_ctx != step_ctx);
     model.ctx = prev_step_ctx;
@@ -1525,7 +1551,7 @@ extern "C" Hypothesis* generate_sequence(
         ggml_detach(lprobs);
         ggml_allocr_reset(step_alloc);
 #if DEBUG_MEM_USAGE
-        printf("beam search step %d. Graph.n_nodes: %d.\n", step_nr, gf.n_nodes);
+        printf("beam search step %d. Graph.n_nodes: %d.\n", step_nr, gf->n_nodes);
         printf(" Fwd mem: %.1fMB, reserved %.1fMb\n", fwd_mem/(double)MB, local_bufs[3].capacity()/(double)MB);
         std::fill(local_bufs[3].begin(), local_bufs[3].end(), 0xAA);
 #endif
@@ -1631,7 +1657,7 @@ end_of_beam_search:
     );
 
     printf_mem_usage(search_ctx, "search_ctx");
-    fairseq2_kv_cache_reset(model);
+    // fairseq2_kv_cache_reset(model);
     model.ctx = original_ctx;
     return finished_searches_begin;
 }