fairseq2.cpp 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706
  #include <algorithm>
  #include <cstring>
  #include <fnmatch.h>
  #include <iostream>
  #include <math.h>
  #include <queue>
  #include <unordered_map>
  #include "kaldi-native-fbank/csrc/feature-fbank.h"
  #include "kaldi-native-fbank/csrc/feature-window.h"
  #include "fairseq2.h"
  #include "ggml.h"
  11. ggml_tensor* ggml_detach(ggml_tensor* a) {
  12. a->op = GGML_OP_NONE;
  13. std::fill(a->src, a->src + GGML_MAX_SRC, nullptr);
  14. return a;
  15. }
  16. #define DEBUG_MEM_USAGE 0
  17. void printf_mem_usage(ggml_context* ctx, std::string name) {
  18. #if DEBUG_MEM_USAGE
  19. double mb = 1024.0 * 1024.0;
  20. printf(
  21. "ctx %s: memory used = %8.2f MB, memory reserved = %8.2f Mb\n",
  22. name.c_str(),
  23. ggml_used_mem(ctx) / mb,
  24. ggml_get_mem_size(ctx) / mb
  25. );
  26. #endif
  27. }
  // Swap two lvalues of the same copyable type through a temporary copy.
  // The expansion is a braced block, so the macro can be used several times in
  // the same scope and with any lvalue expression (e.g. `a[i]`, `p->x`), with
  // or without a trailing semicolon. Note: as with the original multi-statement
  // form, `if (c) SWAP(a, b); else ...` is not supported; use braces there.
  #define SWAP(x, y) \
      { auto swap_tmp_ = (x); (x) = (y); (y) = swap_tmp_; }
  30. /// allocate the fairseq2 model and hyperparameters
  31. extern "C" fairseq2_model* fairseq2_model_alloc() {
  32. // pre-allocate some memory to write hyperparameters and tensors pointers
  33. auto* model = new fairseq2_model;
  34. model->tensors_ctx = nullptr;
  35. return model;
  36. }
  37. extern "C" void fairseq2_kv_cache_alloc(const fairseq2_model& model, int beam_size, int max_seq_len) {
  38. // Note: we only allocate the cache for the decoder attention.
  39. // For encoder attention since we compute it all at once,
  40. // the allocation is delayed to the first forward pass, to not over allocate.
  41. auto attn_glob = "text_decoder.*_attn.k_proj.weight";
  42. auto self_attn_glob = "text_decoder.*self_attn.k_proj.weight";
  43. ggml_tensor* self_attn_mask = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, max_seq_len, max_seq_len);
  44. self_attn_mask = ggml_diag_mask_inf_inplace(model.ctx, self_attn_mask, 0);
  45. ggml_format_name(self_attn_mask, "self_attn_mask[%d]", max_seq_len);
  46. for (auto named_tensor : model.tensors) {
  47. const std::string& name = named_tensor.first;
  48. if (::fnmatch(attn_glob, name.c_str(), 0) == FNM_NOMATCH)
  49. continue;
  50. // create a cache entry without the ".k_proj.weight" suffix
  51. const std::string& shortname = name.substr(0, name.size() - 14);
  52. KeyValueTensor& kv = model.kv_cache[shortname];
  53. kv.step_nr = 0;
  54. if (::fnmatch(self_attn_glob, name.c_str(), 0) == FNM_NOMATCH) {
  55. // enc_dec_attn
  56. // the tensors will be allocated during the first forward
  57. continue;
  58. }
  59. // self_attn
  60. ggml_tensor* k_proj = named_tensor.second;
  61. int model_dim = k_proj->ne[0];
  62. kv.full_k = ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size);
  63. kv.full_v = ggml_new_tensor_3d(model.ctx, k_proj->type, model_dim, max_seq_len, beam_size);
  64. kv.self_attn_mask = self_attn_mask;
  65. ggml_format_name(kv.full_k, "%s.k_cache", shortname.c_str());
  66. ggml_format_name(kv.full_v, "%s.v_cache", shortname.c_str());
  67. }
  68. }
  69. extern "C" void fairseq2_kv_cache_reset(const fairseq2_model& model) {
  70. // TODO: use a dedicated allocator, so that kv_cache.clear actually frees the memory
  71. model.kv_cache.clear();
  72. }
  73. bool has_kv_cache(const fairseq2_model& model) {
  74. return model.kv_cache.size() > 0;
  75. }
  76. // copy k and v to kv cache
  77. // kv.full_k[step_nr] = k;
  78. // kv.full_v[step_nr] = v;
  79. void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
  80. KeyValueTensor& kv = model.kv_cache[prefix];
  81. GGML_ASSERT(kv.full_k != nullptr); // key not found !
  82. int step_nr = kv.step_nr;
  83. ggml_tensor* full_k = kv.full_k;
  84. ggml_tensor* full_v = kv.full_v;
  85. // (N, S_kv, K_proj)
  86. GGML_ASSERT((*k)->ne[1] == 1); // TODO I think we could handle adding a full prefix sequence
  87. ggml_tensor* updated_k = ggml_set_2d_inplace(model.ctx, full_k, *k, full_k->nb[2], full_k->nb[1] * step_nr);
  88. ggml_tensor* updated_v = ggml_set_2d_inplace(model.ctx, full_v, *v, full_v->nb[2], full_v->nb[1] * step_nr);
  89. *k = ggml_slice(model.ctx, updated_k, 1, 0, step_nr + 1);
  90. *v = ggml_slice(model.ctx, updated_v, 1, 0, step_nr + 1);
  91. ggml_format_name(*k, "%s (step=%d)", full_k->name, step_nr);
  92. ggml_format_name(*v, "%s (step=%d)", full_v->name, step_nr);
  93. // qk is (B * H, Sq, Sk) == (B*H, 1, Sk) in incremental mode
  94. // we return the Sq slice of the (Sq, Sk) attention mask
  95. *self_attn_mask = ggml_slice(
  96. model.ctx,
  97. ggml_slice(model.ctx, kv.self_attn_mask, 0, 0, step_nr + 1),
  98. 1,
  99. step_nr,
  100. step_nr + 1
  101. );
  102. kv.step_nr = step_nr + 1;
  103. }
  104. // variant of ggml_get_rows that allows for a with more than 2 dims.
  105. ggml_tensor* ggml_get_rows2(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  106. int flattened = 0;
  107. GGML_ASSERT(a->n_dims <= 3);
  108. if (a->n_dims == 3) {
  109. flattened = a->ne[0];
  110. a = ggml_flatten_1d(ctx, a, 0);
  111. }
  112. a = ggml_get_rows(ctx, a, b);
  113. if (flattened) {
  114. a = ggml_unflatten_1d(ctx, a, 0, flattened);
  115. }
  116. return a;
  117. }
  118. void _reorder_kv_cache(ggml_context* ctx, ggml_cgraph* gf, KeyValueTensor& kv, ggml_tensor* new_order) {
  119. if (kv.full_k != nullptr) {
  120. ggml_detach(kv.full_k);
  121. kv.full_k = ggml_get_rows2(ctx, kv.full_k, new_order);
  122. ggml_build_forward_expand(gf, kv.full_k);
  123. }
  124. if (kv.full_v != nullptr) {
  125. ggml_detach(kv.full_v);
  126. kv.full_v = ggml_get_rows2(ctx, kv.full_v, new_order);
  127. ggml_build_forward_expand(gf, kv.full_v);
  128. }
  129. }
  130. void reorder_kv_cache(const fairseq2_model& model, ggml_context* ctx, ggml_cgraph* gf, ggml_tensor* new_order) {
  131. for (auto& named_kv : model.kv_cache) {
  132. _reorder_kv_cache(ctx, gf, named_kv.second, new_order);
  133. }
  134. }
  135. inline double model_layer_config_d(const fairseq2_model& model, std::string name) {
  136. const std::int64_t* data = &model.layer_config.at(name);
  137. double val = *(const double*)data;
  138. return val;
  139. }
  140. extern "C" double fairseq2_model_layer_config_double(const fairseq2_model& model, const char* name) {
  141. return model_layer_config_d(model, std::string(name));
  142. }
  143. extern "C" std::int64_t fairseq2_model_layer_config_int(const fairseq2_model& model, const char* name) {
  144. return model.layer_config.at(std::string(name));
  145. }
  146. extern "C" void fairseq2_model_free(fairseq2_model* model) {
  147. if (model->tensors_ctx) ggml_free(model->tensors_ctx);
  148. delete model;
  149. }
  150. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
  151. model->ctx = ctx;
  152. }
  /// Heap-allocate a std::string copy of `c_str` for C callers.
  /// Release it with std_string_free. A null `c_str` yields an empty string,
  /// because constructing std::string from nullptr is undefined behavior.
  extern "C" std::string* std_string_alloc(char* c_str) {
      if (c_str == nullptr) return new std::string();
      return new std::string(c_str);
  }
  /// Release a string obtained from std_string_alloc.
  /// Like `delete`, passing nullptr does nothing.
  extern "C" void std_string_free(std::string* str) {
      delete str;
  }
  159. bool has_layer(fairseq2_model& model, const std::string& name) {
  160. return model.tensors.find(name) != model.tensors.end();
  161. }
  162. ggml_tensor* mul_mat(ggml_context* ctx, ggml_tensor* a, ggml_tensor* b) {
  163. if (b->ne[1] == 1 && b->ne[2] > 1 && a->n_dims == 2) {
  164. // `b` has shape (B, 1, D).
  165. // if `a` is (D_out, D), then we do one matmul for the full batch.
  166. b = ggml_flatten_1d(ctx, b, 1);
  167. return ggml_unflatten_1d(ctx, ggml_mul_mat(ctx, a, b), 1, 1);
  168. }
  169. // there is also the k * q matmul -> (D, 1, B) * (D, 1, B) -> (1, 1, B)
  170. // not sure what's the best way to compute this with BLAS
  171. return ggml_mul_mat(ctx, a, b); // (d_out)
  172. }
  173. extern "C" ggml_tensor* Linear_forward(
  174. fairseq2_model& model,
  175. const std::string &prefix,
  176. ggml_tensor* input // (d_in)
  177. ) {
  178. // Note: for now we assumed un-batched input
  179. ggml_tensor* weight = model.tensors[prefix + ".weight"]; // (d_in, d_out)
  180. GGML_ASSERT(weight != nullptr);
  181. ggml_tensor* out = mul_mat(model.ctx, weight, input); // (d_out)
  182. ggml_tensor* bias = model.tensors[prefix + ".bias"]; // (d_out)
  183. if (bias == nullptr) return out;
  184. return ggml_add_inplace(model.ctx, out, bias);
  185. }
  186. extern "C" ggml_tensor* LayerNorm_forward(
  187. fairseq2_model& model,
  188. const std::string &prefix,
  189. ggml_tensor* input
  190. ) {
  191. ggml_tensor* weight = model.tensors[prefix + ".weight"];
  192. GGML_ASSERT(weight != nullptr);
  193. ggml_tensor* bias = model.tensors[prefix + ".bias"];
  194. GGML_ASSERT(bias != nullptr);
  195. auto ctx = model.ctx;
  196. double eps = model_layer_config_d(model, prefix + ".eps");
  197. input = ggml_norm(ctx, input, /*eps*/eps);
  198. return ggml_add_inplace(
  199. ctx,
  200. ggml_mul_inplace(ctx, ggml_repeat(ctx, weight, input), input),
  201. ggml_repeat(ctx, bias, input)
  202. );
  203. }
  204. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  205. fairseq2_model& model,
  206. const std::string& prefix,
  207. ggml_tensor* seqs
  208. ) {
  209. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  210. // inner_activation = ReLu // TODO: allow other activation
  211. seqs = ggml_relu_inplace(model.ctx, seqs);
  212. if (has_layer(model, prefix + ".inner_layer_norm")) {
  213. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  214. }
  215. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  216. return seqs;
  217. }
  218. extern "C" ggml_tensor* SiluFeedForwardNetwork_forward(
  219. fairseq2_model& model,
  220. const std::string& prefix,
  221. ggml_tensor* seqs
  222. ) {
  223. seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
  224. seqs = ggml_silu(model.ctx, seqs);
  225. if (has_layer(model, prefix + ".inner_layer_norm")) {
  226. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  227. }
  228. seqs = Linear_forward(model, prefix + ".output_proj", seqs);
  229. return seqs;
  230. }
  231. ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim) {
  232. int n_dims = x->n_dims;
  233. GGML_ASSERT(dim >= 0);
  234. GGML_ASSERT(dim < n_dims);
  235. GGML_ASSERT(ggml_is_contiguous(x));
  236. // Nothing to do
  237. if (dim == n_dims - 1) return x;
  238. if (n_dims == 2) {
  239. return ggml_reshape_1d(ctx, x, x->ne[0] * x->ne[1]);
  240. } else if (n_dims == 3) {
  241. if (dim == 0) {
  242. return ggml_reshape_2d(ctx, x, x->ne[0] * x->ne[1], x->ne[2]);
  243. } else { // dim == 1
  244. return ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2]);
  245. }
  246. } else { // n_dims == 4
  247. if (dim == 0) {
  248. return ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
  249. } else if (dim == 1) {
  250. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
  251. } else { // dim == 2
  252. return ggml_reshape_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2] * x->ne[3]);
  253. }
  254. }
  255. }
  256. ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el) {
  257. int n_dims = x->n_dims;
  258. GGML_ASSERT(dim >= 0);
  259. GGML_ASSERT(dim < n_dims);
  260. GGML_ASSERT(n_dims < 4);
  261. GGML_ASSERT(x->ne[dim] % num_el == 0);
  262. GGML_ASSERT(x->nb[dim + 1] == x->nb[dim] * x->ne[dim]); // `x` isn't contiguous along `dim`
  263. if (n_dims == 1) {
  264. return ggml_view_2d(ctx, x, num_el, x->ne[0] / num_el, x->nb[0] * num_el, 0);
  265. } else if (n_dims == 2) {
  266. if (dim == 0) {
  267. return ggml_view_3d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->nb[0] * num_el, x->nb[1], 0);
  268. } else { // dim == 1
  269. return ggml_view_3d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->nb[1], num_el * x->nb[1], 0);
  270. }
  271. } else { // (n_dims == 3)
  272. if (dim == 0) {
  273. return ggml_view_4d(ctx, x, num_el, x->ne[0] / num_el, x->ne[1], x->ne[2], x->nb[0] * num_el, x->nb[1], x->nb[2], 0);
  274. } else if (dim == 1) {
  275. return ggml_view_4d(ctx, x, x->ne[0], num_el, x->ne[1] / num_el, x->ne[2], x->nb[1], num_el * x->nb[1], x->nb[2], 0);
  276. } else { // dim == 2
  277. return ggml_view_4d(ctx, x, x->ne[0], x->ne[1], num_el, x->ne[2] / num_el, x->nb[1], x->nb[2], num_el * x->nb[2], 0);
  278. }
  279. }
  280. }
  281. ggml_tensor* _reshape_num_head(ggml_context* ctx, ggml_tensor* x, int head_dim) {
  282. // (B, S, dim) -> (B, S, H, H_dim)
  283. x = ggml_unflatten_1d(ctx, x, 0, head_dim);
  284. x = ggml_permute(ctx, x, 0, 2, 1, 3); // (B, H, S, H_dim)
  285. x = ggml_cont(ctx, x);
  286. x = ggml_flatten_1d(ctx, x, 2); // (B * H, S, H_dim)
  287. return x;
  288. }
  289. /// (B, Sk, dim) -> // (B?, H, H_dim, Sk)
  290. ggml_tensor* _reshape_num_head_values(ggml_context* ctx, ggml_tensor* v, int head_dim ) {
  291. // (B, Sk, dim) -> (B, Sk, H, H_dim)
  292. v = ggml_unflatten_1d(ctx, v, 0, head_dim);
  293. v = ggml_permute(ctx, v, 1, 2, 0, 3); // (B?, H, H_dim, Sk)
  294. v = ggml_cont(ctx, v);
  295. v = ggml_flatten_1d(ctx, v, 2); // (B * H, S, H_dim)
  296. return v;
  297. }
  298. // flash_attn doesn't work for cross attention because it assumes Q <= K
  299. // and it seems to yield slightly different scores than expected, and thus a different beam search
  300. # define UNITY_FLASH_ATTN 0
  301. extern "C" ggml_tensor* MultiheadAttention_forward(
  302. fairseq2_model& model,
  303. const std::string &prefix,
  304. ggml_tensor* queries, // (slen, d_in)
  305. ggml_tensor* keys, // (klen, d_in)
  306. ggml_tensor* values, // (klen, d_out)
  307. ggml_tensor* attn_mask // (klen, slen)
  308. ) {
  309. int model_dim = queries->ne[0];
  310. int num_heads = model.layer_config.at(prefix + ".num_heads");
  311. int head_dim = model_dim / num_heads;
  312. GGML_ASSERT(model_dim % num_heads == 0);
  313. ggml_context* ctx = model.ctx;
  314. ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries); // (B, S, H * H_dim)
  315. ggml_set_name(q, "q");
  316. q = _reshape_num_head(ctx, q, head_dim); // (B * H, S, H_dim)
  317. ggml_tensor *k, *v;
  318. if (!has_kv_cache(model)) {
  319. k = Linear_forward(model, prefix + ".k_proj", keys);
  320. ggml_set_name(k, "k");
  321. v = Linear_forward(model, prefix + ".v_proj", values);
  322. ggml_set_name(v, "v");
  323. } else {
  324. bool encoder_decoder_attn = keys == values && keys != queries;
  325. if (encoder_decoder_attn) {
  326. // The K and V tensors of an encoder-decoder attention (i.e. the
  327. // projected encoder outputs) remain static during evaluation.
  328. KeyValueTensor& kv_cache = model.kv_cache[prefix];
  329. if (kv_cache.step_nr == 0) {
  330. k = Linear_forward(model, prefix + ".k_proj", keys);
  331. v = Linear_forward(model, prefix + ".v_proj", values);
  332. // TODO: encoder_padding_mask
  333. // Note we are only storing a pointer to the buffer, not the full graph
  334. kv_cache.full_k = ggml_detach(ggml_dup_inplace(ctx, k));
  335. ggml_format_name(kv_cache.full_k, "%s.k_cache", prefix.c_str());
  336. kv_cache.full_v = ggml_detach(ggml_dup_inplace(ctx, v));
  337. ggml_format_name(kv_cache.full_v, "%s.v_cache", prefix.c_str());
  338. kv_cache.step_nr = keys->ne[1];
  339. } else {
  340. k = kv_cache.full_k;
  341. v = kv_cache.full_v;
  342. // This is a cache collision. TODO: fairseq2_kv_cache_reset
  343. GGML_ASSERT(keys->ne[1] == k->ne[1]);
  344. GGML_ASSERT(values->ne[1] == v->ne[1]);
  345. }
  346. } else { // self attention
  347. // (1, K) -> (N, 1, K_proj)
  348. k = Linear_forward(model, prefix + ".k_proj", keys);
  349. ggml_set_name(k, "k");
  350. // (1, V) -> (N, 1, V_proj)
  351. v = Linear_forward(model, prefix + ".v_proj", values);
  352. ggml_set_name(v, "v");
  353. append_to_prev_kv(model, prefix, &k, &v, &attn_mask);
  354. }
  355. }
  356. k = _reshape_num_head(ctx, k, head_dim); // (B * H, Sk, H_dim)
  357. v = _reshape_num_head_values(ctx, v, head_dim); // (B * H, H_dim, Sk)
  358. v = ggml_cont(ctx, v);
  359. #if UNITY_FLASH_ATTN
  360. // For flash_attn, we assume either no masks, or triangular masks.
  361. ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/attn_mask != nullptr); // (B * H, S, H_dim)
  362. ggml_set_name(attn, "attn");
  363. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  364. attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (B, S, H, H_dim)
  365. #else
  366. // (B * H, Sk, H_dim) x (B * H, S, H_dim) -> (B * H, S, Sk)
  367. ggml_tensor* qk = mul_mat(ctx, k, q);
  368. ggml_set_name(qk, "qk");
  369. ggml_tensor* qk_scale = ggml_new_tensor_1d(ctx, qk->type, 1);
  370. ggml_set_f32(qk_scale, 1.0f/sqrtf(float(head_dim)));
  371. qk = ggml_scale_inplace(ctx, qk, qk_scale);
  372. ggml_set_name(qk, "qk_scaled");
  373. // TODO: Should we replace this by ggml_diag_mask_inf ?
  374. if (attn_mask) qk = ggml_add_inplace(ctx, qk, attn_mask);
  375. // TODO: upgrade qk to float32 if needed
  376. ggml_tensor* attn_weights = ggml_soft_max(ctx, qk); // (B * H, S, Sk)
  377. ggml_set_name(attn_weights, "attn_weights");
  378. // (B * H, S, Sk) x (B * H, H_dim, Sk) -> (B * H, H_dim, S)
  379. ggml_tensor* attn = mul_mat(ctx, attn_weights, v);
  380. ggml_set_name(attn, "attn");
  381. attn = ggml_unflatten_1d(ctx, attn, 2, num_heads); // (B, H, H_dim, S)
  382. attn = ggml_permute(ctx, attn, 2, 0, 1, 3); // (B, S, H, H_dim)
  383. #endif // UNITY_FLASH_ATTN
  384. attn = ggml_cont(ctx, attn);
  385. attn = ggml_flatten_1d(ctx, attn, 0); // (B, S, H * H_dim)
  386. // out -> (B, S, d_out)
  387. ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
  388. ggml_set_name(out, "out");
  389. return out;
  390. }
  391. extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
  392. fairseq2_model& model,
  393. const std::string& prefix,
  394. ggml_tensor* seqs,
  395. ggml_tensor* padding_mask
  396. ) {
  397. ggml_context* ctx = model.ctx;
  398. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  399. // _forward_self_attn(seqs, padding_mask)
  400. auto residual = seqs;
  401. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  402. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  403. // TODO: add padding_mask to MultiheadAttention_forward
  404. GGML_ASSERT(padding_mask == nullptr);
  405. seqs = MultiheadAttention_forward(
  406. model,
  407. prefix + ".self_attn",
  408. seqs,
  409. seqs,
  410. seqs,
  411. /*attn_mask=*/nullptr
  412. );
  413. if (has_layer(model, prefix + ".self_attn_norm"))
  414. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  415. seqs = ggml_add_inplace(ctx, seqs, residual);
  416. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  417. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  418. // _forward_ffn(seqs)
  419. residual = seqs;
  420. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  421. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  422. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  423. // TODO: if self.residual_scale is not None:
  424. // residual = self.residual_scale * residual
  425. seqs = ggml_add_inplace(ctx, seqs, residual);
  426. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  427. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  428. return seqs;
  429. }
  430. extern "C" ggml_tensor* WaveformToFbank_forward(
  431. fairseq2_model& model,
  432. const std::string &prefix,
  433. ggml_tensor* waveform
  434. ) {
  435. // Hardcoding: num_bins 80, sample rate 16k, always standardize
  436. ggml_context* ctx = model.ctx;
  437. knf::MelBanksOptions mel_opts{};
  438. mel_opts.num_bins = 80;
  439. knf::FrameExtractionOptions frame_opts{};
  440. frame_opts.samp_freq = 16000;
  441. knf::FbankOptions opts{};
  442. opts.frame_opts = frame_opts;
  443. opts.mel_opts = mel_opts;
  444. std::vector<float_t> signal_frame{};
  445. std::int32_t num_frames = knf::NumFrames(/*num_samples=*/waveform->ne[0], frame_opts);
  446. ggml_tensor* output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 80, num_frames);
  447. knf::FbankComputer native_(opts);
  448. knf::FeatureWindowFunction window_fn_(native_.GetFrameOptions());
  449. for (std::int32_t frame_nr = 0; frame_nr < num_frames; ++frame_nr) {
  450. signal_frame.resize(0);
  451. // Extract the frame from the waveform tensor.
  452. knf::ExtractWindow(
  453. /*sample_offset=*/0,
  454. (float *)(waveform->data),
  455. waveform->ne[0],
  456. frame_nr,
  457. frame_opts,
  458. window_fn_,
  459. &signal_frame);
  460. native_.Compute(
  461. /*signal_raw_log_energy=*/0, /*vtln_warp=*/1.0, &signal_frame, ((float *)(output->data) + frame_nr * 80));
  462. }
  463. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  464. output = ggml_norm(ctx, output, 1e-5);
  465. output = ggml_dup(ctx, ggml_transpose(ctx, output));
  466. if (output->ne[1] % 2 == 1) {
  467. ggml_tensor* remove_last = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, output->ne[1]-1);
  468. for (int i = 0; i < output->ne[1]-1; ++i) {
  469. ((int32_t *) remove_last->data)[i] = i;
  470. }
  471. output = ggml_get_rows(ctx, output, remove_last);
  472. }
  473. output = ggml_reshape_2d(ctx, output, output->ne[0] * 2, output->ne[1] / 2);
  474. return output;
  475. }
  476. // TODO: Check if it's possible to merge with standard MHA
  477. extern "C" ggml_tensor* RelativePositionMHA_forward(
  478. fairseq2_model& model,
  479. const std::string& prefix,
  480. ggml_tensor* seqs
  481. ) {
  482. ggml_context* ctx = model.ctx;
  483. ggml_tensor* residual = seqs;
  484. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  485. // self_attn: qkv
  486. ggml_tensor* Qcur = Linear_forward(model, prefix + ".q_proj", seqs);
  487. ggml_tensor* Kcur = Linear_forward(model, prefix + ".k_proj", seqs);
  488. ggml_tensor* Vcur = Linear_forward(model, prefix + ".v_proj", seqs);
  489. // self_attn: rel_pos SDPA
  490. int32_t S = seqs->ne[1];
  491. int32_t H = 16; // TODO: Make this configurable
  492. int32_t n_ctx = 4096;
  493. int32_t K_h = seqs->ne[0] / H;
  494. int32_t start_index = n_ctx - S;
  495. int32_t end_index = n_ctx + S - 1;
  496. int num_indices = end_index - start_index;
  497. ggml_tensor* rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices);
  498. for (int i = 0; i < num_indices; i++) {
  499. ((int32_t *)rows->data)[i] = start_index + i;
  500. }
  501. // self_attn: load pos_enc weights & compute_r
  502. // In fairseq2 pos_enc weights are calculated on the fly, since some more custom operators might be needed to enable this,
  503. // we store the results (fixed) in checkpoint as model.audio_enc_pos_enc_w and load directly.
  504. ggml_tensor* r = ggml_get_rows(ctx, model.tensors["speech_encoder.pos_enc"], rows);
  505. r = mul_mat(ctx, model.tensors[prefix + ".sdpa.r_proj.weight"], r);
  506. r = ggml_dup(ctx, ggml_permute(ctx,
  507. ggml_cpy(ctx,
  508. r,
  509. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S*2-1)),
  510. 0, 2, 1, 3));
  511. ggml_tensor* u_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.u_bias"], K_h, 1, H);
  512. ggml_tensor* v_bias = ggml_reshape_3d(ctx, model.tensors[prefix + ".sdpa.v_bias"], K_h, 1, H);
  513. // self_attn: Permute QKV
  514. ggml_tensor* Q = ggml_cont(ctx, ggml_permute(ctx,
  515. ggml_cpy(ctx,
  516. Qcur,
  517. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S)),
  518. 0, 2, 1, 3)); // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  519. ggml_tensor* K = ggml_cont(ctx, ggml_permute(ctx,
  520. ggml_cpy(ctx,
  521. Kcur,
  522. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S)),
  523. 0, 2, 1, 3)); // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
  524. ggml_tensor* V = ggml_cont(ctx, ggml_permute(ctx,
  525. ggml_cpy(ctx,
  526. Vcur,
  527. ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K_h, H, S)),
  528. 1, 2, 0, 3)); // (H * K_h, S) -> (K_h, H, S) -> (H, S, K_h)
  529. ggml_tensor* q_with_u_bias = ggml_add_inplace(ctx, ggml_dup(ctx, Q), u_bias); // (K_h, S, H)
  530. ggml_tensor* q_with_v_bias = ggml_add_inplace(ctx, Q, v_bias); // (K_h, S, H)
  531. ggml_tensor* ac = mul_mat(ctx, K, q_with_u_bias);
  532. ggml_tensor* bd = mul_mat(ctx, r, q_with_v_bias);
  533. // self_attn: shift_bd. Logic follows https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L161
  534. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // H, S, 2S-1
  535. ggml_tensor* pad = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, S, 1);
  536. pad = ggml_set_f32(pad, 0.0);
  537. bd = ggml_concat(ctx, pad, bd); // bd[i][j][0] == 0, (H, S, 2S)
  538. bd = ggml_dup(ctx, ggml_permute(ctx, bd, 2, 1, 0, 3)); // (2S, S, H)
  539. bd = ggml_reshape_3d(ctx, bd, S, 2 * S, H); // (S, 2S, H)
  540. // discard the first set of positive positions
  541. bd = ggml_dup(ctx, ggml_slice(ctx, bd, 1, 1, 2 * S));
  542. // shifts each row by an extra step
  543. bd = ggml_reshape_3d(ctx, bd, 2 * S - 1, S, H);
  544. // Discard positions used for shift.
  545. bd = ggml_slice(ctx, bd, 0, 0, S);
  546. // self_attn: compute attn / weights
  547. ggml_tensor* attn_weights = ggml_add_inplace(ctx, ac, bd);
  548. ggml_tensor* attn_scale = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1);
  549. ggml_set_f32(attn_scale, 1.0 / pow(K_h, 0.5));
  550. attn_weights = ggml_mul_inplace(ctx, attn_weights, ggml_repeat(ctx, attn_scale, attn_weights));
  551. attn_weights = ggml_soft_max(ctx, attn_weights);
  552. ggml_tensor* attn = mul_mat(ctx, V, attn_weights); // K_h, S, H
  553. attn = ggml_dup(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));
  554. ggml_tensor* attn_2d = ggml_reshape_2d(ctx, attn, K_h * H, S);
  555. ggml_tensor* attn_out = mul_mat(ctx, model.tensors[prefix + ".output_proj.weight"], attn_2d);
  556. attn_out = ggml_add_inplace(
  557. ctx,
  558. attn_out,
  559. ggml_repeat(ctx, model.tensors[prefix + ".output_proj.bias"], attn_out)
  560. );
  561. attn_out = ggml_add_inplace(ctx, attn_out, residual);
  562. return attn_out;
  563. }
  564. extern "C" ggml_tensor* ConvModule_forward(
  565. fairseq2_model& model,
  566. const std::string& prefix,
  567. ggml_tensor* seqs
  568. ) {
  569. ggml_context* ctx = model.ctx;
  570. ggml_tensor* residual = seqs;
  571. seqs = LayerNorm_forward(model, prefix + "_layer_norm", seqs);
  572. // conv: Use matmul for pointwise conv 1 - kernel_size=1, no padding case
  573. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv1.weight"], seqs);
  574. // conv: GLU
  575. seqs = ggml_glu(ctx, seqs);
  576. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  577. // S x C -> (S+K-1) x C -> K x S x C -> S x C
  578. seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".depthwise_conv.weight"], seqs, 1, 15, 1);
  579. // conv: Custom implementation of batch norm
  580. seqs = ggml_batch_norm(ctx, seqs, model.tensors[prefix + ".batch_norm.weight"], model.tensors[prefix + ".batch_norm.bias"], model.tensors[prefix + ".batch_norm.running_mean"], model.tensors[prefix + ".batch_norm.running_var"], 1e-5);
  581. // conv: SiLU actvation
  582. seqs = ggml_silu_inplace(ctx, seqs);
  583. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  584. // conv: Use matmul for pointwise conv 2 - kernel_size=1, no padding case
  585. seqs = mul_mat(ctx, model.tensors[prefix + ".pointwise_conv2.weight"], seqs);
  586. // conv: + residual
  587. seqs = ggml_add_inplace(ctx, seqs, residual);
  588. return seqs;
  589. }
  590. extern "C" ggml_tensor* StandardConformerEncoderLayer_forward(
  591. fairseq2_model& model,
  592. const std::string& prefix,
  593. ggml_tensor* seqs,
  594. ggml_tensor* padding_mask
  595. ) {
  596. ggml_context* ctx = model.ctx;
  597. ggml_tensor* ffn_scale = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1);
  598. ggml_set_f32(ffn_scale, 0.5f);
  599. ggml_tensor* residual = seqs;
  600. seqs = LayerNorm_forward(model, prefix + ".ffn1_layer_norm", seqs);
  601. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn1", seqs);
  602. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  603. seqs = ggml_add_inplace(ctx, seqs, residual);
  604. seqs = RelativePositionMHA_forward(model, prefix + ".self_attn", seqs);
  605. seqs = ConvModule_forward(model, prefix + ".conv", seqs);
  606. residual = seqs;
  607. seqs = LayerNorm_forward(model, prefix + ".ffn2_layer_norm", seqs);
  608. seqs = SiluFeedForwardNetwork_forward(model, prefix + ".ffn2", seqs);
  609. seqs = ggml_mul_inplace(ctx, seqs, ggml_repeat(ctx, ffn_scale, seqs));
  610. seqs = ggml_add_inplace(ctx, seqs, residual);
  611. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  612. return seqs;
  613. }
  614. extern "C" ggml_tensor* StandardConformerEncoder_forward(
  615. fairseq2_model& model,
  616. const std::string& prefix,
  617. ggml_tensor* seqs,
  618. ggml_tensor* padding_mask
  619. ) {
  620. ggml_context* ctx = model.ctx;
  621. seqs = WaveformToFbank_forward(model, prefix, seqs);
  622. seqs = LayerNorm_forward(model, prefix + "_frontend.post_extract_layer_norm", seqs);
  623. seqs = Linear_forward(model, prefix + "_frontend.model_dim_proj", seqs);
  624. int layer_idx = 0;
  625. std::string layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  626. while (has_layer(model, layer_name)) {
  627. seqs = StandardConformerEncoderLayer_forward(
  628. model, layer_name, seqs, padding_mask
  629. );
  630. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  631. layer_idx += 1;
  632. layer_name = prefix + ".inner.layers." + std::to_string(layer_idx);
  633. }
  634. seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
  635. ggml_tensor* residual = seqs;
  636. seqs = Linear_forward(model, prefix + ".proj1", seqs);
  637. seqs = ggml_relu_inplace(ctx, seqs);
  638. seqs = Linear_forward(model, prefix + ".proj2", seqs);
  639. ggml_tensor* ffn_scale = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 1);
  640. ggml_set_f32(ffn_scale, 0.5f);
  641. seqs = ggml_mul(ctx, ggml_repeat(ctx, ffn_scale, seqs), seqs);
  642. seqs = ggml_add_inplace(ctx, seqs, residual);
  643. layer_idx = 0;
  644. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  645. while (has_layer(model, layer_name)) {
  646. seqs = StandardConformerEncoderAdaptorLayer_forward(
  647. model, layer_name, seqs, padding_mask
  648. );
  649. ggml_set_name(seqs, ("x_ada_" + std::to_string(layer_idx)).c_str());
  650. layer_idx += 1;
  651. layer_name = prefix + ".adaptor_layers." + std::to_string(layer_idx);
  652. }
  653. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  654. return seqs;
  655. }
  656. extern "C" ggml_tensor* StandardConformerEncoderAdaptorLayer_forward(
  657. fairseq2_model& model,
  658. const std::string& prefix,
  659. ggml_tensor* seqs,
  660. ggml_tensor* padding_mask
  661. ) {
  662. ggml_context* ctx = model.ctx;
  663. ggml_tensor* residual = seqs;
  664. residual = LayerNorm_forward(model, prefix + ".residual_layer_norm", residual);
  665. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  666. residual = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".residual_conv.weight"], residual, 8, 4, 1);
  667. residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
  668. residual = ggml_add_inplace(ctx, ggml_repeat(ctx, model.tensors[prefix + ".residual_conv.bias"], residual), residual);
  669. residual = ggml_glu(ctx, residual);
  670. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  671. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  672. seqs = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".self_attn_conv.weight"], seqs, 8, 4, 1);
  673. seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
  674. seqs = ggml_add_inplace(ctx, seqs, ggml_repeat(ctx, model.tensors[prefix + ".self_attn_conv.bias"], seqs));
  675. seqs = ggml_glu(ctx, seqs);
  676. seqs = MultiheadAttention_forward(
  677. model,
  678. prefix + ".self_attn",
  679. seqs,
  680. seqs,
  681. seqs,
  682. /*attention masks=*/nullptr
  683. );
  684. seqs = ggml_add_inplace(ctx, seqs, residual);
  685. residual = seqs;
  686. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  687. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  688. seqs = ggml_add_inplace(ctx, seqs, residual);
  689. return seqs;
  690. }
  691. /// ggml_slice(X, -1, start, end) is equivalent to X[start:end]
  692. /// ggml_slice(X, 0, start, end) is equivalent to X[..., start:end]
  693. ggml_tensor* ggml_slice(
  694. struct ggml_context * ctx,
  695. struct ggml_tensor * a,
  696. int axis,
  697. int64_t start,
  698. int64_t end
  699. ) {
  700. int64_t ne[4];
  701. std::copy(a->ne, a->ne + 4, ne);
  702. if (axis < 0) axis = a->n_dims + axis;
  703. if (start < 0) start = ne[axis] + start;
  704. if (end <= 0) end = ne[axis] + end;
  705. GGML_ASSERT(0 <= start);
  706. GGML_ASSERT(start < end);
  707. GGML_ASSERT(end <= ne[axis]);
  708. ne[axis] = end - start;
  709. size_t offset = a->nb[axis] * start;
  710. size_t* nb = a->nb;
  711. ggml_tensor* result = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], ne[3], nb[1], nb[2], nb[3], offset);
  712. ggml_format_name(result, "%s [(%d)%ld:%ld]", a->name, axis, start, end);
  713. result->n_dims = a->n_dims;
  714. return result;
  715. }
  716. ggml_tensor* ggml_select(
  717. struct ggml_context * ctx,
  718. struct ggml_tensor * a,
  719. int axis,
  720. int64_t index
  721. ) {
  722. int64_t ne[GGML_MAX_DIMS];
  723. std::copy(a->ne, a->ne + GGML_MAX_DIMS, ne);
  724. if (axis < 0) axis = a->n_dims + axis;
  725. if (index < 0) index = ne[axis] + index;
  726. GGML_ASSERT(0 <= index);
  727. GGML_ASSERT(index < ne[axis]);
  728. std::copy(a->ne + axis + 1, a->ne + GGML_MAX_DIMS, ne + axis);
  729. size_t offset = a->nb[axis] * index;
  730. size_t* nb = a->nb;
  731. GGML_ASSERT(GGML_MAX_DIMS == 4);
  732. ggml_tensor* result = ggml_view_3d(ctx, a, ne[0], ne[1], ne[2], nb[1], nb[2], offset);
  733. ggml_format_name(result, "%s [(%d)%ld]", a->name, axis, index);
  734. result->n_dims = a->n_dims - 1;
  735. return result;
  736. }
  737. // Inplace computation of PositionalEmbedding
  738. extern "C" ggml_tensor* PositionalEmbedding_forward(
  739. fairseq2_model& model,
  740. const std::string& prefix,
  741. ggml_tensor* embeds
  742. ) {
  743. // This only work with the simple pos encoders
  744. int seq_len = embeds->ne[1];
  745. ggml_tensor* full_pos_embeds = model.tensors[prefix];
  746. int start_step = 0;
  747. if (has_kv_cache(model)) {
  748. start_step = model.kv_cache[prefix].step_nr++;
  749. }
  750. ggml_tensor* pos_embeds = ggml_slice(model.ctx, full_pos_embeds, /*axis*/1, start_step, seq_len + start_step);
  751. return ggml_add(model.ctx, embeds, pos_embeds);
  752. }
/// Embeds the token ids in `seqs` with the table `<prefix>.embed.weight`.
/// It then adds positional embeddings (`<prefix>.pos_encoder`) and applies a
/// LayerNorm (`<prefix>.layer_norm`), each only when that layer exists.
///
/// `seqs` must have fewer than GGML_MAX_DIMS dims. The output gains a new
/// innermost (ne[0]) dimension of size embed_weights->ne[0], and has
/// n_dims = seqs->n_dims + 1.
extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
) {
    GGML_ASSERT(seqs->n_dims < GGML_MAX_DIMS);
    ggml_context* ctx = model.ctx;
    ggml_tensor* embed_weights = model.tensors[prefix + ".embed.weight"];
    GGML_ASSERT(embed_weights != nullptr);
    ggml_tensor* embeds;
    if (seqs->n_dims == 1) {
        embeds = ggml_get_rows(ctx, embed_weights, seqs);
    } else {
        // ggml_get_rows isn't very flexible, we have to handle the reshape ourselves.
        ggml_tensor* flat_seqs = seqs;
        if (!ggml_is_contiguous(seqs)) {
            // NOTE: flat_seqs still aliases `seqs`, so this retypes the
            // caller's tensor itself. This is presumably so ggml_cont accepts
            // the 4-byte ids as floats (confirm). The type is not restored,
            // likely because the cont node reads `seqs` at compute time.
            flat_seqs->type = GGML_TYPE_F32;
            flat_seqs = ggml_cont(ctx, flat_seqs);
        }
        flat_seqs = ggml_reshape_1d(ctx, flat_seqs, ggml_nelements(seqs));
        // Reinterpret the flattened ids as int32 for ggml_get_rows.
        flat_seqs->type = GGML_TYPE_I32;
        embeds = ggml_get_rows(ctx, embed_weights, flat_seqs);
        // Restore the original layout with the embedding dim in front:
        // (embed_dim, seqs->ne[0], seqs->ne[1], seqs->ne[2]).
        embeds = ggml_reshape_4d(ctx, embeds, embed_weights->ne[0], seqs->ne[0], seqs->ne[1], seqs->ne[2]);
        embeds->n_dims = seqs->n_dims + 1;
    }
    // padding mask ?
    // padding_mask = to_padding_mask(embeds, seq_lens)
    if (has_layer(model, prefix + ".pos_encoder")) {
        embeds = PositionalEmbedding_forward(model, prefix + ".pos_encoder", embeds);
    }
    if (has_layer(model, prefix + ".layer_norm")) {
        embeds = LayerNorm_forward(model, prefix + ".layer_norm", embeds);
    }
    return embeds;
}
  788. extern "C" ggml_tensor* StandardTransformerEncoder_forward(
  789. fairseq2_model& model,
  790. const std::string& prefix,
  791. ggml_tensor* seqs,
  792. ggml_tensor* padding_mask
  793. ) {
  794. int layer_idx = 0;
  795. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  796. while (has_layer(model, layer_name)) {
  797. seqs = StandardTransformerEncoderLayer_forward(
  798. model, layer_name, seqs, padding_mask
  799. );
  800. ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
  801. layer_idx += 1;
  802. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  803. }
  804. if (has_layer(model, prefix + ".layer_norm"))
  805. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  806. return seqs;
  807. }
  808. extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
  809. fairseq2_model& model,
  810. const std::string& prefix,
  811. ggml_tensor* seqs,
  812. ggml_tensor* self_attn_mask,
  813. ggml_tensor* encoder_output,
  814. ggml_tensor* encoder_padding_mask
  815. ) {
  816. ggml_context* ctx = model.ctx;
  817. auto norm_order = model.layer_config.at(prefix + ".norm_order");
  818. // _forward_self_attn(seqs, padding_mask)
  819. auto residual = seqs;
  820. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  821. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  822. seqs = MultiheadAttention_forward(
  823. model,
  824. prefix + ".self_attn",
  825. seqs,
  826. seqs,
  827. seqs,
  828. /*attn_mask=*/self_attn_mask
  829. );
  830. if (has_layer(model, prefix + ".self_attn_norm"))
  831. seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
  832. seqs = ggml_add_inplace(ctx, seqs, residual);
  833. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  834. seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
  835. // _forward_encoder_decoder_attn
  836. if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
  837. // `encoder_output` must be `None` for decoder-only attention.
  838. GGML_ASSERT(encoder_output == nullptr);
  839. return seqs;
  840. }
  841. // `encoder_output` must not be `None` for encoder-decoder attention.
  842. GGML_ASSERT(encoder_output != nullptr);
  843. residual = seqs;
  844. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  845. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  846. seqs = MultiheadAttention_forward(
  847. model,
  848. prefix + ".encoder_decoder_attn",
  849. seqs,
  850. encoder_output,
  851. encoder_output,
  852. /*attention masks=*/encoder_padding_mask
  853. );
  854. seqs = ggml_add_inplace(ctx, seqs, residual);
  855. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  856. seqs = LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
  857. // _forward_ffn(seqs)
  858. residual = seqs;
  859. if (norm_order != TRANSFORMER_NORM_ORDER_POST)
  860. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  861. seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
  862. // TODO:
  863. // if self.residual_scale is not None:
  864. // residual = self.residual_scale * residual
  865. seqs = ggml_add_inplace(ctx, seqs, residual);
  866. if (norm_order == TRANSFORMER_NORM_ORDER_POST)
  867. seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
  868. return seqs;
  869. }
  870. extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
  871. auto seq_len = seqs->ne[1];
  872. // TODO: allow other ggml_type
  873. ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
  874. return ggml_diag_mask_inf(ctx, mask, 0);
  875. }
  876. extern "C" ggml_tensor* StandardTransformerDecoder_forward(
  877. fairseq2_model& model,
  878. const std::string& prefix,
  879. ggml_tensor* seqs,
  880. ggml_tensor* padding_mask,
  881. ggml_tensor* encoder_output,
  882. ggml_tensor* encoder_padding_mask
  883. ) {
  884. int layer_idx = 0;
  885. std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
  886. ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
  887. while (has_layer(model, layer_name)) {
  888. seqs = StandardTransformerDecoderLayer_forward(
  889. model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
  890. );
  891. ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
  892. layer_idx += 1;
  893. layer_name = prefix + ".layers." + std::to_string(layer_idx);
  894. }
  895. if (has_layer(model, prefix + ".layer_norm"))
  896. seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
  897. return seqs;
  898. }
  899. int _determine_max_seq_len(const SequenceGeneratorJob& job, int source_seq_len) {
  900. auto opts = job.opts;
  901. int max_seq_len = -1;
  902. if (source_seq_len <= 0 || opts.soft_max_seq_len_a <= 0) {
  903. max_seq_len = opts.hard_max_seq_len;
  904. } else {
  905. max_seq_len = std::min(opts.hard_max_seq_len, int(opts.soft_max_seq_len_a * source_seq_len) + opts.soft_max_seq_len_b);
  906. }
  907. if (opts.min_seq_len > max_seq_len) {
  908. printf(
  909. "The effective maximum sequence length must be greater than or equal to `min_seq_len` (%d), but is %d instead. Adjust your soft and hard maximum sequence length limits.\n",
  910. opts.min_seq_len,
  911. max_seq_len
  912. );
  913. GGML_ASSERT(opts.min_seq_len <= max_seq_len);
  914. }
  915. int prefix_seq_len = job.prefix_seq->ne[0];
  916. if (prefix_seq_len >= max_seq_len) {
  917. printf(
  918. "The effective maximum sequence length must be greater than `prefix_seq_len` (%d), but is %d instead.\n",
  919. prefix_seq_len,
  920. max_seq_len
  921. );
  922. GGML_ASSERT(prefix_seq_len < max_seq_len);
  923. }
  924. return max_seq_len;
  925. }
  926. void _fan_out_encoder_output(
  927. ggml_context* ctx,
  928. ggml_tensor** encoder_output_out,
  929. ggml_tensor** encoder_padding_mask_out,
  930. int beam_size
  931. ) {
  932. // (S_enc, M)
  933. ggml_tensor* encoder_output = *encoder_output_out;
  934. ggml_tensor* encoder_padding_mask = *encoder_padding_mask_out;
  935. // (B, S_enc, M)
  936. ggml_tensor* shape = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_output->ne[0], encoder_output->ne[1], beam_size);
  937. // (S_enc, M) -> (B, S_enc, M)
  938. *encoder_output_out = ggml_repeat(ctx, encoder_output, shape);
  939. // (S_enc) -> (B, S_enc)
  940. if (encoder_padding_mask != nullptr) {
  941. ggml_tensor* shape_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_I8, encoder_padding_mask->ne[0], 1, beam_size);
  942. *encoder_padding_mask_out = ggml_repeat(ctx, encoder_padding_mask, shape_mask);
  943. }
  944. }
  945. ggml_tensor* ggml_log_softmax(ggml_context* ctx, ggml_tensor* logits) {
  946. // TODO: this isn't the most precise way of doing this
  947. return ggml_log_inplace(ctx, ggml_soft_max_inplace(ctx, logits));
  948. }
/// Broadcasts `x` to a (ne0, ne1) tensor with ggml_repeat. The result is
/// labelled with x's original type.
/// NOTE: `x->type` is switched to F32 before the repeat node is built and is
/// NOT switched back, so callers see `x` retyped as a side effect. Presumably
/// the repeat kernel reads the source type at compute time and needs F32
/// (confirm before changing). The result is allocated as F32, so relabelling
/// it is presumably only meaningful for 4-byte types — TODO confirm.
ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int64_t ne1) {
    // Only used as the target shape for ggml_repeat.
    ggml_tensor* shape = ggml_new_tensor_2d(ctx, GGML_TYPE_I8, ne0, ne1);
    ggml_type true_type = x->type;
    x->type = GGML_TYPE_F32;
    ggml_tensor* y = ggml_repeat(ctx, x, shape);
    y->type = true_type;
    return y;
}
  957. extern "C" void _bootstrap_seqs_and_scores(
  958. fairseq2_model& model,
  959. const SequenceGeneratorJob& job,
  960. ggml_tensor* full_seqs,
  961. ggml_tensor* scores,
  962. ggml_tensor* encoder_output,
  963. ggml_tensor* encoder_padding_mask,
  964. int n_threads
  965. ) {
  966. int prefix_seq_len = job.prefix_seq->ne[0];
  967. int max_seq_len = scores->ne[0];
  968. int beam_size = scores->ne[1];
  969. GGML_ASSERT(prefix_seq_len > 0);
  970. if (prefix_seq_len == 1)
  971. return;
  972. ggml_context* ctx = model.ctx;
  973. // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
  974. full_seqs->type = GGML_TYPE_F32;
  975. job.prefix_seq->type = GGML_TYPE_F32;
  976. ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
  977. seqs = ggml_cpy(ctx, ggml_repeat(ctx, job.prefix_seq, seqs), seqs);
  978. // We have to bootstrap the model with the already fanned-out encoder
  979. // output to correctly initialize its incremental state.
  980. // Note: we don't start decoding the last prefix token just yet.
  981. seqs = ggml_slice(ctx, seqs, 0, 0, prefix_seq_len - 1);
  982. seqs->type = GGML_TYPE_I32;
  983. // Bootstrap the model state with prefix sequence.
  984. seqs = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", seqs);
  985. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  986. model,
  987. "text_decoder",
  988. seqs,
  989. /*padding_mask*/ nullptr,
  990. encoder_output,
  991. encoder_padding_mask
  992. );
  993. // TODO state_bag.increment_step(prefix_seq_len - 1)
  994. // logits, lprobs: (N, S_pfx - 1, V)
  995. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
  996. int vocab_size = logits->ne[0];
  997. ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
  998. ggml_cgraph gf = ggml_build_forward(lprobs);
  999. ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
  1000. ggml_free(ctx);
  1001. full_seqs->type = GGML_TYPE_I32;
  1002. job.prefix_seq->type = GGML_TYPE_I32;
  1003. // Fetch scores of next steps from "lprobs"
  1004. float p_score = 0;
  1005. for (int i = 1; i < prefix_seq_len; ++i) {
  1006. int p = ggml_get_i32_1d(job.prefix_seq, i);
  1007. p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
  1008. for (int b = 0; b < beam_size; ++b) {
  1009. // scores: (N, S)
  1010. // Note: First step (e.g. BOS)'s score is always 0.
  1011. ggml_set_f32_1d(scores, b * max_seq_len + i, p_score);
  1012. }
  1013. }
  1014. }
  1015. /// Finds the topk indices, and write the winning indices in "candidate_indices" array.
  1016. int topk(
  1017. ggml_tensor* lprobs, // (B, V)
  1018. std::int64_t k,
  1019. ggml_tensor* candidate_indices
  1020. ) {
  1021. // Take the best 2 x `beam_size` predictions. We'll choose the first
  1022. // `beam_size` of these which don't predict EOS to continue with.
  1023. // (N, 2 x B)
  1024. // `vocab_size` - 1 to never select PAD.
  1025. std::int64_t K = std::min(k, ggml_nelements(lprobs));
  1026. auto comp = [lprobs](std::int32_t a, std::int32_t b) {
  1027. return ggml_get_f32_1d(lprobs, a) > ggml_get_f32_1d(lprobs, b);
  1028. };
  1029. GGML_ASSERT(ggml_nelements(candidate_indices) >= k);
  1030. auto cand = (std::int32_t*)candidate_indices->data;
  1031. std::partial_sort(cand, cand + K, cand + ggml_nelements(lprobs), comp);
  1032. return K;
  1033. }
  1034. void _tweak_lprobs(const SequenceGeneratorJob& job, ggml_tensor* lprobs, int step_nr, int max_seq_len, std::size_t vocab_size) {
  1035. std::size_t beam_size = job.opts.beam_size;
  1036. std::size_t eos_idx = job.eos_idx;
  1037. // Do not allow EOS before reaching the minimum sequence length.
  1038. if (step_nr < job.opts.min_seq_len) {
  1039. // lprobs[:, :, self.eos_idx] = -INFINITY;
  1040. for (size_t i = 0; i < beam_size; ++i)
  1041. ggml_set_f32_1d(lprobs, vocab_size * i + eos_idx, -INFINITY);
  1042. }
  1043. // If we have reached the maximum length, force the last step to be EOS.
  1044. if (step_nr == max_seq_len - 2) {
  1045. // lprobs[:, :, : self.eos_idx] = -torch.inf
  1046. // lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
  1047. for (size_t b = 0; b < beam_size; ++b) {
  1048. size_t t = 0;
  1049. for (t = 0; t < eos_idx; ++t)
  1050. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1051. for (t = eos_idx + 1; t < vocab_size; ++t)
  1052. ggml_set_f32_1d(lprobs, vocab_size * b + t, -INFINITY);
  1053. }
  1054. }
  1055. // Never allow PAD.
  1056. std::size_t pad_idx = job.pad_idx;
  1057. for (size_t i = 0; i < beam_size; ++i)
  1058. ggml_set_f32_1d(lprobs, vocab_size * i + pad_idx, -INFINITY);
  1059. // Apply UNK penalty.
  1060. if (job.unk_idx >= 0 && job.opts.unk_penalty != 0) {
  1061. // lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
  1062. auto lprobs_raw = ggml_get_data_f32(lprobs);
  1063. for (size_t i = 0; i < beam_size; ++i)
  1064. lprobs_raw[vocab_size * i + job.unk_idx] -= job.opts.unk_penalty;
  1065. }
  1066. }
  1067. /// Copies the sequence and scores of a given candidate beam.
  1068. void _finalize_hypothesis(
  1069. const SequenceGeneratorJob& job,
  1070. ggml_context* ctx,
  1071. int step_nr,
  1072. std::int32_t beam,
  1073. std::int32_t token,
  1074. float eos_score,
  1075. ggml_tensor* seqs, // (beam_size, seq_len)
  1076. ggml_tensor* scores, // (beam_size, seq_len)
  1077. Hypothesis* hypothesis
  1078. ) {
  1079. ggml_tensor* seq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
  1080. hypothesis->seq = seq;
  1081. ggml_tensor* step_scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, step_nr + 2);
  1082. hypothesis->step_scores = step_scores;
  1083. auto tok = (std::int32_t*)seq->data;
  1084. for (int i = 0; i < step_nr + 1; ++i) {
  1085. tok[i] = ggml_get_i32_1d(seqs, seqs->ne[0] * beam + i);
  1086. }
  1087. tok[step_nr + 1] = token;
  1088. // Convert from cumulative to per-step scores.
  1089. auto sc = (float*)step_scores->data;
  1090. float last_score = eos_score;
  1091. for (int i = step_nr; i >= 0; --i) {
  1092. float sc0 = ggml_get_f32_1d(scores, scores->ne[0] * beam + i);
  1093. sc[i + 1] = last_score - sc0;
  1094. last_score = sc0;
  1095. }
  1096. sc[0] = 0;
  1097. if (job.opts.normalize_scores)
  1098. // Skip first EOS since it is always 0 and skews normalization.
  1099. eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
  1100. hypothesis->score = eos_score;
  1101. }
  1102. // Uses ggml_context to store any object.
  1103. #define GGML_CTX_ALLOC(ctx, Type, n) \
  1104. (Type*)(ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(Type) * n)->data);
  1105. ggml_context* ctx_from_buffer(std::vector<uint8_t>& buffer) {
  1106. return ggml_init({
  1107. /*.mem_size =*/ static_cast<int64_t>(buffer.capacity()),
  1108. /*.mem_buffer =*/ buffer.data(),
  1109. /*.no_alloc =*/ false,
  1110. });
  1111. }
  1112. /// Generates a translation for a single sequence
  1113. // TODO: clean ups
  1114. // * replace manual tensor tweaking with ggml_set_*d (a ggml_set_slice could be useful)
  1115. extern "C" Hypothesis* generate_sequence(
  1116. fairseq2_model& model,
  1117. const SequenceGeneratorJob& job,
  1118. ggml_tensor* encoder_output,
  1119. ggml_tensor* encoder_padding_mask,
  1120. ggml_context* result_ctx,
  1121. int n_threads
  1122. ) {
  1123. std::vector<uint8_t> local_bufs[3] = {
  1124. std::vector<uint8_t>(1024 * 1024 * 1024), // step_ctx
  1125. std::vector<uint8_t>(1024 * 1024 * 1024), // next_step_ctx
  1126. std::vector<uint8_t>(1024 * 1024 * 1024) // search_ctx
  1127. };
  1128. ggml_context* search_ctx = ctx_from_buffer(local_bufs[2]);
  1129. ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
  1130. size_t vocab_size = embed->ne[1];
  1131. std::size_t beam_size = job.opts.beam_size;
  1132. int source_seq_len = encoder_output->ne[1];
  1133. int max_seq_len = _determine_max_seq_len(job, source_seq_len);
  1134. ggml_context* original_ctx = model.ctx;
  1135. model.ctx = search_ctx;
  1136. fairseq2_kv_cache_alloc(model, beam_size, max_seq_len);
  1137. // (S_enc, M) -> (B, S_enc, M)
  1138. _fan_out_encoder_output(search_ctx, &encoder_output, &encoder_padding_mask, beam_size);
  1139. // Allocate results in the context provided by the caller.
  1140. Hypothesis* finished_searches_begin = GGML_CTX_ALLOC(result_ctx, Hypothesis, beam_size);
  1141. Hypothesis* finished_searches = finished_searches_begin;
  1142. for (std::size_t i = 0; i < beam_size; ++i) finished_searches[i] = {nullptr, -INFINITY, nullptr};
  1143. Hypothesis* finished_searches_end = finished_searches + beam_size;
  1144. // Initialize buffers. (B, S)
  1145. ggml_tensor* seqs = ggml_new_tensor_2d(search_ctx, GGML_TYPE_I32, max_seq_len, beam_size);
  1146. ggml_set_i32(seqs, 0);
  1147. ggml_set_name(seqs, "seqs_0");
  1148. ggml_tensor* scores = ggml_new_tensor_2d(search_ctx, GGML_TYPE_F32, max_seq_len, beam_size);
  1149. ggml_set_name(scores, "scores_0");
  1150. ggml_set_f32(scores, 0.0);
  1151. _bootstrap_seqs_and_scores(
  1152. model, job, seqs, scores, encoder_output, encoder_padding_mask, n_threads
  1153. );
  1154. int prefix_seq_len = job.prefix_seq->ne[0];
  1155. int start_step = prefix_seq_len - 1;
  1156. // Holds the indices of beams (a beam can occur more than once) that we
  1157. // should continue with in the next step.
  1158. ggml_tensor* beam_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1159. ggml_tensor* next_tokens = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, beam_size);
  1160. ggml_tensor* next_scores = ggml_new_tensor_1d(search_ctx, GGML_TYPE_F32, beam_size);
  1161. // Array with integers up to 'vocab_size * beam_size' to represent next beams to explore
  1162. ggml_tensor* candidate_indices = ggml_new_tensor_1d(search_ctx, GGML_TYPE_I32, vocab_size * beam_size);
  1163. for (std::size_t i = 0; i < vocab_size * beam_size; ++i)
  1164. ((int32_t *)(candidate_indices->data))[i] = i;
  1165. printf_mem_usage(search_ctx, "search_ctx");
  1166. ggml_context* step_ctx = ctx_from_buffer(local_bufs[0]);
  1167. ggml_context* next_step_ctx = nullptr;
  1168. for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
  1169. model.ctx = step_ctx;
  1170. ggml_tensor* prev_token = ggml_slice(step_ctx, seqs, 0, step_nr, step_nr + 1);
  1171. ggml_tensor* decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", prev_token);
  1172. ggml_tensor* decoder_output = StandardTransformerDecoder_forward(
  1173. model,
  1174. "text_decoder",
  1175. decoder_input,
  1176. nullptr, // We never generate PAD.
  1177. encoder_output,
  1178. encoder_padding_mask
  1179. ); // (B, 1, D)
  1180. // Just look at the last token.
  1181. decoder_output = ggml_flatten_1d(step_ctx, decoder_output, 0); // (B, model_dim)
  1182. ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output); // (B, vocab_size)
  1183. ggml_tensor* lprobs = ggml_log_softmax(step_ctx, logits);
  1184. // Compute lprobs here so we can modify it in place in the lprob tweaking phase
  1185. // TODO: use ggml properly compute the tweaks
  1186. ggml_cgraph gf = ggml_build_forward(lprobs);
  1187. // printf("beam search step %d. Graph.n_nodes: %d\n", step_nr, gf.n_nodes);
  1188. ggml_graph_compute_with_ctx(step_ctx, &gf, n_threads);
  1189. ggml_detach(lprobs);
  1190. _tweak_lprobs(job, lprobs, step_nr, max_seq_len, vocab_size);
  1191. ggml_tensor* last_scores = ggml_slice(step_ctx, scores, 0, step_nr, step_nr+1);
  1192. if (step_nr == start_step) {
  1193. // At the initial step, all hypotheses are equally likely, so we use
  1194. // only the first beam.
  1195. lprobs = ggml_slice(step_ctx, lprobs, 1, 0, 1);
  1196. lprobs = ggml_cont(step_ctx, lprobs);
  1197. // The first step always indicates the beginning of the sequence and has no score.
  1198. if (step_nr > 0) {
  1199. last_scores = ggml_slice(step_ctx, last_scores, 1, 0, 1);
  1200. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1201. }
  1202. } else {
  1203. // Make probabilities contain cumulative scores for each hypothesis.
  1204. lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
  1205. }
  1206. gf = ggml_build_forward(lprobs);
  1207. ggml_graph_compute_with_ctx(step_ctx, &gf, n_threads);
  1208. // Determine (beam, token) candidates for the next step.
  1209. // (N, 2 x B)
  1210. std::int64_t K = topk(
  1211. lprobs, std::min(2 * beam_size, vocab_size - 1), candidate_indices
  1212. );
  1213. std::size_t ongoing_beams = 0;
  1214. for (std::int32_t i = 0; i < K; ++i) {
  1215. int c = ggml_get_f32_1d(candidate_indices, i);
  1216. std::int32_t beam = c / vocab_size;
  1217. std::int32_t token = c % vocab_size;
  1218. float tok_score = ggml_get_f32_1d(lprobs, c);
  1219. // Detect beams that reached the minimum length and that end with an EOS.
  1220. bool eos = token == job.eos_idx;
  1221. eos &= tok_score != -INFINITY;
  1222. if (eos) {
  1223. _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, finished_searches++);
  1224. if (finished_searches == finished_searches_end)
  1225. goto end_of_beam_search;
  1226. continue;
  1227. }
  1228. ggml_set_f32_1d(beam_indices, ongoing_beams, beam);
  1229. ggml_set_f32_1d(next_tokens, ongoing_beams, token);
  1230. ggml_set_f32_1d(next_scores, ongoing_beams, tok_score);
  1231. ongoing_beams += 1;
  1232. if (ongoing_beams >= beam_size) break;
  1233. }
  1234. // Reorder beams in the `seq` and `score` buffers. The same beam can
  1235. // be selected more than once.
  1236. ggml_tensor* new_seqs = seqs;
  1237. ggml_tensor* new_scores = scores;
  1238. if (step_nr > start_step) {
  1239. // (B, S), (B) -> (B, S)
  1240. // ggml_get_rows and ggml_set only work with floats ...
  1241. new_seqs->type = GGML_TYPE_F32;
  1242. new_seqs = ggml_get_rows(search_ctx, seqs, beam_indices);
  1243. new_scores = ggml_get_rows(search_ctx, scores, beam_indices);
  1244. ggml_cgraph gf_reorder = ggml_build_forward(new_seqs);
  1245. ggml_build_forward_expand(&gf_reorder, new_scores);
  1246. reorder_kv_cache(model, step_ctx, &gf_reorder, beam_indices);
  1247. ggml_graph_compute_with_ctx(step_ctx, &gf_reorder, n_threads);
  1248. ggml_detach(new_seqs);
  1249. ggml_detach(new_scores);
  1250. new_seqs->type = GGML_TYPE_I32;
  1251. printf_mem_usage(search_ctx, "search_ctx");
  1252. next_step_ctx = ctx_from_buffer(local_bufs[(step_nr + 1) % 2]);
  1253. SWAP(step_ctx, next_step_ctx);
  1254. ggml_free(next_step_ctx);
  1255. }
  1256. // new_seqs[:, step_nr + 1] = next_tokens
  1257. // new_scores[:, step_nr + 1] = next_scores
  1258. for (std::size_t i = 0; i < beam_size; ++i) {
  1259. ((std::int32_t*)new_seqs->data)[step_nr + 1 + i * max_seq_len] = ggml_get_i32_1d(next_tokens, i);
  1260. ((float*)new_scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
  1261. }
  1262. // TODO the old seqs and score buffers could be reused for next step
  1263. seqs = new_seqs;
  1264. scores = new_scores;
  1265. printf_mem_usage(step_ctx, "step_ctx");
  1266. }
  1267. end_of_beam_search:
  1268. // Ensure that hypotheses are sorted by decreasing scores before returning.
  1269. std::sort(
  1270. finished_searches_begin,
  1271. finished_searches_end,
  1272. [](Hypothesis a, Hypothesis b) { return a.score > b.score; }
  1273. );
  1274. fairseq2_kv_cache_reset(model);
  1275. model.ctx = original_ctx;
  1276. return finished_searches_begin;
  1277. }
  1278. extern "C" Hypothesis* _testing_return_hypothesis_ptr(ggml_context* ctx) {
  1279. Hypothesis* result = GGML_CTX_ALLOC(ctx, struct Hypothesis, 2);
  1280. result[0] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 3.14f, (ggml_tensor*)result};
  1281. ggml_set_i32_1d(result[0].seq, 0, 314);
  1282. result[1] = {ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), 4.21f, nullptr};
  1283. ggml_set_i32_1d(result[1].seq, 0, 421);
  1284. return result;
  1285. }
  1286. // SPM tokenizer
  1287. // original implementation:
  1288. // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
// One piece of the text being tokenized by the SPM tokenizer. Live symbols
// form a doubly linked list through `prev`/`next`, which are indices into the
// tokenizer's symbol vector; -1 marks either end of the list.
// Field order matters: symbols are built with positional aggregate init.
struct llm_symbol {
    using index = int;
    index prev;  // index of the previous live symbol, or -1
    index next;  // index of the next live symbol, or -1
    const char * text;  // start of this piece inside the tokenizer's text buffer (not owned)
    size_t n;  // byte length of the piece; set to 0 once merged into its left neighbour
    llama_vocab::id id;  // vocabulary id of the piece
};
// Compile-time guard that llm_symbol stays a plain, trivially copyable aggregate.
static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");
  1298. static size_t utf8_len(char src) {
  1299. const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
  1300. uint8_t highbits = static_cast<uint8_t>(src) >> 4;
  1301. return lookup[highbits];
  1302. }
  1303. struct llm_bigram_spm {
  1304. struct comparator {
  1305. bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
  1306. return (l.score < r.score) || (l.score == r.score && l.left > r.left);
  1307. }
  1308. };
  1309. using queue_storage = std::vector<llm_bigram_spm>;
  1310. using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
  1311. llm_symbol::index left;
  1312. llm_symbol::index right;
  1313. float score;
  1314. size_t size;
  1315. llama_vocab::id id;
  1316. };
  1317. struct llm_tokenizer_spm {
  1318. llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
  1319. void tokenize(const std::string& input_text, ggml_tensor& output) {
  1320. llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");
  1321. // split string into utf8 chars
  1322. int index = 0;
  1323. size_t offs = 0;
  1324. // This is kind of annoying, but needed because with SPM,
  1325. // characters following a space have a special meaning.
  1326. // And the algorithm rely on substrings to do the lookups.
  1327. std::string text = input_text;
  1328. bool need_extra_space = text.size() > 0 && text[0] != ' ';
  1329. if (need_extra_space) text = " " + text;
  1330. while (offs < text.size()) {
  1331. size_t len = utf8_len(text[offs]);
  1332. size_t n = std::min(len, text.size() - offs);
  1333. auto token = vocab.token_to_id.find(std::string(text, offs, n));
  1334. llama_vocab::id id = token == vocab.token_to_id.end() ? unk_idx : token->second;
  1335. llm_symbol sym = {
  1336. /*prev*/ index - 1,
  1337. /*next*/ offs + n == text.size() ? -1 : index + 1,
  1338. /*text*/ text.c_str() + offs,
  1339. /*n*/ n,
  1340. /*id*/ id
  1341. };
  1342. offs += n;
  1343. index++;
  1344. symbols.emplace_back(sym);
  1345. }
  1346. // seed the work queue with all possible 2-character tokens.
  1347. for (size_t i = 1; i < symbols.size(); ++i) {
  1348. try_add_bigram(i - 1, i);
  1349. }
  1350. // keep substituting the highest frequency pairs for as long as we can.
  1351. while (!work_queue.empty()) {
  1352. auto bigram = work_queue.top();
  1353. work_queue.pop();
  1354. auto & left_sym = symbols[bigram.left];
  1355. auto & right_sym = symbols[bigram.right];
  1356. const std::string text = std::string(left_sym.text, left_sym.n + right_sym.n);
  1357. // if one of the symbols already got merged, skip it.
  1358. if (
  1359. left_sym.n == 0
  1360. || right_sym.n == 0
  1361. || left_sym.n + right_sym.n != bigram.size
  1362. ) continue;
  1363. // merge the right sym into the left one
  1364. left_sym.n += right_sym.n;
  1365. left_sym.id = bigram.id;
  1366. right_sym.n = 0;
  1367. // remove the right sym from the chain
  1368. left_sym.next = right_sym.next;
  1369. if (right_sym.next >= 0) {
  1370. symbols[right_sym.next].prev = bigram.left;
  1371. }
  1372. // find more substitutions
  1373. try_add_bigram(left_sym.prev, bigram.left);
  1374. try_add_bigram(bigram.left, left_sym.next);
  1375. }
  1376. llama_vocab::id* out = (llama_vocab::id*)output.data;
  1377. int out_step = sizeof(llama_vocab::id) / output.nb[0];
  1378. int num_tokens = 0;
  1379. for (int i = 0; i > -1; i = symbols[i].next) {
  1380. llm_symbol& symbol = symbols[i];
  1381. *(out + num_tokens * out_step) = symbol.id;
  1382. num_tokens += 1;
  1383. }
  1384. *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
  1385. num_tokens += 1;
  1386. output.ne[0] = num_tokens;
  1387. }
  1388. private:
  1389. void try_add_bigram(int left, int right) {
  1390. if (left == -1 || right == -1) {
  1391. return;
  1392. }
  1393. const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
  1394. auto token = vocab.token_to_id.find(text);
  1395. if (token == vocab.token_to_id.end()) {
  1396. return;
  1397. }
  1398. llama_vocab::id id = token->second;
  1399. if (static_cast<size_t>(id) >= vocab.id_to_token.size()) {
  1400. return;
  1401. }
  1402. const auto& tok_data = vocab.id_to_token[id];
  1403. llm_bigram_spm bigram = {
  1404. /*left */ left,
  1405. /*right*/ right,
  1406. /*score*/ tok_data.score,
  1407. /*size */ text.size(),
  1408. /*id */ id
  1409. };
  1410. work_queue.push(bigram);
  1411. }
  1412. const llama_vocab& vocab;
  1413. std::vector<llm_symbol> symbols;
  1414. llm_bigram_spm::queue work_queue;
  1415. };
  1416. extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out) {
  1417. llm_tokenizer_spm spm = {model->vocab};
  1418. spm.tokenize(std::string(text), out);
  1419. }
  1420. extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out) {
  1421. int eos_idx = model->vocab.token_to_id["</s>"];
  1422. int sent_len = tokens->ne[0];
  1423. std::size_t written = 0;
  1424. for (int i = 0; i < sent_len; ++i) {
  1425. int id = ggml_get_i32_1d(tokens, i);
  1426. // Don't print the EOS token but only if it appear at the end.
  1427. if (i == sent_len - 1 && eos_idx == id) break;
  1428. std::string token = model->vocab.id_to_token.at(id).text;
  1429. // Skip the first space outputted.
  1430. auto begin = token.begin();
  1431. if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
  1432. std::copy(begin, token.end(), out);
  1433. std::size_t n = token.end() - begin;
  1434. written += n;
  1435. out += n;
  1436. }
  1437. *out = '0';
  1438. return written;
  1439. }