Parcourir la source

make T2TT in blingual model auto-concatenated

Tuan Tran il y a 1 an
Parent
commit
e27d56ebbc
1 fichiers modifiés avec 14 ajouts et 2 suppressions
  1. 14 2
      ggml/examples/unity/lib/unity_lib.cpp

+ 14 - 2
ggml/examples/unity/lib/unity_lib.cpp

@@ -1,6 +1,7 @@
 #include "unity_lib.h"
 #include <algorithm>
 #include <stdexcept>
+#include <numeric>
 
 
 struct ggml_cgraph * unity_text_encoder(
@@ -189,9 +190,20 @@ extern "C" Result unity_eval_text(fairseq2_model& model, const std::string& text
             lid_scores[model.vocab.id_to_token[lang_ids[i]].text] = ggml_get_f32_1d(hypo[0].lid_scores, i); 
         }
         result.lid_scores = lid_scores;
+
+        result.transcription = result_tokens;
+        result.word_confidence_scores = word_scores;
+    } else {
+        // Store the concatenated text in transcription
+        std::string concat_transcription = std::accumulate(std::next(result_tokens.begin()), result_tokens.end(), result_tokens[0],
+            [](const std::string& a, const std::string& b) {
+                return a + " " + b;
+            }
+        );
+        float avg_score = (word_scores.size() > 0) ? std::accumulate(word_scores.begin(), word_scores.end(), 0.0 / word_scores.size()) : 0.0;
+        result.transcription.push_back(concat_transcription);
+        result.word_confidence_scores.push_back(avg_score);
     }
-    result.transcription = result_tokens;
-    result.word_confidence_scores = word_scores;
     result.err = 0;
     ggml_free(model.ctx);
     ggml_allocr_reset(fwd_alloc);