
layer_filter

# Conflicts:
#	ggml/ggml_convert.py
Guillaume Wenzek · 1 year ago
commit f4e33e9b24

ggml/examples/unity/model_loader.cpp (+1, -1)

@@ -47,7 +47,7 @@ model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
     // Note this requires changing the on-disk format
     bool as_float32 = true;
     struct ggml_init_params params = {
-        /*.mem_size   =*/ f32_tensor_size + num_tensor * (int64_t)ggml_tensor_overhead(),
+        /*.mem_size   =*/ f32_tensor_size + (num_tensor + 1) * (int64_t)ggml_tensor_overhead(),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ false,
     };
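
The fix above is an off-by-one in the arena sizing: ggml_init() carves both tensor data and per-tensor metadata (ggml_tensor_overhead() bytes each) out of the single mem_size buffer, so the arena needs one overhead slot for every tensor that will ever be created in it. A minimal sketch of the sizing rule, mirrored in Python; reading the "+ 1" as headroom for one tensor created during loading that num_tensor does not count is an assumption, not something the commit states:

    # Sketch of the arena-size rule from the C++ above (names mirror the diff).
    # The extra slot is assumed to cover one tensor allocated while loading
    # that num_tensor does not account for.
    def arena_size(f32_tensor_size: int, num_tensor: int, tensor_overhead: int) -> int:
        return f32_tensor_size + (num_tensor + 1) * tensor_overhead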

ggml/ggml_convert.py (+15, -8)

@@ -21,6 +21,7 @@ from fairseq2.nn.transformer import RelativePositionalEncoding
 from seamless_communication.models import unity
 
 import ggml
+import re
 
 Preprocessor = Callable[[Any], Any]
 log = logging.getLogger("ggml_convert")
@@ -29,6 +30,7 @@ log = logging.getLogger("ggml_convert")
 def convert_model(
     model_name: Union[str, torch.nn.Module],
     out: Optional[Path] = None,
+    layers: str = "",
     hparams: Optional[Dict[str, Any]] = None,
     vocab: Optional[List[Tuple[str, float]]] = None,
     fp16: bool = False,
@@ -61,7 +63,10 @@ def convert_model(
         model = model_name
 
     state_dict = model.state_dict()
-    layer_config = read_layer_config(model)
+    if layers:
+        state_dict = {k: v for k, v in state_dict.items() if re.match(layers, k)}
+    fixup_model(model, state_dict, layer_filter=layers)
+    layer_config = read_layer_config(model, layer_filter=layers)
     vocab = vocab or []
     write_ggml_file(out, hparams, layer_config, vocab, state_dict, fp16)
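
A hedged usage sketch of the new layers parameter: the string is treated as a regular expression and applied with re.match(), i.e. anchored at the start of each state_dict key, so it effectively selects a prefix/subtree of the model. The model name and pattern below are illustrative, not taken from this commit:

    from pathlib import Path

    from ggml_convert import convert_model  # module shown in this diff

    # Keep only weights whose names start with "text_decoder" (hypothetical
    # prefix); re.match() anchors the pattern at the start of each key.
    convert_model("seamlessM4T_medium", out=Path("decoder.ggml"), layers="text_decoder")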
 
@@ -76,13 +81,15 @@ def _nested_getattr(model: Any, name: str) -> Any:
     return node
 
 
-def find_children(model: torch.nn.Module, t: type) -> List[Tuple[str, torch.nn.Module]]:
+def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> List[Tuple[str, torch.nn.Module]]:
     queue = list(model._modules.items())
     modules = []
     while queue:
         name, node = queue.pop()
         if node is None:
             continue
+        if layer_filter and not re.match(layer_filter, name):
+            continue
         if isinstance(node, t):
             modules.append((name, node))
         for child_name, child_node in node._modules.items():
@@ -91,9 +98,9 @@ def find_children(model: torch.nn.Module, t: type) -> List[Tuple[str, torch.nn.Module]]:
     return modules
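
Two consequences of this filter follow directly from the code above (module names below are hypothetical): re.match() anchors the pattern at the start of the name, and the continue fires before a node's children are enqueued, so a non-matching parent prunes its whole subtree. The pattern therefore has to match every ancestor prefix on the path to the modules you want:

    import re

    # A matching parent keeps its subtree reachable:
    assert re.match("text_decoder", "text_decoder")
    assert re.match("text_decoder", "text_decoder.layers.0")

    # A deeper pattern never matches the top-level parent, which is skipped
    # before its children are queued, so the subtree is silently dropped:
    assert re.match("text_decoder.layers", "text_decoder") is None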
 
 
-def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) -> None:
+def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], layer_filter: str) -> None:
     # Bake the embedding scaling into the weights
-    frontends = find_children(model, TransformerEmbeddingFrontend)
+    frontends = find_children(model, TransformerEmbeddingFrontend, layer_filter)
     if frontends:
         log.info(
             "Upgrading the following TransformerEmbeddingFrontend: {}",
@@ -105,7 +112,7 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) ->
 
     # Sinusoidal embeddings are typically not saved since they are easily recomputed,
 # but this lets us avoid porting the sinusoidal logic to GGML
-    pos_encoders = find_children(model, SinusoidalPositionEncoder)
+    pos_encoders = find_children(model, SinusoidalPositionEncoder, layer_filter)
     if pos_encoders:
         log.info(
             "Upgrading the following SinusoidalPositionEncoder: {}",
@@ -116,7 +123,7 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) ->
         assert name not in state_dict
         state_dict[name] = pos_encoder.freqs
 
-    relative_pos_encs = find_children(model, RelativePositionalEncoding)
+    relative_pos_encs = find_children(model, RelativePositionalEncoding, layer_filter)
     # speech_encoder has several copies of the relative_pos_enc module.
     # For efficiency reasons we only write one copy of it to the GGML file.
     if relative_pos_encs:
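
What the frontend fixup amounts to, as a hedged sketch: fairseq2's TransformerEmbeddingFrontend scales embeddings at lookup time (typically by sqrt(model_dim)), and the converter folds that factor into the stored weights so the GGML side can do a plain table lookup. The key name and the exact scale below are assumptions for illustration, not read from this commit:

    import math
    from typing import Dict

    import torch

    def bake_embedding_scale(state_dict: Dict[str, torch.Tensor], key: str) -> None:
        # Assumed scale: sqrt(model_dim); embedding weights are (vocab, model_dim).
        w = state_dict[key]
        state_dict[key] = w * math.sqrt(w.shape[1])

    # Hypothetical key; the real code iterates over the frontends found above.
    # bake_embedding_scale(state_dict, "text_decoder_frontend.embed.weight")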
@@ -352,7 +359,7 @@ def flatten_config(
     return __flatten(config)
 
 
-def read_layer_config(model: torch.nn.Module) -> Dict[str, Any]:
+def read_layer_config(model: torch.nn.Module, layer_filter: str) -> Dict[str, Any]:
     layer_config = {}
 
     def _append_node_config(node: Any, prefix: str) -> None:
@@ -375,7 +382,7 @@ def read_layer_config(model: torch.nn.Module) -> Dict[str, Any]:
             layer_config[prefix + k] = v
 
     _append_node_config(model, "")
-    for name, node in find_children(model, torch.nn.Module):
+    for name, node in find_children(model, torch.nn.Module, layer_filter):
         _append_node_config(node, name + ".")
     return layer_config
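
With a filter set, only matching submodules contribute prefixed entries, while the root model's own config is still read unconditionally via _append_node_config(model, ""). A hedged sketch with illustrative keys and values:

    # Illustrative only -- the actual keys depend on the model's modules.
    layer_config = read_layer_config(model, layer_filter="text_decoder")
    # e.g. {"model_dim": 1024,                  # root config, always present
    #       "text_decoder.num_layers": 24, ...} # entries from the filtered subtree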
 

ggml/test_unity_cpp.py (+2, -2)

@@ -97,7 +97,7 @@ def download_sample_audio() -> Any:
 def test_convert_linear(tmp_path: Path) -> None:
     module = fairseq2.nn.Linear(16, 24, True)
 
-    layer_config = read_layer_config(module)
+    layer_config = read_layer_config(module, "")
     assert layer_config == {"input_dim": 16, "output_dim": 24}
 
     module_file = tmp_path / "module.ggml"
@@ -112,7 +112,7 @@ def test_convert_linear(tmp_path: Path) -> None:
 def test_convert_linear_fp16(tmp_path: Path, ctx: Ctx) -> None:
     pt_model = torch.nn.ModuleDict({"linear": fairseq2.nn.Linear(16, 24, True)})
 
-    layer_config = read_layer_config(pt_model)
+    layer_config = read_layer_config(pt_model, "")
     assert layer_config == {"linear.input_dim": 16, "linear.output_dim": 24}
 
     ggml_file = tmp_path / "linear.ggml"
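
Note on these test changes: the empty string is the "no filter" sentinel. Both call sites guard with truthiness (if layers: in convert_model, if layer_filter and ... in find_children), so passing "" disables filtering and preserves the old behavior, which is why only the call signatures, and none of the assertions, change here.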