Merge pull request #37 from fairinternal/revert_interleave_change

Revert the interleave change, since this isn't a single-dimension repeat.
Kaushik Ram Sadagopan, 2 years ago
commit e42276c223

+ 2 - 3
src/seamless_communication/models/unity/length_regulator.py

@@ -13,7 +13,6 @@ from typing import Optional, Tuple
 from fairseq2.typing import DataType, Device
 from fairseq2.nn.transformer import create_default_layer_norm
 from fairseq2.nn.normalization import LayerNorm
-from fairseq2.nn.ops import repeat_interleave
 from fairseq2.nn.projection import Linear
 from fairseq2.nn.utils.mask import apply_padding_mask
 
@@ -32,8 +31,8 @@ class HardUpsampling(Module):
         upsampled_seqs = seqs.new_zeros((N, max_len, M))

         for b in range(N):
-            upsampled_seqs[b, : upsampled_seq_lens[b]] = repeat_interleave(
-                seqs[b], dim=0, repeat=durations[b]
+            upsampled_seqs[b, : upsampled_seq_lens[b]] = seqs[b].repeat_interleave(
+                durations[b], dim=0
             )

         return upsampled_seqs, upsampled_seq_lens
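For context, a minimal sketch of the behavior the revert relies on: Tensor.repeat_interleave accepts a per-element repeats tensor along a given dimension, which is what HardUpsampling needs (one duration per source frame), rather than a single repeat count. The tensor values below are illustrative, not taken from the repository.

import torch

# Toy values (hypothetical): 3 source frames of width 2, with
# per-frame durations 2, 1, 3.
seqs_b = torch.tensor([[0.1, 0.2],
                       [0.3, 0.4],
                       [0.5, 0.6]])      # (T, M) = (3, 2)
durations_b = torch.tensor([2, 1, 3])    # one repeat count per frame

# Per-element repeats along dim=0, as in the reverted-to call:
upsampled = seqs_b.repeat_interleave(durations_b, dim=0)
print(upsampled.shape)  # torch.Size([6, 2]) == (durations_b.sum(), M)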

+ 1 - 3
src/seamless_communication/models/vocoder/codehifigan.py

@@ -9,8 +9,6 @@ import torch
 import torch.nn as nn
 from torch import Tensor

-from fairseq2.nn.ops import repeat_interleave
-
 from seamless_communication.models.vocoder.hifigan import Generator
 from seamless_communication.models.unity import VariancePredictor
 
@@ -85,7 +83,7 @@ class CodeGenerator(Generator):
                 torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1
             )
             # B x C x T
-            x = repeat_interleave(x, dim=2, repeat=dur_out.view(-1))
+            x = torch.repeat_interleave(x, dur_out.view(-1), dim=2)

         spkr = self.spkr(sample["spkr"].to(self.spkr.weight.device)).transpose(1, 2)
         spkr = self._upsample(spkr, x.shape[-1])
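Likewise for the vocoder path, a minimal sketch with hypothetical shapes, assuming batch size 1, which is what flattening dur_out with .view(-1) implies: the repeats tensor must match x.size(2) along the time axis.

import torch

# Hypothetical shapes: batch 1, 4 channels, 3 code frames.
x = torch.randn(1, 4, 3)             # B x C x T
dur_out = torch.tensor([[1, 3, 2]])  # predicted duration per frame

# Per-frame repeats along the time axis, as in the reverted-to call:
x_up = torch.repeat_interleave(x, dur_out.view(-1), dim=2)
print(x_up.shape)  # torch.Size([1, 4, 6]) == (B, C, dur_out.sum())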