Tuan Tran 1 жил өмнө
parent
commit
86b10cb3e6

+ 7 - 2
ggml/examples/unity/fairseq2.cpp

@@ -24,7 +24,7 @@ ggml_tensor* ggml_detach(ggml_tensor* a) {
 // Enabling this flag allows to explictly reset memory buffers, making it more explicit
 // when we read garbage data.
 // It also prints memory usage information, which is useful to
-#define DEBUG_MEM_USAGE DEBUG
+#define DEBUG_MEM_USAGE 1
 size_t MB = 1024 * 1024;
 
 void printf_mem_usage(ggml_context* ctx, std::string name) {
@@ -1204,6 +1204,10 @@ void _bootstrap_seqs_and_scores(
     GGML_ASSERT(prefix_seq_len > 0);
     if (prefix_seq_len == 1) {
         // bootstrap all beams in full_seqs with EOS
+        // This is equivalent to:
+        // // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
+        // because in normal case: prefix_seq[0] = EOS
+        // 
         int eos_id = model.vocab.token_to_id["</s>"];
         if (model.tgt_vocab.id_to_token.size()) {
             eos_id = model.tgt_vocab.token_to_id["</s>"];
@@ -1467,7 +1471,8 @@ extern "C" Hypothesis* generate_sequence(
     ggml_set_name(scores, "scores_0");
     ggml_set_f32(scores, 0.0);
 
-    int start_step = 1;
+    int prefix_seq_len = job.prefix_seq->ne[0];
+    int start_step = prefix_seq_len - 1;
     ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[(start_step - 1) % 2]);
     ggml_context* step_ctx = ctx_from_buffer(local_bufs[start_step % 2]);	    
     GGML_ASSERT(step_ctx != search_ctx);