|
@@ -21,6 +21,7 @@ from ggml_convert import convert_model, read_layer_config
|
|
|
from seamless_communication.models.inference.translator import Translator, Modality
|
|
|
from fairseq2.data.audio import WaveformToFbankConverter
|
|
|
import torchaudio
|
|
|
+from typing import List
|
|
|
from ctypes_utils import NULLPTR
|
|
|
from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
|
|
|
|
|
@@ -272,7 +273,9 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
|
|
|
assert state is not None
|
|
|
assert np.allclose(
|
|
|
state.prev_k.numpy(),
|
|
|
- ggml.to_numpy(nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]),
|
|
|
+ ggml.to_numpy(
|
|
|
+ nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
|
|
|
+ ),
|
|
|
atol=1e-3,
|
|
|
)
|
|
|
|
|
@@ -325,7 +328,9 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
|
|
|
assert state is not None
|
|
|
assert np.allclose(
|
|
|
state.prev_k.numpy(),
|
|
|
- ggml.to_numpy(nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]),
|
|
|
+ ggml.to_numpy(
|
|
|
+ nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
|
|
|
+ ),
|
|
|
atol=1e-3,
|
|
|
)
|
|
|
|
|
@@ -688,80 +693,76 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
|
|
|
)
|
|
|
results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
|
|
|
|
|
|
- assert len(results) == len(text_out["hypotheses"])
|
|
|
- for g_hyp, exp in zip(results, text_out["hypotheses"]):
|
|
|
- g_tokens = list(ggml.to_numpy(g_hyp.seq))
|
|
|
- g_step_scores = ggml.to_numpy(g_hyp.step_scores)
|
|
|
- assert g_tokens == exp["seq"]
|
|
|
- assert g_hyp.score == pytest.approx(exp["score"], rel=1e-2)
|
|
|
- # The score error is big, this may negatively impact the beam search.
|
|
|
- assert np.allclose(g_step_scores, exp["step_scores"], atol=0.1)
|
|
|
+ # The step score error is big, this may negatively impact the beam search.
|
|
|
+ assert_hypotheses(
|
|
|
+ text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1
|
|
|
+ )
|
|
|
|
|
|
|
|
|
def test_s2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
src_audio_wav, _ = torchaudio.load(DATA / "test.wav")
|
|
|
- # translator = load_translator()
|
|
|
- # token_encoder = translator.text_tokenizer.create_encoder(
|
|
|
- # task="translation"
|
|
|
- # )
|
|
|
- # decoded_audio = {
|
|
|
- # "waveform": src_audio_wav.t(),
|
|
|
- # "sample_rate": 16000.,
|
|
|
- # "format": -1,
|
|
|
- # }
|
|
|
- # src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
|
|
|
-
|
|
|
- # text_out, _ = translator.get_prediction(
|
|
|
- # translator.model,
|
|
|
- # translator.text_tokenizer,
|
|
|
- # translator.unit_tokenizer,
|
|
|
- # src,
|
|
|
- # input_modality=Modality.SPEECH,
|
|
|
- # output_modality=Modality.TEXT,
|
|
|
- # tgt_lang="cmn",
|
|
|
- # )
|
|
|
-
|
|
|
- # tgt_text = str(text_out.sentences[0])
|
|
|
- # assert tgt_text == "大家好 , 世界无主题。"
|
|
|
- # tgt_tokens = text_out.generator_output.results[0][0].seq
|
|
|
- # score = text_out.generator_output.results[0][0].score.item()
|
|
|
-
|
|
|
- tgt_tokens = [
|
|
|
- 3,
|
|
|
- 256200,
|
|
|
- 16991,
|
|
|
- 249346,
|
|
|
- 249725,
|
|
|
- 146,
|
|
|
- 25220,
|
|
|
- 251069,
|
|
|
- 249211,
|
|
|
- 251148,
|
|
|
- 253935,
|
|
|
- 3,
|
|
|
- ] # "大家好 , 世界无主题。"
|
|
|
- score = -1.606838583946228
|
|
|
- gx = ggml.from_numpy(
|
|
|
- ctx, src_audio_wav * 2**15
|
|
|
- ) # Apply scale before sending into ggml!
|
|
|
+ sample_file = DATA / "test.wav.npz"
|
|
|
+ if not sample_file.exists():
|
|
|
+ translator = load_translator()
|
|
|
+ token_encoder = translator.text_tokenizer.create_encoder(task="translation")
|
|
|
+ decoded_audio = {
|
|
|
+ "waveform": src_audio_wav.t(),
|
|
|
+ "sample_rate": 16000.0,
|
|
|
+ "format": -1,
|
|
|
+ }
|
|
|
+ src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
|
|
|
+
|
|
|
+ text_out, _ = translator.get_prediction(
|
|
|
+ translator.model,
|
|
|
+ translator.text_tokenizer,
|
|
|
+ translator.unit_tokenizer,
|
|
|
+ src,
|
|
|
+ input_modality=Modality.SPEECH,
|
|
|
+ output_modality=Modality.TEXT,
|
|
|
+ tgt_lang="cmn",
|
|
|
+ )
|
|
|
+
|
|
|
+ tgt_text = str(text_out.sentences[0])
|
|
|
+ assert tgt_text == "大家好 , 世界无主题。"
|
|
|
+ hypotheses = [
|
|
|
+ {
|
|
|
+ "seq": h.seq.tolist(),
|
|
|
+ "score": h.score.item(),
|
|
|
+ "step_scores": h.step_scores.numpy(),
|
|
|
+ }
|
|
|
+ for h in text_out.generator_output.results[0]
|
|
|
+ ]
|
|
|
+ np.savez(
|
|
|
+ sample_file,
|
|
|
+ encoder_output=text_out.encoder_output.numpy(),
|
|
|
+ hypotheses=hypotheses,
|
|
|
+ )
|
|
|
+
|
|
|
+ text_out = np.load(sample_file, allow_pickle=True)
|
|
|
+ encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
|
|
|
+ tgt_tokens = text_out["hypotheses"][0]["seq"]
|
|
|
+ max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])
|
|
|
+ max_seq_len = int(max_seq_len * 1.5)
|
|
|
+
|
|
|
+ # Apply scale before sending into ggml!
|
|
|
+ gx = ggml.from_numpy(ctx, src_audio_wav * 2**15)
|
|
|
ggml.ggml_set_name(gx, b"x")
|
|
|
- gy = ggml.forward(
|
|
|
+ encoder_out = ggml.forward(
|
|
|
"StandardConformerEncoder",
|
|
|
g_model,
|
|
|
"speech_encoder",
|
|
|
gx,
|
|
|
None, # TODO support padding mask
|
|
|
)
|
|
|
- gf = ggml.ggml_build_forward(gy)
|
|
|
+ gf = ggml.ggml_build_forward(encoder_out)
|
|
|
ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
|
|
|
|
|
|
- encoder_out = gy
|
|
|
-
|
|
|
+ beam_size = 5
|
|
|
opts = ggml.SequenceGeneratorOptions(
|
|
|
- beam_size=5,
|
|
|
+ beam_size=beam_size,
|
|
|
soft_max_seq_len_a=1,
|
|
|
soft_max_seq_len_b=200,
|
|
|
- hard_max_seq_len=1000,
|
|
|
+ hard_max_seq_len=max_seq_len,
|
|
|
)
|
|
|
job = ggml.SequenceGeneratorJob(
|
|
|
opts=opts,
|
|
@@ -771,6 +772,18 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
bos_idx=2,
|
|
|
eos_idx=3,
|
|
|
)
|
|
|
- result_ptr = ggml.generate_sequence(g_model, job, encoder_out, NULLPTR, ctx)
|
|
|
- g_tokens = list(ggml.to_numpy(result_ptr[0].seq))
|
|
|
- assert g_tokens == tgt_tokens
|
|
|
+ result_ptr = ggml.generate_sequence(g_model, Ptr(job), encoder_out, NULLPTR, ctx)
|
|
|
+ results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
|
|
|
+ assert_hypotheses(text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1)
|
|
|
+
|
|
|
+
|
|
|
+def assert_hypotheses(
+    expected: List[Any],
+    results: List[Any],
+    *,
+    score_rtol: float,
+    step_scores_rtol: float,
+) -> None:
+    """Assert that generated hypotheses match the expected ones, pairwise.
+
+    Args:
+        expected: Reference hypotheses; each item is indexed with the keys
+            "seq", "score" and "step_scores".
+        results: Generated hypotheses; each item exposes ``seq``, ``score``
+            and ``step_scores`` attributes. ``seq`` and ``step_scores`` are
+            converted with ``ggml.to_numpy`` before being compared.
+        score_rtol: Relative tolerance handed to ``pytest.approx`` when
+            comparing hypothesis scores.
+        step_scores_rtol: Relative tolerance handed to ``np.allclose`` when
+            comparing step scores.
+
+    Raises:
+        AssertionError: If the two lists differ in length, or if any pair
+            differs in tokens, score or step scores.
+    """
+    assert len(results) == len(expected)
+    for hypothesis, reference in zip(results, expected):
+        tokens = list(ggml.to_numpy(hypothesis.seq))
+        step_scores = ggml.to_numpy(hypothesis.step_scores)
+        # Token ids must match exactly; only the scores get a tolerance.
+        assert tokens == reference["seq"]
+        expected_score = pytest.approx(reference["score"], rel=score_rtol)
+        assert hypothesis.score == expected_score
+        # Only rtol is passed, so np.allclose's default atol still applies.
+        scores_close = np.allclose(
+            step_scores, reference["step_scores"], rtol=step_scores_rtol
+        )
+        assert scores_close
|