Browse source code

Return durations from the variance adaptor. (#134)

Kaushik Ram Sadagopan, 1 year ago
parent
commit
7537081d50

+ 1 - 1
src/seamless_communication/inference/generator.py

@@ -252,7 +252,7 @@ class UnitYGenerator:
             )
             unit_seqs, _ = unit_gen_output.collate()
         else:
-            unit_decoder_output, decoder_padding_mask = self.model.t2u_model(
+            unit_decoder_output, decoder_padding_mask, _ = self.model.t2u_model(
                 text_decoder_output=decoder_output,
                 text_decoder_padding_mask=decoder_padding_mask,
                 text_seqs=text_seqs,

+ 2 - 4
src/seamless_communication/models/pretssel/builder.py

@@ -5,13 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import List, Literal, Optional, Union
+from typing import Literal, Optional
 
 from fairseq2.assets import asset_store
-from fairseq2.assets.card import AssetCard
 from fairseq2.data import VocabularyInfo
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
+from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import Linear
 from fairseq2.nn.transformer import (
@@ -40,7 +39,6 @@ from seamless_communication.models.unity.fft_decoder_layer import (
     FeedForwardTransformerLayer,
 )
 from seamless_communication.models.unity.length_regulator import (
-    HardUpsampling,
     VarianceAdaptor,
     VariancePredictor,
 )

+ 1 - 1
src/seamless_communication/models/pretssel/pretssel_model.py

@@ -136,7 +136,7 @@ class PretsselDecoderFrontend(Module):
         min_duration: int = 0,
         film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
-        seqs, padding_mask = self.variance_adaptor(
+        seqs, padding_mask, _ = self.variance_adaptor(
             seqs, padding_mask, durations, duration_factor, min_duration, film_cond_emb
         )
 

+ 2 - 4
src/seamless_communication/models/unity/length_regulator.py

@@ -238,8 +238,6 @@ class VarianceAdaptor(Module):
         embed_energy: Optional[Conv1d] = None,
         add_variance_parallel: bool = True,
         upsampling_type: Literal["gaussian", "hard"] = "hard",
-        use_film: bool = False,
-        film_cond_dim: Optional[int] = None,
     ):
         super().__init__()
 
@@ -282,7 +280,7 @@ class VarianceAdaptor(Module):
         duration_factor: float = 1.0,
         min_duration: int = 0,
         film_cond_emb: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, PaddingMask]:
+    ) -> Tuple[Tensor, PaddingMask, Tensor]:
         if self.duration_predictor is not None:
             log_durations = self.duration_predictor(seqs, padding_mask, film_cond_emb)
             durations = torch.clamp(
@@ -320,4 +318,4 @@ class VarianceAdaptor(Module):
         else:
             seqs, seq_lens = self.length_regulator(seqs, durations)
 
-        return seqs, PaddingMask(seq_lens, batch_seq_len=seqs.size(1))
+        return seqs, PaddingMask(seq_lens, batch_seq_len=seqs.size(1)), durations

+ 10 - 6
src/seamless_communication/models/unity/model.py

@@ -383,7 +383,7 @@ class UnitYNART2UModel(Module):
         text_seqs: Optional[Tensor],
         duration_factor: float = 1.0,
         film_cond_emb: Optional[Tensor] = None,
-    ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
+    ) -> Tuple[SequenceModelOutput, Optional[PaddingMask], Tensor]:
         encoder_output, encoder_padding_mask = self.encode(
             text_decoder_output, text_decoder_padding_mask
         )
@@ -391,7 +391,7 @@ class UnitYNART2UModel(Module):
         if self.prosody_proj is not None and film_cond_emb is not None:
             encoder_output = encoder_output + self.prosody_proj(film_cond_emb)
 
-        decoder_output, decoder_padding_mask = self.decode(
+        decoder_output, decoder_padding_mask, durations = self.decode(
             encoder_output,
             encoder_padding_mask,
             text_seqs,
@@ -399,7 +399,7 @@ class UnitYNART2UModel(Module):
             film_cond_emb,
         )
 
-        return self.project(decoder_output), decoder_padding_mask
+        return self.project(decoder_output), decoder_padding_mask, durations
 
     def encode(
         self,
@@ -418,10 +418,10 @@ class UnitYNART2UModel(Module):
         text_seqs: Optional[Tensor],
         duration_factor: float = 1.0,
         film_cond_emb: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
         # encoder_output: (N, S, M)
         # text_seqs: (N, S)
-        seqs, padding_mask = self.decoder_frontend(
+        seqs, padding_mask, durations = self.decoder_frontend(
             encoder_output,
             encoder_padding_mask,
             text_seqs,
@@ -429,7 +429,11 @@ class UnitYNART2UModel(Module):
             film_cond_emb,
         )
 
-        return self.decoder(seqs, padding_mask, film_cond_emb=film_cond_emb)  # type: ignore[no-any-return]
+        seqs, padding_mask = self.decoder(
+            seqs, padding_mask, film_cond_emb=film_cond_emb
+        )
+
+        return seqs, padding_mask, durations  # type: ignore[no-any-return]
 
     def project(self, decoder_output: Tensor) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)

+ 3 - 3
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -304,7 +304,7 @@ class NARDecoderFrontend(Module):
         text_seqs: Optional[Tensor],
         duration_factor: float = 1.0,
         film_cond_emb: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+    ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]:
         assert text_seqs is not None
 
         # text_seqs: (N, S_text)
@@ -321,7 +321,7 @@ class NARDecoderFrontend(Module):
         )
 
         # (N, S_char, M) -> (N, S_unit, M)
-        seqs, padding_mask = self.variance_adaptor(
+        seqs, padding_mask, durations = self.variance_adaptor(
             seqs,
             encoder_padding_mask,
             duration_factor=duration_factor,
@@ -331,4 +331,4 @@ class NARDecoderFrontend(Module):
 
         seqs = self.forward_unit_pos_embedding(seqs, padding_mask)
 
-        return seqs, padding_mask
+        return seqs, padding_mask, durations