Browse Source

Add seamless_streaming assets. (#106)

Kaushik Ram Sadagopan 1 year ago
parent
commit
5198e0586c

+ 10 - 0
src/seamless_communication/cards/seamless_streaming_monotonic_decoder.yaml

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: seamless_streaming_monotonic_decoder
+model_type: monotonic_decoder
+model_arch: dense_1b
+checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/seamless_streaming_monotonic_decoder.pt"

+ 51 - 0
src/seamless_communication/cards/seamless_streaming_unity.yaml

@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: seamless_streaming_unity
+base: unity_nllb-100
+model_arch: base_v2
+char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
+checkpoint: "file://large_experiments/seamless/ust/krs/fairseq2_checkpoints/seamless_streaming_unity.pt"
+num_units: 10000
+unit_langs:
+  - arb
+  - ben
+  - cat
+  - ces
+  - cmn
+  - cym
+  - dan
+  - deu
+  - eng
+  - est
+  - fin
+  - fra
+  - hin
+  - ind
+  - ita
+  - jpn
+  - kan
+  - kor
+  - mlt
+  - nld
+  - pes
+  - pol
+  - por
+  - ron
+  - rus
+  - slk
+  - spa
+  - swe
+  - swh
+  - tam
+  - tel
+  - tgl
+  - tha
+  - tur
+  - ukr
+  - urd
+  - uzn
+  - vie

+ 7 - 0
src/seamless_communication/inference/generator.py

@@ -92,6 +92,13 @@ class UnitYGenerator:
 
         self.model = model
 
+        if model.text_decoder is None:
+            raise ValueError(
+                "`UnitYGenerator` requires a text decoder, but the current UnitY model does not have one."
+            )
+        assert model.text_decoder_frontend is not None
+        assert model.final_proj is not None
+
         s2t_model = UnitYX2TModel(
             encoder_frontend=model.speech_encoder_frontend,
             encoder=model.speech_encoder,

+ 2 - 2
src/seamless_communication/models/monotonic_decoder/loader.py

@@ -28,7 +28,7 @@ from seamless_communication.models.monotonic_decoder.model import MonotonicDecod
 class MonotonicDecoderLoader(
     ModelLoader[MonotonicDecoderModel, MonotonicDecoderConfig]
 ):
-    """Loads NLLB models."""
+    """Loads Monotonic Decoder models."""
 
     @finaloverride
     def _convert_checkpoint(
@@ -37,7 +37,7 @@ class MonotonicDecoderLoader(
         state_dict = checkpoint["model"]
 
         # Check if we have a fairseq2 checkpoint.
-        if "decoder_frontend.embed_weight" in state_dict:
+        if "text_decoder.layers.0.self_attn.k_proj.weight" in state_dict:
             return checkpoint
 
         key_map = self._fairseq_key_map()

+ 3 - 0
src/seamless_communication/models/unity/__init__.py

@@ -40,6 +40,9 @@ from seamless_communication.models.unity.loader import UnitYLoader as UnitYLoade
 from seamless_communication.models.unity.loader import (
     load_gcmvn_stats as load_gcmvn_stats,
 )
+from seamless_communication.models.unity.loader import (
+    load_unity_config as load_unity_config,
+)
 from seamless_communication.models.unity.loader import (
     load_unity_model as load_unity_model,
 )