test_unity_cpp.py
import ctypes
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator

import numpy as np
import pytest
import torch

import fairseq2.nn
import fairseq2.nn.transformer
import ggml
from ggml import NativeObj
from ggml_convert import convert_model
from seamless_communication.models.unity import load_unity_model

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
PARAMS_256MB = ggml.ggml_init_params(mem_size=256 * 1024 * 1024, mem_buffer=None)


@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 256 MB of memory"""
    try:
        ctx = ggml.ggml_init(params=PARAMS_256MB)
        yield ctx
    finally:
        ggml.ggml_free(ctx)


def test_ggml_bindings_work(ctx: Ctx) -> None:
    # Instantiate tensors
    x = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)
    b = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)

    # Use ggml operations to build a computational graph
    x2 = ggml.ggml_mul(ctx, x, x)
    f = ggml.ggml_add(ctx, ggml.ggml_mul(ctx, a, x2), b)
    gf = ggml.ggml_build_forward(f)

    # Set the input values
    ggml.ggml_set_f32(x, 2.0)
    ggml.ggml_set_f32(a, 3.0)
    ggml.ggml_set_f32(b, 4.0)

    # Compute the graph
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    # Get the output value: f = a*x^2 + b = 3*4 + 4 = 16
    output = ggml.ggml_get_f32_1d(f, 0)
    assert output == 16.0


def test_ggml_matmul(ctx: Ctx) -> None:
    # Instantiate tensors
    a = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 4, 2)
    x = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 4, 3)

    # Use ggml operations to build a computational graph
    y = ggml.ggml_mul_mat(ctx, a, x)
    assert ggml.shape(y) == (3, 2)
    gf = ggml.ggml_build_forward(y)

    # Set the input values: x holds 0..11, and a selects columns 1 and 3 of x
    ggml.ggml_set_f32(x, 0.0)
    for i in range(4 * 3):
        ggml.ggml_set_f32_1d(x, i, i)
    ggml.ggml_set_f32(a, 0.0)
    ggml.ggml_set_f32_1d(a, 1, 1.0)
    ggml.ggml_set_f32_1d(a, 7, 1.0)

    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    output = [[ggml.ggml_get_f32_1d(y, j * 2 + i) for j in range(3)] for i in range(2)]
    assert output == [[1, 5, 9], [3, 7, 11]]
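

# A cross-check of the same product through numpy (a sketch): per the convention
# exercised in test_numpy_mul_mat further below, ggml_mul_mat(a, x) computes
# x @ a.T when both operands are viewed with to_numpy.
def test_ggml_matmul_vs_numpy(ctx: Ctx) -> None:
    a = np.zeros((2, 4), dtype=np.float32)
    a[0, 1] = 1.0
    a[1, 3] = 1.0
    x = np.arange(4 * 3, dtype=np.float32).reshape(3, 4)
    gy = ggml.ggml_mul_mat(ctx, ggml.from_numpy(ctx, a), ggml.from_numpy(ctx, x))
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    assert np.allclose(ggml.to_numpy(gy), x @ a.T)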


def test_shape_works(ctx: Ctx) -> None:
    """GGML shape order convention is the reverse from numpy"""
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
    assert ggml.shape(a) == (10,)
    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
    assert ggml.shape(b) == (21, 11)
    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
    assert ggml.shape(c) == (32, 22, 12)
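

# Corollary (a sketch): to_numpy exposes the tensor with the same reversed
# shape, so the numpy view and ggml.shape agree.
def test_shape_matches_numpy_view(ctx: Ctx) -> None:
    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
    assert ggml.to_numpy(b).shape == ggml.shape(b) == (21, 11)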


def test_nb_works(ctx: Ctx) -> None:
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
    assert ggml.nb(a) == (4, 40, 40, 40)
    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F16, 11, 21)
    assert ggml.nb(b) == (2, 22, 462, 462)
    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
    assert ggml.nb(c) == (4, 48, 1056, 33792)
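

# How those byte strides arise (a sketch, assuming ggml's usual convention):
# nb[0] is the element size in bytes and each later stride multiplies in the
# previous dim, with no padding for plain F32 tensors.
def test_nb_matches_manual_computation(ctx: Ctx) -> None:
    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
    ne = (12, 22, 32, 1)  # ggml dims, fastest-varying first
    nb = [4]  # sizeof(float)
    for i in range(1, 4):
        nb.append(nb[i - 1] * ne[i - 1])
    assert ggml.nb(c) == tuple(nb)  # (4, 48, 1056, 33792)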


@pytest.mark.xfail(reason="TODO: fix strides")
def test_strides_works(ctx: Ctx) -> None:
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
    assert ggml.strides(a) == np.ones((10,), dtype=np.float32).strides
    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
    assert ggml.strides(b) == np.ones((11, 21), dtype=np.float32).strides
    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
    assert ggml.strides(c) == np.ones((12, 22, 32), dtype=np.float32).strides


def test_to_numpy_works_with_f32(ctx: Ctx) -> None:
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
    na = ggml.to_numpy(a)
    for i in range(10):
        ggml.ggml_set_f32_1d(a, i, i)
    assert na[5] == 5
    assert np.allclose(na, np.array(range(10), dtype=np.float32))
    ggml.ggml_set_f32_1d(a, 5, -1.5)
    assert na[5] == -1.5

    # Note: GGML order of dims is reversed wrt numpy shapes
    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
    for i in range(11 * 21):
        ggml.ggml_set_f32_1d(b, i, i)
    nb = ggml.to_numpy(b)
    # assert nb.shape == (21, 11)
    assert nb[0, 5] == 5
    assert nb[3, 5] == 11 * 3 + 5
    assert np.allclose(
        nb, np.array(range(11 * 21), dtype=np.float32).reshape(ggml.shape(b))
    )
    ggml.ggml_set_f32_1d(b, 11 * 3 + 5, -1.5)
    assert nb[3, 5] == -1.5

    sum_rows = ggml.ggml_sum_rows(ctx, b)
    gf = ggml.ggml_build_forward(sum_rows)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    np_sum_rows = np.sum(nb, axis=-1, keepdims=True)
    assert np_sum_rows.shape == ggml.shape(sum_rows)
    for i in range(11):
        assert np_sum_rows[i] == ggml.ggml_get_f32_1d(sum_rows, i)

    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
    for i in range(12 * 22 * 32):
        ggml.ggml_set_f32_1d(c, i, i)
    nc = ggml.to_numpy(c)
    assert ggml.shape(c) == (32, 22, 12)
    assert nc[3, 5, 11] == 22 * 12 * 3 + 12 * 5 + 11
    assert np.allclose(
        nc, np.array(range(12 * 22 * 32), dtype=np.float32).reshape(ggml.shape(c))
    )
    ggml.ggml_set_f32_1d(c, 22 * 12 * 3 + 12 * 5 + 11, -1.5)
    assert nc[3, 5, 11] == -1.5
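

# The array returned by to_numpy is a zero-copy view (a sketch): writes made
# through numpy are visible to the ggml accessors too, not only the reverse
# direction exercised above.
def test_to_numpy_is_a_view(ctx: Ctx) -> None:
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
    na = ggml.to_numpy(a)
    na[...] = np.arange(10, dtype=np.float32)
    assert ggml.ggml_get_f32_1d(a, 7) == 7.0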


def test_from_numpy_works_with_f32(ctx: Ctx) -> None:
    a = np.random.normal(size=(10,)).astype(dtype=np.float32)
    ga = ggml.from_numpy(ctx, a)
    assert ggml.shape(ga) == (10,)
    assert ggml.nb(ga) == ggml.nb(ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10))
    assert np.allclose(a, ggml.to_numpy(ga))

    a = np.random.normal(size=(11, 21)).astype(dtype=np.float32)
    ga = ggml.from_numpy(ctx, a)
    assert ggml.shape(ga) == (11, 21)
    assert ggml.nb(ga) == ggml.nb(
        ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, *a.shape[::-1])
    )
    assert np.allclose(a, ggml.to_numpy(ga))

    a = np.random.normal(size=(12, 22, 32)).astype(dtype=np.float32)
    ga = ggml.from_numpy(ctx, a)
    assert ggml.shape(ga) == (12, 22, 32)
    assert ggml.nb(ga) == ggml.nb(
        ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, *a.shape[::-1])
    )
    assert np.allclose(a, ggml.to_numpy(ga))


def test_to_numpy_works_with_f16(ctx: Ctx) -> None:
    # We explicitly fill the tensors, otherwise they might contain garbage values.
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F16, 10)
    na = ggml.to_numpy(a)
    ggml.ggml_set_f32(a, 2.14)
    assert np.allclose(na, np.ones((10,), dtype=np.float16) * 2.14)
    ggml.ggml_set_f32(a, 4.28)
    assert np.allclose(na, np.ones((10,), dtype=np.float16) * 4.28)

    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F16, 11, 21)
    nb = ggml.to_numpy(b)
    ggml.ggml_set_f32(b, 4.18)
    assert np.allclose(nb, np.ones((21, 11), dtype=np.float16) * 4.18)
    ggml.ggml_set_f32(b, 5.12)
    assert np.allclose(nb, np.ones((21, 11), dtype=np.float16) * 5.12)

    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F16, 12, 22, 32)
    nc = ggml.to_numpy(c)
    ggml.ggml_set_f32(c, 3.16)
    assert np.allclose(nc, np.ones((32, 22, 12), dtype=np.float16) * 3.16)
    ggml.ggml_set_f32(c, 5.08)
    assert np.allclose(nc, np.ones((32, 22, 12), dtype=np.float16) * 5.08)


def test_from_numpy_works_with_f16(ctx: Ctx) -> None:
    a = np.random.normal(size=(10,)).astype(dtype=np.float16)
    ga = ggml.from_numpy(ctx, a)
    assert np.allclose(a, ggml.to_numpy(ga))
    a = np.random.normal(size=(11, 21)).astype(dtype=np.float16)
    ga = ggml.from_numpy(ctx, a)
    assert np.allclose(a, ggml.to_numpy(ga))
    a = np.random.normal(size=(12, 22, 32)).astype(dtype=np.float16)
    ga = ggml.from_numpy(ctx, a)
    assert np.allclose(a, ggml.to_numpy(ga))


def test_to_numpy_works_with_transposed(ctx: Ctx) -> None:
    ga = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 10, 5)
    a = ggml.to_numpy(ga)
    a[...] = np.arange(50).reshape(5, 10).astype(dtype=np.float32)

    gat = ggml.ggml_transpose(ctx, ga)
    gf = ggml.ggml_build_forward(ga)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    at = ggml.to_numpy(gat)
    assert np.allclose(a.T, at)


def test_unity_model_load(ctx: Ctx) -> None:
    pytest.skip("broken")
    model, vocab = ggml.unity_model_load(UNITY_MODELS / "unity-large/ggml-model.bin")
    print(model, vocab)
    example = ggml.from_file(
        ctx, UNITY_MODELS / "unity-large/seqs_before_conformer_block.bin", (1024, 137)
    )

    with ggml.MeasureArena() as arena:
        graph = ggml.unity_audio_encoder_graph(model, example)
        # TODO: why the extra memory?
        mem_size = ggml.ggml_allocr_alloc_graph(arena, graph) + ggml.GGML_MEM_ALIGN

    with ggml.FixedSizeArena(mem_size) as allocr:
        print(
            f"unity_audio_encoder_graph: compute buffer size: {mem_size/1024/1024} MB"
        )
        eval_res_ptr = ggml.unity_eval(allocr, model, example, 1)
        eval_res = eval_res_ptr.contents
        inpL = ggml.to_numpy(eval_res.nodes[eval_res.n_nodes - 1])
        expected_raw = "-0.1308,0.0346,-0.2656,0.2873,-0.0104,0.0574,0.4033,-0.1125,-0.0460,-0.0496"
        expected = map(float, expected_raw.split(","))
        assert np.allclose(inpL[0, :10], list(expected), atol=1e-4)


@pytest.fixture(scope="module")
def g_model_once() -> Iterator[c_void_p]:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    with ggml.load_unity_ggml_file(model_file) as model:
        yield model


@pytest.fixture()
def g_model(ctx: Ctx, g_model_once: c_void_p) -> c_void_p:
    ggml.lib.fairseq2_model_set_inference_ctx(g_model_once, ctx)
    return g_model_once


@pytest.fixture(scope="module")
def pt_model() -> Iterator[Any]:
    model = load_unity_model("seamlessM4T_medium")
    print(model)
    model.eval()
    with torch.inference_mode():
        yield model


@pytest.mark.xfail(reason="TODO")
def test_hparams_code_is_up_to_date() -> None:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    hparams_header_file = model_file.with_suffix(".hparams.h")
    hparams_struct = hparams_header_file.read_text().strip()
    actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
    assert hparams_struct in actual_code


def test_numpy_mul_mat(ctx: Ctx) -> None:
    slen, d_in, d_out = (5, 4, 2)
    # torch.nn and fairseq2.nn assume (seq_len, dim) inputs
    x = np.zeros((slen, d_in), dtype=np.float32)  # (seq_len, dim_in)
    x[0, :] = [1, 1 / 3, 0, 0]

    weight = np.eye(d_out, d_in, dtype=np.float32)
    weight[1, 1] = 1
    # assert weight.shape == (d_out, d_in)  # (dim_out, dim_in)
    y_exp = x @ weight.T  # (seq_len, dim_out)

    gx = ggml.from_numpy(ctx, x)  # (dim_in, seq_len)
    gw = ggml.from_numpy(ctx, weight)  # (dim_in, dim_out)
    # gb = ggml.from_numpy(ctx, linear.bias.numpy())  # (dim_out)
    # GGML linear impl (the bias path is sketched in the test below)
    assert ggml.ggml_can_mul_mat(gw, gx)
    # gy = ggml.ggml_add(ctx, ggml.ggml_mul_mat(ctx, gw, gx), gb)  # (dim_out, seq_len)
    gy = ggml.ggml_mul_mat(ctx, gw, gx)  # (dim_out, seq_len)

    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y)
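

# Spelling out the commented-out bias path above (a sketch): a full linear layer
# y = x @ W.T + b would map to ggml_add(ggml_mul_mat(gw, gx), gb). That ggml_add
# broadcasts the (dim_out,) bias over the sequence axis is an assumption here,
# not something this file verifies.
def test_numpy_linear_with_bias(ctx: Ctx) -> None:
    slen, d_in, d_out = (5, 4, 2)
    x = np.random.normal(size=(slen, d_in)).astype(np.float32)
    weight = np.random.normal(size=(d_out, d_in)).astype(np.float32)
    bias = np.random.normal(size=(d_out,)).astype(np.float32)
    gy = ggml.ggml_add(
        ctx,
        ggml.ggml_mul_mat(ctx, ggml.from_numpy(ctx, weight), ggml.from_numpy(ctx, x)),
        ggml.from_numpy(ctx, bias),  # assumption: broadcast over rows
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    assert np.allclose(ggml.to_numpy(gy), x @ weight.T + bias, atol=1e-6)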


@torch.no_grad()
def test_torch_spda_vs_ggml_flash_attn(ctx: Ctx) -> None:
    slen, d_in, num_heads = (5, 4, 2)
    torch.random.manual_seed(0)
    q = torch.zeros((num_heads, slen, d_in))
    torch.nn.init.uniform_(q, -1, 1)
    k = torch.zeros((num_heads, slen, d_in))
    torch.nn.init.uniform_(k, -1, 1)
    v = torch.zeros((num_heads, slen, d_in))
    torch.nn.init.uniform_(v, -1, 1)

    # Use torch's causal scaled-dot-product attention as the reference output.
    y_exp = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    y_exp = y_exp.numpy()

    gq = ggml.from_numpy(ctx, q.numpy())
    gk = ggml.from_numpy(ctx, k.numpy())
    # ggml flash attention expects a different order of axes for v:
    # (H, slen, H_dim) -> (H, H_dim, slen)
    gv = ggml.from_numpy(ctx, v.transpose(1, 2).contiguous().numpy())
    assert ggml.shape(gv) == (num_heads, d_in, slen)

    gy = ggml.ggml_flash_attn(ctx, gq, gk, gv, True)
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y)


def test_ggml_softmax_vs_torch(ctx: Ctx) -> None:
    x = torch.empty((5, 8, 4))
    torch.nn.init.uniform_(x, -1, 1)
    y_exp = torch.softmax(x, dim=-1).numpy()

    gx = ggml.from_numpy(ctx, x.numpy())
    gy = ggml.ggml_soft_max(ctx, gx)
    y = ggml.to_numpy(gy)

    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    assert np.allclose(y_exp, y, rtol=1e-3)


def test_forward_ffn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((21, 1024))  # (seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y, atol=1e-6)


def test_forward_layer_norm(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y, rtol=1e-3, atol=1e-4)


def _name(tensor: ggml.ggml_tensor_p) -> bytes:
    try:
        return tensor.contents.name  # type: ignore[no-any-return]
    except ValueError:
        return b"???"


def test_forward_self_attn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((1, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    self_attn = pt_model.text_encoder.layers[0].self_attn

    # Note: we use different lengths for queries and keys,
    # this tests the implementation in a decoding context too.
    # Note2: ggml_flash_attn requires that we have more keys than queries.
    gxq = ggml.from_numpy(ctx, x[0, :11, :])
    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gx,
        gx,
        None,  # TODO: tests with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    # q_exp = self_attn._project_q(x[:, :11, :], None, None).squeeze(0).numpy()
    y = ggml.to_numpy(gy)
    nodes = {}
    for i in range(gf.n_nodes):
        name = _name(gf.nodes[i])
        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
        print(name, f"op({gf.nodes[i].contents.op})", children)
        nodes[name] = ggml.to_numpy(gf.nodes[i])

    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
    y_exp = y_exp.squeeze(0)  # remove batch dimension

    # q = nodes[b"q"]
    # assert q.shape == q_exp.shape
    # assert np.allclose(q_exp, q, atol=1e-5)

    attn_exp, attn_weights_exp = map(
        lambda t: t.squeeze(0).numpy(), attn_weights_hook._storage[0]
    )

    # With flash_attn we don't have attn_weights.
    flash_attn = b"attn_weights" not in nodes
    if not flash_attn:
        attn_weights = nodes[b"attn_weights"]
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively reduces small softmax weights to 0;
        # the cause is unclear.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)

    attn_exp = attn_exp.transpose(0, 2, 1)
    attn = nodes[b"attn"]
    assert attn_exp.shape == attn.shape
    # Because of rounding errors in softmax, the tolerance is even looser here;
    # flash attention has better numerical precision though.
    assert np.allclose(attn_exp, attn, atol=1e-4 if flash_attn else 1e-2)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)


def test_StandardTransformerEncoderLayer_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO: support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)


def test_StandardTransformerEncoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO: support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)


def test_causal_attention_mask(ctx: Ctx) -> None:
    x = torch.zeros((5, 10))
    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
    mask_exp = generator(x)

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.allclose(mask, mask_exp)
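

# For reference (a sketch, assuming the conventional additive-mask encoding
# that fairseq2's generator produces): a causal mask is 0 on and below the
# diagonal and -inf strictly above it, so position i only attends to j <= i.
def test_causal_attention_mask_matches_numpy(ctx: Ctx) -> None:
    x = torch.zeros((5, 10))
    gmask = ggml.causal_attention_mask(ctx, ggml.from_numpy(ctx, x))
    mask_exp = np.triu(np.full((10, 10), -np.inf, dtype=np.float32), k=1)
    assert np.allclose(ggml.to_numpy(gmask), mask_exp)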


def test_StandardTransformerDecoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 13, 1024))
    encoder_out = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 13))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out[0])
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO: support padding mask
        genc,
        None,
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)