@@ -20,53 +20,11 @@ from typing import List
 from fairseq2.assets import AssetCard
 from fairseq2.models.transformer.frontend import TransformerEmbeddingFrontend
 from fairseq2.nn import SinusoidalPositionEncoder
+from fairseq2.nn.transformer import RelativePositionalEncoding
 from seamless_communication.models.unity import load_unity_config, load_unity_model

 Preprocessor = Callable[[Any], Any]

-def pos_enc(max_seq_len=4096, encoding_dim=1024):
-    weight = torch.empty(
-        ((max_seq_len * 2) - 1, encoding_dim), dtype=torch.float32
-    )
-    # copied from https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/nn/transformer/relative_attention.py#L22
-    dtype = torch.float32
-    weight = weight.to(dtype)
-
-    positive_w = weight[: max_seq_len]
-    negative_w = weight[max_seq_len :]
-
-    device = weight.device
-
-    # (E / 2)
-    indices = torch.arange(0, encoding_dim, step=2, device=device, dtype=dtype)
-
-    # (1, E / 2)
-    indices = indices.unsqueeze(0)
-
-    # (S)
-    steps = torch.arange(max_seq_len, device=device, dtype=dtype)
-
-    # (S, 1)
-    steps = steps.unsqueeze(1)
-
-    factors = torch.exp(indices * -math.log(10000) / encoding_dim)
-
-    # (S, 1) x (1, E / 2) -> (S, E / 2)
-    factors = torch.matmul(steps, factors)
-
-    flipped_factors = factors.flip([0])
-
-    # A mirrored matrix of sinusoidal positive and negative positional
-    # encodings to use in shift trick.
-    #
-    # [max, ..., 3, 2, 1, 0, -1, -2, -3, ..., min]
-    torch.sin(flipped_factors, out=positive_w[:, 0::2])
-    torch.cos(flipped_factors, out=positive_w[:, 1::2])
-
-    torch.sin(-1 * factors[1:], out=negative_w[:, 0::2])
-    torch.cos(-1 * factors[1:], out=negative_w[:, 1::2])
-
-    return weight

 def convert_model(
     model_name: Union[str, torch.nn.Module],
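The removed `pos_enc()` helper built a mirrored table of shape `((max_seq_len * 2) - 1, encoding_dim)` covering relative positions from `max_seq_len - 1` down to `-(max_seq_len - 1)`; the patch instead reuses the table that fairseq2's `RelativePositionalEncoding` already precomputes. A minimal sketch of how such a mirrored table is indexed, derived from the removed code and illustrative only, not part of the patch:

```python
def rel_pos_to_row(rel_pos: int, max_seq_len: int = 4096) -> int:
    """Map a relative position to its row in the mirrored table.

    Row 0 holds position max_seq_len - 1, row max_seq_len - 1 holds position 0,
    and the last row (2 * max_seq_len - 2) holds position -(max_seq_len - 1).
    """
    assert -(max_seq_len - 1) <= rel_pos <= max_seq_len - 1
    return (max_seq_len - 1) - rel_pos


assert rel_pos_to_row(4095) == 0      # most positive offset -> first row
assert rel_pos_to_row(0) == 4095      # offset 0 sits in the middle
assert rel_pos_to_row(-4095) == 8190  # most negative offset -> last row
```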
@@ -82,14 +40,18 @@ def convert_model(
         if "unity" in model_name or "seamlessM4T" in model_name:
             if hparams is None:
                 model_config = load_unity_config(model_name)
-                hparams = flatten_config(dataclasses.asdict(model_config), separator="__")
+                hparams = flatten_config(
+                    dataclasses.asdict(model_config), separator="__"
+                )
                 print(hparams)
             model = load_unity_model(model_name)
         else:
             raise ValueError(f"Unsupported model type: {model_name}")
     else:
         # Use the model passed explicitly
-        assert out is not None, "output path is required when explicitly passing a module"
+        assert (
+            out is not None
+        ), "output path is required when explicitly passing a module"
         hparams = hparams or {}
         model = model_name

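`flatten_config` is defined elsewhere in `ggml_convert.py` and is only reformatted in this hunk. For context, a hypothetical minimal equivalent showing what `separator="__"` produces from a nested config; the function name and example keys below are illustrative assumptions, not the file's actual implementation:

```python
from typing import Any, Dict


def flatten_config_sketch(
    config: Dict[str, Any], separator: str = "__", prefix: str = ""
) -> Dict[str, Any]:
    """Assumed behaviour of flatten_config: join nested dict keys with `separator`."""
    flat: Dict[str, Any] = {}
    for key, value in config.items():
        name = f"{prefix}{separator}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_config_sketch(value, separator, name))
        else:
            flat[name] = value
    return flat


# Illustrative input; the real keys come from dataclasses.asdict(model_config).
print(flatten_config_sketch({"w2v2_encoder": {"model_dim": 1024}}, separator="__"))
# -> {'w2v2_encoder__model_dim': 1024}
```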
@@ -149,7 +111,16 @@ def fixup_model(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]) ->
         assert name not in state_dict
         state_dict[name] = pos_encoder.weight

-    state_dict["speech_encoder.pos_enc"] = pos_enc()
+
+    relative_pos_encs = find_children(model, RelativePositionalEncoding)
+    # speech_encoder has several copies of the relative_pos_enc module.
+    # For efficiency reasons we only make one copy of it to GGML.
+    if relative_pos_encs:
+        print("Merging all speech_encoder RelativePositionalEncoding into one.")
+        _, rel_pos_enc = relative_pos_encs[0]
+        assert isinstance(rel_pos_enc.weight, torch.Tensor)
+        state_dict["speech_encoder.pos_enc"] = rel_pos_enc.weight
+

 def write_ggml_file(
     out: BufferedWriter,
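`find_children` is also defined elsewhere in the file; the new code relies on it returning `(name, module)` pairs, hence the `_, rel_pos_enc` unpacking. A hypothetical sketch of that contract, an assumption about its behaviour rather than the file's actual helper:

```python
from typing import List, Tuple, Type

import torch


def find_children_sketch(
    model: torch.nn.Module, t: Type[torch.nn.Module]
) -> List[Tuple[str, torch.nn.Module]]:
    """Assumed equivalent of find_children: (qualified name, module) pairs of type `t`."""
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, t)
    ]
```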
@@ -213,7 +184,7 @@ def write_state_dict(out: BufferedWriter, state_dict: Dict[str, torch.Tensor]) -
         if key.endswith(".bias") and value.ndim == 1 and "adaptor" not in key:
             # GGML broadcasting isn't as strong as numpy
             value = value.reshape(1, -1)
-        if "pointwise_conv" in key: # pointwise_conv / depthwise_conv
+        if "pointwise_conv" in key:  # pointwise_conv / depthwise_conv
             value = value.squeeze(-1)
         if "depthwise_conv" in key:
             value = value.squeeze(1)
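Background on the two squeezes, not part of the patch: `torch.nn.Conv1d` weights have shape `(out_channels, in_channels // groups, kernel_size)`. A Conformer-style pointwise conv has `kernel_size == 1`, so `squeeze(-1)` leaves a plain 2-D matrix for GGML, while the depthwise conv has `groups == in_channels`, so `squeeze(1)` leaves `(channels, kernel_size)`. A quick check with illustrative channel sizes:

```python
import torch

# Pointwise conv: kernel_size == 1 -> weight (out, in, 1); squeeze(-1) gives (out, in).
pointwise = torch.nn.Conv1d(in_channels=1024, out_channels=2048, kernel_size=1)
assert pointwise.weight.squeeze(-1).shape == (2048, 1024)

# Depthwise conv: groups == in_channels -> weight (channels, 1, kernel); squeeze(1) gives (channels, kernel).
depthwise = torch.nn.Conv1d(in_channels=1024, out_channels=1024, kernel_size=31, groups=1024)
assert depthwise.weight.squeeze(1).shape == (1024, 31)
```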