@@ -3,10 +3,12 @@ import ctypes
 import torch
 import pytest
 import numpy as np
+from typing import Any
 from pathlib import Path
 from typing import Iterator
 from ggml import NativeObj
 from ggml_convert import convert_model
+from seamless_communication.models.unity import load_unity_model

 Ctx = ggml.ggml_context_p

@@ -123,7 +125,7 @@ def test_from_numpy_works_with_f16(ctx: Ctx) -> None:
     assert np.allclose(a, ggml.to_numpy(ga))


-def test_unity_model_load(ctx: Ctx) -> None:
+def test_unity_model_load(ctx: Ctx) -> None:
     model, vocab = ggml.unity_model_load(UNITY_MODELS / "unity-large/ggml-model.bin")
     print(model, vocab)

@@ -134,7 +136,7 @@ def test_unity_model_load(ctx: Ctx) -> None:
     with ggml.MeasureArena() as arena:
         graph = ggml.unity_audio_encoder_graph(model, example)
         # TODO: why the extra memory ?
-        mem_size = ggml.ggml_allocr_alloc_graph(arena.ptr, graph) + ggml.GGML_MEM_ALIGN
+        mem_size = ggml.ggml_allocr_alloc_graph(arena, graph) + ggml.GGML_MEM_ALIGN

     with ggml.FixedSizeArena(mem_size) as allocr:
         print(
@@ -149,16 +151,67 @@ def test_unity_model_load(ctx: Ctx) -> None:
     assert np.allclose(inpL[0, :10], list(expected), atol=1e-4)


-def test_unity_model_load2(ctx: Ctx, tmp_path: Path) -> None:
-
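+# Module-scoped fixture: convert the checkpoint once and reuse the GGML model across tests.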
+@pytest.fixture(scope="module")
+def g_model() -> NativeObj:
     model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
     if not model_file.exists():
         convert_model("seamlessM4T_medium", model_file)
+    return ggml.load_unity_ggml_file(model_file)
+
+
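+# Reference PyTorch model, kept in eval mode so its outputs are deterministic.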
+@pytest.fixture(scope="module")
+def pt_model() -> Iterator[Any]:
+    model = load_unity_model("seamlessM4T_medium")
+    model.eval()
+    with torch.inference_mode():
+        yield model
+
+@pytest.mark.xfail(reason="TODO")
+def test_hparams_code_is_up_to_date() -> None:
+    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
+
     hparams_header_file = model_file.with_suffix(".hparams.h")
     hparams_struct = hparams_header_file.read_text().strip()
     actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
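+    # The generated hparams struct must appear verbatim in the C++ loader.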
-    # breakpoint()
-    # assert hparams_struct in actual_code
+    assert hparams_struct in actual_code
+
+
+def test_unity_ffn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
+    x = torch.empty((1024,))
+    torch.nn.init.uniform_(x, -1, 1)
+
+    # Test FFN without LayerNorm
+    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
+    gx = ggml.from_numpy(ctx, x)
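+    # Run the same layer through GGML, looked up by its module path.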
+    gy = ggml.forward(
+        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
+    )
+    gf = ggml.ggml_build_forward(gy)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
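+    # The final graph node holds the layer output; flatten it for comparison.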
+    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1]).reshape(-1)
+    abs_diff = np.max(np.abs(y - y_exp))
+    assert abs_diff < 1e-2
+    assert np.allclose(y_exp, y, rtol=1e-3)
+
+
+def test_unity_layer_norm(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
+    x = torch.empty((1024,))
+    torch.nn.init.uniform_(x, -1, 1)
+
+    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()
+    gx = ggml.from_numpy(ctx, x)
+    gy = ggml.forward(
+        "LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx
+    )
+    gf = ggml.ggml_build_forward(gy)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

-    model = ggml.load_unity_ggml_file(model_file)
-    print(model)
+    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1]).reshape(-1)
+    abs_diff = np.max(np.abs(y - y_exp))
+    assert np.allclose(y_exp, y)