model_loader.cpp

// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include "ggml/ggml.h"
#include "ggml/ggml-alloc.h"

#include "common.h"
#include "common-ggml.h"

#include <iostream>
#include <stdexcept>

#include "ggml/examples/unity/model_loader.h"
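
// Loads a fairseq2 model serialized in the ggml binary format. The file is
// expected to contain, in order: a 32-bit magic word, the hyper-parameters
// consumed by load_hparams(), and a sequence of named tensor records that are
// matched against the tensors registered by init_model_tensors().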
template<typename T>
void
model_loader<T>::load_ggml_file(const std::string &fname, fairseq2_model<T> &model)
{
    printf("%s: loading model from '%s'\n", __func__, fname.c_str());

    auto fin = std::ifstream(fname, std::ios::binary);
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
        throw std::invalid_argument("failed to open file."); // TODO Merge error message.
    }

    if (!verify_magic(fin)) {
        fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
        throw std::invalid_argument("invalid model file (bad magic)."); // TODO Merge error message.
    }

    load_hparams(fin, model.hparams);
    init_model(model);
    load_model_weights(fin, model);
}
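
// A minimal usage sketch, assuming a concrete subclass of model_loader<T> and a
// matching hparams type are declared in model_loader.h (the names `my_loader`
// and `my_hparams` below are illustrative, not part of this file):
//
//     my_loader loader;
//     fairseq2_model<my_hparams> model;
//     loader.load_ggml_file("model.ggml", model);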
template<typename T>
bool
model_loader<T>::verify_magic(std::ifstream &fin)
{
    uint32_t magic;
    fin.read(reinterpret_cast<char *>(&magic), sizeof(magic));
    return magic == GGML_FILE_MAGIC;
}
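
// Creates the ggml context that owns the model tensors. With no_alloc set to
// false, tensor data is allocated inside this context, so compute_context_size()
// must account for both tensor metadata and the weight buffers themselves.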
template<typename T>
void
model_loader<T>::init_model(fairseq2_model<T> &model)
{
    struct ggml_init_params params = {
        /*.mem_size   =*/ compute_context_size(model.hparams),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };

    model.ctx = ggml_init(params);
    if (!model.ctx)
        throw std::runtime_error("ggml_init() failed.");

    init_model_tensors(model);
}
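
// Reads tensor records until the end of the file, copying each one into the
// pre-allocated tensor of the same name created by init_model_tensors().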
template<typename T>
void
model_loader<T>::load_model_weights(std::ifstream &fin, fairseq2_model<T> &model)
{
    size_t total_size = 0;
    // Peek before each record so the loop stops cleanly at end of file instead
    // of attempting to read a tensor name past the last record.
    while (fin.peek() != std::ifstream::traits_type::eof()) {
        auto tensor = next_tensor(fin, model);
        load_tensor_value(fin, tensor);
        total_size += ggml_nbytes(tensor);
    }

    printf("%s: model size = %8.2f MB\n", __func__, total_size / 1024.0 / 1024.0);
}
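
// Reads the next tensor name from the stream and returns the matching
// pre-allocated tensor, or throws if the name is not part of the model.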
template<typename T>
ggml_tensor *
model_loader<T>::next_tensor(std::ifstream &fin, fairseq2_model<T> &model)
{
    auto name = get_name(fin);
    std::cout << "loading tensor: " << name << std::endl;

    if (model.tensors.find(name) == model.tensors.end()) {
        fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
        throw std::invalid_argument("unknown tensor in model file."); // TODO Merge error message.
    }

    return model.tensors[name];
}
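
// A tensor record consists of: int32 n_dims, int32 ggml type, n_dims int32
// dimensions, followed by the raw tensor data. The name has already been
// consumed by next_tensor().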
template<typename T>
void
model_loader<T>::load_tensor_value(std::ifstream &fin, ggml_tensor *tensor)
{
    int32_t n_dims;
    int32_t ttype;
    fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
    fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));

    int32_t nelements = 1;
    int32_t ne[3] = {1, 1, 1};
    for (int i = 0; i < n_dims; ++i) {
        fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
        nelements *= ne[i];
    }

    if (ggml_nelements(tensor) != nelements) {
        fprintf(stderr, "%s: tensor has wrong size in model file: got %d, expected %d\n",
                __func__, nelements, (int) ggml_nelements(tensor));
        throw std::runtime_error("tensor has wrong size in model file.");
    }

    if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
        fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
                __func__, tensor->name, (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
        throw std::runtime_error("tensor has wrong shape in file."); // TODO Merge error message.
    }

    // for debugging
    if (0) {
        printf("[%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n",
               ne[0], ne[1], ggml_type_name(ggml_type(ttype)),
               ggml_nbytes(tensor) / 1024.0 / 1024.0, ggml_nbytes(tensor));
    }

    const size_t bpe = ggml_type_size(ggml_type(ttype));
    if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
        fprintf(stderr, "%s: tensor has wrong size in model file: got %zu, expected %zu\n",
                __func__, ggml_nbytes(tensor), nelements * bpe);
        throw std::runtime_error("tensor has wrong size in file."); // TODO Merge error message.
    }

    fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
}
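
// Tensor names are serialized as an int32 length followed by that many bytes.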
template<typename T>
std::string
model_loader<T>::get_name(std::ifstream &fin)
{
    int32_t length;
    fin.read(reinterpret_cast<char *>(&length), sizeof(length));

    std::string name(length, 0);
    fin.read(&name[0], length);
    return name;
}
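
// NOTE: because these member functions are templates defined outside the
// header, callers only link if this translation unit (or the header that
// includes it) provides an instantiation for the concrete T they use, e.g. a
// hypothetical `template class model_loader<my_hparams>;`.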