ggml_convert.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # MIT_LICENSE file in the root directory of this source tree.
  5. import dataclasses
  6. import logging
  7. import math
  8. import struct
  9. from enum import Enum
  10. from io import BufferedWriter
  11. from pathlib import Path
  12. from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set, final
  13. import torch
  14. from fairseq2.assets import AssetCard
  15. from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
  16. from fairseq2.nn import SinusoidalPositionEncoder
  17. from fairseq2.nn.transformer import RelativePositionalEncoding
  18. from seamless_communication.models import unity
  19. from fairseq2.data.text import SentencePieceTokenizerBase
  20. from fairseq2.data.typing import PathLike
  21. from typing import Sequence
  22. from fairseq2.data.text import SentencePieceEncoder, SentencePieceTokenizerBase
  23. from fairseq2.typing import Device, finaloverride
  24. from fairseq2.models.utils import TokenizerLoaderBase
  25. from fairseq2.assets import asset_store, download_manager
  26. from seamless_communication.models.unity.builder import UnitYConfig, create_unity_model
  27. from fairseq2.models.utils import ModelLoader
  28. from seamless_communication.models.unity.model import UnitYModel
  29. import ggml
  30. Preprocessor = Callable[[Any], Any]
  31. SMALLER_MODELS = [
  32. "unity_nano",
  33. "unity_micro",
  34. ] # Trained with fairseq2, with custom dict (not original NLLB ones)
  35. @final
  36. class NllbLikeTokenizer(SentencePieceTokenizerBase):
  37. """The only difference between this class and NllbTokenizer is it doesn't add a <pad> to control symbol list.
  38. Since NllbTokenizer is defined as final, we couldn't inherit from it directly. So copying ~everything"""
  39. langs: Set[str]
  40. default_lang: str
  41. def __init__(
  42. self, pathname: PathLike, langs: Sequence[str], default_lang: str
  43. ) -> None:
  44. """
  45. :param pathname:
  46. The pathname of the SentencePiece model file.
  47. :param langs:
  48. The list of supported languages.
  49. :param default_lang:
  50. The fall-back language if no language is specified.
  51. """
  52. # Each language is represented by a `__lang__` control symbol.
  53. control_symbols = [f"__{lang}__" for lang in langs]
  54. # Internal control symbols that are not relevant for eval use.
  55. control_symbols.extend(["<MINED_DATA>", "<MMT_BT_DATA>", "<SMT_BT_DATA>"])
  56. super().__init__(pathname, control_symbols)
  57. self.langs = set(langs)
  58. self.default_lang = default_lang
  59. @finaloverride
  60. def create_encoder(
  61. self,
  62. *,
  63. task: Optional[str] = None,
  64. lang: Optional[str] = None,
  65. mode: Optional[str] = None,
  66. device: Optional[Device] = None,
  67. pin_memory: bool = False,
  68. ) -> SentencePieceEncoder:
  69. """Create a token encoder.
  70. :param task:
  71. Must be 'translation'. If ``None``, defaults to 'translation'.
  72. :param lang:
  73. A language from :attr:`langs`. If ``None``, defaults to
  74. :attr:`default_lang`.
  75. :param mode:
  76. Must be 'source' or 'target'. Set to 'source' if ``lang`` is the
  77. source language; set to 'target' if ``lang`` is the target language.
  78. If ``None``, defaults to 'source'.
  79. :param device:
  80. The device on which to construct tensors.
  81. :param pin_memory:
  82. If ``True``, uses pinned memory while constructing tensors.
  83. """
  84. if task is not None and task != "translation":
  85. raise ValueError(f"`task` must be 'translation', but is '{task}' instead.")
  86. if lang is None:
  87. lang = self.default_lang
  88. if lang not in self.langs:
  89. raise ValueError(
  90. f"`lang` must be a supported language, but is '{lang}' instead."
  91. )
  92. if mode is None or mode == "source":
  93. # NLLB models expect a language token in place of BOS in source
  94. # sequences.
  95. prefix_tokens = [f"__{lang}__"]
  96. suffix_tokens = ["</s>"]
  97. elif mode == "source_mining":
  98. prefix_tokens = [f"__{lang}__", "<MINED_DATA>"]
  99. suffix_tokens = ["</s>"]
  100. elif mode == "source_mmt_bt":
  101. prefix_tokens = [f"__{lang}__", "<MMT_BT_DATA>"]
  102. suffix_tokens = ["</s>"]
  103. elif mode == "source_smt_bt":
  104. prefix_tokens = [f"__{lang}__", "<SMT_BT_DATA>"]
  105. suffix_tokens = ["</s>"]
  106. elif mode == "target":
  107. # Target sequences are expected to start with an EOS, followed by
  108. # the language token.
  109. prefix_tokens = ["</s>", f"__{lang}__"]
  110. suffix_tokens = []
  111. else:
  112. raise ValueError(
  113. f"`mode` must be 'source' or 'target', but is '{mode}' instead."
  114. )
  115. return SentencePieceEncoder(
  116. self.model,
  117. prefix_tokens=prefix_tokens,
  118. suffix_tokens=suffix_tokens,
  119. device=device,
  120. pin_memory=pin_memory,
  121. )
# UnitY model loader used by `convert_model` for SMALLER_MODELS instead of
# `unity.load_unity_model`. The default path of the latter converts from the
# fairseq1 checkpoint format, while SMALLER_MODELS were trained with fairseq2.
# NOTE(review): the positional `None` is presumably ModelLoader's checkpoint
# converter slot (hence "without_conversion"); confirm against the fairseq2
# version in use.
load_unity_model_without_conversion = ModelLoader[UnitYModel, UnitYConfig](
    asset_store,
    download_manager,
    unity.load_unity_config,
    create_unity_model,
    None,
    # NOTE(review): presumably relaxes checkpoint key validation; confirm.
    restrict_checkpoints=False,
)
  130. @final
  131. class NllbLikeTokenizerLoader(TokenizerLoaderBase[NllbLikeTokenizer]):
  132. """Loads tokenizers used by NLLB models."""
  133. @finaloverride
  134. def _load(self, pathname: Path, card: AssetCard) -> NllbLikeTokenizer:
  135. langs = card.field("langs").as_list(str)
  136. default_lang = card.field("default_lang").as_(str)
  137. return NllbLikeTokenizer(pathname, langs, default_lang)
  138. def convert_model(
  139. model_name: Union[str, torch.nn.Module],
  140. out: Optional[Path] = None,
  141. hparams: Optional[Dict[str, Any]] = None,
  142. vocab: Optional[List[Tuple[str, float]]] = None,
  143. ) -> None:
  144. if isinstance(model_name, str):
  145. # Load the corresponding fairseq2 model
  146. if out is None:
  147. out = Path(model_name).with_suffix(".ggml")
  148. # The type of model depends on the name
  149. if "unity" in model_name or "seamlessM4T" in model_name:
  150. if hparams is None:
  151. model_config = unity.load_unity_config(model_name)
  152. hparams = flatten_config(
  153. dataclasses.asdict(model_config), separator="__"
  154. )
  155. print(hparams)
  156. # Need the diverge here because current default in SC is to convert from fairseq1 ckpt format
  157. if model_name in SMALLER_MODELS:
  158. model = load_unity_model_without_conversion(model_name)
  159. else:
  160. model = unity.load_unity_model(model_name)
  161. if vocab is None:
  162. # Need the diverge here because current default in SC is to add a separate <pad>
  163. # as control symbol in NllbTokenizer
  164. if model_name in SMALLER_MODELS:
  165. tokenizer = NllbLikeTokenizerLoader(asset_store, download_manager)(
  166. model_name
  167. )
  168. else:
  169. tokenizer = unity.load_unity_text_tokenizer(model_name)
  170. vocab = read_vocab(tokenizer)
  171. else:
  172. raise ValueError(f"Unsupported model type: {model_name}")
  173. else:
  174. # Use the model passed explicitly
  175. assert (
  176. out is not None
  177. ), "output path is required when explicitly passing a module"
  178. hparams = hparams or {}
  179. model = model_name
  180. state_dict = model.state_dict()
  181. fixup_model(model, state_dict)
  182. layer_config = read_layer_config(model)
  183. vocab = vocab or []
  184. write_ggml_file(out, hparams, layer_config, vocab, state_dict)
  185. def _nested_getattr(model: Any, name: str) -> Any:
  186. parts = name.split(".")
  187. node = model
  188. for part in parts:
  189. node = getattr(node, part)
  190. if node is None:
  191. return None
  192. return node
  193. def find_children(model: torch.nn.Module, t: type) -> List[Tuple[str, torch.nn.Module]]:
  194. queue = list(model._modules.items())
  195. modules = []
  196. while queue:
  197. name, node = queue.pop()
  198. if node is None:
  199. continue
  200. if isinstance(node, t):
  201. modules.append((name, node))
  202. for child_name, child_node in node._modules.items():
  203. queue.append((".".join((name, child_name)), child_node))
  204. return modules
  205. def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) -> None:
  206. # Bake the embedding scaling into the weights
  207. frontends = find_children(model, TransformerEmbeddingFrontend)
  208. print(
  209. "Upgrading the following TransformerEmbeddingFrontend:",
  210. [x[0] for x in frontends],
  211. )
  212. for name, frontend in frontends:
  213. embed_weights = state_dict[name + ".embed.weight"]
  214. state_dict[name + ".embed.weight"] = embed_weights * frontend.scale
  215. # Sinusoidal embeddings are typically not saved since they are easily recomputed,
  216. # but this allows to avoid porting the sinusoidal logic to GGML
  217. pos_encoders = find_children(model, SinusoidalPositionEncoder)
  218. print(
  219. "Upgrading the following SinusoidalPositionEncoder:",
  220. [x[0] for x in pos_encoders],
  221. )
  222. for name, pos_encoder in pos_encoders:
  223. assert isinstance(pos_encoder.freqs, torch.Tensor)
  224. assert name not in state_dict
  225. state_dict[name] = pos_encoder.freqs
  226. relative_pos_encs = find_children(model, RelativePositionalEncoding)
  227. # speech_encoder has several copies of the relative_pos_enc module.
  228. # For efficiency reasons we only make one copy of it to GGML.
  229. if relative_pos_encs:
  230. print("Merging all speech_encoder RelativePositionalEncoding into one.")
  231. _, rel_pos_enc = relative_pos_encs[0]
  232. assert isinstance(rel_pos_enc.freqs, torch.Tensor)
  233. state_dict["speech_encoder.pos_enc"] = rel_pos_enc.freqs
  234. def read_vocab(tokenizer: Any) -> List[Tuple[str, float]]:
  235. vocab_info = tokenizer.vocab_info
  236. vocab = [
  237. (tokenizer.model.index_to_token(i).replace("▁", " "), -i)
  238. for i in range(vocab_info.size)
  239. ]
  240. return vocab # type: ignore[return-value]
  241. def write_ggml_file(
  242. out: Path,
  243. hparams: Dict[str, Any],
  244. layer_config: Dict[str, Any],
  245. vocab: List[Tuple[str, float]],
  246. state_dict: Dict[str, torch.Tensor],
  247. ) -> None:
  248. with out.open("wb") as o:
  249. write_ggml_header(o)
  250. write_hparams(o, hparams)
  251. write_hparams(o, layer_config)
  252. write_vocab(o, vocab)
  253. write_state_dict(o, state_dict)
  254. def write_ggml_header(out: BufferedWriter) -> None:
  255. """Write GGML header (in reverse cause big-endian)"""
  256. out.write(b"ggml"[::-1])
  257. def write_hparams(out: BufferedWriter, hparams: Dict[str, Any]) -> None:
  258. """Write hyper parameters.
  259. :params hparams:
  260. flattened dict containing model's hyper parameters.
  261. """
  262. simple_vals = {}
  263. for key, value in hparams.items():
  264. try:
  265. simple_vals[key] = to_ctype(value)
  266. except ValueError:
  267. logging.warning(f"Skipping config for key {key}={value!r}")
  268. continue
  269. out.write(struct.pack("<q", len(simple_vals)))
  270. for key, (ctype, cvalue) in simple_vals.items():
  271. write_string(out, key)
  272. b = struct.pack(ctype, cvalue)
  273. assert len(b) == 8
  274. out.write(b)
  275. logging.info(f"Saved {len(simple_vals)} params.")
  276. def write_vocab(out: BufferedWriter, vocab: List[Tuple[str, float]]) -> None:
  277. out.write(struct.pack("<q", len(vocab)))
  278. # Write all words concatenated in a buffer
  279. words = [bytes(w, "utf8") for w, score in vocab]
  280. packed_words = b"\0".join(words)
  281. # We use i32 to allow reusing the string loading codes
  282. packed_len = struct.pack("<i", len(packed_words))
  283. out.write(packed_len)
  284. out.write(packed_words)
  285. lengths = torch.tensor([len(w) for w in words], dtype=torch.int8)
  286. write_tensor(out, lengths)
  287. scores = torch.tensor([score for w, score in vocab], dtype=torch.float32)
  288. write_tensor(out, scores)
  289. def write_state_dict(out: BufferedWriter, state_dict: Dict[str, torch.Tensor]) -> None:
  290. """Write pytorch state dict.
  291. :paras state_dict:
  292. state dict returned by pytorch model
  293. """
  294. out.write(struct.pack("<q", len(state_dict)))
  295. # Size of each tensor
  296. byte_size = sum(x.numel() * x.element_size() for x in state_dict.values())
  297. # + tensor overhead
  298. byte_size += ggml.ggml_tensor_overhead() * (len(state_dict) + 10)
  299. out.write(struct.pack("<q", byte_size))
  300. logging.warning(
  301. f"Saving a ggml file with {len(state_dict)} tensors, for an estimated amount of {byte_size / (1024**3):.3f} GGML Gb"
  302. )
  303. for key, value in state_dict.items():
  304. write_string(out, key)
  305. if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
  306. # GGML broadcasting isn't as strong as numpy
  307. value = value.reshape(1, -1)
  308. if "pointwise_conv" in key: # pointwise_conv / depthwise_conv
  309. value = value.squeeze(-1)
  310. if "depthwise_conv" in key:
  311. value = value.squeeze(1)
  312. write_tensor(out, value.contiguous())
  313. def write_string(out: BufferedWriter, value: str) -> None:
  314. """Write string in utf-8 format.
  315. :params value:
  316. string value to dump.
  317. """
  318. str_ = value.encode("utf-8")
  319. packed_len = struct.pack("<i", len(str_))
  320. assert len(packed_len) == 4
  321. out.write(packed_len)
  322. out.write(str_)
  323. def write_tensor(out: BufferedWriter, value: torch.Tensor) -> None:
  324. """Write torch tensor in ggml format.
  325. First we save the number of dimensions and the dtype.
  326. Then we save the data as numpy array.
  327. :params value:
  328. Tensor to dump.
  329. """
  330. if value.dtype is torch.int64:
  331. # GGML doesn't have int64, downcast it
  332. value = value.to(dtype=torch.int32)
  333. if value.ndim == 0:
  334. # GGML doesn't support scalar as tensors.
  335. value = value.reshape(1)
  336. data = value.numpy()
  337. n_dims = data.ndim
  338. assert n_dims < 5, "ggml doesn't support 5 dims tensors"
  339. assert n_dims >= 1, "ggml doesn't support 0 dim tensors"
  340. ftype = torch_to_ggml_type(value.dtype)
  341. out.write(struct.pack("<i", n_dims))
  342. out.write(struct.pack("<i", ftype))
  343. for i in range(n_dims):
  344. # ggml uses long for shape
  345. out.write(struct.pack("<q", data.shape[n_dims - 1 - i]))
  346. data.tofile(out)
  347. def torch_to_ggml_type(dtype: torch.dtype) -> int:
  348. if dtype is torch.float32:
  349. return ggml.GGML_TYPE_F32
  350. elif dtype is torch.float16:
  351. return ggml.GGML_TYPE_F16
  352. elif dtype is torch.int32:
  353. return ggml.GGML_TYPE_I32
  354. elif dtype is torch.int8:
  355. return ggml.GGML_TYPE_I8
  356. else:
  357. raise NotImplementedError(f"{dtype} is not mapped to a GGML_TYPE")
  358. def flatten_config(
  359. config: Dict[str, Any],
  360. separator: str,
  361. config_preprocessor: Optional[Preprocessor] = None,
  362. ) -> Dict[str, Any]:
  363. """Flatten nested dictionnary
  364. :param config:
  365. nested dictionnary containing model config.
  366. :param separator:
  367. string separator used when flattening nested hparams
  368. :param config_preprocessor:
  369. Preprocessor used for config/hparams values
  370. :returns:
  371. flat dictionnary
  372. """
  373. if config_preprocessor is None:
  374. config_preprocessor = lambda x: x
  375. def __flatten(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
  376. result = {}
  377. for key in config:
  378. new_key = f"{prefix}{key}"
  379. if isinstance(config[key], dict):
  380. nested_result = __flatten(config[key], f"{new_key}{separator}")
  381. result.update(nested_result)
  382. else:
  383. new_config = config_preprocessor(config[key])
  384. if new_config is not None:
  385. result[new_key] = config[key]
  386. return result
  387. return __flatten(config)
  388. def read_layer_config(model: torch.nn.Module) -> Dict[str, Any]:
  389. layer_config = {}
  390. def _append_node_config(node: Any, prefix: str) -> None:
  391. for k, v in node.__dict__.items():
  392. # Skip special members. In particular all children module and tensors
  393. # will be hidden in special dicts `_parameters` and `_modules`
  394. if k.startswith("_"):
  395. continue
  396. # All modules have a "training" flag
  397. if k in ("training", "init_fn"):
  398. continue
  399. if v is None:
  400. continue
  401. try:
  402. to_ctype(v)
  403. except ValueError:
  404. logging.warning(f"Skipping layer config {k}={v!r}")
  405. continue
  406. layer_config[prefix + k] = v
  407. _append_node_config(model, "")
  408. for name, node in find_children(model, torch.nn.Module):
  409. _append_node_config(node, name + ".")
  410. return layer_config
  411. def to_ctype(value: Any) -> Tuple[str, Any]:
  412. """Transform python type to ctype.
  413. Note: we always use little-endian and 8-byte types.
  414. This make the format independent of the current platform.
  415. :params value:
  416. value to cast into ctype
  417. :returns:
  418. A tuple of ctype and cvalue.
  419. """
  420. if isinstance(value, int):
  421. return ("<q", value)
  422. if isinstance(value, float):
  423. return ("<d", value)
  424. if isinstance(value, bool):
  425. return ("<q", value)
  426. if isinstance(value, Enum):
  427. return ("<q", value.value)
  428. if isinstance(value, tuple) and len(value) == 1:
  429. return to_ctype(value[0])
  430. if isinstance(value, str) and len(value) < 8:
  431. value = bytes(value, "ascii")
  432. if len(value) < 8:
  433. value = value + (8 - len(value)) * b"\0"
  434. return ("8s", value)
  435. raise ValueError(f"Unsupported type {type(value)}")
  436. def get_cpp_type(value: Any) -> str:
  437. """Return equivalent cpp type in string format
  438. :params value:
  439. value to cast into ctype
  440. :returns:
  441. str containing cpp type
  442. """
  443. # used to have compatibility between types
  444. try:
  445. ctype, _ = to_ctype(value)
  446. except ValueError as e:
  447. return f"// Error: {e}"
  448. if ctype == "i":
  449. return "std::int32_t"
  450. if ctype == "l":
  451. return "std::int64_t"
  452. if ctype == "f":
  453. return "float"
  454. if ctype == "d":
  455. return "double"
  456. if ctype == "?":
  457. return "bool"
  458. raise RuntimeError(
  459. f"Should not have reached this part." f"Missing cpp translation for {ctype}"
  460. )
  461. def generate_hparams_struct(
  462. hparams: Dict[str, Any],
  463. struct_name: str,
  464. ) -> str:
  465. """Generate a c++ struct to hold the model hyper-parameters.
  466. :param hparams:
  467. Flattened config of the model.
  468. :param struct_name:
  469. Name of the generated struct.
  470. """
  471. struct = f"struct {struct_name} {{"
  472. fields = [f" {get_cpp_type(value)} {key};" for key, value in hparams.items()]
  473. struct = "\n".join([struct] + fields + ["};\n"])
  474. valid_fields = [
  475. key for key, value in hparams.items() if "Error" not in get_cpp_type(value)
  476. ]
  477. read_struct = f"void read_{struct_name}({struct_name}& out, std::ifstream &fin) {{"
  478. read_fields = [
  479. f" fin.read((char*) &out.{field}, sizeof(out.{field}));"
  480. for field in valid_fields
  481. ]
  482. read_struct = "\n".join([read_struct] + read_fields + ["};\n"])
  483. return "\n".join([struct, read_struct])
  484. if __name__ == "__main__":
  485. import func_argparse
  486. func_argparse.single_main(convert_model)