// unity.cpp
#include "ggml/ggml.h"
#include "ggml/ggml-alloc.h"

#include "common.h"
#include "common-ggml.h"

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <map>
#include <random>
#include <string>
#include <vector>
// default hparams
struct unity_hparams {
    int32_t n_text_vocab         = 256206;
    int32_t n_unit_vocab         = 10084;
    int32_t n_audio_enc_dim      = 1024;
    int32_t n_audio_enc_ffn_dim  = 4096;
    int32_t n_audio_enc_feat_dim = 160;
    int32_t n_audio_enc_layer    = 24;
    int32_t n_audio_enc_head     = 16;
    int32_t ftype                = 1;
    float   eps                  = 1e-5f;
};
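// The defaults above appear to match the conformer speech encoder of the
// UnitY / SeamlessM4T "large" checkpoint (24 layers, dim 1024, 16 heads);
// most of them are overwritten by the hparams read from the model file below.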
// layer def
struct audio_enc_layer {
    struct ggml_tensor * self_attn_layer_norm_w;
    struct ggml_tensor * self_attn_layer_norm_b;

    struct ggml_tensor * self_attn_linear_k_w;
    struct ggml_tensor * self_attn_linear_k_b;
    struct ggml_tensor * self_attn_linear_q_w;
    struct ggml_tensor * self_attn_linear_q_b;
    struct ggml_tensor * self_attn_linear_v_w;
    struct ggml_tensor * self_attn_linear_v_b;
    struct ggml_tensor * self_attn_linear_out_w;
    struct ggml_tensor * self_attn_linear_out_b;
    struct ggml_tensor * self_attn_linear_pos_w;
    struct ggml_tensor * self_attn_pos_bias_u;
    struct ggml_tensor * self_attn_pos_bias_v;

    struct ggml_tensor * conv_layer_norm_w;
    struct ggml_tensor * conv_layer_norm_b;
    struct ggml_tensor * conv_pointwise_conv1_w;
    struct ggml_tensor * conv_depthwise_conv_w;
    struct ggml_tensor * conv_batch_norm_w;
    struct ggml_tensor * conv_batch_norm_b;
    struct ggml_tensor * conv_batch_norm_running_mean;
    struct ggml_tensor * conv_batch_norm_running_var;
    struct ggml_tensor * conv_batch_norm_num_batches_tracked;
    struct ggml_tensor * conv_pointwise_conv2_w;

    struct ggml_tensor * ffn1_layer_norm_w;
    struct ggml_tensor * ffn1_layer_norm_b;
    struct ggml_tensor * ffn1_w1;
    struct ggml_tensor * ffn1_b1;
    struct ggml_tensor * ffn1_w2;
    struct ggml_tensor * ffn1_b2;

    struct ggml_tensor * ffn2_layer_norm_w;
    struct ggml_tensor * ffn2_layer_norm_b;
    struct ggml_tensor * ffn2_w1;
    struct ggml_tensor * ffn2_b1;
    struct ggml_tensor * ffn2_w2;
    struct ggml_tensor * ffn2_b2;

    struct ggml_tensor * final_layer_norm_w;
    struct ggml_tensor * final_layer_norm_b;
};
// struct ggml_tensor * conv_ln;
// struct ggml_tensor * conv_pool_1d;
// model def
struct unity_model {
    unity_hparams hparams;

    // audio encoder
    struct ggml_tensor * post_extract_proj_w;
    struct ggml_tensor * post_extract_proj_b;
    struct ggml_tensor * audio_enc_pos_conv_wg;
    struct ggml_tensor * audio_enc_pos_conv_wv;
    struct ggml_tensor * audio_enc_pos_conv_b;
    struct ggml_tensor * audio_enc_layer_norm_w;
    struct ggml_tensor * audio_enc_layer_norm_b;
    struct ggml_tensor * audio_enc_pos_enc_w;
    struct ggml_tensor * layer_norm_w;
    struct ggml_tensor * layer_norm_b;
    struct ggml_tensor * memory_k;
    struct ggml_tensor * memory_v;
    std::vector<audio_enc_layer> audio_enc_layers;

    // text encoder
    // std::vector<text_enc_layer> text_enc_layers;
    // adaptor
    // std::vector<adapter_layer> adapter_layers;
    // text decoder
    // std::vector<text_dec_layer> text_dec_layers;
    // unit decoder
    // std::vector<unit_dec_layer> unit_dec_layers;

    struct ggml_context * ctx;
    std::map<std::string, struct ggml_tensor *> tensors;
};
// model load
bool unity_model_load(const std::string & fname, unity_model & model, gpt_vocab & vocab) {
    printf("%s: loading model from '%s'\n", __func__, fname.c_str());

    auto fin = std::ifstream(fname, std::ios::binary);
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
        return false;
    }

    // verify magic
    {
        uint32_t magic;
        fin.read((char *) &magic, sizeof(magic));
        if (magic != GGML_FILE_MAGIC) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
            return false;
        }
    }
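    // After the magic number the file contains the serialized hparams,
    // followed by one record per tensor: (n_dims, name_length, ggml type),
    // then the dims, the name string, and the raw tensor data. The loading
    // loop further down parses exactly this layout.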
    // load hparams
    {
        auto & hparams = model.hparams;

        fin.read((char *) &hparams.n_text_vocab,         sizeof(hparams.n_text_vocab));
        fin.read((char *) &hparams.n_audio_enc_dim,      sizeof(hparams.n_audio_enc_dim));
        fin.read((char *) &hparams.n_audio_enc_ffn_dim,  sizeof(hparams.n_audio_enc_ffn_dim));
        fin.read((char *) &hparams.n_audio_enc_feat_dim, sizeof(hparams.n_audio_enc_feat_dim));
        fin.read((char *) &hparams.n_audio_enc_layer,    sizeof(hparams.n_audio_enc_layer));
        fin.read((char *) &hparams.n_audio_enc_head,     sizeof(hparams.n_audio_enc_head));
        fin.read((char *) &hparams.ftype,                sizeof(hparams.ftype));

        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;

        printf("%s: n_text_vocab         = %d\n", __func__, hparams.n_text_vocab);
        printf("%s: n_audio_enc_dim      = %d\n", __func__, hparams.n_audio_enc_dim);
        printf("%s: n_audio_enc_ffn_dim  = %d\n", __func__, hparams.n_audio_enc_ffn_dim);
        printf("%s: n_audio_enc_feat_dim = %d\n", __func__, hparams.n_audio_enc_feat_dim);
        printf("%s: n_audio_enc_layer    = %d\n", __func__, hparams.n_audio_enc_layer);
        printf("%s: n_audio_enc_head     = %d\n", __func__, hparams.n_audio_enc_head);
        printf("%s: ftype                = %d\n", __func__, hparams.ftype);
        printf("%s: qntvr                = %d\n", __func__, qntvr);

        hparams.ftype %= GGML_QNT_VERSION_FACTOR;
    }
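    // ftype packs the quantization format version on top of the base type:
    //   ftype_on_disk = base_ftype + GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR
    // The division above recovers the version, the modulo the base type.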
    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
    // in order to save memory and also to speed up the computation
    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
    if (wtype == GGML_TYPE_COUNT) {
        fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
                __func__, fname.c_str(), model.hparams.ftype);
        return false;
    }
    auto & ctx = model.ctx;

    size_t ctx_size = 0;
    {
        const auto & hparams = model.hparams;

        const int n_audio_enc_dim     = hparams.n_audio_enc_dim;
        const int n_audio_enc_ffn_dim = hparams.n_audio_enc_ffn_dim;
        const int n_audio_enc_layer   = hparams.n_audio_enc_layer;
        const int n_ctx = 4096; // 4096 frames * 20 ms/frame ~ 80 s
        // const int n_text_vocab = hparams.n_text_vocab;
        const int kernel_size = 31;

        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // self_attn_layer_norm_w
        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // self_attn_layer_norm_b
        ctx_size += n_audio_enc_layer*(5*n_audio_enc_dim*n_audio_enc_dim*ggml_type_sizef(wtype)); // self_attn_w (q, k, v, out, pos)
        ctx_size += n_audio_enc_layer*(4*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // self_attn_b
        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // conv_layer_norm_w
        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // conv_layer_norm_b
        ctx_size += n_audio_enc_layer*(n_audio_enc_dim*n_audio_enc_dim*2*ggml_type_sizef(wtype)); // conv_pointwise_conv1_w
        ctx_size += n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // conv_batch_norm_w
        ctx_size += n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // conv_batch_norm_b
        ctx_size += n_audio_enc_layer*(n_audio_enc_dim*n_audio_enc_dim*kernel_size*ggml_type_sizef(wtype)); // conv_depthwise_conv_w
        ctx_size += n_audio_enc_layer*(n_audio_enc_dim*n_audio_enc_dim*ggml_type_sizef(wtype)); // conv_pointwise_conv2_w
        ctx_size += 2*n_audio_enc_layer*(n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // ffn{1,2}_layer_norm_w
        ctx_size += 2*n_audio_enc_layer*(n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // ffn{1,2}_layer_norm_b
        ctx_size += 2*n_audio_enc_layer*(2*n_audio_enc_dim*n_audio_enc_ffn_dim*ggml_type_sizef(wtype)); // ffn{1,2}_w{1,2}
        ctx_size += 2*n_audio_enc_layer*(2*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // ffn{1,2}_b{1,2}
        ctx_size += n_audio_enc_layer*(n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // final_layer_norm_w
        ctx_size += n_audio_enc_layer*(n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // final_layer_norm_b
        ctx_size += n_ctx*n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // memory_k
        ctx_size += n_ctx*n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // memory_v

        // Adaptor
        // ctx_size += n_audio_enc_layer*(n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // conv_ln
        // ctx_size += n_audio_enc_layer*(n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // conv_pool_1d

        // per-object overhead for the ggml_tensor structs allocated in the context
        ctx_size += (6 + 12*n_audio_enc_layer)*512;

        printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
        printf("%s: ggml ctx size    = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
    }
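    // These sizes are upper-bound estimates rather than exact counts (the
    // depthwise conv weight, for example, is budgeted as dim*dim*kernel but
    // stored as kernel*dim); the context only needs to be large enough to
    // hold all weights plus the fixed per-object overhead added above.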
    // create the ggml context
    {
        struct ggml_init_params params = {
            /*.mem_size   =*/ ctx_size,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };

        model.ctx = ggml_init(params);
        if (!model.ctx) {
            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
            return false;
        }
    }
    // prepare memory for the weights
    {
        const auto & hparams = model.hparams;

        const int n_audio_enc_dim      = hparams.n_audio_enc_dim;
        const int n_audio_enc_ffn_dim  = hparams.n_audio_enc_ffn_dim;
        const int n_audio_enc_feat_dim = hparams.n_audio_enc_feat_dim;
        const int n_audio_enc_layer    = hparams.n_audio_enc_layer;
        const int n_audio_enc_head     = hparams.n_audio_enc_head;
        const int n_ctx = 4096; // 4096 frames * 20 ms/frame ~ 80 s
        const int pos_conv_kernel_size = 128;
        // const int n_text_vocab = hparams.n_text_vocab;

        model.audio_enc_layers.resize(n_audio_enc_layer);

        model.audio_enc_pos_enc_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_ctx * 2 - 1);
        model.tensors["model/enc/pos_enc/w"] = model.audio_enc_pos_enc_w;

        model.post_extract_proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_feat_dim, n_audio_enc_dim);
        model.post_extract_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
        model.tensors["model/post_extract_proj/w"] = model.post_extract_proj_w;
        model.tensors["model/post_extract_proj/b"] = model.post_extract_proj_b;

        model.audio_enc_pos_conv_wg = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, pos_conv_kernel_size);
        model.audio_enc_pos_conv_wv = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, pos_conv_kernel_size, 64, n_audio_enc_dim);
        model.audio_enc_pos_conv_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
        model.tensors["model/enc/pos_conv/w_g"] = model.audio_enc_pos_conv_wg;
        model.tensors["model/enc/pos_conv/w_v"] = model.audio_enc_pos_conv_wv;
        model.tensors["model/enc/pos_conv/b"]   = model.audio_enc_pos_conv_b;

        model.audio_enc_layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
        model.audio_enc_layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
        model.tensors["model/enc/layer_norm/w"] = model.audio_enc_layer_norm_w;
        model.tensors["model/enc/layer_norm/b"] = model.audio_enc_layer_norm_b;

        model.layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_feat_dim);
        model.layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_feat_dim);
        model.tensors["model/layer_norm/w"] = model.layer_norm_w;
        model.tensors["model/layer_norm/b"] = model.layer_norm_b;
        for (int i = 0; i < n_audio_enc_layer; ++i) {
            auto & layer = model.audio_enc_layers[i];

            layer.self_attn_layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.self_attn_layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.self_attn_linear_k_w   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_dim);
            layer.self_attn_linear_k_b   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.self_attn_linear_q_w   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_dim);
            layer.self_attn_linear_q_b   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.self_attn_linear_v_w   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_dim);
            layer.self_attn_linear_v_b   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.self_attn_linear_out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_dim);
            layer.self_attn_linear_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.self_attn_linear_pos_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_dim);
            layer.self_attn_pos_bias_u   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim / n_audio_enc_head, n_audio_enc_head);
            layer.self_attn_pos_bias_v   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim / n_audio_enc_head, n_audio_enc_head);
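            // pos_bias_u / pos_bias_v are the learned per-head biases (u, v)
            // of Transformer-XL style relative-position attention, which the
            // Conformer self-attention block reuses.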
            layer.conv_layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.conv_layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.conv_pointwise_conv1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, 2*n_audio_enc_dim);
            layer.conv_depthwise_conv_w  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 31, n_audio_enc_dim);
            layer.conv_batch_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.conv_batch_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.conv_batch_norm_running_mean = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.conv_batch_norm_running_var  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.conv_batch_norm_num_batches_tracked = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
            layer.conv_pointwise_conv2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_dim);

            layer.ffn1_layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.ffn1_layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.ffn1_w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_ffn_dim);
            layer.ffn1_b1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_ffn_dim);
            layer.ffn1_w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_ffn_dim, n_audio_enc_dim);
            layer.ffn1_b2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);

            layer.ffn2_layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.ffn2_layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.ffn2_w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_ffn_dim);
            layer.ffn2_b1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_ffn_dim);
            layer.ffn2_w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_ffn_dim, n_audio_enc_dim);
            layer.ffn2_b2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);

            layer.final_layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
            layer.final_layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);

            // map by name
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_layer_norm/w"] = layer.self_attn_layer_norm_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_layer_norm/b"] = layer.self_attn_layer_norm_b;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_k/w"]   = layer.self_attn_linear_k_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_k/b"]   = layer.self_attn_linear_k_b;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_q/w"]   = layer.self_attn_linear_q_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_q/b"]   = layer.self_attn_linear_q_b;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_v/w"]   = layer.self_attn_linear_v_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_v/b"]   = layer.self_attn_linear_v_b;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_out/w"] = layer.self_attn_linear_out_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_out/b"] = layer.self_attn_linear_out_b;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_pos/w"] = layer.self_attn_linear_pos_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_pos_bias/u"]   = layer.self_attn_pos_bias_u;
            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_pos_bias/v"]   = layer.self_attn_pos_bias_v;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_layer_norm/w"]      = layer.conv_layer_norm_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_layer_norm/b"]      = layer.conv_layer_norm_b;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_pointwise_conv1/w"] = layer.conv_pointwise_conv1_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_depthwise_conv/w"]  = layer.conv_depthwise_conv_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_batch_norm/w"]      = layer.conv_batch_norm_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_batch_norm/b"]      = layer.conv_batch_norm_b;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_batch_norm/m"]      = layer.conv_batch_norm_running_mean;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_batch_norm/v"]      = layer.conv_batch_norm_running_var;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_batch_norm/n"]      = layer.conv_batch_norm_num_batches_tracked;
            model.tensors["model/enc/h" + std::to_string(i) + "/conv_pointwise_conv2/w"] = layer.conv_pointwise_conv2_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_layer_norm/w"]      = layer.ffn1_layer_norm_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_layer_norm/b"]      = layer.ffn1_layer_norm_b;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_w_1/w"]             = layer.ffn1_w1;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_w_1/b"]             = layer.ffn1_b1;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_w_2/w"]             = layer.ffn1_w2;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_w_2/b"]             = layer.ffn1_b2;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_layer_norm/w"]      = layer.ffn2_layer_norm_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_layer_norm/b"]      = layer.ffn2_layer_norm_b;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_w_1/w"]             = layer.ffn2_w1;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_w_1/b"]             = layer.ffn2_b1;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_w_2/w"]             = layer.ffn2_w2;
            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_w_2/b"]             = layer.ffn2_b2;
            model.tensors["model/enc/h" + std::to_string(i) + "/final_layer_norm/w"]     = layer.final_layer_norm_w;
            model.tensors["model/enc/h" + std::to_string(i) + "/final_layer_norm/b"]     = layer.final_layer_norm_b;
        }
    }
    // load weights
    {
        size_t total_size = 0;

        while (true) {
            int32_t n_dims;
            int32_t length;
            int32_t ttype;

            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));

            if (fin.eof()) {
                break;
            }

            int32_t nelements = 1;
            int32_t ne[3] = { 1, 1, 1 };
            for (int i = 0; i < n_dims; ++i) {
                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
                nelements *= ne[i];
            }

            std::string name(length, 0);
            fin.read(&name[0], length);
            std::cout << "loading " << name << " " << n_dims << std::endl;

            if (model.tensors.find(name) == model.tensors.end()) {
                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
                return false;
            }

            auto tensor = model.tensors[name];
            if (ggml_nelements(tensor) != nelements) {
                std::cout << ggml_nelements(tensor) << std::endl;
                std::cout << nelements << std::endl;
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
                return false;
            }

            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
                return false;
            }

            // for debugging
            if (0) {
                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n",
                       name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)),
                       ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
            }

            const size_t bpe = ggml_type_size(ggml_type(ttype));
            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                        __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);
                return false;
            }
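            // For quantized types the data is stored in blocks, so the
            // expected byte count is nelements / block_size * bytes_per_block;
            // the check above compares that against ggml_nbytes(tensor).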
            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));

            total_size += ggml_nbytes(tensor);
        }

        printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
    }

    fin.close();

    return true;
}
// build the computation graph
struct ggml_cgraph * unity_graph(
        const unity_model & model,
        struct ggml_allocr * allocr) {
    const auto & hparams = model.hparams;

    const int n_audio_enc_dim      = hparams.n_audio_enc_dim;
    const int n_audio_enc_ffn_dim  = hparams.n_audio_enc_ffn_dim;
    const int n_audio_enc_feat_dim = hparams.n_audio_enc_feat_dim;
    const int n_audio_enc_layer    = hparams.n_audio_enc_layer;
    const int n_audio_enc_head     = hparams.n_audio_enc_head;
    const int n_ctx = 4096; // 4096 frames * 20 ms/frame ~ 80 s
    const int pos_conv_kernel_size = 128;
    // const int n_text_vocab = hparams.n_text_vocab;
    const int kernel_size = 31;

    // since we are using ggml-alloc, this buffer only needs enough space to hold
    // the ggml_tensor and ggml_cgraph structs, but not the tensor data
    static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead();
    static std::vector<uint8_t> buf(buf_size);

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf.data(),
        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
    };

    struct ggml_context * ctx0 = ggml_init(params);
    struct ggml_cgraph  * gf   = ggml_new_graph(ctx0);
    // For dev, load an example input captured just before the conformer blocks
    auto file = std::ifstream("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_conformer_block.bin", std::ios::binary);
    if (!file) {
        std::cerr << "Failed to open binary file." << std::endl;
    }
    struct ggml_tensor * inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1024, 137);
    inpL->data = malloc(ggml_nbytes(inpL)); // manual alloc since ctx0 is no_alloc (leaks; dev only)
    file.read(reinterpret_cast<char *>(inpL->data), ggml_nbytes(inpL));

    struct ggml_tensor * ffn_scale = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, 1);
    ffn_scale->data = malloc(ggml_nbytes(ffn_scale));
    ggml_set_f32(ffn_scale, 0.5f);
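    // The Conformer block is macaron-style: each of its two feed-forward
    // modules is added to the residual stream with weight 0.5, which is what
    // ffn_scale implements below.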
    for (int il = 0; il < n_audio_enc_layer; ++il) {
        struct ggml_tensor * cur = inpL;
        struct ggml_tensor * residual = cur;
        const audio_enc_layer layer = model.audio_enc_layers[il];

        // FFN1: layernorm
        cur = ggml_norm(ctx0, cur, hparams.eps);
        cur = ggml_add(ctx0,
                ggml_mul(ctx0,
                    ggml_repeat(ctx0, layer.ffn1_layer_norm_w, cur),
                    cur),
                ggml_repeat(ctx0, layer.ffn1_layer_norm_b, cur));

        // FFN1: proj
        cur = ggml_mul_mat(ctx0, layer.ffn1_w1, cur);
        cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.ffn1_b1, cur), cur);
        cur = ggml_silu(ctx0, cur);
        cur = ggml_mul_mat(ctx0, layer.ffn1_w2, cur);
        cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.ffn1_b2, cur), cur);

        // FFN1: * 0.5
        cur = ggml_mul(ctx0, ggml_repeat(ctx0, ffn_scale, cur), cur);
        // FFN1: + residual
        cur = ggml_add(ctx0, cur, residual);

        // self_attn: layernorm
        cur = ggml_norm(ctx0, cur, hparams.eps);
        cur = ggml_add(ctx0,
                ggml_mul(ctx0,
                    ggml_repeat(ctx0, layer.self_attn_layer_norm_w, cur),
                    cur),
                ggml_repeat(ctx0, layer.self_attn_layer_norm_b, cur));
        // self_attn: qkv
        struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.self_attn_linear_q_w, cur);
        Qcur = ggml_add(ctx0, ggml_repeat(ctx0, layer.self_attn_linear_q_b, Qcur), Qcur);

        struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.self_attn_linear_k_w, cur);
        Kcur = ggml_add(ctx0, ggml_repeat(ctx0, layer.self_attn_linear_k_b, Kcur), Kcur);

        struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.self_attn_linear_v_w, cur);
        Vcur = ggml_add(ctx0, ggml_repeat(ctx0, layer.self_attn_linear_v_b, Vcur), Vcur);
        // self_attn: relative-position SDPA
        int32_t S   = cur->ne[1];
        int32_t H   = n_audio_enc_head;
        int32_t K_h = n_audio_enc_dim / H;

        int32_t start_index = n_ctx - S;
        int32_t end_index   = n_ctx + S - 1;
        int num_indices = end_index - start_index; // 2*S - 1

        struct ggml_tensor * rows = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_indices);
        rows->data = malloc(ggml_nbytes(rows)); // manual alloc (no_alloc ctx); dev only
        for (int i = 0; i < num_indices; i++) {
            ((int32_t *) rows->data)[i] = start_index + i;
        }
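        // audio_enc_pos_enc_w holds one embedding per relative offset,
        // presumably ordered from -(n_ctx-1) to n_ctx-1; the indices above
        // select the 2*S - 1 offsets centered on 0 for sequence length S.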
        // self_attn: load pos_enc weights & compute r
        struct ggml_tensor * r = ggml_get_rows(ctx0, model.audio_enc_pos_enc_w, rows);
        r = ggml_mul_mat(ctx0, layer.self_attn_linear_pos_w, r); // TODO: reshape
        r = ggml_dup(ctx0, ggml_permute(ctx0,
                ggml_cpy(ctx0,
                    r,
                    ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, K_h, H, S*2-1)),
                0, 2, 1, 3)); // (K_h, 2S-1, H)

        struct ggml_tensor * u_bias = ggml_reshape_3d(ctx0, layer.self_attn_pos_bias_u, K_h, 1, H);
        struct ggml_tensor * v_bias = ggml_reshape_3d(ctx0, layer.self_attn_pos_bias_v, K_h, 1, H);
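        // Transformer-XL decomposition of the attention score:
        //   score(i, j) = (q_i + u) . k_j  +  (q_i + v) . r_{i-j}
        // "ac" below is the content term, "bd" the position term; bd still
        // has to be shifted from relative to absolute indexing before the
        // two are summed.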
        // (H * K_h, S) -> (K_h, H, S) -> (K_h, S, H)
        struct ggml_tensor * Q =
            ggml_dup(ctx0, ggml_permute(ctx0,
                ggml_cpy(ctx0,
                    Qcur,
                    ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, K_h, H, S)),
                0, 2, 1, 3));
        struct ggml_tensor * K =
            ggml_dup(ctx0, ggml_permute(ctx0,
                ggml_cpy(ctx0,
                    Kcur,
                    ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, K_h, H, S)),
                0, 2, 1, 3));
        struct ggml_tensor * V =
            ggml_dup(ctx0, ggml_permute(ctx0,
                ggml_cpy(ctx0,
                    Vcur,
                    ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, K_h, H, S)),
                1, 2, 0, 3)); // (S, K_h, H), transposed for the final mul_mat
        // (K_h, S, H)
        struct ggml_tensor * q_with_u_bias = ggml_add(ctx0, Q, u_bias);
        struct ggml_tensor * q_with_v_bias = ggml_add(ctx0, Q, v_bias);

        struct ggml_tensor * ac = ggml_mul_mat(ctx0, K, q_with_u_bias);
        struct ggml_tensor * bd = ggml_mul_mat(ctx0, r, q_with_v_bias);
        // self_attn: shift_bd
        bd = ggml_dup(ctx0, ggml_permute(ctx0, bd, 2, 1, 0, 3)); // (H, S, 2S-1)
        struct ggml_tensor * pad = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, H, S, 1);
        pad->data = malloc(ggml_nbytes(pad)); // manual alloc (no_alloc ctx); dev only
        pad = ggml_set_f32(pad, 0.0f);
        bd = ggml_concat(ctx0, pad, bd); // bd[i][j][0] == 0
        bd = ggml_dup(ctx0, ggml_permute(ctx0, bd, 2, 1, 0, 3)); // (2S, S, H) = pytorch (H, S, 2S)
        bd = ggml_dup(ctx0, ggml_reshape_3d(ctx0, bd, S, 2*S, H)); // (S, 2S, H)
        bd = ggml_remove_head_row(ctx0, bd);
        bd = ggml_reshape_3d(ctx0, bd, 2*S-1, S, H);
        bd = ggml_get_first_cols_by_rows(ctx0, bd);
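        // This implements the standard Transformer-XL "relative shift": pad
        // one zero column, fold the (S, 2S) view, drop the first row, and
        // keep only the leading columns, which realigns bd from relative
        // offsets to absolute key positions. Note: ggml_remove_head_row and
        // ggml_get_first_cols_by_rows are not stock ggml ops; they appear to
        // be custom kernels added in this fork for exactly these two steps.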
        // self_attn: compute attn weights
        struct ggml_tensor * attn_weights = ggml_add(ctx0, ac, bd);
        // inpL = ggml_sum(ctx0, attn_weights);
        struct ggml_tensor * attn_scale = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, 1);
        attn_scale->data = malloc(ggml_nbytes(attn_scale));
        ggml_set_f32(attn_scale, 1.0f / sqrtf(K_h));
        attn_weights = ggml_mul(ctx0, ggml_repeat(ctx0, attn_scale, attn_weights), attn_weights);
        attn_weights = ggml_soft_max(ctx0, attn_weights);
        struct ggml_tensor * attn = ggml_mul_mat(ctx0, V, attn_weights);

        inpL = attn;
        break; // dev: only the first conformer layer is wired up so far

        // TODO: conv module
        // TODO: ffn2
        // TODO: final layer norm
    }
    ggml_build_forward_expand(gf, inpL);
    ggml_free(ctx0);

    return gf;
}
bool unity_eval(
        const unity_model & model,
        struct ggml_allocr * allocr,
        const int n_threads) {
    const auto & hparams = model.hparams;

    // reset the allocator to free all the memory allocated during the previous inference
    ggml_allocr_reset(allocr);

    struct ggml_cgraph * gf = unity_graph(model, allocr);

    // allocate tensors
    ggml_allocr_alloc_graph(allocr, gf);

    // run the computation
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
    static std::vector<uint8_t> work_buffer;
    work_buffer.resize(plan.work_size);
    plan.work_data = work_buffer.data();
    ggml_graph_compute(gf, &plan);

    // in this case, the output tensor is the last one in the graph
    struct ggml_tensor * inpL = gf->nodes[gf->n_nodes - 1];

    // dev: dump the first values of the output for inspection
    for (int i = 0; i < 1000; ++i) {
        printf("%8.4f ", ((float *)(inpL->data))[i]);
    }

    return true;
}
int main(int argc, char ** argv) {
    // ggml_time_init();
    // const int64_t t_main_start_us = ggml_time_us();

    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    if (params.seed < 0) {
        params.seed = time(NULL);
    }
    printf("%s: seed = %d\n", __func__, params.seed);

    std::mt19937 rng(params.seed);
    if (params.prompt.empty()) {
        params.prompt = gpt_random_prompt(rng);
    }

    gpt_vocab vocab;
    unity_model model;

    // load the model
    if (!unity_model_load(params.model, model, vocab)) {
        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
        return 1;
    }
    // keep this buffer alive while evaluating the model
    std::vector<uint8_t> compute_buffer;
    struct ggml_allocr * allocr = NULL;

    // allocate the compute buffer
    {
        allocr = ggml_allocr_new_measure(GGML_MEM_ALIGN);

        // build the graph once just to measure the required memory
        struct ggml_cgraph * gf = unity_graph(model, allocr);

        // compute the required memory
        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf) + GGML_MEM_ALIGN;

        // recreate the allocator with the required memory
        ggml_allocr_free(allocr);
        compute_buffer.resize(mem_size);
        allocr = ggml_allocr_new(compute_buffer.data(), mem_size, GGML_MEM_ALIGN);

        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
    }
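    // This is the usual ggml-alloc two-pass pattern: build the graph once
    // against a measure allocator to find the worst-case buffer size, then
    // recreate the allocator over a real buffer of exactly that size.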
    if (!unity_eval(model, allocr, 1)) {
        printf("Failed to predict\n");
        return 1;
    }

    ggml_free(model.ctx);

    return 0;
}