
s2tt generate the sample if needed

Guillaume Wenzek 1 year ago
parent
commit
8c074387a9
2 changed files with 76 additions and 64 deletions
  1. +0 −1 ggml/examples/unity/fairseq2.cpp
  2. +76 −63 ggml/test_unity_cpp.py

+ 0 - 1
ggml/examples/unity/fairseq2.cpp

@@ -1207,7 +1207,6 @@ void _finalize_hypothesis(
 
 
 /// Generates a translation for a single sequence
-// TODO: add IncrementalStateBag support to avoid a O(N^3) generation.
 // TODO: clean ups
 // * replace manual tensor tweaking with ggml_set_*d (a ggml_set_slice could be useful)
 extern "C" Hypothesis* generate_sequence(

+ 76 - 63
ggml/test_unity_cpp.py

@@ -21,6 +21,7 @@ from ggml_convert import convert_model, read_layer_config
 from seamless_communication.models.inference.translator import Translator, Modality
 from fairseq2.data.audio import WaveformToFbankConverter
 import torchaudio
+from typing import Any, List
 from ctypes_utils import NULLPTR
 from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
 
@@ -272,7 +273,9 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
         assert state is not None
         assert np.allclose(
             state.prev_k.numpy(),
-            ggml.to_numpy(nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]),
+            ggml.to_numpy(
+                nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
+            ),
             atol=1e-3,
         )
 
@@ -325,7 +328,9 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
             assert state is not None
             assert np.allclose(
                 state.prev_k.numpy(),
-                ggml.to_numpy(nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]),
+                ggml.to_numpy(
+                    nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
+                ),
                 atol=1e-3,
             )
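
Both hunks above are pure line re-wraps, but they document a useful testing technique: the ggml graph builder assigns names to intermediate tensors (for example "text_decoder.layers.0.self_attn.k_cache (step=%d)"), and the test looks those nodes up by name after the forward pass to compare them against fairseq2's incremental state. A minimal sketch of that assertion, assuming the repo's ggml Python bindings and a nodes mapping from byte names to tensors as used above (assert_cache_node itself is a hypothetical helper):

    import numpy as np

    import ggml  # project-local bindings, as used by test_unity_cpp.py

    def assert_cache_node(nodes: dict, name: bytes, expected: np.ndarray) -> None:
        # Fetch the named intermediate tensor from the computed graph and
        # compare it against the reference incremental state.
        actual = ggml.to_numpy(nodes[name])
        assert np.allclose(actual, expected, atol=1e-3)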
 
@@ -688,80 +693,76 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
     )
     results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
 
-    assert len(results) == len(text_out["hypotheses"])
-    for g_hyp, exp in zip(results, text_out["hypotheses"]):
-        g_tokens = list(ggml.to_numpy(g_hyp.seq))
-        g_step_scores = ggml.to_numpy(g_hyp.step_scores)
-        assert g_tokens == exp["seq"]
-        assert g_hyp.score == pytest.approx(exp["score"], rel=1e-2)
-        # The score error is big, this may negatively impact the beam search.
-        assert np.allclose(g_step_scores, exp["step_scores"], atol=0.1)
+    # The step score error is big, this may negatively impact the beam search.
+    assert_hypotheses(
+        text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1
+    )
 
 
 def test_s2tt(ctx: Ctx, g_model: c_void_p):
     src_audio_wav, _ = torchaudio.load(DATA / "test.wav")
-    # translator = load_translator()
-    # token_encoder = translator.text_tokenizer.create_encoder(
-    #     task="translation"
-    # )
-    # decoded_audio = {
-    #     "waveform": src_audio_wav.t(),
-    #     "sample_rate": 16000.,
-    #     "format": -1,
-    # }
-    # src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
-
-    # text_out, _ = translator.get_prediction(
-    #     translator.model,
-    #     translator.text_tokenizer,
-    #     translator.unit_tokenizer,
-    #     src,
-    #     input_modality=Modality.SPEECH,
-    #     output_modality=Modality.TEXT,
-    #     tgt_lang="cmn",
-    # )
-
-    # tgt_text = str(text_out.sentences[0])
-    # assert tgt_text == "大家好 , 世界无主题。"
-    # tgt_tokens = text_out.generator_output.results[0][0].seq
-    # score = text_out.generator_output.results[0][0].score.item()
-
-    tgt_tokens = [
-        3,
-        256200,
-        16991,
-        249346,
-        249725,
-        146,
-        25220,
-        251069,
-        249211,
-        251148,
-        253935,
-        3,
-    ]  # "大家好 , 世界无主题。"
-    score = -1.606838583946228
-    gx = ggml.from_numpy(
-        ctx, src_audio_wav * 2**15
-    )  # Apply scale before sending into ggml!
+    sample_file = DATA / "test.wav.npz"
+    if not sample_file.exists():
+        translator = load_translator()
+        token_encoder = translator.text_tokenizer.create_encoder(task="translation")
+        decoded_audio = {
+            "waveform": src_audio_wav.t(),
+            "sample_rate": 16000.0,
+            "format": -1,
+        }
+        src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
+
+        text_out, _ = translator.get_prediction(
+            translator.model,
+            translator.text_tokenizer,
+            translator.unit_tokenizer,
+            src,
+            input_modality=Modality.SPEECH,
+            output_modality=Modality.TEXT,
+            tgt_lang="cmn",
+        )
+
+        tgt_text = str(text_out.sentences[0])
+        assert tgt_text == "大家好 , 世界无主题。"
+        hypotheses = [
+            {
+                "seq": h.seq.tolist(),
+                "score": h.score.item(),
+                "step_scores": h.step_scores.numpy(),
+            }
+            for h in text_out.generator_output.results[0]
+        ]
+        np.savez(
+            sample_file,
+            encoder_output=text_out.encoder_output.numpy(),
+            hypotheses=hypotheses,
+        )
+
+    text_out = np.load(sample_file, allow_pickle=True)
+    # Reference encoder output from the cached sample; note it is shadowed by
+    # the ggml forward pass below and currently unused.
+    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
+    tgt_tokens = text_out["hypotheses"][0]["seq"]
+    max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])
+    max_seq_len = int(max_seq_len * 1.5)
+
+    # Apply scale before sending into ggml!
+    gx = ggml.from_numpy(ctx, src_audio_wav * 2**15)
     ggml.ggml_set_name(gx, b"x")
-    gy = ggml.forward(
+    encoder_out = ggml.forward(
         "StandardConformerEncoder",
         g_model,
         "speech_encoder",
         gx,
         None,  # TODO support padding mask
     )
-    gf = ggml.ggml_build_forward(gy)
+    gf = ggml.ggml_build_forward(encoder_out)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
-    encoder_out = gy
-
+    beam_size = 5
     opts = ggml.SequenceGeneratorOptions(
-        beam_size=5,
+        beam_size=beam_size,
         soft_max_seq_len_a=1,
         soft_max_seq_len_b=200,
-        hard_max_seq_len=1000,
+        hard_max_seq_len=max_seq_len,
     )
     job = ggml.SequenceGeneratorJob(
         opts=opts,
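
The hunk above replaces a hard-coded token list with a cached reference sample: on the first run the PyTorch translator is invoked and its hypotheses and encoder output are saved to test.wav.npz; subsequent runs only load the file. A minimal, self-contained sketch of that generate-if-needed pattern (build_reference is a hypothetical stand-in for the expensive translator call):

    from pathlib import Path
    from typing import Callable, Dict

    import numpy as np

    def load_or_build(
        sample_file: Path, build_reference: Callable[[], Dict]
    ) -> "np.lib.npyio.NpzFile":
        # Build the expensive reference once; later runs reuse the cached .npz.
        if not sample_file.exists():
            np.savez(sample_file, **build_reference())
        # allow_pickle=True is needed because "hypotheses" is a list of dicts,
        # which numpy stores as an object array.
        return np.load(sample_file, allow_pickle=True)

    # Hypothetical usage mirroring the test:
    #   data = load_or_build(DATA / "test.wav.npz", run_pytorch_translator)
    #   tgt_tokens = data["hypotheses"][0]["seq"]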
@@ -771,6 +772,18 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
         bos_idx=2,
         eos_idx=3,
     )
-    result_ptr = ggml.generate_sequence(g_model, job, encoder_out, NULLPTR, ctx)
-    g_tokens = list(ggml.to_numpy(result_ptr[0].seq))
-    assert g_tokens == tgt_tokens
+    result_ptr = ggml.generate_sequence(g_model, Ptr(job), encoder_out, NULLPTR, ctx)
+    results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
+    assert_hypotheses(
+        text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1
+    )
+
+
+def assert_hypotheses(
+    expected: List[Any],
+    results: List[Any],
+    *,
+    score_rtol: float,
+    step_scores_rtol: float,
+) -> None:
+    assert len(results) == len(expected)
+    for g_hyp, exp in zip(results, expected):
+        g_tokens = list(ggml.to_numpy(g_hyp.seq))
+        g_step_scores = ggml.to_numpy(g_hyp.step_scores)
+        assert g_tokens == exp["seq"]
+        assert g_hyp.score == pytest.approx(exp["score"], rel=score_rtol)
+        assert np.allclose(g_step_scores, exp["step_scores"], rtol=step_scores_rtol)
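
One behavioral change to note: the old t2tt assertion compared step scores with an absolute tolerance (atol=0.1), while assert_hypotheses passes step_scores_rtol through as a relative tolerance. The two are not interchangeable for scores close to zero, as a quick numeric check shows:

    import numpy as np

    a = np.array([-0.01, -5.00])
    b = np.array([-0.09, -5.05])
    print(np.allclose(a, b, atol=0.1))  # True: both absolute gaps are <= 0.1
    print(np.allclose(a, b, rtol=0.1))  # False: 0.08 is huge relative to 0.09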