cndn committed 1 year ago
parent
commit
e1bb22b13f
1 changed file with 33 additions and 7 deletions
ggml/examples/unity/fairseq2.cpp

@@ -114,6 +114,7 @@ void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, g
     bool no_alloc_save = ggml_get_no_alloc(ctx);
     ggml_set_no_alloc(ctx, false);
     int n_steps = (*k)->ne[1];
+    // printf("Prefix: %s   n_steps: %d\n", prefix.c_str(), n_steps);
     int k_proj, batch_size;
 
     if (kv.full_k != nullptr) {
@@ -136,6 +137,7 @@ void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, g
     ggml_format_name(kv.full_k, "%s.k (step=%d)", prefix.c_str(), step_nr);
     ggml_format_name(kv.full_v, "%s.v (step=%d)", prefix.c_str(), step_nr);
     step_nr += n_steps;
+    // printf("Prefix: %s   step_nr: %d\n", prefix.c_str(), step_nr);
 
     GGML_ASSERT_SHAPE(kv.full_k, k_proj, step_nr, batch_size, 1);
 
@@ -147,7 +149,7 @@ void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, g
             1, step_nr - 1, step_nr
         );
     }
-
+    
     kv.step_nr = step_nr;
     ggml_set_no_alloc(ctx, no_alloc_save);
 }
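The three hunks above only add tracing around append_to_prev_kv's step bookkeeping: the cache grows by n_steps columns per call and step_nr advances accordingly. A minimal sketch of that bookkeeping, using plain std::vector instead of ggml tensors (KvCacheSketch and append_steps are illustrative names, not part of fairseq2.cpp):

```cpp
#include <cstdio>
#include <vector>

struct KvCacheSketch {
    std::vector<float> full_k;  // concatenated keys across all cached steps
    std::vector<float> full_v;  // concatenated values
    int step_nr = 0;            // number of time steps cached so far
};

// Append n_steps new key/value columns and advance the counter,
// mirroring `step_nr += n_steps;` in the hunk above.
void append_steps(KvCacheSketch& kv, const std::vector<float>& k_new,
                  const std::vector<float>& v_new, int n_steps) {
    kv.full_k.insert(kv.full_k.end(), k_new.begin(), k_new.end());
    kv.full_v.insert(kv.full_v.end(), v_new.begin(), v_new.end());
    kv.step_nr += n_steps;
    // The commented-out printf in the diff traces exactly this counter.
}

int main() {
    KvCacheSketch kv;
    append_steps(kv, {1.f, 2.f}, {3.f, 4.f}, 1);
    append_steps(kv, {5.f, 6.f}, {7.f, 8.f}, 1);
    std::printf("cached steps: %d\n", kv.step_nr);  // prints 2
    return 0;
}
```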
@@ -191,8 +193,9 @@ void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, g
 void reorder_kv_cache(const fairseq2_model& model, ggml_context* ctx, ggml_cgraph* gf, ggml_tensor* new_order) {
     auto self_attn_glob = "*.self_attn";
     for (auto& named_kv : model.kv_cache) {
-        if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH)
+        if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH) {
             continue;
+        }
 
         _reorder_kv_cache(ctx, gf, named_kv.second, new_order);
     }
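The braces added above wrap the glob filter that picks the self-attention caches out of model.kv_cache. A self-contained sketch of the same ::fnmatch pattern, with a made-up map standing in for the real cache:

```cpp
#include <fnmatch.h>
#include <cstdio>
#include <map>
#include <string>

int main() {
    std::map<std::string, int> kv_cache = {
        {"decoder.layers.0.self_attn", 0},
        {"decoder.layers.0.encoder_decoder_attn", 1},
    };
    const char* self_attn_glob = "*.self_attn";
    for (auto& named_kv : kv_cache) {
        // FNM_NOMATCH means this entry is not a self-attention cache; skip it.
        if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH)
            continue;
        std::printf("would reorder: %s\n", named_kv.first.c_str());
    }
    return 0;
}
```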
@@ -438,12 +441,14 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
                 ggml_set_name(v, "v");
                 // Note we are only storing a pointer to the buffer, not the full graph
                 kv_cache.full_k = ggml_detach(ggml_dup_inplace(model.ctx, k));
+                printf("prefix: %s, k: %d %d %d\n", prefix.c_str(), kv_cache.full_k->ne[0], kv_cache.full_k->ne[1], kv_cache.full_k->ne[2]);
                 ggml_format_name(kv_cache.full_k, "%s.k_cache", prefix.c_str());
                 kv_cache.full_v = ggml_detach(ggml_dup_inplace(model.ctx, v));
                 ggml_format_name(kv_cache.full_v, "%s.v_cache", prefix.c_str());
                 kv_cache.step_nr = keys->ne[1];
                 model.ctx = ctx;
             } else {
+                printf("prefix: %s, k: %d %d %d\n", prefix.c_str(), kv_cache.full_k->ne[0], kv_cache.full_k->ne[1], kv_cache.full_k->ne[2]);
                 k = kv_cache.full_k;
                 v = kv_cache.full_v;
                 GGML_ASSERT(keys->ne[1] == k->ne[1]);  // cache content doesn't match the input sequence
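Note that the printf lines added in this hunk (and in the next one) pass ne[] values, which are int64_t in ggml, to "%d". A hedged sketch of a portable way to print them, using a stand-in struct rather than ggml_tensor:

```cpp
#include <cinttypes>
#include <cstdint>
#include <cstdio>

struct FakeTensor { std::int64_t ne[4]; };  // mirrors only ggml_tensor::ne

int main() {
    FakeTensor full_k = {{1024, 17, 2, 1}};
    // Dimensions are int64_t, so "%d" is the wrong specifier; PRId64 (or an
    // explicit cast) keeps the output well-defined on all platforms.
    std::printf("k: %" PRId64 " %" PRId64 " %" PRId64 "\n",
                full_k.ne[0], full_k.ne[1], full_k.ne[2]);
    return 0;
}
```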
@@ -451,19 +456,40 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
             }
         } else { // self attention
             // (1, K) -> (N, 1, K_proj)
+            for (auto& named_kv : model.kv_cache) {
+                auto enc_dec_attn_glob = "*.encoder_decoder_attn";
+                if (::fnmatch(enc_dec_attn_glob, named_kv.first.c_str(), 0) != FNM_NOMATCH) {
+                    printf("HERE BEFORE CULPRIT LINE prefix: %s\n", named_kv.first.c_str());
+                    if (named_kv.second.full_k != nullptr)
+                        printf("HERE BEFORE CULPRIT LINE k: %d\n", named_kv.second.full_k->ne[0]);
+                }
+            }
+
             k = Linear_forward(model, prefix + ".k_proj", keys);
+
+            for (auto& named_kv : model.kv_cache) {
+                auto enc_dec_attn_glob = "*.encoder_decoder_attn";
+                if (::fnmatch(enc_dec_attn_glob, named_kv.first.c_str(), 0) != FNM_NOMATCH) {
+                    printf("HERE AFTER CULPRIT LINE prefix: %s\n", named_kv.first.c_str());
+                    if (named_kv.second.full_k != nullptr)
+                        printf("HERE AFTER CULPRIT LINE k: %d\n", named_kv.second.full_k->ne[0]);
+                }
+            }
             ggml_set_name(k, "k");
             // (1, V) -> (N, 1, V_proj)
             v = Linear_forward(model, prefix + ".v_proj", values);
             ggml_set_name(v, "v");
 
+
             append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
+
         }
     }
     k = _reshape_num_head(ctx, k, head_dim);  // (B * H, Sk, H_dim)
     v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
     v = ggml_cont(ctx, v);
 
+
 #if UNITY_FLASH_ATTN
     // For flash_attn, we assume either no masks, or triangular masks.
     ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr);  // (B * H, S, H_dim)
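The two "HERE BEFORE/AFTER CULPRIT LINE" loops added above run the same scan of the encoder-decoder caches before and after the k_proj Linear_forward call. If this tracing were kept, the scan could be hoisted into a single helper so both call sites stay in sync; a sketch under that assumption (dump_enc_dec_cache and CacheEntry are hypothetical names, and the cached tensor is reduced to a nullable pointer to its first dimension):

```cpp
#include <fnmatch.h>
#include <cstdio>
#include <map>
#include <string>

struct CacheEntry { const long long* k_dim0; };  // stand-in for full_k->ne[0]

void dump_enc_dec_cache(const std::map<std::string, CacheEntry>& kv_cache,
                        const char* tag) {
    const char* glob = "*.encoder_decoder_attn";
    for (const auto& named_kv : kv_cache) {
        if (::fnmatch(glob, named_kv.first.c_str(), 0) == FNM_NOMATCH)
            continue;
        std::printf("%s prefix: %s\n", tag, named_kv.first.c_str());
        if (named_kv.second.k_dim0 != nullptr)
            std::printf("%s k: %lld\n", tag, *named_kv.second.k_dim0);
    }
}

int main() {
    long long dim0 = 1024;
    std::map<std::string, CacheEntry> kv_cache = {
        {"decoder.layers.0.encoder_decoder_attn", {&dim0}},
    };
    dump_enc_dec_cache(kv_cache, "HERE BEFORE CULPRIT LINE");
    // ... the call under investigation would go here ...
    dump_enc_dec_cache(kv_cache, "HERE AFTER CULPRIT LINE");
    return 0;
}
```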
@@ -496,6 +522,7 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
     ggml_set_name(out, "out");
 
+
     return out;
 }
 
@@ -1025,7 +1052,6 @@ extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
     if (norm_order != TRANSFORMER_NORM_ORDER_POST)
         seqs =  LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
 
-
     seqs = MultiheadAttention_forward(
         model,
         prefix + ".encoder_decoder_attn",
@@ -1443,8 +1469,8 @@ extern "C" Hypothesis* generate_sequence(
 
     int prefix_seq_len = job.prefix_seq->ne[0];
     int start_step = prefix_seq_len - 1;
-    ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[0]);
-    ggml_context* step_ctx = ctx_from_buffer(local_bufs[1]);
+    ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[1]);
+    ggml_context* step_ctx = ctx_from_buffer(local_bufs[0]);
     GGML_ASSERT(step_ctx != search_ctx);
     GGML_ASSERT(prev_step_ctx != step_ctx);
     model.ctx = prev_step_ctx;
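The swapped indices above only change which of the two local buffers backs the initial prev_step_ctx and step_ctx; the decoding loop then ping-pongs the two contexts every step so the previous step's tensors stay alive exactly one step longer. A sketch of that double-buffer rotation, with plain vectors in place of ggml contexts and an illustrative loop:

```cpp
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    std::vector<char> local_bufs[2] = {std::vector<char>(16), std::vector<char>(16)};
    // Mirrors the corrected assignment: the previous step uses buffer 1,
    // the current step writes into buffer 0.
    std::vector<char>* prev_step_buf = &local_bufs[1];
    std::vector<char>* step_buf = &local_bufs[0];

    for (int step_nr = 0; step_nr < 4; ++step_nr) {
        // ... build this step's graph into *step_buf ...
        std::printf("step %d writes into buffer %td\n", step_nr, step_buf - local_bufs);
        std::swap(prev_step_buf, step_buf);  // reuse the older buffer next step
    }
    return 0;
}
```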
@@ -1525,7 +1551,7 @@ extern "C" Hypothesis* generate_sequence(
         ggml_detach(lprobs);
         ggml_allocr_reset(step_alloc);
 #if DEBUG_MEM_USAGE
-        printf("beam search step %d. Graph.n_nodes: %d.\n", step_nr, gf.n_nodes);
+        printf("beam search step %d. Graph.n_nodes: %d.\n", step_nr, gf->n_nodes);
         printf("  Fwd mem: %.1fMB, reserved %.1fMb\n", fwd_mem/(double)MB, local_bufs[3].capacity()/(double)MB);
         std::fill(local_bufs[3].begin(), local_bufs[3].end(), 0xAA);
 #endif
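Under DEBUG_MEM_USAGE the reserved buffer is refilled with 0xAA after every step, presumably so that bytes which were never overwritten can later reveal the peak usage. A rough sketch of that poisoning idea, with run_step standing in for the real graph execution and the 300-byte figure made up:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

void run_step(std::vector<unsigned char>& buf) {
    // Pretend this step used the first 300 bytes of the scratch buffer.
    std::fill(buf.begin(), buf.begin() + 300, 0x00);
}

int main() {
    std::vector<unsigned char> buf(1024);
    std::fill(buf.begin(), buf.end(), 0xAA);  // poison, as in the diff
    run_step(buf);
    // Trailing bytes still equal to the sentinel were never touched by the step.
    std::size_t untouched = 0;
    for (auto it = buf.rbegin(); it != buf.rend() && *it == 0xAA; ++it) ++untouched;
    std::printf("high-water mark: %zu of %zu bytes\n", buf.size() - untouched, buf.size());
    return 0;
}
```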
@@ -1631,7 +1657,7 @@ end_of_beam_search:
     );
 
     printf_mem_usage(search_ctx, "search_ctx");
-    fairseq2_kv_cache_reset(model);
+    // fairseq2_kv_cache_reset(model);
     model.ctx = original_ctx;
     return finished_searches_begin;
 }
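The last hunk disables fairseq2_kv_cache_reset, so cached key/value tensors now survive past generate_sequence. A hedged sketch of what such a reset typically does, with illustrative types (SketchKV, reset_kv_cache) rather than the real fairseq2.cpp structures:

```cpp
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct SketchKV {
    std::vector<float> full_k, full_v;  // cached keys/values per attention module
    int step_nr = 0;                    // cached decoding steps
};

// Drop the cached storage and zero the step counters so the next
// generation starts from an empty cache.
void reset_kv_cache(std::map<std::string, SketchKV>& kv_cache) {
    for (auto& named_kv : kv_cache) {
        named_kv.second.full_k.clear();
        named_kv.second.full_v.clear();
        named_kv.second.step_nr = 0;
    }
}

int main() {
    std::map<std::string, SketchKV> kv_cache;
    kv_cache["decoder.layers.0.self_attn"].step_nr = 12;
    reset_kv_cache(kv_cache);
    std::printf("step_nr after reset: %d\n", kv_cache["decoder.layers.0.self_attn"].step_nr);
    return 0;
}
```

Commenting the call out, as this commit does, leaves the per-layer caches populated after the search returns.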