Explore the Code

fix n_threads propagation (#264)

Author: Ning (1 year ago)
Parent commit: e0403847a5

+ 9 - 7
ggml/examples/unity/fairseq2.cpp

@@ -1176,7 +1176,8 @@ extern "C" void _bootstrap_seqs_and_scores(
     ggml_tensor* full_seqs,
     ggml_tensor* scores,
     ggml_tensor* encoder_output,
-    ggml_tensor* encoder_padding_mask
+    ggml_tensor* encoder_padding_mask,
+    int n_threads
 ) {
     int prefix_seq_len = job.prefix_seq->ne[0];
     int max_seq_len = scores->ne[0];
@@ -1213,7 +1214,7 @@ extern "C" void _bootstrap_seqs_and_scores(
     ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
 
     ggml_cgraph gf = ggml_build_forward(lprobs);
-    ggml_graph_compute_with_ctx(ctx, &gf, 1);
+    ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
 
     // Fetch scores of next steps from "lprobs"
     float p_score = 0;
@@ -1354,7 +1355,8 @@ extern "C" Hypothesis* generate_sequence(
     const SequenceGeneratorJob& job,
     ggml_tensor* encoder_output,
     ggml_tensor* encoder_padding_mask,
-    ggml_context* result_ctx
+    ggml_context* result_ctx,
+    int n_threads
 ) {
     // Pre allocate memory buffers.
     // * step_ctx: contains metadata for the model graph, as well as some explicit
@@ -1414,7 +1416,7 @@ extern "C" Hypothesis* generate_sequence(
     // search_ctx because we need encoder_decoder_attn.k_cache to survive for the full search
     model.kv_cache_ctx = search_ctx;
     _bootstrap_seqs_and_scores(
-        model, job, seqs, scores, encoder_output, encoder_padding_mask
+        model, job, seqs, scores, encoder_output, encoder_padding_mask, n_threads
     );
 
     // Holds the indices of beams (a beam can occur more than once) that we
@@ -1456,7 +1458,7 @@ extern "C" Hypothesis* generate_sequence(
         ggml_cgraph gf = ggml_build_forward(lprobs);
         size_t fwd_mem = ggml_allocr_alloc_graph(step_alloc, &gf);
         GGML_UNUSED(fwd_mem);
-        ggml_graph_compute_with_ctx(step_ctx, &gf, 1);
+        ggml_graph_compute_with_ctx(step_ctx, &gf, n_threads);
         ggml_detach(lprobs);
         ggml_allocr_reset(step_alloc);
 #if DEBUG_MEM_USAGE
@@ -1483,7 +1485,7 @@ extern "C" Hypothesis* generate_sequence(
         }
 
         gf = ggml_build_forward(lprobs);
-        ggml_graph_compute_with_ctx(step_ctx, &gf, 1);
+        ggml_graph_compute_with_ctx(step_ctx, &gf, n_threads);
 
         // Determine (beam, token) candidates for the next step.
         // (N, 2 x B)
@@ -1523,7 +1525,7 @@ extern "C" Hypothesis* generate_sequence(
         ggml_cgraph gf_reorder = ggml_build_forward(new_seqs);
         ggml_build_forward_expand(&gf_reorder, new_scores);
         reorder_kv_cache(model, step_ctx, &gf_reorder, beam_indices);
-        ggml_graph_compute_with_ctx(step_ctx, &gf_reorder, 1);
+        ggml_graph_compute_with_ctx(step_ctx, &gf_reorder, n_threads);
         seqs = ggml_detach(new_seqs);
         scores = ggml_detach(new_scores);
 

+ 2 - 1
ggml/examples/unity/fairseq2.h

@@ -307,7 +307,8 @@ extern "C" Hypothesis* generate_sequence(
     const SequenceGeneratorJob& opts,
     ggml_tensor* encoder_output,
     ggml_tensor* encoder_padding_mask,
-    ggml_context* result_ctx
+    ggml_context* result_ctx,
+    int n_threads
 );
 
 extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out);

+ 1 - 1
ggml/examples/unity/unity.cpp

@@ -116,7 +116,7 @@ Hypothesis* unity_decode(
     ((int *)prefix_seq->data)[0]  = job.eos_idx;
     ((int *)prefix_seq->data)[1]  = tgt_lang_idx;
     job.prefix_seq = prefix_seq;
-    return generate_sequence(model, job, encoder_output, nullptr, model.ctx);
+    return generate_sequence(model, job, encoder_output, nullptr, model.ctx, n_threads);
 }
 
 int main(int argc, char ** argv) {