|
|
@@ -13,7 +13,6 @@ from typing import Optional, Tuple
|
|
|
from fairseq2.typing import DataType, Device
|
|
|
from fairseq2.nn.transformer import create_default_layer_norm
|
|
|
from fairseq2.nn.normalization import LayerNorm
|
|
|
-from fairseq2.nn.ops import repeat_interleave
|
|
|
from fairseq2.nn.projection import Linear
|
|
|
from fairseq2.nn.utils.mask import apply_padding_mask
|
|
|
|
|
|
@@ -32,8 +31,8 @@ class HardUpsampling(Module):
|
|
|
upsampled_seqs = seqs.new_zeros((N, max_len, M))
|
|
|
|
|
|
for b in range(N):
|
|
|
- upsampled_seqs[b, : upsampled_seq_lens[b]] = repeat_interleave(
|
|
|
- seqs[b], dim=0, repeat=durations[b]
|
|
|
+ upsampled_seqs[b, : upsampled_seq_lens[b]] = seqs[b].repeat_interleave(
|
|
|
+ durations[b], dim=0
|
|
|
)
|
|
|
|
|
|
return upsampled_seqs, upsampled_seq_lens
|