Ver código fonte

Update test_unity_cpp.py

Ning 1 ano atrás
pai
commit
d9061b89b3
1 arquivos alterados com 8 adições e 12 exclusões
  1. 8 12
      ggml/test_unity_cpp.py

+ 8 - 12
ggml/test_unity_cpp.py

@@ -24,7 +24,7 @@ from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtra
 Ctx = ggml.ggml_context_p
 
 UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
-CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)
+CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024 * 5, mem_buffer=None)
 
 FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
 UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
@@ -618,8 +618,7 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
     # score = text_out.generator_output.results[0][0].score.item()
 
     tgt_tokens = [     3, 256200,  16991, 249346, 249725,    146,  25220, 251069, 249211,
-        251148, 253935,      3]
-    score = -1.606838583946228
+        251148, 253935,      3] # "大家好 , 世界无主题。"
     gx = ggml.from_numpy(ctx, src_audio_wav * 2**15) # Apply scale before sending into ggml!
     ggml.ggml_set_name(gx, b"x")
     gy = ggml.forward(
@@ -635,11 +634,11 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
     encoder_out = gy
 
     job = ggml.SequenceGeneratorJob()
-    job.opts.beam_size = 1
+    job.opts.beam_size = 5
     job.opts.min_seq_len = 1
     job.opts.soft_max_seq_len_a = 1
     job.opts.soft_max_seq_len_b = 200
-    job.opts.hard_max_seq_len = 20
+    job.opts.hard_max_seq_len = 1000
     job.opts.len_penalty = 1.0
     job.opts.unk_penalty = 0.0
     job.prefix_seq = ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32))
@@ -648,12 +647,9 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
     job.unk_idx = 1
     job.bos_idx = 2
     job.eos_idx = 3
-
-    result = ggml.ggml_tensor()
     
-    g_score = ggml.generate_sequence(
-        g_model, job, encoder_out, None, ctypes.byref(result)
+    result_ptr = ggml.generate_sequence(
+        g_model, job, encoder_out, None, ctx
     )
-    tokens = list(ggml.to_numpy(result))
-    assert tokens == tgt_tokens
-    assert g_score == pytest.approx(score)
+    g_tokens = list(ggml.to_numpy(result_ptr[0].seq))
+    assert g_tokens == tgt_tokens