Browse code

unity.cpp sync (#329)

* Sync unity.cpp

Co-authored-by: Tuan Tran <1254753+antoine-tran@users.noreply.github.com>
Co-authored-by: Guillaume Wenzek <5920036+gwenzek@users.noreply.github.com>
Co-authored-by: Ning <7022920+cndn@users.noreply.github.com>

* address comments

---------

Co-authored-by: Tuan Tran <1254753+antoine-tran@users.noreply.github.com>
Co-authored-by: Guillaume Wenzek <5920036+gwenzek@users.noreply.github.com>
Co-authored-by: Ning <7022920+cndn@users.noreply.github.com>
Ning 1 year ago
parent
commit
b0415debec

+ 2 - 5
ggml/README.md

@@ -11,8 +11,7 @@ To build the interactive console for S2TT & ASR,
 
 cd seamless_communication/ggml
 mkdir build; cd build
-cmake \
-    -DGGML_OPENBLAS=ON \
+cmake -DGGML_OPENBLAS=ON \
     -DBUILD_SHARED_LIBS=On \
 	  -DCMAKE_BUILD_TYPE=Release \
 	  -DCMAKE_CXX_FLAGS="-g2 -fno-omit-frame-pointer" \
@@ -20,8 +19,6 @@ cmake \
 make -j4 unity # Interactive Console
 
 ```
-Note that `-DGGML_OPENBLAS=ON` is not necessary on macOS.
-
 For more build commands see [Makefile](Makefile). 
 
 ## CLI usage
@@ -34,7 +31,7 @@ In the console, enter the path of local waveform file and target language, separ
 Converted ggml models could be downloaded from 
 |SeamlessM4T_large | SeamlessM4T_medium | 
 |-------- | -------- | 
-| [model](https://dl.fbaipublicfiles.com/seamless/models/seamlessM4T_large.ggml) | [model](https://dl.fbaipublicfiles.com/seamless/models/seamlessM4T_medium.ggml) |  
+| [model](dl.fbaipublicfiles.com/seamless/models/seamlessM4T_large.ggml) | [model](dl.fbaipublicfiles.com/seamless/models/seamlessM4T_medium.ggml) |  
 
 ## Fairseq2 model conversion 
 Models from fairseq2 checkpoints could be converted to ggml automatically with [ggml_convert.py](ggml_convert.py). 

+ 5 - 1
ggml/examples/common.h

@@ -37,8 +37,12 @@ struct gpt_params {
     int32_t n_gpu_layers     = 0;
 };
 
+bool unity_params_parse(int argc, char ** argv, unity_params & params);
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 
+void unity_print_usage(int /*argc*/, char ** argv, const unity_params & params);
+
 void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
 
 
@@ -175,4 +179,4 @@ struct sam_params {
 
 bool sam_params_parse(int argc, char ** argv, sam_params & params);
 
-void sam_print_usage(int argc, char ** argv, const sam_params & params);
+void sam_print_usage(int argc, char ** argv, const sam_params & params);

+ 14 - 2
ggml/examples/unity/CMakeLists.txt

@@ -7,12 +7,24 @@ target_sources(fairseq2_cpp
         fairseq2.cpp
         model_loader.cpp
 )
+add_library(unity_lib)
+target_include_directories(unity_lib PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+target_link_libraries(unity_lib PRIVATE ggml kaldi-native-fbank fairseq2_cpp)
+target_sources(unity_lib
+    PRIVATE
+        lib/unity_lib.h
+        lib/unity_lib.cpp
+)
+
 add_executable(unity unity.cpp)
 find_package(PkgConfig REQUIRED)
-pkg_check_modules(SNDFILE REQUIRED IMPORTED_TARGET sndfile)
-target_link_libraries(unity PRIVATE ggml PkgConfig::SNDFILE)
+pkg_check_modules(SNDFILE REQUIRED sndfile)
+target_include_directories(unity PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${SNDFILE_INCLUDE_DIRS})
+target_link_libraries(unity PRIVATE ggml unity_lib ${SNDFILE_LIBRARIES})
 target_sources(unity
     PRIVATE
         fairseq2.cpp
         model_loader.cpp
+        lib/unity_lib.h
+        lib/unity_lib.cpp
 )

+ 160 - 51
ggml/examples/unity/fairseq2.cpp

@@ -11,6 +11,8 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 
+#include <numeric>
+
 ggml_tensor* ggml_detach(ggml_tensor* a) {
     a->op = GGML_OP_NONE;
     std::fill(a->src, a->src + GGML_MAX_SRC, nullptr);
@@ -56,7 +58,6 @@ extern "C" void fairseq2_kv_cache_alloc(fairseq2_model& model, ggml_context* kv_
     // Note: we only allocate the masks, proper kv cache allocation is delayed.
     GGML_ASSERT(kv_cache_ctx);
     GGML_ASSERT(!ggml_get_no_alloc(kv_cache_ctx));  // We need to be able to alloc the kv_cache buffers
-    model.kv_cache_ctx = kv_cache_ctx;
     auto attn_glob = "text_decoder.*_attn.k_proj.weight";
     FORCE_ALLOC(self_attn_mask, kv_cache_ctx, ggml_new_tensor_2d(kv_cache_ctx, GGML_TYPE_F32, max_seq_len, max_seq_len));
     self_attn_mask = ggml_diag_mask_inf_inplace(kv_cache_ctx, self_attn_mask, 0);
@@ -107,7 +108,7 @@ inline ggml_tensor* ggml_unsqueeze(ggml_context* ctx, ggml_tensor* x, int dim) {
 void append_to_prev_kv(const fairseq2_model& model, const std::string& prefix, ggml_tensor** k, ggml_tensor** v, ggml_tensor** self_attn_mask) {
     KeyValueTensor& kv = model.kv_cache[prefix];
     int step_nr = kv.step_nr;
-    ggml_context* ctx = model.kv_cache_ctx ? model.kv_cache_ctx : model.ctx;
+    ggml_context* ctx = model.ctx;
     // We need to force allocation here, otherwise the kv_cache buffers can be reused
     bool no_alloc_save = ggml_get_no_alloc(ctx);
     ggml_set_no_alloc(ctx, false);
@@ -214,7 +215,7 @@ extern "C" std::int64_t fairseq2_model_layer_config_int(const fairseq2_model& mo
 
 extern "C" void fairseq2_model_free(fairseq2_model* model) {
     if (model->tensors_ctx) ggml_free(model->tensors_ctx);
-    delete model;
+    // delete model;
 }
 
 extern "C" void fairseq2_model_set_inference_ctx(fairseq2_model* model, ggml_context* ctx) {
@@ -429,7 +430,7 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
             if (kv_cache.step_nr == 0) {
                 // If possible we use the ctx dedicated to kv_cache here,
                 // because the enc dec attention is typically long lived.
-                if (model.kv_cache_ctx) model.ctx = model.kv_cache_ctx;
+                if (model.enc_kv_cache_ctx) model.ctx = model.enc_kv_cache_ctx;
                 k = Linear_forward(model, prefix + ".k_proj", keys);
                 ggml_set_name(k, "k");
                 v = Linear_forward(model, prefix + ".v_proj", values);
@@ -594,11 +595,7 @@ extern "C" ggml_tensor* WaveformToFbank_forward(
     output = ggml_norm(ctx, output, 1e-5);
     output = ggml_dup(ctx, ggml_transpose(ctx, output));
     if (output->ne[1] % 2 == 1) {
-        ggml_tensor* remove_last = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, output->ne[1]-1);
-        for (int i = 0; i < output->ne[1]-1; ++i) {
-            ((int32_t *) remove_last->data)[i] = i;
-        }
-        output = ggml_get_rows(ctx, output, remove_last);
+        output = ggml_dup(ctx, ggml_slice(ctx, output, 1, 0, output->ne[1]-1));
    }
     output = ggml_reshape_2d(ctx, output, output->ne[0] * 2, output->ne[1] / 2);
     return output;
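The hunk above swaps an explicit index-gather for a view. A minimal sketch of the two approaches, assuming this repo's `ggml_slice(ctx, tensor, axis, start, end)` view helper (declared in fairseq2.h; it is not part of upstream ggml):

```cpp
#include "ggml.h"

// Old approach: materialize an index tensor [0 .. n-2] and gather rows,
// allocating and filling n-1 int32 values just to drop one row.
ggml_tensor* drop_last_row_old(ggml_context* ctx, ggml_tensor* output) {
    ggml_tensor* keep = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, output->ne[1] - 1);
    for (int i = 0; i < output->ne[1] - 1; ++i)
        ((int32_t*)keep->data)[i] = i;
    return ggml_get_rows(ctx, output, keep);
}

// New approach: take a view over rows [0, n-1) and ggml_dup it so the
// tensor is contiguous again before the following ggml_reshape_2d.
ggml_tensor* drop_last_row_new(ggml_context* ctx, ggml_tensor* output) {
    return ggml_dup(ctx, ggml_slice(ctx, output, /*axis*/1, 0, output->ne[1] - 1));
}
```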
@@ -714,7 +711,9 @@ extern "C" ggml_tensor* ConvModule_forward(
         seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
 
         // S x C -> (S+K-1) x C -> K x S x C -> S x C
-        seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".depthwise_conv.weight"], seqs, 1, 15, 1);
+        int K = model.tensors[prefix + ".depthwise_conv.weight"]->ne[0];
+
+        seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".depthwise_conv.weight"], seqs, 1, K / 2, 1, seqs->ne[1]);
 
         // conv: Custom implementation of batch norm
         seqs = ggml_batch_norm(ctx, seqs, model.tensors[prefix + ".batch_norm.weight"], model.tensors[prefix + ".batch_norm.bias"], model.tensors[prefix + ".batch_norm.running_mean"], model.tensors[prefix + ".batch_norm.running_var"], 1e-5);
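Deriving the padding from the kernel width keeps the depthwise convolution length-preserving for any odd `K`; the old hardcoded 15 only matched `K = 31`. The extra trailing argument appears to pass the channel count as a groups parameter to this repo's extended `ggml_conv_1d` (upstream ggml's takes no groups argument). A quick check of the "same" padding arithmetic with the standard conv output-length formula:

```cpp
#include <cassert>

// out_len = (in_len + 2*pad - dilation*(K-1) - 1) / stride + 1
int conv1d_out_len(int in_len, int K, int stride, int pad, int dil) {
    return (in_len + 2 * pad - dil * (K - 1) - 1) / stride + 1;
}

int main() {
    // For stride 1, dilation 1, and odd K, pad = K / 2 preserves length.
    const int kernels[] = {3, 15, 31};
    for (int K : kernels) {
        assert(conv1d_out_len(100, K, /*stride*/1, /*pad*/K / 2, /*dil*/1) == 100);
    }
    return 0;
}
```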
@@ -813,14 +812,14 @@ extern "C" ggml_tensor* StandardConformerEncoderAdaptorLayer_forward(
     ggml_tensor* residual = seqs;
     residual = LayerNorm_forward(model, prefix + ".residual_layer_norm", residual);
     residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
-    residual = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".residual_conv.weight"], residual, 8, 4, 1);
+    residual = ggml_conv_1d(ctx, model.tensors[prefix + ".residual_conv.weight"], residual, 8, 4, 1, 1);
     residual = ggml_dup(ctx, ggml_permute(ctx, residual, 1, 0, 2, 3));
     residual = ggml_add_inplace(ctx, ggml_repeat(ctx, model.tensors[prefix + ".residual_conv.bias"], residual), residual);
     residual = ggml_glu(ctx, residual);
 
     seqs = LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
     seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
-    seqs = ggml_conv_1d_generic(ctx, model.tensors[prefix + ".self_attn_conv.weight"], seqs, 8, 4, 1);
+    seqs = ggml_conv_1d(ctx, model.tensors[prefix + ".self_attn_conv.weight"], seqs, 8, 4, 1, 1);
     seqs = ggml_dup(ctx, ggml_permute(ctx, seqs, 1, 0, 2, 3));
     seqs = ggml_add_inplace(ctx, seqs, ggml_repeat(ctx, model.tensors[prefix + ".self_attn_conv.bias"], seqs));
     seqs = ggml_glu(ctx, seqs);
@@ -1160,23 +1159,31 @@ ggml_tensor* ggml_expand_2d(ggml_context* ctx, ggml_tensor* x, int64_t ne0, int6
     return y;
 }
 
-extern "C" void _bootstrap_seqs_and_scores(
+void _bootstrap_seqs_and_scores(
     fairseq2_model& model,
     const SequenceGeneratorJob& job,
     ggml_tensor* full_seqs,
     ggml_tensor* scores,
     ggml_tensor* encoder_output,
     ggml_tensor* encoder_padding_mask,
-    int n_threads
+    ggml_tensor* lid_scores,
+    int n_threads,
+    const std::vector<int>& lang_ids
 ) {
+    // Returns LID score map
     int prefix_seq_len = job.prefix_seq->ne[0];
     int max_seq_len = scores->ne[0];
     int beam_size = scores->ne[1];
     GGML_ASSERT(prefix_seq_len > 0);
-    if (prefix_seq_len == 1)
-        return;
-
     ggml_context* ctx = model.ctx;
+    if (prefix_seq_len == 1) {
+        // We only have one token in prefix, we won't compute decoding scores,
+        // we just need to copy the token to seqs.
+        // Note: it also means the enc_kv_cache will be populated later.
+        ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
+        ggml_set_i32(seqs, ggml_get_i32_1d(job.prefix_seq, 0));
+        return;
+    }
 
     // full_seqs[:, : prefix_seq_len] = job.prefix_seq;
     ggml_tensor* seqs = ggml_slice(ctx, full_seqs, 0, 0, prefix_seq_len);
@@ -1202,14 +1209,33 @@ extern "C" void _bootstrap_seqs_and_scores(
     ggml_tensor* logits = Linear_forward(model, "final_proj", decoder_output);
     int vocab_size = logits->ne[0];
     ggml_tensor* lprobs = ggml_log_softmax(ctx, ggml_slice(ctx, logits, 1, 0, 1));
-
-    ggml_cgraph gf = ggml_build_forward(lprobs);
-    ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx);
+    ggml_build_forward_expand(gf, lprobs);
+    ggml_graph_compute_with_ctx(ctx, gf, n_threads);
+
+    full_seqs->type = GGML_TYPE_I32;
+    job.prefix_seq->type = GGML_TYPE_I32;
+    // For LID
+    for (size_t i = 0; i < lang_ids.size(); ++i) {
+        ggml_set_f32_1d(lid_scores, i, std::exp(ggml_get_f32_1d(lprobs, lang_ids[i])));
+    }
 
     // Fetch scores of next steps from "lprobs"
     float p_score = 0;
     for (int i = 1; i < prefix_seq_len; ++i) {
-        int p = ggml_get_i32_1d(job.prefix_seq, i);
+        int p;
+        if (ggml_get_i32_1d(job.prefix_seq, i) == model.vocab.token_to_id["<unk>"]) {
+            // If tgt_lang is unk, use the most probable lang tag predicted by model
+            float max_value = std::numeric_limits<float>::lowest();
+            for (size_t j = 0; j < lang_ids.size(); j++) {
+                if (ggml_get_f32_1d(lprobs, lang_ids[j]) > max_value) {
+                    max_value = ggml_get_f32_1d(lprobs, lang_ids[j]);
+                    p = lang_ids[j];
+                }
+            }
+        } else {
+            p = ggml_get_i32_1d(job.prefix_seq, i);
+        }
         p_score += ggml_get_f32_1d(lprobs, i * vocab_size + p);
         for (int b = 0; b < beam_size; ++b) {
             // scores: (N, S)
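In prose: the first decoding step's log-probabilities over the `__lang__` tokens double as a language-ID distribution, and when the requested target language is `<unk>` the argmax language token is substituted into the prefix. A standalone sketch of that logic with illustrative names (not the repo's API):

```cpp
#include <cmath>
#include <limits>
#include <utility>
#include <vector>

// Given first-step log-probs and the vocab ids of all __lang__ tokens,
// return the argmax language id plus per-language probabilities.
std::pair<int, std::vector<float>> lid_from_lprobs(
        const std::vector<float>& lprobs,
        const std::vector<int>& lang_ids) {
    std::vector<float> probs(lang_ids.size());
    int best_id = -1;
    float best_lprob = std::numeric_limits<float>::lowest();
    for (size_t j = 0; j < lang_ids.size(); ++j) {
        float lp = lprobs[lang_ids[j]];
        probs[j] = std::exp(lp);  // exposed to the user as probabilities
        if (lp > best_lprob) { best_lprob = lp; best_id = lang_ids[j]; }
    }
    return {best_id, probs};
}
```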
@@ -1290,6 +1316,7 @@ void _finalize_hypothesis(
     float eos_score,
     ggml_tensor* seqs, // (beam_size, seq_len)
     ggml_tensor* scores, // (beam_size, seq_len)
+    ggml_tensor* lid_scores,
     Hypothesis* hypothesis
 ) {
     ggml_tensor* seq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, step_nr + 2);
@@ -1317,6 +1344,7 @@ void _finalize_hypothesis(
         // Skip first EOS since it is always 0 and skews normalization.
         eos_score /= (float)std::pow((step_nr + 1), job.opts.len_penalty);
     hypothesis->score = eos_score;
+    hypothesis->lid_scores = lid_scores;
 }
 
 // Uses ggml_context to store any object.
@@ -1366,6 +1394,15 @@ extern "C" Hypothesis* generate_sequence(
     };
     ggml_allocr* step_alloc = new_arena_allocr(local_bufs[3]);
 
+    std::vector<int> lang_ids;
+    if (job.prefix_seq->ne[0] > 1) {
+        for (const auto& kv : model.vocab.token_to_id) {
+            if (kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
+                lang_ids.push_back(kv.second);
+            }
+        }
+        std::sort(lang_ids.begin(), lang_ids.end());
+    }
     ggml_tensor* embed = model.tensors["text_decoder_frontend.embed.weight"];
     size_t vocab_size = embed->ne[1];
     std::size_t beam_size = job.opts.beam_size;
@@ -1395,22 +1432,20 @@ extern "C" Hypothesis* generate_sequence(
     ggml_tensor* scores = ggml_new_tensor_2d(search_ctx, GGML_TYPE_F32, max_seq_len, beam_size);
     ggml_set_name(scores, "scores_0");
     ggml_set_f32(scores, 0.0);
-
     int prefix_seq_len = job.prefix_seq->ne[0];
     int start_step = prefix_seq_len - 1;
-
-    ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[(start_step - 1) % 2]);
+    ggml_context* prev_step_ctx = ctx_from_buffer(local_bufs[(start_step + 1) % 2]);
     ggml_context* step_ctx = ctx_from_buffer(local_bufs[start_step % 2]);
     GGML_ASSERT(step_ctx != search_ctx);
-    GGML_ASSERT(prev_step_ctx != step_ctx);
-    model.ctx = prev_step_ctx;
-    // search_ctx because we need encoder_decoder_attn.k_cache to survive for the full search
-    model.kv_cache_ctx = search_ctx;
+    model.enc_kv_cache_ctx = search_ctx;
+    ggml_tensor* lid_scores = nullptr;  // stays null when the vocab has no lang tokens
+    if (lang_ids.size()) {
+        lid_scores = ggml_new_tensor_1d(result_ctx, GGML_TYPE_F32, lang_ids.size());
+    }
+    // Multilingual models: Bootstrap LID scores
     _bootstrap_seqs_and_scores(
-        model, job, seqs, scores, encoder_output, encoder_padding_mask, n_threads
+        model, job, seqs, scores, encoder_output, encoder_padding_mask, lid_scores, n_threads, lang_ids
     );
-    // Now we will only add self_attn.k_cache and those need to be resorted and copied at every step.
-    model.kv_cache_ctx = nullptr;
 
     // Holds the indices of beams (a beam can occur more than once) that we
     // should continue with in the next step.
@@ -1428,6 +1463,24 @@ extern "C" Hypothesis* generate_sequence(
     for (int step_nr = start_step; step_nr < max_seq_len - 1; ++step_nr) {
         model.ctx = step_ctx;
         ggml_set_no_alloc(step_ctx, true); // Use allocr for the model forward pass
+        if (step_nr == start_step) {
+            // Find the most probable lang_tok and assign it to all beams, when prefix_seq[1] is <unk>
+            if (lang_ids.size() && ggml_get_i32_1d(job.prefix_seq, 1) == model.vocab.token_to_id["<unk>"]) {
+                int p = lang_ids[0];
+                float max_lprob = std::numeric_limits<float>::lowest();
+                for (size_t j = 0; j < lang_ids.size(); j++) {
+                    auto val = ggml_get_f32_1d(lid_scores, j);
+                    if (val > max_lprob) {
+                        max_lprob = val;
+                        p = lang_ids[j];
+                    }
+                }
+                for (int k = 0; k < beam_size; k++) {
+                    ggml_set_i32_1d(seqs, k * max_seq_len + step_nr, p);
+                }
+            }
+        }
         ggml_tensor* prev_token = ggml_slice(step_ctx, seqs, 0, step_nr, step_nr + 1);
 
         ggml_tensor* decoder_input = TransformerEmbeddingFrontend_forward(model, "text_decoder_frontend", prev_token);
@@ -1448,10 +1501,11 @@ extern "C" Hypothesis* generate_sequence(
 
         // Compute lprobs here so we can modify it in place in the lprob tweaking phase
         // TODO: use ggml properly compute the tweaks
-        ggml_cgraph gf = ggml_build_forward(lprobs);
-        size_t fwd_mem = ggml_allocr_alloc_graph(step_alloc, &gf);
+        struct ggml_cgraph * gf = ggml_new_graph(step_ctx);
+        ggml_build_forward_expand(gf, lprobs);
+        size_t fwd_mem = ggml_allocr_alloc_graph(step_alloc, gf);
         GGML_UNUSED(fwd_mem);
-        ggml_graph_compute_with_ctx(step_ctx, &gf, n_threads);
+        ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);
         ggml_detach(lprobs);
         ggml_allocr_reset(step_alloc);
 #if DEBUG_MEM_USAGE
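The `gf` changes here and in the surrounding hunks are all the same migration: newer ggml dropped the by-value `ggml_build_forward()`, so graphs are now allocated inside a context and nodes appended explicitly. The pattern, extracted (a sketch assuming a live `ggml_context`):

```cpp
#include "ggml.h"

// Evaluate `out` with the post-migration graph API.
void eval_tensor(ggml_context* ctx, ggml_tensor* out, int n_threads) {
    struct ggml_cgraph* gf = ggml_new_graph(ctx);     // graph lives in ctx
    ggml_build_forward_expand(gf, out);               // append out + its deps
    ggml_graph_compute_with_ctx(ctx, gf, n_threads);  // run the forward pass
}
```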
@@ -1476,9 +1530,8 @@ extern "C" Hypothesis* generate_sequence(
             // Make probabilities contain cumulative scores for each hypothesis.
             lprobs = ggml_add_inplace(step_ctx, lprobs, ggml_repeat(step_ctx, last_scores, lprobs));
         }
-
-        gf = ggml_build_forward(lprobs);
-        ggml_graph_compute_with_ctx(step_ctx, &gf, n_threads);
+        ggml_build_forward_expand(gf, lprobs);
+        ggml_graph_compute_with_ctx(step_ctx, gf, n_threads);
 
         // Determine (beam, token) candidates for the next step.
         // (N, 2 x B)
@@ -1497,7 +1550,7 @@ extern "C" Hypothesis* generate_sequence(
             bool eos = token == job.eos_idx;
             eos &= tok_score != -INFINITY;
             if (eos) {
-                _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, finished_searches++);
+                _finalize_hypothesis(job, result_ctx, step_nr, beam, token, tok_score, seqs, scores, lid_scores, finished_searches++);
                 if (finished_searches == finished_searches_end)
                     goto end_of_beam_search;
                 continue;
@@ -1517,10 +1570,11 @@ extern "C" Hypothesis* generate_sequence(
         ggml_set_no_alloc(step_ctx, false);
         ggml_tensor* new_seqs = ggml_get_rows(step_ctx, seqs, beam_indices);
         ggml_tensor* new_scores = ggml_get_rows(step_ctx, scores, beam_indices);
-        ggml_cgraph gf_reorder = ggml_build_forward(new_seqs);
-        ggml_build_forward_expand(&gf_reorder, new_scores);
-        reorder_kv_cache(model, step_ctx, &gf_reorder, beam_indices);
-        ggml_graph_compute_with_ctx(step_ctx, &gf_reorder, n_threads);
+        struct ggml_cgraph * gf_reorder = ggml_new_graph(step_ctx);
+        ggml_build_forward_expand(gf_reorder, new_seqs);
+        ggml_build_forward_expand(gf_reorder, new_scores);
+        reorder_kv_cache(model, step_ctx, gf_reorder, beam_indices);
+        ggml_graph_compute_with_ctx(step_ctx, gf_reorder, n_threads);
         seqs = ggml_detach(new_seqs);
         scores = ggml_detach(new_scores);
 
@@ -1531,7 +1585,7 @@ extern "C" Hypothesis* generate_sequence(
             ((float*)scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
         }
 
-        printf_mem_usage(step_ctx, "  step_ctx");
+        printf_mem_usage(step_ctx, "step_ctx");
         ggml_free(prev_step_ctx);
         prev_step_ctx = step_ctx;
 #if DEBUG_MEM_USAGE
@@ -1607,7 +1661,7 @@ struct llm_bigram_spm {
 struct llm_tokenizer_spm {
     llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
 
-    void tokenize(const std::string& input_text, ggml_tensor& output) {
+    void tokenize(const std::string& input_text, ggml_tensor* output) {
         llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");
 
         // split string into utf8 chars
@@ -1675,8 +1729,8 @@ struct llm_tokenizer_spm {
             try_add_bigram(bigram.left, left_sym.next);
         }
 
-        llama_vocab::id* out = (llama_vocab::id*)output.data;
-        int out_step = sizeof(llama_vocab::id) / output.nb[0];
+        llama_vocab::id* out = (llama_vocab::id*)output->data;
+        int out_step = sizeof(llama_vocab::id) / output->nb[0];
         int num_tokens = 0;
         for (int i = 0; i > -1; i = symbols[i].next) {
             llm_symbol& symbol = symbols[i];
@@ -1685,7 +1739,7 @@ struct llm_tokenizer_spm {
         }
         *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
         num_tokens += 1;
-        output.ne[0] = num_tokens;
+        output->ne[0] = num_tokens;
     }
 
 private:
@@ -1724,21 +1778,23 @@ private:
 };
 
 
-extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out) {
+extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out) {
     llm_tokenizer_spm spm = {model->vocab};
     spm.tokenize(std::string(text), out);
 }
 
+
 extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out) {
-    int eos_idx = model->vocab.token_to_id["</s>"];
+    bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
+    int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
     int sent_len = tokens->ne[0];
     std::size_t written = 0;
+    std::vector<llama_vocab::token_data> id_to_token = no_tgt_vocab ? model->vocab.id_to_token : model->tgt_vocab.id_to_token;
     for (int i = 0; i < sent_len; ++i) {
         int id = ggml_get_i32_1d(tokens, i);
         // Don't print the EOS token but only if it appear at the end.
         if (i == sent_len - 1 && eos_idx == id) break;
-
-        std::string token = model->vocab.id_to_token.at(id).text;
+        std::string token = no_tgt_vocab ? model->vocab.id_to_token.at(id).text : model->tgt_vocab.id_to_token.at(id).text;
         // Skip the first space outputted.
         auto begin = token.begin();
         if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
@@ -1750,3 +1806,56 @@ extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tenso
     *out = '0';
     return written;
 }
+
+
+// TODO: Unify with the above?
+std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(
+        fairseq2_model* model,
+        ggml_tensor* tokens,
+        ggml_tensor* scores,
+        char* out) {
+    bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
+    int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
+    int sent_len = tokens->ne[0];
+    std::size_t written = 0;
+    std::vector<float> word_scores;
+    std::vector<float> subword_scores;
+    std::vector<std::string> result_text;
+    std::string curr_token = "";
+    for (int i = 0; i < sent_len; ++i) {
+        int id = ggml_get_i32_1d(tokens, i);
+        // Don't print the EOS token but only if it appear at the end.
+        if (i == sent_len - 1 && eos_idx == id) break;
+
+        std::string token = no_tgt_vocab ? model->vocab.id_to_token.at(id).text : model->tgt_vocab.id_to_token.at(id).text;
+        float score = ggml_get_f32_1d(scores, i+2); // 2 is prefix size
+        if(token[0] == ' ') {
+            // reset word score
+            if(subword_scores.size() > 0) {
+                float avg = std::accumulate(subword_scores.begin(), subword_scores.end(), 0.0f) / subword_scores.size();
+                word_scores.push_back(avg);
+                subword_scores.clear();
+                result_text.push_back(curr_token);
+            }
+            curr_token = token.substr(1);
+        } else {
+            curr_token += token;
+        }
+        subword_scores.push_back(score);
+        // Skip the first space outputted.
+        auto begin = token.begin();
+        if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
+        std::copy(begin, token.end(), out);
+        std::size_t n = token.end() - begin;
+        written += n;
+        out += n;
+
+    }
+    if(subword_scores.size() > 0) {
+        word_scores.push_back(*std::min_element(subword_scores.begin(), subword_scores.end()));
+        subword_scores.clear();
+        result_text.push_back(curr_token);
+    }
+    *out = '0';
+    return std::make_pair(result_text, word_scores);
+}
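The new overload folds per-token scores into per-word confidences: a SentencePiece piece starting with a space opens a new word, and a word's score is the mean of its subword scores (note the final flush in the diff uses the min instead of the mean). A hypothetical standalone version of that aggregation, with illustrative names:

```cpp
#include <numeric>
#include <string>
#include <utility>
#include <vector>

std::pair<std::vector<std::string>, std::vector<float>> words_with_scores(
        const std::vector<std::string>& pieces,     // decoded subword strings
        const std::vector<float>& piece_scores) {   // one score per piece
    std::vector<std::string> words;
    std::vector<float> word_scores;
    std::string word;
    std::vector<float> buf;
    auto flush = [&] {
        if (buf.empty()) return;
        words.push_back(word);
        word_scores.push_back(
            std::accumulate(buf.begin(), buf.end(), 0.0f) / buf.size());
        buf.clear();
    };
    for (size_t i = 0; i < pieces.size(); ++i) {
        if (!pieces[i].empty() && pieces[i][0] == ' ') {
            flush();                    // a leading space starts a new word
            word = pieces[i].substr(1);
        } else {
            word += pieces[i];
        }
        buf.push_back(piece_scores[i]); // the piece's score goes to its word
    }
    flush();                            // flush the final word
    return {words, word_scores};
}
```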

+ 19 - 3
ggml/examples/unity/fairseq2.h

@@ -97,8 +97,12 @@ struct fairseq2_model {
     // Normally those can be inferred from hparams, but it avoids doing this logic in GGML
     std::unordered_map<std::string, std::int64_t> layer_config = {};
 
+    // Vocabulary for text transcription and translation APIs
     llama_vocab vocab;
 
+    // Optional target vocabulary for bilingual models
+    llama_vocab tgt_vocab;
+
     // KV cache for attention layers
     mutable std::unordered_map<std::string, KeyValueTensor> kv_cache = {};
 
@@ -106,7 +110,7 @@ struct fairseq2_model {
     // TODO: is this the best place to store this or should we also pass this to all forward methods ?
     ggml_context* ctx = nullptr;
 
-    ggml_context* kv_cache_ctx = nullptr;
+    ggml_context* enc_kv_cache_ctx = nullptr;
 };
 
 double fairseq2_model_layer_config_double(const fairseq2_model& model, std::string name);
@@ -199,6 +203,13 @@ extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
     ggml_tensor* padding_mask
 );
 
+extern "C" ggml_tensor* StandardTransformerEncoder_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
+    ggml_tensor* seqs,
+    ggml_tensor* padding_mask
+);
+
 extern "C" ggml_tensor* RelativePositionMHA_forward(
     fairseq2_model& model,
     const std::string& prefix,
@@ -302,6 +313,9 @@ struct Hypothesis {
 
     /// The score of each individual sequence step.
     ggml_tensor* step_scores;
+
+    /// The score of each lang tok at first decoding step, serving as LID
+    ggml_tensor* lid_scores;
 };
 
 
@@ -311,8 +325,10 @@ extern "C" Hypothesis* generate_sequence(
     ggml_tensor* encoder_output,
     ggml_tensor* encoder_padding_mask,
     ggml_context* result_ctx,
-    int n_threads
+    int threads
 );
 
-extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out);
+extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out);
 extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out);
+
+std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, ggml_tensor* scores, char* out);

+ 8 - 2
ggml/examples/unity/model_loader.cpp

@@ -1,5 +1,5 @@
-#include <string>
 #include "model_loader.h"
+#include <string>
 
 #define DEBUG_MODEL_LOAD 0
 
@@ -133,7 +133,10 @@ void model_loader::load_vocab(llama_vocab& vocab, std::ifstream &fin)
 
     std::int64_t vocab_size = 0;
     fin.read(reinterpret_cast<char*>(&vocab_size), sizeof(vocab_size));
-    GGML_ASSERT(fin.gcount() == 8);
+    // GGML_ASSERT(fin.gcount() == 8);
+    if (vocab_size == 0) {
+        return;
+    }
 
     vocab.token_to_id.reserve(vocab_size);
     vocab.id_to_token.reserve(vocab_size);
@@ -219,5 +222,8 @@ extern "C" int load_fairseq2_ggml_file(fairseq2_model& model, const char* fname)
     loader.load_hparams(model.layer_config, fin);
     loader.load_vocab(model.vocab, fin);
     loader.load_model_weights(model, fin);
+
+    // load optional target vocabulary in cases of bilingual models
+    loader.load_vocab(model.tgt_vocab, fin);
     return 0;
 }
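Commenting out the `gcount` assertion is what makes the trailing target vocabulary optional: for single-vocab models the stream is already at EOF, the read fails, `vocab_size` keeps its zero initializer, and `load_vocab` returns with `tgt_vocab` left empty (which `fairseq2_spm_detokenize` then checks). A simplified sketch of that mechanism, with hypothetical names:

```cpp
#include <cstdint>
#include <fstream>

// Returns true only if a non-empty vocabulary header was actually read.
bool try_read_vocab_size(std::ifstream& fin, std::int64_t& vocab_size) {
    vocab_size = 0;
    fin.read(reinterpret_cast<char*>(&vocab_size), sizeof(vocab_size));
    // At EOF the read fails and vocab_size stays 0; the caller treats
    // "0 entries" as "no optional target vocabulary present".
    return fin.gcount() == sizeof(vocab_size) && vocab_size > 0;
}
```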

+ 0 - 2
ggml/examples/unity/model_loader.h

@@ -25,8 +25,6 @@ public:
     void load_vocab(llama_vocab& vocab, std::ifstream &fin);
 
 private:
-    ggml_tensor * next_tensor(std::ifstream &fin, fairseq2_model &model);
-
     std::string get_name(std::ifstream &fin);
 };
 

+ 62 - 86
ggml/examples/unity/unity.cpp

@@ -4,24 +4,17 @@
 #include "math.h"
 #include "model_loader.h"
 #include "fairseq2.h"
-
-#include <thread>
-#include <cassert>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <fstream>
-#include <map>
-#include <string>
-#include <vector>
-#include <iostream>
+#include "lib/unity_lib.h"
 #include <sndfile.h>
 #include <cstdlib>
 #include "ggml-alloc.h"
+#include <numeric>
+#include <algorithm>
 
 struct unity_params {
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    std::string model      = "seamlessM4T_medium.ggml"; // model path
+    std::string model = "seamlessM4T_medium.ggml"; // model path
+    std::string input_text = "";
     std::string tgt_lang = "eng";
     std::vector<std::string> files = {};
     bool text = false;
@@ -34,9 +27,10 @@ struct unity_params {
         /*len_penalty*/ 1.0,
         /*unk_penalty*/ 0.0,
         /*normalize_scores*/ true,
-        /*mem_mb*/ 512,
+        /*mem_mb*/ 512
     };
     int32_t max_audio_s = 30;
+    bool verbose = false;
 };
 
 
@@ -45,13 +39,17 @@ void unity_print_usage(int /*argc*/, char ** argv, const unity_params & params)
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -i, --input           Input text for the text-2-text translation\n");
+    fprintf(stderr, "  -l, --tgt-lang        Target translation lang (default: %s)\n", params.tgt_lang.c_str());
+
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -v, --verbose         Print out word level confidence score and LID score (default: off)\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --text                text output\n");
     fprintf(stderr, "  --beam-size           beam size (default: %d)\n", params.opts.beam_size);
     fprintf(stderr, "  -M, --mem             memory buffer, increase for long inputs (default: %d)\n", params.opts.mem_mb);
-    fprintf(stderr, "  --max-audio           max duration of audio in seconds (default: %d)\n", params.max_audio_s);
+    fprintf(stderr, " --max-audio max duration of audio in seconds (default: %d)\n", params.max_audio_s);
     fprintf(stderr, "\n");
 }
 
@@ -75,12 +73,16 @@ bool unity_params_parse(int argc, char ** argv, unity_params & params) {
             params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-m" || arg == "--model") {
             params.model = get_next_arg(i, argc, argv, arg, params);
+        } else if (arg == "-i" || arg == "--input") {
+            params.input_text = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "-l" || arg == "--tgt-lang") {
             params.tgt_lang = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "--text") {
             params.text = true;
         } else if (arg == "-b" || arg == "--beam-size") {
             params.opts.beam_size = std::stoi(get_next_arg(i, argc, argv, arg, params));
+        } else if (arg == "-v" || arg == "--verbose") {
+            params.verbose = true;
         } else if (arg == "-M" || arg == "--mem") {
             params.opts.mem_mb = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--max-audio") {
@@ -92,41 +94,6 @@ bool unity_params_parse(int argc, char ** argv, unity_params & params) {
     return true;
 }
 
-struct ggml_cgraph * unity_speech_encoder(
-        fairseq2_model& model,
-        struct ggml_tensor * speech_input) {
-    ggml_context* ctx0 = model.ctx;
-    ggml_cgraph* gf = ggml_new_graph(ctx0);
-    ggml_tensor* seqs = StandardConformerEncoder_forward(model, "speech_encoder", speech_input, nullptr);
-    seqs = ggml_dup(model.ctx, seqs);
-    ggml_build_forward_expand(gf, seqs);
-    return gf;
-}
-
-
-Hypothesis* unity_decode(
-        fairseq2_model& model,
-        const SequenceGeneratorOptions& opts,
-        int tgt_lang_idx,
-        ggml_tensor* encoder_output,
-        int n_threads
-) {
-    SequenceGeneratorJob job = {
-        opts,
-        /*prefix_seq*/ nullptr,
-        /*pad_idx*/model.vocab.token_to_id["<pad>"],
-        /*unk_idx*/model.vocab.token_to_id["<unk>"],
-        /*bos_idx*/model.vocab.token_to_id["<s>"],
-        /*eos_idx*/model.vocab.token_to_id["</s>"],
-        /*num_threads*/n_threads,
-    };
-    FORCE_ALLOC(prefix_seq, model.ctx, ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 2));
-    ((int *)prefix_seq->data)[0]  = job.eos_idx;
-    ((int *)prefix_seq->data)[1]  = tgt_lang_idx;
-    job.prefix_seq = prefix_seq;
-    return generate_sequence(model, job, encoder_output, nullptr, model.ctx, n_threads);
-}
-
 int main(int argc, char ** argv) {
 
     unity_params params;
@@ -151,8 +118,13 @@ int main(int argc, char ** argv) {
     char result_str[4096];
 
     std::string input;
-    bool interactive = params.files.size() == 0;
+    bool interactive = (params.files.size() == 0 && params.input_text.length() == 0);
     auto next_file = params.files.begin();
+
+    // Flag for the input case: true --> s2st, false --> t2tt
+    bool s2st_or_t2tt = true;
+
+    // S2ST
     while (true) {
         if (interactive) {
             std::cout << "\nEnter audio_path and tgt_lang, separated by space (or 'exit' to quit):\n";
@@ -161,7 +133,10 @@ int main(int argc, char ** argv) {
                 break;
             }
         } else {
-            if (next_file == params.files.end()) break;
+            if (params.input_text.length() > 0) {
+                break;
+            }
+            if (next_file == params.files.end() && s2st_or_t2tt) break;
             input = *(next_file++);
         }
         std::istringstream iss(input);
@@ -179,46 +154,47 @@ int main(int argc, char ** argv) {
             if (interactive) continue;
             else return 1;
         }
-        auto tgt_lang_ptr = model.vocab.token_to_id.find("__" + tgt_lang + "__");
-        if (tgt_lang_ptr == model.vocab.token_to_id.end()) {
-            std::cerr << "Unknown language " << tgt_lang << "\n";
-            if (interactive) continue;
-            else return 2;
-        }
-        int tgt_lang_idx = tgt_lang_ptr->second;
-
-
-        // Reset the ggml_context
-        model.ctx = ctx_from_buffer(encoder_buf);
-        ggml_set_no_alloc(model.ctx, true);
+        // Load audio input
         GGML_ASSERT(info.samplerate == 16000);
         GGML_ASSERT(info.channels == 1);
         // Truncate audio input. Ideally we should chunk it, but this will prevent most obvious OOM.
         int n_frames = std::min(info.samplerate * params.max_audio_s, (int)info.frames);
-        ggml_tensor* seqs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_frames, info.channels);
-        ggml_allocr_alloc(fwd_alloc, seqs);
+        std::vector<float> data(n_frames * info.channels);
+        sf_readf_float(sndfile, data.data(), n_frames);
 
-        // Load audio input
-        sf_readf_float(sndfile, (float*)seqs->data, n_frames);
-
-        // Audio encoder
-        ggml_cgraph* gf = unity_speech_encoder(model, seqs);
-        size_t enc_mem_used = ggml_allocr_alloc_graph(fwd_alloc, gf);
-        ggml_graph_compute_with_ctx(model.ctx, gf, params.n_threads);
-        // encoder_output is valid until we call `ggml_allocr_reset(fwd_alloc)`
-        ggml_tensor* encoder_output = gf->nodes[gf->n_nodes - 1];
-
-        // Beam search decoding
-        const Hypothesis* result = unity_decode(model, params.opts, tgt_lang_idx, encoder_output, params.n_threads);
-
-        // Drop language and bos token.
-        ggml_tensor* tokens = ggml_slice(model.ctx, result[0].seq, 0, 2, 0);
-
-        // Collect result string
-        int n = fairseq2_spm_detokenize(&model, tokens, (char*)&result_str);
-        std::cout << std::string((char*)&result_str, n) << std::endl;
-        ggml_free(model.ctx);
-        ggml_allocr_reset(fwd_alloc);
+        Result result = unity_eval_speech(model, data, params.opts, tgt_lang, params.n_threads);
+        std::string concat_transcription = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
+            [](const std::string& a, const std::string& b) {
+                return a + " " + b;
+            }
+        );
+        if (params.verbose) {
+            std::cout << "Final transcription: " << concat_transcription << std::endl;
+            std::cout << std::endl;
+            std::cout << "Word level confidence score:" << std::endl;
+            for (size_t i = 0; i < result.transcription.size(); ++i) {
+                std::cout << "Word: " << result.transcription[i] << " | Score: " << result.word_confidence_scores[i] << std::endl;
+            }
+            std::cout << std::endl;
+            std::cout << "LID scores: " << std::endl;
+            for (const auto& kv : result.lid_scores) {
+                std::cout << "Language: " << kv.first << " | Score: " << kv.second << std::endl;
+            }
+        } else {
+            std::cout << concat_transcription << std::endl;
+        }
+    }
+
+    // T2TT
+    if (params.input_text.length() > 0) {
+        // tokenize the input text
+        Result result = unity_eval_text(model, params.input_text, params.opts, params.tgt_lang, params.n_threads);
+        std::string concat_translation = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
+            [](const std::string& a, const std::string& b) {
+                return a + " " + b;
+            }
+        );
+        std::cout << "Translation: " << concat_translation << std::endl;
     }
 
     return 0;
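The `std::accumulate` join used for both modalities above assumes `result.transcription` is non-empty; `std::next(begin())` on an empty vector is undefined behavior. Extracted as a helper with a guard (a sketch, not part of the commit):

```cpp
#include <numeric>
#include <string>
#include <vector>

std::string join_words(const std::vector<std::string>& words) {
    if (words.empty()) return "";  // guard missing in the calls above
    return std::accumulate(std::next(words.begin()), words.end(), words[0],
        [](const std::string& a, const std::string& b) { return a + " " + b; });
}
```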

+ 2 - 2
ggml/ggml.py

@@ -282,7 +282,7 @@ class NativeObj:
         cls._cache[kind] = (alloc_fn, free_fn)
         return (alloc_fn, free_fn)
 
-    def __init__(self, kind: str, ptr: ctypes.c_void_p = NULL):
+    def __init__(self, kind: str, ptr: ctypes.c_void_p = NULLPTR):
         self.kind = kind
         alloc_fn, self._free_fn = self._init_c_func(kind)
         self.ptr = alloc_fn() if ptr is None else ptr
@@ -292,7 +292,7 @@ class NativeObj:
         if self.ptr is not None:
             self._free_fn(self.ptr)
             # print(f"freeing {self}")
-            self.ptr = NULL
+            self.ptr = NULLPTR
 
     def __enter__(self) -> ctypes.c_void_p:
         return self.ptr

+ 329 - 50
ggml/ggml_convert.py

@@ -6,54 +6,330 @@
 
 import dataclasses
 import logging
-import math
 import struct
 from enum import Enum
 from io import BufferedWriter
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Mapping, Tuple, Union, Sequence, Set, final
+import re
 
 import torch
 from fairseq2.assets import AssetCard
 from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
 from fairseq2.nn import SinusoidalPositionEncoder
 from fairseq2.nn.transformer import RelativePositionalEncoding
-from seamless_communication.models import unity
+from fairseq2.data.text import SentencePieceEncoder, SentencePieceTokenizerBase
+from fairseq2.data.typing import PathLike
+from fairseq2.typing import Device, finaloverride
+from fairseq2.models.utils import TokenizerLoaderBase, ModelLoader
+from fairseq2.models.utils.checkpoint import convert_model_state_dict
+from fairseq2.assets import asset_store, download_manager
 
 import ggml
-import re
 
 Preprocessor = Callable[[Any], Any]
 log = logging.getLogger("ggml_convert")
 
 
+class ModelType(str, Enum):
+    AUTO = "auto"  # inferred from the model name
+    UNITY = "unity"
+    NLLB = "nllb"
+    MT = "bitext"
+    MTS = "bitext_scripted"
+
+
+UNITY_SMALLER_MODELS = [
+    "unity_nano",
+    "unity_micro",
+]  # Trained with fairseq2, with custom dict (not original NLLB ones)
+
+
+NLLB_2_UNITY_KEYMAP = {
+    r"^encoder_frontend\.": r"text_encoder_frontend.",
+    r"^encoder\."         : r"text_encoder.",
+    r"^decoder\."         : r"text_decoder.",
+    r"^decoder_frontend\.": r"text_decoder_frontend.",
+}
+
+
+@final
+class NllbLikeTokenizer(SentencePieceTokenizerBase):
+    """The only difference between this class and NllbTokenizer is it doesn't add a <pad> to control symbol list.
+    Since NllbTokenizer is defined as final, we couldn't inherit from it directly. So copying ~everything"""
+
+    langs: Set[str]
+    default_lang: str
+
+    def __init__(
+        self, pathname: PathLike, langs: Sequence[str], default_lang: str
+    ) -> None:
+        """
+        :param pathname:
+            The pathname of the SentencePiece model file.
+        :param langs:
+            The list of supported languages.
+        :param default_lang:
+            The fall-back language if no language is specified.
+        """
+        # Each language is represented by a `__lang__` control symbol.
+        control_symbols = [f"__{lang}__" for lang in langs]
+
+        # Internal control symbols that are not relevant for eval use.
+        control_symbols.extend(["<MINED_DATA>", "<MMT_BT_DATA>", "<SMT_BT_DATA>"])
+        super().__init__(pathname, control_symbols)
+
+        self.langs = set(langs)
+
+        self.default_lang = default_lang
+
+    @finaloverride
+    def create_encoder(
+        self,
+        *,
+        task: Optional[str] = None,
+        lang: Optional[str] = None,
+        mode: Optional[str] = None,
+        device: Optional[Device] = None,
+        pin_memory: bool = False,
+    ) -> SentencePieceEncoder:
+        """Create a token encoder.
+
+        :param task:
+            Must be 'translation'. If ``None``, defaults to 'translation'.
+        :param lang:
+            A language from :attr:`langs`. If ``None``, defaults to
+            :attr:`default_lang`.
+        :param mode:
+            Must be 'source' or 'target'. Set to 'source' if ``lang`` is the
+            source language; set to 'target' if ``lang`` is the target language.
+            If ``None``, defaults to 'source'.
+        :param device:
+            The device on which to construct tensors.
+        :param pin_memory:
+            If ``True``, uses pinned memory while constructing tensors.
+        """
+        if task is not None and task != "translation":
+            raise ValueError(f"`task` must be 'translation', but is '{task}' instead.")
+
+        if lang is None:
+            lang = self.default_lang
+
+        if lang not in self.langs:
+            raise ValueError(
+                f"`lang` must be a supported language, but is '{lang}' instead."
+            )
+
+        if mode is None or mode == "source":
+            # NLLB models expect a language token in place of BOS in source
+            # sequences.
+            prefix_tokens = [f"__{lang}__"]
+            suffix_tokens = ["</s>"]
+        elif mode == "source_mining":
+            prefix_tokens = [f"__{lang}__", "<MINED_DATA>"]
+            suffix_tokens = ["</s>"]
+        elif mode == "source_mmt_bt":
+            prefix_tokens = [f"__{lang}__", "<MMT_BT_DATA>"]
+            suffix_tokens = ["</s>"]
+        elif mode == "source_smt_bt":
+            prefix_tokens = [f"__{lang}__", "<SMT_BT_DATA>"]
+            suffix_tokens = ["</s>"]
+        elif mode == "target":
+            # Target sequences are expected to start with an EOS, followed by
+            # the language token.
+            prefix_tokens = ["</s>", f"__{lang}__"]
+            suffix_tokens = []
+        else:
+            raise ValueError(
+                f"`mode` must be 'source', 'target', or one of the 'source_*' variants, but is '{mode}' instead."
+            )
+
+        return SentencePieceEncoder(
+            self.model,
+            prefix_tokens=prefix_tokens,
+            suffix_tokens=suffix_tokens,
+            device=device,
+            pin_memory=pin_memory,
+        )
+
+
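As a concrete illustration (with a hypothetical language code "eng"), the prefix/suffix handling above yields:

enc_src = tokenizer.create_encoder(lang="eng", mode="source")  # __eng__ <tokens...> </s>
enc_tgt = tokenizer.create_encoder(lang="eng", mode="target")  # </s> __eng__ <tokens...>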
+@final
+class NllbLikeTokenizerLoader(TokenizerLoaderBase[NllbLikeTokenizer]):
+    """Loads tokenizers used by NLLB models."""
+
+    @finaloverride
+    def _load(self, pathname: Path, card: AssetCard) -> NllbLikeTokenizer:
+        langs = card.field("langs").as_list(str)
+
+        default_lang = card.field("default_lang").as_(str)
+
+        return NllbLikeTokenizer(pathname, langs, default_lang)
+
+
+def convert_state_dict(
+    state_dict: Dict[str, Any], key_map: Optional[Mapping[str, str]] = None
+) -> Dict[str, Any]:
+
+    if key_map is None:
+        return state_dict
+    
+    state_dict = convert_model_state_dict(state_dict, key_map=key_map)
+
+    # We use the built-in version attribute of `torch.nn.Module`.
+    try:
+        del state_dict["encoder.version"]
+    except KeyError:
+        pass
+    try:
+        del state_dict["decoder.version"]
+    except KeyError:
+        pass
+
+    try:
+        del state_dict["encoder.embed_positions._float_tensor"]
+    except KeyError:
+        pass
+    try:
+        del state_dict["decoder.embed_positions._float_tensor"]
+    except KeyError:
+        pass
+
+    return state_dict
+
+
+def convert_unity_model(
+    model_name: str,
+    hparams: Optional[Dict[str, Any]] = None,
+):
+    from seamless_communication.models import unity
+    from seamless_communication.models.unity.builder import UnitYConfig, create_unity_model
+    from seamless_communication.models.unity.model import UnitYModel
+
+    load_unity_model_without_conversion = ModelLoader[UnitYModel, UnitYConfig](
+        asset_store,
+        download_manager,
+        unity.load_unity_config,
+        create_unity_model,
+        None,
+        restrict_checkpoints=False,
+    )
+
+    model_config = unity.load_unity_config(model_name)
+    hparams = flatten_config(
+        dataclasses.asdict(model_config), separator="__", overrides=hparams
+    )
+    hparams["multilingual"] = True
+    log.info(hparams)
+    # We need to diverge here because the current default in seamless_communication is to convert from the fairseq1 checkpoint format
+    if model_name in UNITY_SMALLER_MODELS:
+        model = load_unity_model_without_conversion(model_name)
+        tokenizer = NllbLikeTokenizerLoader(asset_store, download_manager)(model_name)
+    else:
+        model = unity.load_unity_model(model_name)
+        tokenizer = unity.load_unity_text_tokenizer(model_name)
+
+    vocab = read_vocab(tokenizer)
+
+    return model, hparams, vocab
+
+
+def convert_nllb_model(
+    model_name: str,
+    hparams: Optional[Dict[str, Any]] = None,
+):
+    from fairseq2.models.nllb.loader import load_nllb_tokenizer, load_nllb_model, load_nllb_config
+
+    model_config = load_nllb_config(model_name)
+    hparams = flatten_config(
+        dataclasses.asdict(model_config), separator="__", overrides=hparams,
+    )
+    hparams["multilingual"] = True
+
+    model = load_nllb_model(model_name)
+    tokenizer = load_nllb_tokenizer(model_name)
+    vocab = read_vocab(tokenizer)
+
+    return model, hparams, vocab
+
+
+def convert_bitext_model(
+    model_name: str,
+    hparams: Optional[Dict[str, Any]] = None,
+):
+    from mt import load_mt_model, load_vocab  #, test_mt
+
+    hparams = hparams or {}
+    hparams["multilingual"] = False
+    model = load_mt_model(model_name)
+    src_vocab, src_spm = load_vocab(model_name, "src")
+    tgt_vocab, tgt_spm = load_vocab(model_name, "tgt")
+
+    # test_mt(model, src_spm, tgt_spm)
+
+    return model, hparams, src_vocab, tgt_vocab
+
+
 def convert_model(
 def convert_model(
     model_name: Union[str, torch.nn.Module],
     model_name: Union[str, torch.nn.Module],
     out: Optional[Path] = None,
     out: Optional[Path] = None,
+    model_type: ModelType = ModelType.AUTO,
     layers: str = "",
     layers: str = "",
     hparams: Optional[Dict[str, Any]] = None,
     hparams: Optional[Dict[str, Any]] = None,
-    vocab: Optional[List[Tuple[str, float]]] = None,
     fp16: bool = False,
     fp16: bool = False,
 ) -> None:
 ) -> None:
+    """
+    Entry function for converting different kinds of model into GGML file. Supported model checkpoints:
+        - unity models
+        - nllb models
+        - Bilingual encoder-decoder model (Pytorch) with separate vocabulary for src and tgt languages
+        - Bilingual encoder-decoder model (torchscript)
+    Args:
+        model_name: name of a registered model (discoverable in a fairseq2 asset), path to a checkpoint,\
+            or the model object passed directly
+        out: path to store the converted .ggml model. If None, the ggml model is stored in the same place\
+            as input model
+        model_type: type of the model (or inferred from the name, only applied to nllb, unity and seamless)
+        layers: wildcard patterns to filter the layers from the model. Does not applied to scripted models
+        hparams: override the hparams in the model with the user-defined values
+        vocab: Path to  vocabulary files (in case not bundled with the model checkpoint)
+        extra_vocab: Path to additional vocabulary files (used in bilingual models with explicit tgt languages)
+        fp16: Save to .GGML float16 tensors instead of float32
+    """
+
+    key_map: Optional[Dict[str, str]] = None
+    tgt_vocab: Optional[List[Tuple[str, float]]] = None
     if isinstance(model_name, str):
     if isinstance(model_name, str):
         # Load the corresponding fairseq2 model
         # Load the corresponding fairseq2 model
         if out is None:
         if out is None:
             out = Path(model_name).with_suffix(".ggml")
             out = Path(model_name).with_suffix(".ggml")
 
 
-        # The type of model depends on the name
-        if "unity" in model_name or "seamlessM4T" in model_name:
-            if hparams is None:
-                model_config = unity.load_unity_config(model_name)
-                hparams = flatten_config(
-                    dataclasses.asdict(model_config), separator="__"
-                )
-                log.info(hparams)
-            model = unity.load_unity_model(model_name)
-            if vocab is None:
-                tokenizer = unity.load_unity_text_tokenizer(model_name)
-                vocab = read_vocab(tokenizer)
-        else:
-            raise ValueError(f"Unsupported model type: {model_name}")
+        # Infer the model architecture from the model name or user input
+        try:
+            if model_type == ModelType.AUTO:
+                if "unity" in model_name or "seamlessM4T" in model_name:
+                    model_type = ModelType.UNITY
+                elif "nllb" in model_name:
+                    model_type = ModelType.NLLB
+
+            assert (
+                model_type != ModelType.AUTO
+            ), "Cannot infer model type from the `model_name`. Please specify `model_type`"
+
+            if model_type == ModelType.UNITY:
+                model, hparams, vocab = convert_unity_model(model_name, hparams=hparams)
+            elif model_type == ModelType.NLLB:
+                model, hparams, vocab = convert_nllb_model(model_name, hparams=hparams)
+                key_map = NLLB_2_UNITY_KEYMAP
+            elif model_type == ModelType.MTS:
+                # TODO: implement the EdgeML model conversion here
+                raise NotImplementedError("Scripted model conversion not implemented yet")
+            
+            # Bilingual non-scripted model
+            else:
+                model, hparams, vocab, tgt_vocab = convert_bitext_model(model_name, hparams=hparams)
+                key_map = NLLB_2_UNITY_KEYMAP
+        except Exception as exc:
+            raise ValueError(f"Error in loading model: {model_name}") from exc
     else:
     else:
         # Use the model passed explicitly
         # Use the model passed explicitly
         assert (
         assert (
@@ -66,19 +342,12 @@ def convert_model(
     if layers:
     if layers:
         state_dict = {k: v for k, v in state_dict.items() if re.match(layers, k)}
         state_dict = {k: v for k, v in state_dict.items() if re.match(layers, k)}
     fixup_model(model, state_dict, layer_filter=layers)
     fixup_model(model, state_dict, layer_filter=layers)
-    layer_config = read_layer_config(model, layer_filter=layers)
-    vocab = vocab or []
-    write_ggml_file(out, hparams, layer_config, vocab, state_dict, fp16)
-
+    state_dict = convert_state_dict(state_dict, key_map=key_map)
+    layer_config = read_layer_config(model, layer_filter=layers, key_map=key_map)
 
 
-def _nested_getattr(model: Any, name: str) -> Any:
-    parts = name.split(".")
-    node = model
-    for part in parts:
-        node = getattr(node, part)
-        if node is None:
-            return None
-    return node
+    vocab = vocab or []
+    tgt_vocab = tgt_vocab or []
+    write_ggml_file(out, hparams, layer_config, state_dict=state_dict, vocab=vocab, tgt_vocab=tgt_vocab, fp16=fp16)
 
 
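Hypothetical invocations, for illustration (the model names are illustrative, not part of this change):

convert_model("seamlessM4T_medium", fp16=True)            # type inferred as UNITY
convert_model("nllb-200_dense_distil_600m")               # type inferred as NLLB
convert_model("my_bitext_ckpt", model_type=ModelType.MT)  # bitext type is never inferred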
 
 
 def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
 def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
@@ -133,15 +402,6 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], lay
         state_dict["speech_encoder.pos_enc"] = rel_pos_enc.freqs
         state_dict["speech_encoder.pos_enc"] = rel_pos_enc.freqs
 
 
 
 
-def convert_to_fp16(state_dict: Dict[str, torch.Tensor]) -> None:
-    for k in state_dict:
-        v = state_dict[k]
-        if v.dtype != torch.float32:
-            # ignore int tensors
-            continue
-        state_dict[k] = v.to(torch.float16)
-
-
 def read_vocab(tokenizer: Any) -> List[Tuple[str, float]]:
 def read_vocab(tokenizer: Any) -> List[Tuple[str, float]]:
     vocab_info = tokenizer.vocab_info
     vocab_info = tokenizer.vocab_info
     vocab = [
     vocab = [
@@ -155,9 +415,10 @@ def write_ggml_file(
     out: Path,
     out: Path,
     hparams: Dict[str, Any],
     hparams: Dict[str, Any],
     layer_config: Dict[str, Any],
     layer_config: Dict[str, Any],
-    vocab: List[Tuple[str, float]],
     state_dict: Dict[str, torch.Tensor],
     state_dict: Dict[str, torch.Tensor],
-    fp16: bool,
+    vocab: List[Tuple[str, float]],
+    tgt_vocab: Optional[List[Tuple[str, float]]] = None,  # target-side vocabulary for bilingual models
+    fp16: bool = False,
 ) -> None:
 ) -> None:
     with out.open("wb") as o:
     with out.open("wb") as o:
         write_ggml_header(o)
         write_ggml_header(o)
@@ -165,6 +426,7 @@ def write_ggml_file(
         write_hparams(o, layer_config)
         write_hparams(o, layer_config)
         write_vocab(o, vocab)
         write_vocab(o, vocab)
         write_state_dict(o, state_dict, fp16)
         write_state_dict(o, state_dict, fp16)
+        write_vocab(o, tgt_vocab)
 
 
 
 
 def write_ggml_header(out: BufferedWriter) -> None:
 def write_ggml_header(out: BufferedWriter) -> None:
@@ -200,6 +462,9 @@ def write_hparams(out: BufferedWriter, hparams: Dict[str, Any]) -> None:
 def write_vocab(out: BufferedWriter, vocab: List[Tuple[str, float]]) -> None:
 def write_vocab(out: BufferedWriter, vocab: List[Tuple[str, float]]) -> None:
     out.write(struct.pack("<q", len(vocab)))
     out.write(struct.pack("<q", len(vocab)))
 
 
+    if len(vocab) == 0:
+        return
+
     # Write all words concatenated in a buffer
     # Write all words concatenated in a buffer
     words = [bytes(w, "utf8") for w, score in vocab]
     words = [bytes(w, "utf8") for w, score in vocab]
     packed_words = b"\0".join(words)
     packed_words = b"\0".join(words)
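For example, the NUL-joined buffer round-trips cleanly:

words = [b"hello", b"world"]
packed = b"\0".join(words)        # b"hello\x00world"
assert packed.split(b"\0") == words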
@@ -246,10 +511,12 @@ def write_state_dict(
         # Compressed size
         # Compressed size
         compressed_byte_size = sum(_fp16_byte_size(x) for x in state_dict.values())
         compressed_byte_size = sum(_fp16_byte_size(x) for x in state_dict.values())
         log.warning(
         log.warning(
-            f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}Gb compressed to {compressed_byte_size / GB:.3f}"
+            f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}Gb"
+            f". Compressed to {compressed_byte_size / GB:.3f}Gb"
         )
         )
 
 
     for key, value in state_dict.items():
     for key, value in state_dict.items():
+        # Keys were already renamed (in convert_state_dict) to match the "unity" architecture layout
         write_string(out, key)
         write_string(out, key)
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
             # GGML broadcasting isn't as strong as numpy
             # GGML broadcasting isn't as strong as numpy
@@ -324,7 +591,7 @@ def torch_to_ggml_type(dtype: torch.dtype) -> int:
 def flatten_config(
 def flatten_config(
     config: Dict[str, Any],
     config: Dict[str, Any],
     separator: str,
     separator: str,
-    config_preprocessor: Optional[Preprocessor] = None,
+    overrides: Optional[Dict[str, Any]] = None,
 ) -> Dict[str, Any]:
 ) -> Dict[str, Any]:
     """Flatten nested dictionnary
     """Flatten nested dictionnary
 
 
@@ -339,9 +606,6 @@ def flatten_config(
         flat dictionary
         flat dictionary
     """
     """
 
 
-    if config_preprocessor is None:
-        config_preprocessor = lambda x: x
-
     def __flatten(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
     def __flatten(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
         result = {}
         result = {}
         for key in config:
         for key in config:
@@ -350,16 +614,22 @@ def flatten_config(
                 nested_result = __flatten(config[key], f"{new_key}{separator}")
                 nested_result = __flatten(config[key], f"{new_key}{separator}")
                 result.update(nested_result)
                 result.update(nested_result)
             else:
             else:
-                new_config = config_preprocessor(config[key])
+                new_config = config[key]
                 if new_config is not None:
                 if new_config is not None:
                     result[new_key] = config[key]
                     result[new_key] = config[key]
 
 
         return result
         return result
 
 
-    return __flatten(config)
+    res_config = __flatten(config)
+    if overrides:
+        return {**res_config, **overrides}
+    else:
+        return res_config
 
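A small worked example of the flattening and the new `overrides` merge (config values are hypothetical):

cfg = {"model": {"dim": 1024, "num_layers": 24}, "dropout": 0.1}
flatten_config(cfg, separator="__", overrides={"multilingual": True})
# -> {"model__dim": 1024, "model__num_layers": 24, "dropout": 0.1, "multilingual": True}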
 
 
 
-def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, Any]:
+def read_layer_config(
+    model: torch.nn.Module, layer_filter: str, key_map: Optional[Dict[str, str]] = None
+) -> Dict[str, Any]:
     layer_config = {}
     layer_config = {}
 
 
     def _append_node_config(node: Any, prefix: str) -> None:
     def _append_node_config(node: Any, prefix: str) -> None:
@@ -384,6 +654,15 @@ def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, An
     _append_node_config(model, "")
     _append_node_config(model, "")
     for name, node in find_children(model, torch.nn.Module, layer_filter):
     for name, node in find_children(model, torch.nn.Module, layer_filter):
         _append_node_config(node, name + ".")
         _append_node_config(node, name + ".")
+
+    key_map = key_map or {}
+    keys_to_replace = []
+    for k, v in layer_config.items():
+        for old_pattern, replacement in key_map.items():
+            if (new_key := re.sub(old_pattern, replacement, k)) != k:
+                keys_to_replace.append((k, new_key))
+    for old_key, new_key in keys_to_replace:
+        layer_config[new_key] = layer_config.pop(old_key)
     return layer_config
     return layer_config
 
 
 
 

+ 75 - 9
ggml/include/ggml/ggml-alloc.h

@@ -6,21 +6,87 @@
 extern "C" {
 extern "C" {
 #endif
 #endif
 
 
+struct ggml_backend;
+struct ggml_backend_buffer;
+struct ggml_backend_buffer_type;
 
 
-GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
-GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
+//
+// Legacy API
+//
+
+typedef struct ggml_allocr * ggml_allocr_t;
+
+// initialize allocator for use with CPU backend only
+GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment);
+GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment);
+
+// initialize allocator for use with ggml-backend
+GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
+GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
+GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend);
+
+GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc);
 
 
 // tell the allocator to parse nodes following the order described in the list
 // tell the allocator to parse nodes following the order described in the list
 // you should call this if your graph are optimized to execute out-of-order
 // you should call this if your graph are optimized to execute out-of-order
-GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
+GGML_API void   ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n);
+
+GGML_API void   ggml_allocr_free       (ggml_allocr_t alloc);
+GGML_API bool   ggml_allocr_is_measure (ggml_allocr_t alloc);
+GGML_API void   ggml_allocr_reset      (ggml_allocr_t alloc);
+GGML_API void   ggml_allocr_alloc      (ggml_allocr_t alloc, struct ggml_tensor * tensor);
+GGML_API size_t ggml_allocr_max_size   (ggml_allocr_t alloc);
+
+GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph);
+
+//
+// ggml-backend v2 API
+//
+
+// Separate tensor and graph allocator objects
+// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
+// The original API is kept as a wrapper around the new API
+
+// Tensor allocator
+typedef struct ggml_tallocr * ggml_tallocr_t;
 
 
-GGML_API void   ggml_allocr_free(struct ggml_allocr * alloc);
-GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);
-GGML_API void   ggml_allocr_reset(struct ggml_allocr * alloc);
-GGML_API void   ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
-GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
+GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment);
+GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment);
+GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
+GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
+GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
 
 
+GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
+
+GGML_API void   ggml_tallocr_free       (ggml_tallocr_t talloc);
+GGML_API bool   ggml_tallocr_is_measure (ggml_tallocr_t talloc);
+GGML_API void   ggml_tallocr_reset      (ggml_tallocr_t talloc);
+GGML_API void   ggml_tallocr_alloc      (ggml_tallocr_t talloc, struct ggml_tensor * tensor);
+GGML_API size_t ggml_tallocr_max_size   (ggml_tallocr_t talloc);
+
+
+// Graph allocator
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+GGML_API ggml_gallocr_t ggml_gallocr_new(void);
+GGML_API void   ggml_gallocr_free(ggml_gallocr_t galloc);
+
+GGML_API void   ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n);
+GGML_API size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph);
+
+// Allocate tensors from the allocators given by the hash table
+GGML_API void   ggml_gallocr_alloc_graph_n(
+                    ggml_gallocr_t galloc,
+                    struct ggml_cgraph * graph,
+                    struct ggml_hash_set hash_set,
+                    ggml_tallocr_t * hash_node_talloc);
+
+
+// Utils
+// Create a buffer and allocate all the tensors in a ggml_context
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
 
 
 #ifdef  __cplusplus
 #ifdef  __cplusplus
 }
 }
-#endif
+#endif

+ 181 - 0
ggml/include/ggml/ggml-backend.h

@@ -0,0 +1,181 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    typedef struct ggml_backend * ggml_backend_t;
+    typedef void * ggml_backend_graph_plan_t;
+
+    //
+    // Backend buffer
+    //
+
+    // buffer type
+    GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
+    GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
+    GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
+
+    // buffer
+    GGML_API void   ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
+    GGML_API void * ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
+    GGML_API size_t ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
+    GGML_API void   ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
+
+    //
+    // Backend
+    //
+
+
+    GGML_API const char * ggml_backend_name(ggml_backend_t backend);
+    GGML_API void         ggml_backend_free(ggml_backend_t backend);
+
+    GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t      ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+    GGML_API size_t                     ggml_backend_get_alignment(ggml_backend_t backend);
+
+    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+    GGML_API void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+    GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
+
+    GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+    GGML_API void ggml_backend_graph_plan_free   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+    GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+    GGML_API void ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API bool ggml_backend_supports_op       (ggml_backend_t backend, const struct ggml_tensor * op);
+
+    // tensor copy between different backends
+    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
+
+    //
+    // CPU backend
+    //
+
+    GGML_API ggml_backend_t ggml_backend_cpu_init(void);
+
+    GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend);
+    GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
+
+    // Create a backend buffer from an existing pointer
+    GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+
+    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
+
+    //
+    // Backend registry
+    //
+
+    // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
+
+    GGML_API size_t                     ggml_backend_reg_get_count(void);
+    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
+    GGML_API const char *               ggml_backend_reg_get_name(size_t i);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
+    GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
+    GGML_API ggml_backend_buffer_t      ggml_backend_reg_alloc_buffer(size_t i, size_t size);
+
+    //
+    // Backend scheduler
+    //
+
+    // The backend scheduler allows for multiple backends to be used together
+    // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
+    // The backends are selected based on:
+    // - the backend that supports the operation
+    // - the location of the pre-allocated tensors (e.g. the weights)
+    /*
+      Example usage:
+
+        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends);
+        // sched is initialized with measure allocators and cannot be used until allocated with a measure graph
+
+        // initialize buffers from a measure graph
+        measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed
+
+        // in build_graph:
+        build_graph(...) {
+            // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer)
+            alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu);
+            ggml_allocr_alloc(alloc_cpu, tensor);
+
+            // manually assigning nodes to a backend (optional, shouldn't be needed in most cases)
+            struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
+            ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
+        }
+
+        // allocate backend buffers from measure graph
+        ggml_backend_sched_init_measure(sched, measure_graph);
+
+        // the scheduler is now ready to compute graphs
+
+        // compute
+        graph = build_graph(sched);
+        ggml_backend_sched_graph_compute(sched, graph);
+    */
+
+    struct ggml_backend_sched;
+    typedef struct ggml_backend_sched * ggml_backend_sched_t;
+
+    // Initialize a backend scheduler
+    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends);
+
+    GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
+
+    // Initialize backend buffers from a measure graph
+    GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+
+    GGML_API ggml_tallocr_t        ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
+
+    GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+
+    // Allocate a graph on the backend scheduler
+    GGML_API void ggml_backend_sched_graph_compute(
+            ggml_backend_sched_t sched,
+            struct ggml_cgraph * graph);
+
+
+    //
+    // Utils
+    //
+
+    struct ggml_backend_graph_copy {
+        ggml_backend_buffer_t buffer;
+        struct ggml_context * ctx_allocated;
+        struct ggml_context * ctx_unallocated;
+        struct ggml_cgraph * graph;
+    };
+
+    // Copy a graph to a different backend
+    GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
+    GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
+
+    typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
+
+    // Compare the output of two backends
+    GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+
+    // Tensor initialization
+    GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
+    GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+
+
+#ifdef  __cplusplus
+}
+#endif

+ 308 - 101
ggml/include/ggml/ggml.h

@@ -58,7 +58,8 @@
 //   {
 //   {
 //       ...
 //       ...
 //
 //
-//       struct ggml_cgraph gf = ggml_build_forward(f);
+//       struct ggml_cgraph * gf = ggml_new_graph(ctx);
+//       ggml_build_forward_expand(gf, f);
 //
 //
 //       // set the input variable and parameter values
 //       // set the input variable and parameter values
 //       ggml_set_f32(x, 2.0f);
 //       ggml_set_f32(x, 2.0f);
@@ -213,15 +214,14 @@
 #define GGML_QNT_VERSION        2    // bump this on quantization format changes
 #define GGML_QNT_VERSION        2    // bump this on quantization format changes
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 
-#define GGML_MAX_DIMS          4
-#define GGML_MAX_NODES         4096
-#define GGML_MAX_PARAMS        256
-#define GGML_MAX_CONTEXTS      64
-#define GGML_MAX_SRC           6
-#define GGML_MAX_NAME          64
-#define GGML_MAX_OP_PARAMS     32
-#define GGML_DEFAULT_N_THREADS 4
-
+#define GGML_MAX_DIMS           4
+#define GGML_MAX_PARAMS         4096
+#define GGML_MAX_CONTEXTS       64
+#define GGML_MAX_SRC            10
+#define GGML_MAX_NAME           64
+#define GGML_MAX_OP_PARAMS      64
+#define GGML_DEFAULT_N_THREADS  4
+#define GGML_DEFAULT_GRAPH_SIZE 4096
 #if UINTPTR_MAX == 0xFFFFFFFF
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define GGML_MEM_ALIGN 4
     #define GGML_MEM_ALIGN 4
 #else
 #else
@@ -231,8 +231,9 @@
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
 #define GGML_EXIT_ABORTED 1
 
 
-#define GGUF_MAGIC   0x46554747 // "GGUF"
-#define GGUF_VERSION 2
+#define GGUF_MAGIC "GGUF"
+
+#define GGUF_VERSION 3
 
 
 #define GGUF_DEFAULT_ALIGNMENT 32
 #define GGUF_DEFAULT_ALIGNMENT 32
 
 
@@ -243,11 +244,21 @@
 #define GGML_ASSERT(x) \
 #define GGML_ASSERT(x) \
     do { \
     do { \
         if (!(x)) { \
         if (!(x)) { \
+            fflush(stdout); \
             fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
             fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            ggml_print_backtrace(); \
             abort(); \
             abort(); \
         } \
         } \
     } while (0)
     } while (0)
 
 
+#ifndef NDEBUG
+#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
+#elif defined(__GNUC__)
+#define GGML_UNREACHABLE() __builtin_unreachable()
+#else
+#define GGML_UNREACHABLE() ((void) 0)
+#endif
+
 // used to copy the number of elements and stride in bytes of tensors into local variables.
 // used to copy the number of elements and stride in bytes of tensors into local variables.
 // main purpose is to reduce code duplication and improve readability.
 // main purpose is to reduce code duplication and improve readability.
 //
 //
@@ -272,6 +283,20 @@
     const type prefix##3 = (pointer)->array[3]; \
     const type prefix##3 = (pointer)->array[3]; \
     GGML_UNUSED(prefix##3);
     GGML_UNUSED(prefix##3);
 
 
+#define GGML_TENSOR_UNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
 #ifdef  __cplusplus
 #ifdef  __cplusplus
 extern "C" {
 extern "C" {
 #endif
 #endif
@@ -318,7 +343,7 @@ extern "C" {
         GGML_TYPE_COUNT,
         GGML_TYPE_COUNT,
     };
     };
 
 
-    enum ggml_backend {
+    enum ggml_backend_type {
         GGML_BACKEND_CPU = 0,
         GGML_BACKEND_CPU = 0,
         GGML_BACKEND_GPU = 10,
         GGML_BACKEND_GPU = 10,
         GGML_BACKEND_GPU_SPLIT = 20,
         GGML_BACKEND_GPU_SPLIT = 20,
@@ -371,6 +396,7 @@ extern "C" {
         GGML_OP_GROUP_NORM,
         GGML_OP_GROUP_NORM,
 
 
         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT,
+        GGML_OP_MUL_MAT_ID,
         GGML_OP_OUT_PROD,
         GGML_OP_OUT_PROD,
 
 
         GGML_OP_SCALE,
         GGML_OP_SCALE,
@@ -392,21 +418,23 @@ extern "C" {
         GGML_OP_ROPE_BACK,
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
         GGML_OP_CLAMP,
+        GGML_OP_CONV_TRANSPOSE_1D,
+        GGML_OP_IM2COL,
         GGML_OP_CONV_1D,
         GGML_OP_CONV_1D,
-        GGML_OP_CONV_1D_GENERIC,
         GGML_OP_CONV_2D,
         GGML_OP_CONV_2D,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D,
+        GGML_OP_DEPTHWISE_CONV_STAGE_0,  // internal
+        GGML_OP_DEPTHWISE_CONV_STAGE_1,  // internal
+        GGML_OP_DEPTHWISE_CONV_STAGE_2,  // internal
 
 
-        GGML_OP_CONV_1D_STAGE_0,  // internal
-        GGML_OP_CONV_1D_STAGE_1,  // internal
-        GGML_OP_CONV_1D_STAGE_2,  // internal
-
-        GGML_OP_CONV_1D_GENERIC_STAGE_0,
-        GGML_OP_CONV_1D_GENERIC_STAGE_1,  
-
+        GGML_OP_CONV_1D_STAGE_0,
+        GGML_OP_CONV_1D_STAGE_1,  
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_PAD,
+        GGML_OP_ARGSORT,
+        GGML_OP_LEAKY_RELU,
 
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_FF,
@@ -446,6 +474,8 @@ extern "C" {
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_SILU,
+
+        GGML_UNARY_OP_COUNT,
         GGML_UNARY_OP_GLU,
         GGML_UNARY_OP_GLU,
     };
     };
 
 
@@ -455,6 +485,12 @@ extern "C" {
         GGML_OBJECT_WORK_BUFFER
         GGML_OBJECT_WORK_BUFFER
     };
     };
 
 
+    enum ggml_log_level {
+        GGML_LOG_LEVEL_ERROR = 2,
+        GGML_LOG_LEVEL_WARN = 3,
+        GGML_LOG_LEVEL_INFO = 4
+    };
+
     // ggml object
     // ggml object
     struct ggml_object {
     struct ggml_object {
         size_t offs;
         size_t offs;
@@ -471,14 +507,16 @@ extern "C" {
 
 
     // n-dimensional tensor
     // n-dimensional tensor
     struct ggml_tensor {
     struct ggml_tensor {
-        enum ggml_type    type;
-        enum ggml_backend backend;
+        enum ggml_type         type;
+        enum ggml_backend_type backend;
+
+        struct ggml_backend_buffer * buffer;
 
 
         int     n_dims;
         int     n_dims;
         int64_t ne[GGML_MAX_DIMS]; // number of elements
         int64_t ne[GGML_MAX_DIMS]; // number of elements
         size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
         size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
-                                   // nb[0] = sizeof(type)
-                                   // nb[1] = nb[0]   * ne[0] + padding
+                                   // nb[0] = ggml_type_size(type)
+                                   // nb[1] = nb[0]   * (ne[0] / ggml_blck_size(type)) + padding
                                    // nb[i] = nb[i-1] * ne[i-1]
                                    // nb[i] = nb[i-1] * ne[i-1]
 
 
         // compute data
         // compute data
@@ -506,7 +544,7 @@ extern "C" {
 
 
         void * extra; // extra things e.g. for ggml-cuda.cu
         void * extra; // extra things e.g. for ggml-cuda.cu
 
 
-        char padding[4];
+        char padding[12];
     };
     };
 
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -519,29 +557,35 @@ extern "C" {
 
 
         int n_threads;
         int n_threads;
 
 
-        // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
-        int n_tasks[GGML_MAX_NODES];
-
         // abort ggml_graph_compute when true
         // abort ggml_graph_compute when true
         bool (*abort_callback)(void * data);
         bool (*abort_callback)(void * data);
         void * abort_callback_data;
         void * abort_callback_data;
     };
     };
 
 
-    // next prime after GGML_MAX_NODES
-    // #define GGML_GRAPH_HASHTABLE_SIZE 4099
-    // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
-    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+    enum ggml_cgraph_eval_order {
+        GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+        GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+        GGML_CGRAPH_EVAL_ORDER_COUNT
+    };
+
+    struct ggml_hash_set {
+        size_t size;
+        struct ggml_tensor ** keys;
+    };
 
 
     // computation graph
     // computation graph
     struct ggml_cgraph {
     struct ggml_cgraph {
+        int size;
         int n_nodes;
         int n_nodes;
         int n_leafs;
         int n_leafs;
 
 
-        struct ggml_tensor * nodes[GGML_MAX_NODES];
-        struct ggml_tensor * grads[GGML_MAX_NODES];
-        struct ggml_tensor * leafs[GGML_MAX_NODES];
+        struct ggml_tensor ** nodes;
+        struct ggml_tensor ** grads;
+        struct ggml_tensor ** leafs;
+
+        struct ggml_hash_set visited_hash_table;
 
 
-        void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+        enum ggml_cgraph_eval_order order;
 
 
         // performance
         // performance
         int     perf_runs;
         int     perf_runs;
@@ -549,8 +593,6 @@ extern "C" {
         int64_t perf_time_us;
         int64_t perf_time_us;
     };
     };
 
 
-    static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
-
     // scratch buffer
     // scratch buffer
     struct ggml_scratch {
     struct ggml_scratch {
         size_t offs;
         size_t offs;
@@ -560,7 +602,7 @@ extern "C" {
 
 
     struct ggml_init_params {
     struct ggml_init_params {
         // memory pool
         // memory pool
-        int64_t mem_size;   // bytes
+        size_t mem_size;   // bytes
         void * mem_buffer; // if NULL, memory will be allocated internally
         void * mem_buffer; // if NULL, memory will be allocated internally
         bool   no_alloc;   // don't allocate memory for the tensor data
         bool   no_alloc;   // don't allocate memory for the tensor data
     };
     };
@@ -595,6 +637,8 @@ extern "C" {
     GGML_API int64_t ggml_cycles(void);
     GGML_API int64_t ggml_cycles(void);
     GGML_API int64_t ggml_cycles_per_ms(void);
     GGML_API int64_t ggml_cycles_per_ms(void);
 
 
+    GGML_API void    ggml_print_backtrace(void);
+
     GGML_API void    ggml_numa_init(void); // call once for better performance on NUMA systems
     GGML_API void    ggml_numa_init(void); // call once for better performance on NUMA systems
     GGML_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
     GGML_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
 
 
@@ -615,6 +659,9 @@ extern "C" {
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
     GGML_API const char * ggml_op_symbol(enum ggml_op   op);
     GGML_API const char * ggml_op_symbol(enum ggml_op   op);
 
 
+    GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
+
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
 
 
     GGML_API bool    ggml_is_quantized(enum ggml_type type);
     GGML_API bool    ggml_is_quantized(enum ggml_type type);
@@ -643,7 +690,7 @@ extern "C" {
     GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
     GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
 
 
     GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
     GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
-    GGML_API int64_t  ggml_get_mem_size       (const struct ggml_context * ctx);
+    GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx);
     GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);
     GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);
 
 
     GGML_API struct ggml_tensor * ggml_new_tensor(
     GGML_API struct ggml_tensor * ggml_new_tensor(
@@ -684,18 +731,30 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
     GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
     GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
     GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
 
 
+    // Context tensor enumeration and lookup
+    GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx);
+    GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor);
     GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
     GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
 
 
     GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
     GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
     GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
     GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
     GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
     GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
 
 
+    // Converts a flat index into coordinates
+    GGML_API void    ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
+
     GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
     GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
     GGML_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
     GGML_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
 
 
+    GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_API void    ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
     GGML_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
     GGML_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
     GGML_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
     GGML_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
 
 
+    GGML_API float   ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_API void    ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
 
@@ -729,6 +788,12 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
             struct ggml_tensor  * b);
 
 
+    GGML_API struct ggml_tensor * ggml_add_cast(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            enum   ggml_type      type);
+
     GGML_API struct ggml_tensor * ggml_add1(
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
@@ -739,6 +804,9 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
             struct ggml_tensor  * b);
 
 
+    // dst = a
+    // view(dst, nb1, nb2, nb3, offset) += b
+    // return dst
     GGML_API struct ggml_tensor * ggml_acc(
     GGML_API struct ggml_tensor * ggml_acc(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
@@ -838,6 +906,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
             struct ggml_tensor  * b);
 
 
+    // sums repetitions in a into shape of b
     GGML_API struct ggml_tensor * ggml_repeat_back(
     GGML_API struct ggml_tensor * ggml_repeat_back(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
@@ -902,11 +971,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
             struct ggml_tensor  * a);
 
 
+    GGML_API struct ggml_tensor * ggml_leaky_relu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a, float negative_slope, bool inplace);
+
     GGML_API struct ggml_tensor * ggml_relu_inplace(
     GGML_API struct ggml_tensor * ggml_relu_inplace(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
             struct ggml_tensor  * a);
 
 
-    // TODO: double-check this computation is correct
     GGML_API struct ggml_tensor * ggml_gelu(
     GGML_API struct ggml_tensor * ggml_gelu(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
             struct ggml_tensor  * a);
@@ -952,8 +1024,8 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
             float                 eps);
             float                 eps);
-
-    GGML_API struct ggml_tensor * ggml_batch_norm(
+
+    GGML_API struct ggml_tensor * ggml_batch_norm(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
             struct ggml_tensor  * gamma,
             struct ggml_tensor  * gamma,
@@ -993,14 +1065,24 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * b,
             float                 eps);
             float                 eps);
 
 
-    // A: n columns, m rows
-    // B: n columns, p rows  (i.e. we transpose it internally)
-    // result is m columns, p rows
+    // A: k columns, n rows => [ne03, ne02, n, k]
+    // B: k columns, m rows  (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
+    // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
     GGML_API struct ggml_tensor * ggml_mul_mat(
     GGML_API struct ggml_tensor * ggml_mul_mat(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
             struct ggml_tensor  * b);
 
 
+    // indirect matrix multiplication
+    //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
+    GGML_API struct ggml_tensor * ggml_mul_mat_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * const as[],
+            int                   n_as,
+            struct ggml_tensor  * ids,
+            int                   id,
+            struct ggml_tensor  * b);
+
     // A: m columns, n rows,
     // A: m columns, n rows,
     // B: p columns, n rows,
     // B: p columns, n rows,
     // result is m columns, p rows
     // result is m columns, p rows
@@ -1072,7 +1154,6 @@ extern "C" {
             size_t                nb1,
             size_t                nb1,
             size_t                offset);
             size_t                offset);
 
 
-
     // a -> b, return view(b)
     // a -> b, return view(b)
     GGML_API struct ggml_tensor * ggml_cpy(
     GGML_API struct ggml_tensor * ggml_cpy(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
@@ -1095,6 +1176,33 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
             struct ggml_tensor  * a);
 
 
+    // make contiguous, with new shape
+    GGML_API struct ggml_tensor * ggml_cont_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0);
+
+    GGML_API struct ggml_tensor * ggml_cont_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1);
+
+    GGML_API struct ggml_tensor * ggml_cont_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2);
+
+    GGML_API struct ggml_tensor * ggml_cont_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3);
+
     // return view(a), b specifies the new shape
     // return view(a), b specifies the new shape
     // TODO: when we start computing gradient, make a copy instead of view
     // TODO: when we start computing gradient, make a copy instead of view
     GGML_API struct ggml_tensor * ggml_reshape(
     GGML_API struct ggml_tensor * ggml_reshape(
@@ -1182,6 +1290,7 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
             struct ggml_tensor  * a);
 
 
+    // supports 3D: a->ne[2] == b->ne[1]
     GGML_API struct ggml_tensor * ggml_get_rows(
     GGML_API struct ggml_tensor * ggml_get_rows(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
@@ -1230,6 +1339,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
             struct ggml_tensor  * a);
 
 
+    // fused soft_max(a*scale + mask)
+    // mask is optional
+    GGML_API struct ggml_tensor * ggml_soft_max_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * mask,
+            float                 scale);
+
     GGML_API struct ggml_tensor * ggml_soft_max_back(
     GGML_API struct ggml_tensor * ggml_soft_max_back(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
@@ -1242,14 +1359,15 @@ extern "C" {
             struct ggml_tensor  * b);
             struct ggml_tensor  * b);
 
 
     // rotary position embedding
     // rotary position embedding
-    // if mode & 1 == 1, skip n_past elements
+    // if mode & 1 == 1, skip n_past elements (DEPRECATED)
     // if mode & 2 == 1, GPT-NeoX style
     // if mode & 2 == 1, GPT-NeoX style
     // if mode & 4 == 1, ChatGLM style
     // if mode & 4 == 1, ChatGLM style
-    // TODO: avoid creating a new tensor every time
+    //
+    // b is an int32 vector with size a->ne[2], it contains the positions
     GGML_API struct ggml_tensor * ggml_rope(
     GGML_API struct ggml_tensor * ggml_rope(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   n_dims,
             int                   mode,
             int                   mode,
             int                   n_ctx);
             int                   n_ctx);
@@ -1258,7 +1376,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_rope_inplace(
     GGML_API struct ggml_tensor * ggml_rope_inplace(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   n_dims,
             int                   mode,
             int                   mode,
             int                   n_ctx);
             int                   n_ctx);
@@ -1267,29 +1385,43 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_rope_custom(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    // compute correction dims for YaRN RoPE scaling
+    void ggml_rope_yarn_corr_dims(
+        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
     // xPos RoPE, in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             float                 base,
             bool                  down);
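
A hedged sketch of the extended signature; the YaRN values below (`beta_fast = 32`, `beta_slow = 1`, `ext_factor = 0`) are placeholders, not tuned settings from this PR:

```c
#include "ggml.h"

static struct ggml_tensor * rope_ext_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,
        struct ggml_tensor  * pos,
        int n_dims, int mode, int n_ctx, int n_orig_ctx) {
    // the correction dims can be inspected up front if needed ...
    float corr_dims[2];
    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, 10000.0f, 32.0f, 1.0f, corr_dims);
    // ... but ggml_rope_custom derives them internally from beta_fast/beta_slow
    return ggml_rope_custom(ctx, cur, pos, n_dims, mode, n_ctx, n_orig_ctx,
                            /*freq_base  =*/ 10000.0f, /*freq_scale =*/ 1.0f,
                            /*ext_factor =*/ 0.0f,     /*attn_factor=*/ 1.0f,
                            /*beta_fast  =*/ 32.0f,    /*beta_slow  =*/ 1.0f);
}
```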
@@ -1299,18 +1431,23 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_rope_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
             float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow,
             float                 xpos_base,
             bool                  xpos_down);
 
     // alibi position embedding
     // in-place, returns view(a)
-    struct ggml_tensor * ggml_alibi(
+    GGML_API struct ggml_tensor * ggml_alibi(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past,
@@ -1319,27 +1456,33 @@ extern "C" {
 
     // clamp
     // in-place, returns view(a)
-    struct ggml_tensor * ggml_clamp(
+    GGML_API struct ggml_tensor * ggml_clamp(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             float                 min,
             float                 max);
 
-    GGML_API struct ggml_tensor * ggml_conv_1d(
+    GGML_API struct ggml_tensor * ggml_im2col(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            int                   s0,  // stride
-            int                   p0,  // padding
-            int                   d0); // dilation
+            int                  s0,
+            int                  s1,
+            int                  p0,
+            int                  p1,
+            int                  d0,
+            int                  d1,
+            bool                 is_2D);
 
-    GGML_API struct ggml_tensor * ggml_conv_1d_generic(
+    GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             int                   s0,  // stride
             int                   p0,  // padding
-            int                   d0); // dilation
+            int                   d0,  // dilation
+            int                   groups); // number of blocked connections from input channels to output channels; 1 and model_dim (depthwise convolution) are supported
 
     // conv_1d with padding = half
     // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
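
The extra `groups` argument is presumably what the speech encoder's depthwise convolutions need; a sketch with illustrative names:

```c
#include "ggml.h"

// groups == 1: ordinary convolution; groups == number of channels: depthwise
static struct ggml_tensor * depthwise_conv1d(
        struct ggml_context * ctx,
        struct ggml_tensor  * kernel,  // a
        struct ggml_tensor  * input,   // b
        int n_channels) {
    return ggml_conv_1d(ctx, kernel, input,
                        /*s0 =*/ 1, /*p0 =*/ 0, /*d0 =*/ 1,
                        /*groups =*/ n_channels);
}
```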
@@ -1350,6 +1493,14 @@ extern "C" {
             int                   s,
             int                   d);
 
+    GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s0,
+            int                   p0,
+            int                   d0);
+
     GGML_API struct ggml_tensor * ggml_conv_2d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1408,6 +1559,8 @@ extern "C" {
             int                   s0, // stride
             int                   p0); // padding
 
+    // the result will have 2*p0 padding for the first dimension
+    // and 2*p1 padding for the second dimension
     GGML_API struct ggml_tensor * ggml_pool_2d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1416,8 +1569,8 @@ extern "C" {
             int                   k1,
             int                   s0,
             int                   s1,
-            int                   p0,
-            int                   p1);
+            float                 p0,
+            float                 p1);
 
     // nearest interpolate
     // used in stable-diffusion
@@ -1426,6 +1579,32 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   scale_factor);
 
+    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+    GGML_API struct ggml_tensor * ggml_pad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
+    // sort rows
+    enum ggml_sort_order {
+        GGML_SORT_ASC,
+        GGML_SORT_DESC,
+    };
+
+    GGML_API struct ggml_tensor * ggml_argsort(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_sort_order  order);
+
+    // top k elements per row
+    GGML_API struct ggml_tensor * ggml_top_k(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   k);
+
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
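
A sketch combining the new tensor ops; per the comments above, `ggml_pad` appends zeros and `ggml_top_k` selects the k largest entries per row (a descending `ggml_argsort` underneath):

```c
#include "ggml.h"

static struct ggml_tensor * top_k_padded(
        struct ggml_context * ctx,
        struct ggml_tensor  * logits,  // [n_vocab, n_rows]
        int pad0, int k) {
    // grow the first dimension by pad0 zeros
    struct ggml_tensor * padded = ggml_pad(ctx, logits, pad0, 0, 0, 0);
    // indices of the k largest entries per row
    return ggml_top_k(ctx, padded, k);
}
```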
@@ -1487,7 +1666,6 @@ extern "C" {
             int                   kh);
 
     // used in sam
-
     GGML_API struct ggml_tensor * ggml_add_rel_pos(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1658,19 +1836,22 @@ extern "C" {
     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
 
-    GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
-    GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
-
     // graph allocation in a context
-    GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx);
-    GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API struct ggml_cgraph * ggml_new_graph         (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
+    GGML_API struct ggml_cgraph * ggml_new_graph_custom  (struct ggml_context * ctx, size_t size, bool grads);
+    GGML_API struct ggml_cgraph * ggml_graph_dup         (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API struct ggml_cgraph   ggml_graph_view        (struct ggml_cgraph * cgraph, int i0, int i1);
+    GGML_API void                 ggml_graph_cpy         (struct ggml_cgraph * src, struct ggml_cgraph * dst);
+    GGML_API void                 ggml_graph_reset       (struct ggml_cgraph * cgraph);  // zero grads
+    GGML_API void                 ggml_graph_clear       (struct ggml_cgraph * cgraph);
+
     GGML_API size_t ggml_graph_overhead(void);
+    GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
 
     // ggml_graph_plan() has to be called before ggml_graph_compute()
     // when plan.work_size > 0, caller must allocate memory for plan.work_data
     GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
-    GGML_API               int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
-    GGML_API              void ggml_graph_reset  (struct ggml_cgraph * cgraph);
+    GGML_API int               ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
 
     // same as ggml_graph_compute() but the work data is allocated as a part of the context
     // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
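
Graphs are now heap objects allocated inside a `ggml_context`, so the context has to reserve their metadata up front. A minimal sketch, assuming `ctx` was created with `mem_size >= ggml_graph_overhead_custom(2048, false)`:

```c
#include "ggml.h"

static struct ggml_cgraph * build_graph_sketch(struct ggml_context * ctx,
                                               struct ggml_tensor  * out) {
    // custom-sized graph with gradients disabled; ggml_new_graph(ctx)
    // would use GGML_DEFAULT_GRAPH_SIZE instead
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, 2048, /*grads =*/ false);
    ggml_build_forward_expand(gf, out);
    return gf;
}
```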
@@ -1678,8 +1859,8 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
 
-    GGML_API void               ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
-    GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
+    GGML_API void                 ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
+    GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
 
     // print info and performance information for the graph
     GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
@@ -1687,6 +1868,16 @@ extern "C" {
     // dump the graph into a file using the dot format
     GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
 
+    // build gradient checkpointing backward graph gb for gf using provided checkpoints
+    // gb_tmp will contain original backward graph with rewritten backward process nodes,
+    // but without the second forward pass nodes.
+    GGML_API void ggml_build_backward_gradient_checkpointing(
+            struct ggml_context   * ctx,
+            struct ggml_cgraph    * gf,
+            struct ggml_cgraph    * gb,
+            struct ggml_cgraph    * gb_tmp,
+            struct ggml_tensor  * * checkpoints,
+            int                     n_checkpoints);
     //
     // optimization
     //
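
A hedged sketch of the new training helper; choosing which activations become `checkpoints` is left to the caller, and the graph arguments are assumed to have been created with gradients enabled:

```c
#include "ggml.h"

static void build_backward_with_checkpoints(
        struct ggml_context * ctx,
        struct ggml_cgraph  * gf,            // forward graph
        struct ggml_cgraph  * gb,            // receives the checkpointed backward graph
        struct ggml_cgraph  * gb_tmp,        // scratch: backward graph before rewriting
        struct ggml_tensor ** checkpoints,   // activations to keep resident
        int                   n_checkpoints) {
    // non-checkpointed activations are recomputed during the backward pass,
    // trading extra compute for a smaller peak allocation
    ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints, n_checkpoints);
}
```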
@@ -1713,6 +1904,7 @@ extern "C" {
         GGML_OPT_NO_CONTEXT,
         GGML_OPT_INVALID_WOLFE,
         GGML_OPT_FAIL,
+        GGML_OPT_CANCEL,
 
         GGML_LINESEARCH_FAIL = -128,
         GGML_LINESEARCH_MINIMUM_STEP,
@@ -1721,7 +1913,8 @@ extern "C" {
         GGML_LINESEARCH_INVALID_PARAMETERS,
     };
 
-    typedef void (*ggml_opt_callback)(void * data, float * sched);
+    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
+    typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
 
     // optimization parameters
     //
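
The callback gains an accumulation-step index and a `cancel` out-parameter; when the callback sets `*cancel`, the optimizer stops with the new `GGML_OPT_CANCEL` result. A sketch using a hypothetical step budget as user data:

```c
#include <stdbool.h>
#include "ggml.h"

static void opt_callback_sketch(void * data, int accum_step, float * sched, bool * cancel) {
    (void)accum_step;
    int * steps_left = (int *)data; // hypothetical user state
    *sched = 1.0f;                  // keep the full learning-rate schedule
    if (--(*steps_left) <= 0) {
        *cancel = true;             // ask the optimizer to stop early
    }
}
```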
@@ -1730,6 +1923,8 @@ extern "C" {
     struct ggml_opt_params {
         enum ggml_opt_type type;
 
+        size_t graph_size;
+
         int n_threads;
 
         // delta-based convergence test
@@ -1752,6 +1947,8 @@ extern "C" {
         bool print_forward_graph;
         bool print_backward_graph;
 
+        int n_gradient_accumulation;
+
         // ADAM parameters
         struct {
             int n_iter;
@@ -1797,6 +1994,7 @@ extern "C" {
         float loss_after;
 
         struct {
+            struct ggml_tensor * g;  // current gradient
             struct ggml_tensor * m;  // first moment
             struct ggml_tensor * v;  // second moment
             struct ggml_tensor * pf; // past function values
@@ -1860,12 +2058,19 @@ extern "C" {
     // quantization
     //
 
+    // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk
     GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
 
+    GGML_API size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+
     GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
 
     //
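
A sketch of the generic entry point that the TODO above favors; it assumes `n` is a multiple of the type's block size (32 floats for Q4_0) and that `dst` is large enough for the quantized output:

```c
#include <stdint.h>
#include "ggml.h"

static size_t quantize_q4_0_sketch(const float * src, void * dst, int n) {
    int64_t hist[16] = {0}; // histogram of quantized values, filled by ggml
    return ggml_quantize_chunk(GGML_TYPE_Q4_0, src, dst, /*start =*/ 0, n, hist);
}
```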
@@ -1913,26 +2118,27 @@ extern "C" {
 
     GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
     GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
-    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
-
-    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
-    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
-
-    // results are undefined if the wrong type is used for the key
-    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int i);
-    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int i);
-    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int i);
-    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int i);
-    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int i);
-    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int i);
-    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int i);
-    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int i);
-    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int i);
-    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int i);
-    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int i);
-    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
-    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int i);
-    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
+    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
+
+    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
+    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
+
+    // will abort if the wrong type is used for the key
+    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int key_id);
+    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int key_id);
+    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
+    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
+    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
+    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
+    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
+    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
+    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
+    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
+    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
+    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
+    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
     GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
 
     GGML_API int    gguf_get_n_tensors    (const struct gguf_context * ctx);
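
A sketch of reading a key with the renamed `key_id` getters; note the behavioral change called out above (a wrong-type read now aborts, so the type is checked first):

```c
#include <stdio.h>
#include "ggml.h"

static void print_u32_kv(const struct gguf_context * gctx, const char * key) {
    int key_id = gguf_find_key(gctx, key);
    if (key_id < 0) {
        return; // key not present
    }
    if (gguf_get_kv_type(gctx, key_id) == GGUF_TYPE_UINT32) {
        printf("%s = %u\n", gguf_get_key(gctx, key_id), gguf_get_val_u32(gctx, key_id));
    }
}
```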
@@ -2001,6 +2207,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
+    GGML_API int ggml_cpu_has_metal      (void);
     GGML_API int ggml_cpu_has_f16c       (void);
     GGML_API int ggml_cpu_has_fp16_va    (void);
     GGML_API int ggml_cpu_has_wasm_simd  (void);
@@ -2038,8 +2245,8 @@ extern "C" {
         enum ggml_type    vec_dot_type;
     } ggml_type_traits_t;
 
-    ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
+    GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
 
 #ifdef  __cplusplus
 }
-#endif
+#endif

+ 182 - 0
ggml/mt.py

@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+#
+# This script contains the builder and loader for the MT models. It has some
+# overlaps with fairseq2.models.nllb, except for a few subtle changes
+# in the tokenizer, patches of layers, etc.
+
+from pathlib import Path
+from typing import Any, Mapping, Optional, Literal
+import torch
+from torch.nn.parameter import Parameter
+
+from fairseq2.assets import InProcAssetMetadataProvider, asset_store, download_manager
+from fairseq2.generation.beam_search import BeamSearchSeq2SeqGenerator
+from fairseq2.nn.embedding import StandardEmbedding
+from fairseq2.models.nllb.builder import NllbBuilder, NllbConfig
+from fairseq2.models.nllb.loader import load_nllb_config
+from fairseq2.nn.projection import TiedProjection
+from fairseq2.models.transformer.model import TransformerModel
+from fairseq2.models.utils import ModelLoader
+from fairseq2.typing import Device, DataType
+from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint
+
+import sentencepiece as spm
+
+
+class MTBuilder(NllbBuilder):
+    def build_embedding(self) -> StandardEmbedding:
+        return StandardEmbedding(
+            num_embeddings=self.config.vocab_info.size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=self.config.vocab_info.pad_idx,
+            init_fn=lambda x: x,
+            device=self.device,
+            dtype=self.dtype,
+        ).requires_grad_(False)
+
+    def build_model(self) -> TransformerModel:
+        """Build a model."""
+        encoder_embed = self.build_embedding()
+        decoder_embed = self.build_embedding()
+
+        encoder_frontend = self.build_frontend(encoder_embed)
+        decoder_frontend = self.build_frontend(decoder_embed)
+
+        encoder = self.build_encoder()
+        decoder = self.build_decoder()
+
+        # Unlike NLLB, in MT we de-couple the final projection from the
+        # embeddings: it gets its own weights instead of sharing them
+        new_weight = Parameter(torch.zeros_like(
+            encoder_embed.weight, requires_grad=False)
+        )
+        final_proj = TiedProjection(new_weight, bias=None)
+
+        return TransformerModel(
+            encoder_frontend,
+            encoder,
+            decoder_frontend,
+            decoder,
+            final_proj,
+            self.config.vocab_info,
+        )
+
+
+def create_mt_model(
+    config: NllbConfig,
+    *,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> TransformerModel:
+    return MTBuilder(config, device=device, dtype=dtype).build_model()
+
+
+def convert_mt_checkpoint(
+    ckpt: Mapping[str, Any], config: NllbConfig,
+) -> Mapping[str, Any]:
+    global_key_map = {
+        # fmt: off
+        r"^encoder\.embed_tokens\.":                              r"encoder_frontend.embed.",
+        r"^decoder\.embed_tokens\.":                              r"decoder_frontend.embed.",
+        r"^encoder\.embed_positions.weights":                     r"encoder_frontend.pos_encoder.freqs",
+        r"^decoder\.embed_positions.weights":                     r"decoder_frontend.pos_encoder.freqs",
+        r"^encoder\.layernorm_embedding\.":                       r"encoder_frontend.layer_norm.",
+        r"^decoder\.layernorm_embedding\.":                       r"decoder_frontend.layer_norm.",
+        r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"decoder.layers.\1.self_attn.output_proj.",
+        r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"encoder.layers.\1.self_attn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"decoder.layers.\1.encoder_decoder_attn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"decoder.layers.\1.encoder_decoder_attn.",
+        r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+        r"^encoder\.layers\.([0-9]+)\.fc1\.":                     r"encoder.layers.\1.ffn.inner_proj.",
+        r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"decoder.layers.\1.ffn.inner_proj.",
+        r"^encoder\.layers\.([0-9]+)\.fc2\.":                     r"encoder.layers.\1.ffn.output_proj.",
+        r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"decoder.layers.\1.ffn.output_proj.",
+        r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"encoder.layers.\1.ffn_layer_norm.",
+        r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"decoder.layers.\1.ffn_layer_norm.",
+        r"^decoder\.output_projection\.":                         r"final_proj.",
+        # fmt: on
+    }
+    return convert_fairseq_checkpoint(ckpt, global_key_map)
+
+
+def load_vocab(model_dir: str, mode: Literal["src", "tgt"]):
+    vocab_file = f"{model_dir}/{mode}.spm"
+    spmp = spm.SentencePieceProcessor(vocab_file)
+
+    return [
+        (spmp.id_to_piece(id).replace("▁", " "), spmp.get_score(id))
+        for id in range(spmp.get_piece_size())
+    ], spmp
+
+
+def load_mt_model(model_dir: str):
+    """
+    Load MT model and the vocabulary processors (spm) for source and target languages
+    Args:
+        model_dir: Directory of the model. It must contain files averaged_checkpoint.pt, src.spm and tgt.spm
+    """
+
+    # Create a fairseq2 model card on the fly. This assumes that no other fairseq2
+    # environment resolver overrides "mt_model", so the lookup below always returns this card
+    model_dir = Path(model_dir)
+    model_card_info = [
+        {
+            "name": "mt_model",
+            "model_type": "nllb",  # Re-use the same encoder-decoder arch of NLLB
+            "model_arch": "dense_600m",  # Dummy value to pass fairseq2 asset's validation logic
+            "checkpoint": "file://" + str(model_dir / "averaged_checkpoint.pt"),
+            "model_config": {
+                "model_dim": 512,
+                "num_encoder_layers": 4,
+                "num_decoder_layers": 2,
+                "ffn_inner_dim": 2048,
+                "vocab_info": {
+                    "size": 10000,
+                    "unk_idx": 3,
+                    "bos_idx": 0,
+                    "eos_idx": 2,
+                    "pad_idx": 1,
+                }
+            }
+        }
+    ]
+    asset_store.metadata_providers.append(
+        InProcAssetMetadataProvider(model_card_info)
+    )
+    mt_card = asset_store.retrieve_card("mt_model")
+
+    return ModelLoader[TransformerModel, NllbConfig](
+        asset_store,
+        download_manager,
+        load_nllb_config,
+        create_mt_model,
+        convert_mt_checkpoint,
+        restrict_checkpoints=False,
+    )(mt_card)
+
+
+def test_mt(
+    model: TransformerModel,
+    src_spm: spm.SentencePieceProcessor,
+    tgt_spm: spm.SentencePieceProcessor,
+):
+    from fairseq2.nn.padding import pad_seqs
+
+    # Tokens of "This is an example"
+    src_tokens = torch.LongTensor([688, 153, 62, 4581, 2])
+    src_seqs, src_padding_mask = pad_seqs(src_tokens, src_spm.pad_id())
+
+    # Force the decoder to begin with the EOS token <s>
+    prompt_tokens = torch.LongTensor([[2]])
+    generator = BeamSearchSeq2SeqGenerator(model)
+    output = generator(src_seqs, src_padding_mask, prompt_tokens, None)
+
+    print(output.hypotheses[0][0].seq)
+    tgt_tokens = output.hypotheses[0][0].seq.tolist()
+    out_text = tgt_spm.decode(tgt_tokens)
+
+    # assert out_text == "Este es un ejemplo"
+    print(out_text)

+ 3 - 0
ggml/requirements.txt

@@ -4,3 +4,6 @@ sentencepiece==0.1.98
 torch==2.0.1
 torchaudio==2.0.2
 torchvision==0.15.2
+transformers==4.29.2
+fairseq2==0.2.1
+func_argparse

+ 94 - 12
ggml/src/CMakeLists.txt

@@ -26,6 +26,15 @@ if (NOT UNAME_M)
 endif()
 #message(STATUS "UNAME_S: ${UNAME_S}  UNAME_P: ${UNAME_P}  UNAME_M: ${UNAME_M}")
 
+# this version of Apple ld64 is buggy
+execute_process(
+    COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
+    ERROR_VARIABLE output
+)
+if (output MATCHES "dyld-1015\.7")
+    add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
+endif()
+
 # Mac OS + Arm can report x86_64
 # ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
 if (UNAME_S MATCHES "Darwin")
@@ -162,7 +171,7 @@ if (GGML_OPENBLAS)
 
         set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS}  ${OPENBLAS_LIB})
         set(GGML_EXTRA_INCS  ${GGML_EXTRA_INCS}  ${OPENBLAS_INC})
-	set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
+        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
     else()
         message(WARNING "OpenBLAS not found")
     endif()
@@ -177,12 +186,12 @@ if (GGML_CLBLAST)
         )
 	find_path(CLBLAST_INC NAMES clblast.h PATHS ${CLBLAST_INCLUDE_SEARCH_PATHS})
 	find_library(CLBLAST_LIB NAMES clblast)
-	if (CLBLAST_LIB AND CLBLAST_INC)
+	find_library(OPENCL_LIB NAMES OpenCL)
+	if (CLBLAST_LIB AND OPENCL_LIB AND CLBLAST_INC)
 		message(STATUS "clBLAST found")
 
-
 		set(GGML_EXTRA_INCS  ${GGML_EXTRA_INCS}  ${CLBLAST_INC})
-		set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS}  ${CLBLAST_LIB})
+		set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS}  ${CLBLAST_LIB}  ${OPENCL_LIB})
 		set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CLBLAST)
 
 		set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h)
@@ -204,7 +213,17 @@ if (GGML_CUBLAS)
 
         set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
 
-        add_compile_definitions(GGML_USE_CUBLAS)
+        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUBLAS)
+
+        if (GGML_CUDA_FORCE_DMMV)
+            add_compile_definitions(GGML_CUDA_FORCE_DMMV)
+        endif()
+        if (GGML_CUDA_FORCE_MMQ)
+            add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+        endif()
+
+        # required for dynamic parallelism
+        # set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
 
         if (GGML_STATIC)
             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
@@ -212,11 +231,59 @@ if (GGML_CUBLAS)
             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
         endif()
 
+        if (CMAKE_BUILD_TYPE MATCHES Debug)
+            set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo")
+        endif()
     else()
         message(WARNING "cuBLAS not found")
     endif()
 endif()
 
+if (GGML_HIPBLAS)
+    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+
+    if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
+    endif()
+    if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+    endif()
+
+    find_package(hip)
+    find_package(hipblas)
+    find_package(rocblas)
+
+    if (${hipblas_FOUND} AND ${hip_FOUND})
+        message(STATUS "HIP and hipBLAS found")
+
+        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUBLAS)
+
+        add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
+        if (BUILD_SHARED_LIBS)
+            set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
+        endif()
+        if (GGML_CUDA_FORCE_DMMV)
+            target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV)
+        endif()
+        if (GGML_CUDA_FORCE_MMQ)
+            target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_MMQ)
+        endif()
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
+        target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
+        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+        target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
+        target_include_directories(ggml-rocm PRIVATE . ../include ../include/ggml)
+
+        if (GGML_STATIC)
+            message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+        endif()
+        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ggml-rocm)
+    else()
+        message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
+    endif()
+endif()
+
 if (GGML_METAL)
     find_library(FOUNDATION_LIBRARY         Foundation              REQUIRED)
     find_library(METAL_FRAMEWORK            Metal                   REQUIRED)
@@ -225,8 +292,9 @@ if (GGML_METAL)
 
     set(GGML_METAL_SOURCES ggml-metal.m ggml-metal.h)
 
-    add_compile_definitions(GGML_USE_METAL)
-    add_compile_definitions(GGML_METAL_NDEBUG)
+    set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_METAL)
+
+    #add_compile_definitions(GGML_METAL_NDEBUG)
 
     # get full path to the file
     #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")
@@ -249,8 +317,13 @@ endif()
 add_library(${TARGET}
     ggml.c
     ggml-alloc.c
+    ggml-backend.c
+    ggml-quants.c
+    ggml-impl.h
+    ggml-backend-impl.h
     ../include/ggml/ggml.h
     ../include/ggml/ggml-alloc.h
+    ../include/ggml/ggml-backend.h
     ${GGML_CUDA_SOURCES}
     ${GGML_OPENCL_SOURCES}
     ${GGML_METAL_SOURCES}
@@ -301,8 +374,16 @@ if (MINGW)
 endif()
 
 if (GGML_CUDA_SOURCES)
-    message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
-    set_property(TARGET ggml  PROPERTY CUDA_ARCHITECTURES "52;61")
+    message(STATUS "GGML CUDA sources found")
+    if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+        # Only configure ggml CUDA architectures if not globally set
+        if (NOT DEFINED GGML_CUDA_ARCHITECTURES)
+            # Not overridden by the user, so set defaults
+            set(GGML_CUDA_ARCHITECTURES 52 61 70)
+        endif()
+        message(STATUS "GGML Configuring CUDA architectures ${GGML_CUDA_ARCHITECTURES}")
+        set_property(TARGET ggml  PROPERTY CUDA_ARCHITECTURES ${GGML_CUDA_ARCHITECTURES})
+    endif()
     set_property(TARGET ggml  PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
     if (NOT MSVC)
         target_link_libraries(ggml PUBLIC stdc++)
@@ -311,12 +392,13 @@ endif()
 
 set (GGML_PUBLIC_HEADERS
      ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml.h
-     ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml-alloc.h)
+     ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml-alloc.h
+     ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml-backend.h)
+
 set_target_properties(${TARGET} PROPERTIES
                       PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
 
 install(TARGETS ${TARGET}
     LIBRARY DESTINATION lib
-    ARCHIVE DESTINATION lib/static
     PUBLIC_HEADER DESTINATION include/ggml
-    )
+    )

+ 453 - 284
ggml/src/ggml-alloc.c

@@ -1,69 +1,21 @@
 #include "ggml-alloc.h"
+#include "ggml-backend-impl.h"
 #include "ggml.h"
+#include "ggml-impl.h"
 #include <assert.h>
+#include <limits.h>
 #include <stdarg.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 
-#ifdef __has_include
-    #if __has_include(<unistd.h>)
-        #include <unistd.h>
-        #if defined(_POSIX_MAPPED_FILES)
-            #include <sys/types.h>
-            #include <sys/mman.h>
-        #endif
-    #endif
-#endif
-
-#if defined(_WIN32)
-    #define WIN32_LEAN_AND_MEAN
-    #ifndef NOMINMAX
-        #define NOMINMAX
-    #endif
-    #include <windows.h>
-    #include <memoryapi.h>
-#endif
-
-
-#define UNUSED(x) (void)(x)
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
-#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+#define MAX_FREE_BLOCKS 256
 
 //#define GGML_ALLOCATOR_DEBUG
 
-//#define AT_PRINTF printf
-#define AT_PRINTF(...) ((void)0)
-
-struct hash_node {
-    struct ggml_tensor * t;
-    int n_children;
-    int n_views;
-};
-
-static size_t hash(void * p) {
-    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
-
-static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
-    size_t h = hash(t);
-
-    // linear probing
-    size_t i = h;
-    while (hash_table[i].t != NULL) {
-        if (hash_table[i].t == t) {
-            return &hash_table[i];
-        }
-        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
-        if (i == h) {
-            // hash table is full
-            GGML_ASSERT(false);
-        }
-    }
-
-    hash_table[i].t = t;
-    return &hash_table[i];
-}
+//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
+#define AT_PRINTF(...)
 
 // TODO: GGML_PAD ?
 static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
@@ -77,19 +29,18 @@ struct free_block {
     size_t size;
 };
 
-#define MAX_FREE_BLOCKS 128
-
-struct ggml_allocr {
-    void * data;
-    size_t size;
+struct ggml_tallocr {
+    struct ggml_backend_buffer * buffer;
+    bool buffer_owned;
+    void * base;
     size_t alignment;
+
     int n_free_blocks;
     struct free_block free_blocks[MAX_FREE_BLOCKS];
-    struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
     size_t max_size;
+
     bool measure;
-    int parse_seq[GGML_MAX_CONCUR];
-    int parse_seq_len;
 
 #ifdef GGML_ALLOCATOR_DEBUG
     struct ggml_tensor * allocated_tensors[1024];
@@ -97,7 +48,7 @@ struct ggml_allocr {
 };
 
 #ifdef GGML_ALLOCATOR_DEBUG
-static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+static void add_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == NULL) {
            alloc->allocated_tensors[i] = tensor;
@@ -106,7 +57,7 @@ static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor
     }
     GGML_ASSERT(!"out of allocated_tensors");
 }
-static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == tensor ||
             (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
@@ -119,24 +70,20 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens
 }
 #endif
 
-static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
-    return ggml_nbytes(tensor);
-
-    UNUSED(alloc);
+// check if a tensor is allocated by this buffer
+static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
+    return tensor->buffer == alloc->buffer;
 }
 
-// check if a tensor is allocated by this buffer
-static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
-    void * ptr = tensor->data;
-    return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
+static bool ggml_is_view(struct ggml_tensor * t) {
+    return t->view_src != NULL;
 }
 
-void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
-#ifdef GGML_ALLOCATOR_DEBUG
+void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
     GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
     GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
-#endif
-    size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
+
+    size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
     size = aligned_offset(NULL, size, alloc->alignment);
 
     AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
@@ -183,10 +130,14 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
     }
 
     tensor->data = addr;
+    tensor->buffer = alloc->buffer;
+    if (!alloc->measure) {
+        ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
+    }
 
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
-    size_t cur_max = (char*)addr - (char*)alloc->data + size;
+    size_t cur_max = (char*)addr - (char*)alloc->base + size;
     if (cur_max > alloc->max_size) {
         printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {
@@ -198,23 +149,24 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
     }
 #endif
 
-    alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
+    alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size);
 }
 
 // this is a very naive implementation, but for our case the number of free blocks should be very small
-static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
-    void * ptr = tensor->data;
-
-    if (ggml_allocr_is_own(alloc, tensor) == false) {
+static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
+    if (ggml_tallocr_is_own(alloc, tensor) == false) {
         // the tensor was not allocated in this buffer
         // this can happen because the graph allocator will try to free weights and other tensors from different buffers
         // the easiest way to deal with this is just to ignore it
+        // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
         return;
     }
 
-    size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
+    void * ptr = tensor->data;
+
+    size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
     size = aligned_offset(NULL, size, alloc->alignment);
-    AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
+    AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
 
 #ifdef GGML_ALLOCATOR_DEBUG
     remove_allocated_tensor(alloc, tensor);
@@ -268,139 +220,179 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens
     alloc->n_free_blocks++;
 }
 
-void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
-    for (int i = 0; i < n; i++) {
-        alloc->parse_seq[i] = list[i];
+void ggml_tallocr_reset(ggml_tallocr_t alloc) {
+    alloc->n_free_blocks = 1;
+    size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment);
+    alloc->free_blocks[0].addr = (char *)alloc->base + align_offset;
+
+    if (alloc->measure) {
+        alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
+    } else {
+        alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
     }
-    alloc->parse_seq_len = n;
 }
 
-void ggml_allocr_reset(struct ggml_allocr * alloc) {
-    alloc->n_free_blocks = 1;
-    size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
-    alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
-    alloc->free_blocks[0].size = alloc->size - align_offset;
-}
+ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
+    struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
 
-struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
-    struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
+    ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
 
-    *alloc = (struct ggml_allocr){
-        /*.data          = */ data,
-        /*.size          = */ size,
+    *alloc = (struct ggml_tallocr) {
+        /*.buffer        = */ buffer,
+        /*.buffer_owned  = */ true,
+        /*.base          = */ ggml_backend_buffer_get_base(buffer),
         /*.alignment     = */ alignment,
         /*.n_free_blocks = */ 0,
         /*.free_blocks   = */ {{0}},
-        /*.hash_table    = */ {{0}},
         /*.max_size      = */ 0,
        /*.measure       = */ false,
-        /*.parse_seq     = */ {0},
-        /*.parse_seq_len = */ 0,
 #ifdef GGML_ALLOCATOR_DEBUG
         /*.allocated_tensors = */ {0},
 #endif
     };
 
-    ggml_allocr_reset(alloc);
+    ggml_tallocr_reset(alloc);
 
     return alloc;
 }
 
-// OS specific functions to allocate and free uncommitted virtual memory
-static void * alloc_vmem(size_t size) {
-#if defined(_WIN32)
-    return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS);
-#elif defined(_POSIX_MAPPED_FILES)
-    void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
-    if (ptr == MAP_FAILED) {
-        return NULL;
-    }
-    return ptr;
-#else
-    // use a fixed address for other platforms
-    uintptr_t base_addr = (uintptr_t)-size - 0x100;
-    return (void *)base_addr;
-#endif
-}
+ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment) {
+    ggml_tallocr_t alloc = ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment);
+    alloc->measure = true;
 
 
-static void free_vmem(void * base_addr, size_t size) {
-#if defined(_WIN32)
-    VirtualFree(base_addr, 0, MEM_RELEASE);
-    UNUSED(size);
-#elif defined(_POSIX_MAPPED_FILES)
-    munmap(base_addr, size);
-#else
-    // nothing to do
-    UNUSED(base_addr);
-    UNUSED(size);
-#endif
+    return alloc;
 }
 }
 
 
-// allocate uncommitted virtual memory to measure the size of the graph
-static void alloc_measure_vmem(void ** base_addr, size_t * size) {
-    // 1TB for 64-bit, 1GB for 32-bit
-    *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<40;
-    do {
-        *base_addr = alloc_vmem(*size);
-        if (*base_addr != NULL) {
-            AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr);
-            return;
-        }
-        // try again with half the size
-        *size /= 2;
-    } while (*size > 0);
+ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) {
+    // create a backend buffer to get the correct tensor allocation sizes
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1);
 
 
-    GGML_ASSERT(!"failed to allocate virtual memory for measure buffer");
+    // TODO: move alloc initialization to a common ggml_tallocr_new_impl function
+    ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
+    alloc->buffer_owned = true;
+    alloc->measure = true;
+    ggml_tallocr_reset(alloc);
+    return alloc;
 }
 }
 
 
-static void free_measure_vmem(void * base_addr, size_t size) {
-    free_vmem(base_addr, size);
+ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) {
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size);
+    ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
+    alloc->buffer_owned = true;
+    return alloc;
 }
 }
 
 
-struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
-    struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
+ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
+    ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
 
 
-    void * base_addr;
-    size_t size;
-
-    alloc_measure_vmem(&base_addr, &size);
-
-    *alloc = (struct ggml_allocr){
-        /*.data          = */ base_addr,
-        /*.size          = */ size,
-        /*.alignment     = */ alignment,
+    *alloc = (struct ggml_tallocr) {
+        /*.buffer        = */ buffer,
+        /*.buffer_owned  = */ false,
+        /*.base          = */ ggml_backend_buffer_get_base(buffer),
+        /*.alignment     = */ ggml_backend_buffer_get_alignment(buffer),
        /*.n_free_blocks = */ 0,
        /*.free_blocks   = */ {{0}},
-        /*.hash_table    = */ {{0}},
        /*.max_size      = */ 0,
-        /*.measure       = */ true,
-        /*.parse_seq     = */ {0},
-        /*.parse_seq_len = */ 0,
+        /*.measure       = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
        /*.allocated_tensors = */ {0},
#endif
    };

-    ggml_allocr_reset(alloc);
+    ggml_tallocr_reset(alloc);

    return alloc;
}

-void ggml_allocr_free(struct ggml_allocr * alloc) {
-    if (alloc->measure) {
-        free_measure_vmem(alloc->data, alloc->size);
+struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) {
+    return alloc->buffer;
+}
+
+void ggml_tallocr_free(ggml_tallocr_t alloc) {
+    if (alloc == NULL) {
+        return;
+    }
+
+    if (alloc->buffer_owned) {
+        ggml_backend_buffer_free(alloc->buffer);
    }
    free(alloc);
}

-bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
+bool ggml_tallocr_is_measure(ggml_tallocr_t alloc) {
    return alloc->measure;
}

-//////////// compute graph allocator
+size_t ggml_tallocr_max_size(ggml_tallocr_t alloc) {
+    return alloc->max_size;
+}
-static bool ggml_is_view(struct ggml_tensor * t) {
-    return t->view_src != NULL;
+// graph allocator
+
+struct hash_node {
+    int n_children;
+    int n_views;
+};
+
+struct ggml_gallocr {
+    ggml_tallocr_t talloc;
+    struct ggml_hash_set hash_set;
+    struct hash_node * hash_values;
+    size_t hash_values_size;
+    ggml_tallocr_t * hash_allocs;
+    int * parse_seq;
+    int parse_seq_len;
+};
+
+ggml_gallocr_t ggml_gallocr_new(void) {
+    ggml_gallocr_t galloc = (ggml_gallocr_t)malloc(sizeof(struct ggml_gallocr));
+
+    *galloc = (struct ggml_gallocr) {
+        /*.talloc           = */ NULL,
+        /*.hash_set         = */ {0},
+        /*.hash_values      = */ NULL,
+        /*.hash_values_size = */ 0,
+        /*.hash_allocs      = */ NULL,
+        /*.parse_seq        = */ NULL,
+        /*.parse_seq_len    = */ 0,
+    };
+
+    return galloc;
+}
+
+void ggml_gallocr_free(ggml_gallocr_t galloc) {
+    if (galloc == NULL) {
+        return;
+    }
+
+    if (galloc->hash_set.keys != NULL) {
+        free(galloc->hash_set.keys);
+    }
+    if (galloc->hash_values != NULL) {
+        free(galloc->hash_values);
+    }
+    if (galloc->hash_allocs != NULL) {
+        free(galloc->hash_allocs);
+    }
+    if (galloc->parse_seq != NULL) {
+        free(galloc->parse_seq);
+    }
+    free(galloc);
+}
+
+void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n) {
+    free(galloc->parse_seq);
+    galloc->parse_seq = malloc(sizeof(int) * n);
+
+    for (int i = 0; i < n; i++) {
+        galloc->parse_seq[i] = list[i];
+    }
+    galloc->parse_seq_len = n;
+}
+
+static struct hash_node * hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+    size_t i = ggml_hash_find_or_insert(galloc->hash_set, t);
+    return &galloc->hash_values[i];
}

static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
@@ -435,7 +427,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
        case GGML_OP_ROPE:
        case GGML_OP_RMS_NORM:
        case GGML_OP_SOFT_MAX:
-        case GGML_OP_CONT:
            return true;

        default:
@@ -443,12 +434,39 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
    }
}

-static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
-    struct hash_node * ht = alloc->hash_table;
+static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+    if (galloc->talloc != NULL) {
+        return galloc->talloc;
+    }
+
+    return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)];
+}
+
+static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
+    ggml_tallocr_t alloc = node_tallocr(galloc, view);
+
+    GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
+    if (update_backend) {
+        view->backend = view->view_src->backend;
+    }
+    view->buffer  = view->view_src->buffer;
+    view->data    = (char *)view->view_src->data + view->view_offs;
+
+    // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
+    // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
+    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
+
+    if (!alloc->measure) {
+        ggml_backend_buffer_init_tensor(alloc->buffer, view);
+    }
+}
+
+static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+    ggml_tallocr_t alloc = node_tallocr(galloc, node);
+
    if (node->data == NULL) {
        if (ggml_is_view(node)) {
-            assert(node->view_src->data != NULL);
-            node->data = (char *)node->view_src->data + node->view_offs;
+            init_view(galloc, node, true);
        } else {
            // see if we can reuse a parent's buffer (inplace)
            if (ggml_op_can_inplace(node->op)) {
@@ -459,16 +477,16 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
                    }

                    // if the node's data is external, then we cannot re-use it
-                    if (ggml_allocr_is_own(alloc, parent) == false) {
+                    if (ggml_tallocr_is_own(alloc, parent) == false) {
                        AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
                        continue;
                    }

-                    struct hash_node * p_hn = hash_get(ht, parent);
+                    struct hash_node * p_hn = hash_get(galloc, parent);
                    if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
                        if (ggml_is_view(parent)) {
                            struct ggml_tensor * view_src = parent->view_src;
-                            struct hash_node * view_src_hn = hash_get(ht, view_src);
+                            struct hash_node * view_src_hn = hash_get(galloc, view_src);
                            if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
                                // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
                                // the parent's data that it will need later (same layout requirement). the problem is that then
@@ -476,158 +494,309 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
                                // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
                                // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
                                AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
-                                node->data = parent->data;
+                                node->view_src = view_src;
+                                view_src_hn->n_views += 1;
+                                init_view(galloc, node, false);
                                return;
                            }
-                        }
-                        else {
+                        } else {
                            AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
-                            node->data = parent->data;
+                            node->view_src = parent;
+                            p_hn->n_views += 1;
+                            init_view(galloc, node, false);
                            return;
                        }
                    }
                }
            }
-            ggml_allocr_alloc(alloc, node);
+            ggml_tallocr_alloc(alloc, node);
        }
    }
}

-static size_t ggml_allocr_alloc_graph_tensors_n(
-    struct ggml_allocr * alloc,
-    struct ggml_cgraph ** graphs, int n_graphs,
-    struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
+static void free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+    ggml_tallocr_t alloc = node_tallocr(galloc, node);
-    // reset hash table
-    struct hash_node * ht = alloc->hash_table;
-    memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);
+    ggml_tallocr_free_tensor(alloc, node);
+}
+
+static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * gf) {
+    const int * parse_seq     = galloc->parse_seq;
+    int         parse_seq_len = galloc->parse_seq_len;

    // count number of children and views
-    for (int g = 0; g < n_graphs; g++) {
-        struct ggml_cgraph * gf = graphs[g];
-        for (int i = 0; i < gf->n_nodes; i++) {
-            struct ggml_tensor * node = gf->nodes[i];
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];

-            if (ggml_is_view(node)) {
-                struct ggml_tensor * view_src = node->view_src;
-                hash_get(ht, view_src)->n_views += 1;
+        if (ggml_is_view(node)) {
+            struct ggml_tensor * view_src = node->view_src;
+            hash_get(galloc, view_src)->n_views += 1;
+            if (node->buffer == NULL && node->data != NULL) {
+                // view of a pre-allocated tensor, didn't call init_view() yet
+                init_view(galloc, node, true);
            }
+        }

+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                break;
+            }
+            hash_get(galloc, parent)->n_children += 1;
+            if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
+                init_view(galloc, parent, true);
+            }
+        }
+    }
+
+    // allocate tensors
+    // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
+    int last_barrier_pos = 0;
+    int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes;
+
+    for (int ind = 0; ind < n_nodes; ind++) {
+        // allocate a node if there is no parse_seq or this is not a barrier
+        if (parse_seq_len == 0 || parse_seq[ind] != -1) {
+            int i = parse_seq_len ? parse_seq[ind] : ind;
+            struct ggml_tensor * node = gf->nodes[i];
+
+            // allocate parents (leaves)
            for (int j = 0; j < GGML_MAX_SRC; j++) {
                struct ggml_tensor * parent = node->src[j];
                if (parent == NULL) {
                    break;
                }
-                hash_get(ht, parent)->n_children += 1;
+                allocate_node(galloc, parent);
            }
-        }
-    }

-    // allocate tensors
-    for (int g = 0; g < n_graphs; g++) {
-        struct ggml_cgraph * gf = graphs[g];
-        AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
-        // graph inputs are allocated first to ensure that they are not overwritten by each other
-        if (inputs != NULL && inputs[g] != NULL) {
-            for (int i = 0; inputs[g][i] != NULL; i++) {
-                struct ggml_tensor * input = inputs[g][i];
-                AT_PRINTF("input: %s\n", input->name);
-                allocate_node(alloc, input);
+            // allocate node
+            allocate_node(galloc, node);
+
+            AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * parent = node->src[j];
+                if (parent == NULL) {
+                    break;
+                }
+                AT_PRINTF("%s", parent->name);
+                if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+                    AT_PRINTF(", ");
+                }
            }
+            AT_PRINTF("\n");
        }
-        // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
-        int last_barrier_pos = 0;
-        int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
-        for (int ind = 0; ind < n_nodes; ind++) {
-            // allocate a node if there is no parse_seq or this is not a barrier
-            if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
-                int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
-                struct ggml_tensor * node = gf->nodes[i];
+        // update parents
+        // update immediately if there is no parse_seq
+        // update only at barriers if there is parse_seq
+        if ((parse_seq_len == 0) || parse_seq[ind] == -1) {
+            int update_start = parse_seq_len ? last_barrier_pos : ind;
+            int update_end   = parse_seq_len ? ind              : ind + 1;
+            for (int i = update_start; i < update_end; i++) {
+                int node_i = parse_seq_len ? parse_seq[i] : i;
+                struct ggml_tensor * node = gf->nodes[node_i];

-                // allocate parents (leafs)
                for (int j = 0; j < GGML_MAX_SRC; j++) {
                    struct ggml_tensor * parent = node->src[j];
                    if (parent == NULL) {
                        break;
                    }
-                    allocate_node(alloc, parent);
-                }
+                    struct hash_node * p_hn = hash_get(galloc, parent);
+                    p_hn->n_children -= 1;

-                // allocate node
-                allocate_node(alloc, node);
+                    //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);

-                AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
-                for (int j = 0; j < GGML_MAX_SRC; j++) {
-                    struct ggml_tensor * parent = node->src[j];
-                    if (parent == NULL) {
-                        break;
-                    }
-                    AT_PRINTF("%s", parent->name);
-                    if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
-                        AT_PRINTF(", ");
-                    }
-                }
-                AT_PRINTF("\n");
-            }
-
-            // update parents
-            // update immediately if there is no parse_seq
-            // update only at barriers if there is parse_seq
-            if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
-                int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
-                int update_end   = alloc->parse_seq_len ? ind              : ind + 1;
-                for (int i = update_start; i < update_end; i++) {
-                    int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
-                    struct ggml_tensor * node = gf->nodes[node_i];
-
-                    for (int j = 0; j < GGML_MAX_SRC; j++) {
-                        struct ggml_tensor * parent = node->src[j];
-                        if (parent == NULL) {
-                            break;
-                        }
-                        struct hash_node * p_hn = hash_get(ht, parent);
-                        p_hn->n_children -= 1;
-
-                        //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
-
-                        if (p_hn->n_children == 0 && p_hn->n_views == 0) {
-                            if (ggml_is_view(parent)) {
-                                struct ggml_tensor * view_src = parent->view_src;
-                                struct hash_node * view_src_hn = hash_get(ht, view_src);
-                                view_src_hn->n_views -= 1;
-                                AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
-                                if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
-                                    ggml_allocr_free_tensor(alloc, view_src);
-                                }
-                            }
-                            else {
-                                if (parent->data != node->data) {
-                                    ggml_allocr_free_tensor(alloc, parent);
-                                }
+                    if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+                        if (ggml_is_view(parent)) {
+                            struct ggml_tensor * view_src = parent->view_src;
+                            struct hash_node * view_src_hn = hash_get(galloc, view_src);
+                            view_src_hn->n_views -= 1;
+                            AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+                            if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) {
+                                free_node(galloc, view_src);
                            }
                        }
+                        else {
+                            free_node(galloc, parent);
+                        }
                    }
                }
-                AT_PRINTF("\n");
-                if (alloc->parse_seq_len) {
-                    last_barrier_pos = ind + 1;
-                }
            }
-        }
-        // free graph outputs here that wouldn't be freed otherwise because they have no children
-        if (outputs != NULL && outputs[g] != NULL) {
-            for (int i = 0; outputs[g][i] != NULL; i++) {
-                struct ggml_tensor * output = outputs[g][i];
-                AT_PRINTF("output: %s\n", output->name);
-                ggml_allocr_free_tensor(alloc, output);
+            AT_PRINTF("\n");
+            if (parse_seq_len) {
+                last_barrier_pos = ind + 1;
            }
        }
    }
+}

-    return alloc->max_size;
+size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph) {
+    size_t hash_size = graph->visited_hash_table.size;
+
+    // check if the hash table is initialized and large enough
+    if (galloc->hash_set.size < hash_size) {
+        if (galloc->hash_set.keys != NULL) {
+            free(galloc->hash_set.keys);
+        }
+        if (galloc->hash_values != NULL) {
+            free(galloc->hash_values);
+        }
+        galloc->hash_set.keys = malloc(sizeof(struct ggml_tensor *) * hash_size);
+        galloc->hash_set.size = hash_size;
+        galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
+    }
+
+    // reset hash table
+    memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * hash_size);
+    memset(galloc->hash_values,   0, sizeof(struct hash_node) * hash_size);
+
+    galloc->talloc = talloc;
+    ggml_tallocr_alloc_graph_impl(galloc, graph);
+    galloc->talloc = NULL;
+
+    size_t max_size = ggml_tallocr_max_size(talloc);
+
+    return max_size;
}

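// [editor's note] Illustrative sketch only, not part of this change: typical use of the
// graph allocator above pairs it with a tensor allocator for a single backend, e.g.
//     ggml_tallocr_t talloc = ggml_tallocr_new_from_backend(backend, buf_size);
//     ggml_gallocr_t galloc = ggml_gallocr_new();
//     size_t max_size = ggml_gallocr_alloc_graph(galloc, talloc, graph);
//     ggml_gallocr_free(galloc);
//     ggml_tallocr_free(talloc); // buffer_owned == true, so the backend buffer is freed too
// `backend`, `buf_size` and `graph` are assumed to exist; error handling is omitted.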
-size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
-    return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
+void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_talloc) {
+    const size_t hash_size = hash_set.size;
+
+    GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
+
+    galloc->talloc = NULL;
+
+    // alloc hash_values if needed
+    if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) {
+        free(galloc->hash_values);
+        galloc->hash_values      = malloc(sizeof(struct hash_node) * hash_size);
+        galloc->hash_values_size = hash_size;
+    }
+
+    // free hash_set.keys if needed
+    if (galloc->hash_set.keys != NULL) {
+        free(galloc->hash_set.keys);
+    }
+    galloc->hash_set = hash_set;
+
+    // reset hash values
+    memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
+
+    galloc->hash_allocs = hash_node_talloc;
+
+    ggml_tallocr_alloc_graph_impl(galloc, graph);
+
+    // remove unowned resources
+    galloc->hash_set.keys = NULL;
+    galloc->hash_allocs = NULL;
+}
+
+// legacy API wrapper
+
+struct ggml_allocr {
+    ggml_tallocr_t talloc;
+    ggml_gallocr_t galloc;
+};
+
+static ggml_allocr_t ggml_allocr_new_impl(ggml_tallocr_t talloc) {
+    ggml_allocr_t alloc = (ggml_allocr_t)malloc(sizeof(struct ggml_allocr));
+    *alloc = (struct ggml_allocr) {
+        /*.talloc = */ talloc,
+        /*.galloc = */ ggml_gallocr_new(),
+    };
+    return alloc;
+}
+
+ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment) {
+    return ggml_allocr_new_impl(ggml_tallocr_new(data, size, alignment));
+}
+
+ggml_allocr_t ggml_allocr_new_measure(size_t alignment) {
+    return ggml_allocr_new_impl(ggml_tallocr_new_measure(alignment));
+}
+
+ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
+    return ggml_allocr_new_impl(ggml_tallocr_new_from_buffer(buffer));
+}
+
+ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size) {
+    return ggml_allocr_new_impl(ggml_tallocr_new_from_backend(backend, size));
+}
+
+ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend) {
+    return ggml_allocr_new_impl(ggml_tallocr_new_measure_from_backend(backend));
+}
+
+struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc) {
+    return ggml_tallocr_get_buffer(alloc->talloc);
+}
+
+void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
+    ggml_gallocr_set_parse_seq(alloc->galloc, list, n);
+}
+
+void ggml_allocr_free(ggml_allocr_t alloc) {
+    ggml_gallocr_free(alloc->galloc);
+    ggml_tallocr_free(alloc->talloc);
+    free(alloc);
+}
+
+bool ggml_allocr_is_measure(ggml_allocr_t alloc) {
+    return ggml_tallocr_is_measure(alloc->talloc);
+}
+
+void ggml_allocr_reset(ggml_allocr_t alloc) {
+    ggml_tallocr_reset(alloc->talloc);
+}
+
+void ggml_allocr_alloc(ggml_allocr_t alloc, struct ggml_tensor * tensor) {
+    ggml_tallocr_alloc(alloc->talloc, tensor);
+}
+
+size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
+    return ggml_tallocr_max_size(alloc->talloc);
+}
+
+size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
+    return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
+}
+
+// utils
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
+
+    size_t alignment = ggml_backend_buft_get_alignment(buft);
+
+    size_t nbytes = 0;
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL && t->view_src == NULL) {
+            nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
+        }
+    }
+
+    if (nbytes == 0) {
+        fprintf(stderr, "%s: no tensors to allocate\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
+    ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
+
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL) {
+            if (t->view_src == NULL) {
+                ggml_tallocr_alloc(tallocr, t);
+            } else {
+                ggml_backend_view_init(buffer, t);
+            }
+        }
+    }
+
+    ggml_tallocr_free(tallocr);
+
+    return buffer;
}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
+    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
+}

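Editor's note: the sketch below is illustrative only and not part of the diff. It shows the measure-then-allocate pattern that the legacy `ggml_allocr` wrapper above is meant for; `build_graph()` is an assumed user-provided helper that rebuilds the compute graph on every call, which is required because the measure pass writes placement addresses into the tensors.

```
// hypothetical usage of the legacy ggml_allocr wrapper (assumptions noted above)
#include "ggml-alloc.h"
#include "ggml-backend.h"

extern struct ggml_cgraph * build_graph(ggml_allocr_t alloc); // assumed user helper

static ggml_allocr_t create_graph_allocr(ggml_backend_t backend) {
    // pass 1: the measure allocator computes the worst-case buffer size without real memory
    ggml_allocr_t measure = ggml_allocr_new_measure_from_backend(backend);
    size_t mem_size = ggml_allocr_alloc_graph(measure, build_graph(measure));
    ggml_allocr_free(measure);

    // pass 2: allocate a real backend buffer of that size and place a fresh graph in it
    ggml_allocr_t alloc = ggml_allocr_new_from_backend(backend, mem_size);
    ggml_allocr_alloc_graph(alloc, build_graph(alloc));
    return alloc; // caller releases it with ggml_allocr_free()
}
```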
+ 112 - 0
ggml/src/ggml-backend-impl.h

@@ -0,0 +1,112 @@
+#pragma once
+
+// ggml-backend internal header
+
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    //
+    // Backend buffer
+    //
+
+    // buffer type
+    typedef void * ggml_backend_buffer_type_context_t;
+
+    struct ggml_backend_buffer_type_i {
+        ggml_backend_buffer_t (*alloc_buffer)    (ggml_backend_buffer_type_t buft, size_t size);
+        size_t                (*get_alignment)   (ggml_backend_buffer_type_t buft); // tensor alignment
+        size_t                (*get_alloc_size)  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
+        bool                  (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
+    };
+
+    struct ggml_backend_buffer_type {
+        struct ggml_backend_buffer_type_i  iface;
+        ggml_backend_buffer_type_context_t context;
+    };
+
+    // buffer
+    typedef void * ggml_backend_buffer_context_t;
+
+    struct ggml_backend_buffer_i {
+        void     (*free_buffer)(ggml_backend_buffer_t buffer);
+        //void     (*reset)      (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        void *   (*get_base)   (ggml_backend_buffer_t buffer);
+        void     (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void     (*set_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void     (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        // (optional) copy tensor between different buffer types, allowing single-copy transfers
+        void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to)  (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
+    };
+
+    struct ggml_backend_buffer {
+        struct ggml_backend_buffer_i  iface;
+        ggml_backend_buffer_type_t    buft;
+        ggml_backend_buffer_context_t context;
+        size_t size;
+    };
+
+    ggml_backend_buffer_t ggml_backend_buffer_init(
+                   ggml_backend_buffer_type_t      buft,
+            struct ggml_backend_buffer_i           iface,
+                   ggml_backend_buffer_context_t   context,
+                   size_t                          size);
+
+
+    //
+    // Backend
+    //
+
+    typedef void * ggml_backend_context_t;
+
+    struct ggml_backend_i {
+        const char * (*get_name)(ggml_backend_t backend);
+
+        void (*free)(ggml_backend_t backend);
+
+        // buffer allocation
+        ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
+
+        // (optional) asynchronous tensor data access
+        void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+        // (optional) asynchronous tensor copy
+        void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to_async)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+        void (*synchronize)     (ggml_backend_t backend);
+
+        // compute graph with a plan
+        ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+        void                      (*graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+        void                      (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+        // compute graph without a plan
+        void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+        // check if the backend supports an operation
+        bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+    };
+
+    struct ggml_backend {
+        struct ggml_backend_i iface;
+
+        ggml_backend_context_t context;
+    };
+
+
+    //
+    // Backend registry
+    //
+
+    typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
+
+    void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
+
+#ifdef  __cplusplus
+}
+#endif

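Editor's note: another illustrative sketch, not part of the diff. It drives the interface declared above through the CPU backend implemented in `ggml-backend.c` (next file); `ctx` is assumed to be a ggml context created with `no_alloc = true`, and `gf` a graph built from tensors in that context.

```
// hypothetical end-to-end use of the backend API (assumptions noted above)
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

void run_on_cpu(struct ggml_context * ctx, struct ggml_cgraph * gf,
                struct ggml_tensor * input, const float * data, float * out,
                struct ggml_tensor * result) {
    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_cpu_set_n_threads(backend, 4);

    // place every tensor of the context into a single backend buffer
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    ggml_backend_tensor_set(input, data, 0, ggml_nbytes(input));   // upload
    ggml_backend_graph_compute(backend, gf);                       // run
    ggml_backend_tensor_get(result, out, 0, ggml_nbytes(result));  // download

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
}
```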
+ 1357 - 0
ggml/src/ggml-backend.c

@@ -0,0 +1,1357 @@
+#include "ggml-backend-impl.h"
+#include "ggml-alloc.h"
+#include "ggml-impl.h"
+
+#include <assert.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+
+// backend buffer type
+
+ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    return buft->iface.alloc_buffer(buft, size);
+}
+
+size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_alignment(buft);
+}
+
+size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+    // get_alloc_size is optional, defaults to ggml_nbytes
+    if (buft->iface.get_alloc_size) {
+        return buft->iface.get_alloc_size(buft, tensor);
+    }
+    return ggml_nbytes(tensor);
+}
+
+bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return buft->iface.supports_backend(buft, backend);
+}
+
+// backend buffer
+
+ggml_backend_buffer_t ggml_backend_buffer_init(
+               ggml_backend_buffer_type_t      buft,
+        struct ggml_backend_buffer_i           iface,
+               ggml_backend_buffer_context_t   context,
+               size_t                          size) {
+    ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer));
+
+    GGML_ASSERT(iface.get_base != NULL);
+
+    (*buffer) = (struct ggml_backend_buffer) {
+        /* .interface = */ iface,
+        /* .buft      = */ buft,
+        /* .context   = */ context,
+        /* .size      = */ size,
+    };
+
+    return buffer;
+}
+
+void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
+    if (buffer == NULL) {
+        return;
+    }
+
+    if (buffer->iface.free_buffer != NULL) {
+        buffer->iface.free_buffer(buffer);
+    }
+    free(buffer);
+}
+
+size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
+    return buffer->size;
+}
+
+void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+    void * base = buffer->iface.get_base(buffer);
+
+    GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
+
+    return base;
+}
+
+void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    // init_tensor is optional
+    if (buffer->iface.init_tensor) {
+        buffer->iface.init_tensor(buffer, tensor);
+    }
+}
+
+size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_get_alignment(ggml_backend_buffer_type(buffer));
+}
+
+size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type(buffer), tensor);
+}
+
+ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) {
+    return buffer->buft;
+}
+
+// backend
+
+const char * ggml_backend_name(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return "NULL";
+    }
+    return backend->iface.get_name(backend);
+}
+
+void ggml_backend_free(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return;
+    }
+
+    backend->iface.free(backend);
+}
+
+ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+    return backend->iface.get_default_buffer_type(backend);
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
+    return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);
+}
+
+size_t ggml_backend_get_alignment(ggml_backend_t backend) {
+    return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));
+}
+
+void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    backend->iface.set_tensor_async(backend, tensor, data, offset, size);
+}
+
+void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+    backend->iface.get_tensor_async(backend, tensor, data, offset, size);
+}
+
+void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size);
+}
+
+void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+    tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size);
+}
+
+void ggml_backend_synchronize(ggml_backend_t backend) {
+    if (backend->iface.synchronize == NULL) {
+        return;
+    }
+
+    backend->iface.synchronize(backend);
+}
+
+ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    return backend->iface.graph_plan_create(backend, cgraph);
+}
+
+void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    backend->iface.graph_plan_free(backend, plan);
+}
+
+void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    backend->iface.graph_plan_compute(backend, plan);
+
+    // TODO: optional sync
+    ggml_backend_synchronize(backend);
+}
+
+void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    backend->iface.graph_compute(backend, cgraph);
+
+    // TODO: optional sync
+    ggml_backend_synchronize(backend);
+}
+
+bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    return backend->iface.supports_op(backend, op);
+}
+
+// backend copy
+
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
+    //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]);
+    //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]);
+    GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
+
+    // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
+
+    if (src == dst) {
+        return;
+    }
+
+    // TODO: allow backends to support copy to/from same backend
+
+    if (dst->buffer->iface.cpy_tensor_from != NULL) {
+        dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst);
+    } else if (src->buffer->iface.cpy_tensor_to != NULL) {
+        src->buffer->iface.cpy_tensor_to(src->buffer, src, dst);
+    } else {
+        // shouldn't be hit when copying from/to CPU
+        #ifndef NDEBUG
+        fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to "
+                        "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name);
+        #endif
+        size_t nbytes = ggml_nbytes(src);
+        void * data = malloc(nbytes);
+        ggml_backend_tensor_get(src, data, 0, nbytes);
+        ggml_backend_tensor_set(dst, data, 0, nbytes);
+        free(data);
+    }
+}
+
+// backend registry
+
+#define GGML_MAX_BACKENDS_REG 16
+
+struct ggml_backend_reg {
+    char name[128];
+    ggml_backend_init_fn init_fn;
+    ggml_backend_buffer_type_t default_buffer_type;
+    void * user_data;
+};
+
+static struct ggml_backend_reg ggml_backend_registry[GGML_MAX_BACKENDS_REG];
+static size_t ggml_backend_registry_count = 0;
+
+static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data);
+
+static void ggml_backend_registry_init(void) {
+    static bool initialized = false;
+
+    if (initialized) {
+        return;
+    }
+
+    initialized = true;
+
+    ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL);
+
+    // add forward decls here to avoid including the backend headers
+#ifdef GGML_USE_CUBLAS
+    extern void ggml_backend_cuda_reg_devices(void);
+    ggml_backend_cuda_reg_devices();
+#endif
+
+#ifdef GGML_USE_METAL
+    extern ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data);
+    extern ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+    ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL);
+#endif
+}
+
+void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
+    GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
+
+    int id = ggml_backend_registry_count;
+
+    ggml_backend_registry[id] = (struct ggml_backend_reg) {
+        /* .name                = */ {0},
+        /* .fn                  = */ init_fn,
+        /* .default_buffer_type = */ default_buffer_type,
+        /* .user_data           = */ user_data,
+    };
+
+    snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name);
+
+#ifndef NDEBUG
+    fprintf(stderr, "%s: registered backend %s\n", __func__, name);
+#endif
+
+    ggml_backend_registry_count++;
+}
+
+size_t ggml_backend_reg_get_count(void) {
+    ggml_backend_registry_init();
+
+    return ggml_backend_registry_count;
+}
+
+size_t ggml_backend_reg_find_by_name(const char * name) {
+    ggml_backend_registry_init();
+
+    for (size_t i = 0; i < ggml_backend_registry_count; i++) {
+        // TODO: case insensitive in a portable way
+        if (strcmp(ggml_backend_registry[i].name, name) == 0) {
+            return i;
+        }
+    }
+    return SIZE_MAX;
+}
+
+// init from backend:params string
+ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) {
+    ggml_backend_registry_init();
+
+    const char * params = strchr(backend_str, ':');
+    char backend_name[128];
+    if (params == NULL) {
+        strcpy(backend_name, backend_str);
+        params = "";
+    } else {
+        strncpy(backend_name, backend_str, params - backend_str);
+        backend_name[params - backend_str] = '\0';
+        params++;
+    }
+
+    size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
+    if (backend_i == SIZE_MAX) {
+        fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
+        return NULL;
+    }
+
+    return ggml_backend_reg_init_backend(backend_i, params);
+}
+
+const char * ggml_backend_reg_get_name(size_t i) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].name;
+}
+
+ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data);
+}
+
+ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].default_buffer_type;
+}
+
+ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size);
+}
+
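// [editor's note] Illustrative only, not part of the change: the registry above lets a
// caller pick a backend by name at runtime, e.g.
//     size_t i = ggml_backend_reg_find_by_name("CPU");
//     ggml_backend_t backend = (i == SIZE_MAX) ? NULL : ggml_backend_reg_init_backend(i, "");
// or in one step via ggml_backend_reg_init_backend_from_str("CPU").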
+// backend CPU
+
+static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+    return (void *)buffer->context;
+}
+
+static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    free(buffer->context);
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
+    GGML_UNUSED(buffer);
+}
+
+static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
+    /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_cpu_buffer_cpy_tensor_to,
+};
+
+// for buffers from ptr, free is not called
+static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
+    /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_cpu_buffer_cpy_tensor_to,
+};
+
+static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
+
+static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    size += TENSOR_ALIGNMENT;   // malloc may return an address that is not aligned
+    void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
+
+    GGML_ASSERT(data != NULL && "failed to allocate buffer");
+
+    return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
+}
+
+static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_cpu(backend);
+
+    GGML_UNUSED(buft);
+}
+
+ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cpu = {
+        /* .iface = */ {
+            /* .alloc_buffer     = */ ggml_backend_cpu_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend,
+        },
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_buffer_type_cpu;
+}
+
+struct ggml_backend_cpu_context {
+    int n_threads;
+    void * work_data;
+    size_t work_size;
+};
+
+static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
+    return "CPU";
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_cpu_free(ggml_backend_t backend) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+    free(cpu_ctx->work_data);
+    free(cpu_ctx);
+    free(backend);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(backend);
+}
+
+struct ggml_backend_plan_cpu {
+    struct ggml_cplan cplan;
+    struct ggml_cgraph cgraph;
+};
+
+static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu));
+
+    cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
+    cpu_plan->cgraph = *cgraph;
+
+    if (cpu_plan->cplan.work_size > 0) {
+        cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
+    }
+
+    return cpu_plan;
+}
+
+static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    free(cpu_plan->cplan.work_data);
+    free(cpu_plan);
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
+
+    if (cpu_ctx->work_size < cplan.work_size) {
+        // TODO: may be faster to free and use malloc to avoid the copy
+        cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
+        cpu_ctx->work_size = cplan.work_size;
+    }
+
+    cplan.work_data = cpu_ctx->work_data;
+
+    ggml_graph_compute(cgraph, &cplan);
+}
+
+static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    return true;
+
+    GGML_UNUSED(backend);
+    GGML_UNUSED(op);
+}
+
+static struct ggml_backend_i cpu_backend_i = {
+    /* .get_name                = */ ggml_backend_cpu_name,
+    /* .free                    = */ ggml_backend_cpu_free,
+    /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
+    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
+    /* .supports_op             = */ ggml_backend_cpu_supports_op,
+};
+
+ggml_backend_t ggml_backend_cpu_init(void) {
+    struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
+
+    ctx->n_threads = GGML_DEFAULT_N_THREADS;
+    ctx->work_data = NULL;
+    ctx->work_size = 0;
+
+    ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
+
+    *cpu_backend = (struct ggml_backend) {
+        /* .interface = */ cpu_backend_i,
+        /* .context   = */ ctx
+    };
+    return cpu_backend;
+}
+
+bool ggml_backend_is_cpu(ggml_backend_t backend) {
+    return backend->iface.get_name == ggml_backend_cpu_name;
+}
+
+void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->n_threads = n_threads;
+}
+
+ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
+    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
+}
+
+static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) {
+    return ggml_backend_cpu_init();
+
+    GGML_UNUSED(params);
+    GGML_UNUSED(user_data);
+}
+
+
+// scheduler
+
+#define GGML_MAX_BACKENDS 4
+#define GGML_MAX_SPLITS 256
+#define GGML_MAX_SPLIT_INPUTS 16
+
+struct ggml_backend_sched_split {
+    ggml_tallocr_t tallocr;
+    int i_start;
+    int i_end;
+    struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
+    int n_inputs;
+    struct ggml_cgraph graph;
+};
+
+struct ggml_backend_sched {
+    int n_backends;
+    ggml_backend_t backends[GGML_MAX_BACKENDS];
+    ggml_tallocr_t  tallocs[GGML_MAX_BACKENDS];
+
+    ggml_gallocr_t galloc;
+
+    struct ggml_hash_set    hash_set;
+    ggml_tallocr_t *        node_talloc;                     // [hash_set.size]
+    struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS]
+
+    struct ggml_cgraph * graph;
+    struct ggml_backend_sched_split splits[GGML_MAX_SPLITS];
+    int n_splits;
+
+    struct ggml_context * ctx;
+
+    // align context_buffer to GGML_MEM_ALIGN
+    #ifdef _MSC_VER
+    __declspec(align(GGML_MEM_ALIGN))
+    #else
+    __attribute__((aligned(GGML_MEM_ALIGN)))
+    #endif
+    char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+};
+
+#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
+#define node_allocr(node) sched->node_talloc[hash_id(node)]
+
+static bool ggml_is_view_op(enum ggml_op op) {
+    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
+
+// returns the priority of the backend, lower is better
+static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->backends[i] == backend) {
+            return i;
+        }
+    }
+    return INT_MAX;
+}
+
+static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->tallocs[i] == allocr) {
+            return i;
+        }
+    }
+    return INT_MAX;
+}
+
+static ggml_backend_t get_buffer_backend(ggml_backend_sched_t sched, ggml_backend_buffer_t buffer) {
+    if (buffer == NULL) {
+        return NULL;
+    }
+    // find highest prio backend that supports the buffer type
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
+            return sched->backends[i];
+        }
+    }
+    GGML_ASSERT(false && "tensor buffer type not supported by any backend");
+}
+
+static ggml_backend_t get_allocr_backend(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
+    if (allocr == NULL) {
+        return NULL;
+    }
+    // find highest prio backend that supports the buffer type
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->tallocs[i] == allocr) {
+            return sched->backends[i];
+        }
+    }
+    GGML_UNREACHABLE();
+}
+
+#if 0
+static char causes[GGML_DEFAULT_GRAPH_SIZE*8 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
+#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
+#define GET_CAUSE(node) causes[hash_id(node)]
+#else
+#define SET_CAUSE(node, ...)
+#define GET_CAUSE(node) ""
+#endif
+
+// returns the backend that should be used for the node based on the current locations
+static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+    // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
+    // (e.g. KV cache updates)
+    // note that this doesn't allow falling back to the CPU; output tensors would have to be added to the splits so the data can be copied back to the original backend.
+    // dst
+    ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer);
+    if (cur_backend != NULL) {
+        SET_CAUSE(node, "1.dst");
+        return cur_backend;
+    }
+
+    // view_src
+    if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) {
+        SET_CAUSE(node, "1.vsrc");
+        return get_buffer_backend(sched, node->view_src->buffer);
+    }
+
+    // src
+    int cur_prio = INT_MAX;
+    size_t cur_size = 0;
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        const struct ggml_tensor * src = node->src[i];
+        if (src == NULL) {
+            break;
+        }
+        ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
+        if (src_backend != NULL) {
+            int src_prio = sched_backend_prio(sched, src_backend);
+            size_t src_size = ggml_nbytes(src);
+            if (src_prio < cur_prio && src_size >= cur_size) {
+                cur_prio = src_prio;
+                cur_size = src_size;
+                cur_backend = src_backend;
+                SET_CAUSE(node, "1.src%d", i);
+            }
+        }
+    }
+    return cur_backend;
+}
+
+static char * fmt_size(size_t size) {
+    static char buffer[128];
+    if (size >= 1024*1024) {
+        sprintf(buffer, "%zuM", size/1024/1024);
+    } else {
+        sprintf(buffer, "%zuK", size/1024);
+    }
+    return buffer;
+}
+
+static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    int cur_split = 0;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
+            ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr);
+            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
+                sched->splits[cur_split].n_inputs);
+            for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
+                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
+                    fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
+            }
+            fprintf(stderr, "\n");
+            cur_split++;
+        }
+        struct ggml_tensor * node = graph->nodes[i];
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
+        ggml_tallocr_t node_allocr = node_allocr(node);
+        ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
+        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name,
+            fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                break;
+            }
+            ggml_tallocr_t src_allocr = node_allocr(src);
+            ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
+            fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name,
+                fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
+        }
+        fprintf(stderr, "\n");
+    }
+}
+
+// creates a copy of the tensor with the same memory layout
+static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
+    struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        dup->nb[i] = tensor->nb[i];
+    }
+    return dup;
+}
+
+// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
+// TODO: merge passes
+static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    // reset state
+    size_t hash_size = sched->hash_set.size;
+    memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
+    memset(sched->node_talloc,   0, sizeof(sched->node_talloc[0])   * hash_size);
+    memset(sched->node_copies,   0, sizeof(sched->node_copies[0])   * hash_size);
+    sched->n_splits = 0;
+
+    struct ggml_init_params params = {
+        /* .mem_size =   */ sizeof(sched->context_buffer),
+        /* .mem_buffer = */ sched->context_buffer,
+        /* .no_alloc =   */ true
+    };
+
+    if (sched->ctx != NULL) {
+        ggml_free(sched->ctx);
+    }
+
+    sched->ctx = ggml_init(params);
+
+    // pass 1: assign backends to ops with allocated inputs
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        if (node_allocr(leaf) != NULL) {
+            // do not overwrite user assignments
+            continue;
+        }
+        ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer);
+        if (leaf_backend == NULL && leaf->view_src != NULL) {
+            leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer);
+        }
+        if (leaf_backend != NULL) {
+            node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
+        }
+    }
+
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        if (node_allocr(node) != NULL) {
+            // do not overwrite user assignments
+            continue;
+        }
+        ggml_backend_t node_backend = sched_backend_from_cur(sched, node);
+        if (node_backend != NULL) {
+            node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend);
+        }
+    }
+    //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+
+    // pass 2: assign backends to ops from current assignments
+    // TODO:
+    //  - reuse sched_backend_from_cur
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        ggml_tallocr_t node_allocr = node_allocr(node);
+        if (node_allocr == NULL) {
+            int    cur_prio = INT_MAX;
+            size_t cur_size = 0;
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * src = node->src[j];
+                if (src == NULL) {
+                    break;
+                }
+                ggml_tallocr_t src_allocr = node_allocr(src);
+                if (src_allocr != NULL) {
+                    int    src_prio = sched_allocr_prio(sched, src_allocr);
+                    size_t src_size = ggml_nbytes(src);
+                    if (src_prio < cur_prio && src_size >= cur_size) {
+                        cur_prio = src_prio;
+                        cur_size = src_size;
+                        node_allocr = src_allocr;
+                        SET_CAUSE(node, "2.src%d", j);
+                    }
+                }
+            }
+            if (node_allocr != NULL) {
+                node_allocr(node) = node_allocr;
+            }
+        }
+    }
+    //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+
+    // pass 3: assign backends to remaining src from dst (should only be leafs)
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        ggml_tallocr_t node_allocr = node_allocr(node);
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                break;
+            }
+            ggml_tallocr_t src_allocr = node_allocr(src);
+            if (src_allocr == NULL) {
+                node_allocr(src) = node_allocr;
+            }
+        }
+    }
+    //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+
+    // pass 4: split graph, find tensors that need to be copied
+    // TODO:
+    //  - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost
+    // find first backend
+    int cur_split = 0;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        if (node->view_src == NULL) {
+            sched->splits[0].tallocr = node_allocr(node);
+            break;
+        }
+    }
+    sched->splits[0].i_start = 0;
+    sched->splits[0].n_inputs = 0;
+    memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
+    ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
+    size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
+
+        ggml_tallocr_t node_allocr = node_allocr(node);
+
+        if (node_allocr != cur_allocr) {
+            sched->splits[cur_split].i_end = i;
+            cur_split++;
+            GGML_ASSERT(cur_split < GGML_MAX_SPLITS);
+            sched->splits[cur_split].tallocr = node_allocr;
+            sched->splits[cur_split].i_start = i;
+            sched->splits[cur_split].n_inputs = 0;
+            memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK
+            cur_allocr = node_allocr;
+            cur_backend_id = sched_allocr_prio(sched, cur_allocr);
+        }
+
+        // find inputs that are not on the same backend
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                break;
+            }
+            ggml_tallocr_t src_allocr = node_allocr(src);
+            if (src_allocr != node_allocr) {
+                int n_inputs = sched->splits[cur_split].n_inputs++;
+                GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
+                sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src;
+
+                // create copies
+                size_t id = hash_id(src);
+                if (sched->node_copies[id][cur_backend_id] == NULL) {
+                    struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+                    sched->node_copies[id][cur_backend_id] = tensor_copy;
+                    node_allocr(tensor_copy) = cur_allocr;
+                    ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
+                    ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
+                }
+                node->src[j] = sched->node_copies[id][cur_backend_id];
+            }
+        }
+    }
+    sched->splits[cur_split].i_end = graph->n_nodes;
+    sched->n_splits = cur_split + 1;
+
+    //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout);
+
+#if 1
+    // sanity check: all sources should have the same backend as the node
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        ggml_tallocr_t node_allocr = node_allocr(node);
+        if (node_allocr == NULL) {
+            fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
+        }
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                break;
+            }
+            ggml_tallocr_t src_allocr = node_allocr(src);
+            if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
+                fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
+                    node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
+                    j, src->name, src_allocr ? ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
+            }
+        }
+    }
+#endif
+
+    // create copies of the graph for each split
+    // FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
+    for (int i = 0; i < sched->n_splits; i++) {
+        struct ggml_backend_sched_split * split = &sched->splits[i];
+        split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
+
+        // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
+        for (int j = 0; j < split->n_inputs; j++) {
+            struct ggml_tensor * input = split->inputs[j];
+            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
+            input_cpy->src[0] = input;
+            graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
+        }
+
+        for (int j = split->i_start; j < split->i_end; j++) {
+            graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
+        }
+    }
+    sched->graph = graph_copy;
+}
+
+static void sched_alloc_splits(ggml_backend_sched_t sched) {
+    ggml_gallocr_alloc_graph_n(
+        sched->galloc,
+        sched->graph,
+        sched->hash_set,
+        sched->node_talloc);
+}
+
+static void sched_compute_splits(ggml_backend_sched_t sched) {
+    uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
+    uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
+
+    struct ggml_backend_sched_split * splits = sched->splits;
+
+    for (int i = 0; i < sched->n_splits; i++) {
+        struct ggml_backend_sched_split * split = &splits[i];
+        ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr);
+        int split_backend_id = sched_backend_prio(sched, split_backend);
+
+        // copy the input tensors to the split backend
+        uint64_t copy_start_us = ggml_time_us();
+        for (int j = 0; j < split->n_inputs; j++) {
+            struct ggml_tensor * input = split->inputs[j];
+            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)];
+            if (input->buffer == NULL) {
+                if (input->view_src == NULL) {
+                    fprintf(stderr, "input %s has no buffer and no view_src\n", input->name);
+                    exit(1);
+                }
+                // FIXME: may need to use the sched buffer instead
+                ggml_backend_view_init(input->view_src->buffer, input);
+            }
+            if (input_cpy->buffer == NULL) {
+                fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
+                exit(1);
+            }
+            //GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend);
+            //GGML_ASSERT(input_cpy->buffer->backend == split_backend);
+            ggml_backend_tensor_copy(input, input_cpy);
+        }
+        // ggml_backend_synchronize(split_backend);
+        int64_t copy_end_us = ggml_time_us();
+        copy_us[split_backend_id] += copy_end_us - copy_start_us;
+
+#if 0
+        char split_filename[GGML_MAX_NAME];
+        snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend));
+        ggml_graph_dump_dot(split->graph, NULL, split_filename);
+#endif
+
+        uint64_t compute_start_us = ggml_time_us();
+        ggml_backend_graph_compute(split_backend, &split->graph);
+        // ggml_backend_synchronize(split_backend);
+        uint64_t compute_end_us = ggml_time_us();
+        compute_us[split_backend_id] += compute_end_us - compute_start_us;
+    }
+
+#if 0
+    // per-backend timings
+    fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits);
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (copy_us[i] > 0 || compute_us[i] > 0) {
+            fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]);
+        }
+    }
+#endif
+}
+
+static void sched_reset(ggml_backend_sched_t sched) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        ggml_tallocr_reset(sched->tallocs[i]);
+    }
+}
+
+ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) {
+    GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS);
+
+    struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
+    memset(sched, 0, sizeof(struct ggml_backend_sched));
+
+    sched->n_backends = n_backends;
+    for (int i = 0; i < n_backends; i++) {
+        sched->backends[i] = backends[i];
+    }
+
+    sched->galloc = ggml_gallocr_new();
+
+    // init measure allocs for each backend
+    for (int i = 0; i < n_backends; i++) {
+        sched->tallocs[i] = ggml_tallocr_new_measure_from_backend(backends[i]);
+    }
+
+    return sched;
+}
+
+void ggml_backend_sched_free(ggml_backend_sched_t sched) {
+    if (sched == NULL) {
+        return;
+    }
+    for (int i = 0; i < sched->n_backends; i++) {
+        ggml_tallocr_free(sched->tallocs[i]);
+    }
+    ggml_gallocr_free(sched->galloc);
+    free(sched->hash_set.keys);
+    free(sched->node_talloc);
+    free(sched->node_copies);
+    free(sched);
+}
+
+void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+    // initialize hash tables
+    size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS;
+    sched->hash_set.size = hash_size;
+    sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size);
+    sched->node_talloc   = malloc(sizeof(sched->node_talloc[0])   * hash_size);
+    sched->node_copies   = malloc(sizeof(sched->node_copies[0])   * hash_size);
+
+    sched_split_graph(sched, measure_graph);
+    sched_alloc_splits(sched);
+
+    // allocate buffers and reset allocators
+    for (int i = 0; i < sched->n_backends; i++) {
+        size_t size = ggml_tallocr_max_size(sched->tallocs[i]);
+        ggml_tallocr_free(sched->tallocs[i]);
+        sched->tallocs[i] = ggml_tallocr_new_from_backend(sched->backends[i], size);
+    }
+
+    sched_reset(sched);
+}
+
+void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
+
+    sched_split_graph(sched, graph);
+    sched_alloc_splits(sched);
+    sched_compute_splits(sched);
+    sched_reset(sched);
+}
+
+ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    int backend_index = sched_backend_prio(sched, backend);
+    return sched->tallocs[backend_index];
+}
+
+ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    int backend_index = sched_backend_prio(sched, backend);
+    return ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
+}
+
+void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+    int backend_index = sched_backend_prio(sched, backend);
+    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+    node_allocr(node) = sched->tallocs[backend_index];
+}
+
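
As a quick orientation, the scheduler API above can be exercised with the hedged sketch below. Only the `ggml_backend_sched_*` calls and `ggml_backend_cpu_init()` come from this diff; `gpu_backend` and `build_graph()` are placeholders for user code.

```c
// Hypothetical call sequence; gpu_backend and build_graph() are placeholders
// for user code, only the ggml_backend_sched_* calls come from this file.
static void sched_usage_sketch(ggml_backend_t gpu_backend) {
    ggml_backend_t backends[2] = { gpu_backend, ggml_backend_cpu_init() };
    ggml_backend_sched_t sched = ggml_backend_sched_new(backends, 2);

    // one-time measurement pass: sizes the per-backend buffers from a worst-case graph
    ggml_backend_sched_init_measure(sched, build_graph(/* ... */));

    // per iteration: split the graph across the backends, allocate and compute it
    ggml_backend_sched_graph_compute(sched, build_graph(/* ... */));

    ggml_backend_sched_free(sched);
}
```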
+// utils
+void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src != NULL);
+    GGML_ASSERT(tensor->view_src->buffer != NULL);
+    GGML_ASSERT(tensor->view_src->data != NULL);
+
+    tensor->buffer = buffer;
+    tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
+    tensor->backend = tensor->view_src->backend;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src == NULL);
+    GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
+    GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+                (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
+
+    tensor->buffer = buffer;
+    tensor->data = addr;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+static struct ggml_tensor * graph_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
+    struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
+
+    GGML_ASSERT(src != NULL);
+    GGML_ASSERT(src->data && "graph must be allocated");
+
+    size_t id = ggml_hash_insert(hash_set, src);
+    if (id == GGML_HASHTABLE_ALREADY_EXISTS) {
+        return node_copies[ggml_hash_find(hash_set, src)];
+    }
+
+    struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
+    if (src->view_src != NULL) {
+        dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
+        dst->view_offs = src->view_offs;
+    }
+    dst->op = src->op;
+    memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
+    ggml_set_name(dst, src->name);
+
+    // copy src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            break;
+        }
+        dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
+    }
+
+    node_copies[id] = dst;
+    return dst;
+}
+
+static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
+    size_t id = ggml_hash_find(hash_set, src);
+    if (node_init[id]) {
+        return;
+    }
+    node_init[id] = true;
+
+    struct ggml_tensor * dst = node_copies[id];
+    if (dst->view_src != NULL) {
+        ggml_backend_view_init(dst->view_src->buffer, dst);
+    }
+    else {
+        ggml_backend_tensor_copy(src, dst);
+    }
+
+    // init src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            break;
+        }
+        graph_init_tensor(hash_set, node_copies, node_init, s);
+    }
+}
+
+struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+    struct ggml_hash_set hash_set = {
+        /* .size = */ graph->visited_hash_table.size,
+        /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1)
+    };
+    struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1);
+    bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1);
+
+    struct ggml_init_params params = {
+        /* .mem_size   = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ true
+    };
+
+    struct ggml_context * ctx_allocated = ggml_init(params);
+    struct ggml_context * ctx_unallocated = ggml_init(params);
+
+    // dup nodes
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
+    }
+
+    // allocate nodes
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
+
+    //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
+
+    // copy data and init views
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_init_tensor(hash_set, node_copies, node_init, node);
+    }
+
+    // build graph copy
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)];
+        graph_copy->nodes[i] = node_copy;
+    }
+    graph_copy->n_nodes = graph->n_nodes;
+
+    free(hash_set.keys);
+    free(node_copies);
+    free(node_init);
+
+    return (struct ggml_backend_graph_copy) {
+        /* .buffer           = */ buffer,
+        /* .ctx_allocated    = */ ctx_allocated,
+        /* .ctx_unallocated  = */ ctx_unallocated,
+        /* .graph            = */ graph_copy,
+    };
+}
+
+void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
+    ggml_backend_buffer_free(copy.buffer);
+    ggml_free(copy.ctx_allocated);
+    ggml_free(copy.ctx_unallocated);
+}
+
+void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
+    struct ggml_cgraph * g1 = graph;
+    struct ggml_cgraph * g2 = copy.graph;
+
+    assert(g1->n_nodes == g2->n_nodes);
+
+    for (int i = 0; i < g1->n_nodes; i++) {
+        //printf("eval %d/%d\n", i, g1->n_nodes);
+        struct ggml_tensor * t1 = g1->nodes[i];
+        struct ggml_tensor * t2 = g2->nodes[i];
+
+        assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+
+        struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+        struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+
+        ggml_backend_graph_compute(backend1, &g1v);
+        ggml_backend_graph_compute(backend2, &g2v);
+
+        if (ggml_is_view_op(t1->op)) {
+            continue;
+        }
+
+        // compare results, calculate rms etc
+        if (!callback(i, t1, t2, user_data)) {
+            break;
+        }
+    }
+
+    ggml_backend_graph_copy_free(copy);
+}
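
The comparison helper above drives a user-supplied `ggml_backend_eval_callback`. The sketch below shows the shape such a callback could take; only the signature and the return-false-to-stop behaviour follow from the call site above, the body is an assumption.

```c
// Illustrative callback: invoked once per non-view node with the tensor as
// computed by each backend; returning false stops the comparison early.
static bool eval_cb(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
    (void) user_data;
    (void) t2; // a real callback would read both tensors back and compare them, e.g. by RMS error
    fprintf(stderr, "checked node %d (%s)\n", node_index, t1->name);
    return true;
}
```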

+ 243 - 0
ggml/src/ggml-impl.h

@@ -0,0 +1,243 @@
+#pragma once
+
+#include "ggml.h"
+
+// GGML internal header
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdbool.h>
+#include <string.h> // memcpy
+#include <math.h>   // fabsf
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// static_assert is normally provided as a #define; if it is not,
+// fall back to the C11 _Static_assert keyword.
+// under plain C99 there is no equivalent, so static_assert becomes a no-op.
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+
+// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __FMA__
+#define __FMA__
+#endif
+#ifndef __F16C__
+#define __F16C__
+#endif
+#ifndef __SSE3__
+#define __SSE3__
+#endif
+#endif
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#if defined(__ARM_NEON) && !defined(_MSC_VER)
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
+#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
+
+#define GGML_FP16_TO_FP32(x) ((float) (x))
+#define GGML_FP32_TO_FP16(x) (x)
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#ifdef __F16C__
+
+#ifdef _MSC_VER
+#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
+#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
+#else
+#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+#endif
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+/* the inline asm below is about 12% faster than the lookup method */
+#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    register float f;
+    register double d;
+    __asm__(
+        "mtfprd %0,%2\n"
+        "xscvhpdp %0,%0\n"
+        "frsp %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=f"(f):
+        /* in */   "r"(h));
+    return f;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+    register double d;
+    register ggml_fp16_t r;
+    __asm__( /* xscvdphp can work on double or single precision */
+        "xscvdphp %0,%2\n"
+        "mffprd %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=r"(r):
+        /* in */   "f"(f));
+    return r;
+}
+
+#else
+
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32;
+    fp32.as_bits = w;
+    return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
+}
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
+
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float exp_scale = 0x1.0p-112f;
+#else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+    return fp32_from_bits(result);
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float scale_to_inf = 0x1.0p+112f;
+    const float scale_to_zero = 0x1.0p-110f;
+#else
+    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+    const uint32_t w = fp32_to_bits(f);
+    const uint32_t shl1_w = w + w;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+    if (bias < UINT32_C(0x71000000)) {
+        bias = UINT32_C(0x71000000);
+    }
+
+    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+    const uint32_t bits = fp32_to_bits(base);
+    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+    const uint32_t nonsign = exp_bits + mantissa_bits;
+    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+#endif // __F16C__
+
+#endif // __ARM_NEON
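
Whichever branch above is compiled, the `GGML_COMPUTE_*` macros convert one value at a time. A minimal round-trip sketch follows; the numeric comment is a reminder of fp16's 11 significand bits, not something asserted by this header.

```c
// fp16 keeps 11 significand bits, so a value such as 0.1f rounds slightly.
static float fp16_round_trip_example(void) {
    ggml_fp16_t h = GGML_COMPUTE_FP32_TO_FP16(0.1f);
    return GGML_COMPUTE_FP16_TO_FP32(h); // ~0.0999755859375f
}
```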
+
+// precomputed f32 table for f16 (256 KB)
+// defined in ggml.c, initialized in ggml_init()
+extern float ggml_table_f32_f16[1 << 16];
+
+// On ARM NEON it is faster to convert directly than to go through the ggml_lookup_fp16_to_fp32 table,
+// so GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 are defined elsewhere for NEON.
+// The same holds for POWER9.
+#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
+
+inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
+    uint16_t s;
+    memcpy(&s, &f, sizeof(uint16_t));
+    return ggml_table_f32_f16[s];
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+#endif
+
+#define GGML_HASHTABLE_FULL ((size_t)-1)
+#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
+
+bool   ggml_hash_contains      (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
+size_t ggml_hash_find          (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
+size_t ggml_hash_insert        (      struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// return index, asserts if table is full
+size_t ggml_hash_find_or_insert(      struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+#ifdef __cplusplus
+}
+#endif
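
The hash-set helpers declared above back the scheduler's node bookkeeping (the `hash_id`/`node_allocr` macros in ggml-backend.c). A hedged usage sketch, with the table size and the parallel value array chosen purely for illustration:

```c
#include "ggml-impl.h"
#include <stdlib.h>

// Map each tensor to a slot in a parallel array, the same pattern the
// scheduler uses for node_talloc / node_copies. Sizes here are arbitrary.
void tag_tensor(struct ggml_tensor * t) {
    struct ggml_hash_set set = {
        /* .size = */ 64,
        /* .keys = */ calloc(64, sizeof(struct ggml_tensor *)),
    };
    int * values = calloc(64, sizeof(int));

    size_t slot = ggml_hash_find_or_insert(set, t); // asserts if the table is full
    values[slot] = 1;

    if (ggml_hash_contains(set, t)) {
        // the key is present; ggml_hash_find(set, t) would return the same slot
    }

    free(set.keys);
    free(values);
}
```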

+ 7382 - 0
ggml/src/ggml-quants.c

@@ -0,0 +1,7382 @@
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+
+#ifdef __ARM_NEON
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+    // Get absolute values of x vectors
+    const __m128i ax = _mm_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m128i sy = _mm_sign_epi8(y, x);
+    // Perform multiplication and create 16-bit values
+    const __m128i dot = _mm_maddubs_epi16(ax, sy);
+    const __m128i ones = _mm_set1_epi16(1);
+    return _mm_madd_epi16(ones, dot);
+}
+
+#if __AVX__ || __AVX2__ || __AVX512F__
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
+    return _mm256_and_si256(lowMask, bytes);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+#if __AVXVNNI__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
+static inline __m128i packNibbles( __m256i bytes )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
+#else
+    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+    __m256i high = _mm256_andnot_si256( lowByte, bytes );
+    __m256i low = _mm256_and_si256( lowByte, bytes );
+    high = _mm256_srli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+
+    // Compress uint16_t lanes into bytes
+    __m128i r0 = _mm256_castsi256_si128( bytes );
+    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+    return _mm_packus_epi16( r0, r1 );
+#endif
+}
+#elif defined(__AVX__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytesl = _mm_or_si128(bytesl, bit_mask);
+    bytesh = _mm_or_si128(bytesh, bit_mask);
+    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+    return MM256_SET_M128I(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return MM256_SET_M128I(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m128i lowByte = _mm_set1_epi16( 0xFF );
+    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+    __m128i low = _mm_and_si128( lowByte, bytes1 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes1 = _mm_or_si128( low, high );
+    high = _mm_andnot_si128( lowByte, bytes2 );
+    low = _mm_and_si128( lowByte, bytes2 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes2 = _mm_or_si128( low, high );
+
+    return _mm_packus_epi16( bytes1, bytes2);
+}
+#endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 =_mm_hadd_ps(a, b);
+    __m128 res_1 =_mm_hadd_ps(c, d);
+    __m128 res =_mm_hadd_ps(res_0, res_1);
+    res =_mm_hadd_ps(res, res);
+    res =_mm_hadd_ps(res, res);
+
+    return _mm_cvtss_f32(res);
+}
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+
+#if defined(__ARM_NEON)
+#if !defined(__aarch64__)
+
+// 64-bit compatibility
+
+// vaddvq_s16
+// vpaddq_s16
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline static float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+    int32x4_t res;
+
+    res[0] = roundf(vgetq_lane_f32(v, 0));
+    res[1] = roundf(vgetq_lane_f32(v, 1));
+    res[2] = roundf(vgetq_lane_f32(v, 2));
+    res[3] = roundf(vgetq_lane_f32(v, 3));
+
+    return res;
+}
+
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vld1q_s16(ptr + 0);
+    res.val[1] = vld1q_s16(ptr + 8);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+    res.val[2] = vld1q_u8(ptr + 32);
+    res.val[3] = vld1q_u8(ptr + 48);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+    int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+    ggml_int8x16x2_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+    res.val[2] = vld1q_s8(ptr + 32);
+    res.val[3] = vld1q_s8(ptr + 48);
+
+    return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t  int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t  int8x16x2_t
+#define ggml_int8x16x4_t  int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2  vld1q_u8_x2
+#define ggml_vld1q_u8_x4  vld1q_u8_x4
+#define ggml_vld1q_s8_x2  vld1q_s8_x2
+#define ggml_vld1q_s8_x4  vld1q_s8_x4
+
+#endif
+#endif
+
+#if defined(__ARM_NEON) || defined(__wasm_simd128__)
+#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
+
+// precomputed tables for expanding 8bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
+#endif
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
+    static const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max  = v;
+            }
+        }
+
+        const float d  = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = x[i*qk + 0    + j]*id;
+            const float x1 = x[i*qk + qk/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            y[i].qs[j]  = xi0;
+            y[i].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q4_0_reference(x, y, k);
+}
+
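
The q4_0 reference above stores a shared fp16 scale d = max / -8 and packs two 4-bit codes per byte (first half of the block in the low nibbles, second half in the high nibbles). Below is a hedged sketch of the inverse mapping, x ≈ d * (q - 8); the helper name is invented for illustration.

```c
// Decode one q4_0 block back to floats; out must hold QK4_0 values.
static void dequant_q4_0_sketch(const block_q4_0 * restrict b, float * restrict out) {
    const float d = GGML_FP16_TO_FP32(b->d);
    for (int j = 0; j < QK4_0/2; ++j) {
        out[j]           = d * (float)((b->qs[j] & 0x0F) - 8); // low nibble  -> first half
        out[j + QK4_0/2] = d * (float)((b->qs[j] >>   4) - 8); // high nibble -> second half
    }
}
```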
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
+    const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d  = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = (x[i*qk + 0    + j] - min)*id;
+            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            y[i].qs[j]  = xi0;
+            y[i].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q4_1_reference(x, y, k);
+}
+
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+    static const int qk = QK5_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max  = v;
+            }
+        }
+
+        const float d  = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = x[i*qk + 0    + j]*id;
+            const float x1 = x[i*qk + qk/2 + j]*id;
+
+            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(qh));
+    }
+}
+
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q5_0_reference(x, y, k);
+}
+
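
For q5_0 the low four bits of each code live in qs as above, while the fifth bit of element j (respectively j + qk/2) is stored as bit j (respectively j + qk/2) of qh. A hedged decode sketch mirroring that packing, x ≈ d * (q - 16); the helper name is invented for illustration.

```c
// Decode one q5_0 block; out must hold QK5_0 values.
static void dequant_q5_0_sketch(const block_q5_0 * restrict b, float * restrict out) {
    const float d = GGML_FP16_TO_FP32(b->d);
    uint32_t qh;
    memcpy(&qh, &b->qh, sizeof(qh));
    for (int j = 0; j < QK5_0/2; ++j) {
        const uint8_t xh0 = ((qh >> (j + 0))       << 4) & 0x10; // 5th bit, first half
        const uint8_t xh1 = ((qh >> (j + QK5_0/2)) << 4) & 0x10; // 5th bit, second half
        out[j]           = d * (float)(((b->qs[j] & 0x0F) | xh0) - 16);
        out[j + QK5_0/2] = d * (float)(((b->qs[j] >>   4) | xh1) - 16);
    }
}
```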
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+    const int qk = QK5_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d  = (max - min) / ((1 << 5) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = (x[i*qk + 0    + j] - min)*id;
+            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
+    }
+}
+
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q5_1_reference(x, y, k);
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = x[i*QK8_0 + j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = x[i*QK8_0 + j]*id;
+
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+}
+
+void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since AVX lacks some of the required instructions,
+        // we split the registers in half and use the SSE equivalents
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
+// reference implementation for deterministic creation of model files
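+// (Q8_1 differs from Q8_0 in also storing s = d * sum(qs); the dot product
+// kernels for the offset formats q4_1/q5_1 use it to fold in the per-block
+// min without touching the individual quants)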
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+    assert(QK8_1 == 32);
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_1; j++) {
+            const float v = x[i*QK8_1 + j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        int sum = 0;
+
+        for (int j = 0; j < QK8_1/2; ++j) {
+            const float v0 = x[i*QK8_1           + j]*id;
+            const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
+
+            y[i].qs[          j] = roundf(v0);
+            y[i].qs[QK8_1/2 + j] = roundf(v1);
+
+            sum += y[i].qs[          j];
+            sum += y[i].qs[QK8_1/2 + j];
+        }
+
+        y[i].s = sum*d;
+    }
+}
+
+void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        int32x4_t accv = vdupq_n_s32(0);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
+        }
+
+        y[i].s = d * vaddvq_s32(accv);
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        v128_t accv = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+
+            accv = wasm_i32x4_add(accv, vi);
+        }
+
+        y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
+                      wasm_i32x4_extract_lane(accv, 1) +
+                      wasm_i32x4_extract_lane(accv, 2) +
+                      wasm_i32x4_extract_lane(accv, 3));
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since AVX lacks some of the required instructions,
+        // we split the registers in half and use the SSE equivalents
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+        y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d  = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+        // compute sum for y[i].s
+        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+        // set y[i].s
+        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+        y[i].s = sum*d;
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_1_reference(x, y, k);
+#endif
+}
+
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F) - 8;
+            const int x1 = (x[i].qs[j] >>   4) - 8;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
+    static const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F);
+            const int x1 = (x[i].qs[j] >>   4);
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK5_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
+    static const int qk = QK5_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
+            const int x1 = (x[i].qs[j] >>   4) | xh_1;
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK8_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int j = 0; j < qk; ++j) {
+            y[i*qk + j] = x[i].qs[j]*d;
+        }
+    }
+}
+
+//
+// 2-6 bit quantization in super-blocks
+//
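+// Each "K" format below packs weights in super-blocks of QK_K elements
+// (256 by default, 64 when GGML_QKK_64 is defined). Sub-blocks of 16 or 32
+// weights carry their own low-precision scale (and min, for the asymmetric
+// formats), which is in turn multiplied by one fp16 factor per super-block.
+//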
+
+//
+// ===================== Helper functions
+//
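+
+// round-to-nearest via the float representation: adding 2^23 + 2^22
+// (12582912.0f) leaves the rounded integer in the low mantissa bits, biased
+// by 2^22 (0x00400000), for any |fval| < 2^22;
+// e.g. nearest_int(3.7f): 3.7 + 12582912 rounds to 12582916 when stored,
+// whose low mantissa bits are 0x400004, and 0x400004 - 0x400000 = 4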
+static inline int nearest_int(float fval) {
+    assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
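+// fit a single scale so that x[i] ~= scale * (L[i] - nmax), with L[i] in
+// [0, 2*nmax): rmse_type == 0 does plain rounding; otherwise a least-squares
+// fit (weighted by x^2 for odd rmse_type) is re-tried for 18 perturbed
+// scales and the best candidate is kept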
+static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (amax < 1e-30f) { // all zero
+        for (int i = 0; i < n; ++i) {
+            L[i] = 0;
+        }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (rmse_type == 0) {
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+        }
+        return 1/iscale;
+    }
+    bool return_early = false;
+    if (rmse_type < 0) {
+        rmse_type = -rmse_type;
+        return_early = true;
+    }
+    int weight_type = rmse_type%2;
+    float sumlx = 0;
+    float suml2 = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+        float w = weight_type == 1 ? x[i] * x[i] : 1;
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    float scale = sumlx/suml2;
+    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
+    float best = scale * sumlx;
+    for (int is = -9; is <= 9; ++is) {
+        if (is == 0) {
+            continue;
+        }
+        iscale = -(nmax + 0.1f*is) / max;
+        sumlx = suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            float w = weight_type == 1 ? x[i] * x[i] : 1;
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        if (suml2 > 0 && sumlx*sumlx > best*suml2) {
+            for (int i = 0; i < n; ++i) {
+                int l = nearest_int(iscale * x[i]);
+                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+            }
+            scale = sumlx/suml2; best = scale*sumlx;
+        }
+    }
+    return scale;
+}
+
+static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (!amax) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (do_rmse) {
+        float sumlx = 0;
+        float suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            L[i] = l;
+            float w = x[i]*x[i];
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        for (int itry = 0; itry < 5; ++itry) {
+            int n_changed = 0;
+            for (int i = 0; i < n; ++i) {
+                float w = x[i]*x[i];
+                float slx = sumlx - w*x[i]*L[i];
+                if (slx > 0) {
+                    float sl2 = suml2 - w*L[i]*L[i];
+                    int new_l = nearest_int(x[i] * sl2 / slx);
+                    new_l = MAX(-nmax, MIN(nmax-1, new_l));
+                    if (new_l != L[i]) {
+                        slx += w*x[i]*new_l;
+                        sl2 += w*new_l*new_l;
+                        if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
+                            L[i] = new_l; sumlx = slx; suml2 = sl2;
+                            ++n_changed;
+                        }
+                    }
+                }
+            }
+            if (!n_changed) {
+                break;
+            }
+        }
+        for (int i = 0; i < n; ++i) {
+            L[i] += nmax;
+        }
+        return sumlx / suml2;
+    }
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+    }
+    return 1/iscale;
+}
+
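+// fit scale and min for asymmetric quantization x ~= scale*L + min, with
+// L in [0, nmax]: alternates re-rounding L and re-fitting (scale, min),
+// damping the min update by alpha, for up to ntry iterations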
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+        int ntry, float alpha) {
+    float min = x[0];
+    float max = x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+    }
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = 0;
+        return 0.f;
+    }
+    if (min > 0) min = 0;
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    for (int itry = 0; itry < ntry; ++itry) {
+        float sumlx = 0; int suml2 = 0;
+        bool did_change = false;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            if (l != L[i]) {
+                L[i] = l;
+                did_change = true;
+            }
+            sumlx += (x[i] - min)*l;
+            suml2 += l*l;
+        }
+        scale = sumlx/suml2;
+        float sum = 0;
+        for (int i = 0; i < n; ++i) {
+            sum += x[i] - scale*L[i];
+        }
+        min = alpha*min + (1 - alpha)*sum/n;
+        if (min > 0) min = 0;
+        iscale = 1/scale;
+        if (!did_change) break;
+    }
+    *the_min = -min;
+    return scale;
+}
+
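+// like make_qkx1_quants, but with per-element weights: for each candidate
+// scale on the grid set by rmin/rdelta/nstep it solves the weighted
+// least-squares normal equations for (scale, min) and keeps the pair with
+// the smallest weighted error (absolute or squared, per use_mad)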
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights[0];
+    float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 1; i < n; ++i) {
+#else
+    for (int i = 1; i < n; ++i) {
+#endif
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) min = 0;
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff * diff;
+        float w = weights[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights[i];
+            sum_l += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff * diff;
+                float w = weights[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
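+// q4_K / q5_K pack 8 six-bit scales and 8 six-bit mins into 12 bytes:
+// scales/mins 0-3 occupy the low 6 bits of bytes 0-3 / 4-7, while entries
+// 4-7 take their low 4 bits from the nibbles of bytes 8-11 and their high
+// 2 bits from the top bits of bytes 0-7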
+#if QK_K == 256
+static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+    if (j < 4) {
+        *d = q[j] & 63; *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+#endif
+
+//========================= 2-bit (de)-quantization
+
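+// (q2_K: 2-bit quants, with a 4-bit scale and a 4-bit min per 16-element
+// sub-block packed into one byte each, both scaled by the super-block fp16
+// factors d and dmin)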
+void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[16];
+    float   weights[16];
+    float mins[QK_K/16];
+    float scales[QK_K/16];
+
+    const float q4scale = 15.f;
+
+    for (int i = 0; i < nb; i++) {
+        float max_scale = 0; // as we are subtracting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
+            scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        if (max_scale > 0) {
+            float iscale = q4scale/max_scale;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(iscale*scales[j]);
+                y[i].scales[j] = l;
+            }
+            y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale);
+        } else {
+            for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+        }
+        if (max_min > 0) {
+            float iscale = q4scale/max_min;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(iscale*mins[j]);
+                y[i].scales[j] |= (l << 4);
+            }
+            y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale);
+        } else {
+            y[i].dmin = GGML_FP32_TO_FP16(0.f);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF);
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4);
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int((x[16*j + ii] + dm)/d);
+                l = MAX(0, MIN(3, l));
+                L[16*j + ii] = l;
+            }
+        }
+
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                uint8_t sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+#else
+        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+        for (int l = 0; l < 16; ++l) {
+            y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
+            y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
+            y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
+            y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
+        }
+        y += QK_K;
+#endif
+    }
+}
+
+void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
+    quantize_row_q2_K_reference(x, vy, k);
+}
+
+size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
+        quantize_row_q2_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q2_K));
+}
+
+//========================= 3-bit (de)-quantization
+
+void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    int8_t L[QK_K];
+    float scales[QK_K / 16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0;
+        float amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
+            float scale = fabsf(scales[j]);
+            if (scale > amax) {
+                amax = scale; max_scale = scales[j];
+            }
+        }
+
+#if QK_K == 256
+        memset(y[i].scales, 0, 12);
+        if (max_scale) {
+            float iscale = -32.f/max_scale;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int8_t l = nearest_int(iscale*scales[j]);
+                l = MAX(-32, MIN(31, l)) + 32;
+                if (j < 8) {
+                    y[i].scales[j] = l & 0xF;
+                } else {
+                    y[i].scales[j-8] |= ((l & 0xF) << 4);
+                }
+                l >>= 4;
+                y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
+            }
+            y[i].d = GGML_FP32_TO_FP16(1/iscale);
+        } else {
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+        }
+
+        int8_t sc;
+        for (int j = 0; j < QK_K/16; ++j) {
+            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
+            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
+            float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+#else
+        if (max_scale) {
+            float iscale = -8.f/max_scale;
+            for (int j = 0; j < QK_K/16; j+=2) {
+                int l1 = nearest_int(iscale*scales[j]);
+                l1 = 8 + MAX(-8, MIN(7, l1));
+                int l2 = nearest_int(iscale*scales[j+1]);
+                l2 = 8 + MAX(-8, MIN(7, l2));
+                y[i].scales[j/2] = l1 | (l2 << 4);
+            }
+            y[i].d = GGML_FP32_TO_FP16(1/iscale);
+        } else {
+            for (int j = 0; j < QK_K/16; j+=2) {
+                y[i].scales[j/2] = 0;
+            }
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
+            float d = GGML_FP16_TO_FP32(y[i].d) * (s - 8);
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+#endif
+
+        memset(y[i].hmask, 0, QK_K/8);
+        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
+        int m = 0;
+        uint8_t hm = 1;
+        for (int j = 0; j < QK_K; ++j) {
+            if (L[j] > 3) {
+                y[i].hmask[m] |= hm;
+                L[j] -= 4;
+            }
+            if (++m == QK_K/8) {
+                m = 0; hm <<= 1;
+            }
+        }
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
+
+        x += QK_K;
+    }
+}
+
+#if QK_K == 256
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    uint32_t aux[4];
+    const int8_t * scales = (const int8_t*)aux;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        uint8_t m = 1;
+
+        memcpy(aux, x[i].scales, 12);
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+
+    }
+}
+#else
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    assert(QK_K == 64);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l=0; l<8; ++l) {
+            uint8_t h = hm[l];
+            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        }
+        y += QK_K;
+    }
+}
+#endif
+
+void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
+    quantize_row_q3_K_reference(x, vy, k);
+}
+
+size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
+        quantize_row_q3_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q3_K));
+}
+
+// ====================== 4-bit (de)-quantization
+
+void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[32];
+    float   weights[32];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0; // as we are subtracting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+#if QK_K == 256
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
+        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+            }
+        }
+#else
+        const float s_factor = 15.f;
+        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? s_factor/max_min   : 0.f;
+        int d1 = nearest_int(inv_scale*scales[0]);
+        int m1 = nearest_int(inv_min*mins[0]);
+        int d2 = nearest_int(inv_scale*scales[1]);
+        int m2 = nearest_int(inv_min*mins[1]);
+        y[i].scales[0] = d1 | (m1 << 4);
+        y[i].scales[1] = d2 | (m2 << 4);
+        y[i].d[0] = GGML_FP32_TO_FP16(max_scale/s_factor);
+        y[i].d[1] = GGML_FP32_TO_FP16(max_min/s_factor);
+
+        float sumlx = 0;
+        int   suml2 = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            const uint8_t sd = y[i].scales[j] & 0xF;
+            const uint8_t sm = y[i].scales[j] >>  4;
+            const float d = GGML_FP16_TO_FP32(y[i].d[0]) * sd;
+            if (!d) continue;
+            const float m = GGML_FP16_TO_FP32(y[i].d[1]) * sm;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + m)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+                sumlx += (x[32*j + ii] + m)*l*sd;
+                suml2 += l*l*sd*sd;
+            }
+        }
+        if (suml2) {
+            y[i].d[0] = GGML_FP32_TO_FP16(sumlx/suml2);
+        }
+#endif
+        uint8_t * q = y[i].qs;
+        for (int j = 0; j < QK_K; j += 64) {
+            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+            q += 32;
+        }
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
+
+        const float d   = GGML_FP16_TO_FP32(x[i].d);
+        const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+        int is = 0;
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K; j += 64) {
+            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
+            q += 32; is += 2;
+        }
+#else
+        const float dall = GGML_FP16_TO_FP32(x[i].d[0]);
+        const float mall = GGML_FP16_TO_FP32(x[i].d[1]);
+        const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+        const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
+        for (int l = 0; l < 32; ++l) {
+            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+            y[l+32] = d2 * (q[l] >>  4) - m2;
+        }
+        y += QK_K;
+#endif
+
+    }
+}
+
+void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q4_K * restrict y = vy;
+    quantize_row_q4_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
+        quantize_row_q4_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q4_K));
+}
+
+// ====================== 5-bit (de)-quantization
+
+void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+#if QK_K == 256
+    uint8_t L[QK_K];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+    float weights[32];
+    uint8_t Laux[32];
+#else
+    int8_t L[QK_K];
+    float scales[QK_K/16];
+#endif
+
+    for (int i = 0; i < nb; i++) {
+
+#if QK_K == 256
+
+        float max_scale = 0; // as we are subtracting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
+        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(31, l));
+                L[32*j + ii] = l;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        uint8_t m1 = 1, m2 = 2;
+        for (int n = 0; n < QK_K; n += 64) {
+            for (int j = 0; j < 32; ++j) {
+                int l1 = L[n + j];
+                if (l1 > 15) {
+                    l1 -= 16; qh[j] |= m1;
+                }
+                int l2 = L[n + j + 32];
+                if (l2 > 15) {
+                    l2 -= 16; qh[j] |= m2;
+                }
+                ql[j] = l1 | (l2 << 4);
+            }
+            m1 <<= 2; m2 <<= 2;
+            ql += 32;
+        }
+#else
+        float max_scale = 0, amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+            float abs_scale = fabsf(scales[j]);
+            if (abs_scale > amax) {
+                amax = abs_scale;
+                max_scale = scales[j];
+            }
+        }
+
+        float iscale = -128.f/max_scale;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = nearest_int(iscale*scales[j]);
+            y[i].scales[j] = MAX(-128, MIN(127, l));
+        }
+        y[i].d = GGML_FP32_TO_FP16(1/iscale);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
+            if (!d) continue;
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-16, MIN(15, l));
+                L[16*j + ii] = l + 16;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        for (int j = 0; j < 32; ++j) {
+            int jm = j%8;
+            int is = j/8;
+            int l1 = L[j];
+            if (l1 > 15) {
+                l1 -= 16; qh[jm] |= (1 << is);
+            }
+            int l2 = L[j + 32];
+            if (l2 > 15) {
+                l2 -= 16; qh[jm] |= (1 << (4 + is));
+            }
+            ql[j] = l1 | (l2 << 4);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * ql = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+
+#if QK_K == 256
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+        int is = 0;
+        uint8_t sc, m;
+        uint8_t u1 = 1, u2 = 2;
+        for (int j = 0; j < QK_K; j += 64) {
+            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+#else
+        float d = GGML_FP16_TO_FP32(x[i].d);
+        const int8_t * restrict s = x[i].scales;
+        for (int l = 0; l < 8; ++l) {
+            y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * s[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * s[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * s[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * s[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
+        }
+        y += QK_K;
+#endif
+    }
+}
+
+void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q5_K * restrict y = vy;
+    quantize_row_q5_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
+        quantize_row_q5_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q5_K));
+}
+
+// ====================== 6-bit (de)-quantization
+
+void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    int8_t L[QK_K];
+    float   scales[QK_K/16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0;
+        float max_abs_scale = 0;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+
+            const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
+            scales[ib] = scale;
+
+            const float abs_scale = fabsf(scale);
+            if (abs_scale > max_abs_scale) {
+                max_abs_scale = abs_scale;
+                max_scale = scale;
+            }
+
+        }
+
+        if (!max_abs_scale) {
+            memset(&y[i], 0, sizeof(block_q6_K));
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+            x += QK_K;
+            continue;
+        }
+
+        float iscale = -128.f/max_scale;
+        y[i].d = GGML_FP32_TO_FP16(1/iscale);
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
+        }
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-32, MIN(31, l));
+                L[16*j + ii] = l + 32;
+            }
+        }
+
+        uint8_t * restrict ql = y[i].ql;
+        uint8_t * restrict qh = y[i].qh;
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                const uint8_t q1 = L[j + l +  0] & 0xF;
+                const uint8_t q2 = L[j + l + 32] & 0xF;
+                const uint8_t q3 = L[j + l + 64] & 0xF;
+                const uint8_t q4 = L[j + l + 96] & 0xF;
+                ql[l+ 0] = q1 | (q3 << 4);
+                ql[l+32] = q2 | (q4 << 4);
+                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
+            }
+            ql += 64;
+            qh += 32;
+        }
+#else
+        for (int l = 0; l < 32; ++l) {
+            const uint8_t q1 = L[l +  0] & 0xF;
+            const uint8_t q2 = L[l + 32] & 0xF;
+            ql[l] = q1 | (q2 << 4);
+        }
+        for (int l = 0; l < 16; ++l) {
+            qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict ql = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict sc = x[i].scales;
+
+#if QK_K == 256
+        for (int n = 0; n < QK_K; n += 128) {
+            for (int l = 0; l < 32; ++l) {
+                int is = l/16;
+                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+                y[l +  0] = d * sc[is + 0] * q1;
+                y[l + 32] = d * sc[is + 2] * q2;
+                y[l + 64] = d * sc[is + 4] * q3;
+                y[l + 96] = d * sc[is + 6] * q4;
+            }
+            y  += 128;
+            ql += 64;
+            qh += 32;
+            sc += 8;
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            const int8_t q3 = (int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            const int8_t q4 = (int8_t)((ql[l+16]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            y[l+ 0] = d * sc[0] * q1;
+            y[l+16] = d * sc[1] * q2;
+            y[l+32] = d * sc[2] * q3;
+            y[l+48] = d * sc[3] * q4;
+        }
+        y  += 64;
+#endif
+
+    }
+}
+
+void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q6_K * restrict y = vy;
+    quantize_row_q6_K_reference(x, y, k);
+}
+
+size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
+        quantize_row_q6_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q6_K));
+}
+
+//===================================== Q8_K ==============================================
+
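+// Q8_K is the 8-bit activation format used by the k-quant dot products: one
+// float scale d per QK_K block, QK_K signed 8-bit quants, and 16-element
+// partial sums (bsums) that the dot products below use for the min terms.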
+void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        float max = 0;
+        float amax = 0;
+        for (int j = 0; j < QK_K; ++j) {
+            float ax = fabsf(x[j]);
+            if (ax > amax) {
+                amax = ax; max = x[j];
+            }
+        }
+        if (!amax) {
+            y[i].d = 0;
+            memset(y[i].qs, 0, QK_K);
+            x += QK_K;
+            continue;
+        }
+        const float iscale = -128.f/max;
+        for (int j = 0; j < QK_K; ++j) {
+            int v = nearest_int(iscale*x[j]);
+            y[i].qs[j] = MIN(127, v);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            int sum = 0;
+            for (int ii = 0; ii < 16; ++ii) {
+                sum += y[i].qs[j*16 + ii];
+            }
+            y[i].bsums[j] = sum;
+        }
+        y[i].d = 1/iscale;
+        x += QK_K;
+    }
+}
+
+void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+        for (int j = 0; j < QK_K; ++j) {
+            *y++ = x[i].d * x[i].qs[j];
+        }
+    }
+}
+
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q8_K_reference(x, y, k);
+}
+
+//===================================== Dot products =================================
+
+//
+// Helper functions
+//
+#if __AVX__ || __AVX2__ || __AVX512F__
+
+// shuffles to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m256i get_scale_shuffle_k4(int i) {
+    static const uint8_t k_shuffle[256] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m128i get_scale_shuffle(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
+        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
+        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
+        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
+    };
+    return _mm_loadu_si128((const __m128i*)k_shuffle + i);
+}
+#endif
+
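+// Dot product of a Q4_0 row with a Q8_0 row: for every 32-element block the
+// 4-bit nibbles are offset by -8, multiplied against the int8 activations,
+// and the integer sum is scaled by the product of the two block scales.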
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+
+    const block_q4_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+
+        const __m128i lowMask = _mm_set1_epi8(0xF);
+        const __m128i off = _mm_set1_epi8(8);
+
+        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+        __m128i bx = _mm_and_si128(lowMask, tmp);
+        __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx = _mm_sub_epi8(bx, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
+
+        bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
+        by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx = _mm_sub_epi8(bx, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
+
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+    // set constants
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    const __m128i off = _mm_set1_epi8(8);
+
+    // Initialize accumulator with zeros
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
+
+    // First round without accumulation
+    {
+        _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        acc_0 = _mm_mul_ps( d_0_1, p0 );
+        acc_1 = _mm_mul_ps( d_0_1, p1 );
+        acc_2 = _mm_mul_ps( d_2_3, p2 );
+        acc_3 = _mm_mul_ps( d_2_3, p3 );
+    }
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    // Main loop
+    for (int i = 2; i < nb; i+=2) {
+        _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+        // Accumulate
+        acc_0 = _mm_add_ps(p0_d, acc_0);
+        acc_1 = _mm_add_ps(p1_d, acc_1);
+        acc_2 = _mm_add_ps(p2_d, acc_2);
+        acc_3 = _mm_add_ps(p3_d, acc_3);
+    }
+
+    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        // subtract offset
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[i].qs[j] & 0x0F) - 8;
+            const int v1 = (x[i].qs[j] >>   4) - 8;
+
+            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
+        }
+
+        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
+    }
+
+    *s = sumf;
+#endif
+}
+
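+// Same as above, but Q4_1 blocks carry a min m in addition to the scale d, so
+// the result accumulates an extra m * y[i].s term (the per-block sum term s
+// stored in each Q8_1 block).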
+void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+    // TODO: add WASM SIMD
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs = 0;
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_1 * restrict y0 = &y[i + 0];
+        const block_q8_1 * restrict y1 = &y[i + 1];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+#elif defined(__AVX2__) || defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float d0 = GGML_FP16_TO_FP32(x[i].d);
+        const float d1 = y[i].d;
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+        const __m256 d0v = _mm256_set1_ps( d0 );
+        const __m256 d1v = _mm256_set1_ps( d1 );
+
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+
+        const __m256 xy = mul_sum_us8_pairs_float(bx, by);
+
+        // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[i].qs[j] & 0x0F);
+            const int v1 = (x[i].qs[j] >>   4);
+
+            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
+        }
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
+#endif
+}
+
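+// Q5_0 x Q8_0 dot product: the 5th bit of every weight is packed in qh; it is
+// recombined with the low nibble and offset by -16 before the int8 multiply,
+// then scaled by the product of the two block scales.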
+void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_0);
+
+    const block_q5_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q5_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        // extract the 5th bit via lookup table ((!b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    uint32_t qh;
+    uint64_t tmp[4];
+
+    // TODO: check if unrolling this is better
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q8_0 * restrict y0 = &y[i];
+
+        const v128_t m4b  = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
+        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
+                        wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
+    }
+
+    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
+        bx = _mm256_or_si256(bx, bxhi);
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8((char)0xF0);
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_andnot_si128(bxhil, mask);
+        bxhih = _mm_andnot_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = MM256_SET_M128I(bxh, bxl);
+
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    uint32_t qh;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    // These temporary registers are for masking and shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+    for (int i = 0; i < nb; i++) {
+        memcpy(&qh, x[i].qh, sizeof(uint32_t));
+
+        // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+
+        // ((qh & (1u << (j + 16))) >> (j + 12));
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
+
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+
+            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
+
+            sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
+        }
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
+    }
+
+    *s = sumf;
+#endif
+}
+
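+// Q5_1 x Q8_1 dot product: as in Q5_0 the high bit comes from qh, but there is
+// no -16 offset; instead each block's min m contributes m * y[i].s, as in the
+// Q4_1 case above.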
+void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_1);
+
+    const block_q5_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q5_1 * restrict x1 = &x[i + 1];
+        const block_q8_1 * restrict y0 = &y[i];
+        const block_q8_1 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
+        summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
+
+        // extract the 5th bit via lookup table ((b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add high bit
+        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    float summs = 0.0f;
+
+    uint32_t qh;
+    uint64_t tmp[4];
+
+    // TODO: check if unrolling this is better
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q8_1 * restrict y0 = &y[i];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
+
+        const v128_t m4b = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // add high bit
+        const v128_t v0lf = wasm_v128_or(v0l, qhl);
+        const v128_t v0hf = wasm_v128_or(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv,
+                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
+    }
+
+    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+        bx = _mm256_or_si256(bx, bxhi);
+
+        const __m256 dy = _mm256_set1_ps(y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
+
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8(0x10);
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_and_si128(bxhil, mask);
+        bxhih = _mm_and_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = MM256_SET_M128I(bxh, bxl);
+
+        const __m256 dy = _mm256_set1_ps(y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
+
+        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    uint32_t qh;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    // temporary registers for shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+    for (int i = 0; i < nb; i++) {
+        memcpy(&qh, x[i].qh, sizeof(uint32_t));
+
+        // load qh
+        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
+
+        // ((qh >> (j +  0)) << 4) & 0x10;
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
+
+        // ((qh >> (j + 12))     ) & 0x10;
+        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
+
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        int sumi = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
+            const int32_t x1 = (x[i].qs[j] >>  4) | xh_1;
+
+            sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
+        }
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
+#endif
+}
+
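+// Q8_0 x Q8_0 dot product: a plain int8 dot product per 32-element block,
+// scaled by the product of the two fp16 block scales.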
+void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+
+    const block_q8_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const int8x16_t x0_0 = vld1q_s8(x0->qs);
+        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
+        const int8x16_t x1_0 = vld1q_s8(x1->qs);
+        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
+
+        // load y
+        const int8x16_t y0_0 = vld1q_s8(y0->qs);
+        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
+        const int8x16_t y1_0 = vld1q_s8(y1->qs);
+        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+
+#else
+        const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
+        const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
+        const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
+        const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
+
+        const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
+        const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
+        const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
+        const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
+
+        const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
+        const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+        const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
+        const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+#endif
+    }
+
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__) || defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
+        __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        // Multiply q with scale and accumulate
+#if defined(__AVX2__)
+        acc = _mm256_fmadd_ps( d, q, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
+#endif
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+    size_t vl = __riscv_vsetvl_e8m1(qk);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
+        vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
+
+        vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
+
+        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+
+        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
+    }
+
+    *s = sumf;
+#else
+    // scalar
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk; j++) {
+            sumi += x[i].qs[j]*y[i].qs[j];
+        }
+
+        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
+    }
+
+    *s = sumf;
+#endif
+}
+
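+// Q2_K x Q8_K dot product (QK_K == 256): each 16-element sub-block has a 4-bit
+// scale and a 4-bit min packed in x[i].scales; the min contribution is folded
+// in via y[i].bsums and dmin, the 2-bit quants via the sub-block scales and d.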
+#if QK_K == 256
+void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+
+    const block_q2_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m3 = vdupq_n_u8(0x3);
+    const uint8x16_t m4 = vdupq_n_u8(0xF);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    ggml_int8x16x2_t q2bytes;
+    uint8_t aux[16];
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * restrict sc = x[i].scales;
+
+        const uint8x16_t mins_and_scales = vld1q_u8(sc);
+        const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
+        vst1q_u8(aux, scales);
+
+        const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+        const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
+                                       vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
+        const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
+                                       vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
+        sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
+
+        int isum = 0;
+        int is = 0;
+
+// We use this macro instead of a function call because, for some reason, the
+// code runs 2-3% slower with a function call, even if the function is declared inline
+#if defined(__ARM_FEATURE_DOTPROD)
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+#else
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        {\
+    const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
+                                   vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
+    const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
+                                   vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
+    isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
+        }
+#endif
+
+#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
+        q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
+        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
+        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
+        MULTIPLY_ACCUM_WITH_SCALE((index));
+
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
+
+            ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
+            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+            MULTIPLY_ACCUM_WITH_SCALE(0);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
+
+            is += 8;
+        }
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+        const __m256i mins = _mm256_cvtepi8_epi16(mins8);
+        const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
+
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
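+        // l_scales/h_scales are each duplicated into both 128-bit halves so that the
+        // _mm256_shuffle_epi8 below can broadcast one 16-bit scale per 16-quant group.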
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
+            const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
+            const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
+            const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
+
+            __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+            __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+            __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
+            __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
+
+            p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
+            p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
+            p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
+            p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
+
+            p0 = _mm256_add_epi32(p0, p1);
+            p2 = _mm256_add_epi32(p2, p3);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
+        }
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(0x3);
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(0x2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // load mins and scales from block_q2_K.scales[QK_K/16]
+        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
+        const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+        const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
+        const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
+
+        // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
+        const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
+        const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
+
+        // sumf += -dmin * summs in 32bits*8
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
+
+        const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
+        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
+        const __m128i scales[2] = { scales_0, scales_1 };
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
+            __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+            const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+            const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+            const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+            const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+            q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+            const __m128i q2_1 = _mm_and_si128(q2bits, m3);
+            const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+            const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+            const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+            // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
+            __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
+            __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
+            __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
+            __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
+            __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
+            __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
+            __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
+            __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
+
+            // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
+            __m128i shuffle = _mm_set1_epi16(0x0100);
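+            // shuffle = 0x0100 selects the first 16-bit scale of scales[j] in every lane;
+            // adding m2 (2 to every byte) steps the selector to the next 16-bit scale.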
+            p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
+
+            p0 = _mm_add_epi32(p0, p1);
+            p2 = _mm_add_epi32(p2, p3);
+            p4 = _mm_add_epi32(p4, p5);
+            p6 = _mm_add_epi32(p6, p7);
+
+            // isum in 32bits*4*2
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
+        }
+
+        // sumf += dall * isum - dmin * summs in 32bits
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
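+    // Gather indices: with v_b = {0..0, 1..1} and offset k, the vrgather below broadcasts
+    // scale aux[k] to the first 16 lanes and aux[k+1] to the next 16.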
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        size_t vl = 16;
+
+        vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+        vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+
+        vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+
+        vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+        vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+        vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+        vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+
+        sumf  += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+
+        vl = 32;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+
+        uint8_t is=0;
+        int isum=0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load Q2
+            vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+
+            vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+            vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+
+            // duplicate scale elements for product
+            vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
+            vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
+            vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
+            vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+
+            vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+            vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+            vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+            vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+
+            // load Q8
+            vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
+            vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+
+            vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+            vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+            vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+            vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+
+            isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+
+            q2+=32;  q8+=128;  is=8;
+
+        }
+
+        sumf += dall * isum;
+
+    }
+
+    *s = sumf;
+
+#else
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int summs = 0;
+        for (int j = 0; j < 16; ++j) {
+            summs += y[i].bsums[j] * (sc[j] >> 4);
+        }
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
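+        // Reference decomposition: sumf += dall * sum_j(scale_j * dot_j) - dmin * sum_j(min_j * bsum_j),
+        // where dot_j is the dot product of the j-th group of 16 quants with Q8.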
+        int isum = 0;
+        int is = 0;
+        int d;
+        for (int k = 0; k < QK_K/128; ++k) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                d = sc[is++] & 0xF;
+                int isuml = 0;
+                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                d = sc[is++] & 0xF;
+                isuml = 0;
+                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                shift += 2;
+                q8 += 32;
+            }
+            q2 += 32;
+        }
+        sumf += dall * isum - dmin * summs;
+    }
+    *s = sumf;
+#endif
+}
+
+#else
+
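+// QK_K == 64 variant: each block holds a single 64-quant super-block with 4 groups of 16;
+// the low nibbles of the 4 scale bytes are the group scales, the high nibbles the group mins.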
+void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+
+    const block_q2_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m3 = vdupq_n_u8(0x3);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    ggml_int8x16x4_t q2bytes;
+
+    uint32_t aux32[2];
+    const uint8_t * scales = (const uint8_t *)aux32;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const float dmin = -y[i].d * (float)x[i].dmin;
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+
+        aux32[0] = sc[0] & 0x0f0f0f0f;
+        aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
+
+        int isum1 = 0, isum2 = 0;
+
+        const uint8x16_t q2bits = vld1q_u8(q2);
+
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
+
+        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
+        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
+        q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
+        q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
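+        // Extract the four 2-bit planes: quants 0..15 from shift 0, 16..31 from shift 2,
+        // 32..47 from shift 4 and 48..63 from shift 6.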
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
+#else
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        isum1 += vaddvq_s16(p1) * scales[0];
+        isum2 += vaddvq_s16(p2) * scales[1];
+
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum1 += vaddvq_s16(p3) * scales[2];
+        isum2 += vaddvq_s16(p4) * scales[3];
+#endif
+        sum += d * (isum1 + isum2);
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
+        const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
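+        // q2_0 packs groups 0 and 1 (shifts 0 and 2) into its two 128-bit halves, q2_1 groups
+        // 2 and 3, so the four 128-bit extracts below line up with the scales db[0..3].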
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+        const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+
+        const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
+        const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
+        const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
+        const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+        const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+        const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+        const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+        const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+        const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+        const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux32[2];
+    const uint8_t * scales = (const uint8_t *)aux32;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const float dmin = -y[i].d * (float)x[i].dmin;
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+
+        aux32[0] = sc[0] & 0x0f0f0f0f;
+        aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
+
+        int isum1 = 0;
+        int isum2 = 0;
+
+        size_t vl = 16;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        // load Q2
+        vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl);
+
+        vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl));
+        vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl));
+        vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl));
+        vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl));
+
+        // load Q8, and take product with Q2
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl);
+        vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl);
+        vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl);
+        vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl);
+
+        isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0];
+        isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1];
+        isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2];
+        isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3];
+
+        sumf += d * (isum1 + isum2);
+
+    }
+
+    *s = sumf;
+
+#else
+
+    float sumf = 0;
+
+    int isum[4];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int summs = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            summs += y[i].bsums[j] * (sc[j] >> 4);
+        }
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        isum[0] = isum[1] = isum[2] = isum[3] = 0;
+        for (int l =  0; l < 16; ++l) {
+            isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
+            isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
+            isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
+            isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
+        }
+        for (int l = 0; l < 4; ++l) {
+            isum[l] *= (sc[l] & 0xF);
+        }
+        sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
+    }
+    *s = sumf;
+#endif
+}
+#endif
+
+#if QK_K == 256
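+// Dot product of q3_K blocks with q8_K blocks. Each 256-quant super-block stores 2-bit low
+// quants (qs), a 1-bit high mask (hmask) and 16 six-bit scales packed into 12 bytes; the
+// effective quant is ((low2 | high<<2) - 4) and each 16-quant group is weighted by (scale - 32).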
+void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    const block_q3_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const uint8x16_t m3b = vdupq_n_u8(0x3);
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    const uint8x16_t m0 = vdupq_n_u8(1);
+    const uint8x16_t m1 = vshlq_n_u8(m0, 1);
+    const uint8x16_t m2 = vshlq_n_u8(m0, 2);
+    const uint8x16_t m3 = vshlq_n_u8(m0, 3);
+    const int8_t m32 = 32;
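+    // m0..m3 isolate successive bit planes of the high-bit mask (the vbicq below selects
+    // quants whose high bit is clear, which then get 4 subtracted).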
+
+    ggml_int8x16x4_t q3bytes;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
+
+        ggml_uint8x16x4_t q3h;
+
+        int32_t isum = 0;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= m32;
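+        // The 12 scale bytes unpack into 16 six-bit scales: low 4 bits from aux[0..1],
+        // high 2 bits from aux[2]; subtracting 32 recenters them to [-32, 31].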
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
+            q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
+            q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
+            q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+#else
+            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
+                                     vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
+            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
+                                     vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
+            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
+                                     vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
+            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
+                                     vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
+            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
+#endif
+            scale += 4;
+
+            q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
+            q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
+            q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
+            q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+#else
+            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
+                           vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
+            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
+                           vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
+            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
+                           vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
+            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
+                           vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
+            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
+#endif
+            scale += 4;
+
+            if (j == 0) {
+                qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
+                qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
+            }
+
+        }
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m256i mone = _mm256_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t aux[3];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        __m128i scales128 = _mm_set_epi32(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+        // high bit
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
+
+        // integer accumulator
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
+        int is  = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits
+            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
+
+            // prepare low and high bits
+            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
+            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
+            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
+            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
+            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            // load Q8 quants
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            // Dot product: we multiply the 2 low bits and the high-bit part separately, so we can use _mm256_maddubs_epi16,
+            // and then subtract. The high-bit part carries the 4 that must be subtracted: the q3h_* values
+            // are 4 where the high bit is not set and 0 where it is set.
+            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+            // multiply with scales
+            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+            // accumulate
+            p16_0 = _mm256_add_epi32(p16_0, p16_1);
+            p16_2 = _mm256_add_epi32(p16_2, p16_3);
+            sumi  = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
+
+        }
+
+        // multiply with block scale and accumulate
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i mone = _mm_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    const uint32_t *aux;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Set up scales
+        aux = (const uint32_t *)x[i].scales;
+        __m128i scales128 = _mm_set_epi32(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
+        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
+        const __m128i scales[2] = { scales_0, scales_1 };
+
+        // high bit *128*2 from block_q3_K.hmask[QK_K/8]
+        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
+        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
+
+        // integer accumulator
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
+            const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+            const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+
+            // prepare low and high bits
+            const int bit = j << 2;
+
+            const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
+            const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
+            const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
+            const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
+
+            const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
+            const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
+            const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+            const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+
+            const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
+            const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
+            const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+            const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+
+            const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
+            const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
+            const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+            const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+
+            // load Q8 quants from block_q8_K.qs[QK_K]
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            // Dot product: we multiply the 2 low bits and the high-bit part separately, so we can use _mm_maddubs_epi16,
+            // and then subtract. The high-bit part carries the 4 that must be subtracted: the q3h_* values
+            // are 4 where the high bit is not set and 0 where it is set.
+            __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
+            __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
+            __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
+            __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
+            __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
+            __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
+            __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
+            __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
+
+            __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
+            __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
+            __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
+            __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
+            __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
+
+            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
+            // multiply with scales
+            __m128i shuffle = _mm_set1_epi16(0x0100);
+            p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
+
+            // accumulate
+            p16_0 = _mm_add_epi32(p16_0, p16_1);
+            p16_2 = _mm_add_epi32(p16_2, p16_3);
+            p16_4 = _mm_add_epi32(p16_4, p16_5);
+            p16_6 = _mm_add_epi32(p16_6, p16_7);
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
+
+        }
+
+        // multiply with block scale and accumulate
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
+        size_t vl = 32;
+        uint8_t m =  1;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
+
+        int sum_t = 0;
+
+        for (int j = 0; j < QK_K; j += 128) {
+
+            vl = 32;
+
+            // load Q3
+            vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+
+            vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+            vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+            vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+            vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
+
+            // compute mask for subtraction
+            vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+            vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+            vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl);
+            m <<= 1;
+
+            // load Q8 and take product with Q3
+            vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            // retrieve lane to multiply with scale
+            vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+            vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+            vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+            vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+            vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+            vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+            vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+            vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+
+            sum_t +=  __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q3 += 32;    q8 += 128;   scale += 8;
+
+        }
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        sumf += d*sum_t;
+
+    }
+
+    *s = sumf;
+
+#else
+    // scalar version
+    // This function is written like this so the compiler can manage to vectorize most of it
+    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
+    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
+    // The ideal situation would be if we could just write the code once, and the compiler would
+    // automatically produce the best possible set of machine instructions, instead of us having to manually
+    // write vectorized versions for AVX, ARM_NEON, etc.
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    uint32_t auxs[4];
+    const int8_t * scales = (const int8_t*)auxs;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            q3 += 32;
+        }
+        a = aux8;
+
+        memcpy(auxs, x[i].scales, 12);
+        uint32_t tmp = auxs[2];
+        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+
+#endif
+
+}
+
+#else
+
+void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q3_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    const uint8x16_t m3b = vdupq_n_u8(0x3);
+    const uint8x16_t mh  = vdupq_n_u8(4);
+
+    ggml_int8x16x4_t q3bytes;
+
+    uint16_t aux16[2];
+    int8_t * scales = (int8_t *)aux16;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        ggml_uint8x16x4_t q3h;
+
+        const uint8x8_t  hbits    = vld1_u8(x[i].hmask);
+        const uint8x16_t q3bits   = vld1q_u8(x[i].qs);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        for (int j = 0; j < 4; ++j) scales[j] -= 8;
+
+        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
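+        // The quants below are built as (low2 | high<<2) in 0..7; the constant -4 offset is
+        // pre-subtracted here via the per-group Q8 sums (bsums), with scales ordered 0,2,1,3
+        // to match how the nibbles were unpacked into aux16.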
+
+        const float d = y[i].d * (float)x[i].d;
+
+        const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
+        q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
+        q3h.val[1] = vandq_u8(mh, htmp);
+        q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
+        q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
+
+        q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b),                q3h.val[0]));
+        q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
+        q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
+        q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6),                q3h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
+#else
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
+#endif
+
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m256i m1 = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+        const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
+        __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
+        q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
+        q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+        // prepare low and high bits
+        const __m256i q3aux  = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
+        const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
+        const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        // Dot product: we multiply the 2 low bits and the high-bit part separately, so we can use _mm256_maddubs_epi16,
+        // and then subtract. The high-bit part carries the 4 that must be subtracted: the q3h_* values
+        // are 4 where the high bit is not set and 0 where it is set.
+        const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+        const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+
+        __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+        __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+
+        p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+
+        // multiply with scales
+        p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+        p16_0 = _mm256_add_epi32(p16_0, p16_1);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
+        const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
+        const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
+        const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
+        __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
+        __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
+        q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
+        q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
+        q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
+        q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+        // prepare low and high bits
+        const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
+        const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
+        const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
+        const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        // Dot product: we multiply the 2 low bits and the 1 high bit part separately, so we can use _mm_maddubs_epi16,
+        // and then subtract. The high bit part already carries the offset of 4 (and, because of the andnot above, it is
+        // zero if the high bit was set and 4 if it was not), so the subtraction yields the signed quant times q8.
+        const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        // multiply with scales
+        p16_0 = _mm_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm_madd_epi16(scale_1, p16_1);
+        p16_2 = _mm_madd_epi16(scale_2, p16_2);
+        p16_3 = _mm_madd_epi16(scale_3, p16_3);
+
+        p16_0 = _mm_add_epi32(p16_0, p16_2);
+        p16_1 = _mm_add_epi32(p16_1, p16_3);
+        __m256i p16 = MM256_SET_M128I(p16_1, p16_0);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    uint16_t aux16[2];
+    int8_t * scales = (int8_t *)aux16;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        for (int j = 0; j < 4; ++j) scales[j] -= 8;
+
+        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
+
+        const float d = y[i].d * (float)x[i].d;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load qh
+        vuint8mf4_t qh_x1   = __riscv_vle8_v_u8mf4(x[i].hmask, 8);
+        vuint8mf2_t qh_x2   = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
+
+        size_t vl = 16;
+
+        // extend and combine both qh_x1 and qh_x2
+        vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
+
+        vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
+        vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl);
+        vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
+        vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl);
+
+        // load Q3
+        vuint8mf2_t q3_x  = __riscv_vle8_v_u8mf2(q3, vl);
+
+        vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl);
+        vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl);
+        vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl);
+        vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl);
+
+        vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0);
+        vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1);
+        vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2);
+        vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3);
+
+        // load Q8 and take product with Q3
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3];
+
+        sumf += d * isum;
+
+    }
+
+    *s = sumf;
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    int32_t scales[4];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+        int8_t * restrict a = aux8;
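+        // q3_K decode: qs packs four 2-bit values per byte and hmask supplies the high
+        // (third) bit; subtracting 4 when that bit is clear maps each quant to [-4, 3].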
+        for (int l = 0; l < 8; ++l) {
+            a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
+            a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
+            a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
+            a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
+            a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
+            a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
+            a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
+            a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
+        }
+
+        scales[0] = (x[i].scales[0] & 0xF) - 8;
+        scales[1] = (x[i].scales[0] >>  4) - 8;
+        scales[2] = (x[i].scales[1] & 0xF) - 8;
+        scales[3] = (x[i].scales[1] >>  4) - 8;
+
+        memset(aux32, 0, 8*sizeof(int32_t));
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+
+#endif
+
+}
+#endif
+
+#if QK_K == 256
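+// Dot product of a q4_K-quantized row with a q8_K-quantized vector: per super-block this
+// accumulates d * sum_j(scale_j * q4 . q8) and subtracts dmin * sum_j(min_j * bsums_j).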
+void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
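+    // Each q4_K super-block packs eight 6-bit scales and eight 6-bit mins into 12 bytes;
+    // the kmask constants below unpack them into utmp, leaving the 8 scale bytes at
+    // utmp[0..1] and the 8 min bytes at utmp[2..3].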
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x2_t q8bytes;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, 12);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
+
+#ifdef __ARM_FEATURE_DOTPROD
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+
+            const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
+
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+            const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+
+            sumi2 += vaddvq_s32(p2) * scales[2*j+1];
+#else
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
+
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
+
+#endif
+        }
+
+        sumf += d * (sumi1 + sumi2);
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+    __m128 acc_m = _mm_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
+
+        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4l = _mm256_and_si256(q4bits, m4);
+            const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+            const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+            p16l = _mm256_madd_epi16(scale_l, p16l);
+
+            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+            p16h = _mm256_madd_epi16(scale_h, p16h);
+            const __m256i sumj = _mm256_add_epi32(p16l, p16h);
+
+            sumi = _mm256_add_epi32(sumi, sumj);
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(0x2);
+
+    __m256 acc = _mm256_setzero_ps();
+    __m128 acc_m = _mm_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i scales = _mm_cvtepu8_epi16(utmps);
+        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+        const __m128i prod = _mm_madd_epi16(mins, q8s);
+        acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        __m128i shuffle = _mm_set1_epi16(0x0100);
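+        // 0x0100 selects bytes {0,1} of `scales`, broadcasting the first 16-bit scale;
+        // adding m2 (2 in every byte) advances the shuffle to the next scale each step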
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+
+            __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
+            const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+            q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
+            const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+
+            const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
+            p16l = _mm_madd_epi16(scale_l, p16l);
+            sumi_0 = _mm_add_epi32(sumi_0, p16l);
+            const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
+            p16l = _mm_madd_epi16(scale_l, p16l);
+            sumi_1 = _mm_add_epi32(sumi_1, p16l);
+
+            const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
+            p16h = _mm_madd_epi16(scale_h, p16h);
+            sumi_0 = _mm_add_epi32(sumi_0, p16h);
+            const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
+            p16h = _mm_madd_epi16(scale_h, p16h);
+            sumi_1 = _mm_add_epi32(sumi_1, p16h);
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        size_t vl = 8;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums   = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8  = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t  prod   = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        vl = 32;
+
+        int32_t sum_1 = 0;
+        int32_t sum_2 = 0;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q4
+            vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+            // load Q8 and multiply it with lower Q4 nibble
+            vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+            vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+            vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+
+            sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
+
+            // load Q8 and multiply it with upper Q4 nibble
+            vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+            vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+            vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+
+            sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
+
+            q4 += 32;    q8 += 64;
+
+        }
+
+        sumf += d*(sum_1 + sum_2);
+
+    }
+
+    *s = sumf;
+
+#else
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
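+        // expand the 4-bit quants to bytes: low nibbles form the first 32 values of each
+        // 64-value group, high nibbles the next 32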
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            a += 32;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
+            a += 32; q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#else
+void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    float sumf = 0;
+
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x4_t q8bytes;
+
+    float sum_mins = 0.f;
+
+    uint16_t aux16[2];
+    const uint8_t * restrict scales = (const uint8_t *)aux16;
+
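+    // For QK_K == 64 each block stores two fp16 factors (d[0] = scale, d[1] = min) and
+    // packs two 4-bit scales and two 4-bit mins into x[i].scales; the unpack below puts
+    // the scales in scales[0..1] and the mins in scales[2..3].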
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t * restrict a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
+        sum_mins += y[i].d * (float)x[i].d[1] * summi;
+
+        const float d = y[i].d * (float)x[i].d[0];
+
+        const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
+
+#ifdef __ARM_FEATURE_DOTPROD
+        q8bytes = ggml_vld1q_s8_x4(q8);
+        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+
+        const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+        const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
+
+        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+        const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+        const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
+
+#else
+        q8bytes = ggml_vld1q_s8_x4(q8);
+        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
+
+        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
+        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
+
+#endif
+        sumf += d * (sumi1 + sumi2);
+
+    }
+
+    *s = sumf - sum_mins;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
+        const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m256i q4l = _mm256_and_si256(q4bits, m4);
+        const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+        const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+        const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+
+        const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
+
+        const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
+        const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
+        const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
+        const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
+        const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
+        const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
+        const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
+        const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
+
+        const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
+        const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
+#elif defined __riscv_v_intrinsic
+
+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
+
+        size_t vl = 32;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        // load Q4
+        vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+        // load Q8 and multiply it with lower Q4 nibble
+        vint8m1_t  q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+        vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl);
+        vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl);
+
+        sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1);
+
+        // load Q8 and multiply it with upper Q4 nibble
+        vint8m1_t  q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+        vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+        vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl);
+
+        sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2);
+
+    }
+
+    *s = sumf;
+
+#else
+
+    uint8_t aux8[QK_K];
+    int16_t aux16[16];
+    float   sums [8];
+    memset(sums, 0, 8*sizeof(float));
+
+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        uint8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
+        for (int l = 0; l < 32; ++l) a[l+32] = q4[l]  >> 4;
+
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
+
+        for (int j = 0; j < QK_K/32; ++j) {
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            q8 += 16; a += 16;
+            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
+            q8 += 16; a += 16;
+            const float dl = d * scales[j];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
+        }
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#endif
+
+#if QK_K == 256
+void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q5_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
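+    // q5_K uses the same 12-byte packing of eight 6-bit scales and mins as q4_K above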
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const uint8x16_t mone = vdupq_n_u8(1);
+    const uint8x16_t mtwo = vdupq_n_u8(2);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    ggml_int8x16x4_t q5bytes;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        int32_t sumi_mins = vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
+
+        ggml_uint8x16x4_t q5h;
+
+        int32_t sumi = 0;
+
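+        // qh supplies the 5th bit of every quant: each 64-value step consumes two bits
+        // per byte of qhbits, shifting them up to bit 4 before OR-ing onto the nibbles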
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+            const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+            q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+            q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
+            q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
+            qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
+            qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
+
+            q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
+            q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
+            q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
+            q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+#else
+
+            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
+
+            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+            sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
+#endif
+        }
+
+        sumf += d * sumi - dmin * sumi_mins;
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m128i mzero = _mm_setzero_si128();
+    const __m256i mone  = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+#if QK_K == 256
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+#else
+        // TODO (dead branch here: this function is only compiled when QK_K == 256, see above)
+        const float d = 0, dmin = 0;
+#endif
+
+        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
+
+        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
+        __m256i hmask = mone;
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
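+        // hbits holds all 256 high bits; hmask/bit walk one bit position per 32 quants,
+        // and the extracted bit is shifted up to add 16 to the corresponding low nibble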
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
+
+            const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+            const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_0  = _mm256_add_epi8(q5l_0, q5h_0);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+            const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_1  = _mm256_add_epi8(q5l_1, q5h_1);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+
+            p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mzero = _mm_setzero_si128();
+    const __m128i mone  = _mm_set1_epi8(1);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i scales = _mm_cvtepu8_epi16(utmps);
+        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+        const __m128i prod = _mm_madd_epi16(mins, q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
+
+        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
+        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
+        __m128i hmask = mone;
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        int bit = 0;
+
+        __m128i shuffle = _mm_set1_epi16(0x0100);
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+
+            const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+            const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+
+            __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
+            __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
+            __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            __m128i q5_0  = _mm_add_epi8(q5l_0, q5h_0);
+            __m128i q5_1  = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_0 = _mm_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm_madd_epi16(scale_0, p16_1);
+
+            q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
+            q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
+            q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            q5_0  = _mm_add_epi8(q5l_0, q5h_0);
+            q5_1  = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_2 = _mm_madd_epi16(scale_1, p16_2);
+            p16_3 = _mm_madd_epi16(scale_1, p16_3);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+    float sums = 0.f;
+
+    size_t vl;
+
+    for (int i = 0; i < nb; ++i) {
+
+        vl = 8;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        vl = 32;
+        int32_t aux32 = 0;
+        int is = 0;
+
+        uint8_t m = 1;
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q5 and Q8
+            vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
+            vint8m1_t  q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
+
+            // compute mask for addition
+            vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl);
+            m <<= 1;
+
+            vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl);
+            m <<= 1;
+
+            vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
+            vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+
+            vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
+            vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+
+            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
+            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
+
+            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+            q5 += 32;    q8 += 64;
+
+        }
+
+        vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
+        sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+
+    }
+
+    *s = sumf+sums;
+
+#else
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        uint8_t m = 1;
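+        // decode: low/high nibble of qs, plus 16 when the matching qh bit is set,
+        // giving unsigned quants in [0, 31]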
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#else
+
+void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q5_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const uint8x16_t mh = vdupq_n_u8(16);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    ggml_int8x16x4_t q5bytes;
+    ggml_uint8x16x4_t q5h;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint8x8_t qhbits = vld1_u8(qh);
+
+        const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
+
+        const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
+        q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
+        q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
+        q5h.val[2] = vbicq_u8(mh, htmp);
+        q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
+
+        q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
+        q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
+        q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
+        q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
+#else
+
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
+
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
+
+        sumf += d*sumi;
+#endif
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i mone  = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+        const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
+
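+        // qh stores the high bit of value v at bit (v/8) of byte (v%8); the >>1
+        // and >>2 copies line up the right bit plane in each byte so the andnot
+        // masks below can test bit 0 (values 0..31) and bit 4 (values 32..63)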
+        const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
+        const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
+
+        const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+        const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
+        const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
+        const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
+        const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
+
+        const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mone  = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
+        const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
+        const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
+        const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+
+        const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
+        const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
+        const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
+        const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+
+        const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
+        const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
+        const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
+        const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
+        const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+
+        const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
+        const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load qh
+        vuint8mf4_t qh_x1   = __riscv_vle8_v_u8mf4(qh, 8);
+        vuint8mf2_t qh_x2   = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
+
+        size_t vl = 16;
+
+        // combine both qh_1 and qh_2
+        vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
+
+        vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
+        vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl);
+        vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl);
+        vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
+
+        vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0);
+        vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1);
+        vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2);
+        vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3);
+
+        // load q5
+        vuint8mf2_t q5_x1  = __riscv_vle8_v_u8mf2(q5, vl);
+        vuint8mf2_t q5_x2  = __riscv_vle8_v_u8mf2(q5+16, vl);
+
+        vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl));
+        vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl));
+        vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl));
+        vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl));
+
+        vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl);
+        vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl);
+        vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl);
+        vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl);
+
+        // load Q8 and multiply it with Q5
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0);
+        int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1);
+        int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2);
+        int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3);
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
+    }
+
+    *s = sumf;
+
+#else
+
+    int8_t aux8[QK_K];
+    int16_t aux16[16];
+    float   sums [8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        int8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) {
+            a[l+ 0] = q4[l] & 0xF;
+            a[l+32] = q4[l]  >> 4;
+        }
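+        // high bit of value 8*is + l lives at bit `is` of hm[l]; a cleared bit
+        // shifts the nibble down by 16, giving the signed range [-16, 15]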
+        for (int is = 0; is < 8; ++is) {
+            uint8_t m = 1 << is;
+            for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const int8_t * restrict sc = x[i].scales;
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float dl = d * sc[j];
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l <  8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
+            q8 += 16; a += 16;
+        }
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#endif
+
+
+#if QK_K == 256
+void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q6_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    float sum = 0;
+
+    const uint8x16_t m4b = vdupq_n_u8(0xF);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+    //const int8x16_t  m32s = vdupq_n_s8(32);
+
+    const uint8x16_t mone = vdupq_n_u8(3);
+
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
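+        // every q6 value carries a -32 offset; rather than subtracting it per
+        // element, accumulate scale[j] * bsums[j] below and remove the whole
+        // offset at the end via (isum - 32 * isum_mins)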
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const int8x16_t scales = vld1q_s8(scale);
+        const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+
+        const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
+                                                   vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
+                                         vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
+                                                   vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
+        int32_t isum_mins = vaddvq_s32(prod);
+
+        int32_t isum = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+            ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+            ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+            uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
+            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 2);
+            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
+            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
+            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
+            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
+            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
+            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+            isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+            scale += 4;
+
+#else
+
+            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                     vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                     vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
+            scale += 2;
+
+            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                     vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                     vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
+            scale += 2;
+#endif
+
+            q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            shifted = vshrq_n_u8(qhbits.val[0], 4);
+            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 4);
+            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[0], 6);
+            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 6);
+            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
+            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
+            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
+            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
+            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
+            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
+            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
+            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+            isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+            scale += 4;
+
+            //for (int l = 0; l < 4; ++l) {
+            //    const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
+            //    isum += vaddvq_s32(p) * *scale++;
+            //}
+#else
+            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                    vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                    vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
+            scale += 2;
+
+            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                    vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                    vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
+            scale += 2;
+#endif
+
+        }
+        //sum += isum * d_all * y[i].d;
+        sum += d_all * y[i].d * (isum - 32 * isum_mins);
+
+    }
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m2 = _mm256_set1_epi8(3);
+    const __m256i m32s = _mm256_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;
+
+            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
+
+            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
+            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
+            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
+            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
+
+            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
+            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
+            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
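+            // _mm256_maddubs_epi16 needs its first operand unsigned, so the -32
+            // offset on the q6 values is handled separately: q8s = 32 * (pairwise
+            // q8 sums), subtracted from the p16 products below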
+            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
+
+        }
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+
+            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
+            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
+            const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
+            const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
+            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
+            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+
+            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+
+            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
+            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
+            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
+            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
+            __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
+            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
+            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
+            __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
+            __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
+            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
+            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
+
+            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
+            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
+            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
+            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
+            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
+
+            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+
+            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
+            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
+
+        }
+
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        size_t vl;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        int sum_t = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            vl = 32;
+
+            // load qh
+            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+            // load Q6
+            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+
+            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+
+            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+            // load Q8 and take product
+            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q6 += 64;   qh += 32;   q8 += 128;   is=8;
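+            // is = 8 (rather than is += 8) suffices: with QK_K == 256 this loop
+            // body runs exactly twice, consuming scale[0..7] then scale[8..15]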
+
+        }
+
+        sumf += d * sum_t;
+
+    }
+
+    *s = sumf;
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a  += 128;
+            q4 += 64;
+            qh += 32;
+        }
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int scale = x[i].scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#else
+
+void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q6_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    float sum = 0;
+
+    const uint8x16_t m4b = vdupq_n_u8(0xF);
+    const int8x16_t  m32s = vdupq_n_s8(32);
+#if defined(__ARM_FEATURE_DOTPROD)
+    const int32x4_t  vzero = vdupq_n_s32(0);
+#endif
+
+    const uint8x16_t mone = vdupq_n_u8(3);
+
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = (float)x[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        int32_t isum = 0;
+
+        uint8x16_t qhbits = vld1q_u8(qh);
+        ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
+        ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
+
+        q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
+        uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
+        q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+        shifted = vshrq_n_u8(qhbits, 4);
+        q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+        shifted = vshrq_n_u8(qhbits, 6);
+        q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+        q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+        q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+        q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
+        q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+#else
+
+        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
+
+        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
+#endif
+
+        sum += isum * d_all * y[i].d;
+
+    }
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m2 = _mm256_set1_epi8(3);
+    const __m256i m32s = _mm256_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
+        const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+
+        const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+        const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+        __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+
+        __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+        __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+
+        p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+
+        p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+
+        sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+        const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+        const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+        const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+        const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+        const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+        const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+        const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+        __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+        __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+        __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+        p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+        p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+
+        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = (float)x[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        int32_t isum = 0;
+
+        size_t vl = 16;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load Q6
+        vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl);
+        vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl);
+
+        // load qh
+        vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl);
+
+        vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+
+        vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl);
+        vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl);
+        vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl);
+        vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl);
+
+        vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl);
+        vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl);
+        vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl);
+        vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl);
+
+        // load Q8 and take product
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3];
+
+        sumf += isum * d_all * y[i].d;
+
+    }
+
+    *s = sumf;
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int l = 0; l < 16; ++l) {
+            a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            a[l+32] = (int8_t)((q4[l+ 0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            a[l+48] = (int8_t)((q4[l+16] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+        }
+        int is = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int scale = x[i].scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#endif

+ 224 - 0
ggml/src/ggml-quants.h

@@ -0,0 +1,224 @@
+#pragma once
+
+#include "ggml-impl.h"
+
+// GGML internal header
+
+#include <stdint.h>
+#include <stddef.h>
+
+#define QK4_0 32
+typedef struct {
+    ggml_fp16_t d;          // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    ggml_fp16_t d;          // delta
+    ggml_fp16_t m;          // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qh[4];         // 5th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qh[4];         // 5th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    int8_t  qs[QK8_0];     // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+typedef struct {
+    float d;               // delta
+    float s;               // d * sum(qs[i])
+    int8_t  qs[QK8_1];     // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
+
+//
+// Super-block quantization structures
+//
+
+// Super-block size
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
+#define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
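+
+// (the "effectively N bits per weight" figures below assume QK_K == 256)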
+
+// 2-bit quantization
+// weight is represented as x = a * q + b
+// 16 blocks of 16 elements each
+// Effectively 2.625 bits per weight
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    ggml_fp16_t d;           // super-block scale for quantized scales
+    ggml_fp16_t dmin;        // super-block scale for quantized mins
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
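+// i.e. 16 (scales) + 64 (qs) + 4 (d, dmin) = 84 bytes per 256 weights = 2.625 bpw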
+
+// 3-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 3.4375 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[2];
+    ggml_fp16_t d;             // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
+#else
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[12];        // scales, quantized with 6 bits
+    ggml_fp16_t d;             // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+#endif
+
+// 4-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 4.5 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+    ggml_fp16_t d[2];          // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
+    uint8_t qs[QK_K/2];        // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
+typedef struct {
+    ggml_fp16_t d;             // super-block scale for quantized scales
+    ggml_fp16_t dmin;          // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+#endif
+
+// 5-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 5.5 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+    ggml_fp16_t d;               // super-block scale
+    int8_t  scales[QK_K/16];     // 8-bit block scales
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+    ggml_fp16_t d;               // super-block scale for quantized scales
+    ggml_fp16_t dmin;            // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE];   // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
+
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    ggml_fp16_t d;           // super-block scale
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
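+// e.g. 128 (ql) + 64 (qh) + 16 (scales) + 2 (d) = 210 bytes per 256 weights = 6.5625 bpw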
+
+// This is only used for intermediate quantization and dot products
+typedef struct {
+    float   d;              // delta
+    int8_t  qs[QK_K];       // quants
+    int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+} block_q8_K;
+static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
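+// the per-group bsums let the k-quant dot products fold constant offsets (e.g.
+// q6_K's -32) into a single correction term instead of adjusting every quant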
+
+
+// Quantization
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k);
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k);
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k);
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k);
+
+void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
+void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
+void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
+void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
+void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
+void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_1(const float * restrict x, void * restrict y, int k);
+
+void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
+
+// Dequantization
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k);
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k);
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k);
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k);
+//void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k);
+
+void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
+void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
+void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
+void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
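Dequantization inverts the mapping, lossily. A hypothetical roundtrip sketch, building on the helper above and under the same k % QK_K == 0 assumption:

// Hypothetical sketch: float -> q6_K -> float roundtrip; dst receives a
// lossy reconstruction of src (about 6.6 bits of information per weight).
static void roundtrip_example(const float * src, float * dst, int k) {
    block_q6_K * q = quantize_row_example(src, k);  // from the sketch above
    dequantize_row_q6_K(q, dst, k);
    free(q);
}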
+
+// Dot product
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
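These fused dot products are the hot path of quantized matrix multiplication: the weight row stays in its storage format (e.g. q4_0) while the activation row is quantized on the fly to the matching q8 format. A hypothetical end-to-end sketch, assuming the standard ggml block size QK4_0 = QK8_0 = 32 and n a multiple of it:

#include <stdlib.h>

// Hypothetical sketch: dot product of a q4_0-quantized row with a row
// quantized on the fly to q8_0.
static float dot_example(const float * a, const float * b, int n) {
    const int nb = n / 32;                 // 32 weights per block
    block_q4_0 * qa = malloc(nb * sizeof(block_q4_0));
    block_q8_0 * qb = malloc(nb * sizeof(block_q8_0));
    quantize_row_q4_0(a, qa, n);           // "weights"
    quantize_row_q8_0(b, qb, n);           // "activations"
    float s = 0.0f;
    ggml_vec_dot_q4_0_q8_0(n, &s, qa, qb);
    free(qa);
    free(qb);
    return s;
}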

+ 69 - 1290
ggml/src/ggml.c

File diff suppressed because it is too large


+ 15 - 9
ggml/test_unity_cpp.py

@@ -6,26 +6,27 @@
 
 import ctypes
 import functools
+import shutil
 from ctypes import c_void_p
 from pathlib import Path
-from typing import Any, Iterator, List, Tuple
+from typing import Any, Iterator, Tuple
 
 import fairseq2.nn
 import fairseq2.nn.transformer
 import numpy as np
 import pytest
+import requests  # type: ignore
 import torch
-import torchaudio
+import torchaudio  # type: ignore
+from ctypes_utils import NULLPTR, Ptr
 from fairseq2.data.audio import WaveformToFbankConverter
-from seamless_communication.inference.generator import SequenceGeneratorOptions
 from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
-from seamless_communication.inference.translator import Modality, Translator
+from ggml_convert import convert_model, read_layer_config
 
 import ggml
-from ctypes_utils import NULLPTR, Ptr
 from ggml import NativeObj
-from ggml_convert import convert_model, read_layer_config
-import requests
+from seamless_communication.inference.generator import SequenceGeneratorOptions
+from seamless_communication.inference.translator import Modality, Translator
 
 Ctx = ggml.ggml_context_p
 
@@ -56,6 +57,10 @@ def _ctx() -> Iterator[Ctx]:
                 no_alloc=True,
             )
         )
+
+        # Create 'dot' folder for temporary dump of ggml graphs
+        (Path(__file__).parent / "dot").mkdir(exist_ok=True)
+
         with torch.inference_mode():
             yield ctx
     finally:
@@ -87,6 +92,7 @@ def load_pt_model() -> Any:
 
 
 def download_sample_audio() -> Any:
+    Path(DATA).mkdir(exist_ok=True)
    response = requests.get(TEST_AUDIO_SAMPLE_URL, stream=True)
    with open(DATA / "LJ037-0171_sr16k.wav", "wb") as file:
        for chunk in response.iter_content(chunk_size=1024):
@@ -180,7 +186,7 @@ def test_Linear_forward(ctx: Ctx, g_model: c_void_p) -> None:
     y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()
     gx = ggml.from_numpy(ctx, x)
     gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
-    gf = ggml.build_and_compute(ctx, gy, dump="dot/test_Linear_forward.dot")
+    ggml.build_and_compute(ctx, gy, dump="dot/test_Linear_forward.dot")
 
     y = ggml.to_numpy(gy)
     assert np.allclose(y_exp, y, atol=1e-5)
@@ -613,7 +619,7 @@ def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) ->
                 "text_decoder_frontend.pos_encoder",
                 gseq,
             )
-            gf = ggml.build_and_compute(ctx, gy, dump=t == 1)
+            ggml.build_and_compute(ctx, gy, dump=t == 1)
             y = ggml.to_numpy(gy)
 
             y_exp = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag).numpy()

+ 615 - 145
ggml/third_party_ggml.py

File diff suppressed because it is too large


Some files were not shown because too many files changed in this diff