// fairseq2.cpp
#include <algorithm>
#include <fnmatch.h>
#include <iostream>
#include <limits>
#include <math.h>
#include <numeric>
#include <queue>
#include <unordered_map>

#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/feature-window.h"

#include "fairseq2.h"
#include "ggml.h"
#include "ggml-alloc.h"
ggml_tensor* ggml_detach(ggml_tensor* a) {
    a->op = GGML_OP_NONE;
    std::fill(a->src, a->src + GGML_MAX_SRC, nullptr);
    return a;
}
// generate_sequence uses ggml_context and ggml_allocr to reuse memory buffers across steps.
// This can lead to dangling pointers, which don't segfault but instead read garbage data.
// Enabling this flag explicitly resets memory buffers, making it more obvious
// when we read garbage data.
// It also prints memory usage information, which is useful for debugging.
#define DEBUG_MEM_USAGE 1
size_t MB = 1024 * 1024;

void printf_mem_usage(ggml_context* ctx, std::string name) {
#if DEBUG_MEM_USAGE
    double mb = 1024.0 * 1024.0;
    printf(
        "%s: memory used = %8.2f MB, memory reserved = %8.2f MB\n",
        name.c_str(),
        ggml_used_mem(ctx) / mb,
        ggml_get_mem_size(ctx) / mb
    );
#endif
}
#define SWAP(x, y) \
    auto tmp_ ## x = x; x = y; y = tmp_ ## x;

#define GGML_ASSERT_SHAPE(x, ne0, ne1, ne2, ne3) \
    GGML_ASSERT((ne0 == -1 || x->ne[0] == ne0) && (ne1 == -1 || x->ne[1] == ne1) && (ne2 == -1 || x->ne[2] == ne2) && (ne3 == -1 || x->ne[3] == ne3));
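
// Usage sketch (illustrative): -1 acts as a wildcard for dims we don't want
// to check. For a tensor `t` with ne = (512, 11, 5, 1):
//   GGML_ASSERT_SHAPE(t, 512, -1, 5, -1); // passes; ne[1] and ne[3] are ignored
//   GGML_ASSERT_SHAPE(t, 512, 12, 5, 1);  // fails on ne[1]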
/// Allocates the fairseq2 model and hyperparameters.
extern "C" fairseq2_model* fairseq2_model_alloc() {
    // pre-allocate some memory to write hyperparameters and tensor pointers
    auto* model = new fairseq2_model;
    model->tensors_ctx = nullptr;
    return model;
}
  47. extern "C" void fairseq2_kv_cache_alloc(fairseq2_model& model, ggml_context* kv_cache_ctx, int beam_size, int max_seq_len) {
  48. // Note: we only allocate the masks, proper kv cache allocation is delayed.
  49. GGML_ASSERT(kv_cache_ctx);
  50. GGML_ASSERT(!ggml_get_no_alloc(kv_cache_ctx)); // We need to be able to alloc the kv_cache buffers
  51. auto attn_glob = "text_decoder.*_attn.k_proj.weight";
  52. FORCE_ALLOC(self_attn_mask, kv_cache_ctx, ggml_new_tensor_2d(kv_cache_ctx, GGML_TYPE_F32, max_seq_len, max_seq_len));
  53. self_attn_mask = ggml_diag_mask_inf_inplace(kv_cache_ctx, self_attn_mask, 0);
  54. ggml_format_name(self_attn_mask, "self_attn_mask[%d]", max_seq_len);
  55. for (auto named_tensor : model.tensors) {
  56. const std::string& name = named_tensor.first;
  57. if (::fnmatch(attn_glob, name.c_str(), 0) == FNM_NOMATCH)
  58. continue;
  59. // create a cache entry without the ".k_proj.weight" suffix
  60. const std::string& shortname = name.substr(0, name.size() - 14);
  61. KeyValueTensor& kv = model.kv_cache[shortname];
  62. kv.step_nr = 0;
  63. kv.full_k = nullptr;
  64. kv.full_v = nullptr;
  65. kv.self_attn_mask = self_attn_mask;
  66. }
  67. }
  68. extern "C" void fairseq2_kv_cache_reset(const fairseq2_model& model) {
  69. // TODO: use a dedicated allocator, so that kv_cache.clear actually frees the memory
  70. model.kv_cache.clear();
  71. }
  72. bool has_kv_cache(const fairseq2_model& model) {
  73. return model.kv_cache.size() > 0;
  74. }
inline ggml_tensor* ggml_squeeze(ggml_context* ctx, ggml_tensor* x, int dim) {
    int n_dims = x->n_dims;
    GGML_ASSERT(dim >= 0);
    GGML_ASSERT(dim < n_dims);
    GGML_ASSERT(x->ne[dim] == 1);
    return ggml_flatten_1d(ctx, x, dim);
}

inline ggml_tensor* ggml_unsqueeze(ggml_context* ctx, ggml_tensor* x, int dim) {
    return ggml_unflatten_1d(ctx, x, dim, 1);
}
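
// Illustrative examples: squeeze/unsqueeze are just flatten/unflatten with a
// size-1 dim. For `x` with ne = (4, 1, 3):
//   ggml_squeeze(ctx, x, 1);   // -> (4, 3)
//   ggml_unsqueeze(ctx, x, 0); // -> (1, 4, 1, 3)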
// Copy k and v into the kv cache:
//   kv.full_k[step_nr] = k;
//   kv.full_v[step_nr] = v;
void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
    KeyValueTensor& kv = model.kv_cache[prefix];
    int step_nr = kv.step_nr;
    ggml_context* ctx = model.ctx;
    // We need to force allocation here, otherwise the kv_cache buffers can be reused
    bool no_alloc_save = ggml_get_no_alloc(ctx);
    ggml_set_no_alloc(ctx, false);
    int n_steps = (*k)->ne[1];
    // printf("Prefix: %s n_steps: %d\n", prefix.c_str(), n_steps);
    int k_proj, batch_size;
    if (kv.full_k != nullptr) {
        // (N, S_kv, K_proj)
        k_proj = kv.full_k->ne[0];
        batch_size = kv.full_k->ne[2];
        ggml_detach(kv.full_k);
        ggml_detach(kv.full_v);
        kv.full_k = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_k, 1), ggml_unsqueeze(ctx, *k, 1)), 1);
        kv.full_v = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_v, 1), ggml_unsqueeze(ctx, *v, 1)), 1);
    } else {
        GGML_ASSERT(step_nr == 0);
        k_proj = (*k)->ne[0];
        batch_size = (*v)->ne[2];
        kv.full_k = ggml_dup(ctx, *k);
        kv.full_v = ggml_dup(ctx, *v);
    }
    *k = kv.full_k;
    *v = kv.full_v;
    ggml_format_name(kv.full_k, "%s.k (step=%d)", prefix.c_str(), step_nr);
    ggml_format_name(kv.full_v, "%s.v (step=%d)", prefix.c_str(), step_nr);
    step_nr += n_steps;
    // printf("Prefix: %s step_nr: %d\n", prefix.c_str(), step_nr);
    GGML_ASSERT_SHAPE(kv.full_k, k_proj, step_nr, batch_size, 1);

    // qk is (B * H, Sq, Sk) == (B * H, 1, Sk) in incremental mode;
    // we return the Sq slice of the (Sq, Sk) attention mask.
    if (self_attn_mask != nullptr) {
        *self_attn_mask = ggml_slice(
            ctx, ggml_slice(ctx, kv.self_attn_mask, 0, 0, step_nr),
            1, step_nr - 1, step_nr
        );
    }
    kv.step_nr = step_nr;
    ggml_set_no_alloc(ctx, no_alloc_save);
}
// Variant of ggml_get_rows that allows `a` to have more than 2 dims.
ggml_tensor* ggml_get_rows2(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
    int flattened = 0;
    GGML_ASSERT(a->n_dims <= 3);
    if (a->n_dims == 3) {
        flattened = a->ne[0];
        a = ggml_flatten_1d(ctx, a, 0);
    }
    a = ggml_get_rows(ctx, a, b);
    if (flattened) {
        a = ggml_unflatten_1d(ctx, a, 0, flattened);
    }
    return a;
}
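
// Illustrative example: reordering beams in the kv cache. For `a` = kv.full_k
// with ne = (K_proj, S, B) and an I32 tensor `b` = [2, 0, 1], `a` is flattened
// to (K_proj * S, B), the rows given by `b` are gathered with ggml_get_rows,
// and the result is unflattened back to (K_proj, S, 3), i.e. the beam dim
// has been reordered.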
void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, ggml_tensor* new_order) {
    // GGML_ASSERT(ctx == kv.full_k->con);
    if (kv.full_k != nullptr) {
        ggml_detach(kv.full_k);
        const char* name = kv.full_k->name;
        kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
        ggml_build_forward_expand(gf, kv.full_k);
        ggml_format_name(kv.full_k, "%s (sorted)", name);
    }
    if (kv.full_v != nullptr) {
        ggml_detach(kv.full_v);
        const char* name = kv.full_v->name;
        kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
        ggml_build_forward_expand(gf, kv.full_v);
        ggml_format_name(kv.full_v, "%s (sorted)", name);
    }
}
void reorder_kv_cache(const fairseq2_model& model, ggml_context* ctx, ggml_cgraph* gf, ggml_tensor* new_order) {
    auto self_attn_glob = "*.self_attn";
    for (auto& named_kv : model.kv_cache) {
        if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH) {
            continue;
        }
        _reorder_kv_cache(ctx, gf, named_kv.second, new_order);
    }
}
inline double model_layer_config_d(const fairseq2_model& model, std::string name) {
    // The layer config stores doubles bit-cast into int64 slots.
    const std::int64_t* data = &model.layer_config.at(name);
    double val = *(const double*)data;
    return val;
}

extern "C" double fairseq2_model_layer_config_double(const fairseq2_model& model, const char* name) {
    return model_layer_config_d(model, std::string(name));
}

extern "C" std::int64_t fairseq2_model_layer_config_int(const fairseq2_model& model, const char* name) {
    return model.layer_config.at(std::string(name));
}
  182. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  183. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  184. // delete model;
  185. }
  186. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  187. model->ctx = ctx;
  188. }
  189. extern "C" std::string* std_string_alloc(char* c_str) {
  190. return new std::string(c_str);
  191. }
  192. extern "C" void std_string_free(std::string* str) {
  193. delete str;
  194. }
bool has_layer(fairseq2_model& model, const std::string& name) {
    return model.tensors.find(name) != model.tensors.end();
}

ggml_tensor* mul_mat(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
    if (b->ne[1] == 1 && b->ne[2] > 1 && a->n_dims == 2) {
        // `b` has shape (B, 1, D).
        // If `a` is (D_out, D), we can do one matmul for the full batch.
        b = ggml_flatten_1d(ctx, b, 1);
        return ggml_unflatten_1d(ctx, ggml_mul_mat(ctx, a, b), 1, 1);
    }
    // There is also the k * q matmul -> (D, 1, B) x (D, 1, B) -> (1, 1, B);
    // not sure what's the best way to compute this with BLAS.
    return ggml_mul_mat(ctx, a, b); // (d_out)
}
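
// Illustrative example of the batching trick above: for a weight `a` with
// ne = (D, D_out) and input `b` with ne = (D, 1, B), `b` is flattened to
// (D, B) so one matmul covers the whole batch, and the (D_out, B) result
// is viewed back as (D_out, 1, B).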
  209. extern "C" ggml_tensor* Linear_forward(
  210. fairseq2_model& model,
  211. const std::string &prefix,
  212. ggml_tensor* input // (d_in)
  213. ) {
  214. // Note: for now we assumed un-batched input
  215. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  216. GGML_ASSERT(weight != nullptr);
  217. ggml_tensor* out = mul_mat(model.ctx, weight, input); // (d_out)
  218. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  219. if (bias == nullptr) return out;
  220. return ggml_add(model.ctx, out, bias);
  221. }
  222. extern "C" ggml_tensor* LayerNorm_forward(
  223. fairseq2_model& model,
  224. const std::string &prefix,
  225. ggml_tensor* input
  226. ) {
  227. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  228. GGML_ASSERT(weight != nullptr);
  229. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  230. GGML_ASSERT(bias != nullptr);
  231. auto ctx = model.ctx;
  232. double eps = model_layer_config_d(model, prefix + ".eps");
  233. input = ggml_norm(ctx, input, /*eps*/eps);
  234. return ggml_add_inplace(
  235. ctx,
  236. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  237. ggml_repeat(ctx, bias, input)
  238. );
  239. }
  240. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  241. fairseq2_model& model,
  242. const std::string& prefix,
  243. ggml_tensor* seqs
  244. ) {
  245. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  246. // inner_activation = ReLu // TODO: allow other activation
  247. seqs = ggml_relu_inplace(model.ctx, seqs);
  248. if (has_layer(model, prefix + ".inner_layer_norm")) {
  249. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  250. }
  251. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  252. return seqs;
  253. }
  254. extern "C" ggml_tensor* SiluFeedForwardNetwork_forward(
  255. fairseq2_model& model,
  256. const std::string& prefix,
  257. ggml_tensor* seqs
  258. ) {
  259. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  260. seqs = ggml_silu(model.ctx, seqs);
  261. if (has_layer(model, prefix + ".inner_layer_norm")) {
  262. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  263. }
  264. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  265. return seqs;
  266. }
ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
    int n_dims = x->n_dims;
    GGML_ASSERT(dim >= 0);
    GGML_ASSERT(dim < n_dims);
    GGML_ASSERT(ggml_is_contiguous(x));
    if (dim == n_dims - 1) return x; // nothing to do

    if (n_dims == 2) {
        return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
    } else if (n_dims == 3) {
        if (dim == 0) {
            return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
        } else { // dim == 1
            return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
        }
    } else { // n_dims == 4
        if (dim == 0) {
            return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
        } else if (dim == 1) {
            return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
        } else { // dim == 2
            return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
        }
    }
}
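
// Illustrative examples: flattening merges dim `dim` with dim `dim + 1`.
// For `x` with ne = (2, 3, 5):
//   ggml_flatten_1d(ctx, x, 0); // -> (6, 5)
//   ggml_flatten_1d(ctx, x, 1); // -> (2, 15)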
ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
    int n_dims = x->n_dims;
    GGML_ASSERT(dim >= 0);
    GGML_ASSERT(dim < n_dims);
    GGML_ASSERT(n_dims < 4);
    GGML_ASSERT(x->ne[dim] % num_el == 0);
    GGML_ASSERT(x->nb[dim + 1] == x->nb[dim] * x->ne[dim]); // `x` must be contiguous along `dim`
    if (n_dims == 1) {
        return ggml_view_2d(ctx, x, num_el, x->ne[0] / num_el, x->nb[0] * num_el, 0);
    } else if (n_dims == 2) {
        if (dim == 0) {
            return ggml_view_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->nb[0] * num_el, x->nb[1], 0);
        } else { // dim == 1
            return ggml_view_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->nb[1], num_el * x->nb[1], 0);
        }
    } else { // n_dims == 3
        if (dim == 0) {
            return ggml_view_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2], x->nb[0] * num_el, x->nb[1], x->nb[2], 0);
        } else if (dim == 1) {
            return ggml_view_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2], x->nb[1], num_el * x->nb[1], x->nb[2], 0);
        } else { // dim == 2
            return ggml_view_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el, x->nb[1], x->nb[2], num_el * x->nb[2], 0);
        }
    }
}
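
// Illustrative examples: unflattening splits dim `dim` into
// (num_el, ne[dim] / num_el). For `x` with ne = (6, 5):
//   ggml_unflatten_1d(ctx, x, 0, 2); // -> (2, 3, 5)
//   ggml_unflatten_1d(ctx, x, 1, 5); // -> (6, 5, 1)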
ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
    // (B, S, dim) -> (B, S, H, H_dim)
    x = ggml_unflatten_1d(ctx, x, 0, head_dim);
    x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
    x = ggml_cont(ctx, x);
    x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
    return x;
}
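
// Illustrative walk-through: with model_dim = 8 and num_heads = 2
// (head_dim = 4), a projected tensor with ne = (8, S, B) becomes:
//   unflatten -> (4, 2, S, B); permute -> (4, S, 2, B); flatten -> (4, S, 2 * B)
// i.e. heads are folded into the batch dim for the attention matmuls.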
/// (B, Sk, dim) -> (B * H, H_dim, Sk)
ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim) {
    // (B, Sk, dim) -> (B, Sk, H, H_dim)
    v = ggml_unflatten_1d(ctx, v, 0, head_dim);
    v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B, H, H_dim, Sk)
    v = ggml_cont(ctx, v);
    v = ggml_flatten_1d(ctx, v, 2); // (B * H, H_dim, Sk)
    return v;
}
// flash_attn doesn't work for cross attention, because it assumes the query
// is no longer than the keys. It also seems to yield slightly different scores
// than expected, and thus a different beam search.
#define UNITY_FLASH_ATTN 0
  337. extern "C" ggml_tensor* MultiheadAttention_forward(
  338. fairseq2_model& model,
  339. const std::string &prefix,
  340. ggml_tensor* queries, // (slen, d_in)
  341. ggml_tensor* keys, // (klen, d_in)
  342. ggml_tensor* values, // (klen, d_out)
  343. ggml_tensor* attn_mask // (klen, slen)
  344. ) {
  345. int model_dim = queries->ne[0];
  346. int num_heads = model.layer_config.at(prefix + ".num_heads");
  347. int head_dim = model_dim / num_heads;
  348. GGML_ASSERT(model_dim % num_heads == 0);
  349. ggml_context* ctx = model.ctx;
  350. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
  351. q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
  352. ggml_set_name(q, "q");
  353. ggml_tensor *k, *v;
  354. if (!has_kv_cache(model)) {
  355. k = Linear_forward(model, prefix + ".k_proj", keys);
  356. ggml_set_name(k, "k");
  357. v = Linear_forward(model, prefix + ".v_proj", values);
  358. ggml_set_name(v, "v");
  359. } else {
  360. bool encoder_decoder_attn = keys == values && keys != queries;
  361. if (encoder_decoder_attn) {
  362. // The K and V tensors of an encoder-decoder attention (i.e. the
  363. // projected encoder outputs) remain static during evaluation.
  364. KeyValueTensor& kv_cache = model.kv_cache[prefix];
  365. if (kv_cache.step_nr == 0) {
  366. // If possible we use the ctx dedicated to kv_cache here,
  367. // because the enc dec attention is typically long lived.
  368. if (model.enc_kv_cache_ctx) model.ctx = model.enc_kv_cache_ctx;
  369. k = Linear_forward(model, prefix + ".k_proj", keys);
  370. ggml_set_name(k, "k");
  371. v = Linear_forward(model, prefix + ".v_proj", values);
  372. ggml_set_name(v, "v");
  373. // Note we are only storing a pointer to the buffer, not the full graph
  374. kv_cache.full_k = ggml_detach(ggml_dup_inplace(model.ctx, k));
  375. printf("prefix: %s, k: %d %d %d\n", prefix.c_str(), kv_cache.full_k->ne[0], kv_cache.full_k->ne[1], kv_cache.full_k->ne[2]);
  376. ggml_format_name(kv_cache.full_k, "%s.k_cache", prefix.c_str());
  377. kv_cache.full_v = ggml_detach(ggml_dup_inplace(model.ctx, v));
  378. ggml_format_name(kv_cache.full_v, "%s.v_cache", prefix.c_str());
  379. kv_cache.step_nr = keys->ne[1];
  380. model.ctx = ctx;
  381. } else {
  382. printf("prefix: %s, k: %d %d %d\n", prefix.c_str(), kv_cache.full_k->ne[0], kv_cache.full_k->ne[1], kv_cache.full_k->ne[2]);
  383. k = kv_cache.full_k;
  384. v = kv_cache.full_v;
  385. GGML_ASSERT(keys->ne[1] == k->ne[1]); // cache content doesn't match the input sequence
  386. GGML_ASSERT(values->ne[1] == v->ne[1]); // cache content doesn't match the input sequence
  387. }
  388. } else { // self attention
  389. // (1, K) -> (N, 1, K_proj)
  390. for (auto& named_kv : model.kv_cache) {
  391. auto enc_dec_attn_glob = "*.encoder_decoder_attn";
  392. if (::fnmatch(enc_dec_attn_glob, named_kv.first.c_str(), 0) != FNM_NOMATCH) {
  393. printf("HERE BEFORE CULPRIT LINE prefix: %s\n", named_kv.first.c_str());
  394. if(named_kv.second.full_k != nullptr)
  395. printf("HERE BEFORE CULPRIT LINE k: %d\n", named_kv.second.full_k->ne[0]);
  396. }
  397. }
  398. k = Linear_forward(model, prefix + ".k_proj", keys);
  399. for (auto& named_kv : model.kv_cache) {
  400. auto enc_dec_attn_glob = "*.encoder_decoder_attn";
  401. if (::fnmatch(enc_dec_attn_glob, named_kv.first.c_str(), 0) != FNM_NOMATCH) {
  402. printf("HERE AFTER CULPRIT LINE prefix: %s\n", named_kv.first.c_str());
  403. if(named_kv.second.full_k != nullptr)
  404. printf("HERE AFTER CULPRIT LINE k: %d\n", named_kv.second.full_k->ne[0]);
  405. }
  406. }
  407. ggml_set_name(k, "k");
  408. // (1, V) -> (N, 1, V_proj)
  409. v = Linear_forward(model, prefix + ".v_proj", values);
  410. ggml_set_name(v, "v");
  411. append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
  412. }
  413. }
  414. k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
  415. v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
  416. v = ggml_cont(ctx, v);
  417. #if UNITY_FLASH_ATTN
  418. // For flash_attn, we assume either no masks, or triangular masks.
  419. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr); // (B * H, S, H_dim)
  420. ggml_set_name(attn, "attn");
  421. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  422. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  423. #else
  424. // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
  425. ggml_tensor* qk = mul_mat(ctx, k, q);
  426. ggml_set_name(qk, "qk");
  427. FORCE_ALLOC(qk_scale, ctx, ggml_new_tensor_1d(ctx, qk->type, 1));
  428. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  429. qk = ggml_scale(ctx, qk, qk_scale);
  430. ggml_set_name(qk, "qk_scaled");
  431. if (attn_mask) qk = ggml_add_inplace(ctx, qk, attn_mask);
  432. // TODO: upgrade qk to float32 if needed
  433. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
  434. ggml_set_name(attn_weights, "attn_weights");
  435. // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
  436. ggml_tensor* attn = mul_mat(ctx, attn_weights, v);
  437. ggml_set_name(attn, "attn");
  438. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  439. attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
  440. #endif // UNITY_FLASH_ATTN
  441. attn = ggml_cont(ctx, attn);
  442. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  443. // out -> (B, S, d_out)
  444. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  445. ggml_set_name(out, "out");
  446. return out;
  447. }
  448. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  449. fairseq2_model& model,
  450. const std::string& prefix,
  451. ggml_tensor* seqs,
  452. ggml_tensor* padding_mask
  453. ) {
  454. ggml_context* ctx = model.ctx;
  455. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  456. // _forward_self_attn(seqs, padding_mask)
  457. auto residual = seqs;
  458. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  459. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  460. // TODO: add padding_mask to MultiheadAttention_forward
  461. GGML_ASSERT(padding_mask == nullptr);
  462. seqs = MultiheadAttention_forward(
  463. model,
  464. prefix + ".self_attn",
  465. seqs,
  466. seqs,
  467. seqs,
  468. /*attn_mask=*/nullptr
  469. );
  470. if (has_layer(model, prefix + ".self_attn_norm"))
  471. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  472. seqs = ggml_add_inplace(ctx, seqs, residual);
  473. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  474. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  475. // _forward_ffn(seqs)
  476. residual = seqs;
  477. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  478. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  479. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  480. // TODO: if self.residual_scale is not None:
  481. // residual = self.residual_scale * residual
  482. seqs = ggml_add_inplace(ctx, seqs, residual);
  483. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  484. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  485. return seqs;
  486. }
  487. extern "C" ggml_tensor* WaveformToFbank_forward(
  488. fairseq2_model& model,
  489. const std::string &prefix,
  490. ggml_tensor* waveform
  491. ) {
  492. // Hardcoding: num_bins 80, sample rate 16k, always standardize
  493. ggml_context* ctx = model.ctx;
  494. knf::MelBanksOptions mel_opts{};
  495. mel_opts.num_bins = 80;
  496. knf::FrameExtractionOptions frame_opts{};
  497. frame_opts.samp_freq = 16000;
  498. knf::FbankOptions opts{};
  499. opts.frame_opts = frame_opts;
  500. opts.mel_opts = mel_opts;
  501. std::vector<float_t> signal_frame{};
  502. std::int32_t num_frames = knf::NumFrames(/*num_samples=*/waveform->ne[0], frame_opts);
  503. FORCE_ALLOC(output, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 80, num_frames));
  504. knf::FbankComputer native_(opts);
  505. knf::FeatureWindowFunction window_fn_(native_.GetFrameOptions());
  506. for (std::int32_t frame_nr = 0; frame_nr < num_frames; ++frame_nr) {
  507. signal_frame.resize(0);
  508. // Extract the frame from the waveform tensor.
  509. knf::ExtractWindow(
  510. /*sample_offset=*/0,
  511. (float *)(waveform->data),
  512. waveform->ne[0],
  513. frame_nr,
  514. frame_opts,
  515. window_fn_,
  516. &signal_frame);
  517. native_.Compute(
  518. /*signal_raw_log_energy=*/0, /*vtln_warp=*/1.0, &signal_frame, ((float *)(output->data) + frame_nr * 80));
  519. }
  520. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  521. output = ggml_norm(ctx, output, 1e-5);
  522. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  523. if (output->ne[1] % 2 == 1) {
  524. output = ggml_dup(ctx, ggml_slice(ctx, output, 1, 0, output->ne[1]-1));
  525. }
  526. output = ggml_reshape_2d(ctx, output, output->ne[0] * 2, output->ne[1] / 2);
  527. return output;
  528. }
// TODO: check if it's possible to merge this with the standard MHA.
extern "C" ggml_tensor* RelativePositionMHA_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
) {
    ggml_context* ctx = model.ctx;
    ggml_tensor* residual = seqs;
    seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);

    // self_attn: qkv
    ggml_tensor* Qcur = Linear_forward(model, prefix + ".q_proj", seqs);
    ggml_tensor* Kcur = Linear_forward(model, prefix + ".k_proj", seqs);
    ggml_tensor* Vcur = Linear_forward(model, prefix + ".v_proj", seqs);

    // self_attn: rel_pos SDPA
    int32_t S = seqs->ne[1];
    int32_t H = 16; // TODO: make this configurable
    int32_t n_ctx = 4096;
    int32_t K_h = seqs->ne[0] / H;
    int32_t start_index = n_ctx - S;
    int32_t end_index = n_ctx + S - 1;
    int num_indices = end_index - start_index;
    FORCE_ALLOC(rows, ctx, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices));
    for (int i = 0; i < num_indices; i++) {
        ((int32_t*)rows->data)[i] = start_index + i;
    }

    // self_attn: load pos_enc weights & compute_r.
    // In fairseq2 the pos_enc weights are computed on the fly; since more custom
    // operators would be needed to enable that here, we store the (fixed) results
    // in the checkpoint as speech_encoder.pos_enc and load them directly.
    ggml_tensor* r = ggml_get_rows(ctx, model.tensors["speech_encoder.pos_enc"], rows);
    r = mul_mat(ctx, model.tensors[prefix + ".sdpa.r_proj.weight"], r);
    r = ggml_dup(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, r, 0, K_h), 0, 2, 1, 3));

    ggml_tensor* u_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.u_bias"], K_h, 1, H);
    ggml_tensor* v_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.v_bias"], K_h, 1, H);

    // self_attn: permute QKV
    // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
    ggml_tensor* Q = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Qcur, 0, K_h), 0, 2, 1, 3));
    // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
    ggml_tensor* K = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Kcur, 0, K_h), 0, 2, 1, 3));
    // (H * K_h, S) -> (K_h, H, S) -> (H, S, K_h)
    ggml_tensor* V = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Vcur, 0, K_h), 1, 2, 0, 3));

    ggml_tensor* q_with_u_bias = ggml_add_inplace(ctx, ggml_dup(ctx, Q), u_bias); // (K_h, S, H)
    ggml_tensor* q_with_v_bias = ggml_add_inplace(ctx, Q, v_bias); // (K_h, S, H)
    ggml_tensor* ac = mul_mat(ctx, K, q_with_u_bias);
    ggml_tensor* bd = mul_mat(ctx, r, q_with_v_bias);

    // self_attn: shift_bd. Logic follows
    // https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L161
    bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // (H, S, 2S-1)
    FORCE_ALLOC(pad, ctx, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, S, 1));
    pad = ggml_set_f32(pad, 0.0);
    bd = ggml_concat(ctx, pad, bd); // bd[i][j][0] == 0, (H, S, 2S)
    bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // (2S, S, H)
    bd = ggml_reshape_3d(ctx, bd, S, 2 * S, H); // (S, 2S, H)
    // Discard the first set of positive positions.
    bd = ggml_dup(ctx, ggml_slice(ctx, bd, 1, 1, 2 * S));
    // Shift each row by an extra step.
    bd = ggml_reshape_3d(ctx, bd, 2 * S - 1, S, H);
    // Discard the positions used for the shift.
    bd = ggml_slice(ctx, bd, 0, 0, S);
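
    // Illustrative sketch of the shift above (the usual "skew" trick): each
    // query row of bd holds scores for relative offsets [-S+1 .. S-1].
    // Prepending a zero column and reshaping (S, 2S, H) -> (2S-1, S, H)
    // rotates row i left by i positions, so that after discarding the padding
    // columns, entry (i, j) holds the score for query i attending to key j,
    // i.e. relative offset j - i.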
    // self_attn: compute attn / weights
    ggml_tensor* attn_weights = ggml_add_inplace(ctx, ac, bd);
    FORCE_ALLOC(attn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
    ggml_set_f32(attn_scale, 1.0 / pow(K_h, 0.5));
    attn_weights = ggml_mul_inplace(ctx, attn_weights, ggml_repeat(ctx, attn_scale, attn_weights));
    attn_weights = ggml_soft_max(ctx, attn_weights);
    ggml_tensor* attn = mul_mat(ctx, V, attn_weights); // (K_h, S, H)
    attn = ggml_dup(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));
    ggml_tensor* attn_2d = ggml_reshape_2d(ctx, attn, K_h * H, S);
    ggml_tensor* attn_out = mul_mat(ctx, model.tensors[prefix + ".output_proj.weight"], attn_2d);
    attn_out = ggml_add_inplace(
        ctx,
        attn_out,
        ggml_repeat(ctx, model.tensors[prefix + ".output_proj.bias"], attn_out)
    );
    attn_out = ggml_add_inplace(ctx, attn_out, residual);
    return attn_out;
}
  604. extern "C" ggml_tensor* ConvModule_forward(
  605. fairseq2_model& model,
  606. const std::string& prefix,
  607. ggml_tensor* seqs
  608. ) {
  609. ggml_context* ctx = model.ctx;
  610. ggml_tensor* residual = seqs;
  611. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  612. // conv: Use matmul for pointwise conv 1 - kernel_size=1, no padding case
  613. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv1.weight"], seqs);
  614. // conv: GLU
  615. seqs = ggml_glu(ctx, seqs);
  616. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  617. // S x C -> (S+K-1) x C -> K x S x C -> S x C
  618. int K = model.tensors[prefix + ".depthwise_conv.weight"]->ne[0];
  619. seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".depthwise_conv.weight"], seqs, 1, K / 2, 1, seqs->ne[1]);
  620. // conv: Custom implementation of batch norm
  621. seqs = ggml_batch_norm(ctx, seqs, model.tensors[prefix + ".batch_norm.weight"], model.tensors[prefix + ".batch_norm.bias"], model.tensors[prefix + ".batch_norm.running_mean"], model.tensors[prefix + ".batch_norm.running_var"], 1e-5);
  622. // conv: SiLU actvation
  623. seqs = ggml_silu_inplace(ctx, seqs);
  624. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  625. // conv: Use matmul for pointwise conv 2 - kernel_size=1, no padding case
  626. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv2.weight"], seqs);
  627. // conv: + residual
  628. seqs = ggml_add_inplace(ctx, seqs, residual);
  629. return seqs;
  630. }
  631. extern "C" ggml_tensor* StandardConformerEncoderLayer_forward(
  632. fairseq2_model& model,
  633. const std::string& prefix,
  634. ggml_tensor* seqs,
  635. ggml_tensor* padding_mask
  636. ) {
  637. ggml_context* ctx = model.ctx;
  638. FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  639. ggml_set_f32(ffn_scale, 0.5f);
  640. ggml_tensor* residual = seqs;
  641. seqs = LayerNorm_forward(model, prefix + ".ffn1_layer_norm", seqs);
  642. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn1", seqs);
  643. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  644. seqs = ggml_add_inplace(ctx, seqs, residual);
  645. seqs = RelativePositionMHA_forward(model, prefix + ".self_attn", seqs);
  646. seqs = ConvModule_forward(model, prefix + ".conv", seqs);
  647. residual = seqs;
  648. seqs = LayerNorm_forward(model, prefix + ".ffn2_layer_norm", seqs);
  649. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn2", seqs);
  650. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  651. seqs = ggml_add_inplace(ctx, seqs, residual);
  652. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  653. return seqs;
  654. }
  655. extern "C" ggml_tensor* StandardConformerEncoder_forward(
  656. fairseq2_model& model,
  657. const std::string& prefix,
  658. ggml_tensor* seqs,
  659. ggml_tensor* padding_mask
  660. ) {
  661. ggml_context* ctx = model.ctx;
  662. seqs = WaveformToFbank_forward(model, prefix, seqs);
  663. seqs = LayerNorm_forward(model, prefix + "_frontend.post_extract_layer_norm", seqs);
  664. seqs = Linear_forward(model, prefix + "_frontend.model_dim_proj", seqs);
  665. int layer_idx = 0;
  666. std::string layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  667. while (has_layer(model, layer_name)) {
  668. seqs = StandardConformerEncoderLayer_forward(
  669. model, layer_name, seqs, padding_mask
  670. );
  671. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  672. layer_idx += 1;
  673. layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  674. }
  675. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  676. ggml_tensor* residual = seqs;
  677. seqs = Linear_forward(model, prefix + ".proj1", seqs);
  678. seqs = ggml_relu_inplace(ctx, seqs);
  679. seqs = Linear_forward(model, prefix + ".proj2", seqs);
  680. FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  681. ggml_set_f32(ffn_scale, 0.5f);
  682. seqs = ggml_mul(ctx, ggml_repeat(ctx, ffn_scale, seqs), seqs);
  683. seqs = ggml_add_inplace(ctx, seqs, residual);
  684. layer_idx = 0;
  685. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  686. while (has_layer(model, layer_name)) {
  687. seqs = StandardConformerEncoderAdaptorLayer_forward(
  688. model, layer_name, seqs, padding_mask
  689. );
  690. ggml_set_name(seqs, ("x_ada_" + std::to_string(layer_idx)).c_str());
  691. layer_idx += 1;
  692. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  693. }
  694. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  695. return seqs;
  696. }
  697. extern "C" ggml_tensor* StandardConformerEncoderAdaptorLayer_forward(
  698. fairseq2_model& model,
  699. const std::string& prefix,
  700. ggml_tensor* seqs,
  701. ggml_tensor* padding_mask
  702. ) {
  703. ggml_context* ctx = model.ctx;
  704. ggml_tensor* residual = seqs;
  705. residual = LayerNorm_forward(model, prefix + ".residual_layer_norm", residual);
  706. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  707. residual = ggml_conv_1d(ctx, model.tensors[prefix + ".residual_conv.weight"], residual, 8, 4, 1, 1);
  708. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  709. residual = ggml_add_inplace(ctx, ggml_repeat(ctx, model.tensors[prefix + ".residual_conv.bias"], residual), residual);
  710. residual = ggml_glu(ctx, residual);
  711. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  712. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  713. seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".self_attn_conv.weight"], seqs, 8, 4, 1, 1);
  714. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  715. seqs = ggml_add_inplace(ctx, seqs, ggml_repeat(ctx, model.tensors[prefix + ".self_attn_conv.bias"], seqs));
  716. seqs = ggml_glu(ctx, seqs);
  717. seqs = MultiheadAttention_forward(
  718. model,
  719. prefix + ".self_attn",
  720. seqs,
  721. seqs,
  722. seqs,
  723. /*attention masks=*/nullptr
  724. );
  725. seqs = ggml_add_inplace(ctx, seqs, residual);
  726. residual = seqs;
  727. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  728. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  729. seqs = ggml_add_inplace(ctx, seqs, residual);
  730. return seqs;
  731. }
/// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
/// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
ggml_tensor* ggml_slice(
    struct ggml_context* ctx,
    struct ggml_tensor* a,
    int axis,
    int64_t start,
    int64_t end
) {
    int64_t ne[4];
    std::copy(a->ne, a->ne + 4, ne);
    if (axis < 0) axis = a->n_dims + axis;
    if (start < 0) start = ne[axis] + start;
    if (end <= 0) end = ne[axis] + end;
    GGML_ASSERT(0 <= start);
    GGML_ASSERT(start < end);
    GGML_ASSERT(end <= ne[axis]);

    ne[axis] = end - start;
    size_t offset = a->nb[axis] * start;
    size_t* nb = a->nb;
    ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
    ggml_format_name(result, "%s [(%d)%ld:%ld]", a->name, axis, start, end);
    result->n_dims = a->n_dims;
    return result;
}
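
// Illustrative examples, for `x` with ne = (10, 4):
//   ggml_slice(ctx, x, 0, 2, 5);  // x[..., 2:5] -> ne = (3, 4)
//   ggml_slice(ctx, x, 1, 0, -1); // x[0:-1]     -> ne = (10, 3)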
ggml_tensor* ggml_select(
    struct ggml_context* ctx,
    struct ggml_tensor* a,
    int axis,
    int64_t index
) {
    int64_t ne[GGML_MAX_DIMS];
    std::copy(a->ne, a->ne + GGML_MAX_DIMS, ne);
    if (axis < 0) axis = a->n_dims + axis;
    if (index < 0) index = ne[axis] + index;
    GGML_ASSERT(0 <= index);
    GGML_ASSERT(index < ne[axis]);
    // Drop the selected axis by shifting the higher dims down.
    std::copy(a->ne + axis + 1, a->ne + GGML_MAX_DIMS, ne + axis);

    size_t offset = a->nb[axis] * index;
    size_t* nb = a->nb;
    GGML_ASSERT(GGML_MAX_DIMS == 4);
    ggml_tensor* result = ggml_view_3d(ctx, a, ne[0], ne[1], ne[2], nb[1], nb[2], offset);
    ggml_format_name(result, "%s [(%d)%ld]", a->name, axis, index);
    result->n_dims = a->n_dims - 1;
    return result;
}
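
// Illustrative examples, for `x` with ne = (10, 4): selecting along the
// outermost axis drops it from the shape:
//   ggml_select(ctx, x, 1, 2);  // x[2] -> ne = (10,)
//   ggml_select(ctx, x, -1, 0); // x[0] -> ne = (10,)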
// In-place computation of PositionalEmbedding.
extern "C" ggml_tensor* PositionalEmbedding_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* embeds
) {
    // This only works with the simple pos encoders.
    int seq_len = embeds->ne[1];
    ggml_tensor* full_pos_embeds = model.tensors[prefix];
    int start_step = 0;
    if (has_kv_cache(model)) {
        start_step = model.kv_cache[prefix].step_nr++;
    }
    ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, start_step, seq_len + start_step);
    return ggml_add(model.ctx, embeds, pos_embeds);
}
  794. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  795. fairseq2_model& model,
  796. const std::string& prefix,
  797. ggml_tensor* seqs
  798. ) {
  799. GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
  800. ggml_context* ctx = model.ctx;
  801. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  802. GGML_ASSERT(embed_weights != nullptr);
  803. ggml_tensor* embeds;
  804. if (seqs->n_dims == 1) {
  805. embeds = ggml_get_rows(ctx, embed_weights, seqs);
  806. } else {
  807. // ggml_get_rows isn't very flexible, we have to handle the reshape ourselves.
  808. ggml_tensor* flat_seqs = seqs;
  809. if (!ggml_is_contiguous(seqs)) {
  810. flat_seqs = ggml_cont(ctx, flat_seqs);
  811. }
  812. flat_seqs = ggml_reshape_1d(ctx, flat_seqs, ggml_nelements(seqs));
  813. embeds = ggml_get_rows(ctx, embed_weights, flat_seqs);
  814. embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  815. embeds->n_dims = seqs->n_dims + 1;
  816. }
  817. // padding mask ?
  818. // padding_mask = to_padding_mask(embeds, seq_lens)
  819. if (has_layer(model, prefix + ".pos_encoder")) {
  820. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  821. }
  822. if (has_layer(model, prefix + ".layer_norm")) {
  823. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  824. }
  825. return embeds;
  826. }
  827. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  828. fairseq2_model& model,
  829. const std::string& prefix,
  830. ggml_tensor* seqs,
  831. ggml_tensor* padding_mask
  832. ) {
  833. int layer_idx = 0;
  834. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  835. while (has_layer(model, layer_name)) {
  836. seqs = StandardTransformerEncoderLayer_forward(
  837. model, layer_name, seqs, padding_mask
  838. );
  839. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  840. layer_idx += 1;
  841. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  842. }
  843. if (has_layer(model, prefix + ".layer_norm"))
  844. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  845. return seqs;
  846. }
  847. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  848. fairseq2_model& model,
  849. const std::string& prefix,
  850. ggml_tensor* seqs,
  851. ggml_tensor* self_attn_mask,
  852. ggml_tensor* encoder_output,
  853. ggml_tensor* encoder_padding_mask
  854. ) {
  855. ggml_context* ctx = model.ctx;
  856. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  857. // _forward_self_attn(seqs, padding_mask)
  858. auto residual = seqs;
  859. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  860. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  861. seqs = MultiheadAttention_forward(
  862. model,
  863. prefix + ".self_attn",
  864. seqs,
  865. seqs,
  866. seqs,
  867. /*attn_mask=*/self_attn_mask
  868. );
  869. if (has_layer(model, prefix + ".self_attn_norm"))
  870. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  871. seqs = ggml_add_inplace(ctx, seqs, residual);
  872. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  873. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  874. // _forward_encoder_decoder_attn
  875. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  876. // `encoder_output` must be `None` for decoder-only attention.
  877. GGML_ASSERT(encoder_output == nullptr);
  878. return seqs;
  879. }
  880. // `encoder_output` must not be `None` for encoder-decoder attention.
  881. GGML_ASSERT(encoder_output != nullptr);
  882. residual = seqs;
  883. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  884. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  885. seqs = MultiheadAttention_forward(
  886. model,
  887. prefix + ".encoder_decoder_attn",
  888. seqs,
  889. encoder_output,
  890. encoder_output,
  891. /*attention masks=*/encoder_padding_mask
  892. );
  893. seqs = ggml_add_inplace(ctx, seqs, residual);
  894. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  895. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  896. // _forward_ffn(seqs)
  897. residual = seqs;
  898. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  899. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  900. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  901. // TODO:
  902. // if self.residual_scale is not None:
  903. // residual = self.residual_scale * residual
  904. seqs = ggml_add_inplace(ctx, seqs, residual);
  905. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  906. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  907. return seqs;
  908. }
  909. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  910. auto seq_len = seqs->ne[1];
  911. // TODO: allow other ggml_type
  912. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  913. return ggml_diag_mask_inf(ctx, mask, 0);
  914. }
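
// Conceptually, for seq_len = 3 the resulting mask added to the attention
// scores is:
//   [[0, -inf, -inf],
//    [0,    0, -inf],
//    [0,    0,    0]]
// so that position i cannot attend to positions j > i.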
  915. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  916. fairseq2_model& model,
  917. const std::string& prefix,
  918. ggml_tensor* seqs,
  919. ggml_tensor* padding_mask,
  920. ggml_tensor* encoder_output,
  921. ggml_tensor* encoder_padding_mask
  922. ) {
  923. int layer_idx = 0;
  924. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  925. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  926. while (has_layer(model, layer_name)) {
  927. seqs = StandardTransformerDecoderLayer_forward(
  928. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  929. );
  930. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  931. layer_idx += 1;
  932. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  933. }
  934. if (has_layer(model, prefix + ".layer_norm"))
  935. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  936. return seqs;
  937. }
int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
    auto opts = job.opts;
    int max_seq_len = -1;
    if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
        max_seq_len = opts.hard_max_seq_len;
    } else {
        max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len) + opts.soft_max_seq_len_b);
    }
    if (opts.min_seq_len > max_seq_len) {
        printf(
            "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
            opts.min_seq_len,
            max_seq_len
        );
        GGML_ASSERT(opts.min_seq_len <= max_seq_len);
    }
    int prefix_seq_len = job.prefix_seq->ne[0];
    if (prefix_seq_len >= max_seq_len) {
        printf(
            "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
            prefix_seq_len,
            max_seq_len
        );
        GGML_ASSERT(prefix_seq_len < max_seq_len);
    }
    return max_seq_len;
}
void _fan_out_encoder_output(
    ggml_context* ctx,
    ggml_tensor** encoder_output_out,
    ggml_tensor** encoder_padding_mask_out,
    int beam_size
) {
    // (S_enc, M)
    ggml_tensor* encoder_output = *encoder_output_out;
    ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
    // (B, S_enc, M)
    ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
    // (S_enc, M) -> (B, S_enc, M)
    *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
    // (S_enc) -> (B, S_enc)
    if (encoder_padding_mask != nullptr) {
        ggml_tensor* shape_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], 1, beam_size);
        *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
    }
}
ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
    // TODO: this isn't the most precise way of doing this
    return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
}

ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
    ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
    ggml_type true_type = x->type;
    ggml_tensor* y = ggml_repeat(ctx, x, shape);
    y->type = true_type;
    return y;
}
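
// Illustrative example: broadcasts `x` against a dummy I8 tensor that only
// supplies the target shape. For `row` with ne = (4, 1):
//   ggml_tensor* y = ggml_expand_2d(ctx, row, 4, 3); // ne = (4, 3), broadcast along ne[1]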
void _bootstrap_seqs_and_scores(
    fairseq2_model& model,
    const SequenceGeneratorJob& job,
    ggml_tensor* full_seqs,
    ggml_tensor* scores,
    ggml_tensor* encoder_output,
    ggml_tensor* encoder_padding_mask,
    ggml_tensor* lid_scores,
    int n_threads,
    const std::vector<int>& lang_ids
) {
    // Fills in the LID score map as a side effect.
    int prefix_seq_len = job.prefix_seq->ne[0];
    int max_seq_len = scores->ne[0];
    int beam_size = scores->ne[1];
    GGML_ASSERT(prefix_seq_len > 0);
    ggml_context* ctx = model.ctx;

    if (prefix_seq_len == 1) {
        // We only have one token in the prefix, so we won't compute decoding
        // scores; we just need to copy the token to seqs.
        // Note: it also means the enc_kv_cache will be populated later.
        ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
        ggml_set_i32(seqs, ggml_get_i32_1d(job.prefix_seq, 0));
        return;
    }

    // full_seqs[:, :prefix_seq_len] = job.prefix_seq;
    ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
    seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);

    // We have to bootstrap the model with the already fanned-out encoder
    // output to correctly initialize its incremental state.
    // Note: we don't start decoding the last prefix token just yet.
    seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);

    // Bootstrap the model state with the prefix sequence.
    seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
    ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
        model,
        "text_decoder",
        seqs,
        /*padding_mask*/ nullptr,
        encoder_output,
        encoder_padding_mask
    );

    // logits, lprobs: (N, S_pfx - 1, V)
    ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
    int vocab_size = logits->ne[0];
    ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));

    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, lprobs);
    ggml_graph_compute_with_ctx(ctx, gf, n_threads);
    full_seqs->type = GGML_TYPE_I32;
    job.prefix_seq->type = GGML_TYPE_I32;

    // For LID
    for (size_t i = 0; i < lang_ids.size(); ++i) {
        ggml_set_f32_1d(lid_scores, i, std::exp(ggml_get_f32_1d(lprobs, lang_ids[i])));
    }

    // Fetch the scores of the next steps from lprobs.
    float p_score = 0;
    for (int i = 1; i < prefix_seq_len; ++i) {
        int p = -1;
        if (ggml_get_i32_1d(job.prefix_seq, i) == model.vocab.token_to_id["<unk>"]) {
            // If tgt_lang is <unk>, use the most probable lang tag predicted by the model.
            float max_value = -std::numeric_limits<float>::infinity();
            for (size_t j = 0; j < lang_ids.size(); j++) {
                if (ggml_get_f32_1d(lprobs, lang_ids[j]) > max_value) {
                    max_value = ggml_get_f32_1d(lprobs, lang_ids[j]);
                    p = lang_ids[j];
                }
            }
        } else {
            p = ggml_get_i32_1d(job.prefix_seq, i);
        }
        p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
        for (int b = 0; b < beam_size; ++b) {
            // scores: (N, S)
            // Note: the first step's score (e.g. BOS) is always 0.
            ggml_set_f32_1d(scores, b * max_seq_len + i, p_score);
        }
    }
}
/// Finds the top-k indices and writes the winning indices into the
/// "candidate_indices" array.
int topk(
    ggml_tensor* lprobs, // (B, V)
    std::int64_t k,
    ggml_tensor* candidate_indices
) {
    // Take the best 2 x `beam_size` predictions. We'll choose the first
    // `beam_size` of these which don't predict EOS to continue with.
    // (N, 2 x B)
    // The caller passes `vocab_size - 1` as `k` to never select PAD.
    std::int64_t K = std::min(k, ggml_nelements(lprobs));
    auto comp = [lprobs](std::int32_t a, std::int32_t b) {
        return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
    };
    GGML_ASSERT(ggml_nelements(candidate_indices) >= k);
    auto cand = (std::int32_t*)candidate_indices->data;
    std::partial_sort(cand, cand + K, cand + ggml_nelements(lprobs), comp);
    return (int)K;
}
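
// Illustrative usage only (mirrors the beam-search loop below): after topk()
// has sorted the first K entries of `candidate_indices`, each winning index
// encodes a (beam, token) pair over a flattened (beam_size, vocab_size)
// lprobs tensor:
//
//     int K = topk(lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices);
//     for (int i = 0; i < K; ++i) {
//         int flat = ggml_get_i32_1d(candidate_indices, i);
//         int beam = flat / vocab_size;   // which hypothesis to extend
//         int token = flat % vocab_size;  // which token extends it
//     }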
void _tweak_lprobs(const SequenceGeneratorJob& job, ggml_tensor* lprobs, int step_nr, int max_seq_len, std::size_t vocab_size) {
    std::size_t beam_size = job.opts.beam_size;
    std::size_t eos_idx = job.eos_idx;

    // Do not allow EOS before reaching the minimum sequence length.
    if (step_nr < job.opts.min_seq_len) {
        // lprobs[:, :, self.eos_idx] = -INFINITY;
        for (size_t i = 0; i < beam_size; ++i)
            ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
    }

    // If we have reached the maximum length, force the last step to be EOS.
    if (step_nr == max_seq_len - 2) {
        // lprobs[:, :, : self.eos_idx] = -torch.inf
        // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
        for (size_t b = 0; b < beam_size; ++b) {
            size_t t = 0;
            for (t = 0; t < eos_idx; ++t)
                ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
            for (t = eos_idx + 1; t < vocab_size; ++t)
                ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
        }
    }

    // Never allow PAD.
    std::size_t pad_idx = job.pad_idx;
    for (size_t i = 0; i < beam_size; ++i)
        ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);

    // Apply the UNK penalty.
    if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
        // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
        auto lprobs_raw = ggml_get_data_f32(lprobs);
        for (size_t i = 0; i < beam_size; ++i)
            lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
    }
}
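
// Sketch of the intended effect (hypothetical values), for one beam row of
// lprobs at a step below `min_seq_len`, with unk_penalty = 1.0:
//
//     token:   the    <eos>   <pad>   <unk>
//     before:  -0.5   -0.1    -0.3    -2.0
//     after:   -0.5   -inf    -inf    -3.0
//
// EOS and PAD can then never win the top-k selection, and UNK is demoted.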
/// Copies the sequence and scores of a given candidate beam.
void _finalize_hypothesis(
    const SequenceGeneratorJob& job,
    ggml_context* ctx,
    int step_nr,
    std::int32_t beam,
    std::int32_t token,
    float eos_score,
    ggml_tensor* seqs, // (beam_size, seq_len)
    ggml_tensor* scores, // (beam_size, seq_len)
    ggml_tensor* lid_scores,
    Hypothesis* hypothesis
) {
    ggml_tensor* seq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
    hypothesis->seq = seq;
    ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
    hypothesis->step_scores = step_scores;

    auto tok = (std::int32_t*)seq->data;
    for (int i = 0; i < step_nr + 1; ++i) {
        tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
    }
    tok[step_nr + 1] = token;

    // Convert from cumulative to per-step scores.
    auto sc = (float*)step_scores->data;
    float last_score = eos_score;
    for (int i = step_nr; i >= 0; --i) {
        float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
        sc[i + 1] = last_score - sc0;
        last_score = sc0;
    }
    sc[0] = 0;

    if (job.opts.normalize_scores)
        // Skip the first EOS since it is always 0 and skews normalization.
        eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
    hypothesis->score = eos_score;
    hypothesis->lid_scores = lid_scores;
}
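
// Worked example (made-up numbers) of the cumulative -> per-step conversion:
// with cumulative scores [0.0, -1.2, -2.0] for a beam and eos_score = -2.5,
// the per-step scores come out as [0.0, -1.2, -0.8, -0.5]: each entry is the
// difference between consecutive cumulative scores, and the final entry is
// the cost of emitting EOS itself.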
// Uses a ggml_context to store any object.
#define GGML_CTX_ALLOC(ctx, Type, n) \
    ((Type*)(ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(Type) * (n))->data))

ggml_context* ctx_from_buffer(std::vector<uint8_t>& buffer) {
    return ggml_init({
        /*.mem_size =*/ static_cast<int64_t>(buffer.capacity()),
        /*.mem_buffer =*/ buffer.data(),
        /*.no_alloc =*/ false,
    });
}

ggml_allocr* new_arena_allocr(std::vector<uint8_t>& buffer) {
    return ggml_allocr_new(buffer.data(), buffer.capacity(), 8);
}
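
// Illustrative usage only (the buffer size is an arbitrary example value):
// backing a context with caller-owned memory, then carving a plain C array
// out of it. `Hypothesis` and `MB` are defined elsewhere in this file.
//
//     std::vector<uint8_t> buf(16 * MB);
//     ggml_context* ctx = ctx_from_buffer(buf);
//     Hypothesis* hyps = GGML_CTX_ALLOC(ctx, Hypothesis, 5);
//     // ... use hyps[0..4]; the storage lives as long as ctx/buf do.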
/// Generates a translation for a single sequence.
/// The resulting Hypothesis array is written inside `result_ctx`.
extern "C" Hypothesis* generate_sequence(
    fairseq2_model& model,
    const SequenceGeneratorJob& job,
    ggml_tensor* encoder_output,
    ggml_tensor* encoder_padding_mask,
    ggml_context* result_ctx,
    int n_threads
) {
    // Pre-allocate memory buffers.
    // * step_ctx: contains metadata for the model graph, as well as some explicit
    //   buffers for the lprobs tweaking.
    // * prev_step_ctx: an additional buffer, because we need some results from
    //   previous steps to compute the next step, notably the self-attention kv cache.
    // * search_ctx: contains tensors that should live for the full search,
    //   like the encoder kv cache.
    // * step_alloc: contains the buffer for the forward pass of the model.
    // Split mem_mb across the different contexts we need to use.
    int mem_mb = job.opts.mem_mb;
    std::vector<uint8_t> local_bufs[4] = {
        std::vector<uint8_t>(mem_mb * MB * 3 / 10), // step_ctx
        std::vector<uint8_t>(mem_mb * MB * 3 / 10), // prev_step_ctx
        std::vector<uint8_t>(mem_mb * MB * 3 / 10), // search_ctx
        std::vector<uint8_t>(mem_mb * MB * 1 / 10), // step_alloc
    };
    ggml_allocr* step_alloc = new_arena_allocr(local_bufs[3]);

    // Collect the language tokens ("__xx__") of multilingual models.
    std::vector<int> lang_ids;
    if (model.hparams["multilingual"] != 0) {
        for (const auto& kv : model.vocab.token_to_id) {
            if (kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
                lang_ids.push_back(kv.second);
            }
        }
        std::sort(lang_ids.begin(), lang_ids.end());
    }
    std::cout << "multilingual model: " << model.hparams["multilingual"] << " (" << lang_ids.size() << " langs)" << std::endl;
    ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
    size_t vocab_size = embed->ne[1];
    std::size_t beam_size = job.opts.beam_size;
    ggml_detach(encoder_output);
    int source_seq_len = encoder_output->ne[1];
    int max_seq_len = _determine_max_seq_len(job, source_seq_len);

    ggml_context* search_ctx = ctx_from_buffer(local_bufs[2]);
    ggml_context* original_ctx = model.ctx;
    fairseq2_kv_cache_alloc(model, search_ctx, beam_size, max_seq_len);

    // (S_enc, M) -> (B, S_enc, M)
    model.ctx = search_ctx;
    _fan_out_encoder_output(search_ctx, &encoder_output, &encoder_padding_mask, beam_size);

    // Allocate results in the context provided by the caller.
    ggml_set_no_alloc(result_ctx, false);
    Hypothesis* finished_searches_begin = GGML_CTX_ALLOC(result_ctx, Hypothesis, beam_size);
    Hypothesis* finished_searches = finished_searches_begin;
    for (std::size_t i = 0; i < beam_size; ++i) finished_searches[i] = {nullptr, -INFINITY, nullptr};
    Hypothesis* finished_searches_end = finished_searches + beam_size;

    // Initialize buffers. (B, S)
    ggml_tensor* seqs = ggml_new_tensor_2d(search_ctx, GGML_TYPE_I32, max_seq_len, beam_size);
    printf("Seqs dim: [%d %d %d]\n", (int)seqs->ne[0], (int)seqs->ne[1], (int)seqs->ne[2]);
    ggml_set_i32(seqs, 0);
    ggml_set_name(seqs, "seqs_0");
    ggml_tensor* scores = ggml_new_tensor_2d(search_ctx, GGML_TYPE_F32, max_seq_len, beam_size);
    ggml_set_name(scores, "scores_0");
    ggml_set_f32(scores, 0.0);

    int prefix_seq_len = job.prefix_seq->ne[0];
    int start_step = prefix_seq_len - 1;
    ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[(start_step + 1) % 2]);
    ggml_context* step_ctx = ctx_from_buffer(local_bufs[start_step % 2]);
    GGML_ASSERT(step_ctx != search_ctx);
    GGML_ASSERT(prev_step_ctx != step_ctx);
    model.ctx = prev_step_ctx;
    // search_ctx, because we need encoder_decoder_attn.k_cache to survive for the full search.
    model.enc_kv_cache_ctx = search_ctx;

    ggml_tensor* lid_scores = nullptr;
    if (lang_ids.size()) {
        lid_scores = ggml_new_tensor_1d(result_ctx, GGML_TYPE_F32, lang_ids.size());
    }
    // Multilingual models: bootstrap the LID scores.
    _bootstrap_seqs_and_scores(
        model, job, seqs, scores, encoder_output, encoder_padding_mask, lid_scores, n_threads, lang_ids
    );
    printf("Seqs dim after bootstrapping: [%d %d %d]\n", (int)seqs->ne[0], (int)seqs->ne[1], (int)seqs->ne[2]);

    // Holds the indices of beams (a beam can occur more than once) that we
    // should continue with in the next step.
    ggml_tensor* beam_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
    ggml_tensor* next_tokens = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
    ggml_tensor* next_scores = ggml_new_tensor_1d(search_ctx, GGML_TYPE_F32, beam_size);

    // Array with integers up to 'vocab_size * beam_size' to represent the next beams to explore.
    ggml_tensor* candidate_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, vocab_size * beam_size);
    for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
        ((int32_t *)(candidate_indices->data))[i] = i;
    printf_mem_usage(search_ctx, "search_ctx");
    for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
        model.ctx = step_ctx;
        ggml_set_no_alloc(step_ctx, true); // Use allocr for the model forward pass.
        if (step_nr == start_step) {
            // When prefix_seq[1] is <unk>, find the most probable lang_tok
            // and assign it to all beams.
            if (lang_ids.size() && ggml_get_i32_1d(job.prefix_seq, 1) == model.vocab.token_to_id["<unk>"]) {
                float max_lprob = -INFINITY;
                int p = -1;
                for (size_t j = 0; j < lang_ids.size(); j++) {
                    auto val = ggml_get_f32_1d(lid_scores, j);
                    if (val > max_lprob) {
                        max_lprob = val;
                        p = lang_ids[j];
                    }
                }
                for (int k = 0; k < beam_size; k++) {
                    ggml_set_i32_1d(seqs, k * max_seq_len + step_nr, p);
                }
            }
        }
        ggml_tensor* prev_token = ggml_slice(step_ctx, seqs, 0, step_nr, step_nr + 1);
        ggml_tensor* decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", prev_token);
        ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
            model,
            "text_decoder",
            decoder_input,
            nullptr, // We never generate PAD.
            encoder_output,
            encoder_padding_mask
        ); // (B, 1, D)
        decoder_output = ggml_flatten_1d(step_ctx, decoder_output, 0); // (B, model_dim)

        // Force logits to be allocated in step_ctx, not in step_alloc.
        ggml_set_no_alloc(step_ctx, false);
        ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output); // (B, vocab_size)
        ggml_tensor* lprobs = ggml_log_softmax(step_ctx, logits);

        // Compute lprobs here so we can modify it in place in the lprobs-tweaking phase.
        // TODO: use ggml to properly compute the tweaks.
        struct ggml_cgraph * gf = ggml_new_graph(step_ctx);
        ggml_build_forward_expand(gf, lprobs);
        size_t fwd_mem = ggml_allocr_alloc_graph(step_alloc, gf);
        GGML_UNUSED(fwd_mem);
        ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);
        ggml_detach(lprobs);
        ggml_allocr_reset(step_alloc);
#if DEBUG_MEM_USAGE
        printf("beam search step %d. Graph.n_nodes: %d.\n", step_nr, gf->n_nodes);
        printf("  Fwd mem: %.1fMB, reserved %.1fMB\n", fwd_mem / (double)MB, local_bufs[3].capacity() / (double)MB);
        std::fill(local_bufs[3].begin(), local_bufs[3].end(), 0xAA);
#endif
        _tweak_lprobs(job, lprobs, step_nr, max_seq_len, vocab_size);

        ggml_tensor* last_scores = ggml_slice(step_ctx, scores, 0, step_nr, step_nr + 1);
        if (step_nr == start_step) {
            // At the initial step, all hypotheses are equally likely, so we use
            // only the first beam.
            lprobs = ggml_slice(step_ctx, lprobs, 1, 0, 1);
            lprobs = ggml_cont(step_ctx, lprobs);
            // The first step always indicates the beginning of the sequence and has no score.
            if (step_nr > 0) {
                last_scores = ggml_slice(step_ctx, last_scores, 1, 0, 1);
                lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
            }
        } else {
            // Make the probabilities contain the cumulative scores for each hypothesis.
            lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
        }
        ggml_build_forward_expand(gf, lprobs);
        ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);

        // Determine the (beam, token) candidates for the next step.
        // (N, 2 x B)
        std::int64_t K = topk(
            lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
        );
        std::size_t ongoing_beams = 0;
        for (std::int32_t i = 0; i < K; ++i) {
            int c = ggml_get_i32_1d(candidate_indices, i);
            std::int32_t beam = c / vocab_size;
            std::int32_t token = c % vocab_size;
            float tok_score = ggml_get_f32_1d(lprobs, c);

            // Detect beams that reached the minimum length and that end with an EOS.
            bool eos = token == job.eos_idx;
            eos &= tok_score != -INFINITY;
            if (eos) {
                _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, lid_scores, finished_searches++);
                if (finished_searches == finished_searches_end)
                    goto end_of_beam_search;
                continue;
            }

            ggml_set_i32_1d(beam_indices, ongoing_beams, beam);
            ggml_set_i32_1d(next_tokens, ongoing_beams, token);
            ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
            if (model.hparams["multilingual"] == 0) {
                printf("Token at top%d: %d (%s)\n", i, token, model.tgt_vocab.id_to_token.at(token).text.c_str());
            } else {
                printf("Token at top%d: %d (%s)\n", i, token, model.vocab.id_to_token.at(token).text.c_str());
            }
            // printf("Seqs dim: [%d %d %d], beam_indices: [%d %d]\n", seqs->ne[0], seqs->ne[1], seqs->ne[2], beam_indices->ne[0], beam_indices->ne[1]);
            ongoing_beams += 1;
            if (ongoing_beams >= beam_size) break;
        }
        // Reorder the beams in the `seqs` and `scores` buffers. The same beam
        // can be selected more than once.
        // (B, S), (B) -> (B, S)
        // Don't use the allocr API here, because it might reuse a kv cache buffer several times.
        ggml_set_no_alloc(step_ctx, false);
        printf("Seqs dim before getting rows step %d: [%d %d %d]\n", step_nr, (int)seqs->ne[0], (int)seqs->ne[1], (int)seqs->ne[2]);
        ggml_tensor* new_seqs = ggml_get_rows(step_ctx, seqs, beam_indices);
        ggml_tensor* new_scores = ggml_get_rows(step_ctx, scores, beam_indices);
        struct ggml_cgraph * gf_reorder = ggml_new_graph(step_ctx);
        ggml_build_forward_expand(gf_reorder, new_seqs);
        ggml_build_forward_expand(gf_reorder, new_scores);
        reorder_kv_cache(model, step_ctx, gf_reorder, beam_indices);
        ggml_graph_compute_with_ctx(step_ctx, gf_reorder, n_threads);
        seqs = ggml_detach(new_seqs);
        printf("Seqs dim after detach step %d: [%d %d %d]\n", step_nr, (int)seqs->ne[0], (int)seqs->ne[1], (int)seqs->ne[2]);
        scores = ggml_detach(new_scores);

        // seqs[:, step_nr + 1] = next_tokens
        // scores[:, step_nr + 1] = next_scores
        for (std::size_t i = 0; i < beam_size; ++i) {
            ((std::int32_t*)seqs->data)[step_nr + 1 + i * max_seq_len] = ggml_get_i32_1d(next_tokens, i);
            ((float*)scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
        }

        printf_mem_usage(step_ctx, "step_ctx");
        ggml_free(prev_step_ctx);
        prev_step_ctx = step_ctx;
#if DEBUG_MEM_USAGE
        std::fill(local_bufs[(step_nr + 1) % 2].begin(), local_bufs[(step_nr + 1) % 2].end(), 0xAA);
#endif
        step_ctx = ctx_from_buffer(local_bufs[(step_nr + 1) % 2]);
    }
end_of_beam_search:
    // Ensure that hypotheses are sorted by decreasing score before returning.
    std::sort(
        finished_searches_begin,
        finished_searches_end,
        [](const Hypothesis& a, const Hypothesis& b) { return a.score > b.score; }
    );
    printf_mem_usage(search_ctx, "search_ctx");
    // fairseq2_kv_cache_reset(model);
    model.ctx = original_ctx;
    return finished_searches_begin;
}
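
// Minimal calling sketch (hypothetical setup; `encoder_output` would come
// from a real encoder forward pass, and `job` from the caller's options):
//
//     std::vector<uint8_t> result_buf(4 * MB);
//     ggml_context* result_ctx = ctx_from_buffer(result_buf);
//     Hypothesis* hyps = generate_sequence(
//         model, job, encoder_output, /*encoder_padding_mask=*/nullptr,
//         result_ctx, /*n_threads=*/4);
//     // hyps[0] now holds the best-scoring hypothesis; hyps[0].seq is an
//     // I32 tensor of token ids, suitable for fairseq2_spm_detokenize.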
extern "C" Hypothesis* _testing_return_hypothesis_ptr(ggml_context* ctx) {
    Hypothesis* result = GGML_CTX_ALLOC(ctx, struct Hypothesis, 2);

    result[0] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 3.14f, (ggml_tensor*)result};
    ggml_set_i32_1d(result[0].seq, 0, 314);

    result[1] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 4.21f, nullptr};
    ggml_set_i32_1d(result[1].seq, 0, 421);

    return result;
}
// SPM tokenizer
// Original implementation:
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4

struct llm_symbol {
    using index = int;
    index prev;
    index next;
    const char * text;
    size_t n;
    llama_vocab::id id;
};

static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");

static size_t utf8_len(char src) {
    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
    return lookup[highbits];
}
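
// For example, 'a' (0x61) has high nibble 0x6 and maps to length 1, while the
// first byte of "é" in UTF-8 (0xC3) has high nibble 0xC and maps to length 2.
// Continuation bytes (high nibble 0x8..0xB) also map to 1; the tokenizer loop
// below only passes lead bytes when the input is valid UTF-8.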
struct llm_bigram_spm {
    struct comparator {
        bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
        }
    };
    using queue_storage = std::vector<llm_bigram_spm>;
    using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
    llm_symbol::index left;
    llm_symbol::index right;
    float score;
    size_t size;
    llama_vocab::id id;
};
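
// With this comparator the priority queue pops the highest-scoring bigram
// first and, on score ties, the leftmost one (smallest `left` index), which
// gives the greedy merge order the SPM algorithm expects.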
struct llm_tokenizer_spm {
    llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}

    void tokenize(const std::string& input_text, ggml_tensor* output) {
        llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");

        // Split the string into utf8 chars.
        int index = 0;
        size_t offs = 0;
        // This is kind of annoying, but needed because with SPM,
        // characters following a space have a special meaning,
        // and the algorithm relies on substrings to do the lookups.
        std::string text = input_text;
        bool need_extra_space = text.size() > 0 && text[0] != ' ';
        if (need_extra_space) text = " " + text;

        while (offs < text.size()) {
            size_t len = utf8_len(text[offs]);
            size_t n = std::min(len, text.size() - offs);
            auto token = vocab.token_to_id.find(std::string(text, offs, n));
            llama_vocab::id id = token == vocab.token_to_id.end() ? unk_idx : token->second;
            llm_symbol sym = {
                /*prev*/ index - 1,
                /*next*/ offs + n == text.size() ? -1 : index + 1,
                /*text*/ text.c_str() + offs,
                /*n*/ n,
                /*id*/ id
            };
            offs += n;
            index++;
            symbols.emplace_back(sym);
        }

        // Seed the work queue with all possible 2-character tokens.
        for (size_t i = 1; i < symbols.size(); ++i) {
            try_add_bigram(i - 1, i);
        }

        // Keep substituting the highest-scoring pairs for as long as we can.
        while (!work_queue.empty()) {
            auto bigram = work_queue.top();
            work_queue.pop();
            auto & left_sym = symbols[bigram.left];
            auto & right_sym = symbols[bigram.right];
            // If one of the symbols already got merged, skip it.
            if (
                left_sym.n == 0
                || right_sym.n == 0
                || left_sym.n + right_sym.n != bigram.size
            ) continue;

            // Merge the right sym into the left one.
            left_sym.n += right_sym.n;
            left_sym.id = bigram.id;
            right_sym.n = 0;

            // Remove the right sym from the chain.
            left_sym.next = right_sym.next;
            if (right_sym.next >= 0) {
                symbols[right_sym.next].prev = bigram.left;
            }

            // Find more substitutions.
            try_add_bigram(left_sym.prev, bigram.left);
            try_add_bigram(bigram.left, left_sym.next);
        }

        llama_vocab::id* out = (llama_vocab::id*)output->data;
        int out_step = sizeof(llama_vocab::id) / output->nb[0];
        int num_tokens = 0;
        for (int i = 0; i > -1 && i < (int)symbols.size(); i = symbols[i].next) {
            llm_symbol& symbol = symbols[i];
            *(out + num_tokens * out_step) = symbol.id;
            num_tokens += 1;
        }
        *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
        num_tokens += 1;
        output->ne[0] = num_tokens;
    }
private:
    void try_add_bigram(int left, int right) {
        if (left == -1 || right == -1) {
            return;
        }

        const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
        auto token = vocab.token_to_id.find(text);
        if (token == vocab.token_to_id.end()) {
            return;
        }

        llama_vocab::id id = token->second;
        if (static_cast<size_t>(id) >= vocab.id_to_token.size()) {
            return;
        }

        const auto& tok_data = vocab.id_to_token[id];
        llm_bigram_spm bigram = {
            /*left */ left,
            /*right*/ right,
            /*score*/ tok_data.score,
            /*size */ text.size(),
            /*id   */ id
        };
        work_queue.push(bigram);
    }

    const llama_vocab& vocab;
    std::vector<llm_symbol> symbols;
    llm_bigram_spm::queue work_queue;
};
extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out) {
    llm_tokenizer_spm spm = {model->vocab};
    spm.tokenize(std::string(text), out);
}
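
// Illustrative usage only (the buffer size is an arbitrary example value):
// `out` must be an I32 tensor large enough for the token ids plus the
// trailing </s>; tokenize() shrinks out->ne[0] to the actual token count.
//
//     ggml_tensor* out = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 128);
//     fairseq2_spm_tokenize(&model, "Hello world", out);
//     // out->ne[0] now holds the number of tokens written, EOS included.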
extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out) {
    bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
    int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
    int sent_len = tokens->ne[0];
    std::size_t written = 0;
    const std::vector<llama_vocab::token_data>& id_to_token = no_tgt_vocab ? model->vocab.id_to_token : model->tgt_vocab.id_to_token;
    for (int i = 0; i < sent_len; ++i) {
        int id = ggml_get_i32_1d(tokens, i);
        // Don't output the EOS token, but only when it appears at the end.
        if (i == sent_len - 1 && eos_idx == id) break;
        std::string token = id_to_token.at(id).text;

        // Skip the leading space of the first token.
        auto begin = token.begin();
        if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;

        std::copy(begin, token.end(), out);
        std::size_t n = token.end() - begin;
        written += n;
        out += n;
    }
    *out = '\0';
    return written;
}
// TODO: Unify with the above?
std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(
    fairseq2_model* model,
    ggml_tensor* tokens,
    ggml_tensor* scores,
    char* out) {
    bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
    int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
    int sent_len = tokens->ne[0];
    std::size_t written = 0;
    std::vector<float> word_scores;
    std::vector<float> subword_scores;
    std::vector<std::string> result_text;
    std::string curr_token = "";
    for (int i = 0; i < sent_len; ++i) {
        int id = ggml_get_i32_1d(tokens, i);
        // Don't output the EOS token, but only when it appears at the end.
        if (i == sent_len - 1 && eos_idx == id) break;
        std::string token = model->vocab.id_to_token.at(id).text;
        float score = ggml_get_f32_1d(scores, i + 2); // Offset by the prefix size (2).
        if (token[0] == ' ') {
            // A new word starts: flush the pending word and its score.
            if (subword_scores.size() > 0) {
                float avg = std::accumulate(subword_scores.begin(), subword_scores.end(), 0.0f) / subword_scores.size();
                word_scores.push_back(avg);
                subword_scores.clear();
                result_text.push_back(curr_token);
            }
            curr_token = token.substr(1);
        } else {
            curr_token += token;
        }
        subword_scores.push_back(score);

        // Skip the leading space of the first token.
        auto begin = token.begin();
        if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
        std::copy(begin, token.end(), out);
        std::size_t n = token.end() - begin;
        written += n;
        out += n;
    }
    // Flush the last word.
    if (subword_scores.size() > 0) {
        word_scores.push_back(*std::min_element(subword_scores.begin(), subword_scores.end()));
        subword_scores.clear();
        result_text.push_back(curr_token);
    }
    *out = '\0';
    return std::make_pair(result_text, word_scores);
}
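
// Worked example (made-up numbers): tokens " He" "llo" " world" with scores
// [-0.2, -0.4, -0.1] group into the words "Hello" and "world"; "Hello" gets
// the average of its subword scores, (-0.2 + -0.4) / 2 = -0.3, while the
// final word is flushed with the minimum of its subword scores, -0.1 here.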