Browse Source

move layers to fairseq2.cpp

Guillaume Wenzek 1 year ago
parent
commit
506dee42d8

+ 2 - 2
ggml/Makefile

@@ -1,12 +1,12 @@
 build: build/src/libggml.so build/bin/unity
 
-build/src/libggml.so: examples/unity/*.cpp
+build/src/libggml.so: examples/unity/*.h examples/unity/*.cpp
 	mkdir -p build
 	cd build; cmake -DBUILD_SHARED_LIBS=On -DCMAKE_BUILD_TYPE=Debug ..
 	cd build; make -j4 ggml
 	find build/ -iname '*.so'
 
-build/bin/unity: examples/unity/*.cpp
+build/bin/unity: examples/unity/*.h examples/unity/*.cpp
 	mkdir -p build
 	cd build; cmake ..
 	cd build; make -j4 unity

+ 1 - 0
ggml/examples/unity/CMakeLists.txt

@@ -5,6 +5,7 @@ target_include_directories(unity PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
 target_link_libraries(unity PRIVATE ggml common common-ggml)
 target_sources(unity
     PRIVATE
+        fairseq2.cpp
         model_loader.cpp
         unity_model_loader.cpp
 )

+ 79 - 0
ggml/examples/unity/fairseq2.cpp

@@ -0,0 +1,79 @@
+#include "ggml.h"
+#include "fairseq2.h"
+
+/// Build a fairseq2_model value bound to `ctx` and carrying `hparams`.
+///
+/// The tensor registry of the returned model is empty; nothing is allocated
+/// in `ctx` by this function.
+/// @param ctx      ggml context stored in the returned model.
+/// @param hparams  opaque hyper-parameter pointer, stored as-is.
+/// @return the freshly built model, by value.
+fairseq2_model fairseq2_model_init(ggml_context* ctx, void* hparams)
+{
+    // TODO? allocate the model in the ggml_context
+    fairseq2_model result;
+    result.hparams = hparams;
+    result.ctx = ctx;
+
+    // TODO: init_model_tensors(result);
+    return result;
+}
+
+// Linear
+
+/// Bytes of F32 tensor data needed by a Linear layer.
+///
+/// Counts an `input_dim x output_dim` weight matrix plus an `output_dim`
+/// bias vector. The bias is always counted, even though Linear_init only
+/// allocates it when asked to.
+/// @param input_dim   number of input features.
+/// @param output_dim  number of output features.
+/// @return size in bytes of the weight and bias data.
+std::size_t Linear_size(int32_t input_dim, int32_t output_dim)
+{
+    // Widen before multiplying: `input_dim * output_dim` evaluated in 32-bit
+    // int overflows for large layers (signed overflow is undefined behavior).
+    const std::size_t in_features = static_cast<std::size_t>(input_dim);
+    const std::size_t out_features = static_cast<std::size_t>(output_dim);
+    const std::size_t element_size = ggml_type_size(GGML_TYPE_F32);
+
+    return (in_features * out_features * element_size) // weight
+        + (out_features * element_size); // bias
+}
+
+/// Allocate the tensors of a Linear layer and register them by name.
+///
+/// The weight is created in `model.ctx` and stored in `model.tensors` under
+/// "<prefix>.weight". When `bias` is true a bias vector is created as well and
+/// stored under "<prefix>.bias"; otherwise `self->bias` is set to nullptr.
+/// @param self        layer whose tensor pointers are filled in.
+/// @param model       provides the ggml context and the name -> tensor registry.
+/// @param prefix      fully qualified name of this layer.
+/// @param input_dim   number of input features.
+/// @param output_dim  number of output features.
+/// @param bias        whether to allocate a bias vector.
+void Linear_init(
+    Linear* self,
+    fairseq2_model& model,
+    const std::string &prefix,
+    int input_dim,
+    int output_dim,
+    bool bias
+) {
+    // NOTE(review): ggml_new_tensor_2d takes (ne0, ne1); other code in this
+    // example creates projection weights as (in, out) before ggml_mul_mat.
+    // Confirm that (output_dim, input_dim) is the intended order here.
+    self->weight = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, output_dim, input_dim);
+    model.tensors[prefix + ".weight"] = self->weight;
+
+    // Give the pointer a defined value even when no bias is requested, so
+    // callers can test it instead of reading an indeterminate pointer.
+    self->bias = nullptr;
+    if (bias) {
+        self->bias = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, output_dim);
+        // `prefix` already names the layer, so the key is just "<prefix>.bias"
+        // (previously a hard-coded ".inner_proj.bias" suffix was appended).
+        model.tensors[prefix + ".bias"] = self->bias;
+    }
+}
+
+// LayerNorm
+
+/// Bytes of F32 data for a LayerNorm over `dim` features: one weight vector
+/// and one bias vector of `dim` elements each.
+std::size_t LayerNorm_size(int32_t dim)
+{
+    const auto bytes_per_element = ggml_type_size(GGML_TYPE_F32);
+    return 2 * dim * bytes_per_element; // weight and bias
+}
+
+/// Allocate the weight and bias vectors of a LayerNorm in `model.ctx` and
+/// register them in `model.tensors` as "<prefix>.weight" and "<prefix>.bias".
+/// @param self    layer whose tensor pointers are filled in.
+/// @param model   provides the ggml context and the name -> tensor registry.
+/// @param prefix  fully qualified name of this layer.
+/// @param dim     number of elements in each vector.
+void LayerNorm_init(
+    LayerNorm* self,
+    fairseq2_model& model,
+    const std::string &prefix,
+    int dim
+) {
+    // Weight is allocated before bias, same order as the tensors are named.
+    ggml_tensor* const scale = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, dim);
+    self->weight = scale;
+    model.tensors[prefix + ".weight"] = scale;
+
+    ggml_tensor* const shift = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, dim);
+    self->bias = shift;
+    model.tensors[prefix + ".bias"] = shift;
+}
+
+/// Bytes of tensor data needed by a StandardFeedForwardNetwork.
+///
+/// Mirrors StandardFeedForwardNetwork_init: an inner projection
+/// (dim -> inner_dim), a layer norm over the inner activations (inner_dim)
+/// and an output projection (inner_dim -> dim).
+/// @param dim        model dimension.
+/// @param inner_dim  hidden dimension of the feed-forward network.
+/// @return size in bytes of all weight and bias data.
+std::size_t StandardFeedForwardNetwork_size(int32_t dim, int32_t inner_dim)
+{
+    // The layer norm is allocated with `inner_dim` elements by
+    // StandardFeedForwardNetwork_init, so it must be sized with `inner_dim`
+    // (sizing it with `dim` under-estimates the required memory).
+    return LayerNorm_size(inner_dim) + Linear_size(dim, inner_dim) + Linear_size(inner_dim, dim);
+}
+
+/// Allocate and register every tensor of a StandardFeedForwardNetwork.
+///
+/// Sub-layers are initialised in this order, each under its own name:
+/// "<prefix>.inner_proj", "<prefix>.inner_layer_norm", "<prefix>.output_proj".
+/// Both projections are created with a bias.
+/// @param self       network whose sub-layers are filled in.
+/// @param model      provides the ggml context and the name -> tensor registry.
+/// @param prefix     fully qualified name of this network.
+/// @param model_dim  model dimension (input and output size).
+/// @param inner_dim  hidden dimension.
+void StandardFeedForwardNetwork_init(
+    StandardFeedForwardNetwork* self,
+    fairseq2_model& model,
+    const std::string &prefix,
+    int model_dim,
+    int inner_dim
+) {
+    const bool with_bias = true;
+
+    // model_dim -> inner_dim
+    Linear_init(&self->inner_proj, model, prefix + ".inner_proj", model_dim, inner_dim, with_bias);
+
+    // normalisation over the inner activations
+    LayerNorm_init(&self->inner_layer_norm, model, prefix + ".inner_layer_norm", inner_dim);
+
+    // inner_dim -> model_dim
+    Linear_init(&self->output_proj, model, prefix + ".output_proj", inner_dim, model_dim, with_bias);
+}
+
+/// Forward pass of the feed-forward network.
+///
+/// Placeholder: the layer parameters are not used yet and the input tensor is
+/// returned unchanged.
+/// @param self  the network (currently unused).
+/// @param seqs  input tensor.
+/// @return `seqs`, untouched.
+ggml_tensor* StandardFeedForwardNetwork_forward(StandardFeedForwardNetwork* self, ggml_tensor* seqs)
+{
+    (void)self; // not used until the real computation is implemented
+    ggml_tensor* const output = seqs;
+    return output;
+}

+ 82 - 0
ggml/examples/unity/fairseq2.h

@@ -0,0 +1,82 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <map>
+#include <new>
+#include <string>
+
+#include "ggml.h"
+
+
+/// Ties a ggml context to the named tensors and hyper-parameters of a model.
+struct fairseq2_model {
+    // ggml context stored with the model; the layer *_init helpers allocate their tensors in it.
+    ggml_context* ctx;
+    // Registry mapping a tensor name (e.g. "<prefix>.weight") to its ggml tensor.
+    std::map<std::string, struct ggml_tensor *> tensors;
+    // Type-erased pointer to the model-specific hyper-parameters.
+    void* hparams;
+};
+
+/// Build a fairseq2_model holding `ctx` and `hparams`, with an empty tensor registry.
+/// This is the function defined in fairseq2.cpp.
+fairseq2_model fairseq2_model_init(ggml_context* ctx, void* hparams);
+
+// NOTE(review): no definition of fairseq2_model_alloc is visible in fairseq2.cpp
+// (only fairseq2_model_init). Declaration kept for backward compatibility;
+// calling it fails at link time unless a definition exists elsewhere.
+fairseq2_model fairseq2_model_alloc(ggml_context* ctx, void* hparams);
+
+/// Tensors of a linear projection layer.
+struct Linear {
+    struct ggml_tensor* weight;  // out_dim * in_dim
+    struct ggml_tensor* bias;  // out_dim
+};
+
+/// Size in bytes of the F32 weight and bias data of a Linear layer with the given dimensions.
+std::size_t Linear_size(int32_t input_dim, int32_t output_dim);
+/// Allocate the layer's tensors in `model.ctx` and register them in `model.tensors`
+/// under names derived from `prefix`; a bias tensor is only created when `bias` is true.
+void Linear_init(Linear* self,fairseq2_model& model, const std::string &prefix, int input_dim, int output_dim, bool bias);
+
+// LayerNorm
+
+/// Weight and bias vectors of a layer normalization.
+struct LayerNorm {
+    struct ggml_tensor* weight;  // model_dim
+    struct ggml_tensor* bias;  // model_dim
+};
+
+/// Size in bytes of the F32 weight and bias vectors of a LayerNorm over `dim` elements.
+std::size_t LayerNorm_size(int32_t dim);
+
+/// Allocate both vectors in `model.ctx` and register them in `model.tensors`
+/// as "<prefix>.weight" and "<prefix>.bias".
+void LayerNorm_init(LayerNorm* self, fairseq2_model& model, const std::string &prefix, int dim);
+
+/// Parameters of a multi-head attention layer: query/key/value projections,
+/// two extra bias tensors and the output projection.
+/// The commented-out entries name attributes that have no field here yet
+/// (presumably mirroring the fairseq2 Python module — TODO confirm).
+struct MultiheadAttention {
+    // num_key_value_heads: int
+    struct Linear q_proj;
+    struct Linear k_proj;
+    struct Linear v_proj;
+    // pos_encoder: Optional[PositionEncoder]
+    struct ggml_tensor* bias_k;
+    struct ggml_tensor* bias_v;
+    // add_zero_attn: bool
+    // head_scale_weight: Optional[Parameter]
+    struct Linear output_proj;
+};
+
+/// Parameters of a feed-forward network: inner projection, layer norm over
+/// the inner activations, and output projection.
+struct StandardFeedForwardNetwork {
+    struct Linear inner_proj; // ffn_inner_dim x model_dim
+    // inner_activation -> Relu for unity
+    // struct Dropout inner_dropout;
+    struct LayerNorm inner_layer_norm; // ffn_inner_dim
+    struct Linear output_proj; // model_dim x ffn_inner_dim
+};
+
+/// Size in bytes of the tensor data of a feed-forward network with model
+/// dimension `dim` and hidden dimension `inner_dim`.
+std::size_t StandardFeedForwardNetwork_size(int32_t dim, int32_t inner_dim);
+
+/// Allocate the tensors of all three sub-layers in `model.ctx` and register
+/// them in `model.tensors` under "<prefix>.inner_proj", "<prefix>.inner_layer_norm"
+/// and "<prefix>.output_proj".
+void StandardFeedForwardNetwork_init(
+    StandardFeedForwardNetwork* self,
+    fairseq2_model& model,
+    const std::string &prefix,
+    int model_dim,
+    int inner_dim
+);
+
+/// Forward pass over `seqs`; returns the resulting tensor.
+ggml_tensor* StandardFeedForwardNetwork_forward(
+    StandardFeedForwardNetwork* self,
+    ggml_tensor* seqs
+);
+
+/// Parameters of one transformer decoder layer: self-attention,
+/// encoder-decoder attention and a feed-forward network, each with its layer norm(s).
+/// The commented-out entries name attributes that have no field here yet
+/// (presumably mirroring the fairseq2 Python module — TODO confirm).
+struct TransformerDecoderLayer {
+    struct MultiheadAttention self_attn;
+    struct LayerNorm self_attn_norm;
+    // self_attn_dropout: Optional[Dropout]
+    struct LayerNorm self_attn_layer_norm;
+    struct MultiheadAttention encoder_decoder_attn;
+    // encoder_decoder_dropout: Optional[Dropout]
+    struct LayerNorm encoder_decoder_attn_layer_norm;
+    struct StandardFeedForwardNetwork ffn;
+    // ffn_dropout: Optional[Dropout]
+    // residual_scale: Optional[Parameter]
+    struct LayerNorm ffn_layer_norm;
+    // norm_order: TransformerNormOrder
+};

+ 17 - 69
ggml/examples/unity/model_loader.cpp

@@ -1,74 +1,26 @@
-// Copyright (c) Meta Platforms, Inc. and affiliates.
-// All rights reserved.
-//
-// This source code is licensed under the license found in the
-// LICENSE file in the root directory of this source tree.
-
-
-#include "ggml/ggml.h"
-#include "ggml/ggml-alloc.h"
-
-#include "common.h"
-#include "common-ggml.h"
-
-#include <iostream>
-#include <stdexcept>
-
+#include <string>
 #include "model_loader.h"
 
+std::ifstream open_ggml_file(const char* fname) {
+    printf("%s: loading model from '%s'\n", __func__, fname);
 
-template<typename T>
-void
-model_loader<T>::load_ggml_file(const std::string &fname, fairseq2_model<T> &model)
-{
-    printf("%s: loading model from '%s'\n", __func__, fname.c_str());
-
-    auto fin = std::ifstream(fname, std::ios::binary);
+    auto fin = std::ifstream(std::string(fname), std::ios::binary);
     if (!fin) {
-        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname);
         throw std::invalid_argument("failed to open file."); // TODO Merge error message.
     }
 
-    if (!verify_magic(fin)) {
-        fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+    std::uint32_t magic;
+    fin.read((char*)&magic, 4);
+    if (magic != GGML_FILE_MAGIC) {
+        fprintf(stderr, "%s: invalid model file '%s' (bad header %d)\n", __func__, fname, magic);
         throw std::invalid_argument("failed to open file."); // TODO Merge error message.
     }
+    return fin;
+}
 
-    load_hparams(fin, model.hparams);
-    init_model(model);
-    load_model_weights(fin, model);
-};
-
-template<typename T>
-bool 
-model_loader<T>::verify_magic(std::ifstream &fin)
-{
-    uint32_t magic;
-    fin.read((char *) &magic, sizeof(magic));
-
-    return magic == GGML_FILE_MAGIC;
-};
-
-template<typename T>
 void
-model_loader<T>::init_model(fairseq2_model<T> &model)
-{
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ compute_context_size(model.hparams),
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ false,
-    };
-
-    model.ctx = ggml_init(params);
-    if (!model.ctx)
-        throw std::runtime_error("ggml_init() failed.");
-
-    init_model_tensors(model);
-};
-
-template<typename T>
-void
-model_loader<T>::load_model_weights(std::ifstream &fin, fairseq2_model<T> &model)
+model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
 {
     size_t total_size = 0;
     while (!fin.eof()) {
@@ -80,13 +32,12 @@ model_loader<T>::load_model_weights(std::ifstream &fin, fairseq2_model<T> &model
     printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
 };
 
-template<typename T>
 ggml_tensor *
-model_loader<T>::next_tensor(std::ifstream &fin, fairseq2_model<T> &model)
+model_loader::next_tensor(std::ifstream &fin, fairseq2_model &model)
 {
     auto name = get_name(fin);
     std::cout << "loading tensor: " << name << std::endl;
-   
+
     if (model.tensors.find(name) == model.tensors.end()) {
         fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
         throw std::invalid_argument("failed to open file."); // TODO Merge error message.
@@ -95,9 +46,8 @@ model_loader<T>::next_tensor(std::ifstream &fin, fairseq2_model<T> &model)
     return model.tensors[name];
 };
 
-template<typename T>
 void
-model_loader<T>::load_tensor_value(std::ifstream &fin, ggml_tensor *tensor)
+model_loader::load_tensor_value(std::ifstream &fin, ggml_tensor *tensor)
 {
     int32_t n_dims;
     int32_t ttype;
@@ -140,15 +90,13 @@ model_loader<T>::load_tensor_value(std::ifstream &fin, ggml_tensor *tensor)
     fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
 };
 
-
-template<typename T>
 std::string
-model_loader<T>::get_name(std::ifstream& fin)
+model_loader::get_name(std::ifstream& fin)
 {
     int32_t length;
     fin.read(reinterpret_cast<char *>(&length), sizeof(length));
     std::string name(length, 0);
     fin.read(&name[0], length);
- 
+
     return name;
 };

+ 34 - 26
ggml/examples/unity/model_loader.h

@@ -12,46 +12,54 @@
 
 #include "common.h"
 #include "common-ggml.h"
+#include "fairseq2.h"
 
 #include <iostream>
 #include <stdexcept>
 
-
-template <typename T>
-struct fairseq2_model {
-    struct ggml_context *ctx;
-    std::map<std::string, struct ggml_tensor *> tensors;
-
-    T hparams;
-};
-
-template <typename T>
 class model_loader {
 public:
-    void
-    load_ggml_file(const std::string &fname, fairseq2_model<T> &model);
+    virtual ~model_loader() {};
 
-protected:
-    virtual void
-    load_hparams(std::ifstream &fin, T &hparams) = 0;
+    virtual fairseq2_model& alloc_model(ggml_context* ctx) = 0;
+
+    virtual void load_hparams(fairseq2_model& model, std::ifstream &fin) = 0;
+
+    virtual void load_model_weights(fairseq2_model &model, std::ifstream &fin);
 
     virtual std::size_t
-    compute_context_size(T &hparams) = 0;
+    compute_context_size(void *raw_hparams) = 0;
 
     virtual void
-    init_model_tensors(fairseq2_model<T> &model);
+    init_model_tensors(fairseq2_model &model) = 0;
 
 private:
-    bool verify_magic(std::ifstream &fin);
-
-    void
-    init_model(fairseq2_model<T> &model);
-
-    void load_model_weights(std::ifstream &fin, fairseq2_model<T> &model);
-    
-    ggml_tensor * next_tensor(std::ifstream &fin, fairseq2_model<T> &model);
+    ggml_tensor * next_tensor(std::ifstream &fin, fairseq2_model &model);
 
     // TODO Move these two to helpers
     void load_tensor_value(std::ifstream &fin, ggml_tensor *tensor);
     std::string get_name(std::ifstream &fin);
-};
+};
+
+/// Allocate a fairseq2_model and storage for its hyper-parameters inside `ctx`.
+///
+/// Two I8 tensors are created in the context and used as raw byte buffers:
+/// one of `sizeof(T)` bytes for the hyper-parameters and one of
+/// `sizeof(fairseq2_model)` bytes in which the model object is constructed.
+/// @tparam T   hyper-parameter struct; its storage is left uninitialised.
+/// @param ctx  ggml context providing the memory; it must allocate tensor data.
+/// @return reference to the model living inside `ctx`.
+/// @throws std::runtime_error if the context returns tensors without data.
+/// NOTE(review): assumes ggml tensor data is suitably aligned for
+/// fairseq2_model and T — TODO confirm. The model's destructor is never run
+/// by the context, so the tensor map's nodes are not released with it.
+template<typename T>
+fairseq2_model& alloc_fairseq2_model(ggml_context* ctx) {
+    void* hparams = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(T))->data;
+    void* model_storage = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, sizeof(fairseq2_model))->data;
+    if (hparams == nullptr || model_storage == nullptr) {
+        throw std::runtime_error("alloc_fairseq2_model: ggml tensor has no data buffer.");
+    }
+
+    // `data` is a pointer TO the buffer. Casting the pointer variable itself to
+    // fairseq2_model& would overlay the model on the tensor header instead.
+    // Placement-new also runs the std::map constructor, which raw bytes lack.
+    fairseq2_model* model = new (model_storage) fairseq2_model;
+    model->ctx = ctx;
+    model->hparams = hparams;
+    return *model;
+}
+
+std::ifstream open_ggml_file(const char* fname);
+
+/// Load a fairseq2 model from a ggml file using loader type `T`.
+///
+/// Steps: allocate the model through the loader, open the file, read the
+/// hyper-parameters, register the model tensors, then read the weights.
+/// @tparam T     loader type providing alloc_model, load_hparams,
+///               init_model_tensors and load_model_weights.
+/// @param ctx    ggml context the model is allocated in.
+/// @param fname  path of the ggml model file.
+/// @return reference to the model returned by `T::alloc_model`.
+/// Exceptions thrown by open_ggml_file or the loader propagate to the caller.
+template<typename T>
+fairseq2_model& load_fairseq2_ggml_file(ggml_context* ctx, const char* fname) {
+    T loader;
+    fairseq2_model& model = loader.alloc_model(ctx);
+    auto fin = open_ggml_file(fname);
+    loader.load_hparams(model, fin);
+    // Tensors must be registered before the weights are read: next_tensor
+    // rejects any tensor name that is missing from model.tensors.
+    loader.init_model_tensors(model);
+    loader.load_model_weights(model, fin);
+    return model;
+}

+ 55 - 61
ggml/examples/unity/unity.cpp

@@ -3,6 +3,7 @@
 
 #include "common.h"
 #include "common-ggml.h"
+#include "fairseq2.h"
 
 #include <cassert>
 #include <cmath>
@@ -26,14 +27,9 @@ struct unity_hparams {
     float   eps     = 1e-5f;
 };
 
-// layer def
-struct layer_norm_layer {
-    struct ggml_tensor * w;
-    struct ggml_tensor * b;
-};
 
 struct audio_enc_layer {
-    struct layer_norm_layer self_attn_layer_norm;
+    struct LayerNorm self_attn_layer_norm;
 
     struct ggml_tensor * self_attn_linear_k_w;
     struct ggml_tensor * self_attn_linear_k_b;
@@ -48,7 +44,7 @@ struct audio_enc_layer {
     struct ggml_tensor * self_attn_pos_bias_u;
     struct ggml_tensor * self_attn_pos_bias_v;
 
-    struct layer_norm_layer conv_layer_norm;
+    struct LayerNorm conv_layer_norm;
 
     struct ggml_tensor * conv_pointwise_conv1_w;
     struct ggml_tensor * conv_depthwise_conv_w;
@@ -59,21 +55,23 @@ struct audio_enc_layer {
     struct ggml_tensor * conv_batch_norm_num_batches_tracked;
     struct ggml_tensor * conv_pointwise_conv2_w;
 
-    struct layer_norm_layer ffn1_layer_norm;
+    struct LayerNorm ffn1_layer_norm;
     struct ggml_tensor * ffn1_w1;
     struct ggml_tensor * ffn1_b1;
     struct ggml_tensor * ffn1_w2;
     struct ggml_tensor * ffn1_b2;
 
-    struct layer_norm_layer ffn2_layer_norm;
+    struct LayerNorm ffn2_layer_norm;
     struct ggml_tensor * ffn2_w1;
     struct ggml_tensor * ffn2_b1;
     struct ggml_tensor * ffn2_w2;
     struct ggml_tensor * ffn2_b2;
 
-    struct layer_norm_layer final_layer_norm;
+    struct LayerNorm final_layer_norm;
 };
 
+
+
 // struct ggml_tensor * conv_ln;
 // struct ggml_tensor * conv_pool_1d;
 
@@ -86,9 +84,9 @@ struct unity_model {
     struct ggml_tensor * audio_enc_pos_conv_wg;
     struct ggml_tensor * audio_enc_pos_conv_wv;
     struct ggml_tensor * audio_enc_pos_conv_b;
-    struct layer_norm_layer audio_enc_layer_norm;
+    struct LayerNorm audio_enc_layer_norm;
     struct ggml_tensor * audio_enc_pos_enc_w;
-    struct layer_norm_layer layer_norm;
+    struct LayerNorm layer_norm;
     struct ggml_tensor * memory_k;
     struct ggml_tensor * memory_v;
     std::vector<audio_enc_layer> audio_enc_layers;
@@ -100,7 +98,7 @@ struct unity_model {
     // std::vector<adapter_layer> adapter_layers;
 
     // text decoder
-    // std::vector<text_dec_layer> text_dec_layers;
+    std::vector<TransformerDecoderLayer> text_dec_layers;
 
     // unit decoder
     // std::vector<unit_dec_layer> unit_dec_layers;
@@ -196,14 +194,14 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
         // const int n_text_vocab = hparams.n_text_vocab;
         const int kernel_size = 31;
 
-        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // self_attn_layer_norm.w
-        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // self_attn_layer_norm.b
+        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // self_attn_layer_norm.weight
+        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // self_attn_layer_norm.bias
 
         ctx_size += n_audio_enc_layer*(5*n_audio_enc_dim*n_audio_enc_dim*ggml_type_sizef(wtype));         // self_attn_w
         ctx_size += n_audio_enc_layer*(4*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // self_attn_b
 
-        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // conv_layer_norm.w
-        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // conv_layer_norm.b
+        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // conv_layer_norm.weight
+        ctx_size += n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // conv_layer_norm.bias
 
         ctx_size += n_audio_enc_layer*(n_audio_enc_dim*n_audio_enc_dim*2*ggml_type_sizef(wtype));           // conv_pointwise_conv1_w
         ctx_size += n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // conv_batch_norm_w
@@ -212,12 +210,12 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
         ctx_size += n_audio_enc_layer*(n_audio_enc_dim*n_audio_enc_dim*ggml_type_sizef(wtype));           // conv_pointwise_conv2_w
 
         ctx_size += 2 * n_audio_enc_layer * (n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // ffn{1,2}_layer_norm.w
-        ctx_size += 2 * n_audio_enc_layer * (n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // ffn{1,2}_layer_norm.b
+        ctx_size += 2 * n_audio_enc_layer * (n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // ffn{1,2}_layer_norm.bias
         ctx_size += 2 * n_audio_enc_layer * (2 * n_audio_enc_dim * n_audio_enc_ffn_dim * ggml_type_sizef(wtype));  // ffn{1,2}_w{1,2}
         ctx_size += 2 * n_audio_enc_layer * (2 * n_audio_enc_dim * ggml_type_sizef(GGML_TYPE_F32));  // ffn{1,2}_b{1,2}
 
         ctx_size += n_audio_enc_layer*(n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // final_layer_norm.w
-        ctx_size += n_audio_enc_layer*(n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // final_layer_norm.b
+        ctx_size += n_audio_enc_layer*(n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32)); // final_layer_norm.bias
 
         ctx_size += n_ctx*n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // memory_k
         ctx_size += n_ctx*n_audio_enc_layer*n_audio_enc_dim*ggml_type_sizef(GGML_TYPE_F32); // memory_v
@@ -277,23 +275,23 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
         model.tensors["model/enc/pos_conv/w_v"] = model.audio_enc_pos_conv_wv;
         model.tensors["model/enc/pos_conv/b"] = model.audio_enc_pos_conv_b;
 
-        model.audio_enc_layer_norm.w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
-        model.audio_enc_layer_norm.b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
-        model.tensors["model/enc/layer_norm/w"] = model.audio_enc_layer_norm.w;
-        model.tensors["model/enc/layer_norm/b"] = model.audio_enc_layer_norm.b;
+        model.audio_enc_layer_norm.weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+        model.audio_enc_layer_norm.bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+        model.tensors["model/enc/layer_norm/w"] = model.audio_enc_layer_norm.weight;
+        model.tensors["model/enc/layer_norm/b"] = model.audio_enc_layer_norm.bias;
 
-        model.layer_norm.w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_feat_dim);
-        model.layer_norm.b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_feat_dim);
-        model.tensors["model/layer_norm/w"] = model.layer_norm.w;
-        model.tensors["model/layer_norm/b"] = model.layer_norm.b;
+        model.layer_norm.weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_feat_dim);
+        model.layer_norm.bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_feat_dim);
+        model.tensors["model/layer_norm/w"] = model.layer_norm.weight;
+        model.tensors["model/layer_norm/b"] = model.layer_norm.bias;
 
         
 
         for (int i = 0; i < n_audio_enc_layer; ++i) {
             auto & layer = model.audio_enc_layers[i];
 
-            layer.self_attn_layer_norm.w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
-            layer.self_attn_layer_norm.b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.self_attn_layer_norm.weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.self_attn_layer_norm.bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
 
             layer.self_attn_linear_k_w   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_dim);
             layer.self_attn_linear_k_b   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
@@ -308,8 +306,8 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
             layer.self_attn_pos_bias_u = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim / n_audio_enc_head, n_audio_enc_head);
             layer.self_attn_pos_bias_v = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim / n_audio_enc_head, n_audio_enc_head);
 
-            layer.conv_layer_norm.w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
-            layer.conv_layer_norm.b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.conv_layer_norm.weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.conv_layer_norm.bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
 
             layer.conv_pointwise_conv1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, 2*n_audio_enc_dim);
             layer.conv_depthwise_conv_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 31, n_audio_enc_dim);
@@ -322,8 +320,8 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
 
             layer.conv_pointwise_conv2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_dim);
 
-            layer.ffn1_layer_norm.w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
-            layer.ffn1_layer_norm.b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.ffn1_layer_norm.weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.ffn1_layer_norm.bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
 
             layer.ffn1_w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_ffn_dim);
             layer.ffn1_b1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_ffn_dim);
@@ -331,8 +329,8 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
             layer.ffn1_w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_ffn_dim, n_audio_enc_dim);
             layer.ffn1_b2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
 
-            layer.ffn2_layer_norm.w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
-            layer.ffn2_layer_norm.b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.ffn2_layer_norm.weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.ffn2_layer_norm.bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
 
             layer.ffn2_w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_dim, n_audio_enc_ffn_dim);
             layer.ffn2_b1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_ffn_dim);
@@ -340,13 +338,13 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
             layer.ffn2_w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_enc_ffn_dim, n_audio_enc_dim);
             layer.ffn2_b2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
 
-            layer.final_layer_norm.w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
-            layer.final_layer_norm.b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.final_layer_norm.weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
+            layer.final_layer_norm.bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_enc_dim);
 
             // map by name
 
-            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_layer_norm/w"] = layer.self_attn_layer_norm.w;
-            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_layer_norm/b"] = layer.self_attn_layer_norm.b;
+            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_layer_norm/w"] = layer.self_attn_layer_norm.weight;
+            model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_layer_norm/b"] = layer.self_attn_layer_norm.bias;
 
             model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_k/w"] = layer.self_attn_linear_k_w;
             model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_linear_k/b"] = layer.self_attn_linear_k_b;
@@ -360,8 +358,8 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
             model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_pos_bias/u"] = layer.self_attn_pos_bias_u;
             model.tensors["model/enc/h" + std::to_string(i) + "/self_attn_pos_bias/v"] = layer.self_attn_pos_bias_v;
 
-            model.tensors["model/enc/h" + std::to_string(i) + "/conv_layer_norm/w"]        = layer.conv_layer_norm.w;
-            model.tensors["model/enc/h" + std::to_string(i) + "/conv_layer_norm/b"]        = layer.conv_layer_norm.b;
+            model.tensors["model/enc/h" + std::to_string(i) + "/conv_layer_norm/w"]        = layer.conv_layer_norm.weight;
+            model.tensors["model/enc/h" + std::to_string(i) + "/conv_layer_norm/b"]        = layer.conv_layer_norm.bias;
 
             model.tensors["model/enc/h" + std::to_string(i) + "/conv_pointwise_conv1/w"] = layer.conv_pointwise_conv1_w;
             model.tensors["model/enc/h" + std::to_string(i) + "/conv_depthwise_conv/w"] = layer.conv_depthwise_conv_w;
@@ -372,22 +370,22 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
             model.tensors["model/enc/h" + std::to_string(i) + "/conv_batch_norm/n"] = layer.conv_batch_norm_num_batches_tracked;
             model.tensors["model/enc/h" + std::to_string(i) + "/conv_pointwise_conv2/w"] = layer.conv_pointwise_conv2_w;
 
-            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_layer_norm/w"] = layer.ffn1_layer_norm.w;
-            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_layer_norm/b"] = layer.ffn1_layer_norm.b;
+            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_layer_norm/w"] = layer.ffn1_layer_norm.weight;
+            model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_layer_norm/b"] = layer.ffn1_layer_norm.bias;
             model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_w_1/w"] = layer.ffn1_w1;
             model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_w_1/b"] = layer.ffn1_b1;
             model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_w_2/w"] = layer.ffn1_w2;
             model.tensors["model/enc/h" + std::to_string(i) + "/ffn1_w_2/b"] = layer.ffn1_b2;
 
-            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_layer_norm/w"] = layer.ffn2_layer_norm.w;
-            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_layer_norm/b"] = layer.ffn2_layer_norm.b;
+            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_layer_norm/w"] = layer.ffn2_layer_norm.weight;
+            model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_layer_norm/b"] = layer.ffn2_layer_norm.bias;
             model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_w_1/w"] = layer.ffn2_w1;
             model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_w_1/b"] = layer.ffn2_b1;
             model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_w_2/w"] = layer.ffn2_w2;
             model.tensors["model/enc/h" + std::to_string(i) + "/ffn2_w_2/b"] = layer.ffn2_b2;
 
-            model.tensors["model/enc/h" + std::to_string(i) + "/final_layer_norm/w"] = layer.final_layer_norm.w;
-            model.tensors["model/enc/h" + std::to_string(i) + "/final_layer_norm/b"] = layer.final_layer_norm.b;
+            model.tensors["model/enc/h" + std::to_string(i) + "/final_layer_norm/w"] = layer.final_layer_norm.weight;
+            model.tensors["model/enc/h" + std::to_string(i) + "/final_layer_norm/b"] = layer.final_layer_norm.bias;
         }
     }
 
@@ -467,17 +465,17 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
     return true;
 }
 
-extern "C" ggml_tensor* unity_layer_norm(
+extern "C" ggml_tensor* LayerNorm_forward(
+    const LayerNorm& layer,
     ggml_context* ctx,
     ggml_tensor* cur,
-    const layer_norm_layer& layer,
-    const unity_hparams& hparams
+    float eps
 ) {
-    cur = ggml_norm(ctx, cur, hparams.eps);
+    cur = ggml_norm(ctx, cur, eps);
     return ggml_add(
         ctx,
-        ggml_mul(ctx, ggml_repeat(ctx, layer.w, cur), cur),
-        ggml_repeat(ctx, layer.b, cur)
+        ggml_mul(ctx, ggml_repeat(ctx, layer.weight, cur), cur),
+        ggml_repeat(ctx, layer.bias, cur)
     );
 }
 
@@ -519,12 +517,8 @@ extern "C" ggml_cgraph* unity_audio_encoder_graph(
         struct ggml_tensor * residual = cur;
         const audio_enc_layer layer = model.audio_enc_layers[il];
         // FFN1: layernorm
-        cur = ggml_norm(ctx0, cur, hparams.eps);
-        cur = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, layer.ffn1_layer_norm.w, cur),
-                    cur),
-                ggml_repeat(ctx0, layer.ffn1_layer_norm.b, cur));
+        cur = LayerNorm_forward(layer.ffn1_layer_norm, ctx0, cur, hparams.eps);
+
         // FFN1: proj
         cur = ggml_mul_mat(ctx0, layer.ffn1_w1, cur);
         cur = ggml_add(ctx0, ggml_repeat(ctx0, layer.ffn1_b1, cur), cur);
@@ -543,9 +537,9 @@ extern "C" ggml_cgraph* unity_audio_encoder_graph(
         cur = ggml_norm(ctx0, cur, hparams.eps);
         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,
-                    ggml_repeat(ctx0, layer.self_attn_layer_norm.w, cur),
+                    ggml_repeat(ctx0, layer.self_attn_layer_norm.weight, cur),
                     cur),
-                ggml_repeat(ctx0, layer.self_attn_layer_norm.b, cur));
+                ggml_repeat(ctx0, layer.self_attn_layer_norm.bias, cur));
         
         // self_attn: qkv
         struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,

+ 68 - 4
ggml/examples/unity/unity_model_loader.cpp

@@ -12,10 +12,69 @@
 
 #include "unity_model_loader.h"
 
+struct audio_enc_layer {
+    struct LayerNorm self_attn_layer_norm;
 
-void
-unity_model_loader::load_hparams(std::ifstream& fin, unity_hparams& hparams)
+    struct ggml_tensor * self_attn_linear_k_w;
+    struct ggml_tensor * self_attn_linear_k_b;
+    struct ggml_tensor * self_attn_linear_q_w;
+    struct ggml_tensor * self_attn_linear_q_b;
+    struct ggml_tensor * self_attn_linear_v_w;
+    struct ggml_tensor * self_attn_linear_v_b;
+    struct ggml_tensor * self_attn_linear_out_w;
+    struct ggml_tensor * self_attn_linear_out_b;
+    struct ggml_tensor * self_attn_linear_pos_w;
+
+    struct ggml_tensor * self_attn_pos_bias_u;
+    struct ggml_tensor * self_attn_pos_bias_v;
+
+    struct LayerNorm conv_layer_norm;
+
+    struct ggml_tensor * conv_pointwise_conv1_w;
+    struct ggml_tensor * conv_depthwise_conv_w;
+    struct ggml_tensor * conv_batch_norm_w;
+    struct ggml_tensor * conv_batch_norm_b;
+    struct ggml_tensor * conv_batch_norm_running_mean;
+    struct ggml_tensor * conv_batch_norm_running_var;
+    struct ggml_tensor * conv_batch_norm_num_batches_tracked;
+    struct ggml_tensor * conv_pointwise_conv2_w;
+
+    struct LayerNorm ffn1_layer_norm;
+    struct ggml_tensor * ffn1_w1;
+    struct ggml_tensor * ffn1_b1;
+    struct ggml_tensor * ffn1_w2;
+    struct ggml_tensor * ffn1_b2;
+
+    struct LayerNorm ffn2_layer_norm;
+    struct ggml_tensor * ffn2_w1;
+    struct ggml_tensor * ffn2_b1;
+    struct ggml_tensor * ffn2_w2;
+    struct ggml_tensor * ffn2_b2;
+
+    struct LayerNorm final_layer_norm;
+};
+
+
+/// Tensors of the unity model's audio encoder, plus a pointer to the hyper-parameters.
+/// NOTE(review): unity.cpp also defines a struct named unity_model with
+/// different members; confirm the two definitions never meet in one program.
+struct unity_model {
+    unity_hparams* hparams;
+    // audio encoder
+    struct ggml_tensor * post_extract_proj_w;
+    struct ggml_tensor * post_extract_proj_b;
+    struct ggml_tensor * audio_enc_pos_conv_wg;
+    struct ggml_tensor * audio_enc_pos_conv_wv;
+    struct ggml_tensor * audio_enc_pos_conv_b;
+    struct LayerNorm audio_enc_layer_norm;
+    struct ggml_tensor * audio_enc_pos_enc_w;
+    struct LayerNorm layer_norm;
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+    // one entry per audio encoder layer
+    std::vector<audio_enc_layer> audio_enc_layers;
+};
+
+void unity_model_loader::load_hparams(fairseq2_model& model, std::ifstream &fin)
 {
+    auto hparams = (unity_hparams&)model.hparams;
+
     fin.read((char*) &hparams.model_dim, sizeof(hparams.model_dim));
     fin.read((char*) &hparams.w2v2_encoder_config__model_dim, sizeof(hparams.w2v2_encoder_config__model_dim));
     fin.read((char*) &hparams.w2v2_encoder_config__max_seq_len, sizeof(hparams.w2v2_encoder_config__max_seq_len));
@@ -71,13 +130,18 @@ unity_model_loader::load_hparams(std::ifstream& fin, unity_hparams& hparams)
 };
 
 std::size_t
-unity_model_loader::compute_context_size(unity_hparams &hparams)
+unity_model_loader::compute_context_size(void* raw_hparams)
 {
     // TODO
+    // Flowing off the end of a function that returns std::size_t is undefined
+    // behaviour, so a value is returned on every path.
+    // NOTE(review): this delegates to the free compute_context_size(void*) from
+    // unity_model_loader.h as a stand-in until the TODO above is resolved; confirm
+    // that this is the intended size estimate.
+    return ::compute_context_size(raw_hparams);
 };
 
 void
-unity_model_loader::init_model_tensors(fairseq2_model<unity_hparams> &model)
+unity_model_loader::init_model_tensors(fairseq2_model &model)
 {
     // TODO: not implemented yet; this stub leaves `model` untouched.
 };
+
+// extern "C" fairseq2_model<unity_hparams>* unity_model_load2(const char* fname, ggml_context* ctx) {
+//     return nullptr;
+// }

+ 28 - 82
ggml/examples/unity/unity_model_loader.h

@@ -66,7 +66,7 @@ struct unity_hparams {
     float adaptor_dropout_p;
 };
 
-// Methods
+
 
 // Embedding
 std::size_t compute_embed_size(int32_t vocab_size, int32_t dim)
@@ -74,65 +74,6 @@ std::size_t compute_embed_size(int32_t vocab_size, int32_t dim)
     return vocab_size * dim * ggml_type_size(GGML_TYPE_F32);
 };
 
-// Projection
-std::size_t compute_projection_size(int32_t in_dim, int32_t out_dim)
-{
-    return (in_dim * out_dim * ggml_type_size(GGML_TYPE_F32)) // weight
-        + (out_dim * ggml_type_size(GGML_TYPE_F32)); // bias
-};
-
-// LayerNorm
-std::size_t compute_layer_norm_size(int32_t dim)
-{
-    return 2 * dim * ggml_type_size(GGML_TYPE_F32); // weight and bias
-};
-
-// FFN Layer
-
-struct ffn_layer {
-    struct ggml_tensor* layer_norm_w; // model_dim
-    struct ggml_tensor* layer_norm_b; // model_dim
-
-    struct ggml_tensor* inner_proj_w; // ffn_inner_dim x model_dim
-    struct ggml_tensor* inner_proj_b; // ffn_inner_dim
-
-    struct ggml_tensor* output_proj_w; // model_dim x ffn_inner_dim
-    struct ggml_tensor* output_proj_b; // model_dim
-};
-
-std::size_t compute_ffn_layer_size(int32_t dim, int32_t inner_dim)
-{
-    return compute_layer_norm_size(dim)
-        + compute_projection_size(dim, inner_dim)
-        + compute_projection_size(inner_dim, dim);
-};
-
-void init_ffn_layer(
-    ffn_layer *layer,
-    fairseq2_model<unity_hparams> &model_ctx,
-    const std::string &prefix)
-{
-    const auto dim = model_ctx.hparams.nllb_config__model_dim;
-    const auto inner_dim = model_ctx.hparams.nllb_config__ffn_inner_dim;
-    auto ctx = model_ctx.ctx;
-    auto &tensor_map = model_ctx.tensors;
-
-    layer->layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
-    tensor_map[prefix + "_layer_norm.weight"] = layer->layer_norm_w;
-    layer->layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
-    tensor_map[prefix + "_layer_norm.bias"] = layer->layer_norm_b;
-
-    layer->inner_proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, inner_dim, dim);
-    tensor_map[prefix + ".inner_proj.weight"] = layer->inner_proj_w;
-    layer->inner_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, inner_dim);
-    tensor_map[prefix + ".inner_proj.bias"] = layer->inner_proj_b;
-
-    layer->output_proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, inner_dim);
-    tensor_map[prefix + ".output_proj.weight"] = layer->output_proj_w;
-    layer->output_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
-    tensor_map[prefix + ".output_proj.bias"] = layer->output_proj_b;
-}
-
 // Attention Layer
 
 struct attention_layer {
@@ -152,16 +93,17 @@ struct attention_layer {
 
 std::size_t compute_attention_layer_size(int32_t dim)
 {
-    return compute_layer_norm_size(dim)
-        + 4 * compute_projection_size(dim, dim); // q, k, v, and out
+    // One layer norm plus four dim x dim projections: q, k, v, and out.
+    const std::size_t norm_size = LayerNorm_size(dim);
+    const std::size_t proj_size = Linear_size(dim, dim);
+    return norm_size + 4 * proj_size;
 };
 
 void init_attention_layer(
     attention_layer *layer,
-    fairseq2_model<unity_hparams> &model_ctx,
+    fairseq2_model &model_ctx,
     const std::string &prefix)
 {
-    const auto dim = model_ctx.hparams.nllb_config__model_dim;
+    // model_ctx.hparams is an untyped pointer: dereference it as unity_hparams
+    // instead of reinterpreting the pointer's own storage as the struct.
+    const auto& hparams = *static_cast<const unity_hparams*>(model_ctx.hparams);
+    const auto dim = hparams.nllb_config__model_dim;
     auto ctx = model_ctx.ctx;
     auto &tensor_map = model_ctx.tensors;
 
@@ -197,23 +139,23 @@ void init_attention_layer(
 struct attention_head {
     struct attention_layer* self_attn; // model_dim
     struct attention_layer* encoder_decoder_attn; // model_dim
-    struct ffn_layer* ffn;
+    struct StandardFeedForwardNetwork* ffn; // set up by init_attention_head through StandardFeedForwardNetwork_init
 };
 
 std::size_t compute_attention_head_size(int32_t dim, int32_t inner_dim)
 {
-    return 2 * compute_attention_layer_size(dim)
-        + compute_ffn_layer_size(dim, inner_dim);
+    // Two attention layers (self and encoder-decoder) plus one feed-forward network.
+    const std::size_t attn_size = compute_attention_layer_size(dim);
+    const std::size_t ffn_size = StandardFeedForwardNetwork_size(dim, inner_dim);
+    return 2 * attn_size + ffn_size;
 };
 
 void init_attention_head(
     attention_head *head,
-    fairseq2_model<unity_hparams> &model_ctx,
+    fairseq2_model &model_ctx,
     const std::string &prefix)
 {
+    // model_ctx.hparams is an untyped pointer: dereference it as unity_hparams
+    // instead of reinterpreting the pointer's own storage as the struct.
+    const auto& hparams = *static_cast<const unity_hparams*>(model_ctx.hparams);
     init_attention_layer(head->self_attn, model_ctx, prefix + ".self_attn");
     init_attention_layer(head->encoder_decoder_attn, model_ctx, prefix + ".encoder_decoder_attn");
-    init_ffn_layer(head->ffn, model_ctx, prefix + ".ffn");
+    StandardFeedForwardNetwork_init(head->ffn, model_ctx, prefix + ".ffn", hparams.nllb_config__model_dim, hparams.nllb_config__ffn_inner_dim);
 }
 
 // TODO: attention_head_compute_graph
@@ -227,8 +169,9 @@ struct text_decoder {
     struct ggml_tensor* layer_norm_b;
 };
 
-std::size_t compute_context_size(unity_hparams &hparams)
+std::size_t compute_context_size(void* raw_hparams)
 {
+    // raw_hparams is an untyped pointer to a unity_hparams: dereference it instead
+    // of reinterpreting the pointer's own storage as the struct.
+    const auto& hparams = *static_cast<const unity_hparams*>(raw_hparams);
     const auto vocab_size = hparams.nllb_config__vocabulary_size;
     const auto dim = hparams.nllb_config__model_dim;
     const auto inner_dim = hparams.nllb_config__ffn_inner_dim;
@@ -238,17 +181,17 @@ std::size_t compute_context_size(unity_hparams &hparams)
 
     return compute_embed_size(vocab_size, dim)
         + n_layers * compute_attention_head_size(dim, inner_dim)
-        + compute_layer_norm_size(dim)
+        + LayerNorm_size(dim)
         + overhead;
 };
 
 void init_model_tensors(
     text_decoder &model,
-    fairseq2_model<unity_hparams> &model_ctx,
+    fairseq2_model &model_ctx,
     const std::string &prefix)
 {
     const auto ctx = model_ctx.ctx;
-    const auto hparams = model_ctx.hparams;
+    // model_ctx.hparams is an untyped pointer: dereference it as unity_hparams
+    // instead of reinterpreting the pointer's own storage as the struct.
+    const auto& hparams = *static_cast<const unity_hparams*>(model_ctx.hparams);
-    auto tensor_map = model_ctx.tensors;
+    // Bind by reference (as init_attention_layer does) so that entries added
+    // through tensor_map land in the model's map rather than in a local copy.
+    auto &tensor_map = model_ctx.tensors;
 
     const auto vocab_size = hparams.nllb_config__vocabulary_size;
@@ -277,16 +220,19 @@ void init_model_tensors(
 };
 
 
+class unity_model_loader: public model_loader { // model_loader subclass for the unity model
+    public:
+    fairseq2_model& alloc_model(ggml_context* ctx) { // forwards to alloc_fairseq2_model instantiated with unity_hparams
+        return alloc_fairseq2_model<unity_hparams>(ctx);
+    };
 
-// Model
-class unity_model_loader: public model_loader<unity_hparams> {
-protected:
-    void
-    load_hparams(std::ifstream &fin, unity_hparams &hparams);
+    void load_hparams(fairseq2_model& model, std::ifstream &fin); // reads hyper-parameter fields from `fin`
 
-    std::size_t
-    compute_context_size(unity_hparams &hparams) = 0;
+    std::size_t compute_context_size(void* raw_hparams); // receives the hyper-parameters as an untyped pointer
 
-    void
-    init_model_tensors(fairseq2_model<unity_hparams> &model);
+    void init_model_tensors(fairseq2_model &model); // NOTE(review): presumably creates the model's tensors — confirm against the definition
 };
+
+// C-linkage entry point: loads the ggml file `fname` into `ctx` by calling
+// load_fairseq2_ggml_file instantiated with unity_model_loader.
+extern "C" fairseq2_model& load_unity_ggml_file(ggml_context* ctx, const char* fname) {
+    fairseq2_model& loaded = load_fairseq2_ggml_file<unity_model_loader>(ctx, fname);
+    return loaded;
+}

+ 12 - 3
ggml/ggml.py

@@ -94,9 +94,7 @@ def _pad_shape(shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
 
 
 def from_numpy(ctx: ggml_context_p, array: np.ndarray) -> ggml_tensor_p:
-    tensor_p = ggml_new_tensor(
-        ctx, from_numpy_dtype(array.dtype), 1, GgmlShape()
-    )
+    tensor_p = ggml_new_tensor(ctx, from_numpy_dtype(array.dtype), 1, GgmlShape())
     tensor_p.contents.n_dims = array.ndim
     tensor_p.contents.data = array.ctypes.data_as(ctypes.c_void_p)
     tensor_p.contents.ne = GgmlShape(*_pad_shape(array.shape))
@@ -191,6 +189,17 @@ def unity_model_load(model_file: Path) -> Tuple[NativeObj, NativeObj]:
     return model, vocab
 
 
+# Argument order mirrors the C signature:
+#   load_unity_ggml_file(ggml_context* ctx, const char* fname)
+lib.load_unity_ggml_file.argtypes = [ggml_context_p, ctypes.c_char_p]
+lib.load_unity_ggml_file.restype = ctypes.c_void_p
+
+
+def load_unity_ggml_file(model_file: Path) -> ctypes.c_void_p:
+    """Load the ggml model file at `model_file` through the native library.
+
+    Returns whatever `lib.load_unity_ggml_file` returns (its restype is c_void_p).
+    """
+    # NOTE(review): `ctx` is not a parameter of this function; it has to be
+    # a module-level ggml context — confirm that one is defined.
+    fname = ctypes.create_string_buffer(str(model_file).encode("utf-8"))
+    return lib.load_unity_ggml_file(ctx, fname)
+
+
 lib.unity_audio_encoder_graph.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
 lib.unity_audio_encoder_graph.restype = ctypes.POINTER(ggml_cgraph)
 

+ 1 - 0
ggml/src/CMakeLists.txt

@@ -249,6 +249,7 @@ endif()
 add_library(${TARGET}
     ggml.c
     ggml-alloc.c
+    ../examples/unity/fairseq2.cpp
     ../examples/unity/model_loader.cpp
     ../examples/unity/unity_model_loader.cpp
     ../examples/unity/unity.cpp

+ 8 - 0
ggml/test_unity_cpp.py

@@ -145,3 +145,11 @@ def test_unity_model_load(ctx: Ctx) -> None:
         expected = map(float, expected_raw.split(","))
         assert np.allclose(inpL[0, :10], list(expected), atol=1e-4)
 
+
+# def test_unity_model_load2(ctx: Ctx) -> None:
+#     model = ggml.unity_model_load(
+#         UNITY_MODELS / "unity-large/ggml-model.bin"
+#     )
+#     print(model, vocab)
+#
+#