Browse Source

Fix unit_tokenizer to take in model_arch.

Kaushik Ram Sadagopan 2 years ago
parent
commit
210a4ebe38

+ 0 - 2
src/seamless_communication/models/unity/generator.py

@@ -16,12 +16,10 @@ from fairseq2.generation import (
     SequenceToTextGenerator,
     SequenceToTextOutput,
 )
-from fairseq2.models.seq2seq import Seq2SeqBatch
 from seamless_communication.models.unity.model import (
     UnitYModel,
     UnitYX2TModel,
     UnitYT2UModel,
-    UnitYNART2UModel,
 )
 from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenDecoder,

+ 3 - 1
src/seamless_communication/models/unity/loader.py

@@ -311,7 +311,9 @@ class UnitYUnitTokenizerLoader:
             card = self.asset_store.retrieve_card(model_name_or_card)
 
         return UnitTokenizer(
-            card.field("num_units").as_(int), card.field("unit_langs").as_list(str)
+            card.field("num_units").as_(int),
+            card.field("unit_langs").as_list(str),
+            card.field("model_arch").as_(str),
         )
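A minimal sketch of what the loader now produces, with hypothetical values standing in for the three card fields read above:

from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer

# Illustrative values; the real ones come from the asset card.
tokenizer = UnitTokenizer(
    num_units=10000,
    langs=["eng", "fra"],
    model_arch="nar_multilingual",
)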
 
 

+ 2 - 0
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -148,6 +148,8 @@ class NARDecoderFrontend(Module):
             self.register_module("dropout", None)
 
     def indices_to_subwords(self, text_seqs: Tensor) -> List[List[str]]:
+        # TODO: To be replaced with fairseq2's indices_to_tokens SPM model method
+        # once implemented.
         N, seq_len = text_seqs.shape
         subwords_batch = []
         for b in range(N):
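Until that fairseq2 method lands, a standalone sketch of the idea, assuming the text indices map directly to SentencePiece ids (the real frontend may apply control-symbol offsets):

import sentencepiece as spm
import torch

def indices_to_subwords(sp: spm.SentencePieceProcessor, text_seqs: torch.Tensor) -> list[list[str]]:
    # Convert each row of token indices back to its subword strings.
    N, seq_len = text_seqs.shape
    return [
        [sp.id_to_piece(int(text_seqs[b, i])) for i in range(seq_len)]
        for b in range(N)
    ]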

+ 33 - 11
src/seamless_communication/models/unity/unit_tokenizer.py

@@ -19,22 +19,31 @@ class UnitTokenizer:
     langs: Sequence[str]
     lang_map: Dict[str, int]
 
-    def __init__(self, num_units: int, langs: Sequence[str]) -> None:
+    def __init__(self, num_units: int, langs: Sequence[str], model_arch: str) -> None:
         """
         :param num_units:
             The number of speech units.
         :param langs:
             The list of supported languages.
+        :param model_arch:
+            The type of UnitY model architecture.
         """
         self.num_units = num_units
 
         self.langs = langs
 
+        self.model_arch = model_arch
+
         self.lang_map = {lang: idx for idx, lang in enumerate(langs)}
 
-        # For legacy reasons, we have to repeat the language symbols twice,
-        # along with a placeholder `<mask>` token.
-        vocab_size = num_units + (len(langs) + 1) + 4
+        if self.model_arch == "nar_multilingual":
+            self.lang_symbol_repititions = 1
+        else:
+            # For legacy reasons, we have to repeat the language symbols twice,
+            # along with a placeholder `<mask>` token for UnitY AR models.
+            self.lang_symbol_repititions = 2
+
+        vocab_size = num_units + self.lang_symbol_repititions * (len(langs) + 1) + 4
 
         # We use fairseq's control symbol order.
         self.vocab_info = VocabularyInfo(
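As a worked example with a hypothetical card declaring num_units=10000 and 36 unit languages: the NAR layout keeps a single language block, giving 10000 + 1 * (36 + 1) + 4 = 10041 symbols, while the legacy AR layout repeats it, giving 10000 + 2 * (36 + 1) + 4 = 10078.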
@@ -45,7 +54,12 @@ class UnitTokenizer:
         """Return the symbol index of the specified language."""
         # +4 for PAD/EOS/BOS/UNK, and +1 for the `<mask>` token.
         try:
-            return self.num_units + self.lang_map[lang] + 5
+            return (
+                self.num_units
+                + (self.lang_symbol_repititions - 1) * len(self.langs)
+                + self.lang_map[lang]
+                + 5
+            )
         except KeyError:
             langs = ", ".join(self.langs)
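Continuing the hypothetical example (num_units=10000, langs=["eng", "fra"], so lang_map["fra"] == 1): a NAR tokenizer returns 10000 + 0 * 2 + 1 + 5 = 10006 for "fra", while an AR tokenizer skips past the first copy of the language block and returns 10000 + 1 * 2 + 1 + 5 = 10008.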
 
@@ -76,7 +90,7 @@ class UnitTokenizer:
 
     def create_decoder(self) -> "UnitTokenDecoder":
         """Create a token decoder."""
-        return UnitTokenDecoder(self)
+        return UnitTokenDecoder(self, self.model_arch)
 
 
 class UnitTokenEncoder:
@@ -158,16 +172,19 @@ class UnitTokenDecoder:
     eos_idx: int
     pad_idx: int
 
-    def __init__(self, tokenizer: UnitTokenizer) -> None:
+    def __init__(self, tokenizer: UnitTokenizer, model_arch: str) -> None:
         """
         :param tokenizer:
             The unit tokenizer to use.
+        :param model_arch:
+            The type of UnitY model architecture.
         """
         assert tokenizer.vocab_info.eos_idx is not None
         assert tokenizer.vocab_info.pad_idx is not None
 
         self.eos_idx = tokenizer.vocab_info.eos_idx
         self.pad_idx = tokenizer.vocab_info.pad_idx
+        self.model_arch = model_arch
 
     def __call__(self, token_indices: Tensor) -> Tensor:
         """Decode ``token_indices`` to speech units.
@@ -184,16 +201,21 @@ class UnitTokenDecoder:
         if token_indices.size(1) == 0:
             return token_indices
 
-        # Remove the prefix EOS symbol. The language symbol is still expected to
-        # be part of the decoded output.
         units = token_indices.clone().detach()
 
+        # Remove the prefix EOS symbol from the decoded output for AR UnitY.
+        if self.model_arch != "nar_multilingual":
+            units = units[:, 1:]
+
         # Also, replace EOS with PAD at sequence ends.
         units[units == self.eos_idx] = self.pad_idx
 
         units[units == self.pad_idx] = self.pad_idx + 4
 
-        # Remove offset of control symbols (exclude language symbol).
-        units -= 4
+        # Remove offset of control symbols (exclude language symbol for AR UnitY).
+        if self.model_arch == "nar_multilingual":
+            units -= 4
+        else:
+            units[:, 1:] -= 4
 
         return units
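A minimal sketch of the new decoding branch, assuming the fairseq control-symbol order referenced above puts PAD at index 1 and EOS at index 2, so a raw unit u is stored at u + 4; the token values are made up for illustration:

import torch

from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer

tokenizer = UnitTokenizer(num_units=10000, langs=["eng", "fra"], model_arch="nar_multilingual")
decoder = tokenizer.create_decoder()

# NAR output: no prefix EOS, no language column; units 7, 8, 9 then EOS, PAD.
tokens = torch.tensor([[11, 12, 13, 2, 1]])
print(decoder(tokens))  # tensor([[7, 8, 9, 1, 1]]) -- trailing EOS/PAD collapse to pad_idx

For an AR tokenizer the same call would first drop the prefix EOS column and leave the leading language symbol un-shifted.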