
use ggml_diag_mask_inf

Guillaume Wenzek 1 year ago
commit eb80195345
2 changed files with 34 additions and 88 deletions
  1. ggml/examples/unity/fairseq2.cpp (+2, -23)
  2. ggml/test_unity_cpp.py (+32, -65)

+ 2 - 23
ggml/examples/unity/fairseq2.cpp

@@ -396,32 +396,11 @@ extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
     return seqs;
 }
 
-ggml_tensor* causal_mask_cache = nullptr;
-
 extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
     auto seq_len = seqs->ne[1];
-    auto mask = causal_mask_cache;
-    // TODO: this cache only works as long as we don't change the size/device too often
     // TODO: allow other ggml_type
-    if (mask == nullptr || mask->backend != seqs->backend || mask->ne[0] < seq_len) {
-        printf("new causal_mask (%ld, %ld) created\n", seq_len, seq_len);
-        mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
-        char* data = (char*)mask->data;
-
-        // tensor([[0., -inf, -inf, -inf],
-        //         [0.,   0., -inf, -inf],
-        //         [0.,   0.,   0., -inf],
-        //         [0.,   0.,   0.,   0.]])
-        for (int i = 0; i < seq_len; ++i) {
-            char* row = data + i * mask->nb[1];
-            for (int j = 0; j <= i; ++j) {*(float*)(row + j * mask->nb[0]) = 0;}
-            for (int j = i + 1; j < seq_len; ++j) {*(float*)(row + j * mask->nb[0]) = -INFINITY;}
-        }
-
-        causal_mask_cache = mask;
-    }
-
-    return ggml_view_2d(ctx, mask, seq_len, seq_len, mask->nb[1], 0);
+    ggml_tensor* mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
+    return ggml_diag_mask_inf(ctx, mask, 0);
 }
 
 extern "C" ggml_tensor* StandardTransformerDecoder_forward(

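Note: the hand-rolled, globally cached mask above is replaced by a single ggml_diag_mask_inf node, so the mask is now built lazily as part of the compute graph instead of being filled eagerly on the host. For reference, the values the op is expected to produce (zeros on and below the diagonal, -inf strictly above it, matching fairseq2's CausalAttentionMaskGenerator and the tensor comment removed above) can be sketched in numpy. This sketch is illustrative only and is not part of the commit:

    import numpy as np

    def reference_causal_mask(seq_len: int) -> np.ndarray:
        # Zeros on and below the diagonal, -inf strictly above it.
        mask = np.zeros((seq_len, seq_len), dtype=np.float32)
        mask[np.triu_indices(seq_len, k=1)] = -np.inf
        return mask

    print(reference_causal_mask(4))
    # [[  0. -inf -inf -inf]
    #  [  0.   0. -inf -inf]
    #  [  0.   0.   0. -inf]
    #  [  0.   0.   0.   0.]]
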
+ 32 - 65
ggml/test_unity_cpp.py

@@ -26,6 +26,7 @@ CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)
 FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
 UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
 
+
 @pytest.fixture(name="ctx")
 def _ctx() -> Iterator[Ctx]:
     """Allocate a new context with 1024 MB of memory"""
@@ -36,7 +37,6 @@ def _ctx() -> Iterator[Ctx]:
         ggml.ggml_free(ctx)
 
 
-
 @pytest.fixture(scope="module")
 def g_model_once() -> Iterator[c_void_p]:
     model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
@@ -78,6 +78,36 @@ def test_hparams_code_is_up_to_date() -> None:
     assert hparams_struct in actual_code
 
 
+def test_causal_attention_mask(ctx: Ctx):
+    x = torch.zeros((1, 10, 32))
+    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
+    mask_exp = generator(x).numpy()
+
+    gx = ggml.from_numpy(ctx, x)
+    gmask = ggml.causal_attention_mask(ctx, gx)
+    mask = ggml.to_numpy(gmask)
+
+    gf = ggml.ggml_build_forward(gmask)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+    assert mask_exp.shape == (10, 10)
+    assert mask.shape == (10, 10)
+    assert np.all(mask == mask_exp)
+
+    x = x[:, :8, :]
+    mask_exp = generator(x).numpy()
+    gx = ggml.from_numpy(ctx, x)
+    gmask = ggml.causal_attention_mask(ctx, gx)
+    mask = ggml.to_numpy(gmask)
+
+    gf = ggml.ggml_build_forward(gmask)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+    assert mask_exp.shape == (8, 8)
+    assert mask.shape == (8, 8)
+    assert np.all(mask == mask_exp)
+
+
 def test_forward_ffn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
     x = torch.empty((21, 1024))  # (seq_len, model_dim)
     torch.nn.init.uniform_(x, -1 / 32, 1 / 32)
@@ -162,7 +192,6 @@ def test_forward_self_attn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
     # assert q.shape == q_exp.shape
     # assert np.allclose(q_exp, q, atol=1e-5)
 
-
     # with flash_attn we don't have attn_weights
     if not UNITY_FLASH_ATTN:
         attn_weights = nodes[b"attn_weights"]
@@ -241,30 +270,6 @@ def test_StandardTransformerEncoder_forward(
     assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)
 
 
-def test_causal_attention_mask(ctx: Ctx):
-    x = torch.zeros((1, 10, 32))
-    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
-    mask_exp = generator(x).numpy()
-
-    gx = ggml.from_numpy(ctx, x)
-    gmask = ggml.causal_attention_mask(ctx, gx)
-    mask = ggml.to_numpy(gmask)
-
-    assert mask_exp.shape == (10, 10)
-    assert mask.shape == (10, 10)
-    assert np.all(mask == mask_exp)
-
-    x = x[:, :8, :]
-    mask_exp = generator(x).numpy()
-    gx = ggml.from_numpy(ctx, x)
-    gmask = ggml.causal_attention_mask(ctx, gx)
-    mask = ggml.to_numpy(gmask)
-    assert mask_exp.shape == (8, 8)
-    assert mask.shape == (8, 8)
-    assert np.all(mask == mask_exp)
-
-
-
 def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
     seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
     # this _legacy_pad_idx is suspicious. Shouldn't the model use 1 ? But
@@ -399,44 +404,6 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p):
     g_score = ggml.generate_sequence(
         g_model, job, encoder_out, encoder_padding_mask, ctypes.byref(result)
     )
-    tokens = list(ggml.to_numpy(result))
+    tokens = list(ggml.to_numpy(ctypes.pointer(result)))
     assert tokens == tgt_tokens
     assert g_score == pytest.approx(score)
-
-
-def test_in_loop(ctx: Ctx, g_model: c_void_p, pt_model: Any):
-    resources = locals()
-
-    import importlib
-    import time
-
-    testcase = test_TransformerEmbeddingFrontend_forward.__name__
-    name, script = __name__, __file__
-    root = Path(__file__).parent
-    watched_files = [Path(__file__), root / "ggml.py", root / "build/src/libggml.so"]
-    last_try = 0.0
-
-    while True:
-        last_save = max(f.stat().st_mtime for f in watched_files)
-        if last_save <= last_try:
-            time.sleep(0.1)
-            continue
-
-        last_try = last_save
-        spec = importlib.util.spec_from_file_location(name, script)
-        module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(module)
-        sys.modules[name] = module
-        f = getattr(module, testcase)
-        f_args = [k for k in f.__annotations__ if k != "return"]
-        try:
-            f(**{k: resources[k] for k in f_args})
-            print(f"Testcase {testcase} success")
-        except AssertionError as e:
-            print(f"Testcase {testcase} failed: {e}")
-
-        except Exception as e:
-            import pdb
-
-            logging.exception(f"Testcase {testcase} crashed !")
-            pdb.post_mortem()
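
Note: because causal_attention_mask now returns a graph op rather than a pre-filled tensor, callers have to build and compute the graph before reading the result back, which is exactly what the relocated test above adds. A minimal sketch of that pattern, reusing the ggml.py bindings already used by test_unity_cpp.py (illustrative only, not part of the commit):

    import ctypes
    import torch
    import ggml

    def compute_causal_mask(ctx, seq_len: int):
        # Build the lazy mask node, then force evaluation before reading it back.
        gx = ggml.from_numpy(ctx, torch.zeros((1, seq_len, 32)))
        gmask = ggml.causal_attention_mask(ctx, gx)
        gf = ggml.ggml_build_forward(gmask)
        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
        return ggml.to_numpy(gmask)  # (seq_len, seq_len): 0 on/below the diagonal, -inf above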