model_loader.cpp

#include <cstdint>
#include <cstdio>
#include <fstream>
#include <stdexcept>
#include <string>
#include "model_loader.h"

#define DEBUG_MODEL_LOAD 0
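
// On-disk layout, as inferred from the reading code below (not an authoritative spec):
// a 32-bit GGML_FILE_MAGIC header, followed by a sequence of tensor records. Each record
// is: name length (uint32), name bytes, n_dims (int32), ggml type (int32), n_dims shape
// entries (int64 each), then the raw tensor data. A record with name length 0 terminates
// the sequence.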

std::ifstream open_ggml_file(const char* fname) {
    printf("%s: loading model from '%s'\n", __func__, fname);

    auto fin = std::ifstream(std::string(fname), std::ios::binary);
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname);
        throw std::invalid_argument("failed to open file."); // TODO Merge error message.
    }

    // The file must start with the GGML magic number.
    std::uint32_t magic;
    fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
    if (magic != GGML_FILE_MAGIC) {
        fprintf(stderr, "%s: invalid model file '%s' (bad header %#x)\n", __func__, fname, magic);
        throw std::invalid_argument("invalid model file (bad magic)."); // TODO Merge error message.
    }
    return fin;
}

int
model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
{
    size_t total_size = 0;
    while (!fin.eof()) {
        // Tensors are stored as (name, metadata, data) records; an empty name marks the end.
        std::string name = get_name(fin);
        if (name.length() == 0)
            break;

        auto tensor = load_tensor_value(fin, model.ctx);
        if (tensor == nullptr) {
            // Abort in case of error: the input stream is corrupted at this point.
            fprintf(stderr, "%s: error while reading tensor '%s'\n", __func__, name.c_str());
            return 1;
        }
        model.tensors[name] = tensor;

        if (DEBUG_MODEL_LOAD) {
            printf("%s [%5ld, %5ld], type = %6s, %6.2f MB, %9zu bytes\n",
                   name.c_str(), tensor->ne[0], tensor->ne[1],
                   ggml_type_name(tensor->type),
                   ggml_nbytes(tensor) / 1024.0 / 1024.0, ggml_nbytes(tensor));
        }
        total_size += ggml_nbytes(tensor);
    }
    printf("%s: model size = %8.2f MB\n", __func__, total_size / 1024.0 / 1024.0);
    return 0;
}

ggml_tensor* load_tensor_value(std::ifstream &fin, ggml_context* ctx)
{
    int32_t n_dims = 0;
    int32_t raw_type = 0;
    fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
    fin.read(reinterpret_cast<char *>(&raw_type), sizeof(raw_type));
    ggml_type type = ggml_type(raw_type);

    // Reject obviously corrupted metadata before allocating anything.
    // Note: GGML_TYPE_COUNT itself is not a valid type, hence >=.
    if (n_dims <= 0 || n_dims > GGML_MAX_DIMS || raw_type < 0 || raw_type >= GGML_TYPE_COUNT) {
        return nullptr;
    }

    int64_t ne[4] = {1, 1, 1, 1};
    for (int i = 0; i < n_dims; ++i) {
        fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
    }

    ggml_tensor* tensor = ggml_new_tensor(ctx, type, n_dims, ne);
    if (tensor == nullptr) {
        // Defensive check: don't dereference a null tensor if allocation failed.
        return nullptr;
    }
    fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
    return tensor;
}

std::string
model_loader::get_name(std::ifstream& fin)
{
    // Names are serialized as a 32-bit length followed by the raw bytes.
    std::uint32_t length = 0;
    fin.read(reinterpret_cast<char *>(&length), sizeof(length));
    if (length == 0) {
        return std::string();
    }

    std::string name(length, 0);
    fin.read(&name[0], length);
    return name;
}
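
// Minimal usage sketch (not part of the original file; assumes the caller has already
// allocated model.ctx with enough memory for the weights, e.g. via ggml_init, and that
// the path below is a placeholder):
//
//     fairseq2_model model;
//     model_loader loader;
//     std::ifstream fin = open_ggml_file("/path/to/model.ggml");
//     if (loader.load_model_weights(model, fin) != 0) {
//         fprintf(stderr, "failed to load model weights\n");
//     }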