// model_loader.h
  1. // Copyright (c) Meta Platforms, Inc. and affiliates.
  2. // All rights reserved.
  3. //
  4. // This source code is licensed under the license found in the
  5. // LICENSE file in the root directory of this source tree.
  6. #pragma once
  7. #include "ggml/ggml.h"
  8. #include "ggml/ggml-alloc.h"
  9. #include "common.h"
  10. #include "common-ggml.h"
  11. #include "fairseq2.h"
  12. #include <iostream>
  13. #include <stdexcept>
  14. class model_loader {
  15. public:
  16. virtual ~model_loader() {};
  17. virtual fairseq2_model& alloc_model(ggml_context* ctx) = 0;
  18. virtual void load_hparams(fairseq2_model& model, std::ifstream &fin) = 0;
  19. virtual void load_model_weights(fairseq2_model &model, std::ifstream &fin);
  20. virtual std::size_t
  21. compute_context_size(void *raw_hparams) = 0;
  22. virtual void
  23. init_model_tensors(fairseq2_model &model) = 0;
  24. private:
  25. ggml_tensor * next_tensor(std::ifstream &fin, fairseq2_model &model);
  26. // TODO Move these two to helpers
  27. void load_tensor_value(std::ifstream &fin, ggml_tensor *tensor);
  28. std::string get_name(std::ifstream &fin);
  29. };
  30. /// allocate the fairseq2 model and hyperparameters into the ggml context
  31. template<typename T>
  32. fairseq2_model& alloc_fairseq2_model(ggml_context* ctx) {
  33. auto hparams = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(T))->data;
  34. auto& model = (fairseq2_model&)ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(fairseq2_model))->data;
  35. model.ctx = ctx;
  36. model.hparams = hparams;
  37. return model;
  38. };
  39. std::ifstream open_ggml_file(const char* fname);
  40. template<typename T>
  41. fairseq2_model& load_fairseq2_ggml_file(ggml_context* ctx, const char* fname) {
  42. T loader;
  43. fairseq2_model& model = loader.alloc_model(ctx);
  44. auto fin = open_ggml_file(fname);
  45. loader.load_hparams(model, fin);
  46. loader.load_model_weights(model, fin);
  47. return model;
  48. }