Ver código fonte

Fix bug in unit ngram filtering. (#175)

* Fix bug in unit ngram filtering.

* Disable ngram filtering for NAR unit decoder.
Kaushik Ram Sadagopan 1 ano atrás
pai
commit
383ef96b5c
1 arquivos alterados com 8 adições e 3 exclusões
  1. 8 3
      src/seamless_communication/inference/generator.py

+ 8 - 3
src/seamless_communication/inference/generator.py

@@ -338,8 +338,13 @@ class UnitYGenerator:
         # Convert to speech units.
         units = self.unit_decoder(unit_seqs)
 
-        if ngram_filtering:
-            arr = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
-            units = torch.tensor(arr)
+        # ngram-filtering doesn't apply to NAR unit decoding.
+        if ngram_filtering and isinstance(self.model.t2u_model, UnitYT2UModel):
+            if units.size(0) > 1:
+                raise NotImplementedError(
+                    "unit ngram_filtering is not implemented for batch_size > 1."
+                )
+            arr = remove_consecutive_repeated_ngrams(units[0].tolist())
+            units = torch.tensor(arr).to(units).unsqueeze(0)
 
         return texts, units