fairseq2.cpp

#include <math.h>
#include <cstdint>
#include <string>
#include "ggml.h"
#include "fairseq2.h"

/// allocate the fairseq2 model and hyperparameters
extern "C" fairseq2_model* fairseq2_model_alloc() {
    // pre-allocate some memory to write hyperparameters and tensor pointers
    auto* model = new fairseq2_model;
    model->hparams = new std::uint8_t[8 * 1024];
    model->arch = new std::uint64_t[16 * 1024]; // max tensors allowed
    model->tensors_ctx = nullptr;
    return model;
}

extern "C" void fairseq2_model_free(fairseq2_model* model) {
    if (model->tensors_ctx) ggml_free(model->tensors_ctx);
    // hparams and arch were allocated with new[], so release them with delete[]
    delete[] (std::uint64_t*)(model->arch);
    delete[] (std::uint8_t*)model->hparams;
    delete model;
}

extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
    model->ctx = ctx;
}

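// Example lifecycle (a sketch, not part of the library): the caller owns both
// the model handle and the ggml context used for inference. How weights get
// into `model->tensors` is omitted here, since that depends on the loader.
//
//   fairseq2_model* model = fairseq2_model_alloc();
//   // ... read hyperparameters and tensors into `model` ...
//   ggml_init_params params = {/*mem_size*/ 256u * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false};
//   ggml_context* ctx = ggml_init(params);
//   fairseq2_model_set_inference_ctx(model, ctx);
//   // ... build and evaluate graphs with the *_forward functions below ...
//   ggml_free(ctx);
//   fairseq2_model_free(model);
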
extern "C" std::string* std_string_alloc(char* c_str) {
    return new std::string(c_str);
}

extern "C" void std_string_free(std::string* str) {
    delete str;
}

bool has_layer(fairseq2_model& model, const std::string& name) {
    return model.tensors.find(name) != model.tensors.end();
}

extern "C" ggml_tensor* Linear_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input // (d_in)
) {
    // Note: for now we assume un-batched input
    ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
    GGML_ASSERT(weight != nullptr);
    ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
    GGML_ASSERT(bias != nullptr);
    return ggml_add(
        model.ctx,
        ggml_mul_mat(model.ctx, weight, input), // (d_out)
        bias
    );
}

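// Minimal call sketch (the "foo" prefix and `d_in` are illustrative; the model
// must contain "foo.weight" and "foo.bias"). ggml_mul_mat contracts over the
// shared innermost ggml dimension, i.e. d_in:
//
//   ggml_tensor* x = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, d_in);
//   ggml_tensor* y = Linear_forward(model, "foo", x); // y: (d_out)
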
extern "C" ggml_tensor* LayerNorm_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input
) {
    ggml_tensor* weight = model.tensors[prefix + ".weight"];
    GGML_ASSERT(weight != nullptr);
    ggml_tensor* bias = model.tensors[prefix + ".bias"];
    GGML_ASSERT(bias != nullptr);
    auto ctx = model.ctx;
    // TODO: should `eps` be part of the unity hparams?
    input = ggml_norm(ctx, input, /*eps*/1e-5);
    return ggml_add(
        ctx,
        ggml_mul(ctx, ggml_repeat(ctx, weight, input), input),
        ggml_repeat(ctx, bias, input)
    );
}

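// For reference, this computes the usual layer norm with a learned affine
// transform: y = weight * (x - mean(x)) / sqrt(var(x) + eps) + bias.
// ggml_norm performs the normalization; the mul/add above apply weight and
// bias, broadcast over the sequence via ggml_repeat.
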
extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
) {
    seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
    // inner_activation = ReLU // TODO: allow other activations
    seqs = ggml_relu(model.ctx, seqs);
    if (has_layer(model, prefix + ".inner_layer_norm")) {
        seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
    }
    seqs = Linear_forward(model, prefix + ".output_proj", seqs);
    return seqs;
}

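// i.e. output_proj(ReLU(inner_proj(x))), with an optional layer norm between
// the activation and the output projection whenever the checkpoint provides an
// "<prefix>.inner_layer_norm" tensor group.
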
ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads) {
    int slen = x->ne[1];
    int model_dim = x->ne[0];
    // (S, dim) -> (S, H, H_dim)
    x = ggml_reshape_3d(ctx, x, model_dim / num_heads, num_heads, slen);
    // (S, H, H_dim) -> (H, S, H_dim)
    x = ggml_permute(ctx, x, 0, 2, 1, 3);
    return x;
}

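// Worked shape example (numbers are illustrative, using the logical
// outer-to-inner notation of the comments above): with model_dim = 1024,
// num_heads = 16 and a 7-token sequence, a (7, 1024) activation becomes
// (7, 16, 64) after the reshape and (16, 7, 64) after the permute, i.e. one
// 64-dimensional vector per head and position.
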
#define UNITY_FLASH_ATTN

extern "C" ggml_tensor* MultiheadAttention_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* queries, // (slen, d_in)
    ggml_tensor* keys,    // (klen, d_in)
    ggml_tensor* values,  // (klen, d_out)
    ggml_tensor* mask     // (klen, slen)
) {
    int slen = queries->ne[1];
    int slenk = keys->ne[1];
    int num_heads = 16; // TODO: read num_heads from hparams
    int head_dim = queries->ne[0] / num_heads;
    ggml_context* ctx = model.ctx;

    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
    q = reshape_num_head(ctx, q, num_heads); // (H, S, H_dim)
    ggml_set_name(q, "q");
    ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
    k = reshape_num_head(ctx, k, num_heads); // (H, Sk, H_dim)
    ggml_set_name(k, "k");
    ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
    v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slenk); // (Sk, H, H_dim)
    v = ggml_permute(ctx, v, 1, 2, 0, 3); // (H, H_dim, Sk)
    v = ggml_cont(ctx, v);
    ggml_set_name(v, "v");

#ifdef UNITY_FLASH_ATTN
    // For flash_attn, we assume either no masks, or triangular masks.
    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr); // (H, S, H_dim)
    ggml_set_name(attn, "attn");
    attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (S, H, H_dim)
    attn = ggml_cont(ctx, attn);
    attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen); // (S, H * H_dim)
#else
    // (H, Sk, H_dim) x (H, S, H_dim) -> (H, S, Sk)
    ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
    ggml_set_name(qk, "qk");
    ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
    ggml_set_f32(qk_scale, 1.0f / sqrtf(float(head_dim)));
    qk = ggml_scale(ctx, qk, qk_scale);
    ggml_set_name(qk, "qk_scaled");
    if (mask) qk = ggml_add(ctx, qk, mask);
    // TODO: upgrade qk to float32 if needed
    ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (H, S, Sk)
    ggml_set_name(attn_weights, "attn_weights");
    // (H, S, Sk) x (H, H_dim, Sk) -> (H, H_dim, S)
    ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
    ggml_set_name(attn, "attn");
    attn = ggml_reshape_2d(ctx, attn, slen, num_heads * head_dim); // (H * H_dim, S)
    attn = ggml_transpose(ctx, attn); // (S, H * H_dim)
    // ggml_transpose only creates a view, so make it contiguous before the output projection
    attn = ggml_cont(ctx, attn);
#endif // UNITY_FLASH_ATTN

    // out -> (S, d_out)
    ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
    ggml_set_name(out, "out");
    return out;
}

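// The non-flash branch above is plain scaled dot-product attention, computed
// per head: softmax(Q K^T / sqrt(head_dim) + mask) V; the flash branch fuses
// the same computation into a single ggml op. Minimal self-attention call
// sketch (the prefix is illustrative and depends on the checkpoint layout):
//
//   // seqs: (slen, model_dim); passing nullptr as mask gives bidirectional attention
//   ggml_tensor* out = MultiheadAttention_forward(
//       model, "text_encoder.layers.0.self_attn", seqs, seqs, seqs, /*mask*/nullptr);
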
extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
) {
    ggml_context* ctx = model.ctx;
    // TODO: read norm_order from model
    auto norm_order = TRANSFORMER_NORM_ORDER_PRE;

    // _forward_self_attn(seqs, padding_mask)
    auto residual = seqs;
    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
    // TODO: add padding_mask to MultiheadAttention_forward
    GGML_ASSERT(padding_mask == nullptr);
    seqs = MultiheadAttention_forward(
        model,
        prefix + ".self_attn",
        seqs,
        seqs,
        seqs,
        /*attention mask=*/nullptr
    );
    if (has_layer(model, prefix + ".self_attn_norm"))
        seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
    seqs = ggml_add(ctx, seqs, residual);
    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);

    // _forward_ffn(seqs)
    residual = seqs;
    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
    seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
    // TODO: if self.residual_scale is not None:
    //     residual = self.residual_scale * residual
    seqs = ggml_add(ctx, seqs, residual);
    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
    return seqs;
}

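// With the pre-norm order used here, an encoder layer therefore computes
//   x = x + SelfAttn(LayerNorm(x))
//   x = x + FFN(LayerNorm(x))
// whereas the post-norm order would apply each layer norm after its residual
// addition instead.
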
extern "C" ggml_tensor* StandardTransformerEncoder_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
) {
    int layer_idx = 0;
    std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
    while (has_layer(model, layer_name)) {
        seqs = StandardTransformerEncoderLayer_forward(
            model, layer_name, seqs, padding_mask
        );
        ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
        layer_idx += 1;
        layer_name = prefix + ".layers." + std::to_string(layer_idx);
    }
    if (has_layer(model, prefix + ".layer_norm"))
        seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
    return seqs;
}

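// Note that the *_forward functions only build the ggml graph; nothing is
// evaluated until the graph is computed. A rough usage sketch (the prefix
// string is illustrative, and the exact graph-building calls depend on the
// ggml version in use):
//
//   ggml_tensor* x = StandardTransformerEncoder_forward(
//       model, "text_encoder", embedded_seqs, /*padding_mask*/nullptr);
//   ggml_cgraph gf = ggml_build_forward(x);
//   ggml_graph_compute_with_ctx(model.ctx, &gf, /*n_threads*/4);
//   // x->data now holds the encoder output
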
extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* self_attn_mask,
    ggml_tensor* encoder_output,
    ggml_tensor* encoder_padding_mask
) {
    ggml_context* ctx = model.ctx;
    // TODO: read norm_order from model
    auto norm_order = TRANSFORMER_NORM_ORDER_PRE;

    // _forward_self_attn(seqs, padding_mask)
    auto residual = seqs;
    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
    seqs = MultiheadAttention_forward(
        model,
        prefix + ".self_attn",
        seqs,
        seqs,
        seqs,
        /*attention mask=*/self_attn_mask
    );
    if (has_layer(model, prefix + ".self_attn_norm"))
        seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
    seqs = ggml_add(ctx, seqs, residual);
    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);

    // _forward_encoder_decoder_attn
    if (!has_layer(model, prefix + ".encoder_decoder_attn")) {
        // `encoder_output` must be `None` for decoder-only attention.
        GGML_ASSERT(encoder_output == nullptr);
        return seqs;
    }
    // `encoder_output` must not be `None` for encoder-decoder attention.
    GGML_ASSERT(encoder_output != nullptr);
    residual = seqs;
    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
    seqs = MultiheadAttention_forward(
        model,
        prefix + ".encoder_decoder_attn",
        seqs,
        encoder_output,
        encoder_output,
        /*attention mask=*/encoder_padding_mask
    );
    seqs = ggml_add(ctx, seqs, residual);
    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);

    // _forward_ffn(seqs)
    residual = seqs;
    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
    seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
    // TODO: if self.residual_scale is not None:
    //     residual = self.residual_scale * residual
    seqs = ggml_add(ctx, seqs, residual);
    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
    return seqs;
}

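// In the pre-norm configuration a decoder layer therefore computes
//   x = x + SelfAttn(LayerNorm(x), self_attn_mask)
//   x = x + EncDecAttn(LayerNorm(x), encoder_output)   // skipped for decoder-only models
//   x = x + FFN(LayerNorm(x))
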
ggml_tensor* causal_mask_cache = nullptr;

extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
    auto seq_len = seqs->ne[0];
    auto mask = causal_mask_cache;
    // TODO: this cache only works as long as we don't change the size/device too often
    // TODO: allow other ggml_type
    if (mask == nullptr || mask->backend != seqs->backend || mask->ne[0] < seq_len) {
        printf("new causal_mask (%ld, %ld) created\n", seq_len, seq_len);
        mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
        char* data = (char*)mask->data;
        // Upper-triangular -inf mask, e.g. for seq_len = 4:
        // tensor([[0., -inf, -inf, -inf],
        //         [0.,   0., -inf, -inf],
        //         [0.,   0.,   0., -inf],
        //         [0.,   0.,   0.,   0.]])
        for (int i = 0; i < seq_len; ++i) {
            char* row = data + i * mask->nb[1];
            for (int j = 0; j <= i; ++j) { *(float*)(row + j * mask->nb[0]) = 0; }
            for (int j = i + 1; j < seq_len; ++j) { *(float*)(row + j * mask->nb[0]) = -INFINITY; }
        }
        causal_mask_cache = mask;
    }
    return ggml_view_2d(ctx, mask, seq_len, seq_len, mask->nb[1], 0);
}

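// The returned view is meant to be passed as `self_attn_mask` and added to the
// raw attention scores before the softmax, which removes any probability of
// attending to future positions. Note that when UNITY_FLASH_ATTN is defined,
// MultiheadAttention_forward only uses the presence of the mask (ggml_flash_attn
// takes a boolean `masked` flag) and ignores its contents.
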
extern "C" ggml_tensor* StandardTransformerDecoder_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask,
    ggml_tensor* encoder_output,
    ggml_tensor* encoder_padding_mask
) {
    int layer_idx = 0;
    std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
    ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
    while (has_layer(model, layer_name)) {
        seqs = StandardTransformerDecoderLayer_forward(
            model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
        );
        ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
        layer_idx += 1;
        layer_name = prefix + ".layers." + std::to_string(layer_idx);
    }
    if (has_layer(model, prefix + ".layer_norm"))
        seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
    return seqs;
}
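
// A rough end-to-end sketch of an encoder-decoder forward pass (prefixes are
// illustrative and depend on the checkpoint layout; embedding lookups and
// graph evaluation are left out):
//
//   ggml_tensor* enc_out = StandardTransformerEncoder_forward(
//       model, "text_encoder", src_embeds, /*padding_mask*/nullptr);
//   ggml_tensor* dec_out = StandardTransformerDecoder_forward(
//       model, "text_decoder", tgt_embeds, /*padding_mask*/nullptr,
//       enc_out, /*encoder_padding_mask*/nullptr);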