unity_lib.cpp

#include "unity_lib.h"

#include <algorithm>
#include <iostream>
#include <numeric>
#include <stdexcept>
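
// Builds the forward graph for the text encoder: the embedding frontend
// followed by the standard Transformer encoder stack.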
struct ggml_cgraph * unity_text_encoder(
        fairseq2_model& model,
        struct ggml_tensor * text_input) {
    ggml_context* ctx0 = model.ctx;
    ggml_cgraph* gf = ggml_new_graph(ctx0);
    ggml_tensor* seqs = TransformerEmbeddingFrontend_forward(model, "text_encoder_frontend", text_input);
    ggml_tensor* encoder_output = StandardTransformerEncoder_forward(
        model,
        "text_encoder",
        seqs,
        nullptr  // TODO: handle padding mask
    );
    encoder_output = ggml_dup(model.ctx, encoder_output);
    ggml_build_forward_expand(gf, encoder_output);
    return gf;
}
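
// Builds the forward graph for the speech encoder (Conformer stack).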
struct ggml_cgraph * unity_speech_encoder(
        fairseq2_model& model,
        struct ggml_tensor * speech_input) {
    ggml_context* ctx0 = model.ctx;
    ggml_cgraph* gf = ggml_new_graph(ctx0);
    ggml_tensor* seqs = StandardConformerEncoder_forward(model, "speech_encoder", speech_input, nullptr);
    seqs = ggml_dup(model.ctx, seqs);
    ggml_build_forward_expand(gf, seqs);
    return gf;
}
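
// Runs beam-search decoding over `encoder_output`. The decoding prefix is
// [</s>] for bilingual models and [</s>, tgt_lang_idx] for multilingual ones
// (tgt_lang_idx == 0 means "no language token").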
Hypothesis* unity_decode(
        fairseq2_model& model,
        const SequenceGeneratorOptions& opts,
        int tgt_lang_idx,
        ggml_tensor* encoder_output,
        int n_threads
) {
    SequenceGeneratorJob job = {
        opts,
        /*prefix_seq*/ nullptr,
        /*pad_idx*/ model.vocab.token_to_id["<pad>"],
        /*unk_idx*/ model.vocab.token_to_id["<unk>"],
        /*bos_idx*/ model.vocab.token_to_id["<s>"],
        /*eos_idx*/ model.vocab.token_to_id["</s>"],
        /*num_threads*/ n_threads,
    };
    int prefix_seq_len = tgt_lang_idx ? 2 : 1;
    FORCE_ALLOC(prefix_seq, model.ctx, ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, prefix_seq_len));
    ((int *)prefix_seq->data)[0] = job.eos_idx;
    if (tgt_lang_idx != 0) {  // multilingual case
        ((int *)prefix_seq->data)[1] = tgt_lang_idx;
    }
    job.prefix_seq = prefix_seq;
    return generate_sequence(model, job, encoder_output, nullptr, model.ctx, n_threads);
}
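
// Loads a fairseq2 model from a ggml checkpoint file on disk.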
extern "C" fairseq2_model unity_init_model(const char* model_path) {
    fairseq2_model model;
    load_fairseq2_ggml_file(model, model_path);
    return model;
}
// Transcribes/translates the input audio. Returns a Result struct holding
// the transcription, per-word confidence scores, and LID scores.
extern "C" Result unity_eval_speech(fairseq2_model& model, std::vector<float>& data, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads) {
    Result result;
    // The required context size mostly depends on the input length and the model dim.
    int ctx_size_mb = opts.mem_mb;
    auto encoder_buf = std::vector<uint8_t>(8 * 1024 * 1024);  // only holds tensor metadata, so it can stay small
    auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
    ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);
    int tgt_lang_idx;
    if (tgt_lang == "unk") {
        tgt_lang_idx = model.vocab.token_to_id["<unk>"];
    } else {
        auto tgt_lang_ptr = model.vocab.token_to_id.find("__" + tgt_lang + "__");
        if (tgt_lang_ptr == model.vocab.token_to_id.end()) {
            std::cerr << "Unknown language " << tgt_lang << "\n";
            result.err = 1;
            return result;
        }
        tgt_lang_idx = tgt_lang_ptr->second;
    }
    // Reset the ggml_context
    model.ctx = ctx_from_buffer(encoder_buf);
    ggml_set_no_alloc(model.ctx, true);
    ggml_tensor* seqs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, data.size(), 1);
    seqs->data = data.data();
    // Audio encoder
    ggml_cgraph* gf = unity_speech_encoder(model, seqs);
    ggml_allocr_alloc_graph(fwd_alloc, gf);
    ggml_graph_compute_with_ctx(model.ctx, gf, n_threads);
    // encoder_output stays valid until we call `ggml_allocr_reset(fwd_alloc)`
    ggml_tensor* encoder_output = gf->nodes[gf->n_nodes - 1];
    // Beam search decoding
    const Hypothesis* hypo = unity_decode(model, opts, tgt_lang_idx, encoder_output, n_threads);
    // Drop the language and bos tokens.
    ggml_tensor* tokens = ggml_slice(model.ctx, hypo[0].seq, 0, 2, 0);
    // Collect the result string
    char result_str[4096];
    std::pair<std::vector<std::string>, std::vector<float>> p = fairseq2_spm_detokenize(&model, tokens, hypo[0].step_scores, result_str);
    std::vector<std::string> result_tokens = p.first;
    std::vector<float> word_scores = p.second;
    // Collect LID scores for every `__lang__` token in the vocab.
    std::unordered_map<std::string, float> lid_scores;
    std::vector<int> lang_ids;
    for (const auto& kv : model.vocab.token_to_id) {
        if (kv.first.size() >= 4 && kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
            lang_ids.push_back(kv.second);
        }
    }
    std::sort(lang_ids.begin(), lang_ids.end());
    for (size_t i = 0; i < lang_ids.size(); ++i) {
        lid_scores[model.vocab.id_to_token[lang_ids[i]].text] = ggml_get_f32_1d(hypo[0].lid_scores, i);
    }
    result.transcription = result_tokens;
    result.word_confidence_scores = word_scores;
    result.lid_scores = lid_scores;
    result.err = 0;
    ggml_free(model.ctx);
    ggml_allocr_reset(fwd_alloc);
    return result;
}
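
// Translates `text` into `tgt_lang`. Multilingual models also report LID
// scores; bilingual models return a single space-joined string together with
// an averaged confidence score.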
extern "C" Result unity_eval_text(fairseq2_model& model, const std::string& text, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads) {
    Result result;
    // The required context size mostly depends on the input length and the model dim.
    int ctx_size_mb = opts.mem_mb;
    auto encoder_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
    auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
    ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);
    int tgt_lang_idx = 0;
    if (model.hparams["multilingual"] != 0) {
        auto tgt_lang_ptr = model.vocab.token_to_id.find("__" + tgt_lang + "__");
        if (tgt_lang_ptr == model.vocab.token_to_id.end()) {
            std::cerr << "Unknown language " << tgt_lang << "\n";
            result.err = 1;
            return result;
        }
        tgt_lang_idx = tgt_lang_ptr->second;
    }
    // Tokenize the input text.
    model.ctx = ctx_from_buffer(encoder_buf);
    ggml_set_no_alloc(model.ctx, false);
    ggml_tensor* tokens_tensor = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 64);
    ggml_set_no_alloc(model.ctx, true);
    fairseq2_spm_tokenize(&model, text.c_str(), tokens_tensor);
    // Text encoder
    ggml_cgraph* gf = unity_text_encoder(model, tokens_tensor);
    ggml_allocr_alloc_graph(fwd_alloc, gf);
    ggml_graph_compute_with_ctx(model.ctx, gf, n_threads);
    ggml_tensor* encoder_output = gf->nodes[gf->n_nodes - 1];
    // Beam search decoding
    const Hypothesis* hypo = unity_decode(model, opts, tgt_lang_idx, encoder_output, n_threads);
    // Drop the language and bos tokens for a multilingual model, or only the bos token for a bilingual one.
    int token_offset = (model.hparams["multilingual"] != 0) ? 2 : 1;
    ggml_tensor* tgt_tokens = ggml_slice(model.ctx, hypo[0].seq, 0, token_offset, 0);
    // Collect the result string
    char result_str[4096];
    std::pair<std::vector<std::string>, std::vector<float>> p = fairseq2_spm_detokenize(&model, tgt_tokens, hypo[0].step_scores, result_str);
    std::vector<std::string> result_tokens = p.first;
    std::vector<float> word_scores = p.second;
    std::unordered_map<std::string, float> lid_scores;
    if (model.hparams["multilingual"] != 0) {
        std::vector<int> lang_ids;
        for (const auto& kv : model.vocab.token_to_id) {
            if (kv.first.size() >= 4 && kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
                lang_ids.push_back(kv.second);
            }
        }
        std::sort(lang_ids.begin(), lang_ids.end());
        for (size_t i = 0; i < lang_ids.size(); ++i) {
            lid_scores[model.vocab.id_to_token[lang_ids[i]].text] = ggml_get_f32_1d(hypo[0].lid_scores, i);
        }
        result.lid_scores = lid_scores;
        result.transcription = result_tokens;
        result.word_confidence_scores = word_scores;
    } else {
        // Store the space-joined text as a single transcription entry.
        std::string concat_transcription = result_tokens.empty()
            ? std::string()
            : std::accumulate(
                  std::next(result_tokens.begin()), result_tokens.end(), result_tokens[0],
                  [](const std::string& a, const std::string& b) { return a + " " + b; });
        // Average the per-word confidence scores.
        float avg_score = word_scores.empty()
            ? 0.0f
            : std::accumulate(word_scores.begin(), word_scores.end(), 0.0f) / word_scores.size();
        result.transcription.push_back(concat_transcription);
        result.word_confidence_scores.push_back(avg_score);
    }
    result.err = 0;
    ggml_free(model.ctx);
    ggml_allocr_reset(fwd_alloc);
    return result;
}
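
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the library): how a caller
// might drive unity_eval_text. The checkpoint filename and `mem_mb` value are
// placeholders, and the zero-initialized SequenceGeneratorOptions is assumed
// to carry acceptable defaults.
//
//     #include "unity_lib.h"
//     #include <iostream>
//
//     int main() {
//         fairseq2_model model = unity_init_model("unity_model.ggml");
//         SequenceGeneratorOptions opts = {};
//         opts.mem_mb = 512;
//         Result res = unity_eval_text(model, "Hello, world!", opts, "fra", /*n_threads=*/4);
//         if (res.err != 0) return res.err;
//         for (const auto& piece : res.transcription)
//             std::cout << piece << "\n";
//         return 0;
//     }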