add failing test: test_PositionalEmbedding_forward_with_cache

Guillaume Wenzek, 1 year ago
commit ddff1a0644
1 file changed, 34 insertions and 2 deletions

ggml/test_unity_cpp.py  +34 -2

@@ -270,6 +270,7 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
         state = state_bag.get_state(
             attn, fairseq2.nn.transformer.MultiheadAttentionState
         )
+        state_bag.increment_step()
         assert state is not None
         assert np.allclose(
             state.prev_k.numpy(),
@@ -334,6 +335,7 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
                 atol=1e-3,
             )
 
+        state_bag.increment_step()
         y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
         assert y_exp.shape == (2, 1, 1024)
         assert np.allclose(y, y_exp, atol=1e-2)
@@ -566,6 +568,30 @@ def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
     assert np.allclose(y_exp, y, atol=1e-6)
 
 
+def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) -> None:
+    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
+    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
+    pos_encoder.eval()
+    state_bag = fairseq2.nn.IncrementalStateBag()
+
+    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
+    # Incremental decoding
+    for t in range(20):
+        gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
+        ggml.ggml_set_name(gseq, b"seq")
+        gy = ggml.forward(
+            "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
+        )
+        gf = ggml.ggml_build_forward(gy)
+        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+        y = ggml.to_numpy(gy)
+
+        y_exp = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag).numpy()
+        state_bag.increment_step()
+        assert y.shape == y_exp.shape
+        assert np.allclose(y_exp, y, atol=1e-6)
+
+
 def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None:
     seq = torch.arange(2 * 20).reshape(2, 20)
     seq[1, 15:] = 0  # padding for second sentence
@@ -774,11 +800,17 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
     )
     result_ptr = ggml.generate_sequence(g_model, Ptr(job), encoder_out, NULLPTR, ctx)
     results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
-    assert_hypotheses(text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1)
+    assert_hypotheses(
+        text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1
+    )
 
 
 def assert_hypotheses(
-    expected: List[Any], results: List[Any], *, score_rtol: float, step_scores_rtol: float
+    expected: List[Any],
+    results: List[Any],
+    *,
+    score_rtol: float,
+    step_scores_rtol: float
 ) -> None:
     assert len(results) == len(expected)
     for g_hyp, exp in zip(results, expected):
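
For context (not part of the commit): the fairseq2 reference path used in the new test encodes one frame at a time, with the IncrementalStateBag supplying the step offset. A minimal sketch of that behaviour, reusing only API calls already present in the test above, checks that the step-by-step output matches the corresponding slice of a full-sequence pass:

    # Sketch only -- pure fairseq2/PyTorch, no ggml involved.
    import torch
    import fairseq2.nn

    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    pos_encoder.eval()

    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    full = pos_encoder(seq, None)  # encodes positions 0..19 in one call

    state_bag = fairseq2.nn.IncrementalStateBag()
    for t in range(20):
        # With a state_bag, the encoder uses the current step as the position offset.
        step_out = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag)
        state_bag.increment_step()
        assert torch.allclose(step_out, full[:, t : t + 1, :])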