|
|
@@ -19,22 +19,31 @@ class UnitTokenizer:
|
|
|
langs: Sequence[str]
|
|
|
lang_map: Dict[str, int]
|
|
|
|
|
|
- def __init__(self, num_units: int, langs: Sequence[str]) -> None:
|
|
|
+ def __init__(self, num_units: int, langs: Sequence[str], model_arch: str) -> None:
|
|
|
"""
|
|
|
:param num_units:
|
|
|
The number of speech units.
|
|
|
:param langs:
|
|
|
The list of supported languages.
|
|
|
+ :param model_arch:
|
|
|
+ The type of UnitY model architecture.
|
|
|
"""
|
|
|
self.num_units = num_units
|
|
|
|
|
|
self.langs = langs
|
|
|
|
|
|
+ self.model_arch = model_arch
|
|
|
+
|
|
|
self.lang_map = {lang: idx for idx, lang in enumerate(langs)}
|
|
|
|
|
|
- # For legacy reasons, we have to repeat the language symbols twice,
|
|
|
- # along with a placeholder `<mask>` token.
|
|
|
- vocab_size = num_units + (len(langs) + 1) + 4
|
|
|
+ if self.model_arch == "nar_multilingual":
|
|
|
+ self.lang_symbol_repititions = 1
|
|
|
+ else:
|
|
|
+ # For legacy reasons, we have to repeat the language symbols twice,
|
|
|
+ # along with a placeholder `<mask>` token for UnitY AR models.
|
|
|
+ self.lang_symbol_repititions = 2
|
|
|
+
|
|
|
+ vocab_size = num_units + self.lang_symbol_repititions * (len(langs) + 1) + 4
|
|
|
|
|
|
# We use fairseq's control symbol order.
|
|
|
self.vocab_info = VocabularyInfo(
|
|
|
@@ -45,7 +54,12 @@ class UnitTokenizer:
|
|
|
"""Return the symbol index of the specified language."""
|
|
|
# +4 for PAD/EOS/BOS/UNK, and +1 for the `<mask>` token.
|
|
|
try:
|
|
|
- return self.num_units + self.lang_map[lang] + 5
|
|
|
+ return (
|
|
|
+ self.num_units
|
|
|
+ + (self.lang_symbol_repititions - 1) * len(self.langs)
|
|
|
+ + self.lang_map[lang]
|
|
|
+ + 5
|
|
|
+ )
|
|
|
except KeyError:
|
|
|
langs = ", ".join(self.langs)
|
|
|
|
|
|
@@ -76,7 +90,7 @@ class UnitTokenizer:
|
|
|
|
|
|
def create_decoder(self) -> "UnitTokenDecoder":
|
|
|
"""Create a token decoder."""
|
|
|
- return UnitTokenDecoder(self)
|
|
|
+ return UnitTokenDecoder(self, self.model_arch)
|
|
|
|
|
|
|
|
|
class UnitTokenEncoder:
|
|
|
@@ -158,16 +172,19 @@ class UnitTokenDecoder:
|
|
|
eos_idx: int
|
|
|
pad_idx: int
|
|
|
|
|
|
- def __init__(self, tokenizer: UnitTokenizer) -> None:
|
|
|
+ def __init__(self, tokenizer: UnitTokenizer, model_arch: str) -> None:
|
|
|
"""
|
|
|
:param tokenizer:
|
|
|
The unit tokenizer to use.
|
|
|
+ :param model_arch:
|
|
|
+ The type of UnitY model architecture.
|
|
|
"""
|
|
|
assert tokenizer.vocab_info.eos_idx is not None
|
|
|
assert tokenizer.vocab_info.pad_idx is not None
|
|
|
|
|
|
self.eos_idx = tokenizer.vocab_info.eos_idx
|
|
|
self.pad_idx = tokenizer.vocab_info.pad_idx
|
|
|
+ self.model_arch = model_arch
|
|
|
|
|
|
def __call__(self, token_indices: Tensor) -> Tensor:
|
|
|
"""Decode ``token_indices`` to speech units.
|
|
|
@@ -184,16 +201,21 @@ class UnitTokenDecoder:
|
|
|
if token_indices.size(1) == 0:
|
|
|
return token_indices
|
|
|
|
|
|
- # Remove the prefix EOS symbol. The language symbol is still expected to
|
|
|
- # be part of the decoded output.
|
|
|
units = token_indices.clone().detach()
|
|
|
|
|
|
+ # Remove the prefix EOS symbol from the decoded output for AR UnitY.
|
|
|
+ if self.model_arch != "nar_multilingual":
|
|
|
+ units = units[:, 1:]
|
|
|
+
|
|
|
# Also, replace EOS with PAD at sequence ends.
|
|
|
units[units == self.eos_idx] = self.pad_idx
|
|
|
|
|
|
units[units == self.pad_idx] = self.pad_idx + 4
|
|
|
|
|
|
- # Remove offset of control symbols (exclude language symbol).
|
|
|
- units -= 4
|
|
|
+ # Remove the offset of the control symbols (for AR UnitY, the leading language symbol is left unshifted).
|
|
|
+ if self.model_arch == "nar_multilingual":
|
|
|
+ units -= 4
|
|
|
+ else:
|
|
|
+ units[:, 1:] -= 4
|
|
|
|
|
|
return units
|