Pārlūkot izejas kodu

layer_filter

# Conflicts:
#	ggml/ggml_convert.py
Guillaume Wenzek 1 gadu atpakaļ
vecāks
revīzija
ac6b874c43
1 mainīts fails ar 15 papildinājumiem un 8 dzēšanām
  1. +15 −8
      ggml/ggml_convert.py

+ 15 - 8
ggml/ggml_convert.py

@@ -21,6 +21,7 @@ from fairseq2.nn.transformer import RelativePositionalEncoding
 from seamless_communication.models import unity
 
 import ggml
+import re
 
 Preprocessor = Callable[[Any], Any]
 log = logging.getLogger("ggml_convert")
@@ -29,6 +30,7 @@ log = logging.getLogger("ggml_convert")
 def convert_model(
     model_name: Union[str, torch.nn.Module],
     out: Optional[Path] = None,
+    layers: str = "",
     hparams: Optional[Dict[str, Any]] = None,
     vocab: Optional[List[Tuple[str, float]]] = None,
     fp16: bool = False,
@@ -61,7 +63,10 @@ def convert_model(
         model = model_name
 
     state_dict = model.state_dict()
-    layer_config = read_layer_config(model)
+    if layers:
+        state_dict = {k: v for k, v in state_dict.items() if re.match(layers, k)}
+    fixup_model(model, state_dict, layer_filter=layers)
+    layer_config = read_layer_config(model, layer_filter=layers)
     vocab = vocab or []
     write_ggml_file(out, hparams, layer_config, vocab, state_dict, fp16)
 
@@ -76,13 +81,15 @@ def _nested_getattr(model: Any, name: str) -> Any:
     return node
 
 
-def find_children(model: torch.nn.Module, t: type) -> List[Tuple[str, torch.nn.Module]]:
+def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
     queue = list(model._modules.items())
     modules = []
     while queue:
         name, node = queue.pop()
         if node is None:
             continue
+        if layer_filter and not re.match(layer_filter, name):
+            continue
         if isinstance(node, t):
             modules.append((name, node))
         for child_name, child_node in node._modules.items():
@@ -91,9 +98,9 @@ def find_children(model: torch.nn.Module, t: type) -> List[Tuple[str, torch.nn.M
     return modules
 
 
-def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) -> None:
+def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], layer_filter: str) -> None:
     # Bake the embedding scaling into the weights
-    frontends = find_children(model, TransformerEmbeddingFrontend)
+    frontends = find_children(model, TransformerEmbeddingFrontend, layer_filter)
     if frontends:
         log.info(
             "Upgrading the following TransformerEmbeddingFrontend: {}",
@@ -105,7 +112,7 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) ->
 
     # Sinusoidal embeddings are typically not saved since they are easily recomputed,
     # but this allows to avoid porting the sinusoidal logic to GGML
-    pos_encoders = find_children(model, SinusoidalPositionEncoder)
+    pos_encoders = find_children(model, SinusoidalPositionEncoder, layer_filter)
     if pos_encoders:
         log.info(
             "Upgrading the following SinusoidalPositionEncoder: {}",
@@ -116,7 +123,7 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) ->
         assert name not in state_dict
         state_dict[name] = pos_encoder.freqs
 
-    relative_pos_encs = find_children(model, RelativePositionalEncoding)
+    relative_pos_encs = find_children(model, RelativePositionalEncoding, layer_filter)
     # speech_encoder has several copies of the relative_pos_enc module.
     # For efficiency reasons we only make one copy of it to GGML.
     if relative_pos_encs:
@@ -352,7 +359,7 @@ def flatten_config(
     return __flatten(config)
 
 
-def read_layer_config(model: torch.nn.Module) -> Dict[str, Any]:
+def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, Any]:
     layer_config = {}
 
     def _append_node_config(node: Any, prefix: str) -> None:
@@ -375,7 +382,7 @@ def read_layer_config(model: torch.nn.Module) -> Dict[str, Any]:
             layer_config[prefix + k] = v
 
     _append_node_config(model, "")
-    for name, node in find_children(model, torch.nn.Module):
+    for name, node in find_children(model, torch.nn.Module, layer_filter):
         _append_node_config(node, name + ".")
     return layer_config