Guillaume Wenzek, 1 year ago
Parent
Commit cc23e2b1c7
2 changed files with 35 additions and 36 deletions
  1. ggml/ggml.py (0 additions, 1 deletion)
  2. ggml/test_unity_cpp.py (35 additions, 35 deletions)

+ 0 - 1
ggml/ggml.py

@@ -300,7 +300,6 @@ lib.std_string_free.restype = None
 NativeObj._cache["std_string"] = (lib.std_string_alloc, lib.std_string_free)
 
 
-@functools.lru_cache(1024)
 def CppStr(content: str) -> NativeObj:
     c_str = ctypes.create_string_buffer(content.encode("utf-8"))
     cpp_str = lib.std_string_alloc(c_str)
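Dropping @functools.lru_cache(1024) means every CppStr call now allocates a fresh std::string on the native side instead of handing repeated contents the same cached NativeObj, so each caller owns and releases its own native handle. A minimal usage sketch of the new behaviour (the string literal is purely illustrative):

    s1 = CppStr("vocab.bin")
    s2 = CppStr("vocab.bin")
    assert s1 is not s2  # without the cache, each call wraps its own std::string allocation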

+ 35 - 35
ggml/test_unity_cpp.py

@@ -9,6 +9,7 @@ import fairseq2.nn.transformer
 import logging
 import sys
 import functools
+from typing import Tuple
 from pathlib import Path
 from ctypes_utils import Ptr
 from ctypes import c_void_p
@@ -48,6 +49,7 @@ def _load_g_model_once() -> NativeObj:
         convert_model("seamlessM4T_medium", model_file)
     return ggml.load_fairseq2_ggml_file(model_file)
 
+
 @pytest.fixture()
 def g_model(ctx: Ctx) -> c_void_p:
     model = _load_g_model_once()
@@ -61,6 +63,7 @@ def load_translator() -> Translator:
         "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
     )
 
+
 def load_pt_model() -> Any:
     return load_translator().model
 
@@ -76,7 +79,9 @@ def test_convert_linear(tmp_path: Path) -> None:
     g_module = ggml.load_fairseq2_ggml_file(module_file)
 
     for k, v in layer_config.items():
-        assert ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
+        assert (
+            ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
+        )
 
 
 def test_causal_attention_mask(ctx: Ctx):
@@ -161,27 +166,26 @@ def _name(tensor: ggml.ggml_tensor_p) -> bytes:
         return b"???"
 
 
-@pytest.mark.parametrize("flash_attn", [False, True])
-def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: bool) -> None:
+@pytest.mark.parametrize("lengths", [(11, 21), (21, 13)])
+def test_MultiheadAttention_forward(
+    ctx: Ctx, g_model: c_void_p, lengths: Tuple[int, int]
+) -> None:
     x = torch.empty((2, 21, 1024))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
 
-    pt_model = load_pt_model()
-    self_attn = pt_model.text_encoder.layers[0].self_attn
-
     # Note: we use different lengths for queries and keys,
     # this tests the implementation in decoding context too.
     # Note2: ggml_flash_attn requires that we have more keys than queries
-    if flash_attn:
-        xq = x[:, :11, :]
-        xk = x
-    else:
-        xq = x
-        xk = x[:, :13, :]
+    # qlen, klen = (11, 21) if flash_attn else (21, 13)
+    qlen, klen = lengths
+    xq = x[:, :qlen]
+    xk = x[:, :klen]
+    if qlen > klen and UNITY_FLASH_ATTN:
+        pytest.skip(reason="flash_attn requires klen >= qlen")
 
     gxq = ggml.from_numpy(ctx, xq.contiguous())
-    gxk = ggml.from_numpy(ctx, xk)
+    gxk = ggml.from_numpy(ctx, xk.contiguous())
     ggml.ggml_set_name(gxk, b"xk")
     gy = ggml.forward(
         "MultiheadAttention",
@@ -195,6 +199,8 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: boo
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
+    pt_model = load_pt_model()
+    self_attn = pt_model.text_encoder.layers[0].self_attn
     q_exp = self_attn.q_proj(xq).numpy()
 
     y = ggml.to_numpy(gy)
@@ -216,7 +222,8 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: boo
     assert np.allclose(q_exp, q, atol=1e-5)
 
     # with flash_attn we don't have attn_weights
-    if not flash_attn:
+    naive_attn = b"attn_weights" in nodes
+    if naive_attn:
         attn_weights = nodes[b"attn_weights"]
         [attn_weights_exp] = attn_weights_hook._storage
         attn_weights_exp = attn_weights_exp.numpy()
@@ -225,17 +232,16 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: boo
         # so the error isn't that small
         assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
         # But the sums should be close to 1
-        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 1)))
+        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, qlen)))
         # And the maximum index should match the original ones.
-        assert np.allclose(np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
+        assert np.allclose(
+            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
         )
     assert y.shape == y_exp.shape
-    assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)
+    assert np.allclose(y_exp, y, atol=1e-2 if naive_attn else 1e-4)
 
 
-def test_StandardTransformerEncoderLayer_forward(
-    ctx: Ctx, g_model: c_void_p
-) -> None:
+def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
     padding_mask = torch.ones((2, 21))
     torch.random.manual_seed(0)
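A note on the reworked test_MultiheadAttention_forward above: attn_weights is shaped (batch * num_heads, qlen, klen) in this test, and the softmax over the key axis makes every (batch, head, query) row sum to 1; hence the reference tensor becomes np.ones((2 * 16, qlen)) instead of np.ones((2 * 16, 1)). A minimal numpy sketch of that invariant, with illustrative shapes rather than values taken from the model:

    import numpy as np

    w = np.random.rand(2 * 16, 11, 21)                      # (batch * heads, qlen, klen)
    w = np.exp(w) / np.exp(w).sum(axis=-1, keepdims=True)   # softmax over the key axis
    assert np.allclose(w.sum(axis=-1), np.ones((2 * 16, 11)))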
@@ -290,7 +296,7 @@ def test_StandardConformerEncoderLayer_forward(
     y = ggml.to_numpy(gy)
 
     y_exp, _ = layer(x, padding_mask)
-    y_exp = y_exp.numpy()  
+    y_exp = y_exp.numpy()
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=2e-3)
 
@@ -315,15 +321,13 @@ def test_StandardConformerEncoderAdaptorLayer_forward(
     y = ggml.to_numpy(gy)
 
     y_exp, _ = layer(x, None)
-    y_exp = y_exp.numpy() 
+    y_exp = y_exp.numpy()
 
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=2e-3)
 
 
-def test_StandardTransformerEncoder_forward(
-    ctx: Ctx, g_model: c_void_p
-) -> None:
+def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
     padding_mask = torch.ones((2, 21))
     torch.random.manual_seed(0)
@@ -390,7 +394,7 @@ def test_StandardConformerEncoder_forward(
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=1e-2) # There are 10 elements in a 137*1024 tensor with error >1e-2
 
-    
+
 
 def test_WaveformToFbank_forward(
     ctx: Ctx, g_model: c_void_p
@@ -406,7 +410,7 @@ def test_WaveformToFbank_forward(
     wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
     gx = ggml.from_numpy(ctx, wav * 2**15) # Apply scale before sending into ggml!
     ggml.ggml_set_name(gx, b"x")
-    
+
     gy = ggml.forward(
         "WaveformToFbank",
         g_model,
@@ -423,7 +427,7 @@ def test_WaveformToFbank_forward(
         "format": -1,
     }
     y_exp = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
-    y_exp = y_exp.numpy() 
+    y_exp = y_exp.numpy()
 
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=4e-3) # reduce? error is from standardization
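On the 2**15 scale in test_WaveformToFbank_forward above: torchaudio.load returns float32 samples normalized to [-1.0, 1.0], while the fbank computation on the ggml side evidently expects 16-bit-PCM-scale amplitudes, hence the multiplication before handing the tensor to ggml.from_numpy. A hedged sketch of that step, with a placeholder path:

    wav, sample_rate = torchaudio.load("example.wav")  # float32 in [-1.0, 1.0]
    gx = ggml.from_numpy(ctx, wav * 2**15)              # rescale to roughly the int16 range
    ggml.ggml_set_name(gx, b"x")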
@@ -463,9 +467,7 @@ def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
     assert np.allclose(y_exp, y, atol=1e-6)
 
 
-def test_TransformerEmbeddingFrontend_forward(
-    ctx: Ctx, g_model: c_void_p
-) -> None:
+def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None:
     seq = torch.arange(2 * 20).reshape(2, 20)
     seq[1, 15:] = 0  # padding for second sentence
     seq_len = torch.tensor([20, 15])
@@ -486,9 +488,7 @@ def test_TransformerEmbeddingFrontend_forward(
     assert np.allclose(y_exp, y, atol=1e-6)
 
 
-def test_StandardTransformerDecoder_forward(
-    ctx: Ctx, g_model: c_void_p
-) -> None:
+def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 13, 1024))
     encoder_out = torch.empty((2, 21, 1024))
     padding_mask = torch.ones((2, 13))
@@ -658,7 +658,7 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
     job.unk_idx = 1
     job.bos_idx = 2
     job.eos_idx = 3
-    
+
     result_ptr = ggml.generate_sequence(
         g_model, job, encoder_out, None, ctx
     )
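For context on the hard-coded special-token ids above: instead of the literals 1/2/3, the same values could be read from a loaded fairseq2 text tokenizer's VocabularyInfo. This is an assumption about the fairseq2 API rather than something this commit does, and tokenizer below is a placeholder for such a loaded tokenizer:

    job.unk_idx = tokenizer.vocab_info.unk_idx  # assumed fairseq2 VocabularyInfo fields
    job.bos_idx = tokenizer.vocab_info.bos_idx
    job.eos_idx = tokenizer.vocab_info.eos_idx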