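"""Tests for the GGML/C++ implementation of SeamlessM4T (fairseq2.cpp).

Most tests build a layer or a full model on the ggml side and compare its
output against the reference fairseq2 PyTorch implementation.
"""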

import ctypes
import functools
import logging
import sys
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator, List, Tuple

import fairseq2.nn
import fairseq2.nn.transformer
import numpy as np
import pytest
import torch
import torchaudio
from fairseq2.data.audio import WaveformToFbankConverter
from fairseq2.generation import SequenceGeneratorOptions
from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
from seamless_communication.inference.translator import Modality, Translator

import ggml
from ctypes_utils import NULLPTR, Ptr
from ggml import NativeObj
from ggml_convert import convert_model, read_layer_config

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024 * 5, mem_buffer=None)
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
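# True unless flash attention was explicitly disabled in fairseq2.cpp; some
# tolerance thresholds below depend on which attention kernel is compiled in.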

DATA = Path(__file__).parent
DATA_DEV = DATA / "dev"
if not DATA_DEV.exists():
    DATA_DEV = Path(
        "/private/home/dnn/internal_sc/seamless_communication/ggml/examples/unity/dev"
    )
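# Tests that need intermediate activations read them from DATA_DEV and are
# skipped when the folder is absent (the fallback above is an internal path).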


@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 5 GB of memory"""
    ctx = ggml.ggml_init(params=CTX_PARAMS)
    try:
        with torch.inference_mode():
            yield ctx
    finally:
        ggml.ggml_free(ctx)


@functools.lru_cache()
def _load_g_model_once() -> NativeObj:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    return ggml.load_fairseq2_ggml_file(model_file)


@pytest.fixture()
def g_model(ctx: Ctx) -> c_void_p:
    model = _load_g_model_once()
    ggml.lib.fairseq2_model_set_inference_ctx(model.ptr, ctx)
    return model.ptr
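

# The PyTorch reference model is created once per session: instantiating the
# Translator loads the full seamlessM4T_medium checkpoint and its vocoder,
# which is expensive.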


@functools.lru_cache(maxsize=1)
def load_translator() -> Translator:
    return Translator(
        "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
    )


def load_pt_model() -> Any:
    return load_translator().model


def test_convert_linear(tmp_path: Path) -> None:
    module = fairseq2.nn.Linear(16, 24, True)
    layer_config = read_layer_config(module)
    assert layer_config == {"input_dim": 16, "output_dim": 24}

    module_file = tmp_path / "module.ggml"
    convert_model(module, module_file)
    g_module = ggml.load_fairseq2_ggml_file(module_file)
    for k, v in layer_config.items():
        assert (
            ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
        )


def test_causal_attention_mask(ctx: Ctx) -> None:
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskFactory()
    mask_exp = generator(x, x).materialize().numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    x = x[:, :8, :]
    mask_exp = generator(x, x).materialize().numpy()
    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)
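

# Most layer tests below follow the same parity pattern against fairseq2:
#
#     gx = ggml.from_numpy(ctx, x)  # copy the input into the ggml context
#     gy = ggml.forward("LayerType", g_model, "path.to.layer", gx)
#     ggml.build_and_compute(ctx, gy)  # ggml_build_forward + graph compute
#     assert np.allclose(y_exp, ggml.to_numpy(gy), atol=...)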


def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_Linear_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))  # (bs, seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


@pytest.mark.parametrize("lengths", [(11, 21), (21, 13)])
def test_MultiheadAttention_forward(
    ctx: Ctx, g_model: c_void_p, lengths: Tuple[int, int]
) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    # Note: we use different lengths for queries and keys,
    # this tests the implementation in decoding context too.
    # Note2: ggml_flash_attn requires that we have more keys than queries
    # qlen, klen = (11, 21) if flash_attn else (21, 13)
    qlen, klen = lengths
    xq = x[:, :qlen]
    xk = x[:, :klen]
    if qlen > klen and UNITY_FLASH_ATTN:
        pytest.skip(reason="flash_attn requires qlen <= klen")

    gxq = ggml.from_numpy(ctx, xq.contiguous())
    gxk = ggml.from_numpy(ctx, xk.contiguous())
    ggml.ggml_set_name(gxk, b"xk")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gxk,
        gxk,
        NULLPTR,  # TODO: tests with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    pt_model = load_pt_model()
    self_attn = pt_model.text_encoder.layers[0].self_attn
    q_exp = self_attn.q_proj(xq).numpy()

    y = ggml.to_numpy(gy)
    nodes = ggml.nodes(gf)

    attn_weights_hook = fairseq2.nn.transformer.AttentionWeightStoreHook([])
    self_attn.register_attn_weight_hook(attn_weights_hook)
    y_exp = self_attn(xq, None, xk, None, xk).numpy()

    q = ggml.to_numpy(nodes[b"q"])
    assert q.shape == q_exp.shape
    assert np.allclose(q_exp, q, atol=1e-5)

    # with flash_attn we don't have attn_weights
    naive_attn = b"attn_weights" in nodes
    if naive_attn:
        attn_weights = ggml.to_numpy(nodes[b"attn_weights"]).reshape(
            -1, 16, qlen, klen
        )
        [(_, attn_weights_exp)] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively reduces small softmax weights to 0,
        # so the error isn't that small.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
        # But the sums should be close to 1
        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2, 16, qlen)))
        # And the maximum index should match the original ones.
        assert np.allclose(
            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
        )

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-2 if naive_attn else 1e-4)
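

# The next two tests exercise the KV cache: for self-attention the cache grows
# by one step per decoded token, while for cross-attention the keys/values are
# computed once from the encoder output and reused at every step.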


def test_MultiheadAttention_forward_self_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].self_attn
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    state_bag = fairseq2.nn.IncrementalStateBag(100)

    with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
        # Incremental decoding
        for t in range(3):
            xq = x[:, t : t + 1]
            y_exp = attn(xq, None, xq, None, xq, state_bag=state_bag).numpy()
            assert y_exp.shape == (2, 1, 1024)

            gxq = ggml.from_numpy(ctx, xq.contiguous())
            ggml.ggml_set_name(gxq, b"xq")
            gy = ggml.forward(
                "MultiheadAttention",
                g_model,
                "text_decoder.layers.0.self_attn",
                gxq,
                gxq,
                gxq,
                None,  # type: ignore
            )
            gf = ggml.ggml_build_forward(gy)
            ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

            nodes = ggml.nodes(gf)
            state = state_bag.get_state(attn, fairseq2.nn.transformer.AttentionState)
            state_bag.increment_step()
            assert state is not None
            assert np.allclose(
                state.get()[0].transpose(1, 2).reshape(2, t + 1, -1).numpy(),
                ggml.to_numpy(
                    nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
                ),
                atol=1e-3,
            )

            y = ggml.to_numpy(gy)
            assert np.allclose(y, y_exp, atol=1e-2)


def test_MultiheadAttention_forward_cross_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].encoder_decoder_attn
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    state_bag = fairseq2.nn.IncrementalStateBag(100)

    with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
        # Incremental decoding: the keys come from the encoder and don't
        # change during decoding.
        xk = x[:, :11]
        gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")
        for t in range(3):
            xq = x[:, t : t + 1]
            gxq = ggml.from_numpy(ctx, xq.contiguous())
            ggml.ggml_set_name(gxq, b"xq")
            gy = ggml.forward(
                "MultiheadAttention",
                g_model,
                "text_decoder.layers.0.encoder_decoder_attn",
                gxq,
                gxk,
                gxk,
                None,  # type: ignore
            )
            gf = ggml.ggml_build_forward(gy)
            ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
            y = ggml.to_numpy(gy)
            nodes = ggml.nodes(gf)
            leaves = ggml.leafs(gf)

            if t > 0:
                # The cache only appears in the graph from the second call on.
                state = state_bag.get_state(
                    attn, fairseq2.nn.transformer.AttentionState
                )
                assert state is not None
                assert np.allclose(
                    state.get()[0].transpose(1, 2).reshape(2, 11, -1).numpy(),
                    ggml.to_numpy(
                        nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
                    ),
                    atol=1e-3,
                )

            state_bag.increment_step()
            y_exp = attn(xq, None, xk, None, xk, state_bag=state_bag).numpy()
            assert y_exp.shape == (2, 1, 1024)
            assert np.allclose(y, y_exp, atol=1e-2)


def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    pt_model = load_pt_model()
    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([21, 21]), 21)
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    y_exp, _ = layer(x, padding_mask=None)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_StandardConformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    if not DATA_DEV.exists():
        pytest.skip(reason=f"Folder {DATA_DEV} not found!")

    x = torch.load(DATA_DEV / "seqs_before_conformer_block.pt")
    padding_mask = torch.ones((1, x.shape[1]))
    layer = pt_model.speech_encoder.inner.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardConformerEncoderLayer",
        g_model,
        "speech_encoder.inner.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)


def test_StandardConformerEncoderAdaptorLayer_forward(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    if not DATA_DEV.exists():
        pytest.skip(reason=f"Folder {DATA_DEV} not found!")

    x = torch.load(DATA_DEV / "seqs_before_adaptor.pt")
    layer = pt_model.speech_encoder.adaptor_layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoderAdaptorLayer",
        g_model,
        "speech_encoder.adaptor_layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    y_exp, _ = layer(x, None)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)


def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([21, 21]), 21)
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=5e-3)


def test_StandardConformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    wav, _ = torchaudio.load(DATA / "test.wav")
    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    converter = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=True,
    )
    converter_input = {
        "waveform": wav.transpose(0, 1),
        "sample_rate": 16000.0,
        "format": -1,
    }

    y = ggml.to_numpy(gy)
    speech_encoder_input = pt_model.speech_encoder_frontend(
        converter(converter_input)["fbank"].unsqueeze(0), None
    )[0]
    y_exp, _ = pt_model.speech_encoder(speech_encoder_input, None)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(
        y_exp, y, atol=1e-2
    )  # There are 10 elements in a 137*1024 tensor with error >1e-2


def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    converter = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=True,
    )
    extractor = Wav2Vec2FbankFeatureExtractor(80, stride=2, sample_every_k=1)
    wav, _ = torchaudio.load(DATA / "test.wav")

    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward("WaveformToFbank", g_model, "", gx)
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    converter_input = {
        "waveform": wav.transpose(0, 1),
        "sample_rate": 16000.0,
        "format": -1,
    }
    y_exp, _ = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)
    y_exp = y_exp.squeeze(0).numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error is from standardization


def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    # This _legacy_pad_idx is suspicious. Shouldn't the model use 1? But this is
    # consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset.
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    pos_encoder.eval()
    state_bag = fairseq2.nn.IncrementalStateBag(100)

    with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
        # Incremental decoding
        for t in range(20):
            gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
            ggml.ggml_set_name(gseq, b"seq")
            gy = ggml.forward(
                "PositionalEmbedding",
                g_model,
                "text_decoder_frontend.pos_encoder",
                gseq,
            )
            gf = ggml.ggml_build_forward(gy)
            ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
            y = ggml.to_numpy(gy)

            y_exp = pos_encoder(
                seq[:, t : t + 1, :], None, state_bag=state_bag
            ).numpy()
            state_bag.increment_step()

            assert y.shape == y_exp.shape
            assert np.allclose(y_exp, y, atol=1e-6)


def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.arange(2 * 20).reshape(2, 20)
    seq[1, 15:] = 0  # padding for second sentence
    seq_len = torch.tensor([20, 15])

    gseq = ggml.from_numpy(ctx, seq.numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 13, 1024))
    encoder_out = torch.empty((2, 21, 1024))
    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([13, 13]), 13)
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out)
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO support padding mask
        genc,
        None,
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-3)
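

# The end-to-end tests below run the fairseq2 Translator once to produce
# reference hypotheses and cache them in an .npz file next to the test data;
# subsequent runs only compare the ggml beam search against the cached output.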


def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
    src_lang = "eng"
    src_text = "We are all in a yellow submarine."
    tgt_lang = "fra"
    sample_file = DATA / "sample_input.npz"
    beam_size = 2

    if not sample_file.exists():
        translator = load_translator()
        device = translator.device
        token_encoder = translator.text_tokenizer.create_encoder(
            task="translation", lang=src_lang, mode="source", device=device
        )
        src = translator.collate(token_encoder(src_text))
        text_out, _ = translator.get_prediction(
            translator.model,
            translator.text_tokenizer,
            translator.unit_tokenizer,
            src["seqs"],
            None,
            input_modality=Modality.TEXT,
            output_modality=Modality.TEXT,
            tgt_lang=tgt_lang,
            text_generation_opts=SequenceGeneratorOptions(beam_size=beam_size),
            unit_generation_opts=None,
        )
        tgt_text = str(text_out.sentences[0])
        assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
        hypotheses = [
            {
                "seq": h.seq.tolist(),
                "score": h.score.item(),
                "step_scores": h.step_scores.numpy(),
            }
            for h in text_out.generator_output.results[0]
        ]
        np.savez(
            sample_file,
            encoder_output=text_out.encoder_output.numpy(),
            hypotheses=hypotheses,
        )

    # allow_pickle to load the hyp dicts
    text_out = np.load(sample_file, allow_pickle=True)
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    prefix_seq = np.array(text_out["hypotheses"][0]["seq"][:2]).astype(np.int32)
    max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])

    opts = ggml.SequenceGeneratorOptions(
        beam_size=beam_size,
        min_seq_len=1,
        soft_max_seq_len_a=1,
        soft_max_seq_len_b=200,
        hard_max_seq_len=int(max_seq_len * 1.5),
        len_penalty=1.0,
        unk_penalty=0.0,
        normalize_scores=True,
    )
    job = ggml.SequenceGeneratorJob(
        opts=opts,
        prefix_seq=ggml.from_numpy(ctx, prefix_seq),
        pad_idx=0,
        unk_idx=1,
        bos_idx=2,
        eos_idx=3,
    )
    result_ptr = ggml.generate_sequence(g_model, job, encoder_out, NULLPTR, ctx)
    results = [
        result_ptr[i] for i in range(beam_size) if result_ptr[i].seq is not None
    ]
    # The step-score error is big, which may negatively impact the beam search.
    assert_hypotheses(
        text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1
    )
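

# Unlike the PyTorch path, the ggml speech encoder consumes the raw waveform
# (scaled by 2**15) and computes fbank features itself, so the ggml side of
# this test feeds wav samples straight into "StandardConformerEncoder".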


def test_s2tt(ctx: Ctx, g_model: c_void_p) -> None:
    src_audio_wav, _ = torchaudio.load(DATA / "test.wav")
    sample_file = DATA / "test.wav.npz"

    if not sample_file.exists():
        translator = load_translator()
        token_encoder = translator.text_tokenizer.create_encoder(task="translation")
        decoded_audio = {
            "waveform": src_audio_wav.t(),
            "sample_rate": 16000.0,
            "format": -1,
        }
        src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
        text_out, _ = translator.get_prediction(
            translator.model,
            translator.text_tokenizer,
            translator.unit_tokenizer,
            src,
            input_modality=Modality.SPEECH,
            output_modality=Modality.TEXT,
            tgt_lang="cmn",
        )
        tgt_text = str(text_out.sentences[0])
        assert tgt_text == "大家好 , 世界无主题。"
        hypotheses = [
            {
                "seq": h.seq.tolist(),
                "score": h.score.item(),
                "step_scores": h.step_scores.numpy(),
            }
            for h in text_out.generator_output.results[0]
        ]
        np.savez(
            sample_file,
            encoder_output=text_out.encoder_output.numpy(),
            hypotheses=hypotheses,
        )

    text_out = np.load(sample_file, allow_pickle=True)
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    tgt_tokens = text_out["hypotheses"][0]["seq"]
    max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])
    max_seq_len = int(max_seq_len * 1.5)

    # Apply scale before sending into ggml!
    gx = ggml.from_numpy(ctx, src_audio_wav * 2**15)
    ggml.ggml_set_name(gx, b"x")
    encoder_out = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(encoder_out)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    beam_size = 5
    opts = ggml.SequenceGeneratorOptions(
        beam_size=beam_size,
        soft_max_seq_len_a=1,
        soft_max_seq_len_b=200,
        hard_max_seq_len=max_seq_len,
    )
    job = ggml.SequenceGeneratorJob(
        opts=opts,
        prefix_seq=ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32)),
        pad_idx=0,
        unk_idx=1,
        bos_idx=2,
        eos_idx=3,
    )
    result_ptr = ggml.generate_sequence(g_model, Ptr(job), encoder_out, NULLPTR, ctx)
    results = [
        result_ptr[i] for i in range(beam_size) if result_ptr[i].seq is not None
    ]
    assert_hypotheses(
        text_out["hypotheses"], results, score_rtol=1e-2, step_scores_rtol=0.1
    )


def assert_hypotheses(
    expected: List[Any],
    results: List[Any],
    *,
    score_rtol: float,
    step_scores_rtol: float,
) -> None:
    assert len(results) == len(expected)
    for g_hyp, exp in zip(results, expected):
        g_tokens = list(ggml.to_numpy(g_hyp.seq))
        g_step_scores = ggml.to_numpy(g_hyp.step_scores)
        assert g_tokens == exp["seq"]
        assert g_hyp.score == pytest.approx(exp["score"], rel=score_rtol)
        assert np.allclose(g_step_scores, exp["step_scores"], rtol=step_scores_rtol)