
Remove mBART support for now; UnitY supports only the NLLB text decoder.

Kaushik Ram Sadagopan · 2 years ago
commit 9b343874f8

+ 13 - 25
src/seamless_communication/models/unity/builder.py

@@ -5,11 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Optional
 
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
-from fairseq2.models.mbart import mBartBuilder, mBartConfig
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.models.wav2vec2 import Wav2Vec2EncoderBuilder, Wav2Vec2EncoderConfig
@@ -49,7 +48,7 @@ class UnitYConfig:
     w2v2_encoder_config: Wav2Vec2EncoderConfig
     """The configuration of the underlying wav2vec 2.0 encoder."""
 
-    mt_model_config: Union[NllbConfig, mBartConfig]
+    mt_model_config: NllbConfig
     """The configuration of the underlying MT text encoder-decoder."""
 
     t2u_config: Optional[UnitYT2UConfig]
@@ -172,7 +171,7 @@ class UnitYBuilder:
 
     config: UnitYConfig
     w2v2_encoder_builder: Wav2Vec2EncoderBuilder
-    mt_model_builder: Union[NllbBuilder, mBartBuilder]
+    mt_model_builder: NllbBuilder
     t2u_builder: Optional["UnitYT2UBuilder"]
     device: Optional[Device]
     dtype: Optional[DataType]
@@ -181,7 +180,7 @@ class UnitYBuilder:
         self,
         config: UnitYConfig,
         w2v2_encoder_builder: Wav2Vec2EncoderBuilder,
-        mt_model_builder: Union[NllbBuilder, mBartBuilder],
+        mt_model_builder: NllbBuilder,
         t2u_builder: Optional["UnitYT2UBuilder"],
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
@@ -384,25 +383,14 @@ def create_unity_model(
     else:
         t2u_builder = UnitYT2UBuilder(config.t2u_config, device=device, dtype=dtype)
 
-    if isinstance(config.mt_model_config, NllbConfig):
-        nllb_builder = NllbBuilder(config.mt_model_config, device=device, dtype=dtype)
-        unity_builder = UnitYBuilder(
-            config,
-            w2v2_encoder_builder,
-            nllb_builder,
-            t2u_builder,
-            device=device,
-            dtype=dtype,
-        )
-    else:
-        mbart_builder = mBartBuilder(config.mt_model_config, device=device, dtype=dtype)
-        unity_builder = UnitYBuilder(
-            config,
-            w2v2_encoder_builder,
-            mbart_builder,
-            t2u_builder,
-            device=device,
-            dtype=dtype,
-        )
+    nllb_builder = NllbBuilder(config.mt_model_config, device=device, dtype=dtype)
+    unity_builder = UnitYBuilder(
+        config,
+        w2v2_encoder_builder,
+        nllb_builder,
+        t2u_builder,
+        device=device,
+        dtype=dtype,
+    )
 
     return unity_builder.build_model()
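
For context, a hedged usage sketch of the simplified path: after this commit, create_unity_model() always wires an NllbBuilder, so the config's mt_model_config must be an NllbConfig. The unity_archs registry, the "base" arch name, and the get_config() call below are illustrative assumptions; only UnitYConfig, create_unity_model(), and the NllbBuilder wiring appear in this diff.

# Usage sketch (assumptions noted above): build a UnitY model through the
# single NLLB-only path that remains after this change.
import torch

from fairseq2.models.nllb import NllbConfig
from seamless_communication.models.unity.builder import (
    create_unity_model,
    unity_archs,  # assumed name of the ArchitectureRegistry for UnitYConfig
)

config = unity_archs.get_config("base")  # assumed arch name, for illustration only

# mt_model_config is now typed as NllbConfig; an mBartConfig is no longer accepted.
assert isinstance(config.mt_model_config, NllbConfig)

model = create_unity_model(config, device=torch.device("cpu"), dtype=torch.float32)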

+ 2 - 3
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -4,14 +4,13 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Optional, Tuple, Union, final
+from typing import List, Optional, Tuple, final
 
 from torch import Tensor
 from torch.nn import Dropout, Module, Parameter
 
 from fairseq2.data import VocabularyInfo
 from fairseq2.data.text import TextTokenizer
-from fairseq2.models.mbart.tokenizer import mBartTokenizer
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.nn.embedding import Embedding
 from fairseq2.nn.normalization import LayerNorm
@@ -88,7 +87,7 @@ class NARDecoderFrontend(Module):
         self,
         embed: Embedding,
         embed_char: Embedding,
-        text_tokenizer: Union[NllbTokenizer, mBartTokenizer],
+        text_tokenizer: NllbTokenizer,
         char_tokenizer: CharTokenizer,
         unit_pos_encoder: PositionEncoder,
         char_pos_encoder: PositionEncoder,

+ 31 - 70
src/seamless_communication/models/unity/t2u_builder.py

@@ -30,7 +30,6 @@ from fairseq2.nn.transformer import (
     create_default_sdpa,
 )
 from fairseq2.typing import DataType, Device
-from fairseq2.models.mbart.loader import mBartTokenizerLoader
 from fairseq2.models.transformer import (
     TransformerEmbeddingFrontend,
     TransformerFrontend,
@@ -69,7 +68,6 @@ class NARDecoderFrontendConfig:
 
 @dataclass
 class NARDecoderConfig:
-    text_tokenizer_type: Literal["nllb", "mbart"]
     model_name_or_card: Union[str, AssetCard]
     char_vocabulary_size: int
     char_max_seq_len: int
@@ -182,7 +180,6 @@ def _nar_multilingual_t2u() -> UnitYT2UConfig:
     )
 
     nar_decoder_config = NARDecoderConfig(
-        text_tokenizer_type="nllb",
         model_name_or_card="unity_nar_multilingual",
         char_vocabulary_size=10904,
         char_max_seq_len=4096,
@@ -367,76 +364,40 @@ class UnitYT2UBuilder:
             self.config.nar_decoder_frontend_config
         )
 
-        if self.config.nar_decoder_config.text_tokenizer_type == "nllb":
-            nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
-                self.config.nar_decoder_config.model_name_or_card
-            )
-            text_pad_idx = nllb_tokenizer.vocab_info.pad_idx
-
-            char_pos_encoder = SinusoidalPositionEncoder(
-                self.config.model_dim,
-                self.config.nar_decoder_config.char_max_seq_len,
-                _legacy_pad_idx=text_pad_idx,
-                device=self.device,
-                dtype=self.dtype,
-            )
-
-            embed_char = Embedding(
-                num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
-                embedding_dim=self.config.model_dim,
-                pad_idx=text_pad_idx,
-                scaled=True,
-                device=self.device,
-                dtype=self.dtype,
-            )
-
-            return NARDecoderFrontend(
-                embed_unit,
-                embed_char,
-                nllb_tokenizer,
-                char_tokenizer,
-                unit_pos_encoder,
-                char_pos_encoder,
-                variance_adaptor,
-                dropout_p=self.config.dropout_p,
-                device=self.device,
-                dtype=self.dtype,
-            )
+        nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
+            self.config.nar_decoder_config.model_name_or_card
+        )
+        text_pad_idx = nllb_tokenizer.vocab_info.pad_idx
 
-        else:
-            mbart_tokenizer = mBartTokenizerLoader(asset_store, download_manager)(
-                self.config.nar_decoder_config.model_name_or_card
-            )
-            text_pad_idx = mbart_tokenizer.vocab_info.pad_idx
+        char_pos_encoder = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.nar_decoder_config.char_max_seq_len,
+            _legacy_pad_idx=text_pad_idx,
+            device=self.device,
+            dtype=self.dtype,
+        )
 
-            char_pos_encoder = SinusoidalPositionEncoder(
-                self.config.model_dim,
-                self.config.nar_decoder_config.char_max_seq_len,
-                _legacy_pad_idx=text_pad_idx,
-                device=self.device,
-                dtype=self.dtype,
-            )
+        embed_char = Embedding(
+            num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
+            embedding_dim=self.config.model_dim,
+            pad_idx=text_pad_idx,
+            scaled=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
 
-            embed_char = Embedding(
-                num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
-                embedding_dim=self.config.model_dim,
-                pad_idx=text_pad_idx,
-                scaled=True,
-                device=self.device,
-                dtype=self.dtype,
-            )
-            return NARDecoderFrontend(
-                embed_unit,
-                embed_char,
-                mbart_tokenizer,
-                char_tokenizer,
-                unit_pos_encoder,
-                char_pos_encoder,
-                variance_adaptor,
-                dropout_p=self.config.dropout_p,
-                device=self.device,
-                dtype=self.dtype,
-            )
+        return NARDecoderFrontend(
+            embed_unit,
+            embed_char,
+            nllb_tokenizer,
+            char_tokenizer,
+            unit_pos_encoder,
+            char_pos_encoder,
+            variance_adaptor,
+            dropout_p=self.config.dropout_p,
+            device=self.device,
+            dtype=self.dtype,
+        )
 
     def build_decoder(self) -> TransformerDecoder:
         """Build a Transformer decoder."""