test-mul-mat2.c

  1. // quantized matrix multiplication
  2. #include "ggml.h"
  3. #include <float.h>
  4. #include <stdint.h>
  5. #include <stdio.h>
  6. #include <inttypes.h>
  7. #include <assert.h>
  8. #include <stdlib.h>
  9. #include <string.h>
  10. #include <math.h>
  11. #if defined(__ARM_NEON)
  12. #include "arm_neon.h"
  13. #elif defined(__AVX__) || defined(__AVX2__)
  14. #include "immintrin.h"
  15. #endif
  16. #ifndef MIN
  17. #define MAX(a, b) ((a) > (b) ? (a) : (b))
  18. #define MIN(a, b) ((a) < (b) ? (a) : (b))
  19. #endif
  20. #if defined(_MSC_VER)
  21. #pragma warning(disable: 4244 4267) // possible loss of data
  22. #include <intrin.h>
  23. #define __builtin_popcountll __popcnt64
  24. #endif
  25. const int M = 1280;
  26. const int N = 1536;
  27. const int K = 1280;
  28. //const int M = 64;
  29. //const int N = 64;
  30. //const int K = 64;
  31. #define QK 64
  32. #define QB 4
  33. //#define GGML_GQ_USE_FP16_SCALE
  34. #if defined(GGML_GQ_USE_FP16_SCALE)
  35. #define gq_scale_t ggml_fp16_t
  36. #define GGML_FP32_TO_GQ(x) ggml_fp32_to_fp16(x)
  37. #define GGML_GQ_TO_FP32(x) ggml_fp16_to_fp32(x)
  38. #else
  39. #define gq_scale_t float
  40. #define GGML_FP32_TO_GQ(x) (x)
  41. #define GGML_GQ_TO_FP32(x) (x)
  42. #endif
  43. #define gq_t_bits 64
  44. #define gq_quant_t uint64_t
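// How the bit-plane quantization below is laid out: each group of gq_t_bits
// consecutive values contributes QB words of type gq_quant_t, where word b
// holds bit b of every quantized value in the group (e.g. with QB = 4, a value
// q = 0b1010 at lane l sets bit l of planes 1 and 3). QK values form one block
// that shares its scale parameter(s).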
  45. float frand(void) {
  46. return (float) rand() / (float) RAND_MAX;
  47. }
  48. #if defined(__AVX2__)
  49. // horizontally reduce 8 32-bit integers
  50. static inline uint32_t _mm256_hadd_epi32_gg(__m256i v) {
  51. __m128i v0 = _mm256_extractf128_si256(v, 0);
  52. __m128i v1 = _mm256_extractf128_si256(v, 1);
  53. v0 = _mm_add_epi32(v0, v1);
  54. v1 = _mm_shuffle_epi32(v0, 0x0e);
  55. v0 = _mm_add_epi32(v0, v1);
  56. v1 = _mm_shuffle_epi32(v0, 0x01);
  57. v0 = _mm_add_epi32(v0, v1);
  58. return _mm_cvtsi128_si32(v0);
  59. }
  60. //static inline float _mm256_hadd_epi32_gg(__m256i v) {
  61. // const __m256 v0 = _mm256_cvtepi32_ps(v);
  62. // const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v0), _mm256_extractf128_ps(v0, 1));
  63. // const __m128 t1 = _mm_hadd_ps(t0, t0);
  64. //
  65. // return _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
  66. //}
  67. // horizontally reduce 32 8-bit integers
  68. static inline int32_t _mm256_hadd_epi8_gg(__m256i v0) {
  69. __m256i v1 = _mm256_maddubs_epi16(v0, _mm256_set1_epi8(1));
  70. __m256i v2 = _mm256_madd_epi16 (v1, _mm256_set1_epi16(1));
  71. return _mm256_hadd_epi32_gg(v2);
  72. }
  73. static inline float _mm256_hadd_ps_gg(__m256 v) {
  74. const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
  75. const __m128 t1 = _mm_hadd_ps(t0, t0);
  76. return _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
  77. }
  78. #endif
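// A minimal sanity-check sketch for the three helpers above (kept under #if 0,
// not part of the test; the function name is illustrative only):
#if 0
static void check_hadd_helpers(void) {
    assert(_mm256_hadd_epi32_gg(_mm256_set1_epi32(3)) ==  8*3);  //  8 lanes of 3
    assert(_mm256_hadd_epi8_gg (_mm256_set1_epi8 (2)) == 32*2);  // 32 lanes of 2
    assert(_mm256_hadd_ps_gg   (_mm256_set1_ps(0.5f)) ==  4.0f); //  8 lanes of 0.5
}
#endif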
  79. //
  80. // naive implementation
  81. //
  82. void mul_mat_f32_naive(
  83. const float * restrict src0, // M x K
  84. const float * restrict src1, // N x K (transposed)
  85. float * dst,
  86. int m, int n, int k) {
  87. for (int i = 0; i < m; i++) {
  88. for (int j = 0; j < n; j++) {
  89. float sum = 0;
  90. for (int l = 0; l < k; l++) {
  91. sum += src0[i*k + l] * src1[j*k + l];
  92. }
  93. dst[i*n + j] = sum;
  94. }
  95. }
  96. }
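// Note: both operands are row-major and src1 is stored pre-transposed (N x K),
// so each output element is a contiguous dot product of length k. The
// quantized variants below follow the same convention.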
  97. //
  98. // method 1
  99. //
  100. static inline int quantize_1_blocks_per_row(int k) {
  101. return k/QK;
  102. }
  103. static inline int quantize_1_quants_per_block(void) {
  104. return QK/gq_t_bits;
  105. }
  106. static inline int quantize_1_row_size(int k) {
  107. const int nb = quantize_1_blocks_per_row(k);
  108. const int nq = quantize_1_quants_per_block();
  109. return nb*(2*sizeof(gq_scale_t) + nq*QB*sizeof(gq_quant_t));
  110. }
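// Worked example, assuming the defaults above (QK = 64, QB = 4, gq_t_bits = 64,
// float scales): a row of k = 1280 floats has 1280/64 = 20 blocks, each storing
// 2*4 bytes of scales + 1*4*8 = 32 bytes of bit-planes, so
// quantize_1_row_size(1280) = 20*40 = 800 bytes vs 1280*4 = 5120 bytes in fp32
// (~6.4x smaller).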
  111. void quantize_1(const float * src, void * dst, int n, int k) {
  112. char * p0 = dst;
  113. gq_quant_t pp[QB];
  114. for (int j = 0; j < n; j++) {
  115. for (int i = 0; i < k/QK; i++) {
  116. float min = FLT_MAX;
  117. float max = -FLT_MAX;
  118. // find min/max
  119. #ifdef __ARM_NEON
  120. {
  121. float32x4_t minv = vdupq_n_f32(FLT_MAX);
  122. float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
  123. for (int l = 0; l < QK; l += 4) {
  124. float32x4_t v = vld1q_f32(src + j*k + i*QK + l);
  125. minv = vminq_f32(minv, v);
  126. maxv = vmaxq_f32(maxv, v);
  127. }
  128. float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
  129. float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
  130. min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
  131. max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
  132. //printf("SIMD min/max: %f %f\n", min, max);
  133. }
  134. #else
  135. {
  136. for (int l = 0; l < QK; l++) {
  137. const float v = src[j*k + i*QK + l];
  138. if (v < min) min = v;
  139. if (v > max) max = v;
  140. }
  141. //printf("NORM min/max: %f %f\n", min, max);
  142. }
  143. #endif
  144. const float d = (max - min) / ((1 << QB) - 1);
  145. const float id = d ? 1.0/d : 0.0;
  146. memcpy(p0, &min, sizeof(float)); p0 += sizeof(float);
  147. memcpy(p0, &d, sizeof(float)); p0 += sizeof(float);
  148. //printf("min/max/d/id: %f %f %f %f\n", min, max, d, id);
  149. for (int s = 0; s < QK/gq_t_bits; ++s) {
  150. memset(pp, 0, sizeof(pp));
  151. for (int l = 0; l < gq_t_bits; l++) {
  152. const float v = src[j*k + i*QK + s*gq_t_bits + l];
  153. const uint8_t q = (v - min)*id;
  154. for (int b = 0; b < QB; b++) {
  155. pp[b] |= q & (1 << b) ? (1ULL << l) : 0;
  156. }
  157. }
  158. for (int b = 0; b < QB; b++) {
  159. memcpy(p0, &pp[b], sizeof(gq_quant_t)); p0 += sizeof(gq_quant_t);
  160. }
  161. }
  162. }
  163. }
  164. }
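// Per-block layout written by quantize_1 (scales interleaved with the data):
//   [min : float][d : float][QB x gq_quant_t bit-planes per group of gq_t_bits values]
// Element l of a group is reconstructed as v ~= min + d*q, with
// q = sum_b ((pp[b] >> l) & 1) << b.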
  165. void mul_mat_gq_1(
  166. const void * src0,
  167. const void * src1,
  168. float * dst,
  169. int m, int n, int k) {
  170. const int kp = k & ~(gq_t_bits - 1);
  171. const char * restrict p0 = src0;
  172. const char * restrict p1 = src1;
  173. float s0[QB + 1];
  174. float s1[QB + 1];
  175. gq_quant_t m0[QB + 1];
  176. gq_quant_t m1[QB + 1];
  177. for (int ir0 = 0; ir0 < m; ir0++) {
  178. for (int ir1 = 0; ir1 < n; ir1++) {
  179. float sumf = 0.0;
  180. const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_quant_t))*(k/QK));
  181. const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_quant_t))*(k/QK));
  182. for (int i = 0; i < kp/QK; i++) {
  183. float min0, d0;
  184. memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float);
  185. memcpy(&d0, pp0, sizeof(float)); pp0 += sizeof(float);
  186. float min1, d1;
  187. memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float);
  188. memcpy(&d1, pp1, sizeof(float)); pp1 += sizeof(float);
  189. //printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1);
  190. #if 1
  191. // >>> General case for any QB
  192. s0[0] = min0;
  193. s1[0] = min1;
  194. for (int b = 0; b < QB; b++) {
  195. s0[b + 1] = d0*(1 << b);
  196. s1[b + 1] = d1*(1 << b);
  197. }
  198. m0[0] = 0-1ULL;
  199. m1[0] = 0-1ULL;
  200. for (int s = 0; s < QK/gq_t_bits; ++s) {
  201. for (int b = 0; b < QB; b++) {
  202. memcpy(&m0[b + 1], pp0, sizeof(gq_quant_t)); pp0 += sizeof(gq_quant_t);
  203. memcpy(&m1[b + 1], pp1, sizeof(gq_quant_t)); pp1 += sizeof(gq_quant_t);
  204. }
  205. for (int q0 = 0; q0 < QB + 1; q0++) {
  206. for (int q1 = 0; q1 < QB + 1; q1++) {
  207. sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]);
  208. }
  209. }
  210. }
  211. #else
  212. #endif
  213. }
  214. dst[ir0*n + ir1] = sumf;
  215. }
  216. }
  217. }
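// Why the popcount form above works (per group of gq_t_bits lanes):
//   sum_l (min0 + d0*q0_l)*(min1 + d1*q1_l) = sum_{a,b} s0[a]*s1[b]*popcount(m0[a] & m1[b])
// where m0[0] = m1[0] = ~0 carries the constant terms with s0[0] = min0 and
// s1[0] = min1, and m0[a+1] is bit-plane a of q0 with weight s0[a+1] = d0*(1 << a)
// (likewise for m1/s1).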
  218. //
  219. // method 2
  220. // n-bit quantization (2nd attempt)
  221. //
  222. static inline int quantize_2_blocks_per_row(int k) {
  223. return k/QK;
  224. }
  225. static inline int quantize_2_quants_per_block(void) {
  226. return QK/gq_t_bits;
  227. }
  228. static inline int quantize_2_row_size(int k) {
  229. const int nb = quantize_2_blocks_per_row(k);
  230. const int nq = quantize_2_quants_per_block();
  231. return nb*(2*sizeof(gq_scale_t) + nq*QB*sizeof(gq_quant_t));
  232. }
  233. void quantize_2_row(const float * restrict src, void * restrict dst, int k) {
  234. assert(k % QK == 0);
  235. const int nb = quantize_2_blocks_per_row(k);
  236. const int nq = quantize_2_quants_per_block();
  237. gq_scale_t * restrict pm = (gq_scale_t *) (dst);
  238. gq_scale_t * restrict pd = (gq_scale_t *) (pm + nb);
  239. gq_quant_t * restrict pb = (gq_quant_t *) (pd + nb);
  240. gq_quant_t pp[QB];
  241. static const int32_t sh[32] = {
  242. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
  243. 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
  244. };
  245. for (int i = 0; i < nb; i++) {
  246. float min = FLT_MAX;
  247. float max = -FLT_MAX;
  248. #ifdef __ARM_NEON
  249. {
  250. float32x4_t minv = vdupq_n_f32(FLT_MAX);
  251. float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
  252. for (int l = 0; l < QK; l += 4) {
  253. float32x4_t v = vld1q_f32(src + i*QK + l);
  254. minv = vminq_f32(minv, v);
  255. maxv = vmaxq_f32(maxv, v);
  256. }
  257. float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
  258. float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
  259. min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
  260. max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
  261. }
  262. #else
  263. {
  264. for (int l = 0; l < QK; l++) {
  265. const float v = src[i*QK + l];
  266. if (v < min) min = v;
  267. if (v > max) max = v;
  268. }
  269. }
  270. #endif
  271. const float d = (max - min) / ((1 << QB) - 1);
  272. const float id = d ? 1.0/d : 0.0;
  273. pm[i] = GGML_FP32_TO_GQ(min);
  274. pd[i] = GGML_FP32_TO_GQ(d);
  275. for (int s = 0; s < nq; ++s) {
  276. memset(pp, 0, sizeof(pp));
  277. #if 1
  278. for (int l = 0; l < gq_t_bits; l++) {
  279. const float v = src[i*QK + s*gq_t_bits + l];
  280. const uint8_t q = (v - min)*id + frand();
  281. for (int b = 0; b < QB; b++) {
  282. pp[b] |= q & (1 << b) ? (1ULL << l) : 0;
  283. }
  284. }
  285. #elif defined(__ARM_NEON)
  286. #if 1
  287. {
  288. uint32_t ppt[2*4*QB];
  289. float32x4_t minv = vdupq_n_f32(min);
  290. float32x4_t idv = vdupq_n_f32(id);
  291. assert(gq_t_bits % 16 == 0);
  292. uint32x4_t p0[QB] = { vdupq_n_u32(0) };
  293. uint32x4_t p1[QB] = { vdupq_n_u32(0) };
  294. for (int l = 0; l < gq_t_bits; l += 16) {
  295. float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
  296. float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
  297. float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
  298. float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
  299. v0 = vsubq_f32(v0, minv);
  300. v1 = vsubq_f32(v1, minv);
  301. v2 = vsubq_f32(v2, minv);
  302. v3 = vsubq_f32(v3, minv);
  303. v0 = vmulq_f32(v0, idv);
  304. v1 = vmulq_f32(v1, idv);
  305. v2 = vmulq_f32(v2, idv);
  306. v3 = vmulq_f32(v3, idv);
  307. #if 1
  308. v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
  309. v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
  310. v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
  311. v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
  312. #endif
  313. uint32x4_t q0 = vcvtq_u32_f32(v0);
  314. uint32x4_t q1 = vcvtq_u32_f32(v1);
  315. uint32x4_t q2 = vcvtq_u32_f32(v2);
  316. uint32x4_t q3 = vcvtq_u32_f32(v3);
  317. for (int b = 0; b < QB; ++b) {
  318. uint32x4_t m = vdupq_n_u32(1 << b);
  319. int32x4_t r = vdupq_n_s32(-b); // vshlq_u32 expects a signed shift-count vector (as in quantize_3_row)
  320. if (l < 32) {
  321. p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0)));
  322. p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l + 4)));
  323. p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l + 8)));
  324. p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l + 12)));
  325. } else {
  326. p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l - 32)));
  327. p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l - 28)));
  328. p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l - 24)));
  329. p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l - 20)));
  330. }
  331. }
  332. }
  333. #if QB == 4
  334. vst1q_u32((uint32_t *) ppt + 0, p0[0]);
  335. vst1q_u32((uint32_t *) ppt + 4, p1[0]);
  336. vst1q_u32((uint32_t *) ppt + 8, p0[1]);
  337. vst1q_u32((uint32_t *) ppt + 12, p1[1]);
  338. vst1q_u32((uint32_t *) ppt + 16, p0[2]);
  339. vst1q_u32((uint32_t *) ppt + 20, p1[2]);
  340. vst1q_u32((uint32_t *) ppt + 24, p0[3]);
  341. vst1q_u32((uint32_t *) ppt + 28, p1[3]);
  342. pp[0] = (ppt[0] | ppt[1] | ppt[2] | ppt[3] ) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7]) ) << 32;
  343. pp[1] = (ppt[8] | ppt[9] | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32;
  344. pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32;
  345. pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32;
  346. #else
  347. for (int b = 0; b < QB; ++b) {
  348. vst1q_u32((uint32_t *) ppt + 0, p0[b]);
  349. vst1q_u32((uint32_t *) ppt + 4, p1[b]);
  350. pp[b] = (ppt[0] | ppt[1] | ppt[2] | ppt[3]) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7])) << 32;
  351. }
  352. #endif
  353. }
  354. #else
  355. // less optimal SIMD
  356. {
  357. float32x4_t minv = vdupq_n_f32(min);
  358. float32x4_t idv = vdupq_n_f32(id);
  359. assert(gq_t_bits == 64);
  360. uint8_t qq[gq_t_bits];
  361. for (int l = 0; l < gq_t_bits; l += 16) {
  362. float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
  363. float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
  364. float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
  365. float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
  366. v0 = vsubq_f32(v0, minv);
  367. v1 = vsubq_f32(v1, minv);
  368. v2 = vsubq_f32(v2, minv);
  369. v3 = vsubq_f32(v3, minv);
  370. v0 = vmulq_f32(v0, idv);
  371. v1 = vmulq_f32(v1, idv);
  372. v2 = vmulq_f32(v2, idv);
  373. v3 = vmulq_f32(v3, idv);
  374. #if 0
  375. v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
  376. v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
  377. v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
  378. v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
  379. #endif
  380. uint32x4_t q0 = vcvtq_u32_f32(v0);
  381. uint32x4_t q1 = vcvtq_u32_f32(v1);
  382. uint32x4_t q2 = vcvtq_u32_f32(v2);
  383. uint32x4_t q3 = vcvtq_u32_f32(v3);
  384. // store in qq as uint8_t
  385. vst1_u8(qq + l + 0, vmovn_u16(vcombine_u16(vmovn_u32(q0), vmovn_u32(q1))));
  386. vst1_u8(qq + l + 8, vmovn_u16(vcombine_u16(vmovn_u32(q2), vmovn_u32(q3))));
  387. }
  388. for (int l = 0; l < gq_t_bits; l++) {
  389. for (int b = 0; b < QB; b++) {
  390. const uint64_t ql = qq[l];
  391. /*pp[b] |= qq[l] & (1 << b) ? (1ULL << l) : 0;*/
  392. pp[b] |= ((ql & (1 << b)) >> b) << l;
  393. }
  394. }
  395. }
  396. #endif
  397. #endif
  398. memcpy(pb + i*nq*QB + s*QB, pp, sizeof(pp));
  399. }
  400. }
  401. }
  402. // reimplementation of quantize_2 using quantize_2_row
  403. void quantize_2(const float * restrict src, char * restrict dst, int n, int k) {
  404. assert(k % QK == 0);
  405. for (int j = 0; j < n; j++) {
  406. quantize_2_row(src + j*k, dst, k);
  407. dst = (char *) dst + quantize_2_row_size(k);
  408. }
  409. }
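// Unlike method 1, methods 2-4 group the per-row scale arrays ahead of the
// packed data: quantize_2_row writes all nb block minimums, then all nb deltas,
// then nb*nq*QB bit-plane words, which is exactly what the pm/pd/pb pointers in
// vec_dot_gq_2 below assume.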
  410. void vec_dot_gq_2(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
  411. const int nb = quantize_2_blocks_per_row(n);
  412. const int nq = quantize_2_quants_per_block();
  413. const gq_scale_t * restrict pm0 = (const gq_scale_t *) x;
  414. const gq_scale_t * restrict pm1 = (const gq_scale_t *) y;
  415. const gq_scale_t * restrict pd0 = pm0 + nb;
  416. const gq_scale_t * restrict pd1 = pm1 + nb;
  417. const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb);
  418. const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb);
  419. float sumf = 0.0;
  420. #if 1
  421. for (int i = 0; i < nb; i++) {
  422. const float m0 = GGML_GQ_TO_FP32(pm0[i]);
  423. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  424. const float m1 = GGML_GQ_TO_FP32(pm1[i]);
  425. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  426. #if QB == 4
  427. int isum01 = 0;
  428. int isum10 = 0;
  429. int isum11 = 0;
  430. for (int s = 0; s < nq; ++s) {
  431. const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB;
  432. const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB;
  433. #define bpcnt(x) __builtin_popcountll(x)
  434. isum01 += (1 << 0)*(bpcnt(mm1[0]));
  435. isum01 += (1 << 1)*(bpcnt(mm1[1]));
  436. isum01 += (1 << 2)*(bpcnt(mm1[2]));
  437. isum01 += (1 << 3)*(bpcnt(mm1[3]));
  438. isum10 += (1 << 0)*(bpcnt(mm0[0]));
  439. isum10 += (1 << 1)*(bpcnt(mm0[1]));
  440. isum10 += (1 << 2)*(bpcnt(mm0[2]));
  441. isum10 += (1 << 3)*(bpcnt(mm0[3]));
  442. isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0]));
  443. isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0]));
  444. isum11 += (1 << 2)*(bpcnt(mm0[0] & mm1[2]) + bpcnt(mm0[1] & mm1[1]) + bpcnt(mm0[2] & mm1[0]));
  445. isum11 += (1 << 3)*(bpcnt(mm0[0] & mm1[3]) + bpcnt(mm0[1] & mm1[2]) + bpcnt(mm0[2] & mm1[1]) + bpcnt(mm0[3] & mm1[0]));
  446. isum11 += (1 << 4)*(bpcnt(mm0[1] & mm1[3]) + bpcnt(mm0[2] & mm1[2]) + bpcnt(mm0[3] & mm1[1]));
  447. isum11 += (1 << 5)*(bpcnt(mm0[2] & mm1[3]) + bpcnt(mm0[3] & mm1[2]));
  448. isum11 += (1 << 6)*(bpcnt(mm0[3] & mm1[3]));
  449. #undef bpcnt
  450. }
  451. sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1);
  452. #elif QB == 3
  453. int isum01 = 0;
  454. int isum10 = 0;
  455. int isum11 = 0;
  456. for (int s = 0; s < nq; ++s) {
  457. const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB;
  458. const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB;
  459. #if gq_t_bits == 32
  460. #define bpcnt(x) __builtin_popcount(x)
  461. #else
  462. #define bpcnt(x) __builtin_popcountll(x)
  463. #endif
  464. isum01 += (1 << 0)*(bpcnt(mm1[0]));
  465. isum01 += (1 << 1)*(bpcnt(mm1[1]));
  466. isum01 += (1 << 2)*(bpcnt(mm1[2]));
  467. isum10 += (1 << 0)*(bpcnt(mm0[0]));
  468. isum10 += (1 << 1)*(bpcnt(mm0[1]));
  469. isum10 += (1 << 2)*(bpcnt(mm0[2]));
  470. isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0]));
  471. isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0]));
  472. isum11 += (1 << 2)*(bpcnt(mm0[0] & mm1[2]) + bpcnt(mm0[1] & mm1[1]) + bpcnt(mm0[2] & mm1[0]));
  473. isum11 += (1 << 3)*(bpcnt(mm0[1] & mm1[2]) + bpcnt(mm0[2] & mm1[1]));
  474. isum11 += (1 << 4)*(bpcnt(mm0[2] & mm1[2]));
  475. #undef bpcnt
  476. }
  477. sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1);
  478. #elif QB == 2
  479. int isum01 = 0;
  480. int isum10 = 0;
  481. int isum11 = 0;
  482. for (int s = 0; s < nq; ++s) {
  483. const gq_quant_t * restrict mm0 = pb0 + i*nq*QB + s*QB;
  484. const gq_quant_t * restrict mm1 = pb1 + i*nq*QB + s*QB;
  485. #if gq_t_bits == 32
  486. #define bpcnt(x) __builtin_popcount(x)
  487. #else
  488. #define bpcnt(x) __builtin_popcountll(x)
  489. #endif
  490. isum01 += (1 << 0)*(bpcnt(mm1[0]));
  491. isum01 += (1 << 1)*(bpcnt(mm1[1]));
  492. isum10 += (1 << 0)*(bpcnt(mm0[0]));
  493. isum10 += (1 << 1)*(bpcnt(mm0[1]));
  494. isum11 += (1 << 0)*(bpcnt(mm0[0] & mm1[0]));
  495. isum11 += (1 << 1)*(bpcnt(mm0[0] & mm1[1]) + bpcnt(mm0[1] & mm1[0]));
  496. isum11 += (1 << 2)*(bpcnt(mm0[1] & mm1[1]));
  497. #undef bpcnt
  498. }
  499. sumf += nq*gq_t_bits*(m0*m1) + isum01*(m0*d1) + isum10*(m1*d0) + isum11*(d0*d1);
  500. #else
  501. float s0[QB + 1];
  502. float s1[QB + 1];
  503. s0[0] = m0;
  504. s1[0] = m1;
  505. for (int b = 0; b < QB; b++) {
  506. s0[b + 1] = d0*(1 << b);
  507. s1[b + 1] = d1*(1 << b);
  508. }
  509. for (int s = 0; s < nq; ++s) {
  510. for (int q0 = 0; q0 < QB + 1; q0++) {
  511. const gq_quant_t mm0 = q0 ? pb0[i*nq*QB + s*QB + q0 - 1] : -1ULL;
  512. for (int q1 = 0; q1 < QB + 1; q1++) {
  513. const gq_quant_t mm1 = q1 ? pb1[i*nq*QB + s*QB + q1 - 1] : -1ULL;
  514. sumf += s0[q0]*s1[q1]*__builtin_popcountll(mm0 & mm1);
  515. }
  516. }
  517. }
  518. #endif
  519. }
  520. #else
  521. #error "not implemented"
  522. #endif
  523. *s = sumf;
  524. }
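// The QB == 4 path above expands the same product as method 1 but groups terms:
//   sum_l (m0 + d0*q0_l)*(m1 + d1*q1_l)
//     = N*m0*m1 + m0*d1*sum(q1) + m1*d0*sum(q0) + d0*d1*sum(q0*q1),  N = nq*gq_t_bits
// isum01 = sum(q1) and isum10 = sum(q0) come from per-plane popcounts, and
// isum11 = sum(q0*q1) from cross-plane popcounts weighted by 2^(b0 + b1).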
  525. // use vec_dot_gq_2 to compute the dot product of two rows
  526. void mul_mat_gq_2(
  527. const void * src0,
  528. const void * src1, // transposed
  529. float * dst,
  530. int m, int n, int k) {
  531. assert(k % QK == 0);
  532. for (int ir0 = 0; ir0 < m; ir0++) {
  533. for (int ir1 = 0; ir1 < n; ir1++) {
  534. vec_dot_gq_2(k, dst + ir1, src0, src1);
  535. src1 = (const char *) src1 + quantize_2_row_size(k);
  536. }
  537. src0 = (const char *) src0 + quantize_2_row_size(k);
  538. src1 = (const char *) src1 - n*quantize_2_row_size(k);
  539. dst = (float *) dst + n;
  540. }
  541. }
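// Minimal usage sketch for method 2: quantize both operands row-wise, then
// multiply. Kept under #if 0; the buffer handling and function name are
// illustrative only, not part of the test.
#if 0
static void example_mul_mat_gq_2(const float * a,  // M x K
                                 const float * b,  // N x K (transposed)
                                 float       * c)  // M x N
{
    char * qa = malloc((size_t) M*quantize_2_row_size(K));
    char * qb = malloc((size_t) N*quantize_2_row_size(K));
    quantize_2(a, qa, M, K);
    quantize_2(b, qb, N, K);
    mul_mat_gq_2(qa, qb, c, M, N, K);
    free(qa);
    free(qb);
}
#endif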
  542. //
  543. // method 3
  544. // (does not work)
  545. //
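// Likely reason for the note above: method 3 keeps only an absmax-derived scale
// d per block (no offset), yet still quantizes v*id straight into unsigned
// bit-planes, so negative inputs have no valid representation.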
  546. static inline int quantize_3_blocks_per_row(int k) {
  547. return k/QK;
  548. }
  549. static inline int quantize_3_quants_per_block(void) {
  550. return QK/gq_t_bits;
  551. }
  552. static inline int quantize_3_row_size(int k) {
  553. const int nb = quantize_3_blocks_per_row(k);
  554. const int nq = quantize_3_quants_per_block();
  555. return nb*(sizeof(gq_scale_t) + nq*QB*sizeof(gq_quant_t));
  556. }
  557. void quantize_3_row(const float * restrict src, void * restrict dst, int k) {
  558. assert(k % QK == 0);
  559. const int nb = quantize_3_blocks_per_row(k);
  560. const int nq = quantize_3_quants_per_block();
  561. gq_scale_t * restrict pd = (gq_scale_t *) (dst);
  562. gq_quant_t * restrict pb = (gq_quant_t *) (pd + nb);
  563. gq_quant_t pp[QB];
  564. static const int32_t sh[32] = {
  565. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
  566. 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
  567. };
  568. for (int i = 0; i < nb; i++) {
  569. float amax = 0.0f; // abs max
  570. #ifdef __ARM_NEON
  571. {
  572. // min / max
  573. //float32x4_t minv = vdupq_n_f32(FLT_MAX);
  574. //float32x4_t maxv = vdupq_n_f32(-FLT_MAX);
  575. //for (int l = 0; l < QK; l += 4) {
  576. // float32x4_t v = vld1q_f32(src + i*QK + l);
  577. // minv = vminq_f32(minv, v);
  578. // maxv = vmaxq_f32(maxv, v);
  579. //}
  580. //float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv));
  581. //float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv));
  582. //min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1));
  583. //max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1));
  584. // abs max
  585. float32x4_t amaxv = vdupq_n_f32(0.0f);
  586. for (int l = 0; l < QK; l += 4) {
  587. float32x4_t v = vld1q_f32(src + i*QK + l);
  588. amaxv = vmaxq_f32(amaxv, vabsq_f32(v));
  589. }
  590. float32x2_t amaxv32 = vpmax_f32(vget_low_f32(amaxv), vget_high_f32(amaxv));
  591. amax = MAX(vget_lane_f32(amaxv32, 0), vget_lane_f32(amaxv32, 1));
  592. }
  593. #else
  594. {
  595. for (int l = 0; l < QK; l++) {
  596. const float v = src[i*QK + l];
  597. amax = MAX(amax, fabsf(v));
  598. }
  599. }
  600. #endif
  601. const float d = amax / ((1 << (QB - 1)) - 1);
  602. const float id = d ? 1.0/d : 0.0;
  603. pd[i] = GGML_FP32_TO_GQ(d);
  604. for (int s = 0; s < nq; ++s) {
  605. memset(pp, 0, sizeof(pp));
  606. #if 0
  607. for (int l = 0; l < gq_t_bits; l++) {
  608. const float v = src[i*QK + s*gq_t_bits + l];
  609. const uint8_t q = v*id + frand();
  610. for (int b = 0; b < QB; b++) {
  611. pp[b] |= q & (1 << b) ? (1ULL << l) : 0;
  612. }
  613. }
  614. #elif defined(__ARM_NEON)
  615. {
  616. uint32_t ppt[2*4*QB];
  617. float32x4_t idv = vdupq_n_f32(id);
  618. assert(gq_t_bits == 64);
  619. uint32x4_t p0[QB] = { vdupq_n_u32(0) };
  620. uint32x4_t p1[QB] = { vdupq_n_u32(0) };
  621. for (int l = 0; l < gq_t_bits; l += 16) {
  622. float32x4_t v0 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 0);
  623. float32x4_t v1 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 4);
  624. float32x4_t v2 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 8);
  625. float32x4_t v3 = vld1q_f32(src + i*QK + s*gq_t_bits + l + 12);
  626. v0 = vmulq_f32(v0, idv);
  627. v1 = vmulq_f32(v1, idv);
  628. v2 = vmulq_f32(v2, idv);
  629. v3 = vmulq_f32(v3, idv);
  630. #if 1
  631. v0[0] += frand(); v0[1] += frand(); v0[2] += frand(); v0[3] += frand();
  632. v1[0] += frand(); v1[1] += frand(); v1[2] += frand(); v1[3] += frand();
  633. v2[0] += frand(); v2[1] += frand(); v2[2] += frand(); v2[3] += frand();
  634. v3[0] += frand(); v3[1] += frand(); v3[2] += frand(); v3[3] += frand();
  635. #endif
  636. uint32x4_t q0 = vcvtq_u32_f32(v0);
  637. uint32x4_t q1 = vcvtq_u32_f32(v1);
  638. uint32x4_t q2 = vcvtq_u32_f32(v2);
  639. uint32x4_t q3 = vcvtq_u32_f32(v3);
  640. for (int b = 0; b < QB; ++b) {
  641. uint32x4_t m = vdupq_n_u32(1 << b);
  642. int32x4_t r = vdupq_n_s32(-b);
  643. if (l < 32) {
  644. p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l + 0)));
  645. p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l + 4)));
  646. p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l + 8)));
  647. p0[b] = vorrq_u32(p0[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l + 12)));
  648. } else {
  649. p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q0, m), r), vld1q_s32(sh + l - 32)));
  650. p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q1, m), r), vld1q_s32(sh + l - 28)));
  651. p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q2, m), r), vld1q_s32(sh + l - 24)));
  652. p1[b] = vorrq_u32(p1[b], vshlq_u32(vshlq_u32(vandq_u32(q3, m), r), vld1q_s32(sh + l - 20)));
  653. }
  654. }
  655. }
  656. #if QB == 4
  657. vst1q_u32((uint32_t *) ppt + 0, p0[0]);
  658. vst1q_u32((uint32_t *) ppt + 4, p1[0]);
  659. vst1q_u32((uint32_t *) ppt + 8, p0[1]);
  660. vst1q_u32((uint32_t *) ppt + 12, p1[1]);
  661. vst1q_u32((uint32_t *) ppt + 16, p0[2]);
  662. vst1q_u32((uint32_t *) ppt + 20, p1[2]);
  663. vst1q_u32((uint32_t *) ppt + 24, p0[3]);
  664. vst1q_u32((uint32_t *) ppt + 28, p1[3]);
  665. pp[0] = (ppt[0] | ppt[1] | ppt[2] | ppt[3] ) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7]) ) << 32;
  666. pp[1] = (ppt[8] | ppt[9] | ppt[10] | ppt[11]) | ((uint64_t) (ppt[12] | ppt[13] | ppt[14] | ppt[15])) << 32;
  667. pp[2] = (ppt[16] | ppt[17] | ppt[18] | ppt[19]) | ((uint64_t) (ppt[20] | ppt[21] | ppt[22] | ppt[23])) << 32;
  668. pp[3] = (ppt[24] | ppt[25] | ppt[26] | ppt[27]) | ((uint64_t) (ppt[28] | ppt[29] | ppt[30] | ppt[31])) << 32;
  669. #else
  670. for (int q = 0; q < QB; ++q) {
  671. vst1q_u32((uint32_t *) ppt + 0, p0[q]);
  672. vst1q_u32((uint32_t *) ppt + 4, p1[q]);
  673. pp[q] = (ppt[0] | ppt[1] | ppt[2] | ppt[3]) | ((uint64_t) (ppt[4] | ppt[5] | ppt[6] | ppt[7])) << 32;
  674. }
  675. #endif
  676. }
  677. #endif
  678. memcpy(pb + i*nq*QB + s*QB, pp, sizeof(pp));
  679. }
  680. }
  681. }
  682. // reimplementation of quantize_3 using quantize_3_row
  683. void quantize_3(const float * restrict src, char * restrict dst, int n, int k) {
  684. assert(k % QK == 0);
  685. for (int j = 0; j < n; j++) {
  686. quantize_3_row(src + j*k, dst, k);
  687. dst = (char *) dst + quantize_3_row_size(k);
  688. }
  689. }
  690. void vec_dot_gq_3(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
  691. float sumf = 0.0f;
  692. const int nb = quantize_3_blocks_per_row(n);
  693. const int nq = quantize_3_quants_per_block();
  694. const gq_scale_t * restrict pd0 = (const gq_scale_t *) x;
  695. const gq_scale_t * restrict pd1 = (const gq_scale_t *) y;
  696. const gq_quant_t * restrict pb0 = (const gq_quant_t *) (pd0 + nb);
  697. const gq_quant_t * restrict pb1 = (const gq_quant_t *) (pd1 + nb);
  698. #if 1
  699. for (int i = 0; i < nb; i++) {
  700. int isum = 0;
  701. #if QB == 4
  702. for (int s = 0; s < nq; ++s) {
  703. const gq_quant_t * restrict m0 = pb0 + i*nq*QB + s*QB;
  704. const gq_quant_t * restrict m1 = pb1 + i*nq*QB + s*QB;
  705. isum += (1 << 0)*(__builtin_popcountll(m0[0] & m1[0]));
  706. isum += (1 << 1)*(__builtin_popcountll(m0[0] & m1[1]) + __builtin_popcountll(m0[1] & m1[0]));
  707. isum += (1 << 2)*(__builtin_popcountll(m0[0] & m1[2]) + __builtin_popcountll(m0[1] & m1[1]) + __builtin_popcountll(m0[2] & m1[0]));
  708. isum += (1 << 3)*(__builtin_popcountll(m0[0] & m1[3]) + __builtin_popcountll(m0[1] & m1[2]) + __builtin_popcountll(m0[2] & m1[1]) + __builtin_popcountll(m0[3] & m1[0]));
  709. isum += (1 << 4)*(__builtin_popcountll(m0[1] & m1[3]) + __builtin_popcountll(m0[2] & m1[2]) + __builtin_popcountll(m0[3] & m1[1]));
  710. isum += (1 << 5)*(__builtin_popcountll(m0[2] & m1[3]) + __builtin_popcountll(m0[3] & m1[2]));
  711. isum += (1 << 6)*(__builtin_popcountll(m0[3] & m1[3]));
  712. }
  713. #else
  714. for (int s = 0; s < nq; ++s) {
  715. for (int q0 = 0; q0 < QB; q0++) {
  716. const gq_quant_t mm0 = pb0[i*nq*QB + s*QB + q0];
  717. for (int q1 = 0; q1 < QB; q1++) {
  718. const gq_quant_t mm1 = pb1[i*nq*QB + s*QB + q1];
  719. isum += (1 << (q0 + q1))*(__builtin_popcountll(mm0 & mm1));
  720. }
  721. }
  722. }
  723. #endif
  724. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  725. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  726. sumf += d0*d1*isum;
  727. }
  728. #else
  729. #ifdef __ARM_NEON
  730. // gq_quant_t == uint64_t
  731. for (int i = 0; i < nb; i += 4) {
  732. int isum[4] = {0, 0, 0, 0};
  733. for (int k = 0; k < 4; ++k) {
  734. for (int s = 0; s < nq; ++s) {
  735. const gq_quant_t * restrict m0 = pb0 + (i+k)*nq*QB + s*QB;
  736. const gq_quant_t * restrict m1 = pb1 + (i+k)*nq*QB + s*QB;
  737. #if QB == 4
  738. #define bpcnt(x) __builtin_popcountll(x)
  739. //isum[k] += (1ULL << 0)*(bpcnt(m0[0] & m1[0])) +
  740. // (1ULL << 1)*(bpcnt(m0[0] & m1[1]) + bpcnt(m0[1] & m1[0])) +
  741. // (1ULL << 2)*(bpcnt(m0[0] & m1[2]) + bpcnt(m0[1] & m1[1]) + bpcnt(m0[2] & m1[0])) +
  742. // (1ULL << 3)*(bpcnt(m0[0] & m1[3]) + bpcnt(m0[1] & m1[2]) + bpcnt(m0[2] & m1[1]) + bpcnt(m0[3] & m1[0])) +
  743. // (1ULL << 4)*(bpcnt(m0[1] & m1[3]) + bpcnt(m0[2] & m1[2]) + bpcnt(m0[3] & m1[1])) +
  744. // (1ULL << 5)*(bpcnt(m0[2] & m1[3]) + bpcnt(m0[3] & m1[2])) +
  745. // (1ULL << 6)*(bpcnt(m0[3] & m1[3]));
  746. #undef bpcnt
  747. const uint8x8_t m00 = vld1_u8((const uint8_t *) (m0 + 0));
  748. const uint8x8_t m01 = vld1_u8((const uint8_t *) (m0 + 1));
  749. const uint8x8_t m02 = vld1_u8((const uint8_t *) (m0 + 2));
  750. const uint8x8_t m03 = vld1_u8((const uint8_t *) (m0 + 3));
  751. const uint8x8_t m10 = vld1_u8((const uint8_t *) (m1 + 0));
  752. const uint8x8_t m11 = vld1_u8((const uint8_t *) (m1 + 1));
  753. const uint8x8_t m12 = vld1_u8((const uint8_t *) (m1 + 2));
  754. const uint8x8_t m13 = vld1_u8((const uint8_t *) (m1 + 3));
  755. const uint8x8_t m00m10 = vand_u8(m00, m10);
  756. const uint8x8_t m00m11 = vand_u8(m00, m11);
  757. const uint8x8_t m01m10 = vand_u8(m01, m10);
  758. const uint8x8_t m00m12 = vand_u8(m00, m12);
  759. const uint8x8_t m01m11 = vand_u8(m01, m11);
  760. const uint8x8_t m02m10 = vand_u8(m02, m10);
  761. const uint8x8_t m00m13 = vand_u8(m00, m13);
  762. const uint8x8_t m01m12 = vand_u8(m01, m12);
  763. const uint8x8_t m02m11 = vand_u8(m02, m11);
  764. const uint8x8_t m03m10 = vand_u8(m03, m10);
  765. const uint8x8_t m01m13 = vand_u8(m01, m13);
  766. const uint8x8_t m02m12 = vand_u8(m02, m12);
  767. const uint8x8_t m03m11 = vand_u8(m03, m11);
  768. const uint8x8_t m02m13 = vand_u8(m02, m13);
  769. const uint8x8_t m03m12 = vand_u8(m03, m12);
  770. const uint8x8_t m03m13 = vand_u8(m03, m13);
  771. #define bpcnt(x) vaddv_u8(vcnt_u8(x))
  772. isum[k] += (1ULL << 0)*(bpcnt(m00m10)) +
  773. (1ULL << 1)*(bpcnt(m00m11) + bpcnt(m01m10)) +
  774. (1ULL << 2)*(bpcnt(m00m12) + bpcnt(m01m11) + bpcnt(m02m10)) +
  775. (1ULL << 3)*(bpcnt(m00m13) + bpcnt(m01m12) + bpcnt(m02m11) + bpcnt(m03m10)) +
  776. (1ULL << 4)*(bpcnt(m01m13) + bpcnt(m02m12) + bpcnt(m03m11)) +
  777. (1ULL << 5)*(bpcnt(m02m13) + bpcnt(m03m12)) +
  778. (1ULL << 6)*(bpcnt(m03m13));
  779. #undef bpcnt
  780. #else
  781. for (int q0 = 0; q0 < QB; q0++) {
  782. const gq_quant_t mm0 = m0[q0];
  783. for (int q1 = 0; q1 < QB; q1++) {
  784. const gq_quant_t mm1 = m1[q1];
  785. isum[k] += (1ULL << (q0 + q1))*(__builtin_popcountll(mm0 & mm1));
  786. }
  787. }
  788. #endif
  789. }
  790. }
  791. int32x4_t isumv = vld1q_s32(isum);
  792. float32x4_t d0v = vld1q_f32(pd0 + i);
  793. float32x4_t d1v = vld1q_f32(pd1 + i);
  794. float32x4_t sumfv = vmulq_f32(d0v, d1v);
  795. sumfv = vmulq_f32(sumfv, vcvtq_f32_s32(isumv));
  796. sumf += vaddvq_f32(sumfv);
  797. }
  798. #else
  799. #error "not implemented"
  800. #endif
  801. #endif
  802. *s = sumf;
  803. }
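// With no offset term in method 3, the per-block dot product reduces to
// d0*d1*sum(q0*q1), where sum(q0*q1) is the cross-plane popcount sum weighted
// by 2^(b0 + b1) accumulated in isum above.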
  804. // use vec_dot_gq_3 to compute the dot product of two rows
  805. void mul_mat_gq_3(
  806. const void * src0,
  807. const void * src1, // transposed
  808. float * dst,
  809. int m, int n, int k) {
  810. assert(k % QK == 0);
  811. const int nb = quantize_3_blocks_per_row(k);
  812. const int nq = quantize_3_quants_per_block();
  813. for (int ir0 = 0; ir0 < m; ir0++) {
  814. for (int ir1 = 0; ir1 < n; ir1++) {
  815. vec_dot_gq_3(k, dst + ir1, src0, src1);
  816. src1 = (const char *) src1 + quantize_3_row_size(k);
  817. }
  818. src0 = (const char *) src0 + quantize_3_row_size(k);
  819. src1 = (const char *) src1 - n*quantize_3_row_size(k);
  820. dst = (float *) dst + n;
  821. }
  822. }
  823. //
  824. // method 4
  825. // 4-bit quantization
  826. //
  827. static inline int quantize_4_blocks_per_row(int k) {
  828. return k/QK;
  829. }
  830. static inline int quantize_4_row_size(int k) {
  831. const int nb = quantize_4_blocks_per_row(k);
  832. return nb*(2*sizeof(gq_scale_t) + QK/2);
  833. }
  834. void quantize_4_row(const float * restrict src, void * restrict dst, int k) {
  835. assert(k % QK == 0);
  836. assert(QB == 4);
  837. const int nb = quantize_4_blocks_per_row(k);
  838. gq_scale_t * restrict pm = (gq_scale_t *) (dst);
  839. gq_scale_t * restrict pd = (gq_scale_t *) (pm + nb);
  840. uint8_t * restrict pb = (uint8_t *) (pd + nb);
  841. uint8_t pp[QK/2];
  842. for (int i = 0; i < nb; i++) {
  843. memset(pp, 0, sizeof(pp));
  844. float min = FLT_MAX;
  845. float max = -FLT_MAX;
  846. #if defined(__AVX2__)
  847. {
  848. assert(QK == 64);
  849. enum { QK8 = QK/8 };
  850. __m256 srcv[QK8];
  851. __m256 minv[QK8];
  852. __m256 maxv[QK8];
  853. for (int l = 0; l < QK8; l++) {
  854. srcv[l] = _mm256_loadu_ps(src + i*QK + 8*l);
  855. }
  856. for (int l = 0; l < QK8/2; l++) {
  857. minv[2*l] = _mm256_min_ps(srcv[2*l], srcv[2*l+1]);
  858. maxv[2*l] = _mm256_max_ps(srcv[2*l], srcv[2*l+1]);
  859. }
  860. for (int l = 0; l < QK8/4; l++) {
  861. minv[4*l] = _mm256_min_ps(minv[4*l], minv[4*l+2]);
  862. maxv[4*l] = _mm256_max_ps(maxv[4*l], maxv[4*l+2]);
  863. }
  864. for (int l = 0; l < QK8/8; l++) {
  865. minv[8*l] = _mm256_min_ps(minv[8*l], minv[8*l+4]);
  866. maxv[8*l] = _mm256_max_ps(maxv[8*l], maxv[8*l+4]);
  867. }
  868. //min = MIN(minv[0][0], MIN(minv[0][1], MIN(minv[0][2], MIN(minv[0][3], MIN(minv[0][4], MIN(minv[0][5], MIN(minv[0][6], minv[0][7])))))));
  869. //max = MAX(maxv[0][0], MAX(maxv[0][1], MAX(maxv[0][2], MAX(maxv[0][3], MAX(maxv[0][4], MAX(maxv[0][5], MAX(maxv[0][6], maxv[0][7])))))));
  870. const __m256 minv0_0 = _mm256_permute2f128_ps(minv[0], minv[0], 3);
  871. const __m256 minv0_1 = _mm256_min_ps(minv[0], minv0_0);
  872. const __m256 minv0_2 = _mm256_permute_ps(minv0_1, 0x4e);
  873. const __m256 minv0_3 = _mm256_min_ps(minv0_1, minv0_2);
  874. const __m256 minv0_4 = _mm256_permute_ps(minv0_3, 0xb1);
  875. const __m256 minv0_5 = _mm256_min_ps(minv0_3, minv0_4);
  876. const __m256 maxv0_0 = _mm256_permute2f128_ps(maxv[0], maxv[0], 3);
  877. const __m256 maxv0_1 = _mm256_max_ps(maxv[0], maxv0_0);
  878. const __m256 maxv0_2 = _mm256_permute_ps(maxv0_1, 0x4e);
  879. const __m256 maxv0_3 = _mm256_max_ps(maxv0_1, maxv0_2);
  880. const __m256 maxv0_4 = _mm256_permute_ps(maxv0_3, 0xb1);
  881. const __m256 maxv0_5 = _mm256_max_ps(maxv0_3, maxv0_4);
  882. min = _mm256_cvtss_f32(minv0_5);
  883. max = _mm256_cvtss_f32(maxv0_5);
  884. const float d = (max - min) / ((1 << QB) - 2);
  885. const float id = d ? 1.0/d : 0.0;
  886. pm[i] = GGML_FP32_TO_GQ(min);
  887. pd[i] = GGML_FP32_TO_GQ(d);
  888. const __m256 idv = _mm256_set1_ps(id);
  889. for (int l = 0; l < QK/8; l++) {
  890. __m256 v = _mm256_mul_ps(_mm256_sub_ps(srcv[l], _mm256_set1_ps(min)), idv);
  891. #if 0
  892. v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand();
  893. v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand();
  894. #endif
  895. // convert to uint8
  896. __m256i vi = _mm256_cvtps_epi32(v);
  897. uint32_t vi_0 = _mm256_extract_epi32(vi, 0);
  898. uint32_t vi_1 = _mm256_extract_epi32(vi, 1);
  899. uint32_t vi_2 = _mm256_extract_epi32(vi, 2);
  900. uint32_t vi_3 = _mm256_extract_epi32(vi, 3);
  901. uint32_t vi_4 = _mm256_extract_epi32(vi, 4);
  902. uint32_t vi_5 = _mm256_extract_epi32(vi, 5);
  903. uint32_t vi_6 = _mm256_extract_epi32(vi, 6);
  904. uint32_t vi_7 = _mm256_extract_epi32(vi, 7);
  905. // convert to 4-bit: pack two consecutive values into one byte
  906. pp[4*l + 0] = vi_0 | (vi_1 << 4);
  907. pp[4*l + 1] = vi_2 | (vi_3 << 4);
  908. pp[4*l + 2] = vi_4 | (vi_5 << 4);
  909. pp[4*l + 3] = vi_6 | (vi_7 << 4);
  910. //printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
  911. //printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
  912. }
  913. memcpy(pb + i*QK/2, pp, sizeof(pp));
  914. }
  915. #elif defined(__ARM_NEON) && 0
  916. {
  917. // TODO
  918. }
  919. #else
  920. {
  921. for (int l = 0; l < QK; l++) {
  922. const float v = src[i*QK + l];
  923. if (v < min) min = v;
  924. if (v > max) max = v;
  925. }
  926. const float d = (max - min) / ((1 << QB) - 1);
  927. const float id = d ? 1.0/d : 0.0;
  928. pm[i] = GGML_FP32_TO_GQ(min);
  929. pd[i] = GGML_FP32_TO_GQ(d);
  930. for (int l = 0; l < QK; l++) {
  931. const float v = (src[i*QK + l] - min) * id;
  932. const uint8_t vi = (uint8_t) (v + frand());
  933. pp[l/2] |= (vi & 0xf) << (4*(l & 1));
  934. }
  935. memcpy(pb + i*QK/2, pp, sizeof(pp));
  936. }
  937. #endif
  938. //printf("min %f max %f\n", min, max);
  939. }
  940. }
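// Method 4 layout per row: all nb block minimums, then all nb deltas, then
// nb*QK/2 data bytes. Two 4-bit values share a byte: element 2j goes in the low
// nibble and element 2j+1 in the high nibble of byte j; dequantization is
// v ~= min + d*q.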
  941. // reimplementation of quantize_4 using quantize_4_row
  942. void quantize_4(const float * restrict src, char * restrict dst, int n, int k) {
  943. assert(k % QK == 0);
  944. for (int j = 0; j < n; j++) {
  945. quantize_4_row(src + j*k, dst, k);
  946. dst = (char *) dst + quantize_4_row_size(k);
  947. }
  948. }
  949. void vec_dot_gq_4(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
  950. const int nb = quantize_4_blocks_per_row(n);
  951. const gq_scale_t * restrict pm0 = (const gq_scale_t *) x;
  952. const gq_scale_t * restrict pm1 = (const gq_scale_t *) y;
  953. const gq_scale_t * restrict pd0 = pm0 + nb;
  954. const gq_scale_t * restrict pd1 = pm1 + nb;
  955. const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
  956. const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
  957. float sumf = 0.0;
  958. #if 0
  959. // scalar
  960. for (int i = 0; i < nb; i++) {
  961. const float m0 = GGML_GQ_TO_FP32(pm0[i]);
  962. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  963. const float m1 = GGML_GQ_TO_FP32(pm1[i]);
  964. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  965. const uint8_t * restrict p0 = pb0 + i*QK/2;
  966. const uint8_t * restrict p1 = pb1 + i*QK/2;
  967. for (int j = 0; j < QK/2; j++) {
  968. const uint8_t v0 = p0[j];
  969. const uint8_t v1 = p1[j];
  970. const float f0 = d0*(v0 & 0xf) + m0;
  971. const float f1 = d0*(v0 >> 4) + m0;
  972. const float f2 = d1*(v1 & 0xf) + m1;
  973. const float f3 = d1*(v1 >> 4) + m1;
  974. sumf += f0*f2 + f1*f3;
  975. }
  976. }
  977. #else
  978. #if defined(__AVX2__)
  979. #if QK == 64 && 0
  980. __m256 sumv0 = _mm256_setzero_ps();
  981. __m256 sumv1 = _mm256_setzero_ps();
  982. for (int i = 0; i < nb; i++) {
  983. const float m0 = GGML_GQ_TO_FP32(pm0[i]);
  984. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  985. const float m1 = GGML_GQ_TO_FP32(pm1[i]);
  986. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  987. const uint8_t * restrict p0 = pb0 + i*QK/2;
  988. const uint8_t * restrict p1 = pb1 + i*QK/2;
  989. const __m256 m0v = _mm256_set1_ps(m0);
  990. const __m256 d0v = _mm256_set1_ps(d0);
  991. const __m256 m1v = _mm256_set1_ps(m1);
  992. const __m256 d1v = _mm256_set1_ps(d1);
  993. const __m256i m4b = _mm256_set1_epi8(0xf);
  994. __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
  995. //_mm_prefetch((const char *) (p0 + 32), _MM_HINT_T0);
  996. //_mm_prefetch((const char *) (p1 + 32), _MM_HINT_T0);
  997. //_mm_prefetch((const char *) (pm0 + i + 1), _MM_HINT_T0);
  998. //_mm_prefetch((const char *) (pm1 + i + 1), _MM_HINT_T0);
  999. //_mm_prefetch((const char *) (pd0 + i + 1), _MM_HINT_T0);
  1000. //_mm_prefetch((const char *) (pd1 + i + 1), _MM_HINT_T0);
  1001. __m256i v00 = _mm256_and_si256(v0, _mm256_set1_epi32(0x000000FF));
  1002. __m256i v01 = _mm256_srli_epi32(_mm256_and_si256(v0, _mm256_set1_epi32(0x0000FFFF)), 8);
  1003. __m256i v02 = _mm256_srli_epi32(_mm256_and_si256(v0, _mm256_set1_epi32(0x00FFFFFF)), 16);
  1004. __m256i v03 = _mm256_srli_epi32(v0, 24);
  1005. //////////////////////
  1006. //{
  1007. // uint32_t vi_0 = _mm256_extract_epi32(v00, 0);
  1008. // uint32_t vi_1 = _mm256_extract_epi32(v00, 1);
  1009. // uint32_t vi_2 = _mm256_extract_epi32(v00, 2);
  1010. // uint32_t vi_3 = _mm256_extract_epi32(v00, 3);
  1011. // uint32_t vi_4 = _mm256_extract_epi32(v00, 4);
  1012. // uint32_t vi_5 = _mm256_extract_epi32(v00, 5);
  1013. // uint32_t vi_6 = _mm256_extract_epi32(v00, 6);
  1014. // uint32_t vi_7 = _mm256_extract_epi32(v00, 7);
  1015. // printf("v0: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
  1016. // printf("p0: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[0], p0[4], p0[8], p0[12], p0[16], p0[20], p0[24], p0[28]);
  1017. // printf("p1: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[1], p0[5], p0[9], p0[13], p0[17], p0[21], p0[25], p0[29]);
  1018. // printf("p2: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[2], p0[6], p0[10], p0[14], p0[18], p0[22], p0[26], p0[30]);
  1019. // printf("p3: %7d %7d %7d %7d %7d %7d %7d %7d\n", p0[3], p0[7], p0[11], p0[15], p0[19], p0[23], p0[27], p0[31]);
  1020. //}
  1021. // compute 32 x 4-bit values (low and high)
  1022. __m256i v00l = _mm256_and_si256(v00, m4b);
  1023. __m256i v01l = _mm256_and_si256(v01, m4b);
  1024. __m256i v02l = _mm256_and_si256(v02, m4b);
  1025. __m256i v03l = _mm256_and_si256(v03, m4b);
  1026. __m256i v00h = _mm256_srli_epi32(v00, 4);
  1027. __m256i v01h = _mm256_srli_epi32(v01, 4);
  1028. __m256i v02h = _mm256_srli_epi32(v02, 4);
  1029. __m256i v03h = _mm256_srli_epi32(v03, 4);
  1030. //{
  1031. // uint32_t vi_0 = _mm256_extract_epi32(v00l, 0);
  1032. // uint32_t vi_1 = _mm256_extract_epi32(v00l, 1);
  1033. // uint32_t vi_2 = _mm256_extract_epi32(v00l, 2);
  1034. // uint32_t vi_3 = _mm256_extract_epi32(v00l, 3);
  1035. // uint32_t vi_4 = _mm256_extract_epi32(v00l, 4);
  1036. // uint32_t vi_5 = _mm256_extract_epi32(v00l, 5);
  1037. // uint32_t vi_6 = _mm256_extract_epi32(v00l, 6);
  1038. // uint32_t vi_7 = _mm256_extract_epi32(v00l, 7);
  1039. // printf("v0l: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
  1040. // vi_0 = _mm256_extract_epi32(v00h, 0);
  1041. // vi_1 = _mm256_extract_epi32(v00h, 1);
  1042. // vi_2 = _mm256_extract_epi32(v00h, 2);
  1043. // vi_3 = _mm256_extract_epi32(v00h, 3);
  1044. // vi_4 = _mm256_extract_epi32(v00h, 4);
  1045. // vi_5 = _mm256_extract_epi32(v00h, 5);
  1046. // vi_6 = _mm256_extract_epi32(v00h, 6);
  1047. // vi_7 = _mm256_extract_epi32(v00h, 7);
  1048. // printf("v0h: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
  1049. //}
  1050. // convert to float
  1051. __m256 vf00l = _mm256_cvtepi32_ps(v00l);
  1052. __m256 vf01l = _mm256_cvtepi32_ps(v01l);
  1053. __m256 vf02l = _mm256_cvtepi32_ps(v02l);
  1054. __m256 vf03l = _mm256_cvtepi32_ps(v03l);
  1055. __m256 vf00h = _mm256_cvtepi32_ps(v00h);
  1056. __m256 vf01h = _mm256_cvtepi32_ps(v01h);
  1057. __m256 vf02h = _mm256_cvtepi32_ps(v02h);
  1058. __m256 vf03h = _mm256_cvtepi32_ps(v03h);
  1059. //{
  1060. // printf("vf00l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf00l[0], vf00l[1], vf00l[2], vf00l[3], vf00l[4], vf00l[5], vf00l[6], vf00l[7]);
  1061. // printf("vf01l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf01l[0], vf01l[1], vf01l[2], vf01l[3], vf01l[4], vf01l[5], vf01l[6], vf01l[7]);
  1062. // printf("vf02l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf02l[0], vf02l[1], vf02l[2], vf02l[3], vf02l[4], vf02l[5], vf02l[6], vf02l[7]);
  1063. // printf("vf03l: %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", vf03l[0], vf03l[1], vf03l[2], vf03l[3], vf03l[4], vf03l[5], vf03l[6], vf03l[7]);
  1064. //}
  1065. // multiply by scale and add offset
  1066. vf00l = _mm256_fmadd_ps(vf00l, d0v, m0v);
  1067. vf01l = _mm256_fmadd_ps(vf01l, d0v, m0v);
  1068. vf02l = _mm256_fmadd_ps(vf02l, d0v, m0v);
  1069. vf03l = _mm256_fmadd_ps(vf03l, d0v, m0v);
  1070. vf00h = _mm256_fmadd_ps(vf00h, d0v, m0v);
  1071. vf01h = _mm256_fmadd_ps(vf01h, d0v, m0v);
  1072. vf02h = _mm256_fmadd_ps(vf02h, d0v, m0v);
  1073. vf03h = _mm256_fmadd_ps(vf03h, d0v, m0v);
  1074. __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
  1075. __m256i v10 = _mm256_and_si256(v1, _mm256_set1_epi32(0x000000FF));
  1076. __m256i v11 = _mm256_srli_epi32(_mm256_and_si256(v1, _mm256_set1_epi32(0x0000FFFF)), 8);
  1077. __m256i v12 = _mm256_srli_epi32(_mm256_and_si256(v1, _mm256_set1_epi32(0x00FFFFFF)), 16);
  1078. __m256i v13 = _mm256_srli_epi32(v1, 24);
  1079. __m256i v10l = _mm256_and_si256(v10, m4b);
  1080. __m256i v11l = _mm256_and_si256(v11, m4b);
  1081. __m256i v12l = _mm256_and_si256(v12, m4b);
  1082. __m256i v13l = _mm256_and_si256(v13, m4b);
  1083. __m256i v10h = _mm256_srli_epi32(v10, 4);
  1084. __m256i v11h = _mm256_srli_epi32(v11, 4);
  1085. __m256i v12h = _mm256_srli_epi32(v12, 4);
  1086. __m256i v13h = _mm256_srli_epi32(v13, 4);
  1087. __m256 vf10l = _mm256_cvtepi32_ps(v10l);
  1088. __m256 vf11l = _mm256_cvtepi32_ps(v11l);
  1089. __m256 vf12l = _mm256_cvtepi32_ps(v12l);
  1090. __m256 vf13l = _mm256_cvtepi32_ps(v13l);
  1091. __m256 vf10h = _mm256_cvtepi32_ps(v10h);
  1092. __m256 vf11h = _mm256_cvtepi32_ps(v11h);
  1093. __m256 vf12h = _mm256_cvtepi32_ps(v12h);
  1094. __m256 vf13h = _mm256_cvtepi32_ps(v13h);
  1095. vf10l = _mm256_fmadd_ps(vf10l, d1v, m1v);
  1096. vf11l = _mm256_fmadd_ps(vf11l, d1v, m1v);
  1097. vf12l = _mm256_fmadd_ps(vf12l, d1v, m1v);
  1098. vf13l = _mm256_fmadd_ps(vf13l, d1v, m1v);
  1099. vf10h = _mm256_fmadd_ps(vf10h, d1v, m1v);
  1100. vf11h = _mm256_fmadd_ps(vf11h, d1v, m1v);
  1101. vf12h = _mm256_fmadd_ps(vf12h, d1v, m1v);
  1102. vf13h = _mm256_fmadd_ps(vf13h, d1v, m1v);
  1103. // compute dot product
  1104. sumv0 = _mm256_fmadd_ps(vf00l, vf10l, sumv0);
  1105. sumv0 = _mm256_fmadd_ps(vf01l, vf11l, sumv0);
  1106. sumv0 = _mm256_fmadd_ps(vf02l, vf12l, sumv0);
  1107. sumv0 = _mm256_fmadd_ps(vf03l, vf13l, sumv0);
  1108. sumv1 = _mm256_fmadd_ps(vf00h, vf10h, sumv1);
  1109. sumv1 = _mm256_fmadd_ps(vf01h, vf11h, sumv1);
  1110. sumv1 = _mm256_fmadd_ps(vf02h, vf12h, sumv1);
  1111. sumv1 = _mm256_fmadd_ps(vf03h, vf13h, sumv1);
  1112. }
  1113. // accumulate (horizontal sum)
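// fold the two 8-wide accumulators into one scalar: add the 128-bit halves,
// then two hadd steps leave the sum of all 8 lanes in lane 0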
  1114. const __m256 vdot = _mm256_add_ps(sumv0, sumv1);
  1115. const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(vdot), _mm256_extractf128_ps(vdot, 1));
  1116. const __m128 t1 = _mm_hadd_ps(t0, t0);
  1117. sumf += _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
  1118. #elif QK == 64 && 0
  1119. float sum00 = 0.0f;
  1120. float sum01 = 0.0f;
  1121. float sum10 = 0.0f;
  1122. float sum11 = 0.0f;
  1123. const __m256i m4b = _mm256_set1_epi8(0xf);
  1124. for (int i = 0; i < nb; i++) {
  1125. const float m0 = GGML_GQ_TO_FP32(pm0[i]);
  1126. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  1127. const float m1 = GGML_GQ_TO_FP32(pm1[i]);
  1128. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  1129. const uint8_t * restrict p0 = pb0 + i*QK/2;
  1130. const uint8_t * restrict p1 = pb1 + i*QK/2;
  1131. // 64 x 4
  1132. const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
  1133. const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
  1134. // 32 x 8
  1135. const __m256i v0l = _mm256_and_si256(v0, m4b);
  1136. const __m256i v1l = _mm256_and_si256(v1, m4b);
  1137. const __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
  1138. const __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
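// _mm256_maddubs_epi16 multiplies unsigned bytes by signed bytes and sums adjacent pairs into
// 16-bit lanes (no saturation possible here, all inputs are 4-bit values); the madd against 1
// then widens the 16-bit pair sums into 32-bit lanes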
  1139. const __m256i pl = _mm256_maddubs_epi16(v0l, v1l);
  1140. const __m256i ph = _mm256_maddubs_epi16(v0h, v1h);
  1141. const __m256i p16 = _mm256_add_epi16(ph, pl);
  1142. const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
  1143. sum00 += m0*m1;
  1144. sum01 += m1*d0*(_mm256_hadd_epi8_gg(_mm256_add_epi8(v0l, v0h)));
  1145. sum10 += m0*d1*(_mm256_hadd_epi8_gg(_mm256_add_epi8(v1l, v1h)));
  1146. sum11 += d0*d1*(_mm256_hadd_epi32_gg(p));
  1147. }
  1148. sumf = 64.0*sum00 + sum01 + sum10 + sum11;
  1149. #elif QK == 64 && 1 // this is the best when using min + d
  1150. float sum00 = 0.0f;
  1151. __m256 sum01 = _mm256_setzero_ps();
  1152. __m256 sum10 = _mm256_setzero_ps();
  1153. __m256 sum11 = _mm256_setzero_ps();
  1154. for (int i = 0; i < nb; i++) {
  1155. const float m0 = GGML_GQ_TO_FP32(pm0[i]);
  1156. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  1157. const float m1 = GGML_GQ_TO_FP32(pm1[i]);
  1158. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  1159. const uint8_t * restrict p0 = pb0 + i*QK/2;
  1160. const uint8_t * restrict p1 = pb1 + i*QK/2;
  1161. const __m256 m0v = _mm256_set1_ps(m0);
  1162. const __m256 d0v = _mm256_set1_ps(d0);
  1163. const __m256 m1v = _mm256_set1_ps(m1);
  1164. const __m256 d1v = _mm256_set1_ps(d1);
  1165. const __m256 m1d0v = _mm256_mul_ps(m1v, d0v);
  1166. const __m256 m0d1v = _mm256_mul_ps(m0v, d1v);
  1167. const __m256 d0d1v = _mm256_mul_ps(d0v, d1v);
  1168. const __m256i m4b = _mm256_set1_epi8(0xf);
  1169. // 64 x 4
  1170. const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
  1171. const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
  1172. // 32 x 8
  1173. const __m256i v0l = _mm256_and_si256(v0, m4b);
  1174. const __m256i v1l = _mm256_and_si256(v1, m4b);
  1175. const __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
  1176. const __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
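// per-block sum of all 64 quantized values (needed for the m*d cross terms): add low and high
// nibbles, fold the two 128-bit halves, widen the bytes to 32-bit and convert to float;
// the final horizontal reduction happens once, after the loop, via _mm256_hadd_ps_gg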
  1177. const __m256i v0a = _mm256_add_epi8(v0l, v0h);
  1178. const __m256i v1a = _mm256_add_epi8(v1l, v1h);
  1179. const __m128i v0al = _mm256_extracti128_si256(v0a, 0);
  1180. const __m128i v0ah = _mm256_extracti128_si256(v0a, 1);
  1181. const __m128i v1al = _mm256_extracti128_si256(v1a, 0);
  1182. const __m128i v1ah = _mm256_extracti128_si256(v1a, 1);
  1183. const __m128i v0as = _mm_add_epi8(v0al, v0ah);
  1184. const __m128i v1as = _mm_add_epi8(v1al, v1ah);
  1185. const __m256i v0as_0 = _mm256_cvtepu8_epi32(v0as);
  1186. const __m256i v0as_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(v0as, 8));
  1187. const __m256i v1as_0 = _mm256_cvtepu8_epi32(v1as);
  1188. const __m256i v1as_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(v1as, 8));
  1189. const __m256i v0ass = _mm256_add_epi32(v0as_0, v0as_1);
  1190. const __m256i v1ass = _mm256_add_epi32(v1as_0, v1as_1);
  1191. const __m256 v0f = _mm256_cvtepi32_ps(v0ass);
  1192. const __m256 v1f = _mm256_cvtepi32_ps(v1ass);
  1193. const __m256i pl = _mm256_maddubs_epi16(v0l, v1l);
  1194. const __m256i ph = _mm256_maddubs_epi16(v0h, v1h);
  1195. const __m256i p16 = _mm256_add_epi16(ph, pl);
  1196. const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
  1197. sum00 += m0*m1;
  1198. sum01 = _mm256_fmadd_ps(m1d0v, v0f, sum01);
  1199. sum10 = _mm256_fmadd_ps(m0d1v, v1f, sum10);
  1200. sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11);
  1201. }
  1202. sumf = 64.0*sum00 + _mm256_hadd_ps_gg(sum01) + _mm256_hadd_ps_gg(sum10) + _mm256_hadd_ps_gg(sum11);
  1203. #endif
  1204. #elif defined (__ARM_NEON)
  1205. float sum00 = 0.0f;
  1206. float sum01 = 0.0f;
  1207. float sum10 = 0.0f;
  1208. float sum11 = 0.0f;
  1209. for (int i = 0; i < nb; i++) {
  1210. const float m0 = GGML_GQ_TO_FP32(pm0[i]);
  1211. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  1212. const float m1 = GGML_GQ_TO_FP32(pm1[i]);
  1213. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  1214. const uint8_t * restrict p0 = pb0 + i*QK/2;
  1215. const uint8_t * restrict p1 = pb1 + i*QK/2;
  1216. const uint8x16_t m4b = vdupq_n_u8(0xf);
  1217. const uint8x16_t v0_0 = vld1q_u8(p0);
  1218. const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
  1219. const uint8x16_t v1_0 = vld1q_u8(p1);
  1220. const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
  1221. // and with 0xf
  1222. const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
  1223. const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
  1224. const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
  1225. const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
  1226. const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
  1227. const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
  1228. const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
  1229. const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
  1230. // dot product into uint16x8_t
  1231. const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
  1232. const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
  1233. const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
  1234. const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
  1235. const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
  1236. const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
  1237. const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
  1238. const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
  1239. const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
  1240. const uint16x8_t pl1 = vaddq_u16(pl1l, pl1h);
  1241. const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
  1242. const uint16x8_t ph1 = vaddq_u16(ph1l, ph1h);
  1243. const uint16x8_t pl = vaddq_u16(pl0, pl1);
  1244. const uint16x8_t ph = vaddq_u16(ph0, ph1);
  1245. sum00 += m0*m1;
  1246. sum01 += m1*d0*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h) + vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
  1247. sum10 += m0*d1*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h) + vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
  1248. //sum11 += d0*d1*(
  1249. // vaddvq_u16(vaddq_u16(vaddq_u16(pl0l, pl0h), vaddq_u16(pl1l, pl1h))) +
  1250. // vaddvq_u16(vaddq_u16(vaddq_u16(ph0l, ph0h), vaddq_u16(ph1l, ph1h))));
  1251. sum11 += d0*d1*vaddvq_u16(vaddq_u16(pl, ph));
  1252. }
  1253. sumf = 64.0*sum00 + sum01 + sum10 + sum11;
  1254. #endif
  1255. #endif
  1256. *s = sumf;
  1257. }
  1258. // use vec_dot_gq_4 to compute the dot product of two rows
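// src1 is stored transposed, so dst[ir0*n + ir1] is the dot product of row ir0 of src0 with
// row ir1 of src1; the quantized row pointers advance by quantize_4_row_size(k) bytes and
// src1 is rewound after each output row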
  1259. void mul_mat_gq_4(
  1260. const void * src0,
  1261. const void * src1, // transposed
  1262. float * dst,
  1263. int m, int n, int k) {
  1264. assert(k % QK == 0);
  1265. const int nb = quantize_4_blocks_per_row(k);
  1266. for (int ir0 = 0; ir0 < m; ir0++) {
  1267. for (int ir1 = 0; ir1 < n; ir1++) {
  1268. vec_dot_gq_4(k, dst + ir1, src0, src1);
  1269. src1 = (const char *) src1 + quantize_4_row_size(k);
  1270. }
  1271. src0 = (const char *) src0 + quantize_4_row_size(k);
  1272. src1 = (const char *) src1 - n*quantize_4_row_size(k);
  1273. dst = (float *) dst + n;
  1274. }
  1275. }
  1276. //
  1277. // method 5
  1278. // 4-bit quantization (without min, only delta)
  1279. //
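// row layout (see quantize_5_row / vec_dot_gq_5 below): nb = k/QK blocks, stored as
// nb deltas (gq_scale_t) followed by nb*QK/2 bytes of packed 4-bit values, two per byte
// (low nibble = even element); a quantized value q decodes as d*(q - 8)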
  1280. static inline int quantize_5_blocks_per_row(int k) {
  1281. return k/QK;
  1282. }
  1283. static inline int quantize_5_row_size(int k) {
  1284. const int nb = quantize_5_blocks_per_row(k);
  1285. return nb*(sizeof(gq_scale_t) + QK/2);
  1286. }
  1287. void quantize_5_row(const float * restrict src, void * restrict dst, int k) {
  1288. assert(k % QK == 0);
  1289. assert(QB == 4);
  1290. const int nb = quantize_5_blocks_per_row(k);
  1291. gq_scale_t * restrict pd = (gq_scale_t *) (dst);
  1292. uint8_t * restrict pb = (uint8_t *) (pd + nb);
  1293. uint8_t pp[QK/2];
  1294. for (int i = 0; i < nb; i++) {
  1295. memset(pp, 0, sizeof(pp));
  1296. float amax = 0.0f; // absolute max
  1297. #if defined(__AVX2__)
  1298. {
  1299. assert(QK == 64);
  1300. enum { QK8 = QK/8 };
  1301. __m256 srcv [QK8];
  1302. __m256 asrcv[QK8];
  1303. __m256 amaxv[QK8];
  1304. for (int l = 0; l < QK8; l++) {
  1305. srcv[l] = _mm256_loadu_ps(src + i*QK + 8*l);
  1306. }
  1307. for (int l = 0; l < QK8; l++) {
  1308. asrcv[l] = _mm256_and_ps(srcv[l], _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)));
  1309. }
  1310. for (int l = 0; l < QK8/2; l++) {
  1311. amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]);
  1312. }
  1313. for (int l = 0; l < QK8/4; l++) {
  1314. amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]);
  1315. }
  1316. for (int l = 0; l < QK8/8; l++) {
  1317. amaxv[8*l] = _mm256_max_ps(amaxv[8*l], amaxv[8*l+4]);
  1318. }
  1319. //amax = MAX(amaxv[0][0], MAX(amaxv[0][1], MAX(amaxv[0][2], MAX(amaxv[0][3], MAX(amaxv[0][4], MAX(amaxv[0][5], MAX(amaxv[0][6], amaxv[0][7])))))));
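// horizontal max of amaxv[0]: swap the 128-bit halves, then the 64-bit pairs, then adjacent
// lanes, taking the max after each shuffle so lane 0 ends up holding the block-wide maximum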
  1320. const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3);
  1321. const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0);
  1322. const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e);
  1323. const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2);
  1324. const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1);
  1325. const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4);
  1326. amax = _mm256_cvtss_f32(amaxv0_5);
  1327. //printf("amax = %f\n", amax);
  1328. const float d = amax / ((1 << (QB - 1)) - 1);
  1329. const float id = d ? 1.0/d : 0.0;
  1330. pd[i] = GGML_FP32_TO_GQ(d);
  1331. const __m256 idv = _mm256_set1_ps(id);
  1332. for (int l = 0; l < QK/8; l++) {
  1333. __m256 v = _mm256_mul_ps(srcv[l], idv);
  1334. #if 0
  1335. v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand();
  1336. v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand();
  1337. #endif
  1338. // convert to int8
  1339. __m256i vi = _mm256_cvtps_epi32(v);
  1340. vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8));
  1341. int32_t vi_0 = _mm256_extract_epi32(vi, 0);
  1342. int32_t vi_1 = _mm256_extract_epi32(vi, 1);
  1343. int32_t vi_2 = _mm256_extract_epi32(vi, 2);
  1344. int32_t vi_3 = _mm256_extract_epi32(vi, 3);
  1345. int32_t vi_4 = _mm256_extract_epi32(vi, 4);
  1346. int32_t vi_5 = _mm256_extract_epi32(vi, 5);
  1347. int32_t vi_6 = _mm256_extract_epi32(vi, 6);
  1348. int32_t vi_7 = _mm256_extract_epi32(vi, 7);
  1349. // convert to 4-bit, 2 consecutive packed into 1 byte
  1350. pp[4*l + 0] = vi_0 | (vi_1 << 4);
  1351. pp[4*l + 1] = vi_2 | (vi_3 << 4);
  1352. pp[4*l + 2] = vi_4 | (vi_5 << 4);
  1353. pp[4*l + 3] = vi_6 | (vi_7 << 4);
  1354. //printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
  1355. ////printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
  1356. assert(vi_0 >= 0 && vi_0 < 16);
  1357. assert(vi_1 >= 0 && vi_1 < 16);
  1358. assert(vi_2 >= 0 && vi_2 < 16);
  1359. assert(vi_3 >= 0 && vi_3 < 16);
  1360. assert(vi_4 >= 0 && vi_4 < 16);
  1361. assert(vi_5 >= 0 && vi_5 < 16);
  1362. assert(vi_6 >= 0 && vi_6 < 16);
  1363. assert(vi_7 >= 0 && vi_7 < 16);
  1364. }
  1365. memcpy(pb + i*QK/2, pp, sizeof(pp));
  1366. }
  1367. #elif defined(__ARM_NEON) && 0
  1368. {
  1369. // TODO
  1370. }
  1371. #else
  1372. {
  1373. for (int l = 0; l < QK; l++) {
  1374. const float v = src[i*QK + l];
  1375. amax = MAX(amax, fabsf(v));
  1376. }
  1377. const float d = amax / ((1 << (QB - 1)) - 1);
  1378. const float id = d ? 1.0/d : 0.0;
  1379. pd[i] = GGML_FP32_TO_GQ(d);
  1380. for (int l = 0; l < QK; l++) {
  1381. const float v = src[i*QK + l]*id;
  1382. const int8_t vi = ((int8_t) (round(v))) + 8;
  1383. assert(vi >= 0 && vi < 16);
  1384. pp[l/2] |= (vi & 0xf) << (4*(l & 1));
  1385. }
  1386. memcpy(pb + i*QK/2, pp, sizeof(pp));
  1387. }
  1388. #endif
  1389. //printf("min %f max %f\n", min, max);
  1390. }
  1391. }
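// A minimal scalar sketch (not part of the original test) of the inverse transform, handy when
// inspecting quantization error by hand. It assumes the row layout described above; the helper
// name dequantize_5_row_ref is hypothetical. Kept inside #if 0 so the build is unchanged.
#if 0
static void dequantize_5_row_ref(const void * restrict src, float * restrict dst, int k) {
    const int nb = quantize_5_blocks_per_row(k);

    const gq_scale_t * restrict pd = (const gq_scale_t *) src;
    const uint8_t    * restrict pb = (const uint8_t *) (pd + nb);

    for (int i = 0; i < nb; i++) {
        const float d = GGML_GQ_TO_FP32(pd[i]);

        for (int l = 0; l < QK/2; l++) {
            const uint8_t v = pb[i*QK/2 + l];

            // low nibble is the even element, high nibble the odd one
            dst[i*QK + 2*l + 0] = d*((int8_t) (v & 0xf) - 8);
            dst[i*QK + 2*l + 1] = d*((int8_t) (v >> 4) - 8);
        }
    }
}
#endif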
  1392. // reimplementation of quantize_5 using quantize_5_row
  1393. void quantize_5(const float * restrict src, char * restrict dst, int n, int k) {
  1394. assert(k % QK == 0);
  1395. for (int j = 0; j < n; j++) {
  1396. quantize_5_row(src + j*k, dst, k);
  1397. dst = (char *) dst + quantize_5_row_size(k);
  1398. }
  1399. }
  1400. void vec_dot_gq_5(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
  1401. const int nb = quantize_5_blocks_per_row(n);
  1402. const gq_scale_t * restrict pd0 = (const gq_scale_t *) x;
  1403. const gq_scale_t * restrict pd1 = (const gq_scale_t *) y;
  1404. const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
  1405. const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
  1406. float sumf = 0.0;
  1407. #if 0
  1408. // scalar
  1409. for (int i = 0; i < nb; i++) {
  1410. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  1411. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  1412. const uint8_t * restrict p0 = pb0 + i*QK/2;
  1413. const uint8_t * restrict p1 = pb1 + i*QK/2;
  1414. for (int j = 0; j < QK/2; j++) {
  1415. const uint8_t v0 = p0[j];
  1416. const uint8_t v1 = p1[j];
  1417. const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
  1418. const float f1 = d0*((int8_t) (v0 >> 4) - 8);
  1419. const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
  1420. const float f3 = d1*((int8_t) (v1 >> 4) - 8);
  1421. sumf += f0*f2 + f1*f3;
  1422. }
  1423. }
  1424. #else
  1425. #if defined(__AVX2__)
  1426. #if QK == 64 && 1
  1427. __m256 sum11 = _mm256_setzero_ps();
  1428. for (int i = 0; i < nb; i++) {
  1429. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  1430. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  1431. const uint8_t * restrict p0 = pb0 + i*QK/2;
  1432. const uint8_t * restrict p1 = pb1 + i*QK/2;
  1433. const __m256 d0v = _mm256_set1_ps(d0);
  1434. const __m256 d1v = _mm256_set1_ps(d1);
  1435. const __m256 d0d1v = _mm256_mul_ps(d0v, d1v);
  1436. const __m256i m4b = _mm256_set1_epi8(0xf);
  1437. // 64 x 4
  1438. const __m256i v0 = _mm256_loadu_si256((__m256i *) p0);
  1439. const __m256i v1 = _mm256_loadu_si256((__m256i *) p1);
  1440. // 32 x 8
  1441. __m256i v0l = _mm256_and_si256(v0, m4b);
  1442. __m256i v1l = _mm256_and_si256(v1, m4b);
  1443. __m256i v0h = _mm256_and_si256(_mm256_srli_epi16(v0, 4), m4b);
  1444. __m256i v1h = _mm256_and_si256(_mm256_srli_epi16(v1, 4), m4b);
  1445. // sub 8
  1446. v0l = _mm256_sub_epi8(v0l, _mm256_set1_epi8(8));
  1447. v0h = _mm256_sub_epi8(v0h, _mm256_set1_epi8(8));
  1448. v1l = _mm256_sub_epi8(v1l, _mm256_set1_epi8(8));
  1449. v1h = _mm256_sub_epi8(v1h, _mm256_set1_epi8(8));
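// _mm256_maddubs_epi16 treats its first operand as unsigned, so take |v0| and copy v0's sign
// onto v1: per byte, |v0| * (sign(v0)*v1) == v0*v1 (and 0 where v0 == 0), which maddubs then
// pair-sums into 16-bit lanes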
  1450. // abs
  1451. const __m256i v0la = _mm256_sign_epi8(v0l, v0l);
  1452. const __m256i v0ha = _mm256_sign_epi8(v0h, v0h);
  1453. // sign
  1454. const __m256i v1ls = _mm256_sign_epi8(v1l, v0l);
  1455. const __m256i v1hs = _mm256_sign_epi8(v1h, v0h);
  1456. const __m256i pl = _mm256_maddubs_epi16(v0la, v1ls);
  1457. const __m256i ph = _mm256_maddubs_epi16(v0ha, v1hs);
  1458. const __m256i p16 = _mm256_add_epi16(ph, pl);
  1459. const __m256i p = _mm256_madd_epi16(_mm256_set1_epi16(1), p16);
  1460. sum11 = _mm256_fmadd_ps(d0d1v, _mm256_cvtepi32_ps(p), sum11);
  1461. }
  1462. sumf = _mm256_hadd_ps_gg(sum11);
  1463. #endif
  1464. #elif defined (__ARM_NEON)
  1465. float sum11 = 0.0f;
  1466. //float32x4_t sum_0 = vdupq_n_f32(0.0f);
  1467. //float32x4_t sum_1 = vdupq_n_f32(0.0f);
  1468. //float16x8_t sum_0 = vdupq_n_f16(0.0f);
  1469. //float16x8_t sum_1 = vdupq_n_f16(0.0f);
  1470. for (int i = 0; i < nb; i++) {
  1471. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  1472. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  1473. //float32x4_t d0d1v = vdupq_n_f32(d0*d1);
  1474. //float16x8_t d0d1v = vdupq_n_f16(d0*d1);
  1475. const uint8_t * restrict p0 = pb0 + i*QK/2;
  1476. const uint8_t * restrict p1 = pb1 + i*QK/2;
  1477. const uint8x16_t m4b = vdupq_n_u8(0xf);
  1478. const int8x16_t s8b = vdupq_n_s8(0x8);
  1479. const uint8x16_t v0_0 = vld1q_u8(p0);
  1480. const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
  1481. const uint8x16_t v1_0 = vld1q_u8(p1);
  1482. const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
  1483. // 4-bit -> 8-bit
  1484. const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
  1485. const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
  1486. const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
  1487. const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
  1488. const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
  1489. const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
  1490. const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
  1491. const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
  1492. // sub 8
  1493. const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
  1494. const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
  1495. const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
  1496. const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
  1497. const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
  1498. const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
  1499. const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
  1500. const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
  1501. // dot product into int16x8_t
  1502. const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
  1503. const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
  1504. const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
  1505. const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
  1506. const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
  1507. const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
  1508. const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
  1509. const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
  1510. const int16x8_t pl0 = vaddq_s16(pl0l, pl0h);
  1511. const int16x8_t pl1 = vaddq_s16(pl1l, pl1h);
  1512. const int16x8_t ph0 = vaddq_s16(ph0l, ph0h);
  1513. const int16x8_t ph1 = vaddq_s16(ph1l, ph1h);
  1514. const int16x8_t pl = vaddq_s16(pl0, pl1);
  1515. const int16x8_t ph = vaddq_s16(ph0, ph1);
  1516. //const int8x16_t pl0 = vmulq_s8(v0_0ls, v1_0ls);
  1517. //const int8x16_t pl1 = vmulq_s8(v0_1ls, v1_1ls);
  1518. //const int8x16_t ph0 = vmulq_s8(v0_0hs, v1_0hs);
  1519. //const int8x16_t ph1 = vmulq_s8(v0_1hs, v1_1hs);
  1520. //const int16x8_t pll = vaddl_s8(vget_low_s8(pl0), vget_low_s8(pl1));
  1521. //const int16x8_t plh = vaddl_s8(vget_high_s8(pl0), vget_high_s8(pl1));
  1522. //const int16x8_t phl = vaddl_s8(vget_low_s8(ph0), vget_low_s8(ph1));
  1523. //const int16x8_t phh = vaddl_s8(vget_high_s8(ph0), vget_high_s8(ph1));
  1524. //const int16x8_t pl = vaddq_s16(pll, plh);
  1525. //const int16x8_t ph = vaddq_s16(phl, phh);
  1526. const int16x8_t p = vaddq_s16(pl, ph);
  1527. // convert to float
  1528. //const float32x4_t pf0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (p)));
  1529. //const float32x4_t pf1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(p)));
  1530. // scalar
  1531. sum11 += d0*d1*vaddvq_s16(p);
  1532. //sum11 += d0*d1*(vaddvq_s16(pl) + vaddvq_s16(ph));
  1533. //sum11 += d0*d1*vaddvq_s16(vaddq_s16(pl, ph));
  1534. //sum11 += d0*d1*(vaddvq_s8(pl0) + vaddvq_s8(pl1) + vaddvq_s8(ph0) + vaddvq_s8(ph1));
  1535. //sum11 += d0*d1*(vaddvq_s16(pll) + vaddvq_s16(plh) + vaddvq_s16(phl) + vaddvq_s16(phh));
  1536. //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(p));
  1537. //sum_0 = vfmaq_f16(sum_0, d0d1v, vcvtq_f16_s16(pl));
  1538. //sum_1 = vfmaq_f16(sum_1, d0d1v, vcvtq_f16_s16(ph));
  1539. // vectorize
  1540. //sum_0 = vmlaq_f32(sum_0, d0d1v, pf0);
  1541. //sum_1 = vmlaq_f32(sum_1, d0d1v, pf1);
  1542. }
  1543. sumf = sum11;
  1544. //sumf = vaddvq_f32(sum_0) + vaddvq_f32(sum_1);
  1545. //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];
  1546. //sum_0 = vaddq_f16(sum_0, sum_1);
  1547. //sumf = sum_0[0] + sum_0[1] + sum_0[2] + sum_0[3] + sum_0[4] + sum_0[5] + sum_0[6] + sum_0[7];
  1548. #endif
  1549. #endif
  1550. *s = sumf;
  1551. }
  1552. // use vec_dot_gq_5 to compute the dot product of two rows
  1553. void mul_mat_gq_5(
  1554. const void * src0,
  1555. const void * src1, // transposed
  1556. float * dst,
  1557. int m, int n, int k) {
  1558. assert(k % QK == 0);
  1559. const int nb = quantize_5_blocks_per_row(k);
  1560. for (int ir0 = 0; ir0 < m; ir0++) {
  1561. for (int ir1 = 0; ir1 < n; ir1++) {
  1562. vec_dot_gq_5(k, dst + ir1, src0, src1);
  1563. src1 = (const char *) src1 + quantize_5_row_size(k);
  1564. }
  1565. src0 = (const char *) src0 + quantize_5_row_size(k);
  1566. src1 = (const char *) src1 - n*quantize_5_row_size(k);
  1567. dst = (float *) dst + n;
  1568. }
  1569. }
  1570. //
  1571. // method 6
  1572. // same as 5 but with 32 element blocks
  1573. //
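// identical layout to method 5, only the block size changes: nb = k/32 blocks, each stored as
// one gq_scale_t delta plus 16 packed bytes (two 4-bit values per byte), decoded as d*(q - 8)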
  1574. static inline int quantize_6_blocks_per_row(int k) {
  1575. return k/32;
  1576. }
  1577. static inline int quantize_6_row_size(int k) {
  1578. const int nb = quantize_6_blocks_per_row(k);
  1579. return nb*(sizeof(gq_scale_t) + 16);
  1580. }
  1581. void quantize_6_row(const float * restrict src, void * restrict dst, int k) {
  1582. assert(k % 32 == 0);
  1583. assert(QB == 4);
  1584. const int nb = quantize_6_blocks_per_row(k);
  1585. gq_scale_t * restrict pd = (gq_scale_t *) (dst);
  1586. uint8_t * restrict pb = (uint8_t *) (pd + nb);
  1587. uint8_t pp[16];
  1588. for (int i = 0; i < nb; i++) {
  1589. memset(pp, 0, sizeof(pp));
  1590. float amax = 0.0f; // absolute max
  1591. #if defined(__AVX2__)
  1592. {
  1593. enum { QK8 = 4 };
  1594. __m256 srcv [QK8];
  1595. __m256 asrcv[QK8];
  1596. __m256 amaxv[QK8];
  1597. for (int l = 0; l < QK8; l++) {
  1598. srcv[l] = _mm256_loadu_ps(src + i*32 + 8*l);
  1599. }
  1600. for (int l = 0; l < QK8; l++) {
  1601. asrcv[l] = _mm256_and_ps(srcv[l], _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)));
  1602. }
  1603. for (int l = 0; l < QK8/2; l++) {
  1604. amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]);
  1605. }
  1606. for (int l = 0; l < QK8/4; l++) {
  1607. amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]);
  1608. }
  1609. const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3);
  1610. const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0);
  1611. const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e);
  1612. const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2);
  1613. const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1);
  1614. const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4);
  1615. amax = _mm256_cvtss_f32(amaxv0_5);
  1616. const float d = amax / ((1 << (QB - 1)) - 1);
  1617. const float id = d ? 1.0/d : 0.0;
  1618. pd[i] = GGML_FP32_TO_GQ(d);
  1619. const __m256 idv = _mm256_set1_ps(id);
  1620. for (int l = 0; l < 4; l++) {
  1621. __m256 v = _mm256_mul_ps(srcv[l], idv);
  1622. // convert to int8
  1623. __m256i vi = _mm256_cvtps_epi32(v);
  1624. vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8));
  1625. int32_t vi_0 = _mm256_extract_epi32(vi, 0);
  1626. int32_t vi_1 = _mm256_extract_epi32(vi, 1);
  1627. int32_t vi_2 = _mm256_extract_epi32(vi, 2);
  1628. int32_t vi_3 = _mm256_extract_epi32(vi, 3);
  1629. int32_t vi_4 = _mm256_extract_epi32(vi, 4);
  1630. int32_t vi_5 = _mm256_extract_epi32(vi, 5);
  1631. int32_t vi_6 = _mm256_extract_epi32(vi, 6);
  1632. int32_t vi_7 = _mm256_extract_epi32(vi, 7);
  1633. // convert to 4-bit, 2 consecutive packed into 1 byte
  1634. pp[4*l + 0] = vi_0 | (vi_1 << 4);
  1635. pp[4*l + 1] = vi_2 | (vi_3 << 4);
  1636. pp[4*l + 2] = vi_4 | (vi_5 << 4);
  1637. pp[4*l + 3] = vi_6 | (vi_7 << 4);
  1638. assert(vi_0 >= 0 && vi_0 < 16);
  1639. assert(vi_1 >= 0 && vi_1 < 16);
  1640. assert(vi_2 >= 0 && vi_2 < 16);
  1641. assert(vi_3 >= 0 && vi_3 < 16);
  1642. assert(vi_4 >= 0 && vi_4 < 16);
  1643. assert(vi_5 >= 0 && vi_5 < 16);
  1644. assert(vi_6 >= 0 && vi_6 < 16);
  1645. assert(vi_7 >= 0 && vi_7 < 16);
  1646. }
  1647. memcpy(pb + i*16, pp, sizeof(pp));
  1648. }
  1649. #elif defined(__ARM_NEON)
  1650. {
  1651. float32x4_t srcv [8];
  1652. float32x4_t asrcv[8];
  1653. float32x4_t amaxv[8];
  1654. for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(src + i*32 + 4*l);
  1655. for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
  1656. for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
  1657. for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
  1658. for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
  1659. amax = MAX(
  1660. MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
  1661. MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
  1662. const float d = amax / ((1 << 3) - 1);
  1663. const float id = d ? 1.0/d : 0.0;
  1664. pd[i] = GGML_FP32_TO_GQ(d);
  1665. for (int l = 0; l < 8; l++) {
  1666. const float32x4_t v = vmulq_n_f32(srcv[l], id);
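// vcvtq_s32_f32 truncates toward zero, so adding 8.5 applies the +8 offset and rounds the
// scaled values (which lie in roughly [-7, 7]) to the nearest integer in one step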
  1667. const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
  1668. const int32x4_t vi = vcvtq_s32_f32(vf);
  1669. pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
  1670. pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
  1671. }
  1672. memcpy(pb + i*16, pp, sizeof(pp));
  1673. }
  1674. #else
  1675. {
  1676. for (int l = 0; l < 32; l++) {
  1677. const float v = src[i*32 + l];
  1678. amax = MAX(amax, fabsf(v));
  1679. }
  1680. const float d = amax / ((1 << (QB - 1)) - 1);
  1681. const float id = d ? 1.0/d : 0.0;
  1682. pd[i] = GGML_FP32_TO_GQ(d);
  1683. for (int l = 0; l < 32; l++) {
  1684. const float v = src[i*32 + l]*id;
  1685. const int8_t vi = ((int8_t) (round(v))) + 8;
  1686. assert(vi >= 0 && vi < 16);
  1687. pp[l/2] |= (vi & 0xf) << (4*(l & 1));
  1688. }
  1689. memcpy(pb + i*16, pp, sizeof(pp));
  1690. }
  1691. #endif
  1692. //printf("amax = %f\n", amax);
  1693. }
  1694. }
1695. // reimplementation of quantize_6 using quantize_6_row
  1696. void quantize_6(const float * restrict src, char * restrict dst, int n, int k) {
  1697. assert(k % 32 == 0);
  1698. for (int j = 0; j < n; j++) {
  1699. quantize_6_row(src + j*k, dst, k);
  1700. dst = (char *) dst + quantize_6_row_size(k);
  1701. }
  1702. }
  1703. void vec_dot_gq_6(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
  1704. const int nb = quantize_6_blocks_per_row(n);
  1705. const gq_scale_t * restrict pd0 = (const gq_scale_t *) x;
  1706. const gq_scale_t * restrict pd1 = (const gq_scale_t *) y;
  1707. const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
  1708. const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
  1709. float sumf = 0.0;
  1710. #if 0
  1711. // scalar
  1712. for (int i = 0; i < nb; i++) {
  1713. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  1714. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  1715. const uint8_t * restrict p0 = pb0 + i*16;
  1716. const uint8_t * restrict p1 = pb1 + i*16;
  1717. for (int j = 0; j < 16; j++) {
  1718. const uint8_t v0 = p0[j];
  1719. const uint8_t v1 = p1[j];
  1720. const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
  1721. const float f1 = d0*((int8_t) (v0 >> 4) - 8);
  1722. const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
  1723. const float f3 = d1*((int8_t) (v1 >> 4) - 8);
  1724. sumf += f0*f2 + f1*f3;
  1725. }
  1726. }
  1727. #else
  1728. #if defined(__AVX2__)
  1729. // TODO
  1730. #elif defined (__ARM_NEON)
  1731. #if 0
  1732. float sum0 = 0.0f;
  1733. for (int i = 0; i < nb; i++) {
  1734. const float d0 = GGML_GQ_TO_FP32(pd0[i]);
  1735. const float d1 = GGML_GQ_TO_FP32(pd1[i]);
  1736. //float32x4_t d0d1v = vdupq_n_f32(d0*d1);
  1737. //float16x8_t d0d1v = vdupq_n_f16(d0*d1);
  1738. const uint8_t * restrict p0 = pb0 + i*16;
  1739. const uint8_t * restrict p1 = pb1 + i*16;
  1740. const uint8x16_t m4b = vdupq_n_u8(0xf);
  1741. const int8x16_t s8b = vdupq_n_s8(0x8);
  1742. const uint8x16_t v0_0 = vld1q_u8(p0);
  1743. const uint8x16_t v1_0 = vld1q_u8(p1);
  1744. // 4-bit -> 8-bit
1745. const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
1746. const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
1747. const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
1748. const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
  1749. // sub 8
  1750. const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
  1751. const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
  1752. const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
  1753. const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
  1754. // dot product into int16x8_t
  1755. const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
  1756. const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
  1757. const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
  1758. const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
  1759. const int16x8_t pl = vaddq_s16(pl0l, pl0h);
  1760. const int16x8_t ph = vaddq_s16(ph0l, ph0h);
  1761. const int16x8_t p = vaddq_s16(pl, ph);
  1762. // scalar
  1763. sum0 += d0*d1*vaddvq_s16(p);
  1764. }
  1765. sumf = sum0;
  1766. #elif 1 // this is a bit faster than the above
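// unrolled over two consecutive 32-element blocks per iteration, each with its own delta,
// accumulating into the two independent scalar sums sum0 and sum1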
  1767. float sum0 = 0.0f;
  1768. float sum1 = 0.0f;
  1769. for (int i = 0; i < nb; i += 2) {
  1770. const float d0_0 = GGML_GQ_TO_FP32(pd0[i + 0]);
  1771. const float d1_0 = GGML_GQ_TO_FP32(pd1[i + 0]);
  1772. const float d0_1 = GGML_GQ_TO_FP32(pd0[i + 1]);
  1773. const float d1_1 = GGML_GQ_TO_FP32(pd1[i + 1]);
  1774. const uint8_t * restrict p0 = pb0 + i*16;
  1775. const uint8_t * restrict p1 = pb1 + i*16;
  1776. const uint8x16_t m4b = vdupq_n_u8(0xf);
  1777. const int8x16_t s8b = vdupq_n_s8(0x8);
  1778. const uint8x16_t v0_0 = vld1q_u8(p0);
  1779. const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
  1780. const uint8x16_t v1_0 = vld1q_u8(p1);
  1781. const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
  1782. // 4-bit -> 8-bit
  1783. const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
  1784. const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
  1785. const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
  1786. const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
  1787. const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
  1788. const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
  1789. const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
  1790. const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
  1791. // sub 8
  1792. const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
  1793. const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
  1794. const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
  1795. const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
  1796. const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
  1797. const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
  1798. const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
  1799. const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
  1800. // dot product into int16x8_t
  1801. const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
  1802. const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
  1803. const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
  1804. const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
  1805. const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
  1806. const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
  1807. const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
  1808. const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
  1809. const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
  1810. const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
  1811. const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
  1812. const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
  1813. const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
  1814. const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
  1815. // scalar
  1816. sum0 += d0_0*d1_0*vaddvq_s16(p_0);
  1817. sum1 += d0_1*d1_1*vaddvq_s16(p_1);
  1818. }
  1819. sumf = sum0 + sum1;
  1820. #endif
  1821. #endif
  1822. #endif
  1823. *s = sumf;
  1824. }
  1825. // use vec_dot_gq_6 to compute the dot product of two rows
  1826. void mul_mat_gq_6(
  1827. const void * src0,
  1828. const void * src1, // transposed
  1829. float * dst,
  1830. int m, int n, int k) {
  1831. assert(k % 32 == 0);
  1832. for (int ir0 = 0; ir0 < m; ir0++) {
  1833. for (int ir1 = 0; ir1 < n; ir1++) {
  1834. vec_dot_gq_6(k, dst + ir1, src0, src1);
  1835. src1 = (const char *) src1 + quantize_6_row_size(k);
  1836. }
  1837. src0 = (const char *) src0 + quantize_6_row_size(k);
  1838. src1 = (const char *) src1 - n*quantize_6_row_size(k);
  1839. dst = (float *) dst + n;
  1840. }
  1841. }
  1842. int main(int argc, const char ** argv) {
  1843. assert(sizeof(gq_quant_t)*8 == gq_t_bits);
  1844. ggml_time_init();
  1845. // needed to initialize f16 tables
  1846. {
  1847. struct ggml_init_params params = { 0, NULL, false };
  1848. struct ggml_context * ctx = ggml_init(params);
  1849. ggml_free(ctx);
  1850. }
  1851. int method = 0;
  1852. if (argc > 1) {
  1853. method = atoi(argv[1]);
  1854. }
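// method selects the kernel under test: 0 = naive fp32 matmul, 1..6 = the quantized
// variants defined above (pass it as the first command-line argument)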
  1855. float * src0 = malloc(sizeof(float)*M*K);
  1856. float * src1 = malloc(sizeof(float)*N*K);
  1857. float * dst = malloc(sizeof(float)*M*N);
  1858. // allocate aligned memory
  1859. //float * src0 = (float *)aligned_alloc(32, sizeof(float)*M*K);
  1860. //float * src1 = (float *)aligned_alloc(32, sizeof(float)*N*K);
  1861. //float * dst = (float *)aligned_alloc(32, sizeof(float)*M*N);
  1862. for (int i = 0; i < M*K; i++) {
  1863. src0[i] = 0.8 - rand() / (float)RAND_MAX;
  1864. /*src0[i] = rand() / (float)RAND_MAX;*/
  1865. /*src0[i] = i % 2;*/
  1866. }
  1867. for (int i = 0; i < N*K; i++) {
  1868. src1[i] = 0.8 - rand() / (float)RAND_MAX;
  1869. /*src1[i] = rand() / (float)RAND_MAX;*/
  1870. /*src1[i] = i % 3;*/
  1871. }
  1872. void * src0_gq = NULL;
  1873. void * src1_gq = NULL;
  1874. size_t sizegq = 0;
  1875. {
  1876. if (method == 1) {
  1877. src0_gq = calloc(1, quantize_1_row_size(K)*M);
  1878. src1_gq = calloc(1, quantize_1_row_size(K)*N);
  1879. sizegq = quantize_1_row_size(K)*M + quantize_1_row_size(K)*N;
  1880. }
  1881. if (method == 2) {
  1882. src0_gq = calloc(1, quantize_2_row_size(K)*M);
  1883. src1_gq = calloc(1, quantize_2_row_size(K)*N);
  1884. sizegq = quantize_2_row_size(K)*M + quantize_2_row_size(K)*N;
  1885. }
  1886. if (method == 3) {
  1887. src0_gq = calloc(1, quantize_3_row_size(K)*M);
  1888. src1_gq = calloc(1, quantize_3_row_size(K)*N);
  1889. sizegq = quantize_3_row_size(K)*M + quantize_3_row_size(K)*N;
  1890. }
  1891. if (method == 4) {
  1892. src0_gq = calloc(1, quantize_4_row_size(K)*M);
  1893. src1_gq = calloc(1, quantize_4_row_size(K)*N);
  1894. sizegq = quantize_4_row_size(K)*M + quantize_4_row_size(K)*N;
  1895. }
  1896. if (method == 5) {
  1897. src0_gq = calloc(1, quantize_5_row_size(K)*M);
  1898. src1_gq = calloc(1, quantize_5_row_size(K)*N);
  1899. sizegq = quantize_5_row_size(K)*M + quantize_5_row_size(K)*N;
  1900. }
  1901. if (method == 6) {
  1902. src0_gq = calloc(1, quantize_6_row_size(K)*M);
  1903. src1_gq = calloc(1, quantize_6_row_size(K)*N);
  1904. sizegq = quantize_6_row_size(K)*M + quantize_6_row_size(K)*N;
  1905. }
  1906. }
  1907. const size_t sizef16 = sizeof(ggml_fp16_t)*M*K + sizeof(ggml_fp16_t)*N*K;
  1908. printf("compression: %f\n", (float)sizegq/sizef16);
  1909. // convert fp32 -> gq
  1910. {
  1911. const int64_t t_start = ggml_time_us();
  1912. if (method == 1) {
  1913. quantize_1(src0, src0_gq, M, K);
  1914. quantize_1(src1, src1_gq, N, K);
  1915. }
  1916. if (method == 2) {
  1917. quantize_2(src0, src0_gq, M, K);
  1918. quantize_2(src1, src1_gq, N, K);
  1919. }
  1920. if (method == 3) {
  1921. quantize_3(src0, src0_gq, M, K);
  1922. quantize_3(src1, src1_gq, N, K);
  1923. }
  1924. if (method == 4) {
  1925. quantize_4(src0, src0_gq, M, K);
  1926. quantize_4(src1, src1_gq, N, K);
  1927. }
  1928. if (method == 5) {
  1929. quantize_5(src0, src0_gq, M, K);
  1930. quantize_5(src1, src1_gq, N, K);
  1931. }
  1932. if (method == 6) {
  1933. quantize_6(src0, src0_gq, M, K);
  1934. quantize_6(src1, src1_gq, N, K);
  1935. }
  1936. const int64_t t_end = ggml_time_us();
  1937. printf("convert time: %f ms / method = %d\n", (t_end - t_start) / 1000.0, method);
  1938. }
  1939. for (int i = 0; i < 16; ++i) {
  1940. printf("%f %f\n", src0[i], src1[i]);
  1941. }
  1942. const int nIter = 1;
  1943. const int64_t start = ggml_cycles();
  1944. const int64_t start_us = ggml_time_us();
  1945. double iM = 1.0/M;
  1946. double sum = 0.0f;
  1947. for (int i = 0; i < nIter; i++) {
  1948. if (method == 0) {
  1949. mul_mat_f32_naive(src0, src1, dst, M, N, K);
  1950. }
  1951. if (method == 1) {
  1952. mul_mat_gq_1(src0_gq, src1_gq, dst, M, N, K);
  1953. }
  1954. if (method == 2) {
  1955. mul_mat_gq_2(src0_gq, src1_gq, dst, M, N, K);
  1956. }
  1957. if (method == 3) {
  1958. mul_mat_gq_3(src0_gq, src1_gq, dst, M, N, K);
  1959. }
  1960. if (method == 4) {
  1961. mul_mat_gq_4(src0_gq, src1_gq, dst, M, N, K);
  1962. }
  1963. if (method == 5) {
  1964. mul_mat_gq_5(src0_gq, src1_gq, dst, M, N, K);
  1965. }
  1966. if (method == 6) {
  1967. mul_mat_gq_6(src0_gq, src1_gq, dst, M, N, K);
  1968. }
  1969. }
  1970. for (int i = 0; i < N; i++) {
  1971. sum += dst[i]*iM;
  1972. }
  1973. {
  1974. const int64_t end = ggml_cycles();
  1975. const int64_t end_us = ggml_time_us();
  1976. printf("%s: elapsed ticks: %" PRIu64 "\n", __func__, end - start);
  1977. printf("%s: elapsed us: %d / %f ms\n", __func__, (int)(end_us - start_us), (end_us - start_us) / 1000.0 / nIter);
  1978. }
  1979. #if 0
  1980. // print src0
  1981. printf("src0:\n");
  1982. for (int i = 0; i < M; i++) {
  1983. for (int j = 0; j < K; j++) {
  1984. printf("%4.1f ", src0[i*K+j]);
  1985. }
  1986. printf("\n");
  1987. }
  1988. // print src1
  1989. printf("src1:\n");
  1990. for (int i = 0; i < N; i++) {
  1991. for (int j = 0; j < K; j++) {
  1992. printf("%4.1f ", src1[i*K+j]);
  1993. }
  1994. printf("\n");
  1995. }
  1996. printf("dst:\n");
  1997. for (int i = 0; i < M; i++) {
  1998. for (int j = 0; j < N; j++) {
  1999. printf("%4.1f ", dst[i*N+j]);
  2000. }
  2001. printf("\n");
  2002. }
  2003. #endif
  2004. printf("%f\n", sum);
  2005. free(src0);
  2006. free(src1);
  2007. free(dst);
  2008. if (src0_gq) free(src0_gq);
  2009. if (src1_gq) free(src1_gq);
  2010. return 0;
  2011. }