fairseq2.cpp 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976
  1. #include <math.h>
  2. #include "ggml.h"
  3. #include "fairseq2.h"
  4. #include <unordered_map>
  5. #include <algorithm>
  6. /// allocate the fairseq2 model and hyperparameters
  7. extern "C" fairseq2_model* fairseq2_model_alloc() {
  8. // pre-allocate some memory to write hyperparameters and tensors pointers
  9. auto* model = new fairseq2_model;
  10. model->hparams = new std::uint8_t[8 * 1024];
  11. model->arch = new std::uint64_t[16 * 1024]; // max tensors allowed
  12. model->tensors_ctx = nullptr;
  13. return model;
  14. };
  15. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  16. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  17. delete (std::uint64_t*)(model->arch);
  18. delete (std::uint8_t*)model->hparams;
  19. delete model;
  20. };
  21. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  22. model->ctx = ctx;
  23. }
  /// Heap-allocates a std::string holding a copy of the NUL-terminated `c_str`.
  ///
  /// A null `c_str` yields an empty string: constructing std::string from a null
  /// pointer is undefined behavior, and C/FFI callers may legitimately pass null.
  /// The result must be released with std_string_free().
  extern "C" std::string* std_string_alloc(char* c_str) {
      if (c_str == nullptr) return new std::string();
      return new std::string(c_str);
  }
  /// Releases a string obtained from std_string_alloc(). nullptr is accepted.
  extern "C" void std_string_free(std::string* str) {
      if (str == nullptr) return;  // deleting null is a no-op anyway; make it explicit
      delete str;
  }
  30. bool has_layer(fairseq2_model& model, const std::string& name) {
  31. return model.tensors.find(name) != model.tensors.end();
  32. }
  33. extern "C" ggml_tensor* Linear_forward(
  34. fairseq2_model& model,
  35. const std::string &prefix,
  36. ggml_tensor* input // (d_in)
  37. ) {
  38. // Note: for now we assumed un-batched input
  39. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  40. GGML_ASSERT(weight != nullptr);
  41. ggml_tensor* out = ggml_mul_mat(model.ctx, weight, input); // (d_out)
  42. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  43. if (bias == nullptr) return out;
  44. return ggml_add_inplace(model.ctx, out, bias);
  45. }
  46. extern "C" ggml_tensor* LayerNorm_forward(
  47. fairseq2_model& model,
  48. const std::string &prefix,
  49. ggml_tensor* input
  50. ) {
  51. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  52. GGML_ASSERT(weight != nullptr);
  53. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  54. GGML_ASSERT(bias != nullptr);
  55. auto ctx = model.ctx;
  56. // TODO: should `eps` be part of unity hparams ?
  57. input = ggml_norm(ctx, input, /*eps*/1e-5);
  58. return ggml_add_inplace(
  59. ctx,
  60. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  61. ggml_repeat(ctx, bias, input)
  62. );
  63. }
  64. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  65. fairseq2_model& model,
  66. const std::string& prefix,
  67. ggml_tensor* seqs
  68. ) {
  69. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  70. // inner_activation = ReLu // TODO: allow other activation
  71. seqs = ggml_relu_inplace(model.ctx, seqs);
  72. if (has_layer(model, prefix + ".inner_layer_norm")) {
  73. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  74. }
  75. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  76. return seqs;
  77. }
  78. ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
  79. int n_dims = x->n_dims;
  80. GGML_ASSERT(dim >= 0);
  81. GGML_ASSERT(dim < n_dims);
  82. GGML_ASSERT(ggml_is_contiguous(x));
  83. // Nothing to do
  84. if (dim == n_dims - 1) return x;
  85. if (n_dims == 2) {
  86. return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
  87. } else if (n_dims == 3) {
  88. if (dim == 0) {
  89. return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
  90. } else { // dim == 1
  91. return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
  92. }
  93. } else { // n_dims == 4
  94. if (dim == 0) {
  95. return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
  96. } else if (dim == 1) {
  97. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
  98. } else { // dim == 2
  99. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
  100. }
  101. }
  102. }
  103. ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
  104. int n_dims = x->n_dims;
  105. GGML_ASSERT(dim >= 0);
  106. GGML_ASSERT(dim < n_dims);
  107. GGML_ASSERT(n_dims < 4);
  108. if (n_dims == 1) {
  109. return ggml_reshape_2d(ctx, x, num_el, x->ne[0] / num_el);
  110. } else if (n_dims == 2) {
  111. if (dim == 0) {
  112. return ggml_reshape_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1]);
  113. } else { // dim == 1
  114. return ggml_reshape_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el);
  115. }
  116. } else { // (n_dims == 3)
  117. if (dim == 0) {
  118. return ggml_reshape_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2]);
  119. } else if (dim == 1) {
  120. return ggml_reshape_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2]);
  121. } else { // dim == 2
  122. return ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el);
  123. }
  124. }
  125. }
  126. ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
  127. // (B, S, dim) -> (B, S, H, H_dim)
  128. x = ggml_unflatten_1d(ctx, x, 0, head_dim);
  129. x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
  130. x = ggml_cont(ctx, x);
  131. x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
  132. return x;
  133. }
  134. /// (B, Sk, dim) -> // (B?, H, H_dim, Sk)
  135. ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim ) {
  136. // (B, Sk, dim) -> (B, Sk, H, H_dim)
  137. v = ggml_unflatten_1d(ctx, v, 0, head_dim);
  138. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
  139. v = ggml_cont(ctx, v);
  140. v = ggml_flatten_1d(ctx, v, 2); // (B * H, S, H_dim)
  141. return v;
  142. }
  143. // flash_attn doesn't work for cross attention because it assumes Q <= K
  144. // TODO: enable flash_attn only for the encoder
  145. # define UNITY_FLASH_ATTN 0
  146. extern "C" ggml_tensor* MultiheadAttention_forward(
  147. fairseq2_model& model,
  148. const std::string &prefix,
  149. ggml_tensor* queries, // (slen, d_in)
  150. ggml_tensor* keys, // (klen, d_in)
  151. ggml_tensor* values, // (klen, d_out)
  152. ggml_tensor* mask // (klen, slen)
  153. ) {
  154. int model_dim = queries->ne[0];
  155. int num_heads = 16; // TODO: read from hparams
  156. int head_dim = model_dim / num_heads;
  157. GGML_ASSERT(model_dim % num_heads == 0);
  158. ggml_context* ctx = model.ctx;
  159. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
  160. ggml_set_name(q, "q");
  161. q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
  162. ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
  163. ggml_set_name(k, "k");
  164. k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
  165. ggml_tensor* v = Linear_forward(model, prefix + ".v_proj", values);
  166. ggml_set_name(v, "v");
  167. v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
  168. v = ggml_cont(ctx, v);
  169. #if UNITY_FLASH_ATTN
  170. // For flash_attn, we assume either no masks, or triangular masks.
  171. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr); // (B * H, S, H_dim)
  172. ggml_set_name(attn, "attn");
  173. // TODO test !
  174. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  175. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  176. #else
  177. // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
  178. ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
  179. ggml_set_name(qk, "qk");
  180. ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
  181. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  182. qk = ggml_scale(ctx, qk, qk_scale);
  183. ggml_set_name(qk, "qk_scaled");
  184. // TODO: Should we replace this by ggml_diag_mask_inf ?
  185. if (mask) qk = ggml_add(ctx, qk, mask);
  186. // TODO: upgrade qk to float32 if needed
  187. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
  188. ggml_set_name(attn_weights, "attn_weights");
  189. // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
  190. ggml_tensor* attn = ggml_mul_mat(ctx, attn_weights, v);
  191. ggml_set_name(attn, "attn");
  192. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  193. attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
  194. #endif // UNITY_FLASH_ATTN
  195. attn = ggml_cont(ctx, attn);
  196. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  197. // out -> (B, S, d_out)
  198. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  199. ggml_set_name(out, "out");
  200. return out;
  201. }
  202. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  203. fairseq2_model& model,
  204. const std::string& prefix,
  205. ggml_tensor* seqs,
  206. ggml_tensor* padding_mask
  207. ) {
  208. ggml_context* ctx = model.ctx;
  209. // TODO: read norm_order from model
  210. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  211. // _forward_self_attn(seqs, padding_mask)
  212. auto residual = seqs;
  213. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  214. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  215. // TODO: add padding_mask to MultiheadAttention_forward
  216. GGML_ASSERT(padding_mask == nullptr);
  217. seqs = MultiheadAttention_forward(
  218. model,
  219. prefix + ".self_attn",
  220. seqs,
  221. seqs,
  222. seqs,
  223. /*attention masks=*/nullptr
  224. );
  225. if (has_layer(model, prefix + ".self_attn_norm"))
  226. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  227. seqs = ggml_add(ctx, seqs, residual);
  228. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  229. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  230. // _forward_ffn(seqs)
  231. residual = seqs;
  232. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  233. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  234. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  235. // TODO: if self.residual_scale is not None:
  236. // residual = self.residual_scale * residual
  237. seqs = ggml_add(ctx, seqs, residual);
  238. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  239. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  240. return seqs;
  241. }
  242. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  243. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  244. struct ggml_tensor * ggml_slice(
  245. struct ggml_context * ctx,
  246. struct ggml_tensor * a,
  247. int axis,
  248. int64_t start,
  249. int64_t end
  250. ) {
  251. int64_t ne[4];
  252. std::copy(a->ne, a->ne + 4, ne);
  253. if (axis < 0) axis = a->n_dims + axis;
  254. if (start < 0) start = ne[axis] + start;
  255. if (end < 0) end = ne[axis] + end;
  256. GGML_ASSERT(0 <= start);
  257. GGML_ASSERT(start <= end);
  258. GGML_ASSERT(end <= ne[axis]);
  259. ne[axis] = end - start;
  260. size_t offset = a->nb[axis] * start;
  261. size_t* nb = a->nb;
  262. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  263. result->n_dims = a->n_dims;
  264. return result;
  265. }
  266. extern "C" ggml_tensor* PositionalEmbedding_forward(
  267. fairseq2_model& model,
  268. const std::string& prefix,
  269. ggml_tensor* embeds
  270. ) {
  271. // This only work with the simple pos encoders
  272. int seq_len = embeds->ne[1];
  273. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  274. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, 0, seq_len);
  275. return ggml_add(model.ctx, embeds, pos_embeds);
  276. }
  277. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  278. fairseq2_model& model,
  279. const std::string& prefix,
  280. ggml_tensor* seqs
  281. // TODO: state_bag
  282. ) {
  283. GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
  284. ggml_context* ctx = model.ctx;
  285. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  286. GGML_ASSERT(embed_weights != nullptr);
  287. ggml_tensor* embeds;
  288. if (seqs->n_dims == 1) {
  289. embeds = ggml_get_rows(ctx, embed_weights, seqs);
  290. } else {
  291. // ggml_get_rows isn't very flexible, we have to handle the reshape ourselves.
  292. embeds = ggml_get_rows(ctx, embed_weights, ggml_reshape_1d(ctx, seqs, ggml_nelements(seqs)));
  293. embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  294. embeds->n_dims = seqs->n_dims + 1;
  295. }
  296. // padding mask ?
  297. // padding_mask = to_padding_mask(embeds, seq_lens)
  298. if (has_layer(model, prefix + ".pos_encoder")) {
  299. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  300. }
  301. if (has_layer(model, prefix + ".layer_norm")) {
  302. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  303. }
  304. return embeds;
  305. }
  306. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  307. fairseq2_model& model,
  308. const std::string& prefix,
  309. ggml_tensor* seqs,
  310. ggml_tensor* padding_mask
  311. ) {
  312. int layer_idx = 0;
  313. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  314. while (has_layer(model, layer_name)) {
  315. seqs = StandardTransformerEncoderLayer_forward(
  316. model, layer_name, seqs, padding_mask
  317. );
  318. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  319. layer_idx += 1;
  320. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  321. }
  322. if (has_layer(model, prefix + ".layer_norm"))
  323. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  324. return seqs;
  325. }
  326. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  327. fairseq2_model& model,
  328. const std::string& prefix,
  329. ggml_tensor* seqs,
  330. ggml_tensor* self_attn_mask,
  331. ggml_tensor* encoder_output,
  332. ggml_tensor* encoder_padding_mask
  333. ) {
  334. ggml_context* ctx = model.ctx;
  335. // TODO: read norm_order from model
  336. auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
  337. // _forward_self_attn(seqs, padding_mask)
  338. auto residual = seqs;
  339. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  340. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  341. seqs = MultiheadAttention_forward(
  342. model,
  343. prefix + ".self_attn",
  344. seqs,
  345. seqs,
  346. seqs,
  347. /*attention masks=*/self_attn_mask
  348. );
  349. if (has_layer(model, prefix + ".self_attn_norm"))
  350. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  351. seqs = ggml_add(ctx, seqs, residual);
  352. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  353. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  354. // _forward_encoder_decoder_attn
  355. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  356. // `encoder_output` must be `None` for decoder-only attention.
  357. GGML_ASSERT(encoder_output == nullptr);
  358. return seqs;
  359. }
  360. // `encoder_output` must not be `None` for encoder-decoder attention.
  361. GGML_ASSERT(encoder_output != nullptr);
  362. residual = seqs;
  363. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  364. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  365. seqs = MultiheadAttention_forward(
  366. model,
  367. prefix + ".encoder_decoder_attn",
  368. seqs,
  369. encoder_output,
  370. encoder_output,
  371. /*attention masks=*/encoder_padding_mask
  372. );
  373. seqs = ggml_add(ctx, seqs, residual);
  374. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  375. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  376. // _forward_ffn(seqs)
  377. residual = seqs;
  378. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  379. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  380. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  381. // TODO:
  382. // if self.residual_scale is not None:
  383. // residual = self.residual_scale * residual
  384. seqs = ggml_add(ctx, seqs, residual);
  385. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  386. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  387. return seqs;
  388. }
  389. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  390. auto seq_len = seqs->ne[1];
  391. // TODO: allow other ggml_type
  392. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  393. return ggml_diag_mask_inf(ctx, mask, 0);
  394. }
  395. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  396. fairseq2_model& model,
  397. const std::string& prefix,
  398. ggml_tensor* seqs,
  399. ggml_tensor* padding_mask,
  400. ggml_tensor* encoder_output,
  401. ggml_tensor* encoder_padding_mask
  402. ) {
  403. int layer_idx = 0;
  404. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  405. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  406. while (has_layer(model, layer_name)) {
  407. seqs = StandardTransformerDecoderLayer_forward(
  408. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  409. );
  410. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  411. layer_idx += 1;
  412. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  413. }
  414. if (has_layer(model, prefix + ".layer_norm"))
  415. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  416. return seqs;
  417. }
  418. using IncrementalStateBag = std::unordered_map<ggml_tensor*, ggml_tensor*>*;
  419. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  420. auto opts = job.opts;
  421. int max_seq_len = -1;
  422. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  423. max_seq_len = opts.hard_max_seq_len;
  424. } else {
  425. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len) + opts.soft_max_seq_len_b);
  426. }
  427. if (opts.min_seq_len > max_seq_len) {
  428. printf(
  429. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  430. opts.min_seq_len,
  431. max_seq_len
  432. );
  433. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  434. }
  435. int prefix_seq_len = job.prefix_seq->ne[0];
  436. if (prefix_seq_len >= max_seq_len) {
  437. printf(
  438. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  439. prefix_seq_len,
  440. max_seq_len
  441. );
  442. GGML_ASSERT(prefix_seq_len < max_seq_len);
  443. }
  444. return max_seq_len;
  445. }
  446. void _fan_out_encoder_output(
  447. ggml_context* ctx,
  448. ggml_tensor** encoder_output_out,
  449. ggml_tensor** encoder_padding_mask_out,
  450. int beam_size
  451. ) {
  452. // (S_enc, M)
  453. ggml_tensor* encoder_output = *encoder_output_out;
  454. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  455. // (B, S_enc, M)
  456. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  457. // (S_enc, M) -> (B, S_enc, M)
  458. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  459. // (S_enc) -> (B, S_enc)
  460. ggml_tensor* shape_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], beam_size);
  461. if (encoder_padding_mask != nullptr) {
  462. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  463. }
  464. }
  465. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  466. // TODO: this isn't the most precise way of doing this
  467. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  468. }
  469. ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
  470. ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
  471. ggml_type true_type = x->type;
  472. x->type = GGML_TYPE_F32;
  473. ggml_tensor* y = ggml_repeat(ctx, x, shape);
  474. y->type = true_type;
  475. return y;
  476. }
  477. void _bootstrap_seqs_and_scores(
  478. fairseq2_model& model,
  479. const SequenceGeneratorJob& job,
  480. ggml_tensor* full_seqs,
  481. ggml_tensor* scores,
  482. ggml_tensor* encoder_output,
  483. ggml_tensor* encoder_padding_mask,
  484. IncrementalStateBag state_bag
  485. ) {
  486. int prefix_seq_len = job.prefix_seq->ne[0];
  487. int max_seq_len = scores->ne[0];
  488. int beam_size = scores->ne[1];
  489. GGML_ASSERT(prefix_seq_len > 0);
  490. if (prefix_seq_len == 1)
  491. return;
  492. ggml_context* ctx = model.ctx;
  493. // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  494. full_seqs->type = GGML_TYPE_F32;
  495. job.prefix_seq->type = GGML_TYPE_F32;
  496. ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
  497. seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);
  498. // We have to bootstrap the model with the already fanned-out encoder
  499. // output to correctly initialize its incremental state.
  500. // Note: we don't start decoding the last prefix token just yet.
  501. seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);
  502. // Bootstrap the model state with prefix sequence.
  503. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  504. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  505. model,
  506. "text_decoder",
  507. seqs,
  508. /*padding_mask*/ nullptr,
  509. encoder_output,
  510. encoder_padding_mask
  511. // TODO: state_bag
  512. );
  513. // TODO state_bag.increment_step(prefix_seq_len - 1)
  514. // logits, lprobs: (N, S_pfx - 1, V)
  515. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  516. int vocab_size = logits->ne[0];
  517. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
  518. ggml_cgraph gf = ggml_build_forward(lprobs);
  519. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  520. full_seqs->type = GGML_TYPE_I32;
  521. job.prefix_seq->type = GGML_TYPE_I32;
  522. // Fetch scores of next steps from "lprobs"
  523. float p_score = 0;
  524. for (int i = 0; i < prefix_seq_len; ++i) {
  525. int p = ggml_get_i32_1d(job.prefix_seq, i);
  526. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  527. for (int b = 0; b < beam_size; ++b) {
  528. // scores: (N, S)
  529. // Note: First step (e.g. BOS)'s score is always 0.
  530. ggml_set_f32_1d(scores, b * max_seq_len + i + 1, p_score);
  531. }
  532. }
  533. }
  /// A finished hypothesis produced by the sequence generator.
  struct Hypothesis {
      /// Token ids of the generated sequence (an I32 tensor in _finalize_hypothesis).
      struct ggml_tensor* seq;
      /// Overall score, length-normalized when `normalize_scores` is enabled.
      float score;
      /// Score contributed by each step (an F32 tensor in _finalize_hypothesis).
      struct ggml_tensor* step_scores;
  };
  543. /// Represents a standard beam search algoritm.
  544. int StandardBeamSearch_step(
  545. ggml_context* ctx,
  546. int step_nr,
  547. bool is_start_step,
  548. ggml_tensor* lprobs, // (B, V)
  549. ggml_tensor* last_scores, // (B)
  550. ggml_tensor* candidate_indices
  551. ) {
  552. GGML_ASSERT(lprobs->n_dims == 2);
  553. int vocab_size = lprobs->ne[0];
  554. int beam_size = lprobs->ne[1];
  555. GGML_ASSERT(last_scores->n_dims == 2);
  556. GGML_ASSERT(last_scores->ne[0] == 1);
  557. GGML_ASSERT(last_scores->ne[1] == beam_size);
  558. GGML_ASSERT(candidate_indices->ne[0] == beam_size * vocab_size);
  559. // should this be done by the caller ?
  560. if (is_start_step) {
  561. // At the initial step, all hypotheses are equally likely, so we use
  562. // only the first beam.
  563. lprobs = ggml_slice(ctx, lprobs, 1, 0, 1);
  564. lprobs = ggml_cont(ctx, lprobs);
  565. // The first step always indicates the beginning of the sequence and
  566. // has no score.
  567. if (step_nr > 0) {
  568. lprobs = ggml_add_inplace(ctx, lprobs, ggml_repeat(ctx, last_scores, lprobs));
  569. }
  570. } else {
  571. // Make probabilities contain cumulative scores for each hypothesis.
  572. // TODO this seems incorrect
  573. lprobs = ggml_add(ctx, lprobs, ggml_repeat(ctx, last_scores, lprobs));
  574. }
  575. ggml_cgraph gf = ggml_build_forward(lprobs);
  576. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  577. // Take the best 2 x `beam_size` predictions. We'll choose the first
  578. // `beam_size` of these which don't predict EOS to continue with.
  579. // (N, 2 x B)
  580. // `vocab_size` - 1 to never select PAD.
  581. int topk = std::min(2 * beam_size, vocab_size - 1);
  582. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  583. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  584. };
  585. auto cand = (std::int32_t*)candidate_indices->data;
  586. std::partial_sort(cand, cand + topk, cand + (beam_size * vocab_size), comp);
  587. return topk;
  588. }
  589. void ggml_detach(ggml_tensor* a) {
  590. a->op = GGML_OP_NONE;
  591. a->src[0] = nullptr;
  592. }
  593. int _finalize_hypothesis(
  594. const SequenceGeneratorJob& job,
  595. ggml_context* ctx,
  596. int step_nr,
  597. int vocab_size,
  598. std::int32_t candidate,
  599. float tok_score,
  600. ggml_tensor* seqs, // (beam_size, seq_len)
  601. ggml_tensor* scores, // (beam_size, seq_len)
  602. std::vector<Hypothesis>& hypotheses
  603. ) {
  604. std::int32_t beam = candidate / vocab_size;
  605. std::int32_t token = candidate % vocab_size;
  606. // Detect beams that reached the minimum length and that end with an EOS.
  607. bool eos = token == job.eos_idx;
  608. eos &= tok_score != -INFINITY;
  609. if (!eos) return 0;
  610. // If the candidate beam is "finished", let's copy the score and sequence
  611. ggml_tensor* tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  612. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  613. auto tok = (std::int32_t*)tokens->data;
  614. for (int i = 0; i < step_nr + 1; ++i) {
  615. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  616. }
  617. tok[step_nr + 1] = token;
  618. // Convert from cumulative to per-step scores.
  619. auto sc = (float*)step_scores->data;
  620. float last_score = tok_score;
  621. for (int i = step_nr; i >= 0; --i) {
  622. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  623. sc[i] = last_score - sc0;
  624. last_score = sc0;
  625. }
  626. if (job.opts.normalize_scores)
  627. // Skip first EOS since it is always 0 and skews normalization.
  628. tok_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  629. // TODO the score computed here isn't the same than computed by fairseq2.
  630. hypotheses.emplace_back(Hypothesis{tokens, tok_score, step_scores});
  631. return 1;
  632. }
  633. /// Generates a translation for a single sequence
  634. // TODO: finish this for beam_size=1
  635. // * find out why score is different (seq is the same though)
  636. // TODO: add IncrementalStateBag support to avoid a O(N^3) generation.
  637. // TODO: support beam_size > 1:
  638. // * most layers assume un-batched input, but we want to handle several beams at once
  639. // * need to port "reorder_state_dict"
  640. // TODO: clean up
  641. // * replace manual tensor tweaking with ggml_set_*d (ggml_set_slice could be useful)
  642. extern "C" float generate_sequence(
  643. fairseq2_model& model,
  644. const SequenceGeneratorJob& job,
  645. ggml_tensor* encoder_output,
  646. ggml_tensor* encoder_padding_mask,
  647. ggml_tensor* output_seq
  648. ) {
  649. ggml_context* ctx = model.ctx;
  650. size_t eos_idx = job.eos_idx;
  651. auto pad_idx = job.pad_idx;
  652. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  653. size_t vocab_size = embed->ne[1];
  654. std::size_t beam_size = job.opts.beam_size;
  655. int source_seq_len = encoder_output->ne[1];
  656. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  657. // (S_enc, M) -> (B, S_enc, M)
  658. _fan_out_encoder_output(ctx, &encoder_output, &encoder_padding_mask, beam_size);
  659. std::vector<Hypothesis> finished_searches;
  660. finished_searches.reserve(beam_size);
  661. // Initialize buffers. (B, S)
  662. ggml_tensor* seqs = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  663. ggml_set_i32(seqs, 0);
  664. ggml_tensor* scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  665. ggml_set_f32(scores, 0.0);
  666. IncrementalStateBag state_bag = {};
  667. _bootstrap_seqs_and_scores(
  668. model, job, seqs, scores, encoder_output, encoder_padding_mask, state_bag
  669. );
  670. int prefix_seq_len = job.prefix_seq->ne[0];
  671. int start_step = prefix_seq_len - 1;
  672. // Holds the indices of beams (a beam can occur more than once) that we
  673. // should continue with in the next step.
  674. ggml_tensor* beam_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
  675. ggml_tensor* next_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, beam_size);
  676. ggml_tensor* next_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, beam_size);
  677. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  678. ggml_tensor* candidate_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, vocab_size * beam_size);
  679. for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
  680. ((int32_t *)(candidate_indices->data))[i] = i;
  681. // TODO: memory management
  682. // there should be a per-step ggml_context for intermediary results
  683. // start of beam search:
  684. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  685. // if (beam_indices != nullptr) {
  686. // // If not `None`, it means in the last step we finalized one or
  687. // // more searches. We should ensure that we adjust `beam_indices`
  688. // // before reordering `decoder`'s incremental state.
  689. // if (search_indices != nullptr) {
  690. // num_searches = search_indices->ne[0];
  691. // // (N)
  692. // delta = search_indices - torch.arange(num_searches, device=device)
  693. // // (N) -> (N, 1)
  694. // delta.unsqueeze_(-1)
  695. // // Adjust indices to take into account removed searches.
  696. // beam_indices.view(num_searches, beam_size).add_(delta * beam_size)
  697. // }
  698. // // state_bag.reorder(beam_indices)
  699. // }
  700. // because of no IncrementalStateBag we pass input from the start
  701. // decoder_input = seqs[:, 0 : step_nr + 1]
  702. ggml_tensor* decoder_input = ggml_slice(ctx, seqs, 0, 0, step_nr + 1);
  703. decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", decoder_input);
  704. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  705. model,
  706. "text_decoder",
  707. decoder_input,
  708. nullptr, // We never generate PAD.
  709. encoder_output,
  710. encoder_padding_mask
  711. // state_bag=state_bag,
  712. ); // (B, S, D)
  713. // state_bag.increment_step()
  714. // Because of no IncrementalStateBag decoder_output here is of shape (B, S, D)
  715. // Just look at the last token.
  716. decoder_output = ggml_slice(ctx, decoder_output, 1, step_nr, step_nr+1);
  717. decoder_output = ggml_cont(ctx, decoder_output);
  718. decoder_output = ggml_flatten_1d(ctx, decoder_output, 0); // (B, model_dim)
  719. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output); // (B, vocab_size)
  720. ggml_tensor* lprobs = ggml_log_softmax(ctx, logits);
  721. // Compute lprobs here so we can modify it in place in the lprob tweaking phase
  722. // TODO: use ggml properly compute the tweaks
  723. ggml_cgraph gf = ggml_build_forward(lprobs);
  724. printf("beam search step %d. Graph.n_nodes: %d\n", step_nr, gf.n_nodes);
  725. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  726. ggml_detach(lprobs);
  727. // // Do not allow EOS before reaching the minimum sequence length.
  728. if (step_nr < job.opts.min_seq_len) {
  729. // lprobs[:, :, self.eos_idx] = -INFINITY;
  730. for (size_t i = 0; i < beam_size; ++i)
  731. ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
  732. }
  733. // If we have reached the maximum length, force the last step to be EOS.
  734. // TODO: should this be done in an adhoc loop ? how often does that happen anyway ?
  735. if (step_nr == max_seq_len - 2) {
  736. // lprobs[:, :, : self.eos_idx] = -torch.inf
  737. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  738. for (size_t b = 0; b < beam_size; ++b) {
  739. size_t t = 0;
  740. for (t = 0; t < eos_idx; ++t)
  741. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  742. for (t = eos_idx + 1; t < vocab_size; ++t)
  743. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  744. }
  745. }
  746. // Never allow PAD.
  747. for (size_t i = 0; i < beam_size; ++i)
  748. ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);
  749. // Apply UNK penalty.
  750. if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
  751. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  752. auto lprobs_raw = ggml_get_data_f32(lprobs);
  753. for (size_t i = 0; i < beam_size; ++i)
  754. lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
  755. }
  756. // Determine candidates for the next step.
  757. // (N, 2 x B)
  758. int topk = StandardBeamSearch_step(
  759. ctx,
  760. step_nr,
  761. step_nr == start_step,
  762. lprobs,
  763. ggml_slice(ctx, scores, 0, step_nr, step_nr+1),
  764. candidate_indices
  765. );
  766. std::size_t ongoing_beams = 0;
  767. int new_num_searches = 0;
  768. for (std::int32_t i = 0; i < topk; ++i) {
  769. int c = ggml_get_f32_1d(candidate_indices, i);
  770. float tok_score = ggml_get_f32_1d(lprobs, c);
  771. int finished = _finalize_hypothesis(job, ctx, step_nr, vocab_size, c, tok_score, seqs, scores, finished_searches);
  772. new_num_searches += finished;
  773. if (!finished){
  774. std::int32_t beam = c / vocab_size;
  775. std::int32_t token = c % vocab_size;
  776. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  777. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  778. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  779. ongoing_beams += 1 - finished;
  780. }
  781. if (ongoing_beams >= beam_size) break;
  782. if (finished_searches.size() >= beam_size)
  783. goto end_of_beam_search;
  784. }
  785. // Reorder beams in the `seq` and `score` buffers. The same beam can
  786. // be selected more than once.
  787. ggml_tensor* new_seqs = seqs;
  788. // ggml_get_rows and ggml_set only work with floats ...
  789. new_seqs->type = GGML_TYPE_F32;
  790. ggml_tensor* new_scores = scores;
  791. if (step_nr > start_step) {
  792. // (B, S), (B) -> (B, S)
  793. new_seqs = ggml_get_rows(ctx, seqs, beam_indices);
  794. new_scores = ggml_get_rows(ctx, new_scores, beam_indices);
  795. }
  796. // new_seqs[:, step_nr + 1] = next_tokens
  797. gf = ggml_build_forward(ggml_set_1d_inplace(ctx, new_seqs, next_tokens, new_seqs->nb[0] * (step_nr + 1)));
  798. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  799. ggml_detach(new_seqs);
  800. new_seqs->type = GGML_TYPE_I32;
  801. gf = ggml_build_forward(ggml_set_1d_inplace(ctx, new_scores, next_scores, new_scores->nb[0] * (step_nr + 1)));
  802. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  803. ggml_detach(new_scores);
  804. // TODO the old seqs and score buffers could be reused for next step
  805. seqs = new_seqs;
  806. scores = new_scores;
  807. }
  808. end_of_beam_search:
  809. // Ensure that hypotheses are sorted by decreasing scores before returning.
  810. std::sort(
  811. finished_searches.begin(),
  812. finished_searches.end(),
  813. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  814. );
  815. // For now just return the best sequence
  816. // TODO: return structured output
  817. *output_seq = *(finished_searches[0].seq);
  818. return finished_searches[0].score;
  819. }