|
@@ -24,7 +24,7 @@ from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtra
|
|
|
Ctx = ggml.ggml_context_p
|
|
|
|
|
|
UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
|
|
|
-CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)
|
|
|
+CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024 * 5, mem_buffer=None)
|
|
|
|
|
|
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
|
|
|
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
|
|
@@ -618,8 +618,7 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
# score = text_out.generator_output.results[0][0].score.item()
|
|
|
|
|
|
tgt_tokens = [ 3, 256200, 16991, 249346, 249725, 146, 25220, 251069, 249211,
|
|
|
- 251148, 253935, 3]
|
|
|
- score = -1.606838583946228
|
|
|
+ 251148, 253935, 3] # "大家好 , 世界无主题。"
|
|
|
gx = ggml.from_numpy(ctx, src_audio_wav * 2**15) # Apply scale before sending into ggml!
|
|
|
ggml.ggml_set_name(gx, b"x")
|
|
|
gy = ggml.forward(
|
|
@@ -635,11 +634,11 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
encoder_out = gy
|
|
|
|
|
|
job = ggml.SequenceGeneratorJob()
|
|
|
- job.opts.beam_size = 1
|
|
|
+ job.opts.beam_size = 5
|
|
|
job.opts.min_seq_len = 1
|
|
|
job.opts.soft_max_seq_len_a = 1
|
|
|
job.opts.soft_max_seq_len_b = 200
|
|
|
- job.opts.hard_max_seq_len = 20
|
|
|
+ job.opts.hard_max_seq_len = 1000
|
|
|
job.opts.len_penalty = 1.0
|
|
|
job.opts.unk_penalty = 0.0
|
|
|
job.prefix_seq = ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32))
|
|
@@ -648,12 +647,9 @@ def test_s2tt(ctx: Ctx, g_model: c_void_p):
|
|
|
job.unk_idx = 1
|
|
|
job.bos_idx = 2
|
|
|
job.eos_idx = 3
|
|
|
-
|
|
|
- result = ggml.ggml_tensor()
|
|
|
|
|
|
- g_score = ggml.generate_sequence(
|
|
|
- g_model, job, encoder_out, None, ctypes.byref(result)
|
|
|
+ result_ptr = ggml.generate_sequence(
|
|
|
+ g_model, job, encoder_out, None, ctx
|
|
|
)
|
|
|
- tokens = list(ggml.to_numpy(result))
|
|
|
- assert tokens == tgt_tokens
|
|
|
- assert g_score == pytest.approx(score)
|
|
|
+ g_tokens = list(ggml.to_numpy(result_ptr[0].seq))
|
|
|
+ assert g_tokens == tgt_tokens
|