// test3.zig

const std = @import("std");
const Thread = std.Thread;

const c = @cImport({
    @cInclude("stdlib.h");
    @cInclude("ggml/ggml.h");
});

// approximate equality for floats, used by the checks at the end of main()
fn is_close(a: f32, b: f32, epsilon: f32) bool {
    return std.math.fabs(a - b) < epsilon;
}
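
// This test sets up a least-squares problem with an L2 penalty (ridge
// regression) as a ggml compute graph and minimizes it with ggml's
// built-in L-BFGS optimizer:
//
//     f(x) = sum_j (F_j . x - l_j)^2 / NP + lambda * |x|^2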
pub fn main() !void {
    // 128 MB arena from which every tensor in this test is allocated
    const params = .{
        .mem_size = 128*1024*1024,
        .mem_buffer = null,
        .no_alloc = false,
    };

    var opt_params = c.ggml_opt_default_params(c.GGML_OPT_LBFGS);

    // one optimizer thread per logical CPU
    const nthreads = try Thread.getCpuCount();
    opt_params.n_threads = @intCast(nthreads);

    const NP = 1 << 12; // number of data points
    const NF = 1 << 8; // number of features per point

    const ctx0 = c.ggml_init(params);
    defer c.ggml_free(ctx0);
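
    // F holds one NF-dimensional feature row per data point; l holds
    // the NP target labels.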
    const F = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_F32, NF, NP);
    const l = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NP);

    // regularization weight
    const lambda = c.ggml_new_f32(ctx0, 1e-5);

    // fixed seed so the random dataset is reproducible
    c.srand(0);
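
    // Synthetic, linearly separable dataset: the first NP/2 points get
    // label +1 with their feature mass on the first NF/2 features; the
    // remaining points get label -1 with mass on the second half. A
    // little uniform noise is mixed in, and rows are scaled so that
    // F_j . x = +/-1 at the solution the assertions below expect.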
    const l_data_pointer: [*]f32 = @ptrCast(@alignCast(l.*.data));
    const f_data_pointer: [*]f32 = @ptrCast(@alignCast(F.*.data));
    for (0..NP) |j| {
        const ll = if (j < NP/2) @as(f32, 1.0) else @as(f32, -1.0);
        l_data_pointer[j] = ll;

        for (0..NF) |i| {
            const c_rand: f32 = @floatFromInt(c.rand());
            f_data_pointer[j*NF + i] =
                ((if (ll > 0 and i < NF/2) @as(f32, 1.0) else
                    if (ll < 0 and i >= NF/2) @as(f32, 1.0) else @as(f32, 0.0)) +
                    (c_rand/c.RAND_MAX - 0.5) * 0.1) / (0.5 * NF);
        }
    }
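
    // Build the loss as a ggml graph over the parameter tensor x and
    // hand it to ggml_opt, which constructs the backward graph and runs
    // the optimizer configured in opt_params.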
    {
        // initial guess
        const x = c.ggml_set_f32(c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NF), 0.0);

        // mark x as a trainable parameter
        c.ggml_set_param(ctx0, x);

        // f = sum[(fj*x - l)^2]/n + lambda*|x^2|
        const f =
            c.ggml_add(ctx0,
                c.ggml_div(ctx0,
                    c.ggml_sum(ctx0,
                        c.ggml_sqr(ctx0,
                            c.ggml_sub(ctx0,
                                c.ggml_mul_mat(ctx0, F, x),
                                l))),
                    c.ggml_new_f32(ctx0, @as(f32, NP))),
                c.ggml_mul(ctx0,
                    c.ggml_sum(ctx0, c.ggml_sqr(ctx0, x)),
                    lambda));

        const res = c.ggml_opt(null, opt_params, f);
        try std.testing.expect(res == c.GGML_OPT_OK);

        const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data));

        // print results
        for (0..16) |i| {
            std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] });
        }
        std.debug.print("...\n", .{});
        for (NF - 16..NF) |i| {
            std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] });
        }
        std.debug.print("\n", .{});

        // the recovered weights should be close to +1 for the first half
        // of the features and close to -1 for the second half
        for (0..NF) |i| {
            if (i < NF/2) {
                try std.testing.expect(is_close(x_data_pointer[i], 1.0, 1e-2));
            } else {
                try std.testing.expect(is_close(x_data_pointer[i], -1.0, 1e-2));
            }
        }
    }

    // block on stdin before exiting, presumably to keep the process
    // around for external inspection
    _ = try std.io.getStdIn().reader().readByte();
}