main.cpp 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927
  1. #include "ggml/ggml.h"
  2. #include "common.h"
  3. #include "common-ggml.h"
  4. #include <cassert>
  5. #include <cmath>
  6. #include <cstdio>
  7. #include <cstring>
  8. #include <fstream>
  9. #include <map>
  10. #include <string>
  11. #include <vector>
  12. #if defined(_MSC_VER)
  13. #pragma warning(disable: 4244 4267) // possible loss of data
  14. #endif
  15. // default hparams (GPT-2 117M)
  16. // https://huggingface.co/bigcode/gpt_bigcode-santacoder/blob/main/config.json
  17. struct starcoder_hparams {
  18. int32_t n_vocab = 49280;
  19. int32_t n_ctx = 2048;
  20. int32_t n_embd = 2048;
  21. int32_t n_head = 16;
  22. int32_t n_layer = 24;
  23. int32_t ftype = 1;
  24. float eps = 1e-5f;
  25. };
  26. struct starcoder_layer {
  27. // normalization
  28. struct ggml_tensor * ln_1_g;
  29. struct ggml_tensor * ln_1_b;
  30. struct ggml_tensor * ln_2_g;
  31. struct ggml_tensor * ln_2_b;
  32. // attention
  33. struct ggml_tensor * c_attn_attn_w;
  34. struct ggml_tensor * c_attn_attn_b;
  35. struct ggml_tensor * c_attn_proj_w;
  36. struct ggml_tensor * c_attn_proj_b;
  37. // mlp
  38. struct ggml_tensor * c_mlp_fc_w;
  39. struct ggml_tensor * c_mlp_fc_b;
  40. struct ggml_tensor * c_mlp_proj_w;
  41. struct ggml_tensor * c_mlp_proj_b;
  42. };
  43. struct starcoder_model {
  44. starcoder_hparams hparams;
  45. // normalization
  46. struct ggml_tensor * ln_f_g;
  47. struct ggml_tensor * ln_f_b;
  48. struct ggml_tensor * wte; // position embedding
  49. struct ggml_tensor * wpe; // token embedding
  50. struct ggml_tensor * lm_head; // language model head
  51. std::vector<starcoder_layer> layers;
  52. // key + value memory
  53. struct ggml_tensor * memory_k;
  54. struct ggml_tensor * memory_v;
  55. //
  56. struct ggml_context * ctx;
  57. std::map<std::string, struct ggml_tensor *> tensors;
  58. };
  59. // load the model's weights from a file
  60. bool starcoder_model_load(const std::string & fname, starcoder_model & model, gpt_vocab & vocab) {
  61. printf("%s: loading model from '%s'\n", __func__, fname.c_str());
  62. auto fin = std::ifstream(fname, std::ios::binary);
  63. if (!fin) {
  64. fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
  65. return false;
  66. }
  67. // verify magic
  68. {
  69. uint32_t magic;
  70. fin.read((char *) &magic, sizeof(magic));
  71. if (magic != GGML_FILE_MAGIC) {
  72. fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
  73. return false;
  74. }
  75. }
  76. // load hparams
  77. {
  78. auto & hparams = model.hparams;
  79. fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
  80. fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
  81. fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
  82. fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
  83. fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
  84. fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
  85. const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
  86. printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
  87. printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
  88. printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
  89. printf("%s: n_head = %d\n", __func__, hparams.n_head);
  90. printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
  91. printf("%s: ftype = %d\n", __func__, hparams.ftype);
  92. printf("%s: qntvr = %d\n", __func__, qntvr);
  93. hparams.ftype %= GGML_QNT_VERSION_FACTOR;
  94. }
  95. // load vocab
  96. {
  97. int32_t n_vocab = 0;
  98. fin.read((char *) &n_vocab, sizeof(n_vocab));
  99. if (n_vocab != model.hparams.n_vocab) {
  100. fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
  101. __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
  102. return false;
  103. }
  104. std::string word;
  105. std::vector<char> buf(128);
  106. for (int i = 0; i < n_vocab; i++) {
  107. uint32_t len;
  108. fin.read((char *) &len, sizeof(len));
  109. buf.resize(len);
  110. fin.read((char *) buf.data(), len);
  111. word.assign(buf.data(), len);
  112. vocab.token_to_id[word] = i;
  113. vocab.id_to_token[i] = word;
  114. // if (i < 10) fprintf(stderr, "%.s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
  115. }
  116. // Add StarChat special tokens.
  117. for (std::string token : {
  118. "<|system|>",
  119. "<|user|>",
  120. "<|assistant|>",
  121. "<|end|>",
  122. "<fim-prefix>",
  123. "<fim-middle>",
  124. "<fim-suffix>",
  125. "<fim-pad>",
  126. "<|end_of_turn|>"
  127. }) {
  128. if (vocab.token_to_id.find(token) != vocab.token_to_id.end()) {
  129. vocab.add_special_token(token);
  130. }
  131. }
  132. }
  133. // for the big tensors, we have the option to store the data in 16-bit floats or quantized
  134. // in order to save memory and also to speed up the computation
  135. ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
  136. if (wtype == GGML_TYPE_COUNT) {
  137. fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
  138. __func__, fname.c_str(), model.hparams.ftype);
  139. return false;
  140. }
  141. auto & ctx = model.ctx;
  142. size_t ctx_size = 0;
  143. {
  144. const auto & hparams = model.hparams;
  145. const int n_embd = hparams.n_embd;
  146. const int n_layer = hparams.n_layer;
  147. const int n_ctx = hparams.n_ctx;
  148. const int n_vocab = hparams.n_vocab;
  149. const int head_dim = n_embd / hparams.n_head;
  150. const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head
  151. const int kv_dim = kv_heads * head_dim;
  152. ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
  153. ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
  154. ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
  155. ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
  156. ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
  157. ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
  158. ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
  159. ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
  160. ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
  161. ctx_size += n_layer*((n_embd + 2*kv_dim)*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w // TODO:
  162. ctx_size += n_layer*( (n_embd + 2*kv_dim)*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
  163. ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
  164. ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
  165. ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
  166. ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
  167. ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
  168. ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
  169. ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
  170. ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
  171. ctx_size += (6 + 12*n_layer)*512; // object overhead
  172. printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
  173. }
  174. // create the ggml context
  175. {
  176. struct ggml_init_params params = {
  177. /*.mem_size =*/ ctx_size,
  178. /*.mem_buffer =*/ NULL,
  179. /*.no_alloc =*/ false,
  180. };
  181. model.ctx = ggml_init(params);
  182. if (!model.ctx) {
  183. fprintf(stderr, "%s: ggml_init() failed\n", __func__);
  184. return false;
  185. }
  186. }
  187. // prepare memory for the weights
  188. {
  189. const auto & hparams = model.hparams;
  190. const int n_embd = hparams.n_embd;
  191. const int n_layer = hparams.n_layer;
  192. const int n_ctx = hparams.n_ctx;
  193. const int n_vocab = hparams.n_vocab;
  194. const int head_dim = n_embd / hparams.n_head;
  195. const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head
  196. const int kv_dim = kv_heads * head_dim;
  197. model.layers.resize(n_layer);
  198. model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  199. model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  200. model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
  201. model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
  202. model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
  203. // map by name
  204. model.tensors["model/ln_f/g"] = model.ln_f_g;
  205. model.tensors["model/ln_f/b"] = model.ln_f_b;
  206. model.tensors["model/wte"] = model.wte;
  207. model.tensors["model/wpe"] = model.wpe;
  208. model.tensors["model/lm_head"] = model.lm_head;
  209. for (int i = 0; i < n_layer; ++i) {
  210. auto & layer = model.layers[i];
  211. layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  212. layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  213. layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  214. layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  215. layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd + 2*kv_dim);
  216. layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd + 2*kv_dim);
  217. layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
  218. layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  219. layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd); //TODO: 4*n_embd = config.n_inner
  220. layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
  221. layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
  222. layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  223. // map by name
  224. model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
  225. model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
  226. model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
  227. model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
  228. model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
  229. model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
  230. model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
  231. model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
  232. model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
  233. model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
  234. model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
  235. model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
  236. }
  237. }
  238. // key + value memory
  239. {
  240. const auto & hparams = model.hparams;
  241. const int n_embd = hparams.n_embd;
  242. const int n_layer = hparams.n_layer;
  243. const int n_ctx = hparams.n_ctx;
  244. const int n_mem = n_layer*n_ctx;
  245. const int n_elements = n_embd*n_mem;
  246. model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
  247. model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
  248. const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
  249. printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
  250. }
  251. // load weights
  252. {
  253. size_t total_size = 0;
  254. bool has_lm_head = false;
  255. while (true) {
  256. int32_t n_dims;
  257. int32_t length;
  258. int32_t ttype;
  259. fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
  260. fin.read(reinterpret_cast<char *>(&length), sizeof(length));
  261. fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
  262. if (fin.eof()) {
  263. break;
  264. }
  265. int32_t nelements = 1;
  266. int32_t ne[2] = { 1, 1 };
  267. for (int i = 0; i < n_dims; ++i) {
  268. fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
  269. nelements *= ne[i];
  270. }
  271. std::string name(length, 0);
  272. fin.read(&name[0], length);
  273. if (model.tensors.find(name) == model.tensors.end()) {
  274. fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
  275. return false;
  276. }
  277. auto tensor = model.tensors[name];
  278. if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
  279. fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
  280. __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
  281. return false;
  282. }
  283. if (ggml_nelements(tensor) != nelements) {
  284. fprintf(stderr, "%s: tensor '%s' has wrong size in model file. got %d, expected %d\n",
  285. __func__, name.c_str(), (int) ggml_nelements(tensor), nelements);
  286. return false;
  287. }
  288. // for debugging
  289. if (0) {
  290. printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
  291. }
  292. const size_t bpe = ggml_type_size(ggml_type(ttype));
  293. if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
  294. fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
  295. __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);
  296. return false;
  297. }
  298. fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
  299. // GPT-2 models share the WTE tensor as the LM head
  300. if (name == "model/wte" && has_lm_head == false) {
  301. memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
  302. }
  303. if (name == "model/lm_head") {
  304. has_lm_head = true;
  305. }
  306. total_size += ggml_nbytes(tensor);
  307. }
  308. printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
  309. }
  310. fin.close();
  311. return true;
  312. }
  313. // evaluate the transformer
  314. //
  315. // - model: the model
  316. // - n_threads: number of threads to use
  317. // - n_past: the context size so far
  318. // - embd_inp: the embeddings of the tokens in the context
  319. // - embd_w: the predicted logits for the next token
  320. //
  321. bool starcoder_eval(
  322. const starcoder_model & model,
  323. const int n_threads,
  324. const int n_past,
  325. const std::vector<gpt_vocab::id> & embd_inp,
  326. std::vector<float> & embd_w,
  327. size_t & mem_per_token) {
  328. const int N = embd_inp.size();
  329. const auto & hparams = model.hparams;
  330. const int n_embd = hparams.n_embd;
  331. const int n_layer = hparams.n_layer;
  332. const int n_ctx = hparams.n_ctx;
  333. const int n_head = hparams.n_head;
  334. const int n_vocab = hparams.n_vocab;
  335. static size_t buf_size = 256u*1024*1024;
  336. static void * buf = malloc(buf_size);
  337. // use 2 scratch buffers
  338. // TODO: very hacky solution - reimplement in a more elegant way
  339. static size_t scr0_size = 256u*1024*1024;
  340. static void * scr0 = malloc(scr0_size);
  341. static size_t scr1_size = 256u*1024*1024;
  342. static void * scr1 = malloc(scr1_size);
  343. if (mem_per_token > 0 && mem_per_token*N > buf_size) {
  344. const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
  345. //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
  346. // reallocate
  347. buf_size = buf_size_new;
  348. buf = realloc(buf, buf_size);
  349. if (buf == nullptr) {
  350. fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
  351. return false;
  352. }
  353. }
  354. struct ggml_init_params params = {
  355. /*.mem_size =*/ buf_size,
  356. /*.mem_buffer =*/ buf,
  357. /*.no_alloc =*/ false,
  358. };
  359. struct ggml_context * ctx0 = ggml_init(params);
  360. struct ggml_cgraph gf = {};
  361. struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
  362. memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
  363. struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
  364. for (int i = 0; i < N; ++i) {
  365. ((int32_t *) position->data)[i] = n_past + i;
  366. }
  367. // wte + wpe
  368. struct ggml_tensor * inpL =
  369. ggml_add(ctx0,
  370. ggml_get_rows(ctx0, model.wte, embd),
  371. ggml_get_rows(ctx0, model.wpe, position));
  372. for (int il = 0; il < n_layer; ++il) {
  373. struct ggml_tensor * cur;
  374. ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
  375. // norm
  376. {
  377. // [ 768, N]
  378. cur = ggml_norm(ctx0, inpL, hparams.eps);
  379. // cur = ln_1_g*cur + ln_1_b
  380. // [ 768, N]
  381. cur = ggml_add(ctx0,
  382. ggml_mul(ctx0,
  383. ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
  384. cur),
  385. ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
  386. }
  387. // attn
  388. // [2304, 768] - model.layers[il].c_attn_attn_w
  389. // [2304, 1] - model.layers[il].c_attn_attn_b
  390. // [ 768, N] - cur (in)
  391. // [2304, N] - cur (out)
  392. //
  393. // cur = attn_w*cur + attn_b
  394. // [2304, N]
  395. {
  396. cur = ggml_mul_mat(ctx0,
  397. model.layers[il].c_attn_attn_w,
  398. cur);
  399. cur = ggml_add(ctx0,
  400. ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
  401. cur);
  402. }
  403. // self-attention
  404. {
  405. struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
  406. struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
  407. struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
  408. // store key and value to memory
  409. if (N >= 1) {
  410. struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
  411. struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
  412. ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
  413. ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
  414. }
  415. // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
  416. // [64, N, 12]
  417. struct ggml_tensor * Q =
  418. ggml_permute(ctx0,
  419. ggml_cpy(ctx0,
  420. Qcur,
  421. ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
  422. 0, 2, 1, 3);
  423. // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
  424. // [64, n_past + N, 12]
  425. struct ggml_tensor * K =
  426. ggml_permute(ctx0,
  427. ggml_reshape_3d(ctx0,
  428. ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
  429. n_embd/n_head, n_head, n_past + N),
  430. 0, 2, 1, 3); //TODO: need to be tiled
  431. // GG: flash attention
  432. //struct ggml_tensor * V =
  433. // ggml_cpy(ctx0,
  434. // ggml_permute(ctx0,
  435. // ggml_reshape_3d(ctx0,
  436. // ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
  437. // n_embd/n_head, n_head, n_past + N),
  438. // 1, 2, 0, 3),
  439. // ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
  440. //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
  441. // K * Q
  442. // [n_past + N, N, 12]
  443. struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); //TODO: check if it broadcasts
  444. // KQ_scaled = KQ / sqrt(n_embd/n_head)
  445. // [n_past + N, N, 12]
  446. struct ggml_tensor * KQ_scaled =
  447. ggml_scale_inplace(ctx0,
  448. KQ,
  449. ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
  450. );
  451. // KQ_masked = mask_past(KQ_scaled)
  452. // [n_past + N, N, 12]
  453. struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
  454. // KQ = soft_max(KQ_masked)
  455. // [n_past + N, N, 12]
  456. struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
  457. // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
  458. // [n_past + N, 64, 12]
  459. struct ggml_tensor * V_trans =
  460. ggml_cpy(ctx0,
  461. ggml_permute(ctx0,
  462. ggml_reshape_3d(ctx0,
  463. ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
  464. n_embd/n_head, n_head, n_past + N),
  465. 1, 2, 0, 3),
  466. ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
  467. // KQV = transpose(V) * KQ_soft_max
  468. // [64, N, 12]
  469. struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
  470. // KQV_merged = KQV.permute(0, 2, 1, 3)
  471. // [64, 12, N]
  472. struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
  473. // cur = KQV_merged.contiguous().view(n_embd, N)
  474. // [768, N]
  475. cur = ggml_cpy(ctx0,
  476. KQV_merged,
  477. ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
  478. }
  479. // projection
  480. // [ 768, 768] - model.layers[il].c_attn_proj_w
  481. // [ 768, 1] - model.layers[il].c_attn_proj_b
  482. // [ 768, N] - cur (in)
  483. // [ 768, N] - cur (out)
  484. //
  485. // cur = proj_w*cur + proj_b
  486. // [768, N]
  487. {
  488. cur = ggml_mul_mat(ctx0,
  489. model.layers[il].c_attn_proj_w,
  490. cur);
  491. cur = ggml_add(ctx0,
  492. ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
  493. cur);
  494. }
  495. // add the input
  496. cur = ggml_add(ctx0, cur, inpL);
  497. struct ggml_tensor * inpFF = cur;
  498. ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
  499. // feed-forward network
  500. {
  501. // norm
  502. {
  503. cur = ggml_norm(ctx0, inpFF, hparams.eps);
  504. // cur = ln_2_g*cur + ln_2_b
  505. // [ 768, N]
  506. cur = ggml_add(ctx0,
  507. ggml_mul(ctx0,
  508. ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
  509. cur),
  510. ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
  511. }
  512. // fully connected
  513. // [3072, 768] - model.layers[il].c_mlp_fc_w
  514. // [3072, 1] - model.layers[il].c_mlp_fc_b
  515. // [ 768, N] - cur (in)
  516. // [3072, N] - cur (out)
  517. //
  518. // cur = fc_w*cur + fc_b
  519. // [3072, N]
  520. cur = ggml_mul_mat(ctx0,
  521. model.layers[il].c_mlp_fc_w,
  522. cur);
  523. cur = ggml_add(ctx0,
  524. ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
  525. cur);
  526. // GELU activation
  527. // [3072, N]
  528. cur = ggml_gelu(ctx0, cur);
  529. // projection
  530. // [ 768, 3072] - model.layers[il].c_mlp_proj_w
  531. // [ 768, 1] - model.layers[il].c_mlp_proj_b
  532. // [3072, N] - cur (in)
  533. // [ 768, N] - cur (out)
  534. //
  535. // cur = proj_w*cur + proj_b
  536. // [768, N]
  537. cur = ggml_mul_mat(ctx0,
  538. model.layers[il].c_mlp_proj_w,
  539. cur);
  540. cur = ggml_add(ctx0,
  541. ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
  542. cur);
  543. }
  544. // input for next layer
  545. inpL = ggml_add(ctx0, cur, inpFF);
  546. }
  547. ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
  548. // norm
  549. {
  550. // [ 768, N]
  551. inpL = ggml_norm(ctx0, inpL, hparams.eps);
  552. // inpL = ln_f_g*inpL + ln_f_b
  553. // [ 768, N]
  554. inpL = ggml_add(ctx0,
  555. ggml_mul(ctx0,
  556. ggml_repeat(ctx0, model.ln_f_g, inpL),
  557. inpL),
  558. ggml_repeat(ctx0, model.ln_f_b, inpL));
  559. }
  560. ggml_set_scratch(ctx0, { 0, 0, nullptr, });
  561. // inpL = WTE * inpL
  562. // [ 768, 50257] - model.lm_head
  563. // [ 768, N] - inpL
  564. inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
  565. // logits -> probs
  566. //inpL = ggml_soft_max_inplace(ctx0, inpL);
  567. // run the computation
  568. ggml_build_forward_expand(&gf, inpL);
  569. ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
  570. //if (n_past%100 == 0) {
  571. // ggml_graph_print (&gf);
  572. // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
  573. //}
  574. //embd_w.resize(n_vocab*N);
  575. //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
  576. // return result just for the last token
  577. embd_w.resize(n_vocab);
  578. memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
  579. if (mem_per_token == 0) {
  580. mem_per_token = ggml_used_mem(ctx0)/N;
  581. }
  582. //printf("used_mem = %zu MB\n", ggml_used_mem(ctx0)/(1024*1024));
  583. ggml_free(ctx0);
  584. return true;
  585. }
  586. int main(int argc, char ** argv) {
  587. ggml_time_init();
  588. const int64_t t_main_start_us = ggml_time_us();
  589. gpt_params params;
  590. if (gpt_params_parse(argc, argv, params) == false) {
  591. return 1;
  592. }
  593. if (params.seed < 0) {
  594. params.seed = time(NULL);
  595. }
  596. printf("%s: seed = %d\n", __func__, params.seed);
  597. std::mt19937 rng(params.seed);
  598. if (params.prompt.empty()) {
  599. params.prompt = gpt_random_prompt(rng);
  600. }
  601. int64_t t_load_us = 0;
  602. gpt_vocab vocab;
  603. starcoder_model model;
  604. // load the model
  605. {
  606. const int64_t t_start_us = ggml_time_us();
  607. if (!starcoder_model_load(params.model, model, vocab)) {
  608. fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
  609. return 1;
  610. }
  611. t_load_us = ggml_time_us() - t_start_us;
  612. test_gpt_tokenizer(vocab, params.token_test);
  613. }
  614. if (params.repeat_last_n == -1) {
  615. params.repeat_last_n = model.hparams.n_ctx;
  616. }
  617. printf("\n");
  618. printf("%s: temp = %.3f\n", __func__, params.temp);
  619. printf("%s: top_k = %d\n", __func__, params.top_k);
  620. printf("%s: top_p = %.3f\n", __func__, params.top_p);
  621. printf("%s: repeat_last_n = %d\n", __func__, params.repeat_last_n);
  622. printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty);
  623. int n_past = 0;
  624. int64_t t_sample_us = 0;
  625. int64_t t_predict_us = 0;
  626. std::vector<float> logits;
  627. std::vector<int32_t> last_n_tokens(model.hparams.n_ctx);
  628. std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
  629. // tokenize the prompt
  630. std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
  631. params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
  632. printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
  633. printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
  634. for (size_t i = 0; i < embd_inp.size(); i++) {
  635. printf("%s: token[%zu] = %6d, %s\n", __func__, i, embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
  636. }
  637. printf("\n\n");
  638. // Handle StarChat "<|end|>" and OpenCoder "<|end_of_turn>" tokens.
  639. gpt_vocab::id starchat_end_token = -1;
  640. {
  641. const auto it = vocab.token_to_id.find("<|end|>");
  642. if (it != vocab.token_to_id.end()) {
  643. starchat_end_token = it->second;
  644. } else {
  645. const auto eot_token_id = vocab.token_to_id.find("<|end_of_turn|>");
  646. if (eot_token_id != vocab.token_to_id.end()) {
  647. starchat_end_token = eot_token_id->second;
  648. }
  649. }
  650. }
  651. // submit the input prompt token-by-token
  652. // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
  653. std::vector<gpt_vocab::id> embd;
  654. // determine the required inference memory per token:
  655. size_t mem_per_token = 0;
  656. starcoder_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
  657. for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
  658. // predict
  659. if (embd.size() > 0) {
  660. const int64_t t_start_us = ggml_time_us();
  661. if (!starcoder_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
  662. printf("Failed to predict\n");
  663. return 1;
  664. }
  665. t_predict_us += ggml_time_us() - t_start_us;
  666. }
  667. n_past += embd.size();
  668. embd.clear();
  669. if (i >= embd_inp.size()) {
  670. // sample next token
  671. const int top_k = params.top_k;
  672. const float top_p = params.top_p;
  673. const float temp = params.temp;
  674. const int n_vocab = model.hparams.n_vocab;
  675. gpt_vocab::id id = 0;
  676. {
  677. const int64_t t_start_sample_us = ggml_time_us();
  678. id = gpt_sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, params.repeat_last_n, params.repeat_penalty, rng);
  679. t_sample_us += ggml_time_us() - t_start_sample_us;
  680. }
  681. // add it to the context
  682. embd.push_back(id);
  683. last_n_tokens.erase(last_n_tokens.begin());
  684. last_n_tokens.push_back(id);
  685. } else {
  686. // if here, it means we are still processing the input prompt
  687. for (size_t k = i; k < embd_inp.size(); k++) {
  688. embd.push_back(embd_inp[k]);
  689. last_n_tokens.erase(last_n_tokens.begin());
  690. last_n_tokens.push_back(embd_inp[k]);
  691. if (int32_t(embd.size()) >= params.n_batch) {
  692. break;
  693. }
  694. }
  695. i += embd.size() - 1;
  696. }
  697. // display text
  698. for (auto id : embd) {
  699. printf("%s", vocab.id_to_token[id].c_str());
  700. }
  701. fflush(stdout);
  702. // check if model is santacoder
  703. if (model.hparams.n_layer <= 30 && embd.back() == 49152) {
  704. break;
  705. }
  706. // check if model is starcoder
  707. else if (embd.back() == 0) { //TODO: this is only for starcoder
  708. break;
  709. }
  710. // Handle StarChat "<|end|>" token.
  711. else if (embd.back() == starchat_end_token && i >= embd_inp.size()) {
  712. break;
  713. }
  714. }
  715. // report timing
  716. {
  717. const int64_t t_main_end_us = ggml_time_us();
  718. printf("\n\n");
  719. printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
  720. printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
  721. printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
  722. printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
  723. printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
  724. }
  725. ggml_free(model.ctx);
  726. return 0;
  727. }