
read model config from "layer_config"

Guillaume Wenzek 1 year ago
parent
commit 44a4ca129a

ggml/examples/unity/fairseq2.cpp (+4, -7)

@@ -196,7 +196,7 @@ ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int hea
 
 
 // flash_attn doesn't work for cross attention because it assumes Q <= K
-// TODO: enable flash_attn only for the encoder
+// and it seems to yield slightly different scores than expected, and thus a different beam search
 # define UNITY_FLASH_ATTN 0
 
 extern "C" ggml_tensor* MultiheadAttention_forward(
@@ -208,7 +208,7 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* attn_mask // (klen, slen)
 ) {
     int model_dim = queries->ne[0];
-    int num_heads = 16;  // TODO: read from hparams
+    int num_heads = model.layer_config.at(prefix + ".num_heads");
     int head_dim = model_dim / num_heads;
     GGML_ASSERT(model_dim % num_heads == 0);
 
@@ -229,7 +229,6 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     // For flash_attn, we assume either no masks, or triangular masks.
     ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr);  // (B * H, S, H_dim)
     ggml_set_name(attn, "attn");
-    // TODO test !
     attn = ggml_unflatten_1d(ctx, attn, 2, num_heads);  // (B, H, H_dim, S)
     attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
 #else
@@ -270,8 +269,7 @@ extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
     ggml_tensor* padding_mask
 ) {
     ggml_context* ctx = model.ctx;
-    // TODO: read norm_order from model
-    auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
+    auto norm_order = model.layer_config.at(prefix + ".norm_order");
 
     // _forward_self_attn(seqs, padding_mask)
     auto residual = seqs;
@@ -745,8 +743,7 @@ extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
     ggml_tensor* encoder_padding_mask
 ) {
     ggml_context* ctx = model.ctx;
-    // TODO: read norm_order from model
-    auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
+    auto norm_order = model.layer_config.at(prefix + ".norm_order");
 
     // _forward_self_attn(seqs, padding_mask)
     auto residual = seqs;
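
The change replaces the hard-coded num_heads and norm_order values with lookups in model.layer_config keyed by the layer prefix. Below is a minimal sketch of what such a lookup can look like, assuming layer_config is a plain string-keyed map of integer hyperparameters; the actual type, field names, and loading code in fairseq2.cpp are not part of this diff, so everything here is an illustrative assumption rather than the real API.

// Hypothetical sketch only; names and types are assumptions, not the fairseq2.cpp API.
#include <cstdint>
#include <string>
#include <unordered_map>

struct fairseq2_model_sketch {
    // Maps a fully qualified layer name to a numeric hyperparameter, e.g.
    // "text_encoder.layers.0.self_attn.num_heads" -> 16
    std::unordered_map<std::string, std::int64_t> layer_config;
};

// Mirrors the pattern used in the diff: build the key from the layer prefix
// and look it up with .at(), which throws std::out_of_range on a missing key
// instead of silently falling back to a hard-coded default.
inline std::int64_t layer_hparam(const fairseq2_model_sketch& model,
                                 const std::string& prefix,
                                 const std::string& name) {
    return model.layer_config.at(prefix + "." + name);
}

// Usage, following the two call sites changed above (hypothetical helper):
//   int num_heads   = layer_hparam(model, prefix, "num_heads");
//   auto norm_order = layer_hparam(model, prefix, "norm_order");

Using .at() rather than operator[] is the design choice the diff relies on: a missing key fails loudly at the first forward call, which is preferable to silently inserting a zero-valued default for a hyperparameter such as num_heads.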