fairseq2.cpp 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711
#include <algorithm>
#include <cstring>
#include <fnmatch.h>
#include <iostream>
#include <math.h>
#include <queue>
#include <unordered_map>
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/feature-window.h"
#include "tracy/Tracy.hpp"
#include "fairseq2.h"
#include "ggml.h"
  12. ggml_tensor* ggml_detach(ggml_tensor* a) {
  13. a->op = GGML_OP_NONE;
  14. std::fill(a->src, a->src + GGML_MAX_SRC, nullptr);
  15. return a;
  16. }
  17. #define DEBUG_MEM_USAGE 1
  18. void printf_mem_usage(ggml_context* ctx, std::string name) {
  19. #if DEBUG_MEM_USAGE
  20. double mb = 1024.0 * 1024.0;
  21. printf(
  22. "ctx %s: memory used = %8.2f MB, memory reserved = %8.2f Mb\n",
  23. name.c_str(),
  24. ggml_used_mem(ctx) / mb,
  25. ggml_get_mem_size(ctx) / mb
  26. );
  27. #endif
  28. }
  29. #define SWAP(x, y) \
  30. auto tmp_ ## x = x; x = y; y = tmp_ ## x;
  31. /// allocate the fairseq2 model and hyperparameters
  32. extern "C" fairseq2_model* fairseq2_model_alloc() {
  33. // pre-allocate some memory to write hyperparameters and tensors pointers
  34. auto* model = new fairseq2_model;
  35. model->tensors_ctx = nullptr;
  36. return model;
  37. }
  38. extern "C" void fairseq2_kv_cache_alloc(const fairseq2_model& model, int beam_size, int max_seq_len) {
  39. // Note: we only allocate the cache for the decoder attention.
  40. // For encoder attention since we compute it all at once,
  41. // the allocation is delayed to the first forward pass, to not over allocate.
  42. auto attn_glob = "text_decoder.*_attn.k_proj.weight";
  43. auto self_attn_glob = "text_decoder.*self_attn.k_proj.weight";
  44. ggml_tensor* self_attn_mask = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, max_seq_len, max_seq_len);
  45. self_attn_mask = ggml_diag_mask_inf_inplace(model.ctx, self_attn_mask, 0);
  46. ggml_format_name(self_attn_mask, "self_attn_mask[%d]", max_seq_len);
  47. for (auto named_tensor : model.tensors) {
  48. const std::string& name = named_tensor.first;
  49. if (::fnmatch(attn_glob, name.c_str(), 0) == FNM_NOMATCH)
  50. continue;
  51. // create a cache entry without the ".k_proj.weight" suffix
  52. const std::string& shortname = name.substr(0, name.size() - 14);
  53. KeyValueTensor& kv = model.kv_cache[shortname];
  54. kv.step_nr = 0;
  55. if (::fnmatch(self_attn_glob, name.c_str(), 0) == FNM_NOMATCH) {
  56. // enc_dec_attn
  57. // the tensors will be allocated during the first forward
  58. continue;
  59. }
  60. // self_attn
  61. ggml_tensor* k_proj = named_tensor.second;
  62. int model_dim = k_proj->ne[0];
  63. kv.full_k = ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size);
  64. kv.full_v = ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size);
  65. kv.self_attn_mask = self_attn_mask;
  66. ggml_format_name(kv.full_k, "%s.k_cache", shortname.c_str());
  67. ggml_format_name(kv.full_v, "%s.v_cache", shortname.c_str());
  68. }
  69. }
  70. extern "C" void fairseq2_kv_cache_reset(const fairseq2_model& model) {
  71. // TODO: use a dedicated allocator, so that kv_cache.clear actually frees the memory
  72. model.kv_cache.clear();
  73. }
  74. bool has_kv_cache(const fairseq2_model& model) {
  75. return model.kv_cache.size() > 0;
  76. }
  77. // copy k and v to kv cache
  78. // kv.full_k[step_nr] = k;
  79. // kv.full_v[step_nr] = v;
  80. void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
  81. KeyValueTensor& kv = model.kv_cache[prefix];
  82. GGML_ASSERT(kv.full_k != nullptr); // key not found !
  83. int step_nr = kv.step_nr;
  84. ggml_tensor* full_k = kv.full_k;
  85. ggml_tensor* full_v = kv.full_v;
  86. // (N, S_kv, K_proj)
  87. GGML_ASSERT((*k)->ne[1] == 1); // TODO I think we could handle adding a full prefix sequence
  88. ggml_tensor* updated_k = ggml_set_2d_inplace(model.ctx, full_k, *k, full_k->nb[2], full_k->nb[1] * step_nr);
  89. ggml_tensor* updated_v = ggml_set_2d_inplace(model.ctx, full_v, *v, full_v->nb[2], full_v->nb[1] * step_nr);
  90. *k = ggml_slice(model.ctx, updated_k, 1, 0, step_nr + 1);
  91. *v = ggml_slice(model.ctx, updated_v, 1, 0, step_nr + 1);
  92. ggml_format_name(*k, "%s (step=%d)", full_k->name, step_nr);
  93. ggml_format_name(*v, "%s (step=%d)", full_v->name, step_nr);
  94. // qk is (B * H, Sq, Sk) == (B*H, 1, Sk) in incremental mode
  95. // we return the Sq slice of the (Sq, Sk) attention mask
  96. *self_attn_mask = ggml_slice(
  97. model.ctx,
  98. ggml_slice(model.ctx, kv.self_attn_mask, 0, 0, step_nr + 1),
  99. 1,
  100. step_nr,
  101. step_nr + 1
  102. );
  103. kv.step_nr = step_nr + 1;
  104. }
  105. // variant of ggml_get_rows that allows for a with more than 2 dims.
  106. ggml_tensor* ggml_get_rows2(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  107. int flattened = 0;
  108. GGML_ASSERT(a->n_dims <= 3);
  109. if (a->n_dims == 3) {
  110. flattened = a->ne[0];
  111. a = ggml_flatten_1d(ctx, a, 0);
  112. }
  113. a = ggml_get_rows(ctx, a, b);
  114. if (flattened) {
  115. a = ggml_unflatten_1d(ctx, a, 0, flattened);
  116. }
  117. return a;
  118. }
  119. void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, ggml_tensor* new_order) {
  120. if (kv.full_k != nullptr) {
  121. ggml_detach(kv.full_k);
  122. kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
  123. ggml_build_forward_expand(gf, kv.full_k);
  124. }
  125. if (kv.full_v != nullptr) {
  126. ggml_detach(kv.full_v);
  127. kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
  128. ggml_build_forward_expand(gf, kv.full_v);
  129. }
  130. }
  131. void reorder_kv_cache(const fairseq2_model& model, ggml_context* ctx, ggml_cgraph* gf, ggml_tensor* new_order) {
  132. for (auto& named_kv : model.kv_cache) {
  133. _reorder_kv_cache(ctx, gf, named_kv.second, new_order);
  134. }
  135. }
  136. inline double model_layer_config_d(const fairseq2_model& model, std::string name) {
  137. const std::int64_t* data = &model.layer_config.at(name);
  138. double val = *(const double*)data;
  139. return val;
  140. }
  141. extern "C" double fairseq2_model_layer_config_double(const fairseq2_model& model, const char* name) {
  142. return model_layer_config_d(model, std::string(name));
  143. }
  144. extern "C" std::int64_t fairseq2_model_layer_config_int(const fairseq2_model& model, const char* name) {
  145. return model.layer_config.at(std::string(name));
  146. }
  147. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  148. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  149. delete model;
  150. }
  151. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  152. model->ctx = ctx;
  153. }
  154. extern "C" std::string* std_string_alloc(char* c_str) {
  155. return new std::string(c_str);
  156. }
  157. extern "C" void std_string_free(std::string* str) {
  158. delete str;
  159. }
  160. bool has_layer(fairseq2_model& model, const std::string& name) {
  161. return model.tensors.find(name) != model.tensors.end();
  162. }
  163. ggml_tensor* mul_mat(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  164. if (b->ne[1] == 1 && b->ne[2] > 1 && a->n_dims == 2) {
  165. // `b` has shape (B, 1, D).
  166. // if `a` is (D_out, D), then we do one matmul for the full batch.
  167. b = ggml_flatten_1d(ctx, b, 1);
  168. return ggml_unflatten_1d(ctx, ggml_mul_mat(ctx, a, b), 1, 1);
  169. }
  170. // there is also the k * q matmul -> (D, 1, B) * (D, 1, B) -> (1, 1, B)
  171. // not sure what's the best way to compute this with BLAS
  172. return ggml_mul_mat(ctx, a, b); // (d_out)
  173. }
  174. extern "C" ggml_tensor* Linear_forward(
  175. fairseq2_model& model,
  176. const std::string &prefix,
  177. ggml_tensor* input // (d_in)
  178. ) {
  179. // Note: for now we assumed un-batched input
  180. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  181. GGML_ASSERT(weight != nullptr);
  182. ggml_tensor* out = mul_mat(model.ctx, weight, input); // (d_out)
  183. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  184. if (bias == nullptr) return out;
  185. return ggml_add_inplace(model.ctx, out, bias);
  186. }
  187. extern "C" ggml_tensor* LayerNorm_forward(
  188. fairseq2_model& model,
  189. const std::string &prefix,
  190. ggml_tensor* input
  191. ) {
  192. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  193. GGML_ASSERT(weight != nullptr);
  194. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  195. GGML_ASSERT(bias != nullptr);
  196. auto ctx = model.ctx;
  197. double eps = model_layer_config_d(model, prefix + ".eps");
  198. input = ggml_norm(ctx, input, /*eps*/eps);
  199. return ggml_add_inplace(
  200. ctx,
  201. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  202. ggml_repeat(ctx, bias, input)
  203. );
  204. }
  205. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  206. fairseq2_model& model,
  207. const std::string& prefix,
  208. ggml_tensor* seqs
  209. ) {
  210. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  211. // inner_activation = ReLu // TODO: allow other activation
  212. seqs = ggml_relu_inplace(model.ctx, seqs);
  213. if (has_layer(model, prefix + ".inner_layer_norm")) {
  214. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  215. }
  216. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  217. return seqs;
  218. }
  219. extern "C" ggml_tensor* SiluFeedForwardNetwork_forward(
  220. fairseq2_model& model,
  221. const std::string& prefix,
  222. ggml_tensor* seqs
  223. ) {
  224. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  225. seqs = ggml_silu(model.ctx, seqs);
  226. if (has_layer(model, prefix + ".inner_layer_norm")) {
  227. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  228. }
  229. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  230. return seqs;
  231. }
  232. ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
  233. int n_dims = x->n_dims;
  234. GGML_ASSERT(dim >= 0);
  235. GGML_ASSERT(dim < n_dims);
  236. GGML_ASSERT(ggml_is_contiguous(x));
  237. // Nothing to do
  238. if (dim == n_dims - 1) return x;
  239. if (n_dims == 2) {
  240. return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
  241. } else if (n_dims == 3) {
  242. if (dim == 0) {
  243. return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
  244. } else { // dim == 1
  245. return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
  246. }
  247. } else { // n_dims == 4
  248. if (dim == 0) {
  249. return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
  250. } else if (dim == 1) {
  251. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
  252. } else { // dim == 2
  253. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
  254. }
  255. }
  256. }
  257. ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
  258. int n_dims = x->n_dims;
  259. GGML_ASSERT(dim >= 0);
  260. GGML_ASSERT(dim < n_dims);
  261. GGML_ASSERT(n_dims < 4);
  262. GGML_ASSERT(x->ne[dim] % num_el == 0);
  263. GGML_ASSERT(x->nb[dim + 1] == x->nb[dim] * x->ne[dim]); // `x` isn't contiguous along `dim`
  264. if (n_dims == 1) {
  265. return ggml_view_2d(ctx, x, num_el, x->ne[0] / num_el, x->nb[0] * num_el, 0);
  266. } else if (n_dims == 2) {
  267. if (dim == 0) {
  268. return ggml_view_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->nb[0] * num_el, x->nb[1], 0);
  269. } else { // dim == 1
  270. return ggml_view_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->nb[1], num_el * x->nb[1], 0);
  271. }
  272. } else { // (n_dims == 3)
  273. if (dim == 0) {
  274. return ggml_view_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2], x->nb[0] * num_el, x->nb[1], x->nb[2], 0);
  275. } else if (dim == 1) {
  276. return ggml_view_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2], x->nb[1], num_el * x->nb[1], x->nb[2], 0);
  277. } else { // dim == 2
  278. return ggml_view_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el, x->nb[1], x->nb[2], num_el * x->nb[2], 0);
  279. }
  280. }
  281. }
  282. ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
  283. // (B, S, dim) -> (B, S, H, H_dim)
  284. x = ggml_unflatten_1d(ctx, x, 0, head_dim);
  285. x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
  286. x = ggml_cont(ctx, x);
  287. x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
  288. return x;
  289. }
  290. /// (B, Sk, dim) -> // (B?, H, H_dim, Sk)
  291. ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim ) {
  292. // (B, Sk, dim) -> (B, Sk, H, H_dim)
  293. v = ggml_unflatten_1d(ctx, v, 0, head_dim);
  294. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
  295. v = ggml_cont(ctx, v);
  296. v = ggml_flatten_1d(ctx, v, 2); // (B * H, S, H_dim)
  297. return v;
  298. }
  299. // flash_attn doesn't work for cross attention because it assumes Q <= K
  300. // and it seems to yield slightly different scores than expected, and thus a different beam search
  301. # define UNITY_FLASH_ATTN 0
  302. extern "C" ggml_tensor* MultiheadAttention_forward(
  303. fairseq2_model& model,
  304. const std::string &prefix,
  305. ggml_tensor* queries, // (slen, d_in)
  306. ggml_tensor* keys, // (klen, d_in)
  307. ggml_tensor* values, // (klen, d_out)
  308. ggml_tensor* attn_mask // (klen, slen)
  309. ) {
  310. int model_dim = queries->ne[0];
  311. int num_heads = model.layer_config.at(prefix + ".num_heads");
  312. int head_dim = model_dim / num_heads;
  313. GGML_ASSERT(model_dim % num_heads == 0);
  314. ggml_context* ctx = model.ctx;
  315. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
  316. ggml_set_name(q, "q");
  317. q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
  318. ggml_tensor *k, *v;
  319. if (!has_kv_cache(model)) {
  320. k = Linear_forward(model, prefix + ".k_proj", keys);
  321. ggml_set_name(k, "k");
  322. v = Linear_forward(model, prefix + ".v_proj", values);
  323. ggml_set_name(v, "v");
  324. } else {
  325. bool encoder_decoder_attn = keys == values && keys != queries;
  326. if (encoder_decoder_attn) {
  327. // The K and V tensors of an encoder-decoder attention (i.e. the
  328. // projected encoder outputs) remain static during evaluation.
  329. KeyValueTensor& kv_cache = model.kv_cache[prefix];
  330. if (kv_cache.step_nr == 0) {
  331. k = Linear_forward(model, prefix + ".k_proj", keys);
  332. v = Linear_forward(model, prefix + ".v_proj", values);
  333. // TODO: encoder_padding_mask
  334. // Note we are only storing a pointer to the buffer, not the full graph
  335. kv_cache.full_k = ggml_detach(ggml_dup_inplace(ctx, k));
  336. ggml_format_name(kv_cache.full_k, "%s.k_cache", prefix.c_str());
  337. kv_cache.full_v = ggml_detach(ggml_dup_inplace(ctx, v));
  338. ggml_format_name(kv_cache.full_v, "%s.v_cache", prefix.c_str());
  339. kv_cache.step_nr = keys->ne[1];
  340. } else {
  341. k = kv_cache.full_k;
  342. v = kv_cache.full_v;
  343. // This is a cache collision. TODO: fairseq2_kv_cache_reset
  344. GGML_ASSERT(keys->ne[1] == k->ne[1]);
  345. GGML_ASSERT(values->ne[1] == v->ne[1]);
  346. }
  347. } else { // self attention
  348. // (1, K) -> (N, 1, K_proj)
  349. k = Linear_forward(model, prefix + ".k_proj", keys);
  350. ggml_set_name(k, "k");
  351. // (1, V) -> (N, 1, V_proj)
  352. v = Linear_forward(model, prefix + ".v_proj", values);
  353. ggml_set_name(v, "v");
  354. append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
  355. }
  356. }
  357. k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
  358. v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
  359. v = ggml_cont(ctx, v);
  360. #if UNITY_FLASH_ATTN
  361. // For flash_attn, we assume either no masks, or triangular masks.
  362. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr); // (B * H, S, H_dim)
  363. ggml_set_name(attn, "attn");
  364. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  365. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  366. #else
  367. // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
  368. ggml_tensor* qk = mul_mat(ctx, k, q);
  369. ggml_set_name(qk, "qk");
  370. ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
  371. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  372. qk = ggml_scale_inplace(ctx, qk, qk_scale);
  373. ggml_set_name(qk, "qk_scaled");
  374. // TODO: Should we replace this by ggml_diag_mask_inf ?
  375. if (attn_mask) qk = ggml_add_inplace(ctx, qk, attn_mask);
  376. // TODO: upgrade qk to float32 if needed
  377. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
  378. ggml_set_name(attn_weights, "attn_weights");
  379. // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
  380. ggml_tensor* attn = mul_mat(ctx, attn_weights, v);
  381. ggml_set_name(attn, "attn");
  382. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  383. attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
  384. #endif // UNITY_FLASH_ATTN
  385. attn = ggml_cont(ctx, attn);
  386. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  387. // out -> (B, S, d_out)
  388. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  389. ggml_set_name(out, "out");
  390. return out;
  391. }
  392. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  393. fairseq2_model& model,
  394. const std::string& prefix,
  395. ggml_tensor* seqs,
  396. ggml_tensor* padding_mask
  397. ) {
  398. ggml_context* ctx = model.ctx;
  399. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  400. // _forward_self_attn(seqs, padding_mask)
  401. auto residual = seqs;
  402. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  403. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  404. // TODO: add padding_mask to MultiheadAttention_forward
  405. GGML_ASSERT(padding_mask == nullptr);
  406. seqs = MultiheadAttention_forward(
  407. model,
  408. prefix + ".self_attn",
  409. seqs,
  410. seqs,
  411. seqs,
  412. /*attn_mask=*/nullptr
  413. );
  414. if (has_layer(model, prefix + ".self_attn_norm"))
  415. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  416. seqs = ggml_add_inplace(ctx, seqs, residual);
  417. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  418. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  419. // _forward_ffn(seqs)
  420. residual = seqs;
  421. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  422. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  423. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  424. // TODO: if self.residual_scale is not None:
  425. // residual = self.residual_scale * residual
  426. seqs = ggml_add_inplace(ctx, seqs, residual);
  427. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  428. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  429. return seqs;
  430. }
  431. extern "C" ggml_tensor* WaveformToFbank_forward(
  432. fairseq2_model& model,
  433. const std::string &prefix,
  434. ggml_tensor* waveform
  435. ) {
  436. // Hardcoding: num_bins 80, sample rate 16k, always standardize
  437. ggml_context* ctx = model.ctx;
  438. knf::MelBanksOptions mel_opts{};
  439. mel_opts.num_bins = 80;
  440. knf::FrameExtractionOptions frame_opts{};
  441. frame_opts.samp_freq = 16000;
  442. knf::FbankOptions opts{};
  443. opts.frame_opts = frame_opts;
  444. opts.mel_opts = mel_opts;
  445. std::vector<float_t> signal_frame{};
  446. std::int32_t num_frames = knf::NumFrames(/*num_samples=*/waveform->ne[0], frame_opts);
  447. ggml_tensor* output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 80, num_frames);
  448. knf::FbankComputer native_(opts);
  449. knf::FeatureWindowFunction window_fn_(native_.GetFrameOptions());
  450. for (std::int32_t frame_nr = 0; frame_nr < num_frames; ++frame_nr) {
  451. signal_frame.resize(0);
  452. // Extract the frame from the waveform tensor.
  453. knf::ExtractWindow(
  454. /*sample_offset=*/0,
  455. (float *)(waveform->data),
  456. waveform->ne[0],
  457. frame_nr,
  458. frame_opts,
  459. window_fn_,
  460. &signal_frame);
  461. native_.Compute(
  462. /*signal_raw_log_energy=*/0, /*vtln_warp=*/1.0, &signal_frame, ((float *)(output->data) + frame_nr * 80));
  463. }
  464. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  465. output = ggml_norm(ctx, output, 1e-5);
  466. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  467. if (output->ne[1] % 2 == 1) {
  468. ggml_tensor* remove_last = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, output->ne[1]-1);
  469. for (int i = 0; i < output->ne[1]-1; ++i) {
  470. ((int32_t *) remove_last->data)[i] = i;
  471. }
  472. output = ggml_get_rows(ctx, output, remove_last);
  473. }
  474. output = ggml_reshape_2d(ctx, output, output->ne[0] * 2, output->ne[1] / 2);
  475. return output;
  476. }
  477. // TODO: Check if it's possible to merge with standard MHA
  478. extern "C" ggml_tensor* RelativePositionMHA_forward(
  479. fairseq2_model& model,
  480. const std::string& prefix,
  481. ggml_tensor* seqs
  482. ) {
  483. ggml_context* ctx = model.ctx;
  484. ggml_tensor* residual = seqs;
  485. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  486. // self_attn: qkv
  487. ggml_tensor* Qcur = Linear_forward(model, prefix + ".q_proj", seqs);
  488. ggml_tensor* Kcur = Linear_forward(model, prefix + ".k_proj", seqs);
  489. ggml_tensor* Vcur = Linear_forward(model, prefix + ".v_proj", seqs);
  490. // self_attn: rel_pos SDPA
  491. int32_t S = seqs->ne[1];
  492. int32_t H = 16; // TODO: Make this configurable
  493. int32_t n_ctx = 4096;
  494. int32_t K_h = seqs->ne[0] / H;
  495. int32_t start_index = n_ctx - S;
  496. int32_t end_index = n_ctx + S - 1;
  497. int num_indices = end_index - start_index;
  498. ggml_tensor* rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices);
  499. for (int i = 0; i < num_indices; i++) {
  500. ((int32_t *)rows->data)[i] = start_index + i;
  501. }
  502. // self_attn: load pos_enc weights & compute_r
  503. // In fairseq2 pos_enc weights are calculated on the fly, since some more custom operators might be needed to enable this,
  504. // we store the results (fixed) in checkpoint as model.audio_enc_pos_enc_w and load directly.
  505. ggml_tensor* r = ggml_get_rows(ctx, model.tensors["speech_encoder.pos_enc"], rows);
  506. r = mul_mat(ctx, model.tensors[prefix + ".sdpa.r_proj.weight"], r);
  507. r = ggml_dup(ctx, ggml_permute(ctx,
  508. ggml_cpy(ctx,
  509. r,
  510. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S*2-1)),
  511. 0, 2, 1, 3));
  512. ggml_tensor* u_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.u_bias"], K_h, 1, H);
  513. ggml_tensor* v_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.v_bias"], K_h, 1, H);
  514. // self_attn: Permute QKV
  515. ggml_tensor* Q = ggml_cont(ctx, ggml_permute(ctx,
  516. ggml_cpy(ctx,
  517. Qcur,
  518. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S)),
  519. 0, 2, 1, 3)); // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  520. ggml_tensor* K = ggml_cont(ctx, ggml_permute(ctx,
  521. ggml_cpy(ctx,
  522. Kcur,
  523. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S)),
  524. 0, 2, 1, 3)); // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  525. ggml_tensor* V = ggml_cont(ctx, ggml_permute(ctx,
  526. ggml_cpy(ctx,
  527. Vcur,
  528. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S)),
  529. 1, 2, 0, 3)); // (H * K_h, S) -> (K_h, H, S) -> (H, S, K_h)
  530. ggml_tensor* q_with_u_bias = ggml_add_inplace(ctx, ggml_dup(ctx, Q), u_bias); // (K_h, S, H)
  531. ggml_tensor* q_with_v_bias = ggml_add_inplace(ctx, Q, v_bias); // (K_h, S, H)
  532. ggml_tensor* ac = mul_mat(ctx, K, q_with_u_bias);
  533. ggml_tensor* bd = mul_mat(ctx, r, q_with_v_bias);
  534. // self_attn: shift_bd. Logic follows https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L161
  535. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // H, S, 2S-1
  536. ggml_tensor* pad = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, S, 1);
  537. pad = ggml_set_f32(pad, 0.0);
  538. bd = ggml_concat(ctx, pad, bd); // bd[i][j][0] == 0, (H, S, 2S)
  539. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // (2S, S, H)
  540. bd = ggml_reshape_3d(ctx, bd, S, 2 * S, H); // (S, 2S, H)
  541. // discard the first set of positive positions
  542. bd = ggml_dup(ctx, ggml_slice(ctx, bd, 1, 1, 2 * S));
  543. // shifts each row by an extra step
  544. bd = ggml_reshape_3d(ctx, bd, 2 * S - 1, S, H);
  545. // Discard positions used for shift.
  546. bd = ggml_slice(ctx, bd, 0, 0, S);
  547. // self_attn: compute attn / weights
  548. ggml_tensor* attn_weights = ggml_add_inplace(ctx, ac, bd);
  549. ggml_tensor* attn_scale = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1);
  550. ggml_set_f32(attn_scale, 1.0 / pow(K_h, 0.5));
  551. attn_weights = ggml_mul_inplace(ctx, attn_weights, ggml_repeat(ctx, attn_scale, attn_weights));
  552. attn_weights = ggml_soft_max(ctx, attn_weights);
  553. ggml_tensor* attn = mul_mat(ctx, V, attn_weights); // K_h, S, H
  554. attn = ggml_dup(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));
  555. ggml_tensor* attn_2d = ggml_reshape_2d(ctx, attn, K_h * H, S);
  556. ggml_tensor* attn_out = mul_mat(ctx, model.tensors[prefix + ".output_proj.weight"], attn_2d);
  557. attn_out = ggml_add_inplace(
  558. ctx,
  559. attn_out,
  560. ggml_repeat(ctx, model.tensors[prefix + ".output_proj.bias"], attn_out)
  561. );
  562. attn_out = ggml_add_inplace(ctx, attn_out, residual);
  563. return attn_out;
  564. }
  565. extern "C" ggml_tensor* ConvModule_forward(
  566. fairseq2_model& model,
  567. const std::string& prefix,
  568. ggml_tensor* seqs
  569. ) {
  570. ggml_context* ctx = model.ctx;
  571. ggml_tensor* residual = seqs;
  572. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  573. // conv: Use matmul for pointwise conv 1 - kernel_size=1, no padding case
  574. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv1.weight"], seqs);
  575. // conv: GLU
  576. seqs = ggml_glu(ctx, seqs);
  577. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  578. // S x C -> (S+K-1) x C -> K x S x C -> S x C
  579. seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".depthwise_conv.weight"], seqs, 1, 15, 1);
  580. // conv: Custom implementation of batch norm
  581. seqs = ggml_batch_norm(ctx, seqs, model.tensors[prefix + ".batch_norm.weight"], model.tensors[prefix + ".batch_norm.bias"], model.tensors[prefix + ".batch_norm.running_mean"], model.tensors[prefix + ".batch_norm.running_var"], 1e-5);
  582. // conv: SiLU actvation
  583. seqs = ggml_silu_inplace(ctx, seqs);
  584. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  585. // conv: Use matmul for pointwise conv 2 - kernel_size=1, no padding case
  586. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv2.weight"], seqs);
  587. // conv: + residual
  588. seqs = ggml_add_inplace(ctx, seqs, residual);
  589. return seqs;
  590. }
  591. extern "C" ggml_tensor* StandardConformerEncoderLayer_forward(
  592. fairseq2_model& model,
  593. const std::string& prefix,
  594. ggml_tensor* seqs,
  595. ggml_tensor* padding_mask
  596. ) {
  597. ggml_context* ctx = model.ctx;
  598. ggml_tensor* ffn_scale = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1);
  599. ggml_set_f32(ffn_scale, 0.5f);
  600. ggml_tensor* residual = seqs;
  601. seqs = LayerNorm_forward(model, prefix + ".ffn1_layer_norm", seqs);
  602. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn1", seqs);
  603. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  604. seqs = ggml_add_inplace(ctx, seqs, residual);
  605. seqs = RelativePositionMHA_forward(model, prefix + ".self_attn", seqs);
  606. seqs = ConvModule_forward(model, prefix + ".conv", seqs);
  607. residual = seqs;
  608. seqs = LayerNorm_forward(model, prefix + ".ffn2_layer_norm", seqs);
  609. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn2", seqs);
  610. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  611. seqs = ggml_add_inplace(ctx, seqs, residual);
  612. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  613. return seqs;
  614. }
  615. extern "C" ggml_tensor* StandardConformerEncoder_forward(
  616. fairseq2_model& model,
  617. const std::string& prefix,
  618. ggml_tensor* seqs,
  619. ggml_tensor* padding_mask
  620. ) {
  621. ggml_context* ctx = model.ctx;
  622. seqs = WaveformToFbank_forward(model, prefix, seqs);
  623. seqs = LayerNorm_forward(model, prefix + "_frontend.post_extract_layer_norm", seqs);
  624. seqs = Linear_forward(model, prefix + "_frontend.model_dim_proj", seqs);
  625. int layer_idx = 0;
  626. std::string layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  627. while (has_layer(model, layer_name)) {
  628. seqs = StandardConformerEncoderLayer_forward(
  629. model, layer_name, seqs, padding_mask
  630. );
  631. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  632. layer_idx += 1;
  633. layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  634. }
  635. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  636. ggml_tensor* residual = seqs;
  637. seqs = Linear_forward(model, prefix + ".proj1", seqs);
  638. seqs = ggml_relu_inplace(ctx, seqs);
  639. seqs = Linear_forward(model, prefix + ".proj2", seqs);
  640. ggml_tensor* ffn_scale = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1);
  641. ggml_set_f32(ffn_scale, 0.5f);
  642. seqs = ggml_mul(ctx, ggml_repeat(ctx, ffn_scale, seqs), seqs);
  643. seqs = ggml_add_inplace(ctx, seqs, residual);
  644. layer_idx = 0;
  645. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  646. while (has_layer(model, layer_name)) {
  647. seqs = StandardConformerEncoderAdaptorLayer_forward(
  648. model, layer_name, seqs, padding_mask
  649. );
  650. ggml_set_name(seqs, ("x_ada_" + std::to_string(layer_idx)).c_str());
  651. layer_idx += 1;
  652. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  653. }
  654. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  655. return seqs;
  656. }
  657. extern "C" ggml_tensor* StandardConformerEncoderAdaptorLayer_forward(
  658. fairseq2_model& model,
  659. const std::string& prefix,
  660. ggml_tensor* seqs,
  661. ggml_tensor* padding_mask
  662. ) {
  663. ggml_context* ctx = model.ctx;
  664. ggml_tensor* residual = seqs;
  665. residual = LayerNorm_forward(model, prefix + ".residual_layer_norm", residual);
  666. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  667. residual = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".residual_conv.weight"], residual, 8, 4, 1);
  668. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  669. residual = ggml_add_inplace(ctx, ggml_repeat(ctx, model.tensors[prefix + ".residual_conv.bias"], residual), residual);
  670. residual = ggml_glu(ctx, residual);
  671. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  672. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  673. seqs = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".self_attn_conv.weight"], seqs, 8, 4, 1);
  674. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  675. seqs = ggml_add_inplace(ctx, seqs, ggml_repeat(ctx, model.tensors[prefix + ".self_attn_conv.bias"], seqs));
  676. seqs = ggml_glu(ctx, seqs);
  677. seqs = MultiheadAttention_forward(
  678. model,
  679. prefix + ".self_attn",
  680. seqs,
  681. seqs,
  682. seqs,
  683. /*attention masks=*/nullptr
  684. );
  685. seqs = ggml_add_inplace(ctx, seqs, residual);
  686. residual = seqs;
  687. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  688. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  689. seqs = ggml_add_inplace(ctx, seqs, residual);
  690. return seqs;
  691. }
  692. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  693. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  694. ggml_tensor* ggml_slice(
  695. struct ggml_context * ctx,
  696. struct ggml_tensor * a,
  697. int axis,
  698. int64_t start,
  699. int64_t end
  700. ) {
  701. int64_t ne[4];
  702. std::copy(a->ne, a->ne + 4, ne);
  703. if (axis < 0) axis = a->n_dims + axis;
  704. if (start < 0) start = ne[axis] + start;
  705. if (end <= 0) end = ne[axis] + end;
  706. GGML_ASSERT(0 <= start);
  707. GGML_ASSERT(start < end);
  708. GGML_ASSERT(end <= ne[axis]);
  709. ne[axis] = end - start;
  710. size_t offset = a->nb[axis] * start;
  711. size_t* nb = a->nb;
  712. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  713. ggml_format_name(result, "%s [(%d)%ld:%ld]", a->name, axis, start, end);
  714. result->n_dims = a->n_dims;
  715. return result;
  716. }
  717. ggml_tensor* ggml_select(
  718. struct ggml_context * ctx,
  719. struct ggml_tensor * a,
  720. int axis,
  721. int64_t index
  722. ) {
  723. int64_t ne[GGML_MAX_DIMS];
  724. std::copy(a->ne, a->ne + GGML_MAX_DIMS, ne);
  725. if (axis < 0) axis = a->n_dims + axis;
  726. if (index < 0) index = ne[axis] + index;
  727. GGML_ASSERT(0 <= index);
  728. GGML_ASSERT(index < ne[axis]);
  729. std::copy(a->ne + axis + 1, a->ne + GGML_MAX_DIMS, ne + axis);
  730. size_t offset = a->nb[axis] * index;
  731. size_t* nb = a->nb;
  732. GGML_ASSERT(GGML_MAX_DIMS == 4);
  733. ggml_tensor* result = ggml_view_3d(ctx, a, ne[0], ne[1], ne[2], nb[1], nb[2], offset);
  734. ggml_format_name(result, "%s [(%d)%ld]", a->name, axis, index);
  735. result->n_dims = a->n_dims - 1;
  736. return result;
  737. }
  738. // Inplace computation of PositionalEmbedding
  739. extern "C" ggml_tensor* PositionalEmbedding_forward(
  740. fairseq2_model& model,
  741. const std::string& prefix,
  742. ggml_tensor* embeds
  743. ) {
  744. // This only work with the simple pos encoders
  745. int seq_len = embeds->ne[1];
  746. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  747. int start_step = 0;
  748. if (has_kv_cache(model)) {
  749. start_step = model.kv_cache[prefix].step_nr++;
  750. }
  751. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, start_step, seq_len + start_step);
  752. return ggml_add(model.ctx, embeds, pos_embeds);
  753. }
  754. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  755. fairseq2_model& model,
  756. const std::string& prefix,
  757. ggml_tensor* seqs
  758. ) {
  759. GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
  760. ggml_context* ctx = model.ctx;
  761. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  762. GGML_ASSERT(embed_weights != nullptr);
  763. ggml_tensor* embeds;
  764. if (seqs->n_dims == 1) {
  765. embeds = ggml_get_rows(ctx, embed_weights, seqs);
  766. } else {
  767. // ggml_get_rows isn't very flexible, we have to handle the reshape ourselves.
  768. ggml_tensor* flat_seqs = seqs;
  769. if (!ggml_is_contiguous(seqs)) {
  770. flat_seqs->type = GGML_TYPE_F32;
  771. flat_seqs = ggml_cont(ctx, flat_seqs);
  772. }
  773. flat_seqs = ggml_reshape_1d(ctx, flat_seqs, ggml_nelements(seqs));
  774. flat_seqs->type = GGML_TYPE_I32;
  775. embeds = ggml_get_rows(ctx, embed_weights, flat_seqs);
  776. embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  777. embeds->n_dims = seqs->n_dims + 1;
  778. }
  779. // padding mask ?
  780. // padding_mask = to_padding_mask(embeds, seq_lens)
  781. if (has_layer(model, prefix + ".pos_encoder")) {
  782. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  783. }
  784. if (has_layer(model, prefix + ".layer_norm")) {
  785. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  786. }
  787. return embeds;
  788. }
  789. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  790. fairseq2_model& model,
  791. const std::string& prefix,
  792. ggml_tensor* seqs,
  793. ggml_tensor* padding_mask
  794. ) {
  795. int layer_idx = 0;
  796. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  797. while (has_layer(model, layer_name)) {
  798. seqs = StandardTransformerEncoderLayer_forward(
  799. model, layer_name, seqs, padding_mask
  800. );
  801. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  802. layer_idx += 1;
  803. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  804. }
  805. if (has_layer(model, prefix + ".layer_norm"))
  806. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  807. return seqs;
  808. }
  809. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  810. fairseq2_model& model,
  811. const std::string& prefix,
  812. ggml_tensor* seqs,
  813. ggml_tensor* self_attn_mask,
  814. ggml_tensor* encoder_output,
  815. ggml_tensor* encoder_padding_mask
  816. ) {
  817. ggml_context* ctx = model.ctx;
  818. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  819. // _forward_self_attn(seqs, padding_mask)
  820. auto residual = seqs;
  821. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  822. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  823. seqs = MultiheadAttention_forward(
  824. model,
  825. prefix + ".self_attn",
  826. seqs,
  827. seqs,
  828. seqs,
  829. /*attn_mask=*/self_attn_mask
  830. );
  831. if (has_layer(model, prefix + ".self_attn_norm"))
  832. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  833. seqs = ggml_add_inplace(ctx, seqs, residual);
  834. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  835. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  836. // _forward_encoder_decoder_attn
  837. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  838. // `encoder_output` must be `None` for decoder-only attention.
  839. GGML_ASSERT(encoder_output == nullptr);
  840. return seqs;
  841. }
  842. // `encoder_output` must not be `None` for encoder-decoder attention.
  843. GGML_ASSERT(encoder_output != nullptr);
  844. residual = seqs;
  845. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  846. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  847. seqs = MultiheadAttention_forward(
  848. model,
  849. prefix + ".encoder_decoder_attn",
  850. seqs,
  851. encoder_output,
  852. encoder_output,
  853. /*attention masks=*/encoder_padding_mask
  854. );
  855. seqs = ggml_add_inplace(ctx, seqs, residual);
  856. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  857. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  858. // _forward_ffn(seqs)
  859. residual = seqs;
  860. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  861. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  862. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  863. // TODO:
  864. // if self.residual_scale is not None:
  865. // residual = self.residual_scale * residual
  866. seqs = ggml_add_inplace(ctx, seqs, residual);
  867. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  868. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  869. return seqs;
  870. }
  871. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  872. auto seq_len = seqs->ne[1];
  873. // TODO: allow other ggml_type
  874. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  875. return ggml_diag_mask_inf(ctx, mask, 0);
  876. }
  877. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  878. fairseq2_model& model,
  879. const std::string& prefix,
  880. ggml_tensor* seqs,
  881. ggml_tensor* padding_mask,
  882. ggml_tensor* encoder_output,
  883. ggml_tensor* encoder_padding_mask
  884. ) {
  885. int layer_idx = 0;
  886. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  887. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  888. while (has_layer(model, layer_name)) {
  889. seqs = StandardTransformerDecoderLayer_forward(
  890. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  891. );
  892. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  893. layer_idx += 1;
  894. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  895. }
  896. if (has_layer(model, prefix + ".layer_norm"))
  897. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  898. return seqs;
  899. }
  900. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  901. auto opts = job.opts;
  902. int max_seq_len = -1;
  903. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  904. max_seq_len = opts.hard_max_seq_len;
  905. } else {
  906. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len) + opts.soft_max_seq_len_b);
  907. }
  908. if (opts.min_seq_len > max_seq_len) {
  909. printf(
  910. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  911. opts.min_seq_len,
  912. max_seq_len
  913. );
  914. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  915. }
  916. int prefix_seq_len = job.prefix_seq->ne[0];
  917. if (prefix_seq_len >= max_seq_len) {
  918. printf(
  919. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  920. prefix_seq_len,
  921. max_seq_len
  922. );
  923. GGML_ASSERT(prefix_seq_len < max_seq_len);
  924. }
  925. return max_seq_len;
  926. }
  927. void _fan_out_encoder_output(
  928. ggml_context* ctx,
  929. ggml_tensor** encoder_output_out,
  930. ggml_tensor** encoder_padding_mask_out,
  931. int beam_size
  932. ) {
  933. // (S_enc, M)
  934. ggml_tensor* encoder_output = *encoder_output_out;
  935. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  936. // (B, S_enc, M)
  937. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  938. // (S_enc, M) -> (B, S_enc, M)
  939. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  940. // (S_enc) -> (B, S_enc)
  941. if (encoder_padding_mask != nullptr) {
  942. ggml_tensor* shape_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], 1, beam_size);
  943. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  944. }
  945. }
  946. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  947. // TODO: this isn't the most precise way of doing this
  948. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  949. }
  950. ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
  951. ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
  952. ggml_type true_type = x->type;
  953. x->type = GGML_TYPE_F32;
  954. ggml_tensor* y = ggml_repeat(ctx, x, shape);
  955. y->type = true_type;
  956. return y;
  957. }
  958. extern "C" void _bootstrap_seqs_and_scores(
  959. fairseq2_model& model,
  960. const SequenceGeneratorJob& job,
  961. ggml_tensor* full_seqs,
  962. ggml_tensor* scores,
  963. ggml_tensor* encoder_output,
  964. ggml_tensor* encoder_padding_mask
  965. ) {
  966. ZoneScoped;
  967. int prefix_seq_len = job.prefix_seq->ne[0];
  968. int max_seq_len = scores->ne[0];
  969. int beam_size = scores->ne[1];
  970. GGML_ASSERT(prefix_seq_len > 0);
  971. if (prefix_seq_len == 1)
  972. return;
  973. ggml_context* ctx = model.ctx;
  974. // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  975. full_seqs->type = GGML_TYPE_F32;
  976. job.prefix_seq->type = GGML_TYPE_F32;
  977. ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
  978. seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);
  979. // We have to bootstrap the model with the already fanned-out encoder
  980. // output to correctly initialize its incremental state.
  981. // Note: we don't start decoding the last prefix token just yet.
  982. seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);
  983. seqs->type = GGML_TYPE_I32;
  984. // Bootstrap the model state with prefix sequence.
  985. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  986. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  987. model,
  988. "text_decoder",
  989. seqs,
  990. /*padding_mask*/ nullptr,
  991. encoder_output,
  992. encoder_padding_mask
  993. );
  994. // TODO state_bag.increment_step(prefix_seq_len - 1)
  995. // logits, lprobs: (N, S_pfx - 1, V)
  996. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  997. int vocab_size = logits->ne[0];
  998. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
  999. ggml_cgraph gf = ggml_build_forward(lprobs);
  1000. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  1001. full_seqs->type = GGML_TYPE_I32;
  1002. job.prefix_seq->type = GGML_TYPE_I32;
  1003. // Fetch scores of next steps from "lprobs"
  1004. float p_score = 0;
  1005. for (int i = 1; i < prefix_seq_len; ++i) {
  1006. int p = ggml_get_i32_1d(job.prefix_seq, i);
  1007. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  1008. for (int b = 0; b < beam_size; ++b) {
  1009. // scores: (N, S)
  1010. // Note: First step (e.g. BOS)'s score is always 0.
  1011. ggml_set_f32_1d(scores, b * max_seq_len + i, p_score);
  1012. }
  1013. }
  1014. }
  1015. /// Finds the topk indices, and write the winning indices in "candidate_indices" array.
  1016. int topk(
  1017. ggml_tensor* lprobs, // (B, V)
  1018. std::int64_t k,
  1019. ggml_tensor* candidate_indices
  1020. ) {
  1021. ZoneNamed(topk, true);
  1022. // Take the best 2 x `beam_size` predictions. We'll choose the first
  1023. // `beam_size` of these which don't predict EOS to continue with.
  1024. // (N, 2 x B)
  1025. // `vocab_size` - 1 to never select PAD.
  1026. std::int64_t K = std::min(k, ggml_nelements(lprobs));
  1027. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  1028. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  1029. };
  1030. GGML_ASSERT(ggml_nelements(candidate_indices) >= k);
  1031. auto cand = (std::int32_t*)candidate_indices->data;
  1032. std::partial_sort(cand, cand + K, cand + ggml_nelements(lprobs), comp);
  1033. return K;
  1034. }
  1035. void _tweak_lprobs(const SequenceGeneratorJob& job, ggml_tensor* lprobs, int step_nr, int max_seq_len, std::size_t vocab_size) {
  1036. ZoneNamed(tweak_lprobs, true);
  1037. std::size_t beam_size = job.opts.beam_size;
  1038. std::size_t eos_idx = job.eos_idx;
  1039. // Do not allow EOS before reaching the minimum sequence length.
  1040. if (step_nr < job.opts.min_seq_len) {
  1041. // lprobs[:, :, self.eos_idx] = -INFINITY;
  1042. for (size_t i = 0; i < beam_size; ++i)
  1043. ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
  1044. }
  1045. // If we have reached the maximum length, force the last step to be EOS.
  1046. if (step_nr == max_seq_len - 2) {
  1047. // lprobs[:, :, : self.eos_idx] = -torch.inf
  1048. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  1049. for (size_t b = 0; b < beam_size; ++b) {
  1050. size_t t = 0;
  1051. for (t = 0; t < eos_idx; ++t)
  1052. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1053. for (t = eos_idx + 1; t < vocab_size; ++t)
  1054. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1055. }
  1056. }
  1057. // Never allow PAD.
  1058. std::size_t pad_idx = job.pad_idx;
  1059. for (size_t i = 0; i < beam_size; ++i)
  1060. ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);
  1061. // Apply UNK penalty.
  1062. if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
  1063. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  1064. auto lprobs_raw = ggml_get_data_f32(lprobs);
  1065. for (size_t i = 0; i < beam_size; ++i)
  1066. lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
  1067. }
  1068. }
  1069. /// Copies the sequence and scores of a given candidate beam.
  1070. void _finalize_hypothesis(
  1071. const SequenceGeneratorJob& job,
  1072. ggml_context* ctx,
  1073. int step_nr,
  1074. std::int32_t beam,
  1075. std::int32_t token,
  1076. float eos_score,
  1077. ggml_tensor* seqs, // (beam_size, seq_len)
  1078. ggml_tensor* scores, // (beam_size, seq_len)
  1079. Hypothesis* hypothesis
  1080. ) {
  1081. ZoneNamed(_finalize_hypothesis, true);
  1082. ggml_tensor* seq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  1083. hypothesis->seq = seq;
  1084. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  1085. hypothesis->step_scores = step_scores;
  1086. auto tok = (std::int32_t*)seq->data;
  1087. for (int i = 0; i < step_nr + 1; ++i) {
  1088. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  1089. }
  1090. tok[step_nr + 1] = token;
  1091. // Convert from cumulative to per-step scores.
  1092. auto sc = (float*)step_scores->data;
  1093. float last_score = eos_score;
  1094. for (int i = step_nr; i >= 0; --i) {
  1095. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  1096. sc[i + 1] = last_score - sc0;
  1097. last_score = sc0;
  1098. }
  1099. sc[0] = 0;
  1100. if (job.opts.normalize_scores)
  1101. // Skip first EOS since it is always 0 and skews normalization.
  1102. eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  1103. hypothesis->score = eos_score;
  1104. }
  // Uses ggml_context to store any object: reserves sizeof(Type) * n bytes as an
  // I8 tensor in `ctx` and yields a Type* to them. The memory is owned by `ctx`.
  // No constructors are run and the bytes are not cleared here.
  // Macro arguments are parenthesized (so `n` may be an expression such as a + b),
  // and there is no trailing semicolon, so the macro can be used inside expressions.
  #define GGML_CTX_ALLOC(ctx, Type, n) \
      ((Type*)(ggml_new_tensor_1d((ctx), GGML_TYPE_I8, sizeof(Type) * (n))->data))
  1108. ggml_context* ctx_from_buffer(std::vector<uint8_t>& buffer) {
  1109. return ggml_init({
  1110. /*.mem_size =*/ static_cast<int64_t>(buffer.capacity()),
  1111. /*.mem_buffer =*/ buffer.data(),
  1112. /*.no_alloc =*/ false,
  1113. });
  1114. }
  1115. /// Generates a translation for a single sequence
  1116. // TODO: clean ups
  1117. // * replace manual tensor tweaking with ggml_set_*d (a ggml_set_slice could be useful)
  1118. extern "C" Hypothesis* generate_sequence(
  1119. fairseq2_model& model,
  1120. const SequenceGeneratorJob& job,
  1121. ggml_tensor* encoder_output,
  1122. ggml_tensor* encoder_padding_mask,
  1123. ggml_context* result_ctx
  1124. ) {
  1125. ZoneScoped;
  1126. std::vector<uint8_t> local_bufs[3] = {
  1127. std::vector<uint8_t>(256 * 1024 * 1024), // step_ctx
  1128. std::vector<uint8_t>(256 * 1024 * 1024), // next_step_ctx
  1129. std::vector<uint8_t>(256 * 1024 * 1024) // search_ctx
  1130. };
  1131. ggml_context* search_ctx = ctx_from_buffer(local_bufs[2]);
  1132. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  1133. size_t vocab_size = embed->ne[1];
  1134. std::size_t beam_size = job.opts.beam_size;
  1135. int source_seq_len = encoder_output->ne[1];
  1136. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  1137. ggml_context* original_ctx = model.ctx;
  1138. model.ctx = search_ctx;
  1139. fairseq2_kv_cache_alloc(model, beam_size, max_seq_len);
  1140. // (S_enc, M) -> (B, S_enc, M)
  1141. _fan_out_encoder_output(search_ctx, &encoder_output, &encoder_padding_mask, beam_size);
  1142. // Allocate results in the context provided by the caller.
  1143. Hypothesis* finished_searches_begin = GGML_CTX_ALLOC(result_ctx, Hypothesis, beam_size);
  1144. Hypothesis* finished_searches = finished_searches_begin;
  1145. for (std::size_t i = 0; i < beam_size; ++i) finished_searches[i] = {nullptr, -INFINITY, nullptr};
  1146. Hypothesis* finished_searches_end = finished_searches + beam_size;
  1147. // Initialize buffers. (B, S)
  1148. ggml_tensor* seqs = ggml_new_tensor_2d(search_ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  1149. ggml_set_i32(seqs, 0);
  1150. ggml_set_name(seqs, "seqs_0");
  1151. ggml_tensor* scores = ggml_new_tensor_2d(search_ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  1152. ggml_set_name(scores, "scores_0");
  1153. ggml_set_f32(scores, 0.0);
  1154. _bootstrap_seqs_and_scores(
  1155. model, job, seqs, scores, encoder_output, encoder_padding_mask
  1156. );
  1157. int prefix_seq_len = job.prefix_seq->ne[0];
  1158. int start_step = prefix_seq_len - 1;
  1159. // Holds the indices of beams (a beam can occur more than once) that we
  1160. // should continue with in the next step.
  1161. ggml_tensor* beam_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1162. ggml_tensor* next_tokens = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1163. ggml_tensor* next_scores = ggml_new_tensor_1d(search_ctx, GGML_TYPE_F32, beam_size);
  1164. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  1165. ggml_tensor* candidate_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, vocab_size * beam_size);
  1166. for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
  1167. ((int32_t *)(candidate_indices->data))[i] = i;
  1168. printf_mem_usage(search_ctx, "search_ctx");
  1169. ggml_context* step_ctx = ctx_from_buffer(local_bufs[0]);
  1170. ggml_context* next_step_ctx = nullptr;
  1171. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  1172. model.ctx = step_ctx;
  1173. ggml_tensor* prev_token = ggml_slice(step_ctx, seqs, 0, step_nr, step_nr + 1);
  1174. ggml_tensor* decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", prev_token);
  1175. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1176. model,
  1177. "text_decoder",
  1178. decoder_input,
  1179. nullptr, // We never generate PAD.
  1180. encoder_output,
  1181. encoder_padding_mask
  1182. ); // (B, 1, D)
  1183. // Just look at the last token.
  1184. decoder_output = ggml_flatten_1d(step_ctx, decoder_output, 0); // (B, model_dim)
  1185. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output); // (B, vocab_size)
  1186. ggml_tensor* lprobs = ggml_log_softmax(step_ctx, logits);
  1187. // Compute lprobs here so we can modify it in place in the lprob tweaking phase
  1188. // TODO: use ggml properly compute the tweaks
  1189. ggml_cgraph gf = ggml_build_forward(lprobs);
  1190. // printf("beam search step %d. Graph.n_nodes: %d\n", step_nr, gf.n_nodes);
  1191. ggml_graph_compute_with_ctx(step_ctx, &gf, 1);
  1192. ggml_detach(lprobs);
  1193. _tweak_lprobs(job, lprobs, step_nr, max_seq_len, vocab_size);
  1194. ggml_tensor* last_scores = ggml_slice(step_ctx, scores, 0, step_nr, step_nr+1);
  1195. if (step_nr == start_step) {
  1196. // At the initial step, all hypotheses are equally likely, so we use
  1197. // only the first beam.
  1198. lprobs = ggml_slice(step_ctx, lprobs, 1, 0, 1);
  1199. lprobs = ggml_cont(step_ctx, lprobs);
  1200. // The first step always indicates the beginning of the sequence and has no score.
  1201. if (step_nr > 0) {
  1202. last_scores = ggml_slice(step_ctx, last_scores, 1, 0, 1);
  1203. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1204. }
  1205. } else {
  1206. // Make probabilities contain cumulative scores for each hypothesis.
  1207. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1208. }
  1209. gf = ggml_build_forward(lprobs);
  1210. ggml_graph_compute_with_ctx(step_ctx, &gf, 1);
  1211. // Determine (beam, token) candidates for the next step.
  1212. // (N, 2 x B)
  1213. std::int64_t K = topk(
  1214. lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
  1215. );
  1216. std::size_t ongoing_beams = 0;
  1217. for (std::int32_t i = 0; i < K; ++i) {
  1218. ZoneNamed(beam_search_step, true);
  1219. int c = ggml_get_f32_1d(candidate_indices, i);
  1220. std::int32_t beam = c / vocab_size;
  1221. std::int32_t token = c % vocab_size;
  1222. float tok_score = ggml_get_f32_1d(lprobs, c);
  1223. // Detect beams that reached the minimum length and that end with an EOS.
  1224. bool eos = token == job.eos_idx;
  1225. eos &= tok_score != -INFINITY;
  1226. if (eos) {
  1227. _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, finished_searches++);
  1228. if (finished_searches == finished_searches_end)
  1229. goto end_of_beam_search;
  1230. continue;
  1231. }
  1232. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  1233. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  1234. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  1235. ongoing_beams += 1;
  1236. if (ongoing_beams >= beam_size) break;
  1237. }
  1238. // Reorder beams in the `seq` and `score` buffers. The same beam can
  1239. // be selected more than once.
  1240. ggml_tensor* new_seqs = seqs;
  1241. ggml_tensor* new_scores = scores;
  1242. if (step_nr > start_step) {
  1243. // (B, S), (B) -> (B, S)
  1244. // ggml_get_rows and ggml_set only work with floats ...
  1245. new_seqs->type = GGML_TYPE_F32;
  1246. new_seqs = ggml_get_rows(search_ctx, seqs, beam_indices);
  1247. new_scores = ggml_get_rows(search_ctx, scores, beam_indices);
  1248. ggml_cgraph gf_reorder = ggml_build_forward(new_seqs);
  1249. ggml_build_forward_expand(&gf_reorder, new_scores);
  1250. next_step_ctx = ctx_from_buffer(local_bufs[(step_nr + 1) % 2]);
  1251. reorder_kv_cache(model, next_step_ctx, &gf_reorder, beam_indices);
  1252. ggml_graph_compute_with_ctx(next_step_ctx, &gf_reorder, 1);
  1253. ggml_detach(new_seqs);
  1254. ggml_detach(new_scores);
  1255. new_seqs->type = GGML_TYPE_I32;
  1256. printf_mem_usage(search_ctx, "search_ctx");
  1257. SWAP(step_ctx, next_step_ctx);
  1258. }
  1259. // new_seqs[:, step_nr + 1] = next_tokens
  1260. // new_scores[:, step_nr + 1] = next_scores
  1261. for (std::size_t i = 0; i < beam_size; ++i) {
  1262. ((std::int32_t*)new_seqs->data)[step_nr + 1 + i * max_seq_len] = ggml_get_i32_1d(next_tokens, i);
  1263. ((float*)new_scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
  1264. }
  1265. // TODO the old seqs and score buffers could be reused for next step
  1266. seqs = new_seqs;
  1267. scores = new_scores;
  1268. printf_mem_usage(step_ctx, "step_ctx");
  1269. }
  1270. end_of_beam_search:
  1271. // Ensure that hypotheses are sorted by decreasing scores before returning.
  1272. std::sort(
  1273. finished_searches_begin,
  1274. finished_searches_end,
  1275. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  1276. );
  1277. fairseq2_kv_cache_reset(model);
  1278. model.ctx = original_ctx;
  1279. return finished_searches_begin;
  1280. }
  1281. extern "C" Hypothesis* _testing_return_hypothesis_ptr(ggml_context* ctx) {
  1282. Hypothesis* result = GGML_CTX_ALLOC(ctx, struct Hypothesis, 2);
  1283. result[0] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 3.14f, (ggml_tensor*)result};
  1284. ggml_set_i32_1d(result[0].seq, 0, 314);
  1285. result[1] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 4.21f, nullptr};
  1286. ggml_set_i32_1d(result[1].seq, 0, 421);
  1287. return result;
  1288. }
  1289. // SPM tokenizer
  1290. // original implementation:
  1291. // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
  1292. struct llm_symbol {
  1293. using index = int;
  1294. index prev;
  1295. index next;
  1296. const char * text;
  1297. size_t n;
  1298. llama_vocab::id id;
  1299. };
  1300. static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
  // Byte length of the UTF-8 sequence whose lead byte is `src`, decided by the
  // byte's high nibble: 0xF -> 4, 0xE -> 3, 0xC or 0xD -> 2, anything else -> 1.
  // ASCII and stray continuation bytes (0x80-0xBF) both yield 1, so the result is
  // never 0.
  static size_t utf8_len(char src) {
      const unsigned nibble = static_cast<unsigned char>(src) >> 4;
      if (nibble == 0xF) return 4;
      if (nibble == 0xE) return 3;
      if (nibble >= 0xC) return 2;
      return 1;
  }
  1306. struct llm_bigram_spm {
  1307. struct comparator {
  1308. bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
  1309. return (l.score < r.score) || (l.score == r.score && l.left > r.left);
  1310. }
  1311. };
  1312. using queue_storage = std::vector<llm_bigram_spm>;
  1313. using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
  1314. llm_symbol::index left;
  1315. llm_symbol::index right;
  1316. float score;
  1317. size_t size;
  1318. llama_vocab::id id;
  1319. };
  1320. struct llm_tokenizer_spm {
  1321. llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
  1322. void tokenize(const std::string& input_text, ggml_tensor& output) {
  1323. llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");
  1324. // split string into utf8 chars
  1325. int index = 0;
  1326. size_t offs = 0;
  1327. // This is kind of annoying, but needed because with SPM,
  1328. // characters following a space have a special meaning.
  1329. // And the algorithm rely on substrings to do the lookups.
  1330. std::string text = input_text;
  1331. bool need_extra_space = text.size() > 0 && text[0] != ' ';
  1332. if (need_extra_space) text = " " + text;
  1333. while (offs < text.size()) {
  1334. size_t len = utf8_len(text[offs]);
  1335. size_t n = std::min(len, text.size() - offs);
  1336. auto token = vocab.token_to_id.find(std::string(text, offs, n));
  1337. llama_vocab::id id = token == vocab.token_to_id.end() ? unk_idx : token->second;
  1338. llm_symbol sym = {
  1339. /*prev*/ index - 1,
  1340. /*next*/ offs + n == text.size() ? -1 : index + 1,
  1341. /*text*/ text.c_str() + offs,
  1342. /*n*/ n,
  1343. /*id*/ id
  1344. };
  1345. offs += n;
  1346. index++;
  1347. symbols.emplace_back(sym);
  1348. }
  1349. // seed the work queue with all possible 2-character tokens.
  1350. for (size_t i = 1; i < symbols.size(); ++i) {
  1351. try_add_bigram(i - 1, i);
  1352. }
  1353. // keep substituting the highest frequency pairs for as long as we can.
  1354. while (!work_queue.empty()) {
  1355. auto bigram = work_queue.top();
  1356. work_queue.pop();
  1357. auto & left_sym = symbols[bigram.left];
  1358. auto & right_sym = symbols[bigram.right];
  1359. const std::string text = std::string(left_sym.text, left_sym.n + right_sym.n);
  1360. // if one of the symbols already got merged, skip it.
  1361. if (
  1362. left_sym.n == 0
  1363. || right_sym.n == 0
  1364. || left_sym.n + right_sym.n != bigram.size
  1365. ) continue;
  1366. // merge the right sym into the left one
  1367. left_sym.n += right_sym.n;
  1368. left_sym.id = bigram.id;
  1369. right_sym.n = 0;
  1370. // remove the right sym from the chain
  1371. left_sym.next = right_sym.next;
  1372. if (right_sym.next >= 0) {
  1373. symbols[right_sym.next].prev = bigram.left;
  1374. }
  1375. // find more substitutions
  1376. try_add_bigram(left_sym.prev, bigram.left);
  1377. try_add_bigram(bigram.left, left_sym.next);
  1378. }
  1379. llama_vocab::id* out = (llama_vocab::id*)output.data;
  1380. int out_step = sizeof(llama_vocab::id) / output.nb[0];
  1381. int num_tokens = 0;
  1382. for (int i = 0; i > -1; i = symbols[i].next) {
  1383. llm_symbol& symbol = symbols[i];
  1384. *(out + num_tokens * out_step) = symbol.id;
  1385. num_tokens += 1;
  1386. }
  1387. *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
  1388. num_tokens += 1;
  1389. output.ne[0] = num_tokens;
  1390. }
  1391. private:
  1392. void try_add_bigram(int left, int right) {
  1393. if (left == -1 || right == -1) {
  1394. return;
  1395. }
  1396. const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
  1397. auto token = vocab.token_to_id.find(text);
  1398. if (token == vocab.token_to_id.end()) {
  1399. return;
  1400. }
  1401. llama_vocab::id id = token->second;
  1402. if (static_cast<size_t>(id) >= vocab.id_to_token.size()) {
  1403. return;
  1404. }
  1405. const auto& tok_data = vocab.id_to_token[id];
  1406. llm_bigram_spm bigram = {
  1407. /*left */ left,
  1408. /*right*/ right,
  1409. /*score*/ tok_data.score,
  1410. /*size */ text.size(),
  1411. /*id */ id
  1412. };
  1413. work_queue.push(bigram);
  1414. }
  1415. const llama_vocab& vocab;
  1416. std::vector<llm_symbol> symbols;
  1417. llm_bigram_spm::queue work_queue;
  1418. };
  1419. extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out) {
  1420. llm_tokenizer_spm spm = {model->vocab};
  1421. spm.tokenize(std::string(text), out);
  1422. }
  1423. extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out) {
  1424. int eos_idx = model->vocab.token_to_id["</s>"];
  1425. int sent_len = tokens->ne[0];
  1426. std::size_t written = 0;
  1427. for (int i = 0; i < sent_len; ++i) {
  1428. int id = ggml_get_i32_1d(tokens, i);
  1429. // Don't print the EOS token but only if it appear at the end.
  1430. if (i == sent_len - 1 && eos_idx == id) break;
  1431. std::string token = model->vocab.id_to_token.at(id).text;
  1432. // Skip the first space outputted.
  1433. auto begin = token.begin();
  1434. if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
  1435. std::copy(begin, token.end(), out);
  1436. std::size_t n = token.end() - begin;
  1437. written += n;
  1438. out += n;
  1439. }
  1440. *out = '0';
  1441. return written;
  1442. }