
Unity.cpp nllb support (#258)

* extend fairseq2 and unity_lib data structures

* update read and write for nllb and unity-arch models

* address Ning's and Guil's comments about the nllb layer naming

* convert on write instead of on read

* fix bugs when combining nllb ggml_convert with layer_filters; fix typos

* sync local commit with upstream

* increase default graph size to test bigger nllb model

* fix docstrings

* fix docstrings

* revert nit docstring

---------

Co-authored-by: Tuan Tran <tuantran@devfair0436.h2.fair>
Tuan Tran 1 year ago
parent
commit
34575dc9b3

+ 6 - 6
ggml/examples/unity/fairseq2.cpp

@@ -1580,7 +1580,7 @@ extern "C" Hypothesis* generate_sequence(
             ((float*)scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
         }
 
-        printf_mem_usage(step_ctx, "  step_ctx");
+        printf_mem_usage(step_ctx, "step_ctx");
         ggml_free(prev_step_ctx);
         prev_step_ctx = step_ctx;
 #if DEBUG_MEM_USAGE
@@ -1656,7 +1656,7 @@ struct llm_bigram_spm {
 struct llm_tokenizer_spm {
     llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
 
-    void tokenize(const std::string& input_text, ggml_tensor& output) {
+    void tokenize(const std::string& input_text, ggml_tensor* output) {
         llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");
 
         // split string into utf8 chars
@@ -1724,8 +1724,8 @@ struct llm_tokenizer_spm {
             try_add_bigram(bigram.left, left_sym.next);
         }
 
-        llama_vocab::id* out = (llama_vocab::id*)output.data;
-        int out_step = sizeof(llama_vocab::id) / output.nb[0];
+        llama_vocab::id* out = (llama_vocab::id*)output->data;
+        int out_step = sizeof(llama_vocab::id) / output->nb[0];
         int num_tokens = 0;
         for (int i = 0; i > -1; i = symbols[i].next) {
             llm_symbol& symbol = symbols[i];
@@ -1734,7 +1734,7 @@ struct llm_tokenizer_spm {
         }
         *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
         num_tokens += 1;
-        output.ne[0] = num_tokens;
+        output->ne[0] = num_tokens;
     }
 
 private:
@@ -1773,7 +1773,7 @@ private:
 };
 
 
-extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out) {
+extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out) {
     llm_tokenizer_spm spm = {model->vocab};
     spm.tokenize(std::string(text), out);
 }

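For context, a minimal sketch of calling the tokenizer through its new pointer-based signature, mirroring the pattern used by unity_eval_text further down; the 64-token capacity and the input string are placeholders, and the model is assumed to have been loaded with unity_init_model() with model.ctx pointing at a writable ggml context:

    // Pre-allocate the output tensor; tokenize() fills it in place and shrinks
    // ne[0] to the actual number of tokens (including the trailing </s>).
    ggml_set_no_alloc(model.ctx, false);
    ggml_tensor* tokens = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 64);  // capacity: placeholder
    ggml_set_no_alloc(model.ctx, true);
    fairseq2_spm_tokenize(&model, "Hello world", tokens);
    printf("num tokens: %lld\n", (long long) tokens->ne[0]);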
+ 8 - 2
ggml/examples/unity/fairseq2.h

@@ -199,6 +199,13 @@ extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
     ggml_tensor* padding_mask
 );
 
+extern "C" ggml_tensor* StandardTransformerEncoder_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
+    ggml_tensor* seqs,
+    ggml_tensor* padding_mask
+);
+
 extern "C" ggml_tensor* RelativePositionMHA_forward(
     fairseq2_model& model,
     const std::string& prefix,
@@ -317,8 +324,7 @@ extern "C" Hypothesis* generate_sequence(
     int threads
 );
 
-extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out);
-
+extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out);
 extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out);
 
 std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, ggml_tensor* scores, char* out);

+ 77 - 1
ggml/examples/unity/lib/unity_lib.cpp

@@ -2,6 +2,23 @@
 #include <algorithm>
 
 
+struct ggml_cgraph * unity_text_encoder(
+        fairseq2_model & model,
+        struct ggml_tensor * text_input) {
+    ggml_context* ctx0 = model.ctx;
+    ggml_cgraph* gf = ggml_new_graph(ctx0);
+    ggml_tensor* seqs = TransformerEmbeddingFrontend_forward(model, "text_encoder_frontend", text_input);
+    ggml_tensor* encoder_output = StandardTransformerEncoder_forward(
+        model,
+        "text_encoder",
+        seqs,
+        nullptr  // TODO: handle padding mask
+    );
+    encoder_output = ggml_dup(model.ctx, encoder_output);
+    ggml_build_forward_expand(gf, encoder_output);
+    return gf;
+}
+
 struct ggml_cgraph * unity_speech_encoder(
         fairseq2_model& model,
         struct ggml_tensor * speech_input) {
@@ -43,7 +60,7 @@ extern "C" fairseq2_model unity_init_model(const char* model_path) {
 }
 
 //  struct as return - transcription, CE score, LID 
-extern "C" Result unity_eval(fairseq2_model model, std::vector<float> data, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads, int memory_mb) {
+extern "C" Result unity_eval_speech(fairseq2_model& model, std::vector<float>& data, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads) {
     Result result;
     // The ctx_size_mb mostly depends of input length and model dim.
     int ctx_size_mb = opts.mem_mb;
@@ -101,10 +118,69 @@ extern "C" Result unity_eval(fairseq2_model model, std::vector<float> data, Sequ
         lid_scores[model.vocab.id_to_token[lang_ids[i]].text] = ggml_get_f32_1d(hypo[0].lid_scores, i); 
     }
     
+    
     result.transcription = result_tokens;
     result.word_confidence_scores = word_scores;
     result.lid_scores = lid_scores;
     result.err = 0;
+    ggml_free(model.ctx);
+    ggml_allocr_reset(fwd_alloc);
+    return result;
+}
+
+
+extern "C" Result unity_eval_text(fairseq2_model& model, const std::string& text, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads) {
+    Result result;
+    // The ctx_size_mb mostly depends on the input length and model dim.
+    int ctx_size_mb = opts.mem_mb;
+    auto encoder_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
+    auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
+    ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);
+    int tgt_lang_idx;
+    auto tgt_lang_ptr = model.vocab.token_to_id.find("__" + tgt_lang + "__"); 
+    if (tgt_lang_ptr == model.vocab.token_to_id.end()) {
+        std::cerr << "Unknown language " << tgt_lang << "\n";
+        result.err = 1;
+        return result;
+    }
+    tgt_lang_idx = tgt_lang_ptr->second;
+
+    // tokenize the input text
+    model.ctx = ctx_from_buffer(encoder_buf);
+    ggml_set_no_alloc(model.ctx, false);
+    ggml_tensor* tokens = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 64);
+    ggml_set_no_alloc(model.ctx, true);
+    fairseq2_spm_tokenize(&model, text.c_str(), tokens);
+    
+    // Text encoder
+    ggml_cgraph* gf = unity_text_encoder(model, tokens);
+    ggml_allocr_alloc_graph(fwd_alloc, gf);
+    ggml_graph_compute_with_ctx(model.ctx, gf, n_threads);
+    ggml_tensor* encoder_output = gf->nodes[gf->n_nodes - 1];
+    
+    // Beam search decoding
+    const Hypothesis* hypo = unity_decode(model, opts, tgt_lang_idx, encoder_output, n_threads);
+    
+    // Drop language and bos token.
+    ggml_tensor* tgt_tokens = ggml_slice(model.ctx, hypo[0].seq, 0, 2, 0);
+    // Collect result string
+    char result_str[4096];
+
+    std::pair<std::vector<std::string>, std::vector<float>> p = fairseq2_spm_detokenize(&model, tgt_tokens, hypo[0].step_scores, (char*)&result_str);
+    std::vector<std::string> result_tokens = p.first;
+    std::vector<float> word_scores = p.second;
+
+    std::unordered_map<std::string, float> lid_scores;
+    std::vector<int> lang_ids;
+    for (const auto& kv : model.vocab.token_to_id) {
+        if (kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
+            lang_ids.push_back(kv.second);
+        }
+    }
+    std::sort(lang_ids.begin(), lang_ids.end());
+    for (size_t i = 0; i < lang_ids.size(); ++i) {
+        lid_scores[model.vocab.id_to_token[lang_ids[i]].text] = ggml_get_f32_1d(hypo[0].lid_scores, i); 
+    }
     
     result.transcription = result_tokens;
     result.word_confidence_scores = word_scores;

+ 22 - 2
ggml/examples/unity/lib/unity_lib.h

@@ -26,7 +26,13 @@ struct Result {
 
 struct ggml_cgraph * unity_speech_encoder(
     fairseq2_model& model,
-    struct ggml_tensor * speech_input);
+    struct ggml_tensor * speech_input
+);
+
+struct ggml_cgraph * unity_text_encoder(
+    fairseq2_model& model,
+    struct ggml_tensor * text_input
+);
 
 Hypothesis* unity_decode(
         fairseq2_model& model,
@@ -38,4 +44,18 @@ Hypothesis* unity_decode(
 
 extern "C" fairseq2_model unity_init_model(const char* model_path);
 
-extern "C" Result unity_eval(fairseq2_model model, std::vector<float> data, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads, int memory_gb);
+extern "C" Result unity_eval_speech(
+    fairseq2_model& model, 
+    std::vector<float>& data, 
+    SequenceGeneratorOptions opts, 
+    std::string tgt_lang, 
+    int n_threads
+);
+
+extern "C" Result unity_eval_text(
+    fairseq2_model& model,  
+    const std::string& text, 
+    SequenceGeneratorOptions opts, 
+    std::string tgt_lang, 
+    int n_threads
+);

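As a usage illustration of the split entry points declared above, a hedged sketch: the value-initialized options, the silence buffer, the model path, and the language codes are placeholders; unity.cpp builds its own SequenceGeneratorOptions defaults (mem_mb = 512 among them) and reads real audio with libsndfile:

    #include "unity_lib.h"
    #include <iostream>
    #include <vector>

    int main() {
        fairseq2_model model = unity_init_model("seamlessM4T_medium.ggml");
        SequenceGeneratorOptions opts{};   // placeholder; see the defaults in unity.cpp
        opts.mem_mb = 512;                 // context size used by both eval functions

        std::vector<float> samples(16000, 0.0f);  // placeholder audio (1s of silence)
        Result s2t = unity_eval_speech(model, samples, opts, "eng", /*n_threads=*/4);
        // s2t.transcription holds the recognised words for the speech input

        Result t2t = unity_eval_text(model, "Hello world", opts, "fra", /*n_threads=*/4);
        if (t2t.err == 0) {
            for (const auto& tok : t2t.transcription) std::cout << tok << ' ';
            std::cout << '\n';
            for (const auto& kv : t2t.lid_scores) std::cout << kv.first << ": " << kv.second << '\n';
        }
        return 0;
    }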
+ 31 - 5
ggml/examples/unity/unity.cpp

@@ -13,7 +13,8 @@
 
 struct unity_params {
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    std::string model      = "seamlessM4T_medium.ggml"; // model path
+    std::string model = "seamlessM4T_medium.ggml"; // model path
+    std::string input_text = "";
     std::string tgt_lang = "eng";
     std::vector<std::string> files = {};
     bool text = false;
@@ -26,7 +27,7 @@ struct unity_params {
         /*len_penalty*/ 1.0,
         /*unk_penalty*/ 0.0,
         /*normalize_scores*/ true,
-        /*mem_mb*/ 512,
+        /*mem_mb*/ 512
     };
     bool verbose = false;
 };
@@ -37,6 +38,9 @@ void unity_print_usage(int /*argc*/, char ** argv, const unity_params & params)
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -i, --input           Input text for the text-2-text translation\n");
+    fprintf(stderr, "  -l, --tgt-lang        Target translation lang (default: %s)\n", params.tgt_lang.c_str());
+
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -v, --verbose         Print out word level confidence score and LID score (default: off)");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
@@ -67,6 +71,8 @@ bool unity_params_parse(int argc, char ** argv, unity_params & params) {
             params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-m" || arg == "--model") {
             params.model = get_next_arg(i, argc, argv, arg, params);
+        } else if (arg == "-i" || arg == "--input") {
+            params.input_text = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "-l" || arg == "--tgt-lang") {
             params.tgt_lang = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "--text") {
@@ -108,8 +114,13 @@ int main(int argc, char ** argv) {
     char result_str[4096];
 
     std::string input;
-    bool interactive = params.files.size() == 0;
+    bool interactive = (params.files.size() == 0 && params.input_text.length() == 0);
     auto next_file = params.files.begin();
+
+    // Flag for the input case: true --> s2st, false --> t2tt
+    bool s2st_or_t2tt = true;
+
+    // S2ST
     while (true) {
         if (interactive) {
             std::cout << "\nEnter audio_path and tgt_lang, separated by space (or 'exit' to quit):\n";
@@ -118,7 +129,10 @@ int main(int argc, char ** argv) {
                 break;
             }
         } else {
-            if (next_file == params.files.end()) break;
+            if (params.input_text.length() > 0) {
+                break;
+            }
+            if (next_file == params.files.end() && s2st_or_t2tt) break;
             input = *(next_file++);
         }
         std::istringstream iss(input);
@@ -144,7 +158,7 @@ int main(int argc, char ** argv) {
         std::vector<float> data(n_frames * info.channels);
         sf_readf_float(sndfile, data.data(), n_frames);
 
-        Result result = unity_eval(model, data, params.opts, tgt_lang, params.n_threads, ctx_size_mb);
+        Result result = unity_eval_speech(model, data, params.opts, tgt_lang, params.n_threads);
         std::string concat_transcription = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
             [](const std::string& a, const std::string& b) {
                 return a + " " + b;
@@ -167,5 +181,17 @@ int main(int argc, char ** argv) {
         }
     }
 
+    // T2TT
+    if (params.input_text.length() > 0) {
+        // tokenize the input text
+        Result result = unity_eval_text(model, params.input_text, params.opts, params.tgt_lang, params.n_threads);
+        std::string concat_translation = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
+            [](const std::string& a, const std::string& b) {
+                return a + " " + b;
+            }
+        );
+        std::cout << "Translation: " << concat_translation << std::endl;
+    }
+
     return 0;
 }

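For reference, assuming the example builds to a binary named unity (model path and language code below are illustrative), the new text-to-text path is reached with the -i/--input flag:

    ./unity -m seamlessM4T_medium.ggml -i "Hello world" -l fra

When neither -i nor an audio file is given, the program falls back to the interactive prompt, as before.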
+ 138 - 65
ggml/ggml_convert.py

@@ -6,41 +6,51 @@
 
 import dataclasses
 import logging
-import math
 import struct
 from enum import Enum
 from io import BufferedWriter
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set, final
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence, Set, final
+import re
 
 import torch
 from fairseq2.assets import AssetCard
 from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
 from fairseq2.nn import SinusoidalPositionEncoder
 from fairseq2.nn.transformer import RelativePositionalEncoding
-from seamless_communication.models import unity
-from fairseq2.data.text import SentencePieceTokenizerBase
-from fairseq2.data.typing import PathLike
-from typing import Sequence
 from fairseq2.data.text import SentencePieceEncoder, SentencePieceTokenizerBase
+from fairseq2.data.typing import PathLike
 from fairseq2.typing import Device, finaloverride
-from fairseq2.models.utils import TokenizerLoaderBase
+from fairseq2.models.utils import TokenizerLoaderBase, ModelLoader
+from fairseq2.models.utils.checkpoint import convert_model_state_dict
 from fairseq2.assets import asset_store, download_manager
-from seamless_communication.models.unity.builder import UnitYConfig, create_unity_model
-from fairseq2.models.utils import ModelLoader
-from seamless_communication.models.unity.model import UnitYModel
 
 import ggml
-import re
 
 Preprocessor = Callable[[Any], Any]
 log = logging.getLogger("ggml_convert")
-SMALLER_MODELS = [
+
+
+class ModelType(str, Enum):
+    AUTO = "auto"  # inferred from the model name
+    UNITY = "unity"
+    NLLB = "nllb"
+
+
+UNITY_SMALLER_MODELS = [
     "unity_nano",
     "unity_micro",
 ]  # Trained with fairseq2, with custom dict (not original NLLB ones)
 
 
+NLLB_2_UNITY_KEYMAP = {
+    r"^encoder_frontend\.": r"text_encoder_frontend.",
+    r"^encoder\."         : r"text_encoder.",
+    r"^decoder\."         : r"text_decoder.",
+    r"^decoder_frontend\.": r"text_decoder_frontend.",
+}
+
+
 @final
 class NllbLikeTokenizer(SentencePieceTokenizerBase):
     """The only difference between this class and NllbTokenizer is it doesn't add a <pad> to control symbol list.
@@ -141,16 +151,6 @@ class NllbLikeTokenizer(SentencePieceTokenizerBase):
         )
 
 
-load_unity_model_without_conversion = ModelLoader[UnitYModel, UnitYConfig](
-    asset_store,
-    download_manager,
-    unity.load_unity_config,
-    create_unity_model,
-    None,
-    restrict_checkpoints=False,
-)
-
-
 @final
 class NllbLikeTokenizerLoader(TokenizerLoaderBase[NllbLikeTokenizer]):
     """Loads tokenizers used by NLLB models."""
@@ -164,44 +164,110 @@ class NllbLikeTokenizerLoader(TokenizerLoaderBase[NllbLikeTokenizer]):
         return NllbLikeTokenizer(pathname, langs, default_lang)
 
 
+def convert_unity_model(
+    model_name: str,
+    hparams: Optional[Dict[str, Any]] = None,
+):
+    from seamless_communication.models import unity
+    from seamless_communication.models.unity.builder import UnitYConfig, create_unity_model
+    from seamless_communication.models.unity.model import UnitYModel
+
+    load_unity_model_without_conversion = ModelLoader[UnitYModel, UnitYConfig](
+        asset_store,
+        download_manager,
+        unity.load_unity_config,
+        create_unity_model,
+        None,
+        restrict_checkpoints=False,
+    )
+
+    model_config = unity.load_unity_config(model_name)
+    hparams = flatten_config(
+        dataclasses.asdict(model_config), separator="__", overrides=hparams
+    )
+    log.info(hparams)
+    # Need the diverge here because current default in SC is to convert from fairseq1 ckpt format
+    if model_name in UNITY_SMALLER_MODELS:
+        model = load_unity_model_without_conversion(model_name)
+        tokenizer = NllbLikeTokenizerLoader(asset_store, download_manager)(model_name)
+    else:
+        model = unity.load_unity_model(model_name)
+        tokenizer = unity.load_unity_text_tokenizer(model_name)
+
+    vocab = read_vocab(tokenizer)
+
+    return model, hparams, vocab
+
+
+def convert_nllb_model(
+    model_name: str,
+    hparams: Optional[Dict[str, Any]] = None,
+):
+    from fairseq2.models.nllb.loader import load_nllb_tokenizer, load_nllb_model, load_nllb_config
+
+    model_config = load_nllb_config(model_name)
+    hparams = flatten_config(
+        dataclasses.asdict(model_config), separator="__", overrides=hparams,
+    )
+
+    model = load_nllb_model(model_name)
+    tokenizer = load_nllb_tokenizer(model_name)
+    vocab = read_vocab(tokenizer)
+
+    return model, hparams, vocab
+
+
 def convert_model(
     model_name: Union[str, torch.nn.Module],
     out: Optional[Path] = None,
+    model_type: ModelType = ModelType.AUTO,
     layers: str = "",
     hparams: Optional[Dict[str, Any]] = None,
     vocab: Optional[List[Tuple[str, float]]] = None,
     fp16: bool = False,
 ) -> None:
+    """
+    Entry function for converting different kinds of models into a GGML file. Supported model checkpoints:
+        - unity models
+        - nllb models
+    Args:
+        model_name: name of a registered model (discoverable in a fairseq2 asset), path to a checkpoint,\
+            or the model object passed directly
+        out: path to store the converted .ggml model. If None, the ggml model is stored in the same place\
+            as the input model
+        model_type: type of the model (if AUTO, it is inferred from the name; only applies to nllb, unity and seamless)
+        layers: wildcard patterns to filter the layers from the model. Not applied to scripted models
+        hparams: override the hparams in the model with the user-defined values
+        vocab: list of tokens, or a path to vocabulary files (in case not bundled with the model checkpoint)
+        fp16: save .GGML tensors as float16 instead of float32
+    """
+    key_map: Optional[Dict[str, str]] = None
     if isinstance(model_name, str):
         # Load the corresponding fairseq2 model
         if out is None:
             out = Path(model_name).with_suffix(".ggml")
 
-        # The type of model depends on the name
-        if "unity" in model_name or "seamlessM4T" in model_name:
-            if hparams is None:
-                model_config = unity.load_unity_config(model_name)
-                hparams = flatten_config(
-                    dataclasses.asdict(model_config), separator="__"
-                )
-                log.info(hparams)
-            # Need the diverge here because current default in SC is to convert from fairseq1 ckpt format
-            if model_name in SMALLER_MODELS:
-                model = load_unity_model_without_conversion(model_name)
+        # Reason the model architecture from the model name or user input
+        try:
+            if model_type == ModelType.AUTO:
+                if "unity" in model_name or "seamlessM4T" in model_name:
+                    model_type = ModelType.UNITY
+                elif "nllb" in model_name:
+                    model_type = ModelType.NLLB
+
+            assert (
+                model_type != ModelType.AUTO
+            ), "Cannot infer model type from the `model_name`. Please specify `model_type`"
+
+            if model_type == ModelType.UNITY:
+                model, hparams, vocab = convert_unity_model(model_name, hparams=hparams)
+            elif model_type == ModelType.NLLB:
+                model, hparams, vocab = convert_nllb_model(model_name, hparams=hparams)
+                key_map = NLLB_2_UNITY_KEYMAP
             else:
-                model = unity.load_unity_model(model_name)
-            if vocab is None:
-                # Need the diverge here because current default in SC is to add a separate <pad>
-                # as control symbol in NllbTokenizer
-                if model_name in SMALLER_MODELS:
-                    tokenizer = NllbLikeTokenizerLoader(asset_store, download_manager)(
-                        model_name
-                    )
-                else:
-                    tokenizer = unity.load_unity_text_tokenizer(model_name)
-                vocab = read_vocab(tokenizer)
-        else:
-            raise ValueError(f"Unsupported model type: {model_name}")
+                raise ValueError(f"Unsupported model type: {model_name} (type: {model_type})")
+        except Exception as exc:
+            raise ValueError(f"Error in loading model: {model_name}") from exc
     else:
         # Use the model passed explicitly
         assert (
@@ -214,21 +280,14 @@ def convert_model(
     if layers:
         state_dict = {k: v for k, v in state_dict.items() if re.match(layers, k)}
     fixup_model(model, state_dict, layer_filter=layers)
-    layer_config = read_layer_config(model, layer_filter=layers)
+    if key_map:
+        state_dict = convert_model_state_dict(state_dict, key_map=key_map)
+    layer_config = read_layer_config(model, layer_filter=layers, key_map=key_map)
+
     vocab = vocab or []
     write_ggml_file(out, hparams, layer_config, vocab, state_dict, fp16)
 
 
-def _nested_getattr(model: Any, name: str) -> Any:
-    parts = name.split(".")
-    node = model
-    for part in parts:
-        node = getattr(node, part)
-        if node is None:
-            return None
-    return node
-
-
 def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
     queue = list(model._modules.items())
     modules = []
@@ -385,10 +444,12 @@ def write_state_dict(
         # Compressed size
         compressed_byte_size = sum(_fp16_byte_size(x) for x in state_dict.values())
         log.warning(
-            f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}Gb compressed to {compressed_byte_size / GB:.3f}"
+            f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}Gb"
+            f". Compressed to {compressed_byte_size / GB:.3f}Gb"
         )
 
     for key, value in state_dict.items():
+        # Rename the layers to make it look like "unity-arch"
         write_string(out, key)
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
             # GGML broadcasting isn't as strong as numpy
@@ -463,7 +524,7 @@ def torch_to_ggml_type(dtype: torch.dtype) -> int:
 def flatten_config(
     config: Dict[str, Any],
     separator: str,
-    config_preprocessor: Optional[Preprocessor] = None,
+    overrides: Optional[Dict[str, Any]] = None,
 ) -> Dict[str, Any]:
     """Flatten nested dictionnary
 
@@ -478,9 +539,6 @@ def flatten_config(
         flat dictionnary
     """
 
-    if config_preprocessor is None:
-        config_preprocessor = lambda x: x
-
     def __flatten(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
         result = {}
         for key in config:
@@ -489,16 +547,22 @@ def flatten_config(
                 nested_result = __flatten(config[key], f"{new_key}{separator}")
                 result.update(nested_result)
             else:
-                new_config = config_preprocessor(config[key])
+                new_config = config[key]
                 if new_config is not None:
                     result[new_key] = config[key]
 
         return result
 
-    return __flatten(config)
+    res_config = __flatten(config)
+    if overrides:
+        return {**res_config, **overrides}
+    else:
+        return res_config
 
 
-def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, Any]:
+def read_layer_config(
+    model: torch.nn.Module, layer_filter: str, key_map: Optional[Dict[str, str]] = None
+) -> Dict[str, Any]:
     layer_config = {}
 
     def _append_node_config(node: Any, prefix: str) -> None:
@@ -523,6 +587,15 @@ def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, An
     _append_node_config(model, "")
     for name, node in find_children(model, torch.nn.Module, layer_filter):
         _append_node_config(node, name + ".")
+
+    key_map = key_map or {}
+    keys_to_replace = []
+    for k, v in layer_config.items():
+        for old_pattern, replacement in key_map.items():
+            if (new_key := re.sub(old_pattern, replacement, k)) != k:
+                keys_to_replace.append((k, new_key))
+    for old_key, new_key in keys_to_replace:
+        layer_config[new_key] = layer_config.pop(old_key)
     return layer_config
 
 

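To illustrate the effect of NLLB_2_UNITY_KEYMAP on checkpoint keys and layer-config keys (the full key names below are hypothetical; only the prefixes come from the map):

    encoder_frontend.embed.weight             ->  text_encoder_frontend.embed.weight
    encoder.layers.0.self_attn.q_proj.weight  ->  text_encoder.layers.0.self_attn.q_proj.weight
    decoder_frontend.embed.weight             ->  text_decoder_frontend.embed.weight
    decoder.layers.0.ffn.inner_proj.weight    ->  text_decoder.layers.0.ffn.inner_proj.weight

With this rewrite an NLLB checkpoint is written to GGML with the same text_encoder / text_decoder prefixes that unity_lib.cpp already uses, so unity_text_encoder and unity_decode can run NLLB weights unchanged.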
+ 2 - 2
ggml/include/ggml/ggml.h

@@ -215,13 +215,13 @@
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 #define GGML_MAX_DIMS           4
-#define GGML_MAX_PARAMS         2048
+#define GGML_MAX_PARAMS         4096
 #define GGML_MAX_CONTEXTS       64
 #define GGML_MAX_SRC            10
 #define GGML_MAX_NAME           64
 #define GGML_MAX_OP_PARAMS      64
 #define GGML_DEFAULT_N_THREADS  4
-#define GGML_DEFAULT_GRAPH_SIZE 2048
+#define GGML_DEFAULT_GRAPH_SIZE 4096
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define GGML_MEM_ALIGN 4
 #else
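The doubled GGML_MAX_PARAMS and GGML_DEFAULT_GRAPH_SIZE give headroom for the larger NLLB text-encoder graphs. As a hedged sketch, assuming this tree also carries upstream ggml's ggml_new_graph_custom (the API family introduced alongside GGML_DEFAULT_GRAPH_SIZE), a caller that needs still more nodes can size its graph explicitly instead of raising the global default:

    // ggml_new_graph(ctx) sizes the graph for GGML_DEFAULT_GRAPH_SIZE nodes;
    // a custom-sized graph avoids editing the header for one oversized model.
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, /*size=*/8192, /*grads=*/false);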