
fairseq2 churn ebf637dc52b42850e086eebe2d60a45a89881aad

Guillaume Wenzek, 1 year ago
Parent
Current commit 20fea08676
1 file changed, 32 insertions, 36 deletions

ggml/test_unity_cpp.py  (+32, -36)

@@ -82,7 +82,7 @@ def test_convert_linear(tmp_path: Path) -> None:
     module = fairseq2.nn.Linear(16, 24, True)
 
     layer_config = read_layer_config(module)
-    assert layer_config == {"input_dim": 16, "output_dim": 24, "skip_init": False}
+    assert layer_config == {"input_dim": 16, "output_dim": 24}
 
     module_file = Path("module.ggml")
     convert_model(module, module_file)
@@ -96,8 +96,8 @@ def test_convert_linear(tmp_path: Path) -> None:
 
 def test_causal_attention_mask(ctx: Ctx):
     x = torch.zeros((1, 10, 32))
-    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
-    mask_exp = generator(x).numpy()
+    generator = fairseq2.nn.transformer.CausalAttentionMaskFactory()
+    mask_exp = generator(x, x).materialize().numpy()
 
     gx = ggml.from_numpy(ctx, x)
     gmask = ggml.causal_attention_mask(ctx, gx)
@@ -111,7 +111,7 @@ def test_causal_attention_mask(ctx: Ctx):
     assert np.all(mask == mask_exp)
 
     x = x[:, :8, :]
-    mask_exp = generator(x).numpy()
+    mask_exp = generator(x, x).materialize().numpy()
     gx = ggml.from_numpy(ctx, x)
     gmask = ggml.causal_attention_mask(ctx, gx)
     mask = ggml.to_numpy(gmask)
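
Note: the hunk above follows the fairseq2 rename of CausalAttentionMaskGenerator to
CausalAttentionMaskFactory; the factory call now takes both the query and key batches and
returns a lazy mask object that has to be materialized. A minimal sketch of the new call
pattern, assuming a fairseq2 version that ships this API:

    import torch
    import fairseq2.nn.transformer as transformer

    x = torch.zeros((1, 10, 32))              # (batch, seq_len, model_dim), as in the test
    factory = transformer.CausalAttentionMaskFactory()
    mask = factory(x, x)                      # queries and keys are now passed separately
    dense = mask.materialize()                # realize the dense causal mask as a tensor
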
@@ -209,10 +209,10 @@ def test_MultiheadAttention_forward(
     y = ggml.to_numpy(gy)
     nodes = ggml.nodes(gf)
 
-    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
+    attn_weights_hook = fairseq2.nn.transformer.AttentionWeightStoreHook([])
     self_attn.register_attn_weight_hook(attn_weights_hook)
 
-    y_exp = self_attn(xq, None, xk, xk).numpy()
+    y_exp = self_attn(xq, None, xk, None, xk).numpy()
 
     q = ggml.to_numpy(nodes[b"q"])
     assert q.shape == q_exp.shape
@@ -221,15 +221,15 @@ def test_MultiheadAttention_forward(
     # with flash_attn we don't have attn_weights
     naive_attn = b"attn_weights" in nodes
     if naive_attn:
-        attn_weights = ggml.to_numpy(nodes[b"attn_weights"])
-        [attn_weights_exp] = attn_weights_hook._storage
+        attn_weights = ggml.to_numpy(nodes[b"attn_weights"]).reshape(-1, 16, qlen, klen)
+        [(_, attn_weights_exp)] = attn_weights_hook._storage
         attn_weights_exp = attn_weights_exp.numpy()
         assert attn_weights_exp.shape == attn_weights.shape
         # GGML is very agressively reducing small softmax weights to 0,
         # so the error isn't that small
         assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
         # But the sums should be close to 1
-        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, qlen)))
+        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2, 16, qlen)))
         # And the maximum index should match the original ones.
         assert np.allclose(
             np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
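
Two fairseq2 changes drive this hunk: MultiheadAttention.forward takes an explicit
key_padding_mask argument between the keys and the values, and the renamed
AttentionWeightStoreHook records pairs whose second element is the attention-weight tensor.
A hedged sketch of the updated usage, reusing names from this test:

    hook = fairseq2.nn.transformer.AttentionWeightStoreHook([])   # list receives (..., weights) pairs
    self_attn.register_attn_weight_hook(hook)

    # forward(seqs, padding_mask, keys, key_padding_mask, values): the extra None fills
    # the new key_padding_mask slot.
    y_exp = self_attn(xq, None, xk, None, xk)
    _, attn_weights = hook._storage[-1]        # second tuple element holds the attention weights
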
@@ -248,13 +248,13 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
 
-    state_bag = fairseq2.nn.IncrementalStateBag()
+    state_bag = fairseq2.nn.IncrementalStateBag(100)
 
     with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
         # Incremental decoding
         for t in range(3):
             xq = x[:, t : t + 1]
-            y_exp = attn(xq, None, xq, xq, state_bag=state_bag).numpy()
+            y_exp = attn(xq, None, xq, None, xq, state_bag=state_bag).numpy()
             assert y_exp.shape == (2, 1, 1024)
 
             gxq = ggml.from_numpy(ctx, xq.contiguous())
@@ -272,13 +272,11 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
             ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
             nodes = ggml.nodes(gf)
-            state = state_bag.get_state(
-                attn, fairseq2.nn.transformer.MultiheadAttentionState
-            )
+            state = state_bag.get_state(attn, fairseq2.nn.transformer.AttentionState)
             state_bag.increment_step()
             assert state is not None
             assert np.allclose(
-                state.prev_k.numpy(),
+                state.get()[0].transpose(1, 2).reshape(2, t + 1, -1).numpy(),
                 ggml.to_numpy(
                     nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
                 ),
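
The cache assertions change for two reasons: IncrementalStateBag now requires a maximum
step count at construction, and per-layer caches are exposed through AttentionState, whose
get() appears to return the projected keys and values in a (batch, num_heads, seq_len,
head_dim) layout, hence the transpose/reshape above. A minimal sketch of recovering the flat
(batch, seq_len, model_dim) view that the old prev_k attribute exposed, under those
assumptions:

    state_bag = fairseq2.nn.IncrementalStateBag(100)   # max number of steps is now required

    # ... after a few incremental attn(..., state_bag=state_bag) calls:
    state = state_bag.get_state(attn, fairseq2.nn.transformer.AttentionState)
    assert state is not None
    k_cache, v_cache = state.get()                     # assumed to be (keys, values)
    k_flat = k_cache.transpose(1, 2).reshape(k_cache.size(0), k_cache.size(2), -1)
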
@@ -299,7 +297,7 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
 
-    state_bag = fairseq2.nn.IncrementalStateBag()
+    state_bag = fairseq2.nn.IncrementalStateBag(100)
 
     with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
         # Incremental decoding, the keys come from the encoder, and don't change during decoding
@@ -329,11 +327,11 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
             if t > 0:
                 # the cache only appear in the graph during the second call
                 state = state_bag.get_state(
-                    attn, fairseq2.nn.transformer.MultiheadAttentionState
+                    attn, fairseq2.nn.transformer.AttentionState
                 )
                 assert state is not None
                 assert np.allclose(
-                    state.prev_k.numpy(),
+                    state.get()[0].transpose(1, 2).reshape(2, 11, -1).numpy(),
                     ggml.to_numpy(
                         nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
                     ),
@@ -341,14 +339,13 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
                 )
 
             state_bag.increment_step()
-            y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
+            y_exp = attn(xq, None, xk, None, xk, state_bag=state_bag).numpy()
             assert y_exp.shape == (2, 1, 1024)
             assert np.allclose(y, y_exp, atol=1e-2)
 
 
 def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
-    padding_mask = torch.ones((2, 21))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
 
@@ -357,7 +354,8 @@ def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) ->
 
     gx = ggml.from_numpy(ctx, x)
     ggml.ggml_set_name(gx, b"x")
-    gpad = ggml.from_numpy(ctx, padding_mask)
+    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([21, 21]), 21)
+    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
     ggml.ggml_set_name(gpad, b"padding_mask")
     gy = ggml.forward(
         "StandardTransformerEncoderLayer",
@@ -371,7 +369,7 @@ def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) ->
 
     y = ggml.to_numpy(gy)
 
-    y_exp, _ = layer(x, padding_mask)
+    y_exp, _ = layer(x, padding_mask=None)
     y_exp = y_exp.numpy()
 
     assert y.shape == y_exp.shape
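
Padding masks are no longer plain (batch, seq_len) tensors of ones: the test builds a
fairseq2.nn.padding.PaddingMask from per-sequence lengths and materializes it for the ggml
side, while the PyTorch reference call passes padding_mask=None because no position is
actually padded here. A short sketch of that construction, assuming the PaddingMask API used
above:

    import torch
    from fairseq2.nn.padding import PaddingMask

    seq_lens = torch.tensor([21, 21])         # true length of each sequence in the batch
    padding_mask = PaddingMask(seq_lens, 21)  # second argument: padded batch sequence length
    dense = padding_mask.materialize()        # dense mask tensor handed to ggml.from_numpy
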
@@ -440,13 +438,13 @@ def test_StandardConformerEncoderAdaptorLayer_forward(
 
 def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
-    padding_mask = torch.ones((2, 21))
+    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([21, 21]), 21)
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
 
     gx = ggml.from_numpy(ctx, x)
     ggml.ggml_set_name(gx, b"x")
-    gpad = ggml.from_numpy(ctx, padding_mask)
+    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
     ggml.ggml_set_name(gpad, b"padding_mask")
     gy = ggml.forward(
         "StandardTransformerEncoder",
@@ -532,8 +530,8 @@ def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
         "sample_rate": 16000.0,
         "format": -1,
     }
-    y_exp = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
-    y_exp = y_exp.numpy()
+    y_exp, _ = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)
+    y_exp = y_exp.squeeze(0).numpy()
 
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error is from standardization
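
Here the only change is that the fbank feature extractor now returns a (features,
padding_mask) pair instead of a bare tensor, so the test unpacks the tuple and drops the
singleton batch dimension before comparing. A sketch of the updated reference computation,
under that assumption:

    fbank = converter(converter_input)["fbank"].unsqueeze(0)  # add a batch dimension
    y_exp, _padding_mask = extractor(fbank, None)             # (features, padding_mask) pair
    y_exp = y_exp.squeeze(0).numpy()                          # back to the unbatched layout
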
@@ -563,7 +561,7 @@ def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) ->
     seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
     pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
     pos_encoder.eval()
-    state_bag = fairseq2.nn.IncrementalStateBag()
+    state_bag = fairseq2.nn.IncrementalStateBag(100)
 
     with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
         # Incremental decoding
@@ -607,13 +605,13 @@ def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> No
 def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 13, 1024))
     encoder_out = torch.empty((2, 21, 1024))
-    padding_mask = torch.ones((2, 13))
+    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([13, 13]), 13)
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
     torch.nn.init.uniform_(encoder_out, -1, 1)
     gx = ggml.from_numpy(ctx, x)
     ggml.ggml_set_name(gx, b"x")
-    gpad = ggml.from_numpy(ctx, padding_mask)
+    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
     ggml.ggml_set_name(gpad, b"padding_mask")
     genc = ggml.from_numpy(ctx, encoder_out)
     gy = ggml.forward(
@@ -655,11 +653,13 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
             translator.model,
             translator.text_tokenizer,
             translator.unit_tokenizer,
-            src,
+            src["seqs"],
+            None,
             input_modality=Modality.TEXT,
             output_modality=Modality.TEXT,
             tgt_lang=tgt_lang,
-            beam_size=beam_size,
+            text_generation_opts=SequenceGeneratorOptions(beam_size=beam_size),
+            unit_generation_opts=None,
         )
 
         tgt_text = str(text_out.sentences[0])
@@ -675,14 +675,12 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
         np.savez(
             sample_file,
             encoder_output=text_out.encoder_output.numpy(),
-            encoder_padding_mask=text_out.encoder_padding_mask.numpy(),
             hypotheses=hypotheses,
         )
 
     # allow_pickle to load the hyp dicts
     text_out = np.load(sample_file, allow_pickle=True)
     encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
-    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])
     prefix_seq = np.array(text_out["hypotheses"][0]["seq"][:2]).astype(np.int32)
     max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])
 
@@ -705,9 +703,7 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
         eos_idx=3,
     )
 
-    result_ptr = ggml.generate_sequence(
-        g_model, job, encoder_out, encoder_padding_mask, ctx
-    )
+    result_ptr = ggml.generate_sequence(g_model, job, encoder_out, NULLPTR, ctx)
     results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
 
     # The step score error is big, this may negatively impact the beam search.