Browse Source

fix TransformerEmbeddingFrontend

Guillaume Wenzek 1 year ago
parent
commit
78e7c9a311
2 changed files with 10 additions and 9 deletions
  1. 3 9
      ggml/examples/unity/fairseq2.cpp
  2. 7 0
      ggml/examples/unity/fairseq2.h

+ 3 - 9
ggml/examples/unity/fairseq2.cpp

@@ -253,6 +253,7 @@ extern "C" ggml_tensor* PositionalEmbedding_forward(
     const std::string& prefix,
     ggml_tensor* embeds
 ) {
+    // This only work with the simple pos encoders
     int encoding_dim = embeds->ne[0];
     int seq_len = embeds->ne[1];
     ggml_tensor* full_pos_embeds = model.tensors[prefix];
@@ -271,24 +272,17 @@ extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
     GGML_ASSERT(embed_weights != nullptr);
     ggml_tensor* embeds = ggml_get_rows(ctx, embed_weights, seqs);
 
+    // padding mask ?
     // padding_mask = to_padding_mask(embeds, seq_lens)
 
-    // TODO: scale when saving the model weights
-    // embeds = ggml_scale embeds * self.scale
-
     if (has_layer(model, prefix + ".pos_encoder")) {
-        // This only work with the simple pos encoders
-        int encoding_dim = embeds->ne[0];
-        int seq_len = embeds->ne[1];
-       ggml_tensor* pos_embeds = ggml_view_2d(ctx, model.tensors[prefix + ".pos_encoder"], encoding_dim, seq_len, 0, 0);
-        embeds = ggml_add(ctx, embeds, pos_embeds);
+        embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
     }
 
     if (has_layer(model, prefix + ".layer_norm")) {
         embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
     }
 
-    // padding mask ?
     return embeds;
 }
 

+ 7 - 0
ggml/examples/unity/fairseq2.h

@@ -55,6 +55,13 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* _ // (klen, slen)  TODO: do we need to pass mask here ?
 );
 
+
+extern "C" ggml_tensor* PositionalEmbedding_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
+    ggml_tensor* embeds
+);
+
 extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
     fairseq2_model& model,
     const std::string& prefix,