// main-cnn.cpp

#include "ggml/ggml.h"
#include "common.h"

#include <algorithm>
#include <cmath>
#include <cstdint> // uint8_t
#include <cstdio>
#include <cstdlib> // malloc
#include <cstring>
#include <ctime>
#include <fstream>
#include <string>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
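
// Weights of the small CNN: two convolution layers (kernel + bias each) and a
// final dense layer, plus the ggml context that owns the tensor data loaded
// from the GGUF file.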
struct mnist_model {
    struct ggml_tensor * conv2d_1_kernel;
    struct ggml_tensor * conv2d_1_bias;

    struct ggml_tensor * conv2d_2_kernel;
    struct ggml_tensor * conv2d_2_bias;

    struct ggml_tensor * dense_weight;
    struct ggml_tensor * dense_bias;

    struct ggml_context * ctx;
};
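
// Load the model weights from a GGUF file. With no_alloc == false,
// gguf_init_from_file() also creates a ggml context (stored in model.ctx) that
// holds the tensor data, so the individual tensors can then be looked up by name.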
bool mnist_model_load(const std::string & fname, mnist_model & model) {
    struct gguf_init_params params = {
        /*.no_alloc =*/ false,
        /*.ctx      =*/ &model.ctx,
    };

    gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
    if (!ctx) {
        fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
        return false;
    }

    model.conv2d_1_kernel = ggml_get_tensor(model.ctx, "kernel1");
    model.conv2d_1_bias   = ggml_get_tensor(model.ctx, "bias1");
    model.conv2d_2_kernel = ggml_get_tensor(model.ctx, "kernel2");
    model.conv2d_2_bias   = ggml_get_tensor(model.ctx, "bias2");
    model.dense_weight    = ggml_get_tensor(model.ctx, "dense_w");
    model.dense_bias      = ggml_get_tensor(model.ctx, "dense_b");

    return true;
}
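
// Evaluate the network on a single 28x28 grayscale digit (given as 784 floats
// in [0, 1]). Builds the ggml compute graph, runs it with n_threads, and
// returns the predicted class (0-9). If fname_cgraph is non-null, the compute
// graph is also exported to that file.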
int mnist_eval(
        const mnist_model & model,
        const int n_threads,
        std::vector<float> digit,
        const char * fname_cgraph
        ) {
    static size_t buf_size = 100000 * sizeof(float) * 4;
    static void * buf = malloc(buf_size);

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf,
        /*.no_alloc   =*/ false,
    };

    struct ggml_context * ctx0 = ggml_init(params);
    struct ggml_cgraph gf = {};

    struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 28, 28, 1, 1);
    memcpy(input->data, digit.data(), ggml_nbytes(input));
    ggml_set_name(input, "input");
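
    // First convolution block: ggml_conv_2d(ctx, kernel, data, s0, s1, p0, p1, d0, d1)
    // is called with stride 1, no padding and dilation 1, so the 3x3 kernel shrinks
    // the spatial dimensions from 28x28 to 26x26 (see the shape comments below).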
    ggml_tensor * cur = ggml_conv_2d(ctx0, model.conv2d_1_kernel, input, 1, 1, 0, 0, 1, 1);
    cur = ggml_add(ctx0, cur, model.conv2d_1_bias);
    cur = ggml_relu(ctx0, cur);
    // Output shape after Conv2D: (26 26 32 1)
    cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
    // Output shape after MaxPooling2D: (13 13 32 1)

    cur = ggml_conv_2d(ctx0, model.conv2d_2_kernel, cur, 1, 1, 0, 0, 1, 1);
    cur = ggml_add(ctx0, cur, model.conv2d_2_bias);
    cur = ggml_relu(ctx0, cur);
    // Output shape after Conv2D: (11 11 64 1)
    cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
    // Output shape after MaxPooling2D: (5 5 64 1)
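
    // Flatten: permute + cont + reshape lay the 5x5x64 activations out as a
    // contiguous 1600-element vector with the channel index varying fastest,
    // which presumably matches the Flatten order of the Keras-style model the
    // dense weights were exported from.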
    cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
    // Output shape after permute: (64 5 5 1)
    cur = ggml_reshape_2d(ctx0, cur, 1600, 1);

    // Final Dense layer
    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.dense_weight, cur), model.dense_bias);
    ggml_tensor * probs = ggml_soft_max(ctx0, cur);
    ggml_set_name(probs, "probs");
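
    // Build the compute graph starting from the "probs" output tensor and run it.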
    ggml_build_forward_expand(&gf, probs);
    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

    //ggml_graph_print(&gf);
    ggml_graph_dump_dot(&gf, NULL, "mnist-cnn.dot");

    if (fname_cgraph) {
        // export the compute graph for later use
        // see the "mnist-cpu" example
        ggml_graph_export(&gf, fname_cgraph);

        fprintf(stderr, "%s: exported compute graph to '%s'\n", __func__, fname_cgraph);
    }
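
    // The prediction is the index of the largest of the 10 class probabilities.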
    const float * probs_data = ggml_get_data_f32(probs);
    const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;

    ggml_free(ctx0);

    return prediction;
}
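
// Load the model from a GGUF file (argv[1]), pick a random image from the
// MNIST test set (argv[2]), print it as ASCII art and run inference on it.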
int main(int argc, char ** argv) {
    srand(time(NULL));
    ggml_time_init();

    if (argc != 3) {
        fprintf(stderr, "Usage: %s models/mnist/mnist-cnn.gguf models/mnist/t10k-images.idx3-ubyte\n", argv[0]);
        exit(0);
    }

    uint8_t buf[784];
    mnist_model model;
    std::vector<float> digit;

    // load the model
    {
        const int64_t t_start_us = ggml_time_us();

        if (!mnist_model_load(argv[1], model)) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, argv[1]);
            return 1;
        }

        const int64_t t_load_us = ggml_time_us() - t_start_us;

        fprintf(stdout, "%s: loaded model in %8.2f ms\n", __func__, t_load_us / 1000.0f);
    }
    // read a random digit from the test set
    {
        std::ifstream fin(argv[2], std::ios::binary);
        if (!fin) {
            fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
            return 1;
        }

        // seek to a random test image: 16-byte IDX header + 28*28 bytes per image * random index in [0, 10000)
        fin.seekg(16 + 784 * (rand() % 10000));
        fin.read((char *) &buf, sizeof(buf));
    }
    // render the digit in ASCII and convert it to normalized floats in [0, 1]
    {
        digit.resize(sizeof(buf));

        for (int row = 0; row < 28; row++) {
            for (int col = 0; col < 28; col++) {
                fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
                digit[row*28 + col] = ((float)buf[row*28 + col] / 255.0f);
            }

            fprintf(stderr, "\n");
        }

        fprintf(stderr, "\n");
    }
    const int prediction = mnist_eval(model, 1, digit, nullptr);

    fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction);

    ggml_free(model.ctx);

    return 0;
}