@@ -93,14 +93,11 @@ extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
 }
 
 
-/// Merge the given dimension and the previous one in the tensor.
-/// (..., num_heads, N, ...) -> (..., num_heads * N, ...)
-/// dim is the position of the resulting merged dimension
-/// ggml_flatten_1d(x, d) <==> torch.flatten(x, -1-d-1, -1-d)
 ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
     int n_dims = x->n_dims;
     GGML_ASSERT(dim >= 0);
     GGML_ASSERT(dim < n_dims);
+    GGML_ASSERT(ggml_is_contiguous(x));
     // Nothing to do
     if (dim == n_dims - 1) return x;
 
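For readers mapping these helpers back to PyTorch: ggml stores shapes innermost-first, so a logical (B, S, D) tensor has ne = {D, S, B}, and dim here counts from the innermost axis. A minimal round-trip sketch with hypothetical shapes, assuming the helpers implement the torch.flatten equivalence quoted in the removed doc comment and that a live ggml_context* ctx is in scope:

```cpp
// Logical (B=2, H=4, S=3) tensor; ggml stores it innermost-first: ne = {3, 4, 2}.
ggml_tensor* x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, /*S=*/3, /*H=*/4, /*B=*/2);

// ggml_flatten_1d(x, 1) == torch.flatten(x, -3, -2): merge H into B.
// (B, H, S) -> (B * H, S), i.e. ne = {3, 8}.
ggml_tensor* flat = ggml_flatten_1d(ctx, x, 1);

// ggml_unflatten_1d undoes it: dim 1 of the output gets num_el == 4 elements.
// (B * H, S) -> (B, H, S), i.e. ne = {3, 4, 2} again.
ggml_tensor* back = ggml_unflatten_1d(ctx, flat, 1, /*num_el=*/4);
```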
@@ -123,9 +120,6 @@ ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
     }
 }
 
-/// Split the given dimension.
-/// (..., K * N, ...) -> (..., K, N, ...)
-/// dim is the position of the output dimension with the given number of element (N).
 ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
     int n_dims = x->n_dims;
     GGML_ASSERT(dim >= 0);
@@ -137,7 +131,7 @@ ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int n
         if (dim == 0) {
             return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
         } else { // dim == 1
-            return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
+            return ggml_reshape_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el);
         }
     } else { // (n_dims == 3)
         if (dim == 0) {
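The dim == 1 branch fixed above was a copy-paste of the dim == 0 line, so it split ne[0] a second time instead of ne[1]. A quick trace with hypothetical sizes (ne = {5, 8}, num_el = 4) shows why the old line could not be right:

```cpp
// x: ne = {5, 8}; we want to split ne[1] == 8 into (num_el = 4, 8 / 4 = 2).
// old: ggml_reshape_3d(ctx, x, 4, 5 / 4, 8)  -> ne = {4, 1, 8}, 32 elements;
//      ggml_reshape_3d asserts on the element-count mismatch (40 != 32).
// new: ggml_reshape_3d(ctx, x, 5, 4, 8 / 4)  -> ne = {5, 4, 2}, 40 elements,
//      i.e. (..., K * N) -> (..., K, N) as the doc comment promised.
```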
@@ -154,8 +148,9 @@ ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int n
 ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
     // (B, S, dim) -> (B, S, H, H_dim)
     x = ggml_unflatten_1d(ctx, x, 0, head_dim);
-    // (B?, S, H, H_dim) -> (B?, H, S, H_dim)
-    x = ggml_permute(ctx, x, 0, 2, 1, 3);
+    x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
+    x = ggml_cont(ctx, x);
+    x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
     return x;
 }
 
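A shape trace of the rewritten _reshape_num_head, in ggml's innermost-first ne notation (my reading of the code, not part of the patch). The ggml_cont is needed because ggml_permute returns a strided view, while ggml_flatten_1d reshapes and therefore asserts contiguity:

```cpp
// q_proj output:             logical (B, S, H * H_dim)  ne = {H * H_dim, S, B}
// unflatten_1d(x, 0, H_dim): (B, S, H, H_dim)           ne = {H_dim, H, S, B}
// permute(x, 0, 2, 1, 3):    (B, H, S, H_dim)           ne = {H_dim, S, H, B}  (a view)
// cont(x):                   materialize the view so the reshape below is legal
// flatten_1d(x, 2):          (B * H, S, H_dim)          ne = {H_dim, S, B * H}
```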
@@ -164,6 +159,8 @@ ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int hea
     // (B, Sk, dim) -> (B, Sk, H, H_dim)
     v = ggml_unflatten_1d(ctx, v, 0, head_dim);
     v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
+    v = ggml_cont(ctx, v);
+    v = ggml_flatten_1d(ctx, v, 2); // (B * H, H_dim, Sk)
     return v;
 }
 
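Collapsing B and H into one leading dimension is what lets a single ggml_mul_mat batch the per-head products: it contracts ne[0] of both operands and treats ne[2]/ne[3] as batch axes. A sketch of the query-key product under that convention (my annotation):

```cpp
// With the batch and head axes merged, one batched mul_mat covers all heads:
//   k: ne = {H_dim, Sk, B * H}
//   q: ne = {H_dim, S,  B * H}
// ggml_mul_mat contracts ne[0] and batches over ne[2]:
//   qk = ggml_mul_mat(ctx, k, q);  // ne = {Sk, S, B * H} == logical (B * H, S, Sk)
```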
@@ -186,27 +183,27 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     GGML_ASSERT(model_dim % num_heads == 0);
 
     ggml_context* ctx = model.ctx;
-    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
-    q = _reshape_num_head(ctx, q, head_dim); // (B, H, S, H_dim)
+    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
     ggml_set_name(q, "q");
+    q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
     ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
-    k = _reshape_num_head(ctx, k, head_dim); // (B, H, Sk, H_dim)
     ggml_set_name(k, "k");
+    k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
 
     ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
-    v = _reshape_num_head_values(ctx, v, head_dim); // (B, H, H_dim, Sk)
-    v = ggml_cont(ctx, v);
     ggml_set_name(v, "v");
+    v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
+    v = ggml_cont(ctx, v);
 
 #if UNITY_FLASH_ATTN
     // For flash_attn, we assume either no masks, or triangular masks.
-    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr); // (H, S, H_dim)
+    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr); // (B * H, S, H_dim)
     ggml_set_name(attn, "attn");
+    // TODO: test!
+    attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, S, H_dim)
     attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
-    attn = ggml_cont(ctx, attn);
-    attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
 #else
-    // (B, H, Sk, H_dim) x (B, H, S, H_dim) -> (B, H, S, Sk)
+    // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
     ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
     ggml_set_name(qk, "qk");
     ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
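The same convention explains why _reshape_num_head_values pre-transposes v so that Sk sits in ne[0]: the attention-weights product can then contract it directly. Shape algebra for the second matmul (my annotation):

```cpp
//   attn_weights: ne = {Sk, S,     B * H}
//   v:            ne = {Sk, H_dim, B * H}
// ggml_mul_mat(ctx, attn_weights, v) contracts the shared Sk axis:
//   attn:         ne = {S, H_dim, B * H} == logical (B * H, H_dim, S)
```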
@@ -217,17 +214,17 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     // TODO: Should we replace this by ggml_diag_mask_inf ?
     if (mask) qk = ggml_add(ctx, qk, mask);
     // TODO: upgrade qk to float32 if needed
-    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B, H, S, Sk)
+    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
     ggml_set_name(attn_weights, "attn_weights");
 
-    // (B, H, S, Sk) x (B, H, H_dim, Sk) -> (B, H, H_dim, S)
+    // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
     ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
     ggml_set_name(attn, "attn");
-    attn = ggml_flatten_1d(ctx, attn, 1); // (B, H * H_dim, S)
-    attn = ggml_transpose(ctx, attn); // (B, S, H * H_dim)
-    // // I'm not sure why this one is needed ...
-    attn = ggml_cont(ctx, attn);
+    attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
+    attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
 #endif // UNITY_FLASH_ATTN
+    attn = ggml_cont(ctx, attn);
+    attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
     // out -> (B, S, d_out)
     ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
     ggml_set_name(out, "out");
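Both branches now converge on the same tail: attn leaves either path as logical (B, S, H, H_dim), and the shared ggml_cont + ggml_flatten_1d after #endif restore the (B, S, H * H_dim) layout that output_proj consumes. The whole round trip, summarized (my annotation):

```cpp
// queries (B, S, model_dim)
//   -> q_proj                  (B, S, H * H_dim)
//   -> _reshape_num_head       (B * H, S, H_dim)
//   -> attention (either path) (B, S, H, H_dim)   after unflatten + permute
//   -> cont + flatten_1d(0)    (B, S, H * H_dim)
//   -> output_proj             (B, S, d_out)
```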