|
@@ -48,7 +48,7 @@ class UnitExtractor(nn.Module):
|
|
self.model = Wav2Vec2LayerOutputModel(wav2vec2_model)
|
|
self.model = Wav2Vec2LayerOutputModel(wav2vec2_model)
|
|
self.device = device
|
|
self.device = device
|
|
self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
|
|
self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
|
|
- self.collate = Collater(pad_idx=2, pad_to_multiple=2)
|
|
|
|
|
|
+ self.collate = Collater(pad_value=2, pad_to_multiple=2)
|
|
self.kmeans_model = KmeansModel(kmeans_uri, device)
|
|
self.kmeans_model = KmeansModel(kmeans_uri, device)
|
|
|
|
|
|
@torch.inference_mode()
|
|
@torch.inference_mode()
|