
load_fairseq2_ggml_file

Guillaume Wenzek 1 year ago
parent
commit
772f90dfdc

+ 52 - 21
ggml/examples/unity/fairseq2.cpp

@@ -1,16 +1,22 @@
 #include "ggml.h"
 #include "fairseq2.h"
 
-fairseq2_model fairseq2_model_init(ggml_context* ctx, void* hparams) {
-    // TODO? allocate the model in the ggml_context
-    fairseq2_model model;
-    model.ctx = ctx;
-    model.hparams = hparams;
-    // TODO:
-    // init_model_tensors(model);
+/// allocate the fairseq2 model and hyperparameters
+extern "C" fairseq2_model* fairseq2_model_alloc() {
+    // pre-allocate some memory to write hyperparameters and tensors pointers
+    auto* model = new fairseq2_model;
+    model->hparams = new std::uint8_t[8 * 1024];
+    model->arch = new std::uint64_t[16 * 1024];  // max tensors allowed
     return model;
 };
 
+extern "C" void fairseq2_model_free(fairseq2_model* model) {
+    delete[] (std::uint64_t*)model->arch;
+    delete[] (std::uint8_t*)model->hparams;
+    delete model;
+};
+
+
 // Linear
 
 std::size_t Linear_size(int32_t input_dim, int32_t output_dim)
@@ -20,18 +26,18 @@ std::size_t Linear_size(int32_t input_dim, int32_t output_dim)
 };
 
 void Linear_init(
-    Linear* self,
+    Linear& self,
     fairseq2_model& model,
     const std::string &prefix,
     int input_dim,
     int output_dim,
     bool bias
 ) {
-    self->weight = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, output_dim, input_dim);
-    model.tensors[prefix + ".weight"] = self->weight;
+    self.weight = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, output_dim, input_dim);
+    model.tensors[prefix + ".weight"] = self.weight;
     if (bias) {
-        self->bias = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, output_dim);
-        model.tensors[prefix + ".inner_proj.bias"] = self->bias;
+        self.bias = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, output_dim);
+        model.tensors[prefix + ".bias"] = self.bias;
     }
 }
 
@@ -43,15 +49,15 @@ std::size_t LayerNorm_size(int32_t dim)
 };
 
 void LayerNorm_init(
-    LayerNorm* self,
+    LayerNorm& self,
     fairseq2_model& model,
     const std::string &prefix,
     int dim
 ) {
-    self->weight = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, dim);
-    model.tensors[prefix + ".weight"] = self->weight;
-    self->bias = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, dim);
-    model.tensors[prefix + ".bias"] = self->bias;
+    self.weight = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, dim);
+    model.tensors[prefix + ".weight"] = self.weight;
+    self.bias = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, dim);
+    model.tensors[prefix + ".bias"] = self.bias;
 }
 
 std::size_t StandardFeedForwardNetwork_size(int32_t dim, int32_t inner_dim)
@@ -60,20 +66,45 @@ std::size_t StandardFeedForwardNetwork_size(int32_t dim, int32_t inner_dim)
 };
 
 void StandardFeedForwardNetwork_init(
-    StandardFeedForwardNetwork* self,
+    StandardFeedForwardNetwork& self,
     fairseq2_model& model,
     const std::string &prefix,
     int model_dim,
     int inner_dim
 ) {
-    Linear_init(&self->inner_proj, model, prefix + ".inner_proj", model_dim, inner_dim, true);
-    LayerNorm_init(&self->inner_layer_norm, model, prefix + ".inner_layer_norm", inner_dim);
-    Linear_init(&self->output_proj, model, prefix + ".output_proj", inner_dim, model_dim, true);
+    Linear_init(self.inner_proj, model, prefix + ".inner_proj", model_dim, inner_dim, true);
+    LayerNorm_init(self.inner_layer_norm, model, prefix + ".inner_layer_norm", inner_dim);
+    Linear_init(self.output_proj, model, prefix + ".output_proj", inner_dim, model_dim, true);
 }
 
 ggml_tensor* StandardFeedForwardNetwork_forward(
     StandardFeedForwardNetwork* self,
     ggml_tensor* seqs
 ) {
+
     return seqs;
 }
+
+void MultiheadAttention_init(
+    MultiheadAttention& self,
+    fairseq2_model& model,
+    const std::string &prefix,
+    int model_dim,
+    int num_heads
+) {
+    bool bias = true;
+    int num_key_value_heads = num_heads;
+    int head_dim = model_dim / num_heads;
+
+    Linear_init(self.q_proj, model, prefix + ".q_proj", model_dim, model_dim, bias);
+    Linear_init(self.k_proj, model, prefix + ".k_proj", model_dim, head_dim * num_key_value_heads, bias);
+    Linear_init(self.v_proj, model, prefix + ".v_proj", model_dim, model_dim, bias);
+
+    // (H, 1, K_h)
+    self.bias_k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, num_heads, 1, head_dim * num_key_value_heads / num_heads);
+    // (H, 1, V_h)
+    self.bias_v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, num_heads, 1, model_dim / num_heads);
+}
+
+
+// void TransformerDecoderLayer_init(TransformerDecoderLayer& self);
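
The allocation model changes here: instead of being carved out of a ggml_context, the fairseq2_model and its hparams/arch scratch buffers now live on the C++ heap behind an extern "C" alloc/free pair, so language bindings can own the object. A minimal lifecycle sketch, assuming the load_unity_ggml_file entry point added later in this commit and a placeholder model path:

#include "fairseq2.h"

extern "C" void load_unity_ggml_file(fairseq2_model& model, const char* fname);

int main() {
    // heap-allocates the struct plus the hparams/arch scratch buffers
    fairseq2_model* model = fairseq2_model_alloc();
    // fills hparams, creates model->ctx, allocates and reads the tensors
    load_unity_ggml_file(*model, "ggml-model.bin");
    // ... run inference ...
    ggml_free(model->ctx);       // the ctx is created by the loader, not by alloc
    fairseq2_model_free(model);  // releases the scratch buffers and the struct
    return 0;
}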

+ 54 - 4
ggml/examples/unity/fairseq2.h

@@ -1,15 +1,21 @@
+#pragma once
+
 #include <map>
 #include <string>
+#include <vector>
 #include "ggml.h"
 
 
 struct fairseq2_model {
     ggml_context* ctx;
     std::map<std::string, struct ggml_tensor *> tensors;
+    void* arch;
     void* hparams;
 };
 
-fairseq2_model fairseq2_model_alloc(ggml_context* ctx, void* hparams);
+/// allocate the fairseq2 model and hyperparameters
+extern "C" fairseq2_model* fairseq2_model_alloc();
+extern "C" void fairseq2_model_free(fairseq2_model* model);
 
 struct Linear {
     struct ggml_tensor* weight;  // out_dim * in_dim
@@ -17,7 +23,7 @@ struct Linear {
 };
 
 std::size_t Linear_size(int32_t input_dim, int32_t output_dim);
-void Linear_init(Linear* self,fairseq2_model& model, const std::string &prefix, int input_dim, int output_dim, bool bias);
+void Linear_init(Linear& self,fairseq2_model& model, const std::string &prefix, int input_dim, int output_dim, bool bias);
 
 // LayerNorm
 
@@ -28,7 +34,23 @@ struct LayerNorm {
 
 std::size_t LayerNorm_size(int32_t dim);
 
-void LayerNorm_init(LayerNorm* self, fairseq2_model& model, const std::string &prefix, int dim);
+void LayerNorm_init(LayerNorm& self, fairseq2_model& model, const std::string &prefix, int dim);
+
+// ConformerConvolution
+// struct ConformerConvolution {
+//     // pointwise_conv1: Conv1d
+//     // pointwise_conv1_activation: GLU
+//     // depthwise_conv: Conv1d
+//     // batch_norm: BatchNorm1d
+//     // depthwise_activation: Module
+//     // pointwise_conv2: Conv1d
+// };
+
+// std::size_t ConformerConvolution_size(int32_t dim);
+
+// void ConformerConvolution_init(ConformerConvolution* self, fairseq2_model& model, const std::string &prefix, int dim);
+
+
 
 struct MultiheadAttention {
     // num_key_value_heads: int
@@ -43,6 +65,8 @@ struct MultiheadAttention {
     struct Linear output_proj;
 };
 
+void MultiheadAttention_init(MultiheadAttention& self, fairseq2_model& model, const std::string &prefix, int model_dim, int num_heads);
+
 struct StandardFeedForwardNetwork {
     struct Linear inner_proj; // ffn_inner_dim x model_dim
     // inner_activation -> Relu for unity
@@ -54,7 +78,7 @@ struct StandardFeedForwardNetwork {
 std::size_t StandardFeedForwardNetwork_size(int32_t dim, int32_t inner_dim);
 
 void StandardFeedForwardNetwork_init(
-    StandardFeedForwardNetwork* self,
+    StandardFeedForwardNetwork& self,
     fairseq2_model& model,
     const std::string &prefix,
     int model_dim,
@@ -66,6 +90,15 @@ ggml_tensor* StandardFeedForwardNetwork_forward(
     ggml_tensor* seqs
 );
 
+// Transformer
+
+enum TransformerNormOrder {
+    TRANSFORMER_NORM_ORDER_POST = 0,
+    TRANSFORMER_NORM_ORDER_PRE = 1,
+    TRANSFORMER_NORM_ORDER_PRE_WITH_NORMFORMER = 2
+};
+
+
 struct TransformerDecoderLayer {
     struct MultiheadAttention self_attn;
     struct LayerNorm self_attn_norm;
@@ -80,3 +113,20 @@ struct TransformerDecoderLayer {
     struct LayerNorm ffn_layer_norm;
     // norm_order: TransformerNormOrder
 };
+
+void TransformerDecoderLayer_init();
+
+
+struct TransformerDecoder {
+    std::vector<TransformerDecoderLayer> layers;
+    struct LayerNorm layer_norm;
+};
+
+// std::size_t TransformerDecoder_size(int32_t input_dim, int32_t output_dim);
+// void TransformerDecoder_init(TransformerEncoder* self, fairseq2_model& model, const std::string &prefix, TransformerNormOrder norm_order);
+
+
+// std::size_t TransformerEncoder_size(int32_t input_dim, int32_t output_dim);
+// void TransformerEncoder_init(TransformerEncoder* self, fairseq2_model& model, const std::string &prefix, TransformerNormOrder norm_order);
+
+//
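
StandardFeedForwardNetwork_forward is declared here but left as a stub in fairseq2.cpp above. For orientation, a hedged sketch of what it could compute with ggml ops, following the sublayer layout declared in the struct (inner_proj -> ReLU -> inner_layer_norm -> output_proj). The stub's signature carries no ggml_context, so this sketch takes the model explicitly; the ggml_repeat bias broadcasting and the parameterless ggml_norm call (older ggml API; newer versions take an eps argument) are assumptions, not part of this commit:

ggml_tensor* StandardFeedForwardNetwork_forward_sketch(
    fairseq2_model& model,
    StandardFeedForwardNetwork* self,
    ggml_tensor* seqs
) {
    ggml_context* ctx = model.ctx;
    // inner projection: W1 x + b1
    ggml_tensor* h = ggml_mul_mat(ctx, self->inner_proj.weight, seqs);
    h = ggml_add(ctx, h, ggml_repeat(ctx, self->inner_proj.bias, h));
    // unity uses ReLU as the inner activation (see the comment in the struct)
    h = ggml_relu(ctx, h);
    // inner layer norm: normalize, then apply the affine weight and bias
    h = ggml_norm(ctx, h);
    h = ggml_add(ctx,
        ggml_mul(ctx, ggml_repeat(ctx, self->inner_layer_norm.weight, h), h),
        ggml_repeat(ctx, self->inner_layer_norm.bias, h));
    // output projection: W2 h + b2
    h = ggml_mul_mat(ctx, self->output_proj.weight, h);
    h = ggml_add(ctx, h, ggml_repeat(ctx, self->output_proj.bias, h));
    return h;
}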

+ 1 - 1
ggml/examples/unity/model_loader.cpp

@@ -93,7 +93,7 @@ model_loader::load_tensor_value(std::ifstream &fin, ggml_tensor *tensor)
 std::string
 model_loader::get_name(std::ifstream& fin)
 {
-    int32_t length;
+    std::uint32_t length;
     fin.read(reinterpret_cast<char *>(&length), sizeof(length));
     std::string name(length, 0);
     fin.read(&name[0], length);
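
Reading the name length as an unsigned 32-bit value guards against a negative length on corrupt input. For reference, a sketch of the matching writer side, assuming the on-disk layout is simply a u32 byte count followed by the raw bytes with no terminator:

#include <cstdint>
#include <fstream>
#include <string>

void write_name(std::ofstream& fout, const std::string& name) {
    std::uint32_t length = (std::uint32_t)name.size();
    // u32 byte count first, then the raw bytes (assumed layout)
    fout.write(reinterpret_cast<const char*>(&length), sizeof(length));
    fout.write(name.data(), length);
}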

+ 17 - 22
ggml/examples/unity/model_loader.h

@@ -21,17 +21,13 @@ class model_loader {
 public:
     virtual ~model_loader() {};
 
-    virtual fairseq2_model& alloc_model(ggml_context* ctx) = 0;
-
     virtual void load_hparams(fairseq2_model& model, std::ifstream &fin) = 0;
 
-    virtual void load_model_weights(fairseq2_model &model, std::ifstream &fin);
+    virtual std::size_t compute_context_size(void *raw_hparams) = 0;
 
-    virtual std::size_t
-    compute_context_size(void *raw_hparams) = 0;
+    virtual void tensors_alloc(fairseq2_model& model) = 0;
 
-    virtual void
-    init_model_tensors(fairseq2_model &model) = 0;
+    void load_model_weights(fairseq2_model &model, std::ifstream &fin);
 
 private:
     ggml_tensor * next_tensor(std::ifstream &fin, fairseq2_model &model);
@@ -41,25 +37,24 @@ private:
     std::string get_name(std::ifstream &fin);
 };
 
-/// allocate the fairseq2 model and hyperparameters into the ggml context
-template<typename T>
-fairseq2_model& alloc_fairseq2_model(ggml_context* ctx) {
-    auto hparams = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(T))->data;
-    auto& model = (fairseq2_model&)ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(fairseq2_model))->data;
-
-    model.ctx = ctx;
-    model.hparams = hparams;
-    return model;
-};
-
 std::ifstream open_ggml_file(const char* fname);
 
 template<typename T>
-fairseq2_model& load_fairseq2_ggml_file(ggml_context* ctx, const char* fname) {
+void load_fairseq2_ggml_file(fairseq2_model& model, const char* fname) {
     T loader;
-    fairseq2_model& model = loader.alloc_model(ctx);
     auto fin = open_ggml_file(fname);
     loader.load_hparams(model, fin);
-    loader.load_model_weights(model, fin);
-    return model;
+
+    std::size_t ctx_size = loader.compute_context_size(model.hparams);
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,
+    };
+    model.ctx = ggml_init(params);
+
+    // TODO: should we delay weights loading/allocating ?
+    loader.tensors_alloc(model);
+    loader.load_model_weights(model, fin);
 }
+
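
The loading pipeline is now fixed by this template: load_hparams -> compute_context_size -> ggml_init -> tensors_alloc -> load_model_weights, with the three hooks left to subclasses. A minimal sketch of a conforming loader; the names toy_hparams and "proj.weight" are illustrative, not part of this codebase:

#include <cstdint>
#include <fstream>

struct toy_hparams { std::int32_t dim; };

class toy_model_loader : public model_loader {
public:
    void load_hparams(fairseq2_model& model, std::ifstream& fin) override {
        // the hparams buffer is pre-allocated by fairseq2_model_alloc
        auto* hparams = (toy_hparams*)model.hparams;
        fin.read((char*)&hparams->dim, sizeof(hparams->dim));
    }

    std::size_t compute_context_size(void* raw_hparams) override {
        auto* hparams = (toy_hparams*)raw_hparams;
        // rough budget: one dim x dim f32 weight plus ggml's per-tensor overhead
        return (std::size_t)hparams->dim * hparams->dim * sizeof(float)
            + 2 * ggml_tensor_overhead();
    }

    void tensors_alloc(fairseq2_model& model) override {
        auto* hparams = (toy_hparams*)model.hparams;
        model.tensors["proj.weight"] =
            ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, hparams->dim, hparams->dim);
    }
};

// usage:
//   fairseq2_model* m = fairseq2_model_alloc();
//   load_fairseq2_ggml_file<toy_model_loader>(*m, "toy.bin");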

+ 3 - 3
ggml/examples/unity/unity.cpp

@@ -180,8 +180,6 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
         return false;
     }
 
-    auto & ctx = model.ctx;
-
     size_t ctx_size = 0;
 
     {
@@ -245,6 +243,7 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
             return false;
         }
     }
+    auto & ctx = model.ctx;
 
     // prepare memory for the weights
     {
@@ -494,7 +493,8 @@ extern "C" ggml_cgraph* unity_audio_encoder_graph(
     // const int n_text_vocab = hparams.n_text_vocab;
     const int kernel_size = 31;
 
-    // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
+    // since we are using ggml-alloc, this buffer only needs enough space to hold
+    // the ggml_tensor and ggml_cgraph structs, but not the tensor data
     static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead();
     static std::vector<uint8_t> buf(buf_size);
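
The reflowed comment describes the ggml-alloc pattern mirrored by the Python test at the end of this commit: build the graph once inside a measure allocator to learn the compute-buffer size, then rebuild with a fixed-size buffer. A sketch, where the alignment constant and allocator API names follow the ggml-alloc of this era and should be treated as assumptions:

#include "ggml-alloc.h"

static size_t measure_compute_buffer(struct ggml_cgraph* gf) {
    const size_t align = 16;  // GGML_MEM_ALIGN on typical builds (assumption)
    struct ggml_allocr* alloc = ggml_allocr_new_measure(align);
    // walks the graph and returns the peak memory its tensors would need
    size_t mem_size = ggml_allocr_alloc_graph(alloc, gf) + align;
    ggml_allocr_free(alloc);
    return mem_size;
}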
 

+ 45 - 66
ggml/examples/unity/unity_model_loader.cpp

@@ -12,68 +12,9 @@
 
 #include "unity_model_loader.h"
 
-struct audio_enc_layer {
-    struct LayerNorm self_attn_layer_norm;
-
-    struct ggml_tensor * self_attn_linear_k_w;
-    struct ggml_tensor * self_attn_linear_k_b;
-    struct ggml_tensor * self_attn_linear_q_w;
-    struct ggml_tensor * self_attn_linear_q_b;
-    struct ggml_tensor * self_attn_linear_v_w;
-    struct ggml_tensor * self_attn_linear_v_b;
-    struct ggml_tensor * self_attn_linear_out_w;
-    struct ggml_tensor * self_attn_linear_out_b;
-    struct ggml_tensor * self_attn_linear_pos_w;
-
-    struct ggml_tensor * self_attn_pos_bias_u;
-    struct ggml_tensor * self_attn_pos_bias_v;
-
-    struct LayerNorm conv_layer_norm;
-
-    struct ggml_tensor * conv_pointwise_conv1_w;
-    struct ggml_tensor * conv_depthwise_conv_w;
-    struct ggml_tensor * conv_batch_norm_w;
-    struct ggml_tensor * conv_batch_norm_b;
-    struct ggml_tensor * conv_batch_norm_running_mean;
-    struct ggml_tensor * conv_batch_norm_running_var;
-    struct ggml_tensor * conv_batch_norm_num_batches_tracked;
-    struct ggml_tensor * conv_pointwise_conv2_w;
-
-    struct LayerNorm ffn1_layer_norm;
-    struct ggml_tensor * ffn1_w1;
-    struct ggml_tensor * ffn1_b1;
-    struct ggml_tensor * ffn1_w2;
-    struct ggml_tensor * ffn1_b2;
-
-    struct LayerNorm ffn2_layer_norm;
-    struct ggml_tensor * ffn2_w1;
-    struct ggml_tensor * ffn2_b1;
-    struct ggml_tensor * ffn2_w2;
-    struct ggml_tensor * ffn2_b2;
-
-    struct LayerNorm final_layer_norm;
-};
-
-
-struct unity_model {
-    unity_hparams* hparams;
-    // audio encoder
-    struct ggml_tensor * post_extract_proj_w;
-    struct ggml_tensor * post_extract_proj_b;
-    struct ggml_tensor * audio_enc_pos_conv_wg;
-    struct ggml_tensor * audio_enc_pos_conv_wv;
-    struct ggml_tensor * audio_enc_pos_conv_b;
-    struct LayerNorm audio_enc_layer_norm;
-    struct ggml_tensor * audio_enc_pos_enc_w;
-    struct LayerNorm layer_norm;
-    struct ggml_tensor * memory_k;
-    struct ggml_tensor * memory_v;
-    std::vector<audio_enc_layer> audio_enc_layers;
-};
-
 void unity_model_loader::load_hparams(fairseq2_model& model, std::ifstream &fin)
 {
-    auto hparams = (unity_hparams&)model.hparams;
+    auto& hparams = *(unity_hparams*)model.hparams;
 
     fin.read((char*) &hparams.model_dim, sizeof(hparams.model_dim));
     fin.read((char*) &hparams.w2v2_encoder_config__model_dim, sizeof(hparams.w2v2_encoder_config__model_dim));
@@ -134,14 +75,52 @@ unity_model_loader::compute_context_size(void* raw_hparams)
 {
     // TODO
     auto& hparams = *(unity_hparams*)raw_hparams;
+    return hparams.model_dim * 1024 * 100;
 };
 
-void
-unity_model_loader::init_model_tensors(fairseq2_model &model)
+struct UnityArch {
+    struct TransformerDecoder text_decoder;
+};
+
+void unity_model_loader::tensors_alloc(fairseq2_model &model)
 {
-    // TODO
+    auto& hparams = *(unity_hparams*)model.hparams;
+    auto& arch = *(UnityArch*)model.arch;  // assumes a UnityArch lives in the arch buffer
+    const auto ctx = model.ctx;
+    auto& tensors = model.tensors;
+
+    const auto vocab_size = hparams.nllb_config__vocabulary_size;
+    const auto model_dim = hparams.nllb_config__model_dim;
+
+    // This can be simplified by adding syntax sugar
+
+    // frontend
+    // arch.frontend_embed_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, vocab_size, dim);
+    // tensor_map["text_decoder_frontend.embed.weight"] = arch.frontend_embed_w;
+
+    // layers
+    {
+        const auto n_layers = hparams.nllb_config__num_decoder_layers;
+        arch.text_decoder.layers = std::vector<TransformerDecoderLayer>(n_layers);
+        auto& layers = arch.text_decoder.layers;
+        auto num_heads = hparams.nllb_config__num_decoder_attn_heads;
+        for (int i = 0; i < n_layers; ++i) {
+            auto prefix = "text_decoder.layers." + std::to_string(i);
+            MultiheadAttention_init(layers[i].self_attn, model, prefix + ".self_attn", model_dim, num_heads);
+            LayerNorm_init(layers[i].self_attn_norm, model, prefix + ".self_attn_norm", model_dim);
+        }
+    }
+
+    // // layer_norm
+    // arch.layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    // tensor_map["text_decoder.layer_norm.weight"] = arch.layer_norm_w;
+    // arch.layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    // tensor_map["text_decoder.layer_norm.bias"] = arch.layer_norm_b;
 };
 
-// extern "C" fairseq2_model<unity_hparams>* unity_model_load2(const char* fname, ggml_context* ctx) {
-//     return nullptr;
-// }
+extern "C" void load_unity_ggml_file(fairseq2_model& model, const char* fname) {
+    return load_fairseq2_ggml_file<unity_model_loader>(model, fname);
+}
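
compute_context_size above still returns a placeholder (model_dim * 1024 * 100). A hedged sketch of a more principled estimate for the decoder tensors actually created in tensors_alloc, counting only the sublayers shown there (q/k/v projections with biases plus one layer norm per layer); the per-layer tensor count is an assumption:

std::size_t decoder_context_size(const unity_hparams& hparams) {
    const std::size_t dim = hparams.nllb_config__model_dim;
    const std::size_t n_layers = hparams.nllb_config__num_decoder_layers;
    // per layer: q/k/v projections (weight + bias) and one layer norm (weight + bias)
    const std::size_t per_layer = 3 * (dim * dim + dim) + 2 * dim;
    // roughly 10 tensors per layer (assumption), each paying bookkeeping overhead
    const std::size_t overhead = n_layers * 10 * ggml_tensor_overhead();
    return n_layers * per_layer * sizeof(float) + overhead;
}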

+ 2 - 45
ggml/examples/unity/unity_model_loader.h

@@ -155,7 +155,7 @@ void init_attention_head(
     auto hparams = (unity_hparams&)model_ctx.hparams;
     init_attention_layer(head->self_attn, model_ctx, prefix + ".self_attn");
     init_attention_layer(head->encoder_decoder_attn, model_ctx, prefix + ".encoder_decoder_attn");
-    StandardFeedForwardNetwork_init(head->ffn, model_ctx, prefix + ".ffn", hparams.nllb_config__model_dim, hparams.nllb_config__ffn_inner_dim);
+    StandardFeedForwardNetwork_init((StandardFeedForwardNetwork&)(head->ffn), model_ctx, prefix + ".ffn", hparams.nllb_config__model_dim, hparams.nllb_config__ffn_inner_dim);
 }
 
 // TODO: attention_head_compute_graph
@@ -185,54 +185,11 @@ std::size_t compute_context_size(void* raw_hparams)
         + overhead;
 };
 
-void init_model_tensors(
-    text_decoder &model,
-    fairseq2_model &model_ctx,
-    const std::string &prefix)
-{
-    const auto ctx = model_ctx.ctx;
-    auto hparams = (unity_hparams&)model_ctx.hparams;
-    auto tensor_map = model_ctx.tensors;
-
-    const auto vocab_size = hparams.nllb_config__vocabulary_size;
-    const auto dim = hparams.nllb_config__model_dim;
-    const auto n_layers = hparams.nllb_config__num_decoder_layers;
-
-    // This can be simplified by adding syntax sugar
-
-    // frontend
-    model.frontend_embed_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, vocab_size, dim);
-    tensor_map["text_decoder_frontend.embed.weight"] = model.frontend_embed_w;
-
-    // layers
-    model.multi_head.resize(n_layers);
-    for (int i = 0; i < n_layers; ++i) {
-        auto head = model.multi_head[i];
-        auto prefix = "text_decoder.layers." + std::to_string(i);
-        init_attention_head(head, model_ctx, prefix);
-    }
-
-    // layer_norm
-    model.layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
-    tensor_map["text_decoder.layer_norm.weight"] = model.layer_norm_w;
-    model.layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
-    tensor_map["text_decoder.layer_norm.bias"] = model.layer_norm_b;
-};
-
-
 class unity_model_loader: public model_loader {
     public:
-    fairseq2_model& alloc_model(ggml_context* ctx) {
-        return alloc_fairseq2_model<unity_hparams>(ctx);
-    };
-
     void load_hparams(fairseq2_model& model, std::ifstream &fin);
 
     std::size_t compute_context_size(void* raw_hparams);
 
-    void init_model_tensors(fairseq2_model &model);
+    void tensors_alloc(fairseq2_model &model);
 };
-
-extern "C" fairseq2_model& load_unity_ggml_file(ggml_context* ctx, const char* fname) {
-    return load_fairseq2_ggml_file<unity_model_loader>(ctx, fname);
-}

+ 11 - 6
ggml/ggml.py

@@ -175,6 +175,10 @@ def GptVocab() -> NativeObj:
     return NativeObj("gpt_vocab")
 
 
+def Fairseq2Model() -> NativeObj:
+    return NativeObj("fairseq2_model")
+
+
 lib.unity_model_load.argtypes = [ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p]
 
 
@@ -189,15 +193,16 @@ def unity_model_load(model_file: Path) -> Tuple[NativeObj, NativeObj]:
     return model, vocab
 
 
-lib.load_unity_ggml_file.argtypes = [ctypes.c_char_p, ggml_context_p]
-lib.load_unity_ggml_file.restype = ctypes.c_void_p
+lib.load_unity_ggml_file.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+lib.load_unity_ggml_file.restype = None
 
 
-def load_unity_ggml_file(model_file: Path) -> ctypes.c_void_p:
-    model = UnityModel()
-    return lib.load_unity_ggml_file(
-        ctypes.create_string_buffer(str(model_file).encode("utf-8")), ctx
+def load_unity_ggml_file(model_file: Path) -> NativeObj:
+    model = Fairseq2Model()
+    lib.load_unity_ggml_file(
+        model.ptr, ctypes.create_string_buffer(str(model_file).encode("utf-8"))
     )
+    return model
 
 
 lib.unity_audio_encoder_graph.argtypes = [ctypes.c_void_p, ctypes.c_void_p]

+ 10 - 12
ggml/test_unity_cpp.py

@@ -123,12 +123,12 @@ def test_from_numpy_works_with_f16(ctx: Ctx) -> None:
 
 
 def test_unity_model_load(ctx: Ctx) -> None:
-    model, vocab = ggml.unity_model_load(
-        UNITY_MODELS / "unity-large/ggml-model.bin"
-    )
+    model, vocab = ggml.unity_model_load(UNITY_MODELS / "unity-large/ggml-model.bin")
     print(model, vocab)
 
-    example = ggml.from_file(ctx, UNITY_MODELS / "unity-large/seqs_before_conformer_block.bin", (1024, 137))
+    example = ggml.from_file(
+        ctx, UNITY_MODELS / "unity-large/seqs_before_conformer_block.bin", (1024, 137)
+    )
 
     with ggml.MeasureArena() as arena:
         graph = ggml.unity_audio_encoder_graph(model, example)
@@ -136,7 +136,9 @@ def test_unity_model_load(ctx: Ctx) -> None:
         mem_size = ggml.ggml_allocr_alloc_graph(arena.ptr, graph) + ggml.GGML_MEM_ALIGN
 
     with ggml.FixedSizeArena(mem_size) as allocr:
-        print(f"unity_audio_encoder_graph: compute buffer size: {mem_size/1024/1024} MB")
+        print(
+            f"unity_audio_encoder_graph: compute buffer size: {mem_size/1024/1024} MB"
+        )
 
         eval_res_ptr = ggml.unity_eval(allocr, model, example, 1)
         eval_res = eval_res_ptr.contents
@@ -146,10 +148,6 @@ def test_unity_model_load(ctx: Ctx) -> None:
         assert np.allclose(inpL[0, :10], list(expected), atol=1e-4)
 
 
-# def test_unity_model_load2(ctx: Ctx) -> None:
-#     model = ggml.unity_model_load(
-#         UNITY_MODELS / "unity-large/ggml-model.bin"
-#     )
-#     print(model, vocab)
-#
-#
+def test_unity_model_load2(ctx: Ctx) -> None:
+    model = ggml.load_unity_ggml_file(UNITY_MODELS / "unity-large/ggml-model.bin")
+    print(model)