// test-pool.c (4.3 KB)
  1. #include "ggml/ggml.h"
  2. #include <string.h>
  3. #include <stdio.h>
  4. #include <stdlib.h>
  5. struct ggml_context* make_ctx(void) {
  6. struct ggml_init_params params = {
  7. .mem_size = 2 * 1024 * 1024,
  8. };
  9. return ggml_init(params);
  10. }
  11. int main(int argc, const char** argv) {
  12. float buf_f32[1024];
  13. for (int i = 0; i < 1024; ++i) {
  14. buf_f32[i] = (float)(i + 1);
  15. }
  16. // avg pool 1d
  17. {
  18. struct ggml_context * ctx = make_ctx();
  19. struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
  20. memcpy(t->data, buf_f32, ggml_nbytes(t));
  21. struct ggml_tensor * t_pooled = ggml_pool_1d(ctx, t, GGML_OP_POOL_AVG, 3, 3, 0);
  22. GGML_ASSERT(t_pooled->ne[0] == 3);
  23. GGML_ASSERT(t_pooled->ne[1] == 2);
  24. GGML_ASSERT(t_pooled->ne[2] == 1);
  25. struct ggml_cgraph graph = ggml_build_forward(t_pooled);
  26. ggml_graph_compute_with_ctx(ctx, &graph, 4);
  27. const float * output = ggml_get_data_f32(t_pooled);
  28. GGML_ASSERT(output[0] == 2);
  29. GGML_ASSERT(output[1] == 5);
  30. GGML_ASSERT(output[2] == 8);
  31. GGML_ASSERT(output[3] == 12);
  32. GGML_ASSERT(output[4] == 15);
  33. GGML_ASSERT(output[5] == 18);
  34. ggml_free(ctx);
  35. }
  36. // max pool 1d
  37. {
  38. struct ggml_context * ctx = make_ctx();
  39. struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
  40. memcpy(t->data, buf_f32, ggml_nbytes(t));
  41. struct ggml_tensor * t_pooled = ggml_pool_1d(ctx, t, GGML_OP_POOL_MAX, 3, 3, 0);
  42. GGML_ASSERT(t_pooled->ne[0] == 3);
  43. GGML_ASSERT(t_pooled->ne[1] == 2);
  44. GGML_ASSERT(t_pooled->ne[2] == 1);
  45. struct ggml_cgraph graph = ggml_build_forward(t_pooled);
  46. ggml_graph_compute_with_ctx(ctx, &graph, 4);
  47. const float * output = ggml_get_data_f32(t_pooled);
  48. GGML_ASSERT(output[0] == 3);
  49. GGML_ASSERT(output[1] == 6);
  50. GGML_ASSERT(output[2] == 9);
  51. GGML_ASSERT(output[3] == 13);
  52. GGML_ASSERT(output[4] == 16);
  53. GGML_ASSERT(output[5] == 19);
  54. ggml_free(ctx);
  55. }
  56. // avg pool 2d
  57. {
  58. struct ggml_context * ctx = make_ctx();
  59. struct ggml_tensor * t = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 10, 10, 2);
  60. memcpy(t->data, buf_f32, ggml_nbytes(t));
  61. struct ggml_tensor * t_pooled = ggml_pool_2d(ctx, t, GGML_OP_POOL_AVG, 3, 4, 3, 4, 0, 0);
  62. GGML_ASSERT(t_pooled->ne[0] == 3);
  63. GGML_ASSERT(t_pooled->ne[1] == 2);
  64. GGML_ASSERT(t_pooled->ne[2] == 2);
  65. GGML_ASSERT(t_pooled->ne[3] == 1);
  66. struct ggml_cgraph graph = ggml_build_forward(t_pooled);
  67. ggml_graph_compute_with_ctx(ctx, &graph, 4);
  68. const float * output = ggml_get_data_f32(t_pooled);
  69. GGML_ASSERT(output[0] == 17);
  70. GGML_ASSERT(output[1] == 20);
  71. GGML_ASSERT(output[2] == 23);
  72. GGML_ASSERT(output[3] == 57);
  73. GGML_ASSERT(output[4] == 60);
  74. GGML_ASSERT(output[5] == 63);
  75. GGML_ASSERT(output[6] == 117);
  76. GGML_ASSERT(output[7] == 120);
  77. GGML_ASSERT(output[8] == 123);
  78. GGML_ASSERT(output[9] == 157);
  79. GGML_ASSERT(output[10] == 160);
  80. GGML_ASSERT(output[11] == 163);
  81. ggml_free(ctx);
  82. }
  83. // max pool 2d
  84. {
  85. struct ggml_context * ctx = make_ctx();
  86. struct ggml_tensor * t = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 10, 10, 2);
  87. memcpy(t->data, buf_f32, ggml_nbytes(t));
  88. struct ggml_tensor * t_pooled = ggml_pool_2d(ctx, t, GGML_OP_POOL_MAX, 3, 4, 3, 4, 0, 0);
  89. GGML_ASSERT(t_pooled->ne[0] == 3);
  90. GGML_ASSERT(t_pooled->ne[1] == 2);
  91. GGML_ASSERT(t_pooled->ne[2] == 2);
  92. GGML_ASSERT(t_pooled->ne[3] == 1);
  93. struct ggml_cgraph graph = ggml_build_forward(t_pooled);
  94. ggml_graph_compute_with_ctx(ctx, &graph, 4);
  95. const float * output = ggml_get_data_f32(t_pooled);
  96. GGML_ASSERT(output[0] == 33);
  97. GGML_ASSERT(output[1] == 36);
  98. GGML_ASSERT(output[2] == 39);
  99. GGML_ASSERT(output[3] == 73);
  100. GGML_ASSERT(output[4] == 76);
  101. GGML_ASSERT(output[5] == 79);
  102. GGML_ASSERT(output[6] == 133);
  103. GGML_ASSERT(output[7] == 136);
  104. GGML_ASSERT(output[8] == 139);
  105. GGML_ASSERT(output[9] == 173);
  106. GGML_ASSERT(output[10] == 176);
  107. GGML_ASSERT(output[11] == 179);
  108. ggml_free(ctx);
  109. }
  110. return 0;
  111. }