// fairseq2.cpp
  1. #include <math.h>
  2. #include "ggml.h"
  3. #include "fairseq2.h"
  4. #include <unordered_map>
  5. #include <algorithm>
  6. /// allocate the fairseq2 model and hyperparameters
  7. extern "C" fairseq2_model* fairseq2_model_alloc() {
  8. // pre-allocate some memory to write hyperparameters and tensors pointers
  9. auto* model = new fairseq2_model;
  10. model->hparams = new std::uint8_t[8 * 1024];
  11. model->arch = new std::uint64_t[16 * 1024]; // max tensors allowed
  12. model->tensors_ctx = nullptr;
  13. return model;
  14. };
  15. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  16. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  17. delete (std::uint64_t*)(model->arch);
  18. delete (std::uint8_t*)model->hparams;
  19. delete model;
  20. };
  21. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  22. model->ctx = ctx;
  23. }
  /// Allocates a `std::string` on the heap from a C string.
  ///
  /// @param c_str NUL-terminated text to copy; a null pointer yields an empty
  ///              string (constructing `std::string` from null is undefined).
  /// @return a heap-allocated string; release it with `std_string_free`.
  extern "C" std::string* std_string_alloc(char* c_str) {
      if (c_str == nullptr) return new std::string();
      return new std::string(c_str);
  }
  /// Destroys a heap-allocated `std::string` (for instance one returned by
  /// `std_string_alloc`). Passing null is a no-op.
  extern "C" void std_string_free(std::string* str) {
      if (str == nullptr) return;  // `delete` on null does nothing anyway
      delete str;
  }
  30. bool has_layer(fairseq2_model& model, const std::string& name) {
  31. return model.tensors.find(name) != model.tensors.end();
  32. }
  33. extern "C" ggml_tensor* Linear_forward(
  34. fairseq2_model& model,
  35. const std::string &prefix,
  36. ggml_tensor* input // (d_in)
  37. ) {
  38. // Note: for now we assumed un-batched input
  39. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  40. GGML_ASSERT(weight != nullptr);
  41. ggml_tensor* out = ggml_mul_mat(model.ctx, weight, input); // (d_out)
  42. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  43. if (bias == nullptr) return out;
  44. return ggml_add_inplace(model.ctx, out, bias);
  45. }
  46. extern "C" ggml_tensor* LayerNorm_forward(
  47. fairseq2_model& model,
  48. const std::string &prefix,
  49. ggml_tensor* input
  50. ) {
  51. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  52. GGML_ASSERT(weight != nullptr);
  53. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  54. GGML_ASSERT(bias != nullptr);
  55. auto ctx = model.ctx;
  56. // TODO: should `eps` be part of unity hparams ?
  57. input = ggml_norm(ctx, input, /*eps*/1e-5);
  58. return ggml_add_inplace(
  59. ctx,
  60. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  61. ggml_repeat(ctx, bias, input)
  62. );
  63. }
  64. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  65. fairseq2_model& model,
  66. const std::string& prefix,
  67. ggml_tensor* seqs
  68. ) {
  69. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  70. // inner_activation = ReLu // TODO: allow other activation
  71. seqs = ggml_relu_inplace(model.ctx, seqs);
  72. if (has_layer(model, prefix + ".inner_layer_norm")) {
  73. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  74. }
  75. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  76. return seqs;
  77. }
  78. /// Merge the given dimension and the previous one in the tensor.
  79. /// (..., num_heads, N, ...) -> (..., num_heads * N, ...)
  80. /// dim is the position of the resulting merged dimension
  81. /// ggml_flatten_1d(x, d) <==> torch.flatten(x, -1-d-1, -1-d)
  82. ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
  83. int n_dims = x->n_dims;
  84. GGML_ASSERT(dim >= 0);
  85. GGML_ASSERT(dim < n_dims);
  86. // Nothing to do
  87. if (dim == n_dims - 1) return x;
  88. if (n_dims == 2) {
  89. return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
  90. } else if (n_dims == 3) {
  91. if (dim == 0) {
  92. return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
  93. } else { // dim == 1
  94. return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
  95. }
  96. } else { // n_dims == 4
  97. if (dim == 0) {
  98. return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
  99. } else if (dim == 1) {
  100. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
  101. } else { // dim == 2
  102. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
  103. }
  104. }
  105. }
  106. /// Split the given dimension.
  107. /// (..., K * N, ...) -> (..., K, N, ...)
  108. /// dim is the position of the output dimension with the given number of element (N).
  109. ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
  110. int n_dims = x->n_dims;
  111. GGML_ASSERT(dim >= 0);
  112. GGML_ASSERT(dim < n_dims);
  113. GGML_ASSERT(n_dims < 4);
  114. if (n_dims == 1) {
  115. return ggml_reshape_2d(ctx, x, num_el, x->ne[0] / num_el);
  116. } else if (n_dims == 2) {
  117. if (dim == 0) {
  118. return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
  119. } else { // dim == 1
  120. return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
  121. }
  122. } else { // (n_dims == 3)
  123. if (dim == 0) {
  124. return ggml_reshape_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2]);
  125. } else if (dim == 1) {
  126. return ggml_reshape_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2]);
  127. } else { // dim == 2
  128. return ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el);
  129. }
  130. }
  131. }
  132. ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
  133. // (B, S, dim) -> (B, S, H, H_dim)
  134. x = ggml_unflatten_1d(ctx, x, 0, head_dim);
  135. // (B?, S, H, H_dim) -> (B?, H, S, H_dim)
  136. x = ggml_permute(ctx, x, 0, 2, 1, 3);
  137. return x;
  138. }
  139. /// (B, Sk, dim) -> // (B?, H, H_dim, Sk)
  140. ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim ) {
  141. // (B, Sk, dim) -> (B, Sk, H, H_dim)
  142. v = ggml_unflatten_1d(ctx, v, 0, head_dim);
  143. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
  144. return v;
  145. }
  146. // flash_attn doesn't work for cross attention because it assumes Q <= K
  147. // TODO: enable flash_attn only for the encoder
  148. # define UNITY_FLASH_ATTN 0
  149. extern "C" ggml_tensor* MultiheadAttention_forward(
  150. fairseq2_model& model,
  151. const std::string &prefix,
  152. ggml_tensor* queries, // (slen, d_in)
  153. ggml_tensor* keys, // (klen, d_in)
  154. ggml_tensor* values, // (klen, d_out)
  155. ggml_tensor* mask // (klen, slen)
  156. ) {
  157. int model_dim = queries->ne[0];
  158. int num_heads = 16; // TODO: read from hparams
  159. int head_dim = model_dim / num_heads;
  160. GGML_ASSERT(model_dim % num_heads == 0);
  161. ggml_context* ctx = model.ctx;
  162. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
  163. q = _reshape_num_head(ctx, q, head_dim); // (B, H, S, H_dim)
  164. ggml_set_name(q, "q");
  165. ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
  166. k = _reshape_num_head(ctx, k, head_dim); // (B, H, Sk, H_dim)
  167. ggml_set_name(k, "k");
  168. ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
  169. v = _reshape_num_head_values(ctx, v, head_dim); // (B, H, H_dim, Sk)
  170. v = ggml_cont(ctx, v);
  171. ggml_set_name(v, "v");
  172. #if UNITY_FLASH_ATTN
  173. // For flash_attn, we assume either no masks, or triangular masks.
  174. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr); // (H, S, H_dim)
  175. ggml_set_name(attn, "attn");
  176. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  177. attn = ggml_cont(ctx, attn);
  178. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  179. #else
  180. // (B, H, Sk, H_dim) x (B, H, S, H_dim) -> (B, H, S, Sk)
  181. ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
  182. ggml_set_name(qk, "qk");
  183. ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
  184. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  185. qk = ggml_scale(ctx, qk, qk_scale);
  186. ggml_set_name(qk, "qk_scaled");
  187. // TODO: Should we replace this by ggml_diag_mask_inf ?
  188. if (mask) qk = ggml_add(ctx, qk, mask);
  189. // TODO: upgrade qk to float32 if needed
  190. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B, H, S, Sk)
  191. ggml_set_name(attn_weights, "attn_weights");
  192. // (B, H, S, Sk) x (B, H, H_dim, Sk) -> (B, H, H_dim, S)
  193. ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
  194. ggml_set_name(attn, "attn");
  195. attn = ggml_flatten_1d(ctx, attn, 1); // (B, H * H_dim, S)
  196. attn = ggml_transpose(ctx, attn); // (B, S, H * H_dim)
  197. // // I'm not sure why this one is needed ...
  198. attn = ggml_cont(ctx, attn);
  199. #endif // UNITY_FLASH_ATTN
  200. // out -> (B, S, d_out)
  201. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  202. ggml_set_name(out, "out");
  203. return out;
  204. }
  205. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  206. fairseq2_model& model,
  207. const std::string& prefix,
  208. ggml_tensor* seqs,
  209. ggml_tensor* padding_mask
  210. ) {
  211. ggml_context* ctx = model.ctx;
  212. // TODO: read norm_order from model
  213. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  214. // _forward_self_attn(seqs, padding_mask)
  215. auto residual = seqs;
  216. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  217. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  218. // TODO: add padding_mask to MultiheadAttention_forward
  219. GGML_ASSERT(padding_mask == nullptr);
  220. seqs = MultiheadAttention_forward(
  221. model,
  222. prefix + ".self_attn",
  223. seqs,
  224. seqs,
  225. seqs,
  226. /*attention masks=*/nullptr
  227. );
  228. if (has_layer(model, prefix + ".self_attn_norm"))
  229. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  230. seqs = ggml_add(ctx, seqs, residual);
  231. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  232. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  233. // _forward_ffn(seqs)
  234. residual = seqs;
  235. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  236. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  237. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  238. // TODO: if self.residual_scale is not None:
  239. // residual = self.residual_scale * residual
  240. seqs = ggml_add(ctx, seqs, residual);
  241. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  242. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  243. return seqs;
  244. }
  245. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  246. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  247. struct ggml_tensor * ggml_slice(
  248. struct ggml_context * ctx,
  249. struct ggml_tensor * a,
  250. int axis,
  251. int64_t start,
  252. int64_t end
  253. ) {
  254. int64_t ne[4];
  255. std::copy(a->ne, a->ne + 4, ne);
  256. if (axis < 0) axis = a->n_dims + axis;
  257. if (start < 0) start = ne[axis] + start;
  258. if (end < 0) end = ne[axis] + end;
  259. GGML_ASSERT(0 <= start);
  260. GGML_ASSERT(start <= end);
  261. GGML_ASSERT(end <= ne[axis]);
  262. ne[axis] = end - start;
  263. size_t offset = a->nb[axis] * start;
  264. size_t* nb = a->nb;
  265. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  266. result->n_dims = a->n_dims;
  267. return result;
  268. }
  269. extern "C" ggml_tensor* PositionalEmbedding_forward(
  270. fairseq2_model& model,
  271. const std::string& prefix,
  272. ggml_tensor* embeds
  273. ) {
  274. // This only work with the simple pos encoders
  275. int seq_len = embeds->ne[1];
  276. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  277. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, 0, seq_len);
  278. return ggml_add(model.ctx, embeds, pos_embeds);
  279. }
  280. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  281. fairseq2_model& model,
  282. const std::string& prefix,
  283. ggml_tensor* seqs
  284. // TODO: state_bag
  285. ) {
  286. ggml_context* ctx = model.ctx;
  287. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  288. GGML_ASSERT(embed_weights != nullptr);
  289. ggml_tensor* embeds = ggml_get_rows(ctx, embed_weights, seqs);
  290. // padding mask ?
  291. // padding_mask = to_padding_mask(embeds, seq_lens)
  292. if (has_layer(model, prefix + ".pos_encoder")) {
  293. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  294. }
  295. if (has_layer(model, prefix + ".layer_norm")) {
  296. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  297. }
  298. return embeds;
  299. }
  300. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  301. fairseq2_model& model,
  302. const std::string& prefix,
  303. ggml_tensor* seqs,
  304. ggml_tensor* padding_mask
  305. ) {
  306. int layer_idx = 0;
  307. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  308. while (has_layer(model, layer_name)) {
  309. seqs = StandardTransformerEncoderLayer_forward(
  310. model, layer_name, seqs, padding_mask
  311. );
  312. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  313. layer_idx += 1;
  314. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  315. }
  316. if (has_layer(model, prefix + ".layer_norm"))
  317. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  318. return seqs;
  319. }
  320. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  321. fairseq2_model& model,
  322. const std::string& prefix,
  323. ggml_tensor* seqs,
  324. ggml_tensor* self_attn_mask,
  325. ggml_tensor* encoder_output,
  326. ggml_tensor* encoder_padding_mask
  327. ) {
  328. ggml_context* ctx = model.ctx;
  329. // TODO: read norm_order from model
  330. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  331. // _forward_self_attn(seqs, padding_mask)
  332. auto residual = seqs;
  333. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  334. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  335. seqs = MultiheadAttention_forward(
  336. model,
  337. prefix + ".self_attn",
  338. seqs,
  339. seqs,
  340. seqs,
  341. /*attention masks=*/self_attn_mask
  342. );
  343. if (has_layer(model, prefix + ".self_attn_norm"))
  344. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  345. seqs = ggml_add(ctx, seqs, residual);
  346. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  347. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  348. // _forward_encoder_decoder_attn
  349. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  350. // `encoder_output` must be `None` for decoder-only attention.
  351. GGML_ASSERT(encoder_output == nullptr);
  352. return seqs;
  353. }
  354. // `encoder_output` must not be `None` for encoder-decoder attention.
  355. GGML_ASSERT(encoder_output != nullptr);
  356. residual = seqs;
  357. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  358. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  359. seqs = MultiheadAttention_forward(
  360. model,
  361. prefix + ".encoder_decoder_attn",
  362. seqs,
  363. encoder_output,
  364. encoder_output,
  365. /*attention masks=*/encoder_padding_mask
  366. );
  367. seqs = ggml_add(ctx, seqs, residual);
  368. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  369. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  370. // _forward_ffn(seqs)
  371. residual = seqs;
  372. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  373. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  374. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  375. // TODO:
  376. // if self.residual_scale is not None:
  377. // residual = self.residual_scale * residual
  378. seqs = ggml_add(ctx, seqs, residual);
  379. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  380. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  381. return seqs;
  382. }
  383. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  384. auto seq_len = seqs->ne[1];
  385. // TODO: allow other ggml_type
  386. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  387. return ggml_diag_mask_inf(ctx, mask, 0);
  388. }
  389. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  390. fairseq2_model& model,
  391. const std::string& prefix,
  392. ggml_tensor* seqs,
  393. ggml_tensor* padding_mask,
  394. ggml_tensor* encoder_output,
  395. ggml_tensor* encoder_padding_mask
  396. ) {
  397. int layer_idx = 0;
  398. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  399. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  400. while (has_layer(model, layer_name)) {
  401. seqs = StandardTransformerDecoderLayer_forward(
  402. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  403. );
  404. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  405. layer_idx += 1;
  406. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  407. }
  408. if (has_layer(model, prefix + ".layer_norm"))
  409. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  410. return seqs;
  411. }
  412. using IncrementalStateBag = std::unordered_map<ggml_tensor*, ggml_tensor*>*;
  413. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  414. auto opts = job.opts;
  415. int max_seq_len = -1;
  416. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  417. max_seq_len = opts.hard_max_seq_len;
  418. } else {
  419. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len + opts.soft_max_seq_len_b));
  420. }
  421. if (opts.min_seq_len > max_seq_len) {
  422. printf(
  423. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  424. opts.min_seq_len,
  425. max_seq_len
  426. );
  427. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  428. }
  429. int prefix_seq_len = job.prefix_seq->ne[0];
  430. if (prefix_seq_len >= max_seq_len) {
  431. printf(
  432. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  433. prefix_seq_len,
  434. max_seq_len
  435. );
  436. GGML_ASSERT(prefix_seq_len < max_seq_len);
  437. }
  438. return max_seq_len;
  439. }
  440. void _fan_out_encoder_output(
  441. ggml_context* ctx,
  442. ggml_tensor** encoder_output_out,
  443. ggml_tensor** encoder_padding_mask_out,
  444. int beam_size
  445. ) {
  446. // (S_enc, M)
  447. ggml_tensor* encoder_output = *encoder_output_out;
  448. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  449. // (B, S_enc, M)
  450. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  451. // (S_enc, M) -> (B, S_enc, M)
  452. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  453. // (S_enc) -> (B, S_enc)
  454. ggml_tensor* shape_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], beam_size);
  455. if (encoder_padding_mask != nullptr) {
  456. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  457. }
  458. }
  459. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  460. // TODO: this isn't the most precise way of doing this
  461. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  462. }
  463. ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
  464. ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
  465. ggml_type true_type = x->type;
  466. x->type = GGML_TYPE_F32;
  467. ggml_tensor* y = ggml_repeat(ctx, x, shape);
  468. y->type = true_type;
  469. return y;
  470. }
  471. void _bootstrap_seqs_and_scores(
  472. fairseq2_model& model,
  473. const SequenceGeneratorJob& job,
  474. ggml_tensor* full_seqs,
  475. ggml_tensor* scores,
  476. ggml_tensor* encoder_output,
  477. ggml_tensor* encoder_padding_mask,
  478. IncrementalStateBag state_bag
  479. ) {
  480. int prefix_seq_len = job.prefix_seq->ne[0];
  481. int max_seq_len = scores->ne[0];
  482. int beam_size = scores->ne[1];
  483. GGML_ASSERT(prefix_seq_len > 0);
  484. if (prefix_seq_len == 1)
  485. return;
  486. ggml_context* ctx = model.ctx;
  487. // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  488. full_seqs->type = GGML_TYPE_F32;
  489. job.prefix_seq->type = GGML_TYPE_F32;
  490. ggml_tensor* seqs = ggml_cpy(ctx, job.prefix_seq, ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len));
  491. // We have to bootstrap the model with the already fanned-out encoder
  492. // output to correctly initialize its incremental state.
  493. // (S_pfx) -> (N x B, S_pfx - 1)
  494. // prefix_seq[:-1].expand(beam_size, -1)
  495. seqs = ggml_expand_2d(ctx, ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1), prefix_seq_len - 1, beam_size);
  496. seqs->type = GGML_TYPE_I32;
  497. // Bootstrap the model state with prefix sequence.
  498. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  499. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  500. model,
  501. "text_decoder",
  502. seqs,
  503. /*padding_mask*/ nullptr,
  504. encoder_output,
  505. encoder_padding_mask
  506. // TODO: state_bag
  507. );
  508. // TODO state_bag.increment_step(prefix_seq_len - 1)
  509. // logits, lprobs: (N, S_pfx - 1, V)
  510. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  511. int vocab_size = logits->ne[0];
  512. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
  513. ggml_cgraph gf = ggml_build_forward(lprobs);
  514. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  515. full_seqs->type = GGML_TYPE_I32;
  516. job.prefix_seq->type = GGML_TYPE_I32;
  517. // Fetch scores of next steps from "lprobs"
  518. float p_score = 0;
  519. for (int i = 0; i < prefix_seq_len; ++i) {
  520. int p = ggml_get_i32_1d(job.prefix_seq, i);
  521. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  522. for (int b = 0; b < beam_size; ++b) {
  523. // scores: (N, S)
  524. // Note: First step (e.g. BOS)'s score is always 0.
  525. ggml_set_f32_1d(scores, b * max_seq_len + i + 1, p_score);
  526. }
  527. }
  528. }
  529. /// Represents a hypothesis produced by a sequence generator.
  530. struct Hypothesis {
  531. /// The generated sequence.
  532. ggml_tensor* seq;
  533. /// The score of the hypothesis.
  534. float score;
  535. /// The score of each individual sequence step.
  536. ggml_tensor* step_scores;
  537. };
  538. /// Represents a standard beam search algoritm.
  539. int StandardBeamSearch_step(
  540. ggml_context* ctx,
  541. int step_nr,
  542. bool is_start_step,
  543. ggml_tensor* lprobs, // (B, V)
  544. ggml_tensor* last_scores, // (B)
  545. ggml_tensor* candidate_indices
  546. ) {
  547. GGML_ASSERT(lprobs->n_dims == 2);
  548. int vocab_size = lprobs->ne[0];
  549. int beam_size = lprobs->ne[1];
  550. GGML_ASSERT(last_scores->n_dims == 2);
  551. GGML_ASSERT(last_scores->ne[0] == 1);
  552. GGML_ASSERT(last_scores->ne[1] == beam_size);
  553. GGML_ASSERT(candidate_indices->ne[0] == beam_size * vocab_size);
  554. // should this be done by the caller ?
  555. if (is_start_step) {
  556. // At the initial step, all hypotheses are equally likely, so we use
  557. // only the first beam.
  558. lprobs = ggml_slice(ctx, lprobs, 1, 0, 1);
  559. lprobs = ggml_cont(ctx, lprobs);
  560. // The first step always indicates the beginning of the sequence and
  561. // has no score.
  562. if (step_nr > 0) {
  563. lprobs = ggml_add_inplace(ctx, lprobs, ggml_repeat(ctx, last_scores, lprobs));
  564. }
  565. } else {
  566. // Make probabilities contain cumulative scores for each hypothesis.
  567. // TODO this seems incorrect
  568. lprobs = ggml_add(ctx, lprobs, ggml_repeat(ctx, last_scores, lprobs));
  569. }
  570. ggml_cgraph gf = ggml_build_forward(lprobs);
  571. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  572. // Take the best 2 x `beam_size` predictions. We'll choose the first
  573. // `beam_size` of these which don't predict EOS to continue with.
  574. // (N, 2 x B)
  575. // `vocab_size` - 1 to never select PAD.
  576. int topk = std::min(2 * beam_size, vocab_size - 1);
  577. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  578. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  579. };
  580. auto cand = (std::int32_t*)candidate_indices->data;
  581. std::partial_sort(cand, cand + topk, cand + (beam_size * vocab_size), comp);
  582. return topk;
  583. }
  584. void ggml_detach(ggml_tensor* a) {
  585. a->op = GGML_OP_NONE;
  586. a->src[0] = nullptr;
  587. }
  588. int _finalize_hypothesis(
  589. const SequenceGeneratorJob& job,
  590. ggml_context* ctx,
  591. int step_nr,
  592. int vocab_size,
  593. std::int32_t candidate,
  594. float tok_score,
  595. ggml_tensor* seqs, // (beam_size, seq_len)
  596. ggml_tensor* scores, // (beam_size, seq_len)
  597. std::vector<Hypothesis>& hypotheses
  598. ) {
  599. std::int32_t beam = candidate / vocab_size;
  600. std::int32_t token = candidate % vocab_size;
  601. // Detect beams that reached the minimum length and that end with an EOS.
  602. bool eos = token == job.eos_idx;
  603. eos &= tok_score != -INFINITY;
  604. if (!eos) return 0;
  605. // If the candidate beam is "finished", let's copy the score and sequence
  606. ggml_tensor* tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  607. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  608. auto tok = (std::int32_t*)tokens->data;
  609. for (int i = 0; i < step_nr + 1; ++i) {
  610. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  611. }
  612. tok[step_nr + 1] = token;
  613. // Convert from cumulative to per-step scores.
  614. auto sc = (float*)step_scores->data;
  615. float last_score = tok_score;
  616. for (int i = step_nr; i >= 0; --i) {
  617. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  618. sc[i] = last_score - sc0;
  619. last_score = sc0;
  620. }
  621. if (job.opts.normalize_scores)
  622. // Skip first EOS since it is always 0 and skews normalization.
  623. tok_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  624. // TODO the score computed here isn't the same than computed by fairseq2.
  625. hypotheses.emplace_back(Hypothesis{tokens, tok_score, step_scores});
  626. return 1;
  627. }
  628. /// Generates a translation for a single sequence
  629. // TODO: finish this for beam_size=1
  630. // * find out why score is different (seq is the same though)
  631. // TODO: add IncrementalStateBag support to avoid a O(N^3) generation.
  632. // TODO: support beam_size > 1:
  633. // * most layers assume un-batched input, but we want to handle several beams at once
  634. // * need to port "reorder_state_dict"
  635. // TODO: clean up
  636. // * replace manual tensor tweaking with ggml_set_*d (ggml_set_slice could be useful)
  637. extern "C" float generate_sequence(
  638. fairseq2_model& model,
  639. const SequenceGeneratorJob& job,
  640. ggml_tensor* encoder_output,
  641. ggml_tensor* encoder_padding_mask,
  642. ggml_tensor* output_seq
  643. ) {
  644. ggml_context* ctx = model.ctx;
  645. size_t eos_idx = job.eos_idx;
  646. auto pad_idx = job.pad_idx;
  647. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  648. size_t vocab_size = embed->ne[1];
  649. std::size_t beam_size = job.opts.beam_size;
  650. int source_seq_len = encoder_output->ne[1];
  651. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  652. // (S_enc, M) -> (B, S_enc, M)
  653. _fan_out_encoder_output(ctx, &encoder_output, &encoder_padding_mask, beam_size);
  654. std::vector<Hypothesis> finished_searches;
  655. finished_searches.reserve(beam_size);
  656. // Initialize buffers. (B, S)
  657. ggml_tensor* seqs = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  658. ggml_set_i32(seqs, 0);
  659. ggml_tensor* scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  660. ggml_set_f32(scores, 0.0);
  661. IncrementalStateBag state_bag = {};
  662. _bootstrap_seqs_and_scores(
  663. model, job, seqs, scores, encoder_output, encoder_padding_mask, state_bag
  664. );
  665. int prefix_seq_len = job.prefix_seq->ne[0];
  666. int start_step = prefix_seq_len - 1;
  667. // Holds the indices of beams (a beam can occur more than once) that we
  668. // should continue with in the next step.
  669. ggml_tensor* beam_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
  670. ggml_tensor* next_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
  671. ggml_tensor* next_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, beam_size);
  672. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  673. ggml_tensor* candidate_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, vocab_size * beam_size);
  674. for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
  675. ((int32_t *)(candidate_indices->data))[i] = i;
  676. // TODO: memory management
  677. // there should be a per-step ggml_context for intermediary results
  678. // start of beam search:
  679. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  680. // if (beam_indices != nullptr) {
  681. // // If not `None`, it means in the last step we finalized one or
  682. // // more searches. We should ensure that we adjust `beam_indices`
  683. // // before reordering `decoder`'s incremental state.
  684. // if (search_indices != nullptr) {
  685. // num_searches = search_indices->ne[0];
  686. // // (N)
  687. // delta = search_indices - torch.arange(num_searches, device=device)
  688. // // (N) -> (N, 1)
  689. // delta.unsqueeze_(-1)
  690. // // Adjust indices to take into account removed searches.
  691. // beam_indices.view(num_searches, beam_size).add_(delta * beam_size)
  692. // }
  693. // // state_bag.reorder(beam_indices)
  694. // }
  695. // because of no IncrementalStateBag we pass input from the start
  696. // decoder_input = seqs[:, 0 : step_nr + 1]
  697. ggml_tensor* decoder_input = ggml_slice(ctx, seqs, 0, 0, step_nr + 1);
  698. decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", decoder_input);
  699. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  700. model,
  701. "text_decoder",
  702. decoder_input,
  703. nullptr, // We never generate PAD.
  704. encoder_output,
  705. encoder_padding_mask
  706. // state_bag=state_bag,
  707. );
  708. // state_bag.increment_step()
  709. // Because of no IncrementalStateBag decoder_output here is of shape (B, S, D)
  710. // Just look at the last token.
  711. decoder_output = ggml_slice(ctx, decoder_output, 1, step_nr, step_nr+1);
  712. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  713. ggml_tensor* lprobs = ggml_log_softmax(ctx, logits);
  714. // Compute lprobs here so we can modify it in place in the lprob tweaking phase
  715. // TODO: use ggml properly compute the tweaks
  716. ggml_cgraph gf = ggml_build_forward(lprobs);
  717. printf("beam search step %d. Graph.n_nodes: %d\n", step_nr, gf.n_nodes);
  718. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  719. ggml_detach(lprobs);
  720. // // Do not allow EOS before reaching the minimum sequence length.
  721. if (step_nr < job.opts.min_seq_len) {
  722. // lprobs[:, :, self.eos_idx] = -INFINITY;
  723. for (size_t i = 0; i < beam_size; ++i)
  724. ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
  725. }
  726. // If we have reached the maximum length, force the last step to be EOS.
  727. // TODO: should this be done in an adhoc loop ? how often does that happen anyway ?
  728. if (step_nr == max_seq_len - 2) {
  729. // lprobs[:, :, : self.eos_idx] = -torch.inf
  730. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  731. for (size_t b = 0; b < beam_size; ++b) {
  732. size_t t = 0;
  733. for (t = 0; t < eos_idx; ++t)
  734. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  735. for (t = eos_idx + 1; t < vocab_size; ++t)
  736. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  737. }
  738. }
  739. // Never allow PAD.
  740. for (size_t i = 0; i < beam_size; ++i)
  741. ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);
  742. // Apply UNK penalty.
  743. if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
  744. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  745. auto lprobs_raw = ggml_get_data_f32(lprobs);
  746. for (size_t i = 0; i < beam_size; ++i)
  747. lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
  748. }
  749. // Determine candidates for the next step.
  750. // (N, 2 x B)
  751. int topk = StandardBeamSearch_step(
  752. ctx,
  753. step_nr,
  754. step_nr == start_step,
  755. lprobs,
  756. ggml_slice(ctx, scores, 0, step_nr, step_nr+1),
  757. candidate_indices
  758. );
  759. std::size_t ongoing_beams = 0;
  760. int new_num_searches = 0;
  761. for (std::int32_t i = 0; i < topk; ++i) {
  762. int c = ggml_get_f32_1d(candidate_indices, i);
  763. float tok_score = ggml_get_f32_1d(lprobs, c);
  764. int finished = _finalize_hypothesis(job, ctx, step_nr, vocab_size, c, tok_score, seqs, scores, finished_searches);
  765. new_num_searches += finished;
  766. if (!finished){
  767. std::int32_t beam = c / vocab_size;
  768. std::int32_t token = c % vocab_size;
  769. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  770. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  771. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  772. ongoing_beams += 1 - finished;
  773. }
  774. if (ongoing_beams >= beam_size) break;
  775. if (finished_searches.size() >= beam_size)
  776. goto end_of_beam_search;
  777. }
  778. // Reorder beams in the `seq` and `score` buffers. The same beam can
  779. // be selected more than once.
  780. ggml_tensor* new_seqs = seqs;
  781. // ggml_get_rows and ggml_set only work with floats ...
  782. new_seqs->type = GGML_TYPE_F32;
  783. ggml_tensor* new_scores = scores;
  784. if (step_nr > start_step) {
  785. // (B, S), (B) -> (B, S)
  786. new_seqs = ggml_get_rows(ctx, seqs, beam_indices);
  787. new_scores = ggml_get_rows(ctx, new_scores, beam_indices);
  788. }
  789. // new_seqs[:, step_nr + 1] = next_tokens
  790. gf = ggml_build_forward(ggml_set_1d_inplace(ctx, new_seqs, next_tokens, new_seqs->nb[0] * (step_nr + 1)));
  791. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  792. ggml_detach(new_seqs);
  793. new_seqs->type = GGML_TYPE_I32;
  794. gf = ggml_build_forward(ggml_set_1d_inplace(ctx, new_scores, next_scores, new_scores->nb[0] * (step_nr + 1)));
  795. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  796. ggml_detach(new_scores);
  797. // TODO the old seqs and score buffers could be reused for next step
  798. seqs = new_seqs;
  799. scores = new_scores;
  800. }
  801. end_of_beam_search:
  802. // Ensure that hypotheses are sorted by decreasing scores before returning.
  803. std::sort(
  804. finished_searches.begin(),
  805. finished_searches.end(),
  806. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  807. );
  808. // For now just return the best sequence
  809. // TODO: return structured output
  810. *output_seq = *(finished_searches[0].seq);
  811. return finished_searches[0].score;
  812. }