Add support for `seamlessM4T_v2_large`. (#48)

* Add support for m4t_unity_v2, deprecate nar_multilingual.

* Rename and reorganize the asset cards; update the final m4t_v2 checkpoint asset.

* Replace the SPM hack of reading the char index mapping from a file with logic that derives the mapping directly.
Kaushik Ram Sadagopan, 1 year ago
Commit 8db669e967

+ 0 - 10
src/seamless_communication/assets/cards/s2t_chunk_conformer.yaml

@@ -1,10 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-name: s2t_chunk_conformer
-base: unity_nllb-200
-model_arch: s2t_chunk_conformer
-checkpoint: "file://checkpoint/andyyuan/ckpt_from_rsc/w2vbert-2.0/S2T/avg_last_5_checkpoint.pt"

+ 2 - 2
src/seamless_communication/assets/cards/m4t_v2_s2t.yaml → src/seamless_communication/assets/cards/s2t_m4t_v2.yaml

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-name: m4t_v2_s2t
+name: s2t_m4t_v2
 base: unity_nllb-100
-model_arch: m4t_v2_s2t
+model_arch: s2t_base_v2
 checkpoint: "file://large_experiments/seamless/ust/elbayadm/multitasking_models/m4t_v2_s2t.pt"

+ 51 - 0
src/seamless_communication/assets/cards/seamlessM4T_v2_large.yaml

@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: seamlessM4T_v2_large
+base: unity_nllb-100
+model_arch: base_v2
+char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
+checkpoint: "file://checkpoint/lpw/m4t_v2_final.pt"
+num_units: 10000
+unit_langs:
+  - arb
+  - ben
+  - cat
+  - ces
+  - cmn
+  - cym
+  - dan
+  - deu
+  - eng
+  - est
+  - fin
+  - fra
+  - hin
+  - ind
+  - ita
+  - jpn
+  - kan
+  - kor
+  - mlt
+  - nld
+  - pes
+  - pol
+  - por
+  - ron
+  - rus
+  - slk
+  - spa
+  - swe
+  - swh
+  - tam
+  - tel
+  - tgl
+  - tha
+  - tur
+  - ukr
+  - urd
+  - uzn
+  - vie
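
With this card in place, the model resolves by name through the asset store. A minimal usage sketch; the `vocoder_36langs` card name, the `"s2tt"` task string, and the exact `Translator` signature come from the library's public API rather than this diff, so treat them as assumptions:

```python
import torch

from seamless_communication.models.inference import Translator

# Load the model registered by the new "seamlessM4T_v2_large" asset card.
translator = Translator(
    "seamlessM4T_v2_large",
    "vocoder_36langs",  # assumed vocoder card name
    torch.device("cuda:0"),
    torch.float16,
)

# S2TT: translate input speech into French text (wav/sr are None for text output).
translated_text, _, _ = translator.predict("input.wav", "s2tt", tgt_lang="fra")
```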

+ 0 - 13
src/seamless_communication/assets/cards/unity_nar_multilingual.yaml

@@ -1,13 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-name: unity_nar_multilingual
-base: unity_nllb-100
-model_arch: nar_multilingual
-char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
-checkpoint: "file://large_experiments/seamless/ust/lpw/M4T_UNITY2/ckpt/checkpoint_9_80000.pt"
-num_units: 10000
-unit_langs: [arb, ben, hin, ind, ita, jpn, por, rus, swh, tha, tur, urd, vie, spa, eng]

+ 2 - 2
src/seamless_communication/assets/cards/m4t_v2_x2t.yaml → src/seamless_communication/assets/cards/x2t_m4t_v2.yaml

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-name: m4t_v2_x2t
+name: x2t_m4t_v2
 base: unity_nllb-100
-model_arch: m4t_v2_x2t
+model_arch: x2t_base_v2
 checkpoint: "file://large_experiments/seamless/ust/elbayadm/multitasking_models/m4t_v2_x2t.pt"

+ 1 - 1
src/seamless_communication/models/inference/translator.py

@@ -103,7 +103,7 @@ class Translator(nn.Module):
         cls,
         model: UnitYModel,
         text_tokenizer: TextTokenizer,
-        unit_tokenizer: UnitTokenizer,
+        unit_tokenizer: Optional[UnitTokenizer],
         src: SequenceData,
         input_modality: Modality,
         output_modality: Modality,
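
`unit_tokenizer` becomes `Optional` because text-only output (e.g. S2TT through the X2T/S2T models, which carry no T2U component) produces no units to detokenize. A hypothetical helper sketching the guard this implies; `maybe_unit_decoder` is illustrative, not part of the codebase:

```python
from typing import Optional

from seamless_communication.models.unity.unit_tokenizer import (
    UnitTokenDecoder,
    UnitTokenizer,
)


def maybe_unit_decoder(
    unit_tokenizer: Optional[UnitTokenizer], speech_output: bool
) -> Optional[UnitTokenDecoder]:
    """Only speech output needs a unit decoder; text-only tasks pass None."""
    if not speech_output:
        return None
    assert unit_tokenizer is not None, "speech output requires a unit tokenizer"
    return unit_tokenizer.create_decoder()
```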

+ 14 - 39
src/seamless_communication/models/unity/builder.py

@@ -139,8 +139,8 @@ def _medium() -> UnitYConfig:
     )
 
 
-@unity_arch("m4t_v2_x2t")
-def _m4t_v2_x2t() -> UnitYConfig:
+@unity_arch("base_v2")
+def _base_v2() -> UnitYConfig:
     w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
@@ -149,23 +149,25 @@ def _m4t_v2_x2t() -> UnitYConfig:
 
     mt_model_config.max_seq_len = 4096
 
+    t2u_config = unity_t2u_archs.get_config("base_nar")
+
     return UnitYConfig(
         model_dim=1024,
         w2v2_encoder_config=w2v2_chunk_encoder_config,
         mt_model_config=mt_model_config,
-        t2u_config=None,
-        use_text_encoder=True,
+        t2u_config=t2u_config,
+        use_text_encoder=False,
         use_conformer_adaptor=False,
         num_adaptor_layers=1,
         adaptor_kernel_size=8,
         adaptor_stride=8,
         adaptor_layer_norm=True,
-        adaptor_dropout_p=0.0,
+        adaptor_dropout_p=0.1,
     )
 
 
-@unity_arch("m4t_v2_s2t")
-def _m4t_v2_s2t() -> UnitYConfig:
+@unity_arch("x2t_base_v2")
+def _x2t_base_v2() -> UnitYConfig:
     w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
@@ -179,7 +181,7 @@ def _m4t_v2_s2t() -> UnitYConfig:
         w2v2_encoder_config=w2v2_chunk_encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=None,
-        use_text_encoder=False,
+        use_text_encoder=True,
         use_conformer_adaptor=False,
         num_adaptor_layers=1,
         adaptor_kernel_size=8,
@@ -189,12 +191,14 @@ def _m4t_v2_s2t() -> UnitYConfig:
     )
 
 
-@unity_arch("s2t_chunk_conformer")
-def _s2t_chunk_conformer() -> UnitYConfig:
+@unity_arch("s2t_base_v2")
+def _s2t_base_v2() -> UnitYConfig:
     w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
 
+    mt_model_config.vocabulary_size = 256102  # NLLB-100
+
     mt_model_config.max_seq_len = 4096
 
     return UnitYConfig(
@@ -212,35 +216,6 @@ def _s2t_chunk_conformer() -> UnitYConfig:
     )
 
 
-@unity_arch("nar_multilingual")
-def _nar_multilingual() -> UnitYConfig:
-    w2vbert_config = w2vbert_archs.get_config("600m")
-    w2v2_encoder_config = w2vbert_config.w2v2_config.encoder_config
-    w2v2_encoder_config.pos_encoder_depth = 1
-    w2v2_encoder_config.pos_conv_kernel_size = 128
-    w2v2_encoder_config.num_pos_conv_groups = 16
-
-    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
-
-    mt_model_config.vocabulary_size = 256102  # NLLB-100
-
-    t2u_config = unity_t2u_archs.get_config("nar_multilingual")
-
-    return UnitYConfig(
-        model_dim=1024,
-        w2v2_encoder_config=w2v2_encoder_config,
-        mt_model_config=mt_model_config,
-        t2u_config=t2u_config,
-        use_text_encoder=False,
-        use_conformer_adaptor=False,
-        num_adaptor_layers=1,
-        adaptor_kernel_size=8,
-        adaptor_stride=8,
-        adaptor_layer_norm=True,
-        adaptor_dropout_p=0.1,
-    )
-
-
 class UnitYBuilder:
     """Builds modules of a UnitY model.
 
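
The renamed architectures stay discoverable through the registry, so each card's `model_arch` field resolves to one of the config functions above. A small sketch of the lookup, assuming the registry exposes fairseq2's usual `get_config`:

```python
from seamless_communication.models.unity.builder import unity_archs

# "base_v2" is the full S2ST stack: NAR T2U attached, no text encoder.
s2st_config = unity_archs.get_config("base_v2")
assert s2st_config.t2u_config is not None and not s2st_config.use_text_encoder

# "x2t_base_v2" keeps the text encoder but has no T2U; "s2t_base_v2" drops both.
x2t_config = unity_archs.get_config("x2t_base_v2")
assert x2t_config.t2u_config is None and x2t_config.use_text_encoder
```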

+ 30 - 7
src/seamless_communication/models/unity/loader.py

@@ -4,9 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, Mapping, Union, final
+from typing import Any, Dict, List, Mapping, Union, final
 
-import numpy as np
 import torch
 from fairseq2.assets import AssetStore, download_manager
 from fairseq2.assets.card import AssetCard
@@ -17,6 +16,7 @@ from seamless_communication.models.unity.builder import (
     create_unity_model,
     unity_archs,
 )
+from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 from seamless_communication.models.unity.model import UnitYModel
 from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
 from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
@@ -61,7 +61,6 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             encoder_key = "encoder"
             decoder_key = "decoder"
 
-        # Use the built-in version attribute of `torch.Module`.
         keys_to_delete.append(f"{decoder_key}.version")
         keys_to_delete.append(f"{decoder_key}.embed_positions._float_tensor")
 
@@ -72,6 +71,9 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
 
+        keys_to_delete.append("decoder.char_upsampler.embed_positions._float_tensor")
+        keys_to_delete.append("decoder.char_upsampler.embed_tokens_char.weight")
+
         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [
             key for key in state_dict if key.startswith("decoder.alignment_encoder.")
@@ -108,13 +110,12 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         if config.use_text_encoder:
             state_dict["text_encoder_frontend.embed.weight"] = embeds
 
-        # TODO: Remove this hack once we get the correct char SPM .model file.
         char_embeds = state_dict.get(
             "t2u_model.decoder_frontend.embed_char.weight", None
         )
         if char_embeds is not None:
-            vocab_size = char_embeds.shape[0]
-            index_mapping = np.load("/checkpoint/krs/unity2/char_dict_mapping.npy")
+            index_mapping = self._get_char_index_mapping(config)
+            vocab_size = len(index_mapping)
             char_embeds[torch.arange(vocab_size)] = char_embeds[index_mapping]
 
         # The embedding positions of the control symbols in fairseq's dict do
@@ -133,6 +134,28 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
         return checkpoint
 
+    def _get_char_index_mapping(self, config: UnitYConfig) -> List[int]:
+        assert config.t2u_config is not None
+        assert config.t2u_config.nar_decoder_config is not None
+        char_tokenizer = load_unity_char_tokenizer(
+            config.t2u_config.nar_decoder_config.model_name_or_card
+        )
+        spm_order = [
+            char_tokenizer.model.index_to_token(i)
+            for i in range(char_tokenizer.model.vocabulary_size)
+        ][4:]
+        spm_to_dict_mapping = {
+            ch: idx
+            for (idx, ch) in zip(
+                range(4, char_tokenizer.model.vocabulary_size),
+                sorted(spm_order),
+            )
+        }
+        model_to_dict_mapping = [0, 1, 2, 3] + [
+            spm_to_dict_mapping[ch] for ch in spm_order
+        ]
+        return model_to_dict_mapping
+
     @staticmethod
     def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
         # X2T/S2T + T2U model.
@@ -301,7 +324,7 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
                     r"^decoder\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
                     r"^decoder\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
                     r"^decoder\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
-                    r"^decoder\.dec_pos_emb_alpha_char":                        r"t2u_model.decoder_frontend.pos_emb_alpha_char",
+                    r"^decoder\.char_upsampler\.pos_emb_alpha":                 r"t2u_model.decoder_frontend.pos_emb_alpha_char",
 
                     # T2U Decoder
                     r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
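
`_get_char_index_mapping` derives the permutation that was previously stored in `char_dict_mapping.npy`: the checkpoint's character embeddings are laid out in fairseq-dictionary order (four control symbols followed by the tokens in sorted order), while the char SPM model enumerates tokens in its own order. A self-contained toy walk-through of the same logic, using a made-up three-character vocabulary:

```python
import torch

# Made-up SPM order of the tokens that follow the 4 control symbols.
spm_order = ["c", "a", "b"]

# The fairseq dict sorts the tokens: "a" -> 4, "b" -> 5, "c" -> 6.
spm_to_dict = {
    ch: idx for idx, ch in zip(range(4, 4 + len(spm_order)), sorted(spm_order))
}

mapping = [0, 1, 2, 3] + [spm_to_dict[ch] for ch in spm_order]
assert mapping == [0, 1, 2, 3, 6, 4, 5]

# Reorder dict-ordered embedding rows into SPM order, as the loader does.
char_embeds = torch.arange(7, dtype=torch.float).unsqueeze(1)  # row i holds i
char_embeds[torch.arange(len(mapping))] = char_embeds[mapping]
assert char_embeds[4].item() == 6.0  # SPM token "c" now sits at model index 4
```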

+ 5 - 6
src/seamless_communication/models/unity/t2u_builder.py

@@ -4,7 +4,6 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from dataclasses import dataclass
-from torch.nn import Parameter
 from typing import Literal, Optional, Union
 
 from fairseq2.assets import download_manager
@@ -165,8 +164,8 @@ def _medium_t2u() -> UnitYT2UConfig:
     )
 
 
-@unity_t2u_arch("nar_multilingual")
-def _nar_multilingual_t2u() -> UnitYT2UConfig:
+@unity_t2u_arch("base_nar")
+def _base_nar() -> UnitYT2UConfig:
     duration_predictor_config = VariancePredictorConfig(
         var_pred_hidden_dim=256,
         var_pred_kernel_size=3,
@@ -181,8 +180,8 @@ def _nar_multilingual_t2u() -> UnitYT2UConfig:
     )
 
     nar_decoder_config = NARDecoderConfig(
-        model_name_or_card="unity_nar_multilingual",
-        char_vocabulary_size=10904,
+        model_name_or_card="seamlessM4T_v2_large",
+        char_vocabulary_size=10943,
         char_max_seq_len=4096,
         conv1d_kernel_size=7,
         conv1d_inner_dim=1024,
@@ -192,7 +191,7 @@ def _nar_multilingual_t2u() -> UnitYT2UConfig:
     return UnitYT2UConfig(
         model_dim=1024,
         unit_max_seq_len=2048,
-        unit_vocabulary_size=10020,
+        unit_vocabulary_size=10082,
         unit_pad_idx=1,
         num_encoder_layers=6,
         num_decoder_layers=6,
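
The NAR decoder config now points at the `seamlessM4T_v2_large` card, whose `char_tokenizer` field supplies the character SPM model consumed by `_get_char_index_mapping` above. A hedged sketch of that round trip; the assertion assumes the configured `char_vocabulary_size` (10943) matches the SPM model's size:

```python
from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
from seamless_communication.models.unity.t2u_builder import unity_t2u_archs

t2u_config = unity_t2u_archs.get_config("base_nar")
nar_config = t2u_config.nar_decoder_config

# The char tokenizer is resolved through the same asset card.
char_tokenizer = load_unity_char_tokenizer(nar_config.model_name_or_card)

assert char_tokenizer.model.vocabulary_size == nar_config.char_vocabulary_size
```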

+ 17 - 13
src/seamless_communication/models/unity/unit_tokenizer.py

@@ -32,15 +32,16 @@ class UnitTokenizer:
 
         self.langs = langs
 
-        self.model_arch = model_arch
-
         self.lang_map = {lang: idx for idx, lang in enumerate(langs)}
 
-        if self.model_arch == "nar_multilingual":
+        # The "_v2" unity architectures have a non-autoregressive decoder.
+        if model_arch.split("_")[-1] == "v2":
+            self.is_nar_decoder = True
             self.lang_symbol_repititions = 1
         else:
+            self.is_nar_decoder = False
             # For legacy reasons, we have to repeat the language symbols twice,
-            # along with a placeholder `<mask>` token for UnitY AR models.
+            # along with a placeholder `<mask>` token for UnitY autoregressive models.
             self.lang_symbol_repititions = 2
 
         vocab_size = num_units + self.lang_symbol_repititions * (len(langs) + 1) + 4
@@ -95,7 +96,7 @@ class UnitTokenizer:
 
     def create_decoder(self) -> "UnitTokenDecoder":
         """Create a token decoder."""
-        return UnitTokenDecoder(self, self.model_arch)
+        return UnitTokenDecoder(self, self.is_nar_decoder)
 
 
 class UnitTokenEncoder:
@@ -177,19 +178,20 @@ class UnitTokenDecoder:
     eos_idx: int
     pad_idx: int
 
-    def __init__(self, tokenizer: UnitTokenizer, model_arch: str) -> None:
+    def __init__(self, tokenizer: UnitTokenizer, is_nar_decoder: bool) -> None:
         """
         :param tokenizer:
             The unit tokenizer to use.
-        :param model_arch:
-            The type of UnitY model architecture.
+        :param is_nar_decoder:
+            If True, the unit decoder is non-autoregressive.
         """
         assert tokenizer.vocab_info.eos_idx is not None
         assert tokenizer.vocab_info.pad_idx is not None
 
         self.eos_idx = tokenizer.vocab_info.eos_idx
         self.pad_idx = tokenizer.vocab_info.pad_idx
-        self.model_arch = model_arch
+
+        self.is_nar_decoder = is_nar_decoder
 
     def __call__(self, token_indices: Tensor) -> Tensor:
         """Decode ``token_indices`` to speech units.
@@ -208,8 +210,9 @@ class UnitTokenDecoder:
 
         units = token_indices.clone().detach()
 
-        # Remove the prefix EOS symbol from the decoded output for AR UnitY.
-        if self.model_arch != "nar_multilingual":
+        # Remove the prefix EOS symbol from the decoded output for
+        # autoregressive UnitY.
+        if not self.is_nar_decoder:
             units = units[:, 1:]
 
         # Also, replace EOS with PAD at sequence ends.
@@ -217,10 +220,11 @@ class UnitTokenDecoder:
 
         units[units == self.pad_idx] = self.pad_idx + 4
 
-        # Remove offset of control symbols (exclude language symbol for AR UnitY).
-        if self.model_arch == "nar_multilingual":
+        # Remove offset of control symbols.
+        if self.is_nar_decoder:
             units -= 4
         else:
+            # Exclude language symbol for autoregressive UnitY.
             units[:, 1:] -= 4
 
         return units
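
With the arch-name check, any `*_v2` architecture is treated as non-autoregressive: a single set of language symbols in the unit vocabulary, and no prefix EOS to strip during decoding. A small sketch, assuming a `UnitTokenizer(num_units, langs, model_arch)` constructor and keeping the `lang_symbol_repititions` spelling used in the code above:

```python
from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer

nar_tok = UnitTokenizer(10000, ["eng", "fra"], "base_v2")
assert nar_tok.is_nar_decoder and nar_tok.lang_symbol_repititions == 1

# Legacy autoregressive archs keep the doubled language symbols.
ar_tok = UnitTokenizer(10000, ["eng", "fra"], "medium")
assert not ar_tok.is_nar_decoder and ar_tok.lang_symbol_repititions == 2
```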