|
@@ -325,10 +325,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent): # type: ignore
|
|
blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
|
|
blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
|
|
decoder_features_out = None
|
|
decoder_features_out = None
|
|
|
|
|
|
- while (
|
|
|
|
- len(states.target_indices + pred_indices) < self.max_len(states)
|
|
|
|
- and len(pred_indices) < self.max_consecutive_writes
|
|
|
|
- ):
|
|
|
|
|
|
+ while True:
|
|
index, prob, decoder_features = self.run_decoder(states, pred_indices)
|
|
index, prob, decoder_features = self.run_decoder(states, pred_indices)
|
|
|
|
|
|
if decoder_features_out is None:
|
|
if decoder_features_out is None:
|
|
@@ -361,6 +358,12 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent): # type: ignore
|
|
if prob < self.decision_threshold and not states.source_finished:
|
|
if prob < self.decision_threshold and not states.source_finished:
|
|
break
|
|
break
|
|
|
|
|
|
|
|
+ if (
|
|
|
|
+ len(states.target_indices + pred_indices) >= self.max_len(states)
|
|
|
|
+ or len(pred_indices) >= self.max_consecutive_writes
|
|
|
|
+ ):
|
|
|
|
+ break
|
|
|
|
+
|
|
pred_indices.append(index)
|
|
pred_indices.append(index)
|
|
if self.state_bag.step_nr == 0:
|
|
if self.state_bag.step_nr == 0:
|
|
self.state_bag.increment_step_nr(
|
|
self.state_bag.increment_step_nr(
|