ソースを参照

generate_sequence return full results

Guillaume Wenzek 2 年 前
コミット
6fbb465f2b

+ 9 - 4
ggml/ctypes_utils.py

@@ -12,12 +12,13 @@ class Ptr(Generic[T]):
     contents: T
 
     def __new__(cls):
+        breakpoint()
         return ctypes.pointer()
 
 
 def c_struct(cls):
     struct = types.new_class(cls.__name__, bases=(ctypes.Structure,))
-    struct.__module__ = "ctypes"
+    struct.__module__ = cls.__module__
     struct._fields_ = [
         (k, _py_type_to_ctype(v)) for k, v in cls.__annotations__.items()
     ]
@@ -33,8 +34,11 @@ def _py_type_to_ctype(t: type):
         )
     if t.__module__ == "ctypes":
         return t
-    if isinstance(t, type) and issubclass(t, ctypes.Structure):
-        return t
+    if isinstance(t, type):
+        if issubclass(t, ctypes.Structure):
+            return t
+        if issubclass(t, ctypes._Pointer):
+            return t
     if t is int:
         return ctypes.c_int
     if t is float:
@@ -66,7 +70,8 @@ def _c_fn(module, fn):
 
     @functools.wraps(fn)
     def actual_fn(*args, **kwargs):
-        return c_fn(*args, **kwargs)
+        raw_res = c_fn(*args, **kwargs)
+        return raw_res
 
     return actual_fn
 

+ 38 - 56
ggml/examples/unity/fairseq2.cpp

@@ -617,7 +617,8 @@ void _bootstrap_seqs_and_scores(
         seqs,
         /*padding_mask*/ nullptr,
         encoder_output,
-        /*encoder_padding_mask*/ nullptr // TODO: do we need padding for encoder ?
+        // we assume there is only one input, and therefore we don't need padding.
+        /*encoder_padding_mask*/ nullptr
         // TODO: state_bag
     );
     // TODO state_bag.increment_step(prefix_seq_len - 1)
@@ -645,22 +646,8 @@ void _bootstrap_seqs_and_scores(
     }
 }
 
-/// Represents a hypothesis produced by a sequence generator.
-struct Hypothesis {
-    /// The generated sequence.
-    ggml_tensor* seq;
-
-    /// The score of the hypothesis.
-    float score;
-
-    /// The score of each individual sequence step.
-    ggml_tensor* step_scores;
-};
-
-
-/// Finds the topk indices
+/// Finds the topk indices, and write the winning indices in "candidate_indices" array.
 int topk(
-    ggml_context* ctx,
     ggml_tensor* lprobs,  // (B, V)
     std::int64_t k,
     ggml_tensor* candidate_indices
@@ -687,6 +674,7 @@ void ggml_detach(ggml_tensor* a) {
 }
 
 
+/// Copies the sequence and scores of a given candidate beam.
 void _finalize_hypothesis(
     const SequenceGeneratorJob& job,
     ggml_context* ctx,
@@ -698,7 +686,6 @@ void _finalize_hypothesis(
     ggml_tensor* scores, // (beam_size, seq_len)
     std::vector<Hypothesis>& hypotheses
 ) {
-    // If the candidate beam is "finished", let's copy the score and sequence
     ggml_tensor* tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
     ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
 
@@ -711,12 +698,12 @@ void _finalize_hypothesis(
     // Convert from cumulative to per-step scores.
     auto sc = (float*)step_scores->data;
     float last_score = eos_score;
-    sc[step_nr + 1] = last_score;
     for (int i = step_nr; i >= 0; --i) {
         float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
-        sc[i] = last_score - sc0;
+        sc[i + 1] = last_score - sc0;
         last_score = sc0;
     }
+    sc[0] = 0;
 
     if (job.opts.normalize_scores)
         // Skip first EOS since it is always 0 and skews normalization.
@@ -725,21 +712,21 @@ void _finalize_hypothesis(
     hypotheses.emplace_back(Hypothesis{tokens, eos_score, step_scores});
 }
 
+// Uses ggml_context to store any object.
+#define GGML_CTX_ALLOC(ctx, Type, n) \
+    (Type*)(ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(Type) * n)->data);
+
+
 /// Generates a translation for a single sequence
-// TODO: finish this for beam_size=1
-// * find out why score is different (seq is the same though)
 // TODO: add IncrementalStateBag support to avoid a O(N^3) generation.
-// TODO: support beam_size > 1:
-// * most layers assume un-batched input, but we want to handle several beams at once
-// * need to port "reorder_state_dict"
-// TODO: clean up
-// * replace manual tensor tweaking with ggml_set_*d (ggml_set_slice could be useful)
-extern "C" float generate_sequence(
+// TODO: clean ups
+// * replace manual tensor tweaking with ggml_set_*d (a ggml_set_slice could be useful)
+extern "C" Hypothesis* generate_sequence(
     fairseq2_model& model,
     const SequenceGeneratorJob& job,
     ggml_tensor* encoder_output,
     ggml_tensor* encoder_padding_mask,
-    ggml_tensor* output_seq
+    ggml_context* result_ctx
 ) {
     ggml_context* ctx = model.ctx;
     size_t eos_idx = job.eos_idx;
@@ -787,25 +774,6 @@ extern "C" float generate_sequence(
     // there should be a per-step ggml_context for intermediary results
     // start of beam search:
     for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
-        // if (beam_indices != nullptr) {
-        //     // If not `None`, it means in the last step we finalized one or
-        //     // more searches. We should ensure that we adjust `beam_indices`
-        //     // before reordering `decoder`'s incremental state.
-        //     if (search_indices != nullptr) {
-        //         num_searches = search_indices->ne[0];
-
-        //         // (N)
-        //         delta = search_indices - torch.arange(num_searches, device=device)
-
-        //         // (N) -> (N, 1)
-        //         delta.unsqueeze_(-1)
-
-        //         // Adjust indices to take into account removed searches.
-        //         beam_indices.view(num_searches, beam_size).add_(delta * beam_size)
-        //     }
-
-        //     // state_bag.reorder(beam_indices)
-        // }
         // because of no IncrementalStateBag we pass input from the start
         // decoder_input = seqs[:, 0 : step_nr + 1]
         ggml_tensor* decoder_input = ggml_slice(ctx, seqs, 0, 0, step_nr + 1);
@@ -845,7 +813,6 @@ extern "C" float generate_sequence(
         }
 
         // If we have reached the maximum length, force the last step to be EOS.
-        // TODO: should this be done in an adhoc loop ? how often does that happen anyway ?
         if (step_nr == max_seq_len - 2) {
             // lprobs[:, :, : self.eos_idx]       = -torch.inf
             // lprobs[:, :,   self.eos_idx + 1 :] = -torch.inf
@@ -856,7 +823,6 @@ extern "C" float generate_sequence(
                 for (t = eos_idx + 1; t < vocab_size; ++t)
                     ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
             }
-
         }
 
         // Never allow PAD.
@@ -890,10 +856,10 @@ extern "C" float generate_sequence(
         gf = ggml_build_forward(lprobs);
         ggml_graph_compute_with_ctx(ctx, &gf, 1);
 
-        // Determine candidates for the next step.
+        // Determine (beam, token) candidates for the next step.
         // (N, 2 x B)
         std::int64_t K = topk(
-            ctx, lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
+            lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
         );
 
         std::size_t ongoing_beams = 0;
@@ -907,7 +873,7 @@ extern "C" float generate_sequence(
             bool eos = token == job.eos_idx;
             eos &= tok_score != -INFINITY;
             if (eos) {
-                _finalize_hypothesis(job, ctx, step_nr, beam, token, tok_score, seqs, scores, finished_searches);
+                _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, finished_searches);
                 if (finished_searches.size() >= beam_size)
                     goto end_of_beam_search;
                 continue;
@@ -960,9 +926,25 @@ end_of_beam_search:
         [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
     );
 
-    // For now just return the best sequence
-    // TODO: return structured output
-    *output_seq = *(finished_searches[0].seq);
+    // Copy the scores to an object in the result_ctx.
+    GGML_ASSERT(finished_searches.size() <= beam_size);
+    Hypothesis* result = GGML_CTX_ALLOC(result_ctx, struct Hypothesis, beam_size);
+    std::copy(finished_searches.begin(), finished_searches.end(), result);
+    // In case we have less searches than expected, still make sure to initialize the memory.
+    for (std::size_t i = finished_searches.size(); i < beam_size; ++i)
+        result[i] = Hypothesis{nullptr, -INFINITY, nullptr};
+
+    return result;
+}
+
+extern "C" Hypothesis* _testing_return_hypothesis_ptr(ggml_context* ctx) {
+    Hypothesis* result = GGML_CTX_ALLOC(ctx, struct Hypothesis, 2);
+
+    result[0] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 3.14f, (ggml_tensor*)result};
+    ggml_set_i32_1d(result[0].seq, 0, 314);
 
-    return finished_searches[0].score;
+    result[1] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 4.21f, nullptr};
+    ggml_set_i32_1d(result[1].seq, 0, 421);
+
+    return result;
 }

+ 14 - 2
ggml/examples/unity/fairseq2.h

@@ -142,11 +142,23 @@ struct SequenceGeneratorJob {
     std::int32_t eos_idx;
 };
 
+/// Represents a hypothesis produced by a sequence generator.
+struct Hypothesis {
+    /// The generated sequence.
+    ggml_tensor* seq;
 
-extern "C" float generate_sequence(
+    /// The score of the hypothesis.
+    float score;
+
+    /// The score of each individual sequence step.
+    ggml_tensor* step_scores;
+};
+
+
+extern "C" Hypothesis* generate_sequence(
     fairseq2_model& model,
     const SequenceGeneratorJob& opts,
     ggml_tensor* encoder_output,
     ggml_tensor* encoder_padding_mask,
-    ggml_tensor* output_seq
+    ggml_context* result_ctx
 );

+ 18 - 2
ggml/ggml.py

@@ -424,12 +424,28 @@ class SequenceGeneratorJob:
     eos_idx: int
 
 
+@c_struct
+class Hypothesis:
+    seq: Ptr[ggml_tensor]
+    """The generated sequence."""
+
+    score: float
+    """The score of the hypothesis."""
+
+    step_scores: Ptr[ggml_tensor]
+    """The score of each individual sequence step."""
+
+
 @c_fn(lib)
 def generate_sequence(
     model: ctypes.c_void_p,
     job: Ptr[SequenceGeneratorJob],
     encoder_output: Ptr[ggml_tensor],
     encoder_padding_mask: Ptr[ggml_tensor],
-    output_seq: Ptr[ggml_tensor],
-) -> float:
+    result_ctx: ggml_context_p,
+) -> Ptr[Hypothesis]:
     ...
+
+@c_fn(lib)
+def _testing_return_hypothesis_ptr(ctx: ggml_context_p) -> Ptr[Hypothesis]:
+    return Ptr()

+ 13 - 2
ggml/test_ggml_integration.py

@@ -23,12 +23,12 @@ from seamless_communication.models.inference.translator import Translator, Modal
 Ctx = ggml.ggml_context_p
 
 UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
-CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)
+CTX_PARAMS = ggml.ggml_init_params(mem_size=16 * 1024 * 1024, mem_buffer=None)
 
 
 @pytest.fixture(name="ctx")
 def _ctx() -> Iterator[Ctx]:
-    """Allocate a new context with 1024 MB of memory"""
+    """Allocate a new context with 16 MB of memory"""
     try:
         ctx = ggml.ggml_init(params=CTX_PARAMS)
         yield ctx
@@ -353,3 +353,14 @@ def test_ggml_softmax_vs_torch(ctx: Ctx, shape: Tuple[int, ...]) -> None:
     y = ggml.to_numpy(gy)
     assert np.allclose(y_exp, y, rtol=1e-3)
     assert np.allclose(np.argmax(y_exp, axis=-1), np.argmax(y, axis=-1))
+
+
+def test_can_return_hypothesis_ptr(ctx: Ctx) -> None:
+    hyp_ptr = ggml._testing_return_hypothesis_ptr(ctx)
+
+    hyp0, hyp1 = hyp_ptr[0], hyp_ptr[1]
+    assert ggml.to_numpy(hyp0.seq).tolist() == [314]
+    assert hyp0.score == pytest.approx(3.14)
+
+    assert ggml.to_numpy(hyp1.seq).tolist() == [421]
+    assert hyp1.score == pytest.approx(4.21)

+ 91 - 70
ggml/test_unity_cpp.py

@@ -8,6 +8,7 @@ import fairseq2.nn
 import fairseq2.nn.transformer
 import logging
 import sys
+import functools
 from pathlib import Path
 from ctypes_utils import Ptr
 from ctypes import c_void_p
@@ -32,40 +33,34 @@ def _ctx() -> Iterator[Ctx]:
     """Allocate a new context with 1024 MB of memory"""
     try:
         ctx = ggml.ggml_init(params=CTX_PARAMS)
-        yield ctx
+        with torch.inference_mode():
+            yield ctx
     finally:
         ggml.ggml_free(ctx)
 
 
-@pytest.fixture(scope="module")
-def g_model_once() -> Iterator[c_void_p]:
+@functools.lru_cache()
+def _load_g_model_once() -> NativeObj:
     model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
     if not model_file.exists():
         convert_model("seamlessM4T_medium", model_file)
-    with ggml.load_unity_ggml_file(model_file) as model:
-        yield model
-
+    return ggml.load_unity_ggml_file(model_file)
 
 @pytest.fixture()
-def g_model(ctx: Ctx, g_model_once: c_void_p) -> c_void_p:
-    ggml.lib.fairseq2_model_set_inference_ctx(g_model_once, ctx)
-    return g_model_once
+def g_model(ctx: Ctx) -> c_void_p:
+    model = _load_g_model_once()
+    ggml.lib.fairseq2_model_set_inference_ctx(model.ptr, ctx)
+    return model.ptr
 
 
-@pytest.fixture(scope="module")
-def translator() -> Iterator[Any]:
-    tr = Translator(
+@functools.lru_cache(maxsize=1)
+def load_translator() -> Translator:
+    return Translator(
         "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
     )
-    with torch.inference_mode():
-        yield tr
-
 
-@pytest.fixture(scope="module")
-def pt_model(translator: Translator) -> Any:
-    model = translator.model
-    print(model)
-    return model
+def load_pt_model() -> Any:
+    return load_translator().model
 
 
 @pytest.mark.xfail(reason="TODO")
@@ -108,10 +103,11 @@ def test_causal_attention_mask(ctx: Ctx):
     assert np.all(mask == mask_exp)
 
 
-def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
+def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
     torch.nn.init.uniform_(x, -1, 1)
 
+    pt_model = load_pt_model()
     y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()
     gx = ggml.from_numpy(ctx, x)
     gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
@@ -121,10 +117,11 @@ def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
     assert np.allclose(y_exp, y, atol=1e-5)
 
 
-def test_Linear_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
+def test_Linear_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
     torch.nn.init.uniform_(x, -1, 1)
 
+    pt_model = load_pt_model()
     y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()
     gx = ggml.from_numpy(ctx, x)
     gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
@@ -134,11 +131,12 @@ def test_Linear_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
     assert np.allclose(y_exp, y, atol=1e-5)
 
 
-def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
+def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))  # (bs, seq_len, model_dim)
     torch.nn.init.uniform_(x, -1 / 32, 1 / 32)
 
     # Test FFN without LayerNorm
+    pt_model = load_pt_model()
     y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
     gx = ggml.from_numpy(ctx, x)
     gy = ggml.forward(
@@ -157,11 +155,12 @@ def _name(tensor: ggml.ggml_tensor_p) -> bytes:
         return b"???"
 
 
-def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
+def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
 
+    pt_model = load_pt_model()
     self_attn = pt_model.text_encoder.layers[0].self_attn
 
     # Note: we use different lengths for queries and keys,
@@ -222,13 +221,14 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any)
 
 
 def test_StandardTransformerEncoderLayer_forward(
-    ctx: Ctx, g_model: c_void_p, pt_model: Any
+    ctx: Ctx, g_model: c_void_p
 ) -> None:
     x = torch.empty((2, 21, 1024))
     padding_mask = torch.ones((2, 21))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
 
+    pt_model = load_pt_model()
     layer = pt_model.text_encoder.layers[0]
 
     gx = ggml.from_numpy(ctx, x)
@@ -255,7 +255,7 @@ def test_StandardTransformerEncoderLayer_forward(
 
 
 def test_StandardTransformerEncoder_forward(
-    ctx: Ctx, g_model: c_void_p, pt_model: Any
+    ctx: Ctx, g_model: c_void_p
 ) -> None:
     x = torch.empty((2, 21, 1024))
     padding_mask = torch.ones((2, 21))
@@ -278,6 +278,7 @@ def test_StandardTransformerEncoder_forward(
 
     y = ggml.to_numpy(gy)
 
+    pt_model = load_pt_model()
     y_exp, _ = pt_model.text_encoder(x, padding_mask)
     y_exp = y_exp.numpy()
 
@@ -306,7 +307,7 @@ def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
 
 
 def test_TransformerEmbeddingFrontend_forward(
-    ctx: Ctx, g_model: c_void_p, pt_model: Any
+    ctx: Ctx, g_model: c_void_p
 ) -> None:
     seq = torch.arange(2 * 20).reshape(2, 20)
     seq[1, 15:] = 0  # padding for second sentence
@@ -320,6 +321,7 @@ def test_TransformerEmbeddingFrontend_forward(
     ggml.build_and_compute(ctx, gy)
     y = ggml.to_numpy(gy)
 
+    pt_model = load_pt_model()
     y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
     y_exp = y_exp.numpy()
 
@@ -328,7 +330,7 @@ def test_TransformerEmbeddingFrontend_forward(
 
 
 def test_StandardTransformerDecoder_forward(
-    ctx: Ctx, g_model: c_void_p, pt_model: Any
+    ctx: Ctx, g_model: c_void_p
 ) -> None:
     x = torch.empty((2, 13, 1024))
     encoder_out = torch.empty((2, 21, 1024))
@@ -353,6 +355,7 @@ def test_StandardTransformerDecoder_forward(
     ggml.build_and_compute(ctx, gy)
     y = ggml.to_numpy(gy)
 
+    pt_model = load_pt_model()
     y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
     y_exp = y_exp.numpy()
 
@@ -361,64 +364,82 @@ def test_StandardTransformerDecoder_forward(
 
 
 def test_t2tt(ctx: Ctx, g_model: c_void_p):
-    # def test_t2tt(ctx: Ctx, g_model: c_void_p, translator):
-    # device = translator.device
     src_lang = "eng"
     src_text = "We are all in a yellow submarine."
     tgt_lang = "fra"
-    # token_encoder = translator.text_tokenizer.create_encoder(
-    #     task="translation", lang=src_lang, mode="source", device=device
-    # )
-    # src = translator.collate(token_encoder(src_text))
-
-    # text_out, _ = translator.get_prediction(
-    #     translator.model,
-    #     translator.text_tokenizer,
-    #     translator.unit_tokenizer,
-    #     src,
-    #     input_modality=Modality.TEXT,
-    #     output_modality=Modality.TEXT,
-    #     tgt_lang=tgt_lang,
-    # )
-
-    # tgt_text = str(text_out.sentences[0])
-    # assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
-    # tgt_tokens = text_out.generator_output.results[0][0].seq
-    # score = text_out.generator_output.results[0][0].score.item()
-    # np.savez(
-    #     Path(__file__).parent / "sample_input.npz",
-    #     score=score,
-    #     encoder_output=text_out.encoder_output.squeeze(0).numpy(),
-    #     encoder_padding_mask=text_out.encoder_padding_mask.squeeze(0).numpy(),
-    #     tgt_tokens=tgt_tokens.numpy(),
-    # )
-
-    text_out = np.load(Path(__file__).parent / "sample_input.npz")
-    score = text_out["score"].item()
-
-    tgt_tokens = list(text_out["tgt_tokens"])
+    sample_file = Path(__file__).parent / "sample_input.npz"
+    beam_size = 2
+
+    if not sample_file.exists():
+        translator = load_translator()
+        device = translator.device
+        token_encoder = translator.text_tokenizer.create_encoder(
+            task="translation", lang=src_lang, mode="source", device=device
+        )
+        src = translator.collate(token_encoder(src_text))
+
+        text_out, _ = translator.get_prediction(
+            translator.model,
+            translator.text_tokenizer,
+            translator.unit_tokenizer,
+            src,
+            input_modality=Modality.TEXT,
+            output_modality=Modality.TEXT,
+            tgt_lang=tgt_lang,
+            beam_size=beam_size,
+        )
+
+        tgt_text = str(text_out.sentences[0])
+        assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
+        hypotheses = [
+            {
+                "seq": h.seq.tolist(),
+                "score": h.score.item(),
+                "step_scores": h.step_scores.numpy(),
+            }
+            for h in text_out.generator_output.results[0]
+        ]
+        np.savez(
+            sample_file,
+            encoder_output=text_out.encoder_output.numpy(),
+            encoder_padding_mask=text_out.encoder_padding_mask.numpy(),
+            hypotheses=hypotheses,
+        )
+
+    # allow_pickle to load the hyp dicts
+    text_out = np.load(sample_file, allow_pickle=True)
     encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
     encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])
+    prefix_seq = np.array(text_out["hypotheses"][0]["seq"][:2]).astype(np.int32)
+    max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])
 
     job = ggml.SequenceGeneratorJob()
-    job.opts.beam_size = 2
+    job.opts.beam_size = beam_size
     job.opts.min_seq_len = 1
     job.opts.soft_max_seq_len_a = 1
     job.opts.soft_max_seq_len_b = 200
-    job.opts.hard_max_seq_len = int(len(tgt_tokens) * 1.5)
+    job.opts.hard_max_seq_len = int(max_seq_len * 1.5)
     job.opts.len_penalty = 1.0
     job.opts.unk_penalty = 0.0
     job.opts.normalize_scores = True
-    job.prefix_seq = ggml.from_numpy(ctx, text_out["tgt_tokens"].astype(np.int32)[:2])
+
+    job.prefix_seq = ggml.from_numpy(ctx, prefix_seq)
     job.pad_idx = 0
     job.unk_idx = 1
     job.bos_idx = 2
     job.eos_idx = 3
 
-    result = ggml.ggml_tensor()
-    g_score = ggml.generate_sequence(
-        g_model, job, encoder_out, encoder_padding_mask, ctypes.byref(result)
+    result_ptr = ggml.generate_sequence(
+        g_model, job, encoder_out, encoder_padding_mask, ctx
     )
-    tokens = list(ggml.to_numpy(ctypes.pointer(result)))
-    assert tokens == tgt_tokens
-    assert g_score == pytest.approx(score, rel=1e-2)
+    results = [result_ptr[i] for i in range(beam_size)]
+
+    assert len(results) == len(text_out["hypotheses"])
+    for g_hyp, exp in zip(results, text_out["hypotheses"]):
+        g_tokens = list(ggml.to_numpy(g_hyp.seq))
+        g_step_scores = ggml.to_numpy(g_hyp.step_scores)
+        assert g_tokens == exp["seq"]
+        assert g_hyp.score == pytest.approx(exp["score"], rel=1e-2)
+        # The score error is big, this may negatively impact the beam search.
+        assert np.allclose(g_step_scores, exp["step_scores"], atol=0.1)
+