// fairseq2.cpp

#include "ggml.h"
#include "fairseq2.h"

/// allocate the fairseq2 model and hyperparameters
extern "C" fairseq2_model* fairseq2_model_alloc() {
    // pre-allocate some memory to write hyperparameters and tensor pointers
    auto* model = new fairseq2_model;
    model->hparams = new std::uint8_t[8 * 1024];
    model->arch = new std::uint64_t[16 * 1024];  // max tensors allowed
    model->tensors_ctx = nullptr;
    return model;
}

  12. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  13. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  14. delete (std::uint64_t*)(model->arch);
  15. delete (std::uint8_t*)model->hparams;
  16. delete model;
  17. };
extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
    model->ctx = ctx;
}

extern "C" std::string* std_string_alloc(char* c_str) {
    return new std::string(c_str);
}

extern "C" void std_string_free(std::string* str) {
    delete str;
}
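
// Example (added sketch, not part of the original API surface): the intended
// lifecycle of the C entry points above, assuming the caller creates the ggml
// inference context itself via ggml_init. The weight-loading step is elided;
// prefix names and memory sizes below are illustrative only.
//
//   fairseq2_model* model = fairseq2_model_alloc();
//   // ... read hparams and fill model->tensors / model->tensors_ctx (not shown) ...
//   ggml_init_params params = {/*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false};
//   ggml_context* ctx = ggml_init(params);
//   fairseq2_model_set_inference_ctx(model, ctx);
//   std::string* prefix = std_string_alloc((char*)"text_encoder.layers.0.ffn");
//   // ... call the *_forward functions below with *prefix ...
//   std_string_free(prefix);
//   fairseq2_model_free(model);  // frees model->tensors_ctx, but not the inference ctx
//   ggml_free(ctx);
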
extern "C" ggml_tensor* Linear_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input  // (d_in)
) {
    // Note: for now we assume un-batched input
    ggml_tensor* weight = model.tensors[prefix + ".weight"];  // (d_in, d_out)
    ggml_tensor* bias = model.tensors[prefix + ".bias"];  // (d_out)
    return ggml_add(
        model.ctx,
        ggml_mul_mat(model.ctx, weight, input),  // (d_out)
        bias
    );
}
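
// Added note (sketch, not from the original): ggml stores dimensions in reverse
// order, so a weight with ne = {d_in, d_out} and an input with ne = {d_in}
// yield ggml_mul_mat(ctx, weight, input) with ne = {d_out}, i.e. y = W*x + b as
// in fairseq2's torch Linear. Illustrative shapes, assuming d_in=512, d_out=2048:
//
//   ggml_tensor* W = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, /*ne0=d_in*/ 512, /*ne1=d_out*/ 2048);
//   ggml_tensor* x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, /*ne0=d_in*/ 512);
//   ggml_tensor* y = ggml_mul_mat(ctx, W, x);  // y->ne[0] == 2048
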
extern "C" ggml_tensor* LayerNorm_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input
) {
    ggml_tensor* weight = model.tensors[prefix + ".weight"];
    ggml_tensor* bias = model.tensors[prefix + ".bias"];
    auto ctx = model.ctx;
    // TODO: should `eps` be part of unity hparams?
    input = ggml_norm(ctx, input, /*eps*/1e-5);
    return ggml_add(
        ctx,
        ggml_mul(ctx, ggml_repeat(ctx, weight, input), input),
        ggml_repeat(ctx, bias, input)
    );
}
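
// Added note (for reference, not from the original): together these calls
// compute the usual layer norm
//   y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias
// where ggml_norm normalizes over the model dimension (ne[0] in ggml order)
// and the ggml_repeat calls broadcast the (d_model,) weight and bias to the
// shape of `input` before the elementwise mul/add.
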
extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
) {
    seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
    // inner_activation = ReLU  // TODO: allow other activations
    seqs = ggml_relu(model.ctx, seqs);

    if (model.tensors.find(prefix + ".inner_layer_norm.weight") != model.tensors.end()) {
        seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
    }

    // TODO: inference dropout
    // if self.inner_dropout is not None:
    //     seqs = self.inner_dropout(seqs)
    seqs = Linear_forward(model, prefix + ".output_proj", seqs);
    return seqs;
}
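
// Added sketch (illustrative prefix, mirroring the fairseq2 module names used
// above): a feed-forward block registered under "text_encoder.layers.0.ffn" is
// expected to expose tensors such as
//   model.tensors["text_encoder.layers.0.ffn.inner_proj.weight"]        // + ".bias"
//   model.tensors["text_encoder.layers.0.ffn.inner_layer_norm.weight"]  // optional
//   model.tensors["text_encoder.layers.0.ffn.output_proj.weight"]       // + ".bias"
// and would be invoked as:
//   seqs = StandardFeedForwardNetwork_forward(model, "text_encoder.layers.0.ffn", seqs);
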
ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads) {
    int slen = x->ne[1];
    int model_dim = x->ne[0];
    // (S, dim) -> (S, H, H_dim)
    x = ggml_reshape_3d(ctx, x, model_dim / num_heads, num_heads, slen);
    // (S, H, H_dim) -> (H, S, H_dim)
    x = ggml_permute(ctx, x, 0, 2, 1, 3);
    return x;
}
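
// Added example (illustrative numbers, not from the original): with
// model_dim = 1024, num_heads = 16 and slen = 7, x enters with ne = {1024, 7};
// ggml_reshape_3d gives ne = {64, 16, 7} (H_dim, H, S in ggml order) and the
// permute returns a view with ne = {64, 7, 16}, i.e. one (S, H_dim) block per head.
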
// TODO: broken
extern "C" ggml_tensor* MultiheadAttention_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* queries,  // (slen, d_in)
    ggml_tensor* keys,  // (klen, d_in)
    ggml_tensor* values,  // (klen, d_out)
    ggml_tensor* _  // (klen, slen)  TODO: do we need to pass the mask here?
) {
    int slen = queries->ne[1];
    int num_heads = 16;
    int head_dim = queries->ne[0] / num_heads;
    ggml_context* ctx = model.ctx;

    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
    q = reshape_num_head(ctx, q, num_heads);  // (H, S, H_dim)
    ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
    k = reshape_num_head(ctx, k, num_heads);  // (H, S, H_dim)

    ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
    v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slen);  // (S, H, H_dim)
    // v = ggml_permute(ctx, v, 1, 2, 0, 3);  // (H, H_dim, S)
    v = ggml_permute(ctx, v, 1, 0, 2, 3);  // (S, H_dim, H)
    v = ggml_cont(ctx, v);

    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/false);  // (H, S, H_dim)
    attn = ggml_permute(ctx, attn, 0, 2, 1, 3);  // (S, H, H_dim)
    attn = ggml_cont(ctx, attn);
    attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen);  // (S, H * V_h)
    attn = Linear_forward(model, prefix + ".output_proj", attn);  // (S, d_out)
    return attn;
}
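
// Added sketch (not from the original) of evaluating one attention block,
// assuming the classic ggml graph API (ggml_build_forward /
// ggml_graph_compute_with_ctx); the exact calls depend on the ggml revision
// this file is built against, and the prefix is illustrative:
//
//   ggml_tensor* out = MultiheadAttention_forward(
//       model, "text_encoder.layers.0.self_attn", seqs, seqs, seqs, /*mask*/ nullptr);
//   ggml_cgraph gf = ggml_build_forward(out);
//   ggml_graph_compute_with_ctx(model.ctx, &gf, /*n_threads*/ 1);
//   // out->data now holds the (slen, d_out) result
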
// Reference pseudocode (torch-style) for the manual scaled-dot-product attention path:
// ggml_tensor* attn_weights = ggml_mul_mat(ctx, q, k);  // (H, S, S)
// attn_weights = attn_weights * (q.size(-1) ** -0.5)
// if mask is not None:
//     attn_weights = attn_weights + mask
// # For numerical stability run in single precision.
// attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)
// attn_weights = attn_weights.type_as(q)
// if training and dropout_p > 0.0:
//     attn_weights = dropout(attn_weights, dropout_p)
// # (*, S, S_kv) @ (*, S_kv, V) = (*, S, V)
// attn = torch.matmul(attn_weights, values)
// return attn, attn_weights if needs_weights else None

// extern "C" ggml_tensor*  // (d_out, seq_len)
// SDPA_forward(
//     fairseq2_model& model,
//     const std::string& prefix,
//     ggml_tensor* queries,  // (d_in, len_q)
//     ggml_tensor* keys,  // (d_in, len_k)
//     ggml_tensor* values,  // (d_out, len_k)
//     ggml_tensor* mask  // (seq_len, len_q)
// ) {
//     return queries;
// }