Guillaume Wenzek 1 year ago
commit 3f5912b973

+ 1 - 1
ggml/Makefile

@@ -15,4 +15,4 @@ run: build/bin/unity
 	$< --model examples/unity/models/unity-large/ggml-model.bin
 
 tests: build/src/libggml.so
-	pytest test_unity_cpp.py
+	pytest test_unity_cpp.py -s

+ 52 - 2
ggml/examples/unity/fairseq2.cpp

@@ -16,6 +16,15 @@ extern "C" void fairseq2_model_free(fairseq2_model* model) {
     delete model;
 };
 
+extern "C" std::string* std_string_alloc(char* c_str) {
+    return new std::string(c_str);
+}
+
+extern "C" void std_string_free(std::string* str) {
+    delete str;
+}
+
+
 
 // Linear
 
@@ -41,6 +50,17 @@ void Linear_init(
     }
 }
 
+extern "C" ggml_tensor* Linear_forward(
+    fairseq2_model& model,
+    const std::string &prefix,
+    ggml_tensor* input
+) {
+    ggml_tensor* weight = model.tensors[prefix + ".weight"];
+    ggml_tensor* bias = model.tensors[prefix + ".bias"];
+
+    return ggml_add(model.ctx, ggml_mul_mat(model.ctx, weight, input), bias);
+}
+
 // LayerNorm
 
 std::size_t LayerNorm_size(int32_t dim)
@@ -60,6 +80,24 @@ void LayerNorm_init(
     model.tensors[prefix + ".bias"] = self.bias;
 }
 
+extern "C" ggml_tensor* LayerNorm_forward(
+    fairseq2_model& model,
+    const std::string &prefix,
+    ggml_tensor* input) {
+    ggml_tensor* weight = model.tensors[prefix + ".weight"];
+    ggml_tensor* bias = model.tensors[prefix + ".bias"];
+
+    auto ctx = model.ctx;
+    // TODO: should `eps` be part of unity hparams ?
+    input = ggml_norm(ctx, input, /*eps*/1e-5);
+    return ggml_add(
+        ctx,
+        ggml_mul(ctx, ggml_repeat(ctx, weight, input), input),
+        ggml_repeat(ctx, bias, input)
+    );
+}
+
+
 std::size_t StandardFeedForwardNetwork_size(int32_t dim, int32_t inner_dim)
 {
     return LayerNorm_size(dim) + Linear_size(dim, inner_dim) + Linear_size(inner_dim, dim);
@@ -77,11 +115,23 @@ void StandardFeedForwardNetwork_init(
     Linear_init(self.output_proj, model, prefix + ".output_proj", inner_dim, model_dim, true);
 }
 
-ggml_tensor* StandardFeedForwardNetwork_forward(
-    StandardFeedForwardNetwork* self,
+extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
     ggml_tensor* seqs
 ) {
+    seqs = Linear_forward(model, prefix + ".inner_proj", seqs);
+    // inner_activation = ReLu // TODO: allow other activation
+    seqs = ggml_relu(model.ctx, seqs);
+
+    if (model.tensors.find(prefix + ".inner_layer_norm.weight") != model.tensors.end()) {
+        seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
+    }
 
+    // TODO: inference dropout
+    // if self.inner_dropout is not None:
+    //     seqs = self.inner_dropout(seqs)
+    seqs = Linear_forward(model, prefix + ".output_proj", seqs);
     return seqs;
 }
 

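For reference, the graphs built by the new Linear_forward and LayerNorm_forward above compute the same thing as the following PyTorch sketch (illustrative only, not part of the commit; the function names are made up, and eps is hard-coded to 1e-5 as in the C++):

import torch

def linear_forward_ref(weight, bias, x):
    # ggml_mul_mat(weight, x) is x @ weight.T; ggml_add then adds the bias.
    return x @ weight.T + bias

def layer_norm_forward_ref(weight, bias, x, eps=1e-5):
    # ggml_norm normalizes the last dimension to zero mean / unit variance;
    # ggml_repeat broadcasts weight and bias before the multiply and add.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias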
+ 8 - 3
ggml/examples/unity/fairseq2.h

@@ -17,6 +17,10 @@ struct fairseq2_model {
 extern "C" fairseq2_model* fairseq2_model_alloc();
 extern "C" void fairseq2_model_free(fairseq2_model* model);
 
+extern "C" std::string* std_string_alloc(char* c_str);
+extern "C" void std_string_free(std::string* str);
+
+
 struct Linear {
     struct ggml_tensor* weight;  // out_dim * in_dim
     struct ggml_tensor* bias;  // out_dim
@@ -85,9 +89,10 @@ void StandardFeedForwardNetwork_init(
     int inner_dim
 );
 
-ggml_tensor* StandardFeedForwardNetwork_forward(
-    StandardFeedForwardNetwork* self,
-    ggml_tensor* seqs
+extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
+    ggml_tensor* input
 );
 
 // Transformer

+ 2 - 2
ggml/examples/unity/model_loader.cpp

@@ -1,7 +1,7 @@
 #include <string>
 #include "model_loader.h"
 
-#define DEBUG 1
+#define DEBUG_MODEL_LOAD 0
 
 std::ifstream open_ggml_file(const char* fname) {
     printf("%s: loading model from '%s'\n", __func__, fname);
@@ -31,7 +31,7 @@ model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
             break;
         auto tensor = load_tensor_value(fin, model.ctx);
         model.tensors[name] = tensor;
-        if (DEBUG) {
+        if (DEBUG_MODEL_LOAD) {
             printf("%s [%5ld, %5ld], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), tensor->ne[0], tensor->ne[1], ggml_type_name(tensor->type), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
         }
         total_size += ggml_nbytes(tensor);

+ 1 - 2
ggml/examples/unity/unity.cpp

@@ -464,7 +464,7 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
     return true;
 }
 
-extern "C" ggml_tensor* LayerNorm_forward(
+ggml_tensor* LayerNorm_forward(
     const LayerNorm& layer,
     ggml_context* ctx,
     ggml_tensor* cur,
@@ -505,7 +505,6 @@ extern "C" ggml_cgraph* unity_audio_encoder_graph(
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-
     struct ggml_cgraph  * gf = ggml_new_graph(ctx0);
     struct ggml_tensor * ffn_scale = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, 1);
     ffn_scale->data = malloc(ggml_nbytes(ffn_scale));

+ 40 - 5
ggml/ggml.py

@@ -6,6 +6,7 @@ adding a few utilities to convert between ggml and numpy tensors for testing.
 import numpy as np
 import ctypes
 import torch
+import functools
 from pathlib import Path
 from typing import Self
 from typing import Dict
@@ -93,7 +94,9 @@ def _pad_shape(shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
     return shape + padding  # type: ignore
 
 
-def from_numpy(ctx: ggml_context_p, array: np.ndarray) -> ggml_tensor_p:
+def from_numpy(ctx: ggml_context_p, array: Union[np.ndarray, "torch.Tensor"]) -> ggml_tensor_p:
+    if type(array).__name__ == "Tensor":
+        array = array.numpy()
     tensor_p = ggml_new_tensor(ctx, from_numpy_dtype(array.dtype), 1, GgmlShape())
     tensor_p.contents.n_dims = array.ndim
     tensor_p.contents.data = array.ctypes.data_as(ctypes.c_void_p)
@@ -139,8 +142,8 @@ class NativeObj:
             # print(f"freeing {self}")
             self.ptr = NULL
 
-    def __enter__(self) -> Self:
-        return self
+    def __enter__(self) -> ctypes.c_void_p:
+        return self.ptr
 
     def __exit__(self, *args: Any) -> None:
         self.free()
@@ -178,6 +181,18 @@ def GptVocab() -> NativeObj:
 def Fairseq2Model() -> NativeObj:
     return NativeObj("fairseq2_model")
 
+lib.std_string_alloc.argtypes = [ctypes.c_char_p]
+lib.std_string_alloc.restype = ctypes.c_void_p
+lib.std_string_free.argtypes = [ctypes.c_void_p]
+lib.std_string_free.restype = None
+NativeObj._cache["std_string"] = (lib.std_string_alloc, lib.std_string_free)
+
+@functools.lru_cache(1024)
+def CppStr(content: str) -> NativeObj:
+    c_str = ctypes.create_string_buffer(content.encode("utf-8"))
+    cpp_str = lib.std_string_alloc(c_str)
+    return NativeObj("std_string", cpp_str)
+
 
 lib.unity_model_load.argtypes = [ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p]
 
@@ -222,6 +237,26 @@ lib.unity_eval.restype = ctypes.POINTER(ggml_cgraph)
 
 
 def unity_eval(
-    allocr: NativeObj, model: NativeObj, tensor: ggml_tensor_p, n_threads: int
+    allocr: ctypes.c_void_p, model: NativeObj, tensor: ggml_tensor_p, n_threads: int
 ) -> ggml_cgraph_p:
-    return lib.unity_eval(allocr.ptr, model.ptr, tensor, n_threads)
+    return lib.unity_eval(allocr, model.ptr, tensor, n_threads)
+
+
+_FORWARD_CACHE: Dict[str, Callable[[...], ggml_tensor_p]] = {}
+
+
+def forward(
+    layer_name: str, model: NativeObj, prefix: str, *inputs: ggml_tensor_p
+) -> ggml_tensor_p:
+    fwd: Any = _FORWARD_CACHE.get(layer_name)
+    if fwd is None:
+        fwd = getattr(lib, layer_name + "_forward")
+        num_inputs = len(inputs)
+        fwd.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + [
+            ctypes.POINTER(ggml_tensor)
+        ] * num_inputs
+        fwd.restype = ctypes.POINTER(ggml_tensor)
+        _FORWARD_CACHE[layer_name] = fwd
+
+    with CppStr(prefix) as std_prefix:
+        return fwd(model.ptr, std_prefix, *inputs)  # ignore: type[no-any-return]

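Taken together, CppStr and the generic forward helper let any of the new extern "C" layer entry points be called by name from Python. A minimal usage sketch (assuming ctx is a ggml context and model came from ggml.load_unity_ggml_file, as in the tests below; the layer prefix is chosen for illustration):

import ctypes
import numpy as np
import ggml

x = np.random.uniform(-1, 1, size=(1024,)).astype(np.float32)
gx = ggml.from_numpy(ctx, x)
# Dispatches to the extern "C" Linear_forward, passing the prefix as a std::string.
gy = ggml.forward("Linear", model, "text_encoder.layers.0.ffn.inner_proj", gx)
gf = ggml.ggml_build_forward(gy)
ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])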
+ 1 - 1
ggml/ggml_convert.py

@@ -28,6 +28,7 @@ def convert_model(model_name: str, out: Optional[Path] = None) -> None:
     if "unity" in model_name or "seamlessM4T" in model_name:
         model_config = load_unity_config(model_name)
         hparams = flatten_config(dataclasses.asdict(model_config), separator="__")
+        print(hparams)
         model = load_unity_model(model_name)
     else:
         raise ValueError(f"Unsupported model type: {model_name}")
@@ -49,7 +50,6 @@ def write_ggml_file(
         # Size of each tensor
         byte_size = sum(x.numel() * x.element_size() for x in state_dict.values())
         # + tensor overhead
-        breakpoint()
         byte_size += ggml.ggml_tensor_overhead() * len(state_dict)
         # + some slack cause I'm bad at math
         byte_size = int(byte_size * 1.2)

+ 57 - 8
ggml/test_unity_cpp.py

@@ -3,10 +3,13 @@ import ctypes
 import torch
 import pytest
 import numpy as np
+import torch
+from typing import Any
 from pathlib import Path
 from typing import Iterator
 from ggml import NativeObj
 from ggml_convert import convert_model
+from seamless_communication.models.unity import load_unity_model
 
 Ctx = ggml.ggml_context_p
 
@@ -123,7 +126,7 @@ def test_from_numpy_works_with_f16(ctx: Ctx) -> None:
     assert np.allclose(a, ggml.to_numpy(ga))
 
 
-def test_unity_model_load(ctx: Ctx) -> None:
+def test_ning_model_load(ctx: Ctx) -> None:
     model, vocab = ggml.unity_model_load(UNITY_MODELS / "unity-large/ggml-model.bin")
     print(model, vocab)
 
@@ -134,7 +137,7 @@ def test_unity_model_load(ctx: Ctx) -> None:
     with ggml.MeasureArena() as arena:
         graph = ggml.unity_audio_encoder_graph(model, example)
         # TODO: why the extra memory ?
-        mem_size = ggml.ggml_allocr_alloc_graph(arena.ptr, graph) + ggml.GGML_MEM_ALIGN
+        mem_size = ggml.ggml_allocr_alloc_graph(arena, graph) + ggml.GGML_MEM_ALIGN
 
     with ggml.FixedSizeArena(mem_size) as allocr:
         print(
@@ -149,16 +152,62 @@ def test_unity_model_load(ctx: Ctx) -> None:
         assert np.allclose(inpL[0, :10], list(expected), atol=1e-4)
 
 
-def test_unity_model_load2(ctx: Ctx, tmp_path: Path) -> None:
-
+@pytest.fixture(scope="module")
+def g_model() -> NativeObj:
     model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
     if not model_file.exists():
         convert_model("seamlessM4T_medium", model_file)
+    return ggml.load_unity_ggml_file(model_file)
+
+
+@pytest.fixture(scope="module")
+def pt_model() -> Iterator[Any]:
+    model = load_unity_model("seamlessM4T_medium")
+    model.eval()
+    with torch.inference_mode():
+        yield model
+
+@pytest.mark.xfail(reason="TODO")
+def test_hparams_code_is_up_to_date() -> None:
+    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
+
     hparams_header_file = model_file.with_suffix(".hparams.h")
     hparams_struct = hparams_header_file.read_text().strip()
     actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
-    # breakpoint()
-    # assert hparams_struct in actual_code
+    assert hparams_struct in actual_code
+
+
+def test_unity_ffn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
+    x = torch.empty((1024,))
+    torch.nn.init.uniform_(x, -1, 1)
+
+    # Test FFN without LayerNorm
+    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
+    gx = ggml.from_numpy(ctx, x)
+    gy = ggml.forward(
+        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
+    )
+    gf = ggml.ggml_build_forward(gy)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1]).reshape(-1)
+    abs_diff = np.max(np.abs(y - y_exp))
+    assert abs_diff < 1e-2
+    assert np.allclose(y_exp, y, rtol=1e-3)
+
+
+def test_unity_layer_norm(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
+    x = torch.empty((1024,))
+    torch.nn.init.uniform_(x, -1, 1)
+
+    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()
+    gx = ggml.from_numpy(ctx, x)
+    gy = ggml.forward(
+        "LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx
+    )
+    gf = ggml.ggml_build_forward(gy)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
-    model = ggml.load_unity_ggml_file(model_file)
-    print(model)
+    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1]).reshape(-1)
+    abs_diff = np.max(np.abs(y - y_exp))
+    assert np.allclose(y_exp, y)