@@ -332,10 +332,19 @@ extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
     ggml_tensor* seqs
     // TODO: state_bag
 ) {
+    GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
     ggml_context* ctx = model.ctx;
     ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
     GGML_ASSERT(embed_weights != nullptr);
-    ggml_tensor* embeds = ggml_get_rows(ctx, embed_weights, seqs);
+    ggml_tensor* embeds;
+    if (seqs->n_dims == 1) {
+        embeds = ggml_get_rows(ctx, embed_weights, seqs);
+    } else {
+        // ggml_get_rows isn't very flexible; we have to handle the reshape ourselves.
+        embeds = ggml_get_rows(ctx, embed_weights, ggml_reshape_1d(ctx, seqs, ggml_nelements(seqs)));
+        embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
+        embeds->n_dims = seqs->n_dims + 1;
+    }

     // padding mask ?
     // padding_mask = to_padding_mask(embeds, seq_lens)
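For reference, the flatten → lookup → reshape trick above can be written in isolation. This is a minimal sketch against the same older ggml API (where tensors still carry an explicit n_dims field and, as the patch's comment notes, ggml_get_rows only copes with 1-D index tensors); embed_2d and the shape names are illustrative, not part of the patch:

    // Minimal sketch: embed a 2-D batch of token ids with an embedding
    // table whose rows are d_model wide (weights->ne[0] == d_model).
    static ggml_tensor* embed_2d(ggml_context* ctx, ggml_tensor* weights, ggml_tensor* ids) {
        // Flatten the (seq_len, batch) id tensor into one row of indices.
        ggml_tensor* flat = ggml_reshape_1d(ctx, ids, ggml_nelements(ids));
        // One row lookup per token: result is (d_model, seq_len * batch).
        ggml_tensor* rows = ggml_get_rows(ctx, weights, flat);
        // Fold the batch dimension back in: (d_model, seq_len, batch).
        return ggml_reshape_3d(ctx, rows, weights->ne[0], ids->ne[0], ids->ne[1]);
    }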
@@ -583,14 +592,13 @@ void _bootstrap_seqs_and_scores(
     // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
     full_seqs->type = GGML_TYPE_F32;
     job.prefix_seq->type = GGML_TYPE_F32;
-    ggml_tensor* seqs = ggml_cpy(ctx, job.prefix_seq, ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len));
+    ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
+    seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);

     // We have to bootstrap the model with the already fanned-out encoder
     // output to correctly initialize its incremental state.
-    // (S_pfx) -> (N x B, S_pfx - 1)
-    // prefix_seq[:-1].expand(beam_size, -1)
-    seqs = ggml_expand_2d(ctx, ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1), prefix_seq_len - 1, beam_size);
-    seqs->type = GGML_TYPE_I32;
+    // Note: we don't start decoding the last prefix token just yet.
+    seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);

     // Bootstrap the model state with prefix sequence.
     seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
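The ggml_slice used here appears to be this port's own helper rather than stock ggml API. The slice-assignment it implements, full_seqs[:, : prefix_seq_len] = job.prefix_seq, could equally be sketched with upstream views; the following standalone version is hypothetical (assign_prefix is an invented name, and it assumes the tensors have already been given a float type, as the patch does above, presumably because some of these ops are float-only):

    // Hypothetical sketch of full_seqs[:, :s_pfx] = prefix with stock ggml.
    // In ggml's ne ordering, full_seqs is (seq_len, batch), prefix is (s_pfx, 1).
    static void assign_prefix(ggml_context* ctx, ggml_cgraph* gf,
                              ggml_tensor* full_seqs, ggml_tensor* prefix, int64_t s_pfx) {
        // View over the first s_pfx elements of every row of full_seqs.
        ggml_tensor* dst = ggml_view_2d(ctx, full_seqs, s_pfx, full_seqs->ne[1],
                                        full_seqs->nb[1], /*offset=*/0);
        // ggml_cpy needs matching element counts, so broadcast the single
        // prefix row across the batch/beam dimension first.
        ggml_tensor* src = ggml_repeat(ctx, prefix, dst);
        ggml_build_forward_expand(gf, ggml_cpy(ctx, src, dst));
    }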