|
@@ -22,6 +22,7 @@ from seamless_communication.models.inference.translator import Translator, Modal
|
|
|
from fairseq2.data.audio import WaveformToFbankConverter
|
|
|
import torchaudio
|
|
|
from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
|
|
|
+
|
|
|
Ctx = ggml.ggml_context_p
|
|
|
|
|
|
UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
|
|
@@ -241,6 +242,42 @@ def test_MultiheadAttention_forward(
|
|
|
assert np.allclose(y_exp, y, atol=1e-2 if naive_attn else 1e-4)
|
|
|
|
|
|
|
|
|
+def test_MultiheadAttention_forward_with_state_bag(ctx: Ctx, g_model: c_void_p) -> None:
|
|
|
+ pt_model = load_pt_model()
|
|
|
+ self_attn = pt_model.text_encoder.layers[0].self_attn
|
|
|
+
|
|
|
+ x = torch.empty((2, 21, 1024))
|
|
|
+ torch.random.manual_seed(0)
|
|
|
+ torch.nn.init.uniform_(x, -1, 1)
|
|
|
+
|
|
|
+ state_bag = fairseq2.nn.IncrementalStateBag()
|
|
|
+
|
|
|
+ ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)
|
|
|
+ # Incremental decoding
|
|
|
+ for t in range(3):
|
|
|
+ xq, xk = x[:, t : t + 1], x[:, t : t + 1]
|
|
|
+ y_exp = self_attn(xq, None, xk, xk, state_bag=state_bag).numpy()
|
|
|
+ assert y_exp.shape == (2, 1, 1024)
|
|
|
+
|
|
|
+ gxq = ggml.from_numpy(ctx, xq.contiguous())
|
|
|
+ gxk = ggml.from_numpy(ctx, xk.contiguous())
|
|
|
+ ggml.ggml_set_name(gxk, b"xk")
|
|
|
+ gy = ggml.forward(
|
|
|
+ "MultiheadAttention",
|
|
|
+ g_model,
|
|
|
+ "text_encoder.layers.0.self_attn",
|
|
|
+ gxq,
|
|
|
+ gxk,
|
|
|
+ gxk,
|
|
|
+ None, # type: ignore
|
|
|
+ )
|
|
|
+ gf = ggml.ggml_build_forward(gy)
|
|
|
+ ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
|
|
|
+
|
|
|
+ y = ggml.to_numpy(gy)
|
|
|
+ assert np.allclose(y, y_exp, atol=1e-2)
|
|
|
+
|
|
|
+
|
|
|
def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
|
|
|
x = torch.empty((2, 21, 1024))
|
|
|
padding_mask = torch.ones((2, 21))
|
|
@@ -272,11 +309,12 @@ def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) ->
|
|
|
assert y.shape == y_exp.shape
|
|
|
assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)
|
|
|
|
|
|
-def test_StandardConformerEncoderLayer_forward(
|
|
|
- ctx: Ctx, g_model: c_void_p
|
|
|
-) -> None:
|
|
|
+
|
|
|
+def test_StandardConformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
|
|
|
pt_model = load_pt_model()
|
|
|
- x = torch.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_conformer_block.pt")
|
|
|
+ x = torch.load(
|
|
|
+ "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_conformer_block.pt"
|
|
|
+ )
|
|
|
padding_mask = torch.ones((1, x.shape[1]))
|
|
|
layer = pt_model.speech_encoder.inner.layers[0]
|
|
|
gx = ggml.from_numpy(ctx, x[0])
|
|
@@ -304,7 +342,9 @@ def test_StandardConformerEncoderAdaptorLayer_forward(
|
|
|
ctx: Ctx, g_model: c_void_p
|
|
|
) -> None:
|
|
|
pt_model = load_pt_model()
|
|
|
- x = torch.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_adaptor.pt")
|
|
|
+ x = torch.load(
|
|
|
+ "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_adaptor.pt"
|
|
|
+ )
|
|
|
layer = pt_model.speech_encoder.adaptor_layers[0]
|
|
|
gx = ggml.from_numpy(ctx, x[0])
|
|
|
ggml.ggml_set_name(gx, b"x")
|
|
@@ -356,12 +396,13 @@ def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None
|
|
|
assert y.shape == y_exp.shape
|
|
|
assert np.allclose(y_exp, y, atol=1e-4)
|
|
|
|
|
|
-def test_StandardConformerEncoder_forward(
|
|
|
- ctx: Ctx, g_model: c_void_p
|
|
|
-) -> None:
|
|
|
+
|
|
|
+def test_StandardConformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
|
|
|
pt_model = load_pt_model()
|
|
|
- wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
|
|
|
- gx = ggml.from_numpy(ctx, wav * 2**15) # Apply scale before sending into ggml!
|
|
|
+ wav, _ = torchaudio.load(
|
|
|
+ "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav"
|
|
|
+ )
|
|
|
+ gx = ggml.from_numpy(ctx, wav * 2**15) # Apply scale before sending into ggml!
|
|
|
ggml.ggml_set_name(gx, b"x")
|
|
|
gy = ggml.forward(
|
|
|
"StandardConformerEncoder",
|
|
@@ -381,24 +422,25 @@ def test_StandardConformerEncoder_forward(
|
|
|
)
|
|
|
converter_input = {
|
|
|
"waveform": wav.transpose(0, 1),
|
|
|
- "sample_rate": 16000.,
|
|
|
+ "sample_rate": 16000.0,
|
|
|
"format": -1,
|
|
|
}
|
|
|
|
|
|
y = ggml.to_numpy(gy)
|
|
|
- speech_encoder_input = pt_model.speech_encoder_frontend(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
|
|
|
+ speech_encoder_input = pt_model.speech_encoder_frontend(
|
|
|
+ converter(converter_input)["fbank"].unsqueeze(0), None
|
|
|
+ )[0]
|
|
|
|
|
|
y_exp, _ = pt_model.speech_encoder(speech_encoder_input, None)
|
|
|
y_exp = y_exp.numpy() # remove batch dimension
|
|
|
|
|
|
assert y.shape == y_exp.shape
|
|
|
- assert np.allclose(y_exp, y, atol=1e-2) # There are 10 elements in a 137*1024 tensor with error >1e-2
|
|
|
-
|
|
|
+ assert np.allclose(
|
|
|
+ y_exp, y, atol=1e-2
|
|
|
+ ) # There are 10 elements in a 137*1024 tensor with error >1e-2
|
|
|
|
|
|
|
|
|
-def test_WaveformToFbank_forward(
|
|
|
- ctx: Ctx, g_model: c_void_p
|
|
|
-) -> None:
|
|
|
+def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
|
|
|
pt_model = load_pt_model()
|
|
|
converter = WaveformToFbankConverter(
|
|
|
num_mel_bins=80,
|
|
@@ -407,30 +449,27 @@ def test_WaveformToFbank_forward(
|
|
|
standardize=True,
|
|
|
)
|
|
|
extractor = Wav2Vec2FbankFeatureExtractor(80, 2, 1)
|
|
|
- wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
|
|
|
- gx = ggml.from_numpy(ctx, wav * 2**15) # Apply scale before sending into ggml!
|
|
|
+ wav, _ = torchaudio.load(
|
|
|
+ "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav"
|
|
|
+ )
|
|
|
+ gx = ggml.from_numpy(ctx, wav * 2**15) # Apply scale before sending into ggml!
|
|
|
ggml.ggml_set_name(gx, b"x")
|
|
|
|
|
|
- gy = ggml.forward(
|
|
|
- "WaveformToFbank",
|
|
|
- g_model,
|
|
|
- "",
|
|
|
- gx
|
|
|
- )
|
|
|
+ gy = ggml.forward("WaveformToFbank", g_model, "", gx)
|
|
|
gf = ggml.ggml_build_forward(gy)
|
|
|
ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
|
|
|
|
|
|
y = ggml.to_numpy(gy)
|
|
|
converter_input = {
|
|
|
"waveform": wav.transpose(0, 1),
|
|
|
- "sample_rate": 16000.,
|
|
|
+ "sample_rate": 16000.0,
|
|
|
"format": -1,
|
|
|
}
|
|
|
y_exp = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
|
|
|
y_exp = y_exp.numpy()
|
|
|
|
|
|
assert y.shape == y_exp.shape
|
|
|
- assert np.allclose(y_exp, y, atol=4e-3) # reduce? error is from standardization
|
|
|
+ assert np.allclose(y_exp, y, atol=4e-3) # reduce? error is from standardization
|
|
|
|
|
|
|
|
|
def test_causal_attention_mask(ctx: Ctx):
|
|
@@ -600,8 +639,11 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
# The score error is big, this may negatively impact the beam search.
|
|
|
assert np.allclose(g_step_scores, exp["step_scores"], atol=0.1)
|
|
|
|
|
|
+
|
|
|
def test_s2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
- src_audio_wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
|
|
|
+ src_audio_wav, _ = torchaudio.load(
|
|
|
+ "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav"
|
|
|
+ )
|
|
|
# translator = load_translator()
|
|
|
# token_encoder = translator.text_tokenizer.create_encoder(
|
|
|
# task="translation"
|
|
@@ -628,9 +670,23 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
# tgt_tokens = text_out.generator_output.results[0][0].seq
|
|
|
# score = text_out.generator_output.results[0][0].score.item()
|
|
|
|
|
|
- tgt_tokens = [ 3, 256200, 16991, 249346, 249725, 146, 25220, 251069, 249211,
|
|
|
- 251148, 253935, 3] # "大家好 , 世界无主题。"
|
|
|
- gx = ggml.from_numpy(ctx, src_audio_wav * 2**15) # Apply scale before sending into ggml!
|
|
|
+ tgt_tokens = [
|
|
|
+ 3,
|
|
|
+ 256200,
|
|
|
+ 16991,
|
|
|
+ 249346,
|
|
|
+ 249725,
|
|
|
+ 146,
|
|
|
+ 25220,
|
|
|
+ 251069,
|
|
|
+ 249211,
|
|
|
+ 251148,
|
|
|
+ 253935,
|
|
|
+ 3,
|
|
|
+ ] # "大家好 , 世界无主题。"
|
|
|
+ gx = ggml.from_numpy(
|
|
|
+ ctx, src_audio_wav * 2**15
|
|
|
+ ) # Apply scale before sending into ggml!
|
|
|
ggml.ggml_set_name(gx, b"x")
|
|
|
gy = ggml.forward(
|
|
|
"StandardConformerEncoder",
|
|
@@ -659,8 +715,6 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
job.bos_idx = 2
|
|
|
job.eos_idx = 3
|
|
|
|
|
|
- result_ptr = ggml.generate_sequence(
|
|
|
- g_model, job, encoder_out, None, ctx
|
|
|
- )
|
|
|
+ result_ptr = ggml.generate_sequence(g_model, job, encoder_out, None, ctx)
|
|
|
g_tokens = list(ggml.to_numpy(result_ptr[0].seq))
|
|
|
assert g_tokens == tgt_tokens
|