Browse Source

simplify _finalize_hypothesis

Guillaume Wenzek 1 year ago
parent
commit
1756897d23
1 changed file with 23 additions and 31 deletions
  1. 23 31
      ggml/examples/unity/fairseq2.cpp

+ 23 - 31
ggml/examples/unity/fairseq2.cpp

@@ -687,26 +687,17 @@ void ggml_detach(ggml_tensor* a) {
 }
 
 
-int _finalize_hypothesis(
+void _finalize_hypothesis(
     const SequenceGeneratorJob& job,
     ggml_context* ctx,
     int step_nr,
-    int vocab_size,
-    std::int32_t candidate,
-    float tok_score,
+    std::int32_t beam,
+    std::int32_t token,
+    float eos_score,
     ggml_tensor* seqs, // (beam_size, seq_len)
     ggml_tensor* scores, // (beam_size, seq_len)
     std::vector<Hypothesis>& hypotheses
 ) {
-    std::int32_t beam = candidate / vocab_size;
-    std::int32_t token = candidate % vocab_size;
-
-    // Detect beams that reached the minimum length and that end with an EOS.
-    bool eos = token == job.eos_idx;
-    eos &= tok_score != -INFINITY;
-
-    if (!eos) return 0;
-
     // If the candidate beam is "finished", let's copy the score and sequence
     ggml_tensor* tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
     ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
@@ -719,7 +710,7 @@ int _finalize_hypothesis(
 
     // Convert from cumulative to per-step scores.
     auto sc = (float*)step_scores->data;
-    float last_score = tok_score;
+    float last_score = eos_score;
     sc[step_nr + 1] = last_score;
     for (int i = step_nr; i >= 0; --i) {
         float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
@@ -729,11 +720,9 @@ int _finalize_hypothesis(
 
     if (job.opts.normalize_scores)
         // Skip first EOS since it is always 0 and skews normalization.
-        tok_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
+        eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
 
-    // TODO the score computed here isn't the same than computed by fairseq2.
-    hypotheses.emplace_back(Hypothesis{tokens, tok_score, step_scores});
-    return 1;
+    hypotheses.emplace_back(Hypothesis{tokens, eos_score, step_scores});
 }
 
 /// Generates a translation for a single sequence
@@ -908,24 +897,27 @@ extern "C" float generate_sequence(
         );
 
         std::size_t ongoing_beams = 0;
-        int new_num_searches = 0;
         for (std::int32_t i = 0; i < K; ++i) {
             int c = ggml_get_f32_1d(candidate_indices, i);
+            std::int32_t beam = c / vocab_size;
+            std::int32_t token = c % vocab_size;
             float tok_score = ggml_get_f32_1d(lprobs, c);
-            int finished = _finalize_hypothesis(job, ctx, step_nr, vocab_size, c, tok_score, seqs, scores, finished_searches);
-            new_num_searches += finished;
-            if (!finished){
-                std::int32_t beam = c / vocab_size;
-                std::int32_t token = c % vocab_size;
-
-                ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
-                ggml_set_f32_1d(next_tokens, ongoing_beams, token);
-                ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
-                ongoing_beams += 1 - finished;
+
+            // Detect beams that reached the minimum length and that end with an EOS.
+            bool eos = token == job.eos_idx;
+            eos &= tok_score != -INFINITY;
+            if (eos) {
+                _finalize_hypothesis(job, ctx, step_nr, beam, token, tok_score, seqs, scores, finished_searches);
+                if (finished_searches.size() >= beam_size)
+                    goto end_of_beam_search;
+                continue;
             }
+
+            ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
+            ggml_set_f32_1d(next_tokens, ongoing_beams, token);
+            ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
+            ongoing_beams += 1;
             if (ongoing_beams >= beam_size) break;
-            if (finished_searches.size() >= beam_size)
-                goto end_of_beam_search;
         }
 
         // Reorder beams in the `seq` and `score` buffers. The same beam can