Parcourir la source

Fixes and perf improvements from latest fairseq2 HEAD (#36)

Can Balioglu il y a 2 ans
Parent
commit
785cf2add8

+ 3 - 2
src/seamless_communication/models/unity/length_regulator.py

@@ -13,6 +13,7 @@ from typing import Optional, Tuple
 from fairseq2.typing import DataType, Device
 from fairseq2.nn.transformer import create_default_layer_norm
 from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.ops import repeat_interleave
 from fairseq2.nn.projection import Linear
 from fairseq2.nn.utils.mask import apply_padding_mask
 
@@ -31,8 +32,8 @@ class HardUpsampling(Module):
         upsampled_seqs = seqs.new_zeros((N, max_len, M))
 
         for b in range(N):
-            upsampled_seqs[b, : upsampled_seq_lens[b]] = seqs[b].repeat_interleave(
-                durations[b], dim=0
+            upsampled_seqs[b, : upsampled_seq_lens[b]] = repeat_interleave(
+                seqs[b], dim=0, repeat=durations[b]
             )
 
         return upsampled_seqs, upsampled_seq_lens

+ 0 - 3
src/seamless_communication/models/unity/t2u_builder.py

@@ -333,7 +333,6 @@ class UnitYT2UBuilder:
             self.config.unit_max_seq_len,
             _legacy_pad_idx=self.config.unit_pad_idx,
             device=self.device,
-            dtype=self.dtype,
         )
         return TransformerEmbeddingFrontend(
             embed_unit,
@@ -353,7 +352,6 @@ class UnitYT2UBuilder:
             self.config.unit_max_seq_len,
             _legacy_pad_idx=self.config.unit_pad_idx,
             device=self.device,
-            dtype=self.dtype,
         )
 
         char_tokenizer = load_unity_char_tokenizer(
@@ -374,7 +372,6 @@ class UnitYT2UBuilder:
             self.config.nar_decoder_config.char_max_seq_len,
             _legacy_pad_idx=text_pad_idx,
             device=self.device,
-            dtype=self.dtype,
         )
 
         embed_char = Embedding(

+ 3 - 1
src/seamless_communication/models/vocoder/codehifigan.py

@@ -9,6 +9,8 @@ import torch
 import torch.nn as nn
 from torch import Tensor
 
+from fairseq2.nn.ops import repeat_interleave
+
 from seamless_communication.models.vocoder.hifigan import Generator
 from seamless_communication.models.unity import VariancePredictor
 
@@ -83,7 +85,7 @@ class CodeGenerator(Generator):
                 torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1
             )
             # B x C x T
-            x = torch.repeat_interleave(x, dur_out.view(-1), dim=2)
+            x = repeat_interleave(x, dim=2, repeat=dur_out.view(-1))
 
         spkr = self.spkr(sample["spkr"].to(self.spkr.weight.device)).transpose(1, 2)
         spkr = self._upsample(spkr, x.shape[-1])