# test_unity_cpp.py

import ctypes
import logging
import sys
from ctypes import c_void_p
from pathlib import Path
from typing import Any, Iterator

import numpy as np
import pytest
import torch

import fairseq2.nn
import fairseq2.nn.transformer
import ggml
from ctypes_utils import Ptr
from ggml import NativeObj
from ggml_convert import convert_model
from seamless_communication.models.inference.translator import Translator, Modality

Ctx = ggml.ggml_context_p

UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
CTX_PARAMS = ggml.ggml_init_params(mem_size=1024 * 1024 * 1024, mem_buffer=None)
FAIRSEQ2_CPP = Path(__file__).parent / "examples/unity/fairseq2.cpp"
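# True when examples/unity/fairseq2.cpp was compiled with flash attention enabled.
# The attention tests below skip the attn_weights comparison in that case (flash
# attention never materializes the weights) and use a different output tolerance.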
UNITY_FLASH_ATTN = "\n# define UNITY_FLASH_ATTN 0\n" not in FAIRSEQ2_CPP.read_text()


@pytest.fixture(name="ctx")
def _ctx() -> Iterator[Ctx]:
    """Allocate a new context with 1024 MB of memory"""
    try:
        ctx = ggml.ggml_init(params=CTX_PARAMS)
        yield ctx
    finally:
        ggml.ggml_free(ctx)


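# Module-scoped: converts the seamlessM4T_medium checkpoint to ggml next to this
# file on first use, then reuses the loaded model for every test in the module.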
@pytest.fixture(scope="module")
def g_model_once() -> Iterator[c_void_p]:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    if not model_file.exists():
        convert_model("seamlessM4T_medium", model_file)
    with ggml.load_unity_ggml_file(model_file) as model:
        yield model


@pytest.fixture()
def g_model(ctx: Ctx, g_model_once: c_void_p) -> c_void_p:
    ggml.lib.fairseq2_model_set_inference_ctx(g_model_once, ctx)
    return g_model_once


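# Reference PyTorch pipeline; the ggml outputs below are compared against its model.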
@pytest.fixture(scope="module")
def translator() -> Iterator[Any]:
    tr = Translator(
        "seamlessM4T_medium", "vocoder_36langs", torch.device("cpu"), torch.float32
    )
    with torch.inference_mode():
        yield tr


@pytest.fixture(scope="module")
def pt_model(translator: Translator) -> Any:
    model = translator.model
    print(model)
    return model


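# Guards that the hparams header written next to the ggml file stays in sync with
# unity_model_loader.h; currently marked xfail.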
@pytest.mark.xfail(reason="TODO")
def test_hparams_code_is_up_to_date() -> None:
    model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
    hparams_header_file = model_file.with_suffix(".hparams.h")
    hparams_struct = hparams_header_file.read_text().strip()
    actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
    assert hparams_struct in actual_code


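# ggml.causal_attention_mask should match fairseq2's CausalAttentionMaskGenerator,
# including after truncating the input sequence.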
def test_causal_attention_mask(ctx: Ctx):
    x = torch.zeros((1, 10, 32))
    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
    mask_exp = generator(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (10, 10)
    assert mask.shape == (10, 10)
    assert np.all(mask == mask_exp)

    x = x[:, :8, :]
    mask_exp = generator(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gmask = ggml.causal_attention_mask(ctx, gx)
    mask = ggml.to_numpy(gmask)
    gf = ggml.ggml_build_forward(gmask)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    assert mask_exp.shape == (8, 8)
    assert mask.shape == (8, 8)
    assert np.all(mask == mask_exp)


def test_LayerNorm_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_Linear_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))
    torch.nn.init.uniform_(x, -1, 1)

    y_exp = pt_model.text_encoder.layers[0].ffn.inner_proj(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward("Linear", g_model, "text_encoder.layers.0.ffn.inner_proj", gx)
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


def test_FeedForwardNetwork_forward(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((2, 21, 1024))  # (bs, seq_len, model_dim)
    torch.nn.init.uniform_(x, -1 / 32, 1 / 32)

    # Test FFN without LayerNorm
    y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()

    gx = ggml.from_numpy(ctx, x)
    gy = ggml.forward(
        "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
    )
    ggml.build_and_compute(ctx, gy)

    y = ggml.to_numpy(gy)
    assert np.allclose(y_exp, y, atol=1e-5)


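# Reads the debug name of a graph node; returns a placeholder when the pointer is
# NULL (ctypes raises ValueError when dereferencing it).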
def _name(tensor: ggml.ggml_tensor_p) -> bytes:
    try:
        return tensor.contents.name  # type: ignore[no-any-return]
    except ValueError:
        return b"???"


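# Full self-attention layer against the fairseq2 reference; also checks the
# intermediate attention weights when flash attention is disabled.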
def test_forward_self_attn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
    x = torch.empty((1, 21, 1024))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    self_attn = pt_model.text_encoder.layers[0].self_attn

    # Note: we use different lengths for queries and keys,
    # so this also tests the implementation in a decoding context.
    # Note 2: ggml_flash_attn requires that we have more keys than queries.
    gxq = ggml.from_numpy(ctx, x[0, :11, :])
    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gy = ggml.forward(
        "MultiheadAttention",
        g_model,
        "text_encoder.layers.0.self_attn",
        gxq,
        gx,
        gx,
        None,  # TODO: tests with causal attention masks
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

    # q_exp = self_attn._project_q(x[:, :11, :], None, None).squeeze(0).numpy()
    y = ggml.to_numpy(gy)

    nodes = {}
    for i in range(gf.n_nodes):
        name = _name(gf.nodes[i])
        children = [_name(gf.nodes[i].contents.src[j]) for j in range(2)]
        print(name, f"op({gf.nodes[i].contents.op})", children)
        nodes[name] = ggml.to_numpy(gf.nodes[i])

    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
    self_attn.register_attn_weight_hook(attn_weights_hook)

    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
    y_exp = y_exp.squeeze(0)  # remove batch dimension

    # q = nodes[b"q"]
    # assert q.shape == q_exp.shape
    # assert np.allclose(q_exp, q, atol=1e-5)

    # With flash_attn we don't have attn_weights.
    if not UNITY_FLASH_ATTN:
        attn_weights = nodes[b"attn_weights"]
        [attn_weights_exp] = attn_weights_hook._storage
        attn_weights_exp = attn_weights_exp.squeeze(0).numpy()
        assert attn_weights_exp.shape == attn_weights.shape
        # GGML very aggressively reduces small softmax weights to 0;
        # not sure what causes this.
        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_StandardTransformerEncoderLayer_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    layer = pt_model.text_encoder.layers[0]

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoderLayer",
        g_model,
        "text_encoder.layers.0",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = layer(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_StandardTransformerEncoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    x = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 21))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    gy = ggml.forward(
        "StandardTransformerEncoder",
        g_model,
        "text_encoder",
        gx,
        None,  # TODO support padding mask
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_encoder(x, padding_mask)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)


def test_PositionalEmbedding_forward(ctx: Ctx, g_model: c_void_p) -> None:
    seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
    # This _legacy_pad_idx is suspicious: shouldn't the model use 1? But this is
    # consistent with pt_model.text_decoder_frontend.pos_encoder._sin_offset.
    pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
    y_exp = pos_encoder(seq, None)[0].numpy()

    gseq = ggml.from_numpy(ctx, seq[0].numpy())
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "PositionalEmbedding", g_model, "text_decoder_frontend.pos_encoder", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_TransformerEmbeddingFrontend_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    seq = torch.arange(20).reshape(1, 20)
    seq_len = torch.tensor([20])

    gseq = ggml.from_numpy(ctx, seq[0].numpy().astype(np.int32))
    ggml.ggml_set_name(gseq, b"seq")
    gy = ggml.forward(
        "TransformerEmbeddingFrontend", g_model, "text_decoder_frontend", gseq
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder_frontend(seq, seq_len)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-6)


def test_StandardTransformerDecoder_forward(
    ctx: Ctx, g_model: c_void_p, pt_model: Any
) -> None:
    pytest.skip("foo")
    x = torch.empty((1, 13, 1024))
    encoder_out = torch.empty((1, 21, 1024))
    padding_mask = torch.ones((1, 13))
    torch.random.manual_seed(0)
    torch.nn.init.uniform_(x, -1, 1)
    torch.nn.init.uniform_(encoder_out, -1, 1)

    gx = ggml.from_numpy(ctx, x[0])
    ggml.ggml_set_name(gx, b"x")
    gpad = ggml.from_numpy(ctx, padding_mask[0])
    ggml.ggml_set_name(gpad, b"padding_mask")
    genc = ggml.from_numpy(ctx, encoder_out[0])
    gy = ggml.forward(
        "StandardTransformerDecoder",
        g_model,
        "text_decoder",
        gx,
        None,  # TODO support padding mask
        genc,
        None,
    )
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    y = ggml.to_numpy(gy)

    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension

    assert y.shape == y_exp.shape
    assert np.allclose(y_exp, y, atol=1e-4)


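# End-to-end text-to-text generation. The commented-out block shows how
# sample_input.npz was produced with the PyTorch Translator; the test replays the
# saved encoder output through ggml's sequence generator and checks tokens and score.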
def test_t2tt(ctx: Ctx, g_model: c_void_p):
    # def test_t2tt(ctx: Ctx, g_model: c_void_p, translator):
    # device = translator.device
    src_lang = "eng"
    src_text = "We are all in a yellow submarine."
    tgt_lang = "fra"
    # token_encoder = translator.text_tokenizer.create_encoder(
    #     task="translation", lang=src_lang, mode="source", device=device
    # )
    # src = translator.collate(token_encoder(src_text))
    # text_out, _ = translator.get_prediction(
    #     translator.model,
    #     translator.text_tokenizer,
    #     translator.unit_tokenizer,
    #     src,
    #     input_modality=Modality.TEXT,
    #     output_modality=Modality.TEXT,
    #     tgt_lang=tgt_lang,
    # )
    # tgt_text = str(text_out.sentences[0])
    # assert tgt_text == "Nous sommes tous dans un sous-marin jaune."
    # tgt_tokens = text_out.generator_output.results[0][0].seq
    # score = text_out.generator_output.results[0][0].score.item()
    # np.savez(
    #     Path(__file__).parent / "sample_input.npz",
    #     score=score,
    #     encoder_output=text_out.encoder_output.squeeze(0).numpy(),
    #     encoder_padding_mask=text_out.encoder_padding_mask.squeeze(0).numpy(),
    #     tgt_tokens=tgt_tokens.numpy(),
    # )

    text_out = np.load(Path(__file__).parent / "sample_input.npz")
    score = text_out["score"].item()
    tgt_tokens = list(text_out["tgt_tokens"])
    encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])

    job = ggml.SequenceGeneratorJob()
    job.opts.beam_size = 1
    job.opts.min_seq_len = 1
    job.opts.soft_max_seq_len_a = 1
    job.opts.soft_max_seq_len_b = 200
    job.opts.hard_max_seq_len = int(len(tgt_tokens) * 1.5)
    job.opts.len_penalty = 1.0
    job.opts.unk_penalty = 0.0
    job.prefix_seq = ggml.from_numpy(ctx, text_out["tgt_tokens"].astype(np.int32)[:2])
    job.pad_idx = 0
    job.unk_idx = 1
    job.bos_idx = 2
    job.eos_idx = 3

    result = ggml.ggml_tensor()
    g_score = ggml.generate_sequence(
        g_model, job, encoder_out, encoder_padding_mask, ctypes.byref(result)
    )
    tokens = list(ggml.to_numpy(ctypes.pointer(result)))
    assert tokens == tgt_tokens
    assert g_score == pytest.approx(score)