fairseq2.cpp 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875
  #include <math.h>
  #include "ggml.h"
  #include "fairseq2.h"
  #include <algorithm>
  #include <cstdint>
  #include <cstdio>
  #include <numeric>
  #include <string>
  #include <unordered_map>
  #include <vector>
  6. /// allocate the fairseq2 model and hyperparameters
  7. extern "C" fairseq2_model* fairseq2_model_alloc() {
  8. // pre-allocate some memory to write hyperparameters and tensors pointers
  9. auto* model = new fairseq2_model;
  10. model->hparams = new std::uint8_t[8 * 1024];
  11. model->arch = new std::uint64_t[16 * 1024]; // max tensors allowed
  12. model->tensors_ctx = nullptr;
  13. return model;
  14. };
  15. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  16. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  17. delete (std::uint64_t*)(model->arch);
  18. delete (std::uint8_t*)model->hparams;
  19. delete model;
  20. };
  21. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  22. model->ctx = ctx;
  23. }
  /// Heap-allocate a std::string holding a copy of the NUL-terminated `c_str`.
  /// The caller owns the result and releases it with std_string_free().
  extern "C" std::string* std_string_alloc(char* c_str) {
    auto* str = new std::string(c_str);
    return str;
  }
  /// Destroy a string obtained from std_string_alloc(). A null pointer is
  /// accepted (deleting nullptr does nothing).
  extern "C" void std_string_free(std::string* str) {
    delete str;
  }
  30. bool has_layer(fairseq2_model& model, const std::string& name) {
  31. return model.tensors.find(name) != model.tensors.end();
  32. }
  33. extern "C" ggml_tensor* Linear_forward(
  34. fairseq2_model& model,
  35. const std::string &prefix,
  36. ggml_tensor* input // (d_in)
  37. ) {
  38. // Note: for now we assumed un-batched input
  39. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  40. GGML_ASSERT(weight != nullptr);
  41. ggml_tensor* out = ggml_mul_mat(model.ctx, weight, input); // (d_out)
  42. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  43. if (bias == nullptr) return out;
  44. return ggml_add_inplace(model.ctx, out, bias);
  45. }
  46. extern "C" ggml_tensor* LayerNorm_forward(
  47. fairseq2_model& model,
  48. const std::string &prefix,
  49. ggml_tensor* input
  50. ) {
  51. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  52. GGML_ASSERT(weight != nullptr);
  53. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  54. GGML_ASSERT(bias != nullptr);
  55. auto ctx = model.ctx;
  56. // TODO: should `eps` be part of unity hparams ?
  57. input = ggml_norm(ctx, input, /*eps*/1e-5);
  58. return ggml_add_inplace(
  59. ctx,
  60. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  61. ggml_repeat(ctx, bias, input)
  62. );
  63. }
  64. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  65. fairseq2_model& model,
  66. const std::string& prefix,
  67. ggml_tensor* seqs
  68. ) {
  69. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  70. // inner_activation = ReLu // TODO: allow other activation
  71. seqs = ggml_relu_inplace(model.ctx, seqs);
  72. if (has_layer(model, prefix + ".inner_layer_norm")) {
  73. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  74. }
  75. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  76. return seqs;
  77. }
  78. ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads) {
  79. int slen = x->ne[1];
  80. int model_dim = x->ne[0];
  81. // (S, dim) -> (S, H, H_dim)
  82. x = ggml_reshape_3d(ctx, x, model_dim / num_heads, num_heads, slen);
  83. // (S, H, H_dim) -> (H, S, H_dim)
  84. x = ggml_permute(ctx, x, 0, 2, 1, 3);
  85. return x;
  86. }
  87. // TODO: flash_attn doesn't seem to work for cross attention because it assumes Q <= K
  88. # define UNITY_FLASH_ATTN 0
  89. extern "C" ggml_tensor* MultiheadAttention_forward(
  90. fairseq2_model& model,
  91. const std::string &prefix,
  92. ggml_tensor* queries, // (slen, d_in)
  93. ggml_tensor* keys, // (klen, d_in)
  94. ggml_tensor* values, // (klen, d_out)
  95. ggml_tensor* mask // (klen, slen)
  96. ) {
  97. int slen = queries->ne[1];
  98. int slenk = keys->ne[1];
  99. int num_heads = 16;
  100. int head_dim = queries->ne[0] / num_heads;
  101. ggml_context* ctx = model.ctx;
  102. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
  103. q = reshape_num_head(ctx, q, num_heads); // (H, S, H_dim)
  104. ggml_set_name(q, "q");
  105. ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
  106. k = reshape_num_head(ctx, k, num_heads); // (H, Sk, H_dim)
  107. ggml_set_name(k, "k");
  108. ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
  109. v = ggml_reshape_3d(ctx, v, head_dim, num_heads, slenk); // (Sk, H, H_dim)
  110. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (H, H_dim, Sk)
  111. v = ggml_cont(ctx, v);
  112. ggml_set_name(v, "v");
  113. #if UNITY_FLASH_ATTN
  114. // For flash_attn, we assume either no masks, or triangular masks.
  115. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr); // (H, S, H_dim)
  116. ggml_set_name(attn, "attn");
  117. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (S, H, H_dim)
  118. attn = ggml_cont(ctx, attn);
  119. attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen); // (S, H * H_dim)
  120. #else
  121. // (H, Sk, H_dim) x (H, S, H_dim) -> (H, S, Sk)
  122. ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
  123. ggml_set_name(qk, "qk");
  124. ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
  125. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  126. qk = ggml_scale(ctx, qk, qk_scale);
  127. ggml_set_name(qk, "qk_scaled");
  128. if (mask) qk = ggml_add(ctx, qk, mask);
  129. // TODO: upgrade qk to float32 if needed
  130. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (H, Sk, S)
  131. ggml_set_name(attn_weights, "attn_weights");
  132. // (H, S, Sk) x (H, H_dim, Sk) -> (H, H_dim, S)
  133. ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
  134. ggml_set_name(attn, "attn");
  135. attn = ggml_reshape_2d(ctx, attn, slen, num_heads * head_dim); // (H * H_dim, S)
  136. attn = ggml_transpose(ctx, attn); // (S, H * H_dim)
  137. // // I'm not sure why this one is needed ...
  138. attn = ggml_cont(ctx, attn);
  139. #endif // UNITY_FLASH_ATTN
  140. // out -> (S, d_out)
  141. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  142. ggml_set_name(out, "out");
  143. return out;
  144. }
  145. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  146. fairseq2_model& model,
  147. const std::string& prefix,
  148. ggml_tensor* seqs,
  149. ggml_tensor* padding_mask
  150. ) {
  151. ggml_context* ctx = model.ctx;
  152. // TODO: read norm_order from model
  153. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  154. // _forward_self_attn(seqs, padding_mask)
  155. auto residual = seqs;
  156. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  157. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  158. // TODO: add padding_mask to MultiheadAttention_forward
  159. GGML_ASSERT(padding_mask == nullptr);
  160. seqs = MultiheadAttention_forward(
  161. model,
  162. prefix + ".self_attn",
  163. seqs,
  164. seqs,
  165. seqs,
  166. /*attention masks=*/nullptr
  167. );
  168. if (has_layer(model, prefix + ".self_attn_norm"))
  169. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  170. seqs = ggml_add(ctx, seqs, residual);
  171. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  172. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  173. // _forward_ffn(seqs)
  174. residual = seqs;
  175. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  176. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  177. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  178. // TODO: if self.residual_scale is not None:
  179. // residual = self.residual_scale * residual
  180. seqs = ggml_add(ctx, seqs, residual);
  181. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  182. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  183. return seqs;
  184. }
  185. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  186. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  187. struct ggml_tensor * ggml_slice(
  188. struct ggml_context * ctx,
  189. struct ggml_tensor * a,
  190. int axis,
  191. int64_t start,
  192. int64_t end
  193. ) {
  194. int64_t ne[4];
  195. std::copy(a->ne, a->ne + 4, ne);
  196. if (axis < 0) axis = a->n_dims + axis;
  197. if (start < 0) start = ne[axis] + start;
  198. if (end < 0) end = ne[axis] + end;
  199. GGML_ASSERT(0 <= start);
  200. GGML_ASSERT(start <= end);
  201. GGML_ASSERT(end <= ne[axis]);
  202. ne[axis] = end - start;
  203. size_t offset = a->nb[axis] * start;
  204. size_t* nb = a->nb;
  205. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  206. result->n_dims = a->n_dims;
  207. return result;
  208. }
  209. extern "C" ggml_tensor* PositionalEmbedding_forward(
  210. fairseq2_model& model,
  211. const std::string& prefix,
  212. ggml_tensor* embeds
  213. ) {
  214. // This only work with the simple pos encoders
  215. int seq_len = embeds->ne[1];
  216. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  217. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, 0, seq_len);
  218. return ggml_add(model.ctx, embeds, pos_embeds);
  219. }
  220. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  221. fairseq2_model& model,
  222. const std::string& prefix,
  223. ggml_tensor* seqs
  224. // TODO: state_bag
  225. ) {
  226. ggml_context* ctx = model.ctx;
  227. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  228. GGML_ASSERT(embed_weights != nullptr);
  229. ggml_tensor* embeds = ggml_get_rows(ctx, embed_weights, seqs);
  230. // padding mask ?
  231. // padding_mask = to_padding_mask(embeds, seq_lens)
  232. if (has_layer(model, prefix + ".pos_encoder")) {
  233. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  234. }
  235. if (has_layer(model, prefix + ".layer_norm")) {
  236. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  237. }
  238. return embeds;
  239. }
  240. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  241. fairseq2_model& model,
  242. const std::string& prefix,
  243. ggml_tensor* seqs,
  244. ggml_tensor* padding_mask
  245. ) {
  246. int layer_idx = 0;
  247. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  248. while (has_layer(model, layer_name)) {
  249. seqs = StandardTransformerEncoderLayer_forward(
  250. model, layer_name, seqs, padding_mask
  251. );
  252. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  253. layer_idx += 1;
  254. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  255. }
  256. if (has_layer(model, prefix + ".layer_norm"))
  257. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  258. return seqs;
  259. }
  260. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  261. fairseq2_model& model,
  262. const std::string& prefix,
  263. ggml_tensor* seqs,
  264. ggml_tensor* self_attn_mask,
  265. ggml_tensor* encoder_output,
  266. ggml_tensor* encoder_padding_mask
  267. ) {
  268. ggml_context* ctx = model.ctx;
  269. // TODO: read norm_order from model
  270. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  271. // _forward_self_attn(seqs, padding_mask)
  272. auto residual = seqs;
  273. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  274. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  275. seqs = MultiheadAttention_forward(
  276. model,
  277. prefix + ".self_attn",
  278. seqs,
  279. seqs,
  280. seqs,
  281. /*attention masks=*/self_attn_mask
  282. );
  283. if (has_layer(model, prefix + ".self_attn_norm"))
  284. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  285. seqs = ggml_add(ctx, seqs, residual);
  286. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  287. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  288. // _forward_encoder_decoder_attn
  289. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  290. // `encoder_output` must be `None` for decoder-only attention.
  291. GGML_ASSERT(encoder_output == nullptr);
  292. return seqs;
  293. }
  294. // `encoder_output` must not be `None` for encoder-decoder attention.
  295. GGML_ASSERT(encoder_output != nullptr);
  296. residual = seqs;
  297. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  298. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  299. seqs = MultiheadAttention_forward(
  300. model,
  301. prefix + ".encoder_decoder_attn",
  302. seqs,
  303. encoder_output,
  304. encoder_output,
  305. /*attention masks=*/encoder_padding_mask
  306. );
  307. seqs = ggml_add(ctx, seqs, residual);
  308. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  309. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  310. // _forward_ffn(seqs)
  311. residual = seqs;
  312. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  313. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  314. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  315. // TODO:
  316. // if self.residual_scale is not None:
  317. // residual = self.residual_scale * residual
  318. seqs = ggml_add(ctx, seqs, residual);
  319. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  320. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  321. return seqs;
  322. }
  323. ggml_tensor* causal_mask_cache = nullptr;
  324. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  325. auto seq_len = seqs->ne[1];
  326. auto mask = causal_mask_cache;
  327. // TODO: this cache only works as long as we don't change the size/device too often
  328. // TODO: allow other ggml_type
  329. if (mask == nullptr || mask->backend != seqs->backend || mask->ne[0] < seq_len) {
  330. printf("new causal_mask (%ld, %ld) created\n", seq_len, seq_len);
  331. mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  332. char* data = (char*)mask->data;
  333. // tensor([[0., -inf, -inf, -inf],
  334. // [0., 0., -inf, -inf],
  335. // [0., 0., 0., -inf],
  336. // [0., 0., 0., 0.]])
  337. for (int i = 0; i < seq_len; ++i) {
  338. char* row = data + i * mask->nb[1];
  339. for (int j = 0; j <= i; ++j) {*(float*)(row + j * mask->nb[0]) = 0;}
  340. for (int j = i + 1; j < seq_len; ++j) {*(float*)(row + j * mask->nb[0]) = -INFINITY;}
  341. }
  342. causal_mask_cache = mask;
  343. }
  344. return ggml_view_2d(ctx, mask, seq_len, seq_len, mask->nb[1], 0);
  345. }
  346. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  347. fairseq2_model& model,
  348. const std::string& prefix,
  349. ggml_tensor* seqs,
  350. ggml_tensor* padding_mask,
  351. ggml_tensor* encoder_output,
  352. ggml_tensor* encoder_padding_mask
  353. ) {
  354. int layer_idx = 0;
  355. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  356. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  357. while (has_layer(model, layer_name)) {
  358. seqs = StandardTransformerDecoderLayer_forward(
  359. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  360. );
  361. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  362. layer_idx += 1;
  363. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  364. }
  365. if (has_layer(model, prefix + ".layer_norm"))
  366. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  367. return seqs;
  368. }
  369. using IncrementalStateBag = std::unordered_map<ggml_tensor*, ggml_tensor*>*;
  370. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  371. auto opts = job.opts;
  372. int max_seq_len = -1;
  373. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  374. max_seq_len = opts.hard_max_seq_len;
  375. } else {
  376. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len + opts.soft_max_seq_len_b));
  377. }
  378. if (opts.min_seq_len > max_seq_len) {
  379. printf(
  380. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  381. opts.min_seq_len,
  382. max_seq_len
  383. );
  384. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  385. }
  386. int prefix_seq_len = job.prefix_seq->ne[0];
  387. if (prefix_seq_len >= max_seq_len) {
  388. printf(
  389. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  390. prefix_seq_len,
  391. max_seq_len
  392. );
  393. GGML_ASSERT(prefix_seq_len < max_seq_len);
  394. }
  395. return max_seq_len;
  396. }
  397. void _fan_out_encoder_output(
  398. ggml_context* ctx,
  399. ggml_tensor** encoder_output_out,
  400. ggml_tensor** encoder_padding_mask_out,
  401. int beam_size
  402. ) {
  403. // (S_enc, M)
  404. ggml_tensor* encoder_output = *encoder_output_out;
  405. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  406. // (B, S_enc, M)
  407. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  408. // (S_enc, M) -> (B, S_enc, M)
  409. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  410. // (S_enc) -> (B, S_enc)
  411. ggml_tensor* shape_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], beam_size);
  412. if (encoder_padding_mask != nullptr) {
  413. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  414. }
  415. }
  416. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  417. // TODO: this isn't the smartest way of doing this
  418. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  419. }
  420. void _bootstrap_seqs_and_scores(
  421. fairseq2_model& model,
  422. const SequenceGeneratorJob& job,
  423. ggml_tensor* seqs,
  424. ggml_tensor* scores,
  425. ggml_tensor* encoder_output,
  426. ggml_tensor* encoder_padding_mask,
  427. IncrementalStateBag state_bag
  428. ) {
  429. int prefix_seq_len = job.prefix_seq->ne[0];
  430. int max_seq_len = scores->ne[0];
  431. int beam_size = scores->ne[1];
  432. GGML_ASSERT(prefix_seq_len > 0);
  433. if (prefix_seq_len == 1)
  434. return;
  435. ggml_context* ctx = model.ctx;
  436. // seqs[:, : prefix_seq_len] = job.prefix_seq;
  437. ggml_cpy(ctx, job.prefix_seq, ggml_view_2d(ctx, seqs, 0, prefix_seq_len, seqs->nb[1], 0));
  438. // We have to bootstrap the model with the already fanned-out encoder
  439. // output to correctly initialize its incremental state. This causes some
  440. // redundancy as we have to expand `decoder_input` to match the shape of
  441. // `encoder_output`.
  442. // (S_pfx) -> (N x B, S_pfx - 1)
  443. // prefix_seq[:-1].expand(encoder_output.size(0), -1)
  444. ggml_tensor* decoder_input = ggml_repeat(ctx, ggml_view_1d(ctx, job.prefix_seq, prefix_seq_len - 1, 0), encoder_output);
  445. // Bootstrap the model state with prefix sequence.
  446. decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", decoder_input);
  447. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  448. model,
  449. "text_decoder",
  450. decoder_input,
  451. /*padding_mask*/ nullptr,
  452. encoder_output,
  453. encoder_padding_mask
  454. // TODO: state_bag
  455. );
  456. // TODO state_bag.increment_step(prefix_seq_len - 1)
  457. // logits, lprobs: (N, S_pfx - 1, V)
  458. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  459. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_view_3d(ctx, logits, logits->ne[0], logits->ne[1], 1, 0, 0, 0));
  460. int vocab_size = logits->ne[0];
  461. ggml_cgraph gf = ggml_build_forward(lprobs);
  462. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  463. // Fetch scores of next steps from "lprobs"
  464. float p_score = 0;
  465. for (int i = 0; i < prefix_seq_len; ++i) {
  466. int p = ggml_get_i32_1d(job.prefix_seq, i);
  467. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  468. for (int b = 0; b < beam_size; ++b) {
  469. // scores: (N, S)
  470. // Note: First step (e.g. BOS)'s score is always 0.
  471. ggml_set_f32_1d(scores, b * max_seq_len + i + 1, p_score);
  472. }
  473. }
  474. }
  475. /// Represents a hypothesis produced by a sequence generator.
  476. struct Hypothesis {
  477. /// The generated sequence.
  478. ggml_tensor* seq;
  479. /// The score of the hypothesis.
  480. float score;
  481. /// The score of each individual sequence step.
  482. ggml_tensor* step_scores;
  483. };
  484. /// Represents a standard beam search algoritm.
  485. int StandardBeamSearch_step(
  486. ggml_context* ctx,
  487. int step_nr,
  488. bool is_start_step,
  489. ggml_tensor* lprobs, // (B, V)
  490. ggml_tensor* last_scores, // (B)
  491. ggml_tensor* candidate_indices
  492. ) {
  493. GGML_ASSERT(lprobs->n_dims == 2);
  494. int vocab_size = lprobs->ne[0];
  495. int beam_size = lprobs->ne[1];
  496. GGML_ASSERT(last_scores->n_dims == 2);
  497. GGML_ASSERT(last_scores->ne[0] == 1);
  498. GGML_ASSERT(last_scores->ne[1] == beam_size);
  499. GGML_ASSERT(candidate_indices->ne[0] == beam_size * vocab_size);
  500. // should this be done by the caller ?
  501. if (is_start_step) {
  502. // At the initial step, all hypotheses are equally likely, so we use
  503. // only the first beam.
  504. lprobs = ggml_slice(ctx, lprobs, 1, 0, 1);
  505. lprobs = ggml_cont(ctx, lprobs);
  506. // The first step always indicates the beginning of the sequence and
  507. // has no score.
  508. if (step_nr > 0) {
  509. lprobs = ggml_add_inplace(ctx, lprobs, last_scores);
  510. }
  511. } else {
  512. // Make probabilities contain cumulative scores for each hypothesis.
  513. lprobs = ggml_add_inplace(ctx, lprobs, ggml_repeat(ctx, last_scores, lprobs));
  514. }
  515. // Note this is where we will actually do the model inference.
  516. ggml_cgraph gf = ggml_build_forward(lprobs);
  517. printf("StandardBeamSearch_step.graph.n_nodes: %d\n", gf.n_nodes);
  518. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  519. // Take the best 2 x `beam_size` predictions. We'll choose the first
  520. // `beam_size` of these which don't predict EOS to continue with.
  521. // (N, 2 x B)
  522. // `vocab_size` - 1 to never select PAD.
  523. int topk = std::min(2 * beam_size, vocab_size - 1);
  524. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  525. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  526. };
  527. auto cand = (std::int32_t*)candidate_indices->data;
  528. std::partial_sort(cand, cand + topk, cand + (beam_size * vocab_size), comp);
  529. return topk;
  530. }
  531. int _finalize_hypothesis(
  532. const SequenceGeneratorJob& job,
  533. ggml_context* ctx,
  534. int step_nr,
  535. int vocab_size,
  536. std::int32_t candidate,
  537. float tok_score,
  538. ggml_tensor* seqs, // (beam_size, seq_len)
  539. ggml_tensor* scores, // (beam_size, seq_len)
  540. std::vector<Hypothesis>& hypotheses
  541. ) {
  542. std::int32_t beam = candidate / vocab_size;
  543. std::int32_t token = candidate % vocab_size;
  544. // Detect beams that reached the minimum length and that end with an EOS.
  545. bool eos = token == job.eos_idx;
  546. eos &= tok_score != -INFINITY;
  547. // TODO ignored_beam_mask ?
  548. // eos &= ggml_get_i32_1d(ignored_beam_mask, beam);
  549. // ggml_set_i32_1d(eos_mask, beam, eos);
  550. if (!eos) return 0;
  551. // If the candidate beam is "finished", let's copy the score and sequence
  552. ggml_tensor* tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  553. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  554. auto tok = (std::int32_t*)tokens->data;
  555. for (int i = 0; i < step_nr + 1; ++i) {
  556. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  557. }
  558. tok[step_nr + 1] = token;
  559. // Convert from cumulative to per-step scores.
  560. auto sc = (float*)step_scores->data;
  561. float last_score = tok_score;
  562. for (int i = step_nr; i >= 0; --i) {
  563. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  564. sc[i] = last_score - sc0;
  565. last_score = sc0;
  566. }
  567. if (job.opts.normalize_scores)
  568. // Skip first EOS since it is always 0 and skews normalization.
  569. tok_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  570. hypotheses.emplace_back(Hypothesis{tokens, tok_score, step_scores});
  571. return 1;
  572. }
  573. /// Generates a translation for a single sequence
  574. // TODO: finish this for beam_size=1
  575. // * implement the lprobs tweaking
  576. // TODO: add IncrementalStateBag support to avoid a O(N^3) generation.
  577. // TODO: support beam_size > 1:
  578. // * most layers assume un-batched input, but we want to handle several beams at once
  579. // * need to port "reorder_state_dict"
  580. // * once beam are selected with topk, we need to update seqs and scores tensors
  581. extern "C" float generate_sequence(
  582. fairseq2_model& model,
  583. const SequenceGeneratorJob& job,
  584. ggml_tensor* encoder_output,
  585. ggml_tensor* encoder_padding_mask,
  586. ggml_tensor* output_seq
  587. ) {
  588. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  589. int vocab_size = embed->ne[1];
  590. std::size_t beam_size = job.opts.beam_size;
  591. int source_seq_len = encoder_output->ne[1];
  592. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  593. ggml_context* ctx = model.ctx;
  594. // (S_enc, M) -> (B, S_enc, M)
  595. _fan_out_encoder_output(ctx, &encoder_output, &encoder_padding_mask, beam_size);
  596. std::vector<Hypothesis> finished_searches;
  597. finished_searches.reserve(beam_size);
  598. // Initialize buffers. (B, S)
  599. ggml_tensor* seqs = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  600. ggml_set_i32(seqs, 0);
  601. ggml_tensor* scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  602. ggml_set_f32(scores, 0.0);
  603. IncrementalStateBag state_bag = {};
  604. _bootstrap_seqs_and_scores(
  605. model, job, seqs, scores, encoder_output, encoder_padding_mask, state_bag
  606. );
  607. int prefix_seq_len = job.prefix_seq->ne[0];
  608. int start_step = prefix_seq_len - 1;
  609. // Holds the indices of beams (a beam can occur more than once) that we
  610. // should continue with in the next step.
  611. ggml_tensor* beam_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
  612. ggml_tensor* next_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
  613. ggml_tensor* next_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, beam_size);
  614. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  615. ggml_tensor* candidate_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, vocab_size * beam_size);
  616. for (std::size_t i = 0; i < vocab_size * beam_size; ++i) ggml_set_i32_1d(candidate_indices, i, i);
  617. // TODO: memory management
  618. // there should be a per-step ggml_context for intermediary results
  619. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  620. // if (beam_indices != nullptr) {
  621. // // If not `None`, it means in the last step we finalized one or
  622. // // more searches. We should ensure that we adjust `beam_indices`
  623. // // before reordering `decoder`'s incremental state.
  624. // if (search_indices != nullptr) {
  625. // num_searches = search_indices->ne[0];
  626. // // (N)
  627. // delta = search_indices - torch.arange(num_searches, device=device)
  628. // // (N) -> (N, 1)
  629. // delta.unsqueeze_(-1)
  630. // // Adjust indices to take into account removed searches.
  631. // beam_indices.view(num_searches, beam_size).add_(delta * beam_size)
  632. // }
  633. // // state_bag.reorder(beam_indices)
  634. // }
  635. // because of no IncrementalStateBag we pass input from the start
  636. // decoder_input = seqs[:, 0 : step_nr + 1]
  637. ggml_tensor* decoder_input = ggml_slice(ctx, seqs, 0, 0, step_nr + 1);
  638. decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", decoder_input);
  639. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  640. model,
  641. "text_decoder",
  642. decoder_input,
  643. nullptr, // We never generate PAD.
  644. encoder_output,
  645. encoder_padding_mask
  646. // state_bag=state_bag,
  647. );
  648. // state_bag.increment_step()
  649. // Because of no IncrementalStateBag decoder_output here is of shape (B, S, D)
  650. // Just look at the last token.
  651. decoder_output = ggml_slice(ctx, decoder_output, 1, step_nr, step_nr+1);
  652. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  653. ggml_tensor* lprobs = ggml_log_softmax(ctx, logits);
  654. // // Do not allow EOS before reaching the minimum sequence length.
  655. // if step_nr < self.opts.min_seq_len:
  656. // lprobs[:, :, self.eos_idx] = -torch.inf
  657. // // If we have reached the maximum length, force the last step to be
  658. // // EOS.
  659. // if step_nr == max_seq_len - 2:
  660. // lprobs[:, :, : self.eos_idx] = -torch.inf
  661. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  662. // // Never allow PAD.
  663. // lprobs[:, :, self.pad_idx] = -torch.inf
  664. // // Apply UNK penalty.
  665. // if self.unk_idx is not None:
  666. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  667. // Determine candidates for the next step.
  668. // (N, 2 x B)
  669. int topk = StandardBeamSearch_step(
  670. ctx,
  671. step_nr,
  672. step_nr == start_step,
  673. lprobs,
  674. ggml_slice(ctx, scores, 0, step_nr, step_nr+1),
  675. candidate_indices
  676. );
  677. std::size_t ongoing_beams = 0;
  678. int new_num_searches = 0;
  679. for (std::int32_t i = 0; i < topk; ++i) {
  680. int c = ggml_get_f32_1d(candidate_indices, i);
  681. float tok_score = ggml_get_f32_1d(lprobs, c);
  682. int finished = _finalize_hypothesis(job, ctx, step_nr, vocab_size, c, tok_score, seqs, scores, finished_searches);
  683. new_num_searches += finished;
  684. if (!finished){
  685. std::int32_t beam = c / vocab_size;
  686. std::int32_t token = c % vocab_size;
  687. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  688. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  689. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  690. ongoing_beams += 1 - finished;
  691. }
  692. if (ongoing_beams >= beam_size) break;
  693. if (finished_searches.size() >= beam_size) break;
  694. }
  695. if (finished_searches.size() >= beam_size) break;
  696. // Reorder beams in the `seq` and `score` buffers. The same beam can
  697. // be selected more than once.
  698. ggml_tensor* new_seqs = seqs;
  699. ggml_tensor* new_scores = scores;
  700. if (step_nr > start_step) {
  701. // (B, S), (B) -> (B, S)
  702. // ggml_get_rows only work with floats ...
  703. new_seqs->type = GGML_TYPE_F32;
  704. new_seqs = ggml_get_rows(ctx, seqs, beam_indices);
  705. new_scores = ggml_get_rows(ctx, new_scores, beam_indices);
  706. }
  707. // new_seqs[:, step_nr + 1] = next_tokens
  708. ggml_set_1d_inplace(ctx, new_seqs, next_tokens, new_seqs->nb[0] * (step_nr + 1));
  709. ggml_set_1d_inplace(ctx, new_scores, next_scores, new_scores->nb[0] * (step_nr + 1));
  710. ggml_cgraph gf = ggml_build_forward(new_seqs);
  711. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  712. new_seqs->type = GGML_TYPE_I32;
  713. gf = ggml_build_forward(new_scores);
  714. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  715. // TODO the old seqs and score buffers could be reused for next step
  716. seqs = new_seqs;
  717. scores = new_scores;
  718. }
  719. // Ensure that hypotheses are sorted by decreasing scores before returning.
  720. std::sort(
  721. finished_searches.begin(),
  722. finished_searches.end(),
  723. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  724. );
  725. // For now just return the best sequence
  726. // TODO: return structured output
  727. *output_seq = *(finished_searches[0].seq);
  728. return 0.0f;
  729. }