@@ -330,18 +330,17 @@ def test_TransformerEmbeddingFrontend_forward(
 def test_StandardTransformerDecoder_forward(
     ctx: Ctx, g_model: c_void_p, pt_model: Any
 ) -> None:
-    pytest.skip("foo")
-    x = torch.empty((1, 13, 1024))
-    encoder_out = torch.empty((1, 21, 1024))
-    padding_mask = torch.ones((1, 13))
+    x = torch.empty((2, 13, 1024))
+    encoder_out = torch.empty((2, 21, 1024))
+    padding_mask = torch.ones((2, 13))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
     torch.nn.init.uniform_(encoder_out, -1, 1)
-    gx = ggml.from_numpy(ctx, x[0])
+    gx = ggml.from_numpy(ctx, x)
     ggml.ggml_set_name(gx, b"x")
-    gpad = ggml.from_numpy(ctx, padding_mask[0])
+    gpad = ggml.from_numpy(ctx, padding_mask)
     ggml.ggml_set_name(gpad, b"padding_mask")
-    genc = ggml.from_numpy(ctx, encoder_out[0])
+    genc = ggml.from_numpy(ctx, encoder_out)
     gy = ggml.forward(
         "StandardTransformerDecoder",
         g_model,
@@ -351,15 +350,14 @@ def test_StandardTransformerDecoder_forward(
         genc,
         None,
     )
-    gf = ggml.ggml_build_forward(gy)
-    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+    ggml.build_and_compute(ctx, gy)
     y = ggml.to_numpy(gy)

     y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
-    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension
+    y_exp = y_exp.numpy()

     assert y.shape == y_exp.shape
-    assert np.allclose(y_exp, y, atol=1e-4)
+    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-3)


 def test_t2tt(ctx: Ctx, g_model: c_void_p):
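
Note on the patch above: `ggml.build_and_compute(ctx, gy)` is assumed here to be a small helper in the project's ggml bindings that folds together the two calls it replaces. A minimal sketch under that assumption (the real helper's signature and thread count may differ):

    import ctypes

    def build_and_compute(ctx, tensor) -> None:
        # Assumed wrapper around the removed two-step sequence:
        # build the forward graph for `tensor`, then compute it on `ctx`.
        # `ggml` is the project's binding module already imported by the test file.
        gf = ggml.ggml_build_forward(tensor)
        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)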
|