Selaa lähdekoodia

Merge pull request #40 from fairinternal/embedding

Rename Embedding to StandardEmbedding
Kaushik Ram Sadagopan 2 vuotta sitten
vanhempi
commit
fbeabde759
1 muutettua tiedostoa jossa 3 lisäystä ja 3 poistoa
  1. 3 3
      src/seamless_communication/models/unity/t2u_builder.py

+ 3 - 3
src/seamless_communication/models/unity/t2u_builder.py

@@ -10,7 +10,7 @@ from fairseq2.assets import download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import VocabularyInfo
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.nn.embedding import Embedding
+from fairseq2.nn.embedding import Embedding, StandardEmbedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import TiedProjection
 from fairseq2.nn.transformer import (
@@ -264,7 +264,7 @@ class UnitYT2UBuilder:
 
     def build_unit_embedding(self) -> Embedding:
         """Build a unit embedding table."""
-        return Embedding(
+        return StandardEmbedding(
             num_embeddings=self.config.unit_vocabulary_size,
             embedding_dim=self.config.model_dim,
             pad_idx=self.config.unit_pad_idx,
@@ -374,7 +374,7 @@ class UnitYT2UBuilder:
             device=self.device,
         )
 
-        embed_char = Embedding(
+        embed_char = StandardEmbedding(
             num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
             embedding_dim=self.config.model_dim,
             pad_idx=text_pad_idx,