Prechádzať zdrojové kódy

Fix bug in unit ngram filtering. (#175)

* Fix bug in unit ngram filtering.

* Disable ngram filtering for NAR unit decoder.
Kaushik Ram Sadagopan 1 rok pred
rodič
commit
383ef96b5c

+ 8 - 3
src/seamless_communication/inference/generator.py

@@ -338,8 +338,13 @@ class UnitYGenerator:
         # Convert to speech units.
         units = self.unit_decoder(unit_seqs)
 
-        if ngram_filtering:
-            arr = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
-            units = torch.tensor(arr)
+        # ngram-filtering doesn't apply to NAR unit decoding.
+        if ngram_filtering and isinstance(self.model.t2u_model, UnitYT2UModel):
+            if units.size(0) > 1:
+                raise NotImplementedError(
+                    "unit ngram_filtering is not implemented for batch_size > 1."
+                )
+            arr = remove_consecutive_repeated_ngrams(units[0].tolist())
+            units = torch.tensor(arr).to(units).unsqueeze(0)
 
         return texts, units