test_unity_cpp.py

import ctypes
import logging
import sys
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator

import fairseq2.nn
import fairseq2.nn.transformer
import numpy as np
import pytest
import torch

import ggml
from ctypes_utils import Ptr
from ggml import NativeObj
from ggml_convert import convert_model
from seamless_communication.models.inference.translator import Translator, Modality

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
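
# Module-level setup: each test gets a fresh 1 GiB ggml context (CTX_PARAMS), and
# UNITY_FLASH_ATTN is True unless examples/unity/fairseq2.cpp explicitly defines
# UNITY_FLASH_ATTN 0 (detected by scanning the C++ source). Several tolerances
# in the tests below depend on this flag.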

@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 1024 MB of memory"""
    try:
        ctx = ggml.ggml_init(params=CTX_PARAMS)
        yield ctx
    finally:
        ggml.ggml_free(ctx)
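
# The next two fixtures share one GGML model per test module: g_model_once
# converts the "seamlessM4T_medium" checkpoint to seamlessM4T_medium.ggml next
# to this file (only when the file is missing) and keeps it loaded for the whole
# module; g_model then points its inference context at the per-test ctx fixture.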

@pytest.fixture(scope="module")
def g_model_once() -> Iterator[c_void_p]:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    with ggml.load_unity_ggml_file(model_file) as model:
        yield model

@pytest.fixture()
def g_model(ctx: Ctx, g_model_once: c_void_p) -> c_void_p:
    ggml.lib.fairseq2_model_set_inference_ctx(g_model_once, ctx)
    return g_model_once

@pytest.fixture(scope="module")
def translator() -> Iterator[Any]:
    tr = Translator(
        "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
    )
    with torch.inference_mode():
        yield tr

@pytest.fixture(scope="module")
def pt_model(translator: Translator) -> Any:
    model = translator.model
    print(model)
    return model
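
# translator/pt_model load the reference PyTorch implementation through
# seamless_communication's Translator (CPU, fp32, under inference_mode); the
# forward tests below compare each GGML layer against the corresponding
# submodule of this reference model.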

@pytest.mark.xfail(reason="TODO")
def test_hparams_code_is_up_to_date() -> None:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    hparams_header_file = model_file.with_suffix(".hparams.h")
    hparams_struct = hparams_header_file.read_text().strip()
    actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
    assert hparams_struct in actual_code

def test_causal_attention_mask(ctx: Ctx):
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
    mask_exp = generator(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    x = x[:, :8, :]
    mask_exp = generator(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)

def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    assert np.allclose(y_exp, y, atol=1e-5)
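
# The forward tests in this file all follow the same pattern, sketched here with
# a placeholder `layer_name` (exact class and layer names vary per test):
#
#     gx = ggml.from_numpy(ctx, x)                           # copy the input into the ggml context
#     gy = ggml.forward("Linear", g_model, layer_name, gx)   # graph node for the named fairseq2 module
#     ggml.build_and_compute(ctx, gy)                        # build and run the forward graph
#     y = ggml.to_numpy(gy)                                  # read the result back as numpy
#     assert np.allclose(y_exp, y, atol=...)                 # compare with the fairseq2 reference
#
# Some tests call ggml.ggml_build_forward / ggml_graph_compute_with_ctx directly
# instead of build_and_compute so they can inspect intermediate graph nodes.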

def test_Linear_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    assert np.allclose(y_exp, y, atol=1e-5)

def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))  # (bs, seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    ggml.build_and_compute(ctx, gy)
    y = ggml.to_numpy(gy)

    assert np.allclose(y_exp, y, atol=1e-5)

def _name(tensor: ggml.ggml_tensor_p) -> bytes:
    try:
        return tensor.contents.name  # type: ignore[no-any-return]
    except ValueError:
        return b"???"
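
# _name reads back the debug label attached with ggml.ggml_set_name, falling back
# to b"???" when ctypes cannot decode the field; the MultiheadAttention test below
# uses it to look up intermediate graph nodes such as b"q" and b"attn_weights".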

def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    self_attn = pt_model.text_encoder.layers[0].self_attn

    # Note: we use different lengths for queries and keys,
    # which also exercises the implementation in a decoding context.
    # Note 2: ggml_flash_attn requires that we have more keys than queries.
    gxq = ggml.from_numpy(ctx, x[:, :11, :])
    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gx,
        gx,
        None,  # TODO: tests with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    q_exp = self_attn._project_q(x[:, :11, :], None, None).numpy()
    y = ggml.to_numpy(gy)

    nodes = {}
    for i in range(gf.n_nodes):
        name = _name(gf.nodes[i])
        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
        print(name, f"op({gf.nodes[i].contents.op})", children)
        nodes[name] = ggml.to_numpy(gf.nodes[i])

    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()

    q = nodes[b"q"]
    assert q.shape == q_exp.shape
    assert np.allclose(q_exp, q, atol=1e-5)

    # with flash_attn we don't have attn_weights
    if not UNITY_FLASH_ATTN:
        attn_weights = nodes[b"attn_weights"]
        [attn_weights_exp] = attn_weights_hook._storage
        # Fix the shape of attn_weights_exp
        attn_weights_exp = attn_weights_exp.unflatten(0, (2, 16)).numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively reduces small softmax weights to 0,
        # so we can't compare the weights element-wise:
        # assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
        # But the sums should be close to 1
        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2, 16, 11)))
        # And the maximum index should match the original ones.
        assert np.allclose(
            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
        )

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)
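
# The encoder/decoder tests below feed a single unbatched sequence (x[0]) to the
# GGML graph and compare against the PyTorch output with the batch dimension
# squeezed out; the padding mask is passed as None for now (see the TODOs), and
# the tolerances again depend on UNITY_FLASH_ATTN.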

def test_StandardTransformerEncoderLayer_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)

def test_StandardTransformerEncoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)

def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    # This _legacy_pad_idx is suspicious: shouldn't the model use 1? But 0 is
    # consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset.
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)
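
# The embedding frontend consumes raw token ids rather than embeddings; they are
# cast to int32 before handing them to ggml (the ggml side presumably expects
# integer ids here), while the reference fairseq2 frontend also receives the
# sequence lengths.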

def test_TransformerEmbeddingFrontend_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    seq = torch.arange(20).reshape(1, 20)
    seq_len = torch.tensor([20])

    gseq = ggml.from_numpy(ctx, seq[0].numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)

def test_StandardTransformerDecoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    pytest.skip("foo")
    x = torch.empty((1, 13, 1024))
    encoder_out = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 13))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out[0])
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO support padding mask
        genc,
        None,
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)
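
# End-to-end text-to-text decoding test: it loads a precomputed encoder output
# and the reference translation from sample_input.npz (produced by the
# commented-out Translator block inside the test), runs greedy search
# (beam_size=1) through ggml.generate_sequence with a 2-token prefix, and checks
# that the generated tokens and score match the fairseq2 generator.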

def test_t2tt(ctx: Ctx, g_model: c_void_p):
    # def test_t2tt(ctx: Ctx, g_model: c_void_p, translator):
    # device = translator.device
    src_lang = "eng"
    src_text = "We are all in a yellow submarine."
    tgt_lang = "fra"
    # token_encoder = translator.text_tokenizer.create_encoder(
    #     task="translation", lang=src_lang, mode="source", device=device
    # )
    # src = translator.collate(token_encoder(src_text))
    # text_out, _ = translator.get_prediction(
    #     translator.model,
    #     translator.text_tokenizer,
    #     translator.unit_tokenizer,
    #     src,
    #     input_modality=Modality.TEXT,
    #     output_modality=Modality.TEXT,
    #     tgt_lang=tgt_lang,
    # )
    # tgt_text = str(text_out.sentences[0])
    # assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
    # tgt_tokens = text_out.generator_output.results[0][0].seq
    # score = text_out.generator_output.results[0][0].score.item()
    # np.savez(
    #     Path(__file__).parent / "sample_input.npz",
    #     score=score,
    #     encoder_output=text_out.encoder_output.squeeze(0).numpy(),
    #     encoder_padding_mask=text_out.encoder_padding_mask.squeeze(0).numpy(),
    #     tgt_tokens=tgt_tokens.numpy(),
    # )

    text_out = np.load(Path(__file__).parent / "sample_input.npz")
    score = text_out["score"].item()
    tgt_tokens = list(text_out["tgt_tokens"])
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])

    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = 1
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = int(len(tgt_tokens) * 1.5)
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.prefix_seq = ggml.from_numpy(ctx, text_out["tgt_tokens"].astype(np.int32)[:2])
    job.pad_idx = 0
    job.unk_idx = 1
    job.bos_idx = 2
    job.eos_idx = 3

    result = ggml.ggml_tensor()
    g_score = ggml.generate_sequence(
        g_model, job, encoder_out, encoder_padding_mask, ctypes.byref(result)
    )
    tokens = list(ggml.to_numpy(ctypes.pointer(result)))

    assert tokens == tgt_tokens
    assert g_score == pytest.approx(score)
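
# To regenerate sample_input.npz, re-enable the commented-out Translator block in
# test_t2tt (and restore the `translator` fixture argument, as in the commented
# def line): it runs the PyTorch pipeline once and saves the encoder output,
# encoder padding mask, target tokens, and score with np.savez.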