fairseq2.cpp 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950
  1. #include <math.h>
  2. #include "ggml.h"
  3. #include "fairseq2.h"
  4. #include <unordered_map>
  5. #include <algorithm>
  6. /// allocate the fairseq2 model and hyperparameters
  7. extern "C" fairseq2_model* fairseq2_model_alloc() {
  8. // pre-allocate some memory to write hyperparameters and tensors pointers
  9. auto* model = new fairseq2_model;
  10. model->hparams = new std::uint8_t[8 * 1024];
  11. model->arch = new std::uint64_t[16 * 1024]; // max tensors allowed
  12. model->tensors_ctx = nullptr;
  13. return model;
  14. };
  15. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  16. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  17. delete (std::uint64_t*)(model->arch);
  18. delete (std::uint8_t*)model->hparams;
  19. delete model;
  20. };
  21. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  22. model->ctx = ctx;
  23. }
  /// Heap-allocates a std::string holding a copy of the NUL-terminated
  /// `c_str`. A null `c_str` yields an empty string instead of undefined
  /// behavior. Release the result with std_string_free.
  extern "C" std::string* std_string_alloc(char* c_str) {
      // std::string(nullptr) is undefined behavior; map it to "".
      return new std::string(c_str != nullptr ? c_str : "");
  }
  /// Destroys a string obtained from std_string_alloc. nullptr is accepted.
  extern "C" void std_string_free(std::string* str) {
      std::string* owned = str;
      delete owned;
  }
  30. bool has_layer(fairseq2_model& model, const std::string& name) {
  31. return model.tensors.find(name) != model.tensors.end();
  32. }
  33. extern "C" ggml_tensor* Linear_forward(
  34. fairseq2_model& model,
  35. const std::string &prefix,
  36. ggml_tensor* input // (d_in)
  37. ) {
  38. // Note: for now we assumed un-batched input
  39. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  40. GGML_ASSERT(weight != nullptr);
  41. ggml_tensor* out = ggml_mul_mat(model.ctx, weight, input); // (d_out)
  42. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  43. if (bias == nullptr) return out;
  44. return ggml_add_inplace(model.ctx, out, bias);
  45. }
  46. extern "C" ggml_tensor* LayerNorm_forward(
  47. fairseq2_model& model,
  48. const std::string &prefix,
  49. ggml_tensor* input
  50. ) {
  51. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  52. GGML_ASSERT(weight != nullptr);
  53. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  54. GGML_ASSERT(bias != nullptr);
  55. auto ctx = model.ctx;
  56. // TODO: should `eps` be part of unity hparams ?
  57. input = ggml_norm(ctx, input, /*eps*/1e-5);
  58. return ggml_add_inplace(
  59. ctx,
  60. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  61. ggml_repeat(ctx, bias, input)
  62. );
  63. }
  64. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  65. fairseq2_model& model,
  66. const std::string& prefix,
  67. ggml_tensor* seqs
  68. ) {
  69. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  70. // inner_activation = ReLu // TODO: allow other activation
  71. seqs = ggml_relu_inplace(model.ctx, seqs);
  72. if (has_layer(model, prefix + ".inner_layer_norm")) {
  73. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  74. }
  75. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  76. return seqs;
  77. }
  78. ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
  79. int n_dims = x->n_dims;
  80. GGML_ASSERT(dim >= 0);
  81. GGML_ASSERT(dim < n_dims);
  82. GGML_ASSERT(ggml_is_contiguous(x));
  83. // Nothing to do
  84. if (dim == n_dims - 1) return x;
  85. if (n_dims == 2) {
  86. return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
  87. } else if (n_dims == 3) {
  88. if (dim == 0) {
  89. return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
  90. } else { // dim == 1
  91. return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
  92. }
  93. } else { // n_dims == 4
  94. if (dim == 0) {
  95. return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
  96. } else if (dim == 1) {
  97. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
  98. } else { // dim == 2
  99. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
  100. }
  101. }
  102. }
  103. ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
  104. int n_dims = x->n_dims;
  105. GGML_ASSERT(dim >= 0);
  106. GGML_ASSERT(dim < n_dims);
  107. GGML_ASSERT(n_dims < 4);
  108. if (n_dims == 1) {
  109. return ggml_reshape_2d(ctx, x, num_el, x->ne[0] / num_el);
  110. } else if (n_dims == 2) {
  111. if (dim == 0) {
  112. return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
  113. } else { // dim == 1
  114. return ggml_reshape_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el);
  115. }
  116. } else { // (n_dims == 3)
  117. if (dim == 0) {
  118. return ggml_reshape_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2]);
  119. } else if (dim == 1) {
  120. return ggml_reshape_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2]);
  121. } else { // dim == 2
  122. return ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el);
  123. }
  124. }
  125. }
  126. ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
  127. // (B, S, dim) -> (B, S, H, H_dim)
  128. x = ggml_unflatten_1d(ctx, x, 0, head_dim);
  129. x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
  130. x = ggml_cont(ctx, x);
  131. x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
  132. return x;
  133. }
  134. /// (B, Sk, dim) -> // (B?, H, H_dim, Sk)
  135. ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim ) {
  136. // (B, Sk, dim) -> (B, Sk, H, H_dim)
  137. v = ggml_unflatten_1d(ctx, v, 0, head_dim);
  138. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
  139. v = ggml_cont(ctx, v);
  140. v = ggml_flatten_1d(ctx, v, 2); // (B * H, S, H_dim)
  141. return v;
  142. }
  143. // flash_attn doesn't work for cross attention because it assumes Q <= K
  144. // TODO: enable flash_attn only for the encoder
  145. # define UNITY_FLASH_ATTN 0
  146. extern "C" ggml_tensor* MultiheadAttention_forward(
  147. fairseq2_model& model,
  148. const std::string &prefix,
  149. ggml_tensor* queries, // (slen, d_in)
  150. ggml_tensor* keys, // (klen, d_in)
  151. ggml_tensor* values, // (klen, d_out)
  152. ggml_tensor* mask // (klen, slen)
  153. ) {
  154. int model_dim = queries->ne[0];
  155. int num_heads = 16; // TODO: read from hparams
  156. int head_dim = model_dim / num_heads;
  157. GGML_ASSERT(model_dim % num_heads == 0);
  158. ggml_context* ctx = model.ctx;
  159. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
  160. ggml_set_name(q, "q");
  161. q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
  162. ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
  163. ggml_set_name(k, "k");
  164. k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
  165. ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
  166. ggml_set_name(v, "v");
  167. v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
  168. v = ggml_cont(ctx, v);
  169. #if UNITY_FLASH_ATTN
  170. // For flash_attn, we assume either no masks, or triangular masks.
  171. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr); // (B * H, S, H_dim)
  172. ggml_set_name(attn, "attn");
  173. // TODO test !
  174. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  175. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  176. #else
  177. // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
  178. ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
  179. ggml_set_name(qk, "qk");
  180. ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
  181. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  182. qk = ggml_scale(ctx, qk, qk_scale);
  183. ggml_set_name(qk, "qk_scaled");
  184. // TODO: Should we replace this by ggml_diag_mask_inf ?
  185. // TODO: masks have the wrong shape to be added here
  186. if (mask) qk = ggml_add(ctx, qk, mask);
  187. // TODO: upgrade qk to float32 if needed
  188. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
  189. ggml_set_name(attn_weights, "attn_weights");
  190. // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
  191. ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
  192. ggml_set_name(attn, "attn");
  193. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  194. attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
  195. #endif // UNITY_FLASH_ATTN
  196. attn = ggml_cont(ctx, attn);
  197. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  198. // out -> (B, S, d_out)
  199. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  200. ggml_set_name(out, "out");
  201. return out;
  202. }
  203. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  204. fairseq2_model& model,
  205. const std::string& prefix,
  206. ggml_tensor* seqs,
  207. ggml_tensor* padding_mask
  208. ) {
  209. ggml_context* ctx = model.ctx;
  210. // TODO: read norm_order from model
  211. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  212. // _forward_self_attn(seqs, padding_mask)
  213. auto residual = seqs;
  214. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  215. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  216. // TODO: add padding_mask to MultiheadAttention_forward
  217. GGML_ASSERT(padding_mask == nullptr);
  218. seqs = MultiheadAttention_forward(
  219. model,
  220. prefix + ".self_attn",
  221. seqs,
  222. seqs,
  223. seqs,
  224. /*attention masks=*/nullptr
  225. );
  226. if (has_layer(model, prefix + ".self_attn_norm"))
  227. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  228. seqs = ggml_add(ctx, seqs, residual);
  229. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  230. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  231. // _forward_ffn(seqs)
  232. residual = seqs;
  233. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  234. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  235. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  236. // TODO: if self.residual_scale is not None:
  237. // residual = self.residual_scale * residual
  238. seqs = ggml_add(ctx, seqs, residual);
  239. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  240. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  241. return seqs;
  242. }
  243. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  244. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  245. struct ggml_tensor * ggml_slice(
  246. struct ggml_context * ctx,
  247. struct ggml_tensor * a,
  248. int axis,
  249. int64_t start,
  250. int64_t end
  251. ) {
  252. int64_t ne[4];
  253. std::copy(a->ne, a->ne + 4, ne);
  254. if (axis < 0) axis = a->n_dims + axis;
  255. if (start < 0) start = ne[axis] + start;
  256. if (end < 0) end = ne[axis] + end;
  257. GGML_ASSERT(0 <= start);
  258. GGML_ASSERT(start <= end);
  259. GGML_ASSERT(end <= ne[axis]);
  260. ne[axis] = end - start;
  261. size_t offset = a->nb[axis] * start;
  262. size_t* nb = a->nb;
  263. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  264. result->n_dims = a->n_dims;
  265. return result;
  266. }
  267. extern "C" ggml_tensor* PositionalEmbedding_forward(
  268. fairseq2_model& model,
  269. const std::string& prefix,
  270. ggml_tensor* embeds
  271. ) {
  272. // This only work with the simple pos encoders
  273. int seq_len = embeds->ne[1];
  274. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  275. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, 0, seq_len);
  276. return ggml_add(model.ctx, embeds, pos_embeds);
  277. }
  278. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  279. fairseq2_model& model,
  280. const std::string& prefix,
  281. ggml_tensor* seqs
  282. // TODO: state_bag
  283. ) {
  284. GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
  285. ggml_context* ctx = model.ctx;
  286. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  287. GGML_ASSERT(embed_weights != nullptr);
  288. ggml_tensor* embeds;
  289. if (seqs->n_dims == 1) {
  290. embeds = ggml_get_rows(ctx, embed_weights, seqs);
  291. } else {
  292. // ggml_get_rows isn't very flexible, we have to handle the reshape ourselves.
  293. ggml_tensor* flat_seqs = seqs;
  294. if (!ggml_is_contiguous(seqs)) {
  295. flat_seqs->type = GGML_TYPE_F32;
  296. flat_seqs = ggml_cont(ctx, flat_seqs);
  297. }
  298. flat_seqs = ggml_reshape_1d(ctx, flat_seqs, ggml_nelements(seqs));
  299. flat_seqs->type = GGML_TYPE_I32;
  300. embeds = ggml_get_rows(ctx, embed_weights, flat_seqs);
  301. embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  302. embeds->n_dims = seqs->n_dims + 1;
  303. }
  304. // padding mask ?
  305. // padding_mask = to_padding_mask(embeds, seq_lens)
  306. if (has_layer(model, prefix + ".pos_encoder")) {
  307. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  308. }
  309. if (has_layer(model, prefix + ".layer_norm")) {
  310. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  311. }
  312. return embeds;
  313. }
  314. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  315. fairseq2_model& model,
  316. const std::string& prefix,
  317. ggml_tensor* seqs,
  318. ggml_tensor* padding_mask
  319. ) {
  320. int layer_idx = 0;
  321. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  322. while (has_layer(model, layer_name)) {
  323. seqs = StandardTransformerEncoderLayer_forward(
  324. model, layer_name, seqs, padding_mask
  325. );
  326. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  327. layer_idx += 1;
  328. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  329. }
  330. if (has_layer(model, prefix + ".layer_norm"))
  331. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  332. return seqs;
  333. }
  334. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  335. fairseq2_model& model,
  336. const std::string& prefix,
  337. ggml_tensor* seqs,
  338. ggml_tensor* self_attn_mask,
  339. ggml_tensor* encoder_output,
  340. ggml_tensor* encoder_padding_mask
  341. ) {
  342. ggml_context* ctx = model.ctx;
  343. // TODO: read norm_order from model
  344. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  345. // _forward_self_attn(seqs, padding_mask)
  346. auto residual = seqs;
  347. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  348. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  349. seqs = MultiheadAttention_forward(
  350. model,
  351. prefix + ".self_attn",
  352. seqs,
  353. seqs,
  354. seqs,
  355. /*attention masks=*/self_attn_mask
  356. );
  357. if (has_layer(model, prefix + ".self_attn_norm"))
  358. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  359. seqs = ggml_add(ctx, seqs, residual);
  360. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  361. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  362. // _forward_encoder_decoder_attn
  363. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  364. // `encoder_output` must be `None` for decoder-only attention.
  365. GGML_ASSERT(encoder_output == nullptr);
  366. return seqs;
  367. }
  368. // `encoder_output` must not be `None` for encoder-decoder attention.
  369. GGML_ASSERT(encoder_output != nullptr);
  370. residual = seqs;
  371. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  372. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  373. seqs = MultiheadAttention_forward(
  374. model,
  375. prefix + ".encoder_decoder_attn",
  376. seqs,
  377. encoder_output,
  378. encoder_output,
  379. /*attention masks=*/encoder_padding_mask
  380. );
  381. seqs = ggml_add(ctx, seqs, residual);
  382. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  383. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  384. // _forward_ffn(seqs)
  385. residual = seqs;
  386. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  387. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  388. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  389. // TODO:
  390. // if self.residual_scale is not None:
  391. // residual = self.residual_scale * residual
  392. seqs = ggml_add(ctx, seqs, residual);
  393. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  394. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  395. return seqs;
  396. }
  397. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  398. auto seq_len = seqs->ne[1];
  399. // TODO: allow other ggml_type
  400. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  401. return ggml_diag_mask_inf(ctx, mask, 0);
  402. }
  403. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  404. fairseq2_model& model,
  405. const std::string& prefix,
  406. ggml_tensor* seqs,
  407. ggml_tensor* padding_mask,
  408. ggml_tensor* encoder_output,
  409. ggml_tensor* encoder_padding_mask
  410. ) {
  411. int layer_idx = 0;
  412. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  413. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  414. while (has_layer(model, layer_name)) {
  415. seqs = StandardTransformerDecoderLayer_forward(
  416. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  417. );
  418. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  419. layer_idx += 1;
  420. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  421. }
  422. if (has_layer(model, prefix + ".layer_norm"))
  423. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  424. return seqs;
  425. }
  426. using IncrementalStateBag = std::unordered_map<ggml_tensor*, ggml_tensor*>*;
  427. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  428. auto opts = job.opts;
  429. int max_seq_len = -1;
  430. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  431. max_seq_len = opts.hard_max_seq_len;
  432. } else {
  433. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len) + opts.soft_max_seq_len_b);
  434. }
  435. if (opts.min_seq_len > max_seq_len) {
  436. printf(
  437. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  438. opts.min_seq_len,
  439. max_seq_len
  440. );
  441. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  442. }
  443. int prefix_seq_len = job.prefix_seq->ne[0];
  444. if (prefix_seq_len >= max_seq_len) {
  445. printf(
  446. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  447. prefix_seq_len,
  448. max_seq_len
  449. );
  450. GGML_ASSERT(prefix_seq_len < max_seq_len);
  451. }
  452. return max_seq_len;
  453. }
  454. void _fan_out_encoder_output(
  455. ggml_context* ctx,
  456. ggml_tensor** encoder_output_out,
  457. ggml_tensor** encoder_padding_mask_out,
  458. int beam_size
  459. ) {
  460. // (S_enc, M)
  461. ggml_tensor* encoder_output = *encoder_output_out;
  462. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  463. // (B, S_enc, M)
  464. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  465. // (S_enc, M) -> (B, S_enc, M)
  466. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  467. // (S_enc) -> (B, S_enc)
  468. if (encoder_padding_mask != nullptr) {
  469. ggml_tensor* shape_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], 1, beam_size);
  470. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  471. }
  472. }
  473. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  474. // TODO: this isn't the most precise way of doing this
  475. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  476. }
  477. ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
  478. ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
  479. ggml_type true_type = x->type;
  480. x->type = GGML_TYPE_F32;
  481. ggml_tensor* y = ggml_repeat(ctx, x, shape);
  482. y->type = true_type;
  483. return y;
  484. }
  485. void _bootstrap_seqs_and_scores(
  486. fairseq2_model& model,
  487. const SequenceGeneratorJob& job,
  488. ggml_tensor* full_seqs,
  489. ggml_tensor* scores,
  490. ggml_tensor* encoder_output,
  491. ggml_tensor* encoder_padding_mask,
  492. IncrementalStateBag state_bag
  493. ) {
  494. int prefix_seq_len = job.prefix_seq->ne[0];
  495. int max_seq_len = scores->ne[0];
  496. int beam_size = scores->ne[1];
  497. GGML_ASSERT(prefix_seq_len > 0);
  498. if (prefix_seq_len == 1)
  499. return;
  500. ggml_context* ctx = model.ctx;
  501. // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  502. full_seqs->type = GGML_TYPE_F32;
  503. job.prefix_seq->type = GGML_TYPE_F32;
  504. ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
  505. seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);
  506. // We have to bootstrap the model with the already fanned-out encoder
  507. // output to correctly initialize its incremental state.
  508. // Note: we don't start decoding the last prefix token just yet.
  509. seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);
  510. seqs->type = GGML_TYPE_I32;
  511. // Bootstrap the model state with prefix sequence.
  512. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  513. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  514. model,
  515. "text_decoder",
  516. seqs,
  517. /*padding_mask*/ nullptr,
  518. encoder_output,
  519. // we assume there is only one input, and therefore we don't need padding.
  520. /*encoder_padding_mask*/ nullptr
  521. // TODO: state_bag
  522. );
  523. // TODO state_bag.increment_step(prefix_seq_len - 1)
  524. // logits, lprobs: (N, S_pfx - 1, V)
  525. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  526. int vocab_size = logits->ne[0];
  527. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
  528. ggml_cgraph gf = ggml_build_forward(lprobs);
  529. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  530. full_seqs->type = GGML_TYPE_I32;
  531. job.prefix_seq->type = GGML_TYPE_I32;
  532. // Fetch scores of next steps from "lprobs"
  533. float p_score = 0;
  534. for (int i = 1; i < prefix_seq_len; ++i) {
  535. int p = ggml_get_i32_1d(job.prefix_seq, i);
  536. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  537. for (int b = 0; b < beam_size; ++b) {
  538. // scores: (N, S)
  539. // Note: First step (e.g. BOS)'s score is always 0.
  540. ggml_set_f32_1d(scores, b * max_seq_len + i, p_score);
  541. }
  542. }
  543. }
  544. /// Finds the topk indices, and write the winning indices in "candidate_indices" array.
  545. int topk(
  546. ggml_tensor* lprobs, // (B, V)
  547. std::int64_t k,
  548. ggml_tensor* candidate_indices
  549. ) {
  550. // Take the best 2 x `beam_size` predictions. We'll choose the first
  551. // `beam_size` of these which don't predict EOS to continue with.
  552. // (N, 2 x B)
  553. // `vocab_size` - 1 to never select PAD.
  554. std::int64_t K = std::min(k, ggml_nelements(lprobs));
  555. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  556. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  557. };
  558. GGML_ASSERT(ggml_nelements(candidate_indices) >= k);
  559. auto cand = (std::int32_t*)candidate_indices->data;
  560. std::partial_sort(cand, cand + K, cand + ggml_nelements(lprobs), comp);
  561. return K;
  562. }
  563. void ggml_detach(ggml_tensor* a) {
  564. a->op = GGML_OP_NONE;
  565. a->src[0] = nullptr;
  566. }
  567. /// Copies the sequence and scores of a given candidate beam.
  568. void _finalize_hypothesis(
  569. const SequenceGeneratorJob& job,
  570. ggml_context* ctx,
  571. int step_nr,
  572. std::int32_t beam,
  573. std::int32_t token,
  574. float eos_score,
  575. ggml_tensor* seqs, // (beam_size, seq_len)
  576. ggml_tensor* scores, // (beam_size, seq_len)
  577. std::vector<Hypothesis>& hypotheses
  578. ) {
  579. ggml_tensor* tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  580. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  581. auto tok = (std::int32_t*)tokens->data;
  582. for (int i = 0; i < step_nr + 1; ++i) {
  583. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  584. }
  585. tok[step_nr + 1] = token;
  586. // Convert from cumulative to per-step scores.
  587. auto sc = (float*)step_scores->data;
  588. float last_score = eos_score;
  589. for (int i = step_nr; i >= 0; --i) {
  590. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  591. sc[i + 1] = last_score - sc0;
  592. last_score = sc0;
  593. }
  594. sc[0] = 0;
  595. if (job.opts.normalize_scores)
  596. // Skip first EOS since it is always 0 and skews normalization.
  597. eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  598. hypotheses.emplace_back(Hypothesis{tokens, eos_score, step_scores});
  599. }
  // Uses ggml_context to store any object: creates a GGML_TYPE_I8 tensor of
  // sizeof(Type) * n bytes in `ctx` and yields its data pointer as Type*.
  // No constructor is run on the memory.
  // Arguments are parenthesized and there is no trailing `;`, so the macro is
  // a plain expression (safe in if/else bodies and compound expressions).
  #define GGML_CTX_ALLOC(ctx, Type, n) \
      ((Type*)(ggml_new_tensor_1d((ctx), GGML_TYPE_I8, sizeof(Type) * (n))->data))
  603. /// Generates a translation for a single sequence
  604. // TODO: add IncrementalStateBag support to avoid a O(N^3) generation.
  605. // TODO: clean ups
  606. // * replace manual tensor tweaking with ggml_set_*d (a ggml_set_slice could be useful)
  607. extern "C" Hypothesis* generate_sequence(
  608. fairseq2_model& model,
  609. const SequenceGeneratorJob& job,
  610. ggml_tensor* encoder_output,
  611. ggml_tensor* encoder_padding_mask,
  612. ggml_context* result_ctx
  613. ) {
  614. ggml_context* ctx = model.ctx;
  615. size_t eos_idx = job.eos_idx;
  616. auto pad_idx = job.pad_idx;
  617. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  618. size_t vocab_size = embed->ne[1];
  619. std::size_t beam_size = job.opts.beam_size;
  620. int source_seq_len = encoder_output->ne[1];
  621. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  622. // (S_enc, M) -> (B, S_enc, M)
  623. _fan_out_encoder_output(ctx, &encoder_output, &encoder_padding_mask, beam_size);
  624. std::vector<Hypothesis> finished_searches;
  625. finished_searches.reserve(beam_size);
  626. // Initialize buffers. (B, S)
  627. ggml_tensor* seqs = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  628. ggml_set_i32(seqs, 0);
  629. ggml_set_name(seqs, "seqs_0");
  630. ggml_tensor* scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  631. ggml_set_name(scores, "scores_0");
  632. ggml_set_f32(scores, 0.0);
  633. IncrementalStateBag state_bag = {};
  634. _bootstrap_seqs_and_scores(
  635. model, job, seqs, scores, encoder_output, encoder_padding_mask, state_bag
  636. );
  637. int prefix_seq_len = job.prefix_seq->ne[0];
  638. int start_step = prefix_seq_len - 1;
  639. // Holds the indices of beams (a beam can occur more than once) that we
  640. // should continue with in the next step.
  641. ggml_tensor* beam_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
  642. ggml_tensor* next_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
  643. ggml_tensor* next_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, beam_size);
  644. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  645. ggml_tensor* candidate_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, vocab_size * beam_size);
  646. for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
  647. ((int32_t *)(candidate_indices->data))[i] = i;
  648. // TODO: memory management
  649. // there should be a per-step ggml_context for intermediary results
  650. // start of beam search:
  651. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  652. // because of no IncrementalStateBag we pass input from the start
  653. // decoder_input = seqs[:, 0 : step_nr + 1]
  654. ggml_tensor* decoder_input = ggml_slice(ctx, seqs, 0, 0, step_nr + 1);
  655. decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", decoder_input);
  656. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  657. model,
  658. "text_decoder",
  659. decoder_input,
  660. nullptr, // We never generate PAD.
  661. encoder_output,
  662. /*encoder_padding_mask*/ nullptr // TODO: do we need padding for encoder ?
  663. // state_bag=state_bag,
  664. ); // (B, S, D)
  665. // state_bag.increment_step()
  666. // Because of no IncrementalStateBag decoder_output here is of shape (B, S, D)
  667. // Just look at the last token.
  668. decoder_output = ggml_slice(ctx, decoder_output, 1, step_nr, step_nr+1);
  669. decoder_output = ggml_cont(ctx, decoder_output);
  670. decoder_output = ggml_flatten_1d(ctx, decoder_output, 0); // (B, model_dim)
  671. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output); // (B, vocab_size)
  672. ggml_tensor* lprobs = ggml_log_softmax(ctx, logits);
  673. // Compute lprobs here so we can modify it in place in the lprob tweaking phase
  674. // TODO: use ggml properly compute the tweaks
  675. ggml_cgraph gf = ggml_build_forward(lprobs);
  676. printf("beam search step %d. Graph.n_nodes: %d\n", step_nr, gf.n_nodes);
  677. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  678. ggml_detach(lprobs);
  679. // // Do not allow EOS before reaching the minimum sequence length.
  680. if (step_nr < job.opts.min_seq_len) {
  681. // lprobs[:, :, self.eos_idx] = -INFINITY;
  682. for (size_t i = 0; i < beam_size; ++i)
  683. ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
  684. }
  685. // If we have reached the maximum length, force the last step to be EOS.
  686. if (step_nr == max_seq_len - 2) {
  687. // lprobs[:, :, : self.eos_idx] = -torch.inf
  688. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  689. for (size_t b = 0; b < beam_size; ++b) {
  690. size_t t = 0;
  691. for (t = 0; t < eos_idx; ++t)
  692. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  693. for (t = eos_idx + 1; t < vocab_size; ++t)
  694. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  695. }
  696. }
  697. // Never allow PAD.
  698. for (size_t i = 0; i < beam_size; ++i)
  699. ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);
  700. // Apply UNK penalty.
  701. if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
  702. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  703. auto lprobs_raw = ggml_get_data_f32(lprobs);
  704. for (size_t i = 0; i < beam_size; ++i)
  705. lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
  706. }
  707. ggml_tensor* last_scores = ggml_slice(ctx, scores, 0, step_nr, step_nr+1);
  708. if (step_nr == start_step) {
  709. // At the initial step, all hypotheses are equally likely, so we use
  710. // only the first beam.
  711. lprobs = ggml_slice(ctx, lprobs, 1, 0, 1);
  712. lprobs = ggml_cont(ctx, lprobs);
  713. // The first step always indicates the beginning of the sequence and has no score.
  714. if (step_nr > 0) {
  715. last_scores = ggml_slice(ctx, last_scores, 1, 0, 1);
  716. lprobs = ggml_add_inplace(ctx, lprobs, ggml_repeat(ctx, last_scores, lprobs));
  717. }
  718. } else {
  719. // Make probabilities contain cumulative scores for each hypothesis.
  720. lprobs = ggml_add(ctx, lprobs, ggml_repeat(ctx, last_scores, lprobs));
  721. }
  722. gf = ggml_build_forward(lprobs);
  723. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  724. // Determine (beam, token) candidates for the next step.
  725. // (N, 2 x B)
  726. std::int64_t K = topk(
  727. lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
  728. );
  729. std::size_t ongoing_beams = 0;
  730. for (std::int32_t i = 0; i < K; ++i) {
  731. int c = ggml_get_f32_1d(candidate_indices, i);
  732. std::int32_t beam = c / vocab_size;
  733. std::int32_t token = c % vocab_size;
  734. float tok_score = ggml_get_f32_1d(lprobs, c);
  735. // Detect beams that reached the minimum length and that end with an EOS.
  736. bool eos = token == job.eos_idx;
  737. eos &= tok_score != -INFINITY;
  738. if (eos) {
  739. _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, finished_searches);
  740. if (finished_searches.size() >= beam_size)
  741. goto end_of_beam_search;
  742. continue;
  743. }
  744. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  745. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  746. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  747. ongoing_beams += 1;
  748. if (ongoing_beams >= beam_size) break;
  749. }
  750. // Reorder beams in the `seq` and `score` buffers. The same beam can
  751. // be selected more than once.
  752. ggml_tensor* new_seqs = seqs;
  753. ggml_tensor* new_scores = scores;
  754. if (step_nr > start_step) {
  755. // (B, S), (B) -> (B, S)
  756. // ggml_get_rows and ggml_set only work with floats ...
  757. new_seqs->type = GGML_TYPE_F32;
  758. new_seqs = ggml_get_rows(ctx, seqs, beam_indices);
  759. new_scores = ggml_get_rows(ctx, scores, beam_indices);
  760. gf = ggml_build_forward(new_seqs);
  761. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  762. ggml_detach(new_seqs);
  763. new_seqs->type = GGML_TYPE_I32;
  764. gf = ggml_build_forward(new_scores);
  765. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  766. ggml_detach(new_scores);
  767. }
  768. // new_seqs[:, step_nr + 1] = next_tokens
  769. // new_scores[:, step_nr + 1] = next_scores
  770. for (std::size_t i = 0; i < beam_size; ++i) {
  771. ((std::int32_t*)new_seqs->data)[step_nr + 1 + i * max_seq_len] = ggml_get_i32_1d(next_tokens, i);
  772. ((float*)new_scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
  773. }
  774. // TODO the old seqs and score buffers could be reused for next step
  775. seqs = new_seqs;
  776. scores = new_scores;
  777. }
  778. end_of_beam_search:
  779. // Ensure that hypotheses are sorted by decreasing scores before returning.
  780. std::sort(
  781. finished_searches.begin(),
  782. finished_searches.end(),
  783. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  784. );
  785. // Copy the scores to an object in the result_ctx.
  786. GGML_ASSERT(finished_searches.size() <= beam_size);
  787. Hypothesis* result = GGML_CTX_ALLOC(result_ctx, struct Hypothesis, beam_size);
  788. std::copy(finished_searches.begin(), finished_searches.end(), result);
  789. // In case we have less searches than expected, still make sure to initialize the memory.
  790. for (std::size_t i = finished_searches.size(); i < beam_size; ++i)
  791. result[i] = Hypothesis{nullptr, -INFINITY, nullptr};
  792. return result;
  793. }
  794. extern "C" Hypothesis* _testing_return_hypothesis_ptr(ggml_context* ctx) {
  795. Hypothesis* result = GGML_CTX_ALLOC(ctx, struct Hypothesis, 2);
  796. result[0] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 3.14f, (ggml_tensor*)result};
  797. ggml_set_i32_1d(result[0].seq, 0, 314);
  798. result[1] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 4.21f, nullptr};
  799. ggml_set_i32_1d(result[1].seq, 0, 421);
  800. return result;
  801. }