@@ -21,6 +21,7 @@ from fairseq2.nn.transformer import RelativePositionalEncoding
 from seamless_communication.models import unity
 
 import ggml
+import re
 
 Preprocessor = Callable[[Any], Any]
 log = logging.getLogger("ggml_convert")
@@ -29,6 +30,7 @@ log = logging.getLogger("ggml_convert")
 def convert_model(
     model_name: Union[str, torch.nn.Module],
     out: Optional[Path] = None,
+    layers: str = "",
     hparams: Optional[Dict[str, Any]] = None,
     vocab: Optional[List[Tuple[str, float]]] = None,
     fp16: bool = False,
@@ -61,7 +63,10 @@ def convert_model(
         model = model_name
 
     state_dict = model.state_dict()
-    layer_config = read_layer_config(model)
+    if layers:
+        state_dict = {k: v for k, v in state_dict.items() if re.match(layers, k)}
+    fixup_model(model, state_dict, layer_filter=layers)
+    layer_config = read_layer_config(model, layer_filter=layers)
     vocab = vocab or []
     write_ggml_file(out, hparams, layer_config, vocab, state_dict, fp16)
 
@@ -76,13 +81,15 @@ def _nested_getattr(model: Any, name: str) -> Any:
     return node
 
 
-def find_children(model: torch.nn.Module, t: type) -> List[Tuple[str, torch.nn.Module]]:
+def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
     queue = list(model._modules.items())
     modules = []
     while queue:
         name, node = queue.pop()
         if node is None:
             continue
+        if layer_filter and not re.match(layer_filter, name):
+            continue
         if isinstance(node, t):
             modules.append((name, node))
         for child_name, child_node in node._modules.items():
@@ -91,9 +98,9 @@ def find_children(model: torch.nn.Module, t: type) -> List[Tuple[str, torch.nn.Module]]:
     return modules
 
 
-def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) -> None:
+def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], layer_filter: str) -> None:
     # Bake the embedding scaling into the weights
-    frontends = find_children(model, TransformerEmbeddingFrontend)
+    frontends = find_children(model, TransformerEmbeddingFrontend, layer_filter)
     if frontends:
         log.info(
             "Upgrading the following TransformerEmbeddingFrontend: {}",
@@ -105,7 +112,7 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) -> None:
 
     # Sinusoidal embeddings are typically not saved since they are easily recomputed,
     # but this allows to avoid porting the sinusoidal logic to GGML
-    pos_encoders = find_children(model, SinusoidalPositionEncoder)
+    pos_encoders = find_children(model, SinusoidalPositionEncoder, layer_filter)
     if pos_encoders:
         log.info(
             "Upgrading the following SinusoidalPositionEncoder: {}",
@@ -116,7 +123,7 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) -> None:
             assert name not in state_dict
             state_dict[name] = pos_encoder.freqs
 
-    relative_pos_encs = find_children(model, RelativePositionalEncoding)
+    relative_pos_encs = find_children(model, RelativePositionalEncoding, layer_filter)
     # speech_encoder has several copies of the relative_pos_enc module.
     # For efficiency reasons we only make one copy of it to GGML.
     if relative_pos_encs:
@@ -352,7 +359,7 @@ def flatten_config(
     return __flatten(config)
 
 
-def read_layer_config(model: torch.nn.Module) -> Dict[str, Any]:
+def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, Any]:
     layer_config = {}
 
     def _append_node_config(node: Any, prefix: str) -> None:
@@ -375,7 +382,7 @@ def read_layer_config(model: torch.nn.Module) -> Dict[str, Any]:
             layer_config[prefix + k] = v
 
     _append_node_config(model, "")
-    for name, node in find_children(model, torch.nn.Module):
+    for name, node in find_children(model, torch.nn.Module, layer_filter):
         _append_node_config(node, name + ".")
     return layer_config
 
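A minimal usage sketch of the new `layers` filter, assuming the script is
importable as `ggml_convert` (per its logger name); the model name, output
path, and `text_decoder` layer prefix are hypothetical, only the `layers`
parameter and its `re.match` semantics come from the patch above:

    from pathlib import Path

    from ggml_convert import convert_model  # module name inferred from the logger name

    # Keep only the text decoder. re.match anchors at the start of each
    # state_dict key / module name, so "text_decoder" keeps every parameter
    # whose name begins with that prefix (layer naming is hypothetical).
    convert_model(
        "seamlessM4T_medium",                             # hypothetical model name
        out=Path("seamlessM4T_medium_text_decoder.ggml"), # hypothetical output path
        layers="text_decoder",
    )

Because re.match only anchors at the start of the string, a prefix such as
"text_decoder" would also match a sibling module named e.g.
"text_decoder_frontend"; a stricter pattern like r"text_decoder($|\.)" avoids
that.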