test3.c 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
#include "ggml/ggml.h"

#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
  5. bool is_close(float a, float b, float epsilon) {
  6. return fabs(a - b) < epsilon;
  7. }
  8. int main(int argc, const char ** argv) {
  9. struct ggml_init_params params = {
  10. .mem_size = 1024*1024*1024,
  11. .mem_buffer = NULL,
  12. .no_alloc = false,
  13. };
  14. //struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
  15. struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_LBFGS);
  16. opt_params.n_threads = (argc > 1) ? atoi(argv[1]) : 8;
  17. const int NP = 1 << 12;
  18. const int NF = 1 << 8;
  19. struct ggml_context * ctx0 = ggml_init(params);
  20. struct ggml_tensor * F = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, NF, NP);
  21. struct ggml_tensor * l = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NP);
  22. // regularization weight
  23. struct ggml_tensor * lambda = ggml_new_f32(ctx0, 1e-5f);
  24. srand(0);
  25. for (int j = 0; j < NP; j++) {
  26. const float ll = j < NP/2 ? 1.0f : -1.0f;
  27. ((float *)l->data)[j] = ll;
  28. for (int i = 0; i < NF; i++) {
  29. ((float *)F->data)[j*NF + i] = ((ll > 0 && i < NF/2 ? 1.0f : ll < 0 && i >= NF/2 ? 1.0f : 0.0f) + ((float)rand()/(float)RAND_MAX - 0.5f)*0.1f)/(0.5f*NF);
  30. }
  31. }
  32. {
  33. // initial guess
  34. struct ggml_tensor * x = ggml_set_f32(ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NF), 0.0f);
  35. ggml_set_param(ctx0, x);
  36. // f = sum[(fj*x - l)^2]/n + lambda*|x^2|
  37. struct ggml_tensor * f =
  38. ggml_add(ctx0,
  39. ggml_div(ctx0,
  40. ggml_sum(ctx0,
  41. ggml_sqr(ctx0,
  42. ggml_sub(ctx0,
  43. ggml_mul_mat(ctx0, F, x),
  44. l)
  45. )
  46. ),
  47. ggml_new_f32(ctx0, (float)NP)
  48. ),
  49. ggml_mul(ctx0,
  50. ggml_sum(ctx0, ggml_sqr(ctx0, x)),
  51. lambda)
  52. );
  53. enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
  54. GGML_ASSERT(res == GGML_OPT_OK);
  55. // print results
  56. for (int i = 0; i < 16; i++) {
  57. printf("x[%3d] = %g\n", i, ((float *)x->data)[i]);
  58. }
  59. printf("...\n");
  60. for (int i = NF - 16; i < NF; i++) {
  61. printf("x[%3d] = %g\n", i, ((float *)x->data)[i]);
  62. }
  63. printf("\n");
  64. for (int i = 0; i < NF; ++i) {
  65. if (i < NF/2) {
  66. GGML_ASSERT(is_close(((float *)x->data)[i], 1.0f, 1e-2f));
  67. } else {
  68. GGML_ASSERT(is_close(((float *)x->data)[i], -1.0f, 1e-2f));
  69. }
  70. }
  71. }
  72. ggml_free(ctx0);
  73. return 0;
  74. }