main.cpp

  1. #define _USE_MATH_DEFINES // for M_PI
  2. #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
  3. #include "ggml.h"
  4. #include "ggml-alloc.h"
  5. #define STB_IMAGE_IMPLEMENTATION
  6. #include "stb_image.h"
  7. #define STB_IMAGE_WRITE_IMPLEMENTATION
  8. #include "stb_image_write.h"
  9. #include <cassert>
  10. #include <cmath>
  11. #include <cstddef>
  12. #include <cstdio>
  13. #include <cstring>
  14. #include <fstream>
  15. #include <map>
  16. #include <string>
  17. #include <vector>
  18. #include <thread>
  #include <algorithm> // std::min, std::max
  #include <cstdint> // int32_t, int64_t, uint8_t
  19. #if defined(_MSC_VER)
  20. #pragma warning(disable: 4244 4267) // possible loss of data
  21. #endif
  22. // default hparams (ViT-B SAM)
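  // For these defaults: head dim = 768/12 = 64 and the image embedding grid is 1024/16 = 64x64 patches.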
  23. struct sam_hparams {
  24. int32_t n_enc_state = 768;
  25. int32_t n_enc_layer = 12;
  26. int32_t n_enc_head = 12;
  27. int32_t n_enc_out_chans = 256;
  28. int32_t n_pt_embd = 4;
  29. int32_t n_dec_heads = 8;
  30. int32_t ftype = 1;
  31. float mask_threshold = 0.f;
  32. float iou_threshold = 0.88f;
  33. float stability_score_threshold = 0.95f;
  34. float stability_score_offset = 1.0f;
  35. float eps = 1e-6f;
  36. float eps_decoder_transformer = 1e-5f;
  37. int32_t n_enc_head_dim() const { return n_enc_state / n_enc_head; }
  38. int32_t n_img_size() const { return 1024; }
  39. int32_t n_window_size() const { return 14; }
  40. int32_t n_patch_size() const { return 16; }
  41. int32_t n_img_embd() const { return n_img_size() / n_patch_size(); }
  42. std::vector<int32_t> global_attn_indices() const {
  43. switch (n_enc_state) {
  44. case 768: return { 2, 5, 8, 11 }; // ViT-B
  45. case 1024: return { 5, 11, 17, 23 }; // ViT-L
  46. case 1280: return { 7, 15, 23, 31 }; // ViT-H
  47. default:
  48. {
  49. fprintf(stderr, "%s: unsupported n_enc_state = %d\n", __func__, n_enc_state);
  50. } break;
  51. };
  52. return {};
  53. }
  54. bool is_global_attn(int32_t layer) const {
  55. const auto indices = global_attn_indices();
  56. for (const auto & idx : indices) {
  57. if (layer == idx) {
  58. return true;
  59. }
  60. }
  61. return false;
  62. }
  63. };
  64. struct sam_layer_enc {
  65. struct ggml_tensor * norm1_w;
  66. struct ggml_tensor * norm1_b;
  67. struct ggml_tensor * rel_pos_w;
  68. struct ggml_tensor * rel_pos_h;
  69. struct ggml_tensor * qkv_w;
  70. struct ggml_tensor * qkv_b;
  71. struct ggml_tensor * proj_w;
  72. struct ggml_tensor * proj_b;
  73. struct ggml_tensor * norm2_w;
  74. struct ggml_tensor * norm2_b;
  75. struct ggml_tensor * mlp_lin1_w;
  76. struct ggml_tensor * mlp_lin1_b;
  77. struct ggml_tensor * mlp_lin2_w;
  78. struct ggml_tensor * mlp_lin2_b;
  79. };
  80. struct sam_encoder_image {
  81. struct ggml_tensor * pe;
  82. struct ggml_tensor * proj_w;
  83. struct ggml_tensor * proj_b;
  84. struct ggml_tensor * neck_conv_0;
  85. struct ggml_tensor * neck_norm_0_w;
  86. struct ggml_tensor * neck_norm_0_b;
  87. struct ggml_tensor * neck_conv_1;
  88. struct ggml_tensor * neck_norm_1_w;
  89. struct ggml_tensor * neck_norm_1_b;
  90. std::vector<sam_layer_enc> layers;
  91. };
  92. struct sam_encoder_prompt {
  93. struct ggml_tensor * pe;
  94. struct ggml_tensor * not_a_pt_embd_w;
  95. std::vector<struct ggml_tensor *> pt_embd;
  96. struct ggml_tensor * no_mask_embd_w;
  97. //std::vector<struct ggml_tensor *> mask_down_w;
  98. //std::vector<struct ggml_tensor *> mask_down_b;
  99. };
  100. struct sam_layer_dec_transformer_attn {
  101. // q_proj
  102. struct ggml_tensor * q_w;
  103. struct ggml_tensor * q_b;
  104. // k_proj
  105. struct ggml_tensor * k_w;
  106. struct ggml_tensor * k_b;
  107. // v_proj
  108. struct ggml_tensor * v_w;
  109. struct ggml_tensor * v_b;
  110. // out_proj
  111. struct ggml_tensor * out_w;
  112. struct ggml_tensor * out_b;
  113. };
  114. struct sam_layer_dec_transformer {
  115. sam_layer_dec_transformer_attn self_attn;
  116. // norm1
  117. struct ggml_tensor * norm1_w;
  118. struct ggml_tensor * norm1_b;
  119. sam_layer_dec_transformer_attn cross_attn_token_to_img;
  120. // norm2
  121. struct ggml_tensor * norm2_w;
  122. struct ggml_tensor * norm2_b;
  123. // mlp.lin1
  124. struct ggml_tensor * mlp_lin1_w;
  125. struct ggml_tensor * mlp_lin1_b;
  126. // mlp.lin2
  127. struct ggml_tensor * mlp_lin2_w;
  128. struct ggml_tensor * mlp_lin2_b;
  129. // norm3
  130. struct ggml_tensor * norm3_w;
  131. struct ggml_tensor * norm3_b;
  132. // norm4
  133. struct ggml_tensor * norm4_w;
  134. struct ggml_tensor * norm4_b;
  135. sam_layer_dec_transformer_attn cross_attn_img_to_token;
  136. };
  137. struct sam_layer_dec_output_hypernet_mlps {
  138. // mlps_*.layers.0
  139. struct ggml_tensor * w_0;
  140. struct ggml_tensor * b_0;
  141. // mlps_*.layers.1
  142. struct ggml_tensor * w_1;
  143. struct ggml_tensor * b_1;
  144. // mlps_*.layers.2
  145. struct ggml_tensor * w_2;
  146. struct ggml_tensor * b_2;
  147. };
  148. struct sam_decoder_mask {
  149. std::vector<sam_layer_dec_transformer> transformer_layers;
  150. // transformer.final_attn_token_to_image
  151. sam_layer_dec_transformer_attn transformer_final_attn_token_to_img;
  152. // transformer.norm_final
  153. struct ggml_tensor * transformer_norm_final_w;
  154. struct ggml_tensor * transformer_norm_final_b;
  155. // output_upscaling.0
  156. struct ggml_tensor * output_upscaling_0_w;
  157. struct ggml_tensor * output_upscaling_0_b;
  158. // output_upscaling.1
  159. struct ggml_tensor * output_upscaling_1_w;
  160. struct ggml_tensor * output_upscaling_1_b;
  161. // output_upscaling.3
  162. struct ggml_tensor * output_upscaling_3_w;
  163. struct ggml_tensor * output_upscaling_3_b;
  164. // output_hypernetworks_mlps
  165. std::vector<sam_layer_dec_output_hypernet_mlps> output_hypernet_mlps;
  166. // iou_prediction_head.0
  167. struct ggml_tensor * iou_prediction_head_0_w;
  168. struct ggml_tensor * iou_prediction_head_0_b;
  169. // iou_prediction_head.1
  170. struct ggml_tensor * iou_prediction_head_1_w;
  171. struct ggml_tensor * iou_prediction_head_1_b;
  172. // iou_prediction_head.2
  173. struct ggml_tensor * iou_prediction_head_2_w;
  174. struct ggml_tensor * iou_prediction_head_2_b;
  175. // iou_token.weight
  176. struct ggml_tensor * iou_token_w;
  177. // mask_tokens.weight
  178. struct ggml_tensor * mask_tokens_w;
  179. };
  180. struct sam_state {
  181. struct ggml_tensor * embd_img;
  182. struct ggml_tensor * low_res_masks;
  183. struct ggml_tensor * iou_predictions;
  184. //struct ggml_tensor * tmp_save = {};
  185. struct ggml_context * ctx;
  186. // buffer for `ggml_graph_plan.work_data`
  187. std::vector<uint8_t> work_buffer;
  188. // buffers to evaluate the model
  189. std::vector<uint8_t> buf_alloc_img_enc;
  190. std::vector<uint8_t> buf_compute_img_enc;
  191. std::vector<uint8_t> buf_alloc_fast;
  192. std::vector<uint8_t> buf_compute_fast;
  193. struct ggml_allocr * allocr = {};
  194. };
  195. // void save_tensor(sam_state& state, struct ggml_tensor * t, struct ggml_cgraph * gf) {
  196. // if (!state.tmp_save) {
  197. // state.tmp_save = ggml_new_tensor(state.ctx, t->type, t->n_dims, t->ne);
  198. // }
  199. // struct ggml_tensor * tmp0 = ggml_cpy(state.ctx, t, state.tmp_save);
  200. // ggml_build_forward_expand(gf, tmp0);
  201. // }
  202. struct sam_model {
  203. sam_hparams hparams;
  204. sam_encoder_image enc_img;
  205. sam_encoder_prompt enc_prompt;
  206. sam_decoder_mask dec;
  207. //
  208. struct ggml_context * ctx;
  209. std::map<std::string, struct ggml_tensor *> tensors;
  210. };
  211. struct sam_point {
  212. float x;
  213. float y;
  214. };
  215. // RGB uint8 image
  216. struct sam_image_u8 {
  217. int nx;
  218. int ny;
  219. std::vector<uint8_t> data;
  220. };
  221. // RGB float32 image
  222. // Memory layout: RGBRGBRGB...
  223. struct sam_image_f32 {
  224. int nx;
  225. int ny;
  226. std::vector<float> data;
  227. };
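  // Debug helper: prints a tensor's dims, its first and last n f32 values, and the sum over all elements.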
  228. void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
  229. printf("%s\n", title);
  230. float * data = (float *)t->data;
  231. printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
  232. printf("First & Last %d elements:\n", n);
  233. for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
  234. printf("%.5f ", data[i]);
  235. if (i != 0 && i % t->ne[0] == 0) {
  236. printf("\n");
  237. }
  238. }
  239. printf("\n");
  240. for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
  241. printf("%.5f ", data[ggml_nelements(t) - n + i]);
  242. if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) {
  243. printf("\n");
  244. }
  245. }
  246. printf("\n");
  247. double sum = 0.0;
  248. for (int i = 0; i < ggml_nelements(t); i++) {
  249. sum += data[i];
  250. }
  251. printf("sum: %f\n\n", sum);
  252. }
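  // Detach a tensor from the graph by clearing its op and source links, so it is treated as a plain leaf/input afterwards.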
  253. static void ggml_disconnect_node_from_graph(ggml_tensor * t) {
  254. t->op = GGML_OP_NONE;
  255. for (int i = 0; i < GGML_MAX_SRC; i++) {
  256. t->src[i] = NULL;
  257. }
  258. }
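  // Plan the graph, grow the shared work buffer if the plan needs scratch memory, and run it with n_threads.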
  259. static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
  260. struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
  261. if (plan.work_size > 0) {
  262. buf.resize(plan.work_size);
  263. plan.work_data = buf.data();
  264. }
  265. ggml_graph_compute(graph, &plan);
  266. }
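  // Custom element-wise sin/cos ops (used by the positional encodings below); each of the nth threads handles the slice [ie0, ie1) of the flattened tensor.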
  267. static void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
  268. GGML_ASSERT(userdata == NULL);
  269. GGML_ASSERT(ggml_are_same_shape(dst, src));
  270. GGML_ASSERT(ggml_is_contiguous(dst));
  271. GGML_ASSERT(ggml_is_contiguous(src));
  272. const float * src_data = ggml_get_data_f32(src);
  273. float * dst_data = ggml_get_data_f32(dst);
  274. const int ne = (int)ggml_nelements(dst);
  275. const int dr = (ne + nth - 1) / nth;
  276. const int ie0 = dr * ith;
  277. const int ie1 = std::min(ie0 + dr, ne);
  278. for (int i = ie0; i < ie1; ++i) {
  279. dst_data[i] = sinf(src_data[i]);
  280. }
  281. }
  282. static void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
  283. GGML_ASSERT(userdata == NULL);
  284. GGML_ASSERT(ggml_are_same_shape(dst, src));
  285. GGML_ASSERT(ggml_is_contiguous(dst));
  286. GGML_ASSERT(ggml_is_contiguous(src));
  287. const float * src_data = ggml_get_data_f32(src);
  288. float * dst_data = ggml_get_data_f32(dst);
  289. const int ne = (int)ggml_nelements(dst);
  290. const int dr = (ne + nth - 1) / nth;
  291. const int ie0 = dr * ith;
  292. const int ie1 = std::min(ie0 + dr, ne);
  293. for (int i = ie0; i < ie1; ++i) {
  294. dst_data[i] = cosf(src_data[i]);
  295. }
  296. }
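  // Load an image from disk with stb_image, forcing conversion to 3-channel RGB.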
  297. bool sam_image_load_from_file(const std::string & fname, sam_image_u8 & img) {
  298. int nx, ny, nc;
  299. auto data = stbi_load(fname.c_str(), &nx, &ny, &nc, 3);
  300. if (!data) {
  301. fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname.c_str());
  302. return false;
  303. }
  304. img.nx = nx;
  305. img.ny = ny;
  306. img.data.resize(nx * ny * 3);
  307. memcpy(img.data.data(), data, nx * ny * 3);
  308. stbi_image_free(data);
  309. return true;
  310. }
  311. // ref: https://github.com/facebookresearch/segment-anything/blob/efeab7296ab579d4a261e554eca80faf6b33924a/segment_anything/modeling/sam.py#L164
  312. // resize largest dimension to 1024
  313. // normalize: x = (x - mean) / std
  314. // mean = [123.675, 116.28, 103.53]
  315. // std = [58.395, 57.12, 57.375]
  316. // TODO: why are these hardcoded !?
  317. // pad to 1024x1024
  318. // TODO: for some reason, this is not numerically identical to pytorch's interpolation
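  // e.g. a 2000x1500 input gives scale = 2000/1024 ≈ 1.95, so it is resampled to 1024x768 with bilinear interpolation and the rest of the 1024x1024 buffer stays zero (the padding)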
  319. bool sam_image_preprocess(const sam_image_u8 & img, sam_image_f32 & res) {
  320. const int nx = img.nx;
  321. const int ny = img.ny;
  322. const int nx2 = 1024;
  323. const int ny2 = 1024;
  324. res.nx = nx2;
  325. res.ny = ny2;
  326. res.data.resize(3*nx2*ny2);
  327. const float scale = std::max(nx, ny) / 1024.0f;
  328. fprintf(stderr, "%s: scale = %f\n", __func__, scale);
  329. const int nx3 = int(nx/scale + 0.5f);
  330. const int ny3 = int(ny/scale + 0.5f);
  331. const float m3[3] = { 123.675f, 116.280f, 103.530f };
  332. const float s3[3] = { 58.395f, 57.120f, 57.375f };
  333. for (int y = 0; y < ny3; y++) {
  334. for (int x = 0; x < nx3; x++) {
  335. for (int c = 0; c < 3; c++) {
  336. // linear interpolation
  337. const float sx = (x + 0.5f)*scale - 0.5f;
  338. const float sy = (y + 0.5f)*scale - 0.5f;
  339. const int x0 = std::max(0, (int) std::floor(sx));
  340. const int y0 = std::max(0, (int) std::floor(sy));
  341. const int x1 = std::min(x0 + 1, nx - 1);
  342. const int y1 = std::min(y0 + 1, ny - 1);
  343. const float dx = sx - x0;
  344. const float dy = sy - y0;
  345. const int j00 = 3*(y0*nx + x0) + c;
  346. const int j01 = 3*(y0*nx + x1) + c;
  347. const int j10 = 3*(y1*nx + x0) + c;
  348. const int j11 = 3*(y1*nx + x1) + c;
  349. const float v00 = img.data[j00];
  350. const float v01 = img.data[j01];
  351. const float v10 = img.data[j10];
  352. const float v11 = img.data[j11];
  353. const float v0 = v00*(1.0f - dx) + v01*dx;
  354. const float v1 = v10*(1.0f - dx) + v11*dx;
  355. const float v = v0*(1.0f - dy) + v1*dy;
  356. const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
  357. const int i = 3*(y*nx3 + x) + c;
  358. res.data[i] = (float(v2) - m3[c]) / s3[c];
  359. }
  360. }
  361. }
  362. return true;
  363. }
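  // Hypothetical usage of the two helpers above (the file name is a placeholder):
  //   sam_image_u8 img0;
  //   sam_image_f32 img1;
  //   if (sam_image_load_from_file("example.jpg", img0)) {
  //       sam_image_preprocess(img0, img1); // img1 is now a normalized 1024x1024 float image
  //   }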
  364. // load the model's weights from a file
  365. bool sam_model_load(const std::string & fname, sam_model & model) {
  366. fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
  367. auto fin = std::ifstream(fname, std::ios::binary);
  368. if (!fin) {
  369. fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
  370. return false;
  371. }
  372. // verify magic
  373. {
  374. uint32_t magic;
  375. fin.read((char *) &magic, sizeof(magic));
  376. if (magic != 0x67676d6c) { // "ggml" in ASCII
  377. fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
  378. return false;
  379. }
  380. }
  381. // load hparams
  382. {
  383. auto & hparams = model.hparams;
  384. fin.read((char *) &hparams.n_enc_state, sizeof(hparams.n_enc_state));
  385. fin.read((char *) &hparams.n_enc_layer, sizeof(hparams.n_enc_layer));
  386. fin.read((char *) &hparams.n_enc_head, sizeof(hparams.n_enc_head));
  387. fin.read((char *) &hparams.n_enc_out_chans, sizeof(hparams.n_enc_out_chans));
  388. fin.read((char *) &hparams.n_pt_embd, sizeof(hparams.n_pt_embd));
  389. fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
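  // ftype packs the quantization version and the base weight type: qntvr = ftype / GGML_QNT_VERSION_FACTOR, type = ftype % GGML_QNT_VERSION_FACTOR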
  390. const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
  391. printf("%s: n_enc_state = %d\n", __func__, hparams.n_enc_state);
  392. printf("%s: n_enc_layer = %d\n", __func__, hparams.n_enc_layer);
  393. printf("%s: n_enc_head = %d\n", __func__, hparams.n_enc_head);
  394. printf("%s: n_enc_out_chans = %d\n", __func__, hparams.n_enc_out_chans);
  395. printf("%s: n_pt_embd = %d\n", __func__, hparams.n_pt_embd);
  396. printf("%s: ftype = %d\n", __func__, hparams.ftype);
  397. printf("%s: qntvr = %d\n", __func__, qntvr);
  398. hparams.ftype %= GGML_QNT_VERSION_FACTOR;
  399. }
  400. // for the big tensors, we have the option to store the data in 16-bit floats or quantized
  401. // in order to save memory and also to speed up the computation
  402. ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
  403. if (wtype == GGML_TYPE_COUNT) {
  404. fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
  405. __func__, fname.c_str(), model.hparams.ftype);
  406. return false;
  407. }
  408. auto & ctx = model.ctx;
  409. const size_t ctx_size = [&]() {
  410. size_t ctx_size = 0;
  411. const auto & hparams = model.hparams;
  412. const int32_t n_enc_state = hparams.n_enc_state;
  413. const int32_t n_enc_layer = hparams.n_enc_layer;
  414. const int32_t n_enc_head_dim = hparams.n_enc_head_dim();
  415. const int32_t n_enc_out_chans = hparams.n_enc_out_chans;
  416. const int32_t n_pt_embd = hparams.n_pt_embd;
  417. const int32_t n_enc_layer_local = hparams.global_attn_indices().size();
  418. const int32_t n_enc_layer_global = n_enc_layer - n_enc_layer_local;
  419. const int32_t n_img_embd = hparams.n_img_embd();
  420. const int32_t n_window_size = hparams.n_window_size();
  421. const int32_t n_patch_size = hparams.n_patch_size();
  422. // image encoder
  423. {
  424. ctx_size += n_enc_state*n_img_embd*n_img_embd*ggml_type_sizef(GGML_TYPE_F32);
  425. ctx_size += n_enc_state*3*n_patch_size*n_patch_size*ggml_type_sizef(GGML_TYPE_F16);
  426. ctx_size += n_enc_state*ggml_type_sizef(GGML_TYPE_F32);
  427. ctx_size += n_enc_state*n_enc_out_chans*1*1*ggml_type_sizef(GGML_TYPE_F16);
  428. ctx_size += n_enc_out_chans*n_enc_out_chans*3*3*ggml_type_sizef(GGML_TYPE_F16);
  429. ctx_size += n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F32);
  430. ctx_size += n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F32);
  431. ctx_size += n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F32);
  432. ctx_size += n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F32);
  433. }
  434. // image encoder layers
  435. {
  436. ctx_size += n_enc_layer*n_enc_state*ggml_type_sizef(GGML_TYPE_F32);
  437. ctx_size += n_enc_layer*n_enc_state*ggml_type_sizef(GGML_TYPE_F32);
  438. ctx_size += n_enc_layer_global*n_enc_head_dim*(2*n_img_embd - 1)*ggml_type_sizef(GGML_TYPE_F16);
  439. ctx_size += n_enc_layer_global*n_enc_head_dim*(2*n_img_embd - 1)*ggml_type_sizef(GGML_TYPE_F16);
  440. ctx_size += n_enc_layer_local*n_enc_head_dim*(2*n_window_size - 1)*ggml_type_sizef(GGML_TYPE_F16);
  441. ctx_size += n_enc_layer_local*n_enc_head_dim*(2*n_window_size - 1)*ggml_type_sizef(GGML_TYPE_F16);
  442. ctx_size += n_enc_layer*3*n_enc_state*n_enc_state*ggml_type_sizef(GGML_TYPE_F16);
  443. ctx_size += n_enc_layer*3*n_enc_state* ggml_type_sizef(GGML_TYPE_F32);
  444. ctx_size += n_enc_layer*n_enc_state*n_enc_state*ggml_type_sizef(GGML_TYPE_F16);
  445. ctx_size += n_enc_layer*n_enc_state* ggml_type_sizef(GGML_TYPE_F32);
  446. ctx_size += n_enc_layer*n_enc_state*ggml_type_sizef(GGML_TYPE_F32);
  447. ctx_size += n_enc_layer*n_enc_state*ggml_type_sizef(GGML_TYPE_F32);
  448. ctx_size += n_enc_layer*4*n_enc_state*n_enc_state*ggml_type_sizef(GGML_TYPE_F16);
  449. ctx_size += n_enc_layer*4*n_enc_state* ggml_type_sizef(GGML_TYPE_F32);
  450. ctx_size += n_enc_layer*4*n_enc_state*n_enc_state*ggml_type_sizef(GGML_TYPE_F16);
  451. ctx_size += n_enc_layer*4*n_enc_state* ggml_type_sizef(GGML_TYPE_F32);
  452. }
  453. ctx_size += (8 + 14*n_enc_layer)*ggml_tensor_overhead();
  454. // prompt encoder
  455. {
  456. ctx_size += n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F16); // 2*(n_enc_out_chans/2)
  457. ctx_size += n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F32);
  458. ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F32);
  459. }
  460. ctx_size += (2 + n_pt_embd)*ggml_tensor_overhead();
  461. // mask decoder
  462. {
  463. //transformer
  464. {
  465. const int tfm_layers_count = 2;
  466. const int qkv_count = 3;
  467. const int norm_count = 4;
  468. const int n_hypernet_mpls_count = 4;
  469. // self_attn
  470. ctx_size += tfm_layers_count*qkv_count*n_enc_state*n_enc_state*ggml_type_sizef(GGML_TYPE_F16);
  471. ctx_size += tfm_layers_count*qkv_count*n_enc_state* ggml_type_sizef(GGML_TYPE_F32);
  472. ctx_size += tfm_layers_count*n_enc_state* ggml_type_sizef(GGML_TYPE_F32);
  473. // all norms
  474. ctx_size += tfm_layers_count*norm_count*n_enc_state*ggml_type_sizef(GGML_TYPE_F32);
  475. ctx_size += tfm_layers_count*norm_count*n_enc_state*ggml_type_sizef(GGML_TYPE_F32);
  476. // cross_attn_token_to_img
  477. ctx_size += tfm_layers_count*qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_sizef(GGML_TYPE_F16);
  478. ctx_size += tfm_layers_count*qkv_count*(n_enc_state/2)* ggml_type_sizef(GGML_TYPE_F32);
  479. ctx_size += tfm_layers_count*n_enc_state* ggml_type_sizef(GGML_TYPE_F32);
  480. // mlp
  481. ctx_size += tfm_layers_count*8*n_enc_out_chans*n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F16);
  482. ctx_size += tfm_layers_count*8*n_enc_out_chans* ggml_type_sizef(GGML_TYPE_F32);
  483. ctx_size += tfm_layers_count*n_enc_out_chans*8*n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F16);
  484. ctx_size += tfm_layers_count*n_enc_out_chans* ggml_type_sizef(GGML_TYPE_F32);
  485. // cross_attn_img_to_token
  486. ctx_size += tfm_layers_count*qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_sizef(GGML_TYPE_F16);
  487. ctx_size += tfm_layers_count*qkv_count*(n_enc_state/2)* ggml_type_sizef(GGML_TYPE_F32);
  488. ctx_size += tfm_layers_count*n_enc_state* ggml_type_sizef(GGML_TYPE_F32);
  489. // transformer_final_attn_token_to_img
  490. ctx_size += qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_sizef(GGML_TYPE_F16);
  491. ctx_size += qkv_count*(n_enc_state/2)* ggml_type_sizef(GGML_TYPE_F32);
  492. ctx_size += n_enc_state* ggml_type_sizef(GGML_TYPE_F32);
  493. // transformer_norm_final
  494. ctx_size += norm_count*n_enc_state*ggml_type_sizef(GGML_TYPE_F32);
  495. ctx_size += norm_count*n_enc_state*ggml_type_sizef(GGML_TYPE_F32);
  496. // output_upscaling
  497. ctx_size += n_enc_out_chans*n_img_embd*2*2*ggml_type_sizef(GGML_TYPE_F16);
  498. ctx_size += 3*n_img_embd* ggml_type_sizef(GGML_TYPE_F32);
  499. ctx_size += n_enc_out_chans*n_img_embd*(n_img_embd/2)*2*2*ggml_type_sizef(GGML_TYPE_F16);
  500. ctx_size += (n_img_embd/2)* ggml_type_sizef(GGML_TYPE_F32);
  501. // output_hypernetworks_mlps
  502. ctx_size += n_hypernet_mpls_count*2*n_enc_out_chans*n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F16);
  503. ctx_size += n_hypernet_mpls_count*2*n_enc_out_chans* ggml_type_sizef(GGML_TYPE_F32);
  504. ctx_size += n_hypernet_mpls_count*n_enc_out_chans*(n_img_embd/2)*ggml_type_sizef(GGML_TYPE_F16);
  505. ctx_size += n_hypernet_mpls_count*(n_img_embd/2)* ggml_type_sizef(GGML_TYPE_F32);
  506. // iou_prediction_head
  507. ctx_size += 2*n_enc_out_chans*n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F16);
  508. ctx_size += 2*n_enc_out_chans* ggml_type_sizef(GGML_TYPE_F32);
  509. ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F16);
  510. ctx_size += n_pt_embd* ggml_type_sizef(GGML_TYPE_F32);
  511. // iou_token_w
  512. ctx_size += n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F32);
  513. // mask_tokens_w
  514. ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_sizef(GGML_TYPE_F32);
  515. }
  516. }
  517. fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
  518. return ctx_size;
  519. }();
  520. // create the ggml context
  521. {
  522. struct ggml_init_params params = {
  523. /*.mem_size =*/ ctx_size,
  524. /*.mem_buffer =*/ NULL,
  525. /*.no_alloc =*/ false,
  526. };
  527. ctx = ggml_init(params);
  528. if (!ctx) {
  529. fprintf(stderr, "%s: ggml_init() failed\n", __func__);
  530. return false;
  531. }
  532. }
  533. // prepare memory for the weights
  534. {
  535. const auto & hparams = model.hparams;
  536. const int32_t n_enc_state = hparams.n_enc_state;
  537. const int32_t n_enc_layer = hparams.n_enc_layer;
  538. const int32_t n_enc_head_dim = hparams.n_enc_head_dim();
  539. const int32_t n_enc_out_chans = hparams.n_enc_out_chans;
  540. const int32_t n_pt_embd = hparams.n_pt_embd;
  541. const int32_t n_img_embd = hparams.n_img_embd();
  542. const int32_t n_window_size = hparams.n_window_size();
  543. const int32_t n_patch_size = hparams.n_patch_size();
  544. model.enc_img.layers.resize(n_enc_layer);
  545. // image encoder
  546. {
  547. auto & enc = model.enc_img;
  548. enc.pe = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_enc_state, n_img_embd, n_img_embd, 1);
  549. enc.proj_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_patch_size, n_patch_size, 3, n_enc_state);
  550. enc.proj_b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, 1, n_enc_state);
  551. enc.neck_conv_0 = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, n_enc_state, n_enc_out_chans);
  552. enc.neck_conv_1 = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, n_enc_out_chans, n_enc_out_chans);
  553. enc.neck_norm_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  554. enc.neck_norm_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  555. enc.neck_norm_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  556. enc.neck_norm_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  557. model.tensors["image_encoder.pos_embed"] = enc.pe;
  558. model.tensors["image_encoder.patch_embed.proj.weight"] = enc.proj_w;
  559. model.tensors["image_encoder.patch_embed.proj.bias"] = enc.proj_b;
  560. model.tensors["image_encoder.neck.0.weight"] = enc.neck_conv_0;
  561. model.tensors["image_encoder.neck.2.weight"] = enc.neck_conv_1;
  562. model.tensors["image_encoder.neck.1.weight"] = enc.neck_norm_0_w;
  563. model.tensors["image_encoder.neck.1.bias"] = enc.neck_norm_0_b;
  564. model.tensors["image_encoder.neck.3.weight"] = enc.neck_norm_1_w;
  565. model.tensors["image_encoder.neck.3.bias"] = enc.neck_norm_1_b;
  566. for (int i = 0; i < n_enc_layer; ++i) {
  567. auto & layer = enc.layers[i];
  568. layer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);
  569. layer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);
  570. if (hparams.is_global_attn(i)) {
  571. layer.rel_pos_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_img_embd - 1);
  572. layer.rel_pos_h = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_img_embd - 1);
  573. } else {
  574. layer.rel_pos_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_window_size - 1);
  575. layer.rel_pos_h = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_window_size - 1);
  576. }
  577. layer.qkv_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_state, 3*n_enc_state);
  578. layer.qkv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_enc_state);
  579. layer.proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_state, n_enc_state);
  580. layer.proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);
  581. layer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);
  582. layer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);
  583. layer.mlp_lin1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_state, 4*n_enc_state);
  584. layer.mlp_lin1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_enc_state);
  585. layer.mlp_lin2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4*n_enc_state, n_enc_state);
  586. layer.mlp_lin2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state);
  587. model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm1.weight"] = layer.norm1_w;
  588. model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm1.bias"] = layer.norm1_b;
  589. model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.rel_pos_w"] = layer.rel_pos_w;
  590. model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.rel_pos_h"] = layer.rel_pos_h;
  591. model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.qkv.weight"] = layer.qkv_w;
  592. model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.qkv.bias"] = layer.qkv_b;
  593. model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.proj.weight"] = layer.proj_w;
  594. model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.proj.bias"] = layer.proj_b;
  595. model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm2.weight"] = layer.norm2_w;
  596. model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm2.bias"] = layer.norm2_b;
  597. model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin1.weight"] = layer.mlp_lin1_w;
  598. model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin1.bias"] = layer.mlp_lin1_b;
  599. model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin2.weight"] = layer.mlp_lin2_w;
  600. model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin2.bias"] = layer.mlp_lin2_b;
  601. }
  602. }
  603. // prompt encoder
  604. {
  605. auto & enc = model.enc_prompt;
  606. enc.pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans/2, 2);
  607. enc.not_a_pt_embd_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  608. enc.no_mask_embd_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  609. model.tensors["prompt_encoder.pe_layer.positional_encoding_gaussian_matrix"] = enc.pe;
  610. model.tensors["prompt_encoder.not_a_point_embed.weight"] = enc.not_a_pt_embd_w;
  611. model.tensors["prompt_encoder.no_mask_embed.weight"] = enc.no_mask_embd_w;
  612. enc.pt_embd.resize(n_pt_embd);
  613. for (int i = 0; i < n_pt_embd; i++) {
  614. enc.pt_embd[i] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  615. model.tensors["prompt_encoder.point_embeddings." + std::to_string(i) + ".weight"] = enc.pt_embd[i];
  616. }
  617. }
  618. // mask decoder
  619. {
  620. auto & dec = model.dec;
  621. auto & tfm_layers = dec.transformer_layers;
  622. const int tfm_layers_count = 2;
  623. tfm_layers.resize(tfm_layers_count);
  624. for (int i = 0; i < tfm_layers_count; ++i) {
  625. auto& l = tfm_layers[i];
  626. l.self_attn.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);
  627. l.self_attn.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  628. l.self_attn.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);
  629. l.self_attn.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  630. l.self_attn.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);
  631. l.self_attn.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  632. l.self_attn.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);
  633. l.self_attn.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  634. l.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  635. l.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  636. l.cross_attn_token_to_img.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);
  637. l.cross_attn_token_to_img.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);
  638. l.cross_attn_token_to_img.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);
  639. l.cross_attn_token_to_img.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);
  640. l.cross_attn_token_to_img.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);
  641. l.cross_attn_token_to_img.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);
  642. l.cross_attn_token_to_img.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans);
  643. l.cross_attn_token_to_img.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  644. l.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  645. l.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  646. l.mlp_lin1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, 8*n_enc_out_chans);
  647. l.mlp_lin1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8*n_enc_out_chans);
  648. l.mlp_lin2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 8*n_enc_out_chans, n_enc_out_chans);
  649. l.mlp_lin2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  650. l.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  651. l.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  652. l.norm4_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  653. l.norm4_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  654. l.cross_attn_img_to_token.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);
  655. l.cross_attn_img_to_token.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);
  656. l.cross_attn_img_to_token.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);
  657. l.cross_attn_img_to_token.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);
  658. l.cross_attn_img_to_token.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);
  659. l.cross_attn_img_to_token.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);
  660. l.cross_attn_img_to_token.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans);
  661. l.cross_attn_img_to_token.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  662. const auto prefix = "mask_decoder.transformer.layers." + std::to_string(i) + ".";
  663. model.tensors[prefix + "self_attn.q_proj.weight"] = l.self_attn.q_w;
  664. model.tensors[prefix + "self_attn.q_proj.bias"] = l.self_attn.q_b;
  665. model.tensors[prefix + "self_attn.k_proj.weight"] = l.self_attn.k_w;
  666. model.tensors[prefix + "self_attn.k_proj.bias"] = l.self_attn.k_b;
  667. model.tensors[prefix + "self_attn.v_proj.weight"] = l.self_attn.v_w;
  668. model.tensors[prefix + "self_attn.v_proj.bias"] = l.self_attn.v_b;
  669. model.tensors[prefix + "self_attn.out_proj.weight"] = l.self_attn.out_w;
  670. model.tensors[prefix + "self_attn.out_proj.bias"] = l.self_attn.out_b;
  671. model.tensors[prefix + "norm1.weight"] = l.norm1_w;
  672. model.tensors[prefix + "norm1.bias"] = l.norm1_b;
  673. model.tensors[prefix + "cross_attn_token_to_image.q_proj.weight"] = l.cross_attn_token_to_img.q_w;
  674. model.tensors[prefix + "cross_attn_token_to_image.q_proj.bias"] = l.cross_attn_token_to_img.q_b;
  675. model.tensors[prefix + "cross_attn_token_to_image.k_proj.weight"] = l.cross_attn_token_to_img.k_w;
  676. model.tensors[prefix + "cross_attn_token_to_image.k_proj.bias"] = l.cross_attn_token_to_img.k_b;
  677. model.tensors[prefix + "cross_attn_token_to_image.v_proj.weight"] = l.cross_attn_token_to_img.v_w;
  678. model.tensors[prefix + "cross_attn_token_to_image.v_proj.bias"] = l.cross_attn_token_to_img.v_b;
  679. model.tensors[prefix + "cross_attn_token_to_image.out_proj.weight"] = l.cross_attn_token_to_img.out_w;
  680. model.tensors[prefix + "cross_attn_token_to_image.out_proj.bias"] = l.cross_attn_token_to_img.out_b;
  681. model.tensors[prefix + "norm2.weight"] = l.norm2_w;
  682. model.tensors[prefix + "norm2.bias"] = l.norm2_b;
  683. model.tensors[prefix + "mlp.lin1.weight"] = l.mlp_lin1_w;
  684. model.tensors[prefix + "mlp.lin1.bias"] = l.mlp_lin1_b;
  685. model.tensors[prefix + "mlp.lin2.weight"] = l.mlp_lin2_w;
  686. model.tensors[prefix + "mlp.lin2.bias"] = l.mlp_lin2_b;
  687. model.tensors[prefix + "norm3.weight"] = l.norm3_w;
  688. model.tensors[prefix + "norm3.bias"] = l.norm3_b;
  689. model.tensors[prefix + "norm4.weight"] = l.norm4_w;
  690. model.tensors[prefix + "norm4.bias"] = l.norm4_b;
  691. model.tensors[prefix + "cross_attn_image_to_token.q_proj.weight"] = l.cross_attn_img_to_token.q_w;
  692. model.tensors[prefix + "cross_attn_image_to_token.q_proj.bias"] = l.cross_attn_img_to_token.q_b;
  693. model.tensors[prefix + "cross_attn_image_to_token.k_proj.weight"] = l.cross_attn_img_to_token.k_w;
  694. model.tensors[prefix + "cross_attn_image_to_token.k_proj.bias"] = l.cross_attn_img_to_token.k_b;
  695. model.tensors[prefix + "cross_attn_image_to_token.v_proj.weight"] = l.cross_attn_img_to_token.v_w;
  696. model.tensors[prefix + "cross_attn_image_to_token.v_proj.bias"] = l.cross_attn_img_to_token.v_b;
  697. model.tensors[prefix + "cross_attn_image_to_token.out_proj.weight"] = l.cross_attn_img_to_token.out_w;
  698. model.tensors[prefix + "cross_attn_image_to_token.out_proj.bias"] = l.cross_attn_img_to_token.out_b;
  699. }
  700. dec.transformer_final_attn_token_to_img.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);
  701. dec.transformer_final_attn_token_to_img.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);
  702. dec.transformer_final_attn_token_to_img.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);
  703. dec.transformer_final_attn_token_to_img.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);
  704. dec.transformer_final_attn_token_to_img.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2);
  705. dec.transformer_final_attn_token_to_img.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2);
  706. dec.transformer_final_attn_token_to_img.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans);
  707. dec.transformer_final_attn_token_to_img.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  708. model.tensors["mask_decoder.transformer.final_attn_token_to_image.q_proj.weight"] = dec.transformer_final_attn_token_to_img.q_w;
  709. model.tensors["mask_decoder.transformer.final_attn_token_to_image.q_proj.bias"] = dec.transformer_final_attn_token_to_img.q_b;
  710. model.tensors["mask_decoder.transformer.final_attn_token_to_image.k_proj.weight"] = dec.transformer_final_attn_token_to_img.k_w;
  711. model.tensors["mask_decoder.transformer.final_attn_token_to_image.k_proj.bias"] = dec.transformer_final_attn_token_to_img.k_b;
  712. model.tensors["mask_decoder.transformer.final_attn_token_to_image.v_proj.weight"] = dec.transformer_final_attn_token_to_img.v_w;
  713. model.tensors["mask_decoder.transformer.final_attn_token_to_image.v_proj.bias"] = dec.transformer_final_attn_token_to_img.v_b;
  714. model.tensors["mask_decoder.transformer.final_attn_token_to_image.out_proj.weight"] = dec.transformer_final_attn_token_to_img.out_w;
  715. model.tensors["mask_decoder.transformer.final_attn_token_to_image.out_proj.bias"] = dec.transformer_final_attn_token_to_img.out_b;
  716. dec.transformer_norm_final_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  717. dec.transformer_norm_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  718. model.tensors["mask_decoder.transformer.norm_final_attn.weight"] = dec.transformer_norm_final_w;
  719. model.tensors["mask_decoder.transformer.norm_final_attn.bias"] = dec.transformer_norm_final_b;
  720. dec.output_upscaling_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 2, 2, n_img_embd, n_enc_out_chans);
  721. dec.output_upscaling_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd);
  722. dec.output_upscaling_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd);
  723. dec.output_upscaling_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd);
  724. dec.output_upscaling_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 2, 2, n_img_embd/2, n_img_embd);
  725. dec.output_upscaling_3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd/2);
  726. model.tensors["mask_decoder.output_upscaling.0.weight"] = dec.output_upscaling_0_w;
  727. model.tensors["mask_decoder.output_upscaling.0.bias"] = dec.output_upscaling_0_b;
  728. model.tensors["mask_decoder.output_upscaling.1.weight"] = dec.output_upscaling_1_w;
  729. model.tensors["mask_decoder.output_upscaling.1.bias"] = dec.output_upscaling_1_b;
  730. model.tensors["mask_decoder.output_upscaling.3.weight"] = dec.output_upscaling_3_w;
  731. model.tensors["mask_decoder.output_upscaling.3.bias"] = dec.output_upscaling_3_b;
  732. const int n_hypernet_mpls_count = 4;
  733. dec.output_hypernet_mlps.resize(n_hypernet_mpls_count);
  734. for (int i = 0; i < n_hypernet_mpls_count; ++i) {
  735. auto& mlp = dec.output_hypernet_mlps[i];
  736. mlp.w_0 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);
  737. mlp.b_0 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  738. mlp.w_1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);
  739. mlp.b_1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  740. mlp.w_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_img_embd/2);
  741. mlp.b_2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd/2);
  742. const auto prefix = "mask_decoder.output_hypernetworks_mlps." + std::to_string(i) + ".";
  743. model.tensors[prefix + "layers.0.weight"] = mlp.w_0;
  744. model.tensors[prefix + "layers.0.bias"] = mlp.b_0;
  745. model.tensors[prefix + "layers.1.weight"] = mlp.w_1;
  746. model.tensors[prefix + "layers.1.bias"] = mlp.b_1;
  747. model.tensors[prefix + "layers.2.weight"] = mlp.w_2;
  748. model.tensors[prefix + "layers.2.bias"] = mlp.b_2;
  749. }
  750. dec.iou_prediction_head_0_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);
  751. dec.iou_prediction_head_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  752. dec.iou_prediction_head_1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans);
  753. dec.iou_prediction_head_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans);
  754. dec.iou_prediction_head_2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_pt_embd);
  755. dec.iou_prediction_head_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pt_embd);
  756. dec.iou_token_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans, 1);
  757. dec.mask_tokens_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans, n_pt_embd);
  758. model.tensors["mask_decoder.iou_prediction_head.layers.0.weight"] = dec.iou_prediction_head_0_w;
  759. model.tensors["mask_decoder.iou_prediction_head.layers.0.bias"] = dec.iou_prediction_head_0_b;
  760. model.tensors["mask_decoder.iou_prediction_head.layers.1.weight"] = dec.iou_prediction_head_1_w;
  761. model.tensors["mask_decoder.iou_prediction_head.layers.1.bias"] = dec.iou_prediction_head_1_b;
  762. model.tensors["mask_decoder.iou_prediction_head.layers.2.weight"] = dec.iou_prediction_head_2_w;
  763. model.tensors["mask_decoder.iou_prediction_head.layers.2.bias"] = dec.iou_prediction_head_2_b;
  764. model.tensors["mask_decoder.iou_token.weight"] = dec.iou_token_w;
  765. model.tensors["mask_decoder.mask_tokens.weight"] = dec.mask_tokens_w;
  766. }
  767. }
  768. // load weights
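  // each tensor record: int32 n_dims, int32 name length, int32 ftype, then n_dims int32 dims, the name bytes, and the raw tensor data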
  769. {
  770. int n_tensors = 0;
  771. size_t total_size = 0;
  772. fprintf(stderr, "%s: ", __func__);
  773. while (true) {
  774. int32_t n_dims;
  775. int32_t length;
  776. int32_t ftype;
  777. fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
  778. fin.read(reinterpret_cast<char *>(&length), sizeof(length));
  779. fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
  780. if (fin.eof()) {
  781. break;
  782. }
  783. int64_t nelements = 1;
  784. int64_t ne[4] = { 1, 1, 1, 1 };
  785. for (int i = 0; i < n_dims; ++i) {
  786. int32_t ne_cur;
  787. fin.read(reinterpret_cast<char *>(&ne_cur), sizeof(ne_cur));
  788. ne[i] = ne_cur;
  789. nelements *= ne[i];
  790. }
  791. std::string name(length, 0);
  792. fin.read(&name[0], length);
  793. if (model.tensors.find(name.data()) == model.tensors.end()) {
  794. fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
  795. return false;
  796. }
  797. auto tensor = model.tensors[name.data()];
  798. //printf("ne0 = %jd, ne1 = %jd, ne2 = %jd, ne3 = %jd\n", ne[0], ne[1], ne[2], ne[3]);
  799. if (ggml_nelements(tensor) != nelements) {
  800. fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %d, expected %d\n",
  801. __func__, name.data(), (int) nelements, (int) ggml_nelements(tensor));
  802. return false;
  803. }
  804. if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) {
  805. fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n",
  806. __func__, name.data(),
  807. (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3],
  808. (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3]);
  809. return false;
  810. }
  811. size_t bpe = 0;
  812. switch (ftype) {
  813. case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
  814. case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
  815. case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
  816. case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
  817. default:
  818. {
  819. fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
  820. return false;
  821. }
  822. };
  823. if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
  824. fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
  825. __func__, name.data(), ggml_nbytes(tensor), (size_t) nelements*bpe);
  826. return false;
  827. }
  828. fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
  829. total_size += ggml_nbytes(tensor);
  830. if (++n_tensors % 8 == 0) {
  831. fprintf(stderr, ".");
  832. fflush(stderr);
  833. }
  834. }
  835. if (n_tensors != int(model.tensors.size())) {
  836. fprintf(stderr, "%s: model file has %d tensors, but %d tensors were expected\n", __func__, n_tensors, (int) model.tensors.size());
  837. return false;
  838. }
  839. fprintf(stderr, " done\n");
  840. fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
  841. }
  842. fin.close();
  843. return true;
  844. }
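  // Hypothetical usage sketch (the file name is a placeholder, not from this file):
  //   sam_model model;
  //   if (!sam_model_load("ggml-model-f16.bin", model)) {
  //       fprintf(stderr, "failed to load model\n");
  //   }
  // Build the dense positional encoding for the 64x64 image-embedding grid, mirroring the prompt
  // encoder's random-Fourier-feature PE: project normalized (x, y) coordinates with enc_prompt.pe,
  // scale by 2*pi and concatenate sin and cos.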
  845. struct ggml_tensor * sam_fill_dense_pe(
  846. const sam_model & model,
  847. struct ggml_context * ctx0,
  848. struct ggml_cgraph * gf,
  849. sam_state & state) {
  850. const auto & hparams = model.hparams;
  851. const auto & enc = model.enc_prompt;
  852. const int32_t n_img_embd = hparams.n_img_embd();
  853. const float n_img_embd_inv = 1.0f / n_img_embd;
  854. struct ggml_tensor * xy_embed_stacked = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2, n_img_embd, n_img_embd);
  855. ggml_allocr_alloc(state.allocr, xy_embed_stacked);
  856. if (!ggml_allocr_is_measure(state.allocr)) {
  857. float * data = (float *) ggml_get_data(xy_embed_stacked);
  858. for (int i = 0; i < n_img_embd; ++i) {
  859. const int row = 2*i*n_img_embd;
  860. const float y_val = 2 * (i + 0.5f) * n_img_embd_inv - 1;
  861. for (int j = 0; j < n_img_embd; ++j) {
  862. const float x_val = 2 * (j + 0.5f) * n_img_embd_inv - 1;
  863. data[row + 2*j + 0] = x_val;
  864. data[row + 2*j + 1] = y_val;
  865. }
  866. }
  867. }
  868. struct ggml_tensor * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), xy_embed_stacked);
  869. cur = ggml_scale(ctx0, cur, ggml_new_f32(ctx0, float(2.0*M_PI)));
  870. // concat
  871. // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
  872. {
  873. struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);
  874. struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);
  875. cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]);
  876. ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0)));
  877. ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1])));
  878. }
  879. struct ggml_tensor * pe_img_dense = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
  880. ggml_build_forward_expand(gf, pe_img_dense);
  881. return pe_img_dense;
  882. }
  883. struct ggml_tensor* sam_layer_norm_2d(
  884. struct ggml_context * ctx0,
  885. struct ggml_tensor * layer,
  886. int n_channels,
  887. struct ggml_tensor * w,
  888. struct ggml_tensor * b,
  889. float eps) {
  890. // LayerNorm2d
  891. // normalize along channel dimension
  892. // TODO: better implementation
  893. layer = ggml_permute(ctx0,
  894. ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps),
  895. 2, 0, 1, 3);
  896. layer = ggml_add(ctx0,
  897. ggml_mul(ctx0,
  898. ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer),
  899. layer),
  900. ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer));
  901. return layer;
  902. }
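  // Build the ggml graph for the ViT image encoder: 16x16 patch embedding, learned positional embedding,
  // transformer blocks with windowed or global attention, and the 1x1 + 3x3 convolutional neck that
  // produces the final image embedding (n_enc_out_chans channels).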
struct ggml_cgraph * sam_encode_image(
            const sam_model & model,
                  sam_state & state,
        const sam_image_f32 & img) {

    const auto & hparams = model.hparams;
    const auto & enc     = model.enc_img;

    const int32_t n_enc_state     = hparams.n_enc_state;
    const int32_t n_enc_layer     = hparams.n_enc_layer;
    const int32_t n_enc_head      = hparams.n_enc_head;
    const int32_t n_enc_head_dim  = hparams.n_enc_head_dim();
    const int32_t n_enc_out_chans = hparams.n_enc_out_chans;

    const int32_t n_img_size    = hparams.n_img_size();
    const int32_t n_window_size = hparams.n_window_size();

    struct ggml_init_params ggml_params = {
        /*.mem_size   =*/ state.buf_compute_img_enc.size(),
        /*.mem_buffer =*/ state.buf_compute_img_enc.data(),
        /*.no_alloc   =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements
    };

    struct ggml_context * ctx0 = ggml_init(ggml_params);
    struct ggml_cgraph  * gf   = ggml_new_graph(ctx0);

    struct ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_img_size, n_img_size, 3, 1);
    ggml_allocr_alloc(state.allocr, inp);
    if (!ggml_allocr_is_measure(state.allocr)) {
        float * data = (float *) ggml_get_data(inp);

        const int nx = img.nx;
        const int ny = img.ny;
        const int n  = nx*ny;

        GGML_ASSERT(nx == n_img_size && ny == n_img_size);

        for (int k = 0; k < 3; k++) {
            for (int y = 0; y < ny; y++) {
                for (int x = 0; x < nx; x++) {
                    data[k*n + y*nx + x] = img.data[3*(y*nx + x) + k];
                }
            }
        }
    }

    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L392
    struct ggml_tensor * cur = ggml_conv_2d_sk_p0(ctx0, enc.proj_w, inp);
    cur = ggml_add_inplace(ctx0,
            cur,
            ggml_repeat(ctx0, enc.proj_b, cur));

    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L394
    // keep in F32
    cur = ggml_cont(ctx0,
            ggml_permute(ctx0, cur, 1, 2, 0, 3));

    // convert to F16
    //cur = ggml_cpy(ctx0,
    //        ggml_permute(ctx0, cur, 1, 2, 0, 3),
    //        ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_enc_state, n_img_embd, n_img_embd));

    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L108-L109
    cur = ggml_add_inplace(ctx0, cur, enc.pe);

    struct ggml_tensor * inpL = cur;
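
    // transformer encoder blocks: LayerNorm -> self-attention (windowed for local layers,
    // full feature map for global layers) with decomposed relative positional bias -> residual,
    // followed by LayerNorm -> MLP -> residual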
    for (int il = 0; il < n_enc_layer; ++il) {
        const auto & layer = enc.layers[il];

        // norm
        // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168
        {
            cur = ggml_norm(ctx0, inpL, hparams.eps);

            // cur = ln_0_w*cur + ln_0_b
            cur = ggml_mul(ctx0, cur, layer.norm1_w);
            cur = ggml_add_inplace(ctx0, cur, layer.norm1_b);
        }

        const int64_t w0 = cur->ne[1];
        const int64_t h0 = cur->ne[2];

        if (hparams.is_global_attn(il) == false) {
            // local attention layer - apply window partition
            // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
            cur = ggml_win_part(ctx0, cur, n_window_size);
        }

        const int64_t W = cur->ne[1];
        const int64_t H = cur->ne[2];

        // self-attention
        {
            cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
            cur = ggml_add_inplace(ctx0, cur, layer.qkv_b);

            // split qkv into separate tensors
            // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L225-L229
            const int B = cur->ne[3];

            cur = ggml_reshape_4d(ctx0, cur, n_enc_state, 3, W*H, B);
            cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2));

            struct ggml_tensor * Q;
            struct ggml_tensor * K;
            struct ggml_tensor * V;

            Q = ggml_view_3d   (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 0*cur->nb[3]);
            Q = ggml_reshape_4d(ctx0, Q, n_enc_head_dim, n_enc_head, W*H, B);
            Q = ggml_cont      (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
            Q = ggml_reshape_3d(ctx0, Q, n_enc_head_dim, W*H, B*n_enc_head);

            K = ggml_view_3d   (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 1*cur->nb[3]);
            K = ggml_reshape_4d(ctx0, K, n_enc_head_dim, n_enc_head, W*H, B);
            K = ggml_cont      (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
            K = ggml_reshape_3d(ctx0, K, n_enc_head_dim, W*H, B*n_enc_head);

            V = ggml_view_3d   (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 2*cur->nb[3]);
            V = ggml_reshape_4d(ctx0, V, n_enc_head_dim, n_enc_head, W*H, B);
            V = ggml_cont      (ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // transposed
            V = ggml_reshape_3d(ctx0, V, W*H, n_enc_head_dim, B*n_enc_head);

            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            struct ggml_tensor * KQ_scaled =
                ggml_scale_inplace(ctx0,
                        KQ,
                        ggml_new_f32(ctx0, 1.0f/sqrtf(n_enc_head_dim))
                        );

            struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W);
            struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H);

            struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Q, n_enc_head_dim, W, H, B*n_enc_head);

            struct ggml_tensor * rel_w = ggml_cont(ctx0, ggml_permute(ctx0,
                        ggml_mul_mat(ctx0,
                            rw,
                            ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))),
                        0, 2, 1, 3));
            struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);

            struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);

            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);

            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);

            cur =
                ggml_reshape_4d(ctx0,
                        ggml_cont(ctx0,
                            ggml_permute(ctx0,
                                ggml_reshape_4d(ctx0, KQV, n_enc_head_dim, W*H, n_enc_head, B),
                                0, 2, 1, 3)),
                        n_enc_state, W, H, B);

            cur = ggml_mul_mat(ctx0, layer.proj_w, cur);
            cur = ggml_add_inplace(ctx0, cur, layer.proj_b);
        }

        if (hparams.is_global_attn(il) == false) {
            // local attention layer - reverse window partition
            cur = ggml_win_unpart(ctx0, cur, w0, h0, n_window_size);
        }

        cur = ggml_add_inplace(ctx0, cur, inpL);

        struct ggml_tensor * inpFF = cur;

        // feed-forward network
        {
            // norm
            {
                cur = ggml_norm(ctx0, inpFF, hparams.eps);

                // cur = mlp_ln_w*cur + mlp_ln_b
                cur = ggml_mul(ctx0, cur, layer.norm2_w);
                cur = ggml_add_inplace(ctx0, cur, layer.norm2_b);
            }

            // fully connected
            cur = ggml_mul_mat(ctx0, layer.mlp_lin1_w, cur);
            cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin1_b);

            // GELU activation
            cur = ggml_gelu(ctx0, cur);

            // projection
            cur = ggml_mul_mat(ctx0, layer.mlp_lin2_w, cur);
            cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin2_b);
        }

        inpL = ggml_add(ctx0, cur, inpFF);
    }

    cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));

    cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);
    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b, hparams.eps);
    cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur);
    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps);

    cur = ggml_cpy(ctx0, cur, state.embd_img);

    ggml_build_forward_expand(gf, cur);
    ggml_disconnect_node_from_graph(state.embd_img);

    //ggml_graph_print(&gf);

    ggml_free(ctx0);

    return gf;
}
struct prompt_encoder_result {
    struct ggml_tensor * embd_prompt_sparse = {};
    struct ggml_tensor * embd_prompt_dense  = {};
};

// encode a prompt
//
// - points
// - boxes
// - masks
//
// TODO: currently just encode a single point for simplicity
//
prompt_encoder_result sam_encode_prompt(
        const sam_model     & model,
        struct ggml_context * ctx0,
        struct ggml_cgraph  * gf,
                  sam_state & state,
        int                   nx,
        int                   ny,
        sam_point             point) {

    const auto & hparams = model.hparams;
    const auto & enc = model.enc_prompt;

    // transform points
    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py#L276
    {
        const int nmax = std::max(nx, ny);

        const float scale = hparams.n_img_size() / (float) nmax;

        const int nx_new = int(nx*scale + 0.5f);
        const int ny_new = int(ny*scale + 0.5f);

        point.x = point.x*(float(nx_new)/nx) + 0.5f;
        point.y = point.y*(float(ny_new)/ny) + 0.5f;
    }

    struct ggml_tensor * inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, 2);

    ggml_allocr_alloc(state.allocr, inp);
    if (!ggml_allocr_is_measure(state.allocr)) {
        // set the input by converting the [0, 1] coordinates to [-1, 1]
        float * data = (float *) inp->data;

        data[0] = 2.0f*(point.x / hparams.n_img_size()) - 1.0f;
        data[1] = 2.0f*(point.y / hparams.n_img_size()) - 1.0f;

        // padding
        // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L81-L85
        data[2] = 2.0f*(0.0f) - 1.0f;
        data[3] = 2.0f*(0.0f) - 1.0f;
    }

    struct ggml_tensor * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), inp);

    cur = ggml_scale(ctx0, cur, ggml_new_f32(ctx0, float(2.0*M_PI)));

    // concat
    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
    {
        struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);
        struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);

        cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1]);

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], 0)));
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], t_sin->nb[1])));

        // overwrite label == -1 with not_a_point_embed.weight
        // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L86
        // TODO: extend for multiple points
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, enc.not_a_pt_embd_w, ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], cur->nb[1])));
    }

    // add point_embeddings[1] to label == 1
    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L90
    struct ggml_tensor * v = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], 0);
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_add_inplace(ctx0, v, enc.pt_embd[1]), v));

    struct ggml_tensor * embd_prompt_sparse = cur;
    ggml_build_forward_expand(gf, embd_prompt_sparse);

    struct ggml_tensor * embd_prompt_dense = ggml_repeat(ctx0,
            ggml_cont(ctx0,
                ggml_view_3d(ctx0, enc.no_mask_embd_w,
                    1, 1, enc.no_mask_embd_w->ne[0], enc.no_mask_embd_w->nb[0], enc.no_mask_embd_w->nb[0], 0)),
            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hparams.n_img_embd(), hparams.n_img_embd(), hparams.n_enc_out_chans));

    ggml_build_forward_expand(gf, embd_prompt_dense);

    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));

    prompt_encoder_result res;
    res.embd_prompt_sparse = embd_prompt_sparse;
    res.embd_prompt_dense  = embd_prompt_dense;
    return res;
}
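
// multi-head attention used by the mask decoder's two-way transformer:
// project queries/keys/values, split into n_dec_heads heads, apply scaled dot-product
// attention and merge the heads back through the output projection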
struct ggml_tensor* sam_decode_mask_transformer_attn(
        const sam_layer_dec_transformer_attn & attn,
                          struct ggml_tensor * queries,
                          struct ggml_tensor * keys,
                          struct ggml_tensor * values,
                         struct ggml_context * ctx0,
                             const sam_model & model) {
    const auto & hparams = model.hparams;
    const int n_head = hparams.n_dec_heads;

    struct ggml_tensor * Qcur = {};
    struct ggml_tensor * Kcur = {};
    struct ggml_tensor * Vcur = {};

    Qcur = ggml_mul_mat(ctx0, attn.q_w, queries);
    Qcur = ggml_add_inplace(ctx0, Qcur, attn.q_b);

    Kcur = ggml_mul_mat(ctx0, attn.k_w, keys);
    Kcur = ggml_add_inplace(ctx0, Kcur, attn.k_b);

    Vcur = ggml_mul_mat(ctx0, attn.v_w, values);
    Vcur = ggml_add_inplace(ctx0, Vcur, attn.v_b);

    struct ggml_tensor * Q = {};
    struct ggml_tensor * K = {};
    struct ggml_tensor * V = {};

    Q = ggml_reshape_4d(ctx0, Qcur, Qcur->ne[0]/n_head, n_head, Qcur->ne[1], Qcur->ne[2]);
    Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));

    K = ggml_reshape_4d(ctx0, Kcur, Kcur->ne[0]/n_head, n_head, Kcur->ne[1], Kcur->ne[2]);
    K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));

    V = ggml_reshape_4d(ctx0, Vcur, Vcur->ne[0]/n_head, n_head, Vcur->ne[1], Vcur->ne[2]);
    V = ggml_cont(ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3));

    // Q * K
    struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

    struct ggml_tensor * KQ_scaled =
        ggml_scale_inplace(ctx0,
                KQ,
                ggml_new_f32(ctx0, 1.0f/sqrt(float(Q->ne[0]))));

    struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_scaled);

    struct ggml_tensor * KQV = ggml_mul_mat(ctx0, KQ_soft_max, ggml_cont(ctx0, ggml_transpose(ctx0, V)));

    struct ggml_tensor * KQV_merged = ggml_cont(ctx0, ggml_transpose(ctx0, KQV));
    KQV_merged = ggml_cont(ctx0, ggml_permute(ctx0, KQV_merged, 0, 2, 1, 3));
    KQV_merged = ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]);
    KQV_merged = ggml_mul_mat(ctx0, attn.out_w, KQV_merged);
    KQV_merged = ggml_add_inplace(ctx0, KQV_merged, attn.out_b);

    return KQV_merged;
}
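
// 3-layer MLP with ReLU activations after the first two layers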
struct ggml_tensor * sam_decode_mask_mlp_relu_3(
        struct ggml_tensor  * in,
        struct ggml_tensor  * w_0,
        struct ggml_tensor  * b_0,
        struct ggml_tensor  * w_1,
        struct ggml_tensor  * b_1,
        struct ggml_tensor  * w_2,
        struct ggml_tensor  * b_2,
        struct ggml_context * ctx0) {

    struct ggml_tensor * cur = {};

    cur = ggml_mul_mat(ctx0, w_0, in);
    cur = ggml_add_inplace(ctx0, cur, b_0);
    cur = ggml_relu_inplace(ctx0, cur);

    cur = ggml_mul_mat(ctx0, w_1, cur);
    cur = ggml_add_inplace(ctx0, cur, b_1);
    cur = ggml_relu_inplace(ctx0, cur);

    cur = ggml_mul_mat(ctx0, w_2, cur);
    cur = ggml_add_inplace(ctx0, cur, b_2);

    return cur;
}
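
// mask decoder: run the two-way transformer over the prompt tokens and the image embedding,
// upscale the image features with two transposed convolutions, and predict the low-resolution
// masks and their IoU scores into state.low_res_masks and state.iou_predictions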
bool sam_decode_mask(
                 const sam_model & model,
     const prompt_encoder_result & prompt,
              struct ggml_tensor * pe_img,
             struct ggml_context * ctx0,
              struct ggml_cgraph * gf,
                       sam_state & state) {

    const auto & hparams = model.hparams;
    const auto & dec = model.dec;
    const int n_img_embd = hparams.n_img_embd();

    struct ggml_tensor * tokens = {};
    {
        // Concatenate output tokens
        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L120
        const auto& sparse = prompt.embd_prompt_sparse;

        tokens = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, dec.iou_token_w->ne[0], dec.iou_token_w->ne[1] + dec.mask_tokens_w->ne[1] + sparse->ne[1], sparse->ne[2]);

        const size_t offsets[3] = { 0, dec.iou_token_w->ne[1]*tokens->nb[1], dec.iou_token_w->ne[1]*tokens->nb[1] + dec.mask_tokens_w->ne[1]*tokens->nb[1] };
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.iou_token_w,   ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.iou_token_w->ne[1],   tokens->nb[1], offsets[0])));
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.mask_tokens_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.mask_tokens_w->ne[1], tokens->nb[1], offsets[1])));
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, sparse,            ggml_view_2d(ctx0, tokens, tokens->ne[0], sparse->ne[1],            tokens->nb[1], offsets[2])));
        // TODO: Sparse prompt embeddings can have more than one point
    }

    struct ggml_tensor * src = {};
    struct ggml_tensor * pos_src = {};
    int srcNE[4] = { 0, 0, 0, 0 };
    {
        // Expand per-image data in the batch direction to be per-mask
        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L125
        src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, state.embd_img->ne[0], state.embd_img->ne[1], state.embd_img->ne[2], tokens->ne[2]);

        src = ggml_add(ctx0,
            ggml_repeat(ctx0,
                state.embd_img,
                src),
            prompt.embd_prompt_dense);

        srcNE[0] = src->ne[0];
        srcNE[1] = src->ne[1];
        srcNE[2] = src->ne[2];
        srcNE[3] = src->ne[3];

        // flatten & permute
        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L83
        src = ggml_cont(ctx0, ggml_permute(ctx0,
            ggml_view_3d(ctx0,
                src,
                src->ne[0]*src->ne[1],
                src->ne[2],
                src->ne[3],
                src->nb[2],
                src->nb[3],
                0),
            1, 0, 2, 3));

        pos_src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, pe_img->ne[0], pe_img->ne[1], pe_img->ne[2], tokens->ne[2]);
        pos_src = ggml_repeat(ctx0,
            pe_img,
            pos_src);

        // flatten & permute
        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L83
        pos_src = ggml_cont(ctx0, ggml_permute(ctx0,
            ggml_view_3d(ctx0,
                pos_src,
                pos_src->ne[0]*pos_src->ne[1],
                pos_src->ne[2],
                pos_src->ne[3],
                pos_src->nb[2],
                pos_src->nb[3],
                0),
            1, 0, 2, 3));
    }

    struct ggml_tensor * queries = tokens;
    struct ggml_tensor * keys = src;
    {
        // Run the transformer
        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L62
        for (int i = 0; i < int(model.dec.transformer_layers.size()); ++i) {
            const auto& tfm_layer = model.dec.transformer_layers[i];

            // Self attention block
            // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L154
            const bool skip_first_layer_pe = i == 0;
            if (skip_first_layer_pe) {
                queries = sam_decode_mask_transformer_attn(tfm_layer.self_attn, queries, queries, queries, ctx0, model);
            }
            else {
                struct ggml_tensor * q_0 = ggml_add(ctx0, queries, tokens);

                struct ggml_tensor * self_attn = sam_decode_mask_transformer_attn(tfm_layer.self_attn, q_0, q_0, queries, ctx0, model);
                queries = ggml_add(ctx0, queries, self_attn);
            }

            queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
            queries = ggml_add_inplace(ctx0,
                    ggml_mul(ctx0, queries, tfm_layer.norm1_w),
                    tfm_layer.norm1_b);

            // Cross attention block, tokens attending to image embedding
            // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L163
            struct ggml_tensor * q_1 = ggml_add(ctx0, queries, tokens);
            struct ggml_tensor * k_1 = ggml_add(ctx0, keys, pos_src);

            struct ggml_tensor * cross_attn_token_to_img = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_token_to_img, q_1, k_1, keys, ctx0, model);

            queries = ggml_add_inplace(ctx0, queries, cross_attn_token_to_img);
            queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer);
            queries = ggml_add_inplace(ctx0,
                    ggml_mul(ctx0, queries, tfm_layer.norm2_w),
                    tfm_layer.norm2_b);

            // MLP block
            // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L170
            struct ggml_tensor * mlp_out = ggml_mul_mat(ctx0,
                tfm_layer.mlp_lin1_w,
                queries);

            mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin1_b);

            // RELU activation
            mlp_out = ggml_relu_inplace(ctx0, mlp_out);
            mlp_out = ggml_mul_mat(ctx0, tfm_layer.mlp_lin2_w, mlp_out);

            mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin2_b);

            queries = ggml_add_inplace(ctx0, queries, mlp_out);
            queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer);
            queries = ggml_add_inplace(ctx0,
                    ggml_mul(ctx0, queries, tfm_layer.norm3_w),
                    tfm_layer.norm3_b);

            // Cross attention block, image embedding attending to tokens
            // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L175
            struct ggml_tensor * q_2 = ggml_add(ctx0, queries, tokens);
            struct ggml_tensor * k_2 = ggml_add(ctx0, keys, pos_src);

            struct ggml_tensor * cross_attn_img_to_token = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_img_to_token, k_2, q_2, queries, ctx0, model);

            keys = ggml_add_inplace(ctx0, keys, cross_attn_img_to_token);
            keys = ggml_norm_inplace(ctx0, keys, hparams.eps_decoder_transformer);
            keys = ggml_add_inplace(ctx0,
                    ggml_mul(ctx0, keys, tfm_layer.norm4_w),
                    tfm_layer.norm4_b);
        }

        // Apply the final attention layer from the points to the image
        // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L99
        struct ggml_tensor * q = ggml_add(ctx0, queries, tokens);
        struct ggml_tensor * k = ggml_add(ctx0, keys, pos_src);

        struct ggml_tensor * final_attn_token_to_img = sam_decode_mask_transformer_attn(dec.transformer_final_attn_token_to_img, q, k, keys, ctx0, model);

        queries = ggml_add_inplace(ctx0, queries, final_attn_token_to_img);
        queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer);
        queries = ggml_add_inplace(ctx0,
                ggml_mul(ctx0, queries, dec.transformer_norm_final_w),
                dec.transformer_norm_final_b);
    }
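
    // split the decoder output tokens: token 0 predicts the mask IoU scores, the next
    // num_mask_tokens tokens feed the hypernetwork MLPs that produce the mask weights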
    struct ggml_tensor * iou_pred = ggml_view_2d(ctx0, queries, queries->ne[0], queries->ne[2], queries->nb[2], 0);
    const int num_mask_tokens = 4; // num_multimask_outputs + 1
    struct ggml_tensor * mask_tokens_out = ggml_view_3d(ctx0, queries, queries->ne[0], num_mask_tokens, queries->ne[2], queries->nb[1], num_mask_tokens*queries->nb[1], queries->nb[1]);

    // Upscale mask embeddings and predict masks using the mask tokens
    // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L136
    keys = ggml_cont(ctx0, ggml_transpose(ctx0, keys));
    keys = ggml_view_4d(ctx0, keys, srcNE[0], srcNE[1], srcNE[2], srcNE[3], srcNE[0]*keys->nb[0], keys->nb[1], keys->nb[2], 0);
    // ggml_build_forward_expand(gf, keys);
    struct ggml_tensor * upscaled_embedding = {};
    {
        // ConvTranspose2d
        keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2);
        ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed
        keys = ggml_add_inplace(ctx0, keys, ggml_repeat(ctx0,
                ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]),
                keys));

        keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b, hparams.eps);

        // GELU activation
        keys = ggml_gelu_inplace(ctx0, keys);

        // ConvTranspose2d
        keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2);
        ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed
        keys = ggml_add_inplace(ctx0, ggml_repeat(ctx0,
                ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]),
                keys), keys);

        // GELU activation
        keys = ggml_gelu_inplace(ctx0, keys);
        upscaled_embedding = ggml_reshape_3d(ctx0, keys, keys->ne[0]*keys->ne[1], keys->ne[2], keys->ne[3]);
        upscaled_embedding = ggml_cont(ctx0, ggml_transpose(ctx0, upscaled_embedding)); // TODO: Shouldn't be needed
    }

    struct ggml_tensor * hyper_in = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_img_embd/2, num_mask_tokens, mask_tokens_out->ne[2]);

    for (int i = 0; i < num_mask_tokens; ++i) {
        const auto& mlp = dec.output_hypernet_mlps[i];
        struct ggml_tensor * in = ggml_view_2d(ctx0, mask_tokens_out, mask_tokens_out->ne[0], mask_tokens_out->ne[2], mask_tokens_out->nb[1], i*mask_tokens_out->nb[1]);
        struct ggml_tensor * out = sam_decode_mask_mlp_relu_3(in, mlp.w_0, mlp.b_0, mlp.w_1, mlp.b_1, mlp.w_2, mlp.b_2, ctx0);
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, ggml_view_2d(ctx0, hyper_in, hyper_in->ne[0], hyper_in->ne[2], hyper_in->nb[1], i*hyper_in->nb[1])));
    }

    struct ggml_tensor * masks = ggml_mul_mat(ctx0, hyper_in, upscaled_embedding);
    masks = ggml_cont(ctx0, ggml_transpose(ctx0, masks)); // TODO: Shouldn't be needed
    masks = ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]);

    // Generate mask quality predictions
    // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L146
    iou_pred = sam_decode_mask_mlp_relu_3(iou_pred, dec.iou_prediction_head_0_w, dec.iou_prediction_head_0_b, dec.iou_prediction_head_1_w, dec.iou_prediction_head_1_b, dec.iou_prediction_head_2_w, dec.iou_prediction_head_2_b, ctx0);

    // Select the correct mask or masks for output
    // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L101
    iou_pred = ggml_cpy(state.ctx, ggml_view_1d(ctx0, iou_pred, iou_pred->ne[0] - 1, iou_pred->nb[0]), state.iou_predictions);

    masks = ggml_view_4d(ctx0, masks, masks->ne[0], masks->ne[1], masks->ne[2] - 1, masks->ne[3],
                         masks->nb[1], masks->nb[2], masks->nb[3], masks->nb[2] /* offset*/);
    masks = ggml_cpy(state.ctx, masks, state.low_res_masks);

    ggml_build_forward_expand(gf, masks);
    ggml_build_forward_expand(gf, iou_pred);

    ggml_disconnect_node_from_graph(state.low_res_masks);
    ggml_disconnect_node_from_graph(state.iou_predictions);

    return true;
}
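
// postprocess the predicted masks: upscale the low-resolution masks back to the original image
// size (bilinear), filter them by predicted IoU and stability score, and write each remaining
// mask to mask_out_<i>.png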
bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state) {
    if (state.low_res_masks->ne[2] == 0) return true;
    if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) {
        printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]);
        return false;
    }

    const int n_img_size = hparams.n_img_size();
    const float mask_threshold = hparams.mask_threshold;
    const float iou_threshold = hparams.iou_threshold;
    const float stability_score_threshold = hparams.stability_score_threshold;
    const float intersection_threshold = mask_threshold + hparams.stability_score_offset;
    const float union_threshold = mask_threshold - hparams.stability_score_offset;
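
    // the stability score is the ratio between the mask area binarized at
    // (mask_threshold + offset) and at (mask_threshold - offset), i.e. how stable the mask is
    // to small changes of the binarization threshold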
    const int ne0 = state.low_res_masks->ne[0];
    const int ne1 = state.low_res_masks->ne[1];
    const int ne2 = state.low_res_masks->ne[2];

    // Remove padding and upscale masks to the original image size.
    // ref: https://github.com/facebookresearch/segment-anything/blob/efeab7296ab579d4a261e554eca80faf6b33924a/segment_anything/modeling/sam.py#L140
    const float preprocess_scale = std::max(nx, ny) / float(n_img_size);
    const int cropped_nx = int(nx / preprocess_scale + 0.5f);
    const int cropped_ny = int(ny / preprocess_scale + 0.5f);

    const float scale_x_1 = (float)ne0 / (float)n_img_size;
    const float scale_y_1 = (float)ne1 / (float)n_img_size;

    const float scale_x_2 = float(cropped_nx) / float(nx);
    const float scale_y_2 = float(cropped_ny) / float(ny);

    const auto iou_data = (float*)state.iou_predictions->data;

    for (int i = 0; i < ne2; ++i) {
        if (iou_threshold > 0.f && iou_data[i] < iou_threshold) {
            printf("Skipping mask %d with iou %f below threshold %f\n", i, iou_data[i], iou_threshold);
            continue; // Filtering masks with iou below the threshold
        }

        std::vector<float> mask_data(n_img_size*n_img_size);
        {
            const float* data = (float *) state.low_res_masks->data + i*ne0*ne1;

            for (int iy = 0; iy < n_img_size; ++iy) {
                for (int ix = 0; ix < n_img_size; ++ix) {
                    const float sx = std::max(scale_x_1*(ix + 0.5f) - 0.5f, 0.0f);
                    const float sy = std::max(scale_y_1*(iy + 0.5f) - 0.5f, 0.0f);

                    const int x0 = std::max(0, (int)sx);
                    const int y0 = std::max(0, (int)sy);

                    const int x1 = std::min(x0 + 1, ne0 - 1);
                    const int y1 = std::min(y0 + 1, ne1 - 1);

                    const float dx = sx - x0;
                    const float dy = sy - y0;

                    const int j00 = y0*ne0 + x0;
                    const int j01 = y0*ne0 + x1;
                    const int j10 = y1*ne0 + x0;
                    const int j11 = y1*ne0 + x1;

                    const float v00 = data[j00];
                    const float v01 = data[j01];
                    const float v10 = data[j10];
                    const float v11 = data[j11];

                    const float v0 = (1-dx)*v00 + dx*v01;
                    const float v1 = (1-dx)*v10 + dx*v11;

                    const float v = (1-dy)*v0 + dy*v1;

                    mask_data[iy*n_img_size + ix] = v;
                }
            }
        }

        int intersections = 0;
        int unions = 0;
        sam_image_u8 res;
        int min_iy = ny;
        int max_iy = 0;
        int min_ix = nx;
        int max_ix = 0;
        {
            const float* data = mask_data.data();

            res.nx = nx;
            res.ny = ny;
            res.data.resize(nx*ny);

            for (int iy = 0; iy < ny; ++iy) {
                for (int ix = 0; ix < nx; ++ix) {
                    const float sx = std::max(scale_x_2*(ix + 0.5f) - 0.5f, 0.0f);
                    const float sy = std::max(scale_y_2*(iy + 0.5f) - 0.5f, 0.0f);

                    const int x0 = std::max(0, (int)sx);
                    const int y0 = std::max(0, (int)sy);

                    const int x1 = std::min(x0 + 1, cropped_nx - 1);
                    const int y1 = std::min(y0 + 1, cropped_ny - 1);

                    const float dx = sx - x0;
                    const float dy = sy - y0;

                    const int j00 = y0*n_img_size + x0;
                    const int j01 = y0*n_img_size + x1;
                    const int j10 = y1*n_img_size + x0;
                    const int j11 = y1*n_img_size + x1;

                    const float v00 = data[j00];
                    const float v01 = data[j01];
                    const float v10 = data[j10];
                    const float v11 = data[j11];

                    const float v0 = (1-dx)*v00 + dx*v01;
                    const float v1 = (1-dx)*v10 + dx*v11;

                    const float v = (1-dy)*v0 + dy*v1;

                    if (v > intersection_threshold) {
                        intersections++;
                    }
                    if (v > union_threshold) {
                        unions++;
                    }
                    if (v > mask_threshold) {
                        min_iy = std::min(min_iy, iy);
                        max_iy = std::max(max_iy, iy);
                        min_ix = std::min(min_ix, ix);
                        max_ix = std::max(max_ix, ix);

                        res.data[iy*nx + ix] = 255;
                    }
                }
            }
        }

        const float stability_score = float(intersections) / float(unions);
        if (stability_score_threshold > 0.f && stability_score < stability_score_threshold) {
            printf("Skipping mask %d with stability score %f below threshold %f\n", i, stability_score, stability_score_threshold);
            continue; // Filtering masks with stability score below the threshold
        }

        printf("Mask %d: iou = %f, stability_score = %f, bbox (%d, %d), (%d, %d)\n",
                i, iou_data[i], stability_score, min_ix, max_ix, min_iy, max_iy);

        std::string filename = "mask_out_" + std::to_string(i) + ".png";
        if (!stbi_write_png(filename.c_str(), res.nx, res.ny, 1, res.data.data(), res.nx)) {
            printf("%s: failed to write mask %s\n", __func__, filename.c_str());
            return false;
        }
    }

    return true;
}
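
// build the "fast" graph: prompt encoder, dense positional encoding and mask decoder,
// i.e. the part of the pipeline that is re-evaluated for every new prompt
// (the image encoder graph above only needs to run once per image)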
struct ggml_cgraph * sam_build_fast_graph(
        const sam_model & model,
              sam_state & state,
                    int   nx,
                    int   ny,
              sam_point   point) {

    struct ggml_init_params ggml_params = {
        /*.mem_size   =*/ state.buf_compute_fast.size(),
        /*.mem_buffer =*/ state.buf_compute_fast.data(),
        /*.no_alloc   =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements
    };

    struct ggml_context * ctx0 = ggml_init(ggml_params);
    struct ggml_cgraph  * gf   = ggml_new_graph(ctx0);

    prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, nx, ny, point);
    if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) {
        fprintf(stderr, "%s: failed to encode prompt\n", __func__);
        return {};
    }

    struct ggml_tensor * pe_img_dense = sam_fill_dense_pe(model, ctx0, gf, state);
    if (!pe_img_dense) {
        fprintf(stderr, "%s: failed to get dense positional encoding\n", __func__);
        return {};
    }

    if (!sam_decode_mask(model, enc_res, pe_img_dense, ctx0, gf, state)) {
        fprintf(stderr, "%s: failed to decode mask\n", __func__);
        return {};
    }

    ggml_free(ctx0);

    return gf;
}
struct sam_params {
    int32_t seed      = -1; // RNG seed
    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());

    std::string model     = "models/sam-vit-b/ggml-model-f16.bin"; // model path
    std::string fname_inp = "img.jpg";
    std::string fname_out = "img.out";
};
void sam_print_usage(int argc, char ** argv, const sam_params & params) {
    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h, --help            show this help message and exit\n");
    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
    fprintf(stderr, "  -m FNAME, --model FNAME\n");
    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
    fprintf(stderr, "  -i FNAME, --inp FNAME\n");
    fprintf(stderr, "                        input file (default: %s)\n", params.fname_inp.c_str());
    fprintf(stderr, "  -o FNAME, --out FNAME\n");
    fprintf(stderr, "                        output file (default: %s)\n", params.fname_out.c_str());
    fprintf(stderr, "\n");
}
bool sam_params_parse(int argc, char ** argv, sam_params & params) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];

        if (arg == "-s" || arg == "--seed") {
            params.seed = std::stoi(argv[++i]);
        } else if (arg == "-t" || arg == "--threads") {
            params.n_threads = std::stoi(argv[++i]);
        } else if (arg == "-m" || arg == "--model") {
            params.model = argv[++i];
        } else if (arg == "-i" || arg == "--inp") {
            params.fname_inp = argv[++i];
        } else if (arg == "-o" || arg == "--out") {
            params.fname_out = argv[++i];
        } else if (arg == "-h" || arg == "--help") {
            sam_print_usage(argc, argv, params);
            exit(0);
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            sam_print_usage(argc, argv, params);
            exit(0);
        }
    }

    return true;
}
int main(int argc, char ** argv) {
    const int64_t t_main_start_us = ggml_time_us();

    sam_params params;
    params.model = "models/sam-vit-b/ggml-model-f16.bin";

    sam_model model;
    sam_state state;
    int64_t t_load_us = 0;

    if (sam_params_parse(argc, argv, params) == false) {
        return 1;
    }

    if (params.seed < 0) {
        params.seed = time(NULL);
    }
    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);

    // load the image
    sam_image_u8 img0;
    if (!sam_image_load_from_file(params.fname_inp, img0)) {
        fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, params.fname_inp.c_str());
        return 1;
    }
    fprintf(stderr, "%s: loaded image '%s' (%d x %d)\n", __func__, params.fname_inp.c_str(), img0.nx, img0.ny);

    // preprocess to f32
    sam_image_f32 img1;
    if (!sam_image_preprocess(img0, img1)) {
        fprintf(stderr, "%s: failed to preprocess image\n", __func__);
        return 1;
    }
    fprintf(stderr, "%s: preprocessed image (%d x %d)\n", __func__, img1.nx, img1.ny);

    // load the model
    {
        const int64_t t_start_us = ggml_time_us();

        if (!sam_model_load(params.model, model)) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
            return 1;
        }

        t_load_us = ggml_time_us() - t_start_us;
    }

    {
        static size_t buf_size = 256u*1024*1024;

        struct ggml_init_params ggml_params = {
            /*.mem_size   =*/ buf_size,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };

        state.ctx = ggml_init(ggml_params);

        state.embd_img = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
                model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans);

        state.low_res_masks = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
                model.hparams.n_enc_out_chans, model.hparams.n_enc_out_chans, 3);

        state.iou_predictions = ggml_new_tensor_1d(state.ctx, GGML_TYPE_F32, 3);
    }

    static const size_t tensor_alignment = 32;
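
    // both graph computations below use the same measure-then-allocate pattern:
    // build the graph once with a measure allocator to determine the exact memory needed,
    // then recreate the allocator over a buffer of that size and build/compute the graph again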
    {
        state.buf_compute_img_enc.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
        state.allocr = ggml_allocr_new_measure(tensor_alignment);

        struct ggml_cgraph * gf_measure = sam_encode_image(model, state, img1);
        if (!gf_measure) {
            fprintf(stderr, "%s: failed to encode image\n", __func__);
            return 1;
        }

        size_t alloc_size = ggml_allocr_alloc_graph(state.allocr, gf_measure) + tensor_alignment;
        ggml_allocr_free(state.allocr);

        // recreate allocator with exact memory requirements
        state.buf_alloc_img_enc.resize(alloc_size);
        state.allocr = ggml_allocr_new(state.buf_alloc_img_enc.data(), state.buf_alloc_img_enc.size(), tensor_alignment);

        // compute the graph with the measured exact memory requirements from above
        ggml_allocr_reset(state.allocr);

        struct ggml_cgraph * gf = sam_encode_image(model, state, img1);
        if (!gf) {
            fprintf(stderr, "%s: failed to encode image\n", __func__);
            return 1;
        }

        ggml_allocr_alloc_graph(state.allocr, gf);
        ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads);

        print_t_f32("embd_img", state.embd_img);

        ggml_allocr_free(state.allocr);
        state.allocr = NULL;
        state.work_buffer.clear();
    }

    {
        state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
        state.allocr = ggml_allocr_new_measure(tensor_alignment);

        // TODO: user input
        const sam_point pt = { 414.375f, 162.796875f, };

        // measure memory requirements for the graph
        struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
        if (!gf_measure) {
            fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__);
            return 1;
        }

        size_t alloc_size = ggml_allocr_alloc_graph(state.allocr, gf_measure) + tensor_alignment;
        ggml_allocr_free(state.allocr);

        // recreate allocator with exact memory requirements
        state.buf_alloc_fast.resize(alloc_size);
        state.allocr = ggml_allocr_new(state.buf_alloc_fast.data(), state.buf_alloc_fast.size(), tensor_alignment);

        // compute the graph with the measured exact memory requirements from above
        ggml_allocr_reset(state.allocr);

        struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
        if (!gf) {
            fprintf(stderr, "%s: failed to build fast graph\n", __func__);
            return 1;
        }

        ggml_allocr_alloc_graph(state.allocr, gf);
        ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads);

        //print_t_f32("iou_predictions", state.iou_predictions);
        //print_t_f32("low_res_masks", state.low_res_masks);

        ggml_allocr_free(state.allocr);
        state.allocr = NULL;
    }

    if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) {
        fprintf(stderr, "%s: failed to write masks\n", __func__);
        return 1;
    }

    // report timing
    {
        const int64_t t_main_end_us = ggml_time_us();

        fprintf(stderr, "\n\n");
        fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
        fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
    }

    ggml_free(model.ctx);

    return 0;
}