fairseq2.cpp

#include "ggml.h"
#include "fairseq2.h"

/// allocate the fairseq2 model and hyperparameters
extern "C" fairseq2_model* fairseq2_model_alloc() {
    // pre-allocate some memory to write hyperparameters and tensor pointers
    auto* model = new fairseq2_model;
    model->hparams = new std::uint8_t[8 * 1024];
    model->arch = new std::uint64_t[16 * 1024];  // max tensors allowed
    model->tensors_ctx = nullptr;
    return model;
}
extern "C" void fairseq2_model_free(fairseq2_model* model) {
    if (model->tensors_ctx) ggml_free(model->tensors_ctx);
    // hparams and arch were allocated with new[], so release them with delete[]
    delete[] (std::uint64_t*) model->arch;
    delete[] (std::uint8_t*) model->hparams;
    delete model;
}
extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
    model->ctx = ctx;
}

extern "C" std::string* std_string_alloc(char* c_str) {
    return new std::string(c_str);
}

extern "C" void std_string_free(std::string* str) {
    delete str;
}
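// Hedged usage sketch (not part of the original file): the intended lifecycle of the
// C API above. The 16 MiB buffer size is illustrative only, and the explicit ggml_free
// of the inference context is an assumption about ownership; the file itself only
// frees tensors_ctx in fairseq2_model_free.
static void fairseq2_model_usage_example() {
    fairseq2_model* model = fairseq2_model_alloc();

    // a scratch context that will hold the activations of a forward pass
    ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    fairseq2_model_set_inference_ctx(model, ggml_init(params));

    // ... build layers with the *_init functions and run them with *_forward against model->ctx ...

    ggml_free(model->ctx);       // assumption: the caller owns the inference context
    fairseq2_model_free(model);  // frees hparams, arch and tensors_ctx
}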
// Linear

std::size_t Linear_size(int32_t input_dim, int32_t output_dim)
{
    return (input_dim * output_dim * ggml_type_size(GGML_TYPE_F32))  // weight
        + (output_dim * ggml_type_size(GGML_TYPE_F32));              // bias
}
void Linear_init(
    Linear& self,
    fairseq2_model& model,
    const std::string &prefix,
    int input_dim,
    int output_dim,
    bool bias
) {
    // ne[0] must match the input's inner dim for ggml_mul_mat: (d_in, d_out)
    self.weight = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, input_dim, output_dim);
    model.tensors[prefix + ".weight"] = self.weight;
    if (bias) {
        self.bias = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, output_dim);
        model.tensors[prefix + ".bias"] = self.bias;
    }
}
extern "C" ggml_tensor*
Linear_forward(
    fairseq2_model& model,
    const std::string &prefix,
    ggml_tensor* input  // (d_in)
) {
    // Note: for now we assume un-batched input
    ggml_tensor* weight = model.tensors[prefix + ".weight"];  // (d_in, d_out)
    ggml_tensor* bias = model.tensors[prefix + ".bias"];      // (d_out)
    return ggml_add(
        model.ctx,
        ggml_mul_mat(model.ctx, weight, input),  // (d_out)
        bias
    );
}
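// Hedged shape note (not in the original file): ggml_mul_mat(ctx, a, b) requires
// a->ne[0] == b->ne[0]. With input_dim = 1024 and output_dim = 4096 the weight is
// therefore stored with ne = [1024, 4096]; applied to an un-batched input with
// ne = [1024] it yields ne = [4096], matching the bias shape.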
// LayerNorm

std::size_t LayerNorm_size(int32_t dim)
{
    return 2 * dim * ggml_type_size(GGML_TYPE_F32);  // weight and bias
}

void LayerNorm_init(
    LayerNorm& self,
    fairseq2_model& model,
    const std::string &prefix,
    int dim
) {
    self.weight = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, dim);
    model.tensors[prefix + ".weight"] = self.weight;
    self.bias = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, dim);
    model.tensors[prefix + ".bias"] = self.bias;
}
extern "C" ggml_tensor* LayerNorm_forward(
    fairseq2_model& model,
    const std::string &prefix,
    ggml_tensor* input
) {
    ggml_tensor* weight = model.tensors[prefix + ".weight"];
    ggml_tensor* bias = model.tensors[prefix + ".bias"];
    auto ctx = model.ctx;
    // TODO: should `eps` be part of unity hparams?
    input = ggml_norm(ctx, input, /*eps*/ 1e-5);
    return ggml_add(
        ctx,
        ggml_mul(ctx, ggml_repeat(ctx, weight, input), input),
        ggml_repeat(ctx, bias, input)
    );
}
std::size_t StandardFeedForwardNetwork_size(int32_t dim, int32_t inner_dim)
{
    // the inner layer norm operates on inner_dim (see the init below)
    return LayerNorm_size(inner_dim) + Linear_size(dim, inner_dim) + Linear_size(inner_dim, dim);
}

void StandardFeedForwardNetwork_init(
    StandardFeedForwardNetwork& self,
    fairseq2_model& model,
    const std::string &prefix,
    int model_dim,
    int inner_dim
) {
    Linear_init(self.inner_proj, model, prefix + ".inner_proj", model_dim, inner_dim, /*bias*/ true);
    LayerNorm_init(self.inner_layer_norm, model, prefix + ".inner_layer_norm", inner_dim);
    Linear_init(self.output_proj, model, prefix + ".output_proj", inner_dim, model_dim, /*bias*/ true);
}
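// Hedged illustration (not in the original file) of the prefix-based tensor registry.
// The prefix "encoder.layers.0.ffn" and the dimensions are made up for the example:
//
//   StandardFeedForwardNetwork ffn;
//   StandardFeedForwardNetwork_init(ffn, model, "encoder.layers.0.ffn", /*model_dim*/ 1024, /*inner_dim*/ 4096);
//
// After this call model.tensors holds:
//   "encoder.layers.0.ffn.inner_proj.weight"        ne = [1024, 4096]
//   "encoder.layers.0.ffn.inner_proj.bias"          ne = [4096]
//   "encoder.layers.0.ffn.inner_layer_norm.weight"  ne = [4096]
//   "encoder.layers.0.ffn.inner_layer_norm.bias"    ne = [4096]
//   "encoder.layers.0.ffn.output_proj.weight"       ne = [4096, 1024]
//   "encoder.layers.0.ffn.output_proj.bias"         ne = [1024]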
extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
) {
    seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
    // inner_activation = ReLU  // TODO: allow other activations
    seqs = ggml_relu(model.ctx, seqs);
    if (model.tensors.find(prefix + ".inner_layer_norm.weight") != model.tensors.end()) {
        seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
    }
    // TODO: inference dropout
    // if self.inner_dropout is not None:
    //     seqs = self.inner_dropout(seqs)
    seqs = Linear_forward(model, prefix + ".output_proj", seqs);
    return seqs;
}
void MultiheadAttention_init(
    MultiheadAttention& self,
    fairseq2_model& model,
    const std::string &prefix,
    int model_dim,
    int num_heads
) {
    bool bias = true;
    int num_key_value_heads = num_heads;
    int head_dim = model_dim / num_heads;
    Linear_init(self.q_proj, model, prefix + ".q_proj", model_dim, model_dim, bias);
    Linear_init(self.k_proj, model, prefix + ".k_proj", model_dim, head_dim * num_key_value_heads, bias);
    Linear_init(self.v_proj, model, prefix + ".v_proj", model_dim, model_dim, bias);
    // (H, 1, K_h)
    self.bias_k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, num_heads, 1, head_dim * num_key_value_heads / num_heads);
    // (H, 1, V_h)
    self.bias_v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, num_heads, 1, model_dim / num_heads);
    // output projection back to model_dim; MultiheadAttention_forward expects
    // "<prefix>.output_proj.{weight,bias}" to be registered
    Linear_init(self.output_proj, model, prefix + ".output_proj", model_dim, model_dim, bias);
}
ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads) {
    int slen = x->ne[1];
    int model_dim = x->ne[0];
    // (S, dim) -> (S, H, H_dim)
    x = ggml_reshape_3d(ctx, x, model_dim / num_heads, num_heads, slen);
    // (S, H, H_dim) -> (H, S, H_dim)
    x = ggml_permute(ctx, x, 0, 2, 1, 3);
    return x;
}
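// Hedged worked example (not in the original file): with model_dim = 1024,
// num_heads = 16 and slen = 7, an input with ne = [1024, 7] is first reshaped to
// ne = [64, 16, 7] and then permuted to the non-contiguous view ne = [64, 7, 16],
// i.e. (H, S, H_dim) in the outer-to-inner notation used by the comments here.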
extern "C" ggml_tensor*  // (slen, d_in)
MultiheadAttention_forward(
    fairseq2_model& model,
    const std::string &prefix,
    ggml_tensor* queries,  // (slen, d_in)
    ggml_tensor* keys,     // (klen, d_in)
    ggml_tensor* values,   // (klen, d_out)
    ggml_tensor* _         // (klen, slen)  TODO: do we need to pass mask here?
) {
    int slen = queries->ne[1];
    int num_heads = 16;
    int head_dim = queries->ne[0] / num_heads;
    ggml_context* ctx = model.ctx;
    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
    q = reshape_num_head(ctx, q, num_heads);  // (H, S, H_dim)
    ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
    k = reshape_num_head(ctx, k, num_heads);  // (H, S, H_dim)
    ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
    v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slen);  // (S, H, H_dim)
    // ggml_flash_attn expects v transposed per head
    v = ggml_permute(ctx, v, 1, 2, 0, 3);  // (H, H_dim, S)
    v = ggml_cont(ctx, v);
    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/ false);  // (H, S, H_dim)
    attn = ggml_permute(ctx, attn, 0, 2, 1, 3);  // (S, H, H_dim)
    attn = ggml_cont(ctx, attn);
    attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen);  // (S, H * V_h)
    attn = Linear_forward(model, prefix + ".output_proj", attn);    // (S, d_out)
    return attn;
}
// Reference pseudocode for the non-flash SDPA path:
// ggml_tensor* attn_weights = ggml_mul_mat(ctx, q, k);  // (H, S, S)
// attn_weights = attn_weights * (q.size(-1) ** -0.5)
// if mask is not None:
//     attn_weights = attn_weights + mask
// # For numerical stability run in single precision.
// attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)
// attn_weights = attn_weights.type_as(q)
// if training and dropout_p > 0.0:
//     attn_weights = dropout(attn_weights, dropout_p)
// # (*, S, S_kv) @ (*, S_kv, V) = (*, S, V)
// attn = torch.matmul(attn_weights, values)
// return attn, attn_weights if needs_weights else None

// extern "C" ggml_tensor*  // (d_out, seq_len)
// SDPA_forward(
//     fairseq2_model& model,
//     const std::string &prefix,
//     ggml_tensor* queries,  // (d_in, len_q)
//     ggml_tensor* keys,     // (d_in, len_k)
//     ggml_tensor* values,   // (d_out, len_k)
//     ggml_tensor* mask      // (seq_len, len_q)
// ) {
//     return queries;
// }
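#include <cmath>  // std::sqrt, used only by the sketch below

// Hedged sketch (not part of the original file) of the naive scaled dot-product
// attention described by the commented-out pseudocode above, built from basic ggml
// ops instead of ggml_flash_attn. The helper name `naive_sdpa` and its argument
// layout are assumptions; dropout and returning the attention weights are omitted,
// and on newer ggml revisions ggml_scale takes a plain float rather than a
// 1-element tensor.
ggml_tensor* naive_sdpa(
    ggml_context* ctx,
    ggml_tensor* q,    // (H, S, H_dim)
    ggml_tensor* k,    // (H, S_kv, H_dim)
    ggml_tensor* v,    // (H, H_dim, S_kv), already transposed per head and contiguous
    ggml_tensor* mask  // (S, S_kv) additive mask, or nullptr
) {
    int head_dim = q->ne[0];
    // q @ k^T -> (H, S, S_kv)
    ggml_tensor* attn_weights = ggml_mul_mat(ctx, k, q);
    // scale by 1/sqrt(head_dim)
    attn_weights = ggml_scale(ctx, attn_weights, ggml_new_f32(ctx, 1.0f / std::sqrt((float) head_dim)));
    if (mask != nullptr) {
        // broadcast the (S, S_kv) mask over the head dimension
        attn_weights = ggml_add(ctx, attn_weights, ggml_repeat(ctx, mask, attn_weights));
    }
    // softmax over the key dimension (ne[0])
    attn_weights = ggml_soft_max(ctx, attn_weights);
    // attn_weights @ v -> (H, S, H_dim)
    return ggml_mul_mat(ctx, v, attn_weights);
}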