Эх сурвалжийг харах

Use embed.init_fn instead of scaled=True (#44)

Can Balioglu 1 жил өмнө
parent
commit
9f142ac30e

+ 0 - 2
src/seamless_communication/models/unity/builder.py

@@ -293,8 +293,6 @@ class UnitYBuilder:
             text_encoder_frontend = None
             text_encoder = None
 
-        assert isinstance(text_embed.weight, Parameter)
-
         final_proj = TiedProjection(text_embed.weight, bias=None)
 
         if self.t2u_builder is None:

+ 4 - 6
src/seamless_communication/models/unity/t2u_builder.py

@@ -11,7 +11,7 @@ from fairseq2.assets import download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import VocabularyInfo
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.nn.embedding import Embedding, StandardEmbedding
+from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import TiedProjection
 from fairseq2.nn.transformer import (
@@ -242,8 +242,6 @@ class UnitYT2UBuilder:
 
         decoder = self.build_decoder()
 
-        assert isinstance(embed_unit.weight, Parameter)
-
         final_proj = TiedProjection(embed_unit.weight, bias=None)
 
         if self.config.nar_decoder_config is None:
@@ -265,13 +263,13 @@ class UnitYT2UBuilder:
                 self.config.unit_pad_idx,
             )
 
-    def build_unit_embedding(self) -> Embedding:
+    def build_unit_embedding(self) -> StandardEmbedding:
         """Build a unit embedding table."""
         return StandardEmbedding(
             num_embeddings=self.config.unit_vocabulary_size,
             embedding_dim=self.config.model_dim,
             pad_idx=self.config.unit_pad_idx,
-            scaled=True,
+            init_fn=init_scaled_embedding,
             device=self.device,
             dtype=self.dtype,
         )
@@ -381,7 +379,7 @@ class UnitYT2UBuilder:
             num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
             embedding_dim=self.config.model_dim,
             pad_idx=text_pad_idx,
-            scaled=True,
+            init_fn=init_scaled_embedding,
             device=self.device,
             dtype=self.dtype,
         )