@@ -93,16 +93,81 @@ extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
 }
 
 
-ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads) {
-    int slen = x->ne[1];
-    int model_dim = x->ne[0];
-    // (S, dim) -> (S, H, H_dim)
-    x = ggml_reshape_3d(ctx, x, model_dim / num_heads, num_heads, slen);
-    // (S, H, H_dim) -> (H, S, H_dim)
+/// Merge the given dimension and the previous one in the tensor.
+/// (..., K, N, ...) -> (..., K * N, ...)
+/// dim is the position of the resulting merged dimension.
+/// ggml_flatten_1d(x, d) <==> torch.flatten(x, -d-2, -d-1)
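+/// e.g. ggml_flatten_1d(x, 0): (B, S, H, H_dim) -> (B, S, H * H_dim), as used on the attention output below.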
+ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
+    int n_dims = x->n_dims;
+    GGML_ASSERT(dim >= 0);
+    GGML_ASSERT(dim < n_dims);
+    // Nothing to do
+    if (dim == n_dims - 1) return x;
+
+    if (n_dims == 2) {
+        return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
+    } else if (n_dims == 3) {
+        if (dim == 0) {
+            return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
+        } else { // dim == 1
+            return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
+        }
+    } else { // n_dims == 4
+        if (dim == 0) {
+            return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
+        } else if (dim == 1) {
+            return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
+        } else { // dim == 2
+            return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
+        }
+    }
+}
+
+/// Split the given dimension.
+/// (..., K * N, ...) -> (..., K, N, ...)
+/// dim is the position of the output dimension with the given number of elements (N).
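+/// e.g. ggml_unflatten_1d(x, 0, H_dim): (B, S, H * H_dim) -> (B, S, H, H_dim), as used to split heads below.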
+ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
+    int n_dims = x->n_dims;
+    GGML_ASSERT(dim >= 0);
+    GGML_ASSERT(dim < n_dims);
+    GGML_ASSERT(n_dims < 4);
+    if (n_dims == 1) {
+        return ggml_reshape_2d(ctx, x, num_el, x->ne[0] / num_el);
+    } else if (n_dims == 2) {
+        if (dim == 0) {
+            return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
+        } else { // dim == 1
+            return ggml_reshape_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el);
+        }
+    } else { // (n_dims == 3)
+        if (dim == 0) {
+            return ggml_reshape_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2]);
+        } else if (dim == 1) {
+            return ggml_reshape_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2]);
+        } else { // dim == 2
+            return ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el);
+        }
+    }
+}
+
+
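+/// (B?, S, dim) -> (B?, H, S, H_dim)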
+ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
+    // (B?, S, dim) -> (B?, S, H, H_dim)
+    x = ggml_unflatten_1d(ctx, x, 0, head_dim);
+    // (B?, S, H, H_dim) -> (B?, H, S, H_dim)
     x = ggml_permute(ctx, x, 0, 2, 1, 3);
     return x;
 }
 
+/// (B?, Sk, dim) -> (B?, H, H_dim, Sk)
+ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim) {
+    // (B?, Sk, dim) -> (B?, Sk, H, H_dim)
+    v = ggml_unflatten_1d(ctx, v, 0, head_dim);
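+    // Move Sk to ne[0]: ggml_mul_mat contracts over ne[0], so this layout lets the
+    // attention weights (B?, H, S, Sk) be multiplied directly against v further down.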
+    v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
+    return v;
+}
+
+
 // flash_attn doesn't work for cross attention because it assumes Q <= K
 // TODO: enable flash_attn only for the encoder
 # define UNITY_FLASH_ATTN 0
@@ -115,21 +180,21 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     ggml_tensor* values, // (klen, d_out)
     ggml_tensor* mask // (klen, slen)
 ) {
-    int slen = queries->ne[1];
-    int slenk = keys->ne[1];
-    int num_heads = 16;
-    int head_dim = queries->ne[0] / num_heads;
+    int model_dim = queries->ne[0];
+    int num_heads = 16; // TODO: read from hparams
+    int head_dim = model_dim / num_heads;
+    GGML_ASSERT(model_dim % num_heads == 0);
+
     ggml_context* ctx = model.ctx;
     ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
-    q = reshape_num_head(ctx, q, num_heads); // (H, S, H_dim)
+    q = _reshape_num_head(ctx, q, head_dim); // (B, H, S, H_dim)
     ggml_set_name(q, "q");
     ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
-    k = reshape_num_head(ctx, k, num_heads); // (H, Sk, H_dim)
+    k = _reshape_num_head(ctx, k, head_dim); // (B, H, Sk, H_dim)
     ggml_set_name(k, "k");
 
     ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
-    v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slenk); // (Sk, H, H_dim)
-    v = ggml_permute(ctx, v, 1, 2, 0, 3); // (H, H_dim, Sk)
+    v = _reshape_num_head_values(ctx, v, head_dim); // (B, H, H_dim, Sk)
     v = ggml_cont(ctx, v);
     ggml_set_name(v, "v");
 
@@ -137,11 +202,11 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     // For flash_attn, we assume either no masks, or triangular masks.
     ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr); // (H, S, H_dim)
     ggml_set_name(attn, "attn");
-    attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (S, H, H_dim)
+    attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
     attn = ggml_cont(ctx, attn);
-    attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen); // (S, H * H_dim)
+    attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
 #else
-    // (H, Sk, H_dim) x (H, S, H_dim) -> (H, S, Sk)
+    // (B, H, Sk, H_dim) x (B, H, S, H_dim) -> (B, H, S, Sk)
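+    // ggml_mul_mat(a, b) contracts over ne[0] of both operands (here H_dim) and returns
+    // ne = [a->ne[1], b->ne[1], ...], hence k is passed first so qk comes out as (B, H, S, Sk).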
     ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
     ggml_set_name(qk, "qk");
     ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
@@ -149,20 +214,21 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     qk = ggml_scale(ctx, qk, qk_scale);
     ggml_set_name(qk, "qk_scaled");
 
+    // TODO: Should we replace this by ggml_diag_mask_inf ?
     if (mask) qk = ggml_add(ctx, qk, mask);
     // TODO: upgrade qk to float32 if needed
-    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (H, Sk, S)
+    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B, H, S, Sk)
     ggml_set_name(attn_weights, "attn_weights");
 
-    // (H, S, Sk) x (H, H_dim, Sk) -> (H, H_dim, S)
+    // (B, H, S, Sk) x (B, H, H_dim, Sk) -> (B, H, H_dim, S)
     ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
     ggml_set_name(attn, "attn");
-    attn = ggml_reshape_2d(ctx, attn, slen, num_heads * head_dim); // (H * H_dim, S)
-    attn = ggml_transpose(ctx, attn); // (S, H * H_dim)
+    attn = ggml_flatten_1d(ctx, attn, 1); // (B, H * H_dim, S)
+    attn = ggml_transpose(ctx, attn); // (B, S, H * H_dim)
     // // I'm not sure why this one is needed ...
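+    // ggml_transpose only swaps strides (it returns a view); the ggml_cont below is likely
+    // needed to materialize a contiguous (B, S, H * H_dim) tensor for the output projection.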
     attn = ggml_cont(ctx, attn);
 #endif // UNITY_FLASH_ATTN
-    // out -> (S, d_out)
+    // out -> (B, S, d_out)
     ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
     ggml_set_name(out, "out");
 