#import "main-mtl.h"

#import "ggml/ggml.h"

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>

// TODO: couldn't get this to work
//#define GGML_MTL_HEAP
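
// state needed to run the MNIST graph on Metal: the ggml contexts that own the
// tensors on the CPU side, the Metal device/queue/library, the GPU copies of the
// weight and intermediate-result memory, and a small shared buffer used to read
// back the final probabilities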
struct ggml_mtl_context {
    struct ggml_context * ctx_data;
    struct ggml_context * ctx_eval;
    struct ggml_context * ctx_work;

    id<MTLDevice>       device;
    id<MTLCommandQueue> queue;
    id<MTLLibrary>      library;

#ifdef GGML_MTL_HEAP
    id<MTLHeap> heap_data;
    id<MTLHeap> heap_eval;
#else
    id<MTLBuffer> buffer_data;
    id<MTLBuffer> buffer_eval;
#endif

    id<MTLBuffer> out;

    // custom kernels
    id<MTLFunction>             function_add;
    id<MTLComputePipelineState> pipeline_add;

    id<MTLFunction>             function_relu;
    id<MTLComputePipelineState> pipeline_relu;

    id<MTLFunction>             function_soft_max;
    id<MTLComputePipelineState> pipeline_soft_max;
};
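
// the compute kernels below are compiled from source at runtime with
// newLibraryWithSource. kernel_soft_max is specialized via the k_digits function
// constant and subtracts the largest input (clamped at 0) before exponentiating,
// so exp() cannot overflow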
// MSL code
NSString * const msl_library_mnist = @"\
#include <metal_stdlib>                           \n\
using namespace metal;                            \n\
                                                  \n\
#define MAX(x, y) ((x) > (y) ? (x) : (y))         \n\
                                                  \n\
constant int k_digits [[function_constant(0)]];   \n\
                                                  \n\
kernel void kernel_add(                           \n\
        device const float * src0,                \n\
        device const float * src1,                \n\
        device float * dst,                       \n\
        uint gid[[thread_position_in_grid]]) {    \n\
    dst[gid] = src0[gid] + src1[gid];             \n\
}                                                 \n\
                                                  \n\
kernel void kernel_relu(                          \n\
        device const float * src,                 \n\
        device float * dst,                       \n\
        uint gid[[thread_position_in_grid]]) {    \n\
    dst[gid] = max(0.0f, src[gid]);               \n\
}                                                 \n\
                                                  \n\
kernel void kernel_soft_max(                      \n\
        device const float * src,                 \n\
        device float * dst,                       \n\
        uint gid[[thread_position_in_grid]]) {    \n\
    float max = 0.0f;                             \n\
    for (int i = 0; i < k_digits; i++) {          \n\
        max = MAX(max, src[i]);                   \n\
    }                                             \n\
    float sum = 0.0f;                             \n\
    for (int i = 0; i < k_digits; i++) {          \n\
        dst[i] = exp(src[i] - max);               \n\
        sum += dst[i];                            \n\
    }                                             \n\
    for (int i = 0; i < k_digits; i++) {          \n\
        dst[i] /= sum;                            \n\
    }                                             \n\
}                                                 \n\
";
struct ggml_mtl_context * mnist_mtl_init(
        struct ggml_context * ctx_data,
        struct ggml_context * ctx_eval,
        struct ggml_context * ctx_work,
        struct ggml_cgraph  * gf) {
    fprintf(stderr, "%s: allocating\n", __func__);

    struct ggml_mtl_context * ctx = malloc(sizeof(struct ggml_mtl_context));

    ctx->ctx_data = ctx_data;
    ctx->ctx_eval = ctx_eval;
    ctx->ctx_work = ctx_work;

    ctx->device = MTLCreateSystemDefaultDevice();
    ctx->queue  = [ctx->device newCommandQueue];

    // determine if we can use MPS
    if (MPSSupportsMTLDevice(ctx->device)) {
        fprintf(stderr, "%s: using MPS\n", __func__);
    } else {
        fprintf(stderr, "%s: not using MPS\n", __func__);
        GGML_ASSERT(false && "MPS not supported");
    }
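
    // note: MPS is a hard requirement here - the matrix multiplications in the graph
    // are handled by MPSMatrixMultiplication rather than by a custom kernel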
    // compile from source string and show compile log
    {
        NSError * error = nil;

        ctx->library = [ctx->device newLibraryWithSource:msl_library_mnist options:nil error:&error];
        if (error) {
            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
            exit(1);
        }
    }
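
    // the soft_max kernel is specialized at pipeline-creation time: the number of
    // output classes is read from the "probs" tensor of the graph and passed in as
    // the k_digits function constant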
    // load kernels
    {
        const int k_digits = ggml_graph_get_tensor(gf, "probs")->ne[0];

        MTLFunctionConstantValues * constants = [MTLFunctionConstantValues new];
        [constants setConstantValue:&k_digits type:MTLDataTypeInt withName:@"k_digits"];

        ctx->function_add = [ctx->library newFunctionWithName:@"kernel_add"];
        ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil];
        fprintf(stderr, "%s: loaded kernel_add: %p\n", __func__, ctx->pipeline_add);

        ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
        ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
        fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, ctx->pipeline_relu);

        ctx->function_soft_max = [ctx->library newFunctionWithName:@"kernel_soft_max" constantValues:constants error:nil];
        ctx->pipeline_soft_max = [ctx->device newComputePipelineStateWithFunction:ctx->function_soft_max error:nil];
        fprintf(stderr, "%s: loaded kernel_soft_max: %p\n", __func__, ctx->pipeline_soft_max);
    }
#ifdef GGML_MTL_HEAP
    // MTLHeap approach

    // pin ctx_data memory to GPU
    // use MTLStorageModeShared to allow us to initialize the weights from the CPU
    // TODO: how to use MTLStorageModeManaged?
    // TODO: see if we can avoid this copy somehow
    {
        const void * mem_buffer = ggml_get_mem_buffer(ctx_data);
        const size_t mem_size   = ggml_get_mem_size  (ctx_data);

        MTLHeapDescriptor * heap_desc = [MTLHeapDescriptor new];
        heap_desc.storageMode = MTLStorageModeShared;
        heap_desc.size        = mem_size;

        printf("heap_desc.size = %zu\n", mem_size);

        ctx->heap_data = [ctx->device newHeapWithDescriptor:heap_desc];
        [ctx->heap_data setPurgeableState:MTLPurgeableStateNonVolatile]; // TODO: is this needed?
        ctx->heap_data.label = @"heap_data";

        printf("ctx->heap_data.size = %zu\n", [ctx->heap_data size]);

        id<MTLBuffer> buffer = [ctx->heap_data newBufferWithLength:mem_size options:MTLResourceStorageModeShared];
        if (!buffer) {
            fprintf(stderr, "%s: error: failed to allocate buffer\n", __func__);
            exit(1);
        }

        // copy data from CPU to GPU
        memcpy([buffer contents], mem_buffer, mem_size);

        fprintf(stderr, "%s: allocated data heap, size = %zu\n", __func__, mem_size);
    }

    // pin ctx_eval memory to GPU
    // this heap will be used for the intermediate results of the evaluation
    {
        const size_t mem_size = ggml_get_mem_size(ctx_eval);

        MTLHeapDescriptor * heap_desc = [MTLHeapDescriptor new];
        heap_desc.storageMode = MTLStorageModePrivate; // GPU only
        heap_desc.size        = mem_size;

        ctx->heap_eval = [ctx->device newHeapWithDescriptor:heap_desc];
        [ctx->heap_eval setPurgeableState:MTLPurgeableStateNonVolatile]; // TODO: is this needed?

        fprintf(stderr, "%s: allocated eval heap, size = %zu\n", __func__, mem_size);
    }
#else
    // MTLBuffer approach

    // pin ctx_data memory to GPU
    // use MTLStorageModeShared to allow us to initialize the weights from the CPU
    // TODO: how to use MTLStorageModeManaged?
    // TODO: see if we can avoid this copy somehow
    {
        const void * mem_buffer = ggml_get_mem_buffer(ctx_data);
        const size_t mem_size   = ggml_get_mem_size  (ctx_data);

        ctx->buffer_data = [ctx->device newBufferWithBytes:mem_buffer length:mem_size options:MTLResourceStorageModeShared];

        fprintf(stderr, "%s: allocated data buffer, size = %zu\n", __func__, mem_size);
    }

    // pin ctx_eval memory to GPU
    // this buffer will be used for the intermediate results of the evaluation
    {
        const size_t mem_size = ggml_get_mem_size(ctx_eval);

        ctx->buffer_eval = [ctx->device newBufferWithLength:mem_size options:MTLResourceStorageModePrivate];

        fprintf(stderr, "%s: allocated eval buffer, size = %zu\n", __func__, mem_size);
    }
#endif
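
    // `out` is a small CPU-visible staging buffer sized for the last node of the
    // graph; the result is blitted into it at the end of mnist_mtl_eval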
    // allocate buffer for result extraction
    {
        const size_t mem_size = ggml_nbytes(gf->nodes[gf->n_nodes - 1]);

        ctx->out = [ctx->device newBufferWithLength:mem_size options:MTLResourceStorageModeShared];

        fprintf(stderr, "%s: allocated out buffer, size = %zu\n", __func__, mem_size);
    }

    return ctx;
}
void mnist_mtl_free(struct ggml_mtl_context * ctx) {
    fprintf(stderr, "%s: deallocating\n", __func__);

    free(ctx);
}
#ifdef GGML_MTL_HEAP
// make a view of the respective MTL heap
id<MTLBuffer> mnist_mtl_get_buffer_on_heap(struct ggml_mtl_context * ctx, struct ggml_tensor * t) {
    const int64_t offs_data = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_data);
    const int64_t offs_eval = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_eval);

    const bool is_data = (offs_eval < 0) || (offs_data >= 0 && offs_data < offs_eval);

    const size_t t_size = ggml_nbytes(t);
    const size_t t_offs = is_data ? offs_data : offs_eval;

    id<MTLBuffer> result;

    if (is_data) {
        fprintf(stderr, "%s: data tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
        result = [ctx->heap_data newBufferWithLength:t_size options:MTLResourceStorageModeShared  offset:t_offs];
    } else {
        fprintf(stderr, "%s: eval tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
        result = [ctx->heap_eval newBufferWithLength:t_size options:MTLResourceStorageModePrivate offset:t_offs];
    }

    if (result == nil) {
        fprintf(stderr, "%s: error: buffer is nil\n", __func__);
        GGML_ASSERT(false);
    }

    return result;
}
#else
// get data / eval buffer + offset
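// a tensor is considered a "data" tensor if its address lies in the ctx_data arena:
// its offset into the data buffer is non-negative and smaller than its offset into
// the eval buffer (or the eval offset is negative); otherwise it is an "eval" tensor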
id<MTLBuffer> mnist_mtl_get_buffer(struct ggml_mtl_context * ctx, struct ggml_tensor * t, size_t * offs) {
    const int64_t offs_data = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_data);
    const int64_t offs_eval = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_eval);

    const bool is_data = (offs_eval < 0) || (offs_data >= 0 && offs_data < offs_eval);

    const size_t t_size = ggml_nbytes(t);
    const size_t t_offs = is_data ? offs_data : offs_eval;

    id<MTLBuffer> result;

    if (is_data) {
        fprintf(stderr, "%s: data tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
        result = ctx->buffer_data;
    } else {
        fprintf(stderr, "%s: eval tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
        result = ctx->buffer_eval;
    }

    if (result == nil) {
        fprintf(stderr, "%s: error: buffer is nil\n", __func__);
        GGML_ASSERT(false);
    }

    if (offs != nil) {
        *offs = t_offs;
    }

    return result;
}
#endif
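
// run the graph on the GPU:
//  - the input tensor is copied into its slot in its GPU buffer
//  - ADD, RELU and SOFT_MAX are encoded with the custom kernels on a single compute encoder
//  - MUL_MAT is handed to MPSMatrixMultiplication, which encodes its own commands,
//    so any open compute encoder is ended first
//  - the last node is blitted into the shared `out` buffer, the command buffer is
//    committed and waited on, and the most probable digit is selected on the CPU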
int mnist_mtl_eval(
        struct ggml_mtl_context * ctx,
        struct ggml_cgraph      * gf) {
    fprintf(stderr, "%s: evaluating\n", __func__);

    id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer];
    id<MTLComputeCommandEncoder> encoder = nil;

    size_t offs_src0;
    size_t offs_src1;
    size_t offs_dst;

    // copy the input data to the GPU
    {
        struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "input");

        id<MTLBuffer> id_dst = mnist_mtl_get_buffer(ctx, inp, &offs_src0);

        memcpy((char *) id_dst.contents + offs_src0, inp->data, ggml_nbytes(inp));
    }
    for (int i = 0; i < gf->n_nodes; ++i) {
        fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

        switch (gf->nodes[i]->op) {
            case GGML_OP_ADD:
                {
                    if (encoder == nil) {
                        encoder = [command_buffer computeCommandEncoder];
                    }

                    id<MTLBuffer> id_src0 = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[0], &offs_src0);
                    id<MTLBuffer> id_src1 = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[1], &offs_src1);
                    id<MTLBuffer> id_dst  = mnist_mtl_get_buffer(ctx, gf->nodes[i],         &offs_dst);

                    [encoder setComputePipelineState:ctx->pipeline_add];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];

                    const int64_t n = ggml_nelements(gf->nodes[i]);

                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
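                    // one threadgroup of a single thread per element - simple, not
                    // tuned for occupancy, but sufficient for a model this small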
                } break;
            case GGML_OP_UNARY:
                switch (ggml_get_unary_op(gf->nodes[i])) {
                    case GGML_UNARY_OP_RELU:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            id<MTLBuffer> id_src = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[0], &offs_src0);
                            id<MTLBuffer> id_dst = mnist_mtl_get_buffer(ctx, gf->nodes[i],         &offs_dst);

                            [encoder setComputePipelineState:ctx->pipeline_relu];
                            [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst  atIndex:1];

                            const int64_t n = ggml_nelements(gf->nodes[i]);

                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    default:
                        {
                            fprintf(stderr, "%s: node %3d, op = %8s, unary op %d not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), (int) ggml_get_unary_op(gf->nodes[i]));
                            GGML_ASSERT(false);
                            return -1;
                        }
                        break;
                } break;
            case GGML_OP_SOFT_MAX:
                {
#if 0
                    // NOTE: MPSMatrixSoftMax is not working properly, probably there is a bug
                    if (encoder != nil) {
                        [encoder endEncoding];
                        encoder = nil;
                    }

                    // use MPSMatrixSoftMax
                    id<MTLBuffer> id_src = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[0], &offs_src0);
                    id<MTLBuffer> id_dst = mnist_mtl_get_buffer(ctx, gf->nodes[i],         &offs_dst);

                    MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
                        matrixDescriptorWithRows:1 columns:gf->nodes[i]->ne[0] rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];

                    MPSMatrix * mat_src = [[MPSMatrix alloc] initWithBuffer:id_src offset:offs_src0 descriptor:desc];
                    MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst  descriptor:desc];

                    MPSMatrixSoftMax * softmax = [[MPSMatrixSoftMax alloc] initWithDevice:ctx->device];

                    [softmax encodeToCommandBuffer:command_buffer inputMatrix:mat_src resultMatrix:mat_dst];
#else
                    if (encoder == nil) {
                        encoder = [command_buffer computeCommandEncoder];
                    }

                    id<MTLBuffer> id_src = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[0], &offs_src0);
                    id<MTLBuffer> id_dst = mnist_mtl_get_buffer(ctx, gf->nodes[i],         &offs_dst);

                    [encoder setComputePipelineState:ctx->pipeline_soft_max];
                    [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_dst offset:offs_dst  atIndex:1];

                    [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
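                    // the soft_max kernel is launched as a single thread and loops
                    // over all k_digits values serially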
#endif
                } break;
            case GGML_OP_MUL_MAT:
                {
                    if (encoder != nil) {
                        [encoder endEncoding];
                        encoder = nil;
                    }

                    // use MPSMatrixMultiplication
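                    // ggml's mul_mat computes dst = src1 * src0^T (the rows of src0 are
                    // the weight rows), so src1 is the left matrix and src0 is the right
                    // matrix with transposeRight = true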
                    id<MTLBuffer> id_src0 = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[0], &offs_src0);
                    id<MTLBuffer> id_src1 = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[1], &offs_src1);
                    id<MTLBuffer> id_dst  = mnist_mtl_get_buffer(ctx, gf->nodes[i],         &offs_dst);

                    const int64_t ncols0 = gf->nodes[i]->src[0]->ne[0];
                    const int64_t nrows0 = gf->nodes[i]->src[0]->ne[1];

                    const int64_t ncols1 = gf->nodes[i]->src[1]->ne[0];
                    const int64_t nrows1 = gf->nodes[i]->src[1]->ne[1];

                    const int64_t ncols2 = gf->nodes[i]->ne[0];
                    const int64_t nrows2 = gf->nodes[i]->ne[1];

                    GGML_ASSERT(ncols0 == ncols1);

                    MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
                        matrixDescriptorWithRows:nrows0 columns:ncols0 rowBytes:gf->nodes[i]->src[0]->nb[1] dataType:MPSDataTypeFloat32];
                    MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
                        matrixDescriptorWithRows:nrows1 columns:ncols1 rowBytes:gf->nodes[i]->src[1]->nb[1] dataType:MPSDataTypeFloat32];
                    MPSMatrixDescriptor * desc2 = [MPSMatrixDescriptor
                        matrixDescriptorWithRows:nrows2 columns:ncols2 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];

                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0 descriptor:desc0];
                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1 descriptor:desc1];
                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst  descriptor:desc2];

                    MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] initWithDevice:ctx->device
                        transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];

                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
                } break;
            default:
                {
                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                    GGML_ASSERT(false);
                    return -1;
                }
        }
    }
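
    // the last graph node (the probabilities) lives in the GPU-only eval buffer, so it
    // is blitted into the shared `out` buffer before the CPU can read it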
    // extract results from the GPU
    {
        if (encoder != nil) {
            [encoder endEncoding];
            encoder = nil;
        }

        struct ggml_tensor * out = gf->nodes[gf->n_nodes - 1];

        id<MTLBuffer> id_src = mnist_mtl_get_buffer(ctx, out, &offs_src0);
        id<MTLBuffer> id_dst = ctx->out;

        id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder];
        [encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)];
        [encoder_blit endEncoding];
    }

    [command_buffer commit];
    [command_buffer waitUntilCompleted];

    {
        const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];

        fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
    }
    // select the most probable digit
    int result = 0;
    {
        const float * probs = ctx->out.contents;

        float prob = probs[0];

        for (int i = 0; i < 10; ++i) {
            fprintf(stderr, "%s: probs[%2d] = %f\n", __func__, i, probs[i]);

            if (probs[i] > prob) {
                result = i;
                prob   = probs[i];
            }
        }
    }

    return result;
}