model_loader.h 987 B

12345678910111213141516171819202122232425262728293031323334353637
  1. // Copyright (c) Meta Platforms, Inc. and affiliates.
  2. // All rights reserved.
  3. //
  4. // This source code is licensed under the license found in the
  5. // MIT_LICENSE file in the root directory of this source tree.
  6. #pragma once
  7. #include <fstream>
  8. #include <iostream>
  9. #include <stdexcept>
  10. #include "ggml/ggml.h"
  11. #include "ggml/ggml-alloc.h"
  12. #include "fairseq2.h"
  13. class model_loader {
  14. public:
  15. std::int64_t load_model_weights(fairseq2_model &model, std::ifstream &fin);
  16. void load_hparams(std::unordered_map<std::string, std::int64_t>& hparams, std::ifstream &fin);
  17. void load_vocab(llama_vocab& vocab, std::ifstream &fin);
  18. private:
  19. ggml_tensor * next_tensor(std::ifstream &fin, fairseq2_model &model);
  20. std::string get_name(std::ifstream &fin);
  21. };
  22. ggml_tensor* load_tensor_value(std::ifstream &fin, ggml_context* ctx, bool as_float32);
  /// Returns an input file stream, by value, for the given file name.
  ///
  /// @param fname  NUL-terminated C string naming the file.
  /// @return A `std::ifstream` object.
  ///         NOTE(review): how an unopenable file is reported (exception vs.
  ///         a stream in a failed state) is not visible from this header.
  std::ifstream open_ggml_file(const char* fname);
  24. extern "C" int load_fairseq2_ggml_file(fairseq2_model& model, const char* fname);