|
@@ -270,6 +270,7 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
|
|
|
state = state_bag.get_state(
|
|
|
attn, fairseq2.nn.transformer.MultiheadAttentionState
|
|
|
)
|
|
|
+ state_bag.increment_step()
|
|
|
assert state is not None
|
|
|
assert np.allclose(
|
|
|
state.prev_k.numpy(),
|
|
@@ -334,6 +335,7 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
|
|
|
atol=1e-3,
|
|
|
)
|
|
|
|
|
|
+ state_bag.increment_step()
|
|
|
y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
|
|
|
assert y_exp.shape == (2, 1, 1024)
|
|
|
assert np.allclose(y, y_exp, atol=1e-2)
|
|
@@ -566,6 +568,30 @@ def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
|
|
|
assert np.allclose(y_exp, y, atol=1e-6)
|
|
|
|
|
|
|
|
|
+def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) -> None:
+    """Compare ggml and fairseq2 positional encoding, one decoding step at a time.
+
+    Every time step of an all-zero input is fed separately to the ggml
+    "PositionalEmbedding" forward and to a fairseq2 ``SinusoidalPositionEncoder``
+    that runs with an ``IncrementalStateBag``; the two outputs must agree in
+    shape and in value (``atol=1e-6``) at each step.
+    """
+    num_steps = 20
+    seq = torch.zeros((4, num_steps, 1024), dtype=torch.float32)
+
+    ref_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
+    ref_encoder.eval()
+    state_bag = fairseq2.nn.IncrementalStateBag()
+
+    # NOTE(review): the trailing arguments are presumably the beam size and the
+    # maximum sequence length -- confirm against the ggml binding.
+    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
+
+    # Incremental decoding
+    for step in range(num_steps):
+        # Slice that keeps the time axis, so the input stays three-dimensional.
+        step_input = seq[:, step : step + 1, :]
+
+        # ggml side: build and run the graph for this single step.
+        g_input = ggml.from_numpy(ctx, step_input.numpy())
+        ggml.ggml_set_name(g_input, b"seq")
+        g_output = ggml.forward(
+            "PositionalEmbedding",
+            g_model,
+            "text_decoder_frontend.pos_encoder",
+            g_input,
+        )
+        graph = ggml.ggml_build_forward(g_output)
+        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(graph), 1)
+        actual = ggml.to_numpy(g_output)
+
+        # fairseq2 reference: the state bag is advanced only after the encoder
+        # has been evaluated for the current step.
+        expected = ref_encoder(step_input, None, state_bag=state_bag).numpy()
+        state_bag.increment_step()
+
+        assert actual.shape == expected.shape
+        assert np.allclose(expected, actual, atol=1e-6)
|
|
|
+
|
|
|
+
|
|
|
def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None:
|
|
|
seq = torch.arange(2 * 20).reshape(2, 20)
|
|
|
seq[1, 15:] = 0 # padding for second sentence
|
|
@@ -774,11 +800,17 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
)
|
|
|
result_ptr = ggml.generate_sequence(g_model, Ptr(job), encoder_out, NULLPTR, ctx)
|
|
|
results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
|
|
|
- assert_hypotheses(text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1)
|
|
|
+ assert_hypotheses(
|
|
|
+ text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1
|
|
|
+ )
|
|
|
|
|
|
|
|
|
def assert_hypotheses(
|
|
|
- expected: List[Any], results: List[Any], *, score_rtol: float, step_scores_rtol: float
|
|
|
+ expected: List[Any],
|
|
|
+ results: List[Any],
|
|
|
+ *,
|
|
|
+ score_rtol: float,
|
|
|
+ step_scores_rtol: float
|
|
|
) -> None:
|
|
|
assert len(results) == len(expected)
|
|
|
for g_hyp, exp in zip(results, expected):
|