// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once

#include "ggml/ggml.h"
#include "ggml/ggml-alloc.h"

#include "common.h"
#include "common-ggml.h"

#include "fairseq2.h"

#include <fstream>  // std::ifstream, used throughout this header
#include <iostream>
#include <stdexcept>
class model_loader {
public:
    virtual ~model_loader() = default;

    // Read the model hyper-parameters from the file header into `model`.
    virtual void load_hparams(fairseq2_model& model, std::ifstream& fin) = 0;

    // Compute the ggml context size needed to hold all model tensors.
    virtual std::size_t compute_context_size(void* raw_hparams) = 0;

    // Create the model tensors inside `model.ctx`.
    virtual void tensors_alloc(fairseq2_model& model) = 0;

    // Read the tensor values from `fin` into the allocated tensors.
    void load_model_weights(fairseq2_model& model, std::ifstream& fin);

private:
    ggml_tensor* next_tensor(std::ifstream& fin, fairseq2_model& model);

    // TODO: Move these two to helpers.
    void load_tensor_value(std::ifstream& fin, ggml_tensor* tensor);
    std::string get_name(std::ifstream& fin);
};
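
// Illustrative sketch, not part of this header's API: a minimal concrete
// loader showing which methods a subclass overrides. `toy_hparams` and the
// per-layer size budget are hypothetical, and the sketch assumes
// `model.hparams` points at writable storage large enough for the struct.
//
//     struct toy_hparams { std::int32_t n_layers; };  // hypothetical layout
//
//     class toy_loader : public model_loader {
//         void load_hparams(fairseq2_model& model, std::ifstream& fin) override {
//             fin.read(reinterpret_cast<char*>(model.hparams), sizeof(toy_hparams));
//         }
//         std::size_t compute_context_size(void* raw_hparams) override {
//             auto* hp = static_cast<toy_hparams*>(raw_hparams);
//             // rough budget: per-layer tensor bytes plus ggml object overhead
//             return static_cast<std::size_t>(hp->n_layers) * 16u * 1024u * 1024u;
//         }
//         void tensors_alloc(fairseq2_model& model) override {
//             // create each weight with ggml_new_tensor_* inside model.ctx
//         }
//     };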

// Open `fname` for reading as a ggml binary file.
std::ifstream open_ggml_file(const char* fname);

template<typename T>
void load_fairseq2_ggml_file(fairseq2_model& model, const char* fname) {
    T loader;
    auto fin = open_ggml_file(fname);
    loader.load_hparams(model, fin);

    std::size_t ctx_size = loader.compute_context_size(model.hparams);
    struct ggml_init_params params = {
        /*.mem_size   =*/ ctx_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    model.ctx = ggml_init(params);

    // TODO: should we delay weights loading/allocating?
    loader.tensors_alloc(model);
    loader.load_model_weights(model, fin);
}
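
// Example usage (illustrative; `toy_loader` refers to the hypothetical sketch
// above — substitute a real model_loader subclass):
//
//     fairseq2_model model;
//     load_fairseq2_ggml_file<toy_loader>(model, "model.ggml");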