model_loader.cpp

#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>

#include "model_loader.h"

std::ifstream open_ggml_file(const char* fname) {
    printf("%s: loading model from '%s'\n", __func__, fname);

    auto fin = std::ifstream(std::string(fname), std::ios::binary);
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname);
        throw std::invalid_argument("failed to open file."); // TODO Merge error message.
    }

    std::uint32_t magic;
    fin.read(reinterpret_cast<char *>(&magic), sizeof(magic));
    if (magic != GGML_FILE_MAGIC) {
        fprintf(stderr, "%s: invalid model file '%s' (bad header %#x)\n", __func__, fname, (unsigned int) magic);
        throw std::invalid_argument("invalid model file (bad magic)."); // TODO Merge error message.
    }

    return fin;
}
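// Layout of the weight stream as read by the functions below (inferred from
// this file only; any hparams/vocab parsing would live elsewhere, e.g. behind
// model_loader.h, and is not described here):
//
//   uint32  magic                  -- must equal GGML_FILE_MAGIC
//   repeated until end of file:
//     int32   name_length
//     bytes   name[name_length]    -- key into fairseq2_model::tensors
//     int32   n_dims
//     int32   ttype                -- ggml_type of the stored data
//     int32   ne[n_dims]           -- dimensions (at most 3 here)
//     bytes   data                 -- raw tensor contents, ggml_nbytes() long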
void
model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
{
    size_t total_size = 0;
    // Read tensor records until the end of the file. `peek` is used instead of
    // `eof()` so the loop stops before trying to parse a header past the last record.
    while (fin.peek() != std::ifstream::traits_type::eof()) {
        auto tensor = next_tensor(fin, model);
        load_tensor_value(fin, tensor);
        total_size += ggml_nbytes(tensor);
    }
    printf("%s: model size = %8.2f MB\n", __func__, total_size / 1024.0 / 1024.0);
}
ggml_tensor *
model_loader::next_tensor(std::ifstream &fin, fairseq2_model &model)
{
    auto name = get_name(fin);
    std::cout << "loading tensor: " << name << std::endl;

    if (model.tensors.find(name) == model.tensors.end()) {
        fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
        throw std::invalid_argument("unknown tensor in model file."); // TODO Merge error message.
    }

    return model.tensors[name];
}
void
model_loader::load_tensor_value(std::ifstream &fin, ggml_tensor *tensor)
{
    int32_t n_dims;
    int32_t ttype;
    fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
    fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));

    // Only up to 3 dimensions are supported here; reject anything else before
    // writing past the end of `ne`.
    if (n_dims < 0 || n_dims > 3) {
        fprintf(stderr, "%s: tensor '%s' has unsupported number of dimensions: %d\n", __func__, tensor->name, n_dims);
        throw std::runtime_error("tensor has wrong shape in file."); // TODO Merge error message.
    }

    int32_t nelements = 1;
    int32_t ne[3] = {1, 1, 1};
    for (int i = 0; i < n_dims; ++i) {
        fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
        nelements *= ne[i];
    }

    if (ggml_nelements(tensor) != nelements) {
        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %d, expected %lld\n",
            __func__, tensor->name, nelements, (long long) ggml_nelements(tensor));
        throw std::runtime_error("tensor has wrong size in model file.");
    }

    if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
        fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
            __func__, tensor->name, (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
        throw std::runtime_error("tensor has wrong shape in file."); // TODO Merge error message.
    }

    // for debugging
    if (0) {
        printf("[%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n",
            ne[0], ne[1], ggml_type_name(ggml_type(ttype)),
            ggml_nbytes(tensor) / 1024.0 / 1024.0, ggml_nbytes(tensor));
    }

    const size_t bpe = ggml_type_size(ggml_type(ttype));
    if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
        fprintf(stderr, "%s: tensor has wrong size in model file: got %zu, expected %zu\n",
            __func__, ggml_nbytes(tensor), nelements * bpe);
        throw std::runtime_error("tensor has wrong size in file."); // TODO Merge error message.
    }

    fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
}
std::string
model_loader::get_name(std::ifstream& fin)
{
    int32_t length;
    fin.read(reinterpret_cast<char *>(&length), sizeof(length));

    std::string name(length, 0);
    fin.read(&name[0], length);

    return name;
}
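// Usage sketch (not part of the original loader): how these pieces are meant
// to be called together, assuming model_loader is default-constructible and
// that the fairseq2_model has already been given a ggml context and a
// populated `tensors` map by setup code declared elsewhere (e.g. behind
// model_loader.h). File name and setup step below are illustrative only.
//
//     fairseq2_model model;
//     // ... allocate the ggml context and register entries in model.tensors ...
//
//     std::ifstream fin = open_ggml_file("model.ggml"); // validates GGML_FILE_MAGIC
//     model_loader loader;
//     loader.load_model_weights(model, fin);            // streams every tensor into place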