test_unity_cpp.py

import ctypes
import logging
import sys
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator

import fairseq2.nn
import fairseq2.nn.transformer
import numpy as np
import pytest
import torch

import ggml
from ctypes_utils import Ptr
from ggml import NativeObj
from ggml_convert import convert_model
from seamless_communication.models.inference.translator import Translator, Modality

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)

FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()
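# (UNITY_FLASH_ATTN mirrors the compile-time flag in fairseq2.cpp: when flash
# attention is enabled the intermediate attention weights are never
# materialized, so the tests below skip those checks and relax tolerances.)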


@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 1024 MB of memory"""
    # Init outside the try block: if it raises, there is nothing to free yet.
    ctx = ggml.ggml_init(params=CTX_PARAMS)
    try:
        yield ctx
    finally:
        ggml.ggml_free(ctx)
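

# The ggml model is converted from the fairseq2 checkpoint on first use and
# cached next to this file; delete seamlessM4T_medium.ggml to force a fresh
# conversion.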
@pytest.fixture(scope="module")
def g_model_once() -> Iterator[c_void_p]:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    with ggml.load_unity_ggml_file(model_file) as model:
        yield model


@pytest.fixture()
def g_model(ctx: Ctx, g_model_once: c_void_p) -> c_void_p:
    # Point the shared model at the per-test context so each test gets fresh
    # scratch memory.
    ggml.lib.fairseq2_model_set_inference_ctx(g_model_once, ctx)
    return g_model_once


@pytest.fixture(scope="module")
def translator() -> Iterator[Any]:
    tr = Translator(
        "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
    )
    # All reference computations run under inference mode (no autograd).
    with torch.inference_mode():
        yield tr


@pytest.fixture(scope="module")
def pt_model(translator: Translator) -> Any:
    model = translator.model
    print(model)
    return model
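

# Checks that the hparams struct snapshot written at conversion time is still
# present verbatim in unity_model_loader.h; marked xfail until the C++ side
# is kept in sync.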
@pytest.mark.xfail(reason="TODO")
def test_hparams_code_is_up_to_date() -> None:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    hparams_header_file = model_file.with_suffix(".hparams.h")
    hparams_struct = hparams_header_file.read_text().strip()
    actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
    assert hparams_struct in actual_code
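

# A causal mask for sequence length n is an (n, n) matrix with 0 on and below
# the diagonal and -inf above it, e.g. for n = 3:
#   [[0, -inf, -inf],
#    [0,    0, -inf],
#    [0,    0,    0]]
# The ggml implementation is expected to match fairseq2's
# CausalAttentionMaskGenerator element-wise.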
def test_causal_attention_mask(ctx: Ctx):
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
    mask_exp = generator(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    # Recompute on a shorter slice to check that the mask tracks the input
    # sequence length.
    x = x[:, :8, :]
    mask_exp = generator(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)
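

# The *_forward tests below share one pattern: run a fairseq2 module on random
# input, run the ggml layer of the same name from the converted state dict,
# and compare the outputs with np.allclose.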
def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)
    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_Linear_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)
    y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))  # (bs, seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)
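

# ggml stores tensor names inline in the tensor struct; dereferencing a NULL
# tensor pointer through ctypes raises ValueError, hence the b"???" fallback
# in the helper below.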
def _name(tensor: ggml.ggml_tensor_p) -> bytes:
    try:
        return tensor.contents.name  # type: ignore[no-any-return]
    except ValueError:
        return b"???"


def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    self_attn = pt_model.text_encoder.layers[0].self_attn

    # Note: we use different lengths for queries and keys; this also exercises
    # the implementation in a decoding context. ggml_flash_attn additionally
    # requires at least as many keys as queries.
    gxq = ggml.from_numpy(ctx, x[:, :11, :].contiguous())
    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gx,
        gx,
        None,  # TODO: tests with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    q_exp = self_attn.q_proj(x[:, :11, :]).numpy()
    y = ggml.to_numpy(gy)
    nodes = {}
    for i in range(gf.n_nodes):
        name = _name(gf.nodes[i])
        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
        print(name, f"op({gf.nodes[i].contents.op})", children)
        nodes[name] = ggml.to_numpy(gf.nodes[i])

    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
    self_attn.register_attn_weight_hook(attn_weights_hook)
    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()

    q = nodes[b"q"]
    assert q.shape == q_exp.shape
    assert np.allclose(q_exp, q, atol=1e-5)

    # With flash attention we don't have access to the attention weights.
    if not UNITY_FLASH_ATTN:
        attn_weights = nodes[b"attn_weights"]
        [attn_weights_exp] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively flushes small softmax weights to 0, so the
        # element-wise error isn't that small...
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
        # ...but each row should still sum to 1,
        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 11)))
        # and the argmax should match the reference.
        assert np.allclose(
            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
        )

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_StandardTransformerEncoderLayer_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((2, 21, 1024))
    padding_mask = torch.ones((2, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask)
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_StandardTransformerEncoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((2, 21, 1024))
    padding_mask = torch.ones((2, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask)
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    # The _legacy_pad_idx of 0 is suspicious (shouldn't the model use 1?), but
    # it is consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset.
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)
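

# The embedding frontend combines the token embedding lookup with the
# positional encoding tested above (plus an embedding scale, if the model is
# configured with one); token ids are passed to ggml as int32.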
def test_TransformerEmbeddingFrontend_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    seq = torch.arange(2 * 20).reshape(2, 20)
    seq[1, 15:] = 0  # padding for second sentence
    seq_len = torch.tensor([20, 15])

    gseq = ggml.from_numpy(ctx, seq.numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)
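

# The decoder test feeds a random "encoder output" directly, so it covers the
# self-attention and cross-attention paths without running the encoder.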
def test_StandardTransformerDecoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((2, 13, 1024))
    encoder_out = torch.empty((2, 21, 1024))
    padding_mask = torch.ones((2, 13))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x)
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask)
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out)
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO: support padding mask
        genc,
        None,
    )
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.numpy()
    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-3)
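

# End-to-end text-to-text test. The commented-out block below shows how
# sample_input.npz was produced with the PyTorch Translator; the test itself
# only replays the recorded encoder output through the ggml beam search.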
def test_t2tt(ctx: Ctx, g_model: c_void_p):
    # Inputs used when regenerating sample_input.npz (see the block below):
    src_lang = "eng"
    src_text = "We are all in a yellow submarine."
    tgt_lang = "fra"

    # def test_t2tt(ctx: Ctx, g_model: c_void_p, translator):
    #     device = translator.device
    #     token_encoder = translator.text_tokenizer.create_encoder(
    #         task="translation", lang=src_lang, mode="source", device=device
    #     )
    #     src = translator.collate(token_encoder(src_text))
    #     text_out, _ = translator.get_prediction(
    #         translator.model,
    #         translator.text_tokenizer,
    #         translator.unit_tokenizer,
    #         src,
    #         input_modality=Modality.TEXT,
    #         output_modality=Modality.TEXT,
    #         tgt_lang=tgt_lang,
    #     )
    #     tgt_text = str(text_out.sentences[0])
    #     assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
    #     tgt_tokens = text_out.generator_output.results[0][0].seq
    #     score = text_out.generator_output.results[0][0].score.item()
    #     np.savez(
    #         Path(__file__).parent / "sample_input.npz",
    #         score=score,
    #         encoder_output=text_out.encoder_output.squeeze(0).numpy(),
    #         encoder_padding_mask=text_out.encoder_padding_mask.squeeze(0).numpy(),
    #         tgt_tokens=tgt_tokens.numpy(),
    #     )

    text_out = np.load(Path(__file__).parent / "sample_input.npz")
    score = text_out["score"].item()
    tgt_tokens = list(text_out["tgt_tokens"])
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])

    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = 2
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = int(len(tgt_tokens) * 1.5)
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.opts.normalize_scores = True
    # Seed decoding with the first two reference tokens.
    job.prefix_seq = ggml.from_numpy(ctx, text_out["tgt_tokens"].astype(np.int32)[:2])
    job.pad_idx = 0
    job.unk_idx = 1
    job.bos_idx = 2
    job.eos_idx = 3

    result = ggml.ggml_tensor()
    g_score = ggml.generate_sequence(
        g_model, job, encoder_out, encoder_padding_mask, ctypes.byref(result)
    )
    tokens = list(ggml.to_numpy(ctypes.pointer(result)))
    assert tokens == tgt_tokens
    assert g_score == pytest.approx(score, rel=1e-2)
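

# To run a subset of these tests (a sketch; exact paths depend on your checkout):
#   pytest test_unity_cpp.py -k "LayerNorm or t2tt" -x -s
# The first run may download and convert seamlessM4T_medium, which is slow.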