Browse source

start working out the loader api

Naji El Hachem 1 year ago
parent
commit
853b53bad0

+ 10 - 0
ggml/examples/unity/CMakeLists.txt

@@ -3,6 +3,16 @@
 set(TEST_TARGET unity)
 add_executable(${TEST_TARGET} unity.cpp)
+
+target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+
+target_sources(${TEST_TARGET}
+    PRIVATE
+        model_loader.cpp
+        unity_model_loader.cpp
+        unity_text_decoder.cpp
+)
+
 target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)

 #

+ 82 - 0
ggml/examples/unity/buffered_ggml_writer.py

@@ -0,0 +1,82 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import struct
+from io import BufferedWriter
+
+import torch
+
+from ggml.examples.unity.type_utils import to_ctype
+
+
+class BufferedGGMLWriter:
+    buffer: BufferedWriter
+
+    def __init__(self, buffer: BufferedWriter) -> None:
+        self.buffer = buffer
+
+    def write_magic_hex(self) -> None:
+        """Write GGML Magic Number to internal buffer.
+        This should be called at the start of your convert process.
+        """
+        self.buffer.write(struct.pack("i", 0x67676d6c))
+
+    def write_hparams(self, hparams: dict) -> None:
+        """Write hyper-parameters to the internal buffer.
+
+        :param hparams:
+            Flattened dict containing the model's hyper-parameters.
+        """
+        for key, value in hparams.items():
+            try:
+                ctype, cvalue = to_ctype(value)
+                self.buffer.write(struct.pack(ctype, cvalue))
+            except ValueError as e:
+                # TODO use logger
+                print(f"[Warning] {e}. Skipping config for key {key}")
+
+    def write_state_dict(self, state_dict: dict) -> None:
+        """Write pytorch state dict to internal buffer.
+
+        :paras state_dict:
+            state dict returned by pytorch model
+        """
+        for key, value in state_dict.items():
+            self.write_string(key)
+            self.write_tensor(value)
+
+    def write_string(self, value: str) -> None:
+        """Write string in utf-8 format to internal buffer.
+
+        :params value:
+            string value to dump.
+        """
+        str_ = value.encode("utf-8")
+        self.buffer.write(struct.pack("i", len(str_)))
+        self.buffer.write(str_)
+
+    def write_tensor(self, value: torch.Tensor) -> None:
+        """Write torch tensor in ggml format to internal buffer.
+
+        First we save the number of dimensions and the dtype.
+        Then we save the data as numpy array.
+
+        :params value:
+            Tensor to dump.
+        """
+        data = value.squeeze().numpy()
+        n_dims = len(data.shape)
+
+        # TODO: Convert to fp16 when necessary!
+        ftype = 0
+
+        self.buffer.write(struct.pack("ii", n_dims, ftype))
+        for i in range(n_dims):
+            self.buffer.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+
+        data.tofile(self.buffer)

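For context, a minimal usage sketch of the writer (not part of this commit; the file name and hyper-parameters are made up). It produces the magic/hparams/state-dict layout that model_loader.cpp below expects:

import torch

from ggml.examples.unity.buffered_ggml_writer import BufferedGGMLWriter

# Hypothetical example: serialize one hyper-parameter and one tensor.
state_dict = {"layer_norm.weight": torch.ones(16)}

with open("toy_model.ggml", "wb") as fp:
    writer = BufferedGGMLWriter(fp)
    writer.write_magic_hex()                # magic number comes first
    writer.write_hparams({"model_dim": 16})
    writer.write_state_dict(state_dict)
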
+ 157 - 0
ggml/examples/unity/fairseq2_to_ggml_converter.py

@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import dataclasses
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from fairseq2.assets import AssetCard
+
+from ggml.examples.unity.buffered_ggml_writer import BufferedGGMLWriter
+from ggml.examples.unity.type_utils import get_cpp_type
+from seamless_communication.models.unity import (
+    load_unity_config,
+    load_unity_model
+)
+
+Preprocessor = Callable[[Any], Any]
+
+
+class Fairseq2ToGGMLConverter:
+    """Converter from fairseq2 format to GGML format"""
+
+    config_preprocessor: Preprocessor
+    nested_params_separator: str
+
+    def __init__(
+        self,
+        nested_params_separator: str = ".",
+        config_preprocessor: Optional[Preprocessor] = None,
+    ) -> None:
+        """
+        :param nested_params_separator:
+            String separator used when flattening nested hparams.
+        :param config_preprocessor:
+            Preprocessor used for config/hparams values.
+        """
+        self.config_preprocessor = config_preprocessor or (lambda v: v)
+        self.nested_params_separator = nested_params_separator
+
+    def convert_to_ggml(
+        self,
+        model_name_or_card: Union[str, AssetCard],
+        output_file: Path
+    ) -> None:
+        """Load model from card, convert to ggml format and save result.
+
+        :param model_name_or_card:
+            The name or asset card of the model to load.
+        :param output_file:
+            File path to store binary output.
+        """
+        hparams = self._load_config(model_name_or_card)
+        state_dict = self._load_state_dict(model_name_or_card)
+
+        with output_file.open("wb") as buffer:
+            ggml_writer = BufferedGGMLWriter(buffer)
+
+            ggml_writer.write_magic_hex()
+            ggml_writer.write_hparams(hparams)
+            ggml_writer.write_state_dict(state_dict)
+
+    def generate_hparams_struct(
+        self,
+        model_name_or_card: Union[str, AssetCard],
+        struct_name: str,
+    ) -> str:
+        """Transform config to c++ struct
+
+        :param model_name_or_card:
+            The name or asset card of the model to load.
+        :param output_file:
+            File path to store binary output.
+        """
+        hparams = self._load_config(model_name_or_card)
+        result = f"struct {struct_name} {{\n"
+        for key, value in hparams.items():
+            result = f"{result}\t{get_cpp_type(value)} {key};\n"
+
+        result = f"{result}}};"
+
+        return result
+
+    def _load_config(
+        self,
+        model_name_or_card: Union[str, AssetCard]
+    ) -> dict:
+        """Load model config and transform it to flattened dict.
+
+        :param model_name_or_card:
+            The name or asset card of the model to load.
+
+        :returns:
+            Flat dictionary containing all hyper-parameters.
+        """
+        model_config = load_unity_config(model_name_or_card)
+        model_config_dict = dataclasses.asdict(model_config)
+        flattened = self.__flatten(model_config_dict)
+
+        return flattened
+
+    def _load_state_dict(
+        self,
+        model_name_or_card: Union[str, AssetCard]
+    ) -> dict:
+        """Load model and return state dict.
+
+        :param model_name_or_card:
+            The name or asset card of the model to load.
+
+        :returns:
+            State dict returned by pytorch model.
+        """
+        model = load_unity_model(model_name_or_card)
+
+        return model.state_dict()
+
+    def __flatten(
+        self,
+        config: dict
+    ) -> dict:
+        """Flatten nested dictionnary
+
+        :param config:
+            nested dictionnary containing model config.
+
+        :returns:
+            flat dictionnary
+        """
+        return self.__flatten_recursive(config, '')
+
+    def __flatten_recursive(
+        self,
+        config: dict,
+        prefix: str
+    ) -> dict:
+        """Recursive method used to flatten nested dictionnary"""
+        result = {}
+        for key in config:
+            new_key = f"{prefix}{key}"
+            if isinstance(config[key], dict):
+                nested_result = self.__flatten_recursive(
+                    config[key],
+                    f"{new_key}{self.nested_params_separtor}"
+                )
+                result.update(nested_result)
+            else:
+                new_value = self.config_preprocessor(config[key])
+                if new_value is not None:
+                    result[new_key] = new_value
+
+        return result

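A usage sketch for the converter (illustrative only; the model card name is an assumption, and the "__" separator mirrors the unity_hparams field names further below):

from pathlib import Path

from ggml.examples.unity.fairseq2_to_ggml_converter import Fairseq2ToGGMLConverter

# Hypothetical invocation: "seamlessM4T_medium" stands in for any valid card.
converter = Fairseq2ToGGMLConverter(nested_params_separator="__")
converter.convert_to_ggml("seamlessM4T_medium", Path("seamlessM4T_medium.ggml"))

# Emit a C++ struct matching the flattened hparams, e.g. unity_hparams below.
print(converter.generate_hparams_struct("seamlessM4T_medium", "unity_hparams"))
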
+ 154 - 0
ggml/examples/unity/model_loader.cpp

@@ -0,0 +1,154 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+
+#include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include <iostream>
+#include <stdexcept>
+
+#include "ggml/examples/unity/model_loader.h"
+
+
+template<typename T>
+void
+model_loader<T>::load_ggml_file(const std::string &fname, fairseq2_model<T> &model)
+{
+    printf("%s: loading model from '%s'\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        throw std::invalid_argument("failed to open file."); // TODO Merge error message.
+    }
+
+    if (!verify_magic(fin)) {
+        fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+        throw std::invalid_argument("invalid model file (bad magic).");
+    }
+
+    load_hparams(fin, model.hparams);
+    init_model(model);
+    load_model_weights(fin, model);
+};
+
+template<typename T>
+bool 
+model_loader<T>::verify_magic(std::ifstream &fin)
+{
+    uint32_t magic;
+    fin.read((char *) &magic, sizeof(magic));
+
+    return magic == GGML_FILE_MAGIC;
+};
+
+template<typename T>
+void
+model_loader<T>::init_model(fairseq2_model<T> &model)
+{
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ compute_context_size(model.hparams),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,
+    };
+
+    model.ctx = ggml_init(params);
+    if (!model.ctx)
+        throw std::runtime_error("ggml_init() failed.");
+
+    init_model_tensors(model);
+};
+
+template<typename T>
+void
+model_loader<T>::load_model_weights(std::ifstream &fin, fairseq2_model<T> &model)
+{
+    size_t total_size = 0;
+    // eof() is only set after a failed read, so peek before each tensor
+    while (fin.peek() != std::ifstream::traits_type::eof()) {
+        auto tensor = next_tensor(fin, model);
+        load_tensor_value(fin, tensor);
+        total_size += ggml_nbytes(tensor);
+    }
+
+    printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+};
+
+template<typename T>
+ggml_tensor *
+model_loader<T>::next_tensor(std::ifstream &fin, fairseq2_model<T> &model)
+{
+    auto name = get_name(fin);
+    std::cout << "loading tensor: " << name << std::endl;
+   
+    if (model.tensors.find(name) == model.tensors.end()) {
+        fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
+        throw std::invalid_argument("unknown tensor in model file.");
+    }
+
+    return model.tensors[name];
+};
+
+template<typename T>
+void
+model_loader<T>::load_tensor_value(std::ifstream &fin, ggml_tensor *tensor)
+{
+    int32_t n_dims;
+    int32_t ttype;
+
+    fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+    fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));
+
+    if (n_dims > GGML_MAX_DIMS) {
+        throw std::runtime_error("tensor has too many dimensions in model file.");
+    }
+
+    int32_t nelements = 1;
+    int32_t ne[GGML_MAX_DIMS] = {1, 1, 1, 1};
+    for (int i = 0; i < n_dims; ++i) {
+        fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+        nelements *= ne[i];
+    }
+
+    if (ggml_nelements(tensor) != nelements) {
+        fprintf(stderr, "%s: tensor has wrong number of elements in model file: got %d, expected %d\n",
+                __func__, nelements, (int) ggml_nelements(tensor));
+        throw std::runtime_error("tensor has wrong size in model file.");
+    }
+
+    if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+        fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                __func__, (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
+        throw std::runtime_error("tensor has wrong shape in file."); // TODO Merge error message.
+    }
+
+    // for debugging
+    if (0) {
+        printf("[%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n",
+               ne[0], ne[1], ggml_type_name(ggml_type(ttype)),
+               ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+    }
+
+    const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+    if ((nelements * bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+        fprintf(stderr, "%s: tensor has wrong size in model file: got %zu, expected %zu\n",
+                __func__, ggml_nbytes(tensor), nelements * bpe);
+        throw std::runtime_error("tensor has wrong size in file."); // TODO Merge error message.
+    }
+
+    fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+};
+
+
+template<typename T>
+std::string
+model_loader<T>::get_name(std::ifstream& fin)
+{
+    int32_t length;
+    fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+    std::string name(length, 0);
+    fin.read(&name[0], length);
+ 
+    return name;
+};

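A linker caveat worth noting: because these template member definitions live in a .cpp file rather than the header, every concrete instantiation has to be made explicit (or the definitions moved into model_loader.h). A sketch of the former, assuming unity_model_loader.h is includable from here:

// At the bottom of model_loader.cpp:
#include "ggml/examples/unity/unity_model_loader.h"

template class model_loader<unity_hparams>;
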
+ 57 - 0
ggml/examples/unity/model_loader.h

@@ -0,0 +1,57 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+
+#include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include <iostream>
+#include <stdexcept>
+
+
+template <typename T>
+struct fairseq2_model {
+    struct ggml_context *ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+
+    T hparams;
+};
+
+template <typename T>
+class model_loader {
+public:
+    void
+    load_ggml_file(const std::string &fname, fairseq2_model<T> &model);
+
+protected:
+    virtual void
+    load_hparams(std::ifstream &fin, T &hparams) = 0;
+
+    virtual std::size_t
+    compute_context_size(T &hparams) = 0;
+
+    virtual void
+    init_model_tensors(fairseq2_model<T> &model) = 0;
+
+private:
+    bool verify_magic(std::ifstream &fin);
+
+    void
+    init_model(fairseq2_model<T> &model);
+
+    void load_model_weights(std::ifstream &fin, fairseq2_model<T> &model);
+    
+    ggml_tensor * next_tensor(std::ifstream &fin, fairseq2_model<T> &model);
+
+    // TODO Move these two to helpers
+    void load_tensor_value(std::ifstream &fin, ggml_tensor *tensor);
+    std::string get_name(std::ifstream &fin);
+};

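To illustrate the contract this header defines, here is a hypothetical toy subclass (not from the commit; assumes model_loader.h is included). A loader only has to parse its hparams, estimate the ggml context size, and register its tensors by name:

struct toy_hparams {
    int32_t dim;
};

class toy_model_loader : public model_loader<toy_hparams> {
protected:
    void load_hparams(std::ifstream &fin, toy_hparams &hparams) override {
        fin.read(reinterpret_cast<char *>(&hparams.dim), sizeof(hparams.dim));
    }

    std::size_t compute_context_size(toy_hparams &hparams) override {
        // one fp32 vector plus a small allowance for ggml bookkeeping
        return hparams.dim * ggml_type_size(GGML_TYPE_F32) + 16 * 1024;
    }

    void init_model_tensors(fairseq2_model<toy_hparams> &model) override {
        model.tensors["weight"] =
            ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, model.hparams.dim);
    }
};
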
+ 58 - 0
ggml/examples/unity/type_utils.py

@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from enum import Enum
+from typing import Any, Tuple
+
+
+def to_ctype(value: Any) -> Tuple[str, Any]:
+    """Transform python type to ctype.
+
+    :params value:
+        value to cast into ctype
+
+    :returns:
+        A tuple of ctype and cvalue.
+    """
+    # bool (and IntEnum) subclass int, so those branches must come before
+    # the plain int check or they are unreachable.
+    if isinstance(value, bool):
+        # NOTE: '?' packs a single byte; the C++ reader must use a matching width.
+        return ("?", value)
+    if isinstance(value, Enum):
+        return ("i", value.value)
+    if isinstance(value, int):
+        return ("i", value)
+    if isinstance(value, float):
+        return ("f", value)
+
+    raise ValueError(f"Unsupported type {type(value)}")
+
+
+def get_cpp_type(value: Any) -> str:
+    """Return equivalent cpp type in string format
+
+    :params value:
+        value to cast into ctype
+
+    :returns:
+        str containing cpp type
+    """
+    # delegate to to_ctype so the two mappings stay consistent
+    try:
+        ctype, _ = to_ctype(value)
+    except ValueError as e:
+        return f"Error[{e}]"
+
+    if ctype == "i":
+        return "int32_t"
+    if ctype == "f":
+        return "float"
+    if ctype == "?":
+        return "bool"
+
+    raise RuntimeError(
+        "Should not have reached this part. "
+        f"Missing cpp translation for {ctype}"
+    )

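A quick illustration of the mapping (with the bool/Enum checks ordered before int, as fixed above):

from ggml.examples.unity.type_utils import get_cpp_type, to_ctype

assert to_ctype(512) == ("i", 512)     # packed as int32
assert to_ctype(0.1) == ("f", 0.1)     # packed as float32
assert to_ctype(True) == ("?", True)   # packed as a single byte
assert get_cpp_type(512) == "int32_t"
assert get_cpp_type(0.1) == "float"
assert get_cpp_type(True) == "bool"
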
+ 83 - 0
ggml/examples/unity/unity_model_loader.cpp

@@ -0,0 +1,83 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include "ggml/examples/unity/unity_model_loader.h"
+
+
+void
+unity_model_loader::load_hparams(std::ifstream& fin, unity_hparams& hparams)
+{
+    fin.read((char*) &hparams.model_dim, sizeof(hparams.model_dim));
+    fin.read((char*) &hparams.w2v2_encoder_config__model_dim, sizeof(hparams.w2v2_encoder_config__model_dim));
+    fin.read((char*) &hparams.w2v2_encoder_config__max_seq_len, sizeof(hparams.w2v2_encoder_config__max_seq_len));
+    fin.read((char*) &hparams.w2v2_encoder_config__feature_dim, sizeof(hparams.w2v2_encoder_config__feature_dim));
+    fin.read((char*) &hparams.w2v2_encoder_config__use_fbank, sizeof(hparams.w2v2_encoder_config__use_fbank));
+    fin.read((char*) &hparams.w2v2_encoder_config__first_pass_dropout_p, sizeof(hparams.w2v2_encoder_config__first_pass_dropout_p));
+    fin.read((char*) &hparams.w2v2_encoder_config__layer_norm_features, sizeof(hparams.w2v2_encoder_config__layer_norm_features));
+    fin.read((char*) &hparams.w2v2_encoder_config__feature_extractor_bias, sizeof(hparams.w2v2_encoder_config__feature_extractor_bias));
+    fin.read((char*) &hparams.w2v2_encoder_config__feature_extractor_layer_norm_convs, sizeof(hparams.w2v2_encoder_config__feature_extractor_layer_norm_convs));
+    fin.read((char*) &hparams.w2v2_encoder_config__feature_grad_scale, sizeof(hparams.w2v2_encoder_config__feature_grad_scale));
+    fin.read((char*) &hparams.w2v2_encoder_config__num_fbank_channels, sizeof(hparams.w2v2_encoder_config__num_fbank_channels));
+    fin.read((char*) &hparams.w2v2_encoder_config__fbank_stride, sizeof(hparams.w2v2_encoder_config__fbank_stride));
+    fin.read((char*) &hparams.w2v2_encoder_config__sample_fbank_every_k, sizeof(hparams.w2v2_encoder_config__sample_fbank_every_k));
+    fin.read((char*) &hparams.w2v2_encoder_config__pos_encoder_depth, sizeof(hparams.w2v2_encoder_config__pos_encoder_depth));
+    fin.read((char*) &hparams.w2v2_encoder_config__pos_conv_kernel_size, sizeof(hparams.w2v2_encoder_config__pos_conv_kernel_size));
+    fin.read((char*) &hparams.w2v2_encoder_config__num_pos_conv_groups, sizeof(hparams.w2v2_encoder_config__num_pos_conv_groups));
+    fin.read((char*) &hparams.w2v2_encoder_config__use_conformer, sizeof(hparams.w2v2_encoder_config__use_conformer));
+    fin.read((char*) &hparams.w2v2_encoder_config__num_encoder_layers, sizeof(hparams.w2v2_encoder_config__num_encoder_layers));
+    fin.read((char*) &hparams.w2v2_encoder_config__num_encoder_attn_heads, sizeof(hparams.w2v2_encoder_config__num_encoder_attn_heads));
+    fin.read((char*) &hparams.w2v2_encoder_config__ffn_inner_dim, sizeof(hparams.w2v2_encoder_config__ffn_inner_dim));
+    fin.read((char*) &hparams.w2v2_encoder_config__dropout_p, sizeof(hparams.w2v2_encoder_config__dropout_p));
+    fin.read((char*) &hparams.w2v2_encoder_config__attn_dropout_p, sizeof(hparams.w2v2_encoder_config__attn_dropout_p));
+    fin.read((char*) &hparams.w2v2_encoder_config__layer_drop_p, sizeof(hparams.w2v2_encoder_config__layer_drop_p));
+    fin.read((char*) &hparams.w2v2_encoder_config__norm_order, sizeof(hparams.w2v2_encoder_config__norm_order));
+    fin.read((char*) &hparams.w2v2_encoder_config__depthwise_conv_kernel_size, sizeof(hparams.w2v2_encoder_config__depthwise_conv_kernel_size));
+    fin.read((char*) &hparams.nllb_config__model_dim, sizeof(hparams.nllb_config__model_dim));
+    fin.read((char*) &hparams.nllb_config__max_seq_len, sizeof(hparams.nllb_config__max_seq_len));
+    fin.read((char*) &hparams.nllb_config__vocabulary_size, sizeof(hparams.nllb_config__vocabulary_size));
+    fin.read((char*) &hparams.nllb_config__pad_idx, sizeof(hparams.nllb_config__pad_idx));
+    fin.read((char*) &hparams.nllb_config__num_encoder_layers, sizeof(hparams.nllb_config__num_encoder_layers));
+    fin.read((char*) &hparams.nllb_config__num_decoder_layers, sizeof(hparams.nllb_config__num_decoder_layers));
+    fin.read((char*) &hparams.nllb_config__num_encoder_attn_heads, sizeof(hparams.nllb_config__num_encoder_attn_heads));
+    fin.read((char*) &hparams.nllb_config__num_decoder_attn_heads, sizeof(hparams.nllb_config__num_decoder_attn_heads));
+    fin.read((char*) &hparams.nllb_config__ffn_inner_dim, sizeof(hparams.nllb_config__ffn_inner_dim));
+    fin.read((char*) &hparams.nllb_config__dropout_p, sizeof(hparams.nllb_config__dropout_p));
+    fin.read((char*) &hparams.t2u_config__model_dim, sizeof(hparams.t2u_config__model_dim));
+    fin.read((char*) &hparams.t2u_config__unit_max_seq_len, sizeof(hparams.t2u_config__unit_max_seq_len));
+    fin.read((char*) &hparams.t2u_config__unit_vocabulary_size, sizeof(hparams.t2u_config__unit_vocabulary_size));
+    fin.read((char*) &hparams.t2u_config__unit_pad_idx, sizeof(hparams.t2u_config__unit_pad_idx));
+    fin.read((char*) &hparams.t2u_config__num_encoder_layers, sizeof(hparams.t2u_config__num_encoder_layers));
+    fin.read((char*) &hparams.t2u_config__num_decoder_layers, sizeof(hparams.t2u_config__num_decoder_layers));
+    fin.read((char*) &hparams.t2u_config__num_encoder_attn_heads, sizeof(hparams.t2u_config__num_encoder_attn_heads));
+    fin.read((char*) &hparams.t2u_config__num_decoder_attn_heads, sizeof(hparams.t2u_config__num_decoder_attn_heads));
+    fin.read((char*) &hparams.t2u_config__ffn_inner_dim, sizeof(hparams.t2u_config__ffn_inner_dim));
+    fin.read((char*) &hparams.t2u_config__dropout_p, sizeof(hparams.t2u_config__dropout_p));
+    fin.read((char*) &hparams.use_text_encoder, sizeof(hparams.use_text_encoder));
+    fin.read((char*) &hparams.use_conformer_adaptor, sizeof(hparams.use_conformer_adaptor));
+    fin.read((char*) &hparams.num_adaptor_layers, sizeof(hparams.num_adaptor_layers));
+    fin.read((char*) &hparams.adaptor_kernel_size, sizeof(hparams.adaptor_kernel_size));
+    fin.read((char*) &hparams.adaptor_stride, sizeof(hparams.adaptor_stride));
+    fin.read((char*) &hparams.adaptor_layer_norm, sizeof(hparams.adaptor_layer_norm));
+    fin.read((char*) &hparams.adaptor_dropout_p, sizeof(hparams.adaptor_dropout_p));
+};
+
+std::size_t
+unity_model_loader::compute_context_size(unity_hparams &hparams)
+{
+    // TODO: account for the full model; for now delegate to the
+    // text-decoder estimate defined in unity_model_loader.h.
+    return ::compute_context_size(hparams);
+}
+
+void
+unity_model_loader::init_model_tensors(fairseq2_model<unity_hparams> &model)
+{
+    // TODO
+};

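Since every unity_hparams field is a 4-byte int32_t or float written in declaration order, the field-by-field reads above could in principle collapse into a single bulk read. A sketch, assuming the struct stays in lockstep with write_hparams and that bool hyper-parameters are serialized as 4-byte ints:

// Hypothetical alternative, not part of the commit.
#include <type_traits>

static_assert(std::is_trivially_copyable<unity_hparams>::value,
              "unity_hparams must remain a plain struct for a raw read");

void load_hparams_raw(std::ifstream &fin, unity_hparams &hparams)
{
    fin.read(reinterpret_cast<char *>(&hparams), sizeof(unity_hparams));
}
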
+ 290 - 0
ggml/examples/unity/unity_model_loader.h

@@ -0,0 +1,290 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <vector>
+#include "ggml/examples/unity/model_loader.h"
+
+
+// TODO Merge with Ning implementation
+struct unity_hparams {
+    int32_t model_dim;
+    int32_t w2v2_encoder_config__model_dim;
+    int32_t w2v2_encoder_config__max_seq_len;
+    int32_t w2v2_encoder_config__feature_dim;
+    int32_t w2v2_encoder_config__use_fbank;
+    float w2v2_encoder_config__first_pass_dropout_p;
+    int32_t w2v2_encoder_config__layer_norm_features;
+    int32_t w2v2_encoder_config__feature_extractor_bias;
+    int32_t w2v2_encoder_config__feature_extractor_layer_norm_convs;
+    int32_t w2v2_encoder_config__feature_grad_scale;
+    int32_t w2v2_encoder_config__num_fbank_channels;
+    int32_t w2v2_encoder_config__fbank_stride;
+    int32_t w2v2_encoder_config__sample_fbank_every_k;
+    int32_t w2v2_encoder_config__pos_encoder_depth;
+    int32_t w2v2_encoder_config__pos_conv_kernel_size;
+    int32_t w2v2_encoder_config__num_pos_conv_groups;
+    int32_t w2v2_encoder_config__use_conformer;
+    int32_t w2v2_encoder_config__num_encoder_layers;
+    int32_t w2v2_encoder_config__num_encoder_attn_heads;
+    int32_t w2v2_encoder_config__ffn_inner_dim;
+    float w2v2_encoder_config__dropout_p;
+    float w2v2_encoder_config__attn_dropout_p;
+    float w2v2_encoder_config__layer_drop_p;
+    int32_t w2v2_encoder_config__norm_order;
+    int32_t w2v2_encoder_config__depthwise_conv_kernel_size;
+    int32_t nllb_config__model_dim;
+    int32_t nllb_config__max_seq_len;
+    int32_t nllb_config__vocabulary_size;
+    int32_t nllb_config__pad_idx;
+    int32_t nllb_config__num_encoder_layers;
+    int32_t nllb_config__num_decoder_layers;
+    int32_t nllb_config__num_encoder_attn_heads;
+    int32_t nllb_config__num_decoder_attn_heads;
+    int32_t nllb_config__ffn_inner_dim;
+    float nllb_config__dropout_p;
+    int32_t t2u_config__model_dim;
+    int32_t t2u_config__unit_max_seq_len;
+    int32_t t2u_config__unit_vocabulary_size;
+    int32_t t2u_config__unit_pad_idx;
+    int32_t t2u_config__num_encoder_layers;
+    int32_t t2u_config__num_decoder_layers;
+    int32_t t2u_config__num_encoder_attn_heads;
+    int32_t t2u_config__num_decoder_attn_heads;
+    int32_t t2u_config__ffn_inner_dim;
+    float t2u_config__dropout_p;
+    int32_t use_text_encoder;
+    int32_t use_conformer_adaptor;
+    int32_t num_adaptor_layers;
+    int32_t adaptor_kernel_size;
+    int32_t adaptor_stride;
+    int32_t adaptor_layer_norm;
+    float adaptor_dropout_p;
+};
+
+// Methods
+
+// Embedding
+// NOTE: these helpers are defined in a header included from multiple
+// translation units, so they are marked inline to avoid ODR violations.
+inline std::size_t compute_embed_size(int32_t vocab_size, int32_t dim)
+{
+    return vocab_size * dim * ggml_type_size(GGML_TYPE_F32);
+};
+
+// Projection
+inline std::size_t compute_projection_size(int32_t in_dim, int32_t out_dim)
+{
+    return (in_dim * out_dim * ggml_type_size(GGML_TYPE_F32)) // weight
+        + (out_dim * ggml_type_size(GGML_TYPE_F32)); // bias
+};
+
+// LayerNorm
+inline std::size_t compute_layer_norm_size(int32_t dim)
+{
+    return 2 * dim * ggml_type_size(GGML_TYPE_F32); // weight and bias
+};
+
+// FFN Layer
+
+struct ffn_layer {
+    struct ggml_tensor* layer_norm_w; // model_dim
+    struct ggml_tensor* layer_norm_b; // model_dim
+
+    struct ggml_tensor* inner_proj_w; // ffn_inner_dim x model_dim
+    struct ggml_tensor* inner_proj_b; // ffn_inner_dim
+
+    struct ggml_tensor* output_proj_w; // model_dim x ffn_inner_dim
+    struct ggml_tensor* output_proj_b; // model_dim
+};
+
+inline std::size_t compute_ffn_layer_size(int32_t dim, int32_t inner_dim)
+{
+    return compute_layer_norm_size(dim)
+        + compute_projection_size(dim, inner_dim)
+        + compute_projection_size(inner_dim, dim);
+};
+
+inline void init_ffn_layer(
+    ffn_layer *layer,
+    fairseq2_model<unity_hparams> &model_ctx,
+    const std::string &prefix)
+{
+    const auto dim = model_ctx.hparams.nllb_config__model_dim;
+    const auto inner_dim = model_ctx.hparams.nllb_config__ffn_inner_dim;
+    auto ctx = model_ctx.ctx;
+    auto &tensor_map = model_ctx.tensors;
+
+    layer->layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map[prefix + "_layer_norm.weight"] = layer->layer_norm_w;
+    layer->layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map[prefix + "_layer_norm.bias"] = layer->layer_norm_b;
+
+    layer->inner_proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, inner_dim, dim);
+    tensor_map[prefix + ".inner_proj.weight"] = layer->inner_proj_w;
+    layer->inner_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, inner_dim);
+    tensor_map[prefix + ".inner_proj.bias"] = layer->inner_proj_b;
+
+    layer->output_proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, inner_dim);
+    tensor_map[prefix + ".output_proj.weight"] = layer->output_proj_w;
+    layer->output_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map[prefix + ".output_proj.bias"] = layer->output_proj_b;
+}
+
+// Attention Layer
+
+struct attention_layer {
+    struct ggml_tensor* layer_norm_w; // model_dim
+    struct ggml_tensor* layer_norm_b; // model_dim
+
+    struct ggml_tensor* q_proj_w; // model_dim x model_dim
+    struct ggml_tensor* q_proj_b; // model_dim
+    struct ggml_tensor* k_proj_w; // model_dim x model_dim
+    struct ggml_tensor* k_proj_b; // model_dim
+    struct ggml_tensor* v_proj_w; // model_dim x model_dim
+    struct ggml_tensor* v_proj_b; // model_dim
+
+    struct ggml_tensor* output_proj_w; // model_dim x model_dim
+    struct ggml_tensor* output_proj_b; // model_dim
+};
+
+inline std::size_t compute_attention_layer_size(int32_t dim)
+{
+    return compute_layer_norm_size(dim)
+        + 4 * compute_projection_size(dim, dim); // q, k, v, and out
+};
+
+inline void init_attention_layer(
+    attention_layer *layer,
+    fairseq2_model<unity_hparams> &model_ctx,
+    const std::string &prefix)
+{
+    const auto dim = model_ctx.hparams.nllb_config__model_dim;
+    auto ctx = model_ctx.ctx;
+    auto &tensor_map = model_ctx.tensors;
+
+    layer->layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map[prefix + "_layer_norm.weight"] = layer->layer_norm_w;
+    layer->layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map[prefix + "_layer_norm.bias"] = layer->layer_norm_b;
+
+    layer->q_proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, dim);
+    tensor_map[prefix + ".q_proj.weight"] = layer->q_proj_w;
+    layer->q_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map[prefix + ".q_proj.bias"] = layer->q_proj_b;
+
+    layer->k_proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, dim);
+    tensor_map[prefix + ".k_proj.weight"] = layer->k_proj_w;
+    layer->k_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map[prefix + ".k_proj.bias"] = layer->k_proj_b;
+
+    layer->v_proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, dim);
+    tensor_map[prefix + ".v_proj.weight"] = layer->v_proj_w;
+    layer->v_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map[prefix + ".v_proj.bias"] = layer->v_proj_b;
+
+    layer->output_proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, dim);
+    tensor_map[prefix + ".output_proj.weight"] = layer->output_proj_w;
+    layer->output_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map[prefix + ".output_proj.bias"] = layer->output_proj_b;
+}
+
+
+// Attention Head
+
+struct attention_head {
+    struct attention_layer* self_attn; // model_dim
+    struct attention_layer* encoder_decoder_attn; // model_dim
+    struct ffn_layer* ffn;
+};
+
+inline std::size_t compute_attention_head_size(int32_t dim, int32_t inner_dim)
+{
+    return 2 * compute_attention_layer_size(dim)
+        + compute_ffn_layer_size(dim, inner_dim);
+};
+
+inline void init_attention_head(
+    attention_head *head,
+    fairseq2_model<unity_hparams> &model_ctx,
+    const std::string &prefix)
+{
+    // the sub-layers are owned by the head and must be allocated before use
+    head->self_attn = new attention_layer();
+    head->encoder_decoder_attn = new attention_layer();
+    head->ffn = new ffn_layer();
+
+    init_attention_layer(head->self_attn, model_ctx, prefix + ".self_attn");
+    init_attention_layer(head->encoder_decoder_attn, model_ctx, prefix + ".encoder_decoder_attn");
+    init_ffn_layer(head->ffn, model_ctx, prefix + ".ffn");
+}
+
+// Text Decoder
+
+struct text_decoder {
+    struct ggml_tensor* frontend_embed_w; // vocab_size x model_dim
+    std::vector<attention_head*> multi_head;
+    struct ggml_tensor* layer_norm_w;
+    struct ggml_tensor* layer_norm_b;
+};
+
+inline std::size_t compute_context_size(unity_hparams &hparams)
+{
+    const auto vocab_size = hparams.nllb_config__vocabulary_size;
+    const auto dim = hparams.nllb_config__model_dim;
+    const auto inner_dim = hparams.nllb_config__ffn_inner_dim;
+    const auto n_layers = hparams.nllb_config__num_decoder_layers;
+
+    const auto overhead = (6 + 12 * n_layers) * 512; // TODO Find out what this is.
+
+    return compute_embed_size(vocab_size, dim)
+        + n_layers * compute_attention_head_size(dim, inner_dim)
+        + compute_layer_norm_size(dim)
+        + overhead;
+};
+
+inline void init_model_tensors(
+    text_decoder &model,
+    fairseq2_model<unity_hparams> &model_ctx,
+    const std::string &prefix)
+{
+    const auto ctx = model_ctx.ctx;
+    const auto &hparams = model_ctx.hparams;
+    auto &tensor_map = model_ctx.tensors; // reference: writes must reach the model's map
+
+    const auto vocab_size = hparams.nllb_config__vocabulary_size;
+    const auto dim = hparams.nllb_config__model_dim;
+    const auto n_layers = hparams.nllb_config__num_decoder_layers;
+
+    // This can be simplified by adding syntax sugar
+
+    // frontend
+    model.frontend_embed_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, vocab_size, dim);
+    tensor_map["text_decoder_frontend.embed.weight"] = model.frontend_embed_w;
+
+    // layers
+    model.multi_head.resize(n_layers);
+    for (int i = 0; i < n_layers; ++i) {
+        // allocate each head before init_attention_head dereferences it
+        auto head = model.multi_head[i] = new attention_head();
+        auto layer_prefix = "text_decoder.layers." + std::to_string(i);
+        init_attention_head(head, model_ctx, layer_prefix);
+    }
+
+    // layer_norm
+    model.layer_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map["text_decoder.layer_norm.weight"] = model.layer_norm_w;
+    model.layer_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim);
+    tensor_map["text_decoder.layer_norm.bias"] = model.layer_norm_b;
+}
+
+
+
+// Model
+class unity_model_loader: public model_loader<unity_hparams> {
+protected:
+    void
+    load_hparams(std::ifstream &fin, unity_hparams &hparams) override;
+
+    std::size_t
+    compute_context_size(unity_hparams &hparams) override;
+
+    void
+    init_model_tensors(fairseq2_model<unity_hparams> &model) override;
+};

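As a back-of-envelope check of compute_context_size (all values hypothetical, not from the commit):

// Rough context size for NLLB-like dimensions.
#include <cstdio>
#include "ggml/examples/unity/unity_model_loader.h"

int main()
{
    unity_hparams hp = {};
    hp.nllb_config__vocabulary_size = 256206; // assumed values
    hp.nllb_config__model_dim = 1024;
    hp.nllb_config__ffn_inner_dim = 8192;
    hp.nllb_config__num_decoder_layers = 12;

    // embeddings (~1.0 GiB) + 12 decoder layers (~1.1 GiB) of fp32 weights,
    // so this prints roughly 2.10
    std::printf("%.2f GiB\n", compute_context_size(hp) / 1024.0 / 1024.0 / 1024.0);
    return 0;
}
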
+ 3 - 0
src/seamless_communication/models/unity/__init__.py

@@ -51,6 +51,9 @@ from seamless_communication.models.unity.loader import (
 from seamless_communication.models.unity.loader import (
     load_unity_unit_tokenizer as load_unity_unit_tokenizer,
 )
+from seamless_communication.models.unity.loader import (
+    load_unity_config as load_unity_config
+)
 from seamless_communication.models.unity.model import UnitYModel as UnitYModel
 from seamless_communication.models.unity.model import (
     UnitYNART2UModel as UnitYNART2UModel,