fairseq2.cpp 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919
  #include <fnmatch.h>
  #include <math.h>

  #include <algorithm>
  #include <cstring>
  #include <iostream>
  #include <numeric>
  #include <queue>
  #include <unordered_map>

  #include "kaldi-native-fbank/csrc/feature-fbank.h"
  #include "kaldi-native-fbank/csrc/feature-window.h"
  #include "fairseq2.h"
  #include "ggml.h"
  #include "ggml-alloc.h"
  13. ggml_tensor* ggml_detach(ggml_tensor* a) {
  14. a->op = GGML_OP_NONE;
  15. std::fill(a->src, a->src + GGML_MAX_SRC, nullptr);
  16. return a;
  17. }
  // generate_sequence uses ggml_context and ggml_allocr to reuse memory buffers across steps.
  // This can lead to dangling pointers, which don't segfault, but instead read garbage data.
  // Enabling this flag allows us to explicitly reset memory buffers, making it more obvious
  // when we read garbage data.
  // It also prints memory usage information, which is useful for debugging memory issues.
  23. #define DEBUG_MEM_USAGE 1
  24. size_t MB = 1024 * 1024;
  25. void printf_mem_usage(ggml_context* ctx, std::string name) {
  26. #if DEBUG_MEM_USAGE
  27. double mb = 1024.0 * 1024.0;
  28. printf(
  29. "%s: memory used = %8.2f MB, memory reserved = %8.2f Mb\n",
  30. name.c_str(),
  31. ggml_used_mem(ctx) / mb,
  32. ggml_get_mem_size(ctx) / mb
  33. );
  34. #endif
  35. }
  /// Swaps the values of two variables.
  /// `x` must be a plain identifier because it is pasted into the temporary's name.
  /// The statements are wrapped in their own scope so that the macro
  ///  - behaves as a single statement (e.g. under an unbraced `if`), and
  ///  - can be used several times on the same variable within one scope.
  #define SWAP(x, y) \
      { auto tmp_ ## x = x; x = y; y = tmp_ ## x; }

  /// Asserts the shape of tensor `x`. Pass -1 for a dimension that should not be checked.
  /// Arguments are parenthesized so that arbitrary expressions can be passed safely.
  /// The trailing semicolon is kept for compatibility with existing call sites.
  #define GGML_ASSERT_SHAPE(x, ne0, ne1, ne2, ne3) \
      GGML_ASSERT(((ne0) == -1 || (x)->ne[0] == (ne0)) && ((ne1) == -1 || (x)->ne[1] == (ne1)) && ((ne2) == -1 || (x)->ne[2] == (ne2)) && ((ne3) == -1 || (x)->ne[3] == (ne3)));
  40. /// allocate the fairseq2 model and hyperparameters
  41. extern "C" fairseq2_model* fairseq2_model_alloc() {
  42. // pre-allocate some memory to write hyperparameters and tensors pointers
  43. auto* model = new fairseq2_model;
  44. model->tensors_ctx = nullptr;
  45. return model;
  46. }
  47. extern "C" void fairseq2_kv_cache_alloc(fairseq2_model& model, ggml_context* kv_cache_ctx, int beam_size, int max_seq_len) {
  48. // Note: we only allocate the masks, proper kv cache allocation is delayed.
  49. GGML_ASSERT(kv_cache_ctx);
  50. GGML_ASSERT(!ggml_get_no_alloc(kv_cache_ctx)); // We need to be able to alloc the kv_cache buffers
  51. model.kv_cache_ctx = kv_cache_ctx;
  52. auto attn_glob = "text_decoder.*_attn.k_proj.weight";
  53. FORCE_ALLOC(self_attn_mask, kv_cache_ctx, ggml_new_tensor_2d(kv_cache_ctx, GGML_TYPE_F32, max_seq_len, max_seq_len));
  54. self_attn_mask = ggml_diag_mask_inf_inplace(kv_cache_ctx, self_attn_mask, 0);
  55. ggml_format_name(self_attn_mask, "self_attn_mask[%d]", max_seq_len);
  56. for (auto named_tensor : model.tensors) {
  57. const std::string& name = named_tensor.first;
  58. if (::fnmatch(attn_glob, name.c_str(), 0) == FNM_NOMATCH)
  59. continue;
  60. // create a cache entry without the ".k_proj.weight" suffix
  61. const std::string& shortname = name.substr(0, name.size() - 14);
  62. KeyValueTensor& kv = model.kv_cache[shortname];
  63. kv.step_nr = 0;
  64. kv.full_k = nullptr;
  65. kv.full_v = nullptr;
  66. kv.self_attn_mask = self_attn_mask;
  67. }
  68. }
  69. extern "C" void fairseq2_kv_cache_reset(const fairseq2_model& model) {
  70. // TODO: use a dedicated allocator, so that kv_cache.clear actually frees the memory
  71. model.kv_cache.clear();
  72. }
  73. bool has_kv_cache(const fairseq2_model& model) {
  74. return model.kv_cache.size() > 0;
  75. }
  76. inline ggml_tensor* ggml_squeeze(ggml_context* ctx, ggml_tensor* x, int dim) {
  77. int n_dims = x->n_dims;
  78. GGML_ASSERT(dim >= 0);
  79. GGML_ASSERT(dim < n_dims);
  80. GGML_ASSERT(x->ne[dim] == 1);
  81. return ggml_flatten_1d(ctx, x, dim);
  82. }
  83. inline ggml_tensor* ggml_unsqueeze(ggml_context* ctx, ggml_tensor* x, int dim) {
  84. return ggml_unflatten_1d(ctx, x, dim, 1);
  85. }
  86. // copy k and v to kv cache
  87. // kv.full_k[step_nr] = k;
  88. // kv.full_v[step_nr] = v;
  89. void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
  90. KeyValueTensor& kv = model.kv_cache[prefix];
  91. int step_nr = kv.step_nr;
  92. ggml_context* ctx = model.kv_cache_ctx ? model.kv_cache_ctx : model.ctx;
  93. // We need to force allocation here, otherwise the kv_cache buffers can be reused
  94. bool no_alloc_save = ggml_get_no_alloc(ctx);
  95. ggml_set_no_alloc(ctx, false);
  96. int n_steps = (*k)->ne[1];
  97. // printf("Prefix: %s n_steps: %d\n", prefix.c_str(), n_steps);
  98. int k_proj, batch_size;
  99. if (kv.full_k != nullptr) {
  100. // (N, S_kv, K_proj)
  101. k_proj = kv.full_k->ne[0];
  102. batch_size = kv.full_k->ne[2];
  103. ggml_detach(kv.full_k);
  104. ggml_detach(kv.full_v);
  105. kv.full_k = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_k, 1), ggml_unsqueeze(ctx, *k, 1)), 1);
  106. kv.full_v = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_v, 1), ggml_unsqueeze(ctx, *v, 1)), 1);
  107. } else {
  108. GGML_ASSERT(step_nr == 0);
  109. k_proj = (*k)->ne[0];
  110. batch_size = (*v)->ne[2];
  111. kv.full_k = ggml_dup(ctx, *k);
  112. kv.full_v = ggml_dup(ctx, *v);
  113. }
  114. *k = kv.full_k;
  115. *v = kv.full_v;
  116. ggml_format_name(kv.full_k, "%s.k (step=%d)", prefix.c_str(), step_nr);
  117. ggml_format_name(kv.full_v, "%s.v (step=%d)", prefix.c_str(), step_nr);
  118. step_nr += n_steps;
  119. // printf("Prefix: %s step_nr: %d\n", prefix.c_str(), step_nr);
  120. GGML_ASSERT_SHAPE(kv.full_k, k_proj, step_nr, batch_size, 1);
  121. // qk is (B * H, Sq, Sk) == (B*H, 1, Sk) in incremental mode
  122. // we return the Sq slice of the (Sq, Sk) attention mask
  123. if (self_attn_mask != nullptr) {
  124. *self_attn_mask = ggml_slice(
  125. ctx, ggml_slice(ctx, kv.self_attn_mask, 0, 0, step_nr),
  126. 1, step_nr - 1, step_nr
  127. );
  128. }
  129. kv.step_nr = step_nr;
  130. ggml_set_no_alloc(ctx, no_alloc_save);
  131. }
  132. // variant of ggml_get_rows that allows for a with more than 2 dims.
  133. ggml_tensor* ggml_get_rows2(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  134. int flattened = 0;
  135. GGML_ASSERT(a->n_dims <= 3);
  136. if (a->n_dims == 3) {
  137. flattened = a->ne[0];
  138. a = ggml_flatten_1d(ctx, a, 0);
  139. }
  140. a = ggml_get_rows(ctx, a, b);
  141. if (flattened) {
  142. a = ggml_unflatten_1d(ctx, a, 0, flattened);
  143. }
  144. return a;
  145. }
  146. void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, ggml_tensor* new_order) {
  147. // GGML_ASSERT(ctx == kv.full_k->con);
  148. if (kv.full_k != nullptr) {
  149. ggml_detach(kv.full_k);
  150. const char* name = kv.full_k->name;
  151. kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
  152. ggml_build_forward_expand(gf, kv.full_k);
  153. ggml_format_name(kv.full_k, "%s (sorted)", name);
  154. }
  155. if (kv.full_v != nullptr) {
  156. ggml_detach(kv.full_v);
  157. const char* name = kv.full_v->name;
  158. kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
  159. ggml_build_forward_expand(gf, kv.full_v);
  160. ggml_format_name(kv.full_v, "%s (sorted)", name);
  161. }
  162. }
  163. void reorder_kv_cache(const fairseq2_model& model, ggml_context* ctx, ggml_cgraph* gf, ggml_tensor* new_order) {
  164. auto self_attn_glob = "*.self_attn";
  165. for (auto& named_kv : model.kv_cache) {
  166. if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH) {
  167. continue;
  168. }
  169. _reorder_kv_cache(ctx, gf, named_kv.second, new_order);
  170. }
  171. }
  172. inline double model_layer_config_d(const fairseq2_model& model, std::string name) {
  173. const std::int64_t* data = &model.layer_config.at(name);
  174. double val = *(const double*)data;
  175. return val;
  176. }
  177. extern "C" double fairseq2_model_layer_config_double(const fairseq2_model& model, const char* name) {
  178. return model_layer_config_d(model, std::string(name));
  179. }
  180. extern "C" std::int64_t fairseq2_model_layer_config_int(const fairseq2_model& model, const char* name) {
  181. return model.layer_config.at(std::string(name));
  182. }
  183. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  184. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  185. // delete model;
  186. }
  187. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  188. model->ctx = ctx;
  189. }
  /// Allocates a std::string on the heap from a C string.
  /// A null `c_str` yields an empty string instead of undefined behavior.
  /// The result must be released with std_string_free.
  extern "C" std::string* std_string_alloc(char* c_str) {
      if (c_str == nullptr) {
          return new std::string();
      }
      return new std::string(c_str);
  }
  /// Releases a string obtained from std_string_alloc. Passing null is a no-op.
  extern "C" void std_string_free(std::string* str) {
      std::string* owned = str;
      delete owned;
  }
  196. bool has_layer(fairseq2_model& model, const std::string& name) {
  197. return model.tensors.find(name) != model.tensors.end();
  198. }
  199. ggml_tensor* mul_mat(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  200. if (b->ne[1] == 1 && b->ne[2] > 1 && a->n_dims == 2) {
  201. // `b` has shape (B, 1, D).
  202. // if `a` is (D_out, D), then we do one matmul for the full batch.
  203. b = ggml_flatten_1d(ctx, b, 1);
  204. return ggml_unflatten_1d(ctx, ggml_mul_mat(ctx, a, b), 1, 1);
  205. }
  206. // there is also the k * q matmul -> (D, 1, B) * (D, 1, B) -> (1, 1, B)
  207. // not sure what's the best way to compute this with BLAS
  208. return ggml_mul_mat(ctx, a, b); // (d_out)
  209. }
  210. extern "C" ggml_tensor* Linear_forward(
  211. fairseq2_model& model,
  212. const std::string &prefix,
  213. ggml_tensor* input // (d_in)
  214. ) {
  215. // Note: for now we assumed un-batched input
  216. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  217. GGML_ASSERT(weight != nullptr);
  218. ggml_tensor* out = mul_mat(model.ctx, weight, input); // (d_out)
  219. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  220. if (bias == nullptr) return out;
  221. return ggml_add(model.ctx, out, bias);
  222. }
  223. extern "C" ggml_tensor* LayerNorm_forward(
  224. fairseq2_model& model,
  225. const std::string &prefix,
  226. ggml_tensor* input
  227. ) {
  228. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  229. GGML_ASSERT(weight != nullptr);
  230. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  231. GGML_ASSERT(bias != nullptr);
  232. auto ctx = model.ctx;
  233. double eps = model_layer_config_d(model, prefix + ".eps");
  234. input = ggml_norm(ctx, input, /*eps*/eps);
  235. return ggml_add_inplace(
  236. ctx,
  237. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  238. ggml_repeat(ctx, bias, input)
  239. );
  240. }
  241. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  242. fairseq2_model& model,
  243. const std::string& prefix,
  244. ggml_tensor* seqs
  245. ) {
  246. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  247. // inner_activation = ReLu // TODO: allow other activation
  248. seqs = ggml_relu_inplace(model.ctx, seqs);
  249. if (has_layer(model, prefix + ".inner_layer_norm")) {
  250. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  251. }
  252. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  253. return seqs;
  254. }
  255. extern "C" ggml_tensor* SiluFeedForwardNetwork_forward(
  256. fairseq2_model& model,
  257. const std::string& prefix,
  258. ggml_tensor* seqs
  259. ) {
  260. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  261. seqs = ggml_silu(model.ctx, seqs);
  262. if (has_layer(model, prefix + ".inner_layer_norm")) {
  263. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  264. }
  265. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  266. return seqs;
  267. }
  268. ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
  269. int n_dims = x->n_dims;
  270. GGML_ASSERT(dim >= 0);
  271. GGML_ASSERT(dim < n_dims);
  272. GGML_ASSERT(ggml_is_contiguous(x));
  273. // Nothing to do
  274. if (dim == n_dims - 1) return x;
  275. if (n_dims == 2) {
  276. return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
  277. } else if (n_dims == 3) {
  278. if (dim == 0) {
  279. return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
  280. } else { // dim == 1
  281. return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
  282. }
  283. } else { // n_dims == 4
  284. if (dim == 0) {
  285. return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
  286. } else if (dim == 1) {
  287. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
  288. } else { // dim == 2
  289. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
  290. }
  291. }
  292. }
  293. ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
  294. int n_dims = x->n_dims;
  295. GGML_ASSERT(dim >= 0);
  296. GGML_ASSERT(dim < n_dims);
  297. GGML_ASSERT(n_dims < 4);
  298. GGML_ASSERT(x->ne[dim] % num_el == 0);
  299. GGML_ASSERT(x->nb[dim + 1] == x->nb[dim] * x->ne[dim]); // `x` isn't contiguous along `dim`
  300. if (n_dims == 1) {
  301. return ggml_view_2d(ctx, x, num_el, x->ne[0] / num_el, x->nb[0] * num_el, 0);
  302. } else if (n_dims == 2) {
  303. if (dim == 0) {
  304. return ggml_view_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->nb[0] * num_el, x->nb[1], 0);
  305. } else { // dim == 1
  306. return ggml_view_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->nb[1], num_el * x->nb[1], 0);
  307. }
  308. } else { // (n_dims == 3)
  309. if (dim == 0) {
  310. return ggml_view_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2], x->nb[0] * num_el, x->nb[1], x->nb[2], 0);
  311. } else if (dim == 1) {
  312. return ggml_view_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2], x->nb[1], num_el * x->nb[1], x->nb[2], 0);
  313. } else { // dim == 2
  314. return ggml_view_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el, x->nb[1], x->nb[2], num_el * x->nb[2], 0);
  315. }
  316. }
  317. }
  318. ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
  319. // (B, S, dim) -> (B, S, H, H_dim)
  320. x = ggml_unflatten_1d(ctx, x, 0, head_dim);
  321. x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
  322. x = ggml_cont(ctx, x);
  323. x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
  324. return x;
  325. }
  326. /// (B, Sk, dim) -> // (B?, H, H_dim, Sk)
  327. ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim ) {
  328. // (B, Sk, dim) -> (B, Sk, H, H_dim)
  329. v = ggml_unflatten_1d(ctx, v, 0, head_dim);
  330. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
  331. v = ggml_cont(ctx, v);
  332. v = ggml_flatten_1d(ctx, v, 2); // (B * H, S, H_dim)
  333. return v;
  334. }
  335. // flash_attn doesn't work for cross attention because it assumes Q <= K
  336. // and it seems to yield slightly different scores than expected, and thus a different beam search
  337. # define UNITY_FLASH_ATTN 0
  338. extern "C" ggml_tensor* MultiheadAttention_forward(
  339. fairseq2_model& model,
  340. const std::string &prefix,
  341. ggml_tensor* queries, // (slen, d_in)
  342. ggml_tensor* keys, // (klen, d_in)
  343. ggml_tensor* values, // (klen, d_out)
  344. ggml_tensor* attn_mask // (klen, slen)
  345. ) {
  346. int model_dim = queries->ne[0];
  347. int num_heads = model.layer_config.at(prefix + ".num_heads");
  348. int head_dim = model_dim / num_heads;
  349. GGML_ASSERT(model_dim % num_heads == 0);
  350. ggml_context* ctx = model.ctx;
  351. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
  352. q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
  353. ggml_set_name(q, "q");
  354. ggml_tensor *k, *v;
  355. if (!has_kv_cache(model)) {
  356. k = Linear_forward(model, prefix + ".k_proj", keys);
  357. ggml_set_name(k, "k");
  358. v = Linear_forward(model, prefix + ".v_proj", values);
  359. ggml_set_name(v, "v");
  360. } else {
  361. bool encoder_decoder_attn = keys == values && keys != queries;
  362. if (encoder_decoder_attn) {
  363. // The K and V tensors of an encoder-decoder attention (i.e. the
  364. // projected encoder outputs) remain static during evaluation.
  365. KeyValueTensor& kv_cache = model.kv_cache[prefix];
  366. if (kv_cache.step_nr == 0) {
  367. // If possible we use the ctx dedicated to kv_cache here,
  368. // because the enc dec attention is typically long lived.
  369. if (model.kv_cache_ctx) model.ctx = model.kv_cache_ctx;
  370. k = Linear_forward(model, prefix + ".k_proj", keys);
  371. ggml_set_name(k, "k");
  372. v = Linear_forward(model, prefix + ".v_proj", values);
  373. ggml_set_name(v, "v");
  374. // Note we are only storing a pointer to the buffer, not the full graph
  375. kv_cache.full_k = ggml_detach(ggml_dup_inplace(model.ctx, k));
  376. printf("prefix: %s, k: %d %d %d\n", prefix.c_str(), kv_cache.full_k->ne[0], kv_cache.full_k->ne[1], kv_cache.full_k->ne[2]);
  377. ggml_format_name(kv_cache.full_k, "%s.k_cache", prefix.c_str());
  378. kv_cache.full_v = ggml_detach(ggml_dup_inplace(model.ctx, v));
  379. ggml_format_name(kv_cache.full_v, "%s.v_cache", prefix.c_str());
  380. kv_cache.step_nr = keys->ne[1];
  381. model.ctx = ctx;
  382. } else {
  383. printf("prefix: %s, k: %d %d %d\n", prefix.c_str(), kv_cache.full_k->ne[0], kv_cache.full_k->ne[1], kv_cache.full_k->ne[2]);
  384. k = kv_cache.full_k;
  385. v = kv_cache.full_v;
  386. GGML_ASSERT(keys->ne[1] == k->ne[1]); // cache content doesn't match the input sequence
  387. GGML_ASSERT(values->ne[1] == v->ne[1]); // cache content doesn't match the input sequence
  388. }
  389. } else { // self attention
  390. // (1, K) -> (N, 1, K_proj)
  391. for (auto& named_kv : model.kv_cache) {
  392. auto enc_dec_attn_glob = "*.encoder_decoder_attn";
  393. if (::fnmatch(enc_dec_attn_glob, named_kv.first.c_str(), 0) != FNM_NOMATCH) {
  394. printf("HERE BEFORE CULPRIT LINE prefix: %s\n", named_kv.first.c_str());
  395. if(named_kv.second.full_k != nullptr)
  396. printf("HERE BEFORE CULPRIT LINE k: %d\n", named_kv.second.full_k->ne[0]);
  397. }
  398. }
  399. k = Linear_forward(model, prefix + ".k_proj", keys);
  400. for (auto& named_kv : model.kv_cache) {
  401. auto enc_dec_attn_glob = "*.encoder_decoder_attn";
  402. if (::fnmatch(enc_dec_attn_glob, named_kv.first.c_str(), 0) != FNM_NOMATCH) {
  403. printf("HERE AFTER CULPRIT LINE prefix: %s\n", named_kv.first.c_str());
  404. if(named_kv.second.full_k != nullptr)
  405. printf("HERE AFTER CULPRIT LINE k: %d\n", named_kv.second.full_k->ne[0]);
  406. }
  407. }
  408. ggml_set_name(k, "k");
  409. // (1, V) -> (N, 1, V_proj)
  410. v = Linear_forward(model, prefix + ".v_proj", values);
  411. ggml_set_name(v, "v");
  412. append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
  413. }
  414. }
  415. k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
  416. v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
  417. v = ggml_cont(ctx, v);
  418. #if UNITY_FLASH_ATTN
  419. // For flash_attn, we assume either no masks, or triangular masks.
  420. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr); // (B * H, S, H_dim)
  421. ggml_set_name(attn, "attn");
  422. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  423. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  424. #else
  425. // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
  426. ggml_tensor* qk = mul_mat(ctx, k, q);
  427. ggml_set_name(qk, "qk");
  428. FORCE_ALLOC(qk_scale, ctx, ggml_new_tensor_1d(ctx, qk->type, 1));
  429. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  430. qk = ggml_scale(ctx, qk, qk_scale);
  431. ggml_set_name(qk, "qk_scaled");
  432. if (attn_mask) qk = ggml_add_inplace(ctx, qk, attn_mask);
  433. // TODO: upgrade qk to float32 if needed
  434. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
  435. ggml_set_name(attn_weights, "attn_weights");
  436. // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
  437. ggml_tensor* attn = mul_mat(ctx, attn_weights, v);
  438. ggml_set_name(attn, "attn");
  439. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  440. attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
  441. #endif // UNITY_FLASH_ATTN
  442. attn = ggml_cont(ctx, attn);
  443. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  444. // out -> (B, S, d_out)
  445. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  446. ggml_set_name(out, "out");
  447. return out;
  448. }
  449. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  450. fairseq2_model& model,
  451. const std::string& prefix,
  452. ggml_tensor* seqs,
  453. ggml_tensor* padding_mask
  454. ) {
  455. ggml_context* ctx = model.ctx;
  456. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  457. // _forward_self_attn(seqs, padding_mask)
  458. auto residual = seqs;
  459. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  460. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  461. // TODO: add padding_mask to MultiheadAttention_forward
  462. GGML_ASSERT(padding_mask == nullptr);
  463. seqs = MultiheadAttention_forward(
  464. model,
  465. prefix + ".self_attn",
  466. seqs,
  467. seqs,
  468. seqs,
  469. /*attn_mask=*/nullptr
  470. );
  471. if (has_layer(model, prefix + ".self_attn_norm"))
  472. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  473. seqs = ggml_add_inplace(ctx, seqs, residual);
  474. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  475. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  476. // _forward_ffn(seqs)
  477. residual = seqs;
  478. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  479. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  480. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  481. // TODO: if self.residual_scale is not None:
  482. // residual = self.residual_scale * residual
  483. seqs = ggml_add_inplace(ctx, seqs, residual);
  484. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  485. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  486. return seqs;
  487. }
  488. extern "C" ggml_tensor* WaveformToFbank_forward(
  489. fairseq2_model& model,
  490. const std::string &prefix,
  491. ggml_tensor* waveform
  492. ) {
  493. // Hardcoding: num_bins 80, sample rate 16k, always standardize
  494. ggml_context* ctx = model.ctx;
  495. knf::MelBanksOptions mel_opts{};
  496. mel_opts.num_bins = 80;
  497. knf::FrameExtractionOptions frame_opts{};
  498. frame_opts.samp_freq = 16000;
  499. knf::FbankOptions opts{};
  500. opts.frame_opts = frame_opts;
  501. opts.mel_opts = mel_opts;
  502. std::vector<float_t> signal_frame{};
  503. std::int32_t num_frames = knf::NumFrames(/*num_samples=*/waveform->ne[0], frame_opts);
  504. FORCE_ALLOC(output, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 80, num_frames));
  505. knf::FbankComputer native_(opts);
  506. knf::FeatureWindowFunction window_fn_(native_.GetFrameOptions());
  507. for (std::int32_t frame_nr = 0; frame_nr < num_frames; ++frame_nr) {
  508. signal_frame.resize(0);
  509. // Extract the frame from the waveform tensor.
  510. knf::ExtractWindow(
  511. /*sample_offset=*/0,
  512. (float *)(waveform->data),
  513. waveform->ne[0],
  514. frame_nr,
  515. frame_opts,
  516. window_fn_,
  517. &signal_frame);
  518. native_.Compute(
  519. /*signal_raw_log_energy=*/0, /*vtln_warp=*/1.0, &signal_frame, ((float *)(output->data) + frame_nr * 80));
  520. }
  521. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  522. output = ggml_norm(ctx, output, 1e-5);
  523. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  524. if (output->ne[1] % 2 == 1) {
  525. output = ggml_dup(ctx, ggml_slice(ctx, output, 1, 0, output->ne[1]-1));
  526. }
  527. output = ggml_reshape_2d(ctx, output, output->ne[0] * 2, output->ne[1] / 2);
  528. return output;
  529. }
  530. // TODO: Check if it's possible to merge with standard MHA
  531. extern "C" ggml_tensor* RelativePositionMHA_forward(
  532. fairseq2_model& model,
  533. const std::string& prefix,
  534. ggml_tensor* seqs
  535. ) {
  536. ggml_context* ctx = model.ctx;
  537. ggml_tensor* residual = seqs;
  538. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  539. // self_attn: qkv
  540. ggml_tensor* Qcur = Linear_forward(model, prefix + ".q_proj", seqs);
  541. ggml_tensor* Kcur = Linear_forward(model, prefix + ".k_proj", seqs);
  542. ggml_tensor* Vcur = Linear_forward(model, prefix + ".v_proj", seqs);
  543. // self_attn: rel_pos SDPA
  544. int32_t S = seqs->ne[1];
  545. int32_t H = 16; // TODO: Make this configurable
  546. int32_t n_ctx = 4096;
  547. int32_t K_h = seqs->ne[0] / H;
  548. int32_t start_index = n_ctx - S;
  549. int32_t end_index = n_ctx + S - 1;
  550. int num_indices = end_index - start_index;
  551. FORCE_ALLOC(rows, ctx, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices));
  552. for (int i = 0; i < num_indices; i++) {
  553. ((int32_t *)rows->data)[i] = start_index + i;
  554. }
  555. // self_attn: load pos_enc weights & compute_r
  556. // In fairseq2 pos_enc weights are calculated on the fly, since some more custom operators might be needed to enable this,
  557. // we store the results (fixed) in checkpoint as model.audio_enc_pos_enc_w and load directly.
  558. ggml_tensor* r = ggml_get_rows(ctx, model.tensors["speech_encoder.pos_enc"], rows);
  559. r = mul_mat(ctx, model.tensors[prefix + ".sdpa.r_proj.weight"], r);
  560. r = ggml_dup(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, r, 0, K_h), 0, 2, 1, 3));
  561. ggml_tensor* u_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.u_bias"], K_h, 1, H);
  562. ggml_tensor* v_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.v_bias"], K_h, 1, H);
  563. // self_attn: Permute QKV
  564. // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  565. ggml_tensor* Q = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Qcur, 0, K_h), 0, 2, 1, 3));
  566. // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  567. ggml_tensor* K = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Kcur, 0, K_h), 0, 2, 1, 3));
  568. // (H * K_h, S) -> (K_h, H, S) -> (H, S, K_h)
  569. ggml_tensor* V = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Vcur, 0, K_h), 1, 2, 0, 3));
  570. ggml_tensor* q_with_u_bias = ggml_add_inplace(ctx, ggml_dup(ctx, Q), u_bias); // (K_h, S, H)
  571. ggml_tensor* q_with_v_bias = ggml_add_inplace(ctx, Q, v_bias); // (K_h, S, H)
  572. ggml_tensor* ac = mul_mat(ctx, K, q_with_u_bias);
  573. ggml_tensor* bd = mul_mat(ctx, r, q_with_v_bias);
  574. // self_attn: shift_bd. Logic follows https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L161
  575. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // H, S, 2S-1
  576. FORCE_ALLOC(pad, ctx, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, S, 1));
  577. pad = ggml_set_f32(pad, 0.0);
  578. bd = ggml_concat(ctx, pad, bd); // bd[i][j][0] == 0, (H, S, 2S)
  579. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // (2S, S, H)
  580. bd = ggml_reshape_3d(ctx, bd, S, 2 * S, H); // (S, 2S, H)
  581. // discard the first set of positive positions
  582. bd = ggml_dup(ctx, ggml_slice(ctx, bd, 1, 1, 2 * S));
  583. // shifts each row by an extra step
  584. bd = ggml_reshape_3d(ctx, bd, 2 * S - 1, S, H);
  585. // Discard positions used for shift.
  586. bd = ggml_slice(ctx, bd, 0, 0, S);
  587. // self_attn: compute attn / weights
  588. ggml_tensor* attn_weights = ggml_add_inplace(ctx, ac, bd);
  589. FORCE_ALLOC(attn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  590. ggml_set_f32(attn_scale, 1.0 / pow(K_h, 0.5));
  591. attn_weights = ggml_mul_inplace(ctx, attn_weights, ggml_repeat(ctx, attn_scale, attn_weights));
  592. attn_weights = ggml_soft_max(ctx, attn_weights);
  593. ggml_tensor* attn = mul_mat(ctx, V, attn_weights); // K_h, S, H
  594. attn = ggml_dup(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));
  595. ggml_tensor* attn_2d = ggml_reshape_2d(ctx, attn, K_h * H, S);
  596. ggml_tensor* attn_out = mul_mat(ctx, model.tensors[prefix + ".output_proj.weight"], attn_2d);
  597. attn_out = ggml_add_inplace(
  598. ctx,
  599. attn_out,
  600. ggml_repeat(ctx, model.tensors[prefix + ".output_proj.bias"], attn_out)
  601. );
  602. attn_out = ggml_add_inplace(ctx, attn_out, residual);
  603. return attn_out;
  604. }
  605. extern "C" ggml_tensor* ConvModule_forward(
  606. fairseq2_model& model,
  607. const std::string& prefix,
  608. ggml_tensor* seqs
  609. ) {
  610. ggml_context* ctx = model.ctx;
  611. ggml_tensor* residual = seqs;
  612. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  613. // conv: Use matmul for pointwise conv 1 - kernel_size=1, no padding case
  614. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv1.weight"], seqs);
  615. // conv: GLU
  616. seqs = ggml_glu(ctx, seqs);
  617. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  618. // S x C -> (S+K-1) x C -> K x S x C -> S x C
  619. int K = model.tensors[prefix + ".depthwise_conv.weight"]->ne[0];
  620. seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".depthwise_conv.weight"], seqs, 1, K / 2, 1, seqs->ne[1]);
  621. // conv: Custom implementation of batch norm
  622. seqs = ggml_batch_norm(ctx, seqs, model.tensors[prefix + ".batch_norm.weight"], model.tensors[prefix + ".batch_norm.bias"], model.tensors[prefix + ".batch_norm.running_mean"], model.tensors[prefix + ".batch_norm.running_var"], 1e-5);
  623. // conv: SiLU actvation
  624. seqs = ggml_silu_inplace(ctx, seqs);
  625. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  626. // conv: Use matmul for pointwise conv 2 - kernel_size=1, no padding case
  627. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv2.weight"], seqs);
  628. // conv: + residual
  629. seqs = ggml_add_inplace(ctx, seqs, residual);
  630. return seqs;
  631. }
  632. extern "C" ggml_tensor* StandardConformerEncoderLayer_forward(
  633. fairseq2_model& model,
  634. const std::string& prefix,
  635. ggml_tensor* seqs,
  636. ggml_tensor* padding_mask
  637. ) {
  638. ggml_context* ctx = model.ctx;
  639. FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  640. ggml_set_f32(ffn_scale, 0.5f);
  641. ggml_tensor* residual = seqs;
  642. seqs = LayerNorm_forward(model, prefix + ".ffn1_layer_norm", seqs);
  643. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn1", seqs);
  644. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  645. seqs = ggml_add_inplace(ctx, seqs, residual);
  646. seqs = RelativePositionMHA_forward(model, prefix + ".self_attn", seqs);
  647. seqs = ConvModule_forward(model, prefix + ".conv", seqs);
  648. residual = seqs;
  649. seqs = LayerNorm_forward(model, prefix + ".ffn2_layer_norm", seqs);
  650. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn2", seqs);
  651. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  652. seqs = ggml_add_inplace(ctx, seqs, residual);
  653. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  654. return seqs;
  655. }
  656. extern "C" ggml_tensor* StandardConformerEncoder_forward(
  657. fairseq2_model& model,
  658. const std::string& prefix,
  659. ggml_tensor* seqs,
  660. ggml_tensor* padding_mask
  661. ) {
  662. ggml_context* ctx = model.ctx;
  663. seqs = WaveformToFbank_forward(model, prefix, seqs);
  664. seqs = LayerNorm_forward(model, prefix + "_frontend.post_extract_layer_norm", seqs);
  665. seqs = Linear_forward(model, prefix + "_frontend.model_dim_proj", seqs);
  666. int layer_idx = 0;
  667. std::string layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  668. while (has_layer(model, layer_name)) {
  669. seqs = StandardConformerEncoderLayer_forward(
  670. model, layer_name, seqs, padding_mask
  671. );
  672. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  673. layer_idx += 1;
  674. layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  675. }
  676. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  677. ggml_tensor* residual = seqs;
  678. seqs = Linear_forward(model, prefix + ".proj1", seqs);
  679. seqs = ggml_relu_inplace(ctx, seqs);
  680. seqs = Linear_forward(model, prefix + ".proj2", seqs);
  681. FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  682. ggml_set_f32(ffn_scale, 0.5f);
  683. seqs = ggml_mul(ctx, ggml_repeat(ctx, ffn_scale, seqs), seqs);
  684. seqs = ggml_add_inplace(ctx, seqs, residual);
  685. layer_idx = 0;
  686. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  687. while (has_layer(model, layer_name)) {
  688. seqs = StandardConformerEncoderAdaptorLayer_forward(
  689. model, layer_name, seqs, padding_mask
  690. );
  691. ggml_set_name(seqs, ("x_ada_" + std::to_string(layer_idx)).c_str());
  692. layer_idx += 1;
  693. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  694. }
  695. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  696. return seqs;
  697. }
  698. extern "C" ggml_tensor* StandardConformerEncoderAdaptorLayer_forward(
  699. fairseq2_model& model,
  700. const std::string& prefix,
  701. ggml_tensor* seqs,
  702. ggml_tensor* padding_mask
  703. ) {
  704. ggml_context* ctx = model.ctx;
  705. ggml_tensor* residual = seqs;
  706. residual = LayerNorm_forward(model, prefix + ".residual_layer_norm", residual);
  707. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  708. residual = ggml_conv_1d(ctx, model.tensors[prefix + ".residual_conv.weight"], residual, 8, 4, 1, 1);
  709. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  710. residual = ggml_add_inplace(ctx, ggml_repeat(ctx, model.tensors[prefix + ".residual_conv.bias"], residual), residual);
  711. residual = ggml_glu(ctx, residual);
  712. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  713. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  714. seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".self_attn_conv.weight"], seqs, 8, 4, 1, 1);
  715. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  716. seqs = ggml_add_inplace(ctx, seqs, ggml_repeat(ctx, model.tensors[prefix + ".self_attn_conv.bias"], seqs));
  717. seqs = ggml_glu(ctx, seqs);
  718. seqs = MultiheadAttention_forward(
  719. model,
  720. prefix + ".self_attn",
  721. seqs,
  722. seqs,
  723. seqs,
  724. /*attention masks=*/nullptr
  725. );
  726. seqs = ggml_add_inplace(ctx, seqs, residual);
  727. residual = seqs;
  728. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  729. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  730. seqs = ggml_add_inplace(ctx, seqs, residual);
  731. return seqs;
  732. }
  733. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  734. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  735. ggml_tensor* ggml_slice(
  736. struct ggml_context * ctx,
  737. struct ggml_tensor * a,
  738. int axis,
  739. int64_t start,
  740. int64_t end
  741. ) {
  742. int64_t ne[4];
  743. std::copy(a->ne, a->ne + 4, ne);
  744. if (axis < 0) axis = a->n_dims + axis;
  745. if (start < 0) start = ne[axis] + start;
  746. if (end <= 0) end = ne[axis] + end;
  747. GGML_ASSERT(0 <= start);
  748. GGML_ASSERT(start < end);
  749. GGML_ASSERT(end <= ne[axis]);
  750. ne[axis] = end - start;
  751. size_t offset = a->nb[axis] * start;
  752. size_t* nb = a->nb;
  753. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  754. ggml_format_name(result, "%s [(%d)%ld:%ld]", a->name, axis, start, end);
  755. result->n_dims = a->n_dims;
  756. return result;
  757. }
  758. ggml_tensor* ggml_select(
  759. struct ggml_context * ctx,
  760. struct ggml_tensor * a,
  761. int axis,
  762. int64_t index
  763. ) {
  764. int64_t ne[GGML_MAX_DIMS];
  765. std::copy(a->ne, a->ne + GGML_MAX_DIMS, ne);
  766. if (axis < 0) axis = a->n_dims + axis;
  767. if (index < 0) index = ne[axis] + index;
  768. GGML_ASSERT(0 <= index);
  769. GGML_ASSERT(index < ne[axis]);
  770. std::copy(a->ne + axis + 1, a->ne + GGML_MAX_DIMS, ne + axis);
  771. size_t offset = a->nb[axis] * index;
  772. size_t* nb = a->nb;
  773. GGML_ASSERT(GGML_MAX_DIMS == 4);
  774. ggml_tensor* result = ggml_view_3d(ctx, a, ne[0], ne[1], ne[2], nb[1], nb[2], offset);
  775. ggml_format_name(result, "%s [(%d)%ld]", a->name, axis, index);
  776. result->n_dims = a->n_dims - 1;
  777. return result;
  778. }
  779. // Inplace computation of PositionalEmbedding
  780. extern "C" ggml_tensor* PositionalEmbedding_forward(
  781. fairseq2_model& model,
  782. const std::string& prefix,
  783. ggml_tensor* embeds
  784. ) {
  785. // This only work with the simple pos encoders
  786. int seq_len = embeds->ne[1];
  787. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  788. int start_step = 0;
  789. if (has_kv_cache(model)) {
  790. start_step = model.kv_cache[prefix].step_nr++;
  791. }
  792. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, start_step, seq_len + start_step);
  793. return ggml_add(model.ctx, embeds, pos_embeds);
  794. }
  795. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  796. fairseq2_model& model,
  797. const std::string& prefix,
  798. ggml_tensor* seqs
  799. ) {
  800. GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
  801. ggml_context* ctx = model.ctx;
  802. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  803. GGML_ASSERT(embed_weights != nullptr);
  804. ggml_tensor* embeds;
  805. if (seqs->n_dims == 1) {
  806. embeds = ggml_get_rows(ctx, embed_weights, seqs);
  807. } else {
  808. // ggml_get_rows isn't very flexible, we have to handle the reshape ourselves.
  809. ggml_tensor* flat_seqs = seqs;
  810. if (!ggml_is_contiguous(seqs)) {
  811. flat_seqs = ggml_cont(ctx, flat_seqs);
  812. }
  813. flat_seqs = ggml_reshape_1d(ctx, flat_seqs, ggml_nelements(seqs));
  814. embeds = ggml_get_rows(ctx, embed_weights, flat_seqs);
  815. embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  816. embeds->n_dims = seqs->n_dims + 1;
  817. }
  818. // padding mask ?
  819. // padding_mask = to_padding_mask(embeds, seq_lens)
  820. if (has_layer(model, prefix + ".pos_encoder")) {
  821. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  822. }
  823. if (has_layer(model, prefix + ".layer_norm")) {
  824. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  825. }
  826. return embeds;
  827. }
  828. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  829. fairseq2_model& model,
  830. const std::string& prefix,
  831. ggml_tensor* seqs,
  832. ggml_tensor* padding_mask
  833. ) {
  834. int layer_idx = 0;
  835. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  836. while (has_layer(model, layer_name)) {
  837. seqs = StandardTransformerEncoderLayer_forward(
  838. model, layer_name, seqs, padding_mask
  839. );
  840. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  841. layer_idx += 1;
  842. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  843. }
  844. if (has_layer(model, prefix + ".layer_norm"))
  845. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  846. return seqs;
  847. }
  848. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  849. fairseq2_model& model,
  850. const std::string& prefix,
  851. ggml_tensor* seqs,
  852. ggml_tensor* self_attn_mask,
  853. ggml_tensor* encoder_output,
  854. ggml_tensor* encoder_padding_mask
  855. ) {
  856. ggml_context* ctx = model.ctx;
  857. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  858. // _forward_self_attn(seqs, padding_mask)
  859. auto residual = seqs;
  860. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  861. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  862. seqs = MultiheadAttention_forward(
  863. model,
  864. prefix + ".self_attn",
  865. seqs,
  866. seqs,
  867. seqs,
  868. /*attn_mask=*/self_attn_mask
  869. );
  870. if (has_layer(model, prefix + ".self_attn_norm"))
  871. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  872. seqs = ggml_add_inplace(ctx, seqs, residual);
  873. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  874. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  875. // _forward_encoder_decoder_attn
  876. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  877. // `encoder_output` must be `None` for decoder-only attention.
  878. GGML_ASSERT(encoder_output == nullptr);
  879. return seqs;
  880. }
  881. // `encoder_output` must not be `None` for encoder-decoder attention.
  882. GGML_ASSERT(encoder_output != nullptr);
  883. residual = seqs;
  884. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  885. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  886. seqs = MultiheadAttention_forward(
  887. model,
  888. prefix + ".encoder_decoder_attn",
  889. seqs,
  890. encoder_output,
  891. encoder_output,
  892. /*attention masks=*/encoder_padding_mask
  893. );
  894. seqs = ggml_add_inplace(ctx, seqs, residual);
  895. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  896. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  897. // _forward_ffn(seqs)
  898. residual = seqs;
  899. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  900. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  901. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  902. // TODO:
  903. // if self.residual_scale is not None:
  904. // residual = self.residual_scale * residual
  905. seqs = ggml_add_inplace(ctx, seqs, residual);
  906. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  907. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  908. return seqs;
  909. }
  910. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  911. auto seq_len = seqs->ne[1];
  912. // TODO: allow other ggml_type
  913. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  914. return ggml_diag_mask_inf(ctx, mask, 0);
  915. }
  916. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  917. fairseq2_model& model,
  918. const std::string& prefix,
  919. ggml_tensor* seqs,
  920. ggml_tensor* padding_mask,
  921. ggml_tensor* encoder_output,
  922. ggml_tensor* encoder_padding_mask
  923. ) {
  924. int layer_idx = 0;
  925. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  926. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  927. while (has_layer(model, layer_name)) {
  928. seqs = StandardTransformerDecoderLayer_forward(
  929. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  930. );
  931. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  932. layer_idx += 1;
  933. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  934. }
  935. if (has_layer(model, prefix + ".layer_norm"))
  936. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  937. return seqs;
  938. }
  939. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  940. auto opts = job.opts;
  941. int max_seq_len = -1;
  942. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  943. max_seq_len = opts.hard_max_seq_len;
  944. } else {
  945. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len) + opts.soft_max_seq_len_b);
  946. }
  947. if (opts.min_seq_len > max_seq_len) {
  948. printf(
  949. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  950. opts.min_seq_len,
  951. max_seq_len
  952. );
  953. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  954. }
  955. int prefix_seq_len = job.prefix_seq->ne[0];
  956. if (prefix_seq_len >= max_seq_len) {
  957. printf(
  958. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  959. prefix_seq_len,
  960. max_seq_len
  961. );
  962. GGML_ASSERT(prefix_seq_len < max_seq_len);
  963. }
  964. return max_seq_len;
  965. }
  966. void _fan_out_encoder_output(
  967. ggml_context* ctx,
  968. ggml_tensor** encoder_output_out,
  969. ggml_tensor** encoder_padding_mask_out,
  970. int beam_size
  971. ) {
  972. // (S_enc, M)
  973. ggml_tensor* encoder_output = *encoder_output_out;
  974. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  975. // (B, S_enc, M)
  976. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  977. // (S_enc, M) -> (B, S_enc, M)
  978. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  979. // (S_enc) -> (B, S_enc)
  980. if (encoder_padding_mask != nullptr) {
  981. ggml_tensor* shape_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], 1, beam_size);
  982. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  983. }
  984. }
  985. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  986. // TODO: this isn't the most precise way of doing this
  987. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  988. }
  989. ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
  990. ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
  991. ggml_type true_type = x->type;
  992. ggml_tensor* y = ggml_repeat(ctx, x, shape);
  993. y->type = true_type;
  994. return y;
  995. }
  996. void _bootstrap_seqs_and_scores(
  997. fairseq2_model& model,
  998. const SequenceGeneratorJob& job,
  999. ggml_tensor* full_seqs,
  1000. ggml_tensor* scores,
  1001. ggml_tensor* encoder_output,
  1002. ggml_tensor* encoder_padding_mask,
  1003. ggml_tensor* lid_scores,
  1004. int n_threads,
  1005. const std::vector<int>& lang_ids
  1006. ) {
  1007. // Returns LID score map
  1008. int prefix_seq_len = job.prefix_seq->ne[0];
  1009. int max_seq_len = scores->ne[0];
  1010. int beam_size = scores->ne[1];
  1011. GGML_ASSERT(prefix_seq_len > 0);
  1012. if (prefix_seq_len == 1) {
  1013. // bootstrap all beams in full_seqs with EOS
  1014. // This is equivalent to:
  1015. // // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  1016. // because in normal case: prefix_seq[0] = EOS
  1017. //
  1018. int eos_id = model.vocab.token_to_id["</s>"];
  1019. if (model.tgt_vocab.id_to_token.size()) {
  1020. eos_id = model.tgt_vocab.token_to_id["</s>"];
  1021. }
  1022. size_t vocab_size = model.tensors["text_decoder_frontend.embed.weight"]->ne[1];
  1023. for (int k = 0; k < beam_size; k++) {
  1024. ggml_set_i32_1d(full_seqs, k * vocab_size, eos_id);
  1025. }
  1026. return;
  1027. }
  1028. ggml_context* ctx = model.ctx;
  1029. // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  1030. ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
  1031. seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);
  1032. // We have to bootstrap the model with the already fanned-out encoder
  1033. // output to correctly initialize its incremental state.
  1034. // Note: we don't start decoding the last prefix token just yet.
  1035. seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);
  1036. // Bootstrap the model state with prefix sequence.
  1037. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  1038. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1039. model,
  1040. "text_decoder",
  1041. seqs,
  1042. /*padding_mask*/ nullptr,
  1043. encoder_output,
  1044. encoder_padding_mask
  1045. );
  1046. // logits, lprobs: (N, S_pfx - 1, V)
  1047. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  1048. int vocab_size = logits->ne[0];
  1049. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
  1050. struct ggml_cgraph * gf = ggml_new_graph(ctx);
  1051. ggml_build_forward_expand(gf, lprobs);
  1052. ggml_graph_compute_with_ctx(ctx, gf, n_threads);
  1053. full_seqs->type = GGML_TYPE_I32;
  1054. job.prefix_seq->type = GGML_TYPE_I32;
  1055. // For LID
  1056. for (size_t i = 0; i < lang_ids.size(); ++i) {
  1057. ggml_set_f32_1d(lid_scores, i, std::exp(ggml_get_f32_1d(lprobs, lang_ids[i])));
  1058. }
  1059. // Fetch scores of next steps from "lprobs"
  1060. float p_score = 0;
  1061. for (int i = 1; i < prefix_seq_len; ++i) {
  1062. int p;
  1063. if (ggml_get_i32_1d(job.prefix_seq, i) == model.vocab.token_to_id["<unk>"]) {
  1064. // If tgt_lang is unk, use the most probable lang tag predicted by model
  1065. int max_value = std::numeric_limits<float>::min();
  1066. for (int j = 0; j < lang_ids.size(); j++) {
  1067. if(ggml_get_f32_1d(lprobs, lang_ids[j]) > max_value) {
  1068. max_value = ggml_get_f32_1d(lprobs, lang_ids[j]);
  1069. p = lang_ids[j];
  1070. }
  1071. }
  1072. } else {
  1073. p = ggml_get_i32_1d(job.prefix_seq, i);
  1074. }
  1075. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  1076. for (int b = 0; b < beam_size; ++b) {
  1077. // scores: (N, S)
  1078. // Note: First step (e.g. BOS)'s score is always 0.
  1079. ggml_set_f32_1d(scores, b * max_seq_len + i, p_score);
  1080. }
  1081. }
  1082. }
  1083. /// Finds the topk indices, and write the winning indices in "candidate_indices" array.
  1084. int topk(
  1085. ggml_tensor* lprobs, // (B, V)
  1086. std::int64_t k,
  1087. ggml_tensor* candidate_indices
  1088. ) {
  1089. // Take the best 2 x `beam_size` predictions. We'll choose the first
  1090. // `beam_size` of these which don't predict EOS to continue with.
  1091. // (N, 2 x B)
  1092. // `vocab_size` - 1 to never select PAD.
  1093. std::int64_t K = std::min(k, ggml_nelements(lprobs));
  1094. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  1095. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  1096. };
  1097. GGML_ASSERT(ggml_nelements(candidate_indices) >= k);
  1098. auto cand = (std::int32_t*)candidate_indices->data;
  1099. std::partial_sort(cand, cand + K, cand + ggml_nelements(lprobs), comp);
  1100. return K;
  1101. }
  1102. void _tweak_lprobs(const SequenceGeneratorJob& job, ggml_tensor* lprobs, int step_nr, int max_seq_len, std::size_t vocab_size) {
  1103. std::size_t beam_size = job.opts.beam_size;
  1104. std::size_t eos_idx = job.eos_idx;
  1105. // Do not allow EOS before reaching the minimum sequence length.
  1106. if (step_nr < job.opts.min_seq_len) {
  1107. // lprobs[:, :, self.eos_idx] = -INFINITY;
  1108. for (size_t i = 0; i < beam_size; ++i)
  1109. ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
  1110. }
  1111. // If we have reached the maximum length, force the last step to be EOS.
  1112. if (step_nr == max_seq_len - 2) {
  1113. // lprobs[:, :, : self.eos_idx] = -torch.inf
  1114. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  1115. for (size_t b = 0; b < beam_size; ++b) {
  1116. size_t t = 0;
  1117. for (t = 0; t < eos_idx; ++t)
  1118. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1119. for (t = eos_idx + 1; t < vocab_size; ++t)
  1120. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1121. }
  1122. }
  1123. // Never allow PAD.
  1124. std::size_t pad_idx = job.pad_idx;
  1125. for (size_t i = 0; i < beam_size; ++i)
  1126. ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);
  1127. // Apply UNK penalty.
  1128. if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
  1129. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  1130. auto lprobs_raw = ggml_get_data_f32(lprobs);
  1131. for (size_t i = 0; i < beam_size; ++i)
  1132. lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
  1133. }
  1134. }
  1135. /// Copies the sequence and scores of a given candidate beam.
  1136. void _finalize_hypothesis(
  1137. const SequenceGeneratorJob& job,
  1138. ggml_context* ctx,
  1139. int step_nr,
  1140. std::int32_t beam,
  1141. std::int32_t token,
  1142. float eos_score,
  1143. ggml_tensor* seqs, // (beam_size, seq_len)
  1144. ggml_tensor* scores, // (beam_size, seq_len)
  1145. ggml_tensor* lid_scores,
  1146. Hypothesis* hypothesis
  1147. ) {
  1148. ggml_tensor* seq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  1149. hypothesis->seq = seq;
  1150. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  1151. hypothesis->step_scores = step_scores;
  1152. auto tok = (std::int32_t*)seq->data;
  1153. for (int i = 0; i < step_nr + 1; ++i) {
  1154. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  1155. }
  1156. tok[step_nr + 1] = token;
  1157. // Convert from cumulative to per-step scores.
  1158. auto sc = (float*)step_scores->data;
  1159. float last_score = eos_score;
  1160. for (int i = step_nr; i >= 0; --i) {
  1161. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  1162. sc[i + 1] = last_score - sc0;
  1163. last_score = sc0;
  1164. }
  1165. sc[0] = 0;
  1166. if (job.opts.normalize_scores)
  1167. // Skip first EOS since it is always 0 and skews normalization.
  1168. eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  1169. hypothesis->score = eos_score;
  1170. hypothesis->lid_scores = lid_scores;
  1171. }
  // Uses ggml_context to store any object.
  // Allocates an I8 tensor of sizeof(Type) * n bytes in `ctx` and yields its
  // data pointer cast to Type*. `n` is parenthesised so that an expression
  // argument such as `a + b` is multiplied as a whole.
  // The trailing semicolon is kept for existing call sites; use the macro only
  // as a full statement initialiser.
  #define GGML_CTX_ALLOC(ctx, Type, n) \
      (Type*)(ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(Type) * (n))->data);
  1175. ggml_context* ctx_from_buffer(std::vector<uint8_t>& buffer) {
  1176. return ggml_init({
  1177. /*.mem_size =*/ static_cast<int64_t>(buffer.capacity()),
  1178. /*.mem_buffer =*/ buffer.data(),
  1179. /*.no_alloc =*/ false,
  1180. });
  1181. }
  1182. ggml_allocr* new_arena_allocr(std::vector<uint8_t>& buffer) {
  1183. return ggml_allocr_new(buffer.data(), buffer.capacity(), 8);
  1184. }
  1185. /// Generates a translation for a single sequence
  1186. /// The results Hypothesis are written inside `result_ctx`.
  1187. extern "C" Hypothesis* generate_sequence(
  1188. fairseq2_model& model,
  1189. const SequenceGeneratorJob& job,
  1190. ggml_tensor* encoder_output,
  1191. ggml_tensor* encoder_padding_mask,
  1192. ggml_context* result_ctx,
  1193. int n_threads
  1194. ) {
  1195. // Pre allocate memory buffers.
  1196. // * step_ctx: contains metadata for the model graph, as well as some explicit
  1197. // buffers for the lprobs tweaking.
  1198. // * prev_step_ctx: is an additional buffer because we need some results from previous steps,
  1199. // to compute next step. Notably self attention kv cache.
  1200. // * search_ctx contains tensors that should live for the full search,
  1201. // like encoder kv cache.
  1202. // * step_alloc contains buffer for the forward pass of the model.
  1203. // Split mem_mb into the different context we need to use.
  1204. int mem_mb = job.opts.mem_mb;
  1205. std::vector<uint8_t> local_bufs[4] = {
  1206. std::vector<uint8_t>(mem_mb * MB * 3 / 10), // step_ctx
  1207. std::vector<uint8_t>(mem_mb * MB * 3 / 10), // prev_step_ctx
  1208. std::vector<uint8_t>(mem_mb * MB * 3 / 10), // search_ctx
  1209. std::vector<uint8_t>(mem_mb * MB * 1 / 10), // step_alloc
  1210. };
  1211. ggml_allocr* step_alloc = new_arena_allocr(local_bufs[3]);
  1212. std::vector<int> lang_ids;
  1213. if (model.hparams["multilingual"] != 0) {
  1214. for (const auto& kv : model.vocab.token_to_id) {
  1215. if (kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
  1216. lang_ids.push_back(kv.second);
  1217. }
  1218. }
  1219. std::sort(lang_ids.begin(), lang_ids.end());
  1220. }
  1221. std::cout << "model multilinguality: " << model.hparams["multilingual"] << " (langs)" << std::endl;
  1222. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  1223. size_t vocab_size = embed->ne[1];
  1224. std::size_t beam_size = job.opts.beam_size;
  1225. ggml_detach(encoder_output);
  1226. int source_seq_len = encoder_output->ne[1];
  1227. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  1228. ggml_context* search_ctx = ctx_from_buffer(local_bufs[2]);
  1229. ggml_context* original_ctx = model.ctx;
  1230. fairseq2_kv_cache_alloc(model, search_ctx, beam_size, max_seq_len);
  1231. // (S_enc, M) -> (B, S_enc, M)
  1232. model.ctx = search_ctx;
  1233. _fan_out_encoder_output(search_ctx, &encoder_output, &encoder_padding_mask, beam_size);
  1234. // Allocate results in the context provided by the caller.
  1235. ggml_set_no_alloc(result_ctx, false);
  1236. Hypothesis* finished_searches_begin = GGML_CTX_ALLOC(result_ctx, Hypothesis, beam_size);
  1237. Hypothesis* finished_searches = finished_searches_begin;
  1238. for (std::size_t i = 0; i < beam_size; ++i) finished_searches[i] = {nullptr, -INFINITY, nullptr};
  1239. Hypothesis* finished_searches_end = finished_searches + beam_size;
  1240. // Initialize buffers. (B, S)
  1241. ggml_tensor* seqs = ggml_new_tensor_2d(search_ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  1242. printf("Seqs dim: [%d %d %d]\n", seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  1243. ggml_set_i32(seqs, 0);
  1244. ggml_set_name(seqs, "seqs_0");
  1245. ggml_tensor* scores = ggml_new_tensor_2d(search_ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  1246. ggml_set_name(scores, "scores_0");
  1247. ggml_set_f32(scores, 0.0);
  1248. int prefix_seq_len = job.prefix_seq->ne[0];
  1249. int start_step = prefix_seq_len - 1;
  1250. ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[(start_step - 1) % 2]);
  1251. ggml_context* step_ctx = ctx_from_buffer(local_bufs[start_step % 2]);
  1252. GGML_ASSERT(step_ctx != search_ctx);
  1253. GGML_ASSERT(prev_step_ctx != step_ctx);
  1254. model.ctx = prev_step_ctx;
  1255. // search_ctx because we need encoder_decoder_attn.k_cache to survive for the full search
  1256. model.kv_cache_ctx = search_ctx;
  1257. ggml_tensor* lid_scores;
  1258. if (lang_ids.size()) {
  1259. lid_scores = ggml_new_tensor_1d(result_ctx, GGML_TYPE_F32, lang_ids.size());
  1260. }
  1261. // Multilingual models: Bootstrap LID scores
  1262. _bootstrap_seqs_and_scores(
  1263. model, job, seqs, scores, encoder_output, encoder_padding_mask, lid_scores, n_threads, lang_ids
  1264. );
  1265. printf("Seqs dim after bootstrapping: [%d %d %d]\n", seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  1266. // Now we will only add self_attn.k_cache and those need to be resorted and copied at every step.
  1267. model.kv_cache_ctx = nullptr;
  1268. // Holds the indices of beams (a beam can occur more than once) that we
  1269. // should continue with in the next step.
  1270. ggml_tensor* beam_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1271. ggml_tensor* next_tokens = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1272. ggml_tensor* next_scores = ggml_new_tensor_1d(search_ctx, GGML_TYPE_F32, beam_size);
  1273. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  1274. ggml_tensor* candidate_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, vocab_size * beam_size);
  1275. for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
  1276. ((int32_t *)(candidate_indices->data))[i] = i;
  1277. printf_mem_usage(search_ctx, "search_ctx");
  1278. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  1279. model.ctx = step_ctx;
  1280. ggml_set_no_alloc(step_ctx, true); // Use allocr for the model forward pass
  1281. float max_lprob;
  1282. int p;
  1283. if (step_nr == start_step) {
  1284. // Find the most probable lang_tok and assign it to all beams, when prefix_seq[1] is <unk>
  1285. if (lang_ids.size() && ggml_get_i32_1d(job.prefix_seq, 1) == model.vocab.token_to_id["<unk>"]) {
  1286. float max_lprob = std::numeric_limits<float>::min();
  1287. for(int j = 0; j < lang_ids.size(); j++) {
  1288. auto val = ggml_get_f32_1d(lid_scores, j);
  1289. if (val > max_lprob) {
  1290. max_lprob = val;
  1291. p = lang_ids[j];
  1292. }
  1293. }
  1294. for (int k = 0; k < beam_size; k++) {
  1295. ggml_set_i32_1d(seqs, k * vocab_size + step_nr, p);
  1296. }
  1297. }
  1298. }
  1299. ggml_tensor* prev_token = ggml_slice(step_ctx, seqs, 0, step_nr, step_nr + 1);
  1300. ggml_tensor* decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", prev_token);
  1301. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1302. model,
  1303. "text_decoder",
  1304. decoder_input,
  1305. nullptr, // We never generate PAD.
  1306. encoder_output,
  1307. encoder_padding_mask
  1308. ); // (B, 1, D)
  1309. decoder_output = ggml_flatten_1d(step_ctx, decoder_output, 0); // (B, model_dim)
  1310. // Force logits to be allocated in step_ctx, not in step_alloc.
  1311. ggml_set_no_alloc(step_ctx, false);
  1312. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output); // (B, vocab_size)
  1313. ggml_tensor* lprobs = ggml_log_softmax(step_ctx, logits);
  1314. // Compute lprobs here so we can modify it in place in the lprob tweaking phase
  1315. // TODO: use ggml properly compute the tweaks
  1316. struct ggml_cgraph * gf = ggml_new_graph(step_ctx);
  1317. ggml_build_forward_expand(gf, lprobs);
  1318. size_t fwd_mem = ggml_allocr_alloc_graph(step_alloc, gf);
  1319. GGML_UNUSED(fwd_mem);
  1320. ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);
  1321. ggml_detach(lprobs);
  1322. ggml_allocr_reset(step_alloc);
  1323. #if DEBUG_MEM_USAGE
  1324. printf("beam search step %d. Graph.n_nodes: %d.\n", step_nr, gf->n_nodes);
  1325. printf(" Fwd mem: %.1fMB, reserved %.1fMb\n", fwd_mem/(double)MB, local_bufs[3].capacity()/(double)MB);
  1326. std::fill(local_bufs[3].begin(), local_bufs[3].end(), 0xAA);
  1327. #endif
  1328. _tweak_lprobs(job, lprobs, step_nr, max_seq_len, vocab_size);
  1329. ggml_tensor* last_scores = ggml_slice(step_ctx, scores, 0, step_nr, step_nr+1);
  1330. if (step_nr == start_step) {
  1331. // At the initial step, all hypotheses are equally likely, so we use
  1332. // only the first beam.
  1333. lprobs = ggml_slice(step_ctx, lprobs, 1, 0, 1);
  1334. lprobs = ggml_cont(step_ctx, lprobs);
  1335. // The first step always indicates the beginning of the sequence and has no score.
  1336. if (step_nr > 0) {
  1337. last_scores = ggml_slice(step_ctx, last_scores, 1, 0, 1);
  1338. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1339. }
  1340. } else {
  1341. // Make probabilities contain cumulative scores for each hypothesis.
  1342. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1343. }
  1344. ggml_build_forward_expand(gf, lprobs);
  1345. ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);
  1346. // Determine (beam, token) candidates for the next step.
  1347. // (N, 2 x B)
  1348. std::int64_t K = topk(
  1349. lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
  1350. );
  1351. std::size_t ongoing_beams = 0;
  1352. for (std::int32_t i = 0; i < K; ++i) {
  1353. int c = ggml_get_f32_1d(candidate_indices, i);
  1354. std::int32_t beam = c / vocab_size;
  1355. std::int32_t token = c % vocab_size;
  1356. float tok_score = ggml_get_f32_1d(lprobs, c);
  1357. // Detect beams that reached the minimum length and that end with an EOS.
  1358. bool eos = token == job.eos_idx;
  1359. eos &= tok_score != -INFINITY;
  1360. if (eos) {
  1361. _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, lid_scores, finished_searches++);
  1362. if (finished_searches == finished_searches_end)
  1363. goto end_of_beam_search;
  1364. continue;
  1365. }
  1366. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  1367. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  1368. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  1369. if (model.hparams["multilingual"] == 0) {
  1370. printf("Token at top%d: %d (%s)\n", i, token, model.tgt_vocab.id_to_token.at(token).text.c_str());
  1371. } else {
  1372. printf("Token at top%d: %d (%s)\n", i, token, model.vocab.id_to_token.at(token).text.c_str());
  1373. }
  1374. // printf("Seqs dim: [%d %d %d], beam_indices: [%d %d]\n", seqs->ne[0], seqs->ne[1], seqs->ne[2], beam_indices->ne[0], beam_indices->ne[1]);
  1375. ongoing_beams += 1;
  1376. if (ongoing_beams >= beam_size) break;
  1377. }
  1378. // Reorder beams in the `seq` and `score` buffers. The same beam can
  1379. // be selected more than once.
  1380. // (B, S), (B) -> (B, S)
  1381. // don't use allocr API, cause it might reuse a kv cache buffer several time.
  1382. ggml_set_no_alloc(step_ctx, false);
  1383. printf("Seqs dim before getting rows step %d: [%d %d %d]\n", step_nr, seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  1384. ggml_tensor* new_seqs = ggml_get_rows(step_ctx, seqs, beam_indices);
  1385. ggml_tensor* new_scores = ggml_get_rows(step_ctx, scores, beam_indices);
  1386. struct ggml_cgraph * gf_reorder = ggml_new_graph(step_ctx);
  1387. ggml_build_forward_expand(gf_reorder, new_seqs);
  1388. ggml_build_forward_expand(gf_reorder, new_scores);
  1389. reorder_kv_cache(model, step_ctx, gf_reorder, beam_indices);
  1390. ggml_graph_compute_with_ctx(step_ctx, gf_reorder, n_threads);
  1391. seqs = ggml_detach(new_seqs);
  1392. printf("Seqs dim after detach step %d: [%d %d %d]\n", step_nr, seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  1393. scores = ggml_detach(new_scores);
  1394. // seqs[:, step_nr + 1] = next_tokens
  1395. // scores[:, step_nr + 1] = next_scores
  1396. for (std::size_t i = 0; i < beam_size; ++i) {
  1397. ((std::int32_t*)seqs->data)[step_nr + 1 + i * max_seq_len] = ggml_get_i32_1d(next_tokens, i);
  1398. ((float*)scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
  1399. }
  1400. printf_mem_usage(step_ctx, "step_ctx");
  1401. ggml_free(prev_step_ctx);
  1402. prev_step_ctx = step_ctx;
  1403. #if DEBUG_MEM_USAGE
  1404. std::fill(local_bufs[(step_nr + 1) % 2].begin(), local_bufs[(step_nr + 1) % 2].end(), 0xAA);
  1405. #endif
  1406. step_ctx = ctx_from_buffer(local_bufs[(step_nr + 1) % 2]);
  1407. }
  1408. end_of_beam_search:
  1409. // Ensure that hypotheses are sorted by decreasing scores before returning.
  1410. std::sort(
  1411. finished_searches_begin,
  1412. finished_searches_end,
  1413. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  1414. );
  1415. printf_mem_usage(search_ctx, "search_ctx");
  1416. // fairseq2_kv_cache_reset(model);
  1417. model.ctx = original_ctx;
  1418. return finished_searches_begin;
  1419. }
  1420. extern "C" Hypothesis* _testing_return_hypothesis_ptr(ggml_context* ctx) {
  1421. Hypothesis* result = GGML_CTX_ALLOC(ctx, struct Hypothesis, 2);
  1422. result[0] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 3.14f, (ggml_tensor*)result};
  1423. ggml_set_i32_1d(result[0].seq, 0, 314);
  1424. result[1] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 4.21f, nullptr};
  1425. ggml_set_i32_1d(result[1].seq, 0, 421);
  1426. return result;
  1427. }
  1428. // SPM tokenizer
  1429. // original implementation:
  1430. // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
  1431. struct llm_symbol {
  1432. using index = int;
  1433. index prev;
  1434. index next;
  1435. const char * text;
  1436. size_t n;
  1437. llama_vocab::id id;
  1438. };
  1439. static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
  // Returns how many bytes the UTF-8 sequence starting with lead byte `src`
  // occupies, judged from the top four bits of that byte:
  //   0x00-0xBF -> 1 (ASCII; continuation bytes are also reported as 1)
  //   0xC0-0xDF -> 2
  //   0xE0-0xEF -> 3
  //   0xF0-0xFF -> 4
  // No validation of the following bytes is performed.
  static size_t utf8_len(char src) {
      const uint8_t lead = static_cast<uint8_t>(src);
      if (lead < 0xC0) {
          return 1;
      }
      if (lead < 0xE0) {
          return 2;
      }
      if (lead < 0xF0) {
          return 3;
      }
      return 4;
  }
  1445. struct llm_bigram_spm {
  1446. struct comparator {
  1447. bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
  1448. return (l.score < r.score) || (l.score == r.score && l.left > r.left);
  1449. }
  1450. };
  1451. using queue_storage = std::vector<llm_bigram_spm>;
  1452. using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
  1453. llm_symbol::index left;
  1454. llm_symbol::index right;
  1455. float score;
  1456. size_t size;
  1457. llama_vocab::id id;
  1458. };
  1459. struct llm_tokenizer_spm {
  1460. llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
  1461. void tokenize(const std::string& input_text, ggml_tensor* output) {
  1462. llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");
  1463. // split string into utf8 chars
  1464. int index = 0;
  1465. size_t offs = 0;
  1466. // This is kind of annoying, but needed because with SPM,
  1467. // characters following a space have a special meaning.
  1468. // And the algorithm rely on substrings to do the lookups.
  1469. std::string text = input_text;
  1470. bool need_extra_space = text.size() > 0 && text[0] != ' ';
  1471. if (need_extra_space) text = " " + text;
  1472. while (offs < text.size()) {
  1473. size_t len = utf8_len(text[offs]);
  1474. size_t n = std::min(len, text.size() - offs);
  1475. auto token = vocab.token_to_id.find(std::string(text, offs, n));
  1476. llama_vocab::id id = token == vocab.token_to_id.end() ? unk_idx : token->second;
  1477. llm_symbol sym = {
  1478. /*prev*/ index - 1,
  1479. /*next*/ offs + n == text.size() ? -1 : index + 1,
  1480. /*text*/ text.c_str() + offs,
  1481. /*n*/ n,
  1482. /*id*/ id
  1483. };
  1484. offs += n;
  1485. index++;
  1486. symbols.emplace_back(sym);
  1487. }
  1488. // seed the work queue with all possible 2-character tokens.
  1489. for (size_t i = 1; i < symbols.size(); ++i) {
  1490. try_add_bigram(i - 1, i);
  1491. }
  1492. // keep substituting the highest frequency pairs for as long as we can.
  1493. while (!work_queue.empty()) {
  1494. auto bigram = work_queue.top();
  1495. work_queue.pop();
  1496. auto & left_sym = symbols[bigram.left];
  1497. auto & right_sym = symbols[bigram.right];
  1498. const std::string text = std::string(left_sym.text, left_sym.n + right_sym.n);
  1499. // if one of the symbols already got merged, skip it.
  1500. if (
  1501. left_sym.n == 0
  1502. || right_sym.n == 0
  1503. || left_sym.n + right_sym.n != bigram.size
  1504. ) continue;
  1505. // merge the right sym into the left one
  1506. left_sym.n += right_sym.n;
  1507. left_sym.id = bigram.id;
  1508. right_sym.n = 0;
  1509. // remove the right sym from the chain
  1510. left_sym.next = right_sym.next;
  1511. if (right_sym.next >= 0) {
  1512. symbols[right_sym.next].prev = bigram.left;
  1513. }
  1514. // find more substitutions
  1515. try_add_bigram(left_sym.prev, bigram.left);
  1516. try_add_bigram(bigram.left, left_sym.next);
  1517. }
  1518. llama_vocab::id* out = (llama_vocab::id*)output->data;
  1519. int out_step = sizeof(llama_vocab::id) / output->nb[0];
  1520. int num_tokens = 0;
  1521. for (int i = 0; i > -1; i = symbols[i].next) {
  1522. llm_symbol& symbol = symbols[i];
  1523. *(out + num_tokens * out_step) = symbol.id;
  1524. num_tokens += 1;
  1525. }
  1526. *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
  1527. num_tokens += 1;
  1528. output->ne[0] = num_tokens;
  1529. }
  1530. private:
  1531. void try_add_bigram(int left, int right) {
  1532. if (left == -1 || right == -1) {
  1533. return;
  1534. }
  1535. const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
  1536. auto token = vocab.token_to_id.find(text);
  1537. if (token == vocab.token_to_id.end()) {
  1538. return;
  1539. }
  1540. llama_vocab::id id = token->second;
  1541. if (static_cast<size_t>(id) >= vocab.id_to_token.size()) {
  1542. return;
  1543. }
  1544. const auto& tok_data = vocab.id_to_token[id];
  1545. llm_bigram_spm bigram = {
  1546. /*left */ left,
  1547. /*right*/ right,
  1548. /*score*/ tok_data.score,
  1549. /*size */ text.size(),
  1550. /*id */ id
  1551. };
  1552. work_queue.push(bigram);
  1553. }
  1554. const llama_vocab& vocab;
  1555. std::vector<llm_symbol> symbols;
  1556. llm_bigram_spm::queue work_queue;
  1557. };
  1558. extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out) {
  1559. llm_tokenizer_spm spm = {model->vocab};
  1560. spm.tokenize(std::string(text), out);
  1561. }
  1562. extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out) {
  1563. bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
  1564. int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
  1565. int sent_len = tokens->ne[0];
  1566. std::size_t written = 0;
  1567. std::vector<llama_vocab::token_data> id_to_token = no_tgt_vocab ? model->vocab.id_to_token : model->tgt_vocab.id_to_token;
  1568. for (int i = 0; i < sent_len; ++i) {
  1569. int id = ggml_get_i32_1d(tokens, i);
  1570. // Don't print the EOS token but only if it appear at the end.
  1571. if (i == sent_len - 1 && eos_idx == id) break;
  1572. std::string token = no_tgt_vocab ? model->vocab.id_to_token.at(id).text : model->tgt_vocab.id_to_token.at(id).text;
  1573. // Skip the first space outputted.
  1574. auto begin = token.begin();
  1575. if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
  1576. std::copy(begin, token.end(), out);
  1577. std::size_t n = token.end() - begin;
  1578. written += n;
  1579. out += n;
  1580. }
  1581. *out = '0';
  1582. return written;
  1583. }
  1584. // TODO: Unify with the above?
  1585. std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(
  1586. fairseq2_model* model,
  1587. ggml_tensor* tokens,
  1588. ggml_tensor* scores,
  1589. char* out) {
  1590. bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
  1591. int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
  1592. int sent_len = tokens->ne[0];
  1593. std::size_t written = 0;
  1594. std::vector<float> word_scores;
  1595. std::vector<float> subword_scores;
  1596. std::vector<std::string> result_text;
  1597. std::string curr_token = "";
  1598. for (int i = 0; i < sent_len; ++i) {
  1599. int id = ggml_get_i32_1d(tokens, i);
  1600. // Don't print the EOS token but only if it appear at the end.
  1601. if (i == sent_len - 1 && eos_idx == id) break;
  1602. std::string token = model->vocab.id_to_token.at(id).text;
  1603. float score = ggml_get_f32_1d(scores, i+2); // 2 is prefix size
  1604. if(token[0] == ' ') {
  1605. // reset word score
  1606. if(subword_scores.size() > 0) {
  1607. float avg = std::accumulate(subword_scores.begin(), subword_scores.end(), 0.0f) / subword_scores.size();
  1608. word_scores.push_back(avg);
  1609. subword_scores.clear();
  1610. result_text.push_back(curr_token);
  1611. }
  1612. curr_token = token.substr(1);
  1613. } else {
  1614. curr_token += token;
  1615. }
  1616. subword_scores.push_back(score);
  1617. // Skip the first space outputted.
  1618. auto begin = token.begin();
  1619. if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
  1620. std::copy(begin, token.end(), out);
  1621. std::size_t n = token.end() - begin;
  1622. written += n;
  1623. out += n;
  1624. }
  1625. if(subword_scores.size() > 0) {
  1626. word_scores.push_back(*std::min_element(subword_scores.begin(), subword_scores.end()));
  1627. subword_scores.clear();
  1628. result_text.push_back(curr_token);
  1629. }
  1630. *out = '0';
  1631. return std::make_pair(result_text, word_scores);
  1632. }