import ctypes
import logging
import sys
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator

import numpy as np
import pytest
import torch

import fairseq2.nn
import fairseq2.nn.transformer
import ggml
from ctypes_utils import Ptr
from ggml import NativeObj
from ggml_convert import convert_model
from seamless_communication.models.inference.translator import Modality, Translator

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
# True when the compiled fairseq2.cpp was built with flash attention enabled.
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()


@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 1024 MB of memory"""
    try:
        ctx = ggml.ggml_init(params=CTX_PARAMS)
        yield ctx
    finally:
        ggml.ggml_free(ctx)
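

# Every test below evaluates a ggml output tensor with the same two-step
# pattern: build a forward graph from the tensor, then compute that graph in
# the current context. A minimal sketch of that pattern as a helper (kept for
# reference only; the tests inline the calls so intermediate nodes stay easy
# to inspect):
def _eval_graph(ctx: Ctx, gy: ggml.ggml_tensor_p) -> np.ndarray:
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    return ggml.to_numpy(gy)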


@pytest.fixture(scope="module")
def g_model_once() -> Iterator[c_void_p]:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    with ggml.load_unity_ggml_file(model_file) as model:
        yield model


@pytest.fixture()
def g_model(ctx: Ctx, g_model_once: c_void_p) -> c_void_p:
    ggml.lib.fairseq2_model_set_inference_ctx(g_model_once, ctx)
    return g_model_once


@pytest.fixture(scope="module")
def translator() -> Iterator[Any]:
    tr = Translator(
        "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
    )
    with torch.inference_mode():
        yield tr


@pytest.fixture(scope="module")
def pt_model(translator: Translator) -> Any:
    model = translator.model
    print(model)
    return model


@pytest.mark.xfail(reason="TODO")
def test_hparams_code_is_up_to_date() -> None:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    hparams_header_file = model_file.with_suffix(".hparams.h")
    hparams_struct = hparams_header_file.read_text().strip()
    actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
    assert hparams_struct in actual_code


def test_causal_attention_mask(ctx: Ctx) -> None:
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
    mask_exp = generator(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    # Shrinking the input sequence should shrink the mask accordingly.
    x = x[:, :8, :]
    mask_exp = generator(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)
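

# For reference, the additive causal mask checked above has zeros on and below
# the diagonal and -inf above it, so each position can only attend to itself
# and to earlier positions. E.g. for a length-4 sequence:
#
#   [[0., -inf, -inf, -inf],
#    [0.,   0., -inf, -inf],
#    [0.,   0.,   0., -inf],
#    [0.,   0.,   0.,   0.]]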


def test_forward_ffn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((21, 1024))  # (seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y, atol=1e-5)


def test_forward_layer_norm(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
    assert np.allclose(y_exp, y, rtol=1e-3, atol=1e-4)


def _name(tensor: ggml.ggml_tensor_p) -> bytes:
    try:
        return tensor.contents.name  # type: ignore[no-any-return]
    except ValueError:
        return b"???"


def test_forward_self_attn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((1, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    self_attn = pt_model.text_encoder.layers[0].self_attn

    # Note: we use different lengths for queries and keys,
    # which also exercises the implementation in a decoding context.
    # Note 2: ggml_flash_attn requires that we have more keys than queries.
    gxq = ggml.from_numpy(ctx, x[0, :11, :])
    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gx,
        gx,
        None,  # TODO: test with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    # q_exp = self_attn._project_q(x[:, :11, :], None, None).squeeze(0).numpy()
    y = ggml.to_numpy(gy)
    nodes = {}
    for i in range(gf.n_nodes):
        name = _name(gf.nodes[i])
        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
        print(name, f"op({gf.nodes[i].contents.op})", children)
        nodes[name] = ggml.to_numpy(gf.nodes[i])

    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
    y_exp = y_exp.squeeze(0)  # remove batch dimension

    # q = nodes[b"q"]
    # assert q.shape == q_exp.shape
    # assert np.allclose(q_exp, q, atol=1e-5)

    # With flash_attn we don't have attn_weights.
    if not UNITY_FLASH_ATTN:
        attn_weights = nodes[b"attn_weights"]
        [attn_weights_exp] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.squeeze(0).numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively reduces small softmax weights to 0.
        # The cause isn't clear yet, hence the loose tolerance.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)
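

# For reference, the multi-head attention exercised above computes standard
# scaled dot-product attention per head:
#
#   Attention(Q, K, V) = softmax(Q @ K.T / sqrt(d_k)) @ V
#
# With UNITY_FLASH_ATTN the whole expression is fused by ggml_flash_attn, so
# the intermediate attn_weights tensor never materializes (hence the branch
# above).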


def test_StandardTransformerEncoderLayer_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO: support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_StandardTransformerEncoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO: support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    y = ggml.to_numpy(gy)
    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    # This _legacy_pad_idx is suspicious. Shouldn't the model use 1? But it is
    # consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset.
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)
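

# For reference, the sinusoidal table compared above follows the standard
# Transformer formulation (sketch; up to the exact dimension layout and the
# pad/legacy position offset fairseq2 applies, configured above):
#
#   PE[pos, 2i]     = sin(pos / 10000^(2i / dim))
#   PE[pos, 2i + 1] = cos(pos / 10000^(2i / dim))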


def test_TransformerEmbeddingFrontend_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    seq = torch.arange(20).reshape(1, 20)
    seq_len = torch.tensor([20])

    gseq = ggml.from_numpy(ctx, seq[0].numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_StandardTransformerDecoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    pytest.skip("foo")
    x = torch.empty((1, 13, 1024))
    encoder_out = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 13))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out[0])
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO: support padding mask
        genc,
        None,
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)


def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
    # The commented-out code below was used (with the `translator` fixture) to
    # generate the sample_input.npz file loaded further down:
    # def test_t2tt(ctx: Ctx, g_model: c_void_p, translator):
    #     device = translator.device
    src_lang = "eng"
    src_text = "We are all in a yellow submarine."
    tgt_lang = "fra"
    # token_encoder = translator.text_tokenizer.create_encoder(
    #     task="translation", lang=src_lang, mode="source", device=device
    # )
    # src = translator.collate(token_encoder(src_text))
    # text_out, _ = translator.get_prediction(
    #     translator.model,
    #     translator.text_tokenizer,
    #     translator.unit_tokenizer,
    #     src,
    #     input_modality=Modality.TEXT,
    #     output_modality=Modality.TEXT,
    #     tgt_lang=tgt_lang,
    # )
    # tgt_text = str(text_out.sentences[0])
    # assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
    # tgt_tokens = text_out.generator_output.results[0][0].seq
    # score = text_out.generator_output.results[0][0].score.item()
    # np.savez(
    #     Path(__file__).parent / "sample_input.npz",
    #     score=score,
    #     encoder_output=text_out.encoder_output.squeeze(0).numpy(),
    #     encoder_padding_mask=text_out.encoder_padding_mask.squeeze(0).numpy(),
    #     tgt_tokens=tgt_tokens.numpy(),
    # )

    text_out = np.load(Path(__file__).parent / "sample_input.npz")
    score = text_out["score"].item()
    tgt_tokens = list(text_out["tgt_tokens"])
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])

    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = 1
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = int(len(tgt_tokens) * 1.5)
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.prefix_seq = ggml.from_numpy(ctx, text_out["tgt_tokens"].astype(np.int32)[:2])
    job.pad_idx = 0
    job.unk_idx = 1
    job.bos_idx = 2
    job.eos_idx = 3

    result = ggml.ggml_tensor()
    g_score = ggml.generate_sequence(
        g_model, job, encoder_out, encoder_padding_mask, ctypes.byref(result)
    )
    tokens = list(ggml.to_numpy(ctypes.pointer(result)))
    assert tokens == tgt_tokens
    assert g_score == pytest.approx(score)
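

if __name__ == "__main__":
    # Convenience entry point: these tests are normally collected by pytest;
    # this merely allows running the file directly as well.
    sys.exit(pytest.main([__file__, "-v"]))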