test2.zig

const std = @import("std");
const Thread = std.Thread;
const c = @cImport({
    @cInclude("ggml/ggml.h");
});
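
// test2.zig exercises ggml's L-BFGS optimizer (GGML_OPT_LBFGS) on a few small
// objective functions: a least-squares line fit, an absolute-error line fit,
// a simple quadratic bowl, and the Booth function. Note: this assumes the
// ggml header and library are visible to the Zig build; the include path and
// link settings are not shown here.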

fn is_close(a: f32, b: f32, epsilon: f32) bool {
    return std.math.fabs(a - b) < epsilon;
}

pub fn main() !void {
    const params = .{
        .mem_size = 128 * 1024 * 1024,
        .mem_buffer = null,
        .no_alloc = false,
    };

    var opt_params = c.ggml_opt_default_params(c.GGML_OPT_LBFGS);

    const nthreads = try Thread.getCpuCount();
    opt_params.n_threads = @intCast(nthreads);
    std.debug.print("test2: n_threads:{}\n", .{opt_params.n_threads});

    const xi = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 };
    const yi = [_]f32{ 15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0, 105.0 };

    const n = xi.len;

    const ctx0 = c.ggml_init(params);
    defer c.ggml_free(ctx0);

    const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n);
    const y = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n);

    for (0..n) |i| {
        const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data));
        x_data_pointer[i] = xi[i];
        const y_data_pointer: [*]f32 = @ptrCast(@alignCast(y.*.data));
        y_data_pointer[i] = yi[i];
    }
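
    // The samples above lie exactly on the line y = 10*x + 5, so the two
    // line-fit blocks below should recover t0 ~ 5 (intercept) and t1 ~ 10
    // (slope); the expect() calls at the end of each block check this.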

    {
        const t0 = c.ggml_new_f32(ctx0, 0.0);
        const t1 = c.ggml_new_f32(ctx0, 0.0);

        // initialize auto-diff parameters:
        _ = c.ggml_set_param(ctx0, t0);
        _ = c.ggml_set_param(ctx0, t1);

        // f = sum_i[(t0 + t1*x_i - y_i)^2]/(2n)
        const f =
            c.ggml_div(ctx0,
                c.ggml_sum(ctx0,
                    c.ggml_sqr(ctx0,
                        c.ggml_sub(ctx0,
                            c.ggml_add(ctx0,
                                c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)),
                                c.ggml_repeat(ctx0, t0, x)),
                            y)
                    )
                ),
                c.ggml_new_f32(ctx0, @as(f32, 2.0) * n));

        const res = c.ggml_opt(null, opt_params, f);

        std.debug.print("t0 = {d:.6}\n", .{c.ggml_get_f32_1d(t0, 0)});
        std.debug.print("t1 = {d:.6}\n", .{c.ggml_get_f32_1d(t1, 0)});

        try std.testing.expect(res == c.GGML_OPT_OK);
        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-3));
        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-3));
    }
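
    // Same line fit, but with an L1 (absolute error) objective and the
    // parameters deliberately initialized away from the optimum. The
    // tolerance is looser here (1e-2), presumably because the absolute-value
    // loss is not smooth at the minimum.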

    {
        const t0 = c.ggml_new_f32(ctx0, -1.0);
        const t1 = c.ggml_new_f32(ctx0, 9.0);

        _ = c.ggml_set_param(ctx0, t0);
        _ = c.ggml_set_param(ctx0, t1);

        // f = 0.5*sum_i[abs(t0 + t1*x_i - y_i)]/n
        const f =
            c.ggml_mul(ctx0,
                c.ggml_new_f32(ctx0, @as(f32, 1.0) / (2 * n)),
                c.ggml_sum(ctx0,
                    c.ggml_abs(ctx0,
                        c.ggml_sub(ctx0,
                            c.ggml_add(ctx0,
                                c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)),
                                c.ggml_repeat(ctx0, t0, x)),
                            y)
                    )
                )
            );

        const res = c.ggml_opt(null, opt_params, f);

        try std.testing.expect(res == c.GGML_OPT_OK);
        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-2));
        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-2));
    }
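
    // Minimize the convex bowl f = t0^2 + t1^2 starting from (5, -4); both
    // parameters and the objective value should end up near zero.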

    {
        const t0 = c.ggml_new_f32(ctx0, 5.0);
        const t1 = c.ggml_new_f32(ctx0, -4.0);

        _ = c.ggml_set_param(ctx0, t0);
        _ = c.ggml_set_param(ctx0, t1);

        // f = t0^2 + t1^2
        const f =
            c.ggml_add(ctx0,
                c.ggml_sqr(ctx0, t0),
                c.ggml_sqr(ctx0, t1)
            );

        const res = c.ggml_opt(null, opt_params, f);

        try std.testing.expect(res == c.GGML_OPT_OK);
        try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3));
        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 0.0, 1e-3));
        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 0.0, 1e-3));
    }

    /////////////////////////////////////////
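
    // The Booth function: its global minimum is f(1, 3) = 0, which is what
    // the checks below expect the optimizer to reach from (-7, 8).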

    {
        const t0 = c.ggml_new_f32(ctx0, -7.0);
        const t1 = c.ggml_new_f32(ctx0, 8.0);

        _ = c.ggml_set_param(ctx0, t0);
        _ = c.ggml_set_param(ctx0, t1);

        // f = (t0 + 2*t1 - 7)^2 + (2*t0 + t1 - 5)^2
        const f =
            c.ggml_add(ctx0,
                c.ggml_sqr(ctx0,
                    c.ggml_sub(ctx0,
                        c.ggml_add(ctx0,
                            t0,
                            c.ggml_mul(ctx0, t1, c.ggml_new_f32(ctx0, 2.0))),
                        c.ggml_new_f32(ctx0, 7.0)
                    )
                ),
                c.ggml_sqr(ctx0,
                    c.ggml_sub(ctx0,
                        c.ggml_add(ctx0,
                            c.ggml_mul(ctx0, t0, c.ggml_new_f32(ctx0, 2.0)),
                            t1),
                        c.ggml_new_f32(ctx0, 5.0)
                    )
                )
            );

        const res = c.ggml_opt(null, opt_params, f);

        try std.testing.expect(res == c.GGML_OPT_OK);
        try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3));
        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 1.0, 1e-3));
        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 3.0, 1e-3));
    }
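
    // Block on stdin before exiting, presumably so the printed output stays
    // visible when the program is launched from a window that would close
    // immediately on exit.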
    _ = try std.io.getStdIn().reader().readByte();
}