# test_unity_cpp.py (21 KB)
  1. import ggml
  2. import ctypes
  3. import torch
  4. import pytest
  5. import numpy as np
  6. import torch
  7. import fairseq2.nn
  8. import fairseq2.nn.transformer
  9. import logging
  10. import sys
  11. import functools
  12. from pathlib import Path
  13. from ctypes_utils import Ptr
  14. from ctypes import c_void_p
  15. from typing import Any
  16. from pathlib import Path
  17. from typing import Iterator
  18. from ggml import NativeObj
  19. from ggml_convert import convert_model, read_layer_config
  20. from seamless_communication.models.inference.translator import Translator, Modality
  21. from fairseq2.data.audio import WaveformToFbankConverter
  22. import torchaudio
  23. from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
  24. Ctx = ggml.ggml_context_p
  25. UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
  26. CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024 * 5, mem_buffer=None)
  27. FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
  28. UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
  29. @pytest.fixture(name="ctx")
  30. def _ctx() -> Iterator[Ctx]:
  31. """Allocate a new context with 1024 MB of memory"""
  32. try:
  33. ctx = ggml.ggml_init(params=CTX_PARAMS)
  34. with torch.inference_mode():
  35. yield ctx
  36. finally:
  37. ggml.ggml_free(ctx)
  38. @functools.lru_cache()
  39. def _load_g_model_once() -> NativeObj:
  40. model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
  41. if not model_file.exists():
  42. convert_model("seamlessM4T_medium", model_file)
  43. return ggml.load_fairseq2_ggml_file(model_file)
  44. @pytest.fixture()
  45. def g_model(ctx: Ctx) -> c_void_p:
  46. model = _load_g_model_once()
  47. ggml.lib.fairseq2_model_set_inference_ctx(model.ptr, ctx)
  48. return model.ptr
  49. @functools.lru_cache(maxsize=1)
  50. def load_translator() -> Translator:
  51. return Translator(
  52. "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
  53. )
  54. def load_pt_model() -> Any:
  55. return load_translator().model
  56. def test_convert_linear(tmp_path: Path) -> None:
  57. module = fairseq2.nn.Linear(16, 24, True)
  58. layer_config = read_layer_config(module)
  59. assert layer_config == {"input_dim": 16, "output_dim": 24, "skip_init": False}
  60. module_file = Path("module.ggml")
  61. convert_model(module, module_file)
  62. g_module = ggml.load_fairseq2_ggml_file(module_file)
  63. for k, v in layer_config.items():
  64. assert ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
  65. def test_causal_attention_mask(ctx: Ctx):
  66. x = torch.zeros((1, 10, 32))
  67. generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
  68. mask_exp = generator(x).numpy()
  69. gx = ggml.from_numpy(ctx, x)
  70. gmask = ggml.causal_attention_mask(ctx, gx)
  71. mask = ggml.to_numpy(gmask)
  72. gf = ggml.ggml_build_forward(gmask)
  73. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  74. assert mask_exp.shape == (10, 10)
  75. assert mask.shape == (10, 10)
  76. assert np.all(mask == mask_exp)
  77. x = x[:, :8, :]
  78. mask_exp = generator(x).numpy()
  79. gx = ggml.from_numpy(ctx, x)
  80. gmask = ggml.causal_attention_mask(ctx, gx)
  81. mask = ggml.to_numpy(gmask)
  82. gf = ggml.ggml_build_forward(gmask)
  83. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  84. assert mask_exp.shape == (8, 8)
  85. assert mask.shape == (8, 8)
  86. assert np.all(mask == mask_exp)
  87. def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p) -> None:
  88. x = torch.empty((2, 21, 1024))
  89. torch.nn.init.uniform_(x, -1, 1)
  90. pt_model = load_pt_model()
  91. y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()
  92. gx = ggml.from_numpy(ctx, x)
  93. gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
  94. ggml.build_and_compute(ctx, gy)
  95. y = ggml.to_numpy(gy)
  96. assert np.allclose(y_exp, y, atol=1e-5)
  97. def test_Linear_forward(ctx: Ctx, g_model: c_void_p) -> None:
  98. x = torch.empty((2, 21, 1024))
  99. torch.nn.init.uniform_(x, -1, 1)
  100. pt_model = load_pt_model()
  101. y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()
  102. gx = ggml.from_numpy(ctx, x)
  103. gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
  104. ggml.build_and_compute(ctx, gy)
  105. y = ggml.to_numpy(gy)
  106. assert np.allclose(y_exp, y, atol=1e-5)
  107. def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p) -> None:
  108. x = torch.empty((2, 21, 1024)) # (bs, seq_len, model_dim)
  109. torch.nn.init.uniform_(x, -1 / 32, 1 / 32)
  110. # Test FFN without LayerNorm
  111. pt_model = load_pt_model()
  112. y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
  113. gx = ggml.from_numpy(ctx, x)
  114. gy = ggml.forward(
  115. "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
  116. )
  117. ggml.build_and_compute(ctx, gy)
  118. y = ggml.to_numpy(gy)
  119. assert np.allclose(y_exp, y, atol=1e-5)
  120. def _name(tensor: ggml.ggml_tensor_p) -> bytes:
  121. try:
  122. return tensor.contents.name # type: ignore[no-any-return]
  123. except ValueError:
  124. return b"???"
  125. @pytest.mark.parametrize("flash_attn", [False, True])
  126. def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: bool) -> None:
  127. x = torch.empty((2, 21, 1024))
  128. torch.random.manual_seed(0)
  129. torch.nn.init.uniform_(x, -1, 1)
  130. pt_model = load_pt_model()
  131. self_attn = pt_model.text_encoder.layers[0].self_attn
  132. # Note: we use different lengths for queries and keys,
  133. # this tests the implementation in decoding context too.
  134. # Note2: ggml_flash_attn requires that we have more keys than queries
  135. if flash_attn:
  136. xq = x[:, :11, :]
  137. xk = x
  138. else:
  139. xq = x
  140. xk = x[:, :13, :]
  141. gxq = ggml.from_numpy(ctx, xq.contiguous())
  142. gxk = ggml.from_numpy(ctx, xk)
  143. ggml.ggml_set_name(gxk, b"xk")
  144. gy = ggml.forward(
  145. "MultiheadAttention",
  146. g_model,
  147. "text_encoder.layers.0.self_attn",
  148. gxq,
  149. gxk,
  150. gxk,
  151. None, # TODO: tests with causal attention masks
  152. )
  153. gf = ggml.ggml_build_forward(gy)
  154. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  155. q_exp = self_attn.q_proj(xq).numpy()
  156. y = ggml.to_numpy(gy)
  157. nodes = {}
  158. for i in range(gf.n_nodes):
  159. name = _name(gf.nodes[i])
  160. children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
  161. print(name, f"op({gf.nodes[i].contents.op})", children)
  162. nodes[name] = ggml.to_numpy(gf.nodes[i])
  163. attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
  164. self_attn.register_attn_weight_hook(attn_weights_hook)
  165. y_exp = self_attn(xq, None, xk, xk).numpy()
  166. q = nodes[b"q"]
  167. assert q.shape == q_exp.shape
  168. assert np.allclose(q_exp, q, atol=1e-5)
  169. # with flash_attn we don't have attn_weights
  170. if not flash_attn:
  171. attn_weights = nodes[b"attn_weights"]
  172. [attn_weights_exp] = attn_weights_hook._storage
  173. attn_weights_exp = attn_weights_exp.numpy()
  174. assert attn_weights_exp.shape == attn_weights.shape
  175. # GGML is very agressively reducing small softmax weights to 0,
  176. # so the error isn't that small
  177. assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
  178. # But the sums should be close to 1
  179. assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 1)))
  180. # And the maximum index should match the original ones.
  181. assert np.allclose(np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
  182. )
  183. assert y.shape == y_exp.shape
  184. assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)
  185. def test_StandardTransformerEncoderLayer_forward(
  186. ctx: Ctx, g_model: c_void_p
  187. ) -> None:
  188. x = torch.empty((2, 21, 1024))
  189. padding_mask = torch.ones((2, 21))
  190. torch.random.manual_seed(0)
  191. torch.nn.init.uniform_(x, -1, 1)
  192. pt_model = load_pt_model()
  193. layer = pt_model.text_encoder.layers[0]
  194. gx = ggml.from_numpy(ctx, x)
  195. ggml.ggml_set_name(gx, b"x")
  196. gpad = ggml.from_numpy(ctx, padding_mask)
  197. ggml.ggml_set_name(gpad, b"padding_mask")
  198. gy = ggml.forward(
  199. "StandardTransformerEncoderLayer",
  200. g_model,
  201. "text_encoder.layers.0",
  202. gx,
  203. None, # TODO support padding mask
  204. )
  205. gf = ggml.ggml_build_forward(gy)
  206. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  207. y = ggml.to_numpy(gy)
  208. y_exp, _ = layer(x, padding_mask)
  209. y_exp = y_exp.numpy()
  210. assert y.shape == y_exp.shape
  211. assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)
  212. def test_StandardConformerEncoderLayer_forward(
  213. ctx: Ctx, g_model: c_void_p
  214. ) -> None:
  215. pt_model = load_pt_model()
  216. x = torch.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_conformer_block.pt")
  217. padding_mask = torch.ones((1, x.shape[1]))
  218. layer = pt_model.speech_encoder.inner.layers[0]
  219. gx = ggml.from_numpy(ctx, x[0])
  220. ggml.ggml_set_name(gx, b"x")
  221. gpad = ggml.from_numpy(ctx, padding_mask[0])
  222. ggml.ggml_set_name(gpad, b"padding_mask")
  223. gy = ggml.forward(
  224. "StandardConformerEncoderLayer",
  225. g_model,
  226. "speech_encoder.inner.layers.0",
  227. gx,
  228. None, # TODO support padding mask
  229. )
  230. gf = ggml.ggml_build_forward(gy)
  231. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  232. y = ggml.to_numpy(gy)
  233. y_exp, _ = layer(x, padding_mask)
  234. y_exp = y_exp.numpy()
  235. assert y.shape == y_exp.shape
  236. assert np.allclose(y_exp, y, atol=2e-3)
  237. def test_StandardConformerEncoderAdaptorLayer_forward(
  238. ctx: Ctx, g_model: c_void_p
  239. ) -> None:
  240. pt_model = load_pt_model()
  241. x = torch.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_adaptor.pt")
  242. layer = pt_model.speech_encoder.adaptor_layers[0]
  243. gx = ggml.from_numpy(ctx, x[0])
  244. ggml.ggml_set_name(gx, b"x")
  245. gy = ggml.forward(
  246. "StandardConformerEncoderAdaptorLayer",
  247. g_model,
  248. "speech_encoder.adaptor_layers.0",
  249. gx,
  250. None, # TODO support padding mask
  251. )
  252. gf = ggml.ggml_build_forward(gy)
  253. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  254. y = ggml.to_numpy(gy)
  255. y_exp, _ = layer(x, None)
  256. y_exp = y_exp.numpy()
  257. assert y.shape == y_exp.shape
  258. assert np.allclose(y_exp, y, atol=2e-3)
  259. def test_StandardTransformerEncoder_forward(
  260. ctx: Ctx, g_model: c_void_p
  261. ) -> None:
  262. x = torch.empty((2, 21, 1024))
  263. padding_mask = torch.ones((2, 21))
  264. torch.random.manual_seed(0)
  265. torch.nn.init.uniform_(x, -1, 1)
  266. gx = ggml.from_numpy(ctx, x)
  267. ggml.ggml_set_name(gx, b"x")
  268. gpad = ggml.from_numpy(ctx, padding_mask)
  269. ggml.ggml_set_name(gpad, b"padding_mask")
  270. gy = ggml.forward(
  271. "StandardTransformerEncoder",
  272. g_model,
  273. "text_encoder",
  274. gx,
  275. None, # TODO support padding mask
  276. )
  277. gf = ggml.ggml_build_forward(gy)
  278. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  279. y = ggml.to_numpy(gy)
  280. pt_model = load_pt_model()
  281. y_exp, _ = pt_model.text_encoder(x, padding_mask)
  282. y_exp = y_exp.numpy()
  283. assert y.shape == y_exp.shape
  284. assert np.allclose(y_exp, y, atol=1e-4)
  285. def test_StandardConformerEncoder_forward(
  286. ctx: Ctx, g_model: c_void_p
  287. ) -> None:
  288. pt_model = load_pt_model()
  289. wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
  290. gx = ggml.from_numpy(ctx, wav * 2**15) # Apply scale before sending into ggml!
  291. ggml.ggml_set_name(gx, b"x")
  292. gy = ggml.forward(
  293. "StandardConformerEncoder",
  294. g_model,
  295. "speech_encoder",
  296. gx,
  297. None, # TODO support padding mask
  298. )
  299. gf = ggml.ggml_build_forward(gy)
  300. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  301. converter = WaveformToFbankConverter(
  302. num_mel_bins=80,
  303. waveform_scale=2**15,
  304. channel_last=True,
  305. standardize=True,
  306. )
  307. converter_input = {
  308. "waveform": wav.transpose(0, 1),
  309. "sample_rate": 16000.,
  310. "format": -1,
  311. }
  312. y = ggml.to_numpy(gy)
  313. speech_encoder_input = pt_model.speech_encoder_frontend(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
  314. y_exp, _ = pt_model.speech_encoder(speech_encoder_input, None)
  315. y_exp = y_exp.numpy() # remove batch dimension
  316. assert y.shape == y_exp.shape
  317. assert np.allclose(y_exp, y, atol=1e-2) # There are 10 elements in a 137*1024 tensor with error >1e-2
  318. def test_WaveformToFbank_forward(
  319. ctx: Ctx, g_model: c_void_p
  320. ) -> None:
  321. pt_model = load_pt_model()
  322. converter = WaveformToFbankConverter(
  323. num_mel_bins=80,
  324. waveform_scale=2**15,
  325. channel_last=True,
  326. standardize=True,
  327. )
  328. extractor = Wav2Vec2FbankFeatureExtractor(80, 2, 1)
  329. wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
  330. gx = ggml.from_numpy(ctx, wav * 2**15) # Apply scale before sending into ggml!
  331. ggml.ggml_set_name(gx, b"x")
  332. gy = ggml.forward(
  333. "WaveformToFbank",
  334. g_model,
  335. "",
  336. gx
  337. )
  338. gf = ggml.ggml_build_forward(gy)
  339. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  340. y = ggml.to_numpy(gy)
  341. converter_input = {
  342. "waveform": wav.transpose(0, 1),
  343. "sample_rate": 16000.,
  344. "format": -1,
  345. }
  346. y_exp = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
  347. y_exp = y_exp.numpy()
  348. assert y.shape == y_exp.shape
  349. assert np.allclose(y_exp, y, atol=4e-3) # reduce? error is from standardization
  350. def test_causal_attention_mask(ctx: Ctx):
  351. x = torch.zeros((5, 10))
  352. generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
  353. mask_exp = generator(x)
  354. gx = ggml.from_numpy(ctx, x)
  355. gmask = ggml.causal_attention_mask(ctx, gx)
  356. mask = ggml.to_numpy(gmask)
  357. assert mask_exp.shape == (10, 10)
  358. assert mask.shape == (10, 10)
  359. assert np.allclose(mask, mask_exp)
  360. def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
  361. seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
  362. # this _legacy_pad_idx is suspicious. Shouldn't the model use 1 ? But
  363. # this is consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset
  364. pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
  365. y_exp = pos_encoder(seq, None)[0].numpy()
  366. gseq = ggml.from_numpy(ctx, seq[0].numpy())
  367. ggml.ggml_set_name(gseq, b"seq")
  368. gy = ggml.forward(
  369. "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
  370. )
  371. gf = ggml.ggml_build_forward(gy)
  372. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  373. y = ggml.to_numpy(gy)
  374. assert y.shape == y_exp.shape
  375. assert np.allclose(y_exp, y, atol=1e-6)
  376. def test_TransformerEmbeddingFrontend_forward(
  377. ctx: Ctx, g_model: c_void_p
  378. ) -> None:
  379. seq = torch.arange(2 * 20).reshape(2, 20)
  380. seq[1, 15:] = 0 # padding for second sentence
  381. seq_len = torch.tensor([20, 15])
  382. gseq = ggml.from_numpy(ctx, seq.numpy().astype(np.int32))
  383. ggml.ggml_set_name(gseq, b"seq")
  384. gy = ggml.forward(
  385. "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
  386. )
  387. ggml.build_and_compute(ctx, gy)
  388. y = ggml.to_numpy(gy)
  389. pt_model = load_pt_model()
  390. y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
  391. y_exp = y_exp.numpy()
  392. assert y.shape == y_exp.shape
  393. assert np.allclose(y_exp, y, atol=1e-6)
  394. def test_StandardTransformerDecoder_forward(
  395. ctx: Ctx, g_model: c_void_p
  396. ) -> None:
  397. x = torch.empty((2, 13, 1024))
  398. encoder_out = torch.empty((2, 21, 1024))
  399. padding_mask = torch.ones((2, 13))
  400. torch.random.manual_seed(0)
  401. torch.nn.init.uniform_(x, -1, 1)
  402. torch.nn.init.uniform_(encoder_out, -1, 1)
  403. gx = ggml.from_numpy(ctx, x)
  404. ggml.ggml_set_name(gx, b"x")
  405. gpad = ggml.from_numpy(ctx, padding_mask)
  406. ggml.ggml_set_name(gpad, b"padding_mask")
  407. genc = ggml.from_numpy(ctx, encoder_out)
  408. gy = ggml.forward(
  409. "StandardTransformerDecoder",
  410. g_model,
  411. "text_decoder",
  412. gx,
  413. None, # TODO support padding mask,
  414. genc,
  415. None,
  416. )
  417. ggml.build_and_compute(ctx, gy)
  418. y = ggml.to_numpy(gy)
  419. pt_model = load_pt_model()
  420. y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
  421. y_exp = y_exp.numpy()
  422. assert y.shape == y_exp.shape
  423. assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-3)
  424. def test_t2tt(ctx: Ctx, g_model: c_void_p):
  425. src_lang = "eng"
  426. src_text = "We are all in a yellow submarine."
  427. tgt_lang = "fra"
  428. sample_file = Path(__file__).parent / "sample_input.npz"
  429. beam_size = 2
  430. if not sample_file.exists():
  431. translator = load_translator()
  432. device = translator.device
  433. token_encoder = translator.text_tokenizer.create_encoder(
  434. task="translation", lang=src_lang, mode="source", device=device
  435. )
  436. src = translator.collate(token_encoder(src_text))
  437. text_out, _ = translator.get_prediction(
  438. translator.model,
  439. translator.text_tokenizer,
  440. translator.unit_tokenizer,
  441. src,
  442. input_modality=Modality.TEXT,
  443. output_modality=Modality.TEXT,
  444. tgt_lang=tgt_lang,
  445. beam_size=beam_size,
  446. )
  447. tgt_text = str(text_out.sentences[0])
  448. assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
  449. hypotheses = [
  450. {
  451. "seq": h.seq.tolist(),
  452. "score": h.score.item(),
  453. "step_scores": h.step_scores.numpy(),
  454. }
  455. for h in text_out.generator_output.results[0]
  456. ]
  457. np.savez(
  458. sample_file,
  459. encoder_output=text_out.encoder_output.numpy(),
  460. encoder_padding_mask=text_out.encoder_padding_mask.numpy(),
  461. hypotheses=hypotheses,
  462. )
  463. # allow_pickle to load the hyp dicts
  464. text_out = np.load(sample_file, allow_pickle=True)
  465. encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
  466. encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])
  467. prefix_seq = np.array(text_out["hypotheses"][0]["seq"][:2]).astype(np.int32)
  468. max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])
  469. job = ggml.SequenceGeneratorJob()
  470. job.opts.beam_size = beam_size
  471. job.opts.min_seq_len = 1
  472. job.opts.soft_max_seq_len_a = 1
  473. job.opts.soft_max_seq_len_b = 200
  474. job.opts.hard_max_seq_len = int(max_seq_len * 1.5)
  475. job.opts.len_penalty = 1.0
  476. job.opts.unk_penalty = 0.0
  477. job.opts.normalize_scores = True
  478. job.prefix_seq = ggml.from_numpy(ctx, prefix_seq)
  479. job.pad_idx = 0
  480. job.unk_idx = 1
  481. job.bos_idx = 2
  482. job.eos_idx = 3
  483. result_ptr = ggml.generate_sequence(
  484. g_model, job, encoder_out, encoder_padding_mask, ctx
  485. )
  486. results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
  487. assert len(results) == len(text_out["hypotheses"])
  488. for g_hyp, exp in zip(results, text_out["hypotheses"]):
  489. g_tokens = list(ggml.to_numpy(g_hyp.seq))
  490. g_step_scores = ggml.to_numpy(g_hyp.step_scores)
  491. assert g_tokens == exp["seq"]
  492. assert g_hyp.score == pytest.approx(exp["score"], rel=1e-2)
  493. # The score error is big, this may negatively impact the beam search.
  494. assert np.allclose(g_step_scores, exp["step_scores"], atol=0.1)
  495. def test_s2tt(ctx: Ctx, g_model: c_void_p):
  496. src_audio_wav, _ = torchaudio.load("/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav")
  497. # translator = load_translator()
  498. # token_encoder = translator.text_tokenizer.create_encoder(
  499. # task="translation"
  500. # )
  501. # decoded_audio = {
  502. # "waveform": src_audio_wav.t(),
  503. # "sample_rate": 16000.,
  504. # "format": -1,
  505. # }
  506. # src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
  507. # text_out, _ = translator.get_prediction(
  508. # translator.model,
  509. # translator.text_tokenizer,
  510. # translator.unit_tokenizer,
  511. # src,
  512. # input_modality=Modality.SPEECH,
  513. # output_modality=Modality.TEXT,
  514. # tgt_lang="cmn",
  515. # )
  516. # tgt_text = str(text_out.sentences[0])
  517. # assert tgt_text == "大家好 , 世界无主题。"
  518. # tgt_tokens = text_out.generator_output.results[0][0].seq
  519. # score = text_out.generator_output.results[0][0].score.item()
  520. tgt_tokens = [ 3, 256200, 16991, 249346, 249725, 146, 25220, 251069, 249211,
  521. 251148, 253935, 3] # "大家好 , 世界无主题。"
  522. gx = ggml.from_numpy(ctx, src_audio_wav * 2**15) # Apply scale before sending into ggml!
  523. ggml.ggml_set_name(gx, b"x")
  524. gy = ggml.forward(
  525. "StandardConformerEncoder",
  526. g_model,
  527. "speech_encoder",
  528. gx,
  529. None, # TODO support padding mask
  530. )
  531. gf = ggml.ggml_build_forward(gy)
  532. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  533. encoder_out = gy
  534. job = ggml.SequenceGeneratorJob()
  535. job.opts.beam_size = 5
  536. job.opts.min_seq_len = 1
  537. job.opts.soft_max_seq_len_a = 1
  538. job.opts.soft_max_seq_len_b = 200
  539. job.opts.hard_max_seq_len = 1000
  540. job.opts.len_penalty = 1.0
  541. job.opts.unk_penalty = 0.0
  542. job.prefix_seq = ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32))
  543. job.opts.normalize_scores = True
  544. job.pad_idx = 0
  545. job.unk_idx = 1
  546. job.bos_idx = 2
  547. job.eos_idx = 3
  548. result_ptr = ggml.generate_sequence(
  549. g_model, job, encoder_out, None, ctx
  550. )
  551. g_tokens = list(ggml.to_numpy(result_ptr[0].seq))
  552. assert g_tokens == tgt_tokens