// test-xpos.c
  1. #include "ggml/ggml.h"
  2. #include <math.h>
  3. #include <stdio.h>
  4. #include <stdlib.h>
  5. bool is_close(float a, float b, float epsilon) {
  6. return fabs(a - b) < epsilon;
  7. }
  8. int main(int argc, char ** argv) {
  9. const int n_threads = 1;
  10. const int n_embd_head = 4; // aka head_dim
  11. const int n_head = 1;
  12. const int N = 8;
  13. struct ggml_init_params params = {
  14. .mem_size = 16*1024*1024,
  15. .mem_buffer = NULL,
  16. };
  17. // memory allocation happens here
  18. struct ggml_context * ctx = ggml_init(params);
  19. struct ggml_tensor * Q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);
  20. struct ggml_tensor * K = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);
  21. for (int i = 0; i < ggml_nelements(Q); i++) {
  22. ((float*) Q->data)[i] = 2.0f;
  23. ((float*) K->data)[i] = 2.0f;
  24. }
  25. struct ggml_tensor * Qx = ggml_rope_xpos_inplace(ctx, Q, 1, n_embd_head, 512.0f, false);
  26. struct ggml_tensor * Kx = ggml_rope_xpos_inplace(ctx, K, 1, n_embd_head, 512.0f, true);
  27. struct ggml_cgraph gf = ggml_build_forward(Qx);
  28. ggml_build_forward_expand(&gf, Kx);
  29. ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
  30. // expected output for Qx:
  31. // -0.6009 2.7568 1.9782 2.0182
  32. // -2.6379 0.9815 1.9562 2.0361
  33. // -2.2457 -1.6853 1.9341 2.0538
  34. // 0.2043 -2.7934 1.9118 2.0712
  35. // 2.4550 -1.3341 1.8894 2.0884
  36. // 2.4430 1.3417 1.8668 2.1054
  37. // 0.1905 2.7739 1.8440 2.1221
  38. // -2.2257 1.6550 1.8212 2.1386
  39. for (int i = 0; i < ggml_nelements(Q); i++) {
  40. if (((float*) Qx->data)[i] > 0) printf(" ");
  41. printf("%.4f ", ((float*) Qx->data)[i]);
  42. if ((i+1) % n_embd_head == 0) printf("\n");
  43. }
  44. printf("\n");
  45. GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 0], -2.2257f, 0.0001f));
  46. GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 1], 1.6550f, 0.0001f));
  47. GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 2], 1.8212f, 0.0001f));
  48. GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 3], 2.1386f, 0.0001f));
  49. // expected output for Kx:
  50. // -0.6038 2.7703 1.9816 2.0216
  51. // -2.6639 0.9911 1.9630 2.0431
  52. // -2.2789 -1.7103 1.9441 2.0644
  53. // 0.2083 -2.8486 1.9251 2.0856
  54. // 2.5158 -1.3671 1.9057 2.1065
  55. // 2.5158 1.3816 1.8862 2.1273
  56. // 0.1972 2.8705 1.8665 2.1479
  57. // -2.3146 1.7211 1.8465 2.1684
  58. for (int i = 0; i < ggml_nelements(K); i++) {
  59. if (((float*) Kx->data)[i] > 0) printf(" ");
  60. printf("%.4f ", ((float*) Kx->data)[i]);
  61. if ((i+1) % n_embd_head == 0) printf("\n");
  62. }
  63. printf("\n");
  64. GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 0], -2.3146f, 0.0001f));
  65. GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 1], 1.7211f, 0.0001f));
  66. GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 2], 1.8465f, 0.0001f));
  67. GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 3], 2.1684f, 0.0001f));
  68. ggml_free(ctx);
  69. return 0;
  70. }