@@ -34,6 +34,11 @@ FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
 UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
 
 DATA = Path(__file__).parent
+DATA_DEV = DATA / "dev"
+if not DATA_DEV.exists():
+    DATA_DEV = Path(
+        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev"
+    )
 
 
 @pytest.fixture(name="ctx")
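
The DATA_DEV fallback keeps these tests runnable outside the internal cluster: a "dev" folder next to the tests is preferred, and the absolute path is only a last resort. A minimal sketch of an alternative, assuming one wanted the override to come from an environment variable instead (the UNITY_DEV_DATA name is hypothetical):

    import os
    from pathlib import Path

    DATA = Path(__file__).parent
    # Look next to the tests first; allow an explicit override for machines
    # where the fixture files live elsewhere.
    DATA_DEV = Path(os.environ.get("UNITY_DEV_DATA", str(DATA / "dev")))
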
@@ -245,7 +250,7 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
 
     state_bag = fairseq2.nn.IncrementalStateBag()
 
-    with ggml.model_kv_cache_alloc(g_model, 2, 21):
+    with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
         # Incremental decoding
         for t in range(3):
             xq = x[:, t : t + 1]
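
The rename from model_kv_cache_alloc to fairseq2_kv_cache_alloc makes explicit that the cache layout follows fairseq2's attention modules; usage is unchanged. A sketch of the incremental-decoding pattern these tests follow, assuming the shapes used above (batch of 2, at most 21 steps):

    # KV-cache buffers are allocated for (batch=2, max_seq_len=21) and
    # released when the context manager exits.
    with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
        for t in range(3):
            xq = x[:, t : t + 1]  # feed one query step at a time
            gxq = ggml.from_numpy(ctx, xq.contiguous(), name=b"xq")
            # ... build and run the self-attention graph for this step ...
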
@@ -296,7 +301,7 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
 
     state_bag = fairseq2.nn.IncrementalStateBag()
 
-    with ggml.model_kv_cache_alloc(g_model, 2, 21):
+    with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
         # Incremental decoding, the keys come from the encoder, and don't change during decoding
         xk = x[:, :11]
         gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")
@@ -375,9 +380,10 @@ def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) ->
 
 def test_StandardConformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
     pt_model = load_pt_model()
-    x = torch.load(
-        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_conformer_block.pt"
-    )
+    if not DATA_DEV.exists():
+        pytest.skip(reason=f"Folder {DATA_DEV} not found!")
+
+    x = torch.load(DATA_DEV / "seqs_before_conformer_block.pt")
     padding_mask = torch.ones((1, x.shape[1]))
     layer = pt_model.speech_encoder.inner.layers[0]
     gx = ggml.from_numpy(ctx, x[0])
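
The same existence guard now appears in several tests; it could be shared as a module-level pytest marker instead of inline skips. A sketch, with the hypothetical name requires_dev_data:

    # Hypothetical shared marker, equivalent to the inline pytest.skip calls.
    requires_dev_data = pytest.mark.skipif(
        not DATA_DEV.exists(), reason=f"Folder {DATA_DEV} not found!"
    )

    @requires_dev_data
    def test_StandardConformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
        ...
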
@@ -406,9 +412,10 @@ def test_StandardConformerEncoderAdaptorLayer_forward(
     ctx: Ctx, g_model: c_void_p
 ) -> None:
     pt_model = load_pt_model()
-    x = torch.load(
-        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_adaptor.pt"
-    )
+    if not DATA_DEV.exists():
+        pytest.skip(reason=f"Folder {DATA_DEV} not found!")
+
+    x = torch.load(DATA_DEV / "seqs_before_adaptor.pt")
     layer = pt_model.speech_encoder.adaptor_layers[0]
     gx = ggml.from_numpy(ctx, x[0])
     ggml.ggml_set_name(gx, b"x")
@@ -458,7 +465,7 @@ def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None
     y_exp = y_exp.numpy()
 
     assert y.shape == y_exp.shape
-    assert np.allclose(y_exp, y, atol=1e-4)
+    assert np.allclose(y_exp, y, atol=5e-3)
 
 
 def test_StandardConformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
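
The looser 5e-3 tolerance absorbs float32 error that accumulates across the encoder stack. When a tolerance has to be bumped like this, reporting the observed error keeps future regressions visible; a small sketch:

    max_err = np.abs(y_exp - y).max()
    assert np.allclose(y_exp, y, atol=5e-3), f"max abs error {max_err:.2e} exceeds 5e-3"
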
@@ -510,10 +517,8 @@ def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
         channel_last=True,
         standardize=True,
     )
-    extractor = Wav2Vec2FbankFeatureExtractor(80, 2, 1)
-    wav, _ = torchaudio.load(
-        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav"
-    )
+    extractor = Wav2Vec2FbankFeatureExtractor(80, stride=2, sample_every_k=1)
+    wav, _ = torchaudio.load(DATA / "test.wav")
     gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
     ggml.ggml_set_name(gx, b"x")
 
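
Two readability fixes land here: the magic positional arguments 2 and 1 become the keyword arguments stride and sample_every_k, and the hard-coded wav path moves under DATA. On the 2**15 scale kept above: torchaudio.load returns float samples in [-1.0, 1.0], and the comment suggests the ggml fbank path, like Kaldi-style extractors, works on 16-bit PCM-range samples, hence the rescale. A sketch of the equivalence:

    wav, sample_rate = torchaudio.load(DATA / "test.wav")  # float32 in [-1.0, 1.0]
    gx = ggml.from_numpy(ctx, wav * 2**15)  # map to the int16 range [-32768, 32767]
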
@@ -560,7 +565,7 @@ def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) ->
     pos_encoder.eval()
     state_bag = fairseq2.nn.IncrementalStateBag()
 
-    with ggml.model_kv_cache_alloc(g_model, 2, 21):
+    with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
         # Incremental decoding
         for t in range(20):
             gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
@@ -796,7 +801,7 @@ def assert_hypotheses(
     results: List[Any],
     *,
     score_rtol: float,
-    step_scores_rtol: float
+    step_scores_rtol: float,
 ) -> None:
     assert len(results) == len(expected)
     for g_hyp, exp in zip(results, expected):