// fairseq2.h
#pragma once

#include <cstdint>
#include <map>
#include <string>
#include <vector>

#include "ggml.h"
#include "kaldi-native-fbank/csrc/feature-fbank.h"

struct fairseq2_model {
    // Context containing the memory of all tensors
    ggml_context* tensors_ctx;

    // Named tensors; all tensors should belong to tensors_ctx
    std::map<std::string, struct ggml_tensor *> tensors;

    void* arch;
    void* hparams;

    // An inference context, not managed by this object.
    // TODO: is this the best place to store this, or should we also pass it to all forward methods?
    ggml_context* ctx;
};

/// Allocates the fairseq2 model and its hyperparameters.
extern "C" fairseq2_model* fairseq2_model_alloc();

/// Frees the model and all of its owned tensors.
extern "C" void fairseq2_model_free(fairseq2_model* model);

/// Attaches an inference context to the model. The context is not owned by the model.
extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx);

/// Helpers for non-C++ callers (e.g. language bindings) to create and destroy
/// the std::string prefixes taken by the forward functions below.
extern "C" std::string* std_string_alloc(char* c_str);
extern "C" void std_string_free(std::string* str);
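
// Example (a sketch): typical lifecycle of a fairseq2_model through this C API.
// How the tensors and hyperparameters get loaded is outside this header, and the
// scratch size below is an arbitrary placeholder.
//
//     fairseq2_model* model = fairseq2_model_alloc();
//     // ... populate model->tensors / model->hparams with a model loader ...
//
//     ggml_init_params params = {
//         /*mem_size=*/   512 * 1024 * 1024,
//         /*mem_buffer=*/ NULL,
//         /*no_alloc=*/   false,
//     };
//     ggml_context* ctx = ggml_init(params);
//     fairseq2_model_set_inference_ctx(model, ctx);
//
//     // ... call the *_forward functions declared below ...
//
//     ggml_free(ctx);              // the inference context is not owned by the model
//     fairseq2_model_free(model);
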
extern "C" ggml_tensor* WaveformToFbank_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* waveform
);

extern "C" ggml_tensor* ggml_slice(
    struct ggml_context* ctx,
    struct ggml_tensor* a,
    int axis,
    int64_t start,
    int64_t end
);
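
// Example (a sketch): keep the first 10 elements along axis 0 of an existing
// tensor `seqs`. This assumes `axis` indexes ggml dimensions, i.e. axis 0 is the
// innermost (fastest-varying) dimension.
//
//     ggml_tensor* head = ggml_slice(ctx, seqs, /*axis=*/0, /*start=*/0, /*end=*/10);
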
/// Merge the given dimension and the previous one in the tensor.
/// (..., num_heads, N, ...) -> (..., num_heads * N, ...)
/// dim is the position of the resulting merged dimension.
/// ggml_flatten_1d(x, d) <==> torch.flatten(x, -1-d-1, -1-d)
extern "C" ggml_tensor* ggml_flatten_1d(ggml_context* ctx, ggml_tensor* x, int dim);

/// Split the given dimension.
/// (..., K * N, ...) -> (..., K, N, ...)
/// dim is the position of the output dimension with the given number of elements (N).
extern "C" ggml_tensor* ggml_unflatten_1d(ggml_context* ctx, ggml_tensor* x, int dim, int num_el);
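
// Example (a sketch): round-tripping the last two torch-style dimensions.
// Following the torch.flatten equivalence above, dim counts from the innermost axis.
//
//     // x: torch shape (batch, num_heads, seq, 64)
//     ggml_tensor* y = ggml_flatten_1d(ctx, x, /*dim=*/0);
//     // y: (batch, num_heads, seq * 64)
//     ggml_tensor* z = ggml_unflatten_1d(ctx, y, /*dim=*/0, /*num_el=*/64);
//     // z: (batch, num_heads, seq, 64)
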
extern "C" ggml_tensor* Linear_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input
);

extern "C" ggml_tensor* LayerNorm_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* input
);

extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
);

extern "C" ggml_tensor* SiluFeedForwardNetwork_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
);

extern "C" ggml_tensor* MultiheadAttention_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* queries,  // (slen, d_in)
    ggml_tensor* keys,     // (klen, d_in)
    ggml_tensor* values,   // (klen, d_out)
    ggml_tensor* _         // (klen, slen) TODO: do we need to pass the mask here?
);
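
// Example (a sketch): self-attention through a named submodule. The prefix picks
// the weights out of model.tensors; "text_encoder.layers.0.self_attn" is only an
// illustrative name, and nullptr is passed for the `_` argument (see the TODO above).
//
//     ggml_tensor* attn_out = MultiheadAttention_forward(
//         model,
//         "text_encoder.layers.0.self_attn",
//         seqs, seqs, seqs,    // queries == keys == values for self-attention
//         nullptr
//     );
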
extern "C" ggml_tensor* PositionalEmbedding_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* embeds
);

extern "C" ggml_tensor* TransformerEmbeddingFrontend_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
);

extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
);

extern "C" ggml_tensor* RelativePositionMHA_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
);

extern "C" ggml_tensor* ConvModule_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs
);

extern "C" ggml_tensor* StandardConformerEncoderLayer_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
);

extern "C" ggml_tensor* StandardConformerEncoder_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
);

extern "C" ggml_tensor* StandardConformerEncoderAdaptorLayer_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
);

extern "C" ggml_tensor* StandardConformerEncoderAdaptor_forward(
    fairseq2_model& model,
    const std::string& prefix,
    ggml_tensor* seqs,
    ggml_tensor* padding_mask
);
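
// Example (a sketch): going from a raw waveform tensor to speech encoder states.
// The module prefixes and the exact call sequence are illustrative; the real
// prefixes depend on how the checkpoint was converted.
//
//     ggml_tensor* fbank = WaveformToFbank_forward(model, "speech_encoder_frontend", waveform);
//     ggml_tensor* enc_out = StandardConformerEncoder_forward(
//         model, "speech_encoder", fbank, /*padding_mask=*/nullptr);
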
// Specifies the Layer Normalization order.
enum TransformerNormOrder {
    TRANSFORMER_NORM_ORDER_POST = 0,
    TRANSFORMER_NORM_ORDER_PRE = 1,
    TRANSFORMER_NORM_ORDER_PRE_WITH_NORMFORMER = 2
};

/// Holds the options to pass to a sequence generator.
struct SequenceGeneratorOptions {
    /// The beam size.
    int beam_size = 5;

    /// The minimum length of generated sequences (including the prefix sequence).
    int min_seq_len = 1;

    /// The terms ``a`` and ``b`` of ``ax + b`` where ``x`` is the source
    /// sequence length. The generated sequences (including the prefix sequence)
    /// will have a maximum length of ``min(hard_max_seq_len, ax + b)``. See also
    /// ``hard_max_seq_len``.
    float soft_max_seq_len_a = 1;
    int soft_max_seq_len_b = 200;

    /// The hard limit on the maximum length of generated sequences.
    int hard_max_seq_len = 1024;

    /// The length penalty: values less than 1.0 favor shorter sequences,
    /// values greater than 1.0 favor longer sequences.
    float len_penalty = 1.0;

    /// The unknown symbol penalty: values less than 0 produce more UNKs,
    /// values greater than 0 produce fewer UNKs.
    float unk_penalty = 0.0;

    /// If ``true``, normalizes scores by the length of generated sequences.
    bool normalize_scores = true;
};
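
// Worked example of the length cap above: with the defaults
// (soft_max_seq_len_a = 1, soft_max_seq_len_b = 200, hard_max_seq_len = 1024),
// a source sequence of length 100 allows at most min(1024, 1 * 100 + 200) = 300
// generated steps, while a source of length 2000 is capped at 1024.
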
struct SequenceGeneratorJob {
    SequenceGeneratorOptions opts;
    ggml_tensor* prefix_seq;
    std::int32_t pad_idx;
    std::int32_t unk_idx;
    std::int32_t bos_idx;
    std::int32_t eos_idx;
};

/// Represents a hypothesis produced by a sequence generator.
struct Hypothesis {
    /// The generated sequence.
    ggml_tensor* seq;

    /// The score of the hypothesis.
    float score;

    /// The score of each individual sequence step.
    ggml_tensor* step_scores;
};

extern "C" Hypothesis* generate_sequence(
    fairseq2_model& model,
    const SequenceGeneratorJob& job,
    ggml_tensor* encoder_output,
    ggml_tensor* encoder_padding_mask,
    ggml_context* result_ctx
);
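
// Example (a sketch): building a generation job and decoding against encoder
// output. The vocabulary indices below are placeholders; use the converted
// model's actual special-token ids, and a `prefix_seq` tensor that holds the
// forced prefix (typically just the BOS token).
//
//     SequenceGeneratorJob job;
//     job.opts = SequenceGeneratorOptions{};   // defaults: beam_size = 5, len_penalty = 1.0, ...
//     job.prefix_seq = prefix_tokens;          // assumed pre-built ggml_tensor
//     job.pad_idx = 0;
//     job.unk_idx = 1;
//     job.bos_idx = 2;
//     job.eos_idx = 3;
//
//     Hypothesis* hyps = generate_sequence(model, job, encoder_output, encoder_padding_mask, result_ctx);
//     // Each returned Hypothesis carries its token sequence (seq), overall score,
//     // and per-step scores, with result tensors allocated in result_ctx
//     // (hence the extra context argument).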