Pārlūkot izejas kodu

generate fairseq2.cpp

Guillaume Wenzek 1 gadu atpakaļ
vecāks
revīzija
2fb09f34fb
2 mainītis faili ar 529 papildinājumiem un 0 dzēšanām
  1. 480 0
      ggml/examples/unity/fairseq2.cpp
  2. 49 0
      ggml/examples/unity/fairseq2.h

+ 480 - 0
ggml/examples/unity/fairseq2.cpp

@@ -1,6 +1,8 @@
 #include <math.h>
 #include "ggml.h"
 #include "fairseq2.h"
+#include <unordered_map>
+#include <algorithm>
 
 
 /// allocate the fairseq2 model and hyperparameters
@@ -383,3 +385,481 @@ extern "C" ggml_tensor* StandardTransformerDecoder_forward(
 
     return seqs;
 }
+
+using IncrementalStateBag = std::unordered_map<ggml_tensor*, ggml_tensor*>*;
+
+
+int _determine_max_seq_len(const SequenceGeneratorJob& job) {
+    auto opts = job.opts;
+    int max_seq_len = -1;
+    if (job.source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
+        max_seq_len = opts.hard_max_seq_len;
+    } else {
+        max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * job.source_seq_len + opts.soft_max_seq_len_b));
+    }
+
+    if (opts.min_seq_len > max_seq_len) {
+        printf(
+            "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
+            opts.min_seq_len,
+            max_seq_len
+        );
+        GGML_ASSERT(opts.min_seq_len <= max_seq_len);
+    }
+
+    int prefix_seq_len = job.prefix_seq->ne[0];
+    if (prefix_seq_len >= max_seq_len) {
+        printf(
+            "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
+            prefix_seq_len,
+            max_seq_len
+        );
+        GGML_ASSERT(prefix_seq_len < max_seq_len);
+    }
+
+    return max_seq_len;
+}
+
+void _fan_out_encoder_output(
+    ggml_context* ctx,
+    ggml_tensor** encoder_output_out,
+    ggml_tensor** encoder_padding_mask_out,
+    int beam_size
+) {
+    // (S_enc, M)
+    ggml_tensor* encoder_output = *encoder_output_out;
+    ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
+
+    // (B, S_enc, M)
+    ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
+
+    // (S_enc, M) -> (B, S_enc, M)
+    *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
+    if (encoder_padding_mask != nullptr) {
+        *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape);
+    }
+}
+
+ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
+    // TODO: this isn't the smartest way of doing this
+    return ggml_log(ctx, ggml_soft_max(ctx, logits));
+}
+
+void _bootstrap_seqs_and_scores(
+    fairseq2_model& model,
+    const SequenceGeneratorJob& job,
+    ggml_tensor* seqs,
+    ggml_tensor* scores,
+    ggml_tensor* encoder_output,
+    ggml_tensor* encoder_padding_mask,
+    IncrementalStateBag state_bag
+) {
+    int prefix_seq_len = job.prefix_seq->ne[0];
+    int max_seq_len = scores->ne[0];
+    int beam_size = scores->ne[1];
+    GGML_ASSERT(prefix_seq_len > 0);
+    if (prefix_seq_len == 1)
+        return;
+
+    ggml_context* ctx = model.ctx;
+
+    // seqs[:, : prefix_seq_len] = job.prefix_seq;
+    ggml_cpy(ctx, job.prefix_seq, ggml_view_2d(ctx, seqs, 0, prefix_seq_len, 0, 0));
+
+    // We have to bootstrap the model with the already fanned-out encoder
+    // output to correctly initialize its incremental state. This causes some
+    // redundancy as we have to expand `decoder_input` to match the shape of
+    // `encoder_output`.
+    // (S_pfx) -> (N x B, S_pfx - 1)
+    // prefix_seq[:-1].expand(encoder_output.size(0), -1)
+    ggml_tensor* decoder_input = ggml_repeat(ctx, ggml_view_1d(ctx, job.prefix_seq, prefix_seq_len - 1, 0), encoder_output);
+
+    // Bootstrap the model state with prefix sequence.
+    ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
+        model,
+        ".decoder",
+        seqs,
+        /*padding_mask*/ nullptr,
+        encoder_output,
+        encoder_padding_mask
+        // TODO: state_bag
+    );
+    // TODO state_bag.increment_step(prefix_seq_len - 1)
+
+    // logits, lprobs: (N, S_pfx - 1, V)
+    ggml_tensor* logits = Linear_forward(model, ".decoder.final_proj", decoder_output);
+    ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_view_3d(ctx, logits, logits->ne[0], logits->ne[1], 1, 0, 0, 0));
+    int vocab_size = logits->ne[0];
+
+    ggml_cgraph gf = ggml_build_forward(lprobs);
+    ggml_graph_compute_with_ctx(ctx, &gf, 1);
+
+    // Fetch scores of next steps from "lprobs"
+    float p_score = 0;
+    for (int i = 0; i < prefix_seq_len; ++i) {
+        int p = ggml_get_i32_1d(job.prefix_seq, i);
+        p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
+        for (int b = 0; b < beam_size; ++b) {
+            // scores: (N, S)
+            // Note: First step (e.g. BOS)'s score is always 0.
+            ggml_set_f32_1d(scores, b * max_seq_len + i + 1, p_score);
+        }
+    }
+}
+
+/// Represents a hypothesis produced by a sequence generator.
+struct Hypothesis {
+    /// The generated sequence.
+    ggml_tensor* seq;
+
+    /// The score of the hypothesis.
+    float score;
+
+    /// The score of each individual sequence step.
+    ggml_tensor* step_scores;
+};
+
+
+/// Represents a standard beam search algoritm.
+int StandardBeamSearch_step(
+    ggml_context* ctx,
+    int step_nr,
+    bool is_start_step,
+    ggml_tensor* lprobs,  // (N, S, V)
+    ggml_tensor* scores,  // (N, S)
+    ggml_tensor* candidate_indices
+) {
+    int vocab_size = lprobs->ne[0];
+    int sent_len = lprobs->ne[1];
+    int beam_size = lprobs->ne[2];
+    GGML_ASSERT(scores->ne[0] == sent_len);
+    GGML_ASSERT(scores->ne[1] == beam_size);
+
+    // should this be done by the caller ?
+    ggml_tensor* last_scores = ggml_view_2d(ctx, scores, beam_size, 1, 0, step_nr);
+    if (is_start_step) {
+        // At the initial step, all hypotheses are equally likely, so we use
+        // only the first beam.
+        lprobs = ggml_view_3d(ctx, lprobs, vocab_size, sent_len, 1, 0, 0, 0);
+        lprobs = ggml_cont(ctx, lprobs);
+        // The first step always indicates the beginning of the sequence and
+        // has no score.
+        if (step_nr > 0) {
+            lprobs = ggml_add(ctx, lprobs, last_scores);
+        }
+    } else {
+        // Make probabilities contain cumulative scores for each hypothesis.
+        lprobs = ggml_add(ctx, lprobs, last_scores);
+    }
+
+    ggml_cgraph gf = ggml_build_forward(lprobs);
+    ggml_graph_compute_with_ctx(ctx, &gf, 1);
+
+    // Take the best 2 x `beam_size` predictions. We'll choose the first
+    // `beam_size` of these which don't predict EOS to continue with.
+    // (N, 2 x B)
+    // `vocab_size` - 1 to never select PAD.
+    int topk = std::min(2 * beam_size, vocab_size - 1);
+
+    auto comp = [scores](std::int32_t a, std::int32_t b) {
+        return ggml_get_f32_1d(scores, a) < ggml_get_f32_1d(scores, b);
+    };
+    auto cand = (std::int32_t*)candidate_indices->data;
+    std::partial_sort(cand, cand + topk, cand + (beam_size * vocab_size), comp);
+
+    return topk;
+}
+
+bool _finalize_hypothesis(
+    const SequenceGeneratorJob& job,
+    ggml_context* ctx,
+    int step_nr,
+    std::int32_t candidate,
+    ggml_tensor* seqs, // (beam_size, seq_len)
+    ggml_tensor* scores, // (beam_size, seq_len)
+    std::vector<Hypothesis>& hypotheses
+) {
+    int vocab_size = scores->ne[0];
+    std::int32_t beam = candidate / vocab_size;
+    std::int32_t token = candidate % vocab_size;
+    float tok_score = ggml_get_f32_1d(scores, candidate);
+
+    // Detect beams that reached the minimum length and that end with an EOS.
+    bool eos = token == job.eos_idx;
+    eos &= tok_score != -INFINITY;
+    // TODO ignored_beam_mask ?
+    // eos &= ggml_get_i32_1d(ignored_beam_mask, beam);
+    // ggml_set_i32_1d(eos_mask, beam, eos);
+
+    if (!eos) return false;
+
+    // If the candidate beam is "finished", let's copy the score and sequence
+    ggml_tensor* tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
+    ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
+
+    auto tok = (std::int32_t*)tokens->data;
+    auto sc = (float*)step_scores->data;
+    ggml_set_f32_1d(scores, scores->ne[0] * beam + step_nr + 1, tok_score);
+    for (int i = 0; i < step_nr + 1; ++i) {
+        tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
+    }
+    tok[step_nr + 1] = token;
+
+    float last_score = tok_score;
+    for (int i = step_nr; i >= 0; --i) {
+        // Convert from cumulative to per-step scores.
+        float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i + 0);
+        sc[i] = last_score - sc0;
+        last_score = sc0;
+    }
+
+    // Skip first EOS since it is always 0 and skews normalization.
+    if (job.opts.normalize_scores)
+        tok_score /= std::pow((step_nr + 1), job.opts.len_penalty);
+
+    hypotheses.emplace_back(Hypothesis{tokens, tok_score, step_scores});
+    return true;
+}
+
+/// Generates a translation for a single sequence
+extern "C" float generate_sequence(
+    fairseq2_model& model,
+    const SequenceGeneratorJob& job,
+    ggml_tensor* encoder_output,
+    ggml_tensor* encoder_padding_mask,
+    ggml_tensor** output_seq
+) {
+    int input_seq_len = encoder_output->ne[1];
+    int vocab_size = encoder_output->ne[0];
+    int beam_size = job.opts.beam_size;
+    int max_seq_len = _determine_max_seq_len(job);
+    ggml_context* ctx = model.ctx;
+
+    // (S_enc, M) -> (B, S_enc, M)
+    _fan_out_encoder_output(ctx, &encoder_output, &encoder_padding_mask, beam_size);
+
+    std::vector<Hypothesis> active_searches(beam_size);
+    std::vector<Hypothesis> finished_searches(beam_size);
+
+    // Initialize buffers. (B, S)
+    ggml_tensor* seqs = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, max_seq_len, beam_size);
+    ggml_set_i32(seqs, 0);
+    ggml_tensor* scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, max_seq_len, beam_size);
+    ggml_set_f32(scores, 0.0);
+
+    IncrementalStateBag state_bag = {};
+    _bootstrap_seqs_and_scores(
+        model, job, seqs, scores, encoder_output, encoder_padding_mask, state_bag
+    );
+    int prefix_seq_len = job.prefix_seq->ne[0];
+    int start_step = prefix_seq_len - 1;
+
+    // Holds the indices of beams (a beam can occur more than once) that we
+    // should continue with in the next step.
+    ggml_tensor* beam_indices = nullptr;
+
+    // Indices of next token
+    ggml_tensor* candidate_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, vocab_size * beam_size);
+    for (int i = 0; i < vocab_size * beam_size; ++i) ggml_set_i32_1d(candidate_indices, i, i);
+
+    // Holds the indices of searches that we should continue with in the next
+    // step. If not `None`, it means we finalized one or more searches in the
+    // last step.
+    ggml_tensor* search_indices = nullptr;
+
+    for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
+        // if (beam_indices != nullptr) {
+        //     // If not `None`, it means in the last step we finalized one or
+        //     // more searches. We should ensure that we adjust `beam_indices`
+        //     // before reordering `decoder`'s incremental state.
+        //     if (search_indices != nullptr) {
+        //         num_searches = search_indices->ne[0];
+
+        //         // (N)
+        //         delta = search_indices - torch.arange(num_searches, device=device)
+
+        //         // (N) -> (N, 1)
+        //         delta.unsqueeze_(-1)
+
+        //         // Adjust indices to take into account removed searches.
+        //         beam_indices.view(num_searches, beam_size).add_(delta * beam_size)
+        //     }
+
+        //     // state_bag.reorder(beam_indices)
+        // }
+
+        ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
+            model,
+            ".decoder",
+            // seqs[:, step_nr : step_nr + 1]
+            ggml_view_2d(ctx, seqs, 1, beam_size, step_nr * seqs->nb[0], 0),
+            nullptr,  // We never generate PAD.
+            encoder_output,
+            encoder_padding_mask
+            // state_bag=state_bag,
+        );
+
+        // state_bag.increment_step()
+
+        ggml_tensor* logits = Linear_forward(model, ".decoder.final_proj", decoder_output);
+        ggml_tensor* lprobs = ggml_log_softmax(ctx, logits);
+
+        // // Do not allow EOS before reaching the minimum sequence length.
+        // if step_nr < self.opts.min_seq_len:
+        //     lprobs[:, :, self.eos_idx] = -torch.inf
+
+        // // If we have reached the maximum length, force the last step to be
+        // // EOS.
+        // if step_nr == max_seq_len - 2:
+        //     lprobs[:, :, : self.eos_idx]       = -torch.inf
+        //     lprobs[:, :,   self.eos_idx + 1 :] = -torch.inf
+
+        // // Never allow PAD.
+        // lprobs[:, :, self.pad_idx] = -torch.inf
+
+        // // Apply UNK penalty.
+        // if self.unk_idx is not None:
+        //     lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
+
+        // Determine candidates for the next step.
+        // (N, 2 x B)
+        int topk = StandardBeamSearch_step(
+            ctx,
+            step_nr,
+            step_nr == start_step,
+            lprobs,
+            // TODO only pass scores for new tokens
+            ggml_view_2d(ctx, scores, step_nr + 1, beam_size, 0, 0),
+            candidate_indices
+        );
+
+        int ongoing_beams = 0;
+        for (std::int32_t c = 0; c < topk; ++c) {
+            bool finished = _finalize_hypothesis(job, ctx, step_nr, c, seqs, scores, finished_searches);
+            if (!finished) ongoing_beams += 1;
+
+            if (ongoing_beams >= beam_size) break;
+        }
+        if (finished_searches.size() == beam_size) break;
+
+        // TODO: recreate scores and seqs with the best beams
+
+        // Remove finished searches (ones for which `beam_size` finalized
+        // beams have been generated) from the batch.
+        ggml_tensor* search_indices = nullptr;
+        // if (newly_finished_searches) {
+        //     new_num_searches = num_searches - len(newly_finished_searches)
+
+        //     // Construct `search_indices` which holds indices of searches
+        //     // to keep for the next step.
+        //     search_mask = torch.full((num_searches,), True, device=device)
+
+        //     search_mask[newly_finished_searches] = False
+
+        //     search_indices = torch.arange(num_searches, device=device)
+
+        //     search_indices = search_indices.masked_select(search_mask)
+
+        //     // Filter out removed batches from state variables.
+        //     // (N, B) -> (N - F, B)
+        //     ignored_beam_mask = ignored_beam_mask[search_indices]
+
+        //     // (N, 2 x B) -> (N - F, 2 x B)
+        //     cand_scores       = cand_scores      [search_indices]
+        //     cand_indices      = cand_indices     [search_indices]
+        //     cand_beam_indices = cand_beam_indices[search_indices]
+
+        //     // (N) -> (N - F)
+        //     search_offsets.resize_(new_num_searches, 1)
+
+        //     // (N - F, 2 x B) + (N - F) -> (N - F, 2 x B)
+        //     global_cand_beam_indices = cand_beam_indices + search_offsets
+
+        //     // (N, 2 x B) -> (N - F, 2 x B)
+        //     eos_mask = eos_mask[search_indices]
+
+        //     // (N x B, S) -> (N, B, S)
+        //     seqs   = seqs  .view(num_searches, -1)
+        //     scores = scores.view(num_searches, -1)
+
+        //     // (N, B, S + 1) -> ((N - F) x B, S)
+        //     seqs   = seqs  [search_indices].view(new_num_searches * beam_size, -1)
+        //     scores = scores[search_indices].view(new_num_searches * beam_size, -1)
+
+        //     // (N x B, S_enc, M) -> (N, B, S_enc, M)
+        //     encoder_output = encoder_output.unflatten(0, (num_searches, -1))
+
+        //     // (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
+        //     encoder_output = encoder_output[search_indices].flatten(0, 1)
+
+        //     if encoder_padding_mask is not None:
+        //         // (N x B, S_enc, M) -> (N, B, S_enc, M)
+        //         padding_mask = encoder_padding_mask.unflatten(0, (num_searches, -1))
+
+        //         // (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
+        //         encoder_padding_mask = padding_mask[search_indices].flatten(0, 1)
+
+        //     num_searches = new_num_searches
+        // }
+
+        // eos_mask[:, :beam_size][ignored_beam_mask] = True
+
+        // // Set `beam_weights` so that values greater than or equal to 2 x
+        // // `beam_size` indicate finished beams (i.e. end with EOS) and values
+        // // less than 2 x `beam_size` indicate active beams.
+        // // (N, 2 x B)
+        // beam_weights = cand_offsets + (eos_mask * (2 * beam_size))
+
+        // // Get the top `beam_size` active beams, which are the beams with the
+        // // smallest weights in `active_beam_weights`.
+        // // (N, B)
+        // active_beam_weights, active_beams = torch.topk(
+        //     beam_weights, k=beam_size, dim=1, largest=False
+        // )
+
+        // // Update to ignore finalized beams in the next step.
+        // // (N, B)
+        // ignored_beam_mask = active_beam_weights >= 2 * beam_size
+
+        // // We should always have at least one active beam in each search.
+        // assert (~ignored_beam_mask).any(dim=1).all()
+
+        // // Denotes which beams are continued for each new hypothesis (a beam
+        // // can be selected more than once).
+        // // (N, B)
+        // beam_indices = torch.gather(
+        //     global_cand_beam_indices, dim=1, index=active_beams
+        // )
+
+        // // (N, B) -> (N x B)
+        // beam_indices = beam_indices.view(-1)
+
+        // // Reorder beams in the `seq` and `score` buffers. The same beam can
+        // // be selected more than once.
+        // if (step_nr > start_step) {
+        //     // seqs  [:, : step_nr + 1] = torch.index_select(
+        //     //     seqs  [:, : step_nr + 1], dim=0, index=beam_indices
+        //     // )
+        //     // scores[:, : step_nr + 1] = torch.index_select(
+        //     //     scores[:, : step_nr + 1], dim=0, index=beam_indices
+        //     // )
+        // }
+
+        // // (N x B, S) -> (N, B, S)
+        // // seqs_view   = seqs  .view(num_searches, beam_size, -1)
+        // // scores_view = scores.view(num_searches, beam_size, -1)
+
+        // // seqs_view  [:, :, step_nr + 1] = torch.gather(cand_indices, dim=1, index=active_beams)
+        // // scores_view[:, :, step_nr + 1] = torch.gather(cand_scores,  dim=1, index=active_beams)
+
+    }
+    // Ensure that hypotheses are sorted by their scores before returning.
+    // for batch in finished_searches:
+    //     batch.sort(key=lambda b: b.score, reverse=True)  # type: ignore[arg-type, return-value]
+
+    // return SequenceGeneratorOutput(
+    //     results=finished_searches, device=device, pad_idx=self.pad_idx
+    // )
+
+    return 0.0f;
+}

+ 49 - 0
ggml/examples/unity/fairseq2.h

@@ -68,3 +68,52 @@ enum TransformerNormOrder {
     TRANSFORMER_NORM_ORDER_PRE = 1,
     TRANSFORMER_NORM_ORDER_PRE_WITH_NORMFORMER = 2
 };
+
+
+
+/// Holds the options to pass to a sequence generator.
+struct SequenceGeneratorOptions {
+    /// The beam size.
+    int beam_size = 5;
+
+    /// The minimum length of generated sequences (including prefix sequence).
+    int min_seq_len = 1;
+
+    /// The terms ``a`` and ``b`` of ``ax + b`` where ``x`` is the source
+    /// sequence length. The generated sequences (including prefix sequence) will
+    /// have the maximum length of ``min(hard_max_seq_len, ax + b)``. See also
+    /// ``hard_max_seq_len``.
+    int soft_max_seq_len_a = 1;
+    int soft_max_seq_len_b = 200;
+
+    /// The hard limit on maximum length of generated sequences.
+    int hard_max_seq_len = 1024;
+
+    /// The length penalty, where values less than 1.0 favor shorter, values
+    /// greater than 1.0 favor longer sequences.
+    float len_penalty = 1.0;
+
+    /// The unknown symbol penalty, where values less than 0 produce more UNKs,
+    /// values greater than 0 produce fewer UNKs.
+    float unk_penalty = 0.0;
+
+    /// If ``True``, normalizes scores by the length of generated sequences.
+    bool normalize_scores = true;
+};
+
+
+struct SequenceGeneratorJob {
+    SequenceGeneratorOptions opts;
+    ggml_tensor* prefix_seq;
+    int source_seq_len;
+    std::int32_t eos_idx;
+};
+
+
+extern "C" float generate_sequence(
+    fairseq2_model& model,
+    const SequenceGeneratorJob& opts,
+    ggml_tensor* encoder_output,
+    ggml_tensor* encoder_padding_mask,
+    ggml_tensor** output_seq
+);