
disable flash attn because of cross attention

Guillaume Wenzek, 1 year ago
parent
commit c7b89f32f4
1 changed file with 79 additions and 149 deletions
  1. ggml/examples/unity/fairseq2.cpp  +79 −149

+ 79 - 149
ggml/examples/unity/fairseq2.cpp

@@ -103,7 +103,8 @@ ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads)
     return x;
 }
 
-# define UNITY_FLASH_ATTN
+// TODO: flash_attn doesn't seem to work for cross attention because it assumes Q <= K
+# define UNITY_FLASH_ATTN 0
 
 extern "C" ggml_tensor* MultiheadAttention_forward(
     fairseq2_model& model,
@@ -131,7 +132,7 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     v = ggml_cont(ctx, v);
     ggml_set_name(v, "v");
 
-#ifdef UNITY_FLASH_ATTN
+#if UNITY_FLASH_ATTN
     // For flash_attn, we assume either no masks, or triangular masks.
     ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr);  // (H, S, H_dim)
     ggml_set_name(attn, "attn");
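
Switching the guard from #ifdef to #if matters here: with #ifdef, defining UNITY_FLASH_ATTN to 0 would still take the flash path, whereas #if UNITY_FLASH_ATTN evaluates the value and falls back to an explicit attention computation, which has no q_len <= k_len assumption and therefore works for cross attention where the decoder query and encoder key lengths differ. A minimal sketch of such an explicit path in ggml (illustrative only; the file's actual #else branch and the ggml_scale signature may differ between ggml revisions):

    // q, k, v are (H, S, H_dim) as above; S_q (decoder) and S_k (encoder) may differ.
    ggml_tensor* kq = ggml_mul_mat(ctx, k, q);                        // (H, S_q, S_k)
    kq = ggml_scale(ctx, kq, ggml_new_f32(ctx, 1.0f / sqrtf((float) q->ne[0])));
    if (mask != nullptr)
        kq = ggml_add(ctx, kq, ggml_repeat(ctx, mask, kq));           // additive 0/-inf mask
    kq = ggml_soft_max(ctx, kq);                                      // softmax over the keys
    // (H, H_dim, S_k) x (H, S_q, S_k) -> (H, S_q, H_dim)
    ggml_tensor* attn = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, v)), kq);
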
@@ -255,7 +256,6 @@ extern "C" ggml_tensor* PositionalEmbedding_forward(
     ggml_tensor* embeds
 ) {
     // This only work with the simple pos encoders
-    int encoding_dim = embeds->ne[0];
     int seq_len = embeds->ne[1];
     ggml_tensor* full_pos_embeds = model.tensors[prefix];
     ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, 0, seq_len);
@@ -398,7 +398,7 @@ extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
 ggml_tensor* causal_mask_cache = nullptr;
 
 extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
-    auto seq_len = seqs->ne[0];
+    auto seq_len = seqs->ne[1];
     auto mask = causal_mask_cache;
     // TODO: this cache only works as long as we don't change the size/device too often
     // TODO: allow other ggml_type
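
The sequence length moves from ne[0] to ne[1] because ggml stores the fastest-varying dimension first: seqs here holds embedded tokens, so ne[0] is the model dimension and ne[1] the number of positions. A small illustration with hypothetical sizes:

    // (model_dim, seq_len) in ggml's innermost-first order:
    ggml_tensor* seqs = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, /*model_dim*/ 1024, /*seq_len*/ 7);
    // seqs->ne[0] == 1024  -> embedding dimension
    // seqs->ne[1] == 7     -> sequence length, which is what the causal mask needs

The same convention explains the vocab_size fix further down in generate_sequence: the vocabulary axis of the embedding matrix sits in ne[1], not ne[0].
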
@@ -612,14 +612,16 @@ int StandardBeamSearch_step(
         // The first step always indicates the beginning of the sequence and
         // has no score.
         if (step_nr > 0) {
-            lprobs = ggml_add(ctx, lprobs, last_scores);
+            lprobs = ggml_add_inplace(ctx, lprobs, last_scores);
         }
     } else {
         // Make probabilities contain cumulative scores for each hypothesis.
-        lprobs = ggml_add(ctx, lprobs, last_scores);
+        lprobs = ggml_add_inplace(ctx, lprobs, ggml_repeat(ctx, last_scores, lprobs));
     }
 
+    // Note this is where we will actually do the model inference.
     ggml_cgraph gf = ggml_build_forward(lprobs);
+    printf("StandardBeamSearch_step.graph.n_nodes: %d\n", gf.n_nodes);
     ggml_graph_compute_with_ctx(ctx, &gf, 1);
 
     // Take the best 2 x `beam_size` predictions. We'll choose the first
@@ -629,7 +631,7 @@ int StandardBeamSearch_step(
     int topk = std::min(2 * beam_size, vocab_size - 1);
 
     auto comp = [lprobs](std::int32_t a, std::int32_t b) {
-        return ggml_get_f32_1d(lprobs, a) < ggml_get_f32_1d(lprobs, b);
+        return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
     };
     auto cand = (std::int32_t*)candidate_indices->data;
     std::partial_sort(cand, cand + topk, cand + (beam_size * vocab_size), comp);
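
Flipping the comparator to > is a correctness fix: std::partial_sort moves the elements that come first under the supplied ordering to the front, so with < the loop would have examined the lowest log-probabilities, i.e. the worst candidates. A self-contained illustration with made-up values:

    #include <algorithm>
    #include <vector>

    std::vector<float> lp   = {-2.3f, -0.1f, -1.7f, -0.4f};  // hypothetical log-probabilities
    std::vector<int>   cand = {0, 1, 2, 3};                   // stands in for candidate_indices
    std::partial_sort(cand.begin(), cand.begin() + 2, cand.end(),
                      [&](int a, int b) { return lp[a] > lp[b]; });
    // cand[0] == 1 (-0.1f) and cand[1] == 3 (-0.4f): the two best candidates come first
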
@@ -637,19 +639,19 @@ int StandardBeamSearch_step(
     return topk;
 }
 
-bool _finalize_hypothesis(
+int _finalize_hypothesis(
     const SequenceGeneratorJob& job,
     ggml_context* ctx,
     int step_nr,
+    int vocab_size,
     std::int32_t candidate,
+    float tok_score,
     ggml_tensor* seqs, // (beam_size, seq_len)
     ggml_tensor* scores, // (beam_size, seq_len)
     std::vector<Hypothesis>& hypotheses
 ) {
-    int vocab_size = scores->ne[0];
     std::int32_t beam = candidate / vocab_size;
     std::int32_t token = candidate % vocab_size;
-    float tok_score = ggml_get_f32_1d(scores, candidate);
 
     // Detect beams that reached the minimum length and that end with an EOS.
     bool eos = token == job.eos_idx;
@@ -658,34 +660,33 @@ bool _finalize_hypothesis(
     // eos &= ggml_get_i32_1d(ignored_beam_mask, beam);
     // ggml_set_i32_1d(eos_mask, beam, eos);
 
-    if (!eos) return false;
+    if (!eos) return 0;
 
     // If the candidate beam is "finished", let's copy the score and sequence
     ggml_tensor* tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
     ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
 
     auto tok = (std::int32_t*)tokens->data;
-    auto sc = (float*)step_scores->data;
-    ggml_set_f32_1d(scores, scores->ne[0] * beam + step_nr + 1, tok_score);
     for (int i = 0; i < step_nr + 1; ++i) {
         tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
     }
     tok[step_nr + 1] = token;
 
+    // Convert from cumulative to per-step scores.
+    auto sc = (float*)step_scores->data;
     float last_score = tok_score;
     for (int i = step_nr; i >= 0; --i) {
-        // Convert from cumulative to per-step scores.
-        float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i + 0);
+        float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
         sc[i] = last_score - sc0;
         last_score = sc0;
     }
 
-    // Skip first EOS since it is always 0 and skews normalization.
     if (job.opts.normalize_scores)
-        tok_score /= std::pow((step_nr + 1), job.opts.len_penalty);
+        // Skip first EOS since it is always 0 and skews normalization.
+        tok_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
 
     hypotheses.emplace_back(Hypothesis{tokens, tok_score, step_scores});
-    return true;
+    return 1;
 }
 
 /// Generates a translation for a single sequence
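
The backward loop in _finalize_hypothesis above turns the cumulative scores stored in `scores` into per-step scores by successive differencing. A worked example with hypothetical numbers (step_nr = 2, len_penalty = 1):

    cumulative scores recorded for the beam:  scores[beam] = { 0.0, -0.7, -1.5 }
    cumulative score of the EOS candidate:    tok_score    = -2.1

    sc[2] = -2.1 - (-1.5) = -0.6      // per-step score of the EOS token
    sc[1] = -1.5 - (-0.7) = -0.8
    sc[0] = -0.7 -   0.0  = -0.7

    with normalize_scores, the hypothesis score becomes
    tok_score / (step_nr + 1)^len_penalty = -2.1 / 3 = -0.7
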
@@ -704,8 +705,8 @@ extern "C" float generate_sequence(
     ggml_tensor* output_seq
 ) {
     ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
-    int vocab_size = embed->ne[0];
-    int beam_size = job.opts.beam_size;
+    int vocab_size = embed->ne[1];
+    std::size_t beam_size = job.opts.beam_size;
     int source_seq_len = encoder_output->ne[1];
     int max_seq_len = _determine_max_seq_len(job, source_seq_len);
     ggml_context* ctx = model.ctx;
@@ -713,7 +714,8 @@ extern "C" float generate_sequence(
     // (S_enc, M) -> (B, S_enc, M)
     _fan_out_encoder_output(ctx, &encoder_output, &encoder_padding_mask, beam_size);
 
-    std::vector<Hypothesis> finished_searches(beam_size);
+    std::vector<Hypothesis> finished_searches;
+    finished_searches.reserve(beam_size);
 
     // Initialize buffers. (B, S)
     ggml_tensor* seqs = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, max_seq_len, beam_size);
@@ -730,16 +732,13 @@ extern "C" float generate_sequence(
 
     // Holds the indices of beams (a beam can occur more than once) that we
     // should continue with in the next step.
-    ggml_tensor* beam_indices = nullptr;
+    ggml_tensor* beam_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
+    ggml_tensor* next_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
+    ggml_tensor* next_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, beam_size);
 
-    // Indices of next token
+    // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
     ggml_tensor* candidate_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, vocab_size * beam_size);
-    for (int i = 0; i < vocab_size * beam_size; ++i) ggml_set_i32_1d(candidate_indices, i, i);
-
-    // Holds the indices of searches that we should continue with in the next
-    // step. If not `None`, it means we finalized one or more searches in the
-    // last step.
-    ggml_tensor* search_indices = nullptr;
+    for (std::size_t i = 0; i < vocab_size * beam_size; ++i) ggml_set_i32_1d(candidate_indices, i, i);
 
     // TODO: memory management
     // there should be a per-step ggml_context for intermediary results
@@ -813,133 +812,64 @@ extern "C" float generate_sequence(
             candidate_indices
         );
 
-        int ongoing_beams = 0;
-        for (std::int32_t c = 0; c < topk; ++c) {
-            bool finished = _finalize_hypothesis(job, ctx, step_nr, c, seqs, scores, finished_searches);
-            if (!finished) ongoing_beams += 1;
-
+        std::size_t ongoing_beams = 0;
+        int new_num_searches = 0;
+        for (std::int32_t i = 0; i < topk; ++i) {
+            int c = ggml_get_f32_1d(candidate_indices, i);
+            float tok_score = ggml_get_f32_1d(lprobs, c);
+            int finished = _finalize_hypothesis(job, ctx, step_nr, vocab_size, c, tok_score, seqs, scores, finished_searches);
+            new_num_searches += finished;
+            if (!finished){
+                std::int32_t beam = c / vocab_size;
+                std::int32_t token = c % vocab_size;
+
+                ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
+                ggml_set_f32_1d(next_tokens, ongoing_beams, token);
+                ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
+                ongoing_beams += 1 - finished;
+            }
             if (ongoing_beams >= beam_size) break;
+            if (finished_searches.size() >= beam_size) break;
+        }
+        if (finished_searches.size() >= beam_size) break;
+
+        // Reorder beams in the `seq` and `score` buffers. The same beam can
+        // be selected more than once.
+        ggml_tensor* new_seqs = seqs;
+        ggml_tensor* new_scores = scores;
+        if (step_nr > start_step) {
+            // (B, S), (B) -> (B, S)
+            // ggml_get_rows only work with floats ...
+            new_seqs->type = GGML_TYPE_F32;
+            new_seqs = ggml_get_rows(ctx, seqs, beam_indices);
+            new_scores = ggml_get_rows(ctx, new_scores, beam_indices);
         }
-        if (finished_searches.size() == beam_size) break;
-
-        // TODO: recreate scores and seqs with the best beams
-
-        // Remove finished searches (ones for which `beam_size` finalized
-        // beams have been generated) from the batch.
-        ggml_tensor* search_indices = nullptr;
-        // if (newly_finished_searches) {
-        //     new_num_searches = num_searches - len(newly_finished_searches)
-
-        //     // Construct `search_indices` which holds indices of searches
-        //     // to keep for the next step.
-        //     search_mask = torch.full((num_searches,), True, device=device)
-
-        //     search_mask[newly_finished_searches] = False
-
-        //     search_indices = torch.arange(num_searches, device=device)
-
-        //     search_indices = search_indices.masked_select(search_mask)
-
-        //     // Filter out removed batches from state variables.
-        //     // (N, B) -> (N - F, B)
-        //     ignored_beam_mask = ignored_beam_mask[search_indices]
-
-        //     // (N, 2 x B) -> (N - F, 2 x B)
-        //     cand_scores       = cand_scores      [search_indices]
-        //     cand_indices      = cand_indices     [search_indices]
-        //     cand_beam_indices = cand_beam_indices[search_indices]
-
-        //     // (N) -> (N - F)
-        //     search_offsets.resize_(new_num_searches, 1)
-
-        //     // (N - F, 2 x B) + (N - F) -> (N - F, 2 x B)
-        //     global_cand_beam_indices = cand_beam_indices + search_offsets
-
-        //     // (N, 2 x B) -> (N - F, 2 x B)
-        //     eos_mask = eos_mask[search_indices]
-
-        //     // (N x B, S) -> (N, B, S)
-        //     seqs   = seqs  .view(num_searches, -1)
-        //     scores = scores.view(num_searches, -1)
-
-        //     // (N, B, S + 1) -> ((N - F) x B, S)
-        //     seqs   = seqs  [search_indices].view(new_num_searches * beam_size, -1)
-        //     scores = scores[search_indices].view(new_num_searches * beam_size, -1)
-
-        //     // (N x B, S_enc, M) -> (N, B, S_enc, M)
-        //     encoder_output = encoder_output.unflatten(0, (num_searches, -1))
-
-        //     // (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
-        //     encoder_output = encoder_output[search_indices].flatten(0, 1)
-
-        //     if encoder_padding_mask is not None:
-        //         // (N x B, S_enc, M) -> (N, B, S_enc, M)
-        //         padding_mask = encoder_padding_mask.unflatten(0, (num_searches, -1))
-
-        //         // (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
-        //         encoder_padding_mask = padding_mask[search_indices].flatten(0, 1)
-
-        //     num_searches = new_num_searches
-        // }
-
-        // eos_mask[:, :beam_size][ignored_beam_mask] = True
-
-        // // Set `beam_weights` so that values greater than or equal to 2 x
-        // // `beam_size` indicate finished beams (i.e. end with EOS) and values
-        // // less than 2 x `beam_size` indicate active beams.
-        // // (N, 2 x B)
-        // beam_weights = cand_offsets + (eos_mask * (2 * beam_size))
-
-        // // Get the top `beam_size` active beams, which are the beams with the
-        // // smallest weights in `active_beam_weights`.
-        // // (N, B)
-        // active_beam_weights, active_beams = torch.topk(
-        //     beam_weights, k=beam_size, dim=1, largest=False
-        // )
-
-        // // Update to ignore finalized beams in the next step.
-        // // (N, B)
-        // ignored_beam_mask = active_beam_weights >= 2 * beam_size
-
-        // // We should always have at least one active beam in each search.
-        // assert (~ignored_beam_mask).any(dim=1).all()
-
-        // // Denotes which beams are continued for each new hypothesis (a beam
-        // // can be selected more than once).
-        // // (N, B)
-        // beam_indices = torch.gather(
-        //     global_cand_beam_indices, dim=1, index=active_beams
-        // )
-
-        // // (N, B) -> (N x B)
-        // beam_indices = beam_indices.view(-1)
-
-        // // Reorder beams in the `seq` and `score` buffers. The same beam can
-        // // be selected more than once.
-        // if (step_nr > start_step) {
-        //     // seqs  [:, : step_nr + 1] = torch.index_select(
-        //     //     seqs  [:, : step_nr + 1], dim=0, index=beam_indices
-        //     // )
-        //     // scores[:, : step_nr + 1] = torch.index_select(
-        //     //     scores[:, : step_nr + 1], dim=0, index=beam_indices
-        //     // )
-        // }
 
-        // // (N x B, S) -> (N, B, S)
-        // // seqs_view   = seqs  .view(num_searches, beam_size, -1)
-        // // scores_view = scores.view(num_searches, beam_size, -1)
+        // new_seqs[:, step_nr + 1] = next_tokens
+        ggml_set_1d_inplace(ctx, new_seqs, next_tokens, new_seqs->nb[0] * (step_nr + 1));
+        ggml_set_1d_inplace(ctx, new_scores, next_scores, new_scores->nb[0] * (step_nr + 1));
 
-        // // seqs_view  [:, :, step_nr + 1] = torch.gather(cand_indices, dim=1, index=active_beams)
-        // // scores_view[:, :, step_nr + 1] = torch.gather(cand_scores,  dim=1, index=active_beams)
+        ggml_cgraph gf = ggml_build_forward(new_seqs);
+        ggml_graph_compute_with_ctx(ctx, &gf, 1);
+        new_seqs->type = GGML_TYPE_I32;
+        gf = ggml_build_forward(new_scores);
+        ggml_graph_compute_with_ctx(ctx, &gf, 1);
 
+        // TODO the old seqs and score buffers could be reused for next step
+        seqs = new_seqs;
+        scores = new_scores;
     }
-    // Ensure that hypotheses are sorted by their scores before returning.
-    // for batch in finished_searches:
-    //     batch.sort(key=lambda b: b.score, reverse=True)  # type: ignore[arg-type, return-value]
 
-    // return SequenceGeneratorOutput(
-    //     results=finished_searches, device=device, pad_idx=self.pad_idx
-    // )
+    // Ensure that hypotheses are sorted by decreasing scores before returning.
+    std::sort(
+        finished_searches.begin(),
+        finished_searches.end(),
+        [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
+    );
+
+    // For now just return the best sequence
+    // TODO: return structured output
+    *output_seq = *(finished_searches[0].seq);
 
     return 0.0f;
 }
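
The beam reordering above leans on ggml_get_rows, which gathers whole rows of its first argument using the int32 indices in its second. Per the comment in the loop, its kernel only handles float sources, hence the temporary relabeling of the int32 token buffer as F32; since both element types are 4 bytes wide, the row-wise copy preserves the raw token ids bit-for-bit, and the type is restored once the graph has run. A condensed sketch of the pattern (mirrors the loop body above, names as in the code):

    // seqs: ne = {max_seq_len, beam_size}, GGML_TYPE_I32 token ids, one row per beam
    // beam_indices: GGML_TYPE_I32, index of the surviving source beam for each row
    seqs->type = GGML_TYPE_F32;                                // reinterpret only, no conversion
    ggml_tensor* reordered = ggml_get_rows(ctx, seqs, beam_indices);
    ggml_cgraph gf = ggml_build_forward(reordered);
    ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads*/ 1);
    reordered->type = GGML_TYPE_I32;                           // restore the real element type
    seqs->type = GGML_TYPE_I32;
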