// unity.cpp

#include "ggml/ggml.h"
#include "ggml/ggml-alloc.h"
#include "math.h"
#include "model_loader.h"
#include "fairseq2.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <sndfile.h>

#include "ggml-alloc.h"
  19. struct unity_params {
  20. int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
  21. std::string model = "seamlessM4T_medium.ggml"; // model path
  22. std::string tgt_lang = "eng";
  23. std::vector<std::string> files = {};
  24. bool text = false;
  25. SequenceGeneratorOptions opts = {
  26. /*beam_size*/ 5,
  27. /*min_seq_len*/ 1,
  28. /*soft_max_seq_len_a*/ 1,
  29. /*soft_max_seq_len_b*/ 200,
  30. /*hard_max_seq_len*/ 1000,
  31. /*len_penalty*/ 1.0,
  32. /*unk_penalty*/ 0.0,
  33. /*normalize_scores*/ true,
  34. /*mem_mb*/ 512,
  35. };
  36. int32_t max_audio_s = 30;
  37. };
  38. void unity_print_usage(int /*argc*/, char ** argv, const unity_params & params) {
  39. fprintf(stderr, "usage: %s [options] file1 file2 ...\n", argv[0]);
  40. fprintf(stderr, "\n");
  41. fprintf(stderr, "options:\n");
  42. fprintf(stderr, " -h, --help show this help message and exit\n");
  43. fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
  44. fprintf(stderr, " -m FNAME, --model FNAME\n");
  45. fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
  46. fprintf(stderr, " --text text output\n");
  47. fprintf(stderr, " --beam-size beam size (default: %d)\n", params.opts.beam_size);
  48. fprintf(stderr, " -M, --mem memory buffer, increase for long inputs (default: %d)\n", params.opts.mem_mb);
  49. fprintf(stderr, " --max-audio max duration of audio in seconds (default: %d)\n", params.max_audio_s);
  50. fprintf(stderr, "\n");
  51. }
  52. std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, unity_params& params) {
  53. if (i + 1 < argc && argv[i + 1][0] != '-') {
  54. return argv[++i];
  55. } else {
  56. fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
  57. unity_print_usage(argc, argv, params);
  58. exit(0);
  59. }
  60. }
  61. bool unity_params_parse(int argc, char ** argv, unity_params & params) {
  62. for (int i = 1; i < argc; i++) {
  63. std::string arg = argv[i];
  64. if (arg == "-h" || arg == "--help") {
  65. unity_print_usage(argc, argv, params);
  66. } else if (arg == "-t" || arg == "--threads") {
  67. params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
  68. } else if (arg == "-m" || arg == "--model") {
  69. params.model = get_next_arg(i, argc, argv, arg, params);
  70. } else if (arg == "-l" || arg == "--tgt-lang") {
  71. params.tgt_lang = get_next_arg(i, argc, argv, arg, params);
  72. } else if (arg == "--text") {
  73. params.text = true;
  74. } else if (arg == "-b" || arg == "--beam-size") {
  75. params.opts.beam_size = std::stoi(get_next_arg(i, argc, argv, arg, params));
  76. } else if (arg == "-M" || arg == "--mem") {
  77. params.opts.mem_mb = std::stoi(get_next_arg(i, argc, argv, arg, params));
  78. } else if (arg == "--max-audio") {
  79. params.max_audio_s = std::stoi(get_next_arg(i, argc, argv, arg, params));
  80. } else {
  81. params.files.push_back(std::string(arg));
  82. }
  83. }
  84. return true;
  85. }
  86. struct ggml_cgraph * unity_speech_encoder(
  87. fairseq2_model& model,
  88. struct ggml_tensor * speech_input) {
  89. ggml_context* ctx0 = model.ctx;
  90. ggml_cgraph* gf = ggml_new_graph(ctx0);
  91. ggml_tensor* seqs = StandardConformerEncoder_forward(model, "speech_encoder", speech_input, nullptr);
  92. seqs = ggml_dup(model.ctx, seqs);
  93. ggml_build_forward_expand(gf, seqs);
  94. return gf;
  95. }
  96. Hypothesis* unity_decode(
  97. fairseq2_model& model,
  98. const SequenceGeneratorOptions& opts,
  99. int tgt_lang_idx,
  100. ggml_tensor* encoder_output,
  101. int n_threads
  102. ) {
  103. SequenceGeneratorJob job = {
  104. opts,
  105. /*prefix_seq*/ nullptr,
  106. /*pad_idx*/model.vocab.token_to_id["<pad>"],
  107. /*unk_idx*/model.vocab.token_to_id["<unk>"],
  108. /*bos_idx*/model.vocab.token_to_id["<s>"],
  109. /*eos_idx*/model.vocab.token_to_id["</s>"],
  110. /*num_threads*/n_threads,
  111. };
  112. FORCE_ALLOC(prefix_seq, model.ctx, ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 2));
  113. ((int *)prefix_seq->data)[0] = job.eos_idx;
  114. ((int *)prefix_seq->data)[1] = tgt_lang_idx;
  115. job.prefix_seq = prefix_seq;
  116. return generate_sequence(model, job, encoder_output, nullptr, model.ctx, n_threads);
  117. }
  118. int main(int argc, char ** argv) {
  119. unity_params params;
  120. if (unity_params_parse(argc, argv, params) == false) {
  121. return 1;
  122. }
  123. fairseq2_model model;
  124. // load the model
  125. if (load_fairseq2_ggml_file(model, params.model.c_str())) {
  126. fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
  127. return 1;
  128. }
  129. // The ctx_size_mb mostly depends of input length and model dim.
  130. int ctx_size_mb = params.opts.mem_mb;
  131. auto encoder_buf = std::vector<uint8_t>(8 * 1024 * 1024); // Only tensor metadata goes in there
  132. auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024 / 2);
  133. ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);
  134. char result_str[4096];
  135. std::string input;
  136. bool interactive = params.files.size() == 0;
  137. auto next_file = params.files.begin();
  138. while (true) {
  139. if (interactive) {
  140. std::cout << "\nEnter audio_path and tgt_lang, separated by space (or 'exit' to quit):\n";
  141. std::getline(std::cin, input);
  142. if (input == "exit") {
  143. break;
  144. }
  145. } else {
  146. if (next_file == params.files.end()) break;
  147. input = *(next_file++);
  148. }
  149. std::istringstream iss(input);
  150. std::string audio_path;
  151. std::string tgt_lang = params.tgt_lang;
  152. iss >> audio_path >> tgt_lang;
  153. if (audio_path == "-") {
  154. audio_path = "/proc/self/fd/0";
  155. }
  156. std::cerr << "Translating (Transcribing) " << audio_path << " to " << tgt_lang << "\n";
  157. SF_INFO info;
  158. SNDFILE* sndfile = sf_open(audio_path.c_str(), SFM_READ, &info);
  159. if (!sndfile) {
  160. std::cerr << "Could not open file\n";
  161. if (interactive) continue;
  162. else return 1;
  163. }
  164. auto tgt_lang_ptr = model.vocab.token_to_id.find("__" + tgt_lang + "__");
  165. if (tgt_lang_ptr == model.vocab.token_to_id.end()) {
  166. std::cerr << "Unknown language " << tgt_lang << "\n";
  167. if (interactive) continue;
  168. else return 2;
  169. }
  170. int tgt_lang_idx = tgt_lang_ptr->second;
  171. // Reset the ggml_context
  172. model.ctx = ctx_from_buffer(encoder_buf);
  173. ggml_set_no_alloc(model.ctx, true);
  174. GGML_ASSERT(info.samplerate == 16000);
  175. GGML_ASSERT(info.channels == 1);
  176. // Truncate audio input. Ideally we should chunk it, but this will prevent most obvious OOM.
  177. int n_frames = std::min(info.samplerate * params.max_audio_s, (int)info.frames);
  178. ggml_tensor* seqs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_frames, info.channels);
  179. ggml_allocr_alloc(fwd_alloc, seqs);
  180. // Load audio input
  181. sf_readf_float(sndfile, (float*)seqs->data, n_frames);
  182. // Audio encoder
  183. ggml_cgraph* gf = unity_speech_encoder(model, seqs);
  184. size_t enc_mem_used = ggml_allocr_alloc_graph(fwd_alloc, gf);
  185. ggml_graph_compute_with_ctx(model.ctx, gf, params.n_threads);
  186. // encoder_output is valid until we call `ggml_allocr_reset(fwd_alloc)`
  187. ggml_tensor* encoder_output = gf->nodes[gf->n_nodes - 1];
  188. // Beam search decoding
  189. const Hypothesis* result = unity_decode(model, params.opts, tgt_lang_idx, encoder_output, params.n_threads);
  190. // Drop language and bos token.
  191. ggml_tensor* tokens = ggml_slice(model.ctx, result[0].seq, 0, 2, 0);
  192. // Collect result string
  193. int n = fairseq2_spm_detokenize(&model, tokens, (char*)&result_str);
  194. std::cout << std::string((char*)&result_str, n) << std::endl;
  195. ggml_free(model.ctx);
  196. ggml_allocr_reset(fwd_alloc);
  197. }
  198. return 0;
  199. }