# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
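"""Unit tests for the ggml/C++ port of SeamlessM4T (fairseq2) layers.

Each test runs a fairseq2 module in PyTorch and the corresponding ggml
implementation, then compares the two outputs within a numeric tolerance.
"""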
import ctypes
import functools
import shutil
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator, Tuple

import fairseq2.nn
import fairseq2.nn.transformer
import numpy as np
import pytest
import requests  # type: ignore
import torch
import torchaudio  # type: ignore
from ctypes_utils import NULLPTR, Ptr
from fairseq2.data.audio import WaveformToFbankConverter
from fairseq2.models.wav2vec2.feature_extractor import Wav2Vec2FbankFeatureExtractor
from ggml_convert import convert_model, read_layer_config

import ggml
from ggml import NativeObj
from seamless_communication.inference.generator import SequenceGeneratorOptions
from seamless_communication.inference.translator import Modality, Translator

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()

DATA = Path(__file__).parent / "test_data"
LOCAL_AUDIO_SAMPLE_PATH = DATA / "LJ037-0171_sr16k.wav"
TEST_AUDIO_SAMPLE_URL = (
    "https://dl.fbaipublicfiles.com/seamless/tests/LJ037-0171_sr16k.wav"
)

MB = 1024 * 1024


@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new ggml context backed by 16 MB of memory."""
    try:
        mem_size = 16 * MB
        memory = torch.zeros(mem_size, dtype=torch.uint8)
        ctx = ggml.ggml_init(
            params=ggml.ggml_init_params(
                mem_size=mem_size,
                mem_buffer=ctypes.c_void_p(memory.data_ptr()),
                no_alloc=True,
            )
        )
        # Create a 'dot' folder for temporary dumps of ggml graphs.
        (Path(__file__).parent / "dot").mkdir(exist_ok=True)
        with torch.inference_mode():
            yield ctx
    finally:
        ggml.ggml_free(ctx)


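# Sanity-check sketch: assuming ggml.from_numpy/ggml.to_numpy wrap the same
# underlying buffer (as the forward tests below rely on), a tensor handed to
# ggml should read back unchanged, with no compute step in between.
def test_numpy_round_trip(ctx: Ctx) -> None:
    x = torch.rand(3, 7)
    gx = ggml.from_numpy(ctx, x)
    assert np.allclose(x.numpy(), ggml.to_numpy(gx))

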
@functools.lru_cache()
def _load_g_model_once() -> NativeObj:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    return ggml.load_fairseq2_ggml_file(model_file)


@pytest.fixture()
def g_model(ctx: Ctx) -> c_void_p:
    model = _load_g_model_once()
    ggml.lib.fairseq2_model_set_inference_ctx(model.ptr, ctx)
    return model.ptr


@functools.lru_cache(maxsize=1)
def load_translator() -> Translator:
    return Translator("seamlessM4T_medium", None, device=torch.device("cpu"))


def load_pt_model() -> Any:
    return load_translator().model


def download_sample_audio() -> Any:
    Path(DATA).mkdir(exist_ok=True)
    response = requests.get(TEST_AUDIO_SAMPLE_URL, stream=True)
    response.raise_for_status()
    with open(LOCAL_AUDIO_SAMPLE_PATH, "wb") as file:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                file.write(chunk)


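# Conversion round-trip tests: convert a small fairseq2 module to a .ggml
# file with ggml_convert.convert_model, reload it, and check that the layer
# hyperparameters (and, for fp16, the file size) survive the trip.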
def test_convert_linear(tmp_path: Path) -> None:
    module = fairseq2.nn.Linear(16, 24, True)
    layer_config = read_layer_config(module, "")
    assert layer_config == {"input_dim": 16, "output_dim": 24}

    module_file = tmp_path / "module.ggml"
    convert_model(module, module_file)
    g_module = ggml.load_fairseq2_ggml_file(module_file)
    for k, v in layer_config.items():
        assert (
            ggml.fairseq2_model_layer_config_int(g_module.ptr, bytes(k, "ascii")) == v
        )


def test_convert_linear_fp16(tmp_path: Path, ctx: Ctx) -> None:
    pt_model = torch.nn.ModuleDict({"linear": fairseq2.nn.Linear(16, 24, True)})
    layer_config = read_layer_config(pt_model, "")
    assert layer_config == {"linear.input_dim": 16, "linear.output_dim": 24}

    ggml_file = tmp_path / "linear.ggml"
    convert_model(pt_model, ggml_file, fp16=True)
    # 16 * 24 + 24 fp16 parameters (2 bytes each), plus up to 50% metadata overhead.
    assert ggml_file.stat().st_size < (16 * 24 + 24) * 2 * 1.5

    g_model = ggml.load_fairseq2_ggml_file(ggml_file)
    ggml.lib.fairseq2_model_set_inference_ctx(g_model.ptr, ctx)

    x = torch.empty((2, 5, 16))
    torch.nn.init.uniform_(x, -1, 1)
    y_exp = pt_model.linear(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("Linear", g_model.ptr, "linear", gx)
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-3)


def test_causal_attention_mask(ctx: Ctx):
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskFactory()
    mask_exp = generator(x, x).materialize().numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    ggml.build_and_compute(ctx, gmask)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    x = x[:, :8, :]
    mask_exp = generator(x, x).materialize().numpy()
    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    ggml.build_and_compute(ctx, gmask)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)


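# The forward tests below all follow the same pattern: initialize an input
# tensor, run one layer of the PyTorch model, run the ggml implementation of
# the same layer (addressed by its fully qualified fairseq2 name), and
# compare the two outputs with np.allclose.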
def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_Linear_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
    ggml.build_and_compute(ctx, gy, dump="dot/test_Linear_forward.dot")
    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))  # (bs, seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    pt_model = load_pt_model()
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


@pytest.mark.parametrize("lengths", [(11, 21), (21, 13)])
def test_MultiheadAttention_forward(
    ctx: Ctx, g_model: c_void_p, lengths: Tuple[int, int]
) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    # Note: we use different lengths for queries and keys;
    # this also tests the implementation in a decoding context.
    # Note2: ggml_flash_attn requires at least as many keys as queries.
    qlen, klen = lengths
    xq = x[:, :qlen]
    xk = x[:, :klen]
    if qlen > klen and UNITY_FLASH_ATTN:
        pytest.skip(reason="ggml_flash_attn requires qlen <= klen")

    gxq = ggml.from_numpy(ctx, xq.contiguous())
    ggml.ggml_set_name(gxq, b"xq")
    gxk = ggml.from_numpy(ctx, xk.contiguous())
    ggml.ggml_set_name(gxk, b"xk")
    ggml.ggml_set_no_alloc(ctx, True)
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gxk,
        gxk,
        NULLPTR,  # TODO: tests with causal attention masks
    )
    gf = ggml.build_and_compute(ctx, gy, dump="dot/test_MultiheadAttention_forward")
    y = ggml.to_numpy(gy)
    nodes = ggml.nodes(gf)
    node_buffers = set(t.contents.data for t in nodes.values())

    pt_model = load_pt_model()
    self_attn = pt_model.text_encoder.layers[0].self_attn
    # If buffers overlap, reading node contents can be misleading.
    overlap = len(node_buffers) < len(nodes)
    if not overlap:
        q_exp = self_attn._project_q(xq, None).numpy().reshape(2 * 16, qlen, 64)
        q = ggml.to_numpy(nodes[b"q"])
        assert q.shape == q_exp.shape
        assert np.allclose(q_exp, q, atol=1e-5)

    attn_weights_hook = fairseq2.nn.transformer.AttentionWeightStoreHook([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(xq, None, xk, None, xk).numpy()

    # With flash_attn we don't have attn_weights.
    naive_attn = b"attn_weights" in nodes
    if naive_attn and not overlap:
        attn_weights = ggml.to_numpy(nodes[b"attn_weights"]).reshape(
            -1, 16, qlen, klen
        )
        [(_, attn_weights_exp)] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively reduces small softmax weights to 0,
        # so the absolute error isn't that small.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
        # But the sums should be close to 1,
        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2, 16, qlen)))
        # and the argmax indices should match the reference.
        assert np.allclose(
            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
        )

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-2 if naive_attn else 1e-4)


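# Incremental decoding tests: the ggml KV cache is allocated explicitly via
# fairseq2_kv_cache_alloc, while the fairseq2 reference tracks its cache in
# an IncrementalStateBag. Both caches should stay in sync step by step.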
def test_MultiheadAttention_forward_self_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].self_attn
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    state_bag = fairseq2.nn.IncrementalStateBag(100)

    with ggml.fairseq2_kv_cache_alloc(g_model, 16 * MB, 2, 21):
        # Incremental decoding
        for t in range(3):
            xq = x[:, t : t + 1]
            gxq = ggml.from_numpy(ctx, xq.contiguous())
            ggml.ggml_set_name(gxq, b"xq")
            gy = ggml.forward(
                "MultiheadAttention",
                g_model,
                "text_decoder.layers.0.self_attn",
                gxq,
                gxq,
                gxq,
                None,  # type: ignore
            )
            gf = ggml.build_and_compute(
                ctx,
                gy,
                dump=f"dot/test_MultiheadAttention_forward_self_attn_with_cache_{t}.dot",
            )
            nodes = ggml.nodes(gf)
            gk_cache = ggml.to_numpy(
                nodes[b"text_decoder.layers.0.self_attn.k (step=%d)" % t]
            )
            assert gk_cache.shape == (2, t + 1, 1024)
            gk_cache = gk_cache.reshape(2, t + 1, 16, 64).transpose(0, 2, 1, 3)
            assert gk_cache.shape == (2, 16, t + 1, 64)

            y_exp = attn(xq, None, xq, None, xq, state_bag=state_bag).numpy()
            assert y_exp.shape == (2, 1, 1024)
            state = state_bag.get_state(attn, fairseq2.nn.transformer.AttentionState)
            state_bag.increment_step_nr()
            assert state is not None
            k_cache = state.get()[0].numpy()
            assert k_cache.shape == (2, 16, t + 1, 64)
            assert np.allclose(gk_cache, k_cache, atol=1e-3)

            y = ggml.to_numpy(gy)
            assert np.allclose(y, y_exp, atol=1e-2)


def test_MultiheadAttention_forward_cross_attn_with_cache(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    attn = pt_model.text_decoder.layers[0].encoder_decoder_attn
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    state_bag = fairseq2.nn.IncrementalStateBag(100)

    with ggml.fairseq2_kv_cache_alloc(g_model, 16 * MB, 2, 21):
        # Incremental decoding: the keys come from the encoder
        # and don't change during decoding.
        xk = x[:, :11]
        gxk = ggml.from_numpy(ctx, xk.contiguous(), name=b"xk")
        for t in range(3):
            xq = x[:, t : t + 1]
            gxq = ggml.from_numpy(ctx, xq.contiguous())
            ggml.ggml_set_name(gxq, b"xq")
            gy = ggml.forward(
                "MultiheadAttention",
                g_model,
                "text_decoder.layers.0.encoder_decoder_attn",
                gxq,
                gxk,
                gxk,
                None,  # type: ignore
            )
            gf = ggml.build_and_compute(
                ctx,
                gy,
                dump=f"dot/test_MultiheadAttention_forward_cross_attn_with_cache_{t}.dot",
            )
            y = ggml.to_numpy(gy)
            nodes = ggml.nodes(gf)
            leaves = ggml.leafs(gf)

            if t > 0:
                # The cache only appears in the graph from the second call on.
                state = state_bag.get_state(
                    attn, fairseq2.nn.transformer.AttentionState
                )
                assert state is not None
                assert np.allclose(
                    state.get()[0].transpose(1, 2).numpy(),
                    ggml.to_numpy(
                        nodes[
                            b"text_decoder.layers.0.encoder_decoder_attn.k_cache (view)"
                        ]
                    ),
                    atol=1e-3,
                )
            state_bag.increment_step_nr()

            y_exp = attn(xq, None, xk, None, xk, state_bag=state_bag).numpy()
            assert y_exp.shape == (2, 1, 1024)
            assert np.allclose(y, y_exp, atol=1e-2)


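# Whole-layer and whole-stack tests. Tolerances get looser as more sub-layers
# are chained, since small numerical differences accumulate.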
def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    pt_model = load_pt_model()
    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask=None)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_StandardConformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    x = torch.rand(1, 137, 1024)
    layer = pt_model.speech_encoder.inner.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoderLayer",
        g_model,
        "speech_encoder.inner.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask=None)
    y_exp = y_exp.squeeze(0).numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)


def test_StandardConformerEncoderAdaptorLayer_forward(
    ctx: Ctx, g_model: c_void_p
) -> None:
    pt_model = load_pt_model()
    torch.random.manual_seed(0)
    x = torch.rand(1, 137, 1024)
    layer = pt_model.speech_encoder.adaptor_layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoderAdaptorLayer",
        g_model,
        "speech_encoder.adaptor_layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, None)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=2e-3)


def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 21, 1024))
    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([21, 21]), 21)
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=5e-3)


def test_StandardConformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    pt_model = load_pt_model()
    if not LOCAL_AUDIO_SAMPLE_PATH.exists():
        download_sample_audio()
    wav, _ = torchaudio.load(LOCAL_AUDIO_SAMPLE_PATH)
    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    cache = DATA / "test_StandardConformerEncoder_forward.npy"
    if not cache.exists():
        converter = WaveformToFbankConverter(
            num_mel_bins=80,
            waveform_scale=2**15,
            channel_last=True,
            standardize=True,
        )
        converter_input = {
            "waveform": wav.transpose(0, 1),
            "sample_rate": 16000.0,
            "format": -1,
        }
        pt_model = load_pt_model()
        speech_encoder_input = pt_model.speech_encoder_frontend(
            converter(converter_input)["fbank"].unsqueeze(0), None
        )[0]
        y_exp, _ = pt_model.speech_encoder(speech_encoder_input, None)
        y_exp = y_exp.numpy()
        np.save(cache, y_exp)
    else:
        y_exp = np.load(cache)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-2)


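# Audio front-end tests. Note the 2**15 scaling: the fbank converter is
# configured with waveform_scale=2**15 (int16 range), so the raw waveform is
# scaled accordingly before it is handed to ggml.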
def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
    converter = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=True,
    )
    extractor = Wav2Vec2FbankFeatureExtractor(80, stride=2, sample_every_k=1)
    if not LOCAL_AUDIO_SAMPLE_PATH.exists():
        download_sample_audio()
    wav, _ = torchaudio.load(LOCAL_AUDIO_SAMPLE_PATH)
    gx = ggml.from_numpy(ctx, wav * 2**15)  # Apply scale before sending into ggml!
    ggml.ggml_set_name(gx, b"x")

    gy = ggml.forward("WaveformToFbank", g_model, "", gx)
    gf = ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    converter_input = {
        "waveform": wav.transpose(0, 1),
        "sample_rate": 16000.0,
        "format": -1,
    }
    y_exp, _ = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)
    y_exp = y_exp.squeeze(0).numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error comes from standardization


def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=1)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].clone().numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.build_and_compute(ctx, gy, dump=True)
    y = ggml.to_numpy(gy)
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=1)
    pos_encoder.eval()
    state_bag = fairseq2.nn.IncrementalStateBag(100)

    with ggml.fairseq2_kv_cache_alloc(g_model, 16 * MB, 2, 21):
        # Incremental decoding
        for t in range(20):
            gseq = ggml.from_numpy(ctx, seq[:, t : t + 1, :].numpy())
            ggml.ggml_set_name(gseq, b"seq")
            gy = ggml.forward(
                "PositionalEmbedding",
                g_model,
                "text_decoder_frontend.pos_encoder",
                gseq,
            )
            ggml.build_and_compute(ctx, gy, dump=t == 1)
            y = ggml.to_numpy(gy)

            y_exp = pos_encoder(seq[:, t : t + 1, :], None, state_bag=state_bag).numpy()
            state_bag.increment_step_nr()
            assert y.shape == y_exp.shape
            assert np.allclose(y_exp, y, atol=1e-6)


def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.arange(2 * 20).reshape(2, 20)
    seq[1, 15:] = 0  # padding for the second sentence
    seq_len = torch.tensor([20, 15])
    gseq = ggml.from_numpy(ctx, seq.numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_StandardTransformerDecoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 13, 1024))
    encoder_out = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)
    self_attn_mask = fairseq2.nn.transformer.CausalAttentionMaskFactory()(x, x)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gself_attn_mask = ggml.from_numpy(ctx, self_attn_mask.materialize().numpy())
    ggml.ggml_set_name(gself_attn_mask, b"self_attn_mask")
    genc = ggml.from_numpy(ctx, encoder_out)
    ggml.ggml_set_name(genc, b"encoder_out")
    gy = ggml.forward(
        "StandardTransformerDecoderLayer",
        g_model,
        "text_decoder.layers.0",
        gx,
        gself_attn_mask,
        genc,
        NULLPTR,  # TODO support padding mask
    )
    ggml.build_and_compute(ctx, gy, dump=True)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder.layers[0](
        x, None, encoder_output=encoder_out, self_attn_mask=self_attn_mask
    )
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    # We still have some numerical imprecision.
    assert np.allclose(y_exp, y, atol=0.1)
    assert np.sum(np.abs(y_exp - y) > 1e-2) < 20


def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
    x = torch.empty((2, 13, 1024))
    encoder_out = torch.empty((2, 21, 1024))
    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([13, 13]), 13)
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out)
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO support padding mask
        genc,
        None,
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    pt_model = load_pt_model()
    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-3)  # TODO: those tests are failing now


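# End-to-end speech-to-text translation (s2tt): encode the audio with the
# ggml speech encoder, beam-search a translation with ggml.generate_sequence,
# and compare against a reference transcript produced by the PyTorch
# Translator (cached on disk after the first run).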
def test_s2tt(ctx: Ctx, g_model: c_void_p):
    if not LOCAL_AUDIO_SAMPLE_PATH.exists():
        download_sample_audio()
    src_audio_wav, _ = torchaudio.load(LOCAL_AUDIO_SAMPLE_PATH)
    sample_file = DATA / "LJ037-0171_sr16k.wav.trans"
    translator = load_translator()
    if not sample_file.exists():
        decoded_audio = {
            "waveform": src_audio_wav.t(),
            "sample_rate": 16000.0,
            "format": -1,
        }
        src = translator.collate(translator.convert_to_fbank(decoded_audio))["fbank"]
        text_out, _ = translator.get_prediction(
            translator.model,
            translator.text_tokenizer,
            translator.unit_tokenizer,
            src["seqs"],
            padding_mask=None,
            input_modality=Modality.SPEECH,
            output_modality=Modality.TEXT,
            tgt_lang="cmn",
            text_generation_opts=SequenceGeneratorOptions(),
            unit_generation_opts=None,
        )
        tgt_text = str(text_out[0])
        assert tgt_text == "专家的检查和证据使该委员会得出了结论,可能有五次枪击."
        with open(sample_file, "w") as f:
            f.write(tgt_text)

    with open(sample_file, "r") as exp:
        exp_tgt_text = exp.readlines()[0].strip()

    # Apply scale before sending into ggml!
    gx = ggml.from_numpy(ctx, src_audio_wav * 2**15)
    ggml.ggml_set_name(gx, b"x")
    encoder_out = ggml.forward(
        "StandardConformerEncoder",
        g_model,
        "speech_encoder",
        gx,
        NULLPTR,  # TODO support padding mask
    )
    gf = ggml.build_and_compute(ctx, encoder_out)

    beam_size = 5
    opts = ggml.SequenceGeneratorOptions(
        beam_size=beam_size,
        soft_max_seq_len_a=1,
        soft_max_seq_len_b=200,
        hard_max_seq_len=500,
    )
    job = ggml.SequenceGeneratorJob(
        opts=opts,
        prefix_seq=ggml.from_numpy(ctx, np.array([3, 256200]).astype(np.int32)),
        pad_idx=0,
        unk_idx=1,
        bos_idx=2,
        eos_idx=3,
    )
    result_ptr = ggml.generate_sequence(g_model, Ptr(job), encoder_out, NULLPTR, ctx)
    results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]
    # Strip the two prefix tokens and the trailing EOS, then detokenize.
    tokens = [
        translator.text_tokenizer.model.index_to_token(id)
        for id in ggml.to_numpy(results[0].seq).tolist()
    ][2:-1]
    tokens = "".join(tokens).replace("▁", " ")[1:]
    assert tokens == exp_tgt_text
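
# Usage sketch: with the ggml Python bindings built and importable, the suite
# can be run with `pytest test_unity_cpp.py`, or a single end-to-end check
# with e.g. `pytest test_unity_cpp.py -k s2tt`.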