fairseq2.cpp 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754
#include <fnmatch.h>
#include <math.h>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/feature-window.h"
#include "fairseq2.h"
#include "ggml.h"
#include "ggml-alloc.h"
  12. ggml_tensor* ggml_detach(ggml_tensor* a) {
  13. a->op = GGML_OP_NONE;
  14. std::fill(a->src, a->src + GGML_MAX_SRC, nullptr);
  15. return a;
  16. }
  17. // generate_sequence uses ggml_context and ggml_allocr to reuse memory buffers across steps.
  18. // This can lead to dangling pointers, which don't segfault, but instead read garbage data.
  19. // Enabling this flag allows to explictly reset memory buffers, making it more explicit
  20. // when we read garbage data.
  21. // It also prints memory usage information, which is useful to
  22. #define DEBUG_MEM_USAGE DEBUG
  23. void printf_mem_usage(ggml_context* ctx, std::string name) {
  24. #if DEBUG_MEM_USAGE
  25. double mb = 1024.0 * 1024.0;
  26. printf(
  27. "ctx %s: memory used = %8.2f MB, memory reserved = %8.2f Mb\n",
  28. name.c_str(),
  29. ggml_used_mem(ctx) / mb,
  30. ggml_get_mem_size(ctx) / mb
  31. );
  32. #endif
  33. }
  34. #define SWAP(x, y) \
  35. auto tmp_ ## x = x; x = y; y = tmp_ ## x;
  36. #define GGML_ASSERT_SHAPE(x, ne0, ne1, ne2, ne3) \
  37. GGML_ASSERT((ne0 == -1 || x->ne[0] == ne0) && (ne1 == -1 || x->ne[1] == ne1) && (ne2 == -1 || x->ne[2] == ne2) && (ne3 == -1 || x->ne[3] == ne3));
  38. /// allocate the fairseq2 model and hyperparameters
  39. extern "C" fairseq2_model* fairseq2_model_alloc() {
  40. // pre-allocate some memory to write hyperparameters and tensors pointers
  41. auto* model = new fairseq2_model;
  42. model->tensors_ctx = nullptr;
  43. return model;
  44. }
  45. extern "C" void fairseq2_kv_cache_alloc(fairseq2_model& model, ggml_context* kv_cache_ctx, int beam_size, int max_seq_len) {
  46. // Note: we only allocate the masks, proper kv cache allocation is delayed.
  47. GGML_ASSERT(kv_cache_ctx);
  48. GGML_ASSERT(!ggml_get_no_alloc(kv_cache_ctx)); // We need to be able to alloc the kv_cache buffers
  49. model.kv_cache_ctx = kv_cache_ctx;
  50. auto attn_glob = "text_decoder.*_attn.k_proj.weight";
  51. FORCE_ALLOC(self_attn_mask, kv_cache_ctx, ggml_new_tensor_2d(kv_cache_ctx, GGML_TYPE_F32, max_seq_len, max_seq_len));
  52. self_attn_mask = ggml_diag_mask_inf_inplace(kv_cache_ctx, self_attn_mask, 0);
  53. ggml_format_name(self_attn_mask, "self_attn_mask[%d]", max_seq_len);
  54. for (auto named_tensor : model.tensors) {
  55. const std::string& name = named_tensor.first;
  56. if (::fnmatch(attn_glob, name.c_str(), 0) == FNM_NOMATCH)
  57. continue;
  58. // create a cache entry without the ".k_proj.weight" suffix
  59. const std::string& shortname = name.substr(0, name.size() - 14);
  60. KeyValueTensor& kv = model.kv_cache[shortname];
  61. kv.step_nr = 0;
  62. kv.full_k = nullptr;
  63. kv.full_v = nullptr;
  64. kv.self_attn_mask = self_attn_mask;
  65. }
  66. }
  67. extern "C" void fairseq2_kv_cache_reset(const fairseq2_model& model) {
  68. // TODO: use a dedicated allocator, so that kv_cache.clear actually frees the memory
  69. model.kv_cache.clear();
  70. }
  71. bool has_kv_cache(const fairseq2_model& model) {
  72. return model.kv_cache.size() > 0;
  73. }
  74. inline ggml_tensor* ggml_squeeze(ggml_context* ctx, ggml_tensor* x, int dim) {
  75. int n_dims = x->n_dims;
  76. GGML_ASSERT(dim >= 0);
  77. GGML_ASSERT(dim < n_dims);
  78. GGML_ASSERT(x->ne[dim] == 1);
  79. return ggml_flatten_1d(ctx, x, dim);
  80. }
  81. inline ggml_tensor* ggml_unsqueeze(ggml_context* ctx, ggml_tensor* x, int dim) {
  82. return ggml_unflatten_1d(ctx, x, dim, 1);
  83. }
  84. // copy k and v to kv cache
  85. // kv.full_k[step_nr] = k;
  86. // kv.full_v[step_nr] = v;
  87. void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
  88. KeyValueTensor& kv = model.kv_cache[prefix];
  89. int step_nr = kv.step_nr;
  90. ggml_context* ctx = model.kv_cache_ctx ? model.kv_cache_ctx : model.ctx;
  91. int n_steps = (*k)->ne[1];
  92. int k_proj, batch_size;
  93. if (kv.full_k != nullptr) {
  94. // (N, S_kv, K_proj)
  95. k_proj = kv.full_k->ne[0];
  96. batch_size = kv.full_k->ne[2];
  97. ggml_detach(kv.full_k);
  98. ggml_detach(kv.full_v);
  99. kv.full_k = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_k, 1), ggml_unsqueeze(ctx, *k, 1)), 1);
  100. kv.full_v = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_v, 1), ggml_unsqueeze(ctx, *v, 1)), 1);
  101. } else {
  102. GGML_ASSERT(step_nr == 0);
  103. k_proj = (*k)->ne[0];
  104. batch_size = (*v)->ne[2];
  105. kv.full_k = ggml_dup(ctx, *k);
  106. kv.full_v = ggml_dup(ctx, *v);
  107. }
  108. *k = kv.full_k;
  109. *v = kv.full_v;
  110. ggml_format_name(kv.full_k, "%s.k (step=%d)", prefix.c_str(), step_nr);
  111. ggml_format_name(kv.full_v, "%s.v (step=%d)", prefix.c_str(), step_nr);
  112. step_nr += n_steps;
  113. GGML_ASSERT_SHAPE(kv.full_k, k_proj, step_nr, batch_size, 1);
  114. // qk is (B * H, Sq, Sk) == (B*H, 1, Sk) in incremental mode
  115. // we return the Sq slice of the (Sq, Sk) attention mask
  116. *self_attn_mask = ggml_slice(
  117. model.ctx,
  118. ggml_slice(model.ctx, kv.self_attn_mask, 0, 0, step_nr),
  119. 1,
  120. step_nr - 1,
  121. step_nr
  122. );
  123. kv.step_nr = step_nr;
  124. }
  125. // variant of ggml_get_rows that allows for a with more than 2 dims.
  126. ggml_tensor* ggml_get_rows2(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  127. int flattened = 0;
  128. GGML_ASSERT(a->n_dims <= 3);
  129. if (a->n_dims == 3) {
  130. flattened = a->ne[0];
  131. a = ggml_flatten_1d(ctx, a, 0);
  132. }
  133. a = ggml_get_rows(ctx, a, b);
  134. if (flattened) {
  135. a = ggml_unflatten_1d(ctx, a, 0, flattened);
  136. }
  137. return a;
  138. }
  139. void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, ggml_tensor* new_order) {
  140. // GGML_ASSERT(ctx == kv.full_k->con);
  141. if (kv.full_k != nullptr) {
  142. ggml_detach(kv.full_k);
  143. const char* name = kv.full_k->name;
  144. kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
  145. ggml_build_forward_expand(gf, kv.full_k);
  146. ggml_format_name(kv.full_k, "%s (sorted)", name);
  147. }
  148. if (kv.full_v != nullptr) {
  149. ggml_detach(kv.full_v);
  150. const char* name = kv.full_v->name;
  151. kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
  152. ggml_build_forward_expand(gf, kv.full_v);
  153. ggml_format_name(kv.full_v, "%s (sorted)", name);
  154. }
  155. }
  156. void reorder_kv_cache(const fairseq2_model& model, ggml_context* ctx, ggml_cgraph* gf, ggml_tensor* new_order) {
  157. auto self_attn_glob = "*.self_attn";
  158. for (auto& named_kv : model.kv_cache) {
  159. if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH)
  160. continue;
  161. _reorder_kv_cache(ctx, gf, named_kv.second, new_order);
  162. }
  163. }
  164. inline double model_layer_config_d(const fairseq2_model& model, std::string name) {
  165. const std::int64_t* data = &model.layer_config.at(name);
  166. double val = *(const double*)data;
  167. return val;
  168. }
  169. extern "C" double fairseq2_model_layer_config_double(const fairseq2_model& model, const char* name) {
  170. return model_layer_config_d(model, std::string(name));
  171. }
  172. extern "C" std::int64_t fairseq2_model_layer_config_int(const fairseq2_model& model, const char* name) {
  173. return model.layer_config.at(std::string(name));
  174. }
  175. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  176. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  177. delete model;
  178. }
  179. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  180. model->ctx = ctx;
  181. }
  182. extern "C" std::string* std_string_alloc(char* c_str) {
  183. return new std::string(c_str);
  184. }
  185. extern "C" void std_string_free(std::string* str) {
  186. delete str;
  187. }
  188. bool has_layer(fairseq2_model& model, const std::string& name) {
  189. return model.tensors.find(name) != model.tensors.end();
  190. }
  191. ggml_tensor* mul_mat(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  192. if (b->ne[1] == 1 && b->ne[2] > 1 && a->n_dims == 2) {
  193. // `b` has shape (B, 1, D).
  194. // if `a` is (D_out, D), then we do one matmul for the full batch.
  195. b = ggml_flatten_1d(ctx, b, 1);
  196. return ggml_unflatten_1d(ctx, ggml_mul_mat(ctx, a, b), 1, 1);
  197. }
  198. // there is also the k * q matmul -> (D, 1, B) * (D, 1, B) -> (1, 1, B)
  199. // not sure what's the best way to compute this with BLAS
  200. return ggml_mul_mat(ctx, a, b); // (d_out)
  201. }
  202. extern "C" ggml_tensor* Linear_forward(
  203. fairseq2_model& model,
  204. const std::string &prefix,
  205. ggml_tensor* input // (d_in)
  206. ) {
  207. // Note: for now we assumed un-batched input
  208. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  209. GGML_ASSERT(weight != nullptr);
  210. ggml_tensor* out = mul_mat(model.ctx, weight, input); // (d_out)
  211. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  212. if (bias == nullptr) return out;
  213. return ggml_add(model.ctx, out, bias);
  214. }
  215. extern "C" ggml_tensor* LayerNorm_forward(
  216. fairseq2_model& model,
  217. const std::string &prefix,
  218. ggml_tensor* input
  219. ) {
  220. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  221. GGML_ASSERT(weight != nullptr);
  222. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  223. GGML_ASSERT(bias != nullptr);
  224. auto ctx = model.ctx;
  225. double eps = model_layer_config_d(model, prefix + ".eps");
  226. input = ggml_norm(ctx, input, /*eps*/eps);
  227. return ggml_add_inplace(
  228. ctx,
  229. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  230. ggml_repeat(ctx, bias, input)
  231. );
  232. }
  233. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  234. fairseq2_model& model,
  235. const std::string& prefix,
  236. ggml_tensor* seqs
  237. ) {
  238. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  239. // inner_activation = ReLu // TODO: allow other activation
  240. seqs = ggml_relu_inplace(model.ctx, seqs);
  241. if (has_layer(model, prefix + ".inner_layer_norm")) {
  242. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  243. }
  244. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  245. return seqs;
  246. }
  247. extern "C" ggml_tensor* SiluFeedForwardNetwork_forward(
  248. fairseq2_model& model,
  249. const std::string& prefix,
  250. ggml_tensor* seqs
  251. ) {
  252. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  253. seqs = ggml_silu(model.ctx, seqs);
  254. if (has_layer(model, prefix + ".inner_layer_norm")) {
  255. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  256. }
  257. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  258. return seqs;
  259. }
  260. ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
  261. int n_dims = x->n_dims;
  262. GGML_ASSERT(dim >= 0);
  263. GGML_ASSERT(dim < n_dims);
  264. GGML_ASSERT(ggml_is_contiguous(x));
  265. // Nothing to do
  266. if (dim == n_dims - 1) return x;
  267. if (n_dims == 2) {
  268. return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
  269. } else if (n_dims == 3) {
  270. if (dim == 0) {
  271. return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
  272. } else { // dim == 1
  273. return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
  274. }
  275. } else { // n_dims == 4
  276. if (dim == 0) {
  277. return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
  278. } else if (dim == 1) {
  279. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
  280. } else { // dim == 2
  281. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
  282. }
  283. }
  284. }
  285. ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
  286. int n_dims = x->n_dims;
  287. GGML_ASSERT(dim >= 0);
  288. GGML_ASSERT(dim < n_dims);
  289. GGML_ASSERT(n_dims < 4);
  290. GGML_ASSERT(x->ne[dim] % num_el == 0);
  291. GGML_ASSERT(x->nb[dim + 1] == x->nb[dim] * x->ne[dim]); // `x` isn't contiguous along `dim`
  292. if (n_dims == 1) {
  293. return ggml_view_2d(ctx, x, num_el, x->ne[0] / num_el, x->nb[0] * num_el, 0);
  294. } else if (n_dims == 2) {
  295. if (dim == 0) {
  296. return ggml_view_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->nb[0] * num_el, x->nb[1], 0);
  297. } else { // dim == 1
  298. return ggml_view_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->nb[1], num_el * x->nb[1], 0);
  299. }
  300. } else { // (n_dims == 3)
  301. if (dim == 0) {
  302. return ggml_view_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2], x->nb[0] * num_el, x->nb[1], x->nb[2], 0);
  303. } else if (dim == 1) {
  304. return ggml_view_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2], x->nb[1], num_el * x->nb[1], x->nb[2], 0);
  305. } else { // dim == 2
  306. return ggml_view_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el, x->nb[1], x->nb[2], num_el * x->nb[2], 0);
  307. }
  308. }
  309. }
  310. ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
  311. // (B, S, dim) -> (B, S, H, H_dim)
  312. x = ggml_unflatten_1d(ctx, x, 0, head_dim);
  313. x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
  314. x = ggml_cont(ctx, x);
  315. x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
  316. return x;
  317. }
  318. /// (B, Sk, dim) -> // (B?, H, H_dim, Sk)
  319. ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim ) {
  320. // (B, Sk, dim) -> (B, Sk, H, H_dim)
  321. v = ggml_unflatten_1d(ctx, v, 0, head_dim);
  322. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
  323. v = ggml_cont(ctx, v);
  324. v = ggml_flatten_1d(ctx, v, 2); // (B * H, S, H_dim)
  325. return v;
  326. }
  327. // flash_attn doesn't work for cross attention because it assumes Q <= K
  328. // and it seems to yield slightly different scores than expected, and thus a different beam search
  329. # define UNITY_FLASH_ATTN 0
  330. extern "C" ggml_tensor* MultiheadAttention_forward(
  331. fairseq2_model& model,
  332. const std::string &prefix,
  333. ggml_tensor* queries, // (slen, d_in)
  334. ggml_tensor* keys, // (klen, d_in)
  335. ggml_tensor* values, // (klen, d_out)
  336. ggml_tensor* attn_mask // (klen, slen)
  337. ) {
  338. int model_dim = queries->ne[0];
  339. int num_heads = model.layer_config.at(prefix + ".num_heads");
  340. int head_dim = model_dim / num_heads;
  341. GGML_ASSERT(model_dim % num_heads == 0);
  342. ggml_context* ctx = model.ctx;
  343. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
  344. q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
  345. ggml_set_name(q, "q");
  346. ggml_tensor *k, *v;
  347. if (!has_kv_cache(model)) {
  348. k = Linear_forward(model, prefix + ".k_proj", keys);
  349. ggml_set_name(k, "k");
  350. v = Linear_forward(model, prefix + ".v_proj", values);
  351. ggml_set_name(v, "v");
  352. } else {
  353. bool encoder_decoder_attn = keys == values && keys != queries;
  354. if (encoder_decoder_attn) {
  355. // The K and V tensors of an encoder-decoder attention (i.e. the
  356. // projected encoder outputs) remain static during evaluation.
  357. KeyValueTensor& kv_cache = model.kv_cache[prefix];
  358. if (kv_cache.step_nr == 0) {
  359. // If possible we use the ctx dedicated to kv_cache here,
  360. // because the enc dec attention is typically long lived.
  361. if (model.kv_cache_ctx) model.ctx = model.kv_cache_ctx;
  362. k = Linear_forward(model, prefix + ".k_proj", keys);
  363. ggml_set_name(k, "k");
  364. v = Linear_forward(model, prefix + ".v_proj", values);
  365. ggml_set_name(v, "v");
  366. // Note we are only storing a pointer to the buffer, not the full graph
  367. kv_cache.full_k = ggml_detach(ggml_dup_inplace(model.ctx, k));
  368. ggml_format_name(kv_cache.full_k, "%s.k_cache", prefix.c_str());
  369. kv_cache.full_v = ggml_detach(ggml_dup_inplace(model.ctx, v));
  370. ggml_format_name(kv_cache.full_v, "%s.v_cache", prefix.c_str());
  371. kv_cache.step_nr = keys->ne[1];
  372. model.ctx = ctx;
  373. } else {
  374. k = kv_cache.full_k;
  375. v = kv_cache.full_v;
  376. GGML_ASSERT(keys->ne[1] == k->ne[1]); // cache content doesn't match the input sequence
  377. GGML_ASSERT(values->ne[1] == v->ne[1]); // cache content doesn't match the input sequence
  378. }
  379. } else { // self attention
  380. // (1, K) -> (N, 1, K_proj)
  381. k = Linear_forward(model, prefix + ".k_proj", keys);
  382. ggml_set_name(k, "k");
  383. // (1, V) -> (N, 1, V_proj)
  384. v = Linear_forward(model, prefix + ".v_proj", values);
  385. ggml_set_name(v, "v");
  386. append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
  387. }
  388. }
  389. k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
  390. v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
  391. v = ggml_cont(ctx, v);
  392. #if UNITY_FLASH_ATTN
  393. // For flash_attn, we assume either no masks, or triangular masks.
  394. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr); // (B * H, S, H_dim)
  395. ggml_set_name(attn, "attn");
  396. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  397. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  398. #else
  399. // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
  400. ggml_tensor* qk = mul_mat(ctx, k, q);
  401. ggml_set_name(qk, "qk");
  402. FORCE_ALLOC(qk_scale, ctx, ggml_new_tensor_1d(ctx, qk->type, 1));
  403. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  404. qk = ggml_scale(ctx, qk, qk_scale);
  405. ggml_set_name(qk, "qk_scaled");
  406. if (attn_mask) qk = ggml_add_inplace(ctx, qk, attn_mask);
  407. // TODO: upgrade qk to float32 if needed
  408. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
  409. ggml_set_name(attn_weights, "attn_weights");
  410. // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
  411. ggml_tensor* attn = mul_mat(ctx, attn_weights, v);
  412. ggml_set_name(attn, "attn");
  413. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  414. attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
  415. #endif // UNITY_FLASH_ATTN
  416. attn = ggml_cont(ctx, attn);
  417. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  418. // out -> (B, S, d_out)
  419. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  420. ggml_set_name(out, "out");
  421. return out;
  422. }
  423. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  424. fairseq2_model& model,
  425. const std::string& prefix,
  426. ggml_tensor* seqs,
  427. ggml_tensor* padding_mask
  428. ) {
  429. ggml_context* ctx = model.ctx;
  430. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  431. // _forward_self_attn(seqs, padding_mask)
  432. auto residual = seqs;
  433. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  434. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  435. // TODO: add padding_mask to MultiheadAttention_forward
  436. GGML_ASSERT(padding_mask == nullptr);
  437. seqs = MultiheadAttention_forward(
  438. model,
  439. prefix + ".self_attn",
  440. seqs,
  441. seqs,
  442. seqs,
  443. /*attn_mask=*/nullptr
  444. );
  445. if (has_layer(model, prefix + ".self_attn_norm"))
  446. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  447. seqs = ggml_add_inplace(ctx, seqs, residual);
  448. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  449. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  450. // _forward_ffn(seqs)
  451. residual = seqs;
  452. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  453. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  454. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  455. // TODO: if self.residual_scale is not None:
  456. // residual = self.residual_scale * residual
  457. seqs = ggml_add_inplace(ctx, seqs, residual);
  458. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  459. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  460. return seqs;
  461. }
  462. extern "C" ggml_tensor* WaveformToFbank_forward(
  463. fairseq2_model& model,
  464. const std::string &prefix,
  465. ggml_tensor* waveform
  466. ) {
  467. // Hardcoding: num_bins 80, sample rate 16k, always standardize
  468. ggml_context* ctx = model.ctx;
  469. knf::MelBanksOptions mel_opts{};
  470. mel_opts.num_bins = 80;
  471. knf::FrameExtractionOptions frame_opts{};
  472. frame_opts.samp_freq = 16000;
  473. knf::FbankOptions opts{};
  474. opts.frame_opts = frame_opts;
  475. opts.mel_opts = mel_opts;
  476. std::vector<float_t> signal_frame{};
  477. std::int32_t num_frames = knf::NumFrames(/*num_samples=*/waveform->ne[0], frame_opts);
  478. FORCE_ALLOC(output, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 80, num_frames));
  479. knf::FbankComputer native_(opts);
  480. knf::FeatureWindowFunction window_fn_(native_.GetFrameOptions());
  481. for (std::int32_t frame_nr = 0; frame_nr < num_frames; ++frame_nr) {
  482. signal_frame.resize(0);
  483. // Extract the frame from the waveform tensor.
  484. knf::ExtractWindow(
  485. /*sample_offset=*/0,
  486. (float *)(waveform->data),
  487. waveform->ne[0],
  488. frame_nr,
  489. frame_opts,
  490. window_fn_,
  491. &signal_frame);
  492. native_.Compute(
  493. /*signal_raw_log_energy=*/0, /*vtln_warp=*/1.0, &signal_frame, ((float *)(output->data) + frame_nr * 80));
  494. }
  495. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  496. output = ggml_norm(ctx, output, 1e-5);
  497. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  498. if (output->ne[1] % 2 == 1) {
  499. ggml_tensor* remove_last = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, output->ne[1]-1);
  500. for (int i = 0; i < output->ne[1]-1; ++i) {
  501. ((int32_t *) remove_last->data)[i] = i;
  502. }
  503. output = ggml_get_rows(ctx, output, remove_last);
  504. }
  505. output = ggml_reshape_2d(ctx, output, output->ne[0] * 2, output->ne[1] / 2);
  506. return output;
  507. }
  508. // TODO: Check if it's possible to merge with standard MHA
  509. extern "C" ggml_tensor* RelativePositionMHA_forward(
  510. fairseq2_model& model,
  511. const std::string& prefix,
  512. ggml_tensor* seqs
  513. ) {
  514. ggml_context* ctx = model.ctx;
  515. ggml_tensor* residual = seqs;
  516. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  517. // self_attn: qkv
  518. ggml_tensor* Qcur = Linear_forward(model, prefix + ".q_proj", seqs);
  519. ggml_tensor* Kcur = Linear_forward(model, prefix + ".k_proj", seqs);
  520. ggml_tensor* Vcur = Linear_forward(model, prefix + ".v_proj", seqs);
  521. // self_attn: rel_pos SDPA
  522. int32_t S = seqs->ne[1];
  523. int32_t H = 16; // TODO: Make this configurable
  524. int32_t n_ctx = 4096;
  525. int32_t K_h = seqs->ne[0] / H;
  526. int32_t start_index = n_ctx - S;
  527. int32_t end_index = n_ctx + S - 1;
  528. int num_indices = end_index - start_index;
  529. FORCE_ALLOC(rows, ctx, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices));
  530. for (int i = 0; i < num_indices; i++) {
  531. ((int32_t *)rows->data)[i] = start_index + i;
  532. }
  533. // self_attn: load pos_enc weights & compute_r
  534. // In fairseq2 pos_enc weights are calculated on the fly, since some more custom operators might be needed to enable this,
  535. // we store the results (fixed) in checkpoint as model.audio_enc_pos_enc_w and load directly.
  536. ggml_tensor* r = ggml_get_rows(ctx, model.tensors["speech_encoder.pos_enc"], rows);
  537. r = mul_mat(ctx, model.tensors[prefix + ".sdpa.r_proj.weight"], r);
  538. r = ggml_dup(ctx, ggml_permute(ctx,
  539. ggml_cpy(ctx,
  540. r,
  541. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S*2-1)),
  542. 0, 2, 1, 3));
  543. ggml_tensor* u_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.u_bias"], K_h, 1, H);
  544. ggml_tensor* v_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.v_bias"], K_h, 1, H);
  545. // self_attn: Permute QKV
  546. ggml_tensor* Q = ggml_cont(ctx, ggml_permute(ctx,
  547. ggml_cpy(ctx,
  548. Qcur,
  549. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S)),
  550. 0, 2, 1, 3)); // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  551. ggml_tensor* K = ggml_cont(ctx, ggml_permute(ctx,
  552. ggml_cpy(ctx,
  553. Kcur,
  554. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S)),
  555. 0, 2, 1, 3)); // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  556. ggml_tensor* V = ggml_cont(ctx, ggml_permute(ctx,
  557. ggml_cpy(ctx,
  558. Vcur,
  559. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S)),
  560. 1, 2, 0, 3)); // (H * K_h, S) -> (K_h, H, S) -> (H, S, K_h)
  561. ggml_tensor* q_with_u_bias = ggml_add_inplace(ctx, ggml_dup(ctx, Q), u_bias); // (K_h, S, H)
  562. ggml_tensor* q_with_v_bias = ggml_add_inplace(ctx, Q, v_bias); // (K_h, S, H)
  563. ggml_tensor* ac = mul_mat(ctx, K, q_with_u_bias);
  564. ggml_tensor* bd = mul_mat(ctx, r, q_with_v_bias);
  565. // self_attn: shift_bd. Logic follows https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L161
  566. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // H, S, 2S-1
  567. FORCE_ALLOC(pad, ctx, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, S, 1));
  568. pad = ggml_set_f32(pad, 0.0);
  569. bd = ggml_concat(ctx, pad, bd); // bd[i][j][0] == 0, (H, S, 2S)
  570. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // (2S, S, H)
  571. bd = ggml_reshape_3d(ctx, bd, S, 2 * S, H); // (S, 2S, H)
  572. // discard the first set of positive positions
  573. bd = ggml_dup(ctx, ggml_slice(ctx, bd, 1, 1, 2 * S));
  574. // shifts each row by an extra step
  575. bd = ggml_reshape_3d(ctx, bd, 2 * S - 1, S, H);
  576. // Discard positions used for shift.
  577. bd = ggml_slice(ctx, bd, 0, 0, S);
  578. // self_attn: compute attn / weights
  579. ggml_tensor* attn_weights = ggml_add_inplace(ctx, ac, bd);
  580. FORCE_ALLOC(attn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  581. ggml_set_f32(attn_scale, 1.0 / pow(K_h, 0.5));
  582. attn_weights = ggml_mul_inplace(ctx, attn_weights, ggml_repeat(ctx, attn_scale, attn_weights));
  583. attn_weights = ggml_soft_max(ctx, attn_weights);
  584. ggml_tensor* attn = mul_mat(ctx, V, attn_weights); // K_h, S, H
  585. attn = ggml_dup(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));
  586. ggml_tensor* attn_2d = ggml_reshape_2d(ctx, attn, K_h * H, S);
  587. ggml_tensor* attn_out = mul_mat(ctx, model.tensors[prefix + ".output_proj.weight"], attn_2d);
  588. attn_out = ggml_add_inplace(
  589. ctx,
  590. attn_out,
  591. ggml_repeat(ctx, model.tensors[prefix + ".output_proj.bias"], attn_out)
  592. );
  593. attn_out = ggml_add_inplace(ctx, attn_out, residual);
  594. return attn_out;
  595. }
// Conformer convolution module, residual included:
//   layer norm -> pointwise conv1 (as matmul) -> GLU -> depthwise conv1d
//   -> batch norm -> SiLU -> pointwise conv2 (as matmul) -> + input.
// Weights are read from `model.tensors` under `prefix` (e.g. prefix + ".pointwise_conv1.weight").
// The layer norm is read under `prefix + "_layer_norm"` (underscore, not dot).
extern "C" ggml_tensor* ConvModule_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
) {
    ggml_context* ctx = model.ctx;
    ggml_tensor* residual = seqs;
    seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
    // conv: Use matmul for pointwise conv 1 - kernel_size=1, no padding case
    seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv1.weight"], seqs);
    // conv: GLU
    seqs = ggml_glu(ctx, seqs);
    // Swap the first two axes (materialized with ggml_dup) before the 1D convolution.
    // Presumably this puts the time axis where ggml_conv_1d expects it — TODO confirm.
    seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
    // S x C -> (S+K-1) x C -> K x S x C -> S x C
    // NOTE(review): stride 1, padding 15, dilation 1 are hard-coded. Padding 15 preserves
    // the sequence length only for a kernel size of 31; confirm against depthwise_conv.weight.
    seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".depthwise_conv.weight"], seqs, 1, 15, 1);
    // conv: Custom implementation of batch norm (inference mode: uses running mean/var, eps 1e-5)
    seqs = ggml_batch_norm(ctx, seqs, model.tensors[prefix + ".batch_norm.weight"], model.tensors[prefix + ".batch_norm.bias"], model.tensors[prefix + ".batch_norm.running_mean"], model.tensors[prefix + ".batch_norm.running_var"], 1e-5);
    // conv: SiLU activation
    seqs = ggml_silu_inplace(ctx, seqs);
    // Undo the axis swap done before the convolution.
    seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
    // conv: Use matmul for pointwise conv 2 - kernel_size=1, no padding case
    seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv2.weight"], seqs);
    // conv: + residual
    seqs = ggml_add_inplace(ctx, seqs, residual);
    return seqs;
}
// One Conformer block: half-step FFN1, relative-position self-attention, convolution
// module, half-step FFN2, then a final layer norm.
// The FFN outputs are scaled by 0.5 before their residual add.
// The self-attention and conv sub-blocks add their own residual inside the called functions.
// NOTE(review): `padding_mask` is accepted but not used by this function.
extern "C" ggml_tensor* StandardConformerEncoderLayer_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
) {
    ggml_context* ctx = model.ctx;
    // 1x1 constant 0.5, broadcast with ggml_repeat to scale the FFN outputs.
    FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
    ggml_set_f32(ffn_scale, 0.5f);
    // FFN1: seqs = seqs + 0.5 * ffn1(layer_norm(seqs))
    ggml_tensor* residual = seqs;
    seqs = LayerNorm_forward(model, prefix + ".ffn1_layer_norm", seqs);
    seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn1", seqs);
    seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
    seqs = ggml_add_inplace(ctx, seqs, residual);
    seqs = RelativePositionMHA_forward(model, prefix + ".self_attn", seqs);
    seqs = ConvModule_forward(model, prefix + ".conv", seqs);
    // FFN2: seqs = seqs + 0.5 * ffn2(layer_norm(seqs))
    residual = seqs;
    seqs = LayerNorm_forward(model, prefix + ".ffn2_layer_norm", seqs);
    seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn2", seqs);
    seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
    seqs = ggml_add_inplace(ctx, seqs, residual);
    seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
    return seqs;
}
// Speech encoder forward pass. The steps are:
//   1. WaveformToFbank_forward.
//   2. Frontend layer norm + projection (weights under prefix + "_frontend.*").
//   3. Conformer layers prefix + ".inner.layers.<i>", for as long as has_layer finds them.
//   4. inner_layer_norm.
//   5. Residual half-scaled expand/contract block: proj1 -> ReLU -> proj2, times 0.5, + input.
//   6. Adaptor layers prefix + ".adaptor_layers.<i>".
//   7. Final layer norm.
// Layer outputs are named "x_enc_<i>" / "x_ada_<i>" for debugging.
extern "C" ggml_tensor* StandardConformerEncoder_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
) {
    ggml_context* ctx = model.ctx;
    seqs = WaveformToFbank_forward(model, prefix, seqs);
    seqs = LayerNorm_forward(model, prefix + "_frontend.post_extract_layer_norm", seqs);
    seqs = Linear_forward(model, prefix + "_frontend.model_dim_proj", seqs);
    int layer_idx = 0;
    std::string layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
    // Run conformer layers until the next index has no weights in the model.
    while (has_layer(model, layer_name)) {
        seqs = StandardConformerEncoderLayer_forward(
            model, layer_name, seqs, padding_mask
        );
        ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
        layer_idx += 1;
        layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
    }
    seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
    // seqs = seqs + 0.5 * proj2(relu(proj1(seqs)))
    ggml_tensor* residual = seqs;
    seqs = Linear_forward(model, prefix + ".proj1", seqs);
    seqs = ggml_relu_inplace(ctx, seqs);
    seqs = Linear_forward(model, prefix + ".proj2", seqs);
    FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
    ggml_set_f32(ffn_scale, 0.5f);
    seqs = ggml_mul(ctx, ggml_repeat(ctx, ffn_scale, seqs), seqs);
    seqs = ggml_add_inplace(ctx, seqs, residual);
    layer_idx = 0;
    layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
    // Run adaptor layers until the next index has no weights in the model.
    while (has_layer(model, layer_name)) {
        seqs = StandardConformerEncoderAdaptorLayer_forward(
            model, layer_name, seqs, padding_mask
        );
        ggml_set_name(seqs, ("x_ada_" + std::to_string(layer_idx)).c_str());
        layer_idx += 1;
        layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
    }
    seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
    return seqs;
}
// Encoder adaptor layer. Both the residual path and the main path apply:
// layer norm -> transpose -> strided conv1d -> transpose back -> + bias -> GLU.
// The main path is then fed to self-attention (no mask) and added to the residual path.
// After that comes a standard FFN block with its own residual.
// NOTE(review): `padding_mask` is accepted but not used by this function.
extern "C" ggml_tensor* StandardConformerEncoderAdaptorLayer_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
) {
    ggml_context* ctx = model.ctx;
    // Residual path.
    ggml_tensor* residual = seqs;
    residual = LayerNorm_forward(model, prefix + ".residual_layer_norm", residual);
    residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
    // NOTE(review): (8, 4, 1) presumably means stride 8, padding 4, dilation 1, as in
    // ggml_conv_1d — confirm against ggml_conv_1d_generic.
    residual = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".residual_conv.weight"], residual, 8, 4, 1);
    residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
    residual = ggml_add_inplace(ctx, ggml_repeat(ctx, model.tensors[prefix + ".residual_conv.bias"], residual), residual);
    residual = ggml_glu(ctx, residual);
    // Main path: same conv/GLU shape as the residual path, then self-attention.
    seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
    seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
    seqs = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".self_attn_conv.weight"], seqs, 8, 4, 1);
    seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
    seqs = ggml_add_inplace(ctx, seqs, ggml_repeat(ctx, model.tensors[prefix + ".self_attn_conv.bias"], seqs));
    seqs = ggml_glu(ctx, seqs);
    seqs = MultiheadAttention_forward(
        model,
        prefix + ".self_attn",
        seqs,
        seqs,
        seqs,
        /*attention masks=*/nullptr
    );
    seqs = ggml_add_inplace(ctx, seqs, residual);
    // FFN block: seqs = seqs + ffn(layer_norm(seqs))
    residual = seqs;
    seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
    seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
    seqs = ggml_add_inplace(ctx, seqs, residual);
    return seqs;
}
  723. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  724. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  725. ggml_tensor* ggml_slice(
  726. struct ggml_context * ctx,
  727. struct ggml_tensor * a,
  728. int axis,
  729. int64_t start,
  730. int64_t end
  731. ) {
  732. int64_t ne[4];
  733. std::copy(a->ne, a->ne + 4, ne);
  734. if (axis < 0) axis = a->n_dims + axis;
  735. if (start < 0) start = ne[axis] + start;
  736. if (end <= 0) end = ne[axis] + end;
  737. GGML_ASSERT(0 <= start);
  738. GGML_ASSERT(start < end);
  739. GGML_ASSERT(end <= ne[axis]);
  740. ne[axis] = end - start;
  741. size_t offset = a->nb[axis] * start;
  742. size_t* nb = a->nb;
  743. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  744. ggml_format_name(result, "%s [(%d)%ld:%ld]", a->name, axis, start, end);
  745. result->n_dims = a->n_dims;
  746. return result;
  747. }
  748. ggml_tensor* ggml_select(
  749. struct ggml_context * ctx,
  750. struct ggml_tensor * a,
  751. int axis,
  752. int64_t index
  753. ) {
  754. int64_t ne[GGML_MAX_DIMS];
  755. std::copy(a->ne, a->ne + GGML_MAX_DIMS, ne);
  756. if (axis < 0) axis = a->n_dims + axis;
  757. if (index < 0) index = ne[axis] + index;
  758. GGML_ASSERT(0 <= index);
  759. GGML_ASSERT(index < ne[axis]);
  760. std::copy(a->ne + axis + 1, a->ne + GGML_MAX_DIMS, ne + axis);
  761. size_t offset = a->nb[axis] * index;
  762. size_t* nb = a->nb;
  763. GGML_ASSERT(GGML_MAX_DIMS == 4);
  764. ggml_tensor* result = ggml_view_3d(ctx, a, ne[0], ne[1], ne[2], nb[1], nb[2], offset);
  765. ggml_format_name(result, "%s [(%d)%ld]", a->name, axis, index);
  766. result->n_dims = a->n_dims - 1;
  767. return result;
  768. }
  769. // Inplace computation of PositionalEmbedding
  770. extern "C" ggml_tensor* PositionalEmbedding_forward(
  771. fairseq2_model& model,
  772. const std::string& prefix,
  773. ggml_tensor* embeds
  774. ) {
  775. // This only work with the simple pos encoders
  776. int seq_len = embeds->ne[1];
  777. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  778. int start_step = 0;
  779. if (has_kv_cache(model)) {
  780. start_step = model.kv_cache[prefix].step_nr++;
  781. }
  782. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, start_step, seq_len + start_step);
  783. return ggml_add(model.ctx, embeds, pos_embeds);
  784. }
  785. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  786. fairseq2_model& model,
  787. const std::string& prefix,
  788. ggml_tensor* seqs
  789. ) {
  790. GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
  791. ggml_context* ctx = model.ctx;
  792. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  793. GGML_ASSERT(embed_weights != nullptr);
  794. ggml_tensor* embeds;
  795. if (seqs->n_dims == 1) {
  796. embeds = ggml_get_rows(ctx, embed_weights, seqs);
  797. } else {
  798. // ggml_get_rows isn't very flexible, we have to handle the reshape ourselves.
  799. ggml_tensor* flat_seqs = seqs;
  800. if (!ggml_is_contiguous(seqs)) {
  801. flat_seqs = ggml_cont(ctx, flat_seqs);
  802. }
  803. flat_seqs = ggml_reshape_1d(ctx, flat_seqs, ggml_nelements(seqs));
  804. embeds = ggml_get_rows(ctx, embed_weights, flat_seqs);
  805. embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  806. embeds->n_dims = seqs->n_dims + 1;
  807. }
  808. // padding mask ?
  809. // padding_mask = to_padding_mask(embeds, seq_lens)
  810. if (has_layer(model, prefix + ".pos_encoder")) {
  811. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  812. }
  813. if (has_layer(model, prefix + ".layer_norm")) {
  814. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  815. }
  816. return embeds;
  817. }
  818. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  819. fairseq2_model& model,
  820. const std::string& prefix,
  821. ggml_tensor* seqs,
  822. ggml_tensor* padding_mask
  823. ) {
  824. int layer_idx = 0;
  825. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  826. while (has_layer(model, layer_name)) {
  827. seqs = StandardTransformerEncoderLayer_forward(
  828. model, layer_name, seqs, padding_mask
  829. );
  830. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  831. layer_idx += 1;
  832. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  833. }
  834. if (has_layer(model, prefix + ".layer_norm"))
  835. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  836. return seqs;
  837. }
  838. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  839. fairseq2_model& model,
  840. const std::string& prefix,
  841. ggml_tensor* seqs,
  842. ggml_tensor* self_attn_mask,
  843. ggml_tensor* encoder_output,
  844. ggml_tensor* encoder_padding_mask
  845. ) {
  846. ggml_context* ctx = model.ctx;
  847. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  848. // _forward_self_attn(seqs, padding_mask)
  849. auto residual = seqs;
  850. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  851. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  852. seqs = MultiheadAttention_forward(
  853. model,
  854. prefix + ".self_attn",
  855. seqs,
  856. seqs,
  857. seqs,
  858. /*attn_mask=*/self_attn_mask
  859. );
  860. if (has_layer(model, prefix + ".self_attn_norm"))
  861. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  862. seqs = ggml_add_inplace(ctx, seqs, residual);
  863. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  864. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  865. // _forward_encoder_decoder_attn
  866. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  867. // `encoder_output` must be `None` for decoder-only attention.
  868. GGML_ASSERT(encoder_output == nullptr);
  869. return seqs;
  870. }
  871. // `encoder_output` must not be `None` for encoder-decoder attention.
  872. GGML_ASSERT(encoder_output != nullptr);
  873. residual = seqs;
  874. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  875. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  876. seqs = MultiheadAttention_forward(
  877. model,
  878. prefix + ".encoder_decoder_attn",
  879. seqs,
  880. encoder_output,
  881. encoder_output,
  882. /*attention masks=*/encoder_padding_mask
  883. );
  884. seqs = ggml_add_inplace(ctx, seqs, residual);
  885. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  886. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  887. // _forward_ffn(seqs)
  888. residual = seqs;
  889. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  890. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  891. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  892. // TODO:
  893. // if self.residual_scale is not None:
  894. // residual = self.residual_scale * residual
  895. seqs = ggml_add_inplace(ctx, seqs, residual);
  896. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  897. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  898. return seqs;
  899. }
  900. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  901. auto seq_len = seqs->ne[1];
  902. // TODO: allow other ggml_type
  903. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  904. return ggml_diag_mask_inf(ctx, mask, 0);
  905. }
  906. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  907. fairseq2_model& model,
  908. const std::string& prefix,
  909. ggml_tensor* seqs,
  910. ggml_tensor* padding_mask,
  911. ggml_tensor* encoder_output,
  912. ggml_tensor* encoder_padding_mask
  913. ) {
  914. int layer_idx = 0;
  915. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  916. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  917. while (has_layer(model, layer_name)) {
  918. seqs = StandardTransformerDecoderLayer_forward(
  919. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  920. );
  921. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  922. layer_idx += 1;
  923. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  924. }
  925. if (has_layer(model, prefix + ".layer_norm"))
  926. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  927. return seqs;
  928. }
  929. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  930. auto opts = job.opts;
  931. int max_seq_len = -1;
  932. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  933. max_seq_len = opts.hard_max_seq_len;
  934. } else {
  935. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len) + opts.soft_max_seq_len_b);
  936. }
  937. if (opts.min_seq_len > max_seq_len) {
  938. printf(
  939. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  940. opts.min_seq_len,
  941. max_seq_len
  942. );
  943. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  944. }
  945. int prefix_seq_len = job.prefix_seq->ne[0];
  946. if (prefix_seq_len >= max_seq_len) {
  947. printf(
  948. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  949. prefix_seq_len,
  950. max_seq_len
  951. );
  952. GGML_ASSERT(prefix_seq_len < max_seq_len);
  953. }
  954. return max_seq_len;
  955. }
  956. void _fan_out_encoder_output(
  957. ggml_context* ctx,
  958. ggml_tensor** encoder_output_out,
  959. ggml_tensor** encoder_padding_mask_out,
  960. int beam_size
  961. ) {
  962. // (S_enc, M)
  963. ggml_tensor* encoder_output = *encoder_output_out;
  964. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  965. // (B, S_enc, M)
  966. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  967. // (S_enc, M) -> (B, S_enc, M)
  968. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  969. // (S_enc) -> (B, S_enc)
  970. if (encoder_padding_mask != nullptr) {
  971. ggml_tensor* shape_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], 1, beam_size);
  972. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  973. }
  974. }
  975. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  976. // TODO: this isn't the most precise way of doing this
  977. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  978. }
  979. ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
  980. ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
  981. ggml_type true_type = x->type;
  982. ggml_tensor* y = ggml_repeat(ctx, x, shape);
  983. y->type = true_type;
  984. return y;
  985. }
  986. extern "C" void _bootstrap_seqs_and_scores(
  987. fairseq2_model& model,
  988. const SequenceGeneratorJob& job,
  989. ggml_tensor* full_seqs,
  990. ggml_tensor* scores,
  991. ggml_tensor* encoder_output,
  992. ggml_tensor* encoder_padding_mask
  993. ) {
  994. int prefix_seq_len = job.prefix_seq->ne[0];
  995. int max_seq_len = scores->ne[0];
  996. int beam_size = scores->ne[1];
  997. GGML_ASSERT(prefix_seq_len > 0);
  998. if (prefix_seq_len == 1)
  999. return;
  1000. ggml_context* ctx = model.ctx;
  1001. // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  1002. ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
  1003. seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);
  1004. // We have to bootstrap the model with the already fanned-out encoder
  1005. // output to correctly initialize its incremental state.
  1006. // Note: we don't start decoding the last prefix token just yet.
  1007. seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);
  1008. // Bootstrap the model state with prefix sequence.
  1009. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  1010. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1011. model,
  1012. "text_decoder",
  1013. seqs,
  1014. /*padding_mask*/ nullptr,
  1015. encoder_output,
  1016. encoder_padding_mask
  1017. );
  1018. // logits, lprobs: (N, S_pfx - 1, V)
  1019. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  1020. int vocab_size = logits->ne[0];
  1021. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
  1022. ggml_cgraph gf = ggml_build_forward(lprobs);
  1023. ggml_graph_compute_with_ctx(ctx, &gf, 1);
  1024. // Fetch scores of next steps from "lprobs"
  1025. float p_score = 0;
  1026. for (int i = 1; i < prefix_seq_len; ++i) {
  1027. int p = ggml_get_i32_1d(job.prefix_seq, i);
  1028. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  1029. for (int b = 0; b < beam_size; ++b) {
  1030. // scores: (N, S)
  1031. // Note: First step (e.g. BOS)'s score is always 0.
  1032. ggml_set_f32_1d(scores, b * max_seq_len + i, p_score);
  1033. }
  1034. }
  1035. }
  1036. /// Finds the topk indices, and write the winning indices in "candidate_indices" array.
  1037. int topk(
  1038. ggml_tensor* lprobs, // (B, V)
  1039. std::int64_t k,
  1040. ggml_tensor* candidate_indices
  1041. ) {
  1042. // Take the best 2 x `beam_size` predictions. We'll choose the first
  1043. // `beam_size` of these which don't predict EOS to continue with.
  1044. // (N, 2 x B)
  1045. // `vocab_size` - 1 to never select PAD.
  1046. std::int64_t K = std::min(k, ggml_nelements(lprobs));
  1047. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  1048. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  1049. };
  1050. GGML_ASSERT(ggml_nelements(candidate_indices) >= k);
  1051. auto cand = (std::int32_t*)candidate_indices->data;
  1052. std::partial_sort(cand, cand + K, cand + ggml_nelements(lprobs), comp);
  1053. return K;
  1054. }
  1055. void _tweak_lprobs(const SequenceGeneratorJob& job, ggml_tensor* lprobs, int step_nr, int max_seq_len, std::size_t vocab_size) {
  1056. std::size_t beam_size = job.opts.beam_size;
  1057. std::size_t eos_idx = job.eos_idx;
  1058. // Do not allow EOS before reaching the minimum sequence length.
  1059. if (step_nr < job.opts.min_seq_len) {
  1060. // lprobs[:, :, self.eos_idx] = -INFINITY;
  1061. for (size_t i = 0; i < beam_size; ++i)
  1062. ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
  1063. }
  1064. // If we have reached the maximum length, force the last step to be EOS.
  1065. if (step_nr == max_seq_len - 2) {
  1066. // lprobs[:, :, : self.eos_idx] = -torch.inf
  1067. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  1068. for (size_t b = 0; b < beam_size; ++b) {
  1069. size_t t = 0;
  1070. for (t = 0; t < eos_idx; ++t)
  1071. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1072. for (t = eos_idx + 1; t < vocab_size; ++t)
  1073. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1074. }
  1075. }
  1076. // Never allow PAD.
  1077. std::size_t pad_idx = job.pad_idx;
  1078. for (size_t i = 0; i < beam_size; ++i)
  1079. ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);
  1080. // Apply UNK penalty.
  1081. if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
  1082. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  1083. auto lprobs_raw = ggml_get_data_f32(lprobs);
  1084. for (size_t i = 0; i < beam_size; ++i)
  1085. lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
  1086. }
  1087. }
  1088. /// Copies the sequence and scores of a given candidate beam.
  1089. void _finalize_hypothesis(
  1090. const SequenceGeneratorJob& job,
  1091. ggml_context* ctx,
  1092. int step_nr,
  1093. std::int32_t beam,
  1094. std::int32_t token,
  1095. float eos_score,
  1096. ggml_tensor* seqs, // (beam_size, seq_len)
  1097. ggml_tensor* scores, // (beam_size, seq_len)
  1098. Hypothesis* hypothesis
  1099. ) {
  1100. ggml_tensor* seq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  1101. hypothesis->seq = seq;
  1102. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  1103. hypothesis->step_scores = step_scores;
  1104. auto tok = (std::int32_t*)seq->data;
  1105. for (int i = 0; i < step_nr + 1; ++i) {
  1106. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  1107. }
  1108. tok[step_nr + 1] = token;
  1109. // Convert from cumulative to per-step scores.
  1110. auto sc = (float*)step_scores->data;
  1111. float last_score = eos_score;
  1112. for (int i = step_nr; i >= 0; --i) {
  1113. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  1114. sc[i + 1] = last_score - sc0;
  1115. last_score = sc0;
  1116. }
  1117. sc[0] = 0;
  1118. if (job.opts.normalize_scores)
  1119. // Skip first EOS since it is always 0 and skews normalization.
  1120. eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  1121. hypothesis->score = eos_score;
  1122. }
  1123. // Uses ggml_context to store any object.
  1124. #define GGML_CTX_ALLOC(ctx, Type, n) \
  1125. (Type*)(ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(Type) * n)->data);
  1126. ggml_context* ctx_from_buffer(std::vector<uint8_t>& buffer) {
  1127. return ggml_init({
  1128. /*.mem_size =*/ static_cast<int64_t>(buffer.capacity()),
  1129. /*.mem_buffer =*/ buffer.data(),
  1130. /*.no_alloc =*/ false,
  1131. });
  1132. }
  1133. ggml_allocr* new_arena_allocr(std::vector<uint8_t>& buffer) {
  1134. return ggml_allocr_new(buffer.data(), buffer.capacity(), 8);
  1135. }
  1136. /// Generates a translation for a single sequence
  1137. /// The results Hypothesis are written inside `result_ctx`.
  1138. extern "C" Hypothesis* generate_sequence(
  1139. fairseq2_model& model,
  1140. const SequenceGeneratorJob& job,
  1141. ggml_tensor* encoder_output,
  1142. ggml_tensor* encoder_padding_mask,
  1143. ggml_context* result_ctx
  1144. ) {
  1145. // Pre allocate memory buffers.
  1146. // * step_ctx: contains metadata for the model graph, as well as some explicit
  1147. // buffers for the lprobs tweaking.
  1148. // * prev_step_ctx: is an additional buffer because we need some results from previous steps,
  1149. // to compute next step. Notably self attention kv cache.
  1150. // * search_ctx contains tensors that should live for the full search,
  1151. // like encoder kv cache.
  1152. // * step_alloc contains buffer for the forward pass of the model.
  1153. // TODO: the size allocated should depend on the input length and vocab size
  1154. std::vector<uint8_t> local_bufs[5] = {
  1155. std::vector<uint8_t>(128 * 1024 * 1024), // step_ctx
  1156. std::vector<uint8_t>(128 * 1024 * 1024), // prev_step_ctx
  1157. std::vector<uint8_t>(256 * 1024 * 1024), // search_ctx
  1158. std::vector<uint8_t>(256 * 1024 * 1024), // step_alloc
  1159. };
  1160. ggml_allocr* step_alloc = new_arena_allocr(local_bufs[3]);
  1161. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  1162. size_t vocab_size = embed->ne[1];
  1163. std::size_t beam_size = job.opts.beam_size;
  1164. ggml_detach(encoder_output);
  1165. int source_seq_len = encoder_output->ne[1];
  1166. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  1167. ggml_context* search_ctx = ctx_from_buffer(local_bufs[2]);
  1168. ggml_context* original_ctx = model.ctx;
  1169. fairseq2_kv_cache_alloc(model, search_ctx, beam_size, max_seq_len);
  1170. // (S_enc, M) -> (B, S_enc, M)
  1171. model.ctx = search_ctx;
  1172. _fan_out_encoder_output(search_ctx, &encoder_output, &encoder_padding_mask, beam_size);
  1173. // Allocate results in the context provided by the caller.
  1174. ggml_set_no_alloc(result_ctx, false);
  1175. Hypothesis* finished_searches_begin = GGML_CTX_ALLOC(result_ctx, Hypothesis, beam_size);
  1176. Hypothesis* finished_searches = finished_searches_begin;
  1177. for (std::size_t i = 0; i < beam_size; ++i) finished_searches[i] = {nullptr, -INFINITY, nullptr};
  1178. Hypothesis* finished_searches_end = finished_searches + beam_size;
  1179. // Initialize buffers. (B, S)
  1180. ggml_tensor* seqs = ggml_new_tensor_2d(search_ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  1181. ggml_set_i32(seqs, 0);
  1182. ggml_set_name(seqs, "seqs_0");
  1183. ggml_tensor* scores = ggml_new_tensor_2d(search_ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  1184. ggml_set_name(scores, "scores_0");
  1185. ggml_set_f32(scores, 0.0);
  1186. int prefix_seq_len = job.prefix_seq->ne[0];
  1187. int start_step = prefix_seq_len - 1;
  1188. ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[(start_step - 1) % 2]);
  1189. ggml_context* step_ctx = ctx_from_buffer(local_bufs[start_step % 2]);
  1190. GGML_ASSERT(step_ctx != search_ctx);
  1191. GGML_ASSERT(prev_step_ctx != step_ctx);
  1192. model.ctx = prev_step_ctx;
  1193. // search_ctx because we need encoder_decoder_attn.k_cache to survive for the full search
  1194. model.kv_cache_ctx = search_ctx;
  1195. _bootstrap_seqs_and_scores(
  1196. model, job, seqs, scores, encoder_output, encoder_padding_mask
  1197. );
  1198. // Holds the indices of beams (a beam can occur more than once) that we
  1199. // should continue with in the next step.
  1200. ggml_tensor* beam_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1201. ggml_tensor* next_tokens = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1202. ggml_tensor* next_scores = ggml_new_tensor_1d(search_ctx, GGML_TYPE_F32, beam_size);
  1203. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  1204. ggml_tensor* candidate_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, vocab_size * beam_size);
  1205. for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
  1206. ((int32_t *)(candidate_indices->data))[i] = i;
  1207. printf_mem_usage(search_ctx, "search_ctx");
  1208. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  1209. model.ctx = step_ctx;
  1210. ggml_set_no_alloc(step_ctx, true); // Use allocr for the model forward pass
  1211. ggml_tensor* prev_token = ggml_slice(step_ctx, seqs, 0, step_nr, step_nr + 1);
  1212. ggml_tensor* decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", prev_token);
  1213. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1214. model,
  1215. "text_decoder",
  1216. decoder_input,
  1217. nullptr, // We never generate PAD.
  1218. encoder_output,
  1219. encoder_padding_mask
  1220. ); // (B, 1, D)
  1221. decoder_output = ggml_flatten_1d(step_ctx, decoder_output, 0); // (B, model_dim)
  1222. // Force logits to be allocated in step_ctx, not in step_alloc.
  1223. ggml_set_no_alloc(step_ctx, false);
  1224. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output); // (B, vocab_size)
  1225. ggml_tensor* lprobs = ggml_log_softmax(step_ctx, logits);
  1226. // Compute lprobs here so we can modify it in place in the lprob tweaking phase
  1227. // TODO: use ggml properly compute the tweaks
  1228. ggml_cgraph gf = ggml_build_forward(lprobs);
  1229. size_t fwd_mem = ggml_allocr_alloc_graph(step_alloc, &gf);
  1230. GGML_UNUSED(fwd_mem);
  1231. ggml_graph_compute_with_ctx(step_ctx, &gf, 1);
  1232. ggml_detach(lprobs);
  1233. ggml_allocr_reset(step_alloc);
  1234. #if DEBUG_MEM_USAGE
  1235. printf("beam search step %d. Graph.n_nodes: %d.\n", step_nr, gf.n_nodes);
  1236. printf(" Fwd mem: %.1fMB\n", fwd_mem/1024.0/1024.0);
  1237. std::fill(local_bufs[3].begin(), local_bufs[3].end(), 0xAA);
  1238. #endif
  1239. _tweak_lprobs(job, lprobs, step_nr, max_seq_len, vocab_size);
  1240. ggml_tensor* last_scores = ggml_slice(step_ctx, scores, 0, step_nr, step_nr+1);
  1241. if (step_nr == start_step) {
  1242. // At the initial step, all hypotheses are equally likely, so we use
  1243. // only the first beam.
  1244. lprobs = ggml_slice(step_ctx, lprobs, 1, 0, 1);
  1245. lprobs = ggml_cont(step_ctx, lprobs);
  1246. // The first step always indicates the beginning of the sequence and has no score.
  1247. if (step_nr > 0) {
  1248. last_scores = ggml_slice(step_ctx, last_scores, 1, 0, 1);
  1249. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1250. }
  1251. } else {
  1252. // Make probabilities contain cumulative scores for each hypothesis.
  1253. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1254. }
  1255. gf = ggml_build_forward(lprobs);
  1256. ggml_graph_compute_with_ctx(step_ctx, &gf, 1);
  1257. // Determine (beam, token) candidates for the next step.
  1258. // (N, 2 x B)
  1259. std::int64_t K = topk(
  1260. lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
  1261. );
  1262. std::size_t ongoing_beams = 0;
  1263. for (std::int32_t i = 0; i < K; ++i) {
  1264. int c = ggml_get_f32_1d(candidate_indices, i);
  1265. std::int32_t beam = c / vocab_size;
  1266. std::int32_t token = c % vocab_size;
  1267. float tok_score = ggml_get_f32_1d(lprobs, c);
  1268. // Detect beams that reached the minimum length and that end with an EOS.
  1269. bool eos = token == job.eos_idx;
  1270. eos &= tok_score != -INFINITY;
  1271. if (eos) {
  1272. _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, finished_searches++);
  1273. if (finished_searches == finished_searches_end)
  1274. goto end_of_beam_search;
  1275. continue;
  1276. }
  1277. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  1278. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  1279. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  1280. ongoing_beams += 1;
  1281. if (ongoing_beams >= beam_size) break;
  1282. }
  1283. // Reorder beams in the `seq` and `score` buffers. The same beam can
  1284. // be selected more than once.
  1285. // (B, S), (B) -> (B, S)
  1286. ggml_tensor* new_seqs = ggml_get_rows(step_ctx, seqs, beam_indices);
  1287. ggml_tensor* new_scores = ggml_get_rows(step_ctx, scores, beam_indices);
  1288. ggml_cgraph gf_reorder = ggml_build_forward(new_seqs);
  1289. ggml_build_forward_expand(&gf_reorder, new_scores);
  1290. reorder_kv_cache(model, step_ctx, &gf_reorder, beam_indices);
  1291. ggml_graph_compute_with_ctx(step_ctx, &gf_reorder, 1);
  1292. seqs = ggml_detach(new_seqs);
  1293. scores = ggml_detach(new_scores);
  1294. // seqs[:, step_nr + 1] = next_tokens
  1295. // scores[:, step_nr + 1] = next_scores
  1296. for (std::size_t i = 0; i < beam_size; ++i) {
  1297. ((std::int32_t*)seqs->data)[step_nr + 1 + i * max_seq_len] = ggml_get_i32_1d(next_tokens, i);
  1298. ((float*)scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
  1299. }
  1300. printf_mem_usage(step_ctx, "step_ctx");
  1301. ggml_free(prev_step_ctx);
  1302. prev_step_ctx = step_ctx;
  1303. #if DEBUG_MEM_USAGE
  1304. std::fill(local_bufs[(step_nr + 1) % 2].begin(), local_bufs[(step_nr + 1) % 2].end(), 0xAA);
  1305. #endif
  1306. step_ctx = ctx_from_buffer(local_bufs[(step_nr + 1) % 2]);
  1307. }
  1308. end_of_beam_search:
  1309. // Ensure that hypotheses are sorted by decreasing scores before returning.
  1310. std::sort(
  1311. finished_searches_begin,
  1312. finished_searches_end,
  1313. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  1314. );
  1315. fairseq2_kv_cache_reset(model);
  1316. model.ctx = original_ctx;
  1317. return finished_searches_begin;
  1318. }
  1319. extern "C" Hypothesis* _testing_return_hypothesis_ptr(ggml_context* ctx) {
  1320. Hypothesis* result = GGML_CTX_ALLOC(ctx, struct Hypothesis, 2);
  1321. result[0] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 3.14f, (ggml_tensor*)result};
  1322. ggml_set_i32_1d(result[0].seq, 0, 314);
  1323. result[1] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 4.21f, nullptr};
  1324. ggml_set_i32_1d(result[1].seq, 0, 421);
  1325. return result;
  1326. }
  1327. // SPM tokenizer
  1328. // original implementation:
  1329. // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
  1330. struct llm_symbol {
  1331. using index = int;
  1332. index prev;
  1333. index next;
  1334. const char * text;
  1335. size_t n;
  1336. llama_vocab::id id;
  1337. };
  1338. static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
  1339. static size_t utf8_len(char src) {
  1340. const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
  1341. uint8_t highbits = static_cast<uint8_t>(src) >> 4;
  1342. return lookup[highbits];
  1343. }
  1344. struct llm_bigram_spm {
  1345. struct comparator {
  1346. bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
  1347. return (l.score < r.score) || (l.score == r.score && l.left > r.left);
  1348. }
  1349. };
  1350. using queue_storage = std::vector<llm_bigram_spm>;
  1351. using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
  1352. llm_symbol::index left;
  1353. llm_symbol::index right;
  1354. float score;
  1355. size_t size;
  1356. llama_vocab::id id;
  1357. };
  1358. struct llm_tokenizer_spm {
  1359. llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
  1360. void tokenize(const std::string& input_text, ggml_tensor& output) {
  1361. llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");
  1362. // split string into utf8 chars
  1363. int index = 0;
  1364. size_t offs = 0;
  1365. // This is kind of annoying, but needed because with SPM,
  1366. // characters following a space have a special meaning.
  1367. // And the algorithm rely on substrings to do the lookups.
  1368. std::string text = input_text;
  1369. bool need_extra_space = text.size() > 0 && text[0] != ' ';
  1370. if (need_extra_space) text = " " + text;
  1371. while (offs < text.size()) {
  1372. size_t len = utf8_len(text[offs]);
  1373. size_t n = std::min(len, text.size() - offs);
  1374. auto token = vocab.token_to_id.find(std::string(text, offs, n));
  1375. llama_vocab::id id = token == vocab.token_to_id.end() ? unk_idx : token->second;
  1376. llm_symbol sym = {
  1377. /*prev*/ index - 1,
  1378. /*next*/ offs + n == text.size() ? -1 : index + 1,
  1379. /*text*/ text.c_str() + offs,
  1380. /*n*/ n,
  1381. /*id*/ id
  1382. };
  1383. offs += n;
  1384. index++;
  1385. symbols.emplace_back(sym);
  1386. }
  1387. // seed the work queue with all possible 2-character tokens.
  1388. for (size_t i = 1; i < symbols.size(); ++i) {
  1389. try_add_bigram(i - 1, i);
  1390. }
  1391. // keep substituting the highest frequency pairs for as long as we can.
  1392. while (!work_queue.empty()) {
  1393. auto bigram = work_queue.top();
  1394. work_queue.pop();
  1395. auto & left_sym = symbols[bigram.left];
  1396. auto & right_sym = symbols[bigram.right];
  1397. const std::string text = std::string(left_sym.text, left_sym.n + right_sym.n);
  1398. // if one of the symbols already got merged, skip it.
  1399. if (
  1400. left_sym.n == 0
  1401. || right_sym.n == 0
  1402. || left_sym.n + right_sym.n != bigram.size
  1403. ) continue;
  1404. // merge the right sym into the left one
  1405. left_sym.n += right_sym.n;
  1406. left_sym.id = bigram.id;
  1407. right_sym.n = 0;
  1408. // remove the right sym from the chain
  1409. left_sym.next = right_sym.next;
  1410. if (right_sym.next >= 0) {
  1411. symbols[right_sym.next].prev = bigram.left;
  1412. }
  1413. // find more substitutions
  1414. try_add_bigram(left_sym.prev, bigram.left);
  1415. try_add_bigram(bigram.left, left_sym.next);
  1416. }
  1417. llama_vocab::id* out = (llama_vocab::id*)output.data;
  1418. int out_step = sizeof(llama_vocab::id) / output.nb[0];
  1419. int num_tokens = 0;
  1420. for (int i = 0; i > -1; i = symbols[i].next) {
  1421. llm_symbol& symbol = symbols[i];
  1422. *(out + num_tokens * out_step) = symbol.id;
  1423. num_tokens += 1;
  1424. }
  1425. *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
  1426. num_tokens += 1;
  1427. output.ne[0] = num_tokens;
  1428. }
  1429. private:
  1430. void try_add_bigram(int left, int right) {
  1431. if (left == -1 || right == -1) {
  1432. return;
  1433. }
  1434. const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
  1435. auto token = vocab.token_to_id.find(text);
  1436. if (token == vocab.token_to_id.end()) {
  1437. return;
  1438. }
  1439. llama_vocab::id id = token->second;
  1440. if (static_cast<size_t>(id) >= vocab.id_to_token.size()) {
  1441. return;
  1442. }
  1443. const auto& tok_data = vocab.id_to_token[id];
  1444. llm_bigram_spm bigram = {
  1445. /*left */ left,
  1446. /*right*/ right,
  1447. /*score*/ tok_data.score,
  1448. /*size */ text.size(),
  1449. /*id */ id
  1450. };
  1451. work_queue.push(bigram);
  1452. }
  1453. const llama_vocab& vocab;
  1454. std::vector<llm_symbol> symbols;
  1455. llm_bigram_spm::queue work_queue;
  1456. };
  1457. extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out) {
  1458. llm_tokenizer_spm spm = {model->vocab};
  1459. spm.tokenize(std::string(text), out);
  1460. }
  1461. extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out) {
  1462. int eos_idx = model->vocab.token_to_id["</s>"];
  1463. int sent_len = tokens->ne[0];
  1464. std::size_t written = 0;
  1465. for (int i = 0; i < sent_len; ++i) {
  1466. int id = ggml_get_i32_1d(tokens, i);
  1467. // Don't print the EOS token but only if it appear at the end.
  1468. if (i == sent_len - 1 && eos_idx == id) break;
  1469. std::string token = model->vocab.id_to_token.at(id).text;
  1470. // Skip the first space outputted.
  1471. auto begin = token.begin();
  1472. if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
  1473. std::copy(begin, token.end(), out);
  1474. std::size_t n = token.end() - begin;
  1475. written += n;
  1476. out += n;
  1477. }
  1478. *out = '0';
  1479. return written;
  1480. }