import ggml
import ctypes
import torch
import pytest
import numpy as np
import fairseq2.nn
import fairseq2.nn.transformer
import logging
import sys
from pathlib import Path
from ctypes_utils import Ptr
from ctypes import c_void_p
from typing import Any
from typing import Iterator
from ggml import NativeObj
from ggml_convert import convert_model
from seamless_communication.models.inference.translator import Translator, Modality

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"

PARAMS_256MB = ggml.ggml_init_params(mem_size=256 * 1024 * 1024, mem_buffer=None)


@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 256 MB of memory"""
    try:
        ctx = ggml.ggml_init(params=PARAMS_256MB)
        yield ctx
    finally:
        ggml.ggml_free(ctx)


def test_ggml_bindings_work(ctx: Ctx) -> None:
    # Instantiate tensors
    x = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)
    b = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)

    # Use ggml operations to build a computational graph
    x2 = ggml.ggml_mul(ctx, x, x)
    f = ggml.ggml_add(ctx, ggml.ggml_mul(ctx, a, x2), b)
    gf = ggml.ggml_build_forward(f)

    # Set the input values
    ggml.ggml_set_f32(x, 2.0)
    ggml.ggml_set_f32(a, 3.0)
    ggml.ggml_set_f32(b, 4.0)

    # Compute the graph
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    # Get the output value
    output = ggml.ggml_get_f32_1d(f, 0)
    assert output == 16.0


def test_ggml_matmul(ctx: Ctx) -> None:
    # Instantiate tensors
    a = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 4, 2)
    x = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 4, 3)

    # Use ggml operations to build a computational graph
    y = ggml.ggml_mul_mat(ctx, a, x)
    assert ggml.shape(y) == (3, 2)
    gf = ggml.ggml_build_forward(y)

    # Set the input values
    ggml.ggml_set_f32(x, 0.0)
    for i in range(4 * 3):
        ggml.ggml_set_f32_1d(x, i, i)

    ggml.ggml_set_f32(a, 0.0)
    ggml.ggml_set_f32_1d(a, 1, 1.0)
    ggml.ggml_set_f32_1d(a, 7, 1.0)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    output = [[ggml.ggml_get_f32_1d(y, j * 2 + i) for j in range(3)] for i in range(2)]
    assert output == [[1, 5, 9], [3, 7, 11]]


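# In numpy terms (via to_numpy), ggml_mul_mat(ctx, a, x) computes x @ a.T,
# i.e. it applies `a` as a linear layer to the rows of `x`. A minimal sketch
# restating the test above through from_numpy/to_numpy; it is illustrative
# only and relies solely on bindings already used elsewhere in this file.
def test_mul_mat_matches_numpy(ctx: Ctx) -> None:
    a = np.random.normal(size=(2, 4)).astype(np.float32)  # (dim_out, dim_in)
    x = np.random.normal(size=(3, 4)).astype(np.float32)  # (seq_len, dim_in)
    gy = ggml.ggml_mul_mat(ctx, ggml.from_numpy(ctx, a), ggml.from_numpy(ctx, x))
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    # Expected result in numpy convention: (seq_len, dim_out)
    assert np.allclose(ggml.to_numpy(gy), x @ a.T)

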
def test_shape_works(ctx: Ctx) -> None:
    """GGML shape order convention is the reverse from numpy"""
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
    assert ggml.shape(a) == (10,)

    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
    assert ggml.shape(b) == (21, 11)

    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
    assert ggml.shape(c) == (32, 22, 12)


def test_nb_works(ctx: Ctx) -> None:
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
    assert ggml.nb(a) == (4, 40, 40, 40)

    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F16, 11, 21)
    assert ggml.nb(b) == (2, 22, 462, 462)

    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
    assert ggml.nb(c) == (4, 48, 1056, 33792)


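# nb() returns byte strides in ggml's native dim order: nb[0] is the element
# size, each following entry multiplies by the previous dim, and unused
# trailing dims repeat the full size. A small illustrative helper capturing
# that rule (the helper is ours, not part of the ggml bindings, and only
# holds for the plain F32/F16 tensors used in these tests):
def _expected_nb(ne: tuple, itemsize: int) -> tuple:
    """E.g. _expected_nb((11, 21), itemsize=2) == (2, 22, 462, 462)."""
    # Pad to ggml's fixed 4 dims, then accumulate byte strides.
    ne = tuple(ne) + (1,) * (4 - len(ne))
    nb = [itemsize]
    for d in ne[:3]:
        nb.append(nb[-1] * d)
    return tuple(nb)

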
@pytest.mark.xfail(reason="TODO: fix strides")
def test_strides_works(ctx: Ctx) -> None:
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
    assert ggml.strides(a) == np.ones((10,), dtype=np.float32).strides

    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
    assert ggml.strides(b) == np.ones((11, 21), dtype=np.float32).strides

    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
    assert ggml.strides(c) == np.ones((12, 22, 32), dtype=np.float32).strides


def test_to_numpy_works_with_f32(ctx: Ctx) -> None:
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
    na = ggml.to_numpy(a)
    for i in range(10):
        ggml.ggml_set_f32_1d(a, i, i)
    assert na[5] == 5
    assert np.allclose(na, np.array(range(10), dtype=np.float32))
    ggml.ggml_set_f32_1d(a, 5, -1.5)
    assert na[5] == -1.5

    # Note: GGML order of dims is reversed wrt numpy shapes
    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
    for i in range(11 * 21):
        ggml.ggml_set_f32_1d(b, i, i)
    nb = ggml.to_numpy(b)
    # assert nb.shape == (21, 11)
    assert nb[0, 5] == 5
    assert nb[3, 5] == 11 * 3 + 5
    assert np.allclose(
        nb, np.array(range(11 * 21), dtype=np.float32).reshape(ggml.shape(b))
    )
    ggml.ggml_set_f32_1d(b, 11 * 3 + 5, -1.5)
    assert nb[3, 5] == -1.5

    sum_rows = ggml.ggml_sum_rows(ctx, b)
    gf = ggml.ggml_build_forward(sum_rows)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    np_sum_rows = np.sum(nb, axis=-1, keepdims=True)
    assert np_sum_rows.shape == ggml.shape(sum_rows)
    for i in range(11):
        assert np_sum_rows[i] == ggml.ggml_get_f32_1d(sum_rows, i)

    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
    for i in range(12 * 22 * 32):
        ggml.ggml_set_f32_1d(c, i, i)
    nc = ggml.to_numpy(c)
    assert ggml.shape(c) == (32, 22, 12)
    assert nc[3, 5, 11] == 22 * 12 * 3 + 12 * 5 + 11
    assert np.allclose(
        nc, np.array(range(12 * 22 * 32), dtype=np.float32).reshape(ggml.shape(c))
    )
    ggml.ggml_set_f32_1d(c, 22 * 12 * 3 + 12 * 5 + 11, -1.5)
    assert nc[3, 5, 11] == -1.5


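# The asserts above rely on ggml.to_numpy returning a view over the ggml
# buffer rather than a copy, which is why writes through ggml_set_f32_1d are
# visible in `na`/`nb`/`nc`. A minimal sketch of that property, added for
# illustration and using only bindings already exercised in this file.
def test_to_numpy_is_a_view(ctx: Ctx) -> None:
    t = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 4)
    view = ggml.to_numpy(t)
    ggml.ggml_set_f32_1d(t, 2, 7.0)
    # The write through the ggml API shows up in the numpy view.
    assert view[2] == 7.0

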
def test_from_numpy_works_with_f32(ctx: Ctx) -> None:
    a = np.random.normal(size=(10,)).astype(dtype=np.float32)
    ga = ggml.from_numpy(ctx, a)
    assert ggml.shape(ga) == (10,)
    assert ggml.nb(ga) == ggml.nb(ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10))
    assert np.allclose(a, ggml.to_numpy(ga))

    a = np.random.normal(size=(11, 21)).astype(dtype=np.float32)
    ga = ggml.from_numpy(ctx, a)
    assert ggml.shape(ga) == (11, 21)
    assert ggml.nb(ga) == ggml.nb(
        ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, *a.shape[::-1])
    )
    assert np.allclose(a, ggml.to_numpy(ga))

    a = np.random.normal(size=(12, 22, 32)).astype(dtype=np.float32)
    ga = ggml.from_numpy(ctx, a)
    assert ggml.shape(ga) == (12, 22, 32)
    assert ggml.nb(ga) == ggml.nb(
        ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, *a.shape[::-1])
    )
    assert np.allclose(a, ggml.to_numpy(ga))


def test_to_numpy_works_with_f16(ctx: Ctx) -> None:
    # We explicitly fill the tensors, otherwise they might contain non-zero values.
    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F16, 10)
    na = ggml.to_numpy(a)
    ggml.ggml_set_f32(a, 2.14)
    assert np.allclose(na, np.ones((10,), dtype=np.float16) * 2.14)
    ggml.ggml_set_f32(a, 4.28)
    assert np.allclose(na, np.ones((10,), dtype=np.float16) * 4.28)

    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F16, 11, 21)
    nb = ggml.to_numpy(b)
    ggml.ggml_set_f32(b, 4.18)
    assert np.allclose(nb, np.ones((21, 11), dtype=np.float16) * 4.18)
    ggml.ggml_set_f32(b, 5.12)
    assert np.allclose(nb, np.ones((21, 11), dtype=np.float16) * 5.12)

    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F16, 12, 22, 32)
    nc = ggml.to_numpy(c)
    ggml.ggml_set_f32(c, 3.16)
    assert np.allclose(nc, np.ones((32, 22, 12), dtype=np.float16) * 3.16)
    ggml.ggml_set_f32(c, 5.08)
    assert np.allclose(nc, np.ones((32, 22, 12), dtype=np.float16) * 5.08)


def test_from_numpy_works_with_f16(ctx: Ctx) -> None:
    a = np.random.normal(size=(10,)).astype(dtype=np.float16)
    ga = ggml.from_numpy(ctx, a)
    assert np.allclose(a, ggml.to_numpy(ga))

    a = np.random.normal(size=(11, 21)).astype(dtype=np.float16)
    ga = ggml.from_numpy(ctx, a)
    assert np.allclose(a, ggml.to_numpy(ga))

    a = np.random.normal(size=(12, 22, 32)).astype(dtype=np.float16)
    ga = ggml.from_numpy(ctx, a)
    assert np.allclose(a, ggml.to_numpy(ga))


def test_to_numpy_works_with_transposed(ctx: Ctx) -> None:
    ga = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 10, 5)
    a = ggml.to_numpy(ga)
    a[...] = np.arange(50).reshape(5, 10).astype(dtype=np.float32)

    gat = ggml.ggml_transpose(ctx, ga)
    gf = ggml.ggml_build_forward(ga)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    at = ggml.to_numpy(gat)
    assert np.allclose(a.T, at)


def test_ning_model_load(ctx: Ctx) -> None:
    pytest.skip("broken")
    model, vocab = ggml.unity_model_load(UNITY_MODELS / "unity-large/ggml-model.bin")
    print(model, vocab)

    example = ggml.from_file(
        ctx, UNITY_MODELS / "unity-large/seqs_before_conformer_block.bin", (1024, 137)
    )

    with ggml.MeasureArena() as arena:
        graph = ggml.unity_audio_encoder_graph(model, example)
        # TODO: why the extra memory?
        mem_size = ggml.ggml_allocr_alloc_graph(arena, graph) + ggml.GGML_MEM_ALIGN

    with ggml.FixedSizeArena(mem_size) as allocr:
        print(
            f"unity_audio_encoder_graph: compute buffer size: {mem_size/1024/1024} MB"
        )

        eval_res_ptr = ggml.unity_eval(allocr, model, example, 1)
        eval_res = eval_res_ptr.contents
        inpL = ggml.to_numpy(eval_res.nodes[eval_res.n_nodes - 1])
        expected_raw = "-0.1308,0.0346,-0.2656,0.2873,-0.0104,0.0574,0.4033,-0.1125,-0.0460,-0.0496"
        expected = map(float, expected_raw.split(","))
        assert np.allclose(inpL[0, :10], list(expected), atol=1e-4)


@pytest.fixture(scope="module")
def g_model_once() -> Iterator[c_void_p]:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    with ggml.load_unity_ggml_file(model_file) as model:
        yield model


@pytest.fixture()
def g_model(ctx: Ctx, g_model_once: c_void_p) -> c_void_p:
    ggml.lib.fairseq2_model_set_inference_ctx(g_model_once, ctx)
    return g_model_once


@pytest.fixture(scope="module")
def translator() -> Iterator[Any]:
    tr = Translator(
        "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
    )
    with torch.inference_mode():
        yield tr


@pytest.fixture(scope="module")
def pt_model(translator: Translator) -> Any:
    model = translator.model
    print(model)
    return model


@pytest.mark.xfail(reason="TODO")
def test_hparams_code_is_up_to_date() -> None:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    hparams_header_file = model_file.with_suffix(".hparams.h")
    hparams_struct = hparams_header_file.read_text().strip()
    actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
    assert hparams_struct in actual_code


def test_numpy_mul_mat(ctx: Ctx) -> None:
    slen, d_in, d_out = (5, 4, 2)
    # torch.nn and fairseq2.nn assume (seq_len, dim) inputs.
    x = np.zeros((slen, d_in), dtype=np.float32)  # (seq_len, dim_in)
    x[0, :] = [1, 1 / 3, 0, 0]

    weight = np.eye(d_out, d_in, dtype=np.float32)
    weight[1, 1] = 1
    # assert weight.shape == (d_out, d_in)  # (dim_out, dim_in)
    y_exp = x @ weight.T  # (seq_len, dim_out)

    gx = ggml.from_numpy(ctx, x)  # (dim_in, seq_len)
    gw = ggml.from_numpy(ctx, weight)  # (dim_in, dim_out)
    # gb = ggml.from_numpy(ctx, linear.bias.numpy())  # (dim_out)
    # GGML linear impl
    assert ggml.ggml_can_mul_mat(gw, gx)
    # gy = ggml.ggml_add(ctx, ggml.ggml_mul_mat(ctx, gw, gx), gb)  # (dim_out, seq_len)
    gy = ggml.ggml_mul_mat(ctx, gw, gx)  # (dim_out, seq_len)

    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y)


@torch.no_grad()
def test_torch_spda_vs_ggml_flash_attn(ctx: Ctx) -> None:
    slen, d_in, num_heads = (5, 4, 2)
    torch.random.manual_seed(0)
    q = torch.zeros((num_heads, slen, d_in))
    torch.nn.init.uniform_(q, -1, 1)
    k = torch.zeros((num_heads, slen, d_in))
    torch.nn.init.uniform_(k, -1, 1)
    v = torch.zeros((num_heads, slen, d_in))
    torch.nn.init.uniform_(v, -1, 1)

    y_exp = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    y_exp = y_exp.numpy()

    gq = ggml.from_numpy(ctx, q.numpy())
    gk = ggml.from_numpy(ctx, k.numpy())
    # ggml flash attention expects a different axis order for v:
    # (H, slen, H_dim) -> (H, H_dim, slen)
    gv = ggml.from_numpy(ctx, v.transpose(1, 2).contiguous().numpy())
    assert ggml.shape(gv) == (num_heads, d_in, slen)
    gy = ggml.ggml_flash_attn(ctx, gq, gk, gv, True)
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y)


def test_ggml_softmax_vs_torch(ctx: Ctx) -> None:
    x = torch.empty((5, 8, 4))
    torch.nn.init.uniform_(x, -1, 1)
    y_exp = torch.softmax(x, dim=-1).numpy()

    gx = ggml.from_numpy(ctx, x.numpy())
    gy = ggml.ggml_soft_max(ctx, gx)
    y = ggml.to_numpy(gy)

    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert np.allclose(y_exp, y, rtol=1e-3)


def test_forward_ffn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((21, 1024))  # (seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)
    # Test FFN without LayerNorm
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y, atol=1e-6)


def test_forward_layer_norm(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((21, 1024))
    torch.nn.init.uniform_(x, -1, 1)
    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y, rtol=1e-3, atol=1e-4)


def _name(tensor: ggml.ggml_tensor_p) -> bytes:
    try:
        return tensor.contents.name  # type: ignore[no-any-return]
    except ValueError:
        return b"???"


def test_forward_self_attn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((1, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    self_attn = pt_model.text_encoder.layers[0].self_attn

    # Note: we use different lengths for queries and keys;
    # this also tests the implementation in a decoding context.
    # Note 2: ggml_flash_attn requires more keys than queries.
    gxq = ggml.from_numpy(ctx, x[0, :11, :])
    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gx,
        gx,
        None,  # TODO: tests with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    # q_exp = self_attn._project_q(x[:, :11, :], None, None).squeeze(0).numpy()
    y = ggml.to_numpy(gy)
    nodes = {}

    for i in range(gf.n_nodes):
        name = _name(gf.nodes[i])
        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
        print(name, f"op({gf.nodes[i].contents.op})", children)
        nodes[name] = ggml.to_numpy(gf.nodes[i])

    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
    y_exp = y_exp.squeeze(0)  # remove batch dimension

    # q = nodes[b"q"]
    # assert q.shape == q_exp.shape
    # assert np.allclose(q_exp, q, atol=1e-5)

    attn_exp, attn_weights_exp = map(
        lambda t: t.squeeze(0).numpy(), attn_weights_hook._storage[0]
    )

    # With flash_attn we don't have attn_weights.
    flash_attn = b"attn_weights" not in nodes
    if not flash_attn:
        attn_weights = nodes[b"attn_weights"]
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML reduces small softmax weights to 0 quite aggressively;
        # not sure what this is due to.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)

    attn_exp = attn_exp.transpose(0, 2, 1)
    attn = nodes[b"attn"]
    assert attn_exp.shape == attn.shape
    # Because of rounding errors in softmax, it's even worse here.
    # Flash attention has better numerical precision, though.
    assert np.allclose(attn_exp, attn, atol=1e-4 if flash_attn else 1e-2)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)


def test_StandardTransformerEncoderLayer_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO: support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)


def test_StandardTransformerEncoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO: support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)


def test_causal_attention_mask(ctx: Ctx):
    x = torch.zeros((5, 10))
    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
    mask_exp = generator(x)

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.allclose(mask, mask_exp)


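# For reference, the additive causal mask layout that the allclose above
# relies on (a sketch, assuming the float 0 / -inf convention of fairseq2's
# generator): zeros on and below the diagonal, -inf strictly above it, so
# softmax gives zero weight to future positions.
def _expected_causal_mask(seq_len: int) -> np.ndarray:
    return np.triu(np.full((seq_len, seq_len), -np.inf, dtype=np.float32), k=1)

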
def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    # This _legacy_pad_idx is suspicious. Shouldn't the model use 1? But this
    # is consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset.
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_TransformerEmbeddingFrontend_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    seq = torch.arange(20).reshape(1, 20)
    seq_len = torch.tensor([20])

    gseq = ggml.from_numpy(ctx, seq[0].numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_StandardTransformerDecoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 13, 1024))
    encoder_out = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 13))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out[0])
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO: support padding mask
        genc,
        None,
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)


def test_t2tt(ctx: Ctx, g_model: c_void_p):
    # device = translator.device
    src_lang = "eng"
    src_text = "We are all in a yellow submarine."
    tgt_lang = "fra"
    # token_encoder = translator.text_tokenizer.create_encoder(
    #     task="translation", lang=src_lang, mode="source", device=device
    # )
    # src = translator.collate(token_encoder(src_text))

    # text_out, _ = translator.get_prediction(
    #     translator.model,
    #     translator.text_tokenizer,
    #     translator.unit_tokenizer,
    #     src,
    #     input_modality=Modality.TEXT,
    #     output_modality=Modality.TEXT,
    #     tgt_lang=tgt_lang,
    # )

    # tgt_text = str(text_out.sentences[0])
    # assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
    # tgt_tokens = text_out.generator_output.results[0][0].seq
    # score = text_out.generator_output.results[0][0].score.item()
    # np.savez(
    #     Path(__file__).parent / "sample_input.npz",
    #     score=score,
    #     encoder_output=text_out.encoder_output.squeeze(0).numpy(),
    #     encoder_padding_mask=text_out.encoder_padding_mask.squeeze(0).numpy(),
    #     tgt_tokens=tgt_tokens.numpy(),
    # )

    text_out = np.load(Path(__file__).parent / "sample_input.npz")
    score = text_out["score"].item()
    tgt_tokens = ggml.from_numpy(ctx, text_out["tgt_tokens"].astype(np.int32))
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])

    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = 1
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = 1024
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.prefix_seq = ggml.from_numpy(ctx, text_out["tgt_tokens"].astype(np.int32)[:1])
    job.eos_idx = 3

    result = ctypes.byref(ggml.ggml_tensor())
    g_score = ggml.generate_sequence(
        g_model, job, encoder_out, encoder_padding_mask, result
    )
    breakpoint()
    assert g_score == pytest.approx(score)


def test_in_loop(ctx: Ctx, g_model: c_void_p, pt_model: Any):
    resources = locals()
    import importlib.util
    import time

    testcase = test_TransformerEmbeddingFrontend_forward.__name__
    name, script = __name__, __file__
    root = Path(__file__).parent
    watched_files = [Path(__file__), root / "ggml.py", root / "build/src/libggml.so"]
    last_try = 0.0
    while True:
        last_save = max(f.stat().st_mtime for f in watched_files)
        if last_save <= last_try:
            time.sleep(0.1)
            continue

        last_try = last_save
        spec = importlib.util.spec_from_file_location(name, script)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        sys.modules[name] = module
        f = getattr(module, testcase)
        f_args = [k for k in f.__annotations__ if k != "return"]
        try:
            f(**{k: resources[k] for k in f_args})
            print(f"Testcase {testcase} success")
        except AssertionError as e:
            print(f"Testcase {testcase} failed: {e}")
        except Exception as e:
            import pdb

            logging.exception(f"Testcase {testcase} crashed!")
            pdb.post_mortem()