瀏覽代碼

Fix S2S pipeline edge case error (#243)

Abinesh Ramakrishnan 1 年之前
父節點
當前提交
85be8ba9dd
共有 1 個文件被更改,包括 7 次插入4 次删除
  1. 7 4
      src/seamless_communication/streaming/agents/online_text_decoder.py

+ 7 - 4
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -325,10 +325,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):  # type: ignore
         blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
         blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
         decoder_features_out = None
         decoder_features_out = None
 
 
-        while (
-            len(states.target_indices + pred_indices) < self.max_len(states)
-            and len(pred_indices) < self.max_consecutive_writes
-        ):
+        while True:
             index, prob, decoder_features = self.run_decoder(states, pred_indices)
             index, prob, decoder_features = self.run_decoder(states, pred_indices)
 
 
             if decoder_features_out is None:
             if decoder_features_out is None:
@@ -361,6 +358,12 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):  # type: ignore
             if prob < self.decision_threshold and not states.source_finished:
             if prob < self.decision_threshold and not states.source_finished:
                 break
                 break
 
 
+            if (
+                len(states.target_indices + pred_indices) >= self.max_len(states)
+                or len(pred_indices) >= self.max_consecutive_writes
+            ):
+                break
+
             pred_indices.append(index)
             pred_indices.append(index)
             if self.state_bag.step_nr == 0:
             if self.state_bag.step_nr == 0:
                 self.state_bag.increment_step_nr(
                 self.state_bag.increment_step_nr(