# example_test_all_quants.py
  1. from ggml import ffi, lib
  2. from ggml.utils import init, numpy, copy
  3. import numpy as np
  4. from math import pi, cos, sin, ceil
  5. import matplotlib.pyplot as plt
  6. ctx = init(mem_size=100*1024*1024) # Will be auto-GC'd
  7. n = 256
  8. orig = np.array([
  9. [
  10. cos(j * 2 * pi / n) * (sin(i * 2 * pi / n))
  11. for j in range(n)
  12. ]
  13. for i in range(n)
  14. ], np.float32)
  15. orig_tensor = lib.ggml_new_tensor_2d(ctx, lib.GGML_TYPE_F32, n, n)
  16. copy(orig, orig_tensor)
  17. quants = [
  18. type for type in range(lib.GGML_TYPE_COUNT)
  19. if lib.ggml_is_quantized(type) and
  20. type not in [lib.GGML_TYPE_Q8_1, lib.GGML_TYPE_Q8_K] # Apparently not supported
  21. ]
  22. # quants = [lib.GGML_TYPE_Q2_K] # Test a single one
  23. def get_name(type):
  24. name = lib.ggml_type_name(type)
  25. return ffi.string(name).decode('utf-8') if name else '?'
  26. quants.sort(key=get_name)
  27. quants.insert(0, None)
  28. print(quants)
  29. ncols=4
  30. nrows = ceil(len(quants) / ncols)
  31. plt.figure(figsize=(ncols * 5, nrows * 5), layout='tight')
  32. for i, type in enumerate(quants):
  33. plt.subplot(nrows, ncols, i + 1)
  34. try:
  35. if type == None:
  36. plt.title('Original')
  37. plt.imshow(orig)
  38. else:
  39. quantized_tensor = lib.ggml_new_tensor_2d(ctx, type, n, n)
  40. copy(orig_tensor, quantized_tensor)
  41. quantized = numpy(quantized_tensor, allow_copy=True)
  42. d = quantized - orig
  43. results = {
  44. "l2": np.linalg.norm(d, 2),
  45. "linf": np.linalg.norm(d, np.inf),
  46. "compression":
  47. round(lib.ggml_nbytes(orig_tensor) /
  48. lib.ggml_nbytes(quantized_tensor), 1)
  49. }
  50. name = get_name(type)
  51. print(f'{name}: {results}')
  52. plt.title(f'{name} ({results["compression"]}x smaller)')
  53. plt.imshow(quantized, interpolation='nearest')
  54. except Exception as e:
  55. print(f'Error: {e}')
  56. plt.show()