# test_unity_cpp.py — pytest suite comparing the GGML/C++ port of the
# seamlessM4T "unity" model against the reference fairseq2/PyTorch implementation.
import ctypes
import functools
import logging
import sys
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator

import numpy as np
import pytest
import torch
import torchaudio

import fairseq2.nn
import fairseq2.nn.transformer
from fairseq2.data.audio import WaveformToFbankConverter
from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor

import ggml
from ctypes_utils import Ptr
from ggml import NativeObj
from ggml_convert import convert_model
from seamless_communication.models.inference.translator import Translator, Modality
# Alias for a raw ggml context pointer; used to type the fixtures below.
Ctx = ggml.ggml_context_p

# Example unity models shipped next to this test file.
UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
# One shared 1 GiB allocation request, reused for every per-test context.
CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
# True when the C++ source does NOT contain the "# define UNITY_FLASH_ATTN 0"
# opt-out marker, i.e. flash attention is compiled in.  Several tests relax
# their tolerances / skip attention-weight checks based on this flag.
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
  29. @pytest.fixture(name="ctx")
  30. def _ctx() -> Iterator[Ctx]:
  31. """Allocate a new context with 1024 MB of memory"""
  32. try:
  33. ctx = ggml.ggml_init(params=CTX_PARAMS)
  34. with torch.inference_mode():
  35. yield ctx
  36. finally:
  37. ggml.ggml_free(ctx)
  38. @functools.lru_cache()
  39. def _load_g_model_once() -> NativeObj:
  40. model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
  41. if not model_file.exists():
  42. convert_model("seamlessM4T_medium", model_file)
  43. return ggml.load_unity_ggml_file(model_file)
@pytest.fixture()
def g_model(ctx: Ctx) -> c_void_p:
    """Return the shared GGML model pointer, bound to this test's context."""
    model = _load_g_model_once()
    # Point the cached model's inference context at the per-test ctx so
    # tensors created during forward passes live (and die) with that context.
    ggml.lib.fairseq2_model_set_inference_ctx(model.ptr, ctx)
    return model.ptr
  49. @functools.lru_cache(maxsize=1)
  50. def load_translator() -> Translator:
  51. return Translator(
  52. "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
  53. )
  54. def load_pt_model() -> Any:
  55. return load_translator().model
  56. @pytest.mark.xfail(reason="TODO")
  57. def test_hparams_code_is_up_to_date() -> None:
  58. model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
  59. hparams_header_file = model_file.with_suffix(".hparams.h")
  60. hparams_struct = hparams_header_file.read_text().strip()
  61. actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
  62. assert hparams_struct in actual_code
  63. def test_causal_attention_mask(ctx: Ctx):
  64. x = torch.zeros((1, 10, 32))
  65. generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
  66. mask_exp = generator(x).numpy()
  67. gx = ggml.from_numpy(ctx, x)
  68. gmask = ggml.causal_attention_mask(ctx, gx)
  69. mask = ggml.to_numpy(gmask)
  70. gf = ggml.ggml_build_forward(gmask)
  71. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  72. assert mask_exp.shape == (10, 10)
  73. assert mask.shape == (10, 10)
  74. assert np.all(mask == mask_exp)
  75. x = x[:, :8, :]
  76. mask_exp = generator(x).numpy()
  77. gx = ggml.from_numpy(ctx, x)
  78. gmask = ggml.causal_attention_mask(ctx, gx)
  79. mask = ggml.to_numpy(gmask)
  80. gf = ggml.ggml_build_forward(gmask)
  81. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  82. assert mask_exp.shape == (8, 8)
  83. assert mask.shape == (8, 8)
  84. assert np.all(mask == mask_exp)
  85. def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p) -> None:
  86. x = torch.empty((2, 21, 1024))
  87. torch.nn.init.uniform_(x, -1, 1)
  88. pt_model = load_pt_model()
  89. y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()
  90. gx = ggml.from_numpy(ctx, x)
  91. gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
  92. ggml.build_and_compute(ctx, gy)
  93. y = ggml.to_numpy(gy)
  94. assert np.allclose(y_exp, y, atol=1e-5)
  95. def test_Linear_forward(ctx: Ctx, g_model: c_void_p) -> None:
  96. x = torch.empty((2, 21, 1024))
  97. torch.nn.init.uniform_(x, -1, 1)
  98. pt_model = load_pt_model()
  99. y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()
  100. gx = ggml.from_numpy(ctx, x)
  101. gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
  102. ggml.build_and_compute(ctx, gy)
  103. y = ggml.to_numpy(gy)
  104. assert np.allclose(y_exp, y, atol=1e-5)
  105. def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p) -> None:
  106. x = torch.empty((2, 21, 1024)) # (bs, seq_len, model_dim)
  107. torch.nn.init.uniform_(x, -1 / 32, 1 / 32)
  108. # Test FFN without LayerNorm
  109. pt_model = load_pt_model()
  110. y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
  111. gx = ggml.from_numpy(ctx, x)
  112. gy = ggml.forward(
  113. "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
  114. )
  115. ggml.build_and_compute(ctx, gy)
  116. y = ggml.to_numpy(gy)
  117. assert np.allclose(y_exp, y, atol=1e-5)
  118. def _name(tensor: ggml.ggml_tensor_p) -> bytes:
  119. try:
  120. return tensor.contents.name # type: ignore[no-any-return]
  121. except ValueError:
  122. return b"???"
def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
    """Compare GGML MultiheadAttention against fairseq2's self_attn.

    Besides the final output, this cross-checks the intermediate "q"
    projection and — when flash attention is disabled — the attention
    weights captured from both implementations.
    """
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    pt_model = load_pt_model()
    self_attn = pt_model.text_encoder.layers[0].self_attn
    # Note: we use different lengths for queries and keys,
    # this tests the implementation in decoding context too.
    # Note2: ggml_flash_attn requires that we have more keys than queries
    gxq = ggml.from_numpy(ctx, x[:, :11, :].contiguous())
    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gx,
        gx,
        None,  # TODO: tests with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    q_exp = self_attn.q_proj(x[:, :11, :]).numpy()
    y = ggml.to_numpy(gy)
    # Snapshot every node of the computed graph by name so intermediate
    # tensors ("q", "attn_weights", ...) can be compared below.  The print
    # is debug output to help diagnose mismatches.
    nodes = {}
    for i in range(gf.n_nodes):
        name = _name(gf.nodes[i])
        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
        print(name, f"op({gf.nodes[i].contents.op})", children)
        nodes[name] = ggml.to_numpy(gf.nodes[i])
    # Capture fairseq2's attention weights via a hook during the reference run.
    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
    self_attn.register_attn_weight_hook(attn_weights_hook)
    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
    q = nodes[b"q"]
    assert q.shape == q_exp.shape
    assert np.allclose(q_exp, q, atol=1e-5)
    # with flash_attn we don't have attn_weights
    if not UNITY_FLASH_ATTN:
        attn_weights = nodes[b"attn_weights"]
        [attn_weights_exp] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML is very agressively reducing small softmax weights to 0,
        # so the error isn't that small
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
        # But the sums should be close to 1
        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 11)))
        # And the maximum index should match the original ones.
        assert np.allclose(
            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
        )
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)
  177. def test_StandardTransformerEncoderLayer_forward(
  178. ctx: Ctx, g_model: c_void_p
  179. ) -> None:
  180. x = torch.empty((2, 21, 1024))
  181. padding_mask = torch.ones((2, 21))
  182. torch.random.manual_seed(0)
  183. torch.nn.init.uniform_(x, -1, 1)
  184. pt_model = load_pt_model()
  185. layer = pt_model.text_encoder.layers[0]
  186. gx = ggml.from_numpy(ctx, x)
  187. ggml.ggml_set_name(gx, b"x")
  188. gpad = ggml.from_numpy(ctx, padding_mask)
  189. ggml.ggml_set_name(gpad, b"padding_mask")
  190. gy = ggml.forward(
  191. "StandardTransformerEncoderLayer",
  192. g_model,
  193. "text_encoder.layers.0",
  194. gx,
  195. None, # TODO support padding mask
  196. )
  197. gf = ggml.ggml_build_forward(gy)
  198. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  199. y = ggml.to_numpy(gy)
  200. y_exp, _ = layer(x, padding_mask)
  201. y_exp = y_exp.numpy()
  202. assert y.shape == y_exp.shape
  203. assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)
  204. def test_StandardConformerEncoderLayer_forward(
  205. ctx: Ctx, g_model: c_void_p
  206. ) -> None:
  207. pt_model = load_pt_model()
  208. x = torch.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_conformer_block.pt")
  209. padding_mask = torch.ones((1, x.shape[1]))
  210. layer = pt_model.speech_encoder.inner.layers[0]
  211. gx = ggml.from_numpy(ctx, x[0])
  212. ggml.ggml_set_name(gx, b"x")
  213. gpad = ggml.from_numpy(ctx, padding_mask[0])
  214. ggml.ggml_set_name(gpad, b"padding_mask")
  215. gy = ggml.forward(
  216. "StandardConformerEncoderLayer",
  217. g_model,
  218. "speech_encoder.inner.layers.0",
  219. gx,
  220. None, # TODO support padding mask
  221. )
  222. gf = ggml.ggml_build_forward(gy)
  223. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  224. y = ggml.to_numpy(gy)
  225. y_exp, _ = layer(x, padding_mask)
  226. y_exp = y_exp.numpy()
  227. assert y.shape == y_exp.shape
  228. assert np.allclose(y_exp, y, atol=2e-3)
  229. def test_StandardConformerEncoderAdaptorLayer_forward(
  230. ctx: Ctx, g_model: c_void_p
  231. ) -> None:
  232. pt_model = load_pt_model()
  233. x = torch.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_adaptor.pt")
  234. layer = pt_model.speech_encoder.adaptor_layers[0]
  235. gx = ggml.from_numpy(ctx, x[0])
  236. ggml.ggml_set_name(gx, b"x")
  237. gy = ggml.forward(
  238. "StandardConformerEncoderAdaptorLayer",
  239. g_model,
  240. "speech_encoder.adaptor_layers.0",
  241. gx,
  242. None, # TODO support padding mask
  243. )
  244. gf = ggml.ggml_build_forward(gy)
  245. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  246. y = ggml.to_numpy(gy)
  247. y_exp, _ = layer(x, None)
  248. y_exp = y_exp.numpy()
  249. assert y.shape == y_exp.shape
  250. assert np.allclose(y_exp, y, atol=2e-3)
  251. def test_StandardTransformerEncoder_forward(
  252. ctx: Ctx, g_model: c_void_p
  253. ) -> None:
  254. x = torch.empty((2, 21, 1024))
  255. padding_mask = torch.ones((2, 21))
  256. torch.random.manual_seed(0)
  257. torch.nn.init.uniform_(x, -1, 1)
  258. gx = ggml.from_numpy(ctx, x)
  259. ggml.ggml_set_name(gx, b"x")
  260. gpad = ggml.from_numpy(ctx, padding_mask)
  261. ggml.ggml_set_name(gpad, b"padding_mask")
  262. gy = ggml.forward(
  263. "StandardTransformerEncoder",
  264. g_model,
  265. "text_encoder",
  266. gx,
  267. None, # TODO support padding mask
  268. )
  269. gf = ggml.ggml_build_forward(gy)
  270. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  271. y = ggml.to_numpy(gy)
  272. pt_model = load_pt_model()
  273. y_exp, _ = pt_model.text_encoder(x, padding_mask)
  274. y_exp = y_exp.numpy()
  275. assert y.shape == y_exp.shape
  276. assert np.allclose(y_exp, y, atol=1e-4)
  277. def test_StandardConformerEncoder_forward(
  278. ctx: Ctx, g_model: c_void_p
  279. ) -> None:
  280. pt_model = load_pt_model()
  281. wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
  282. gx = ggml.from_numpy(ctx, wav * 2**15) # Apply scale before sending into ggml!
  283. ggml.ggml_set_name(gx, b"x")
  284. gy = ggml.forward(
  285. "StandardConformerEncoder",
  286. g_model,
  287. "speech_encoder",
  288. gx,
  289. None, # TODO support padding mask
  290. )
  291. gf = ggml.ggml_build_forward(gy)
  292. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  293. converter = WaveformToFbankConverter(
  294. num_mel_bins=80,
  295. waveform_scale=2**15,
  296. channel_last=True,
  297. standardize=True,
  298. )
  299. converter_input = {
  300. "waveform": wav.transpose(0, 1),
  301. "sample_rate": 16000.,
  302. "format": -1,
  303. }
  304. y = ggml.to_numpy(gy)
  305. speech_encoder_input = pt_model.speech_encoder_frontend(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
  306. y_exp, _ = pt_model.speech_encoder(speech_encoder_input, None)
  307. y_exp = y_exp.numpy() # remove batch dimension
  308. assert y.shape == y_exp.shape
  309. assert np.allclose(y_exp, y, atol=1e-2) # There are 10 elements in a 137*1024 tensor with error >1e-2
  310. def test_WaveformToFbank_forward(
  311. ctx: Ctx, g_model: c_void_p
  312. ) -> None:
  313. pt_model = load_pt_model()
  314. converter = WaveformToFbankConverter(
  315. num_mel_bins=80,
  316. waveform_scale=2**15,
  317. channel_last=True,
  318. standardize=True,
  319. )
  320. extractor = Wav2Vec2FbankFeatureExtractor(80, 2, 1)
  321. wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
  322. gx = ggml.from_numpy(ctx, wav * 2**15) # Apply scale before sending into ggml!
  323. ggml.ggml_set_name(gx, b"x")
  324. gy = ggml.forward(
  325. "WaveformToFbank",
  326. g_model,
  327. "",
  328. gx
  329. )
  330. gf = ggml.ggml_build_forward(gy)
  331. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  332. y = ggml.to_numpy(gy)
  333. converter_input = {
  334. "waveform": wav.transpose(0, 1),
  335. "sample_rate": 16000.,
  336. "format": -1,
  337. }
  338. y_exp = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
  339. y_exp = y_exp.numpy()
  340. assert y.shape == y_exp.shape
  341. assert np.allclose(y_exp, y, atol=4e-3) # reduce? error is from standardization
  342. def test_causal_attention_mask(ctx: Ctx):
  343. x = torch.zeros((5, 10))
  344. generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
  345. mask_exp = generator(x)
  346. gx = ggml.from_numpy(ctx, x)
  347. gmask = ggml.causal_attention_mask(ctx, gx)
  348. mask = ggml.to_numpy(gmask)
  349. assert mask_exp.shape == (10, 10)
  350. assert mask.shape == (10, 10)
  351. assert np.allclose(mask, mask_exp)
  352. def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
  353. seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
  354. # this _legacy_pad_idx is suspicious. Shouldn't the model use 1 ? But
  355. # this is consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset
  356. pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
  357. y_exp = pos_encoder(seq, None)[0].numpy()
  358. gseq = ggml.from_numpy(ctx, seq[0].numpy())
  359. ggml.ggml_set_name(gseq, b"seq")
  360. gy = ggml.forward(
  361. "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
  362. )
  363. gf = ggml.ggml_build_forward(gy)
  364. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  365. y = ggml.to_numpy(gy)
  366. assert y.shape == y_exp.shape
  367. assert np.allclose(y_exp, y, atol=1e-6)
  368. def test_TransformerEmbeddingFrontend_forward(
  369. ctx: Ctx, g_model: c_void_p
  370. ) -> None:
  371. seq = torch.arange(2 * 20).reshape(2, 20)
  372. seq[1, 15:] = 0 # padding for second sentence
  373. seq_len = torch.tensor([20, 15])
  374. gseq = ggml.from_numpy(ctx, seq.numpy().astype(np.int32))
  375. ggml.ggml_set_name(gseq, b"seq")
  376. gy = ggml.forward(
  377. "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
  378. )
  379. ggml.build_and_compute(ctx, gy)
  380. y = ggml.to_numpy(gy)
  381. pt_model = load_pt_model()
  382. y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
  383. y_exp = y_exp.numpy()
  384. assert y.shape == y_exp.shape
  385. assert np.allclose(y_exp, y, atol=1e-6)
  386. def test_StandardTransformerDecoder_forward(
  387. ctx: Ctx, g_model: c_void_p
  388. ) -> None:
  389. x = torch.empty((2, 13, 1024))
  390. encoder_out = torch.empty((2, 21, 1024))
  391. padding_mask = torch.ones((2, 13))
  392. torch.random.manual_seed(0)
  393. torch.nn.init.uniform_(x, -1, 1)
  394. torch.nn.init.uniform_(encoder_out, -1, 1)
  395. gx = ggml.from_numpy(ctx, x)
  396. ggml.ggml_set_name(gx, b"x")
  397. gpad = ggml.from_numpy(ctx, padding_mask)
  398. ggml.ggml_set_name(gpad, b"padding_mask")
  399. genc = ggml.from_numpy(ctx, encoder_out)
  400. gy = ggml.forward(
  401. "StandardTransformerDecoder",
  402. g_model,
  403. "text_decoder",
  404. gx,
  405. None, # TODO support padding mask,
  406. genc,
  407. None,
  408. )
  409. ggml.build_and_compute(ctx, gy)
  410. y = ggml.to_numpy(gy)
  411. pt_model = load_pt_model()
  412. y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
  413. y_exp = y_exp.numpy()
  414. assert y.shape == y_exp.shape
  415. assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-3)
def test_t2tt(ctx: Ctx, g_model: c_void_p):
    """End-to-end text-to-text: GGML beam search vs cached fairseq2 hypotheses.

    On the first run the reference hypotheses are generated with the PyTorch
    translator and cached in sample_input.npz (encoder output included);
    later runs only need the cache file.
    """
    src_lang = "eng"
    src_text = "We are all in a yellow submarine."
    tgt_lang = "fra"
    sample_file = Path(__file__).parent / "sample_input.npz"
    beam_size = 2
    if not sample_file.exists():
        # Produce and cache the reference translation + per-beam hypotheses.
        translator = load_translator()
        device = translator.device
        token_encoder = translator.text_tokenizer.create_encoder(
            task="translation", lang=src_lang, mode="source", device=device
        )
        src = translator.collate(token_encoder(src_text))
        text_out, _ = translator.get_prediction(
            translator.model,
            translator.text_tokenizer,
            translator.unit_tokenizer,
            src,
            input_modality=Modality.TEXT,
            output_modality=Modality.TEXT,
            tgt_lang=tgt_lang,
            beam_size=beam_size,
        )
        tgt_text = str(text_out.sentences[0])
        assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
        hypotheses = [
            {
                "seq": h.seq.tolist(),
                "score": h.score.item(),
                "step_scores": h.step_scores.numpy(),
            }
            for h in text_out.generator_output.results[0]
        ]
        np.savez(
            sample_file,
            encoder_output=text_out.encoder_output.numpy(),
            encoder_padding_mask=text_out.encoder_padding_mask.numpy(),
            hypotheses=hypotheses,
        )
    # allow_pickle to load the hyp dicts
    text_out = np.load(sample_file, allow_pickle=True)
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])
    # Seed decoding with the reference's first two tokens so beams line up.
    prefix_seq = np.array(text_out["hypotheses"][0]["seq"][:2]).astype(np.int32)
    max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])
    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = beam_size
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = int(max_seq_len * 1.5)
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.opts.normalize_scores = True
    job.prefix_seq = ggml.from_numpy(ctx, prefix_seq)
    job.pad_idx = 0
    job.unk_idx = 1
    job.bos_idx = 2
    job.eos_idx = 3
    result_ptr = ggml.generate_sequence(
        g_model, job, encoder_out, encoder_padding_mask, ctx
    )
    # NOTE(review): if `seq` is a ctypes pointer type (rather than c_void_p,
    # whose NULL converts to None), `!= None` is always True and this filter
    # is a no-op — confirm the field type in the ggml bindings.
    results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
    assert len(results) == len(text_out["hypotheses"])
    for g_hyp, exp in zip(results, text_out["hypotheses"]):
        g_tokens = list(ggml.to_numpy(g_hyp.seq))
        g_step_scores = ggml.to_numpy(g_hyp.step_scores)
        assert g_tokens == exp["seq"]
        assert g_hyp.score == pytest.approx(exp["score"], rel=1e-2)
        # The score error is big, this may negatively impact the beam search.
        assert np.allclose(g_step_scores, exp["step_scores"], atol=0.1)
def test_s2tt(ctx: Ctx, g_model: c_void_p):
    """End-to-end speech-to-text: GGML encoder + generator vs frozen reference.

    The commented block below shows how the hard-coded expected tokens and
    score were produced with the PyTorch translator; they are inlined so the
    test runs without the full translator stack.
    """
    # NOTE(review): absolute dev path — the other tests use paths relative to
    # __file__; this one should too once the sample wav ships with the repo.
    src_audio_wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
    # translator = load_translator()
    # token_encoder = translator.text_tokenizer.create_encoder(
    # task="translation"
    # )
    # decoded_audio = {
    # "waveform": src_audio_wav.t(),
    # "sample_rate": 16000.,
    # "format": -1,
    # }
    # src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
    # text_out, _ = translator.get_prediction(
    # translator.model,
    # translator.text_tokenizer,
    # translator.unit_tokenizer,
    # src,
    # input_modality=Modality.SPEECH,
    # output_modality=Modality.TEXT,
    # tgt_lang="cmn",
    # )
    # tgt_text = str(text_out.sentences[0])
    # assert tgt_text == "大家好 , 世界无主题。"
    # tgt_tokens = text_out.generator_output.results[0][0].seq
    # score = text_out.generator_output.results[0][0].score.item()
    # Frozen reference output (produced by the commented code above).
    tgt_tokens = [ 3, 256200, 16991, 249346, 249725, 146, 25220, 251069, 249211,
    251148, 253935, 3]
    score = -1.606838583946228
    gx = ggml.from_numpy(ctx, src_audio_wav * 2**15) # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        None, # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    encoder_out = gy
    # Greedy-ish decode (beam_size=1) seeded with [eos, lang] prefix tokens.
    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = 1
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = 20
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.prefix_seq = ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32))
    job.opts.normalize_scores = True
    job.pad_idx = 0
    job.unk_idx = 1
    job.bos_idx = 2
    job.eos_idx = 3
    # NOTE(review): this call signature (byref(result), score returned)
    # differs from the one in test_t2tt (ctx passed, result array returned) —
    # confirm which generate_sequence binding is current.
    result = ggml.ggml_tensor()
    g_score = ggml.generate_sequence(
        g_model, job, encoder_out, None, ctypes.byref(result)
    )
    tokens = list(ggml.to_numpy(result))
    assert tokens == tgt_tokens
    assert g_score == pytest.approx(score)