unity.cpp 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
// project headers
#include "fairseq2.h"
#include "ggml-alloc.h"
#include "ggml/ggml-alloc.h"
#include "ggml/ggml.h"
#include "lib/unity_lib.h"
#include "math.h"
#include "model_loader.h"

// third-party
#include <sndfile.h>

// C++ standard library
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <thread>
#include <vector>
  12. struct unity_params {
  13. int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
  14. std::string model = "seamlessM4T_medium.ggml"; // model path
  15. bool text = false;
  16. SequenceGeneratorOptions opts = {
  17. /*beam_size*/ 5,
  18. /*min_seq_len*/ 1,
  19. /*soft_max_seq_len_a*/ 1,
  20. /*soft_max_seq_len_b*/ 200,
  21. /*hard_max_seq_len*/ 1000,
  22. /*len_penalty*/ 1.0,
  23. /*unk_penalty*/ 0.0,
  24. /*normalize_scores*/ true,
  25. /*mem_mb*/ 512
  26. };
  27. int32_t max_audio_s = 30;
  28. bool verbose = false;
  29. };
  30. void unity_print_usage(int /*argc*/, char ** argv, const unity_params & params) {
  31. fprintf(stderr, "usage: %s [options] file1 file2 ...\n", argv[0]);
  32. fprintf(stderr, "\n");
  33. fprintf(stderr, "options:\n");
  34. fprintf(stderr, " -h, --help show this help message and exit\n");
  35. fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
  36. fprintf(stderr, " -v, --verbose Print out word level confidence score and LID score (default: off)");
  37. fprintf(stderr, " -m FNAME, --model FNAME\n");
  38. fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
  39. fprintf(stderr, " --text text-to-text translation (default is speech-to-text without this option on)\n");
  40. fprintf(stderr, " --beam-size beam size (default: %d)\n", params.opts.beam_size);
  41. fprintf(stderr, " -M, --mem memory buffer, increase for long inputs (default: %d)\n", params.opts.mem_mb);
  42. fprintf(stderr, " --max-audio max duration of audio in seconds (default: %d)\n", params.max_audio_s);
  43. fprintf(stderr, "\n");
  44. }
  45. std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, unity_params& params) {
  46. if (i + 1 < argc && argv[i + 1][0] != '-') {
  47. return argv[++i];
  48. } else {
  49. fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
  50. unity_print_usage(argc, argv, params);
  51. exit(0);
  52. }
  53. }
  54. bool unity_params_parse(int argc, char ** argv, unity_params & params) {
  55. for (int i = 1; i < argc; i++) {
  56. std::string arg = argv[i];
  57. if (arg == "-h" || arg == "--help") {
  58. unity_print_usage(argc, argv, params);
  59. } else if (arg == "-t" || arg == "--threads") {
  60. params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
  61. } else if (arg == "-m" || arg == "--model") {
  62. params.model = get_next_arg(i, argc, argv, arg, params);
  63. } else if (arg == "--text") {
  64. params.text = true;
  65. } else if (arg == "-b" || arg == "--beam-size") {
  66. params.opts.beam_size = std::stoi(get_next_arg(i, argc, argv, arg, params));
  67. } else if (arg == "-v" || arg == "--verbose") {
  68. params.verbose = true;
  69. } else if (arg == "-M" || arg == "--mem") {
  70. params.opts.mem_mb = std::stoi(get_next_arg(i, argc, argv, arg, params));
  71. } else if (arg == "--max-audio") {
  72. params.max_audio_s = std::stoi(get_next_arg(i, argc, argv, arg, params));
  73. }
  74. }
  75. return true;
  76. }
  77. int main(int argc, char ** argv) {
  78. unity_params params;
  79. if (unity_params_parse(argc, argv, params) == false) {
  80. return 1;
  81. }
  82. fairseq2_model model;
  83. // load the model
  84. if (load_fairseq2_ggml_file(model, params.model.c_str())) {
  85. fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
  86. return 1;
  87. }
  88. // The ctx_size_mb mostly depends of input length and model dim.
  89. int ctx_size_mb = params.opts.mem_mb;
  90. auto encoder_buf = std::vector<uint8_t>(8 * 1024 * 1024); // Only tensor metadata goes in there
  91. auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024 / 2);
  92. while (true) {
  93. // S2ST
  94. if (!params.text) {
  95. std::string input;
  96. std::cout << "\nEnter audio_path and tgt_lang, separated by space (or 'exit' to quit):\n";
  97. std::getline(std::cin, input);
  98. if (input == "exit") {
  99. break;
  100. }
  101. std::istringstream iss(input);
  102. std::string audio_path;
  103. std::string tgt_lang;
  104. iss >> audio_path >> tgt_lang;
  105. if (audio_path == "-") {
  106. audio_path = "/proc/self/fd/0";
  107. }
  108. std::cerr << "Translating (Transcribing) " << audio_path << " to " << tgt_lang << "\n";
  109. SF_INFO info;
  110. SNDFILE* sndfile = sf_open(audio_path.c_str(), SFM_READ, &info);
  111. if (!sndfile) {
  112. std::cerr << "Could not open file\n";
  113. continue;
  114. }
  115. // Load audio input
  116. GGML_ASSERT(info.samplerate == 16000);
  117. GGML_ASSERT(info.channels == 1);
  118. // Truncate audio input. Ideally we should chunk it, but this will prevent most obvious OOM.
  119. int n_frames = std::min(info.samplerate * params.max_audio_s, (int)info.frames);
  120. std::vector<float> data(n_frames * info.channels);
  121. sf_readf_float(sndfile, data.data(), n_frames);
  122. Result result = unity_eval_speech(model, data, params.opts, tgt_lang, params.n_threads);
  123. std::string concat_transcription = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
  124. [](const std::string& a, const std::string& b) {
  125. return a + " " + b;
  126. }
  127. );
  128. if (params.verbose) {
  129. std::cout << "Final transcription: " << concat_transcription << std::endl;
  130. std::cout << std::endl;
  131. std::cout << "Word level confidence score:" << std::endl;
  132. for (size_t i = 0; i < result.transcription.size(); ++i) {
  133. std::cout << "Word: " << result.transcription[i] << " | Score: " << result.word_confidence_scores[i] << std::endl;
  134. }
  135. std::cout << std::endl;
  136. std::cout << "LID scores: " << std::endl;
  137. for (const auto& kv : result.lid_scores) {
  138. std::cout << "Language: " << kv.first << "| Score: " << kv.second << std::endl;
  139. }
  140. } else {
  141. std::cout << concat_transcription << std::endl;
  142. }
  143. // T2TT
  144. } else {
  145. std::string line;
  146. std::string input_text;
  147. std::string tgt_lang;
  148. std::cout << "\nEnter input_text and tgt_lang, separated by space (or 'exit' to quit):\n";
  149. if (std::getline(std::cin, line)) {
  150. std::size_t last_space = line.find_last_of(' ');
  151. if (last_space != std::string::npos) {
  152. input_text = line.substr(0, last_space);
  153. tgt_lang = line.substr(last_space + 1);
  154. std::cerr << "Translating \"" << input_text << "\" to " << tgt_lang << "\n";
  155. } else {
  156. std::cout << "No spaces found in the input. \n";
  157. }
  158. }
  159. // tokenize the input text
  160. Result result = unity_eval_text(model, input_text, params.opts, tgt_lang, params.n_threads);
  161. std::string concat_translation = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
  162. [](const std::string& a, const std::string& b) {
  163. return a + " " + b;
  164. }
  165. );
  166. std::cout << "Translation: " << concat_translation << std::endl;
  167. }
  168. }
  169. return 0;
  170. }