
batching -> TransformerEmbeddingFrontend_forward

Guillaume Wenzek 1 year ago
parent
commit
b24dbe3030
2 changed files with 21 additions and 12 deletions
  1. ggml/examples/unity/fairseq2.cpp (+14 -6)
  2. ggml/test_unity_cpp.py (+7 -6)

+ 14 - 6
ggml/examples/unity/fairseq2.cpp

@@ -332,10 +332,19 @@ extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
     ggml_tensor* seqs
     // TODO: state_bag
 ) {
+    GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
     ggml_context* ctx = model.ctx;
     ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
     GGML_ASSERT(embed_weights != nullptr);
-    ggml_tensor* embeds = ggml_get_rows(ctx, embed_weights, seqs);
+    ggml_tensor* embeds;
+    if (seqs->n_dims == 1) {
+        embeds = ggml_get_rows(ctx, embed_weights, seqs);
+    } else {
+        // ggml_get_rows isn't very flexible, so we have to handle the reshape ourselves.
+        embeds = ggml_get_rows(ctx, embed_weights, ggml_reshape_1d(ctx, seqs, ggml_nelements(seqs)));
+        embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
+        embeds->n_dims = seqs->n_dims + 1;
+    }
 
     // padding mask ?
     // padding_mask = to_padding_mask(embeds, seq_lens)
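The new else-branch is a batched embedding lookup done as flatten -> row lookup -> reshape. Below is a minimal numpy sketch of the same semantics; the shapes and names are illustrative, not ggml API. (Note that ggml lists its ne[] axes fastest-varying first, which is why embed_weights->ne[0] above is the embedding dimension.)

import numpy as np

# Illustrative shapes only: embed_weights is (vocab, model_dim),
# seqs is a batched (batch, seq_len) matrix of token ids.
vocab, model_dim = 100, 8
embed_weights = np.random.rand(vocab, model_dim).astype(np.float32)
seqs = np.random.randint(0, vocab, size=(2, 20)).astype(np.int32)

# Flatten so the row lookup sees a 1-D index vector, as the C++ above does,
# then restore the batch/sequence layout on the result.
flat_ids = seqs.reshape(-1)                         # (batch * seq_len,)
embeds = embed_weights[flat_ids]                    # one weight row per token id
embeds = embeds.reshape(*seqs.shape, model_dim)     # (batch, seq_len, model_dim)

assert embeds.shape == (2, 20, model_dim)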
@@ -583,14 +592,13 @@ void _bootstrap_seqs_and_scores(
     // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
     full_seqs->type = GGML_TYPE_F32;
     job.prefix_seq->type = GGML_TYPE_F32;
-    ggml_tensor* seqs = ggml_cpy(ctx, job.prefix_seq, ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len));
+    ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
+    seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);
 
     // We have to bootstrap the model with the already fanned-out encoder
     // output to correctly initialize its incremental state.
-    // (S_pfx) -> (N x B, S_pfx - 1)
-    // prefix_seq[:-1].expand(beam_size, -1)
-    seqs = ggml_expand_2d(ctx, ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1), prefix_seq_len - 1, beam_size);
-    seqs->type = GGML_TYPE_I32;
+    // Note: we don't start decoding the last prefix token just yet.
+    seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);
 
     // Bootstrap the model state with prefix sequence.
     seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
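As the comment above notes, the reference implementation writes full_seqs[:, :prefix_seq_len] = job.prefix_seq and then, per the new note, does not yet start decoding the last prefix token. A rough numpy equivalent of the ggml_slice / ggml_repeat / ggml_cpy sequence, assuming full_seqs is laid out as (beam_size, max_seq_len); the sizes and token ids below are made-up illustration values:

import numpy as np

beam_size, max_seq_len = 4, 10
prefix_seq = np.array([5, 7], dtype=np.int32)   # hypothetical prefix tokens
prefix_seq_len = len(prefix_seq)

full_seqs = np.zeros((beam_size, max_seq_len), dtype=np.int32)
# ggml_repeat fans the 1-D prefix out over the beams; ggml_cpy writes it into the slice.
full_seqs[:, :prefix_seq_len] = prefix_seq

# Don't start decoding the last prefix token just yet.
seqs = full_seqs[:, : prefix_seq_len - 1]
assert seqs.shape == (beam_size, prefix_seq_len - 1)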

+ 7 - 6
ggml/test_unity_cpp.py

@@ -308,19 +308,20 @@ def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
 def test_TransformerEmbeddingFrontend_forward(
     ctx: Ctx, g_model: c_void_p, pt_model: Any
 ) -> None:
-    seq = torch.arange(20).reshape(1, 20)
-    seq_len = torch.tensor([20])
-    gseq = ggml.from_numpy(ctx, seq[0].numpy().astype(np.int32))
+    seq = torch.arange(2 * 20).reshape(2, 20)
+    seq[1, 15:] = 0  # padding for second sentence
+    seq_len = torch.tensor([20, 15])
+    gseq = ggml.from_numpy(ctx, seq.numpy().astype(np.int32))
+
     ggml.ggml_set_name(gseq, b"seq")
     gy = ggml.forward(
         "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
     )
-    gf = ggml.ggml_build_forward(gy)
-    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+    ggml.build_and_compute(ctx, gy)
     y = ggml.to_numpy(gy)
 
     y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
-    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension
+    y_exp = y_exp.numpy()
 
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=1e-6)
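The updated test now covers a batch of two sequences with different lengths, which is exactly the case the "padding mask ?" TODO in TransformerEmbeddingFrontend_forward will eventually have to handle. For reference, a small sketch of the usual seq_len -> boolean mask construction; the True-for-real-tokens convention here is an assumption, not taken from fairseq2's to_padding_mask:

import numpy as np

# Assumed convention: True marks a real token, False marks padding.
seq_len = np.array([20, 15])
max_len = 20
positions = np.arange(max_len)                          # (max_len,)
padding_mask = positions[None, :] < seq_len[:, None]    # (batch, max_len)

assert padding_mask[1, 14] and not padding_mask[1, 15]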