// main-cpu.cpp
// Use a pre-generated MNIST compute graph for inference on the CPU
//
// You can generate a compute graph using the "mnist" tool:
//
// $ ./bin/mnist ./models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
//
// This command creates the "mnist.ggml" file, which contains the generated compute graph.
// Now, you can re-use the compute graph with the "mnist-cpu" tool:
//
// $ ./bin/mnist-cpu ./models/mnist/mnist.ggml ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
//
#include "ggml/ggml.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <vector>
  20. #if defined(_MSC_VER)
  21. #pragma warning(disable: 4244 4267) // possible loss of data
  22. #endif
  23. // evaluate the MNIST compute graph
  24. //
  25. // - fname_cgraph: path to the compute graph
  26. // - n_threads: number of threads to use
  27. // - digit: 784 pixel values
  28. //
  29. // returns 0 - 9 prediction
  30. int mnist_eval(
  31. const char * fname_cgraph,
  32. const int n_threads,
  33. std::vector<float> digit) {
  34. // load the compute graph
  35. struct ggml_context * ctx_data = NULL;
  36. struct ggml_context * ctx_eval = NULL;
  37. struct ggml_cgraph gfi = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
  38. // param export/import test
  39. GGML_ASSERT(ggml_graph_get_tensor(&gfi, "fc1_bias")->op_params[0] == int(0xdeadbeef));
  40. // allocate work context
  41. // needed during ggml_graph_compute() to allocate a work tensor
  42. static size_t buf_size = 128ull*1024*1024; // TODO
  43. static void * buf = malloc(buf_size);
  44. struct ggml_init_params params = {
  45. /*.mem_size =*/ buf_size,
  46. /*.mem_buffer =*/ buf,
  47. /*.no_alloc =*/ false,
  48. };
  49. struct ggml_context * ctx_work = ggml_init(params);
  50. struct ggml_tensor * input = ggml_graph_get_tensor(&gfi, "input");
  51. memcpy(input->data, digit.data(), ggml_nbytes(input));
  52. ggml_graph_compute_with_ctx(ctx_work, &gfi, n_threads);
  53. const float * probs_data = ggml_get_data_f32(ggml_graph_get_tensor(&gfi, "probs"));
  54. const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;
  55. ggml_free(ctx_work);
  56. ggml_free(ctx_data);
  57. ggml_free(ctx_eval);
  58. return prediction;
  59. }
  60. int main(int argc, char ** argv) {
  61. srand(time(NULL));
  62. ggml_time_init();
  63. if (argc != 3) {
  64. fprintf(stderr, "Usage: %s models/mnist/mnist.ggml models/mnist/t10k-images.idx3-ubyte\n", argv[0]);
  65. exit(0);
  66. }
  67. uint8_t buf[784];
  68. std::vector<float> digit;
  69. // read a random digit from the test set
  70. {
  71. std::ifstream fin(argv[2], std::ios::binary);
  72. if (!fin) {
  73. fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
  74. return 1;
  75. }
  76. // seek to a random digit: 16-byte header + 28*28 * (random 0 - 10000)
  77. fin.seekg(16 + 784 * (rand() % 10000));
  78. fin.read((char *) &buf, sizeof(buf));
  79. }
  80. // render the digit in ASCII
  81. {
  82. digit.resize(sizeof(buf));
  83. for (int row = 0; row < 28; row++) {
  84. for (int col = 0; col < 28; col++) {
  85. fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
  86. digit[row*28 + col] = ((float)buf[row*28 + col]);
  87. }
  88. fprintf(stderr, "\n");
  89. }
  90. fprintf(stderr, "\n");
  91. }
  92. const int prediction = mnist_eval(argv[1], 1, digit);
  93. fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction);
  94. return 0;
  95. }