# test_unity_cpp.py

import ggml
import ctypes
import torch
import pytest
import numpy as np
import fairseq2.nn
import fairseq2.nn.transformer
import logging
import sys
import functools
from typing import Any, Iterator, Tuple
from pathlib import Path
from ctypes_utils import Ptr
from ctypes import c_void_p
from ggml import NativeObj
from ggml_convert import convert_model, read_layer_config
from seamless_communication.models.inference.translator import Translator, Modality
from fairseq2.data.audio import WaveformToFbankConverter
import torchaudio
from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024 * 5, mem_buffer=None)
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
# True when examples/unity/fairseq2.cpp is compiled with flash attention enabled;
# several tolerance and skip decisions below depend on it.
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()

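# Fixtures: each test gets a fresh GGML inference context, and the converted
# seamlessM4T_medium GGML model is loaded once per session and attached to that context.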
@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a fresh GGML context (5 GB, see CTX_PARAMS) and free it after the test."""
    try:
        ctx = ggml.ggml_init(params=CTX_PARAMS)
        with torch.inference_mode():
            yield ctx
    finally:
        ggml.ggml_free(ctx)

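# The GGML model file is converted from the fairseq2 checkpoint on first use and cached
# next to this file as seamlessM4T_medium.ggml.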
@functools.lru_cache()
def _load_g_model_once() -> NativeObj:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    return ggml.load_fairseq2_ggml_file(model_file)


@pytest.fixture()
def g_model(ctx: Ctx) -> c_void_p:
    model = _load_g_model_once()
    ggml.lib.fairseq2_model_set_inference_ctx(model.ptr, ctx)
    return model.ptr


@functools.lru_cache(maxsize=1)
def load_translator() -> Translator:
    return Translator(
        "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
    )


def load_pt_model() -> Any:
    return load_translator().model

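# The tests below check the GGML implementation against fairseq2: first the model
# conversion and individual modules, then full layers, encoders/decoders, and finally
# end-to-end sequence generation.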
def test_convert_linear(tmp_path: Path) -> None:
    module = fairseq2.nn.Linear(16, 24, True)

    layer_config = read_layer_config(module)
    assert layer_config == {"input_dim": 16, "output_dim": 24, "skip_init": False}

    module_file = tmp_path / "module.ggml"
    convert_model(module, module_file)
    g_module = ggml.load_fairseq2_ggml_file(module_file)

    for k, v in layer_config.items():
        assert (
            ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
        )

def test_causal_attention_mask(ctx: Ctx):
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
    mask_exp = generator(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    x = x[:, :8, :]
    mask_exp = generator(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)

def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)

def test_Linear_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)

def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))  # (bs, seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)

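# The two (qlen, klen) pairs exercise both qlen < klen and qlen > klen; the latter is
# skipped when flash attention is enabled, since ggml_flash_attn needs at least as many
# keys as queries.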
@pytest.mark.parametrize("lengths", [(11, 21), (21, 13)])
def test_MultiheadAttention_forward(
    ctx: Ctx, g_model: c_void_p, lengths: Tuple[int, int]
) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    # Note: we use different lengths for queries and keys;
    # this also exercises the implementation in a decoding context.
    # Note2: ggml_flash_attn requires that we have more keys than queries.
    # qlen, klen = (11, 21) if flash_attn else (21, 13)
    qlen, klen = lengths
    xq = x[:, :qlen]
    xk = x[:, :klen]
    if qlen > klen and UNITY_FLASH_ATTN:
        pytest.skip(reason="flash_attn requires qlen <= klen")

    gxq = ggml.from_numpy(ctx, xq.contiguous())
    gxk = ggml.from_numpy(ctx, xk.contiguous())
    ggml.ggml_set_name(gxk, b"xk")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gxk,
        gxk,
        None,  # TODO: tests with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    pt_model = load_pt_model()
    self_attn = pt_model.text_encoder.layers[0].self_attn
    q_exp = self_attn.q_proj(xq).numpy()

    y = ggml.to_numpy(gy)
    nodes = ggml.nodes(gf)

    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(xq, None, xk, xk).numpy()

    q = ggml.to_numpy(nodes[b"q"])
    assert q.shape == q_exp.shape
    assert np.allclose(q_exp, q, atol=1e-5)

    # with flash_attn we don't have attn_weights
    naive_attn = b"attn_weights" in nodes
    if naive_attn:
        attn_weights = ggml.to_numpy(nodes[b"attn_weights"])
        [attn_weights_exp] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively reduces small softmax weights to 0,
        # so the error isn't that small.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
        # But the sums should be close to 1
        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, qlen)))
        # And the maximum index should match the original ones.
        assert np.allclose(
            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
        )

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-2 if naive_attn else 1e-4)

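# KV-cache tests: run a few incremental decoding steps and compare both the attention
# output and the GGML k_cache nodes against fairseq2's IncrementalStateBag state.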
def test_MultiheadAttention_forward_self_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].self_attn

    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    state_bag = fairseq2.nn.IncrementalStateBag()
    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)

    # Incremental decoding
    for t in range(3):
        xq = x[:, t : t + 1]
        y_exp = attn(xq, None, xq, xq, state_bag=state_bag).numpy()
        assert y_exp.shape == (2, 1, 1024)

        gxq = ggml.from_numpy(ctx, xq.contiguous())
        ggml.ggml_set_name(gxq, b"xq")
        gy = ggml.forward(
            "MultiheadAttention",
            g_model,
            "text_decoder.layers.0.self_attn",
            gxq,
            gxq,
            gxq,
            None,  # type: ignore
        )
        gf = ggml.ggml_build_forward(gy)
        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

        nodes = ggml.nodes(gf)
        state = state_bag.get_state(
            attn, fairseq2.nn.transformer.MultiheadAttentionState
        )
        assert state is not None
        assert np.allclose(
            state.prev_k.numpy(),
            ggml.to_numpy(
                nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
            ),
            atol=1e-3,
        )

        y = ggml.to_numpy(gy)
        assert np.allclose(y, y_exp, atol=1e-2)

def test_MultiheadAttention_forward_cross_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].encoder_decoder_attn

    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    state_bag = fairseq2.nn.IncrementalStateBag()
    ggml.fairseq2_kv_cache_alloc(g_model, 2, 21)

    # Incremental decoding: the keys come from the encoder and don't change during decoding.
    xk = x[:, :11]
    gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")

    for t in range(3):
        xq = x[:, t : t + 1]
        gxq = ggml.from_numpy(ctx, xq.contiguous())
        ggml.ggml_set_name(gxq, b"xq")
        gy = ggml.forward(
            "MultiheadAttention",
            g_model,
            "text_decoder.layers.0.encoder_decoder_attn",
            gxq,
            gxk,
            gxk,
            None,  # type: ignore
        )
        gf = ggml.ggml_build_forward(gy)
        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
        y = ggml.to_numpy(gy)
        nodes = ggml.nodes(gf)
        leaves = ggml.leafs(gf)

        if t > 0:
            # The k cache only appears in the graph from the second call onward.
            state = state_bag.get_state(
                attn, fairseq2.nn.transformer.MultiheadAttentionState
            )
            assert state is not None
            assert np.allclose(
                state.prev_k.numpy(),
                ggml.to_numpy(
                    nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
                ),
                atol=1e-3,
            )

        y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
        assert y_exp.shape == (2, 1, 1024)
        assert np.allclose(y, y_exp, atol=1e-2)

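# Whole-layer and whole-stack parity tests. Tolerances are looser than for the single
# modules above, as small numerical differences accumulate across sub-layers.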
def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    padding_mask = torch.ones((2, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    pt_model = load_pt_model()
    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask)
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)

def test_StandardConformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    x = torch.load(
        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_conformer_block.pt"
    )
    padding_mask = torch.ones((1, x.shape[1]))
    layer = pt_model.speech_encoder.inner.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardConformerEncoderLayer",
        g_model,
        "speech_encoder.inner.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)

def test_StandardConformerEncoderAdaptorLayer_forward(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    x = torch.load(
        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev/seqs_before_adaptor.pt"
    )
    layer = pt_model.speech_encoder.adaptor_layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoderAdaptorLayer",
        g_model,
        "speech_encoder.adaptor_layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, None)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)

def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    padding_mask = torch.ones((2, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask)
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)

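# The speech-encoder tests below run on a real waveform. Note the 2**15 scale applied
# before handing samples to GGML, matching WaveformToFbankConverter's waveform_scale.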
def test_StandardConformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    wav, _ = torchaudio.load(
        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav"
    )
    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    converter = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=True,
    )
    converter_input = {
        "waveform": wav.transpose(0, 1),
        "sample_rate": 16000.0,
        "format": -1,
    }
    speech_encoder_input = pt_model.speech_encoder_frontend(
        converter(converter_input)["fbank"].unsqueeze(0), None
    )[0]
    y_exp, _ = pt_model.speech_encoder(speech_encoder_input, None)
    y_exp = y_exp.numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(
        y_exp, y, atol=1e-2
    )  # There are 10 elements in a 137*1024 tensor with error >1e-2

def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    converter = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=True,
    )
    extractor = Wav2Vec2FbankFeatureExtractor(80, 2, 1)
    wav, _ = torchaudio.load(
        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav"
    )
    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")

    gy = ggml.forward("WaveformToFbank", g_model, "", gx)
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    converter_input = {
        "waveform": wav.transpose(0, 1),
        "sample_rate": 16000.0,
        "format": -1,
    }
    y_exp = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error is from standardization

def test_causal_attention_mask_2d(ctx: Ctx):
    # 2D variant of test_causal_attention_mask above; reads the mask without
    # explicitly computing the graph.
    x = torch.zeros((5, 10))
    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
    mask_exp = generator(x)

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.allclose(mask, mask_exp)

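# Text decoder frontend tests: sinusoidal positional encoding, then the full embedding
# frontend (token embedding plus positional encoding).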
def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    # This _legacy_pad_idx is suspicious. Shouldn't the model use 1? But this is
    # consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset.
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)

def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.arange(2 * 20).reshape(2, 20)
    seq[1, 15:] = 0  # padding for second sentence
    seq_len = torch.tensor([20, 15])

    gseq = ggml.from_numpy(ctx, seq.numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)

def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 13, 1024))
    encoder_out = torch.empty((2, 21, 1024))
    padding_mask = torch.ones((2, 13))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask)
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out)
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO support padding mask
        genc,
        None,
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-3)

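# End-to-end generation tests: the fairseq2 Translator produces reference hypotheses
# (cached in sample_input.npz on first run), and ggml.generate_sequence must reproduce
# the token sequences and, approximately, the scores.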
def test_t2tt(ctx: Ctx, g_model: c_void_p):
    src_lang = "eng"
    src_text = "We are all in a yellow submarine."
    tgt_lang = "fra"
    sample_file = Path(__file__).parent / "sample_input.npz"
    beam_size = 2

    if not sample_file.exists():
        translator = load_translator()
        device = translator.device
        token_encoder = translator.text_tokenizer.create_encoder(
            task="translation", lang=src_lang, mode="source", device=device
        )
        src = translator.collate(token_encoder(src_text))

        text_out, _ = translator.get_prediction(
            translator.model,
            translator.text_tokenizer,
            translator.unit_tokenizer,
            src,
            input_modality=Modality.TEXT,
            output_modality=Modality.TEXT,
            tgt_lang=tgt_lang,
            beam_size=beam_size,
        )

        tgt_text = str(text_out.sentences[0])
        assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
        hypotheses = [
            {
                "seq": h.seq.tolist(),
                "score": h.score.item(),
                "step_scores": h.step_scores.numpy(),
            }
            for h in text_out.generator_output.results[0]
        ]
        np.savez(
            sample_file,
            encoder_output=text_out.encoder_output.numpy(),
            encoder_padding_mask=text_out.encoder_padding_mask.numpy(),
            hypotheses=hypotheses,
        )

    # allow_pickle to load the hyp dicts
    text_out = np.load(sample_file, allow_pickle=True)
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])
    prefix_seq = np.array(text_out["hypotheses"][0]["seq"][:2]).astype(np.int32)
    max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])

    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = beam_size
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = int(max_seq_len * 1.5)
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.opts.normalize_scores = True
    job.prefix_seq = ggml.from_numpy(ctx, prefix_seq)
    job.pad_idx = 0
    job.unk_idx = 1
    job.bos_idx = 2
    job.eos_idx = 3

    result_ptr = ggml.generate_sequence(
        g_model, job, encoder_out, encoder_padding_mask, ctx
    )
    results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq is not None]
    assert len(results) == len(text_out["hypotheses"])
    for g_hyp, exp in zip(results, text_out["hypotheses"]):
        g_tokens = list(ggml.to_numpy(g_hyp.seq))
        g_step_scores = ggml.to_numpy(g_hyp.step_scores)
        assert g_tokens == exp["seq"]
        assert g_hyp.score == pytest.approx(exp["score"], rel=1e-2)
        # The score error is big; this may negatively impact the beam search.
        assert np.allclose(g_step_scores, exp["step_scores"], atol=0.1)

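# Speech-to-text test: the expected target tokens below correspond to the commented-out
# Translator reference run (transcription: "大家好 , 世界无主题。").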
def test_s2tt(ctx: Ctx, g_model: c_void_p):
    src_audio_wav, _ = torchaudio.load(
        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/test.wav"
    )
    # translator = load_translator()
    # token_encoder = translator.text_tokenizer.create_encoder(
    #     task="translation"
    # )
    # decoded_audio = {
    #     "waveform": src_audio_wav.t(),
    #     "sample_rate": 16000.,
    #     "format": -1,
    # }
    # src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
    # text_out, _ = translator.get_prediction(
    #     translator.model,
    #     translator.text_tokenizer,
    #     translator.unit_tokenizer,
    #     src,
    #     input_modality=Modality.SPEECH,
    #     output_modality=Modality.TEXT,
    #     tgt_lang="cmn",
    # )
    # tgt_text = str(text_out.sentences[0])
    # assert tgt_text == "大家好 , 世界无主题。"
    # tgt_tokens = text_out.generator_output.results[0][0].seq
    # score = text_out.generator_output.results[0][0].score.item()
    tgt_tokens = [
        3,
        256200,
        16991,
        249346,
        249725,
        146,
        25220,
        251069,
        249211,
        251148,
        253935,
        3,
    ]  # "大家好 , 世界无主题。"

    gx = ggml.from_numpy(
        ctx, src_audio_wav * 2**15
    )  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    encoder_out = gy

    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = 5
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = 1000
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.prefix_seq = ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32))
    job.opts.normalize_scores = True
    job.pad_idx = 0
    job.unk_idx = 1
    job.bos_idx = 2
    job.eos_idx = 3

    result_ptr = ggml.generate_sequence(g_model, job, encoder_out, None, ctx)
    g_tokens = list(ggml.to_numpy(result_ptr[0].seq))
    assert g_tokens == tgt_tokens