@@ -6,41 +6,51 @@
import dataclasses
import logging
-import math
import struct
from enum import Enum
from io import BufferedWriter
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set, final
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence, Set, final
+import re

import torch
from fairseq2.assets import AssetCard
from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
from fairseq2.nn import SinusoidalPositionEncoder
from fairseq2.nn.transformer import RelativePositionalEncoding
-from seamless_communication.models import unity
-from fairseq2.data.text import SentencePieceTokenizerBase
-from fairseq2.data.typing import PathLike
-from typing import Sequence
from fairseq2.data.text import SentencePieceEncoder, SentencePieceTokenizerBase
+from fairseq2.data.typing import PathLike
from fairseq2.typing import Device, finaloverride
-from fairseq2.models.utils import TokenizerLoaderBase
+from fairseq2.models.utils import TokenizerLoaderBase, ModelLoader
+from fairseq2.models.utils.checkpoint import convert_model_state_dict
from fairseq2.assets import asset_store, download_manager
-from seamless_communication.models.unity.builder import UnitYConfig, create_unity_model
-from fairseq2.models.utils import ModelLoader
-from seamless_communication.models.unity.model import UnitYModel

import ggml
-import re

Preprocessor = Callable[[Any], Any]
log = logging.getLogger("ggml_convert")
-SMALLER_MODELS = [
+
+
+class ModelType(str, Enum):
+    AUTO = "auto"  # inferred from the model name
+    UNITY = "unity"
+    NLLB = "nllb"
+
+
+UNITY_SMALLER_MODELS = [
    "unity_nano",
    "unity_micro",
]  # Trained with fairseq2, with a custom dict (not the original NLLB one)


+NLLB_2_UNITY_KEYMAP = {
+    r"^encoder_frontend\.": r"text_encoder_frontend.",
+    r"^encoder\.": r"text_encoder.",
+    r"^decoder\.": r"text_decoder.",
+    r"^decoder_frontend\.": r"text_decoder_frontend.",
+}
+
+
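A minimal sketch of how this key map is meant to be applied (the standalone helper `rename_key` is hypothetical; in the patch the map is passed to `convert_model_state_dict` and `read_layer_config`):

    import re

    def rename_key(key: str) -> str:
        # Apply each regex rewrite; only keys starting with an NLLB prefix change.
        for pattern, replacement in NLLB_2_UNITY_KEYMAP.items():
            key = re.sub(pattern, replacement, key)
        return key

    assert rename_key("encoder.layers.0.self_attn.bias") == "text_encoder.layers.0.self_attn.bias"
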
@final
class NllbLikeTokenizer(SentencePieceTokenizerBase):
    """The only difference between this class and NllbTokenizer is that it doesn't add a <pad> to the control symbol list.
@@ -141,16 +151,6 @@ class NllbLikeTokenizer(SentencePieceTokenizerBase):
        )


-load_unity_model_without_conversion = ModelLoader[UnitYModel, UnitYConfig](
-    asset_store,
-    download_manager,
-    unity.load_unity_config,
-    create_unity_model,
-    None,
-    restrict_checkpoints=False,
-)
-
-
@final
class NllbLikeTokenizerLoader(TokenizerLoaderBase[NllbLikeTokenizer]):
    """Loads tokenizers used by NLLB models."""
@@ -164,44 +164,110 @@ class NllbLikeTokenizerLoader(TokenizerLoaderBase[NllbLikeTokenizer]):
        return NllbLikeTokenizer(pathname, langs, default_lang)


+def convert_unity_model(
+    model_name: str,
+    hparams: Optional[Dict[str, Any]] = None,
+):
+    from seamless_communication.models import unity
+    from seamless_communication.models.unity.builder import UnitYConfig, create_unity_model
+    from seamless_communication.models.unity.model import UnitYModel
+
+    load_unity_model_without_conversion = ModelLoader[UnitYModel, UnitYConfig](
+        asset_store,
+        download_manager,
+        unity.load_unity_config,
+        create_unity_model,
+        None,
+        restrict_checkpoints=False,
+    )
+
+    model_config = unity.load_unity_config(model_name)
+    hparams = flatten_config(
+        dataclasses.asdict(model_config), separator="__", overrides=hparams
+    )
+    log.info(hparams)
+    # Need to diverge here because the current default in SC is to convert from the fairseq1 ckpt format
+    if model_name in UNITY_SMALLER_MODELS:
+        model = load_unity_model_without_conversion(model_name)
+        tokenizer = NllbLikeTokenizerLoader(asset_store, download_manager)(model_name)
+    else:
+        model = unity.load_unity_model(model_name)
+        tokenizer = unity.load_unity_text_tokenizer(model_name)
+
+    vocab = read_vocab(tokenizer)
+
+    return model, hparams, vocab
+
+
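A minimal usage sketch (assuming the "unity_micro" asset from UNITY_SMALLER_MODELS is reachable through the fairseq2 asset store):

    model, hparams, vocab = convert_unity_model("unity_micro")
    # hparams is the flattened UnitYConfig (keys joined with "__"),
    # vocab the (token, score) pairs read from the tokenizer.
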
+def convert_nllb_model(
+    model_name: str,
+    hparams: Optional[Dict[str, Any]] = None,
+):
+    from fairseq2.models.nllb.loader import load_nllb_tokenizer, load_nllb_model, load_nllb_config
+
+    model_config = load_nllb_config(model_name)
+    hparams = flatten_config(
+        dataclasses.asdict(model_config), separator="__", overrides=hparams,
+    )
+
+    model = load_nllb_model(model_name)
+    tokenizer = load_nllb_tokenizer(model_name)
+    vocab = read_vocab(tokenizer)
+
+    return model, hparams, vocab
+
+
def convert_model(
    model_name: Union[str, torch.nn.Module],
    out: Optional[Path] = None,
+    model_type: ModelType = ModelType.AUTO,
    layers: str = "",
    hparams: Optional[Dict[str, Any]] = None,
    vocab: Optional[List[Tuple[str, float]]] = None,
    fp16: bool = False,
) -> None:
+    """
+    Entry function for converting different kinds of models into a GGML file. Supported model checkpoints:
+    - unity models
+    - nllb models
+    Args:
+        model_name: name of a registered model (discoverable in a fairseq2 asset), path to a checkpoint,\
+            or the model object passed directly
+        out: path to store the converted .ggml model. If None, the ggml model is stored in the same place\
+            as the input model
+        model_type: type of the model (if AUTO, it is inferred from the name; only nllb, unity and seamless names are recognized)
+        layers: wildcard patterns to filter the layers from the model. Not applied to scripted models
+        hparams: override the hparams in the model with user-defined values
+        vocab: list of tokens, or a path to vocabulary files (in case they are not bundled with the model checkpoint)
+        fp16: save tensors as float16 instead of float32 in the GGML file
+    """
+    key_map: Optional[Dict[str, str]] = None
    if isinstance(model_name, str):
        # Load the corresponding fairseq2 model
        if out is None:
            out = Path(model_name).with_suffix(".ggml")

-        # The type of model depends on the name
-        if "unity" in model_name or "seamlessM4T" in model_name:
-            if hparams is None:
-                model_config = unity.load_unity_config(model_name)
-                hparams = flatten_config(
-                    dataclasses.asdict(model_config), separator="__"
-                )
-                log.info(hparams)
-            # Need the diverge here because current default in SC is to convert from fairseq1 ckpt format
-            if model_name in SMALLER_MODELS:
-                model = load_unity_model_without_conversion(model_name)
+        # Infer the model architecture from the model name or user input
+        try:
+            if model_type == ModelType.AUTO:
+                if "unity" in model_name or "seamlessM4T" in model_name:
+                    model_type = ModelType.UNITY
+                elif "nllb" in model_name:
+                    model_type = ModelType.NLLB
+
+            assert (
+                model_type != ModelType.AUTO
+            ), "Cannot infer model type from the `model_name`. Please specify `model_type`"
+
+            if model_type == ModelType.UNITY:
+                model, hparams, vocab = convert_unity_model(model_name, hparams=hparams)
+            elif model_type == ModelType.NLLB:
+                model, hparams, vocab = convert_nllb_model(model_name, hparams=hparams)
+                key_map = NLLB_2_UNITY_KEYMAP
            else:
-                model = unity.load_unity_model(model_name)
-            if vocab is None:
-                # Need the diverge here because current default in SC is to add a separate <pad>
-                # as control symbol in NllbTokenizer
-                if model_name in SMALLER_MODELS:
-                    tokenizer = NllbLikeTokenizerLoader(asset_store, download_manager)(
-                        model_name
-                    )
-                else:
-                    tokenizer = unity.load_unity_text_tokenizer(model_name)
-                vocab = read_vocab(tokenizer)
-        else:
-            raise ValueError(f"Unsupported model type: {model_name}")
+                raise ValueError(f"Unsupported model type: {model_name} (type: {model_type})")
+        except Exception as exc:
+            raise ValueError(f"Error in loading model: {model_name}") from exc
    else:
        # Use the model passed explicitly
        assert (
@@ -214,21 +280,14 @@ def convert_model(
    if layers:
        state_dict = {k: v for k, v in state_dict.items() if re.match(layers, k)}
    fixup_model(model, state_dict, layer_filter=layers)
-    layer_config = read_layer_config(model, layer_filter=layers)
+    if key_map:
+        state_dict = convert_model_state_dict(state_dict, key_map=key_map)
+    layer_config = read_layer_config(model, layer_filter=layers, key_map=key_map)
+
    vocab = vocab or []
    write_ggml_file(out, hparams, layer_config, vocab, state_dict, fp16)


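A usage sketch for this entry point (the checkpoint names below are illustrative; model_type is only needed when the name does not contain "unity", "seamlessM4T" or "nllb"):

    from pathlib import Path

    # Inferred as ModelType.UNITY from the name:
    convert_model("unity_micro", out=Path("unity_micro.ggml"), fp16=True)

    # Explicit type; NLLB state-dict keys are remapped via NLLB_2_UNITY_KEYMAP:
    convert_model("my_nllb_model", model_type=ModelType.NLLB)
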
-def _nested_getattr(model: Any, name: str) -> Any:
-    parts = name.split(".")
-    node = model
-    for part in parts:
-        node = getattr(node, part)
-        if node is None:
-            return None
-    return node
-
-
def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
    queue = list(model._modules.items())
    modules = []
@@ -385,10 +444,12 @@ def write_state_dict(
    # Compressed size
    compressed_byte_size = sum(_fp16_byte_size(x) for x in state_dict.values())
    log.warning(
-        f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}Gb compressed to {compressed_byte_size / GB:.3f}"
+        f"Saving a ggml file with {len(state_dict)} tensors, totalling {true_byte_size / GB:.3f}GB"
+        f". Compressed to {compressed_byte_size / GB:.3f}GB"
    )

    for key, value in state_dict.items():
+        # Keys have already been renamed to look like the unity arch via key_map
        write_string(out, key)
        if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
            # GGML broadcasting isn't as strong as numpy
@@ -463,7 +524,7 @@ def torch_to_ggml_type(dtype: torch.dtype) -> int:
def flatten_config(
    config: Dict[str, Any],
    separator: str,
-    config_preprocessor: Optional[Preprocessor] = None,
+    overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Flatten nested dictionary
@@ -478,9 +539,6 @@
        flat dictionary
    """

-    if config_preprocessor is None:
-        config_preprocessor = lambda x: x
-
    def __flatten(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        result = {}
        for key in config:
@@ -489,16 +547,22 @@
                nested_result = __flatten(config[key], f"{new_key}{separator}")
                result.update(nested_result)
            else:
-                new_config = config_preprocessor(config[key])
+                new_config = config[key]
                if new_config is not None:
                    result[new_key] = config[key]

        return result

-    return __flatten(config)
+    res_config = __flatten(config)
+    if overrides:
+        return {**res_config, **overrides}
+    else:
+        return res_config


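A small example of the flattening and override behavior (config keys are illustrative):

    config = {"model": {"dim": 1024, "layers": 6}, "fp16": True}
    flat = flatten_config(config, separator="__", overrides={"model__dim": 512})
    assert flat == {"model__dim": 512, "model__layers": 6, "fp16": True}
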
-def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, Any]:
+def read_layer_config(
+    model: torch.nn.Module, layer_filter: str, key_map: Optional[Dict[str, str]] = None
+) -> Dict[str, Any]:
    layer_config = {}

    def _append_node_config(node: Any, prefix: str) -> None:
@@ -523,6 +587,15 @@ def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, Any]:
    _append_node_config(model, "")
    for name, node in find_children(model, torch.nn.Module, layer_filter):
        _append_node_config(node, name + ".")
+
+    key_map = key_map or {}
+    keys_to_replace = []
+    for k in layer_config:
+        for old_pattern, replacement in key_map.items():
+            if (new_key := re.sub(old_pattern, replacement, k)) != k:
+                keys_to_replace.append((k, new_key))
+    for old_key, new_key in keys_to_replace:
+        layer_config[new_key] = layer_config.pop(old_key)
    return layer_config
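With key_map=NLLB_2_UNITY_KEYMAP, this final renaming step makes an NLLB layer config line up with the unity layer names, e.g. (entries are illustrative):

    layer_config = {"encoder.inner_dim": 4096, "decoder_frontend.pos_dim": 1024}
    # After the renaming loop:
    # {"text_encoder.inner_dim": 4096, "text_decoder_frontend.pos_dim": 1024}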