Browse Source

Interactive-only unity.cpp with S2TT+T2TT

cndn 1 năm trước cách đây
mục cha
commit
326f489194
1 tập tin đã thay đổi với 67 bổ sung81 xóa
  1. 67 81
      ggml/examples/unity/unity.cpp

+ 67 - 81
ggml/examples/unity/unity.cpp

@@ -14,9 +14,6 @@
 struct unity_params {
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     std::string model = "seamlessM4T_medium.ggml"; // model path
-    std::string input_text = "";
-    std::string tgt_lang = "eng";
-    std::vector<std::string> files = {};
     bool text = false;
     SequenceGeneratorOptions opts = {
         /*beam_size*/ 5,
@@ -39,14 +36,11 @@ void unity_print_usage(int /*argc*/, char ** argv, const unity_params & params)
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
-    fprintf(stderr, "  -i, --input           Input text for the text-2-text translation\n");
-    fprintf(stderr, "  -l, --tgt-lang        Target translation lang (default: %s\n", params.tgt_lang);
-
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -v, --verbose         Print out word level confidence score and LID score (default: off)");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
-    fprintf(stderr, "  --text                text output\n");
+    fprintf(stderr, "  --text                text-to-text translation (default is speech-to-text without this option on)\n");
     fprintf(stderr, "  --beam-size           beam size (default: %d)\n", params.opts.beam_size);
     fprintf(stderr, "  -M, --mem             memory buffer, increase for long inputs (default: %d)\n", params.opts.mem_mb);
     fprintf(stderr, " --max-audio max duration of audio in seconds (default: %d)\n", params.max_audio_s);
@@ -73,10 +67,6 @@ bool unity_params_parse(int argc, char ** argv, unity_params & params) {
             params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-m" || arg == "--model") {
             params.model = get_next_arg(i, argc, argv, arg, params);
-        } else if (arg == "-i" || arg == "--input") {
-            params.input_text = get_next_arg(i, argc, argv, arg, params);
-        } else if (arg == "-l" || arg == "--tgt-lang") {
-            params.tgt_lang = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "--text") {
             params.text = true;
         } else if (arg == "-b" || arg == "--beam-size") {
@@ -87,9 +77,7 @@ bool unity_params_parse(int argc, char ** argv, unity_params & params) {
             params.opts.mem_mb = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--max-audio") {
             params.max_audio_s = std::stoi(get_next_arg(i, argc, argv, arg, params));
-        } else {
-            params.files.push_back(std::string(arg));
-        }
+        } 
     }
     return true;
 }
@@ -117,84 +105,82 @@ int main(int argc, char ** argv) {
     ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);
     char result_str[4096];
 
-    std::string input;
-    bool interactive = (params.files.size() == 0 && params.input_text.length() == 0);
-    auto next_file = params.files.begin();
-
-    // Flag for the input case: true --> s2st, false --> t2tt
-    bool s2st_or_t2tt = true;
-
-    // S2ST
     while (true) {
-        if (interactive) {
+        // S2ST
+        if (!params.text) {
+            std::string input;
             std::cout << "\nEnter audio_path and tgt_lang, separated by space (or 'exit' to quit):\n";
             std::getline(std::cin, input);
             if (input == "exit") {
                 break;
             }
-        } else {
-            if (params.input_text.length() > 0) {
-                break;
-            }
-            if (next_file == params.files.end() && s2st_or_t2tt) break;
-            input = *(next_file++);
-        }
-        std::istringstream iss(input);
-        std::string audio_path;
-        std::string tgt_lang = params.tgt_lang;
-        iss >> audio_path >> tgt_lang;
-        if (audio_path == "-") {
-            audio_path = "/proc/self/fd/0";
-        }
-        std::cerr << "Translating (Transcribing) " << audio_path << " to " << tgt_lang << "\n";
-        SF_INFO info;
-        SNDFILE* sndfile = sf_open(audio_path.c_str(), SFM_READ, &info);
-        if (!sndfile) {
-            std::cerr << "Could not open file\n";
-            if (interactive) continue;
-            else return 1;
-        }
-        // Load audio input
-        GGML_ASSERT(info.samplerate == 16000);
-        GGML_ASSERT(info.channels == 1);
-        // Truncate audio input. Ideally we should chunk it, but this will prevent most obvious OOM.
-        int n_frames = std::min(info.samplerate * params.max_audio_s, (int)info.frames);
-        std::vector<float> data(n_frames * info.channels);
-        sf_readf_float(sndfile, data.data(), n_frames);
-
-        Result result = unity_eval_speech(model, data, params.opts, tgt_lang, params.n_threads);
-        std::string concat_transcription = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
-            [](const std::string& a, const std::string& b) {
-                return a + " " + b;
+            std::istringstream iss(input);
+            std::string audio_path;
+            std::string tgt_lang;
+            iss >> audio_path >> tgt_lang;
+            if (audio_path == "-") {
+                audio_path = "/proc/self/fd/0";
             }
-        );
-        if (params.verbose) {
-            std::cout << "Final transcription: " << concat_transcription << std::endl;
-            std::cout << std::endl;
-            std::cout << "Word level confidence score:" << std::endl;
-            for (size_t i = 0; i < result.transcription.size(); ++i) {
-                std::cout << "Word: " << result.transcription[i] << " | Score: " << result.word_confidence_scores[i] << std::endl;
+            std::cerr << "Translating (Transcribing) " << audio_path << " to " << tgt_lang << "\n";
+            SF_INFO info;
+            SNDFILE* sndfile = sf_open(audio_path.c_str(), SFM_READ, &info);
+            if (!sndfile) {
+                std::cerr << "Could not open file\n";
+                continue;
             }
-            std::cout << std::endl;
-            std::cout << "LID scores: " << std::endl;
-            for (const auto& kv : result.lid_scores) {
-                std::cout << "Language: " << kv.first << "| Score: " << kv.second << std::endl;
+            // Load audio input
+            GGML_ASSERT(info.samplerate == 16000);
+            GGML_ASSERT(info.channels == 1);
+            // Truncate audio input. Ideally we should chunk it, but this will prevent most obvious OOM.
+            int n_frames = std::min(info.samplerate * params.max_audio_s, (int)info.frames);
+            std::vector<float> data(n_frames * info.channels);
+            sf_readf_float(sndfile, data.data(), n_frames);
+            Result result = unity_eval_speech(model, data, params.opts, tgt_lang, params.n_threads);
+            std::string concat_transcription = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
+                [](const std::string& a, const std::string& b) {
+                    return a + " " + b;
+                }
+            );
+            if (params.verbose) {
+                std::cout << "Final transcription: " << concat_transcription << std::endl;
+                std::cout << std::endl;
+                std::cout << "Word level confidence score:" << std::endl;
+                for (size_t i = 0; i < result.transcription.size(); ++i) {
+                    std::cout << "Word: " << result.transcription[i] << " | Score: " << result.word_confidence_scores[i] << std::endl;
+                }
+                std::cout << std::endl;
+                std::cout << "LID scores: " << std::endl;
+                for (const auto& kv : result.lid_scores) {
+                    std::cout << "Language: " << kv.first << "| Score: " << kv.second << std::endl;
+                }
+            } else {
+                std::cout << concat_transcription << std::endl;
             }
+        // T2TT
         } else {
-            std::cout << concat_transcription << std::endl;
-        }
-    }
-
-    // T2TT
-    if (params.input_text.length() > 0) {
-        // tokenize the input text
-        Result result = unity_eval_text(model, params.input_text, params.opts, params.tgt_lang, params.n_threads);
-        std::string concat_translation = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
-            [](const std::string& a, const std::string& b) {
-                return a + " " + b;
+            std::string line;
+            std::string input_text;
+            std::string tgt_lang;
+            std::cout << "\nEnter input_text and tgt_lang, separated by space (or 'exit' to quit):\n";
+            if (std::getline(std::cin, line)) {
+                std::size_t last_space = line.find_last_of(' ');
+                if (last_space != std::string::npos) {
+                    input_text = line.substr(0, last_space);
+                    tgt_lang = line.substr(last_space + 1);
+                    std::cerr << "Translating \"" << input_text << "\" to " << tgt_lang << "\n";
+                } else {
+                    std::cout << "No spaces found in the input. \n";
+                }
             }
-        );
-        std::cout << "Translation: " << concat_translation << std::endl;
+            // tokenize the input text
+            Result result = unity_eval_text(model, input_text, params.opts, tgt_lang, params.n_threads);
+            std::string concat_translation = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
+                [](const std::string& a, const std::string& b) {
+                    return a + " " + b;
+                }
+            );
+            std::cout << "Translation: " << concat_translation << std::endl;
+        }
     }
 
     return 0;