main-mtl.cpp

// Use a pre-generated MNIST compute graph for inference on the M1 GPU via MPS
//
// You can generate a compute graph using the "mnist" tool:
//
// $ ./bin/mnist ./models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
//
// This command creates the "mnist.ggml" file, which contains the generated compute graph.
// Now, you can re-use the compute graph on the GPU with the "mnist-mtl" tool:
//
// $ ./bin/mnist-mtl ./models/mnist/mnist.ggml ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
//
#include "ggml/ggml.h"

#include "main-mtl.h"

#include <cmath>
#include <cstdio>
#include <cstdlib> // malloc, rand, srand
#include <cstring>
#include <ctime>
#include <fstream>
#include <vector>
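
// mnist_mtl_init, mnist_mtl_eval and mnist_mtl_free are assumed to be declared in
// "main-mtl.h"; they wrap the Metal/MPS context setup, graph evaluation and teardown
// used below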

// evaluate the MNIST compute graph
//
// - fname_cgraph: path to the compute graph
// - digit:        784 pixel values
//
// returns 0 - 9 prediction
int mnist_eval(
        const char * fname_cgraph,
        std::vector<float> digit
    ) {
    // load the compute graph
    // ctx_data receives the tensor data, ctx_eval the graph definition
    struct ggml_context * ctx_data = NULL;
    struct ggml_context * ctx_eval = NULL;

    struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);

    // allocate work context
    static size_t buf_size = 128ull*1024*1024; // TODO
    static void * buf = malloc(buf_size);

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf,
        /*.no_alloc   =*/ false,
    };
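
    // note: with .no_alloc set to false, tensor data created in ctx_work is
    // allocated from the mem_buffer provided above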
    struct ggml_context * ctx_work = ggml_init(params);

    // this allocates all Metal resources and memory buffers
    auto ctx_mtl = mnist_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);

    int prediction = -1;

    // the loop alternates between the real digit and an all-zero input; with the
    // current bound of 1 iteration only the real digit is evaluated
    for (int i = 0; i < 1; ++i) {
        struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "input");

        if (i % 2 == 0) {
            memcpy(input->data, digit.data(), ggml_nbytes(input));
        } else {
            memset(input->data, 0, ggml_nbytes(input));
        }

        // the actual inference happens here
        prediction = mnist_mtl_eval(ctx_mtl, &gf);
    }

    mnist_mtl_free(ctx_mtl);

    ggml_free(ctx_work);
    ggml_free(ctx_data);
    ggml_free(ctx_eval);

    return prediction;
}

int main(int argc, char ** argv) {
    srand(time(NULL));
    ggml_time_init();

    if (argc != 3) {
        fprintf(stderr, "Usage: %s models/mnist/mnist.ggml models/mnist/t10k-images.idx3-ubyte\n", argv[0]);
        exit(0);
    }

    uint8_t buf[784];
    std::vector<float> digit;

    // read a random digit from the test set
    {
        std::ifstream fin(argv[2], std::ios::binary);
        if (!fin) {
            fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
            return 1;
        }

        // seek to a random digit: 16-byte header + 784 bytes per image * (random index in [0, 10000))
        fin.seekg(16 + 784 * (rand() % 10000));
        fin.read((char *) buf, sizeof(buf));
    }
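
    // buf now holds one 28x28 MNIST test image as 784 raw grayscale bytes
    // (0 = background, 255 = full intensity)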

    // render the digit in ASCII and convert the pixels to float for the graph input
    {
        digit.resize(sizeof(buf));

        for (int row = 0; row < 28; row++) {
            for (int col = 0; col < 28; col++) {
                fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
                digit[row*28 + col] = ((float)buf[row*28 + col]);
            }

            fprintf(stderr, "\n");
        }

        fprintf(stderr, "\n");
    }

    const int prediction = mnist_eval(argv[1], digit);

    fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction);

    return 0;
}