convert-pt-to-ggml.py

# Convert Whisper transformer model from PyTorch to ggml format
#
# Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium
#
# You need to clone the original repo in ~/path/to/repo/whisper/
#
#  git clone https://github.com/openai/whisper ~/path/to/repo/whisper/
#
# It is used to load various assets needed by the algorithm:
#
#  - tokenizer
#  - mel filters
#
# Also, you need to have the original models in ~/.cache/whisper/
# See the original repo for more details.
#
# This script loads the specified model and whisper assets and saves them in ggml format.
# The output is a single binary file containing the following information:
#
#  - hparams
#  - mel filters
#  - tokenizer vocab
#  - model variables
#
# For each variable, write the following:
#
#  - Number of dimensions (int)
#  - Name length (int)
#  - Data type (int) (0 = float32, 1 = float16)
#  - Dimensions (int[n_dims])
#  - Name (char[name_length])
#  - Data (float[n_elements])
#
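# For reference, a single variable record could be read back roughly like this
# (a minimal sketch, not used by this script; "fin" is a hypothetical open file,
# and the dims come back in ggml order, i.e. reversed relative to the numpy shape):
#
#   n_dims, name_len, ftype = struct.unpack("iii", fin.read(4*3))
#   dims  = struct.unpack("i" * n_dims, fin.read(4 * n_dims))
#   name  = fin.read(name_len).decode("utf-8")
#   dtype = np.float32 if ftype == 0 else np.float16
#   data  = np.fromfile(fin, dtype=dtype, count=int(np.prod(dims)))
#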
import io
import os
import sys
import struct
import json
import code
import torch
import numpy as np
import base64

from pathlib import Path

#from transformers import GPTJForCausalLM
#from transformers import GPT2TokenizerFast

# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110
#LANGUAGES = {
#    "en": "english",
#    "zh": "chinese",
#    "de": "german",
#    "es": "spanish",
#    "ru": "russian",
#    "ko": "korean",
#    "fr": "french",
#    "ja": "japanese",
#    "pt": "portuguese",
#    "tr": "turkish",
#    "pl": "polish",
#    "ca": "catalan",
#    "nl": "dutch",
#    "ar": "arabic",
#    "sv": "swedish",
#    "it": "italian",
#    "id": "indonesian",
#    "hi": "hindi",
#    "fi": "finnish",
#    "vi": "vietnamese",
#    "iw": "hebrew",
#    "uk": "ukrainian",
#    "el": "greek",
#    "ms": "malay",
#    "cs": "czech",
#    "ro": "romanian",
#    "da": "danish",
#    "hu": "hungarian",
#    "ta": "tamil",
#    "no": "norwegian",
#    "th": "thai",
#    "ur": "urdu",
#    "hr": "croatian",
#    "bg": "bulgarian",
#    "lt": "lithuanian",
#    "la": "latin",
#    "mi": "maori",
#    "ml": "malayalam",
#    "cy": "welsh",
#    "sk": "slovak",
#    "te": "telugu",
#    "fa": "persian",
#    "lv": "latvian",
#    "bn": "bengali",
#    "sr": "serbian",
#    "az": "azerbaijani",
#    "sl": "slovenian",
#    "kn": "kannada",
#    "et": "estonian",
#    "mk": "macedonian",
#    "br": "breton",
#    "eu": "basque",
#    "is": "icelandic",
#    "hy": "armenian",
#    "ne": "nepali",
#    "mn": "mongolian",
#    "bs": "bosnian",
#    "kk": "kazakh",
#    "sq": "albanian",
#    "sw": "swahili",
#    "gl": "galician",
#    "mr": "marathi",
#    "pa": "punjabi",
#    "si": "sinhala",
#    "km": "khmer",
#    "sn": "shona",
#    "yo": "yoruba",
#    "so": "somali",
#    "af": "afrikaans",
#    "oc": "occitan",
#    "ka": "georgian",
#    "be": "belarusian",
#    "tg": "tajik",
#    "sd": "sindhi",
#    "gu": "gujarati",
#    "am": "amharic",
#    "yi": "yiddish",
#    "lo": "lao",
#    "uz": "uzbek",
#    "fo": "faroese",
#    "ht": "haitian creole",
#    "ps": "pashto",
#    "tk": "turkmen",
#    "nn": "nynorsk",
#    "mt": "maltese",
#    "sa": "sanskrit",
#    "lb": "luxembourgish",
#    "my": "myanmar",
#    "bo": "tibetan",
#    "tl": "tagalog",
#    "mg": "malagasy",
#    "as": "assamese",
#    "tt": "tatar",
#    "haw": "hawaiian",
#    "ln": "lingala",
#    "ha": "hausa",
#    "ba": "bashkir",
#    "jw": "javanese",
#    "su": "sundanese",
#}

## ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292
#def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"):
#    os.environ["TOKENIZERS_PARALLELISM"] = "false"
#    path = os.path.join(path_to_whisper_repo, "whisper/assets", name)
#    tokenizer = GPT2TokenizerFast.from_pretrained(path)
#
#    specials = [
#        "<|startoftranscript|>",
#        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
#        "<|translate|>",
#        "<|transcribe|>",
#        "<|startoflm|>",
#        "<|startofprev|>",
#        "<|nocaptions|>",
#        "<|notimestamps|>",
#    ]
#
#    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
#    return tokenizer

# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters that the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
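
# Illustrative example (not executed): printable bytes map to themselves,
# while non-printable bytes are shifted past 255 so they stay printable:
#
#   bu = bytes_to_unicode()
#   bu[ord("A")]   # -> 'A'
#   bu[ord(" ")]   # -> 'Ġ' (chr(256 + 32)), the familiar GPT-2 space marker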

if len(sys.argv) < 4:
    print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
    sys.exit(1)

fname_inp   = Path(sys.argv[1])
dir_whisper = Path(sys.argv[2])
dir_out     = Path(sys.argv[3])

# try to load PyTorch binary data
try:
    model_bytes = open(fname_inp, "rb").read()
    with io.BytesIO(model_bytes) as fp:
        checkpoint = torch.load(fp, map_location="cpu")
except Exception:
    print("Error: failed to load PyTorch model file:", fname_inp)
    sys.exit(1)

hparams = checkpoint["dims"]
print("hparams:", hparams)

list_vars = checkpoint["model_state_dict"]

#print(list_vars['encoder.positional_embedding'])
#print(list_vars['encoder.conv1.weight'])
#print(list_vars['encoder.conv1.weight'].shape)

# load mel filters
n_mels = hparams["n_mels"]
with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f:
    filters = torch.from_numpy(f[f"mel_{n_mels}"])

#print (filters)

#code.interact(local=locals())
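
# note: for the 80-mel models, filters should have shape (80, 201):
# one row of FFT-bin weights per mel band (n_fft = 400 gives 201 bins)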

# load tokenizer
# for backwards compatibility, also check for older hf_transformers format tokenizer files
# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
multilingual = hparams["n_vocab"] == 51865
tokenizer = dir_whisper / "whisper" / "assets" / ("multilingual.tiktoken" if multilingual else "gpt2.tiktoken")
tokenizer_type = "tiktoken"
if not tokenizer.is_file():
    tokenizer = dir_whisper / "whisper" / "assets" / ("multilingual" if multilingual else "gpt2") / "vocab.json"
    tokenizer_type = "hf_transformers"
    if not tokenizer.is_file():
        print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
        sys.exit(1)

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

if tokenizer_type == "tiktoken":
    with open(tokenizer, "rb") as f:
        contents = f.read()
        tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
elif tokenizer_type == "hf_transformers":
    with open(tokenizer, "r", encoding="utf8") as f:
        _tokens_raw = json.load(f)
        if '<|endoftext|>' in _tokens_raw:
            # ensures exact same model as tokenizer_type == tiktoken
            # details: https://github.com/ggerganov/whisper.cpp/pull/725
            del _tokens_raw['<|endoftext|>']
        tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()}
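
# each line of a .tiktoken file is "<base64-encoded token> <rank>";
# e.g. (illustrative) the line "IQ== 0" decodes to the byte b"!" with rank 0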

# output in the same directory as the model
fname_out = dir_out / "ggml-model.bin"

# use 16-bit or 32-bit floats
use_f16 = True
if len(sys.argv) > 4:
    use_f16 = False
    fname_out = dir_out / "ggml-model-f32.bin"

fout = fname_out.open("wb")
  248. fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
  249. fout.write(struct.pack("i", hparams["n_vocab"]))
  250. fout.write(struct.pack("i", hparams["n_audio_ctx"]))
  251. fout.write(struct.pack("i", hparams["n_audio_state"]))
  252. fout.write(struct.pack("i", hparams["n_audio_head"]))
  253. fout.write(struct.pack("i", hparams["n_audio_layer"]))
  254. fout.write(struct.pack("i", hparams["n_text_ctx"]))
  255. fout.write(struct.pack("i", hparams["n_text_state"]))
  256. fout.write(struct.pack("i", hparams["n_text_head"]))
  257. fout.write(struct.pack("i", hparams["n_text_layer"]))
  258. fout.write(struct.pack("i", hparams["n_mels"]))
  259. fout.write(struct.pack("i", use_f16))
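
# for reference, the medium multilingual model reports:
#   n_vocab=51865, n_audio_ctx=1500, n_audio_state=1024, n_audio_head=16,
#   n_audio_layer=24, n_text_ctx=448, n_text_state=1024, n_text_head=16,
#   n_text_layer=24, n_mels=80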

# write mel filters
fout.write(struct.pack("i", filters.shape[0]))
fout.write(struct.pack("i", filters.shape[1]))
for i in range(filters.shape[0]):
    for j in range(filters.shape[1]):
        fout.write(struct.pack("f", filters[i][j]))

# write tokenizer
fout.write(struct.pack("i", len(tokens)))
for key in tokens:
    fout.write(struct.pack("i", len(key)))
    fout.write(key)
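
# note: ranks/ids are not written explicitly; a loader is expected to assign
# ids in write order, which relies on the vocab file listing tokens in
# ascending rank order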

for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()
    print("Processing variable:", name, "with shape:", data.shape)

    # reshape conv bias from [n] to [n, 1]
    if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
        data = data.reshape(data.shape[0], 1)
        print(f"  Reshaped variable: {name} to shape:", data.shape)

    n_dims = len(data.shape)

    # looks like the whisper models are in f16 by default
    # so we need to convert the small tensors to f32 until we fully support f16 in ggml
    # ftype == 0 -> float32, ftype == 1 -> float16
    ftype = 1
    if use_f16:
        if n_dims < 2 or \
           name == "encoder.conv1.bias" or \
           name == "encoder.conv2.bias" or \
           name == "encoder.positional_embedding" or \
           name == "decoder.positional_embedding":
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype = 0
    else:
        data = data.astype(np.float32)
        ftype = 0
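
    # e.g. 1-D tensors such as the layer norm weights/biases remain float32
    # even in f16 mode, because of the n_dims < 2 check above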

    #if name.startswith("encoder"):
    #    if name.endswith("mlp.0.weight") or \
    #       name.endswith("mlp.2.weight"):
    #        print("  Transposing")
    #        data = data.transpose()

    # header
    str_ = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(str_), ftype))
    for i in range(n_dims):
        # dims are written in reverse order, as expected by ggml
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(str_)

    # data
    data.tofile(fout)

fout.close()

print("Done. Output file:", fname_out)
print("")