@@ -245,43 +245,43 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
 
     state_bag = fairseq2.nn.IncrementalStateBag()
 
-    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
-    # Incremental decoding
-    for t in range(3):
-        xq = x[:, t : t + 1]
-        y_exp = attn(xq, None, xq, xq, state_bag=state_bag).numpy()
-        assert y_exp.shape == (2, 1, 1024)
-
-        gxq = ggml.from_numpy(ctx, xq.contiguous())
-        ggml.ggml_set_name(gxq, b"xq")
-        gy = ggml.forward(
-            "MultiheadAttention",
-            g_model,
-            "text_decoder.layers.0.self_attn",
-            gxq,
-            gxq,
-            gxq,
-            None,  # type: ignore
-        )
-        gf = ggml.ggml_build_forward(gy)
-        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+    with ggml.model_kv_cache_alloc(g_model, 2, 21):
+        # Incremental decoding
+        for t in range(3):
+            xq = x[:, t : t + 1]
+            y_exp = attn(xq, None, xq, xq, state_bag=state_bag).numpy()
+            assert y_exp.shape == (2, 1, 1024)
+
+            gxq = ggml.from_numpy(ctx, xq.contiguous())
+            ggml.ggml_set_name(gxq, b"xq")
+            gy = ggml.forward(
+                "MultiheadAttention",
+                g_model,
+                "text_decoder.layers.0.self_attn",
+                gxq,
+                gxq,
+                gxq,
+                None,  # type: ignore
+            )
+            gf = ggml.ggml_build_forward(gy)
+            ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
-        nodes = ggml.nodes(gf)
-        state = state_bag.get_state(
-            attn, fairseq2.nn.transformer.MultiheadAttentionState
-        )
-        state_bag.increment_step()
-        assert state is not None
-        assert np.allclose(
-            state.prev_k.numpy(),
-            ggml.to_numpy(
-                nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
-            ),
-            atol=1e-3,
-        )
+            nodes = ggml.nodes(gf)
+            state = state_bag.get_state(
+                attn, fairseq2.nn.transformer.MultiheadAttentionState
+            )
+            state_bag.increment_step()
+            assert state is not None
+            assert np.allclose(
+                state.prev_k.numpy(),
+                ggml.to_numpy(
+                    nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
+                ),
+                atol=1e-3,
+            )
 
-        y = ggml.to_numpy(gy)
-        assert np.allclose(y, y_exp, atol=1e-2)
+            y = ggml.to_numpy(gy)
+            assert np.allclose(y, y_exp, atol=1e-2)
 
 
 def test_MultiheadAttention_forward_cross_attn_with_cache(
@@ -296,49 +296,49 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
 
     state_bag = fairseq2.nn.IncrementalStateBag()
 
-    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
-    # Incremental decoding, the keys come from the encoder, and don't change during decoding
-    xk = x[:, :11]
-    gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")
-
-    for t in range(3):
-        xq = x[:, t : t + 1]
-
-        gxq = ggml.from_numpy(ctx, xq.contiguous())
-        ggml.ggml_set_name(gxq, b"xq")
-        gy = ggml.forward(
-            "MultiheadAttention",
-            g_model,
-            "text_decoder.layers.0.encoder_decoder_attn",
-            gxq,
-            gxk,
-            gxk,
-            None,  # type: ignore
-        )
-        gf = ggml.ggml_build_forward(gy)
-        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
-        y = ggml.to_numpy(gy)
-        nodes = ggml.nodes(gf)
-        leaves = ggml.leafs(gf)
-
-        if t > 0:
-            # the cache only appear in the graph during the second call
-            state = state_bag.get_state(
-                attn, fairseq2.nn.transformer.MultiheadAttentionState
-            )
-            assert state is not None
-            assert np.allclose(
-                state.prev_k.numpy(),
-                ggml.to_numpy(
-                    nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
-                ),
-                atol=1e-3,
+    with ggml.model_kv_cache_alloc(g_model, 2, 21):
+        # Incremental decoding: the keys come from the encoder and don't change during decoding
+        xk = x[:, :11]
+        gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")
+
+        for t in range(3):
+            xq = x[:, t : t + 1]
+
+            gxq = ggml.from_numpy(ctx, xq.contiguous())
+            ggml.ggml_set_name(gxq, b"xq")
+            gy = ggml.forward(
+                "MultiheadAttention",
+                g_model,
+                "text_decoder.layers.0.encoder_decoder_attn",
+                gxq,
+                gxk,
+                gxk,
+                None,  # type: ignore
             )
-
-        state_bag.increment_step()
-        y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
-        assert y_exp.shape == (2, 1, 1024)
-        assert np.allclose(y, y_exp, atol=1e-2)
+            gf = ggml.ggml_build_forward(gy)
+            ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+            y = ggml.to_numpy(gy)
+            nodes = ggml.nodes(gf)
+            leaves = ggml.leafs(gf)
+
+            if t > 0:
+                # the cache only appears in the graph from the second call onwards
+                state = state_bag.get_state(
+                    attn, fairseq2.nn.transformer.MultiheadAttentionState
+                )
+                assert state is not None
+                assert np.allclose(
+                    state.prev_k.numpy(),
+                    ggml.to_numpy(
+                        nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
+                    ),
+                    atol=1e-3,
+                )
+
+            state_bag.increment_step()
+            y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
+            assert y_exp.shape == (2, 1, 1024)
+            assert np.allclose(y, y_exp, atol=1e-2)
 
 
 def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
@@ -534,20 +534,6 @@ def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
     assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error is from standardization
 
 
-def test_causal_attention_mask(ctx: Ctx):
-    x = torch.zeros((5, 10))
-    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
-    mask_exp = generator(x)
-
-    gx = ggml.from_numpy(ctx, x)
-    gmask = ggml.causal_attention_mask(ctx, gx)
-    mask = ggml.to_numpy(gmask)
-
-    assert mask_exp.shape == (10, 10)
-    assert mask.shape == (10, 10)
-    assert np.allclose(mask, mask_exp)
-
-
 def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
     seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
     # this _legacy_pad_idx is suspicious. Shouldn't the model use 1 ? But
@@ -574,22 +560,22 @@ def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) ->
     pos_encoder.eval()
     state_bag = fairseq2.nn.IncrementalStateBag()
 
-    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
-    # Incremental decoding
-    for t in range(20):
-        gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
-        ggml.ggml_set_name(gseq, b"seq")
-        gy = ggml.forward(
-            "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
-        )
-        gf = ggml.ggml_build_forward(gy)
-        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
-        y = ggml.to_numpy(gy)
-
-        y_exp = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag).numpy()
-        state_bag.increment_step()
-        assert y.shape == y_exp.shape
-        assert np.allclose(y_exp, y, atol=1e-6)
+    with ggml.model_kv_cache_alloc(g_model, 2, 21):
+        # Incremental decoding
+        for t in range(20):
+            gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
+            ggml.ggml_set_name(gseq, b"seq")
+            gy = ggml.forward(
+                "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
+            )
+            gf = ggml.ggml_build_forward(gy)
+            ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+            y = ggml.to_numpy(gy)
+
+            y_exp = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag).numpy()
+            state_bag.increment_step()
+            assert y.shape == y_exp.shape
+            assert np.allclose(y_exp, y, atol=1e-6)
 
 
 def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None: