@@ -9,6 +9,7 @@ import fairseq2.nn.transformer
 import logging
 import sys
 import functools
+from typing import Tuple
 from pathlib import Path
 from ctypes_utils import Ptr
 from ctypes import c_void_p
@@ -48,6 +49,7 @@ def _load_g_model_once() -> NativeObj:
         convert_model("seamlessM4T_medium", model_file)
     return ggml.load_fairseq2_ggml_file(model_file)

+
 @pytest.fixture()
 def g_model(ctx: Ctx) -> c_void_p:
     model = _load_g_model_once()
@@ -61,6 +63,7 @@ def load_translator() -> Translator:
         "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
     )

+
 def load_pt_model() -> Any:
     return load_translator().model

@@ -76,7 +79,9 @@ def test_convert_linear(tmp_path: Path) -> None:
     g_module = ggml.load_fairseq2_ggml_file(module_file)

     for k, v in layer_config.items():
-        assert ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
+        assert (
+            ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
+        )


 def test_causal_attention_mask(ctx: Ctx):
@@ -161,27 +166,26 @@ def _name(tensor: ggml.ggml_tensor_p) -> bytes:
         return b"???"


-@pytest.mark.parametrize("flash_attn", [False, True])
-def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: bool) -> None:
+@pytest.mark.parametrize("lengths", [(11, 21), (21, 13)])
+def test_MultiheadAttention_forward(
+    ctx: Ctx, g_model: c_void_p, lengths: Tuple[int, int]
+) -> None:
     x = torch.empty((2, 21, 1024))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)

-    pt_model = load_pt_model()
-    self_attn = pt_model.text_encoder.layers[0].self_attn
-
     # Note: we use different lengths for queries and keys,
     # this tests the implementation in decoding context too.
     # Note2: ggml_flash_attn requires that we have more keys than queries
-    if flash_attn:
-        xq = x[:, :11, :]
-        xk = x
-    else:
-        xq = x
-        xk = x[:, :13, :]
+    # qlen, klen = (11, 21) if flash_attn else (21, 13)
+    qlen, klen = lengths
+    xq = x[:, :qlen]
+    xk = x[:, :klen]
+    if qlen > klen and UNITY_FLASH_ATTN:
+        pytest.skip(reason="flash_attn requires qlen <= klen")

     gxq = ggml.from_numpy(ctx, xq.contiguous())
-    gxk = ggml.from_numpy(ctx, xk)
+    gxk = ggml.from_numpy(ctx, xk.contiguous())
     ggml.ggml_set_name(gxk, b"xk")
     gy = ggml.forward(
         "MultiheadAttention",
@@ -195,6 +199,8 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: boo
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

+    pt_model = load_pt_model()
+    self_attn = pt_model.text_encoder.layers[0].self_attn
     q_exp = self_attn.q_proj(xq).numpy()

     y = ggml.to_numpy(gy)
@@ -216,7 +222,8 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: boo
     assert np.allclose(q_exp, q, atol=1e-5)

     # with flash_attn we don't have attn_weights
-    if not flash_attn:
+    naive_attn = b"attn_weights" in nodes
+    if naive_attn:
         attn_weights = nodes[b"attn_weights"]
         [attn_weights_exp] = attn_weights_hook._storage
         attn_weights_exp = attn_weights_exp.numpy()
@@ -225,17 +232,16 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: boo
         # so the error isn't that small
         assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
         # But the sums should be close to 1
-        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 1)))
+        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, qlen)))
         # And the maximum index should match the original ones.
-        assert np.allclose(np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
+        assert np.allclose(
+            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
         )
     assert y.shape == y_exp.shape
-    assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)
+    assert np.allclose(y_exp, y, atol=1e-2 if naive_attn else 1e-4)


-def test_StandardTransformerEncoderLayer_forward(
-    ctx: Ctx, g_model: c_void_p
-) -> None:
+def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
     padding_mask = torch.ones((2, 21))
     torch.random.manual_seed(0)
@@ -290,7 +296,7 @@ def test_StandardConformerEncoderLayer_forward(
     y = ggml.to_numpy(gy)

     y_exp, _ = layer(x, padding_mask)
-    y_exp = y_exp.numpy()
+    y_exp = y_exp.numpy()
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=2e-3)

@@ -315,15 +321,13 @@ def test_StandardConformerEncoderAdaptorLayer_forward(
     y = ggml.to_numpy(gy)

     y_exp, _ = layer(x, None)
-    y_exp = y_exp.numpy()
+    y_exp = y_exp.numpy()

     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=2e-3)


-def test_StandardTransformerEncoder_forward(
-    ctx: Ctx, g_model: c_void_p
-) -> None:
+def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
     padding_mask = torch.ones((2, 21))
     torch.random.manual_seed(0)
@@ -390,7 +394,7 @@ def test_StandardConformerEncoder_forward(
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=1e-2)  # There are 10 elements in a 137*1024 tensor with error >1e-2

-
+

 def test_WaveformToFbank_forward(
     ctx: Ctx, g_model: c_void_p
@@ -406,7 +410,7 @@ def test_WaveformToFbank_forward(
     wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
     gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
     ggml.ggml_set_name(gx, b"x")
-
+
     gy = ggml.forward(
         "WaveformToFbank",
         g_model,
@@ -423,7 +427,7 @@ def test_WaveformToFbank_forward(
         "format": -1,
     }
     y_exp = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
-    y_exp = y_exp.numpy()
+    y_exp = y_exp.numpy()

     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error is from standardization
@@ -463,9 +467,7 @@ def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
     assert np.allclose(y_exp, y, atol=1e-6)


-def test_TransformerEmbeddingFrontend_forward(
-    ctx: Ctx, g_model: c_void_p
-) -> None:
+def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None:
     seq = torch.arange(2 * 20).reshape(2, 20)
     seq[1, 15:] = 0  # padding for second sentence
     seq_len = torch.tensor([20, 15])
@@ -486,9 +488,7 @@ def test_TransformerEmbeddingFrontend_forward(
     assert np.allclose(y_exp, y, atol=1e-6)


-def test_StandardTransformerDecoder_forward(
-    ctx: Ctx, g_model: c_void_p
-) -> None:
+def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 13, 1024))
     encoder_out = torch.empty((2, 21, 1024))
     padding_mask = torch.ones((2, 13))
@@ -658,7 +658,7 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
     job.unk_idx = 1
     job.bos_idx = 2
     job.eos_idx = 3
-
+
     result_ptr = ggml.generate_sequence(
         g_model, job, encoder_out, None, ctx
     )
|