// fairseq2.h
  1. #pragma once
  2. #include <map>
  3. #include <string>
  4. #include <vector>
  5. #include "ggml.h"
  /// State shared by the fairseq2 forward functions: a tensor context, a
  /// name -> tensor lookup table, two opaque pointers (`arch`, `hparams`) and
  /// an inference context.
  ///
  /// Every raw pointer member defaults to nullptr so that a freshly
  /// constructed model never exposes indeterminate pointer values; callers
  /// can reliably test a member against nullptr before it has been set.
  struct fairseq2_model {
      // Context containing the memory of all tensors.
      struct ggml_context* tensors_ctx = nullptr;

      // Named tensors; all tensors should belong to tensors_ctx.
      std::map<std::string, struct ggml_tensor*> tensors;

      // Opaque pointers; their concrete types are not visible in this header.
      void* arch = nullptr;
      void* hparams = nullptr;

      // An inference context, not managed by this object.
      // TODO: is this the best place to store this, or should we also pass
      // it to all forward methods?
      struct ggml_context* ctx = nullptr;
  };
  17. /// allocate the fairseq2 model and hyperparameters
  18. extern "C" fairseq2_model* fairseq2_model_alloc();
  19. // free the models and all its owned tensors
  20. extern "C" void fairseq2_model_free(fairseq2_model* model);
  21. extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx);
  /// Creates a std::string from the C string `c_str` and returns a pointer
  /// to it.
  /// NOTE(review): presumably the contents are copied and the result is to
  /// be released with std_string_free -- confirm against the implementation.
  extern "C"
  std::string* std_string_alloc(char* c_str);
  /// Releases a std::string.
  /// NOTE(review): presumably `str` must come from std_string_alloc --
  /// confirm against the implementation.
  extern "C"
  void std_string_free(std::string* str);
  24. extern "C" ggml_tensor* Linear_forward(
  25. fairseq2_model& model,
  26. const std::string &prefix,
  27. ggml_tensor* input
  28. );
  29. extern "C" ggml_tensor* LayerNorm_forward(
  30. fairseq2_model& model,
  31. const std::string &prefix,
  32. ggml_tensor* input
  33. );
  34. extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
  35. fairseq2_model& model,
  36. const std::string& prefix,
  37. ggml_tensor* seqs
  38. );
  39. extern "C" ggml_tensor* MultiheadAttention_forward(
  40. fairseq2_model& model,
  41. const std::string &prefix,
  42. ggml_tensor* queries, // (slen, d_in)
  43. ggml_tensor* keys, // (klen, d_in)
  44. ggml_tensor* values, // (klen, d_out)
  45. ggml_tensor* _ // (klen, slen) TODO: do we need to pass mask here ?
  46. );