test_unity_cpp.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import ggml
  2. import ctypes
  3. import torch
  4. import pytest
  5. import numpy as np
  6. import torch
  7. from typing import Any
  8. from pathlib import Path
  9. from typing import Iterator
  10. from ggml import NativeObj
  11. from ggml_convert import convert_model
  12. from seamless_communication.models.unity import load_unity_model
  13. Ctx = ggml.ggml_context_p
  14. UNITY_MODELS = Path(__file__).parent / "examples/unity/models"
  15. PARAMS_16MB = ggml.ggml_init_params(mem_size=16 * 1024 * 1024, mem_buffer=None)
  16. @pytest.fixture(name="ctx")
  17. def _ctx() -> Iterator[Ctx]:
  18. """Allocate a new context with 16 MB of memory"""
  19. try:
  20. ctx = ggml.ggml_init(params=PARAMS_16MB)
  21. yield ctx
  22. finally:
  23. ggml.ggml_free(ctx)
  24. def test_ggml_bindings_work(ctx: Ctx) -> None:
  25. # Instantiate tensors
  26. x = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)
  27. a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)
  28. b = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)
  29. # Use ggml operations to build a computational graph
  30. x2 = ggml.ggml_mul(ctx, x, x)
  31. f = ggml.ggml_add(ctx, ggml.ggml_mul(ctx, a, x2), b)
  32. gf = ggml.ggml_build_forward(f)
  33. # Set the input values
  34. ggml.ggml_set_f32(x, 2.0)
  35. ggml.ggml_set_f32(a, 3.0)
  36. ggml.ggml_set_f32(b, 4.0)
  37. # Compute the graph
  38. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  39. # Get the output value
  40. output = ggml.ggml_get_f32_1d(f, 0)
  41. assert output == 16.0
  42. def test_shape_works(ctx: Ctx) -> None:
  43. a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
  44. assert ggml.shape(a) == (10,)
  45. b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
  46. assert ggml.shape(b) == (11, 21)
  47. c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
  48. assert ggml.shape(c) == (12, 22, 32)
  49. def test_nb_works(ctx: Ctx) -> None:
  50. a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
  51. assert ggml.nb(a) == (4, 40, 40, 40)
  52. b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F16, 11, 21)
  53. assert ggml.nb(b) == (2, 22, 462, 462)
  54. c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
  55. assert ggml.nb(c) == (4, 48, 1056, 33792)
  56. @pytest.mark.xfail(reason="TODO: fix strides")
  57. def test_strides_works(ctx: Ctx) -> None:
  58. a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
  59. assert ggml.strides(a) == np.ones((10,), dtype=np.float32).strides
  60. b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
  61. assert ggml.strides(b) == np.ones((11, 21), dtype=np.float32).strides
  62. c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
  63. assert ggml.strides(c) == np.ones((12, 22, 32), dtype=np.float32).strides
  64. def test_to_numpy_works_with_f32(ctx: Ctx) -> None:
  65. a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
  66. a = ggml.ggml_set_f32(a, 2.14)
  67. assert np.allclose(ggml.to_numpy(a), np.ones((10,)) * 2.14)
  68. b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
  69. b = ggml.ggml_set_f32(b, 2.14)
  70. assert np.allclose(ggml.to_numpy(b), np.ones((11, 21)) * 2.14)
  71. c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
  72. c = ggml.ggml_set_f32(c, 2.14)
  73. assert np.allclose(ggml.to_numpy(c), np.ones((12, 22, 32)) * 2.14)
  74. def test_from_numpy_works_with_f32(ctx: Ctx) -> None:
  75. a = np.random.normal(size=(10,)).astype(dtype=np.float32)
  76. ga = ggml.from_numpy(ctx, a)
  77. assert ggml.shape(ga) == (10,)
  78. assert ggml.nb(ga) == ggml.nb(ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10))
  79. assert np.allclose(a, ggml.to_numpy(ga))
  80. a = np.random.normal(size=(11, 21)).astype(dtype=np.float32)
  81. ga = ggml.from_numpy(ctx, a)
  82. assert ggml.shape(ga) == (11, 21)
  83. assert ggml.nb(ga) == ggml.nb(
  84. ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
  85. )
  86. assert np.allclose(a, ggml.to_numpy(ga))
  87. a = np.random.normal(size=(12, 22, 32)).astype(dtype=np.float32)
  88. ga = ggml.from_numpy(ctx, a)
  89. assert ggml.shape(ga) == (12, 22, 32)
  90. assert ggml.nb(ga) == ggml.nb(
  91. ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
  92. )
  93. assert np.allclose(a, ggml.to_numpy(ga))
  94. def test_to_numpy_works_with_f16(ctx: Ctx) -> None:
  95. # We explicitly fill the tensor otherwise they might have non-zero values in them.
  96. a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F16, 10)
  97. a = ggml.ggml_set_f32(a, 2.14)
  98. assert np.allclose(ggml.to_numpy(a), np.ones((10,), dtype=np.float16) * 2.14)
  99. b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F16, 11, 21)
  100. b = ggml.ggml_set_f32(b, 4.18)
  101. assert np.allclose(ggml.to_numpy(b), np.ones((11, 21), dtype=np.float16) * 4.18)
  102. c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F16, 12, 22, 32)
  103. c = ggml.ggml_set_f32(c, 3.16)
  104. assert np.allclose(ggml.to_numpy(c), np.ones((12, 22, 32), dtype=np.float16) * 3.16)
  105. def test_from_numpy_works_with_f16(ctx: Ctx) -> None:
  106. a = np.random.normal(size=(10,)).astype(dtype=np.float16)
  107. ga = ggml.from_numpy(ctx, a)
  108. assert np.allclose(a, ggml.to_numpy(ga))
  109. a = np.random.normal(size=(11, 21)).astype(dtype=np.float16)
  110. ga = ggml.from_numpy(ctx, a)
  111. assert np.allclose(a, ggml.to_numpy(ga))
  112. a = np.random.normal(size=(12, 22, 32)).astype(dtype=np.float16)
  113. ga = ggml.from_numpy(ctx, a)
  114. assert np.allclose(a, ggml.to_numpy(ga))
  115. def test_ning_model_load(ctx: Ctx) -> None:
  116. model, vocab = ggml.unity_model_load(UNITY_MODELS / "unity-large/ggml-model.bin")
  117. print(model, vocab)
  118. example = ggml.from_file(
  119. ctx, UNITY_MODELS / "unity-large/seqs_before_conformer_block.bin", (1024, 137)
  120. )
  121. with ggml.MeasureArena() as arena:
  122. graph = ggml.unity_audio_encoder_graph(model, example)
  123. # TODO: why the extra memory ?
  124. mem_size = ggml.ggml_allocr_alloc_graph(arena, graph) + ggml.GGML_MEM_ALIGN
  125. with ggml.FixedSizeArena(mem_size) as allocr:
  126. print(
  127. f"unity_audio_encoder_graph: compute buffer size: {mem_size/1024/1024} MB"
  128. )
  129. eval_res_ptr = ggml.unity_eval(allocr, model, example, 1)
  130. eval_res = eval_res_ptr.contents
  131. inpL = ggml.to_numpy(eval_res.nodes[eval_res.n_nodes - 1])
  132. expected_raw = "-0.1308,0.0346,-0.2656,0.2873,-0.0104,0.0574,0.4033,-0.1125,-0.0460,-0.0496"
  133. expected = map(float, expected_raw.split(","))
  134. assert np.allclose(inpL[0, :10], list(expected), atol=1e-4)
  135. @pytest.fixture(scope="module")
  136. def g_model() -> NativeObj:
  137. model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
  138. if not model_file.exists():
  139. convert_model("seamlessM4T_medium", model_file)
  140. return ggml.load_unity_ggml_file(model_file)
  141. @pytest.fixture(scope="module")
  142. def pt_model() -> Iterator[Any]:
  143. model = load_unity_model("seamlessM4T_medium")
  144. print(model)
  145. model.eval()
  146. with torch.inference_mode():
  147. yield model
  148. @pytest.mark.xfail(reason="TODO")
  149. def test_hparams_code_is_up_to_date() -> None:
  150. model_file = Path(__file__).parent / "seamlessM4T_medium.ggml"
  151. hparams_header_file = model_file.with_suffix(".hparams.h")
  152. hparams_struct = hparams_header_file.read_text().strip()
  153. actual_code = (UNITY_MODELS.parent / "unity_model_loader.h").read_text()
  154. assert hparams_struct in actual_code
  155. def test_forward_ffn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
  156. x = torch.empty((1024))
  157. torch.nn.init.uniform_(x, -1, 1)
  158. # Test FFN without LayerNorm
  159. y_exp = pt_model.text_encoder.layers[0].ffn(x).numpy()
  160. gx = ggml.from_numpy(ctx, x)
  161. gy = ggml.forward(
  162. "StandardFeedForwardNetwork", g_model, "text_encoder.layers.0.ffn", gx
  163. )
  164. gf = ggml.ggml_build_forward(gy)
  165. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  166. y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1]).reshape(-1)
  167. abs_diff = np.max(np.abs(y - y_exp))
  168. assert abs_diff < 1e-2
  169. assert np.allclose(y_exp, y, rtol=1e-3)
  170. def test_forward_layer_norm(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
  171. x = torch.empty((1024,))
  172. torch.nn.init.uniform_(x, -1, 1)
  173. y_exp = pt_model.text_encoder.layers[0].ffn_layer_norm(x).numpy()
  174. gx = ggml.from_numpy(ctx, x)
  175. gy = ggml.forward("LayerNorm", g_model, "text_encoder.layers.0.ffn_layer_norm", gx)
  176. gf = ggml.ggml_build_forward(gy)
  177. ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
  178. y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1]).reshape(-1)
  179. abs_diff = np.max(np.abs(y - y_exp))
  180. assert np.allclose(y_exp, y)