|
|
@@ -13,6 +13,7 @@ from typing import Optional, Tuple
|
|
|
from fairseq2.typing import DataType, Device
|
|
|
from fairseq2.nn.transformer import create_default_layer_norm
|
|
|
from fairseq2.nn.normalization import LayerNorm
|
|
|
+from fairseq2.nn.ops import repeat_interleave
|
|
|
from fairseq2.nn.projection import Linear
|
|
|
from fairseq2.nn.utils.mask import apply_padding_mask
|
|
|
|
|
|
@@ -31,8 +32,8 @@ class HardUpsampling(Module):
|
|
|
upsampled_seqs = seqs.new_zeros((N, max_len, M))
|
|
|
|
|
|
for b in range(N):
|
|
|
- upsampled_seqs[b, : upsampled_seq_lens[b]] = seqs[b].repeat_interleave(
|
|
|
- durations[b], dim=0
|
|
|
+ upsampled_seqs[b, : upsampled_seq_lens[b]] = repeat_interleave(
|
|
|
+ seqs[b], dim=0, repeat=durations[b]
|
|
|
)
|
|
|
|
|
|
return upsampled_seqs, upsampled_seq_lens
|