fairseq2_to_ggml_converter.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import dataclasses
  6. from pathlib import Path
  7. from typing import Any, Callable, Optional, Union
  8. from fairseq2.assets import AssetCard
  9. from ggml.examples.unity.buffered_ggml_writer import BufferedGGMLWriter
  10. from ggml.examples.unity.type_utils import get_cpp_type
  11. from seamless_communication.models.unity import (
  12. load_unity_config,
  13. load_unity_model
  14. )
  15. Preprocessor = Callable[[Any], Any]
  16. class Fairseq2ToGGMLConverter:
  17. """Converter from fairseq2 format to GGML format"""
  18. config_preprocessor: Preprocessor
  19. nested_params_separtor: str
  20. def __init__(
  21. self,
  22. nested_params_separtor: str = ".",
  23. config_preprocessor: Optional[Preprocessor] = None,
  24. ) -> None:
  25. """
  26. :param nested_params_separtor:
  27. string separator used when flattening nested hparams
  28. :param config_preprocessor:
  29. Preprocessor used for config/hparams values
  30. """
  31. self.config_preprocessor = config_preprocessor or (lambda v: v)
  32. self.nested_params_separtor = nested_params_separtor
  33. def convert_to_ggml(
  34. self,
  35. model_name_or_card: Union[str, AssetCard],
  36. output_file: Path
  37. ) -> None:
  38. """Load model from card, convert to ggml format and save result.
  39. :param model_name_or_card:
  40. The name or asset card of the model to load.
  41. :param output_file:
  42. File path to store binary output.
  43. """
  44. hparams = self._load_config(model_name_or_card)
  45. state_dict = self._load_state_dict(model_name_or_card)
  46. buffer = output_file.open("wb")
  47. ggml_writer = BufferedGGMLWriter(buffer)
  48. ggml_writer.write_magic_hex()
  49. ggml_writer.write_hparams(hparams)
  50. ggml_writer.write_state_dict(state_dict)
  51. buffer.close()
  52. def generate_hparams_struct(
  53. self,
  54. model_name_or_card: Union[str, AssetCard],
  55. struct_name: str,
  56. ) -> str:
  57. """Transform config to c++ struct
  58. :param model_name_or_card:
  59. The name or asset card of the model to load.
  60. :param output_file:
  61. File path to store binary output.
  62. """
  63. hparams = self._load_config(model_name_or_card)
  64. result = f"struct {struct_name} {{\n"
  65. for key, value in hparams.items():
  66. result = f"{result}\t{get_cpp_type(value)} {key};\n"
  67. result = f"{result}}};"
  68. return result
  69. def _load_config(
  70. self,
  71. model_name_or_card: Union[str, AssetCard]
  72. ) -> dict:
  73. """Load model config and transform it to flattened dict.
  74. :param model_name_or_card:
  75. The name or asset card of the model to load.
  76. :returns:
  77. Flat dictionnary containing all hyper parameters.
  78. """
  79. model_config = load_unity_config(model_name_or_card)
  80. model_config_dict = dataclasses.asdict(model_config)
  81. flattened = self.__flatten(model_config_dict)
  82. return flattened
  83. def _load_state_dict(
  84. self,
  85. model_name_or_card: Union[str, AssetCard]
  86. ) -> dict:
  87. """Load model and return state dict.
  88. :param model_name_or_card:
  89. The name or asset card of the model to load.
  90. :returns:
  91. State dict returned by pytorch model.
  92. """
  93. model = load_unity_model(model_name_or_card)
  94. return model.state_dict()
  95. def __flatten(
  96. self,
  97. config: dict
  98. ) -> dict:
  99. """Flatten nested dictionnary
  100. :param config:
  101. nested dictionnary containing model config.
  102. :returns:
  103. flat dictionnary
  104. """
  105. return self.__flatten_recursive(config, '')
  106. def __flatten_recursive(
  107. self,
  108. config: dict,
  109. prefix: str
  110. ) -> dict:
  111. """Recursive method used to flatten nested dictionnary"""
  112. result = {}
  113. for key in config:
  114. new_key = f"{prefix}{key}"
  115. if isinstance(config[key], dict):
  116. nested_result = self.__flatten_recursive(
  117. config[key],
  118. f"{new_key}{self.nested_params_separtor}"
  119. )
  120. result.update(nested_result)
  121. else:
  122. new_config = self.config_preprocessor(config[key])
  123. if new_config is not None:
  124. result[new_key] = config[key]
  125. return result