|
@@ -11,7 +11,7 @@ from fairseq2.assets import download_manager
|
|
|
from fairseq2.assets.card import AssetCard
|
|
|
from fairseq2.data import VocabularyInfo
|
|
|
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
|
|
|
-from fairseq2.nn.embedding import Embedding, StandardEmbedding
|
|
|
+from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
|
|
|
from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
|
|
|
from fairseq2.nn.projection import TiedProjection
|
|
|
from fairseq2.nn.transformer import (
|
|
@@ -242,8 +242,6 @@ class UnitYT2UBuilder:
|
|
|
|
|
|
decoder = self.build_decoder()
|
|
|
|
|
|
- assert isinstance(embed_unit.weight, Parameter)
|
|
|
-
|
|
|
final_proj = TiedProjection(embed_unit.weight, bias=None)
|
|
|
|
|
|
if self.config.nar_decoder_config is None:
|
|
@@ -265,13 +263,13 @@ class UnitYT2UBuilder:
|
|
|
self.config.unit_pad_idx,
|
|
|
)
|
|
|
|
|
|
- def build_unit_embedding(self) -> Embedding:
|
|
|
+ def build_unit_embedding(self) -> StandardEmbedding:
|
|
|
"""Build a unit embedding table."""
|
|
|
return StandardEmbedding(
|
|
|
num_embeddings=self.config.unit_vocabulary_size,
|
|
|
embedding_dim=self.config.model_dim,
|
|
|
pad_idx=self.config.unit_pad_idx,
|
|
|
- scaled=True,
|
|
|
+ init_fn=init_scaled_embedding,
|
|
|
device=self.device,
|
|
|
dtype=self.dtype,
|
|
|
)
|
|
@@ -381,7 +379,7 @@ class UnitYT2UBuilder:
|
|
|
num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
|
|
|
embedding_dim=self.config.model_dim,
|
|
|
pad_idx=text_pad_idx,
|
|
|
- scaled=True,
|
|
|
+ init_fn=init_scaled_embedding,
|
|
|
device=self.device,
|
|
|
dtype=self.dtype,
|
|
|
)
|