// fairseq2.cpp
#include <math.h>
#include <cstdint>
#include <string>

#include "ggml.h"
#include "fairseq2.h"

/// allocate the fairseq2 model and hyperparameters
extern "C" fairseq2_model* fairseq2_model_alloc() {
    // pre-allocate some memory to write hyperparameters and tensor pointers
    auto* model = new fairseq2_model;
    model->hparams = new std::uint8_t[8 * 1024];
    model->arch = new std::uint64_t[16 * 1024]; // max tensors allowed
    model->tensors_ctx = nullptr;
    return model;
}

extern "C" void fairseq2_model_free(fairseq2_model* model) {
    if (model->tensors_ctx) ggml_free(model->tensors_ctx);
    // these buffers were allocated with new[], so release them with delete[]
    delete[] (std::uint64_t*)(model->arch);
    delete[] (std::uint8_t*)model->hparams;
    delete model;
}

extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
    model->ctx = ctx;
}

// Small helpers so that C-ABI callers (e.g. language bindings) can build the
// std::string prefixes expected by the *_forward functions below.
extern "C" std::string* std_string_alloc(char* c_str) {
    return new std::string(c_str);
}

extern "C" void std_string_free(std::string* str) {
    delete str;
}

extern "C" ggml_tensor* Linear_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input  // (d_in)
) {
    // Note: for now we assume un-batched input
    ggml_tensor* weight = model.tensors[prefix + ".weight"];  // (d_in, d_out)
    ggml_tensor* bias = model.tensors[prefix + ".bias"];  // (d_out)
    return ggml_add(
        model.ctx,
        ggml_mul_mat(model.ctx, weight, input),  // (d_out)
        bias
    );
}
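
// Note on the ggml convention used above: ggml_mul_mat contracts over ne[0] of both
// operands (d_in here), so with the weight stored as ne = {d_in, d_out} the product is
// a (d_out) vector for the un-batched case, matching the bias shape in the ggml_add.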

extern "C" ggml_tensor* LayerNorm_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input
) {
    ggml_tensor* weight = model.tensors[prefix + ".weight"];
    ggml_tensor* bias = model.tensors[prefix + ".bias"];
    auto ctx = model.ctx;
    // TODO: should `eps` be part of the unity hparams?
    input = ggml_norm(ctx, input, /*eps*/1e-5);
    return ggml_add(
        ctx,
        ggml_mul(ctx, ggml_repeat(ctx, weight, input), input),
        ggml_repeat(ctx, bias, input)
    );
}
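
// For reference, this computes the standard layer norm
//   y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias,
// where ggml_norm performs the mean/variance normalization and the mul/add apply the
// learned elementwise affine transform, with weight and bias repeated across the
// sequence dimension to match the input shape.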

extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
) {
    seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
    // inner_activation = ReLU  // TODO: allow other activations
    seqs = ggml_relu(model.ctx, seqs);
    if (model.tensors.find(prefix + ".inner_layer_norm.weight") != model.tensors.end()) {
        seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
    }
    // TODO: inference dropout
    // if self.inner_dropout is not None:
    //     seqs = self.inner_dropout(seqs)
    seqs = Linear_forward(model, prefix + ".output_proj", seqs);
    return seqs;
}
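
// Tensor keys looked up above, relative to `prefix` (the exact prefix, e.g. an "<...>.ffn"
// module path, depends on the caller and the checkpoint layout):
//   <prefix>.inner_proj.{weight,bias}
//   <prefix>.inner_layer_norm.{weight,bias}   (optional, only applied if present)
//   <prefix>.output_proj.{weight,bias}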

ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads) {
    int slen = x->ne[1];
    int model_dim = x->ne[0];
    // (S, dim) -> (S, H, H_dim)
    x = ggml_reshape_3d(ctx, x, model_dim / num_heads, num_heads, slen);
    // (S, H, H_dim) -> (H, S, H_dim)
    x = ggml_permute(ctx, x, 0, 2, 1, 3);
    return x;
}
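
// Note: ggml stores ne[0] as the innermost (fastest-varying) dimension, so a tensor
// written as (S, H, H_dim) in these comments has ne = {H_dim, H, S}. ggml_permute only
// builds a strided view; callers that need contiguous data must follow it with
// ggml_cont, as MultiheadAttention_forward does below.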

#define UNITY_FLASH_ATTN

extern "C" ggml_tensor* MultiheadAttention_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* queries,  // (slen, d_in)
    ggml_tensor* keys,  // (klen, d_in)
    ggml_tensor* values,  // (klen, d_out)
    ggml_tensor* mask  // (klen, slen)
) {
    int slen = queries->ne[1];
    int slenk = keys->ne[1];
    int num_heads = 16;  // TODO: hard-coded for now, should come from the model hparams
    int head_dim = queries->ne[0] / num_heads;
    ggml_context* ctx = model.ctx;

    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
    q = reshape_num_head(ctx, q, num_heads);  // (H, S, H_dim)
    ggml_set_name(q, "q");
    ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
    k = reshape_num_head(ctx, k, num_heads);  // (H, Sk, H_dim)
    ggml_set_name(k, "k");
    ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
    v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slenk);  // (Sk, H, H_dim)
    v = ggml_permute(ctx, v, 1, 2, 0, 3);  // (H, H_dim, Sk)
    v = ggml_cont(ctx, v);
    ggml_set_name(v, "v");

#ifdef UNITY_FLASH_ATTN
    // For flash_attn, we assume either no mask, or a triangular mask.
    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr);  // (H, S, H_dim)
    ggml_set_name(attn, "attn");
    attn = ggml_permute(ctx, attn, 0, 2, 1, 3);  // (S, H, H_dim)
    attn = ggml_cont(ctx, attn);
    attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen);  // (S, H * H_dim)
#else
    // (H, Sk, H_dim) x (H, S, H_dim) -> (H, S, Sk)
    ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
    ggml_set_name(qk, "qk");
    ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
    ggml_set_f32(qk_scale, 1.0f / sqrtf(float(head_dim)));
    qk = ggml_scale(ctx, qk, qk_scale);
    ggml_set_name(qk, "qk_scaled");
    if (mask) qk = ggml_add(ctx, qk, mask);
    // TODO: upgrade qk to float32 if needed
    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk);  // (H, S, Sk)
    ggml_set_name(attn_weights, "attn_weights");
    // (H, S, Sk) x (H, H_dim, Sk) -> (H, H_dim, S)
    ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
    ggml_set_name(attn, "attn");
    attn = ggml_reshape_2d(ctx, attn, slen, num_heads * head_dim);  // (H * H_dim, S)
    attn = ggml_transpose(ctx, attn);  // (S, H * H_dim)
    // ggml_transpose only returns a strided view; make the data contiguous before
    // the output projection.
    attn = ggml_cont(ctx, attn);
#endif  // UNITY_FLASH_ATTN

    // out -> (S, d_out)
    ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
    ggml_set_name(out, "out");
    return out;
}
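
// Usage sketch (illustrative only; the prefix below is an assumed fairseq2 checkpoint
// key, and the real one depends on the model being loaded). For self-attention inside
// an encoder layer the same tensor is passed as queries, keys and values:
//
//   ggml_tensor* x = ...;  // (slen, model_dim) input sequence
//   ggml_tensor* y = MultiheadAttention_forward(
//       model, "text_encoder.layers.0.self_attn", x, x, x, /*mask*/nullptr);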

// Reference pseudocode for scaled dot-product attention:
// ggml_tensor* attn_weights = ggml_mul_mat(ctx, q, k);  // (H, S, S)
// attn_weights = attn_weights * (q.size(-1) ** -0.5)
// if mask is not None:
//     attn_weights = attn_weights + mask
// # For numerical stability run in single precision.
// attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)
// attn_weights = attn_weights.type_as(q)
// if training and dropout_p > 0.0:
//     attn_weights = dropout(attn_weights, dropout_p)
// # (*, S, S_kv) @ (*, S_kv, V) = (*, S, V)
// attn = torch.matmul(attn_weights, values)
// return attn, attn_weights if needs_weights else None

// extern "C" ggml_tensor*  // (d_out, seq_len)
// SDPA_forward(
//     fairseq2_model& model,
//     const std::string& prefix,
//     ggml_tensor* queries,  // (d_in, len_q)
//     ggml_tensor* keys,  // (d_in, len_k)
//     ggml_tensor* values,  // (d_out, len_k)
//     ggml_tensor* mask  // (seq_len, len_q)
// ) {
//     return queries;
// }