# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import ctypes
import functools
import shutil
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator, Tuple

import fairseq2.nn
import fairseq2.nn.transformer
import numpy as np
import pytest
import requests  # type: ignore
import torch
import torchaudio  # type: ignore
from ctypes_utils import NULLPTR, Ptr
from fairseq2.data.audio import WaveformToFbankConverter
from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
from ggml_convert import convert_model, read_layer_config

import ggml
from ggml import NativeObj
from seamless_communication.inference.generator import SequenceGeneratorOptions
from seamless_communication.inference.translator import Modality, Translator

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()

DATA = Path(__file__).parent / "test_data"
LOCAL_AUDIO_SAMPLE_PATH = DATA / "LJ037-0171_sr16k.wav"
TEST_AUDIO_SAMPLE_URL = (
    "https://dl.fbaipublicfiles.com/seamless/tests/LJ037-0171_sr16k.wav"
)

MB = 1024 * 1024
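
# Each test allocates a fresh ggml context backed by a pre-allocated torch
# buffer; with no_alloc=True the context itself only stores tensor metadata.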
@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 16 MB of memory"""
    ctx = None
    try:
        mem_size = 16 * MB
        memory = torch.zeros(mem_size, dtype=torch.uint8)
        ctx = ggml.ggml_init(
            params=ggml.ggml_init_params(
                mem_size=mem_size,
                mem_buffer=ctypes.c_void_p(memory.data_ptr()),
                no_alloc=True,
            )
        )
        # Create a 'dot' folder for temporary dumps of ggml graphs
        (Path(__file__).parent / "dot").mkdir(exist_ok=True)
        with torch.inference_mode():
            yield ctx
    finally:
        # Guard against ggml_init failing before ctx was assigned
        if ctx is not None:
            ggml.ggml_free(ctx)


@functools.lru_cache()
def _load_g_model_once() -> NativeObj:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    return ggml.load_fairseq2_ggml_file(model_file)
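
# The converted model is loaded once per test session (lru_cache); each test
# then attaches its own short-lived inference context before using it.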
@pytest.fixture()
def g_model(ctx: Ctx) -> c_void_p:
    model = _load_g_model_once()
    ggml.lib.fairseq2_model_set_inference_ctx(model.ptr, ctx)
    return model.ptr


@functools.lru_cache(maxsize=1)
def load_translator() -> Translator:
    return Translator("seamlessM4T_medium", None, device=torch.device("cpu"))


def load_pt_model() -> Any:
    return load_translator().model


def download_sample_audio() -> None:
    DATA.mkdir(exist_ok=True)
    response = requests.get(TEST_AUDIO_SAMPLE_URL, stream=True)
    # Fail loudly instead of writing an HTTP error page to disk
    response.raise_for_status()
    with open(LOCAL_AUDIO_SAMPLE_PATH, "wb") as file:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                file.write(chunk)
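
# Round-trip a single Linear module through the ggml converter and check that
# the recorded layer config survives the file format.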
def test_convert_linear(tmp_path: Path) -> None:
    module = fairseq2.nn.Linear(16, 24, True)

    layer_config = read_layer_config(module)
    assert layer_config == {"input_dim": 16, "output_dim": 24}

    # Write into the pytest-provided tmp dir rather than the current directory
    module_file = tmp_path / "module.ggml"
    convert_model(module, module_file)
    g_module = ggml.load_fairseq2_ggml_file(module_file)

    for k, v in layer_config.items():
        assert (
            ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
        )


def test_causal_attention_mask(ctx: Ctx):
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskFactory()
    mask_exp = generator(x, x).materialize().numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    ggml.build_and_compute(ctx, gmask)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    x = x[:, :8, :]
    mask_exp = generator(x, x).materialize().numpy()
    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    ggml.build_and_compute(ctx, gmask)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)


def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_Linear_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
    ggml.build_and_compute(ctx, gy, dump="dot/test_Linear_forward.dot")

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))  # (bs, seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


@pytest.mark.parametrize("lengths", [(11, 21), (21, 13)])
def test_MultiheadAttention_forward(
    ctx: Ctx, g_model: c_void_p, lengths: Tuple[int, int]
) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    # Note: we use different lengths for queries and keys;
    # this also exercises the implementation in a decoding context.
    # Note2: ggml_flash_attn requires that we have more keys than queries
    # qlen, klen = (11, 21) if flash_attn else (21, 13)
    qlen, klen = lengths
    xq = x[:, :qlen]
    xk = x[:, :klen]
    if qlen > klen and UNITY_FLASH_ATTN:
        pytest.skip(reason="ggml_flash_attn requires klen >= qlen")

    gxq = ggml.from_numpy(ctx, xq.contiguous())
    ggml.ggml_set_name(gxq, b"xq")
    gxk = ggml.from_numpy(ctx, xk.contiguous())
    ggml.ggml_set_name(gxk, b"xk")
    ggml.ggml_set_no_alloc(ctx, True)

    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gxk,
        gxk,
        NULLPTR,  # TODO: tests with causal attention masks
    )
    gf = ggml.build_and_compute(ctx, gy, dump="dot/test_MultiheadAttention_forward")
    y = ggml.to_numpy(gy)

    nodes = ggml.nodes(gf)
    node_buffers = set(t.contents.data for t in nodes.values())

    pt_model = load_pt_model()
    self_attn = pt_model.text_encoder.layers[0].self_attn

    # If buffers overlap, reading the node contents can be misleading.
    overlap = len(node_buffers) < len(nodes)
    if not overlap:
        q_exp = self_attn._project_q(xq, None).numpy().reshape(2 * 16, qlen, 64)
        q = ggml.to_numpy(nodes[b"q"])
        assert q.shape == q_exp.shape
        assert np.allclose(q_exp, q, atol=1e-5)

    attn_weights_hook = fairseq2.nn.transformer.AttentionWeightStoreHook([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(xq, None, xk, None, xk).numpy()

    # With flash_attn we don't have attn_weights.
    naive_attn = b"attn_weights" in nodes
    if naive_attn and not overlap:
        attn_weights = ggml.to_numpy(nodes[b"attn_weights"]).reshape(
            -1, 16, qlen, klen
        )
        [(_, attn_weights_exp)] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively rounds small softmax weights to 0,
        # so the error isn't that small.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
        # But the sums should be close to 1.
        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2, 16, qlen)))
        # And the argmax of each row should match the reference.
        assert np.allclose(
            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
        )

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-2 if naive_attn else 1e-4)
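
# Incremental decoding: feed one query step at a time and check the ggml
# kv-cache against fairseq2's IncrementalStateBag at each step.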
def test_MultiheadAttention_forward_self_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].self_attn

    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    state_bag = fairseq2.nn.IncrementalStateBag(100)
    with ggml.fairseq2_kv_cache_alloc(g_model, 16 * MB, 2, 21):
        # Incremental decoding
        for t in range(3):
            xq = x[:, t : t + 1]

            gxq = ggml.from_numpy(ctx, xq.contiguous())
            ggml.ggml_set_name(gxq, b"xq")
            gy = ggml.forward(
                "MultiheadAttention",
                g_model,
                "text_decoder.layers.0.self_attn",
                gxq,
                gxq,
                gxq,
                None,  # type: ignore
            )
            gf = ggml.build_and_compute(
                ctx,
                gy,
                dump=f"dot/test_MultiheadAttention_forward_self_attn_with_cache_{t}.dot",
            )
            nodes = ggml.nodes(gf)

            gk_cache = ggml.to_numpy(
                nodes[b"text_decoder.layers.0.self_attn.k (step=%d)" % t]
            )
            assert gk_cache.shape == (2, t + 1, 1024)
            gk_cache = gk_cache.reshape(2, t + 1, 16, 64).transpose(0, 2, 1, 3)
            assert gk_cache.shape == (2, 16, t + 1, 64)

            y_exp = attn(xq, None, xq, None, xq, state_bag=state_bag).numpy()
            assert y_exp.shape == (2, 1, 1024)

            state = state_bag.get_state(attn, fairseq2.nn.transformer.AttentionState)
            state_bag.increment_step_nr()
            assert state is not None
            k_cache = state.get()[0].numpy()
            assert k_cache.shape == (2, 16, t + 1, 64)
            assert np.allclose(gk_cache, k_cache, atol=1e-3)

            y = ggml.to_numpy(gy)
            assert np.allclose(y, y_exp, atol=1e-2)


def test_MultiheadAttention_forward_cross_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].encoder_decoder_attn

    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    state_bag = fairseq2.nn.IncrementalStateBag(100)
    with ggml.fairseq2_kv_cache_alloc(g_model, 16 * MB, 2, 21):
        # Incremental decoding: the keys come from the encoder and don't
        # change during decoding.
        xk = x[:, :11]
        gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")

        for t in range(3):
            xq = x[:, t : t + 1]
            gxq = ggml.from_numpy(ctx, xq.contiguous())
            ggml.ggml_set_name(gxq, b"xq")
            gy = ggml.forward(
                "MultiheadAttention",
                g_model,
                "text_decoder.layers.0.encoder_decoder_attn",
                gxq,
                gxk,
                gxk,
                None,  # type: ignore
            )
            gf = ggml.build_and_compute(
                ctx,
                gy,
                dump=f"dot/test_MultiheadAttention_forward_cross_attn_with_cache_{t}.dot",
            )
            y = ggml.to_numpy(gy)
            nodes = ggml.nodes(gf)
            leaves = ggml.leafs(gf)

            if t > 0:
                # The k_cache only appears in the graph from the second call onward.
                state = state_bag.get_state(
                    attn, fairseq2.nn.transformer.AttentionState
                )
                assert state is not None
                assert np.allclose(
                    state.get()[0].transpose(1, 2).numpy(),
                    ggml.to_numpy(
                        nodes[
                            b"text_decoder.layers.0.encoder_decoder_attn.k_cache (view)"
                        ]
                    ),
                    atol=1e-3,
                )

            state_bag.increment_step_nr()
            y_exp = attn(xq, None, xk, None, xk, state_bag=state_bag).numpy()
            assert y_exp.shape == (2, 1, 1024)
            assert np.allclose(y, y_exp, atol=1e-2)


def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    pt_model = load_pt_model()
    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask=None)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_StandardConformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    x = torch.rand(1, 137, 1024)
    layer = pt_model.speech_encoder.inner.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoderLayer",
        g_model,
        "speech_encoder.inner.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask=None)
    y_exp = y_exp.squeeze(0).numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)


def test_StandardConformerEncoderAdaptorLayer_forward(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    torch.random.manual_seed(0)
    x = torch.rand(1, 137, 1024)
    layer = pt_model.speech_encoder.adaptor_layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoderAdaptorLayer",
        g_model,
        "speech_encoder.adaptor_layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, None)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)


def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([21, 21]), 21)
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=5e-3)
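
# The full speech encoder runs on a real audio sample; the PyTorch reference
# output is cached to disk so repeated runs stay fast.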
def test_StandardConformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    if not LOCAL_AUDIO_SAMPLE_PATH.exists():
        download_sample_audio()
    wav, _ = torchaudio.load(LOCAL_AUDIO_SAMPLE_PATH)

    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    cache = DATA / "test_StandardConformerEncoder_forward.npy"
    if not cache.exists():
        converter = WaveformToFbankConverter(
            num_mel_bins=80,
            waveform_scale=2**15,
            channel_last=True,
            standardize=True,
        )
        converter_input = {
            "waveform": wav.transpose(0, 1),
            "sample_rate": 16000.0,
            "format": -1,
        }
        pt_model = load_pt_model()
        speech_encoder_input = pt_model.speech_encoder_frontend(
            converter(converter_input)["fbank"].unsqueeze(0), None
        )[0]
        y_exp, _ = pt_model.speech_encoder(speech_encoder_input, None)
        y_exp = y_exp.numpy()
        np.save(cache, y_exp)
    else:
        y_exp = np.load(cache)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-2)


def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
    converter = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=True,
    )
    extractor = Wav2Vec2FbankFeatureExtractor(80, stride=2, sample_every_k=1)
    if not LOCAL_AUDIO_SAMPLE_PATH.exists():
        download_sample_audio()
    wav, _ = torchaudio.load(LOCAL_AUDIO_SAMPLE_PATH)

    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward("WaveformToFbank", g_model, "", gx)
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    converter_input = {
        "waveform": wav.transpose(0, 1),
        "sample_rate": 16000.0,
        "format": -1,
    }
    y_exp, _ = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)
    y_exp = y_exp.squeeze(0).numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error is from standardization


def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=1)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].clone().numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.build_and_compute(ctx, gy, dump=True)
    y = ggml.to_numpy(gy)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=1)
    pos_encoder.eval()
    state_bag = fairseq2.nn.IncrementalStateBag(100)

    with ggml.fairseq2_kv_cache_alloc(g_model, 16 * MB, 2, 21):
        # Incremental decoding
        for t in range(20):
            gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
            ggml.ggml_set_name(gseq, b"seq")
            gy = ggml.forward(
                "PositionalEmbedding",
                g_model,
                "text_decoder_frontend.pos_encoder",
                gseq,
            )
            ggml.build_and_compute(ctx, gy, dump=t == 1)
            y = ggml.to_numpy(gy)

            y_exp = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag).numpy()
            state_bag.increment_step_nr()

            assert y.shape == y_exp.shape
            assert np.allclose(y_exp, y, atol=1e-6)


def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.arange(2 * 20).reshape(2, 20)
    seq[1, 15:] = 0  # padding for the second sentence
    seq_len = torch.tensor([20, 15])

    gseq = ggml.from_numpy(ctx, seq.numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_StandardTransformerDecoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 13, 1024))
    encoder_out = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    self_attn_mask = fairseq2.nn.transformer.CausalAttentionMaskFactory()(x, x)
    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gself_attn_mask = ggml.from_numpy(ctx, self_attn_mask.materialize().numpy())
    ggml.ggml_set_name(gself_attn_mask, b"self_attn_mask")
    genc = ggml.from_numpy(ctx, encoder_out)
    ggml.ggml_set_name(genc, b"encoder_out")
    gy = ggml.forward(
        "StandardTransformerDecoderLayer",
        g_model,
        "text_decoder.layers.0",
        gx,
        gself_attn_mask,
        genc,
        NULLPTR,  # TODO support padding mask
    )
    ggml.build_and_compute(ctx, gy, dump=True)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder.layers[0](
        x, None, encoder_output=encoder_out, self_attn_mask=self_attn_mask
    )
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    # We still have some numerical imprecision
    assert np.allclose(y_exp, y, atol=0.1)
    assert np.sum(np.abs(y_exp - y) > 1e-2) < 20


def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 13, 1024))
    encoder_out = torch.empty((2, 21, 1024))
    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([13, 13]), 13)
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out)
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO support padding mask
        genc,
        None,
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.numpy()

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-3)  # TODO: those tests are failing now
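
# End-to-end speech-to-text: encode real audio with the ggml speech encoder,
# beam-search decode, and compare the detokenized text with the reference
# translation produced by the PyTorch Translator.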
def test_s2tt(ctx: Ctx, g_model: c_void_p):
    if not LOCAL_AUDIO_SAMPLE_PATH.exists():
        download_sample_audio()
    src_audio_wav, _ = torchaudio.load(LOCAL_AUDIO_SAMPLE_PATH)

    sample_file = DATA / "LJ037-0171_sr16k.wav.trans"
    translator = load_translator()
    if not sample_file.exists():
        decoded_audio = {
            "waveform": src_audio_wav.t(),
            "sample_rate": 16000.0,
            "format": -1,
        }
        src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]

        text_out, _ = translator.get_prediction(
            translator.model,
            translator.text_tokenizer,
            translator.unit_tokenizer,
            src["seqs"],
            padding_mask=None,
            input_modality=Modality.SPEECH,
            output_modality=Modality.TEXT,
            tgt_lang="cmn",
            text_generation_opts=SequenceGeneratorOptions(),
            unit_generation_opts=None,
        )

        tgt_text = str(text_out[0])
        assert tgt_text == "专家的检查和证据使该委员会得出了结论,可能有五次枪击."
        with open(sample_file, "w") as f:
            f.write(tgt_text)

    with open(sample_file, "r") as exp:
        exp_tgt_text = exp.readlines()[0].strip()

    # Apply scale before sending into ggml!
    gx = ggml.from_numpy(ctx, src_audio_wav * 2**15)
    ggml.ggml_set_name(gx, b"x")
    encoder_out = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        NULLPTR,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, encoder_out)

    beam_size = 5
    opts = ggml.SequenceGeneratorOptions(
        beam_size=beam_size,
        soft_max_seq_len_a=1,
        soft_max_seq_len_b=200,
        hard_max_seq_len=500,
    )
    job = ggml.SequenceGeneratorJob(
        opts=opts,
        prefix_seq=ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32)),
        pad_idx=0,
        unk_idx=1,
        bos_idx=2,
        eos_idx=3,
    )
    result_ptr = ggml.generate_sequence(g_model, Ptr(job), encoder_out, NULLPTR, ctx)
    results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]

    # Strip the 2-token prefix and the final EOS, then detokenize.
    tokens = [
        translator.text_tokenizer.model.index_to_token(id)
        for id in ggml.to_numpy(results[0].seq).tolist()
    ][2:-1]
    tokens = "".join(tokens).replace("▁", " ")[1:]
    assert tokens == exp_tgt_text