fairseq2.cpp 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850
  1. #include <algorithm>
  2. #include <fnmatch.h>
  3. #include <iostream>
  4. #include <math.h>
  5. #include <queue>
  6. #include <unordered_map>
  7. #include "kaldi-native-fbank/csrc/feature-fbank.h"
  8. #include "kaldi-native-fbank/csrc/feature-window.h"
  9. #include "fairseq2.h"
  10. #include "ggml.h"
  11. #include "ggml-alloc.h"
  12. #include <numeric>
  13. ggml_tensor* ggml_detach(ggml_tensor* a) {
  14. a->op = GGML_OP_NONE;
  15. std::fill(a->src, a->src + GGML_MAX_SRC, nullptr);
  16. return a;
  17. }
  18. // generate_sequence uses ggml_context and ggml_allocr to reuse memory buffers across steps.
  19. // This can lead to dangling pointers, which don't segfault, but instead read garbage data.
  20. // Enabling this flag allows to explictly reset memory buffers, making it more explicit
  21. // when we read garbage data.
  22. // It also prints memory usage information, which is useful to
  23. #define DEBUG_MEM_USAGE DEBUG
  24. size_t MB = 1024 * 1024;
  25. void printf_mem_usage(ggml_context* ctx, std::string name) {
  26. #if DEBUG_MEM_USAGE
  27. double mb = 1024.0 * 1024.0;
  28. printf(
  29. "%s: memory used = %8.2f MB, memory reserved = %8.2f Mb\n",
  30. name.c_str(),
  31. ggml_used_mem(ctx) / mb,
  32. ggml_get_mem_size(ctx) / mb
  33. );
  34. #endif
  35. }
  36. #define SWAP(x, y) \
  37. auto tmp_ ## x = x; x = y; y = tmp_ ## x;
  38. #define GGML_ASSERT_SHAPE(x, ne0, ne1, ne2, ne3) \
  39. GGML_ASSERT((ne0 == -1 || x->ne[0] == ne0) && (ne1 == -1 || x->ne[1] == ne1) && (ne2 == -1 || x->ne[2] == ne2) && (ne3 == -1 || x->ne[3] == ne3));
  40. /// allocate the fairseq2 model and hyperparameters
  41. extern "C" fairseq2_model* fairseq2_model_alloc() {
  42. // pre-allocate some memory to write hyperparameters and tensors pointers
  43. auto* model = new fairseq2_model;
  44. model->tensors_ctx = nullptr;
  45. return model;
  46. }
  47. extern "C" void fairseq2_kv_cache_alloc(fairseq2_model& model, ggml_context* kv_cache_ctx, int beam_size, int max_seq_len) {
  48. // Note: we only allocate the masks, proper kv cache allocation is delayed.
  49. GGML_ASSERT(kv_cache_ctx);
  50. GGML_ASSERT(!ggml_get_no_alloc(kv_cache_ctx)); // We need to be able to alloc the kv_cache buffers
  51. model.kv_cache_ctx = kv_cache_ctx;
  52. auto attn_glob = "text_decoder.*_attn.k_proj.weight";
  53. FORCE_ALLOC(self_attn_mask, kv_cache_ctx, ggml_new_tensor_2d(kv_cache_ctx, GGML_TYPE_F32, max_seq_len, max_seq_len));
  54. self_attn_mask = ggml_diag_mask_inf_inplace(kv_cache_ctx, self_attn_mask, 0);
  55. ggml_format_name(self_attn_mask, "self_attn_mask[%d]", max_seq_len);
  56. for (auto named_tensor : model.tensors) {
  57. const std::string& name = named_tensor.first;
  58. if (::fnmatch(attn_glob, name.c_str(), 0) == FNM_NOMATCH)
  59. continue;
  60. // create a cache entry without the ".k_proj.weight" suffix
  61. const std::string& shortname = name.substr(0, name.size() - 14);
  62. KeyValueTensor& kv = model.kv_cache[shortname];
  63. kv.step_nr = 0;
  64. kv.full_k = nullptr;
  65. kv.full_v = nullptr;
  66. kv.self_attn_mask = self_attn_mask;
  67. }
  68. }
  69. extern "C" void fairseq2_kv_cache_reset(const fairseq2_model& model) {
  70. // TODO: use a dedicated allocator, so that kv_cache.clear actually frees the memory
  71. model.kv_cache.clear();
  72. }
  73. bool has_kv_cache(const fairseq2_model& model) {
  74. return model.kv_cache.size() > 0;
  75. }
  76. inline ggml_tensor* ggml_squeeze(ggml_context* ctx, ggml_tensor* x, int dim) {
  77. int n_dims = x->n_dims;
  78. GGML_ASSERT(dim >= 0);
  79. GGML_ASSERT(dim < n_dims);
  80. GGML_ASSERT(x->ne[dim] == 1);
  81. return ggml_flatten_1d(ctx, x, dim);
  82. }
  83. inline ggml_tensor* ggml_unsqueeze(ggml_context* ctx, ggml_tensor* x, int dim) {
  84. return ggml_unflatten_1d(ctx, x, dim, 1);
  85. }
  86. // copy k and v to kv cache
  87. // kv.full_k[step_nr] = k;
  88. // kv.full_v[step_nr] = v;
  89. void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
  90. KeyValueTensor& kv = model.kv_cache[prefix];
  91. int step_nr = kv.step_nr;
  92. ggml_context* ctx = model.kv_cache_ctx ? model.kv_cache_ctx : model.ctx;
  93. // We need to force allocation here, otherwise the kv_cache buffers can be reused
  94. bool no_alloc_save = ggml_get_no_alloc(ctx);
  95. ggml_set_no_alloc(ctx, false);
  96. int n_steps = (*k)->ne[1];
  97. int k_proj, batch_size;
  98. if (kv.full_k != nullptr) {
  99. // (N, S_kv, K_proj)
  100. k_proj = kv.full_k->ne[0];
  101. batch_size = kv.full_k->ne[2];
  102. ggml_detach(kv.full_k);
  103. ggml_detach(kv.full_v);
  104. kv.full_k = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_k, 1), ggml_unsqueeze(ctx, *k, 1)), 1);
  105. kv.full_v = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_v, 1), ggml_unsqueeze(ctx, *v, 1)), 1);
  106. } else {
  107. GGML_ASSERT(step_nr == 0);
  108. k_proj = (*k)->ne[0];
  109. batch_size = (*v)->ne[2];
  110. kv.full_k = ggml_dup(ctx, *k);
  111. kv.full_v = ggml_dup(ctx, *v);
  112. }
  113. *k = kv.full_k;
  114. *v = kv.full_v;
  115. ggml_format_name(kv.full_k, "%s.k (step=%d)", prefix.c_str(), step_nr);
  116. ggml_format_name(kv.full_v, "%s.v (step=%d)", prefix.c_str(), step_nr);
  117. step_nr += n_steps;
  118. GGML_ASSERT_SHAPE(kv.full_k, k_proj, step_nr, batch_size, 1);
  119. // qk is (B * H, Sq, Sk) == (B*H, 1, Sk) in incremental mode
  120. // we return the Sq slice of the (Sq, Sk) attention mask
  121. if (self_attn_mask != nullptr) {
  122. *self_attn_mask = ggml_slice(
  123. ctx, ggml_slice(ctx, kv.self_attn_mask, 0, 0, step_nr),
  124. 1, step_nr - 1, step_nr
  125. );
  126. }
  127. kv.step_nr = step_nr;
  128. ggml_set_no_alloc(ctx, no_alloc_save);
  129. }
  130. // variant of ggml_get_rows that allows for a with more than 2 dims.
  131. ggml_tensor* ggml_get_rows2(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  132. int flattened = 0;
  133. GGML_ASSERT(a->n_dims <= 3);
  134. if (a->n_dims == 3) {
  135. flattened = a->ne[0];
  136. a = ggml_flatten_1d(ctx, a, 0);
  137. }
  138. a = ggml_get_rows(ctx, a, b);
  139. if (flattened) {
  140. a = ggml_unflatten_1d(ctx, a, 0, flattened);
  141. }
  142. return a;
  143. }
  144. void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, ggml_tensor* new_order) {
  145. // GGML_ASSERT(ctx == kv.full_k->con);
  146. if (kv.full_k != nullptr) {
  147. ggml_detach(kv.full_k);
  148. const char* name = kv.full_k->name;
  149. kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
  150. ggml_build_forward_expand(gf, kv.full_k);
  151. ggml_format_name(kv.full_k, "%s (sorted)", name);
  152. }
  153. if (kv.full_v != nullptr) {
  154. ggml_detach(kv.full_v);
  155. const char* name = kv.full_v->name;
  156. kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
  157. ggml_build_forward_expand(gf, kv.full_v);
  158. ggml_format_name(kv.full_v, "%s (sorted)", name);
  159. }
  160. }
  161. void reorder_kv_cache(const fairseq2_model& model, ggml_context* ctx, ggml_cgraph* gf, ggml_tensor* new_order) {
  162. auto self_attn_glob = "*.self_attn";
  163. for (auto& named_kv : model.kv_cache) {
  164. if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH)
  165. continue;
  166. _reorder_kv_cache(ctx, gf, named_kv.second, new_order);
  167. }
  168. }
  169. inline double model_layer_config_d(const fairseq2_model& model, std::string name) {
  170. const std::int64_t* data = &model.layer_config.at(name);
  171. double val = *(const double*)data;
  172. return val;
  173. }
  174. extern "C" double fairseq2_model_layer_config_double(const fairseq2_model& model, const char* name) {
  175. return model_layer_config_d(model, std::string(name));
  176. }
  177. extern "C" std::int64_t fairseq2_model_layer_config_int(const fairseq2_model& model, const char* name) {
  178. return model.layer_config.at(std::string(name));
  179. }
  180. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  181. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  182. // delete model;
  183. }
  184. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  185. model->ctx = ctx;
  186. }
  187. extern "C" std::string* std_string_alloc(char* c_str) {
  188. return new std::string(c_str);
  189. }
  190. extern "C" void std_string_free(std::string* str) {
  191. delete str;
  192. }
  193. bool has_layer(fairseq2_model& model, const std::string& name) {
  194. return model.tensors.find(name) != model.tensors.end();
  195. }
  196. ggml_tensor* mul_mat(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  197. if (b->ne[1] == 1 && b->ne[2] > 1 && a->n_dims == 2) {
  198. // `b` has shape (B, 1, D).
  199. // if `a` is (D_out, D), then we do one matmul for the full batch.
  200. b = ggml_flatten_1d(ctx, b, 1);
  201. return ggml_unflatten_1d(ctx, ggml_mul_mat(ctx, a, b), 1, 1);
  202. }
  203. // there is also the k * q matmul -> (D, 1, B) * (D, 1, B) -> (1, 1, B)
  204. // not sure what's the best way to compute this with BLAS
  205. return ggml_mul_mat(ctx, a, b); // (d_out)
  206. }
  207. extern "C" ggml_tensor* Linear_forward(
  208. fairseq2_model& model,
  209. const std::string &prefix,
  210. ggml_tensor* input // (d_in)
  211. ) {
  212. // Note: for now we assumed un-batched input
  213. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  214. GGML_ASSERT(weight != nullptr);
  215. ggml_tensor* out = mul_mat(model.ctx, weight, input); // (d_out)
  216. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  217. if (bias == nullptr) return out;
  218. return ggml_add(model.ctx, out, bias);
  219. }
  220. extern "C" ggml_tensor* LayerNorm_forward(
  221. fairseq2_model& model,
  222. const std::string &prefix,
  223. ggml_tensor* input
  224. ) {
  225. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  226. GGML_ASSERT(weight != nullptr);
  227. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  228. GGML_ASSERT(bias != nullptr);
  229. auto ctx = model.ctx;
  230. double eps = model_layer_config_d(model, prefix + ".eps");
  231. input = ggml_norm(ctx, input, /*eps*/eps);
  232. return ggml_add_inplace(
  233. ctx,
  234. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  235. ggml_repeat(ctx, bias, input)
  236. );
  237. }
  238. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  239. fairseq2_model& model,
  240. const std::string& prefix,
  241. ggml_tensor* seqs
  242. ) {
  243. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  244. // inner_activation = ReLu // TODO: allow other activation
  245. seqs = ggml_relu_inplace(model.ctx, seqs);
  246. if (has_layer(model, prefix + ".inner_layer_norm")) {
  247. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  248. }
  249. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  250. return seqs;
  251. }
  252. extern "C" ggml_tensor* SiluFeedForwardNetwork_forward(
  253. fairseq2_model& model,
  254. const std::string& prefix,
  255. ggml_tensor* seqs
  256. ) {
  257. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  258. seqs = ggml_silu(model.ctx, seqs);
  259. if (has_layer(model, prefix + ".inner_layer_norm")) {
  260. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  261. }
  262. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  263. return seqs;
  264. }
  265. ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
  266. int n_dims = x->n_dims;
  267. GGML_ASSERT(dim >= 0);
  268. GGML_ASSERT(dim < n_dims);
  269. GGML_ASSERT(ggml_is_contiguous(x));
  270. // Nothing to do
  271. if (dim == n_dims - 1) return x;
  272. if (n_dims == 2) {
  273. return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
  274. } else if (n_dims == 3) {
  275. if (dim == 0) {
  276. return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
  277. } else { // dim == 1
  278. return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
  279. }
  280. } else { // n_dims == 4
  281. if (dim == 0) {
  282. return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
  283. } else if (dim == 1) {
  284. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
  285. } else { // dim == 2
  286. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
  287. }
  288. }
  289. }
/// Split dimension `dim` of `x` (ggml ne order) into two dimensions
/// (num_el, ne[dim] / num_el), with the size-num_el dimension at `dim`.
/// Returns a strided view sharing `x`'s data (no copy).
/// Requirements: `x` has fewer than 4 dims, ne[dim] is divisible by num_el,
/// and `x` is contiguous between `dim` and `dim + 1`.
ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
    int n_dims = x->n_dims;
    GGML_ASSERT(dim >= 0);
    GGML_ASSERT(dim < n_dims);
    GGML_ASSERT(n_dims < 4);
    GGML_ASSERT(x->ne[dim] % num_el == 0);
    GGML_ASSERT(x->nb[dim + 1] == x->nb[dim] * x->ne[dim]); // `x` isn't contiguous along `dim`
    // In every branch the new dim `dim` keeps stride nb[dim], the following
    // dim (extent ne[dim] / num_el) gets stride num_el * nb[dim], and all
    // other strides are copied from `x`.
    if (n_dims == 1) {
        return ggml_view_2d(ctx, x, num_el, x->ne[0] / num_el, x->nb[0] * num_el, 0);
    } else if (n_dims == 2) {
        if (dim == 0) {
            return ggml_view_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->nb[0] * num_el, x->nb[1], 0);
        } else { // dim == 1
            return ggml_view_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->nb[1], num_el * x->nb[1], 0);
        }
    } else { // (n_dims == 3)
        if (dim == 0) {
            return ggml_view_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2], x->nb[0] * num_el, x->nb[1], x->nb[2], 0);
        } else if (dim == 1) {
            return ggml_view_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2], x->nb[1], num_el * x->nb[1], x->nb[2], 0);
        } else { // dim == 2
            return ggml_view_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el, x->nb[1], x->nb[2], num_el * x->nb[2], 0);
        }
    }
}
  315. ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
  316. // (B, S, dim) -> (B, S, H, H_dim)
  317. x = ggml_unflatten_1d(ctx, x, 0, head_dim);
  318. x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
  319. x = ggml_cont(ctx, x);
  320. x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
  321. return x;
  322. }
  323. /// (B, Sk, dim) -> // (B?, H, H_dim, Sk)
  324. ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim ) {
  325. // (B, Sk, dim) -> (B, Sk, H, H_dim)
  326. v = ggml_unflatten_1d(ctx, v, 0, head_dim);
  327. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
  328. v = ggml_cont(ctx, v);
  329. v = ggml_flatten_1d(ctx, v, 2); // (B * H, S, H_dim)
  330. return v;
  331. }
  332. // flash_attn doesn't work for cross attention because it assumes Q <= K
  333. // and it seems to yield slightly different scores than expected, and thus a different beam search
  334. # define UNITY_FLASH_ATTN 0
  335. extern "C" ggml_tensor* MultiheadAttention_forward(
  336. fairseq2_model& model,
  337. const std::string &prefix,
  338. ggml_tensor* queries, // (slen, d_in)
  339. ggml_tensor* keys, // (klen, d_in)
  340. ggml_tensor* values, // (klen, d_out)
  341. ggml_tensor* attn_mask // (klen, slen)
  342. ) {
  343. int model_dim = queries->ne[0];
  344. int num_heads = model.layer_config.at(prefix + ".num_heads");
  345. int head_dim = model_dim / num_heads;
  346. GGML_ASSERT(model_dim % num_heads == 0);
  347. ggml_context* ctx = model.ctx;
  348. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
  349. q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
  350. ggml_set_name(q, "q");
  351. ggml_tensor *k, *v;
  352. if (!has_kv_cache(model)) {
  353. k = Linear_forward(model, prefix + ".k_proj", keys);
  354. ggml_set_name(k, "k");
  355. v = Linear_forward(model, prefix + ".v_proj", values);
  356. ggml_set_name(v, "v");
  357. } else {
  358. bool encoder_decoder_attn = keys == values && keys != queries;
  359. if (encoder_decoder_attn) {
  360. // The K and V tensors of an encoder-decoder attention (i.e. the
  361. // projected encoder outputs) remain static during evaluation.
  362. KeyValueTensor& kv_cache = model.kv_cache[prefix];
  363. if (kv_cache.step_nr == 0) {
  364. // If possible we use the ctx dedicated to kv_cache here,
  365. // because the enc dec attention is typically long lived.
  366. if (model.kv_cache_ctx) model.ctx = model.kv_cache_ctx;
  367. k = Linear_forward(model, prefix + ".k_proj", keys);
  368. ggml_set_name(k, "k");
  369. v = Linear_forward(model, prefix + ".v_proj", values);
  370. ggml_set_name(v, "v");
  371. // Note we are only storing a pointer to the buffer, not the full graph
  372. kv_cache.full_k = ggml_detach(ggml_dup_inplace(model.ctx, k));
  373. ggml_format_name(kv_cache.full_k, "%s.k_cache", prefix.c_str());
  374. kv_cache.full_v = ggml_detach(ggml_dup_inplace(model.ctx, v));
  375. ggml_format_name(kv_cache.full_v, "%s.v_cache", prefix.c_str());
  376. kv_cache.step_nr = keys->ne[1];
  377. model.ctx = ctx;
  378. } else {
  379. k = kv_cache.full_k;
  380. v = kv_cache.full_v;
  381. GGML_ASSERT(keys->ne[1] == k->ne[1]); // cache content doesn't match the input sequence
  382. GGML_ASSERT(values->ne[1] == v->ne[1]); // cache content doesn't match the input sequence
  383. }
  384. } else { // self attention
  385. // (1, K) -> (N, 1, K_proj)
  386. k = Linear_forward(model, prefix + ".k_proj", keys);
  387. ggml_set_name(k, "k");
  388. // (1, V) -> (N, 1, V_proj)
  389. v = Linear_forward(model, prefix + ".v_proj", values);
  390. ggml_set_name(v, "v");
  391. append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
  392. }
  393. }
  394. k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
  395. v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
  396. v = ggml_cont(ctx, v);
  397. #if UNITY_FLASH_ATTN
  398. // For flash_attn, we assume either no masks, or triangular masks.
  399. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr); // (B * H, S, H_dim)
  400. ggml_set_name(attn, "attn");
  401. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  402. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  403. #else
  404. // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
  405. ggml_tensor* qk = mul_mat(ctx, k, q);
  406. ggml_set_name(qk, "qk");
  407. FORCE_ALLOC(qk_scale, ctx, ggml_new_tensor_1d(ctx, qk->type, 1));
  408. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  409. qk = ggml_scale(ctx, qk, qk_scale);
  410. ggml_set_name(qk, "qk_scaled");
  411. if (attn_mask) qk = ggml_add_inplace(ctx, qk, attn_mask);
  412. // TODO: upgrade qk to float32 if needed
  413. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
  414. ggml_set_name(attn_weights, "attn_weights");
  415. // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
  416. ggml_tensor* attn = mul_mat(ctx, attn_weights, v);
  417. ggml_set_name(attn, "attn");
  418. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  419. attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
  420. #endif // UNITY_FLASH_ATTN
  421. attn = ggml_cont(ctx, attn);
  422. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  423. // out -> (B, S, d_out)
  424. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  425. ggml_set_name(out, "out");
  426. return out;
  427. }
  428. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  429. fairseq2_model& model,
  430. const std::string& prefix,
  431. ggml_tensor* seqs,
  432. ggml_tensor* padding_mask
  433. ) {
  434. ggml_context* ctx = model.ctx;
  435. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  436. // _forward_self_attn(seqs, padding_mask)
  437. auto residual = seqs;
  438. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  439. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  440. // TODO: add padding_mask to MultiheadAttention_forward
  441. GGML_ASSERT(padding_mask == nullptr);
  442. seqs = MultiheadAttention_forward(
  443. model,
  444. prefix + ".self_attn",
  445. seqs,
  446. seqs,
  447. seqs,
  448. /*attn_mask=*/nullptr
  449. );
  450. if (has_layer(model, prefix + ".self_attn_norm"))
  451. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  452. seqs = ggml_add_inplace(ctx, seqs, residual);
  453. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  454. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  455. // _forward_ffn(seqs)
  456. residual = seqs;
  457. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  458. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  459. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  460. // TODO: if self.residual_scale is not None:
  461. // residual = self.residual_scale * residual
  462. seqs = ggml_add_inplace(ctx, seqs, residual);
  463. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  464. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  465. return seqs;
  466. }
  467. extern "C" ggml_tensor* WaveformToFbank_forward(
  468. fairseq2_model& model,
  469. const std::string &prefix,
  470. ggml_tensor* waveform
  471. ) {
  472. // Hardcoding: num_bins 80, sample rate 16k, always standardize
  473. ggml_context* ctx = model.ctx;
  474. knf::MelBanksOptions mel_opts{};
  475. mel_opts.num_bins = 80;
  476. knf::FrameExtractionOptions frame_opts{};
  477. frame_opts.samp_freq = 16000;
  478. knf::FbankOptions opts{};
  479. opts.frame_opts = frame_opts;
  480. opts.mel_opts = mel_opts;
  481. std::vector<float_t> signal_frame{};
  482. std::int32_t num_frames = knf::NumFrames(/*num_samples=*/waveform->ne[0], frame_opts);
  483. FORCE_ALLOC(output, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 80, num_frames));
  484. knf::FbankComputer native_(opts);
  485. knf::FeatureWindowFunction window_fn_(native_.GetFrameOptions());
  486. for (std::int32_t frame_nr = 0; frame_nr < num_frames; ++frame_nr) {
  487. signal_frame.resize(0);
  488. // Extract the frame from the waveform tensor.
  489. knf::ExtractWindow(
  490. /*sample_offset=*/0,
  491. (float *)(waveform->data),
  492. waveform->ne[0],
  493. frame_nr,
  494. frame_opts,
  495. window_fn_,
  496. &signal_frame);
  497. native_.Compute(
  498. /*signal_raw_log_energy=*/0, /*vtln_warp=*/1.0, &signal_frame, ((float *)(output->data) + frame_nr * 80));
  499. }
  500. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  501. output = ggml_norm(ctx, output, 1e-5);
  502. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  503. if (output->ne[1] % 2 == 1) {
  504. output = ggml_dup(ctx, ggml_slice(ctx, output, 1, 0, output->ne[1]-1));
  505. }
  506. output = ggml_reshape_2d(ctx, output, output->ne[0] * 2, output->ne[1] / 2);
  507. return output;
  508. }
  509. // TODO: Check if it's possible to merge with standard MHA
  510. extern "C" ggml_tensor* RelativePositionMHA_forward(
  511. fairseq2_model& model,
  512. const std::string& prefix,
  513. ggml_tensor* seqs
  514. ) {
  515. ggml_context* ctx = model.ctx;
  516. ggml_tensor* residual = seqs;
  517. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  518. // self_attn: qkv
  519. ggml_tensor* Qcur = Linear_forward(model, prefix + ".q_proj", seqs);
  520. ggml_tensor* Kcur = Linear_forward(model, prefix + ".k_proj", seqs);
  521. ggml_tensor* Vcur = Linear_forward(model, prefix + ".v_proj", seqs);
  522. // self_attn: rel_pos SDPA
  523. int32_t S = seqs->ne[1];
  524. int32_t H = 16; // TODO: Make this configurable
  525. int32_t n_ctx = 4096;
  526. int32_t K_h = seqs->ne[0] / H;
  527. int32_t start_index = n_ctx - S;
  528. int32_t end_index = n_ctx + S - 1;
  529. int num_indices = end_index - start_index;
  530. FORCE_ALLOC(rows, ctx, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices));
  531. for (int i = 0; i < num_indices; i++) {
  532. ((int32_t *)rows->data)[i] = start_index + i;
  533. }
  534. // self_attn: load pos_enc weights & compute_r
  535. // In fairseq2 pos_enc weights are calculated on the fly, since some more custom operators might be needed to enable this,
  536. // we store the results (fixed) in checkpoint as model.audio_enc_pos_enc_w and load directly.
  537. ggml_tensor* r = ggml_get_rows(ctx, model.tensors["speech_encoder.pos_enc"], rows);
  538. r = mul_mat(ctx, model.tensors[prefix + ".sdpa.r_proj.weight"], r);
  539. r = ggml_dup(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, r, 0, K_h), 0, 2, 1, 3));
  540. ggml_tensor* u_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.u_bias"], K_h, 1, H);
  541. ggml_tensor* v_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.v_bias"], K_h, 1, H);
  542. // self_attn: Permute QKV
  543. // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  544. ggml_tensor* Q = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Qcur, 0, K_h), 0, 2, 1, 3));
  545. // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  546. ggml_tensor* K = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Kcur, 0, K_h), 0, 2, 1, 3));
  547. // (H * K_h, S) -> (K_h, H, S) -> (H, S, K_h)
  548. ggml_tensor* V = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Vcur, 0, K_h), 1, 2, 0, 3));
  549. ggml_tensor* q_with_u_bias = ggml_add_inplace(ctx, ggml_dup(ctx, Q), u_bias); // (K_h, S, H)
  550. ggml_tensor* q_with_v_bias = ggml_add_inplace(ctx, Q, v_bias); // (K_h, S, H)
  551. ggml_tensor* ac = mul_mat(ctx, K, q_with_u_bias);
  552. ggml_tensor* bd = mul_mat(ctx, r, q_with_v_bias);
  553. // self_attn: shift_bd. Logic follows https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L161
  554. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // H, S, 2S-1
  555. FORCE_ALLOC(pad, ctx, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, S, 1));
  556. pad = ggml_set_f32(pad, 0.0);
  557. bd = ggml_concat(ctx, pad, bd); // bd[i][j][0] == 0, (H, S, 2S)
  558. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // (2S, S, H)
  559. bd = ggml_reshape_3d(ctx, bd, S, 2 * S, H); // (S, 2S, H)
  560. // discard the first set of positive positions
  561. bd = ggml_dup(ctx, ggml_slice(ctx, bd, 1, 1, 2 * S));
  562. // shifts each row by an extra step
  563. bd = ggml_reshape_3d(ctx, bd, 2 * S - 1, S, H);
  564. // Discard positions used for shift.
  565. bd = ggml_slice(ctx, bd, 0, 0, S);
  566. // self_attn: compute attn / weights
  567. ggml_tensor* attn_weights = ggml_add_inplace(ctx, ac, bd);
  568. FORCE_ALLOC(attn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  569. ggml_set_f32(attn_scale, 1.0 / pow(K_h, 0.5));
  570. attn_weights = ggml_mul_inplace(ctx, attn_weights, ggml_repeat(ctx, attn_scale, attn_weights));
  571. attn_weights = ggml_soft_max(ctx, attn_weights);
  572. ggml_tensor* attn = mul_mat(ctx, V, attn_weights); // K_h, S, H
  573. attn = ggml_dup(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));
  574. ggml_tensor* attn_2d = ggml_reshape_2d(ctx, attn, K_h * H, S);
  575. ggml_tensor* attn_out = mul_mat(ctx, model.tensors[prefix + ".output_proj.weight"], attn_2d);
  576. attn_out = ggml_add_inplace(
  577. ctx,
  578. attn_out,
  579. ggml_repeat(ctx, model.tensors[prefix + ".output_proj.bias"], attn_out)
  580. );
  581. attn_out = ggml_add_inplace(ctx, attn_out, residual);
  582. return attn_out;
  583. }
  584. extern "C" ggml_tensor* ConvModule_forward(
  585. fairseq2_model& model,
  586. const std::string& prefix,
  587. ggml_tensor* seqs
  588. ) {
  589. ggml_context* ctx = model.ctx;
  590. ggml_tensor* residual = seqs;
  591. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  592. // conv: Use matmul for pointwise conv 1 - kernel_size=1, no padding case
  593. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv1.weight"], seqs);
  594. // conv: GLU
  595. seqs = ggml_glu(ctx, seqs);
  596. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  597. // S x C -> (S+K-1) x C -> K x S x C -> S x C
  598. seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".depthwise_conv.weight"], seqs, 1, 15, 1);
  599. // conv: Custom implementation of batch norm
  600. seqs = ggml_batch_norm(ctx, seqs, model.tensors[prefix + ".batch_norm.weight"], model.tensors[prefix + ".batch_norm.bias"], model.tensors[prefix + ".batch_norm.running_mean"], model.tensors[prefix + ".batch_norm.running_var"], 1e-5);
  601. // conv: SiLU actvation
  602. seqs = ggml_silu_inplace(ctx, seqs);
  603. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  604. // conv: Use matmul for pointwise conv 2 - kernel_size=1, no padding case
  605. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv2.weight"], seqs);
  606. // conv: + residual
  607. seqs = ggml_add_inplace(ctx, seqs, residual);
  608. return seqs;
  609. }
  610. extern "C" ggml_tensor* StandardConformerEncoderLayer_forward(
  611. fairseq2_model& model,
  612. const std::string& prefix,
  613. ggml_tensor* seqs,
  614. ggml_tensor* padding_mask
  615. ) {
  616. ggml_context* ctx = model.ctx;
  617. FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  618. ggml_set_f32(ffn_scale, 0.5f);
  619. ggml_tensor* residual = seqs;
  620. seqs = LayerNorm_forward(model, prefix + ".ffn1_layer_norm", seqs);
  621. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn1", seqs);
  622. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  623. seqs = ggml_add_inplace(ctx, seqs, residual);
  624. seqs = RelativePositionMHA_forward(model, prefix + ".self_attn", seqs);
  625. seqs = ConvModule_forward(model, prefix + ".conv", seqs);
  626. residual = seqs;
  627. seqs = LayerNorm_forward(model, prefix + ".ffn2_layer_norm", seqs);
  628. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn2", seqs);
  629. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  630. seqs = ggml_add_inplace(ctx, seqs, residual);
  631. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  632. return seqs;
  633. }
  634. extern "C" ggml_tensor* StandardConformerEncoder_forward(
  635. fairseq2_model& model,
  636. const std::string& prefix,
  637. ggml_tensor* seqs,
  638. ggml_tensor* padding_mask
  639. ) {
  640. ggml_context* ctx = model.ctx;
  641. seqs = WaveformToFbank_forward(model, prefix, seqs);
  642. seqs = LayerNorm_forward(model, prefix + "_frontend.post_extract_layer_norm", seqs);
  643. seqs = Linear_forward(model, prefix + "_frontend.model_dim_proj", seqs);
  644. int layer_idx = 0;
  645. std::string layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  646. while (has_layer(model, layer_name)) {
  647. seqs = StandardConformerEncoderLayer_forward(
  648. model, layer_name, seqs, padding_mask
  649. );
  650. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  651. layer_idx += 1;
  652. layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  653. }
  654. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  655. ggml_tensor* residual = seqs;
  656. seqs = Linear_forward(model, prefix + ".proj1", seqs);
  657. seqs = ggml_relu_inplace(ctx, seqs);
  658. seqs = Linear_forward(model, prefix + ".proj2", seqs);
  659. FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  660. ggml_set_f32(ffn_scale, 0.5f);
  661. seqs = ggml_mul(ctx, ggml_repeat(ctx, ffn_scale, seqs), seqs);
  662. seqs = ggml_add_inplace(ctx, seqs, residual);
  663. layer_idx = 0;
  664. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  665. while (has_layer(model, layer_name)) {
  666. seqs = StandardConformerEncoderAdaptorLayer_forward(
  667. model, layer_name, seqs, padding_mask
  668. );
  669. ggml_set_name(seqs, ("x_ada_" + std::to_string(layer_idx)).c_str());
  670. layer_idx += 1;
  671. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  672. }
  673. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  674. return seqs;
  675. }
  676. extern "C" ggml_tensor* StandardConformerEncoderAdaptorLayer_forward(
  677. fairseq2_model& model,
  678. const std::string& prefix,
  679. ggml_tensor* seqs,
  680. ggml_tensor* padding_mask
  681. ) {
  682. ggml_context* ctx = model.ctx;
  683. ggml_tensor* residual = seqs;
  684. residual = LayerNorm_forward(model, prefix + ".residual_layer_norm", residual);
  685. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  686. residual = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".residual_conv.weight"], residual, 8, 4, 1);
  687. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  688. residual = ggml_add_inplace(ctx, ggml_repeat(ctx, model.tensors[prefix + ".residual_conv.bias"], residual), residual);
  689. residual = ggml_glu(ctx, residual);
  690. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  691. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  692. seqs = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".self_attn_conv.weight"], seqs, 8, 4, 1);
  693. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  694. seqs = ggml_add_inplace(ctx, seqs, ggml_repeat(ctx, model.tensors[prefix + ".self_attn_conv.bias"], seqs));
  695. seqs = ggml_glu(ctx, seqs);
  696. seqs = MultiheadAttention_forward(
  697. model,
  698. prefix + ".self_attn",
  699. seqs,
  700. seqs,
  701. seqs,
  702. /*attention masks=*/nullptr
  703. );
  704. seqs = ggml_add_inplace(ctx, seqs, residual);
  705. residual = seqs;
  706. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  707. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  708. seqs = ggml_add_inplace(ctx, seqs, residual);
  709. return seqs;
  710. }
  711. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  712. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  713. ggml_tensor* ggml_slice(
  714. struct ggml_context * ctx,
  715. struct ggml_tensor * a,
  716. int axis,
  717. int64_t start,
  718. int64_t end
  719. ) {
  720. int64_t ne[4];
  721. std::copy(a->ne, a->ne + 4, ne);
  722. if (axis < 0) axis = a->n_dims + axis;
  723. if (start < 0) start = ne[axis] + start;
  724. if (end <= 0) end = ne[axis] + end;
  725. GGML_ASSERT(0 <= start);
  726. GGML_ASSERT(start < end);
  727. GGML_ASSERT(end <= ne[axis]);
  728. ne[axis] = end - start;
  729. size_t offset = a->nb[axis] * start;
  730. size_t* nb = a->nb;
  731. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  732. ggml_format_name(result, "%s [(%d)%ld:%ld]", a->name, axis, start, end);
  733. result->n_dims = a->n_dims;
  734. return result;
  735. }
  736. ggml_tensor* ggml_select(
  737. struct ggml_context * ctx,
  738. struct ggml_tensor * a,
  739. int axis,
  740. int64_t index
  741. ) {
  742. int64_t ne[GGML_MAX_DIMS];
  743. std::copy(a->ne, a->ne + GGML_MAX_DIMS, ne);
  744. if (axis < 0) axis = a->n_dims + axis;
  745. if (index < 0) index = ne[axis] + index;
  746. GGML_ASSERT(0 <= index);
  747. GGML_ASSERT(index < ne[axis]);
  748. std::copy(a->ne + axis + 1, a->ne + GGML_MAX_DIMS, ne + axis);
  749. size_t offset = a->nb[axis] * index;
  750. size_t* nb = a->nb;
  751. GGML_ASSERT(GGML_MAX_DIMS == 4);
  752. ggml_tensor* result = ggml_view_3d(ctx, a, ne[0], ne[1], ne[2], nb[1], nb[2], offset);
  753. ggml_format_name(result, "%s [(%d)%ld]", a->name, axis, index);
  754. result->n_dims = a->n_dims - 1;
  755. return result;
  756. }
  757. // Inplace computation of PositionalEmbedding
  758. extern "C" ggml_tensor* PositionalEmbedding_forward(
  759. fairseq2_model& model,
  760. const std::string& prefix,
  761. ggml_tensor* embeds
  762. ) {
  763. // This only work with the simple pos encoders
  764. int seq_len = embeds->ne[1];
  765. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  766. int start_step = 0;
  767. if (has_kv_cache(model)) {
  768. start_step = model.kv_cache[prefix].step_nr++;
  769. }
  770. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, start_step, seq_len + start_step);
  771. return ggml_add(model.ctx, embeds, pos_embeds);
  772. }
  773. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  774. fairseq2_model& model,
  775. const std::string& prefix,
  776. ggml_tensor* seqs
  777. ) {
  778. GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
  779. ggml_context* ctx = model.ctx;
  780. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  781. GGML_ASSERT(embed_weights != nullptr);
  782. ggml_tensor* embeds;
  783. if (seqs->n_dims == 1) {
  784. embeds = ggml_get_rows(ctx, embed_weights, seqs);
  785. } else {
  786. // ggml_get_rows isn't very flexible, we have to handle the reshape ourselves.
  787. ggml_tensor* flat_seqs = seqs;
  788. if (!ggml_is_contiguous(seqs)) {
  789. flat_seqs = ggml_cont(ctx, flat_seqs);
  790. }
  791. flat_seqs = ggml_reshape_1d(ctx, flat_seqs, ggml_nelements(seqs));
  792. embeds = ggml_get_rows(ctx, embed_weights, flat_seqs);
  793. embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  794. embeds->n_dims = seqs->n_dims + 1;
  795. }
  796. // padding mask ?
  797. // padding_mask = to_padding_mask(embeds, seq_lens)
  798. if (has_layer(model, prefix + ".pos_encoder")) {
  799. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  800. }
  801. if (has_layer(model, prefix + ".layer_norm")) {
  802. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  803. }
  804. return embeds;
  805. }
  806. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  807. fairseq2_model& model,
  808. const std::string& prefix,
  809. ggml_tensor* seqs,
  810. ggml_tensor* padding_mask
  811. ) {
  812. int layer_idx = 0;
  813. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  814. while (has_layer(model, layer_name)) {
  815. seqs = StandardTransformerEncoderLayer_forward(
  816. model, layer_name, seqs, padding_mask
  817. );
  818. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  819. layer_idx += 1;
  820. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  821. }
  822. if (has_layer(model, prefix + ".layer_norm"))
  823. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  824. return seqs;
  825. }
  826. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  827. fairseq2_model& model,
  828. const std::string& prefix,
  829. ggml_tensor* seqs,
  830. ggml_tensor* self_attn_mask,
  831. ggml_tensor* encoder_output,
  832. ggml_tensor* encoder_padding_mask
  833. ) {
  834. ggml_context* ctx = model.ctx;
  835. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  836. // _forward_self_attn(seqs, padding_mask)
  837. auto residual = seqs;
  838. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  839. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  840. seqs = MultiheadAttention_forward(
  841. model,
  842. prefix + ".self_attn",
  843. seqs,
  844. seqs,
  845. seqs,
  846. /*attn_mask=*/self_attn_mask
  847. );
  848. if (has_layer(model, prefix + ".self_attn_norm"))
  849. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  850. seqs = ggml_add_inplace(ctx, seqs, residual);
  851. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  852. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  853. // _forward_encoder_decoder_attn
  854. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  855. // `encoder_output` must be `None` for decoder-only attention.
  856. GGML_ASSERT(encoder_output == nullptr);
  857. return seqs;
  858. }
  859. // `encoder_output` must not be `None` for encoder-decoder attention.
  860. GGML_ASSERT(encoder_output != nullptr);
  861. residual = seqs;
  862. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  863. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  864. seqs = MultiheadAttention_forward(
  865. model,
  866. prefix + ".encoder_decoder_attn",
  867. seqs,
  868. encoder_output,
  869. encoder_output,
  870. /*attention masks=*/encoder_padding_mask
  871. );
  872. seqs = ggml_add_inplace(ctx, seqs, residual);
  873. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  874. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  875. // _forward_ffn(seqs)
  876. residual = seqs;
  877. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  878. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  879. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  880. // TODO:
  881. // if self.residual_scale is not None:
  882. // residual = self.residual_scale * residual
  883. seqs = ggml_add_inplace(ctx, seqs, residual);
  884. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  885. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  886. return seqs;
  887. }
  888. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  889. auto seq_len = seqs->ne[1];
  890. // TODO: allow other ggml_type
  891. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  892. return ggml_diag_mask_inf(ctx, mask, 0);
  893. }
  894. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  895. fairseq2_model& model,
  896. const std::string& prefix,
  897. ggml_tensor* seqs,
  898. ggml_tensor* padding_mask,
  899. ggml_tensor* encoder_output,
  900. ggml_tensor* encoder_padding_mask
  901. ) {
  902. int layer_idx = 0;
  903. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  904. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  905. while (has_layer(model, layer_name)) {
  906. seqs = StandardTransformerDecoderLayer_forward(
  907. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  908. );
  909. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  910. layer_idx += 1;
  911. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  912. }
  913. if (has_layer(model, prefix + ".layer_norm"))
  914. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  915. return seqs;
  916. }
  917. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  918. auto opts = job.opts;
  919. int max_seq_len = -1;
  920. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  921. max_seq_len = opts.hard_max_seq_len;
  922. } else {
  923. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len) + opts.soft_max_seq_len_b);
  924. }
  925. if (opts.min_seq_len > max_seq_len) {
  926. printf(
  927. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  928. opts.min_seq_len,
  929. max_seq_len
  930. );
  931. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  932. }
  933. int prefix_seq_len = job.prefix_seq->ne[0];
  934. if (prefix_seq_len >= max_seq_len) {
  935. printf(
  936. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  937. prefix_seq_len,
  938. max_seq_len
  939. );
  940. GGML_ASSERT(prefix_seq_len < max_seq_len);
  941. }
  942. return max_seq_len;
  943. }
  944. void _fan_out_encoder_output(
  945. ggml_context* ctx,
  946. ggml_tensor** encoder_output_out,
  947. ggml_tensor** encoder_padding_mask_out,
  948. int beam_size
  949. ) {
  950. // (S_enc, M)
  951. ggml_tensor* encoder_output = *encoder_output_out;
  952. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  953. // (B, S_enc, M)
  954. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  955. // (S_enc, M) -> (B, S_enc, M)
  956. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  957. // (S_enc) -> (B, S_enc)
  958. if (encoder_padding_mask != nullptr) {
  959. ggml_tensor* shape_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], 1, beam_size);
  960. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  961. }
  962. }
  963. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  964. // TODO: this isn't the most precise way of doing this
  965. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  966. }
  967. ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
  968. ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
  969. ggml_type true_type = x->type;
  970. ggml_tensor* y = ggml_repeat(ctx, x, shape);
  971. y->type = true_type;
  972. return y;
  973. }
  974. void _bootstrap_seqs_and_scores(
  975. fairseq2_model& model,
  976. const SequenceGeneratorJob& job,
  977. ggml_tensor* full_seqs,
  978. ggml_tensor* scores,
  979. ggml_tensor* encoder_output,
  980. ggml_tensor* encoder_padding_mask,
  981. ggml_tensor* lid_scores,
  982. int n_threads,
  983. const std::vector<int>& lang_ids
  984. ) {
  985. // Returns LID score map
  986. int prefix_seq_len = job.prefix_seq->ne[0];
  987. int max_seq_len = scores->ne[0];
  988. int beam_size = scores->ne[1];
  989. GGML_ASSERT(prefix_seq_len > 0);
  990. if (prefix_seq_len == 1)
  991. return;
  992. ggml_context* ctx = model.ctx;
  993. // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  994. ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
  995. seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);
  996. // We have to bootstrap the model with the already fanned-out encoder
  997. // output to correctly initialize its incremental state.
  998. // Note: we don't start decoding the last prefix token just yet.
  999. seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);
  1000. // Bootstrap the model state with prefix sequence.
  1001. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  1002. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1003. model,
  1004. "text_decoder",
  1005. seqs,
  1006. /*padding_mask*/ nullptr,
  1007. encoder_output,
  1008. encoder_padding_mask
  1009. );
  1010. // logits, lprobs: (N, S_pfx - 1, V)
  1011. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  1012. int vocab_size = logits->ne[0];
  1013. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
  1014. struct ggml_cgraph * gf = ggml_new_graph(ctx);
  1015. ggml_build_forward_expand(gf, lprobs);
  1016. ggml_graph_compute_with_ctx(ctx, gf, n_threads);
  1017. full_seqs->type = GGML_TYPE_I32;
  1018. job.prefix_seq->type = GGML_TYPE_I32;
  1019. // For LID
  1020. for (size_t i = 0; i < lang_ids.size(); ++i) {
  1021. ggml_set_f32_1d(lid_scores, i, std::exp(ggml_get_f32_1d(lprobs, lang_ids[i])));
  1022. }
  1023. // Fetch scores of next steps from "lprobs"
  1024. float p_score = 0;
  1025. for (int i = 1; i < prefix_seq_len; ++i) {
  1026. int p;
  1027. if (ggml_get_i32_1d(job.prefix_seq, i) == model.vocab.token_to_id["<unk>"]) {
  1028. // If tgt_lang is unk, use the most probable lang tag predicted by model
  1029. int max_value = std::numeric_limits<float>::min();
  1030. for (int j = 0; j < lang_ids.size(); j++) {
  1031. if(ggml_get_f32_1d(lprobs, lang_ids[j]) > max_value) {
  1032. max_value = ggml_get_f32_1d(lprobs, lang_ids[j]);
  1033. p = lang_ids[j];
  1034. }
  1035. }
  1036. } else {
  1037. p = ggml_get_i32_1d(job.prefix_seq, i);
  1038. }
  1039. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  1040. for (int b = 0; b < beam_size; ++b) {
  1041. // scores: (N, S)
  1042. // Note: First step (e.g. BOS)'s score is always 0.
  1043. ggml_set_f32_1d(scores, b * max_seq_len + i, p_score);
  1044. }
  1045. }
  1046. }
  1047. /// Finds the topk indices, and write the winning indices in "candidate_indices" array.
  1048. int topk(
  1049. ggml_tensor* lprobs, // (B, V)
  1050. std::int64_t k,
  1051. ggml_tensor* candidate_indices
  1052. ) {
  1053. // Take the best 2 x `beam_size` predictions. We'll choose the first
  1054. // `beam_size` of these which don't predict EOS to continue with.
  1055. // (N, 2 x B)
  1056. // `vocab_size` - 1 to never select PAD.
  1057. std::int64_t K = std::min(k, ggml_nelements(lprobs));
  1058. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  1059. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  1060. };
  1061. GGML_ASSERT(ggml_nelements(candidate_indices) >= k);
  1062. auto cand = (std::int32_t*)candidate_indices->data;
  1063. std::partial_sort(cand, cand + K, cand + ggml_nelements(lprobs), comp);
  1064. return K;
  1065. }
  1066. void _tweak_lprobs(const SequenceGeneratorJob& job, ggml_tensor* lprobs, int step_nr, int max_seq_len, std::size_t vocab_size) {
  1067. std::size_t beam_size = job.opts.beam_size;
  1068. std::size_t eos_idx = job.eos_idx;
  1069. // Do not allow EOS before reaching the minimum sequence length.
  1070. if (step_nr < job.opts.min_seq_len) {
  1071. // lprobs[:, :, self.eos_idx] = -INFINITY;
  1072. for (size_t i = 0; i < beam_size; ++i)
  1073. ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
  1074. }
  1075. // If we have reached the maximum length, force the last step to be EOS.
  1076. if (step_nr == max_seq_len - 2) {
  1077. // lprobs[:, :, : self.eos_idx] = -torch.inf
  1078. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  1079. for (size_t b = 0; b < beam_size; ++b) {
  1080. size_t t = 0;
  1081. for (t = 0; t < eos_idx; ++t)
  1082. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1083. for (t = eos_idx + 1; t < vocab_size; ++t)
  1084. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1085. }
  1086. }
  1087. // Never allow PAD.
  1088. std::size_t pad_idx = job.pad_idx;
  1089. for (size_t i = 0; i < beam_size; ++i)
  1090. ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);
  1091. // Apply UNK penalty.
  1092. if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
  1093. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  1094. auto lprobs_raw = ggml_get_data_f32(lprobs);
  1095. for (size_t i = 0; i < beam_size; ++i)
  1096. lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
  1097. }
  1098. }
  1099. /// Copies the sequence and scores of a given candidate beam.
  1100. void _finalize_hypothesis(
  1101. const SequenceGeneratorJob& job,
  1102. ggml_context* ctx,
  1103. int step_nr,
  1104. std::int32_t beam,
  1105. std::int32_t token,
  1106. float eos_score,
  1107. ggml_tensor* seqs, // (beam_size, seq_len)
  1108. ggml_tensor* scores, // (beam_size, seq_len)
  1109. ggml_tensor* lid_scores,
  1110. Hypothesis* hypothesis
  1111. ) {
  1112. ggml_tensor* seq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  1113. hypothesis->seq = seq;
  1114. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  1115. hypothesis->step_scores = step_scores;
  1116. auto tok = (std::int32_t*)seq->data;
  1117. for (int i = 0; i < step_nr + 1; ++i) {
  1118. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  1119. }
  1120. tok[step_nr + 1] = token;
  1121. // Convert from cumulative to per-step scores.
  1122. auto sc = (float*)step_scores->data;
  1123. float last_score = eos_score;
  1124. for (int i = step_nr; i >= 0; --i) {
  1125. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  1126. sc[i + 1] = last_score - sc0;
  1127. last_score = sc0;
  1128. }
  1129. sc[0] = 0;
  1130. if (job.opts.normalize_scores)
  1131. // Skip first EOS since it is always 0 and skews normalization.
  1132. eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  1133. hypothesis->score = eos_score;
  1134. hypothesis->lid_scores = lid_scores;
  1135. }
  // Uses ggml_context to store any object. Allocates an I8 tensor of
  // sizeof(Type) * (n) bytes in `ctx` and evaluates to its data as a `Type*`.
  // The memory lives as long as `ctx`, and no constructor is run.
  // This is an expression: it carries no trailing semicolon, and `n` is
  // parenthesized so that GGML_CTX_ALLOC(ctx, T, a + b) allocates (a + b) elements.
  #define GGML_CTX_ALLOC(ctx, Type, n) \
      (static_cast<Type*>(ggml_new_tensor_1d((ctx), GGML_TYPE_I8, sizeof(Type) * (n))->data))
  1139. ggml_context* ctx_from_buffer(std::vector<uint8_t>& buffer) {
  1140. return ggml_init({
  1141. /*.mem_size =*/ static_cast<int64_t>(buffer.capacity()),
  1142. /*.mem_buffer =*/ buffer.data(),
  1143. /*.no_alloc =*/ false,
  1144. });
  1145. }
  1146. ggml_allocr* new_arena_allocr(std::vector<uint8_t>& buffer) {
  1147. return ggml_allocr_new(buffer.data(), buffer.capacity(), 8);
  1148. }
  1149. /// Generates a translation for a single sequence
  1150. /// The results Hypothesis are written inside `result_ctx`.
  1151. extern "C" Hypothesis* generate_sequence(
  1152. fairseq2_model& model,
  1153. const SequenceGeneratorJob& job,
  1154. ggml_tensor* encoder_output,
  1155. ggml_tensor* encoder_padding_mask,
  1156. ggml_context* result_ctx,
  1157. int n_threads
  1158. ) {
  1159. // Pre allocate memory buffers.
  1160. // * step_ctx: contains metadata for the model graph, as well as some explicit
  1161. // buffers for the lprobs tweaking.
  1162. // * prev_step_ctx: is an additional buffer because we need some results from previous steps,
  1163. // to compute next step. Notably self attention kv cache.
  1164. // * search_ctx contains tensors that should live for the full search,
  1165. // like encoder kv cache.
  1166. // * step_alloc contains buffer for the forward pass of the model.
  1167. // Split mem_mb into the different context we need to use.
  1168. int mem_mb = job.opts.mem_mb;
  1169. std::vector<uint8_t> local_bufs[4] = {
  1170. std::vector<uint8_t>(mem_mb * MB * 3 / 10), // step_ctx
  1171. std::vector<uint8_t>(mem_mb * MB * 3 / 10), // prev_step_ctx
  1172. std::vector<uint8_t>(mem_mb * MB * 3 / 10), // search_ctx
  1173. std::vector<uint8_t>(mem_mb * MB * 1 / 10), // step_alloc
  1174. };
  1175. ggml_allocr* step_alloc = new_arena_allocr(local_bufs[3]);
  1176. std::vector<int> lang_ids;
  1177. for (const auto& kv : model.vocab.token_to_id) {
  1178. if (kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
  1179. lang_ids.push_back(kv.second);
  1180. }
  1181. }
  1182. std::sort(lang_ids.begin(), lang_ids.end());
  1183. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  1184. size_t vocab_size = embed->ne[1];
  1185. std::size_t beam_size = job.opts.beam_size;
  1186. ggml_detach(encoder_output);
  1187. int source_seq_len = encoder_output->ne[1];
  1188. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  1189. ggml_context* search_ctx = ctx_from_buffer(local_bufs[2]);
  1190. ggml_context* original_ctx = model.ctx;
  1191. fairseq2_kv_cache_alloc(model, search_ctx, beam_size, max_seq_len);
  1192. // (S_enc, M) -> (B, S_enc, M)
  1193. model.ctx = search_ctx;
  1194. _fan_out_encoder_output(search_ctx, &encoder_output, &encoder_padding_mask, beam_size);
  1195. // Allocate results in the context provided by the caller.
  1196. ggml_set_no_alloc(result_ctx, false);
  1197. Hypothesis* finished_searches_begin = GGML_CTX_ALLOC(result_ctx, Hypothesis, beam_size);
  1198. Hypothesis* finished_searches = finished_searches_begin;
  1199. for (std::size_t i = 0; i < beam_size; ++i) finished_searches[i] = {nullptr, -INFINITY, nullptr};
  1200. Hypothesis* finished_searches_end = finished_searches + beam_size;
  1201. // Initialize buffers. (B, S)
  1202. ggml_tensor* seqs = ggml_new_tensor_2d(search_ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  1203. ggml_set_i32(seqs, 0);
  1204. ggml_set_name(seqs, "seqs_0");
  1205. ggml_tensor* scores = ggml_new_tensor_2d(search_ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  1206. ggml_set_name(scores, "scores_0");
  1207. ggml_set_f32(scores, 0.0);
  1208. int prefix_seq_len = job.prefix_seq->ne[0];
  1209. int start_step = prefix_seq_len - 1;
  1210. ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[(start_step - 1) % 2]);
  1211. ggml_context* step_ctx = ctx_from_buffer(local_bufs[start_step % 2]);
  1212. GGML_ASSERT(step_ctx != search_ctx);
  1213. GGML_ASSERT(prev_step_ctx != step_ctx);
  1214. model.ctx = prev_step_ctx;
  1215. // search_ctx because we need encoder_decoder_attn.k_cache to survive for the full search
  1216. model.kv_cache_ctx = search_ctx;
  1217. ggml_tensor* lid_scores = ggml_new_tensor_1d(result_ctx, GGML_TYPE_F32, lang_ids.size());
  1218. _bootstrap_seqs_and_scores(
  1219. model, job, seqs, scores, encoder_output, encoder_padding_mask, lid_scores, n_threads, lang_ids
  1220. );
  1221. // Now we will only add self_attn.k_cache and those need to be resorted and copied at every step.
  1222. model.kv_cache_ctx = nullptr;
  1223. // Holds the indices of beams (a beam can occur more than once) that we
  1224. // should continue with in the next step.
  1225. ggml_tensor* beam_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1226. ggml_tensor* next_tokens = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1227. ggml_tensor* next_scores = ggml_new_tensor_1d(search_ctx, GGML_TYPE_F32, beam_size);
  1228. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  1229. ggml_tensor* candidate_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, vocab_size * beam_size);
  1230. for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
  1231. ((int32_t *)(candidate_indices->data))[i] = i;
  1232. printf_mem_usage(search_ctx, "search_ctx");
  1233. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  1234. model.ctx = step_ctx;
  1235. ggml_set_no_alloc(step_ctx, true); // Use allocr for the model forward pass
  1236. float max_lprob;
  1237. int p;
  1238. if (step_nr == start_step) {
  1239. // Find the most probable lang_tok and assign it to all beams, when prefix_seq[1] is <unk>
  1240. if (ggml_get_i32_1d(job.prefix_seq, 1) == model.vocab.token_to_id["<unk>"]) {
  1241. float max_lprob = std::numeric_limits<float>::min();
  1242. for(int j = 0; j < lang_ids.size(); j++) {
  1243. auto val = ggml_get_f32_1d(lid_scores, j);
  1244. if (val > max_lprob) {
  1245. max_lprob = val;
  1246. p = lang_ids[j];
  1247. }
  1248. }
  1249. for (int k = 0; k < beam_size; k++) {
  1250. ggml_set_i32_1d(seqs, k * vocab_size + step_nr, p);
  1251. }
  1252. }
  1253. }
  1254. ggml_tensor* prev_token = ggml_slice(step_ctx, seqs, 0, step_nr, step_nr + 1);
  1255. ggml_tensor* decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", prev_token);
  1256. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1257. model,
  1258. "text_decoder",
  1259. decoder_input,
  1260. nullptr, // We never generate PAD.
  1261. encoder_output,
  1262. encoder_padding_mask
  1263. ); // (B, 1, D)
  1264. decoder_output = ggml_flatten_1d(step_ctx, decoder_output, 0); // (B, model_dim)
  1265. // Force logits to be allocated in step_ctx, not in step_alloc.
  1266. ggml_set_no_alloc(step_ctx, false);
  1267. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output); // (B, vocab_size)
  1268. ggml_tensor* lprobs = ggml_log_softmax(step_ctx, logits);
  1269. // Compute lprobs here so we can modify it in place in the lprob tweaking phase
  1270. // TODO: use ggml properly compute the tweaks
  1271. struct ggml_cgraph * gf = ggml_new_graph(step_ctx);
  1272. ggml_build_forward_expand(gf, lprobs);
  1273. size_t fwd_mem = ggml_allocr_alloc_graph(step_alloc, gf);
  1274. GGML_UNUSED(fwd_mem);
  1275. ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);
  1276. ggml_detach(lprobs);
  1277. ggml_allocr_reset(step_alloc);
  1278. #if DEBUG_MEM_USAGE
  1279. printf("beam search step %d. Graph.n_nodes: %d.\n", step_nr, gf.n_nodes);
  1280. printf(" Fwd mem: %.1fMB, reserved %.1fMb\n", fwd_mem/(double)MB, local_bufs[3].capacity()/(double)MB);
  1281. std::fill(local_bufs[3].begin(), local_bufs[3].end(), 0xAA);
  1282. #endif
  1283. _tweak_lprobs(job, lprobs, step_nr, max_seq_len, vocab_size);
  1284. ggml_tensor* last_scores = ggml_slice(step_ctx, scores, 0, step_nr, step_nr+1);
  1285. if (step_nr == start_step) {
  1286. // At the initial step, all hypotheses are equally likely, so we use
  1287. // only the first beam.
  1288. lprobs = ggml_slice(step_ctx, lprobs, 1, 0, 1);
  1289. lprobs = ggml_cont(step_ctx, lprobs);
  1290. // The first step always indicates the beginning of the sequence and has no score.
  1291. if (step_nr > 0) {
  1292. last_scores = ggml_slice(step_ctx, last_scores, 1, 0, 1);
  1293. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1294. }
  1295. } else {
  1296. // Make probabilities contain cumulative scores for each hypothesis.
  1297. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1298. }
  1299. ggml_build_forward_expand(gf, lprobs);
  1300. ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);
  1301. // Determine (beam, token) candidates for the next step.
  1302. // (N, 2 x B)
  1303. std::int64_t K = topk(
  1304. lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
  1305. );
  1306. std::size_t ongoing_beams = 0;
  1307. for (std::int32_t i = 0; i < K; ++i) {
  1308. int c = ggml_get_f32_1d(candidate_indices, i);
  1309. std::int32_t beam = c / vocab_size;
  1310. std::int32_t token = c % vocab_size;
  1311. float tok_score = ggml_get_f32_1d(lprobs, c);
  1312. // Detect beams that reached the minimum length and that end with an EOS.
  1313. bool eos = token == job.eos_idx;
  1314. eos &= tok_score != -INFINITY;
  1315. if (eos) {
  1316. _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, lid_scores, finished_searches++);
  1317. if (finished_searches == finished_searches_end)
  1318. goto end_of_beam_search;
  1319. continue;
  1320. }
  1321. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  1322. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  1323. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  1324. ongoing_beams += 1;
  1325. if (ongoing_beams >= beam_size) break;
  1326. }
  1327. // Reorder beams in the `seq` and `score` buffers. The same beam can
  1328. // be selected more than once.
  1329. // (B, S), (B) -> (B, S)
  1330. // don't use allocr API, cause it might reuse a kv cache buffer several time.
  1331. ggml_set_no_alloc(step_ctx, false);
  1332. ggml_tensor* new_seqs = ggml_get_rows(step_ctx, seqs, beam_indices);
  1333. ggml_tensor* new_scores = ggml_get_rows(step_ctx, scores, beam_indices);
  1334. struct ggml_cgraph * gf_reorder = ggml_new_graph(step_ctx);
  1335. ggml_build_forward_expand(gf_reorder, new_seqs);
  1336. ggml_build_forward_expand(gf_reorder, new_scores);
  1337. reorder_kv_cache(model, step_ctx, gf_reorder, beam_indices);
  1338. ggml_graph_compute_with_ctx(step_ctx, gf_reorder, n_threads);
  1339. seqs = ggml_detach(new_seqs);
  1340. scores = ggml_detach(new_scores);
  1341. // seqs[:, step_nr + 1] = next_tokens
  1342. // scores[:, step_nr + 1] = next_scores
  1343. for (std::size_t i = 0; i < beam_size; ++i) {
  1344. ((std::int32_t*)seqs->data)[step_nr + 1 + i * max_seq_len] = ggml_get_i32_1d(next_tokens, i);
  1345. ((float*)scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
  1346. }
  1347. printf_mem_usage(step_ctx, "step_ctx");
  1348. ggml_free(prev_step_ctx);
  1349. prev_step_ctx = step_ctx;
  1350. #if DEBUG_MEM_USAGE
  1351. std::fill(local_bufs[(step_nr + 1) % 2].begin(), local_bufs[(step_nr + 1) % 2].end(), 0xAA);
  1352. #endif
  1353. step_ctx = ctx_from_buffer(local_bufs[(step_nr + 1) % 2]);
  1354. }
  1355. end_of_beam_search:
  1356. // Ensure that hypotheses are sorted by decreasing scores before returning.
  1357. std::sort(
  1358. finished_searches_begin,
  1359. finished_searches_end,
  1360. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  1361. );
  1362. printf_mem_usage(search_ctx, "search_ctx");
  1363. fairseq2_kv_cache_reset(model);
  1364. model.ctx = original_ctx;
  1365. return finished_searches_begin;
  1366. }
  1367. extern "C" Hypothesis* _testing_return_hypothesis_ptr(ggml_context* ctx) {
  1368. Hypothesis* result = GGML_CTX_ALLOC(ctx, struct Hypothesis, 2);
  1369. result[0] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 3.14f, (ggml_tensor*)result};
  1370. ggml_set_i32_1d(result[0].seq, 0, 314);
  1371. result[1] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 4.21f, nullptr};
  1372. ggml_set_i32_1d(result[1].seq, 0, 421);
  1373. return result;
  1374. }
  1375. // SPM tokenizer
  1376. // original implementation:
  1377. // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
  1378. struct llm_symbol {
  1379. using index = int;
  1380. index prev;
  1381. index next;
  1382. const char * text;
  1383. size_t n;
  1384. llama_vocab::id id;
  1385. };
  1386. static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
  // Byte length of the UTF-8 sequence that starts with `src`, decided by the
  // high nibble of the lead byte: 0xF -> 4, 0xE -> 3, 0xC/0xD -> 2, and
  // anything else (ASCII or a stray continuation byte) -> 1.
  static size_t utf8_len(char src) {
      const uint8_t high_nibble = static_cast<uint8_t>(src) >> 4;
      if (high_nibble == 0xF) return 4;
      if (high_nibble == 0xE) return 3;
      if (high_nibble >= 0xC) return 2;
      return 1;
  }
  1392. struct llm_bigram_spm {
  1393. struct comparator {
  1394. bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
  1395. return (l.score < r.score) || (l.score == r.score && l.left > r.left);
  1396. }
  1397. };
  1398. using queue_storage = std::vector<llm_bigram_spm>;
  1399. using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
  1400. llm_symbol::index left;
  1401. llm_symbol::index right;
  1402. float score;
  1403. size_t size;
  1404. llama_vocab::id id;
  1405. };
  1406. struct llm_tokenizer_spm {
  1407. llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
  1408. void tokenize(const std::string& input_text, ggml_tensor* output) {
  1409. llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");
  1410. // split string into utf8 chars
  1411. int index = 0;
  1412. size_t offs = 0;
  1413. // This is kind of annoying, but needed because with SPM,
  1414. // characters following a space have a special meaning.
  1415. // And the algorithm rely on substrings to do the lookups.
  1416. std::string text = input_text;
  1417. bool need_extra_space = text.size() > 0 && text[0] != ' ';
  1418. if (need_extra_space) text = " " + text;
  1419. while (offs < text.size()) {
  1420. size_t len = utf8_len(text[offs]);
  1421. size_t n = std::min(len, text.size() - offs);
  1422. auto token = vocab.token_to_id.find(std::string(text, offs, n));
  1423. llama_vocab::id id = token == vocab.token_to_id.end() ? unk_idx : token->second;
  1424. llm_symbol sym = {
  1425. /*prev*/ index - 1,
  1426. /*next*/ offs + n == text.size() ? -1 : index + 1,
  1427. /*text*/ text.c_str() + offs,
  1428. /*n*/ n,
  1429. /*id*/ id
  1430. };
  1431. offs += n;
  1432. index++;
  1433. symbols.emplace_back(sym);
  1434. }
  1435. // seed the work queue with all possible 2-character tokens.
  1436. for (size_t i = 1; i < symbols.size(); ++i) {
  1437. try_add_bigram(i - 1, i);
  1438. }
  1439. // keep substituting the highest frequency pairs for as long as we can.
  1440. while (!work_queue.empty()) {
  1441. auto bigram = work_queue.top();
  1442. work_queue.pop();
  1443. auto & left_sym = symbols[bigram.left];
  1444. auto & right_sym = symbols[bigram.right];
  1445. const std::string text = std::string(left_sym.text, left_sym.n + right_sym.n);
  1446. // if one of the symbols already got merged, skip it.
  1447. if (
  1448. left_sym.n == 0
  1449. || right_sym.n == 0
  1450. || left_sym.n + right_sym.n != bigram.size
  1451. ) continue;
  1452. // merge the right sym into the left one
  1453. left_sym.n += right_sym.n;
  1454. left_sym.id = bigram.id;
  1455. right_sym.n = 0;
  1456. // remove the right sym from the chain
  1457. left_sym.next = right_sym.next;
  1458. if (right_sym.next >= 0) {
  1459. symbols[right_sym.next].prev = bigram.left;
  1460. }
  1461. // find more substitutions
  1462. try_add_bigram(left_sym.prev, bigram.left);
  1463. try_add_bigram(bigram.left, left_sym.next);
  1464. }
  1465. llama_vocab::id* out = (llama_vocab::id*)output->data;
  1466. int out_step = sizeof(llama_vocab::id) / output->nb[0];
  1467. int num_tokens = 0;
  1468. for (int i = 0; i > -1; i = symbols[i].next) {
  1469. llm_symbol& symbol = symbols[i];
  1470. *(out + num_tokens * out_step) = symbol.id;
  1471. num_tokens += 1;
  1472. }
  1473. *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
  1474. num_tokens += 1;
  1475. output->ne[0] = num_tokens;
  1476. }
  1477. private:
  1478. void try_add_bigram(int left, int right) {
  1479. if (left == -1 || right == -1) {
  1480. return;
  1481. }
  1482. const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
  1483. auto token = vocab.token_to_id.find(text);
  1484. if (token == vocab.token_to_id.end()) {
  1485. return;
  1486. }
  1487. llama_vocab::id id = token->second;
  1488. if (static_cast<size_t>(id) >= vocab.id_to_token.size()) {
  1489. return;
  1490. }
  1491. const auto& tok_data = vocab.id_to_token[id];
  1492. llm_bigram_spm bigram = {
  1493. /*left */ left,
  1494. /*right*/ right,
  1495. /*score*/ tok_data.score,
  1496. /*size */ text.size(),
  1497. /*id */ id
  1498. };
  1499. work_queue.push(bigram);
  1500. }
  1501. const llama_vocab& vocab;
  1502. std::vector<llm_symbol> symbols;
  1503. llm_bigram_spm::queue work_queue;
  1504. };
  1505. extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out) {
  1506. llm_tokenizer_spm spm = {model->vocab};
  1507. spm.tokenize(std::string(text), out);
  1508. }
  1509. extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out) {
  1510. int eos_idx = model->vocab.token_to_id["</s>"];
  1511. int sent_len = tokens->ne[0];
  1512. std::size_t written = 0;
  1513. for (int i = 0; i < sent_len; ++i) {
  1514. int id = ggml_get_i32_1d(tokens, i);
  1515. // Don't print the EOS token but only if it appear at the end.
  1516. if (i == sent_len - 1 && eos_idx == id) break;
  1517. std::string token = model->vocab.id_to_token.at(id).text;
  1518. // Skip the first space outputted.
  1519. auto begin = token.begin();
  1520. if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
  1521. std::copy(begin, token.end(), out);
  1522. std::size_t n = token.end() - begin;
  1523. written += n;
  1524. out += n;
  1525. }
  1526. *out = '0';
  1527. return written;
  1528. }
  1529. // TODO: Unify with the above?
  1530. std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, ggml_tensor* scores, char* out) {
  1531. int eos_idx = model->vocab.token_to_id["</s>"];
  1532. int sent_len = tokens->ne[0];
  1533. std::size_t written = 0;
  1534. std::vector<float> word_scores;
  1535. std::vector<float> subword_scores;
  1536. std::vector<std::string> result_text;
  1537. std::string curr_token = "";
  1538. for (int i = 0; i < sent_len; ++i) {
  1539. int id = ggml_get_i32_1d(tokens, i);
  1540. // Don't print the EOS token but only if it appear at the end.
  1541. if (i == sent_len - 1 && eos_idx == id) break;
  1542. std::string token = model->vocab.id_to_token.at(id).text;
  1543. float score = ggml_get_f32_1d(scores, i+2); // 2 is prefix size
  1544. if(token[0] == ' ') {
  1545. // reset word score
  1546. if(subword_scores.size() > 0) {
  1547. float avg = std::accumulate(subword_scores.begin(), subword_scores.end(), 0.0f) / subword_scores.size();
  1548. word_scores.push_back(avg);
  1549. subword_scores.clear();
  1550. result_text.push_back(curr_token);
  1551. }
  1552. curr_token = token.substr(1);
  1553. } else {
  1554. curr_token += token;
  1555. }
  1556. subword_scores.push_back(score);
  1557. // Skip the first space outputted.
  1558. auto begin = token.begin();
  1559. if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
  1560. std::copy(begin, token.end(), out);
  1561. std::size_t n = token.end() - begin;
  1562. written += n;
  1563. out += n;
  1564. }
  1565. if(subword_scores.size() > 0) {
  1566. word_scores.push_back(*std::min_element(subword_scores.begin(), subword_scores.end()));
  1567. subword_scores.clear();
  1568. result_text.push_back(curr_token);
  1569. }
  1570. *out = '0';
  1571. return std::make_pair(result_text, word_scores);
  1572. }