@@ -1223,11 +1223,11 @@ void _bootstrap_seqs_and_scores(
     // Fetch scores of next steps from "lprobs"
     float p_score = 0;
     for (int i = 1; i < prefix_seq_len; ++i) {
-        int p;
+        int p = 0;
         if (ggml_get_i32_1d(job.prefix_seq, i) == model.vocab.token_to_id["<unk>"]) {
             // If tgt_lang is unk, use the most probable lang tag predicted by model
             int max_value = std::numeric_limits<float>::min();
-            for (int j = 0; j < lang_ids.size(); j++) {
+            for (size_t j = 0; j < lang_ids.size(); j++) {
                 if(ggml_get_f32_1d(lprobs, lang_ids[j]) > max_value) {
                     max_value = ggml_get_f32_1d(lprobs, lang_ids[j]);
                     p = lang_ids[j];
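The int p = 0; initialization matters because every candidate can fail the > max_value comparison: log-probabilities are negative, and the surrounding context seeds the maximum with std::numeric_limits<float>::min() (the smallest positive float) truncated into an int, so the old code could read p uninitialized. A standalone sketch of the same argmax, with hypothetical plain-vector inputs standing in for the lprobs tensor, seeded with lowest() so negative scores can actually win:

    // Standalone sketch (not from the patch): the lang-tag argmax above, with
    // plain vectors standing in for the lprobs tensor. Note that
    // std::numeric_limits<float>::min() is the smallest *positive* float;
    // lowest() is the right seed for a running maximum over negative
    // log-probabilities.
    #include <cstdio>
    #include <limits>
    #include <vector>

    int argmax_lang_tag(const std::vector<float>& lprobs, const std::vector<int>& lang_ids) {
        int p = 0; // initialized, as in the patch, so the fallback is well-defined
        float max_value = std::numeric_limits<float>::lowest();
        for (size_t j = 0; j < lang_ids.size(); j++) {
            float val = lprobs[lang_ids[j]];
            if (val > max_value) {
                max_value = val;
                p = lang_ids[j];
            }
        }
        return p;
    }

    int main() {
        std::vector<float> lprobs = {-3.2f, -0.7f, -1.5f, -2.9f};
        std::vector<int> lang_ids = {1, 2, 3};
        std::printf("best lang tag: %d\n", argmax_lang_tag(lprobs, lang_ids)); // prints 1
        return 0;
    }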
@@ -1354,7 +1354,7 @@ void _finalize_hypothesis(
 
 ggml_context* ctx_from_buffer(std::vector<uint8_t>& buffer) {
     return ggml_init({
-        /*.mem_size =*/ static_cast<int64_t>(buffer.capacity()),
+        /*.mem_size =*/ static_cast<size_t>(buffer.capacity()),
         /*.mem_buffer =*/ buffer.data(),
         /*.no_alloc =*/ false,
     });
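The cast changes because ggml_init_params declares mem_size as a size_t, so the old int64_t cast forced a signed-to-unsigned conversion (and a compiler warning) at the call site; buffer.capacity() already returns size_t, making the new cast a no-op that documents the type. A minimal usage sketch of the same caller-owned-buffer pattern, with a hypothetical 16 MiB scratch vector:

    // Sketch: a ggml context laid over a caller-owned buffer, as in
    // ctx_from_buffer above. The vector owns the memory; ggml_free releases
    // only the context bookkeeping, so the buffer can be reused next step.
    #include <cstdint>
    #include <vector>
    #include "ggml.h"

    int main() {
        std::vector<std::uint8_t> scratch(16 * 1024 * 1024); // hypothetical size
        ggml_context* ctx = ggml_init({
            /*.mem_size   =*/ scratch.capacity(), // size_t, matching ggml_init_params
            /*.mem_buffer =*/ scratch.data(),
            /*.no_alloc   =*/ false,
        });
        ggml_tensor* t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024); // lives inside scratch
        ggml_set_f32_1d(t, 0, 1.0f);
        ggml_free(ctx);
        return 0;
    }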
@@ -1438,7 +1438,7 @@ extern "C" Hypothesis* generate_sequence(
     ggml_context* step_ctx = ctx_from_buffer(local_bufs[start_step % 2]);
     GGML_ASSERT(step_ctx != search_ctx);
     model.enc_kv_cache_ctx = search_ctx;
-    ggml_tensor* lid_scores;
+    ggml_tensor* lid_scores = ggml_new_tensor_1d(result_ctx, GGML_TYPE_F32, 1); // Dummy initialization to get rid of warnings
     if (lang_ids.size()) {
         lid_scores = ggml_new_tensor_1d(result_ctx, GGML_TYPE_F32, lang_ids.size());
     }
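The one-element dummy tensor exists only to silence maybe-uninitialized warnings on paths where lang_ids is empty and lid_scores is never read. An alternative sketch (not what the patch does) that makes the empty state explicit instead, reusing the result_ctx and lang_ids names from generate_sequence above:

    // Alternative sketch: nullptr as the explicit "no lang IDs" state, with a
    // loud failure if a later path reads lid_scores unexpectedly.
    ggml_tensor* lid_scores = nullptr;
    if (lang_ids.size()) {
        lid_scores = ggml_new_tensor_1d(result_ctx, GGML_TYPE_F32, lang_ids.size());
    }
    // ... at every later read:
    // GGML_ASSERT(lid_scores != nullptr);
    // auto val = ggml_get_f32_1d(lid_scores, j);

The dummy-tensor route costs a few bytes of result_ctx memory but keeps later reads free of asserts; either way, reads are reached only behind the same lang_ids.size() guard.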
@@ -1463,20 +1463,19 @@ extern "C" Hypothesis* generate_sequence(
     for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
         model.ctx = step_ctx;
         ggml_set_no_alloc(step_ctx, true); // Use allocr for the model forward pass
-        float max_lprob;
-        int p;
+        int p = 0;
         if (step_nr == start_step) {
             // Find the most probable lang_tok and assign it to all beams, when prefix_seq[1] is <unk>
             if (lang_ids.size() && ggml_get_i32_1d(job.prefix_seq, 1) == model.vocab.token_to_id["<unk>"]) {
                 float max_lprob = std::numeric_limits<float>::min();
-                for(int j = 0; j < lang_ids.size(); j++) {
+                for(size_t j = 0; j < lang_ids.size(); j++) {
                     auto val = ggml_get_f32_1d(lid_scores, j);
                     if (val > max_lprob) {
                         max_lprob = val;
                         p = lang_ids[j];
                     }
                 }
-                for (int k = 0; k < beam_size; k++) {
+                for (std::size_t k = 0; k < beam_size; k++) {
                     ggml_set_i32_1d(seqs, k * vocab_size + step_nr, p);
                 }
             }
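Deleting the outer float max_lprob; is safe because the only surviving use is the inner declaration inside the <unk> branch, so the outer one was dead; p stays at loop scope and is now zero-initialized, so later reads are well-defined even when the branch never assigns it. For readability, the same argmax over the language scores could be phrased with std::max_element, sketched here over a plain vector assumed to have been copied from lid_scores:

    // Readability sketch (not the patch's code): argmax via <algorithm>,
    // assuming scores[j] mirrors ggml_get_f32_1d(lid_scores, j). Callers must
    // guarantee scores is non-empty, as the lang_ids.size() guard above does.
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    int pick_lang_tag(const std::vector<float>& scores, const std::vector<int>& lang_ids) {
        auto best = std::max_element(scores.begin(), scores.end());
        std::size_t j = static_cast<std::size_t>(best - scores.begin());
        return lang_ids[j]; // lang token whose score is largest
    }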