import ggml
import ctypes
import torch
import pytest
import numpy as np
import fairseq2.nn
import fairseq2.nn.transformer
import logging
import sys
from pathlib import Path
from ctypes_utils import Ptr
from ctypes import c_void_p
from typing import Any, Iterator
from ggml import NativeObj
from ggml_convert import convert_model
from seamless_communication.models.inference.translator import Translator, Modality

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)

FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
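# Added note: UNITY_FLASH_ATTN reflects how examples/unity/fairseq2.cpp was
# compiled. When flash attention is enabled, the graph does not materialize the
# intermediate "attn_weights" tensor, and the output tolerances used in the
# tests below are chosen accordingly.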

@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 1024 MB of memory"""
    try:
        ctx = ggml.ggml_init(params=CTX_PARAMS)
        yield ctx
    finally:
        ggml.ggml_free(ctx)


@pytest.fixture(scope="module")
def g_model_once() -> Iterator[c_void_p]:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    with ggml.load_unity_ggml_file(model_file) as model:
        yield model


@pytest.fixture()
def g_model(ctx: Ctx, g_model_once: c_void_p) -> c_void_p:
    ggml.lib.fairseq2_model_set_inference_ctx(g_model_once, ctx)
    return g_model_once


@pytest.fixture(scope="module")
def translator() -> Iterator[Any]:
    tr = Translator(
        "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
    )
    with torch.inference_mode():
        yield tr


@pytest.fixture(scope="module")
def pt_model(translator: Translator) -> Any:
    model = translator.model
    print(model)
    return model
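
# Added convenience sketch (not part of the original file and not used by the
# tests below; the name `_compute` is ours): it bundles the build-graph /
# compute / extract-to-numpy pattern that each test repeats, using only the
# ggml bindings already exercised in this module.
def _compute(ctx: Ctx, tensor: ggml.ggml_tensor_p) -> np.ndarray:
    gf = ggml.ggml_build_forward(tensor)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    return ggml.to_numpy(tensor)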

@pytest.mark.xfail(reason="TODO")
def test_hparams_code_is_up_to_date() -> None:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    hparams_header_file = model_file.with_suffix(".hparams.h")
    hparams_struct = hparams_header_file.read_text().strip()
    actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
    assert hparams_struct in actual_code

def test_forward_ffn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((21, 1024))  # (seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y, atol=1e-5)

def test_forward_layer_norm(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y, rtol=1e-3, atol=1e-4)

def _name(tensor: ggml.ggml_tensor_p) -> bytes:
    try:
        return tensor.contents.name  # type: ignore[no-any-return]
    except ValueError:
        return b"???"

def test_forward_self_attn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((1, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    self_attn = pt_model.text_encoder.layers[0].self_attn

    # Note: we use different lengths for queries and keys,
    # so this also exercises the implementation in a decoding context.
    # Note2: ggml_flash_attn requires that we have more keys than queries.
    gxq = ggml.from_numpy(ctx, x[0, :11, :])
    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gx,
        gx,
        None,  # TODO: tests with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    # q_exp = self_attn._project_q(x[:, :11, :], None, None).squeeze(0).numpy()
    y = ggml.to_numpy(gy)

    nodes = {}
    for i in range(gf.n_nodes):
        name = _name(gf.nodes[i])
        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
        print(name, f"op({gf.nodes[i].contents.op})", children)
        nodes[name] = ggml.to_numpy(gf.nodes[i])

    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
    y_exp = y_exp.squeeze(0)  # remove batch dimension

    # q = nodes[b"q"]
    # assert q.shape == q_exp.shape
    # assert np.allclose(q_exp, q, atol=1e-5)

    # With flash_attn we don't have attn_weights.
    if not UNITY_FLASH_ATTN:
        attn_weights = nodes[b"attn_weights"]
        [attn_weights_exp] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.squeeze(0).numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively reduces small softmax weights to 0;
        # not sure what this is due to.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)

def test_StandardTransformerEncoderLayer_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO: support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)

def test_StandardTransformerEncoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO: support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)

def test_causal_attention_mask(ctx: Ctx):
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
    mask_exp = generator(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    x = x[:, :8, :]
    mask_exp = generator(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)
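
# Added reference note (assumption about fairseq2 semantics, consistent with the
# equality check above): CausalAttentionMaskGenerator returns an additive float
# mask with 0.0 at allowed positions and -inf above the diagonal, e.g. for
# seq_len = 4:
#
#   [[0., -inf, -inf, -inf],
#    [0.,   0., -inf, -inf],
#    [0.,   0.,   0., -inf],
#    [0.,   0.,   0.,   0.]]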

def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    # This _legacy_pad_idx is suspicious: shouldn't the model use 1? It is,
    # however, consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset.
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)

def test_TransformerEmbeddingFrontend_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    seq = torch.arange(20).reshape(1, 20)
    seq_len = torch.tensor([20])

    gseq = ggml.from_numpy(ctx, seq[0].numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)

def test_StandardTransformerDecoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    pytest.skip("foo")
    x = torch.empty((1, 13, 1024))
    encoder_out = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 13))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out[0])
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO: support padding mask
        genc,
        None,
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)

def test_t2tt(ctx: Ctx, g_model: c_void_p):
    # def test_t2tt(ctx: Ctx, g_model: c_void_p, translator):
    # Regenerating the reference data requires the `translator` fixture; the
    # commented-out code below shows how sample_input.npz was produced.
    # device = translator.device
    src_lang = "eng"
    src_text = "We are all in a yellow submarine."
    tgt_lang = "fra"
    # token_encoder = translator.text_tokenizer.create_encoder(
    #     task="translation", lang=src_lang, mode="source", device=device
    # )
    # src = translator.collate(token_encoder(src_text))
    # text_out, _ = translator.get_prediction(
    #     translator.model,
    #     translator.text_tokenizer,
    #     translator.unit_tokenizer,
    #     src,
    #     input_modality=Modality.TEXT,
    #     output_modality=Modality.TEXT,
    #     tgt_lang=tgt_lang,
    # )
    # tgt_text = str(text_out.sentences[0])
    # assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
    # tgt_tokens = text_out.generator_output.results[0][0].seq
    # score = text_out.generator_output.results[0][0].score.item()
    # np.savez(
    #     Path(__file__).parent / "sample_input.npz",
    #     score=score,
    #     encoder_output=text_out.encoder_output.squeeze(0).numpy(),
    #     encoder_padding_mask=text_out.encoder_padding_mask.squeeze(0).numpy(),
    #     tgt_tokens=tgt_tokens.numpy(),
    # )

    text_out = np.load(Path(__file__).parent / "sample_input.npz")
    score = text_out["score"].item()
    tgt_tokens = list(text_out["tgt_tokens"])
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])

    # Greedy decoding (beam_size=1), so the generated tokens should reproduce
    # the reference tokens exactly.
    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = 1
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = int(len(tgt_tokens) * 1.5)
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.prefix_seq = ggml.from_numpy(ctx, text_out["tgt_tokens"].astype(np.int32)[:2])
    job.pad_idx = 0
    job.unk_idx = 1
    job.bos_idx = 2
    job.eos_idx = 3

    result = ggml.ggml_tensor()
    g_score = ggml.generate_sequence(
        g_model, job, encoder_out, encoder_padding_mask, ctypes.byref(result)
    )
    tokens = list(ggml.to_numpy(result))
    assert tokens == tgt_tokens
    assert g_score == pytest.approx(score)

def test_in_loop(ctx: Ctx, g_model: c_void_p, pt_model: Any):
    """Rerun a single testcase in a loop, reloading this file whenever one of
    the watched files (this script, ggml.py, or libggml.so) changes on disk."""
    resources = locals()
    import importlib.util
    import time

    testcase = test_TransformerEmbeddingFrontend_forward.__name__
    name, script = __name__, __file__
    root = Path(__file__).parent
    watched_files = [Path(__file__), root / "ggml.py", root / "build/src/libggml.so"]
    last_try = 0.0
    while True:
        last_save = max(f.stat().st_mtime for f in watched_files)
        if last_save <= last_try:
            time.sleep(0.1)
            continue

        last_try = last_save
        spec = importlib.util.spec_from_file_location(name, script)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        sys.modules[name] = module
        f = getattr(module, testcase)
        f_args = [k for k in f.__annotations__ if k != "return"]
        try:
            f(**{k: resources[k] for k in f_args})
            print(f"Testcase {testcase} success")
        except AssertionError as e:
            print(f"Testcase {testcase} failed: {e}")
        except Exception as e:
            import pdb

            logging.exception(f"Testcase {testcase} crashed!")
            pdb.post_mortem()
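
# Added convenience entry point (not part of the original test suite): running
# this file directly launches the hot-reload loop above through pytest, so the
# ctx / g_model / pt_model fixtures are set up normally. Note that test_in_loop
# runs indefinitely; stop it with Ctrl-C.
if __name__ == "__main__":
    sys.exit(pytest.main([__file__, "-k", "test_in_loop", "-s"]))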