|
@@ -23,7 +23,7 @@ from simuleval.data.segments import Segment, TextSegment
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
|
-class DecoderAgentStates(AgentStates):
|
|
|
+class DecoderAgentStates(AgentStates): # type: ignore
|
|
|
def reset(self) -> None:
|
|
|
self.source_len = 0
|
|
|
self.target_indices: List[int] = []
|
|
@@ -50,7 +50,7 @@ class DecoderAgentStates(AgentStates):
|
|
|
self.source_len = self.source.size(1)
|
|
|
|
|
|
|
|
|
-class OnlineTextDecoderAgent(GenericAgent):
|
|
|
+class OnlineTextDecoderAgent(GenericAgent): # type: ignore
|
|
|
"""
|
|
|
Online text decoder
|
|
|
"""
|
|
@@ -139,7 +139,7 @@ class OnlineTextDecoderAgent(GenericAgent):
|
|
|
self.prefix_indices[-1] = tgt_lang_tag_idx
|
|
|
|
|
|
|
|
|
-class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
+class MMATextDecoderAgent(OnlineTextDecoderAgent): # type: ignore
|
|
|
def __init__(
|
|
|
self,
|
|
|
model: MonotonicDecoderModel,
|
|
@@ -278,15 +278,17 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
states: DecoderAgentStates,
|
|
|
pred_indices: List[int],
|
|
|
decoder_features_out: Tensor,
|
|
|
- blocked_ngrams: Set[str],
|
|
|
+ blocked_ngrams: Optional[Set[str]],
|
|
|
index: int,
|
|
|
- ) -> bool:
|
|
|
+ ) -> Tuple[bool, Tensor]:
|
|
|
"""
|
|
|
This check is used to force a READ decision when n-gram repeat
|
|
|
happens before source_finished
|
|
|
"""
|
|
|
if not self.block_ngrams or states.source_finished:
|
|
|
return False, decoder_features_out
|
|
|
+
|
|
|
+ assert blocked_ngrams is not None
|
|
|
all_indices = states.target_indices + pred_indices + [index]
|
|
|
for n in [3, 2]: # TODO: make it configurable
|
|
|
if len(all_indices) >= n and states.ngram_block_count <= 4:
|