Browse Source

export model size in hparams

Guillaume Wenzek 1 year ago
parent
commit
c0bec21155
3 changed files with 144 additions and 85 deletions
  1. +2 −3  ggml/ggml.py
  2. +129 −80  ggml/ggml_convert.py
  3. +13 −2  ggml/test_unity_cpp.py

+ 2 - 3
ggml/ggml.py

@@ -199,9 +199,8 @@ lib.load_unity_ggml_file.restype = None
 
 def load_unity_ggml_file(model_file: Path) -> NativeObj:
     model = Fairseq2Model()
-    lib.load_unity_ggml_file(
-        model.ptr, ctypes.create_string_buffer(str(model_file).encode("utf-8"))
-    )
+    bytes_file = ctypes.create_string_buffer(str(model_file).encode("utf-8"))
+    lib.load_unity_ggml_file(model.ptr, bytes_file)
     return model
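
The only change here is binding the encoded path to a local before the ctypes call. A minimal sketch of the `ctypes.create_string_buffer` pattern, with an illustrative path that is not part of this commit:

    import ctypes

    # create_string_buffer allocates a mutable, NUL-terminated char array;
    # binding it to a name makes the argument passed to the C function explicit.
    path = "seamlessM4T_medium.ggml"  # illustrative path
    bytes_file = ctypes.create_string_buffer(path.encode("utf-8"))
    assert bytes_file.value == b"seamlessM4T_medium.ggml"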
 
 

+ 129 - 80
ggml/examples/unity/ggml_convert.py → ggml/ggml_convert.py

@@ -13,63 +13,58 @@ from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
+import ggml
 from fairseq2.assets import AssetCard
 from seamless_communication.models.unity import load_unity_config, load_unity_model
 
 Preprocessor = Callable[[Any], Any]
 
 
-def to_ctype(value: Any) -> Tuple[str, Any]:
-    """Transform python type to ctype.
-
-    :params value:
-        value to cast into ctype
-
-    :returns:
-        A tuple of ctype and cvalue.
-    """
-    if isinstance(value, int):
-        return ("i", value)
-    if isinstance(value, float):
-        return ("f", value)
-    if isinstance(value, bool):
-        return ("?", value)
-    if isinstance(value, Enum):
-        return ("i", value.value)
+def convert_model(model_name: str, out: Optional[Path] = None) -> None:
+    if out is None:
+        out = Path(model_name).with_suffix(".ggml")
 
-    raise ValueError(f"Unsupported type {type(value)}")
+    # The type of model depends on the name
+    if "unity" in model_name or "seamlessM4T" in model_name:
+        model_config = load_unity_config(model_name)
+        hparams = flatten_config(dataclasses.asdict(model_config), separator="__")
+        model = load_unity_model(model_name)
+    else:
+        raise ValueError(f"Unsupported model type: {model_name}")
 
+    with out.open("wb") as o:
+        write_ggml_file(o, hparams, model.state_dict())
 
-def get_cpp_type(value: Any) -> str:
-    """Return equivalent cpp type in string format
+    with out.with_suffix(".hparams.h").open("w") as h:
+        h.write(generate_hparams_struct(hparams, "unity_hparams"))
 
-    :params value:
-        value to cast into ctype
 
-    :returns:
-        str containing cpp type
-    """
-    # used to have compatibility between types
-    try:
-        ctype, _ = to_ctype(value)
-    except ValueError as e:
-        return f"// Error: {e}"
+def write_ggml_file(
+    out: BufferedWriter, hparams: Dict[str, Any], state_dict: Dict[str, torch.Tensor]
+) -> None:
+    write_ggml_header(out)
 
-    if ctype == "i":
-        return "std::int32_t"
-    if ctype == "f":
-        return "std::float32"
-    if ctype == "?":
-        return "bool"
+    # Append the model byte size to the hparams.
+    if "model_byte_size" not in hparams:
+        # raw payload of every tensor
+        byte_size = sum(x.numel() * x.element_size() for x in state_dict.values())
+        # + per-tensor metadata overhead
+        byte_size += ggml.ggml_tensor_overhead() * len(state_dict)
+        # + 20% slack, since this is only a rough estimate
+        byte_size = int(byte_size * 1.2)
+        hparams["model_byte_size"] = byte_size
+        logging.warning(f"Saving a GGML file with {len(state_dict)} tensors, estimated buffer size: {byte_size / (1024**3):.2f} GiB")
+    # Sentinel marking the end of the hparams: b"hparams_" as a native int64
+    # (6877961321223123048 on little-endian 64-bit platforms).
+    hparams["__end_of_hparams__"] = struct.unpack("l", b"hparams_")[0]
 
-    raise RuntimeError(
-        f"Should not have reached this part." f"Missing cpp translation for {ctype}"
-    )
+    write_hparams(out, hparams)
+    write_state_dict(out, state_dict)
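
The `__end_of_hparams__` sentinel is the eight bytes b"hparams_" reinterpreted as a native int64. A loader-side check is not part of this commit; a hedged sketch of what one could look like, assuming a 64-bit platform where the native "l" format is 8 bytes (the same assumption the writer makes):

    import struct

    END_OF_HPARAMS = struct.unpack("l", b"hparams_")[0]  # 6877961321223123048

    def check_end_of_hparams(fin) -> None:
        # Hypothetical check: read the int64 written right after the hparams
        # and fail loudly if the writer and loader structs have drifted apart.
        (value,) = struct.unpack("l", fin.read(struct.calcsize("l")))
        assert value == END_OF_HPARAMS, "hparams out of sync with the loader"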
 
 
 def write_ggml_header(out: BufferedWriter) -> None:
-    """Write GGML header"""
-    out.write(b"ggml")
+    """Write GGML header (in reverse cause why not)"""
+    out.write(b"ggml"[::-1])
 
 
 def write_hparams(out: BufferedWriter, hparams: Dict[str, Any]) -> None:
@@ -79,6 +74,8 @@ def write_hparams(out: BufferedWriter, hparams: Dict[str, Any]) -> None:
         flattened dict containing model's hyper parameters.
 
     """
+    # TODO: should we prepend the size of the hparams struct?
+    # This would help catch out-of-sync writer/loader code.
     for key, value in hparams.items():
         try:
             # TODO: this is not cross platform, what's the standard way of writing hparams in GGML ?
@@ -120,25 +117,37 @@ def write_tensor(out: BufferedWriter, value: torch.Tensor) -> None:
     :params value:
         Tensor to dump.
     """
-    data = value.squeeze().numpy()
-    n_dims = len(data.shape)
-
-    # TODO: Convert to fp16 when necessary!
-    ftype = 0
-
-    out.write(struct.pack("ii", n_dims, ftype))
+    if value.dtype is torch.int64:
+        # GGML doesn't have int64; downcast it
+        value = value.to(dtype=torch.int32)
+
+    if value.ndim == 0:
+        # GGML doesn't support scalars as tensors.
+        value = value.reshape(1)
+
+    data = value.numpy()
+    n_dims = data.ndim
+    assert n_dims < 5, "ggml doesn't support tensors with more than 4 dims"
+    assert n_dims >= 1, "ggml doesn't support 0-dim tensors"
+
+    ftype = torch_to_ggml_type(value.dtype)
+    out.write(struct.pack("i", n_dims))
+    out.write(struct.pack("i", ftype))
     for i in range(n_dims):
-        out.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        # ggml stores shape dims as int64 ("l" is 8 bytes on 64-bit platforms)
+        out.write(struct.pack("l", data.shape[n_dims - 1 - i]))
 
     data.tofile(out)
 
-
-def write_ggml_file(
-    out: BufferedWriter, hparams: Dict[str, Any], state_dict: Dict[str, torch.Tensor]
-) -> None:
-    write_ggml_header(out)
-    write_hparams(out, hparams)
-    write_state_dict(out, state_dict)
+def torch_to_ggml_type(dtype: type) -> int:
+    if dtype is torch.float32:
+        return ggml.GGML_TYPE_F32
+    elif dtype is torch.float16:
+        return ggml.GGML_TYPE_F16
+    elif dtype is torch.int32:
+        return ggml.GGML_TYPE_I32
+    else:
+        raise NotImplementedError(f"{dtype} is not mapped to a GGML_TYPE")
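
Putting the pieces together, each tensor record is: an int32 n_dims, an int32 ftype, n_dims native longs holding the shape innermost-first, then the raw payload. A hedged sketch of the inverse (not part of this commit), assuming a little-endian 64-bit platform and an f32 tensor:

    import struct
    import numpy as np

    def read_tensor(fin) -> np.ndarray:
        # Hypothetical reader mirroring write_tensor above.
        n_dims, ftype = struct.unpack("ii", fin.read(8))
        assert ftype == 0  # GGML_TYPE_F32, the only case this sketch handles
        # Shapes were written innermost-first; reverse them back.
        long_size = struct.calcsize("l")
        shape = struct.unpack(f"{n_dims}l", fin.read(n_dims * long_size))[::-1]
        data = np.fromfile(fin, dtype=np.float32, count=int(np.prod(shape)))
        return data.reshape(shape)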
 
 
 def flatten_config(
@@ -179,6 +188,56 @@ def flatten_config(
     return __flatten(config)
 
 
+def to_ctype(value: Any) -> Tuple[str, Any]:
+    """Transform python type to ctype.
+
+    :params value:
+        value to cast into ctype
+
+    :returns:
+        A tuple of ctype and cvalue.
+    """
+    # bool must be checked before int: Python's bool subclasses int.
+    if isinstance(value, bool):
+        return ("?", value)
+    if isinstance(value, int):
+        return ("l", value)
+    if isinstance(value, float):
+        return ("f", value)
+    if isinstance(value, Enum):
+        return ("i", value.value)
+
+    raise ValueError(f"Unsupported type {type(value)}")
+
+
+def get_cpp_type(value: Any) -> str:
+    """Return equivalent cpp type in string format
+
+    :params value:
+        value to cast into ctype
+
+    :returns:
+        str containing cpp type
+    """
+    # reuse to_ctype so the generated C++ types stay in sync with the writer
+    try:
+        ctype, _ = to_ctype(value)
+    except ValueError as e:
+        return f"// Error: {e}"
+
+    if ctype == "i":
+        return "std::int32_t"
+    if ctype == "l":
+        return "std::int64_t"
+    if ctype == "f":
+        return "float"
+    if ctype == "?":
+        return "bool"
+
+    raise RuntimeError(
+        f"Should not have reached this part." f"Missing cpp translation for {ctype}"
+    )
+
+
 def generate_hparams_struct(
     hparams: Dict[str, Any],
     struct_name: str,
@@ -190,34 +249,24 @@ def generate_hparams_struct(
     :param struct_name:
         Name of the generated struct.
     """
-    struct = f"struct {struct_name} {{\n"
-    fields = "\n".join(
-        [f"    {get_cpp_type(value)} {key};" for key, value in hparams.items()]
-    )
+    struct = f"struct {struct_name} {{"
+    fields = [f"    {get_cpp_type(value)} {key};" for key, value in hparams.items()]
+    struct = "\n".join([struct] + fields + ["};\n"])
 
-    return struct + fields + "\n};\n"
+    valid_fields = [
+        key for key, value in hparams.items() if "Error" not in get_cpp_type(value)
+    ]
+    read_struct = f"void read_{struct_name}({struct_name}& out, std::ifstream &fin) {{"
+    read_fields = [
+        f"    fin.read((char*) &out.{field}, sizeof(out.{field}));"
+        for field in valid_fields
+    ]
+    read_struct = "\n".join([read_struct] + read_fields + ["};\n"])
 
-
-def main(model_name: str, out: Optional[Path] = None) -> None:
-    if out is None:
-        out = Path(model_name).with_suffix(".ggml")
-
-    # The type of model depends on the name
-    if "unity" in model_name or "seamlessM4T" in model_name:
-        model_config = load_unity_config(model_name)
-        hparams = flatten_config(dataclasses.asdict(model_config), separator="__")
-        model = load_unity_model(model_name)
-    else:
-        raise ValueError(f"Unsupported model type: {model_name}")
-
-    with out.open("wb") as o:
-        write_ggml_file(o, hparams, model.state_dict())
-
-    with out.with_suffix(".hparams.h").open("w") as h:
-        h.write(generate_hparams_struct(hparams, model_name + "_hparams"))
+    return "\n".join([struct, read_struct])
 
 
 if __name__ == "__main__":
     import func_argparse
 
-    func_argparse.single_main(main)
+    func_argparse.single_main(convert_model)
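
Since func_argparse derives the CLI from convert_model's signature, the converter can be run from the shell or called directly; a usage sketch with an illustrative model name:

    # Shell equivalent (flags come from convert_model's parameter names):
    #   python ggml/ggml_convert.py --model_name seamlessM4T_medium
    from pathlib import Path
    from ggml_convert import convert_model

    # Writes seamlessM4T_medium.ggml plus seamlessM4T_medium.hparams.h
    convert_model("seamlessM4T_medium", Path("seamlessM4T_medium.ggml"))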

+ 13 - 2
ggml/test_unity_cpp.py

@@ -6,6 +6,7 @@ import numpy as np
 from pathlib import Path
 from typing import Iterator
 from ggml import NativeObj
+from ggml_convert import convert_model
 
 Ctx = ggml.ggml_context_p
 
@@ -148,6 +149,16 @@ def test_unity_model_load(ctx: Ctx) -> None:
         assert np.allclose(inpL[0, :10], list(expected), atol=1e-4)
 
 
-def test_unity_model_load2(ctx: Ctx) -> None:
-    model = ggml.load_unity_ggml_file(UNITY_MODELS / "unity-large/ggml-model.bin")
+def test_unity_model_load2(ctx: Ctx, tmp_path: Path) -> None:
+    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
+    if not model_file.exists():
+        convert_model("seamlessM4T_medium", model_file)
+    hparams_header_file = model_file.with_suffix(".hparams.h")
+    hparams_struct = hparams_header_file.read_text().strip()
+    actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
+    # TODO: assert hparams_struct in actual_code
+
+    model = ggml.load_unity_ggml_file(model_file)
     print(model)