Эх сурвалжийг харах

batching -> StandardTransformerDecoder

Guillaume Wenzek 1 жил өмнө
parent
commit
3d69d4975f
1 файл өөрчлөгдсөн: 9 мөр нэмэгдсэн, 11 мөр устгасан
  1. 9 11
      ggml/test_unity_cpp.py

+ 9 - 11
ggml/test_unity_cpp.py

@@ -330,18 +330,17 @@ def test_TransformerEmbeddingFrontend_forward(
 def test_StandardTransformerDecoder_forward(
     ctx: Ctx, g_model: c_void_p, pt_model: Any
 ) -> None:
-    pytest.skip("foo")
-    x = torch.empty((1, 13, 1024))
-    encoder_out = torch.empty((1, 21, 1024))
-    padding_mask = torch.ones((1, 13))
+    x = torch.empty((2, 13, 1024))
+    encoder_out = torch.empty((2, 21, 1024))
+    padding_mask = torch.ones((2, 13))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
     torch.nn.init.uniform_(encoder_out, -1, 1)
-    gx = ggml.from_numpy(ctx, x[0])
+    gx = ggml.from_numpy(ctx, x)
     ggml.ggml_set_name(gx, b"x")
-    gpad = ggml.from_numpy(ctx, padding_mask[0])
+    gpad = ggml.from_numpy(ctx, padding_mask)
     ggml.ggml_set_name(gpad, b"padding_mask")
-    genc = ggml.from_numpy(ctx, encoder_out[0])
+    genc = ggml.from_numpy(ctx, encoder_out)
     gy = ggml.forward(
         "StandardTransformerDecoder",
         g_model,
@@ -351,15 +350,14 @@ def test_StandardTransformerDecoder_forward(
         genc,
         None,
     )
-    gf = ggml.ggml_build_forward(gy)
-    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+    ggml.build_and_compute(ctx, gy)
     y = ggml.to_numpy(gy)
 
     y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
-    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension
+    y_exp = y_exp.numpy()
 
     assert y.shape == y_exp.shape
-    assert np.allclose(y_exp, y, atol=1e-4)
+    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-3)
 
 
 def test_t2tt(ctx: Ctx, g_model: c_void_p):