fairseq2.cpp 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859
  #include <algorithm>
  #include <cstring>
  #include <fnmatch.h>
  #include <iostream>
  #include <math.h>
  #include <numeric>
  #include <queue>
  #include <unordered_map>
  #include "kaldi-native-fbank/csrc/feature-fbank.h"
  #include "kaldi-native-fbank/csrc/feature-window.h"
  #include "fairseq2.h"
  #include "ggml.h"
  #include "ggml-alloc.h"
  13. ggml_tensor* ggml_detach(ggml_tensor* a) {
  14. a->op = GGML_OP_NONE;
  15. std::fill(a->src, a->src + GGML_MAX_SRC, nullptr);
  16. return a;
  17. }
  // generate_sequence uses ggml_context and ggml_allocr to reuse memory buffers across steps.
  // This can lead to dangling pointers, which don't segfault, but instead read garbage data.
  // Enabling this flag allows us to explicitly reset memory buffers, making it more obvious
  // when we read garbage data.
  // It also prints memory usage information, which is useful when debugging memory issues.
  23. #define DEBUG_MEM_USAGE DEBUG
  24. std::size_t MB = 1024 * 1024;
  25. void printf_mem_usage(ggml_context* ctx, std::string name) {
  26. #if DEBUG_MEM_USAGE
  27. double mb = 1024.0 * 1024.0;
  28. printf(
  29. "%s: memory used = %8.2f MB, memory reserved = %8.2f Mb\n",
  30. name.c_str(),
  31. ggml_used_mem(ctx) / mb,
  32. ggml_get_mem_size(ctx) / mb
  33. );
  34. #endif
  35. }
  // Swaps `x` and `y` through a temporary named `tmp_<x>`.
  // `x` must be a plain identifier because it is token-pasted into the temporary's name.
  // The macro expands to several statements (including a declaration), so it must not be
  // used as the unbraced body of an `if`/`for`, nor twice with the same `x` in one scope.
  #define SWAP(x, y) \
      auto tmp_ ## x = x; x = y; y = tmp_ ## x;

  // Asserts the four dimensions of tensor `x`; pass -1 for a dimension that should not be checked.
  // Every macro parameter is parenthesized so that expression arguments (ternaries, shifts,
  // bitwise ops) are compared as a whole instead of being re-associated with `==` / `||`.
  // The trailing `;` is kept so existing call sites with or without their own `;` still compile.
  #define GGML_ASSERT_SHAPE(x, ne0, ne1, ne2, ne3) \
      GGML_ASSERT(((ne0) == -1 || (x)->ne[0] == (ne0)) && ((ne1) == -1 || (x)->ne[1] == (ne1)) && ((ne2) == -1 || (x)->ne[2] == (ne2)) && ((ne3) == -1 || (x)->ne[3] == (ne3)));
  40. /// allocate the fairseq2 model and hyperparameters
  41. extern "C" fairseq2_model* fairseq2_model_alloc() {
  42. // pre-allocate some memory to write hyperparameters and tensors pointers
  43. auto* model = new fairseq2_model;
  44. model->tensors_ctx = nullptr;
  45. return model;
  46. }
  47. extern "C" void fairseq2_kv_cache_alloc(fairseq2_model& model, ggml_context* kv_cache_ctx, int beam_size, int max_seq_len) {
  48. // Note: we only allocate the masks, proper kv cache allocation is delayed.
  49. GGML_ASSERT(kv_cache_ctx);
  50. GGML_ASSERT(!ggml_get_no_alloc(kv_cache_ctx)); // We need to be able to alloc the kv_cache buffers
  51. auto attn_glob = "text_decoder.*_attn.k_proj.weight";
  52. FORCE_ALLOC(self_attn_mask, kv_cache_ctx, ggml_new_tensor_2d(kv_cache_ctx, GGML_TYPE_F32, max_seq_len, max_seq_len));
  53. self_attn_mask = ggml_diag_mask_inf_inplace(kv_cache_ctx, self_attn_mask, 0);
  54. ggml_format_name(self_attn_mask, "self_attn_mask[%d]", max_seq_len);
  55. for (auto named_tensor : model.tensors) {
  56. const std::string& name = named_tensor.first;
  57. if (::fnmatch(attn_glob, name.c_str(), 0) == FNM_NOMATCH)
  58. continue;
  59. // create a cache entry without the ".k_proj.weight" suffix
  60. const std::string& shortname = name.substr(0, name.size() - 14);
  61. KeyValueTensor& kv = model.kv_cache[shortname];
  62. kv.step_nr = 0;
  63. kv.full_k = nullptr;
  64. kv.full_v = nullptr;
  65. kv.self_attn_mask = self_attn_mask;
  66. }
  67. }
  68. extern "C" void fairseq2_kv_cache_reset(const fairseq2_model& model) {
  69. // TODO: use a dedicated allocator, so that kv_cache.clear actually frees the memory
  70. model.kv_cache.clear();
  71. }
  72. bool has_kv_cache(const fairseq2_model& model) {
  73. return model.kv_cache.size() > 0;
  74. }
  75. inline ggml_tensor* ggml_squeeze(ggml_context* ctx, ggml_tensor* x, int dim) {
  76. int n_dims = x->n_dims;
  77. GGML_ASSERT(dim >= 0);
  78. GGML_ASSERT(dim < n_dims);
  79. GGML_ASSERT(x->ne[dim] == 1);
  80. return ggml_flatten_1d(ctx, x, dim);
  81. }
  82. inline ggml_tensor* ggml_unsqueeze(ggml_context* ctx, ggml_tensor* x, int dim) {
  83. return ggml_unflatten_1d(ctx, x, dim, 1);
  84. }
  85. // copy k and v to kv cache
  86. // kv.full_k[step_nr] = k;
  87. // kv.full_v[step_nr] = v;
  88. void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
  89. KeyValueTensor& kv = model.kv_cache[prefix];
  90. int step_nr = kv.step_nr;
  91. ggml_context* ctx = model.ctx;
  92. // We need to force allocation here, otherwise the kv_cache buffers can be reused
  93. bool no_alloc_save = ggml_get_no_alloc(ctx);
  94. ggml_set_no_alloc(ctx, false);
  95. int n_steps = (*k)->ne[1];
  96. int k_proj, batch_size;
  97. if (kv.full_k != nullptr) {
  98. // (N, S_kv, K_proj)
  99. k_proj = kv.full_k->ne[0];
  100. batch_size = kv.full_k->ne[2];
  101. ggml_detach(kv.full_k);
  102. ggml_detach(kv.full_v);
  103. kv.full_k = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_k, 1), ggml_unsqueeze(ctx, *k, 1)), 1);
  104. kv.full_v = ggml_squeeze(ctx, ggml_concat(ctx, ggml_unsqueeze(ctx, kv.full_v, 1), ggml_unsqueeze(ctx, *v, 1)), 1);
  105. } else {
  106. GGML_ASSERT(step_nr == 0);
  107. k_proj = (*k)->ne[0];
  108. batch_size = (*v)->ne[2];
  109. kv.full_k = ggml_dup(ctx, *k);
  110. kv.full_v = ggml_dup(ctx, *v);
  111. }
  112. *k = kv.full_k;
  113. *v = kv.full_v;
  114. ggml_format_name(kv.full_k, "%s.k (step=%d)", prefix.c_str(), step_nr);
  115. ggml_format_name(kv.full_v, "%s.v (step=%d)", prefix.c_str(), step_nr);
  116. step_nr += n_steps;
  117. GGML_ASSERT_SHAPE(kv.full_k, k_proj, step_nr, batch_size, 1);
  118. // qk is (B * H, Sq, Sk) == (B*H, 1, Sk) in incremental mode
  119. // we return the Sq slice of the (Sq, Sk) attention mask
  120. if (self_attn_mask != nullptr) {
  121. *self_attn_mask = ggml_slice(
  122. ctx, ggml_slice(ctx, kv.self_attn_mask, 0, 0, step_nr),
  123. 1, step_nr - 1, step_nr
  124. );
  125. }
  126. kv.step_nr = step_nr;
  127. }
  128. // variant of ggml_get_rows that allows for a with more than 2 dims.
  129. ggml_tensor* ggml_get_rows2(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  130. int flattened = 0;
  131. GGML_ASSERT(a->n_dims <= 3);
  132. if (a->n_dims == 3) {
  133. flattened = a->ne[0];
  134. a = ggml_flatten_1d(ctx, a, 0);
  135. }
  136. a = ggml_get_rows(ctx, a, b);
  137. if (flattened) {
  138. a = ggml_unflatten_1d(ctx, a, 0, flattened);
  139. }
  140. return a;
  141. }
  142. void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, ggml_tensor* new_order) {
  143. // GGML_ASSERT(ctx == kv.full_k->con);
  144. if (kv.full_k != nullptr) {
  145. ggml_detach(kv.full_k);
  146. const char* name = kv.full_k->name;
  147. kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
  148. ggml_build_forward_expand(gf, kv.full_k);
  149. ggml_format_name(kv.full_k, "%s (sorted)", name);
  150. }
  151. if (kv.full_v != nullptr) {
  152. ggml_detach(kv.full_v);
  153. const char* name = kv.full_v->name;
  154. kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
  155. ggml_build_forward_expand(gf, kv.full_v);
  156. ggml_format_name(kv.full_v, "%s (sorted)", name);
  157. }
  158. }
  159. void reorder_kv_cache(const fairseq2_model& model, ggml_context* ctx, ggml_cgraph* gf, ggml_tensor* new_order) {
  160. auto self_attn_glob = "*.self_attn";
  161. for (auto& named_kv : model.kv_cache) {
  162. if (::fnmatch(self_attn_glob, named_kv.first.c_str(), 0) == FNM_NOMATCH)
  163. continue;
  164. _reorder_kv_cache(ctx, gf, named_kv.second, new_order);
  165. }
  166. }
  167. inline double model_layer_config_d(const fairseq2_model& model, std::string name) {
  168. const std::int64_t* data = &model.layer_config.at(name);
  169. double val = *(const double*)data;
  170. return val;
  171. }
  172. extern "C" double fairseq2_model_layer_config_double(const fairseq2_model& model, const char* name) {
  173. return model_layer_config_d(model, std::string(name));
  174. }
  175. extern "C" std::int64_t fairseq2_model_layer_config_int(const fairseq2_model& model, const char* name) {
  176. return model.layer_config.at(std::string(name));
  177. }
  178. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  179. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  180. // delete model;
  181. }
  182. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  183. model->ctx = ctx;
  184. }
  /// Creates a heap-allocated std::string holding a copy of `c_str`.
  /// A null `c_str` yields an empty string (constructing std::string from a null pointer
  /// is undefined behaviour). Release the result with std_string_free.
  extern "C" std::string* std_string_alloc(char* c_str) {
      return new std::string(c_str != nullptr ? c_str : "");
  }
  /// Deletes a std::string; a null pointer is accepted and ignored.
  extern "C" void std_string_free(std::string* str) {
      if (str == nullptr) return;
      delete str;
  }
  191. bool has_layer(fairseq2_model& model, const std::string& name) {
  192. return model.tensors.find(name) != model.tensors.end();
  193. }
  194. ggml_tensor* mul_mat(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  195. if (b->ne[1] == 1 && b->ne[2] > 1 && a->n_dims == 2) {
  196. // `b` has shape (B, 1, D).
  197. // if `a` is (D_out, D), then we do one matmul for the full batch.
  198. b = ggml_flatten_1d(ctx, b, 1);
  199. return ggml_unflatten_1d(ctx, ggml_mul_mat(ctx, a, b), 1, 1);
  200. }
  201. // there is also the k * q matmul -> (D, 1, B) * (D, 1, B) -> (1, 1, B)
  202. // not sure what's the best way to compute this with BLAS
  203. return ggml_mul_mat(ctx, a, b); // (d_out)
  204. }
  205. extern "C" ggml_tensor* Linear_forward(
  206. fairseq2_model& model,
  207. const std::string &prefix,
  208. ggml_tensor* input // (d_in)
  209. ) {
  210. // Note: for now we assumed un-batched input
  211. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  212. GGML_ASSERT(weight != nullptr);
  213. ggml_tensor* out = mul_mat(model.ctx, weight, input); // (d_out)
  214. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  215. if (bias == nullptr) return out;
  216. return ggml_add(model.ctx, out, bias);
  217. }
  218. extern "C" ggml_tensor* LayerNorm_forward(
  219. fairseq2_model& model,
  220. const std::string &prefix,
  221. ggml_tensor* input
  222. ) {
  223. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  224. GGML_ASSERT(weight != nullptr);
  225. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  226. GGML_ASSERT(bias != nullptr);
  227. auto ctx = model.ctx;
  228. double eps = model_layer_config_d(model, prefix + ".eps");
  229. input = ggml_norm(ctx, input, /*eps*/eps);
  230. return ggml_add_inplace(
  231. ctx,
  232. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  233. ggml_repeat(ctx, bias, input)
  234. );
  235. }
  236. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  237. fairseq2_model& model,
  238. const std::string& prefix,
  239. ggml_tensor* seqs
  240. ) {
  241. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  242. // inner_activation = ReLu // TODO: allow other activation
  243. seqs = ggml_relu_inplace(model.ctx, seqs);
  244. if (has_layer(model, prefix + ".inner_layer_norm")) {
  245. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  246. }
  247. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  248. return seqs;
  249. }
  250. extern "C" ggml_tensor* SiluFeedForwardNetwork_forward(
  251. fairseq2_model& model,
  252. const std::string& prefix,
  253. ggml_tensor* seqs
  254. ) {
  255. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  256. seqs = ggml_silu(model.ctx, seqs);
  257. if (has_layer(model, prefix + ".inner_layer_norm")) {
  258. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  259. }
  260. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  261. return seqs;
  262. }
  263. ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
  264. int n_dims = x->n_dims;
  265. GGML_ASSERT(dim >= 0);
  266. GGML_ASSERT(dim < n_dims);
  267. GGML_ASSERT(ggml_is_contiguous(x));
  268. // Nothing to do
  269. if (dim == n_dims - 1) return x;
  270. if (n_dims == 2) {
  271. return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
  272. } else if (n_dims == 3) {
  273. if (dim == 0) {
  274. return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
  275. } else { // dim == 1
  276. return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
  277. }
  278. } else { // n_dims == 4
  279. if (dim == 0) {
  280. return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
  281. } else if (dim == 1) {
  282. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
  283. } else { // dim == 2
  284. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
  285. }
  286. }
  287. }
  288. ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
  289. int n_dims = x->n_dims;
  290. GGML_ASSERT(dim >= 0);
  291. GGML_ASSERT(dim < n_dims);
  292. GGML_ASSERT(n_dims < 4);
  293. GGML_ASSERT(x->ne[dim] % num_el == 0);
  294. GGML_ASSERT(x->nb[dim + 1] == x->nb[dim] * x->ne[dim]); // `x` isn't contiguous along `dim`
  295. if (n_dims == 1) {
  296. return ggml_view_2d(ctx, x, num_el, x->ne[0] / num_el, x->nb[0] * num_el, 0);
  297. } else if (n_dims == 2) {
  298. if (dim == 0) {
  299. return ggml_view_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->nb[0] * num_el, x->nb[1], 0);
  300. } else { // dim == 1
  301. return ggml_view_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->nb[1], num_el * x->nb[1], 0);
  302. }
  303. } else { // (n_dims == 3)
  304. if (dim == 0) {
  305. return ggml_view_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2], x->nb[0] * num_el, x->nb[1], x->nb[2], 0);
  306. } else if (dim == 1) {
  307. return ggml_view_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2], x->nb[1], num_el * x->nb[1], x->nb[2], 0);
  308. } else { // dim == 2
  309. return ggml_view_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el, x->nb[1], x->nb[2], num_el * x->nb[2], 0);
  310. }
  311. }
  312. }
  313. ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
  314. // (B, S, dim) -> (B, S, H, H_dim)
  315. x = ggml_unflatten_1d(ctx, x, 0, head_dim);
  316. x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
  317. x = ggml_cont(ctx, x);
  318. x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
  319. return x;
  320. }
  321. /// (B, Sk, dim) -> // (B?, H, H_dim, Sk)
  322. ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim ) {
  323. // (B, Sk, dim) -> (B, Sk, H, H_dim)
  324. v = ggml_unflatten_1d(ctx, v, 0, head_dim);
  325. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
  326. v = ggml_cont(ctx, v);
  327. v = ggml_flatten_1d(ctx, v, 2); // (B * H, S, H_dim)
  328. return v;
  329. }
  330. // flash_attn doesn't work for cross attention because it assumes Q <= K
  331. // and it seems to yield slightly different scores than expected, and thus a different beam search
  332. # define UNITY_FLASH_ATTN 0
  333. extern "C" ggml_tensor* MultiheadAttention_forward(
  334. fairseq2_model& model,
  335. const std::string &prefix,
  336. ggml_tensor* queries, // (slen, d_in)
  337. ggml_tensor* keys, // (klen, d_in)
  338. ggml_tensor* values, // (klen, d_out)
  339. ggml_tensor* attn_mask // (klen, slen)
  340. ) {
  341. int model_dim = queries->ne[0];
  342. int num_heads = model.layer_config.at(prefix + ".num_heads");
  343. int head_dim = model_dim / num_heads;
  344. GGML_ASSERT(model_dim % num_heads == 0);
  345. ggml_context* ctx = model.ctx;
  346. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
  347. q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
  348. ggml_set_name(q, "q");
  349. ggml_tensor *k, *v;
  350. if (!has_kv_cache(model)) {
  351. k = Linear_forward(model, prefix + ".k_proj", keys);
  352. ggml_set_name(k, "k");
  353. v = Linear_forward(model, prefix + ".v_proj", values);
  354. ggml_set_name(v, "v");
  355. } else {
  356. bool encoder_decoder_attn = keys == values && keys != queries;
  357. if (encoder_decoder_attn) {
  358. // The K and V tensors of an encoder-decoder attention (i.e. the
  359. // projected encoder outputs) remain static during evaluation.
  360. KeyValueTensor& kv_cache = model.kv_cache[prefix];
  361. if (kv_cache.step_nr == 0) {
  362. // If possible we use the ctx dedicated to kv_cache here,
  363. // because the enc dec attention is typically long lived.
  364. if (model.enc_kv_cache_ctx) model.ctx = model.enc_kv_cache_ctx;
  365. k = Linear_forward(model, prefix + ".k_proj", keys);
  366. ggml_set_name(k, "k");
  367. v = Linear_forward(model, prefix + ".v_proj", values);
  368. ggml_set_name(v, "v");
  369. // Note we are only storing a pointer to the buffer, not the full graph
  370. kv_cache.full_k = ggml_detach(ggml_dup_inplace(model.ctx, k));
  371. ggml_format_name(kv_cache.full_k, "%s.k_cache", prefix.c_str());
  372. kv_cache.full_v = ggml_detach(ggml_dup_inplace(model.ctx, v));
  373. ggml_format_name(kv_cache.full_v, "%s.v_cache", prefix.c_str());
  374. kv_cache.step_nr = keys->ne[1];
  375. model.ctx = ctx;
  376. } else {
  377. k = kv_cache.full_k;
  378. v = kv_cache.full_v;
  379. GGML_ASSERT(keys->ne[1] == k->ne[1]); // cache content doesn't match the input sequence
  380. GGML_ASSERT(values->ne[1] == v->ne[1]); // cache content doesn't match the input sequence
  381. }
  382. } else { // self attention
  383. // (1, K) -> (N, 1, K_proj)
  384. k = Linear_forward(model, prefix + ".k_proj", keys);
  385. ggml_set_name(k, "k");
  386. // (1, V) -> (N, 1, V_proj)
  387. v = Linear_forward(model, prefix + ".v_proj", values);
  388. ggml_set_name(v, "v");
  389. append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
  390. }
  391. }
  392. k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
  393. v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
  394. v = ggml_cont(ctx, v);
  395. #if UNITY_FLASH_ATTN
  396. // For flash_attn, we assume either no masks, or triangular masks.
  397. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr); // (B * H, S, H_dim)
  398. ggml_set_name(attn, "attn");
  399. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  400. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  401. #else
  402. // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
  403. ggml_tensor* qk = mul_mat(ctx, k, q);
  404. ggml_set_name(qk, "qk");
  405. FORCE_ALLOC(qk_scale, ctx, ggml_new_tensor_1d(ctx, qk->type, 1));
  406. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  407. qk = ggml_scale(ctx, qk, qk_scale);
  408. ggml_set_name(qk, "qk_scaled");
  409. if (attn_mask) qk = ggml_add_inplace(ctx, qk, attn_mask);
  410. // TODO: upgrade qk to float32 if needed
  411. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
  412. ggml_set_name(attn_weights, "attn_weights");
  413. // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
  414. ggml_tensor* attn = mul_mat(ctx, attn_weights, v);
  415. ggml_set_name(attn, "attn");
  416. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  417. attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
  418. #endif // UNITY_FLASH_ATTN
  419. attn = ggml_cont(ctx, attn);
  420. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  421. // out -> (B, S, d_out)
  422. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  423. ggml_set_name(out, "out");
  424. return out;
  425. }
  426. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  427. fairseq2_model& model,
  428. const std::string& prefix,
  429. ggml_tensor* seqs,
  430. ggml_tensor* padding_mask
  431. ) {
  432. ggml_context* ctx = model.ctx;
  433. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  434. // _forward_self_attn(seqs, padding_mask)
  435. auto residual = seqs;
  436. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  437. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  438. // TODO: add padding_mask to MultiheadAttention_forward
  439. GGML_ASSERT(padding_mask == nullptr);
  440. seqs = MultiheadAttention_forward(
  441. model,
  442. prefix + ".self_attn",
  443. seqs,
  444. seqs,
  445. seqs,
  446. /*attn_mask=*/nullptr
  447. );
  448. if (has_layer(model, prefix + ".self_attn_norm"))
  449. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  450. seqs = ggml_add_inplace(ctx, seqs, residual);
  451. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  452. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  453. // _forward_ffn(seqs)
  454. residual = seqs;
  455. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  456. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  457. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  458. // TODO: if self.residual_scale is not None:
  459. // residual = self.residual_scale * residual
  460. seqs = ggml_add_inplace(ctx, seqs, residual);
  461. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  462. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  463. return seqs;
  464. }
  465. extern "C" ggml_tensor* WaveformToFbank_forward(
  466. fairseq2_model& model,
  467. const std::string &prefix,
  468. ggml_tensor* waveform
  469. ) {
  470. // Hardcoding: num_bins 80, sample rate 16k, always standardize
  471. ggml_context* ctx = model.ctx;
  472. knf::MelBanksOptions mel_opts{};
  473. mel_opts.num_bins = 80;
  474. knf::FrameExtractionOptions frame_opts{};
  475. frame_opts.samp_freq = 16000;
  476. knf::FbankOptions opts{};
  477. opts.frame_opts = frame_opts;
  478. opts.mel_opts = mel_opts;
  479. std::vector<float_t> signal_frame{};
  480. std::int32_t num_frames = knf::NumFrames(/*num_samples=*/waveform->ne[0], frame_opts);
  481. FORCE_ALLOC(output, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 80, num_frames));
  482. knf::FbankComputer native_(opts);
  483. knf::FeatureWindowFunction window_fn_(native_.GetFrameOptions());
  484. for (std::int32_t frame_nr = 0; frame_nr < num_frames; ++frame_nr) {
  485. signal_frame.resize(0);
  486. // Extract the frame from the waveform tensor.
  487. knf::ExtractWindow(
  488. /*sample_offset=*/0,
  489. (float *)(waveform->data),
  490. waveform->ne[0],
  491. frame_nr,
  492. frame_opts,
  493. window_fn_,
  494. &signal_frame);
  495. native_.Compute(
  496. /*signal_raw_log_energy=*/0, /*vtln_warp=*/1.0, &signal_frame, ((float *)(output->data) + frame_nr * 80));
  497. }
  498. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  499. output = ggml_norm(ctx, output, 1e-5);
  500. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  501. if (output->ne[1] % 2 == 1) {
  502. output = ggml_dup(ctx, ggml_slice(ctx, output, 1, 0, output->ne[1]-1));
  503. }
  504. output = ggml_reshape_2d(ctx, output, output->ne[0] * 2, output->ne[1] / 2);
  505. return output;
  506. }
  507. // TODO: Check if it's possible to merge with standard MHA
  508. extern "C" ggml_tensor* RelativePositionMHA_forward(
  509. fairseq2_model& model,
  510. const std::string& prefix,
  511. ggml_tensor* seqs
  512. ) {
  513. ggml_context* ctx = model.ctx;
  514. ggml_tensor* residual = seqs;
  515. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  516. // self_attn: qkv
  517. ggml_tensor* Qcur = Linear_forward(model, prefix + ".q_proj", seqs);
  518. ggml_tensor* Kcur = Linear_forward(model, prefix + ".k_proj", seqs);
  519. ggml_tensor* Vcur = Linear_forward(model, prefix + ".v_proj", seqs);
  520. // self_attn: rel_pos SDPA
  521. int32_t S = seqs->ne[1];
  522. int32_t H = 16; // TODO: Make this configurable
  523. int32_t n_ctx = 4096;
  524. int32_t K_h = seqs->ne[0] / H;
  525. int32_t start_index = n_ctx - S;
  526. int32_t end_index = n_ctx + S - 1;
  527. int num_indices = end_index - start_index;
  528. FORCE_ALLOC(rows, ctx, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices));
  529. for (int i = 0; i < num_indices; i++) {
  530. ((int32_t *)rows->data)[i] = start_index + i;
  531. }
  532. // self_attn: load pos_enc weights & compute_r
  533. // In fairseq2 pos_enc weights are calculated on the fly, since some more custom operators might be needed to enable this,
  534. // we store the results (fixed) in checkpoint as model.audio_enc_pos_enc_w and load directly.
  535. ggml_tensor* r = ggml_get_rows(ctx, model.tensors["speech_encoder.pos_enc"], rows);
  536. r = mul_mat(ctx, model.tensors[prefix + ".sdpa.r_proj.weight"], r);
  537. r = ggml_dup(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, r, 0, K_h), 0, 2, 1, 3));
  538. ggml_tensor* u_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.u_bias"], K_h, 1, H);
  539. ggml_tensor* v_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.v_bias"], K_h, 1, H);
  540. // self_attn: Permute QKV
  541. // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  542. ggml_tensor* Q = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Qcur, 0, K_h), 0, 2, 1, 3));
  543. // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  544. ggml_tensor* K = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Kcur, 0, K_h), 0, 2, 1, 3));
  545. // (H * K_h, S) -> (K_h, H, S) -> (H, S, K_h)
  546. ggml_tensor* V = ggml_cont(ctx, ggml_permute(ctx, ggml_unflatten_1d(ctx, Vcur, 0, K_h), 1, 2, 0, 3));
  547. ggml_tensor* q_with_u_bias = ggml_add_inplace(ctx, ggml_dup(ctx, Q), u_bias); // (K_h, S, H)
  548. ggml_tensor* q_with_v_bias = ggml_add_inplace(ctx, Q, v_bias); // (K_h, S, H)
  549. ggml_tensor* ac = mul_mat(ctx, K, q_with_u_bias);
  550. ggml_tensor* bd = mul_mat(ctx, r, q_with_v_bias);
  551. // self_attn: shift_bd. Logic follows https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L161
  552. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // H, S, 2S-1
  553. FORCE_ALLOC(pad, ctx, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, S, 1));
  554. pad = ggml_set_f32(pad, 0.0);
  555. bd = ggml_concat(ctx, pad, bd); // bd[i][j][0] == 0, (H, S, 2S)
  556. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // (2S, S, H)
  557. bd = ggml_reshape_3d(ctx, bd, S, 2 * S, H); // (S, 2S, H)
  558. // discard the first set of positive positions
  559. bd = ggml_dup(ctx, ggml_slice(ctx, bd, 1, 1, 2 * S));
  560. // shifts each row by an extra step
  561. bd = ggml_reshape_3d(ctx, bd, 2 * S - 1, S, H);
  562. // Discard positions used for shift.
  563. bd = ggml_slice(ctx, bd, 0, 0, S);
  564. // self_attn: compute attn / weights
  565. ggml_tensor* attn_weights = ggml_add_inplace(ctx, ac, bd);
  566. FORCE_ALLOC(attn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  567. ggml_set_f32(attn_scale, 1.0 / pow(K_h, 0.5));
  568. attn_weights = ggml_mul_inplace(ctx, attn_weights, ggml_repeat(ctx, attn_scale, attn_weights));
  569. attn_weights = ggml_soft_max(ctx, attn_weights);
  570. ggml_tensor* attn = mul_mat(ctx, V, attn_weights); // K_h, S, H
  571. attn = ggml_dup(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));
  572. ggml_tensor* attn_2d = ggml_reshape_2d(ctx, attn, K_h * H, S);
  573. ggml_tensor* attn_out = mul_mat(ctx, model.tensors[prefix + ".output_proj.weight"], attn_2d);
  574. attn_out = ggml_add_inplace(
  575. ctx,
  576. attn_out,
  577. ggml_repeat(ctx, model.tensors[prefix + ".output_proj.bias"], attn_out)
  578. );
  579. attn_out = ggml_add_inplace(ctx, attn_out, residual);
  580. return attn_out;
  581. }
  582. extern "C" ggml_tensor* ConvModule_forward(
  583. fairseq2_model& model,
  584. const std::string& prefix,
  585. ggml_tensor* seqs
  586. ) {
  587. ggml_context* ctx = model.ctx;
  588. ggml_tensor* residual = seqs;
  589. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  590. // conv: Use matmul for pointwise conv 1 - kernel_size=1, no padding case
  591. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv1.weight"], seqs);
  592. // conv: GLU
  593. seqs = ggml_glu(ctx, seqs);
  594. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  595. // S x C -> (S+K-1) x C -> K x S x C -> S x C
  596. int K = model.tensors[prefix + ".depthwise_conv.weight"]->ne[0];
  597. seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".depthwise_conv.weight"], seqs, 1, K / 2, 1, seqs->ne[1]);
  598. // conv: Custom implementation of batch norm
  599. seqs = ggml_batch_norm(ctx, seqs, model.tensors[prefix + ".batch_norm.weight"], model.tensors[prefix + ".batch_norm.bias"], model.tensors[prefix + ".batch_norm.running_mean"], model.tensors[prefix + ".batch_norm.running_var"], 1e-5);
  600. // conv: SiLU actvation
  601. seqs = ggml_silu_inplace(ctx, seqs);
  602. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  603. // conv: Use matmul for pointwise conv 2 - kernel_size=1, no padding case
  604. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv2.weight"], seqs);
  605. // conv: + residual
  606. seqs = ggml_add_inplace(ctx, seqs, residual);
  607. return seqs;
  608. }
  609. extern "C" ggml_tensor* StandardConformerEncoderLayer_forward(
  610. fairseq2_model& model,
  611. const std::string& prefix,
  612. ggml_tensor* seqs,
  613. ggml_tensor* padding_mask
  614. ) {
  615. ggml_context* ctx = model.ctx;
  616. FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  617. ggml_set_f32(ffn_scale, 0.5f);
  618. ggml_tensor* residual = seqs;
  619. seqs = LayerNorm_forward(model, prefix + ".ffn1_layer_norm", seqs);
  620. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn1", seqs);
  621. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  622. seqs = ggml_add_inplace(ctx, seqs, residual);
  623. seqs = RelativePositionMHA_forward(model, prefix + ".self_attn", seqs);
  624. seqs = ConvModule_forward(model, prefix + ".conv", seqs);
  625. residual = seqs;
  626. seqs = LayerNorm_forward(model, prefix + ".ffn2_layer_norm", seqs);
  627. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn2", seqs);
  628. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  629. seqs = ggml_add_inplace(ctx, seqs, residual);
  630. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  631. return seqs;
  632. }
  633. extern "C" ggml_tensor* StandardConformerEncoder_forward(
  634. fairseq2_model& model,
  635. const std::string& prefix,
  636. ggml_tensor* seqs,
  637. ggml_tensor* padding_mask
  638. ) {
  639. ggml_context* ctx = model.ctx;
  640. seqs = WaveformToFbank_forward(model, prefix, seqs);
  641. seqs = LayerNorm_forward(model, prefix + "_frontend.post_extract_layer_norm", seqs);
  642. seqs = Linear_forward(model, prefix + "_frontend.model_dim_proj", seqs);
  643. int layer_idx = 0;
  644. std::string layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  645. while (has_layer(model, layer_name)) {
  646. seqs = StandardConformerEncoderLayer_forward(
  647. model, layer_name, seqs, padding_mask
  648. );
  649. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  650. layer_idx += 1;
  651. layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  652. }
  653. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  654. ggml_tensor* residual = seqs;
  655. seqs = Linear_forward(model, prefix + ".proj1", seqs);
  656. seqs = ggml_relu_inplace(ctx, seqs);
  657. seqs = Linear_forward(model, prefix + ".proj2", seqs);
  658. FORCE_ALLOC(ffn_scale, ctx, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1));
  659. ggml_set_f32(ffn_scale, 0.5f);
  660. seqs = ggml_mul(ctx, ggml_repeat(ctx, ffn_scale, seqs), seqs);
  661. seqs = ggml_add_inplace(ctx, seqs, residual);
  662. layer_idx = 0;
  663. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  664. while (has_layer(model, layer_name)) {
  665. seqs = StandardConformerEncoderAdaptorLayer_forward(
  666. model, layer_name, seqs, padding_mask
  667. );
  668. ggml_set_name(seqs, ("x_ada_" + std::to_string(layer_idx)).c_str());
  669. layer_idx += 1;
  670. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  671. }
  672. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  673. return seqs;
  674. }
  675. extern "C" ggml_tensor* StandardConformerEncoderAdaptorLayer_forward(
  676. fairseq2_model& model,
  677. const std::string& prefix,
  678. ggml_tensor* seqs,
  679. ggml_tensor* padding_mask
  680. ) {
  681. ggml_context* ctx = model.ctx;
  682. ggml_tensor* residual = seqs;
  683. residual = LayerNorm_forward(model, prefix + ".residual_layer_norm", residual);
  684. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  685. residual = ggml_conv_1d(ctx, model.tensors[prefix + ".residual_conv.weight"], residual, 8, 4, 1, 1);
  686. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  687. residual = ggml_add_inplace(ctx, ggml_repeat(ctx, model.tensors[prefix + ".residual_conv.bias"], residual), residual);
  688. residual = ggml_glu(ctx, residual);
  689. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  690. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  691. seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".self_attn_conv.weight"], seqs, 8, 4, 1, 1);
  692. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  693. seqs = ggml_add_inplace(ctx, seqs, ggml_repeat(ctx, model.tensors[prefix + ".self_attn_conv.bias"], seqs));
  694. seqs = ggml_glu(ctx, seqs);
  695. seqs = MultiheadAttention_forward(
  696. model,
  697. prefix + ".self_attn",
  698. seqs,
  699. seqs,
  700. seqs,
  701. /*attention masks=*/nullptr
  702. );
  703. seqs = ggml_add_inplace(ctx, seqs, residual);
  704. residual = seqs;
  705. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  706. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  707. seqs = ggml_add_inplace(ctx, seqs, residual);
  708. return seqs;
  709. }
  710. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  711. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  712. ggml_tensor* ggml_slice(
  713. struct ggml_context * ctx,
  714. struct ggml_tensor * a,
  715. int axis,
  716. int64_t start,
  717. int64_t end
  718. ) {
  719. int64_t ne[4];
  720. std::copy(a->ne, a->ne + 4, ne);
  721. if (axis < 0) axis = a->n_dims + axis;
  722. if (start < 0) start = ne[axis] + start;
  723. if (end <= 0) end = ne[axis] + end;
  724. GGML_ASSERT(0 <= start);
  725. GGML_ASSERT(start < end);
  726. GGML_ASSERT(end <= ne[axis]);
  727. ne[axis] = end - start;
  728. std::size_t offset = a->nb[axis] * start;
  729. std::size_t* nb = a->nb;
  730. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  731. ggml_format_name(result, "%s [(%d)%ld:%ld]", a->name, axis, start, end);
  732. result->n_dims = a->n_dims;
  733. return result;
  734. }
  735. ggml_tensor* ggml_select(
  736. struct ggml_context * ctx,
  737. struct ggml_tensor * a,
  738. int axis,
  739. int64_t index
  740. ) {
  741. int64_t ne[GGML_MAX_DIMS];
  742. std::copy(a->ne, a->ne + GGML_MAX_DIMS, ne);
  743. if (axis < 0) axis = a->n_dims + axis;
  744. if (index < 0) index = ne[axis] + index;
  745. GGML_ASSERT(0 <= index);
  746. GGML_ASSERT(index < ne[axis]);
  747. std::copy(a->ne + axis + 1, a->ne + GGML_MAX_DIMS, ne + axis);
  748. std::size_t offset = a->nb[axis] * index;
  749. std::size_t* nb = a->nb;
  750. GGML_ASSERT(GGML_MAX_DIMS == 4);
  751. ggml_tensor* result = ggml_view_3d(ctx, a, ne[0], ne[1], ne[2], nb[1], nb[2], offset);
  752. ggml_format_name(result, "%s [(%d)%ld]", a->name, axis, index);
  753. result->n_dims = a->n_dims - 1;
  754. return result;
  755. }
  756. // Inplace computation of PositionalEmbedding
  757. extern "C" ggml_tensor* PositionalEmbedding_forward(
  758. fairseq2_model& model,
  759. const std::string& prefix,
  760. ggml_tensor* embeds
  761. ) {
  762. // This only work with the simple pos encoders
  763. int seq_len = embeds->ne[1];
  764. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  765. int start_step = 0;
  766. if (has_kv_cache(model)) {
  767. start_step = model.kv_cache[prefix].step_nr++;
  768. }
  769. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, start_step, seq_len + start_step);
  770. return ggml_add(model.ctx, embeds, pos_embeds);
  771. }
  772. extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
  773. fairseq2_model& model,
  774. const std::string& prefix,
  775. ggml_tensor* seqs
  776. ) {
  777. GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
  778. ggml_context* ctx = model.ctx;
  779. ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
  780. GGML_ASSERT(embed_weights != nullptr);
  781. ggml_tensor* embeds;
  782. if (seqs->n_dims == 1) {
  783. embeds = ggml_get_rows(ctx, embed_weights, seqs);
  784. } else {
  785. // ggml_get_rows isn't very flexible, we have to handle the reshape ourselves.
  786. ggml_tensor* flat_seqs = seqs;
  787. if (!ggml_is_contiguous(seqs)) {
  788. flat_seqs = ggml_cont(ctx, flat_seqs);
  789. }
  790. flat_seqs = ggml_reshape_1d(ctx, flat_seqs, ggml_nelements(seqs));
  791. embeds = ggml_get_rows(ctx, embed_weights, flat_seqs);
  792. embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
  793. embeds->n_dims = seqs->n_dims + 1;
  794. }
  795. // padding mask ?
  796. // padding_mask = to_padding_mask(embeds, seq_lens)
  797. if (has_layer(model, prefix + ".pos_encoder")) {
  798. embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
  799. }
  800. if (has_layer(model, prefix + ".layer_norm")) {
  801. embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
  802. }
  803. return embeds;
  804. }
  805. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  806. fairseq2_model& model,
  807. const std::string& prefix,
  808. ggml_tensor* seqs,
  809. ggml_tensor* padding_mask
  810. ) {
  811. int layer_idx = 0;
  812. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  813. while (has_layer(model, layer_name)) {
  814. seqs = StandardTransformerEncoderLayer_forward(
  815. model, layer_name, seqs, padding_mask
  816. );
  817. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  818. layer_idx += 1;
  819. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  820. }
  821. if (has_layer(model, prefix + ".layer_norm"))
  822. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  823. return seqs;
  824. }
  825. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  826. fairseq2_model& model,
  827. const std::string& prefix,
  828. ggml_tensor* seqs,
  829. ggml_tensor* self_attn_mask,
  830. ggml_tensor* encoder_output,
  831. ggml_tensor* encoder_padding_mask
  832. ) {
  833. ggml_context* ctx = model.ctx;
  834. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  835. // _forward_self_attn(seqs, padding_mask)
  836. auto residual = seqs;
  837. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  838. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  839. seqs = MultiheadAttention_forward(
  840. model,
  841. prefix + ".self_attn",
  842. seqs,
  843. seqs,
  844. seqs,
  845. /*attn_mask=*/self_attn_mask
  846. );
  847. if (has_layer(model, prefix + ".self_attn_norm"))
  848. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  849. seqs = ggml_add_inplace(ctx, seqs, residual);
  850. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  851. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  852. // _forward_encoder_decoder_attn
  853. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  854. // `encoder_output` must be `None` for decoder-only attention.
  855. GGML_ASSERT(encoder_output == nullptr);
  856. return seqs;
  857. }
  858. // `encoder_output` must not be `None` for encoder-decoder attention.
  859. GGML_ASSERT(encoder_output != nullptr);
  860. residual = seqs;
  861. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  862. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  863. seqs = MultiheadAttention_forward(
  864. model,
  865. prefix + ".encoder_decoder_attn",
  866. seqs,
  867. encoder_output,
  868. encoder_output,
  869. /*attention masks=*/encoder_padding_mask
  870. );
  871. seqs = ggml_add_inplace(ctx, seqs, residual);
  872. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  873. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  874. // _forward_ffn(seqs)
  875. residual = seqs;
  876. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  877. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  878. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  879. // TODO:
  880. // if self.residual_scale is not None:
  881. // residual = self.residual_scale * residual
  882. seqs = ggml_add_inplace(ctx, seqs, residual);
  883. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  884. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  885. return seqs;
  886. }
  887. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  888. auto seq_len = seqs->ne[1];
  889. // TODO: allow other ggml_type
  890. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  891. return ggml_diag_mask_inf(ctx, mask, 0);
  892. }
  893. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  894. fairseq2_model& model,
  895. const std::string& prefix,
  896. ggml_tensor* seqs,
  897. ggml_tensor* padding_mask,
  898. ggml_tensor* encoder_output,
  899. ggml_tensor* encoder_padding_mask
  900. ) {
  901. int layer_idx = 0;
  902. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  903. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  904. while (has_layer(model, layer_name)) {
  905. seqs = StandardTransformerDecoderLayer_forward(
  906. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  907. );
  908. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  909. layer_idx += 1;
  910. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  911. }
  912. if (has_layer(model, prefix + ".layer_norm"))
  913. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  914. return seqs;
  915. }
  916. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  917. auto opts = job.opts;
  918. int max_seq_len = -1;
  919. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  920. max_seq_len = opts.hard_max_seq_len;
  921. } else {
  922. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len) + opts.soft_max_seq_len_b);
  923. }
  924. if (opts.min_seq_len > max_seq_len) {
  925. printf(
  926. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  927. opts.min_seq_len,
  928. max_seq_len
  929. );
  930. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  931. }
  932. int prefix_seq_len = job.prefix_seq->ne[0];
  933. if (prefix_seq_len >= max_seq_len) {
  934. printf(
  935. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  936. prefix_seq_len,
  937. max_seq_len
  938. );
  939. GGML_ASSERT(prefix_seq_len < max_seq_len);
  940. }
  941. return max_seq_len;
  942. }
  943. void _fan_out_encoder_output(
  944. ggml_context* ctx,
  945. ggml_tensor** encoder_output_out,
  946. ggml_tensor** encoder_padding_mask_out,
  947. int beam_size
  948. ) {
  949. // (S_enc, M)
  950. ggml_tensor* encoder_output = *encoder_output_out;
  951. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  952. // (B, S_enc, M)
  953. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  954. // (S_enc, M) -> (B, S_enc, M)
  955. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  956. // (S_enc) -> (B, S_enc)
  957. if (encoder_padding_mask != nullptr) {
  958. ggml_tensor* shape_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], 1, beam_size);
  959. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  960. }
  961. }
  962. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  963. // TODO: this isn't the most precise way of doing this
  964. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  965. }
  966. ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
  967. ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
  968. ggml_type true_type = x->type;
  969. ggml_tensor* y = ggml_repeat(ctx, x, shape);
  970. y->type = true_type;
  971. return y;
  972. }
  973. void _bootstrap_seqs_and_scores(
  974. fairseq2_model& model,
  975. const SequenceGeneratorJob& job,
  976. ggml_tensor* full_seqs,
  977. ggml_tensor* scores,
  978. ggml_tensor* encoder_output,
  979. ggml_tensor* encoder_padding_mask,
  980. ggml_tensor* lid_scores,
  981. int n_threads,
  982. const std::vector<int>& lang_ids
  983. ) {
  984. // Returns LID score map
  985. int prefix_seq_len = job.prefix_seq->ne[0];
  986. int max_seq_len = scores->ne[0];
  987. int beam_size = scores->ne[1];
  988. GGML_ASSERT(prefix_seq_len > 0);
  989. ggml_context* ctx = model.ctx;
  990. if (prefix_seq_len == 1) {
  991. // We only have one token in prefix, we won't compute decoding scores,
  992. // we just need to copy the token to seqs.
  993. // Note: it also means the enc_kv_cache will be populated later.
  994. ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
  995. ggml_set_i32(seqs, ggml_get_i32_1d(job.prefix_seq, 0));
  996. return;
  997. }
  998. // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  999. ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
  1000. seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);
  1001. // We have to bootstrap the model with the already fanned-out encoder
  1002. // output to correctly initialize its incremental state.
  1003. // Note: we don't start decoding the last prefix token just yet.
  1004. seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);
  1005. // Bootstrap the model state with prefix sequence.
  1006. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  1007. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1008. model,
  1009. "text_decoder",
  1010. seqs,
  1011. /*padding_mask*/ nullptr,
  1012. encoder_output,
  1013. encoder_padding_mask
  1014. );
  1015. // logits, lprobs: (N, S_pfx - 1, V)
  1016. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  1017. int vocab_size = logits->ne[0];
  1018. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
  1019. struct ggml_cgraph * gf = ggml_new_graph(ctx);
  1020. ggml_build_forward_expand(gf, lprobs);
  1021. ggml_graph_compute_with_ctx(ctx, gf, n_threads);
  1022. full_seqs->type = GGML_TYPE_I32;
  1023. job.prefix_seq->type = GGML_TYPE_I32;
  1024. // For LID
  1025. for (std::size_t i = 0; i < lang_ids.size(); ++i) {
  1026. ggml_set_f32_1d(lid_scores, i, std::exp(ggml_get_f32_1d(lprobs, lang_ids[i])));
  1027. }
  1028. // Fetch scores of next steps from "lprobs"
  1029. float p_score = 0;
  1030. for (int i = 1; i < prefix_seq_len; ++i) {
  1031. int p = 0;
  1032. if (ggml_get_i32_1d(job.prefix_seq, i) == model.vocab.token_to_id["<unk>"]) {
  1033. // If tgt_lang is unk, use the most probable lang tag predicted by model
  1034. int max_value = std::numeric_limits<float>::min();
  1035. for (std::size_t j = 0; j < lang_ids.size(); j++) {
  1036. if(ggml_get_f32_1d(lprobs, lang_ids[j]) > max_value) {
  1037. max_value = ggml_get_f32_1d(lprobs, lang_ids[j]);
  1038. p = lang_ids[j];
  1039. }
  1040. }
  1041. } else {
  1042. p = ggml_get_i32_1d(job.prefix_seq, i);
  1043. }
  1044. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  1045. for (int b = 0; b < beam_size; ++b) {
  1046. // scores: (N, S)
  1047. // Note: First step (e.g. BOS)'s score is always 0.
  1048. ggml_set_f32_1d(scores, b * max_seq_len + i, p_score);
  1049. }
  1050. }
  1051. }
  1052. /// Finds the topk indices, and write the winning indices in "candidate_indices" array.
  1053. int topk(
  1054. ggml_tensor* lprobs, // (B, V)
  1055. std::int64_t k,
  1056. ggml_tensor* candidate_indices
  1057. ) {
  1058. // Take the best 2 x `beam_size` predictions. We'll choose the first
  1059. // `beam_size` of these which don't predict EOS to continue with.
  1060. // (N, 2 x B)
  1061. // `vocab_size` - 1 to never select PAD.
  1062. std::int64_t K = std::min(k, ggml_nelements(lprobs));
  1063. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  1064. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  1065. };
  1066. GGML_ASSERT(ggml_nelements(candidate_indices) >= k);
  1067. auto cand = (std::int32_t*)candidate_indices->data;
  1068. std::partial_sort(cand, cand + K, cand + ggml_nelements(lprobs), comp);
  1069. return K;
  1070. }
  1071. void _tweak_lprobs(const SequenceGeneratorJob& job, ggml_tensor* lprobs, int step_nr, int max_seq_len, std::size_t vocab_size) {
  1072. std::size_t beam_size = job.opts.beam_size;
  1073. std::size_t eos_idx = job.eos_idx;
  1074. // Do not allow EOS before reaching the minimum sequence length.
  1075. if (step_nr < job.opts.min_seq_len) {
  1076. // lprobs[:, :, self.eos_idx] = -INFINITY;
  1077. for (std::size_t i = 0; i < beam_size; ++i)
  1078. ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
  1079. }
  1080. // If we have reached the maximum length, force the last step to be EOS.
  1081. if (step_nr == max_seq_len - 2) {
  1082. // lprobs[:, :, : self.eos_idx] = -torch.inf
  1083. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  1084. for (std::size_t b = 0; b < beam_size; ++b) {
  1085. std::size_t t = 0;
  1086. for (t = 0; t < eos_idx; ++t)
  1087. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1088. for (t = eos_idx + 1; t < vocab_size; ++t)
  1089. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1090. }
  1091. }
  1092. // Never allow PAD.
  1093. std::size_t pad_idx = job.pad_idx;
  1094. for (std::size_t i = 0; i < beam_size; ++i)
  1095. ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);
  1096. // Apply UNK penalty.
  1097. if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
  1098. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  1099. auto lprobs_raw = ggml_get_data_f32(lprobs);
  1100. for (std::size_t i = 0; i < beam_size; ++i)
  1101. lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
  1102. }
  1103. }
  1104. /// Copies the sequence and scores of a given candidate beam.
  1105. void _finalize_hypothesis(
  1106. const SequenceGeneratorJob& job,
  1107. ggml_context* ctx,
  1108. int step_nr,
  1109. std::int32_t beam,
  1110. std::int32_t token,
  1111. float eos_score,
  1112. ggml_tensor* seqs, // (beam_size, seq_len)
  1113. ggml_tensor* scores, // (beam_size, seq_len)
  1114. ggml_tensor* lid_scores,
  1115. Hypothesis* hypothesis
  1116. ) {
  1117. ggml_tensor* seq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  1118. hypothesis->seq = seq;
  1119. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  1120. hypothesis->step_scores = step_scores;
  1121. auto tok = (std::int32_t*)seq->data;
  1122. for (int i = 0; i < step_nr + 1; ++i) {
  1123. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  1124. }
  1125. tok[step_nr + 1] = token;
  1126. // Convert from cumulative to per-step scores.
  1127. auto sc = (float*)step_scores->data;
  1128. float last_score = eos_score;
  1129. for (int i = step_nr; i >= 0; --i) {
  1130. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  1131. sc[i + 1] = last_score - sc0;
  1132. last_score = sc0;
  1133. }
  1134. sc[0] = 0;
  1135. if (job.opts.normalize_scores)
  1136. // Skip first EOS since it is always 0 and skews normalization.
  1137. eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  1138. hypothesis->score = eos_score;
  1139. hypothesis->lid_scores = lid_scores;
  1140. }
  // Uses ggml_context to store any object: allocates an I8 tensor of
  // `sizeof(Type) * n` bytes in `ctx` and returns its data pointer as `Type*`.
  // The memory is not initialised by this macro.
  // `ctx` and `n` are parenthesised so that expressions such as `a + b` are
  // sized correctly, and the expansion carries no trailing `;` so that the
  // macro can be used anywhere an expression is expected.
  #define GGML_CTX_ALLOC(ctx, Type, n) \
      ((Type*)(ggml_new_tensor_1d((ctx), GGML_TYPE_I8, sizeof(Type) * (n))->data))
  1144. ggml_context* ctx_from_buffer(std::vector<uint8_t>& buffer) {
  1145. return ggml_init({
  1146. /*.mem_size =*/ static_cast<std::size_t>(buffer.capacity()),
  1147. /*.mem_buffer =*/ buffer.data(),
  1148. /*.no_alloc =*/ false,
  1149. });
  1150. }
  1151. ggml_allocr* new_arena_allocr(std::vector<uint8_t>& buffer) {
  1152. return ggml_allocr_new(buffer.data(), buffer.capacity(), 8);
  1153. }
  1154. /// Generates a translation for a single sequence
  1155. /// The results Hypothesis are written inside `result_ctx`.
  1156. extern "C" Hypothesis* generate_sequence(
  1157. fairseq2_model& model,
  1158. const SequenceGeneratorJob& job,
  1159. ggml_tensor* encoder_output,
  1160. ggml_tensor* encoder_padding_mask,
  1161. ggml_context* result_ctx,
  1162. int n_threads
  1163. ) {
  1164. // Pre allocate memory buffers.
  1165. // * step_ctx: contains metadata for the model graph, as well as some explicit
  1166. // buffers for the lprobs tweaking.
  1167. // * prev_step_ctx: is an additional buffer because we need some results from previous steps,
  1168. // to compute next step. Notably self attention kv cache.
  1169. // * search_ctx contains tensors that should live for the full search,
  1170. // like encoder kv cache.
  1171. // * step_alloc contains buffer for the forward pass of the model.
  1172. // Split mem_mb into the different context we need to use.
  1173. int mem_mb = job.opts.mem_mb;
  1174. std::vector<uint8_t> local_bufs[4] = {
  1175. std::vector<uint8_t>(mem_mb * MB * 3 / 10), // step_ctx
  1176. std::vector<uint8_t>(mem_mb * MB * 3 / 10), // prev_step_ctx
  1177. std::vector<uint8_t>(mem_mb * MB * 3 / 10), // search_ctx
  1178. std::vector<uint8_t>(mem_mb * MB * 1 / 10), // step_alloc
  1179. };
  1180. ggml_allocr* step_alloc = new_arena_allocr(local_bufs[3]);
  1181. std::vector<int> lang_ids;
  1182. if (job.prefix_seq->ne[0] > 1) {
  1183. for (const auto& kv : model.vocab.token_to_id) {
  1184. if (kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
  1185. lang_ids.push_back(kv.second);
  1186. }
  1187. }
  1188. std::sort(lang_ids.begin(), lang_ids.end());
  1189. }
  1190. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  1191. std::size_t vocab_size = embed->ne[1];
  1192. std::size_t beam_size = job.opts.beam_size;
  1193. ggml_detach(encoder_output);
  1194. int source_seq_len = encoder_output->ne[1];
  1195. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  1196. ggml_context* search_ctx = ctx_from_buffer(local_bufs[2]);
  1197. ggml_context* original_ctx = model.ctx;
  1198. fairseq2_kv_cache_alloc(model, search_ctx, beam_size, max_seq_len);
  1199. // (S_enc, M) -> (B, S_enc, M)
  1200. model.ctx = search_ctx;
  1201. _fan_out_encoder_output(search_ctx, &encoder_output, &encoder_padding_mask, beam_size);
  1202. // Allocate results in the context provided by the caller.
  1203. ggml_set_no_alloc(result_ctx, false);
  1204. Hypothesis* finished_searches_begin = GGML_CTX_ALLOC(result_ctx, Hypothesis, beam_size);
  1205. Hypothesis* finished_searches = finished_searches_begin;
  1206. for (std::size_t i = 0; i < beam_size; ++i) finished_searches[i] = {nullptr, -INFINITY, nullptr};
  1207. Hypothesis* finished_searches_end = finished_searches + beam_size;
  1208. // Initialize buffers. (B, S)
  1209. ggml_tensor* seqs = ggml_new_tensor_2d(search_ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  1210. ggml_set_i32(seqs, 0);
  1211. ggml_set_name(seqs, "seqs_0");
  1212. ggml_tensor* scores = ggml_new_tensor_2d(search_ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  1213. ggml_set_name(scores, "scores_0");
  1214. ggml_set_f32(scores, 0.0);
  1215. int prefix_seq_len = job.prefix_seq->ne[0];
  1216. int start_step = prefix_seq_len - 1;
  1217. ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[(start_step + 1) % 2]);
  1218. ggml_context* step_ctx = ctx_from_buffer(local_bufs[start_step % 2]);
  1219. GGML_ASSERT(step_ctx != search_ctx);
  1220. model.enc_kv_cache_ctx = search_ctx;
  1221. ggml_tensor* lid_scores = ggml_new_tensor_1d(result_ctx, GGML_TYPE_F32, 1); // Dummy initialization to get rid of warnings
  1222. if (lang_ids.size()) {
  1223. lid_scores = ggml_new_tensor_1d(result_ctx, GGML_TYPE_F32, lang_ids.size());
  1224. }
  1225. // Multilingual models: Bootstrap LID scores
  1226. _bootstrap_seqs_and_scores(
  1227. model, job, seqs, scores, encoder_output, encoder_padding_mask, lid_scores, n_threads, lang_ids
  1228. );
  1229. // Holds the indices of beams (a beam can occur more than once) that we
  1230. // should continue with in the next step.
  1231. ggml_tensor* beam_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1232. ggml_tensor* next_tokens = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1233. ggml_tensor* next_scores = ggml_new_tensor_1d(search_ctx, GGML_TYPE_F32, beam_size);
  1234. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  1235. ggml_tensor* candidate_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, vocab_size * beam_size);
  1236. for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
  1237. ((int32_t *)(candidate_indices->data))[i] = i;
  1238. printf_mem_usage(search_ctx, "search_ctx");
  1239. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  1240. model.ctx = step_ctx;
  1241. ggml_set_no_alloc(step_ctx, true); // Use allocr for the model forward pass
  1242. int p = 0;
  1243. if (step_nr == start_step) {
  1244. // Find the most probable lang_tok and assign it to all beams, when prefix_seq[1] is <unk>
  1245. if (lang_ids.size() && ggml_get_i32_1d(job.prefix_seq, 1) == model.vocab.token_to_id["<unk>"]) {
  1246. float max_lprob = std::numeric_limits<float>::min();
  1247. for(std::size_t j = 0; j < lang_ids.size(); j++) {
  1248. auto val = ggml_get_f32_1d(lid_scores, j);
  1249. if (val > max_lprob) {
  1250. max_lprob = val;
  1251. p = lang_ids[j];
  1252. }
  1253. }
  1254. for (std::size_t k = 0; k < beam_size; k++) {
  1255. ggml_set_i32_1d(seqs, k * vocab_size + step_nr, p);
  1256. }
  1257. }
  1258. }
  1259. ggml_tensor* prev_token = ggml_slice(step_ctx, seqs, 0, step_nr, step_nr + 1);
  1260. ggml_tensor* decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", prev_token);
  1261. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1262. model,
  1263. "text_decoder",
  1264. decoder_input,
  1265. nullptr, // We never generate PAD.
  1266. encoder_output,
  1267. encoder_padding_mask
  1268. ); // (B, 1, D)
  1269. decoder_output = ggml_flatten_1d(step_ctx, decoder_output, 0); // (B, model_dim)
  1270. // Force logits to be allocated in step_ctx, not in step_alloc.
  1271. ggml_set_no_alloc(step_ctx, false);
  1272. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output); // (B, vocab_size)
  1273. ggml_tensor* lprobs = ggml_log_softmax(step_ctx, logits);
  1274. // Compute lprobs here so we can modify it in place in the lprob tweaking phase
  1275. // TODO: use ggml properly compute the tweaks
  1276. struct ggml_cgraph * gf = ggml_new_graph(step_ctx);
  1277. ggml_build_forward_expand(gf, lprobs);
  1278. std::size_t fwd_mem = ggml_allocr_alloc_graph(step_alloc, gf);
  1279. GGML_UNUSED(fwd_mem);
  1280. ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);
  1281. ggml_detach(lprobs);
  1282. ggml_allocr_reset(step_alloc);
  1283. #if DEBUG_MEM_USAGE
  1284. printf("beam search step %d. Graph.n_nodes: %d.\n", step_nr, gf.n_nodes);
  1285. printf(" Fwd mem: %.1fMB, reserved %.1fMb\n", fwd_mem/(double)MB, local_bufs[3].capacity()/(double)MB);
  1286. std::fill(local_bufs[3].begin(), local_bufs[3].end(), 0xAA);
  1287. #endif
  1288. _tweak_lprobs(job, lprobs, step_nr, max_seq_len, vocab_size);
  1289. ggml_tensor* last_scores = ggml_slice(step_ctx, scores, 0, step_nr, step_nr+1);
  1290. if (step_nr == start_step) {
  1291. // At the initial step, all hypotheses are equally likely, so we use
  1292. // only the first beam.
  1293. lprobs = ggml_slice(step_ctx, lprobs, 1, 0, 1);
  1294. lprobs = ggml_cont(step_ctx, lprobs);
  1295. // The first step always indicates the beginning of the sequence and has no score.
  1296. if (step_nr > 0) {
  1297. last_scores = ggml_slice(step_ctx, last_scores, 1, 0, 1);
  1298. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1299. }
  1300. } else {
  1301. // Make probabilities contain cumulative scores for each hypothesis.
  1302. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1303. }
  1304. ggml_build_forward_expand(gf, lprobs);
  1305. ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);
  1306. // Determine (beam, token) candidates for the next step.
  1307. // (N, 2 x B)
  1308. std::int64_t K = topk(
  1309. lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
  1310. );
  1311. std::size_t ongoing_beams = 0;
  1312. for (std::int32_t i = 0; i < K; ++i) {
  1313. int c = ggml_get_f32_1d(candidate_indices, i);
  1314. std::int32_t beam = c / vocab_size;
  1315. std::int32_t token = c % vocab_size;
  1316. float tok_score = ggml_get_f32_1d(lprobs, c);
  1317. // Detect beams that reached the minimum length and that end with an EOS.
  1318. bool eos = token == job.eos_idx;
  1319. eos &= tok_score != -INFINITY;
  1320. if (eos) {
  1321. _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, lid_scores, finished_searches++);
  1322. if (finished_searches == finished_searches_end)
  1323. goto end_of_beam_search;
  1324. continue;
  1325. }
  1326. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  1327. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  1328. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  1329. ongoing_beams += 1;
  1330. if (ongoing_beams >= beam_size) break;
  1331. }
  1332. // Reorder beams in the `seq` and `score` buffers. The same beam can
  1333. // be selected more than once.
  1334. // (B, S), (B) -> (B, S)
  1335. // don't use allocr API, cause it might reuse a kv cache buffer several time.
  1336. ggml_set_no_alloc(step_ctx, false);
  1337. ggml_tensor* new_seqs = ggml_get_rows(step_ctx, seqs, beam_indices);
  1338. ggml_tensor* new_scores = ggml_get_rows(step_ctx, scores, beam_indices);
  1339. struct ggml_cgraph * gf_reorder = ggml_new_graph(step_ctx);
  1340. ggml_build_forward_expand(gf_reorder, new_seqs);
  1341. ggml_build_forward_expand(gf_reorder, new_scores);
  1342. reorder_kv_cache(model, step_ctx, gf_reorder, beam_indices);
  1343. ggml_graph_compute_with_ctx(step_ctx, gf_reorder, n_threads);
  1344. seqs = ggml_detach(new_seqs);
  1345. scores = ggml_detach(new_scores);
  1346. // seqs[:, step_nr + 1] = next_tokens
  1347. // scores[:, step_nr + 1] = next_scores
  1348. for (std::size_t i = 0; i < beam_size; ++i) {
  1349. ((std::int32_t*)seqs->data)[step_nr + 1 + i * max_seq_len] = ggml_get_i32_1d(next_tokens, i);
  1350. ((float*)scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
  1351. }
  1352. printf_mem_usage(step_ctx, "step_ctx");
  1353. ggml_free(prev_step_ctx);
  1354. prev_step_ctx = step_ctx;
  1355. #if DEBUG_MEM_USAGE
  1356. std::fill(local_bufs[(step_nr + 1) % 2].begin(), local_bufs[(step_nr + 1) % 2].end(), 0xAA);
  1357. #endif
  1358. step_ctx = ctx_from_buffer(local_bufs[(step_nr + 1) % 2]);
  1359. }
  1360. end_of_beam_search:
  1361. // Ensure that hypotheses are sorted by decreasing scores before returning.
  1362. std::sort(
  1363. finished_searches_begin,
  1364. finished_searches_end,
  1365. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  1366. );
  1367. printf_mem_usage(search_ctx, "search_ctx");
  1368. fairseq2_kv_cache_reset(model);
  1369. model.ctx = original_ctx;
  1370. return finished_searches_begin;
  1371. }
  1372. extern "C" Hypothesis* _testing_return_hypothesis_ptr(ggml_context* ctx) {
  1373. Hypothesis* result = GGML_CTX_ALLOC(ctx, struct Hypothesis, 2);
  1374. result[0] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 3.14f, (ggml_tensor*)result};
  1375. ggml_set_i32_1d(result[0].seq, 0, 314);
  1376. result[1] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 4.21f, nullptr};
  1377. ggml_set_i32_1d(result[1].seq, 0, 421);
  1378. return result;
  1379. }
  1380. // SPM tokenizer
  1381. // original implementation:
  1382. // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
  1383. struct llm_symbol {
  1384. using index = int;
  1385. index prev;
  1386. index next;
  1387. const char * text;
  1388. std::size_t n;
  1389. llama_vocab::id id;
  1390. };
  1391. static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
  // Returns the number of bytes in the UTF-8 sequence introduced by `src`,
  // decided solely from the top four bits of that byte:
  //   0x0-0xB -> 1 (ASCII, and also continuation bytes), 0xC-0xD -> 2,
  //   0xE     -> 3, 0xF -> 4.
  // No validation of the following bytes is performed.
  static std::size_t utf8_len(char src) {
      static const std::size_t len_by_high_nibble[16] = {
          1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 2, 2, 3, 4,
      };
      // Go through uint8_t first so that a negative `char` does not sign-extend.
      return len_by_high_nibble[static_cast<uint8_t>(src) >> 4];
  }
  1397. struct llm_bigram_spm {
  1398. struct comparator {
  1399. bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
  1400. return (l.score < r.score) || (l.score == r.score && l.left > r.left);
  1401. }
  1402. };
  1403. using queue_storage = std::vector<llm_bigram_spm>;
  1404. using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
  1405. llm_symbol::index left;
  1406. llm_symbol::index right;
  1407. float score;
  1408. std::size_t size;
  1409. llama_vocab::id id;
  1410. };
  1411. struct llm_tokenizer_spm {
  1412. llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
  1413. void tokenize(const std::string& input_text, ggml_tensor* output) {
  1414. llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");
  1415. // split string into utf8 chars
  1416. int index = 0;
  1417. std::size_t offs = 0;
  1418. // This is kind of annoying, but needed because with SPM,
  1419. // characters following a space have a special meaning.
  1420. // And the algorithm rely on substrings to do the lookups.
  1421. std::string text = input_text;
  1422. bool need_extra_space = text.size() > 0 && text[0] != ' ';
  1423. if (need_extra_space) text = " " + text;
  1424. while (offs < text.size()) {
  1425. std::size_t len = utf8_len(text[offs]);
  1426. std::size_t n = std::min(len, text.size() - offs);
  1427. auto token = vocab.token_to_id.find(std::string(text, offs, n));
  1428. llama_vocab::id id = token == vocab.token_to_id.end() ? unk_idx : token->second;
  1429. llm_symbol sym = {
  1430. /*prev*/ index - 1,
  1431. /*next*/ offs + n == text.size() ? -1 : index + 1,
  1432. /*text*/ text.c_str() + offs,
  1433. /*n*/ n,
  1434. /*id*/ id
  1435. };
  1436. offs += n;
  1437. index++;
  1438. symbols.emplace_back(sym);
  1439. }
  1440. // seed the work queue with all possible 2-character tokens.
  1441. for (std::size_t i = 1; i < symbols.size(); ++i) {
  1442. try_add_bigram(i - 1, i);
  1443. }
  1444. // keep substituting the highest frequency pairs for as long as we can.
  1445. while (!work_queue.empty()) {
  1446. auto bigram = work_queue.top();
  1447. work_queue.pop();
  1448. auto & left_sym = symbols[bigram.left];
  1449. auto & right_sym = symbols[bigram.right];
  1450. const std::string text = std::string(left_sym.text, left_sym.n + right_sym.n);
  1451. // if one of the symbols already got merged, skip it.
  1452. if (
  1453. left_sym.n == 0
  1454. || right_sym.n == 0
  1455. || left_sym.n + right_sym.n != bigram.size
  1456. ) continue;
  1457. // merge the right sym into the left one
  1458. left_sym.n += right_sym.n;
  1459. left_sym.id = bigram.id;
  1460. right_sym.n = 0;
  1461. // remove the right sym from the chain
  1462. left_sym.next = right_sym.next;
  1463. if (right_sym.next >= 0) {
  1464. symbols[right_sym.next].prev = bigram.left;
  1465. }
  1466. // find more substitutions
  1467. try_add_bigram(left_sym.prev, bigram.left);
  1468. try_add_bigram(bigram.left, left_sym.next);
  1469. }
  1470. llama_vocab::id* out = (llama_vocab::id*)output->data;
  1471. int out_step = sizeof(llama_vocab::id) / output->nb[0];
  1472. int num_tokens = 0;
  1473. for (int i = 0; i > -1; i = symbols[i].next) {
  1474. llm_symbol& symbol = symbols[i];
  1475. *(out + num_tokens * out_step) = symbol.id;
  1476. num_tokens += 1;
  1477. }
  1478. *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
  1479. num_tokens += 1;
  1480. output->ne[0] = num_tokens;
  1481. }
  1482. private:
  1483. void try_add_bigram(int left, int right) {
  1484. if (left == -1 || right == -1) {
  1485. return;
  1486. }
  1487. const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
  1488. auto token = vocab.token_to_id.find(text);
  1489. if (token == vocab.token_to_id.end()) {
  1490. return;
  1491. }
  1492. llama_vocab::id id = token->second;
  1493. if (static_cast<std::size_t>(id) >= vocab.id_to_token.size()) {
  1494. return;
  1495. }
  1496. const auto& tok_data = vocab.id_to_token[id];
  1497. llm_bigram_spm bigram = {
  1498. /*left */ left,
  1499. /*right*/ right,
  1500. /*score*/ tok_data.score,
  1501. /*size */ text.size(),
  1502. /*id */ id
  1503. };
  1504. work_queue.push(bigram);
  1505. }
  1506. const llama_vocab& vocab;
  1507. std::vector<llm_symbol> symbols;
  1508. llm_bigram_spm::queue work_queue;
  1509. };
  1510. extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out) {
  1511. llm_tokenizer_spm spm = {model->vocab};
  1512. spm.tokenize(std::string(text), out);
  1513. }
  1514. extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out) {
  1515. bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
  1516. int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
  1517. int sent_len = tokens->ne[0];
  1518. std::size_t written = 0;
  1519. std::vector<llama_vocab::token_data> id_to_token = no_tgt_vocab ? model->vocab.id_to_token : model->tgt_vocab.id_to_token;
  1520. for (int i = 0; i < sent_len; ++i) {
  1521. int id = ggml_get_i32_1d(tokens, i);
  1522. // Don't print the EOS token but only if it appear at the end.
  1523. if (i == sent_len - 1 && eos_idx == id) break;
  1524. std::string token = no_tgt_vocab ? model->vocab.id_to_token.at(id).text : model->tgt_vocab.id_to_token.at(id).text;
  1525. // Skip the first space outputted.
  1526. auto begin = token.begin();
  1527. if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
  1528. std::copy(begin, token.end(), out);
  1529. std::size_t n = token.end() - begin;
  1530. written += n;
  1531. out += n;
  1532. }
  1533. *out = '0';
  1534. return written;
  1535. }
  1536. // TODO: Unify with the above?
  1537. std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(
  1538. fairseq2_model* model,
  1539. ggml_tensor* tokens,
  1540. ggml_tensor* scores,
  1541. char* out) {
  1542. bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
  1543. int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
  1544. int sent_len = tokens->ne[0];
  1545. std::size_t written = 0;
  1546. std::vector<float> word_scores;
  1547. std::vector<float> subword_scores;
  1548. std::vector<std::string> result_text;
  1549. std::string curr_token = "";
  1550. for (int i = 0; i < sent_len; ++i) {
  1551. int id = ggml_get_i32_1d(tokens, i);
  1552. // Don't print the EOS token but only if it appear at the end.
  1553. if (i == sent_len - 1 && eos_idx == id) break;
  1554. std::string token = no_tgt_vocab ? model->vocab.id_to_token.at(id).text : model->tgt_vocab.id_to_token.at(id).text;
  1555. float score = ggml_get_f32_1d(scores, i+2); // 2 is prefix size
  1556. if(token[0] == ' ') {
  1557. // reset word score
  1558. if(subword_scores.size() > 0) {
  1559. float avg = std::accumulate(subword_scores.begin(), subword_scores.end(), 0.0f) / subword_scores.size();
  1560. word_scores.push_back(avg);
  1561. subword_scores.clear();
  1562. result_text.push_back(curr_token);
  1563. }
  1564. curr_token = token.substr(1);
  1565. } else {
  1566. curr_token += token;
  1567. }
  1568. subword_scores.push_back(score);
  1569. // Skip the first space outputted.
  1570. auto begin = token.begin();
  1571. if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
  1572. std::copy(begin, token.end(), out);
  1573. std::size_t n = token.end() - begin;
  1574. written += n;
  1575. out += n;
  1576. }
  1577. if(subword_scores.size() > 0) {
  1578. word_scores.push_back(*std::min_element(subword_scores.begin(), subword_scores.end()));
  1579. subword_scores.clear();
  1580. result_text.push_back(curr_token);
  1581. }
  1582. *out = '0';
  1583. return std::make_pair(result_text, word_scores);
  1584. }