
Start using vocab_info instead of pad_idx (#76)

Can Balioglu · 1 year ago
commit e83a4de3af
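In broad strokes, this commit replaces the bare `pad_idx: Optional[int]` fields (and, on the t2u side, the separate `unit_vocabulary_size`) with a single fairseq2 `VocabularyInfo` value threaded through configs, models, and the generator. As a rough sketch of the shape the hunks below assume (field names inferred from the call sites in this commit; the actual fairseq2 definition may differ):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class VocabularyInfo:
        # Sketch only; inferred from VocabularyInfo(size=..., unk_idx=...,
        # bos_idx=..., eos_idx=..., pad_idx=...) calls in this commit.
        size: int                # vocabulary size, e.g. 256102 for NLLB-100
        unk_idx: Optional[int]   # index of the unknown-token symbol
        bos_idx: Optional[int]   # index of the beginning-of-sequence symbol
        eos_idx: Optional[int]   # index of the end-of-sequence symbol
        pad_idx: Optional[int]   # index of the pad symbol, previously passed around alone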

+ 5 - 4
src/seamless_communication/models/unity/builder.py

@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 from typing import Union, Optional
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
@@ -95,7 +96,7 @@ def _base() -> UnitYConfig:
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
 
-    mt_model_config.vocabulary_size = 256102  # NLLB-100
+    mt_model_config.vocab_info.size = 256102  # NLLB-100
 
     t2u_config = unity_t2u_archs.get_config("base")
 
@@ -120,7 +121,7 @@ def _medium() -> UnitYConfig:
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_600m")
 
-    mt_model_config.vocabulary_size = 256206  # NLLB-200
+    mt_model_config.vocab_info.size = 256206  # NLLB-200
 
     t2u_config = unity_t2u_archs.get_config("medium")
 
@@ -145,7 +146,7 @@ def _base_v2() -> UnitYConfig:
 
     mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
 
-    mt_model_config.vocabulary_size = 256102  # NLLB-100
+    mt_model_config.vocab_info.size = 256102  # NLLB-100
 
     mt_model_config.max_seq_len = 4096
 
@@ -261,7 +262,7 @@ class UnitYBuilder:
             text_decoder,
             final_proj,
             t2u_model,
-            self.config.mt_model_config.pad_idx,
+            self.config.mt_model_config.vocab_info,
         )
 
     def build_speech_encoder(self) -> TransformerEncoder:
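With the config switch above, overriding the text vocabulary size now goes through the nested `vocab_info` rather than a flat `vocabulary_size` field. A minimal sketch of the new override pattern, mirroring the hunks above (requires a fairseq2 install to run):

    from fairseq2.models.nllb import NllbConfig, nllb_archs

    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")

    # Before this commit: mt_model_config.vocabulary_size = 256102
    # After: the size lives on the nested VocabularyInfo.
    mt_model_config.vocab_info.size = 256102  # NLLB-100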

+ 4 - 3
src/seamless_communication/models/unity/generator.py

@@ -10,6 +10,7 @@ from typing import Optional, Tuple, List
 import torch
 
 from torch import Tensor
+from fairseq2.data import VocabularyInfo
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
     Seq2SeqGenerator,
@@ -94,7 +95,7 @@ class UnitYGenerator:
             decoder_frontend=model.text_decoder_frontend,
             decoder=model.text_decoder,
             final_proj=model.final_proj,
-            pad_idx=model.pad_idx,
+            target_vocab_info=model.target_vocab_info,
         )
         self.s2t_generator = SequenceToTextGenerator(
             s2t_model, text_tokenizer, target_lang, text_opts
@@ -111,7 +112,7 @@ class UnitYGenerator:
                 decoder_frontend=model.text_decoder_frontend,
                 decoder=model.text_decoder,
                 final_proj=model.final_proj,
-                pad_idx=model.pad_idx,
+                target_vocab_info=model.target_vocab_info,
             )
             self.t2t_generator = SequenceToTextGenerator(
                 t2t_model, text_tokenizer, target_lang, text_opts
@@ -237,7 +238,7 @@ class UnitYGenerator:
             unit_seqs = unit_decoder_output.logits.argmax(dim=2)
             # Apply the padding mask to the generated units.
             unit_seqs = apply_padding_mask(
-                unit_seqs, decoder_padding_mask, pad_value=unit_decoder_output.pad_idx
+                unit_seqs, decoder_padding_mask, unit_decoder_output.vocab_info.pad_idx
             )
 
         # Convert to speech units.
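For downstream callers, the practical consequence is that the padding index no longer lives on the model or on `SequenceModelOutput` directly; it hangs off the vocabulary info, as the `apply_padding_mask` call above shows. A minimal migration sketch (the helper name is hypothetical, not part of the commit):

    from typing import Optional

    def output_pad_idx(output) -> Optional[int]:
        # Hypothetical helper: `output` is a SequenceModelOutput.
        # Before this commit: return output.pad_idx
        return output.vocab_info.pad_idx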

+ 21 - 18
src/seamless_communication/models/unity/model.py

@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, final
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.encoder_decoder import EncoderDecoderModel, Seq2SeqDecoder
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.transformer.frontend import TransformerFrontend
@@ -41,7 +42,6 @@ class UnitYModel(EncoderDecoderModel):
     text_decoder: TransformerDecoder
     final_proj: Projection
     t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None]
-    pad_idx: Optional[int]
 
     def __init__(
         self,
@@ -53,12 +53,12 @@ class UnitYModel(EncoderDecoderModel):
         text_decoder: TransformerDecoder,
         final_proj: Projection,
         t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
-        pad_idx: Optional[int],
+        target_vocab_info: VocabularyInfo,
         input_modality: str = "speech",
     ) -> None:
         model_dim = speech_encoder.model_dim
 
-        super().__init__(model_dim)
+        super().__init__(model_dim, target_vocab_info)
 
         self.input_modality = input_modality
 
@@ -92,7 +92,7 @@ class UnitYModel(EncoderDecoderModel):
         else:
             self.register_module("t2u_model", None)
 
-        self.pad_idx = pad_idx
+        self.target_vocab_info = target_vocab_info
 
     @finaloverride
     def encode(
@@ -136,6 +136,7 @@ class UnitYModel(EncoderDecoderModel):
         padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
+        *,
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self.text_decoder_frontend(
@@ -156,7 +157,7 @@ class UnitYModel(EncoderDecoderModel):
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.pad_idx)
+        return SequenceModelOutput(logits, self.target_vocab_info)
 
 
 @final
@@ -167,7 +168,6 @@ class UnitYX2TModel(EncoderDecoderModel):
     decoder_frontend: TransformerFrontend
     decoder: TransformerDecoder
     final_proj: Projection
-    pad_idx: Optional[int]
 
     def __init__(
         self,
@@ -176,17 +176,18 @@ class UnitYX2TModel(EncoderDecoderModel):
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         final_proj: Projection,
-        pad_idx: Optional[int],
+        target_vocab_info: VocabularyInfo,
     ) -> None:
         model_dim = encoder.model_dim
-        super().__init__(model_dim)
+
+        super().__init__(model_dim, target_vocab_info)
 
         self.encoder_frontend = encoder_frontend
         self.encoder = encoder
         self.decoder_frontend = decoder_frontend
         self.decoder = decoder
         self.final_proj = final_proj
-        self.pad_idx = pad_idx
+        self.target_vocab_info = target_vocab_info
 
     @finaloverride
     def encode(
@@ -202,6 +203,7 @@ class UnitYX2TModel(EncoderDecoderModel):
         padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
+        *,
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self.decoder_frontend(
@@ -222,7 +224,7 @@ class UnitYX2TModel(EncoderDecoderModel):
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.pad_idx)
+        return SequenceModelOutput(logits, self.target_vocab_info)
 
 
 @final
@@ -235,7 +237,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     decoder_frontend: TransformerFrontend
     decoder: TransformerDecoder
     final_proj: Projection
-    pad_idx: Optional[int]
+    target_vocab_info: VocabularyInfo
 
     def __init__(
         self,
@@ -243,7 +245,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         final_proj: Projection,
-        pad_idx: Optional[int],
+        target_vocab_info: VocabularyInfo,
     ) -> None:
         super().__init__()
 
@@ -269,7 +271,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
 
         self.final_proj = final_proj
 
-        self.pad_idx = pad_idx
+        self.target_vocab_info = target_vocab_info
 
     def forward(
         self,
@@ -307,6 +309,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
         padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
+        *,
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self.decoder_frontend(
@@ -326,7 +329,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.pad_idx)
+        return SequenceModelOutput(logits, self.target_vocab_info)
 
 
 @final
@@ -338,7 +341,7 @@ class UnitYNART2UModel(Module):
     decoder_frontend: NARDecoderFrontend
     decoder: NARTransformerDecoder
     final_proj: Projection
-    pad_idx: Optional[int]
+    target_vocab_info: VocabularyInfo
 
     def __init__(
         self,
@@ -346,7 +349,7 @@ class UnitYNART2UModel(Module):
         decoder_frontend: NARDecoderFrontend,
         decoder: NARTransformerDecoder,
         final_proj: Projection,
-        pad_idx: Optional[int],
+        target_vocab_info: VocabularyInfo,
     ) -> None:
         super().__init__()
 
@@ -372,7 +375,7 @@ class UnitYNART2UModel(Module):
 
         self.final_proj = final_proj
 
-        self.pad_idx = pad_idx
+        self.target_vocab_info = target_vocab_info
 
     def forward(
         self,
@@ -419,7 +422,7 @@ class UnitYNART2UModel(Module):
     def project(self, decoder_output: Tensor) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.pad_idx)
+        return SequenceModelOutput(logits, self.target_vocab_info)
 
 
 @dataclass
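Note that besides the vocab plumbing, the `decode` signatures above gain a bare `*`, which makes `state_bag` keyword-only. A sketch of a conforming call, assuming the tensors, masks, and state bag are already in scope:

    # state_bag must now be passed by keyword (note the `*` added to decode):
    seqs, padding_mask = model.decode(
        target_seqs,
        target_padding_mask,
        encoder_output,
        encoder_padding_mask,
        state_bag=state_bag,
    )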

+ 19 - 19
src/seamless_communication/models/unity/t2u_builder.py

@@ -88,11 +88,8 @@ class UnitYT2UConfig:
     unit_max_seq_len: int
     """The expected maximum unit sequence length."""
 
-    unit_vocabulary_size: int
-    """The size of the unit vocabulary."""
-
-    unit_pad_idx: Optional[int]
-    """The index of the pad symbol in the unit vocabulary."""
+    target_vocab_info: VocabularyInfo
+    """The target vocabulary information."""
 
     num_encoder_layers: int
     """The number of Transformer encoder layers."""
@@ -134,8 +131,9 @@ def _base_t2u() -> UnitYT2UConfig:
     return UnitYT2UConfig(
         model_dim=1024,
         unit_max_seq_len=2048,
-        unit_vocabulary_size=10082,
-        unit_pad_idx=1,
+        target_vocab_info=VocabularyInfo(
+            size=10082, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
         num_encoder_layers=6,
         num_decoder_layers=6,
         nar_decoder_frontend_config=None,
@@ -152,8 +150,9 @@ def _medium_t2u() -> UnitYT2UConfig:
     return UnitYT2UConfig(
         model_dim=1024,
         unit_max_seq_len=2048,
-        unit_vocabulary_size=10082,
-        unit_pad_idx=1,
+        target_vocab_info=VocabularyInfo(
+            size=10082, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
         num_encoder_layers=4,
         num_decoder_layers=4,
         nar_decoder_frontend_config=None,
@@ -192,8 +191,9 @@ def _base_nar() -> UnitYT2UConfig:
     return UnitYT2UConfig(
         model_dim=1024,
         unit_max_seq_len=4096,
-        unit_vocabulary_size=10082,
-        unit_pad_idx=1,
+        target_vocab_info=VocabularyInfo(
+            size=10082, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
         num_encoder_layers=6,
         num_decoder_layers=6,
         nar_decoder_frontend_config=nar_decoder_frontend_config,
@@ -253,16 +253,16 @@ class UnitYT2UBuilder:
             decoder_frontend,
             decoder,
             final_proj,
-            self.config.unit_pad_idx,
+            self.config.target_vocab_info,
         )
 
     def build_unit_embedding(self) -> StandardEmbedding:
         """Build a unit embedding table."""
 
         return StandardEmbedding(
-            num_embeddings=self.config.unit_vocabulary_size,
+            num_embeddings=self.config.target_vocab_info.size,
             embedding_dim=self.config.model_dim,
-            pad_idx=self.config.unit_pad_idx,
+            pad_idx=self.config.target_vocab_info.pad_idx,
             init_fn=init_scaled_embedding,
             device=self.device,
             dtype=self.dtype,
@@ -306,7 +306,7 @@ class UnitYT2UBuilder:
         pos_encoder = SinusoidalPositionEncoder(
             self.config.model_dim,
             self.config.unit_max_seq_len,
-            _legacy_pad_idx=self.config.unit_pad_idx,
+            _legacy_pad_idx=self.config.target_vocab_info.pad_idx,
             device=self.device,
         )
         return TransformerEmbeddingFrontend(
@@ -424,16 +424,16 @@ class UnitYNART2UBuilder:
             decoder_frontend,
             decoder,
             final_proj,
-            self.config.unit_pad_idx,
+            self.config.target_vocab_info,
         )
 
     def build_unit_embedding(self) -> StandardEmbedding:
         """Build a unit embedding table."""
 
         return StandardEmbedding(
-            num_embeddings=self.config.unit_vocabulary_size,
+            num_embeddings=self.config.target_vocab_info.size,
             embedding_dim=self.config.model_dim,
-            pad_idx=self.config.unit_pad_idx,
+            pad_idx=self.config.target_vocab_info.pad_idx,
             init_fn=init_scaled_embedding,
             device=self.device,
             dtype=self.dtype,
@@ -505,7 +505,7 @@ class UnitYNART2UBuilder:
         unit_pos_encoder = SinusoidalPositionEncoder(
             self.config.model_dim,
             self.config.unit_max_seq_len,
-            _legacy_pad_idx=self.config.unit_pad_idx,
+            _legacy_pad_idx=self.config.target_vocab_info.pad_idx,
             device=self.device,
         )
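On the t2u side, the two separate config fields `unit_vocabulary_size` and `unit_pad_idx` collapse into one `target_vocab_info`, which both the unit embedding table and the legacy positional encoder now read from. A sketch using the exact indices from the configs above:

    from fairseq2.data import VocabularyInfo

    unit_vocab_info = VocabularyInfo(
        size=10082, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
    )

    # Values the builders now derive from the single source of truth:
    num_embeddings = unit_vocab_info.size        # was config.unit_vocabulary_size
    embedding_pad_idx = unit_vocab_info.pad_idx  # was config.unit_pad_idx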