|
|
@@ -30,7 +30,6 @@ from fairseq2.nn.transformer import (
|
|
|
create_default_sdpa,
|
|
|
)
|
|
|
from fairseq2.typing import DataType, Device
|
|
|
-from fairseq2.models.mbart.loader import mBartTokenizerLoader
|
|
|
from fairseq2.models.transformer import (
|
|
|
TransformerEmbeddingFrontend,
|
|
|
TransformerFrontend,
|
|
|
@@ -69,7 +68,6 @@ class NARDecoderFrontendConfig:
|
|
|
|
|
|
@dataclass
|
|
|
class NARDecoderConfig:
|
|
|
- text_tokenizer_type: Literal["nllb", "mbart"]
|
|
|
model_name_or_card: Union[str, AssetCard]
|
|
|
char_vocabulary_size: int
|
|
|
char_max_seq_len: int
|
|
|
@@ -182,7 +180,6 @@ def _nar_multilingual_t2u() -> UnitYT2UConfig:
|
|
|
)
|
|
|
|
|
|
nar_decoder_config = NARDecoderConfig(
|
|
|
- text_tokenizer_type="nllb",
|
|
|
model_name_or_card="unity_nar_multilingual",
|
|
|
char_vocabulary_size=10904,
|
|
|
char_max_seq_len=4096,
|
|
|
@@ -367,76 +364,40 @@ class UnitYT2UBuilder:
|
|
|
self.config.nar_decoder_frontend_config
|
|
|
)
|
|
|
|
|
|
- if self.config.nar_decoder_config.text_tokenizer_type == "nllb":
|
|
|
- nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
|
|
|
- self.config.nar_decoder_config.model_name_or_card
|
|
|
- )
|
|
|
- text_pad_idx = nllb_tokenizer.vocab_info.pad_idx
|
|
|
-
|
|
|
- char_pos_encoder = SinusoidalPositionEncoder(
|
|
|
- self.config.model_dim,
|
|
|
- self.config.nar_decoder_config.char_max_seq_len,
|
|
|
- _legacy_pad_idx=text_pad_idx,
|
|
|
- device=self.device,
|
|
|
- dtype=self.dtype,
|
|
|
- )
|
|
|
-
|
|
|
- embed_char = Embedding(
|
|
|
- num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
|
|
|
- embedding_dim=self.config.model_dim,
|
|
|
- pad_idx=text_pad_idx,
|
|
|
- scaled=True,
|
|
|
- device=self.device,
|
|
|
- dtype=self.dtype,
|
|
|
- )
|
|
|
-
|
|
|
- return NARDecoderFrontend(
|
|
|
- embed_unit,
|
|
|
- embed_char,
|
|
|
- nllb_tokenizer,
|
|
|
- char_tokenizer,
|
|
|
- unit_pos_encoder,
|
|
|
- char_pos_encoder,
|
|
|
- variance_adaptor,
|
|
|
- dropout_p=self.config.dropout_p,
|
|
|
- device=self.device,
|
|
|
- dtype=self.dtype,
|
|
|
- )
|
|
|
+ nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
|
|
|
+ self.config.nar_decoder_config.model_name_or_card
|
|
|
+ )
|
|
|
+ text_pad_idx = nllb_tokenizer.vocab_info.pad_idx
|
|
|
|
|
|
- else:
|
|
|
- mbart_tokenizer = mBartTokenizerLoader(asset_store, download_manager)(
|
|
|
- self.config.nar_decoder_config.model_name_or_card
|
|
|
- )
|
|
|
- text_pad_idx = mbart_tokenizer.vocab_info.pad_idx
|
|
|
+ char_pos_encoder = SinusoidalPositionEncoder(
|
|
|
+ self.config.model_dim,
|
|
|
+ self.config.nar_decoder_config.char_max_seq_len,
|
|
|
+ _legacy_pad_idx=text_pad_idx,
|
|
|
+ device=self.device,
|
|
|
+ dtype=self.dtype,
|
|
|
+ )
|
|
|
|
|
|
- char_pos_encoder = SinusoidalPositionEncoder(
|
|
|
- self.config.model_dim,
|
|
|
- self.config.nar_decoder_config.char_max_seq_len,
|
|
|
- _legacy_pad_idx=text_pad_idx,
|
|
|
- device=self.device,
|
|
|
- dtype=self.dtype,
|
|
|
- )
|
|
|
+ embed_char = Embedding(
|
|
|
+ num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
|
|
|
+ embedding_dim=self.config.model_dim,
|
|
|
+ pad_idx=text_pad_idx,
|
|
|
+ scaled=True,
|
|
|
+ device=self.device,
|
|
|
+ dtype=self.dtype,
|
|
|
+ )
|
|
|
|
|
|
- embed_char = Embedding(
|
|
|
- num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
|
|
|
- embedding_dim=self.config.model_dim,
|
|
|
- pad_idx=text_pad_idx,
|
|
|
- scaled=True,
|
|
|
- device=self.device,
|
|
|
- dtype=self.dtype,
|
|
|
- )
|
|
|
- return NARDecoderFrontend(
|
|
|
- embed_unit,
|
|
|
- embed_char,
|
|
|
- mbart_tokenizer,
|
|
|
- char_tokenizer,
|
|
|
- unit_pos_encoder,
|
|
|
- char_pos_encoder,
|
|
|
- variance_adaptor,
|
|
|
- dropout_p=self.config.dropout_p,
|
|
|
- device=self.device,
|
|
|
- dtype=self.dtype,
|
|
|
- )
|
|
|
+ return NARDecoderFrontend(
|
|
|
+ embed_unit,
|
|
|
+ embed_char,
|
|
|
+ nllb_tokenizer,
|
|
|
+ char_tokenizer,
|
|
|
+ unit_pos_encoder,
|
|
|
+ char_pos_encoder,
|
|
|
+ variance_adaptor,
|
|
|
+ dropout_p=self.config.dropout_p,
|
|
|
+ device=self.device,
|
|
|
+ dtype=self.dtype,
|
|
|
+ )
|
|
|
|
|
|
def build_decoder(self) -> TransformerDecoder:
|
|
|
"""Build a Transformer decoder."""
|