// fairseq2.cpp

#include <math.h>
#include <cstdint>
#include <string>

#include "ggml.h"
#include "fairseq2.h"

/// allocate the fairseq2 model and hyperparameters
extern "C" fairseq2_model* fairseq2_model_alloc() {
    // pre-allocate some memory to write hyperparameters and tensor pointers
    auto* model = new fairseq2_model;
    model->hparams = new std::uint8_t[8 * 1024];
    model->arch = new std::uint64_t[16 * 1024];  // max tensors allowed
    model->tensors_ctx = nullptr;
    return model;
}
extern "C" void fairseq2_model_free(fairseq2_model* model) {
    if (model->tensors_ctx) ggml_free(model->tensors_ctx);
    // hparams and arch were allocated with new[], so release them with delete[]
    delete[] (std::uint64_t*)(model->arch);
    delete[] (std::uint8_t*)model->hparams;
    delete model;
}
extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
    model->ctx = ctx;
}

extern "C" std::string* std_string_alloc(char* c_str) {
    return new std::string(c_str);
}

extern "C" void std_string_free(std::string* str) {
    delete str;
}
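
// Note on shape comments: ggml's ne[0] is the innermost, fastest-varying dimension,
// so an activation annotated (S, d_in) below corresponds to ggml ne = [d_in, S]
// (e.g. queries->ne[1] is read as the sequence length in MultiheadAttention_forward).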
extern "C" ggml_tensor* Linear_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input  // (d_in)
) {
    // Note: for now we assume un-batched input
    ggml_tensor* weight = model.tensors[prefix + ".weight"];  // (d_in, d_out)
    GGML_ASSERT(weight != nullptr);
    ggml_tensor* bias = model.tensors[prefix + ".bias"];  // (d_out)
    GGML_ASSERT(bias != nullptr);
    return ggml_add(
        model.ctx,
        ggml_mul_mat(model.ctx, weight, input),  // (d_out)
        bias
    );
}
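
// ggml_mul_mat contracts over ne[0]: with weight ne = [d_in, d_out] and input
// ne = [d_in, ...] it returns ne = [d_out, ...], so Linear_forward computes
// x @ W^T + b, the same operation as torch.nn.Linear.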
extern "C" ggml_tensor* LayerNorm_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input
) {
    ggml_tensor* weight = model.tensors[prefix + ".weight"];
    GGML_ASSERT(weight != nullptr);
    ggml_tensor* bias = model.tensors[prefix + ".bias"];
    GGML_ASSERT(bias != nullptr);
    auto ctx = model.ctx;
    // TODO: should `eps` be part of unity hparams?
    input = ggml_norm(ctx, input, /*eps*/ 1e-5);
    return ggml_add(
        ctx,
        ggml_mul(ctx, ggml_repeat(ctx, weight, input), input),
        ggml_repeat(ctx, bias, input)
    );
}
extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
) {
    seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
    // inner_activation = ReLU  // TODO: allow other activations
    seqs = ggml_relu(model.ctx, seqs);
    if (model.tensors.find(prefix + ".inner_layer_norm.weight") != model.tensors.end()) {
        seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
    }
    // TODO: inference dropout
    // if self.inner_dropout is not None:
    //     seqs = self.inner_dropout(seqs)
    seqs = Linear_forward(model, prefix + ".output_proj", seqs);
    return seqs;
}
ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads) {
    int slen = x->ne[1];
    int model_dim = x->ne[0];
    // (S, dim) -> (S, H, H_dim)
    x = ggml_reshape_3d(ctx, x, model_dim / num_heads, num_heads, slen);
    // (S, H, H_dim) -> (H, S, H_dim)
    x = ggml_permute(ctx, x, 0, 2, 1, 3);
    return x;
}
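
// UNITY_FLASH_ATTN selects the fused ggml_flash_attn kernel in
// MultiheadAttention_forward below; without it, attention falls back to the
// explicit softmax(q @ k^T / sqrt(head_dim) + mask) @ v path.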
#define UNITY_FLASH_ATTN

extern "C" ggml_tensor* MultiheadAttention_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* queries,  // (slen, d_in)
    ggml_tensor* keys,     // (klen, d_in)
    ggml_tensor* values,   // (klen, d_out)
    ggml_tensor* mask      // (klen, slen)
) {
    int slen = queries->ne[1];
    int slenk = keys->ne[1];
    int num_heads = 16;  // hard-coded for now
    int head_dim = queries->ne[0] / num_heads;
    ggml_context* ctx = model.ctx;
    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
    q = reshape_num_head(ctx, q, num_heads);  // (H, S, H_dim)
    ggml_set_name(q, "q");
    ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
    k = reshape_num_head(ctx, k, num_heads);  // (H, Sk, H_dim)
    ggml_set_name(k, "k");
    ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
    v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slenk);  // (Sk, H, H_dim)
    v = ggml_permute(ctx, v, 1, 2, 0, 3);  // (H, H_dim, Sk)
    v = ggml_cont(ctx, v);
    ggml_set_name(v, "v");
#ifdef UNITY_FLASH_ATTN
    // For flash_attn, we assume either no mask or a triangular mask.
    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/ mask != nullptr);  // (H, S, H_dim)
    ggml_set_name(attn, "attn");
    attn = ggml_permute(ctx, attn, 0, 2, 1, 3);  // (S, H, H_dim)
    attn = ggml_cont(ctx, attn);
    attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen);  // (S, H * H_dim)
#else
    // (H, Sk, H_dim) x (H, S, H_dim) -> (H, S, Sk)
    ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
    ggml_set_name(qk, "qk");
    ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
    ggml_set_f32(qk_scale, 1.0f / sqrtf(float(head_dim)));
    qk = ggml_scale(ctx, qk, qk_scale);
    ggml_set_name(qk, "qk_scaled");
    if (mask) qk = ggml_add(ctx, qk, mask);
    // TODO: upgrade qk to float32 if needed
    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk);  // (H, S, Sk)
    ggml_set_name(attn_weights, "attn_weights");
    // (H, S, Sk) x (H, H_dim, Sk) -> (H, H_dim, S)
    ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
    ggml_set_name(attn, "attn");
    attn = ggml_reshape_2d(ctx, attn, slen, num_heads * head_dim);  // (H * H_dim, S)
    attn = ggml_transpose(ctx, attn);  // (S, H * H_dim)
    // I'm not sure why this one is needed ...
    attn = ggml_cont(ctx, attn);
#endif  // UNITY_FLASH_ATTN
    // out -> (S, d_out)
    ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
    ggml_set_name(out, "out");
    return out;
}
bool has_layer(fairseq2_model& model, const std::string& name) {
    return model.tensors.find(name) != model.tensors.end();
}
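
// The encoder layer below mirrors fairseq2's StandardTransformerEncoderLayer with
// pre-norm ordering (the only order wired up so far): LayerNorm -> self-attention
// -> residual add, then LayerNorm -> feed-forward network -> residual add.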
extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
) {
    ggml_context* ctx = model.ctx;
    // TODO: read norm_order from model
    auto norm_order = TRANSFORMER_NORM_ORDER_PRE;

    // _forward_self_attn(seqs, padding_mask)
    auto residual = seqs;
    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
    // TODO: add padding_mask to MultiheadAttention_forward
    GGML_ASSERT(padding_mask == nullptr);
    seqs = MultiheadAttention_forward(
        model,
        prefix + ".self_attn",
        seqs,
        seqs,
        seqs,
        /*attention mask=*/nullptr
    );
    if (has_layer(model, prefix + ".self_attn_norm.weight"))
        seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
    // TODO: seqs = self.self_attn_dropout(seqs)
    seqs = ggml_add(ctx, seqs, residual);
    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);

    // _forward_ffn(seqs)
    residual = seqs;
    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
    seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
    // TODO:
    // seqs = self.ffn_dropout(seqs)
    // if self.residual_scale is not None:
    //     residual = self.residual_scale * residual
    seqs = ggml_add(ctx, seqs, residual);
    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
    return seqs;
}
extern "C" ggml_tensor* StandardTransformerEncoder_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
) {
    int layer_idx = 0;
    // TODO: this isn't nice.
    // When loading the model we should add nullptr entries for the module keys,
    // to avoid these string concatenations just to probe for the next layer.
    while (has_layer(model, prefix + ".layers." + std::to_string(layer_idx) + ".self_attn_layer_norm.weight")) {
        seqs = StandardTransformerEncoderLayer_forward(
            model, prefix + ".layers." + std::to_string(layer_idx), seqs, padding_mask
        );
        ggml_set_name(seqs, ("x_" + std::to_string(layer_idx)).c_str());
        layer_idx += 1;
    }
    if (has_layer(model, prefix + ".layer_norm.weight"))
        seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
    return seqs;
}
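
// Hypothetical usage sketch (not part of the upstream file): the prefix
// "text_encoder", the tensor sizes, and the weight-loading step are assumptions
// for illustration only; model->tensors must be populated by the project's
// loader for the lookups above to succeed.
//
//   fairseq2_model* model = fairseq2_model_alloc();
//   // ... populate model->tensors / hparams with the project's loader (not shown) ...
//   ggml_context* ctx = ggml_init({/*mem_size*/ 512u * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false});
//   fairseq2_model_set_inference_ctx(model, ctx);
//   ggml_tensor* seqs = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, /*model_dim*/ 1024, /*seq_len*/ 16);
//   ggml_tensor* out = StandardTransformerEncoder_forward(*model, "text_encoder", seqs, /*padding_mask*/ nullptr);
//   // build and compute the ggml graph on `out` with the graph API of the vendored
//   // ggml revision, then read the result back from out->data.
//   ggml_free(ctx);
//   fairseq2_model_free(model);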