@@ -4,9 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, Mapping, Union, final
+from typing import Any, Dict, List, Mapping, Union, final
 
-import numpy as np
 import torch
 from fairseq2.assets import AssetStore, download_manager
 from fairseq2.assets.card import AssetCard
@@ -17,6 +16,7 @@ from seamless_communication.models.unity.builder import (
     create_unity_model,
     unity_archs,
 )
+from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
 from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
@@ -61,7 +61,6 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         encoder_key = "encoder"
         decoder_key = "decoder"
 
-        # Use the built-in version attribute of `torch.Module`.
         keys_to_delete.append(f"{decoder_key}.version")
         keys_to_delete.append(f"{decoder_key}.embed_positions._float_tensor")
 
@@ -72,6 +71,9 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
 
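+        # Stale copies from the checkpoint's char upsampler; the T2U decoder
+        # frontend loads the char embeddings instead (see embed_char below).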
+        keys_to_delete.append(f"{decoder_key}.char_upsampler.embed_positions._float_tensor")
+        keys_to_delete.append(f"{decoder_key}.char_upsampler.embed_tokens_char.weight")
+
         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [
             key for key in state_dict if key.startswith("decoder.alignment_encoder.")
@@ -108,13 +110,12 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         if config.use_text_encoder:
             state_dict["text_encoder_frontend.embed.weight"] = embeds
 
-        # TODO: Remove this hack once we get the correct char SPM .model file.
         char_embeds = state_dict.get(
             "t2u_model.decoder_frontend.embed_char.weight", None
         )
         if char_embeds is not None:
-            vocab_size = char_embeds.shape[0]
-            index_mapping = np.load("/checkpoint/krs/unity2/char_dict_mapping.npy")
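+            # Reorder the checkpoint's char embedding rows from fairseq dict
+            # order into the char tokenizer's SPM order.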
+            index_mapping = self._get_char_index_mapping(config)
+            vocab_size = len(index_mapping)
             char_embeds[torch.arange(vocab_size)] = char_embeds[index_mapping]
 
         # The embedding positions of the control symbols in fairseq's dict do
@@ -133,6 +134,28 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
         return checkpoint
 
+    def _get_char_index_mapping(self, config: UnitYConfig) -> List[int]:
+        assert config.t2u_config is not None
+        assert config.t2u_config.nar_decoder_config is not None
+        char_tokenizer = load_unity_char_tokenizer(
+            config.t2u_config.nar_decoder_config.model_name_or_card
+        )
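+        # Vocabulary in SPM order, minus the four control symbols at 0-3.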
+        spm_order = [
+            char_tokenizer.model.index_to_token(i)
+            for i in range(char_tokenizer.model.vocabulary_size)
+        ][4:]
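+        # The fairseq dict stores these chars sorted, at indices 4..V-1.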
+        spm_to_dict_mapping = {
+            ch: idx
+            for (idx, ch) in zip(
+                range(4, char_tokenizer.model.vocabulary_size),
+                sorted(spm_order),
+            )
+        }
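+        # SPM index -> dict index; the control symbols keep positions 0-3.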
+        model_to_dict_mapping = [0, 1, 2, 3] + [
+            spm_to_dict_mapping[ch] for ch in spm_order
+        ]
+        return model_to_dict_mapping
+
     @staticmethod
     def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         # X2T/S2T + T2U model.
@@ -301,7 +324,7 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             r"^decoder\.embed_tokens\.": r"t2u_model.decoder_frontend.embed.",
             r"^decoder\.var_adaptor\.duration_predictor\.": r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
             r"^decoder\.dec_pos_emb_alpha": r"t2u_model.decoder_frontend.pos_emb_alpha",
-            r"^decoder\.dec_pos_emb_alpha_char": r"t2u_model.decoder_frontend.pos_emb_alpha_char",
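+            # UnitY2 checkpoints store this scale under char_upsampler.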
+            r"^decoder\.char_upsampler\.pos_emb_alpha": r"t2u_model.decoder_frontend.pos_emb_alpha_char",
 
             # T2U Decoder
             r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
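For intuition, here is a minimal standalone sketch of the permutation that `_get_char_index_mapping` builds, using a toy vocabulary in place of the real char tokenizer (the symbols below are illustrative, not the actual SPM vocabulary):

```python
import torch

# Toy vocabulary in SPM order: four control symbols, then the chars.
spm_vocab = ["<s>", "<pad>", "</s>", "<unk>", "c", "a", "b"]
spm_order = spm_vocab[4:]  # ["c", "a", "b"]

# A fairseq dict keeps the control symbols at 0-3 and stores the chars
# sorted, so "a" -> 4, "b" -> 5, "c" -> 6.
spm_to_dict = {ch: i for i, ch in zip(range(4, len(spm_vocab)), sorted(spm_order))}

# Entry i answers: which dict row holds the embedding for SPM id i?
index_mapping = [0, 1, 2, 3] + [spm_to_dict[ch] for ch in spm_order]
assert index_mapping == [0, 1, 2, 3, 6, 4, 5]

# Same in-place permutation the loader applies to embed_char.weight.
embeds = torch.arange(7.0).unsqueeze(1)  # fake dict-order embedding table
embeds[torch.arange(7)] = embeds[index_mapping]  # rows now follow SPM order
```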