|
@@ -338,8 +338,13 @@ class UnitYGenerator:
|
|
|
# Convert to speech units.
|
|
|
units = self.unit_decoder(unit_seqs)
|
|
|
|
|
|
- if ngram_filtering:
|
|
|
- arr = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
|
|
|
- units = torch.tensor(arr)
|
|
|
+ # ngram-filtering doesn't apply to NAR unit decoding.
|
|
|
+ if ngram_filtering and isinstance(self.model.t2u_model, UnitYT2UModel):
|
|
|
+ if units.size(0) > 1:
|
|
|
+ raise NotImplementedError(
|
|
|
+ "unit ngram_filtering is not implemented for batch_size > 1."
|
|
|
+ )
|
|
|
+ arr = remove_consecutive_repeated_ngrams(units[0].tolist())
|
|
|
+ units = torch.tensor(arr).to(units).unsqueeze(0)
|
|
|
|
|
|
return texts, units
|