fairseq2.cpp 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942
  1. #include <math.h>
  2. #include "ggml.h"
  3. #include "fairseq2.h"
  4. #include <unordered_map>
  5. #include <algorithm>
  6. /// allocate the fairseq2 model and hyperparameters
  7. extern "C" fairseq2_model* fairseq2_model_alloc() {
  8. // pre-allocate some memory to write hyperparameters and tensors pointers
  9. auto* model = new fairseq2_model;
  10. model->hparams = new std::uint8_t[8 * 1024];
  11. model->arch = new std::uint64_t[16 * 1024]; // max tensors allowed
  12. model->tensors_ctx = nullptr;
  13. return model;
  14. };
  15. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  16. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  17. delete (std::uint64_t*)(model->arch);
  18. delete (std::uint8_t*)model->hparams;
  19. delete model;
  20. };
  21. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  22. model->ctx = ctx;
  23. }
  /// C-ABI helper: heap-allocate a `std::string` holding a copy of the
  /// NUL-terminated `c_str`. The caller releases it with `std_string_free`.
  extern "C" std::string* std_string_alloc(char* c_str) {
      std::string* owned = new std::string(c_str);
      return owned;
  }
  /// C-ABI helper: destroy a string obtained from `std_string_alloc`.
  /// A null pointer is accepted and ignored.
  extern "C" void std_string_free(std::string* str) {
      if (str != nullptr) {
          delete str;
      }
  }
  30. bool has_layer(fairseq2_model& model, const std::string& name) {
  31. return model.tensors.find(name) != model.tensors.end();
  32. }
  33. extern "C" ggml_tensor* Linear_forward(
  34. fairseq2_model& model,
  35. const std::string &prefix,
  36. ggml_tensor* input // (d_in)
  37. ) {
  38. // Note: for now we assumed un-batched input
  39. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  40. GGML_ASSERT(weight != nullptr);
  41. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  42. GGML_ASSERT(bias != nullptr);
  43. return ggml_add(
  44. model.ctx,
  45. ggml_mul_mat(model.ctx, weight, input), // (d_out)
  46. bias
  47. );
  48. }
  49. extern "C" ggml_tensor* LayerNorm_forward(
  50. fairseq2_model& model,
  51. const std::string &prefix,
  52. ggml_tensor* input
  53. ) {
  54. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  55. GGML_ASSERT(weight != nullptr);
  56. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  57. GGML_ASSERT(bias != nullptr);
  58. auto ctx = model.ctx;
  59. // TODO: should `eps` be part of unity hparams ?
  60. input = ggml_norm(ctx, input, /*eps*/1e-5);
  61. return ggml_add(
  62. ctx,
  63. ggml_mul(ctx, ggml_repeat(ctx, weight, input), input),
  64. ggml_repeat(ctx, bias, input)
  65. );
  66. }
  67. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  68. fairseq2_model& model,
  69. const std::string& prefix,
  70. ggml_tensor* seqs
  71. ) {
  72. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  73. // inner_activation = ReLu // TODO: allow other activation
  74. seqs = ggml_relu(model.ctx, seqs);
  75. if (has_layer(model, prefix + ".inner_layer_norm")) {
  76. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  77. }
  78. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  79. return seqs;
  80. }
  81. ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads) {
  82. int slen = x->ne[1];
  83. int model_dim = x->ne[0];
  84. // (S, dim) -> (S, H, H_dim)
  85. x = ggml_reshape_3d(ctx, x, model_dim / num_heads, num_heads, slen);
  86. // (S, H, H_dim) -> (H, S, H_dim)
  87. x = ggml_permute(ctx, x, 0, 2, 1, 3);
  88. return x;
  89. }
  // Compile-time switch: when defined, MultiheadAttention_forward calls
  // ggml_flash_attn; otherwise it builds the explicit
  // softmax(scale * Q K^T + mask) V graph in the #else branch.
  # define UNITY_FLASH_ATTN
  /// Multi-head attention block stored under `prefix`.
  ///
  /// Projects the inputs with `<prefix>.q_proj`, `.k_proj` and `.v_proj`,
  /// attends per head, merges the heads and applies `<prefix>.output_proj`.
  /// Shape comments are outermost-first: `(slen, d_in)` means ne[1] = slen.
  ///
  /// NOTE(review): `num_heads` is hard-coded to 16 rather than read from the
  /// model hyperparameters — confirm it matches every model loaded here.
  /// NOTE(review): with UNITY_FLASH_ATTN the mask *contents* are ignored; any
  /// non-null `mask` only switches on ggml_flash_attn's `masked` flag. A caller
  /// passing a padding mask (e.g. for encoder-decoder attention) gets that
  /// behaviour instead — verify callers.
  extern "C" ggml_tensor* MultiheadAttention_forward(
      fairseq2_model& model,
      const std::string &prefix,
      ggml_tensor* queries, // (slen, d_in)
      ggml_tensor* keys, // (klen, d_in)
      ggml_tensor* values, // (klen, d_out)
      ggml_tensor* mask // (klen, slen)
  ) {
      int slen = queries->ne[1];
      int slenk = keys->ne[1];
      int num_heads = 16;
      int head_dim = queries->ne[0] / num_heads;
      ggml_context* ctx = model.ctx;
      ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
      q = reshape_num_head(ctx, q, num_heads); // (H, S, H_dim)
      ggml_set_name(q, "q");
      ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
      k = reshape_num_head(ctx, k, num_heads); // (H, Sk, H_dim)
      ggml_set_name(k, "k");
      ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
      v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slenk); // (Sk, H, H_dim)
      // Unlike q/k, V is laid out as (H, H_dim, Sk) and made contiguous.
      v = ggml_permute(ctx, v, 1, 2, 0, 3); // (H, H_dim, Sk)
      v = ggml_cont(ctx, v);
      ggml_set_name(v, "v");
  #ifdef UNITY_FLASH_ATTN
      // For flash_attn, we assume either no masks, or triangular masks.
      ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr); // (H, S, H_dim)
      ggml_set_name(attn, "attn");
      // Merge heads: (H, S, H_dim) -> (S, H, H_dim) -> (S, H * H_dim).
      attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (S, H, H_dim)
      attn = ggml_cont(ctx, attn);
      attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen); // (S, H * H_dim)
  #else
      // (H, Sk, H_dim) x (H, S, H_dim) -> (H, S, Sk)
      ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
      ggml_set_name(qk, "qk");
      // Scale the attention logits by 1 / sqrt(head_dim).
      ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
      ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
      qk = ggml_scale(ctx, qk, qk_scale);
      ggml_set_name(qk, "qk_scaled");
      // Optional additive mask on the scaled logits.
      if (mask) qk = ggml_add(ctx, qk, mask);
      // TODO: upgrade qk to float32 if needed
      ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (H, S, Sk)
      ggml_set_name(attn_weights, "attn_weights");
      // (H, S, Sk) x (H, H_dim, Sk) -> (H, H_dim, S)
      ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
      ggml_set_name(attn, "attn");
      attn = ggml_reshape_2d(ctx, attn, slen, num_heads * head_dim); // (H * H_dim, S)
      attn = ggml_transpose(ctx, attn); // (S, H * H_dim)
      // ggml_transpose yields a strided view; presumably a contiguous copy is
      // needed before the output projection's matmul — TODO confirm.
      attn = ggml_cont(ctx, attn);
  #endif // UNITY_FLASH_ATTN
      // out -> (S, d_out)
      ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
      ggml_set_name(out, "out");
      return out;
  }
  147. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  148. fairseq2_model& model,
  149. const std::string& prefix,
  150. ggml_tensor* seqs,
  151. ggml_tensor* padding_mask
  152. ) {
  153. ggml_context* ctx = model.ctx;
  154. // TODO: read norm_order from model
  155. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  156. // _forward_self_attn(seqs, padding_mask)
  157. auto residual = seqs;
  158. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  159. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  160. // TODO: add padding_mask to MultiheadAttention_forward
  161. GGML_ASSERT(padding_mask == nullptr);
  162. seqs = MultiheadAttention_forward(
  163. model,
  164. prefix + ".self_attn",
  165. seqs,
  166. seqs,
  167. seqs,
  168. /*attention masks=*/nullptr
  169. );
  170. if (has_layer(model, prefix + ".self_attn_norm"))
  171. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  172. seqs = ggml_add(ctx, seqs, residual);
  173. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  174. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  175. // _forward_ffn(seqs)
  176. residual = seqs;
  177. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  178. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  179. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  180. // TODO: if self.residual_scale is not None:
  181. // residual = self.residual_scale * residual
  182. seqs = ggml_add(ctx, seqs, residual);
  183. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  184. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  185. return seqs;
  186. }
  187. struct ggml_tensor * ggml_slice(
  188. struct ggml_context * ctx,
  189. struct ggml_tensor * a,
  190. int axis,
  191. int64_t start,
  192. int64_t end
  193. ) {
  194. int64_t ne[4];
  195. std::copy(a->ne, a->ne + 4, ne);
  196. if (start < 0) start = ne[axis] + start;
  197. if (end < 0) end = ne[axis] + end;
  198. GGML_ASSERT(0 <= start);
  199. GGML_ASSERT(start <= end);
  200. GGML_ASSERT(end <= ne[axis]);
  201. ne[axis] = end - start;
  202. size_t offset = a->nb[axis] * start;
  203. size_t* nb = a->nb;
  204. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  205. result->n_dims = a->n_dims;
  206. return result;
  207. }
  208. extern "C" ggml_tensor* PositionalEmbedding_forward(
  209. fairseq2_model& model,
  210. const std::string& prefix,
  211. ggml_tensor* embeds
  212. ) {
  213. int encoding_dim = embeds->ne[0];
  214. int seq_len = embeds->ne[1];
  215. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  216. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, 0, seq_len);
  217. return ggml_add(model.ctx, embeds, pos_embeds);
  218. }
  219. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  220. fairseq2_model& model,
  221. const std::string& prefix,
  222. ggml_tensor* seqs
  223. // TODO: state_bag
  224. ) {
  225. ggml_context* ctx = model.ctx;
  226. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  227. GGML_ASSERT(embed_weights != nullptr);
  228. ggml_tensor* embeds = ggml_get_rows(ctx, embed_weights, seqs);
  229. // padding_mask = to_padding_mask(embeds, seq_lens)
  230. // TODO: scale when saving the model weights
  231. // embeds = ggml_scale embeds * self.scale
  232. if (has_layer(model, prefix + ".pos_encoder")) {
  233. // This only work with the simple pos encoders
  234. int encoding_dim = embeds->ne[0];
  235. int seq_len = embeds->ne[1];
  236. ggml_tensor* pos_embeds = ggml_view_2d(ctx, model.tensors[prefix + ".pos_encoder"], encoding_dim, seq_len, 0, 0);
  237. embeds = ggml_add(ctx, embeds, pos_embeds);
  238. }
  239. if (has_layer(model, prefix + ".layer_norm")) {
  240. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  241. }
  242. // padding mask ?
  243. return embeds;
  244. }
  245. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  246. fairseq2_model& model,
  247. const std::string& prefix,
  248. ggml_tensor* seqs,
  249. ggml_tensor* padding_mask
  250. ) {
  251. int layer_idx = 0;
  252. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  253. while (has_layer(model, layer_name)) {
  254. seqs = StandardTransformerEncoderLayer_forward(
  255. model, layer_name, seqs, padding_mask
  256. );
  257. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  258. layer_idx += 1;
  259. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  260. }
  261. if (has_layer(model, prefix + ".layer_norm"))
  262. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  263. return seqs;
  264. }
  265. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  266. fairseq2_model& model,
  267. const std::string& prefix,
  268. ggml_tensor* seqs,
  269. ggml_tensor* self_attn_mask,
  270. ggml_tensor* encoder_output,
  271. ggml_tensor* encoder_padding_mask
  272. ) {
  273. ggml_context* ctx = model.ctx;
  274. // TODO: read norm_order from model
  275. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  276. // _forward_self_attn(seqs, padding_mask)
  277. auto residual = seqs;
  278. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  279. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  280. seqs = MultiheadAttention_forward(
  281. model,
  282. prefix + ".self_attn",
  283. seqs,
  284. seqs,
  285. seqs,
  286. /*attention masks=*/self_attn_mask
  287. );
  288. if (has_layer(model, prefix + ".self_attn_norm"))
  289. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  290. seqs = ggml_add(ctx, seqs, residual);
  291. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  292. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  293. // _forward_encoder_decoder_attn
  294. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  295. // `encoder_output` must be `None` for decoder-only attention.
  296. GGML_ASSERT(encoder_output == nullptr);
  297. return seqs;
  298. }
  299. // `encoder_output` must not be `None` for encoder-decoder attention.
  300. GGML_ASSERT(encoder_output != nullptr);
  301. residual = seqs;
  302. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  303. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  304. seqs = MultiheadAttention_forward(
  305. model,
  306. prefix + ".encoder_decoder_attn",
  307. seqs,
  308. encoder_output,
  309. encoder_output,
  310. /*attention masks=*/encoder_padding_mask
  311. );
  312. seqs = ggml_add(ctx, seqs, residual);
  313. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  314. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  315. // _forward_ffn(seqs)
  316. residual = seqs;
  317. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  318. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  319. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  320. // TODO:
  321. // if self.residual_scale is not None:
  322. // residual = self.residual_scale * residual
  323. seqs = ggml_add(ctx, seqs, residual);
  324. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  325. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  326. return seqs;
  327. }
  328. ggml_tensor* causal_mask_cache = nullptr;
  329. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  330. auto seq_len = seqs->ne[0];
  331. auto mask = causal_mask_cache;
  332. // TODO: this cache only works as long as we don't change the size/device too often
  333. // TODO: allow other ggml_type
  334. if (mask == nullptr || mask->backend != seqs->backend || mask->ne[0] < seq_len) {
  335. printf("new causal_mask (%ld, %ld) created\n", seq_len, seq_len);
  336. mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  337. char* data = (char*)mask->data;
  338. // tensor([[0., -inf, -inf, -inf],
  339. // [0., 0., -inf, -inf],
  340. // [0., 0., 0., -inf],
  341. // [0., 0., 0., 0.]])
  342. for (int i = 0; i < seq_len; ++i) {
  343. char* row = data + i * mask->nb[1];
  344. for (int j = 0; j <= i; ++j) {*(float*)(row + j * mask->nb[0]) = 0;}
  345. for (int j = i + 1; j < seq_len; ++j) {*(float*)(row + j * mask->nb[0]) = -INFINITY;}
  346. }
  347. causal_mask_cache = mask;
  348. }
  349. return ggml_view_2d(ctx, mask, seq_len, seq_len, mask->nb[1], 0);
  350. }
  351. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  352. fairseq2_model& model,
  353. const std::string& prefix,
  354. ggml_tensor* seqs,
  355. ggml_tensor* padding_mask,
  356. ggml_tensor* encoder_output,
  357. ggml_tensor* encoder_padding_mask
  358. ) {
  359. int layer_idx = 0;
  360. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  361. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  362. while (has_layer(model, layer_name)) {
  363. seqs = StandardTransformerDecoderLayer_forward(
  364. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  365. );
  366. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  367. layer_idx += 1;
  368. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  369. }
  370. if (has_layer(model, prefix + ".layer_norm"))
  371. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  372. return seqs;
  373. }
  374. using IncrementalStateBag = std::unordered_map<ggml_tensor*, ggml_tensor*>*;
  375. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  376. auto opts = job.opts;
  377. int max_seq_len = -1;
  378. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  379. max_seq_len = opts.hard_max_seq_len;
  380. } else {
  381. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len + opts.soft_max_seq_len_b));
  382. }
  383. if (opts.min_seq_len > max_seq_len) {
  384. printf(
  385. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  386. opts.min_seq_len,
  387. max_seq_len
  388. );
  389. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  390. }
  391. int prefix_seq_len = job.prefix_seq->ne[0];
  392. if (prefix_seq_len >= max_seq_len) {
  393. printf(
  394. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  395. prefix_seq_len,
  396. max_seq_len
  397. );
  398. GGML_ASSERT(prefix_seq_len < max_seq_len);
  399. }
  400. return max_seq_len;
  401. }
  402. void _fan_out_encoder_output(
  403. ggml_context* ctx,
  404. ggml_tensor** encoder_output_out,
  405. ggml_tensor** encoder_padding_mask_out,
  406. int beam_size
  407. ) {
  408. // (S_enc, M)
  409. ggml_tensor* encoder_output = *encoder_output_out;
  410. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  411. // (B, S_enc, M)
  412. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  413. // (S_enc, M) -> (B, S_enc, M)
  414. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  415. // (S_enc) -> (B, S_enc)
  416. ggml_tensor* shape_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], beam_size);
  417. if (encoder_padding_mask != nullptr) {
  418. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  419. }
  420. }
  421. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  422. // TODO: this isn't the smartest way of doing this
  423. return ggml_log(ctx, ggml_soft_max(ctx, logits));
  424. }
  /// Seed `seqs` and `scores` with the job's forced prefix before beam search.
  ///
  /// Returns immediately when the prefix has a single token. Otherwise it
  /// builds a decoder graph, evaluates it with n_threads = 1, and writes
  /// running sums of prefix log-probabilities into `scores`. Per the reads
  /// below, `scores` has ne[0] = max_seq_len and ne[1] = beam_size.
  ///
  /// NOTE(review): work in progress — several parts look incorrect:
  ///  * the ggml_cpy result is never added to a computed graph, and its
  ///    destination view has ne0 = 0 (seqs[:, :prefix_seq_len] would need
  ///    ne0 = prefix_seq_len, ne1 = beam_size), so `seqs` is presumably not filled;
  ///  * `decoder_input` is built but the decoder is run on `seqs` instead;
  ///  * the ggml_view_3d of `logits` passes nb1 = nb2 = 0 — TODO confirm intent;
  ///  * the score loop looks up prefix_seq[i] in lprobs row i and writes up to
  ///    column prefix_seq_len; if row i predicts the *next* token, it should
  ///    presumably use prefix_seq[i + 1] for i < prefix_seq_len - 1.
  ///  `state_bag` is not used yet.
  void _bootstrap_seqs_and_scores(
      fairseq2_model& model,
      const SequenceGeneratorJob& job,
      ggml_tensor* seqs,
      ggml_tensor* scores,
      ggml_tensor* encoder_output,
      ggml_tensor* encoder_padding_mask,
      IncrementalStateBag state_bag
  ) {
      int prefix_seq_len = job.prefix_seq->ne[0];
      int max_seq_len = scores->ne[0];
      int beam_size = scores->ne[1];
      GGML_ASSERT(prefix_seq_len > 0);
      if (prefix_seq_len == 1)
          return;
      ggml_context* ctx = model.ctx;
      // seqs[:, : prefix_seq_len] = job.prefix_seq;
      ggml_cpy(ctx, job.prefix_seq, ggml_view_2d(ctx, seqs, 0, prefix_seq_len, seqs->nb[1], 0));
      // We have to bootstrap the model with the already fanned-out encoder
      // output to correctly initialize its incremental state. This causes some
      // redundancy as we have to expand `decoder_input` to match the shape of
      // `encoder_output`.
      // (S_pfx) -> (N x B, S_pfx - 1)
      // prefix_seq[:-1].expand(encoder_output.size(0), -1)
      ggml_tensor* decoder_input = ggml_repeat(ctx, ggml_view_1d(ctx, job.prefix_seq, prefix_seq_len - 1, 0), encoder_output);
      // Bootstrap the model state with prefix sequence.
      ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
          model,
          "text_decoder",
          seqs,
          /*padding_mask*/ nullptr,
          encoder_output,
          encoder_padding_mask
          // TODO: state_bag
      );
      // TODO state_bag.increment_step(prefix_seq_len - 1)
      // logits, lprobs: (N, S_pfx - 1, V)
      ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
      ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_view_3d(ctx, logits, logits->ne[0], logits->ne[1], 1, 0, 0, 0));
      int vocab_size = logits->ne[0];
      // Evaluate the graph now (single thread) so lprobs can be read below.
      ggml_cgraph gf = ggml_build_forward(lprobs);
      ggml_graph_compute_with_ctx(ctx, &gf, 1);
      // Fetch scores of next steps from "lprobs" as a running (cumulative) sum.
      float p_score = 0;
      for (int i = 0; i < prefix_seq_len; ++i) {
          int p = ggml_get_i32_1d(job.prefix_seq, i);
          p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
          for (int b = 0; b < beam_size; ++b) {
              // scores: (N, S)
              // Note: First step (e.g. BOS)'s score is always 0.
              ggml_set_f32_1d(scores, b * max_seq_len + i + 1, p_score);
          }
      }
  }
  479. /// Represents a hypothesis produced by a sequence generator.
  480. struct Hypothesis {
  481. /// The generated sequence.
  482. ggml_tensor* seq;
  483. /// The score of the hypothesis.
  484. float score;
  485. /// The score of each individual sequence step.
  486. ggml_tensor* step_scores;
  487. };
  /// One step of a standard beam search algorithm.
  ///
  /// Adds a view of `scores` (`last_scores`) to `lprobs` (using only the first
  /// beam at the start step, and adding nothing when that start step is
  /// step 0), evaluates that graph with n_threads = 1, then partially sorts
  /// `candidate_indices` in place so that its first `topk` entries are the
  /// selected candidates, and returns topk = min(2 * beam_size, vocab_size - 1).
  /// Shape comments are outermost-first: lprobs (N, S, V) has ne[0] = V.
  /// `candidate_indices` must hold at least beam_size * vocab_size I32 entries.
  ///
  /// NOTE(review): work in progress — several parts look incorrect:
  ///  * `last_scores` is a (1, beam_size) view with nb1 = 0 and a byte offset of
  ///    `step_nr` rather than `step_nr * scores->nb[0]` — TODO confirm intended view;
  ///  * the start-step ggml_view_3d of `lprobs` passes nb1 = nb2 = 0;
  ///  * the comparator ranks candidates by `scores`, not by the freshly computed
  ///    `lprobs` (which is never read after being computed), and with `<` the
  ///    partial sort puts the *lowest* values first, while the comment below
  ///    asks for the best predictions.
  int StandardBeamSearch_step(
      ggml_context* ctx,
      int step_nr,
      bool is_start_step,
      ggml_tensor* lprobs, // (N, S, V)
      ggml_tensor* scores, // (N, S)
      ggml_tensor* candidate_indices
  ) {
      int vocab_size = lprobs->ne[0];
      int sent_len = lprobs->ne[1];
      int beam_size = lprobs->ne[2];
      GGML_ASSERT(scores->ne[0] == sent_len);
      GGML_ASSERT(scores->ne[1] == beam_size);
      // should this be done by the caller ?
      ggml_tensor* last_scores = ggml_view_2d(ctx, scores, beam_size, 1, 0, step_nr);
      if (is_start_step) {
          // At the initial step, all hypotheses are equally likely, so we use
          // only the first beam.
          lprobs = ggml_view_3d(ctx, lprobs, vocab_size, sent_len, 1, 0, 0, 0);
          lprobs = ggml_cont(ctx, lprobs);
          // The first step always indicates the beginning of the sequence and
          // has no score.
          if (step_nr > 0) {
              lprobs = ggml_add(ctx, lprobs, last_scores);
          }
      } else {
          // Make probabilities contain cumulative scores for each hypothesis.
          lprobs = ggml_add(ctx, lprobs, last_scores);
      }
      ggml_cgraph gf = ggml_build_forward(lprobs);
      ggml_graph_compute_with_ctx(ctx, &gf, 1);
      // Take the best 2 x `beam_size` predictions. We'll choose the first
      // `beam_size` of these which don't predict EOS to continue with.
      // (N, 2 x B)
      // `vocab_size` - 1 to never select PAD.
      int topk = std::min(2 * beam_size, vocab_size - 1);
      auto comp = [scores](std::int32_t a, std::int32_t b) {
          return ggml_get_f32_1d(scores, a) < ggml_get_f32_1d(scores, b);
      };
      // Reorder the candidate ids in place; only the first `topk` are sorted.
      auto cand = (std::int32_t*)candidate_indices->data;
      std::partial_sort(cand, cand + topk, cand + (beam_size * vocab_size), comp);
      return topk;
  }
  /// If `candidate` ends its beam with EOS, record it as a finished hypothesis.
  ///
  /// `candidate` is decoded as beam * vocab_size + token. When `token` is
  /// `job.eos_idx` and its score is not -inf, this writes the score into
  /// `scores` at (beam, step_nr + 1), allocates an I32 `tokens` and an F32
  /// `step_scores` tensor of length step_nr + 2 in `ctx`, appends a Hypothesis
  /// to `hypotheses` and returns true. Otherwise it returns false and changes
  /// nothing.
  ///
  /// NOTE(review): work in progress — several parts look inconsistent:
  ///  * `vocab_size` is read from `scores->ne[0]`, which per the parameter
  ///    comment is seq_len, not the vocabulary size;
  ///  * `tok_score` is read from `scores` at flat index `candidate`, i.e. a
  ///    token-level index into a per-position score table;
  ///  * the per-step loop sets sc[i] = scores[i + 1] - scores[i] for
  ///    i in [0, step_nr], so sc[step_nr + 1] is left uninitialized and every
  ///    step score sits one slot before its token in `tok`.
  bool _finalize_hypothesis(
      const SequenceGeneratorJob& job,
      ggml_context* ctx,
      int step_nr,
      std::int32_t candidate,
      ggml_tensor* seqs, // (beam_size, seq_len)
      ggml_tensor* scores, // (beam_size, seq_len)
      std::vector<Hypothesis>& hypotheses
  ) {
      int vocab_size = scores->ne[0];
      std::int32_t beam = candidate / vocab_size;
      std::int32_t token = candidate % vocab_size;
      float tok_score = ggml_get_f32_1d(scores, candidate);
      // Only an EOS token with a finite score finishes a beam. (Any minimum-length
      // requirement is not checked here — presumably enforced elsewhere.)
      bool eos = token == job.eos_idx;
      eos &= tok_score != -INFINITY;
      // TODO ignored_beam_mask ?
      // eos &= ggml_get_i32_1d(ignored_beam_mask, beam);
      // ggml_set_i32_1d(eos_mask, beam, eos);
      if (!eos) return false;
      // If the candidate beam is "finished", let's copy the score and sequence
      ggml_tensor* tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
      ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
      auto tok = (std::int32_t*)tokens->data;
      auto sc = (float*)step_scores->data;
      ggml_set_f32_1d(scores, scores->ne[0] * beam + step_nr + 1, tok_score);
      // Copy the beam's first step_nr + 1 tokens, then append the EOS token.
      for (int i = 0; i < step_nr + 1; ++i) {
          tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
      }
      tok[step_nr + 1] = token;
      float last_score = tok_score;
      for (int i = step_nr; i >= 0; --i) {
          // Convert from cumulative to per-step scores.
          float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i + 0);
          sc[i] = last_score - sc0;
          last_score = sc0;
      }
      // Length normalization: divide by (step_nr + 1) ^ len_penalty.
      // Skip first EOS since it is always 0 and skews normalization.
      if (job.opts.normalize_scores)
          tok_score /= std::pow((step_nr + 1), job.opts.len_penalty);
      hypotheses.emplace_back(Hypothesis{tokens, tok_score, step_scores});
      return true;
  }
  575. /// Generates a translation for a single sequence
  576. // TODO: finish this for beam_size=1
  577. // * implement the lprobs tweaking
  578. // TODO: add IncrementalStateBag support to avoid a O(N^3) generation.
  579. // TODO: support beam_size > 1:
  580. // * most layers assume un-batched input, but we want to handle several beams at once
  581. // * need to port "reorder_state_dict"
  582. // * once beam are selected with topk, we need to update seqs and scores tensors
  583. extern "C" float generate_sequence(
  584. fairseq2_model& model,
  585. const SequenceGeneratorJob& job,
  586. ggml_tensor* encoder_output,
  587. ggml_tensor* encoder_padding_mask,
  588. ggml_tensor* output_seq
  589. ) {
  590. int vocab_size = encoder_output->ne[0];
  591. int beam_size = job.opts.beam_size;
  592. int source_seq_len = encoder_output->ne[1];
  593. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  594. ggml_context* ctx = model.ctx;
  595. // (S_enc, M) -> (B, S_enc, M)
  596. _fan_out_encoder_output(ctx, &encoder_output, &encoder_padding_mask, beam_size);
  597. std::vector<Hypothesis> finished_searches(beam_size);
  598. // Initialize buffers. (B, S)
  599. ggml_tensor* seqs = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  600. ggml_set_i32(seqs, 0);
  601. ggml_tensor* scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  602. ggml_set_f32(scores, 0.0);
  603. IncrementalStateBag state_bag = {};
  604. _bootstrap_seqs_and_scores(
  605. model, job, seqs, scores, encoder_output, encoder_padding_mask, state_bag
  606. );
  607. int prefix_seq_len = job.prefix_seq->ne[0];
  608. int start_step = prefix_seq_len - 1;
  609. // Holds the indices of beams (a beam can occur more than once) that we
  610. // should continue with in the next step.
  611. ggml_tensor* beam_indices = nullptr;
  612. // Indices of next token
  613. ggml_tensor* candidate_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, vocab_size * beam_size);
  614. for (int i = 0; i < vocab_size * beam_size; ++i) ggml_set_i32_1d(candidate_indices, i, i);
  615. // Holds the indices of searches that we should continue with in the next
  616. // step. If not `None`, it means we finalized one or more searches in the
  617. // last step.
  618. ggml_tensor* search_indices = nullptr;
  619. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  620. // if (beam_indices != nullptr) {
  621. // // If not `None`, it means in the last step we finalized one or
  622. // // more searches. We should ensure that we adjust `beam_indices`
  623. // // before reordering `decoder`'s incremental state.
  624. // if (search_indices != nullptr) {
  625. // num_searches = search_indices->ne[0];
  626. // // (N)
  627. // delta = search_indices - torch.arange(num_searches, device=device)
  628. // // (N) -> (N, 1)
  629. // delta.unsqueeze_(-1)
  630. // // Adjust indices to take into account removed searches.
  631. // beam_indices.view(num_searches, beam_size).add_(delta * beam_size)
  632. // }
  633. // // state_bag.reorder(beam_indices)
  634. // }
  635. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  636. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  637. model,
  638. "text_decoder",
  639. // seqs[:, step_nr : step_nr + 1]
  640. ggml_view_2d(ctx, seqs, 1, beam_size, step_nr * seqs->nb[0], 0),
  641. nullptr, // We never generate PAD.
  642. encoder_output,
  643. encoder_padding_mask
  644. // state_bag=state_bag,
  645. );
  646. // state_bag.increment_step()
  647. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  648. ggml_tensor* lprobs = ggml_log_softmax(ctx, logits);
  649. // // Do not allow EOS before reaching the minimum sequence length.
  650. // if step_nr < self.opts.min_seq_len:
  651. // lprobs[:, :, self.eos_idx] = -torch.inf
  652. // // If we have reached the maximum length, force the last step to be
  653. // // EOS.
  654. // if step_nr == max_seq_len - 2:
  655. // lprobs[:, :, : self.eos_idx] = -torch.inf
  656. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  657. // // Never allow PAD.
  658. // lprobs[:, :, self.pad_idx] = -torch.inf
  659. // // Apply UNK penalty.
  660. // if self.unk_idx is not None:
  661. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  662. // Determine candidates for the next step.
  663. // (N, 2 x B)
  664. int topk = StandardBeamSearch_step(
  665. ctx,
  666. step_nr,
  667. step_nr == start_step,
  668. lprobs,
  669. // TODO only pass scores for new tokens
  670. ggml_view_2d(ctx, scores, step_nr + 1, beam_size, 0, 0),
  671. candidate_indices
  672. );
  673. int ongoing_beams = 0;
  674. for (std::int32_t c = 0; c < topk; ++c) {
  675. bool finished = _finalize_hypothesis(job, ctx, step_nr, c, seqs, scores, finished_searches);
  676. if (!finished) ongoing_beams += 1;
  677. if (ongoing_beams >= beam_size) break;
  678. }
  679. if (finished_searches.size() == beam_size) break;
  680. // TODO: recreate scores and seqs with the best beams
  681. // Remove finished searches (ones for which `beam_size` finalized
  682. // beams have been generated) from the batch.
  683. ggml_tensor* search_indices = nullptr;
  684. // if (newly_finished_searches) {
  685. // new_num_searches = num_searches - len(newly_finished_searches)
  686. // // Construct `search_indices` which holds indices of searches
  687. // // to keep for the next step.
  688. // search_mask = torch.full((num_searches,), True, device=device)
  689. // search_mask[newly_finished_searches] = False
  690. // search_indices = torch.arange(num_searches, device=device)
  691. // search_indices = search_indices.masked_select(search_mask)
  692. // // Filter out removed batches from state variables.
  693. // // (N, B) -> (N - F, B)
  694. // ignored_beam_mask = ignored_beam_mask[search_indices]
  695. // // (N, 2 x B) -> (N - F, 2 x B)
  696. // cand_scores = cand_scores [search_indices]
  697. // cand_indices = cand_indices [search_indices]
  698. // cand_beam_indices = cand_beam_indices[search_indices]
  699. // // (N) -> (N - F)
  700. // search_offsets.resize_(new_num_searches, 1)
  701. // // (N - F, 2 x B) + (N - F) -> (N - F, 2 x B)
  702. // global_cand_beam_indices = cand_beam_indices + search_offsets
  703. // // (N, 2 x B) -> (N - F, 2 x B)
  704. // eos_mask = eos_mask[search_indices]
  705. // // (N x B, S) -> (N, B, S)
  706. // seqs = seqs .view(num_searches, -1)
  707. // scores = scores.view(num_searches, -1)
  708. // // (N, B, S + 1) -> ((N - F) x B, S)
  709. // seqs = seqs [search_indices].view(new_num_searches * beam_size, -1)
  710. // scores = scores[search_indices].view(new_num_searches * beam_size, -1)
  711. // // (N x B, S_enc, M) -> (N, B, S_enc, M)
  712. // encoder_output = encoder_output.unflatten(0, (num_searches, -1))
  713. // // (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
  714. // encoder_output = encoder_output[search_indices].flatten(0, 1)
  715. // if encoder_padding_mask is not None:
  716. // // (N x B, S_enc, M) -> (N, B, S_enc, M)
  717. // padding_mask = encoder_padding_mask.unflatten(0, (num_searches, -1))
  718. // // (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
  719. // encoder_padding_mask = padding_mask[search_indices].flatten(0, 1)
  720. // num_searches = new_num_searches
  721. // }
  722. // eos_mask[:, :beam_size][ignored_beam_mask] = True
  723. // // Set `beam_weights` so that values greater than or equal to 2 x
  724. // // `beam_size` indicate finished beams (i.e. end with EOS) and values
  725. // // less than 2 x `beam_size` indicate active beams.
  726. // // (N, 2 x B)
  727. // beam_weights = cand_offsets + (eos_mask * (2 * beam_size))
  728. // // Get the top `beam_size` active beams, which are the beams with the
  729. // // smallest weights in `active_beam_weights`.
  730. // // (N, B)
  731. // active_beam_weights, active_beams = torch.topk(
  732. // beam_weights, k=beam_size, dim=1, largest=False
  733. // )
  734. // // Update to ignore finalized beams in the next step.
  735. // // (N, B)
  736. // ignored_beam_mask = active_beam_weights >= 2 * beam_size
  737. // // We should always have at least one active beam in each search.
  738. // assert (~ignored_beam_mask).any(dim=1).all()
  739. // // Denotes which beams are continued for each new hypothesis (a beam
  740. // // can be selected more than once).
  741. // // (N, B)
  742. // beam_indices = torch.gather(
  743. // global_cand_beam_indices, dim=1, index=active_beams
  744. // )
  745. // // (N, B) -> (N x B)
  746. // beam_indices = beam_indices.view(-1)
  747. // // Reorder beams in the `seq` and `score` buffers. The same beam can
  748. // // be selected more than once.
  749. // if (step_nr > start_step) {
  750. // // seqs [:, : step_nr + 1] = torch.index_select(
  751. // // seqs [:, : step_nr + 1], dim=0, index=beam_indices
  752. // // )
  753. // // scores[:, : step_nr + 1] = torch.index_select(
  754. // // scores[:, : step_nr + 1], dim=0, index=beam_indices
  755. // // )
  756. // }
  757. // // (N x B, S) -> (N, B, S)
  758. // // seqs_view = seqs .view(num_searches, beam_size, -1)
  759. // // scores_view = scores.view(num_searches, beam_size, -1)
  760. // // seqs_view [:, :, step_nr + 1] = torch.gather(cand_indices, dim=1, index=active_beams)
  761. // // scores_view[:, :, step_nr + 1] = torch.gather(cand_scores, dim=1, index=active_beams)
  762. }
  763. // Ensure that hypotheses are sorted by their scores before returning.
  764. // for batch in finished_searches:
  765. // batch.sort(key=lambda b: b.score, reverse=True) # type: ignore[arg-type, return-value]
  766. // return SequenceGeneratorOutput(
  767. // results=finished_searches, device=device, pad_idx=self.pad_idx
  768. // )
  769. return 0.0f;
  770. }