123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import dataclasses
- from pathlib import Path
- from typing import Any, Callable, Optional, Union
- from fairseq2.assets import AssetCard
- from ggml.examples.unity.buffered_ggml_writer import BufferedGGMLWriter
- from ggml.examples.unity.type_utils import get_cpp_type
- from seamless_communication.models.unity import (
- load_unity_config,
- load_unity_model
- )
# Callable applied to every leaf config value while flattening hparams.
# A return value of ``None`` causes the corresponding key to be omitted.
Preprocessor = Callable[[Any], Any]


class Fairseq2ToGGMLConverter:
    """Converter from fairseq2 format to GGML format"""

    # Applied to each non-dict config value during flattening.
    config_preprocessor: Preprocessor
    # Joins parent and child keys when nested hparams are flattened.
    # (The attribute name keeps its historical spelling for compatibility.)
    nested_params_separtor: str

    def __init__(
        self,
        nested_params_separtor: str = ".",
        config_preprocessor: Optional[Preprocessor] = None,
    ) -> None:
        """
        :param nested_params_separtor:
            string separator used when flattening nested hparams
        :param config_preprocessor:
            Preprocessor used for config/hparams values. When omitted,
            values are passed through unchanged.
        """
        if config_preprocessor is None:
            # Identity: keep every value exactly as found in the config.
            config_preprocessor = lambda v: v  # noqa: E731
        self.config_preprocessor = config_preprocessor
        self.nested_params_separtor = nested_params_separtor

    def convert_to_ggml(
        self,
        model_name_or_card: Union[str, AssetCard],
        output_file: Path,
    ) -> None:
        """Load model from card, convert to ggml format and save result.

        :param model_name_or_card:
            The name or asset card of the model to load.
        :param output_file:
            File path to store binary output.
        """
        hparams = self._load_config(model_name_or_card)
        state_dict = self._load_state_dict(model_name_or_card)

        # The context manager guarantees the file is closed even when one
        # of the writer calls raises part-way through.
        with output_file.open("wb") as buffer:
            ggml_writer = BufferedGGMLWriter(buffer)
            ggml_writer.write_magic_hex()
            ggml_writer.write_hparams(hparams)
            ggml_writer.write_state_dict(state_dict)

    def generate_hparams_struct(
        self,
        model_name_or_card: Union[str, AssetCard],
        struct_name: str,
    ) -> str:
        """Transform config to c++ struct

        :param model_name_or_card:
            The name or asset card of the model to load.
        :param struct_name:
            Name given to the generated c++ struct.
        :returns:
            Source text of the struct, with one tab-indented member
            declaration per flattened hyper parameter.
        """
        hparams = self._load_config(model_name_or_card)

        members = [
            f"\t{get_cpp_type(value)} {key};\n"
            for key, value in hparams.items()
        ]
        return f"struct {struct_name} {{\n" + "".join(members) + "};"

    def _load_config(
        self,
        model_name_or_card: Union[str, AssetCard],
    ) -> dict:
        """Load model config and transform it to flattened dict.

        :param model_name_or_card:
            The name or asset card of the model to load.
        :returns:
            Flat dictionary containing all hyper parameters.
        """
        model_config = load_unity_config(model_name_or_card)
        model_config_dict = dataclasses.asdict(model_config)
        return self.__flatten(model_config_dict)

    def _load_state_dict(
        self,
        model_name_or_card: Union[str, AssetCard],
    ) -> dict:
        """Load model and return state dict.

        :param model_name_or_card:
            The name or asset card of the model to load.
        :returns:
            The result of calling ``state_dict()`` on the loaded model.
        """
        model = load_unity_model(model_name_or_card)
        return model.state_dict()

    def __flatten(
        self,
        config: dict,
    ) -> dict:
        """Flatten nested dictionary

        :param config:
            nested dictionary containing model config.
        :returns:
            flat dictionary whose keys are the nested keys joined with
            ``nested_params_separtor``.
        """
        return self.__flatten_recursive(config, "")

    def __flatten_recursive(
        self,
        config: dict,
        prefix: str,
    ) -> dict:
        """Recursive method used to flatten nested dictionary

        :param config:
            (sub-)dictionary to flatten.
        :param prefix:
            string prepended to every key of ``config``; already ends with
            the separator for nested levels.
        :returns:
            flat dictionary of preprocessed leaf values. Leaves for which
            the preprocessor returns ``None`` are left out.
        """
        result = {}
        for key, value in config.items():
            new_key = f"{prefix}{key}"
            if isinstance(value, dict):
                result.update(
                    self.__flatten_recursive(
                        value,
                        f"{new_key}{self.nested_params_separtor}",
                    )
                )
                continue

            processed = self.config_preprocessor(value)
            if processed is not None:
                # Store the preprocessed value; the raw value was stored
                # previously, which silently discarded the preprocessing.
                result[new_key] = processed
        return result
|