Unity.cpp nllb support (#258)

* extend fairseq2 and unity_lib data structure

* update read and write for nllb and unity-arch models

* address Ning's and Guil's comments about the nllb layer naming

* convert on write instead of on read

* fix bugs when combining nllb ggml_convert with layer_filters; fix typos

* sync local commit with upstream

* increase default graph size to test bigger nllb model

* fix docstrings

* fix docstrings

* revert nit docstring

---------

Co-authored-by: Tuan Tran <tuantran@devfair0436.h2.fair>
Tuan Tran · 1 year ago · commit 34575dc9b3
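The core of the change is renaming NLLB checkpoint entries to the unity-arch layer layout at conversion (write) time instead of at load (read) time. A minimal sketch of that renaming in Python, using the `NLLB_2_UNITY_KEYMAP` patterns introduced in `ggml_convert.py` below (the helper function and example state dict are illustrative, not part of this commit):

```python
import re

# Regex patterns from NLLB_2_UNITY_KEYMAP: NLLB tensor names -> unity-arch names.
NLLB_2_UNITY_KEYMAP = {
    r"^encoder_frontend\.": "text_encoder_frontend.",
    r"^encoder\.": "text_encoder.",
    r"^decoder\.": "text_decoder.",
    r"^decoder_frontend\.": "text_decoder_frontend.",
}

def rename_nllb_keys(state_dict: dict) -> dict:
    """Apply the key map once to every tensor name (hypothetical helper)."""
    renamed = {}
    for key, value in state_dict.items():
        for pattern, replacement in NLLB_2_UNITY_KEYMAP.items():
            key = re.sub(pattern, replacement, key)
        renamed[key] = value
    return renamed

# Example: an NLLB-style key becomes a unity-arch key.
print(rename_nllb_keys({"encoder.layers.0.self_attn.q_proj.weight": None}))
# {'text_encoder.layers.0.self_attn.q_proj.weight': None}
```

Because the keys already look like the unity architecture inside the .ggml file, the C++ loader can reuse the existing unity layer lookup unchanged.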

ggml/examples/unity/fairseq2.cpp (+6, -6)

@@ -1580,7 +1580,7 @@ extern "C" Hypothesis* generate_sequence(
             ((float*)scores->data)[step_nr + 1 + i * max_seq_len] = ggml_get_f32_1d(next_scores, i);
         }
 
-        printf_mem_usage(step_ctx, "  step_ctx");
+        printf_mem_usage(step_ctx, "step_ctx");
         ggml_free(prev_step_ctx);
         prev_step_ctx = step_ctx;
 #if DEBUG_MEM_USAGE
@@ -1656,7 +1656,7 @@ struct llm_bigram_spm {
 struct llm_tokenizer_spm {
     llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}
 
-    void tokenize(const std::string& input_text, ggml_tensor& output) {
+    void tokenize(const std::string& input_text, ggml_tensor* output) {
         llama_vocab::id unk_idx = vocab.token_to_id.at("<unk>");
 
         // split string into utf8 chars
@@ -1724,8 +1724,8 @@ struct llm_tokenizer_spm {
             try_add_bigram(bigram.left, left_sym.next);
         }
 
-        llama_vocab::id* out = (llama_vocab::id*)output.data;
-        int out_step = sizeof(llama_vocab::id) / output.nb[0];
+        llama_vocab::id* out = (llama_vocab::id*)output->data;
+        int out_step = sizeof(llama_vocab::id) / output->nb[0];
         int num_tokens = 0;
         for (int i = 0; i > -1; i = symbols[i].next) {
             llm_symbol& symbol = symbols[i];
@@ -1734,7 +1734,7 @@ struct llm_tokenizer_spm {
         }
         *(out + num_tokens * out_step) = vocab.token_to_id.at("</s>");
         num_tokens += 1;
-        output.ne[0] = num_tokens;
+        output->ne[0] = num_tokens;
     }
 
 private:
@@ -1773,7 +1773,7 @@ private:
 };
 
 
-extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out) {
+extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out) {
     llm_tokenizer_spm spm = {model->vocab};
     spm.tokenize(std::string(text), out);
 }

ggml/examples/unity/fairseq2.h (+8, -2)

@@ -199,6 +199,13 @@ extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
     ggml_tensor* padding_mask
 );
 
+extern "C" ggml_tensor* StandardTransformerEncoder_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
+    ggml_tensor* seqs,
+    ggml_tensor* padding_mask
+);
+
 extern "C" ggml_tensor* RelativePositionMHA_forward(
     fairseq2_model& model,
     const std::string& prefix,
@@ -317,8 +324,7 @@ extern "C" Hypothesis* generate_sequence(
     int threads
 );
 
-extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor& out);
-
+extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, ggml_tensor* out);
 extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out);
 
 std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, ggml_tensor* scores, char* out);

ggml/examples/unity/lib/unity_lib.cpp (+77, -1)

@@ -2,6 +2,23 @@
 #include <algorithm>
 
 
+struct ggml_cgraph * unity_text_encoder(
+        fairseq2_model & model,
+        struct ggml_tensor * text_input) {
+    ggml_context* ctx0 = model.ctx;
+    ggml_cgraph* gf = ggml_new_graph(ctx0);
+    ggml_tensor* seqs = TransformerEmbeddingFrontend_forward(model, "text_encoder_frontend", text_input);
+    ggml_tensor* encoder_output = StandardTransformerEncoder_forward(
+        model,
+        "text_encoder",
+        seqs,
+        nullptr  // TODO: handle padding mask
+    );
+    encoder_output = ggml_dup(model.ctx, encoder_output);
+    ggml_build_forward_expand(gf, encoder_output);
+    return gf;
+}
+
 struct ggml_cgraph * unity_speech_encoder(
         fairseq2_model& model,
         struct ggml_tensor * speech_input) {
@@ -43,7 +60,7 @@ extern "C" fairseq2_model unity_init_model(const char* model_path) {
 }
 
 //  struct as return - transcription, CE score, LID 
-extern "C" Result unity_eval(fairseq2_model model, std::vector<float> data, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads, int memory_mb) {
+extern "C" Result unity_eval_speech(fairseq2_model& model, std::vector<float>& data, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads) {
     Result result;
     // The ctx_size_mb mostly depends of input length and model dim.
     int ctx_size_mb = opts.mem_mb;
@@ -101,10 +118,69 @@ extern "C" Result unity_eval(fairseq2_model model, std::vector<float> data, Sequ
         lid_scores[model.vocab.id_to_token[lang_ids[i]].text] = ggml_get_f32_1d(hypo[0].lid_scores, i); 
     }
     
+    
     result.transcription = result_tokens;
     result.word_confidence_scores = word_scores;
     result.lid_scores = lid_scores;
     result.err = 0;
+    ggml_free(model.ctx);
+    ggml_allocr_reset(fwd_alloc);
+    return result;
+}
+
+
+extern "C" Result unity_eval_text(fairseq2_model& model, const std::string& text, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads) {
+    Result result;
+    // The ctx_size_mb mostly depends of input length and model dim.
+    int ctx_size_mb = opts.mem_mb;
+    auto encoder_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
+    auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
+    ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);
+    int tgt_lang_idx;
+    auto tgt_lang_ptr = model.vocab.token_to_id.find("__" + tgt_lang + "__"); 
+    if (tgt_lang_ptr == model.vocab.token_to_id.end()) {
+        std::cerr << "Unknown language " << tgt_lang << "\n";
+        result.err = 1;
+        return result;
+    }
+    tgt_lang_idx = tgt_lang_ptr->second;
+
+    // tokenize the input text
+    model.ctx = ctx_from_buffer(encoder_buf);
+    ggml_set_no_alloc(model.ctx, false);
+    ggml_tensor* tokens = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 64);
+    ggml_set_no_alloc(model.ctx, true);
+    fairseq2_spm_tokenize(&model, text.c_str(), tokens);
+    
+    // Text encoder
+    ggml_cgraph* gf = unity_text_encoder(model, tokens);
+    ggml_allocr_alloc_graph(fwd_alloc, gf);
+    ggml_graph_compute_with_ctx(model.ctx, gf, n_threads);
+    ggml_tensor* encoder_output = gf->nodes[gf->n_nodes - 1];
+    
+    // Beam search decoding
+    const Hypothesis* hypo = unity_decode(model, opts, tgt_lang_idx, encoder_output, n_threads);
+    
+    // Drop language and bos token.
+    ggml_tensor* tgt_tokens = ggml_slice(model.ctx, hypo[0].seq, 0, 2, 0);
+    // Collect result string
+    char result_str[4096];
+
+    std::pair<std::vector<std::string>, std::vector<float>> p = fairseq2_spm_detokenize(&model, tgt_tokens, hypo[0].step_scores, (char*)&result_str);
+    std::vector<std::string> result_tokens = p.first;
+    std::vector<float> word_scores = p.second;
+
+    std::unordered_map<std::string, float> lid_scores;
+    std::vector<int> lang_ids;
+    for (const auto& kv : model.vocab.token_to_id) {
+        if (kv.first.substr(0, 2) == "__" && kv.first.substr(kv.first.size() - 2) == "__") {
+            lang_ids.push_back(kv.second);
+        }
+    }
+    std::sort(lang_ids.begin(), lang_ids.end());
+    for (size_t i = 0; i < lang_ids.size(); ++i) {
+        lid_scores[model.vocab.id_to_token[lang_ids[i]].text] = ggml_get_f32_1d(hypo[0].lid_scores, i); 
+    }
     
     result.transcription = result_tokens;
     result.word_confidence_scores = word_scores;

ggml/examples/unity/lib/unity_lib.h (+22, -2)

@@ -26,7 +26,13 @@ struct Result {
 
 struct ggml_cgraph * unity_speech_encoder(
     fairseq2_model& model,
-    struct ggml_tensor * speech_input);
+    struct ggml_tensor * speech_input
+);
+
+struct ggml_cgraph * unity_text_encoder(
+    fairseq2_model& model,
+    struct ggml_tensor * text_input
+);
 
 Hypothesis* unity_decode(
         fairseq2_model& model,
@@ -38,4 +44,18 @@ Hypothesis* unity_decode(
 
 extern "C" fairseq2_model unity_init_model(const char* model_path);
 
-extern "C" Result unity_eval(fairseq2_model model, std::vector<float> data, SequenceGeneratorOptions opts, std::string tgt_lang, int n_threads, int memory_gb);
+extern "C" Result unity_eval_speech(
+    fairseq2_model& model, 
+    std::vector<float>& data, 
+    SequenceGeneratorOptions opts, 
+    std::string tgt_lang, 
+    int n_threads
+);
+
+extern "C" Result unity_eval_text(
+    fairseq2_model& model,  
+    const std::string& text, 
+    SequenceGeneratorOptions opts, 
+    std::string tgt_lang, 
+    int n_threads
+);

ggml/examples/unity/unity.cpp (+31, -5)

@@ -13,7 +13,8 @@
 
 struct unity_params {
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    std::string model      = "seamlessM4T_medium.ggml"; // model path
+    std::string model = "seamlessM4T_medium.ggml"; // model path
+    std::string input_text = "";
     std::string tgt_lang = "eng";
     std::vector<std::string> files = {};
     bool text = false;
@@ -26,7 +27,7 @@ struct unity_params {
         /*len_penalty*/ 1.0,
         /*unk_penalty*/ 0.0,
         /*normalize_scores*/ true,
-        /*mem_mb*/ 512,
+        /*mem_mb*/ 512
     };
     bool verbose = false;
 };
@@ -37,6 +38,9 @@ void unity_print_usage(int /*argc*/, char ** argv, const unity_params & params)
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -i, --input           Input text for the text-2-text translation\n");
+    fprintf(stderr, "  -l, --tgt-lang        Target translation lang (default: %s\n", params.tgt_lang);
+
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -v, --verbose         Print out word level confidence score and LID score (default: off)");
     fprintf(stderr, "  -v, --verbose         Print out word level confidence score and LID score (default: off)");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
@@ -67,6 +71,8 @@ bool unity_params_parse(int argc, char ** argv, unity_params & params) {
             params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
             params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-m" || arg == "--model") {
         } else if (arg == "-m" || arg == "--model") {
             params.model = get_next_arg(i, argc, argv, arg, params);
             params.model = get_next_arg(i, argc, argv, arg, params);
+        } else if (arg == "-i" || arg == "--input") {
+            params.input_text = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "-l" || arg == "--tgt-lang") {
         } else if (arg == "-l" || arg == "--tgt-lang") {
             params.tgt_lang = get_next_arg(i, argc, argv, arg, params);
             params.tgt_lang = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "--text") {
         } else if (arg == "--text") {
@@ -108,8 +114,13 @@ int main(int argc, char ** argv) {
     char result_str[4096];
     char result_str[4096];
 
 
     std::string input;
     std::string input;
-    bool interactive = params.files.size() == 0;
+    bool interactive = (params.files.size() == 0 && params.input_text.length() == 0);
     auto next_file = params.files.begin();
     auto next_file = params.files.begin();
+
+    // Flag for the input case: true --> s2st, false --> t2tt
+    bool s2st_or_t2tt = true;
+
+    // S2ST
     while (true) {
     while (true) {
         if (interactive) {
         if (interactive) {
             std::cout << "\nEnter audio_path and tgt_lang, separated by space (or 'exit' to quit):\n";
             std::cout << "\nEnter audio_path and tgt_lang, separated by space (or 'exit' to quit):\n";
@@ -118,7 +129,10 @@ int main(int argc, char ** argv) {
                 break;
                 break;
             }
             }
         } else {
         } else {
-            if (next_file == params.files.end()) break;
+            if (params.input_text.length() > 0) {
+                break;
+            }
+            if (next_file == params.files.end() && s2st_or_t2tt) break;
             input = *(next_file++);
             input = *(next_file++);
         }
         }
         std::istringstream iss(input);
         std::istringstream iss(input);
@@ -144,7 +158,7 @@ int main(int argc, char ** argv) {
         std::vector<float> data(n_frames * info.channels);
         std::vector<float> data(n_frames * info.channels);
         sf_readf_float(sndfile, data.data(), n_frames);
         sf_readf_float(sndfile, data.data(), n_frames);
 
 
-        Result result = unity_eval(model, data, params.opts, tgt_lang, params.n_threads, ctx_size_mb);
+        Result result = unity_eval_speech(model, data, params.opts, tgt_lang, params.n_threads);
         std::string concat_transcription = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
         std::string concat_transcription = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
             [](const std::string& a, const std::string& b) {
             [](const std::string& a, const std::string& b) {
                 return a + " " + b;
                 return a + " " + b;
@@ -167,5 +181,17 @@ int main(int argc, char ** argv) {
         }
         }
     }
     }
 
 
+    // T2TT
+    if (params.input_text.length() > 0) {
+        // tokenize the input text
+        Result result = unity_eval_text(model, params.input_text, params.opts, params.tgt_lang, params.n_threads);
+        std::string concat_translation = std::accumulate(std::next(result.transcription.begin()), result.transcription.end(), result.transcription[0],
+            [](const std::string& a, const std::string& b) {
+                return a + " " + b;
+            }
+        );
+        std::cout << "Translation: " << concat_translation << std::endl;
+    }
+
     return 0;
     return 0;
 }
 }
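With these additions the same binary handles both modes: audio files passed as positional arguments still go through the S2ST loop above, while `-i`/`--input` triggers the T2TT branch, e.g. `./unity -m seamlessM4T_medium.ggml -i "Hello world" -l fra` (the binary name and language code here are illustrative; the `-m`, `-i` and `-l` flags come from this change).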

ggml/ggml_convert.py (+138, -65)

@@ -6,41 +6,51 @@
 
 import dataclasses
 import logging
-import math
 import struct
 from enum import Enum
 from io import BufferedWriter
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set, final
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence, Set, final
+import re
 
 import torch
 from fairseq2.assets import AssetCard
 from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
 from fairseq2.nn import SinusoidalPositionEncoder
 from fairseq2.nn.transformer import RelativePositionalEncoding
-from seamless_communication.models import unity
-from fairseq2.data.text import SentencePieceTokenizerBase
-from fairseq2.data.typing import PathLike
-from typing import Sequence
 from fairseq2.data.text import SentencePieceEncoder, SentencePieceTokenizerBase
+from fairseq2.data.typing import PathLike
 from fairseq2.typing import Device, finaloverride
-from fairseq2.models.utils import TokenizerLoaderBase
+from fairseq2.models.utils import TokenizerLoaderBase, ModelLoader
+from fairseq2.models.utils.checkpoint import convert_model_state_dict
 from fairseq2.assets import asset_store, download_manager
-from seamless_communication.models.unity.builder import UnitYConfig, create_unity_model
-from fairseq2.models.utils import ModelLoader
-from seamless_communication.models.unity.model import UnitYModel
 
 import ggml
-import re
 
 Preprocessor = Callable[[Any], Any]
 log = logging.getLogger("ggml_convert")
-SMALLER_MODELS = [
+
+
+class ModelType(str, Enum):
+    AUTO = "auto"  # inferred from the model name
+    UNITY = "unity"
+    NLLB = "nllb"
+
+
+UNITY_SMALLER_MODELS = [
     "unity_nano",
     "unity_micro",
 ]  # Trained with fairseq2, with custom dict (not original NLLB ones)
 
 
+NLLB_2_UNITY_KEYMAP = {
+    r"^encoder_frontend\.": r"text_encoder_frontend.",
+    r"^encoder\."         : r"text_encoder.",
+    r"^decoder\."         : r"text_decoder.",
+    r"^decoder_frontend\.": r"text_decoder_frontend.",
+}
+
+
 @final
 class NllbLikeTokenizer(SentencePieceTokenizerBase):
     """The only difference between this class and NllbTokenizer is it doesn't add a <pad> to control symbol list.
@@ -141,16 +151,6 @@ class NllbLikeTokenizer(SentencePieceTokenizerBase):
         )
 
 
-load_unity_model_without_conversion = ModelLoader[UnitYModel, UnitYConfig](
-    asset_store,
-    download_manager,
-    unity.load_unity_config,
-    create_unity_model,
-    None,
-    restrict_checkpoints=False,
-)
-
-
 @final
 class NllbLikeTokenizerLoader(TokenizerLoaderBase[NllbLikeTokenizer]):
     """Loads tokenizers used by NLLB models."""
@@ -164,44 +164,110 @@ class NllbLikeTokenizerLoader(TokenizerLoaderBase[NllbLikeTokenizer]):
         return NllbLikeTokenizer(pathname, langs, default_lang)
 
 
+def convert_unity_model(
+    model_name: str,
+    hparams: Optional[Dict[str, Any]] = None,
+):
+    from seamless_communication.models import unity
+    from seamless_communication.models.unity.builder import UnitYConfig, create_unity_model
+    from seamless_communication.models.unity.model import UnitYModel
+
+    load_unity_model_without_conversion = ModelLoader[UnitYModel, UnitYConfig](
+        asset_store,
+        download_manager,
+        unity.load_unity_config,
+        create_unity_model,
+        None,
+        restrict_checkpoints=False,
+    )
+
+    model_config = unity.load_unity_config(model_name)
+    hparams = flatten_config(
+        dataclasses.asdict(model_config), separator="__", overrides=hparams
+    )
+    log.info(hparams)
+    # Need the diverge here because current default in SC is to convert from fairseq1 ckpt format
+    if model_name in UNITY_SMALLER_MODELS:
+        model = load_unity_model_without_conversion(model_name)
+        tokenizer = NllbLikeTokenizerLoader(asset_store, download_manager)(model_name)
+    else:
+        model = unity.load_unity_model(model_name)
+        tokenizer = unity.load_unity_text_tokenizer(model_name)
+
+    vocab = read_vocab(tokenizer)
+
+    return model, hparams, vocab
+
+
+def convert_nllb_model(
+    model_name: str,
+    hparams: Optional[Dict[str, Any]] = None,
+):
+    from fairseq2.models.nllb.loader import load_nllb_tokenizer, load_nllb_model, load_nllb_config
+
+    model_config = load_nllb_config(model_name)
+    hparams = flatten_config(
+        dataclasses.asdict(model_config), separator="__", overrides=hparams,
+    )
+
+    model = load_nllb_model(model_name)
+    tokenizer = load_nllb_tokenizer(model_name)
+    vocab = read_vocab(tokenizer)
+
+    return model, hparams, vocab
+
+
 def convert_model(
     model_name: Union[str, torch.nn.Module],
     out: Optional[Path] = None,
+    model_type: ModelType = ModelType.AUTO,
     layers: str = "",
     hparams: Optional[Dict[str, Any]] = None,
     vocab: Optional[List[Tuple[str, float]]] = None,
     fp16: bool = False,
 ) -> None:
+    """
+    Entry function for converting different kinds of model into GGML file. Supported model checkpoints:
+        - unity models
+        - nllb models
+    Args:
+        model_name: name of a registered model (discoverable in a fairseq2 asset), path to a checkpoint,\
+            or the model object passed directly
+        out: path to store the converted .ggml model. If None, the ggml model is stored in the same place\
+            as input model
+        model_type: type of the model (or inferred from the name, only applied to nllb, unity and seamless)
+        layers: wildcard patterns to filter the layers from the model. Does not applied to scripted models
+        hparams: override the hparams in the model with the user-defined values
+        vocab: list of tokens, or aPath to  vocabulary files (in case not bundled with the model checkpoint)
+        fp16: Save to .GGML float16 tensors instead of float32
+    """
+    key_map: Optional[Dict[str, str]] = None
     if isinstance(model_name, str):
         # Load the corresponding fairseq2 model
         if out is None:
             out = Path(model_name).with_suffix(".ggml")
 
-        # The type of model depends on the name
-        if "unity" in model_name or "seamlessM4T" in model_name:
-            if hparams is None:
-                model_config = unity.load_unity_config(model_name)
-                hparams = flatten_config(
-                    dataclasses.asdict(model_config), separator="__"
-                )
-                log.info(hparams)
-            # Need the diverge here because current default in SC is to convert from fairseq1 ckpt format
-            if model_name in SMALLER_MODELS:
-                model = load_unity_model_without_conversion(model_name)
+        # Infer the model architecture from the model name or user input
+        try:
+            if model_type == ModelType.AUTO:
+                if "unity" in model_name or "seamlessM4T" in model_name:
+                    model_type = ModelType.UNITY
+                elif "nllb" in model_name:
+                    model_type = ModelType.NLLB
+
+            assert (
+                model_type != ModelType.AUTO
+            ), "Cannot infer model type from the `model_name`. Please specify `model_type`"
+
+            if model_type == ModelType.UNITY:
+                model, hparams, vocab = convert_unity_model(model_name, hparams=hparams)
+            elif model_type == ModelType.NLLB:
+                model, hparams, vocab = convert_nllb_model(model_name, hparams=hparams)
+                key_map = NLLB_2_UNITY_KEYMAP
             else:
-                model = unity.load_unity_model(model_name)
-            if vocab is None:
-                # Need the diverge here because current default in SC is to add a separate <pad>
-                # as control symbol in NllbTokenizer
-                if model_name in SMALLER_MODELS:
-                    tokenizer = NllbLikeTokenizerLoader(asset_store, download_manager)(
-                        model_name
-                    )
-                else:
-                    tokenizer = unity.load_unity_text_tokenizer(model_name)
-                vocab = read_vocab(tokenizer)
-        else:
-            raise ValueError(f"Unsupported model type: {model_name}")
+                raise ValueError(f"Unsupported model type: {model_name} (type: {model_type})")
+        except Exception as exc:
+            raise ValueError(f"Error in loading model: {model_name}") from exc
     else:
         # Use the model passed explicitly
         assert (
@@ -214,21 +280,14 @@ def convert_model(
     if layers:
         state_dict = {k: v for k, v in state_dict.items() if re.match(layers, k)}
     fixup_model(model, state_dict, layer_filter=layers)
-    layer_config = read_layer_config(model, layer_filter=layers)
+    if key_map:
+        state_dict = convert_model_state_dict(state_dict, key_map=key_map)
+    layer_config = read_layer_config(model, layer_filter=layers, key_map=key_map)
+
     vocab = vocab or []
     write_ggml_file(out, hparams, layer_config, vocab, state_dict, fp16)
 
 
-def _nested_getattr(model: Any, name: str) -> Any:
-    parts = name.split(".")
-    node = model
-    for part in parts:
-        node = getattr(node, part)
-        if node is None:
-            return None
-    return node
-
-
 def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
     queue = list(model._modules.items())
     modules = []
@@ -385,10 +444,12 @@ def write_state_dict(
         # Compressed size
         compressed_byte_size = sum(_fp16_byte_size(x) for x in state_dict.values())
         log.warning(
-            f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}Gb compressed to {compressed_byte_size / GB:.3f}"
+            f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}Gb"
+            f". Compressed to {compressed_byte_size / GB:.3f}Gb"
         )
 
     for key, value in state_dict.items():
+        # Rename the layers to make it look like "unity-arch"
         write_string(out, key)
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
             # GGML broadcasting isn't as strong as numpy
@@ -463,7 +524,7 @@ def torch_to_ggml_type(dtype: torch.dtype) -> int:
 def flatten_config(
     config: Dict[str, Any],
     separator: str,
-    config_preprocessor: Optional[Preprocessor] = None,
+    overrides: Optional[Dict[str, Any]] = None,
 ) -> Dict[str, Any]:
     """Flatten nested dictionnary
 
@@ -478,9 +539,6 @@
         flat dictionnary
     """
 
-    if config_preprocessor is None:
-        config_preprocessor = lambda x: x
-
    def __flatten(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
         result = {}
         for key in config:
@@ -489,16 +547,22 @@
                 nested_result = __flatten(config[key], f"{new_key}{separator}")
                 result.update(nested_result)
             else:
-                new_config = config_preprocessor(config[key])
+                new_config = config[key]
                 if new_config is not None:
                     result[new_key] = config[key]
 
         return result
 
-    return __flatten(config)
+    res_config = __flatten(config)
+    if overrides:
+        return {**res_config, **overrides}
+    else:
+        return res_config
 
 
-def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, Any]:
+def read_layer_config(
+    model: torch.nn.Module, layer_filter: str, key_map: Optional[Dict[str, str]] = None
+) -> Dict[str, Any]:
     layer_config = {}
 
     def _append_node_config(node: Any, prefix: str) -> None:
@@ -523,6 +587,15 @@ def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, An
     _append_node_config(model, "")
     for name, node in find_children(model, torch.nn.Module, layer_filter):
         _append_node_config(node, name + ".")
+
+    key_map = key_map or {}
+    keys_to_replace = []
+    for k, v in layer_config.items():
+        for old_pattern, replacement in key_map.items():
+            if (new_key := re.sub(old_pattern, replacement, k)) != k:
+                keys_to_replace.append((k, new_key))
+    for old_key, new_key in keys_to_replace:
+        layer_config[new_key] = layer_config.pop(old_key)
     return layer_config
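Taken together, the converter now routes NLLB checkpoints through `convert_nllb_model` and renames both their weights (via `convert_model_state_dict`) and their layer config (via `read_layer_config`) to the unity-arch layout on write. A minimal usage sketch; the asset name, output path, and import path below are assumptions for illustration, not taken from this commit:

```python
# Driving the converter for an NLLB checkpoint registered as a fairseq2 asset.
from pathlib import Path

from ggml_convert import ModelType, convert_model  # assumes ggml/ is on PYTHONPATH

convert_model(
    "nllb-200_dense_distill_600m",               # illustrative fairseq2 asset name
    out=Path("nllb-200_dense_distill_600m.ggml"),
    model_type=ModelType.NLLB,                   # or ModelType.AUTO to infer from the name
    fp16=True,                                   # write float16 tensors to shrink the file
)
```

With `model_type=ModelType.AUTO`, any name containing "nllb" takes the NLLB path and any name containing "unity" or "seamlessM4T" takes the unity path, per the dispatch in `convert_model` above.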

ggml/include/ggml/ggml.h (+2, -2)

@@ -215,13 +215,13 @@
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 #define GGML_MAX_DIMS           4
-#define GGML_MAX_PARAMS         2048
+#define GGML_MAX_PARAMS         4096
 #define GGML_MAX_CONTEXTS       64
 #define GGML_MAX_SRC            10
 #define GGML_MAX_NAME           64
 #define GGML_MAX_OP_PARAMS      64
 #define GGML_DEFAULT_N_THREADS  4
-#define GGML_DEFAULT_GRAPH_SIZE 2048
+#define GGML_DEFAULT_GRAPH_SIZE 4096
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define GGML_MEM_ALIGN 4
 #else
 #else