@@ -196,7 +196,7 @@ ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int hea


 // flash_attn doesn't work for cross attention because it assumes Q <= K
-// TODO: enable flash_attn only for the encoder
+// and it seems to yield slightly different scores than expected, and thus a different beam search
 # define UNITY_FLASH_ATTN 0

 extern "C" ggml_tensor* MultiheadAttention_forward(
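
The replacement comment records why the fused kernel stays disabled: ggml_flash_attn assumes the query sequence is no longer than the key sequence, which cross attention between decoder queries and encoder keys cannot guarantee, and its scores drifted enough to change the beam search output. As a hedged sketch (not the file's actual layout), the gate the macro controls looks roughly like this, with naive_attention_sketch standing in for the materialized-QK^T fallback of the #else branch shown further down:

#if UNITY_FLASH_ATTN
    // Fused kernel: only one boolean of mask control, and qlen <= klen assumed,
    // so it would only be safe for self attention inside the encoder.
    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr);
#else
    // Fallback: no length assumption, and an arbitrary additive attn_mask is allowed.
    ggml_tensor* attn = naive_attention_sketch(ctx, q, k, v, attn_mask);  // hypothetical helper
#endif
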
@@ -208,7 +208,7 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* attn_mask // (klen, slen)
 ) {
     int model_dim = queries->ne[0];
-    int num_heads = 16; // TODO: read from hparams
+    int num_heads = model.layer_config.at(prefix + ".num_heads");
     int head_dim = model_dim / num_heads;
     GGML_ASSERT(model_dim % num_heads == 0);

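
Instead of a hardcoded 16, the head count now comes from a per-layer lookup keyed by the layer's prefix. Below is a self-contained sketch of that pattern, under the assumption that layer_config behaves like a string-keyed map filled in when the checkpoint's hparams are loaded; the key names and the get_num_heads_sketch helper are illustrative, not the real fairseq2_model API:

#include <cstdio>
#include <map>
#include <string>

// Illustrative stand-in for fairseq2_model::layer_config.
static int get_num_heads_sketch(
        const std::map<std::string, int>& layer_config,
        const std::string& prefix) {
    // .at() throws std::out_of_range on an unknown prefix, so a typo in the
    // layer name fails loudly instead of silently keeping a default of 16.
    return layer_config.at(prefix + ".num_heads");
}

int main() {
    std::map<std::string, int> layer_config = {
        {"text_decoder.layers.0.self_attn.num_heads", 16},  // hypothetical key
    };
    std::printf("num_heads = %d\n",
                get_num_heads_sketch(layer_config, "text_decoder.layers.0.self_attn"));
    return 0;
}
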
@@ -229,7 +229,6 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     // For flash_attn, we assume either no masks, or triangular masks.
     ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr); // (B * H, S, H_dim)
     ggml_set_name(attn, "attn");
-    // TODO test !
     attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
     attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
 #else
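
The surviving context line spells out the masking limitation of the fused path: the only mask ggml_flash_attn can apply is the internal triangular one selected by its boolean argument. The note below makes that explicit; the remark about the #else branch is an assumption about the fallback, which is not shown in this hunk:

// Only two masking modes are expressible through the fused call:
//   ggml_flash_attn(ctx, q, k, v, /*masked*/ false);  // no mask at all
//   ggml_flash_attn(ctx, q, k, v, /*masked*/ true);   // built-in triangular (causal) mask
// A padding mask or any other additive attn_mask cannot be passed through it,
// so those cases have to take the #else path, where the mask can presumably be
// added to the raw scores before the softmax.
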
@@ -270,8 +269,7 @@ extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
     ggml_tensor* padding_mask
 ) {
     ggml_context* ctx = model.ctx;
-    // TODO: read norm_order from model
-    auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
+    auto norm_order = model.layer_config.at(prefix + ".norm_order");

     // _forward_self_attn(seqs, padding_mask)
     auto residual = seqs;
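
The encoder layer now reads its normalization order per layer rather than assuming pre-norm. For orientation, here is a hedged sketch of what such a flag typically selects around the self-attention block; the helper names mimic the file's style, TRANSFORMER_NORM_ORDER_POST is assumed as the counterpart of the removed constant, and the real control flow may differ:

// Sketch only; not the actual body of StandardTransformerEncoderLayer_forward.
ggml_tensor* residual = seqs;
if (norm_order == TRANSFORMER_NORM_ORDER_PRE)
    seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);  // pre-norm: normalize the input
seqs = MultiheadAttention_forward(model, prefix + ".self_attn", seqs, /*keys*/seqs, /*values*/seqs, /*attn_mask*/nullptr);
seqs = ggml_add(ctx, seqs, residual);  // residual connection around the attention block
if (norm_order == TRANSFORMER_NORM_ORDER_POST)  // assumed post-norm counterpart
    seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);  // post-norm: normalize the output

The decoder layer hunk below applies the same per-layer lookup.
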
@@ -745,8 +743,7 @@ extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
     ggml_tensor* encoder_padding_mask
 ) {
     ggml_context* ctx = model.ctx;
-    // TODO: read norm_order from model
-    auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
+    auto norm_order = model.layer_config.at(prefix + ".norm_order");

     // _forward_self_attn(seqs, padding_mask)
     auto residual = seqs;