
merge with wa_models

Tuan Tran 1 year ago
parent commit 747002a537

+ 11 - 5
ggml/examples/unity/fairseq2.cpp

@@ -1781,15 +1781,16 @@ extern "C" void fairseq2_spm_tokenize(fairseq2_model* model, const char* text, g
 
 
 extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, char* out) {
-    int eos_idx = model->vocab.token_to_id["</s>"];
+    bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
+    int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
     int sent_len = tokens->ne[0];
     std::size_t written = 0;
+    const std::vector<llama_vocab::token_data>& id_to_token = no_tgt_vocab ? model->vocab.id_to_token : model->tgt_vocab.id_to_token;
     for (int i = 0; i < sent_len; ++i) {
         int id = ggml_get_i32_1d(tokens, i);
         // Don't print the EOS token, but only if it appears at the end.
         if (i == sent_len - 1 && eos_idx == id) break;
-
-        std::string token = model->vocab.id_to_token.at(id).text;
+        std::string token = id_to_token.at(id).text;
         // Skip the first space outputted.
         auto begin = token.begin();
         if (i == 0 && token.size() > 0 && token[0] == ' ') begin += 1;
@@ -1804,8 +1805,13 @@ extern "C" std::size_t fairseq2_spm_detokenize(fairseq2_model* model, ggml_tenso
 
 
 // TODO: Unify with the above?
-std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(fairseq2_model* model, ggml_tensor* tokens, ggml_tensor* scores, char* out) {
-    int eos_idx = model->vocab.token_to_id["</s>"];
+std::pair<std::vector<std::string>, std::vector<float>> fairseq2_spm_detokenize(
+        fairseq2_model* model,
+        ggml_tensor* tokens,
+        ggml_tensor* scores,
+        char* out) {
+    bool no_tgt_vocab = model->tgt_vocab.id_to_token.empty();
+    int eos_idx = no_tgt_vocab ? model->vocab.token_to_id["</s>"] : model->tgt_vocab.token_to_id["</s>"];
     int sent_len = tokens->ne[0];
     std::size_t written = 0;
     std::vector<float> word_scores;

+ 4 - 0
ggml/examples/unity/fairseq2.h

@@ -97,8 +97,12 @@ struct fairseq2_model {
     // Normally those can be inferred from hparams, but it avoids doing this logic in GGML
     std::unordered_map<std::string, std::int64_t> layer_config = {};
 
+    // Vocabulary for text transcription and translation APIs
     llama_vocab vocab;
 
+    // Optional target vocabulary for bilingual models
+    llama_vocab tgt_vocab;
+
     // KV cache for attention layers
     mutable std::unordered_map<std::string, KeyValueTensor> kv_cache = {};
 

+ 10 - 6
ggml/examples/unity/lib/unity_lib.cpp

@@ -48,7 +48,9 @@ Hypothesis* unity_decode(
     };
     FORCE_ALLOC(prefix_seq, model.ctx, ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, 2));
     ((int *)prefix_seq->data)[0]  = job.eos_idx;
+    if (model.hparams["multilingual"] != 0) {
     ((int *)prefix_seq->data)[1]  = tgt_lang_idx;
+    }
     job.prefix_seq = prefix_seq;
     return generate_sequence(model, job, encoder_output, nullptr, model.ctx, n_threads);
 }
@@ -137,13 +139,15 @@ extern "C" Result unity_eval_text(fairseq2_model& model, const std::string& text
     auto encoder_fwd_buf = std::vector<uint8_t>(ctx_size_mb * 1024 * 1024);
     ggml_allocr* fwd_alloc = ggml_allocr_new(encoder_fwd_buf.data(), encoder_fwd_buf.capacity(), 8);
     int tgt_lang_idx;
-    auto tgt_lang_ptr = model.vocab.token_to_id.find("__" + tgt_lang + "__"); 
-    if (tgt_lang_ptr == model.vocab.token_to_id.end()) {
-        std::cerr << "Unknown language " << tgt_lang << "\n";
-        result.err = 1;
-        return result;
+    if (model.hparams["multilingual"] != 0) {
+        auto tgt_lang_ptr = model.vocab.token_to_id.find("__" + tgt_lang + "__"); 
+        if (tgt_lang_ptr == model.vocab.token_to_id.end()) {
+            std::cerr << "Unknown language " << tgt_lang << "\n";
+            result.err = 1;
+            return result;
+        }
+        tgt_lang_idx = tgt_lang_ptr->second;
     }
-    tgt_lang_idx = tgt_lang_ptr->second;
 
     // tokenize the input text
     model.ctx = ctx_from_buffer(encoder_buf);

+ 1 - 0
ggml/examples/unity/model_loader.cpp

@@ -219,5 +219,6 @@ extern "C" int load_fairseq2_ggml_file(fairseq2_model& model, const char* fname)
     loader.load_hparams(model.layer_config, fin);
     loader.load_vocab(model.vocab, fin);
     loader.load_model_weights(model, fin);
+    loader.load_vocab(model.tgt_vocab, fin);
     return 0;
 }
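
With the extra load_vocab call above, a converted file is now expected to end with the target vocabulary, after the model weights. Purely as an informal summary (the labels below are not identifiers from the code), the section order produced by write_ggml_file in the ggml_convert.py changes further down is roughly:

# Informal summary of the on-disk section order after this change; a unilingual
# conversion still writes the final section, but with a token count of 0.
GGML_SECTION_ORDER = [
    "header",         # magic, written by write_ggml_header
    "hparams",        # model hparams and layer_config, written via write_hparams
    "vocab",          # main (source) vocabulary
    "state_dict",     # model weights
    "tgt_vocab",      # new: target vocabulary, length 0 for unilingual models
]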

+ 78 - 10
ggml/ggml_convert.py

@@ -35,6 +35,8 @@ class ModelType(str, Enum):
     AUTO = "auto"  # inferred from the model name
     UNITY = "unity"
     NLLB = "nllb"
+    BITEXT = "bitext"
+    BITEXT_SCRIPTED = "bitext_scripted"
 
 
 UNITY_SMALLER_MODELS = [
@@ -185,6 +187,7 @@ def convert_unity_model(
     hparams = flatten_config(
         dataclasses.asdict(model_config), separator="__", overrides=hparams
     )
+    hparams["multilingual"] = True
     log.info(hparams)
     # Need the diverge here because current default in SC is to convert from fairseq1 ckpt format
     if model_name in UNITY_SMALLER_MODELS:
@@ -194,7 +197,7 @@ def convert_unity_model(
         model = unity.load_unity_model(model_name)
         tokenizer = unity.load_unity_text_tokenizer(model_name)
 
-    vocab = read_vocab(tokenizer)
+    vocab = read_vocab_from_tokenizer(tokenizer)
 
     return model, hparams, vocab
 
@@ -209,27 +212,72 @@ def convert_nllb_model(
     hparams = flatten_config(
         dataclasses.asdict(model_config), separator="__", overrides=hparams,
     )
+    hparams["multilingual"] = True
 
     model = load_nllb_model(model_name)
     tokenizer = load_nllb_tokenizer(model_name)
-    vocab = read_vocab(tokenizer)
+    vocab = read_vocab_from_tokenizer(tokenizer)
 
     return model, hparams, vocab
 
 
+def convert_bitext_model(
+    model_name: str,
+    src_vocab: str,
+    tgt_vocab: str,
+    hparams: Optional[Dict[str, Any]] = None,
+
+):
+    from fairseq2.models.nllb.loader import load_nllb_model, load_nllb_config
+    import sentencepiece as spm
+    from torch.ao.quantization.qconfig import default_dynamic_qconfig, float_qparams_weight_only_qconfig
+
+    model_config = load_nllb_config(model_name)
+    hparams = flatten_config(
+        dataclasses.asdict(model_config), separator="__", overrides=hparams,
+    )
+    hparams["multilingual"] = False
+    
+    model = load_nllb_model(model_name)
+    # quantize the non-scripted model to optimize the output size
+    torch.ao.quantization.quantize_dynamic(
+        model,
+        {
+            torch.nn.Linear: default_dynamic_qconfig,
+            torch.nn.Embedding: float_qparams_weight_only_qconfig,
+        },
+        dtype=torch.qint8,
+        inplace=True,
+    )
+
+    def _read_vocab(vocab_file: str) -> List[Tuple[str, float]]:
+        sp = spm.SentencePieceProcessor(vocab_file)
+        return [
+            (sp.id_to_piece(id), sp.get_score(id)) for id in range(sp.get_piece_size())  # type: ignore[no-member]
+        ]
+
+    src_vocab = _read_vocab(src_vocab)
+    tgt_vocab = _read_vocab(tgt_vocab)
+
+    return model, hparams, src_vocab, tgt_vocab
+
+
 def convert_model(
     model_name: Union[str, torch.nn.Module],
     out: Optional[Path] = None,
     model_type: ModelType = ModelType.AUTO,
     layers: str = "",
     hparams: Optional[Dict[str, Any]] = None,
-    vocab: Optional[List[Tuple[str, float]]] = None,
+    vocab: Optional[str] = None,  # optional path to a vocabulary file stored separately from the checkpoint
+    extra_vocab: Optional[str] = None,  # optional path to an extra vocabulary file, e.g. the target-language vocabulary of a bilingual model
     fp16: bool = False,
 ) -> None:
     """
     Entry function for converting different kinds of model into GGML file. Supported model checkpoints:
         - unity models
         - nllb models
+        - bilingual encoder-decoder models (PyTorch) with separate vocabularies for src and tgt languages
+        - bilingual encoder-decoder models (TorchScript)
     Args:
         model_name: name of a registered model (discoverable in a fairseq2 asset), path to a checkpoint,\
             or the model object passed directly
@@ -238,10 +286,13 @@ def convert_model(
         model_type: type of the model (or inferred from the name, only applied to nllb, unity and seamless)
         layers: wildcard patterns to filter the layers from the model. Not applied to scripted models
         hparams: override the hparams in the model with the user-defined values
-        vocab: list of tokens, or aPath to  vocabulary files (in case not bundled with the model checkpoint)
+        vocab: Path to the vocabulary file (in case it is not bundled with the model checkpoint)
+        extra_vocab: Path to an additional vocabulary file (used in bilingual models with explicit tgt languages)
         fp16: Save to .GGML float16 tensors instead of float32
     """
+    
     key_map: Optional[Dict[str, str]] = None
+    tgt_vocab: Optional[List[Tuple[str, float]]] = None
     if isinstance(model_name, str):
         # Load the corresponding fairseq2 model
         if out is None:
@@ -264,8 +315,20 @@ def convert_model(
             elif model_type == ModelType.NLLB:
                 model, hparams, vocab = convert_nllb_model(model_name, hparams=hparams)
                 key_map = NLLB_2_UNITY_KEYMAP
+            elif model_type == ModelType.BITEXT_SCRIPTED:
+                # TODO: implement the EdgeML model conversion here
+                raise NotImplementedError("Scripted model conversion not implemented yet")
+            
+            # Bilingual non-scripted model
             else:
-                raise ValueError(f"Unsupported model type: {model_name} (type: {model_type})")
+                assert (
+                    vocab and extra_vocab
+                ), "non-scripted model requires vocbulary files (SPM Protobuf format)"
+
+                model, hparams, vocab, tgt_vocab = convert_bitext_model(
+                    model_name, hparams=hparams, src_vocab=vocab, tgt_vocab=extra_vocab
+                )
+                key_map = NLLB_2_UNITY_KEYMAP
         except Exception as exc:
             raise ValueError(f"Error in loading model: {model_name}") from exc
     else:
@@ -285,7 +348,8 @@ def convert_model(
     layer_config = read_layer_config(model, layer_filter=layers, key_map=key_map)
 
     vocab = vocab or []
-    write_ggml_file(out, hparams, layer_config, vocab, state_dict, fp16)
+    tgt_vocab = tgt_vocab or []
+    write_ggml_file(out, hparams, layer_config, state_dict=state_dict, vocab=vocab, tgt_vocab=tgt_vocab, fp16=fp16)
 
 
 def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
@@ -340,7 +404,7 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], lay
         state_dict["speech_encoder.pos_enc"] = rel_pos_enc.freqs
 
 
-def read_vocab(tokenizer: Any) -> List[Tuple[str, float]]:
+def read_vocab_from_tokenizer(tokenizer: Any) -> List[Tuple[str, float]]:
     vocab_info = tokenizer.vocab_info
     vocab = [
         (tokenizer.model.index_to_token(i).replace("▁", " "), -i)
@@ -353,9 +417,10 @@ def write_ggml_file(
     out: Path,
     hparams: Dict[str, Any],
     layer_config: Dict[str, Any],
-    vocab: List[Tuple[str, float]],
     state_dict: Dict[str, torch.Tensor],
-    fp16: bool,
+    vocab: List[Tuple[str, float]],
+    tgt_vocab: Optional[List[Tuple[str, float]]] = None,  # optional target vocabulary for bilingual models
+    fp16: bool = False,
 ) -> None:
     with out.open("wb") as o:
         write_ggml_header(o)
@@ -363,6 +428,7 @@ def write_ggml_file(
         write_hparams(o, layer_config)
         write_vocab(o, vocab)
         write_state_dict(o, state_dict, fp16)
+        write_vocab(o, tgt_vocab or [])  # empty for unilingual models
 
 
 def write_ggml_header(out: BufferedWriter) -> None:
@@ -398,6 +464,9 @@ def write_hparams(out: BufferedWriter, hparams: Dict[str, Any]) -> None:
 def write_vocab(out: BufferedWriter, vocab: List[Tuple[str, float]]) -> None:
     out.write(struct.pack("<q", len(vocab)))
 
+    if len(vocab) == 0:
+        return
+
     # Write all words concatenated in a buffer
     words = [bytes(w, "utf8") for w, score in vocab]
     packed_words = b"\0".join(words)
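
Because write_vocab now returns right after the length for an empty vocabulary, a unilingual conversion still emits the trailing tgt_vocab section, just as a zero count, which is presumably what the C++ side detects via tgt_vocab.id_to_token.empty(). A minimal sketch of that sentinel (in-memory only, for illustration):

import struct
from io import BytesIO

buf = BytesIO()
tgt_vocab = []  # unilingual model: no target vocabulary to serialize
buf.write(struct.pack("<q", len(tgt_vocab)))  # little-endian int64 token count
# Nothing else follows for an empty vocabulary, mirroring the early return above.
assert buf.getvalue() == b"\x00" * 8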
@@ -449,7 +518,6 @@ def write_state_dict(
         )
 
     for key, value in state_dict.items():
-        # Rename the layers to make it look like "unity-arch"
         write_string(out, key)
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
             # GGML broadcasting isn't as strong as numpy