test-mul-mat1.c 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  #include <assert.h>
  #include <errno.h>
  #include <inttypes.h>
  #include <math.h>
  #include <stdint.h>
  #include <stdio.h>
  #include <stdlib.h>
  #include <string.h>
  #include <time.h>

  #include <sys/time.h>

  #include <arm_neon.h>
  #include <Accelerate/Accelerate.h>
  /* Benchmark problem size: dst (M x N) = src0 (M x K) times the transpose of
   * src1 (N x K). All three dimensions are multiples of 256. */
  const int M = 5 * 256;  /* 1280 */
  const int N = 6 * 256;  /* 1536 */
  const int K = 5 * 256;  /* 1280 */
  /* Current wall-clock time in microseconds since the Unix epoch, read with
   * gettimeofday(). Callers use differences of two readings as elapsed time.
   * Both fields are widened to uint64_t before the arithmetic so that the
   * seconds-to-microseconds multiply cannot overflow on targets where
   * time_t / long is only 32 bits wide. */
  uint64_t get_time_us(void) {
      struct timeval tv;
      gettimeofday(&tv, NULL);
      return (uint64_t)tv.tv_sec * 1000000u + (uint64_t)tv.tv_usec;
  }
  /*
   * Naive reference kernel: dst[i][j] = sum over t of src0[i][t] * src1[j][t].
   *
   * src0 is m x k and src1 is n x k (the right-hand matrix stored transposed),
   * so both operands of every dot product are read contiguously. dst is m x n.
   * All matrices are row-major. Each dot product is accumulated in float,
   * left to right.
   */
  void mul_mat_f32_0(
      const float * restrict src0, // M x K
      const float * restrict src1, // N x K (transposed)
      float * dst,
      int m, int n, int k) {
      for (int row = 0; row < m; ++row) {
          const float *a = src0 + row * k;
          float *out = dst + row * n;
          for (int col = 0; col < n; ++col) {
              const float *b = src1 + col * k;
              float acc = 0;
              int t = 0;
              while (t < k) {
                  acc += a[t] * b[t];
                  ++t;
              }
              out[col] = acc;
          }
      }
  }
  /*
   * fp16 NEON kernel: dst[i][j] = sum over l of src0[i][l] * src1[j][l].
   * src0 is m x k and src1 is n x k (transposed); dst is m x n floats.
   * All matrices are row-major.
   *
   * The first k & ~31 elements of each dot product are processed 32 at a time
   * with four independent fp16 FMA accumulators. These partial sums carry
   * half-precision rounding. They are combined, widened to float and reduced.
   * The remaining k % 32 elements are then added in float by a scalar tail.
   * Previously the tail loop was bounded by k32 instead of k, so it never ran
   * and up to 31 products per output were dropped when k % 32 != 0.
   */
  void mul_mat_f16_0(
      const __fp16 * src0,
      const __fp16 * src1,
      float * dst,
      int m, int n, int k) {
      const int k32 = k & ~31; // end of the part handled by full 32-element chunks
      for (int i = 0; i < m; i++) {
          const __fp16 * restrict p0 = src0 + i*k;
          for (int j = 0; j < n; j++) {
              const __fp16 * restrict p1 = src1 + j*k;
              // four accumulators keep the FMA chains independent
              float16x8_t sum0 = vdupq_n_f16(0.0f);
              float16x8_t sum1 = vdupq_n_f16(0.0f);
              float16x8_t sum2 = vdupq_n_f16(0.0f);
              float16x8_t sum3 = vdupq_n_f16(0.0f);
              for (int l = 0; l < k32; l += 32) {
                  const float16x8_t x0 = vld1q_f16(p0 + l + 0 );
                  const float16x8_t x1 = vld1q_f16(p0 + l + 8 );
                  const float16x8_t x2 = vld1q_f16(p0 + l + 16);
                  const float16x8_t x3 = vld1q_f16(p0 + l + 24);
                  const float16x8_t y0 = vld1q_f16(p1 + l + 0 );
                  const float16x8_t y1 = vld1q_f16(p1 + l + 8 );
                  const float16x8_t y2 = vld1q_f16(p1 + l + 16);
                  const float16x8_t y3 = vld1q_f16(p1 + l + 24);
                  sum0 = vfmaq_f16(sum0, x0, y0);
                  sum1 = vfmaq_f16(sum1, x1, y1);
                  sum2 = vfmaq_f16(sum2, x2, y2);
                  sum3 = vfmaq_f16(sum3, x3, y3);
              }
              // reduce sum0..sum3 to sum0 ((s0 + s1) + (s2 + s3))
              sum0 = vaddq_f16(sum0, sum1);
              sum2 = vaddq_f16(sum2, sum3);
              sum0 = vaddq_f16(sum0, sum2);
              // widen the 8 fp16 lanes to two float32x4_t and reduce to a scalar
              float32x4_t sum0f32 = vcvt_f32_f16(vget_low_f16(sum0));
              const float32x4_t sum1f32 = vcvt_f32_f16(vget_high_f16(sum0));
              sum0f32 = vaddq_f32(sum0f32, sum1f32);
              const float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32));
              float sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
              // scalar tail for the last k % 32 elements (must run up to k, not k32)
              for (int l = k32; l < k; l++) {
                  sumf += (float)p0[l] * (float)p1[l];
              }
              dst[i*n + j] = sumf;
          }
      }
  }
  /*
   * Cache-blocked fp16 kernel with the same contract as mul_mat_f16_0:
   * dst[i][j] = sum over l of src0[i][l] * src1[j][l].
   * src0 is m x k and src1 is n x k (transposed); dst is m x n floats.
   * All matrices are row-major.
   *
   * dst is zeroed first. For each 32 x 32 output tile, every full 32-element
   * k-chunk is dotted with four fp16 FMA vectors, reduced to float, and added
   * into dst. Tiles are clamped at m and n, and the last k % 32 elements go
   * through a scalar float tail. Previously m, n and k all had to be multiples
   * of 32, or the loops read src0/src1 and wrote dst out of bounds.
   */
  void mul_mat_f16_1(
      const __fp16 * src0,
      const __fp16 * src1,
      float * dst,
      int m, int n, int k) {
      const int bs  = 32;       // tile edge; also the k-chunk consumed by 4 x 8 fp16 lanes
      const int k32 = k & ~31;  // k rounded down to a whole number of chunks
      memset(dst, 0, (size_t)m * (size_t)n * sizeof(float));
      for (int i = 0; i < m; i += bs) {
          const int i_end = (m - i < bs) ? m : i + bs;
          for (int j = 0; j < n; j += bs) {
              const int j_end = (n - j < bs) ? n : j + bs;
              for (int l = 0; l < k32; l += bs) {
                  for (int ii = i; ii < i_end; ii++) {
                      const __fp16 * restrict p0 = src0 + ii*k;
                      // this src0 chunk is reused for every column of the tile
                      const float16x8_t x0 = vld1q_f16(p0 + l + 0 );
                      const float16x8_t x1 = vld1q_f16(p0 + l + 8 );
                      const float16x8_t x2 = vld1q_f16(p0 + l + 16);
                      const float16x8_t x3 = vld1q_f16(p0 + l + 24);
                      for (int jj = j; jj < j_end; jj++) {
                          const __fp16 * restrict p1 = src1 + jj*k;
                          const float16x8_t y0 = vld1q_f16(p1 + l + 0 );
                          const float16x8_t y1 = vld1q_f16(p1 + l + 8 );
                          const float16x8_t y2 = vld1q_f16(p1 + l + 16);
                          const float16x8_t y3 = vld1q_f16(p1 + l + 24);
                          float16x8_t sum0 = vfmaq_f16(vdupq_n_f16(0.0f), x0, y0);
                          float16x8_t sum1 = vfmaq_f16(vdupq_n_f16(0.0f), x1, y1);
                          float16x8_t sum2 = vfmaq_f16(vdupq_n_f16(0.0f), x2, y2);
                          float16x8_t sum3 = vfmaq_f16(vdupq_n_f16(0.0f), x3, y3);
                          // reduce sum0..sum3 to sum0 ((s0 + s1) + (s2 + s3))
                          sum0 = vaddq_f16(sum0, sum1);
                          sum2 = vaddq_f16(sum2, sum3);
                          sum0 = vaddq_f16(sum0, sum2);
                          // widen the 8 fp16 lanes to float and reduce to a scalar
                          float32x4_t sum0f32 = vcvt_f32_f16(vget_low_f16(sum0));
                          const float32x4_t sum1f32 = vcvt_f32_f16(vget_high_f16(sum0));
                          sum0f32 = vaddq_f32(sum0f32, sum1f32);
                          const float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32));
                          const float sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
                          dst[ii*n + jj] += sumf;
                      }
                  }
              }
              // scalar tail for the last k % 32 elements of every dot product in this tile
              if (k32 < k) {
                  for (int ii = i; ii < i_end; ii++) {
                      const __fp16 * restrict p0 = src0 + ii*k;
                      for (int jj = j; jj < j_end; jj++) {
                          const __fp16 * restrict p1 = src1 + jj*k;
                          float tail = 0.0f;
                          for (int l = k32; l < k; l++) {
                              tail += (float)p0[l] * (float)p1[l];
                          }
                          dst[ii*n + jj] += tail;
                      }
                  }
              }
          }
      }
  }
  /*
   * Byte kernel: dst[i][j] = sum over l of src0[i][l] * src1[j][l], with each
   * byte read as an unsigned 8-bit integer. Despite the "f8" name, no 8-bit
   * floating-point format is decoded.
   * src0 is m x k and src1 is n x k (transposed); dst is m x n. All matrices
   * are row-major. The exact integer dot product is converted to float.
   *
   * Products are widened to 16 bits (vmull_u8 / vmull_high_u8) and
   * pairwise-accumulated into 32-bit lanes (vpadalq_u16), so nothing wraps.
   * The previous vmulq_u8 / vaddvq_u8 version truncated every product and
   * every horizontal sum to 8 bits. Each 32-bit lane grows by at most
   * 8 * 255 * 255 per 32-element chunk, so the accumulators stay exact for k
   * up to roughly 260,000. The final k % 32 elements, previously dropped, are
   * added by a scalar tail.
   */
  void mul_mat_f8_0(
      const uint8_t * src0,
      const uint8_t * src1,
      float * dst,
      int m, int n, int k) {
      const int k32 = k & ~31; // end of the part handled by full 32-byte chunks
      for (int i = 0; i < m; i++) {
          const uint8_t * restrict p0 = src0 + i*k;
          for (int j = 0; j < n; j++) {
              const uint8_t * restrict p1 = src1 + j*k;
              uint32x4_t acc = vdupq_n_u32(0);
              for (int l = 0; l < k32; l += 32) {
                  const uint8x16_t x0 = vld1q_u8(p0 + l + 0 );
                  const uint8x16_t x1 = vld1q_u8(p0 + l + 16);
                  const uint8x16_t y0 = vld1q_u8(p1 + l + 0 );
                  const uint8x16_t y1 = vld1q_u8(p1 + l + 16);
                  // u8 x u8 -> u16 products, then pairwise add into the u32 lanes
                  acc = vpadalq_u16(acc, vmull_u8(vget_low_u8(x0), vget_low_u8(y0)));
                  acc = vpadalq_u16(acc, vmull_high_u8(x0, y0));
                  acc = vpadalq_u16(acc, vmull_u8(vget_low_u8(x1), vget_low_u8(y1)));
                  acc = vpadalq_u16(acc, vmull_high_u8(x1, y1));
              }
              // widen while reducing so the horizontal sum cannot overflow 32 bits
              uint64_t total = vaddlvq_u32(acc);
              for (int l = k32; l < k; l++) {
                  total += (uint32_t)p0[l] * p1[l];
              }
              dst[i*n + j] = (float)total;
          }
      }
  }
  165. int main(int argc, const char ** argv) {
  166. float * src0 = malloc(sizeof(float)*M*K);
  167. float * src1 = malloc(sizeof(float)*N*K);
  168. float * dst = malloc(sizeof(float)*M*N);
  169. for (int i = 0; i < M*K; i++) {
  170. src0[i] = rand() / (float)RAND_MAX;
  171. }
  172. for (int i = 0; i < N*K; i++) {
  173. src1[i] = rand() / (float)RAND_MAX;
  174. }
  175. // convert src0 and src1 to __fp16
  176. __fp16 * src0_fp16 = (__fp16 *)(malloc(sizeof(__fp16)*M*K));
  177. __fp16 * src1_fp16 = (__fp16 *)(malloc(sizeof(__fp16)*N*K));
  178. uint8_t * src0_fp8 = (uint8_t *)(malloc(sizeof(__fp16)*M*K));
  179. uint8_t * src1_fp8 = (uint8_t *)(malloc(sizeof(__fp16)*N*K));
  180. {
  181. const uint64_t t_start = get_time_us();
  182. for (int i = 0; i < M*K; i++) {
  183. src0_fp16[i] = src0[i];
  184. //printf("%f %f\n", src0[i], src0_fp16[i]);
  185. //assert(!isnan(src0_fp16[i]));
  186. }
  187. for (int i = 0; i < N*K; i++) {
  188. src1_fp16[i] = src1[i];
  189. }
  190. const uint64_t t_end = get_time_us();
  191. printf("convert time: %f ms\n", (t_end - t_start) / 1000.0);
  192. }
  193. for (int i = 0; i < 16; ++i) {
  194. printf("%f %f\n", src0[i], src0_fp16[i]);
  195. }
  196. int method = 0;
  197. if (argc > 1) {
  198. method = atoi(argv[1]);
  199. }
  200. const int nIter = 1;
  201. const clock_t start = clock();
  202. const uint64_t start_us = get_time_us();
  203. double iM = 1.0/M;
  204. double sum = 0.0f;
  205. for (int i = 0; i < nIter; i++) {
  206. if (method == 0) {
  207. mul_mat_f32_0(src0, src1, dst, M, N, K);
  208. }
  209. if (method == 1) {
  210. mul_mat_f16_0(src0_fp16, src1_fp16, dst, M, N, K);
  211. }
  212. if (method == 2) {
  213. mul_mat_f16_1(src0_fp16, src1_fp16, dst, M, N, K);
  214. }
  215. if (method == 3) {
  216. mul_mat_f8_0(src0_fp8, src1_fp8, dst, M, N, K);
  217. }
  218. if (method == 4) {
  219. // Use BLAS sgemm from Accelerate framework
  220. cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, K, 1.0f, src0, K, src1, K, 0.0f, dst, N);
  221. }
  222. }
  223. for (int i = 0; i < N; i++) {
  224. sum += dst[i]*iM;
  225. }
  226. {
  227. const clock_t end = clock();
  228. const uint64_t end_us = get_time_us();
  229. printf("%s: elapsed ticks: %ld\n", __func__, end - start);
  230. printf("%s: elapsed us: %llu / %f ms\n", __func__, end_us - start_us, (end_us - start_us) / 1000.0 / nIter);
  231. }
  232. printf("%f\n", sum);
  233. free(src0);
  234. free(src1);
  235. free(dst);
  236. free(src0_fp16);
  237. free(src1_fp16);
  238. return 0;
  239. }