Browse Source

Clean up M4T v2 and vocoder v2 checkpoints. (#94)

Kaushik Ram Sadagopan, 1 year ago
parent
commit
fcaf953981

+ 1 - 1
src/seamless_communication/cards/seamlessM4T_v2_large.yaml

@@ -8,7 +8,7 @@ name: seamlessM4T_v2_large
 base: unity_nllb-100
 model_arch: base_v2
 char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
-checkpoint: "file://large_experiments/seamless/ust/elbayadm/multitasking_models/m4t_v2_multitask_unity2.pt"
+checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/seamlessM4T_v2_large.pt"
 num_units: 10000
 unit_langs:
   - arb

+ 1 - 1
src/seamless_communication/cards/vocoder_v2.yaml

@@ -7,7 +7,7 @@
 name: vocoder_v2
 model_type: vocoder_code_hifigan
 model_arch: base
-checkpoint: "file://large_experiments/seamless/ust/krs/M4T_Vocoder/lang_36_commercial/km_10000/seed_1/g_00600000"
+checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/vocoder_v2.pt"
 model_config: {
   "lang_spkr_idx_map": {
       "multilingual": {

+ 1 - 1
src/seamless_communication/models/unity/loader.py

@@ -36,7 +36,7 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         state_dict = checkpoint["model"]
 
         # Check if we have a fairseq2 checkpoint.
-        if "decoder_frontend.embed.weight" in state_dict:
+        if "speech_encoder.inner.layers.0.self_attn_layer_norm.weight" in state_dict:
             return checkpoint
 
         key_map = self._fairseq_key_map(config)

+ 6 - 1
src/seamless_communication/models/vocoder/loader.py

@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-
 from typing import Any, Mapping, final
 
 from fairseq2.assets import asset_store, download_manager
@@ -29,6 +28,12 @@ class VocoderLoader(ModelLoader[Vocoder, VocoderConfig]):
     def _convert_checkpoint(
         self, checkpoint: Mapping[str, Any], config: VocoderConfig
     ) -> Mapping[str, Any]:
+        if (
+            "model" in checkpoint
+            and "code_generator.resblocks.0.convs1.0.weight_g" in checkpoint["model"]
+        ):
+            return checkpoint
+
         old_state_dict = checkpoint["generator"]
         new_state_dict = {}
         for key in old_state_dict: