model_loader.h 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. // Copyright (c) Meta Platforms, Inc. and affiliates.
  2. // All rights reserved.
  3. //
  4. // This source code is licensed under the license found in the
  5. // LICENSE file in the root directory of this source tree.
  6. #pragma once
  7. #include "ggml/ggml.h"
  8. #include "ggml/ggml-alloc.h"
  9. #include "common.h"
  10. #include "common-ggml.h"
  11. #include <iostream>
  12. #include <stdexcept>
  13. template <typename T>
  14. struct fairseq2_model {
  15. struct ggml_context *ctx;
  16. std::map<std::string, struct ggml_tensor *> tensors;
  17. T hparams;
  18. };
  19. template <typename T>
  20. class model_loader {
  21. public:
  22. void
  23. load_ggml_file(const std::string &fname, fairseq2_model<T> &model);
  24. protected:
  25. virtual void
  26. load_hparams(std::ifstream &fin, T &hparams) = 0;
  27. virtual std::size_t
  28. compute_context_size(T &hparams) = 0;
  29. virtual void
  30. init_model_tensors(fairseq2_model<T> &model);
  31. private:
  32. bool verify_magic(std::ifstream &fin);
  33. void
  34. init_model(fairseq2_model<T> &model);
  35. void load_model_weights(std::ifstream &fin, fairseq2_model<T> &model);
  36. ggml_tensor * next_tensor(std::ifstream &fin, fairseq2_model<T> &model);
  37. // TODO Move these two to helpers
  38. void load_tensor_value(std::ifstream &fin, ggml_tensor *tensor);
  39. std::string get_name(std::ifstream &fin);
  40. };