// unity_lib.cpp — GGML inference helpers for the Unity speech model
// (speech encoder forward pass, beam-search decoding, end-to-end eval).
  1. #include "unity_lib.h"
  2. #include <algorithm>
  3. struct ggml_cgraph * unity_speech_encoder(
  4. fairseq2_model& model,
  5. struct ggml_tensor * speech_input) {
  6. ggml_context* ctx0 = model.ctx;
  7. ggml_cgraph* gf = ggml_new_graph(ctx0);
  8. ggml_tensor* seqs = StandardConformerEncoder_forward(model, "speech_encoder", speech_input, nullptr);
  9. seqs = ggml_dup(model.ctx, seqs);
  10. ggml_build_forward_expand(gf, seqs);
  11. return gf;
  12. }
  13. Hypothesis* unity_decode(
  14. fairseq2_model& model,
  15. const SequenceGeneratorOptions& opts,
  16. int tgt_lang_idx,
  17. ggml_tensor* encoder_output,
  18. int n_threads
  19. ) {
  20. SequenceGeneratorJob job = {
  21. opts,
  22. /*prefix_seq*/ nullptr,
  23. /*pad_idx*/model.vocab.token_to_id["<pad>"],
  24. /*unk_idx*/model.vocab.token_to_id["<unk>"],
  25. /*bos_idx*/model.vocab.token_to_id["<s>"],
  26. /*eos_idx*/model.vocab.token_to_id["</s>"],
  27. /*num_threads*/n_threads,
  28. };
  29. FORCE_ALLOC(prefix_seq, model.ctx, ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 2));
  30. ((int *)prefix_seq->data)[0] = job.eos_idx;
  31. ((int *)prefix_seq->data)[1] = tgt_lang_idx;
  32. job.prefix_seq = prefix_seq;
  33. return generate_sequence(model, job, encoder_output, nullptr, model.ctx, n_threads);
  34. }
  35. extern "C" fairseq2_model unity_init_model(const char* model_path) {
  36. fairseq2_model model;
  37. load_fairseq2_ggml_file(model, model_path);
  38. return model;
  39. }
  40. // struct as return - transcription, CE score, LID
  41. extern "C" Result unity_eval(fairseq2_model model, std::vector<float> data, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads, int memory_mb) {
  42. Result result;
  43. // The ctx_size_mb mostly depends of input length and model dim.
  44. int ctx_size_mb = 128;
  45. auto encoder_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
  46. auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
  47. ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);
  48. char result_str[4096];
  49. int tgt_lang_idx;
  50. if (tgt_lang == "unk") {
  51. tgt_lang_idx = model.vocab.token_to_id["<unk>"];
  52. } else {
  53. auto tgt_lang_ptr = model.vocab.token_to_id.find("__" + tgt_lang + "__");
  54. if (tgt_lang_ptr == model.vocab.token_to_id.end()) {
  55. std::cerr << "Unknown language " << tgt_lang << "\n";
  56. result.err = 1;
  57. return result;
  58. }
  59. tgt_lang_idx = tgt_lang_ptr->second;
  60. }
  61. // Reset the ggml_context
  62. model.ctx = ctx_from_buffer(encoder_buf);
  63. ggml_set_no_alloc(model.ctx, false);
  64. struct ggml_tensor * seqs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, data.size(), 1);
  65. memcpy(seqs->data, data.data(), data.size() * sizeof(float));
  66. ggml_set_no_alloc(model.ctx, true);
  67. // Audio encoder
  68. ggml_cgraph* gf = unity_speech_encoder(model, seqs);
  69. ggml_allocr_alloc_graph(fwd_alloc, gf);
  70. ggml_graph_compute_with_ctx(model.ctx, gf, n_threads);
  71. // encoder_output is valid until we call `ggml_allocr_reset(fwd_alloc)`
  72. ggml_tensor* encoder_output = gf->nodes[gf->n_nodes - 1];
  73. // Beam search decoding
  74. const Hypothesis* hypo = unity_decode(model, opts, tgt_lang_idx, encoder_output, n_threads);
  75. // Drop language and bos token.
  76. ggml_tensor* tokens = ggml_slice(model.ctx, hypo[0].seq, 0, 2, 0);
  77. // Collect result string
  78. std::pair<std::vector<std::string>, std::vector<float>> p = fairseq2_spm_detokenize(&model, tokens, hypo[0].step_scores, (char*)&result_str);
  79. std::vector<std::string> result_tokens = p.first;
  80. std::vector<float> word_scores = p.second;
  81. std::unordered_map<std::string, float> lid_scores;
  82. std::vector<int> lang_ids;
  83. for (const auto& kv : model.vocab.token_to_id) {
  84. if (kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
  85. lang_ids.push_back(kv.second);
  86. }
  87. }
  88. std::sort(lang_ids.begin(), lang_ids.end());
  89. for (size_t i = 0; i < lang_ids.size(); ++i) {
  90. lid_scores[model.vocab.id_to_token[lang_ids[i]].text] = ggml_get_f32_1d(hypo[0].lid_scores, i);
  91. }
  92. result.transcription = result_tokens;
  93. result.word_confidence_scores = word_scores;
  94. result.lid_scores = lid_scores;
  95. result.err = 0;
  96. ggml_free(model.ctx);
  97. ggml_allocr_reset(fwd_alloc);
  98. return result;
  99. }