quantize.cpp 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. #include "ggml/ggml.h"
  2. #include "common.h"
  3. #include "common-ggml.h"
  4. #include <cassert>
  5. #include <cmath>
  6. #include <cstdio>
  7. #include <cstring>
  8. #include <fstream>
  9. #include <map>
  10. #include <string>
  11. #include <vector>
  12. #include <regex>
  13. // default hparams (GPT-J 6B)
  // Hyperparameters carried in the header of a GPT-J model file.
  // Every field is a 32-bit integer; the initial values are the GPT-J 6B defaults.
  struct gptj_hparams {
      int32_t n_vocab{50400}; // must equal the entry count of the vocab section
      int32_t n_ctx{2048};
      int32_t n_embd{4096};
      int32_t n_head{16};
      int32_t n_layer{28};
      int32_t n_rot{64};
      int32_t ftype{1};       // file type id as stored in the header
  };
  23. // quantize a model
  24. bool gptj_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
  25. gpt_vocab vocab;
  26. printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
  27. auto finp = std::ifstream(fname_inp, std::ios::binary);
  28. if (!finp) {
  29. fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
  30. return false;
  31. }
  32. auto fout = std::ofstream(fname_out, std::ios::binary);
  33. if (!fout) {
  34. fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
  35. return false;
  36. }
  37. // verify magic
  38. {
  39. uint32_t magic;
  40. finp.read((char *) &magic, sizeof(magic));
  41. if (magic != GGML_FILE_MAGIC) {
  42. fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
  43. return false;
  44. }
  45. fout.write((char *) &magic, sizeof(magic));
  46. }
  47. gptj_hparams hparams;
  48. // load hparams
  49. {
  50. finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
  51. finp.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
  52. finp.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
  53. finp.read((char *) &hparams.n_head, sizeof(hparams.n_head));
  54. finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
  55. finp.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
  56. finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
  57. const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
  58. const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
  59. printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
  60. printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
  61. printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
  62. printf("%s: n_head = %d\n", __func__, hparams.n_head);
  63. printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
  64. printf("%s: ftype (src) = %d\n", __func__, hparams.ftype);
  65. printf("%s: qntvr (src) = %d\n", __func__, qntvr_src);
  66. printf("%s: ftype (dst) = %d\n", __func__, ftype_dst);
  67. printf("%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
  68. fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
  69. fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
  70. fout.write((char *) &hparams.n_embd, sizeof(hparams.n_embd));
  71. fout.write((char *) &hparams.n_head, sizeof(hparams.n_head));
  72. fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
  73. fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot));
  74. fout.write((char *) &ftype_dst, sizeof(ftype_dst));
  75. }
  76. // load vocab
  77. {
  78. int32_t n_vocab = 0;
  79. finp.read ((char *) &n_vocab, sizeof(n_vocab));
  80. fout.write((char *) &n_vocab, sizeof(n_vocab));
  81. if (n_vocab != hparams.n_vocab) {
  82. fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
  83. __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
  84. return false;
  85. }
  86. std::string word;
  87. for (int i = 0; i < n_vocab; i++) {
  88. uint32_t len;
  89. finp.read ((char *) &len, sizeof(len));
  90. fout.write((char *) &len, sizeof(len));
  91. word.resize(len);
  92. finp.read ((char *) word.data(), len);
  93. fout.write((char *) word.data(), len);
  94. vocab.token_to_id[word] = i;
  95. vocab.id_to_token[i] = word;
  96. }
  97. }
  98. // regexes of tensor names to be quantized
  99. const std::vector<std::string> to_quant = {
  100. ".*weight",
  101. };
  102. if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) {
  103. fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
  104. return false;
  105. }
  106. finp.close();
  107. fout.close();
  108. return true;
  109. }
  110. // usage:
  111. // ./gpt-j-quantize models/gpt-j-6B/ggml-model.bin models/gpt-j-6B/ggml-model-quant.bin type
  112. //
  113. int main(int argc, char ** argv) {
  114. if (argc != 4) {
  115. fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
  116. ggml_print_ftypes(stderr);
  117. return 1;
  118. }
  119. // needed to initialize f16 tables
  120. {
  121. struct ggml_init_params params = { 0, NULL, false };
  122. struct ggml_context * ctx = ggml_init(params);
  123. ggml_free(ctx);
  124. }
  125. const std::string fname_inp = argv[1];
  126. const std::string fname_out = argv[2];
  127. const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
  128. const int64_t t_main_start_us = ggml_time_us();
  129. int64_t t_quantize_us = 0;
  130. // load the model
  131. {
  132. const int64_t t_start_us = ggml_time_us();
  133. if (!gptj_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
  134. fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
  135. return 1;
  136. }
  137. t_quantize_us = ggml_time_us() - t_start_us;
  138. }
  139. // report timing
  140. {
  141. const int64_t t_main_end_us = ggml_time_us();
  142. printf("\n");
  143. printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
  144. printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
  145. }
  146. return 0;
  147. }