// fairseq2.h
//
// ggml-backed C++ declarations for fairseq2 model components: Linear,
// LayerNorm, MultiheadAttention, StandardFeedForwardNetwork and
// TransformerDecoderLayer, plus their size/init/forward entry points.

  #pragma once

  #include <cstddef>  // std::size_t
  #include <cstdint>  // int32_t
  #include <map>
  #include <string>

  #include "ggml.h"
  // Handle for a ggml-backed fairseq2 model.
  //
  // Plain aggregate: no constructor or destructor, so nothing here allocates
  // or frees the context, the tensors, or the hyper-parameters. Their lifetime
  // is presumably managed by whoever calls fairseq2_model_alloc — TODO confirm
  // against its definition.
  struct fairseq2_model {
      // ggml context associated with the model; presumably the one the tensors
      // below were created in — TODO confirm. Written as `struct ggml_context*`
      // (like the `struct ggml_tensor*` uses in this file) so this declaration
      // does not depend on ggml.h having been seen first; the type is the same.
      struct ggml_context* ctx;

      // Parameter tensors looked up by name. Keys are presumably built from the
      // `prefix` strings passed to the *_init functions — TODO confirm.
      std::map<std::string, struct ggml_tensor*> tensors;

      // Type-erased hyper-parameters; readers must cast back to the concrete
      // type that was stored here.
      void* hparams;
  };
  9. fairseq2_model fairseq2_model_alloc(ggml_context* ctx, void* hparams);
  // Parameters of a Linear (affine) projection.
  // Both members are raw, non-owning tensor pointers; this struct frees nothing.
  struct Linear {
      struct ggml_tensor* weight;  // out_dim * in_dim
      // out_dim. Presumably nullptr when Linear_init is called with
      // bias=false — TODO confirm against its definition.
      struct ggml_tensor* bias;
  };
  14. std::size_t Linear_size(int32_t input_dim, int32_t output_dim);
  15. void Linear_init(Linear* self,fairseq2_model& model, const std::string &prefix, int input_dim, int output_dim, bool bias);
  // LayerNorm parameters, both of length model_dim: presumably the learned
  // scale (`weight`) and shift (`bias`) — TODO confirm.
  // Raw, non-owning tensor pointers; this struct frees nothing.
  struct LayerNorm {
      struct ggml_tensor* weight;  // model_dim
      struct ggml_tensor* bias;    // model_dim
  };
  21. std::size_t LayerNorm_size(int32_t dim);
  22. void LayerNorm_init(LayerNorm* self, fairseq2_model& model, const std::string &prefix, int dim);
  23. struct MultiheadAttention {
  24. // num_key_value_heads: int
  25. struct Linear q_proj;
  26. struct Linear k_proj;
  27. struct Linear v_proj;
  28. // pos_encoder: Optional[PositionEncoder]
  29. struct ggml_tensor* bias_k;
  30. struct ggml_tensor* bias_v;
  31. // add_zero_attn: bool
  32. // head_scale_weight: Optional[Parameter]
  33. struct Linear output_proj;
  34. };
  35. struct StandardFeedForwardNetwork {
  36. struct Linear inner_proj; // ffn_inner_dim x model_dim
  37. // inner_activation -> Relu for unity
  38. // struct Dropout inner_dropout;
  39. struct LayerNorm inner_layer_norm; // ffn_inner_dim
  40. struct Linear output_proj; // model_dim x ffn_inner_dim
  41. };
  42. std::size_t StandardFeedForwardNetwork_size(int32_t dim, int32_t inner_dim);
  43. void StandardFeedForwardNetwork_init(
  44. StandardFeedForwardNetwork* self,
  45. fairseq2_model& model,
  46. const std::string &prefix,
  47. int model_dim,
  48. int inner_dim
  49. );
  50. ggml_tensor* StandardFeedForwardNetwork_forward(
  51. StandardFeedForwardNetwork* self,
  52. ggml_tensor* seqs
  53. );
  54. struct TransformerDecoderLayer {
  55. struct MultiheadAttention self_attn;
  56. struct LayerNorm self_attn_norm;
  57. // self_attn_dropout: Optional[Dropout]
  58. struct LayerNorm self_attn_layer_norm;
  59. struct MultiheadAttention encoder_decoder_attn;
  60. // encoder_decoder_dropout: Optional[Dropout]
  61. struct LayerNorm encoder_decoder_attn_layer_norm;
  62. struct StandardFeedForwardNetwork ffn;
  63. // ffn_dropout: Optional[Dropout]
  64. // residual_scale: Optional[Parameter]
  65. struct LayerNorm ffn_layer_norm;
  66. // norm_order: TransformerNormOrder
  67. };