# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import ctypes
import functools
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator, List, Tuple

import fairseq2.nn
import fairseq2.nn.transformer
import numpy as np
import pytest
import requests
import torch
import torchaudio
from fairseq2.data.audio import WaveformToFbankConverter
from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
from seamless_communication.inference.generator import SequenceGeneratorOptions
from seamless_communication.inference.translator import Modality, Translator

import ggml
from ctypes_utils import NULLPTR, Ptr
from ggml import NativeObj
from ggml_convert import convert_model, read_layer_config

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()

DATA = Path(__file__).parent / "test_data"
LOCAL_AUDIO_SAMPLE_PATH = DATA / "LJ037-0171_sr16k.wav"
TEST_AUDIO_SAMPLE_URL = (
    "https://dl.fbaipublicfiles.com/seamless/tests/LJ037-0171_sr16k.wav"
)

MB = 1024 * 1024

@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 16 MB of memory"""
    mem_size = 16 * MB
    memory = torch.zeros(mem_size, dtype=torch.uint8)
    ctx = ggml.ggml_init(
        params=ggml.ggml_init_params(
            mem_size=mem_size,
            mem_buffer=ctypes.c_void_p(memory.data_ptr()),
            no_alloc=True,
        )
    )
    try:
        with torch.inference_mode():
            yield ctx
    finally:
        ggml.ggml_free(ctx)

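
# A minimal sketch (not from the original suite) of the round-trip the fixture
# above enables: numpy -> ggml tensor -> graph compute -> numpy. It relies on
# the helpers used throughout this file (from_numpy, build_and_compute,
# to_numpy); the use of the auto-generated ggml_dup binding is an assumption.
def test_ctx_roundtrip_sketch(ctx: Ctx) -> None:
    x = np.arange(12, dtype=np.float32).reshape(3, 4)
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.ggml_dup(ctx, gx)  # identity op, just to have a node to compute
    ggml.build_and_compute(ctx, gy)
    assert np.allclose(ggml.to_numpy(gy), x)
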
@functools.lru_cache()
def _load_g_model_once() -> NativeObj:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    return ggml.load_fairseq2_ggml_file(model_file)

@pytest.fixture()
def g_model(ctx: Ctx) -> c_void_p:
    model = _load_g_model_once()
    ggml.lib.fairseq2_model_set_inference_ctx(model.ptr, ctx)
    return model.ptr

@functools.lru_cache(maxsize=1)
def load_translator() -> Translator:
    return Translator("seamlessM4T_medium", None, device=torch.device("cpu"))


def load_pt_model() -> Any:
    return load_translator().model

def download_sample_audio() -> None:
    response = requests.get(TEST_AUDIO_SAMPLE_URL, stream=True)
    response.raise_for_status()
    with open(LOCAL_AUDIO_SAMPLE_PATH, "wb") as file:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                file.write(chunk)

def test_convert_linear(tmp_path: Path) -> None:
    module = fairseq2.nn.Linear(16, 24, True)
    layer_config = read_layer_config(module)
    assert layer_config == {"input_dim": 16, "output_dim": 24}
    module_file = tmp_path / "module.ggml"
    convert_model(module, module_file)
    g_module = ggml.load_fairseq2_ggml_file(module_file)
    for k, v in layer_config.items():
        assert (
            ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
        )

def test_causal_attention_mask(ctx: Ctx):
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskFactory()
    mask_exp = generator(x, x).materialize().numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    ggml.build_and_compute(ctx, gmask)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    x = x[:, :8, :]
    mask_exp = generator(x, x).materialize().numpy()
    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    ggml.build_and_compute(ctx, gmask)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)

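
# A reference sketch (not from the original suite): the causal mask checked
# above is 0 on and below the diagonal and -inf above it, so position i can
# only attend to positions j <= i. Built here with plain numpy and compared
# against the ggml output the same way as the test above.
def test_causal_attention_mask_reference_sketch(ctx: Ctx) -> None:
    seq_len = 10
    mask_exp = np.zeros((seq_len, seq_len), dtype=np.float32)
    mask_exp[np.triu_indices(seq_len, k=1)] = -np.inf
    gx = ggml.from_numpy(ctx, torch.zeros((1, seq_len, 32)))
    gmask = ggml.causal_attention_mask(ctx, gx)
    ggml.build_and_compute(ctx, gmask)
    assert np.all(ggml.to_numpy(gmask) == mask_exp)
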
def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)

def test_Linear_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
    gf = ggml.build_and_compute(ctx, gy, dump="dot/test_Linear_forward.dot")
    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)

def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))  # (bs, seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)
    # Test FFN without LayerNorm
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)

@pytest.mark.parametrize("lengths", [(11, 21), (21, 13)])
def test_MultiheadAttention_forward(
    ctx: Ctx, g_model: c_void_p, lengths: Tuple[int, int]
) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    # Note: we use different lengths for queries and keys;
    # this also tests the implementation in a decoding context.
    # Note2: ggml_flash_attn requires that we have more keys than queries.
    qlen, klen = lengths
    xq = x[:, :qlen]
    xk = x[:, :klen]
    if qlen > klen and UNITY_FLASH_ATTN:
        pytest.skip(reason="flash_attn requires qlen <= klen")
    gxq = ggml.from_numpy(ctx, xq.contiguous())
    ggml.ggml_set_name(gxq, b"xq")
    gxk = ggml.from_numpy(ctx, xk.contiguous())
    ggml.ggml_set_name(gxk, b"xk")
    ggml.ggml_set_no_alloc(ctx, True)
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gxk,
        gxk,
        NULLPTR,  # TODO: test with causal attention masks
    )
    gf = ggml.build_and_compute(ctx, gy, dump="dot/test_MultiheadAttention_forward")
    y = ggml.to_numpy(gy)
    nodes = ggml.nodes(gf)
    node_buffers = set(t.contents.data for t in nodes.values())

    pt_model = load_pt_model()
    self_attn = pt_model.text_encoder.layers[0].self_attn

    # If buffers overlap, reading node contents can be misleading.
    overlap = len(node_buffers) < len(nodes)
    if not overlap:
        q_exp = self_attn._project_q(xq, None).numpy().reshape(2 * 16, qlen, 64)
        q = ggml.to_numpy(nodes[b"q"])
        assert q.shape == q_exp.shape
        assert np.allclose(q_exp, q, atol=1e-5)

    attn_weights_hook = fairseq2.nn.transformer.AttentionWeightStoreHook([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(xq, None, xk, None, xk).numpy()

    # With flash_attn we don't have attn_weights.
    naive_attn = b"attn_weights" in nodes
    if naive_attn and not overlap:
        attn_weights = ggml.to_numpy(nodes[b"attn_weights"]).reshape(-1, 16, qlen, klen)
        [(_, attn_weights_exp)] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively reduces small softmax weights to 0,
        # so the error isn't that small.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
        # But the sums should be close to 1,
        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2, 16, qlen)))
        # and the maximum index should match the original ones.
        assert np.allclose(
            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
        )
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-2 if naive_attn else 1e-4)

def test_MultiheadAttention_forward_self_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].self_attn
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    state_bag = fairseq2.nn.IncrementalStateBag(100)
    with ggml.fairseq2_kv_cache_alloc(g_model, 16 * MB, 2, 21):
        # Incremental decoding
        for t in range(3):
            xq = x[:, t : t + 1]
            gxq = ggml.from_numpy(ctx, xq.contiguous())
            ggml.ggml_set_name(gxq, b"xq")
            gy = ggml.forward(
                "MultiheadAttention",
                g_model,
                "text_decoder.layers.0.self_attn",
                gxq,
                gxq,
                gxq,
                None,  # type: ignore
            )
            gf = ggml.build_and_compute(
                ctx,
                gy,
                dump=f"dot/test_MultiheadAttention_forward_self_attn_with_cache_{t}.dot",
            )
            nodes = ggml.nodes(gf)
            gk_cache = ggml.to_numpy(
                nodes[b"text_decoder.layers.0.self_attn.k (step=%d)" % t]
            )
            assert gk_cache.shape == (2, t + 1, 1024)
            gk_cache = gk_cache.reshape(2, t + 1, 16, 64).transpose(0, 2, 1, 3)
            assert gk_cache.shape == (2, 16, t + 1, 64)

            y_exp = attn(xq, None, xq, None, xq, state_bag=state_bag).numpy()
            assert y_exp.shape == (2, 1, 1024)
            state = state_bag.get_state(attn, fairseq2.nn.transformer.AttentionState)
            state_bag.increment_step_nr()
            assert state is not None
            k_cache = state.get()[0].numpy()
            assert k_cache.shape == (2, 16, t + 1, 64)
            assert np.allclose(gk_cache, k_cache, atol=1e-3)

            y = ggml.to_numpy(gy)
            assert np.allclose(y, y_exp, atol=1e-2)

def test_MultiheadAttention_forward_cross_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].encoder_decoder_attn
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    state_bag = fairseq2.nn.IncrementalStateBag(100)
    with ggml.fairseq2_kv_cache_alloc(g_model, 16 * MB, 2, 21):
        # Incremental decoding: the keys come from the encoder and don't change during decoding.
        xk = x[:, :11]
        gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")
        for t in range(3):
            xq = x[:, t : t + 1]
            gxq = ggml.from_numpy(ctx, xq.contiguous())
            ggml.ggml_set_name(gxq, b"xq")
            gy = ggml.forward(
                "MultiheadAttention",
                g_model,
                "text_decoder.layers.0.encoder_decoder_attn",
                gxq,
                gxk,
                gxk,
                None,  # type: ignore
            )
            gf = ggml.build_and_compute(
                ctx,
                gy,
                dump=f"dot/test_MultiheadAttention_forward_cross_attn_with_cache_{t}.dot",
            )
            y = ggml.to_numpy(gy)
            nodes = ggml.nodes(gf)
            leaves = ggml.leafs(gf)

            if t > 0:
                # The cache only appears in the graph from the second call onward.
                state = state_bag.get_state(
                    attn, fairseq2.nn.transformer.AttentionState
                )
                assert state is not None
                assert np.allclose(
                    state.get()[0].transpose(1, 2).numpy(),
                    ggml.to_numpy(
                        nodes[
                            b"text_decoder.layers.0.encoder_decoder_attn.k_cache (view)"
                        ]
                    ),
                    atol=1e-3,
                )
            state_bag.increment_step_nr()
            y_exp = attn(xq, None, xk, None, xk, state_bag=state_bag).numpy()
            assert y_exp.shape == (2, 1, 1024)
            assert np.allclose(y, y_exp, atol=1e-2)

def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    pt_model = load_pt_model()
    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask=None)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)

def test_StandardConformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    x = torch.rand(1, 137, 1024)
    layer = pt_model.speech_encoder.inner.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoderLayer",
        g_model,
        "speech_encoder.inner.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask=None)
    y_exp = y_exp.squeeze(0).numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)

def test_StandardConformerEncoderAdaptorLayer_forward(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    torch.random.manual_seed(0)
    x = torch.rand(1, 137, 1024)
    layer = pt_model.speech_encoder.adaptor_layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoderAdaptorLayer",
        g_model,
        "speech_encoder.adaptor_layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, None)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)

def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([21, 21]), 21)
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=5e-3)

def test_StandardConformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    if not LOCAL_AUDIO_SAMPLE_PATH.exists():
        download_sample_audio()
    wav, _ = torchaudio.load(LOCAL_AUDIO_SAMPLE_PATH)
    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    cache = DATA / "test_StandardConformerEncoder_forward.npy"
    if not cache.exists():
        converter = WaveformToFbankConverter(
            num_mel_bins=80,
            waveform_scale=2**15,
            channel_last=True,
            standardize=True,
        )
        converter_input = {
            "waveform": wav.transpose(0, 1),
            "sample_rate": 16000.0,
            "format": -1,
        }
        pt_model = load_pt_model()
        speech_encoder_input = pt_model.speech_encoder_frontend(
            converter(converter_input)["fbank"].unsqueeze(0), None
        )[0]
        y_exp, _ = pt_model.speech_encoder(speech_encoder_input, None)
        y_exp = y_exp.numpy()
        np.save(cache, y_exp)
    else:
        y_exp = np.load(cache)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-2)

def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
    converter = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=True,
    )
    extractor = Wav2Vec2FbankFeatureExtractor(80, stride=2, sample_every_k=1)
    if not LOCAL_AUDIO_SAMPLE_PATH.exists():
        download_sample_audio()
    wav, _ = torchaudio.load(LOCAL_AUDIO_SAMPLE_PATH)
    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")

    gy = ggml.forward("WaveformToFbank", g_model, "", gx)
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    converter_input = {
        "waveform": wav.transpose(0, 1),
        "sample_rate": 16000.0,
        "format": -1,
    }
    y_exp, _ = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)
    y_exp = y_exp.squeeze(0).numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error is from standardization

def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=1)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].clone().numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.build_and_compute(ctx, gy, dump=True)
    y = ggml.to_numpy(gy)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)

def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=1)
    pos_encoder.eval()
    state_bag = fairseq2.nn.IncrementalStateBag(100)

    with ggml.fairseq2_kv_cache_alloc(g_model, 16 * MB, 2, 21):
        # Incremental decoding
        for t in range(20):
            gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
            ggml.ggml_set_name(gseq, b"seq")
            gy = ggml.forward(
                "PositionalEmbedding",
                g_model,
                "text_decoder_frontend.pos_encoder",
                gseq,
            )
            gf = ggml.build_and_compute(ctx, gy, dump=t == 1)
            y = ggml.to_numpy(gy)

            y_exp = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag).numpy()
            state_bag.increment_step_nr()
            assert y.shape == y_exp.shape
            assert np.allclose(y_exp, y, atol=1e-6)

def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.arange(2 * 20).reshape(2, 20)
    seq[1, 15:] = 0  # padding for the second sentence
    seq_len = torch.tensor([20, 15])
    gseq = ggml.from_numpy(ctx, seq.numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)

def test_StandardTransformerDecoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 13, 1024))
    encoder_out = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)
    self_attn_mask = fairseq2.nn.transformer.CausalAttentionMaskFactory()(x, x)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gself_attn_mask = ggml.from_numpy(ctx, self_attn_mask.materialize().numpy())
    ggml.ggml_set_name(gself_attn_mask, b"self_attn_mask")
    genc = ggml.from_numpy(ctx, encoder_out)
    ggml.ggml_set_name(genc, b"encoder_out")
    gy = ggml.forward(
        "StandardTransformerDecoderLayer",
        g_model,
        "text_decoder.layers.0",
        gx,
        gself_attn_mask,
        genc,
        NULLPTR,  # TODO support padding mask
    )
    ggml.build_and_compute(ctx, gy, dump=True)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder.layers[0](
        x, None, encoder_output=encoder_out, self_attn_mask=self_attn_mask
    )
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    # We still have some numerical imprecision
    assert np.allclose(y_exp, y, atol=0.1)
    assert np.sum(np.abs(y_exp - y) > 1e-2) < 20

def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 13, 1024))
    encoder_out = torch.empty((2, 21, 1024))
    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([13, 13]), 13)
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out)
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO support padding mask
        genc,
        None,
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-3)  # TODO: those tests are failing now

def test_s2tt(ctx: Ctx, g_model: c_void_p):
    if not LOCAL_AUDIO_SAMPLE_PATH.exists():
        download_sample_audio()
    src_audio_wav, _ = torchaudio.load(LOCAL_AUDIO_SAMPLE_PATH)
    sample_file = DATA / "LJ037-0171_sr16k.wav.trans"
    translator = load_translator()
    if not sample_file.exists():
        decoded_audio = {
            "waveform": src_audio_wav.t(),
            "sample_rate": 16000.0,
            "format": -1,
        }
        src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
        text_out, _ = translator.get_prediction(
            translator.model,
            translator.text_tokenizer,
            translator.unit_tokenizer,
            src["seqs"],
            padding_mask=None,
            input_modality=Modality.SPEECH,
            output_modality=Modality.TEXT,
            tgt_lang="cmn",
            text_generation_opts=SequenceGeneratorOptions(),
            unit_generation_opts=None,
        )
        tgt_text = str(text_out[0])
        assert tgt_text == "专家的检查和证据使该委员会得出了结论,可能有五次枪击."
        with open(sample_file, "w") as f:
            f.write(tgt_text)

    with open(sample_file, "r") as exp:
        exp_tgt_text = exp.readlines()[0].strip()

    # Apply scale before sending into ggml!
    gx = ggml.from_numpy(ctx, src_audio_wav * 2**15)
    ggml.ggml_set_name(gx, b"x")
    encoder_out = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        NULLPTR,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, encoder_out)

    beam_size = 5
    opts = ggml.SequenceGeneratorOptions(
        beam_size=beam_size,
        soft_max_seq_len_a=1,
        soft_max_seq_len_b=200,
        hard_max_seq_len=500,
    )
    job = ggml.SequenceGeneratorJob(
        opts=opts,
        prefix_seq=ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32)),
        pad_idx=0,
        unk_idx=1,
        bos_idx=2,
        eos_idx=3,
    )
    result_ptr = ggml.generate_sequence(g_model, Ptr(job), encoder_out, NULLPTR, ctx)
    results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq is not None]
    tokens = [
        translator.text_tokenizer.model.index_to_token(token_id)
        for token_id in ggml.to_numpy(results[0].seq).tolist()
    ][2:-1]
    tokens = "".join(tokens).replace("▁", " ")[1:]
    assert tokens == exp_tgt_text
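

# Usage note (an assumption, not part of the original file): the suite can be
# run with pytest from the directory containing this file, e.g.
#   pytest test_unity_cpp.py -k "causal or LayerNorm" -x
# The seamlessM4T_medium checkpoint is downloaded and converted to ggml format
# lazily on first use (see _load_g_model_once), so the first run is slow.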