|
@@ -6,7 +6,7 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
from argparse import ArgumentParser, Namespace
|
|
|
-from typing import Any, Dict, List, Tuple
|
|
|
+from typing import Any, Dict, List, Set, Tuple
|
|
|
|
|
|
import torch
|
|
|
from fairseq2.models.nllb.tokenizer import NllbTokenizer
|
|
@@ -27,6 +27,7 @@ class DecoderAgentStates(AgentStates):
|
|
|
self.source_len = 0
|
|
|
self.target_indices: List[int] = []
|
|
|
self.tgt_lang = None
|
|
|
+ self.ngram_block_count = 0
|
|
|
super().reset()
|
|
|
|
|
|
def update_source(self, segment: Segment) -> None:
|
|
@@ -76,9 +77,11 @@ class OnlineTextDecoderAgent(GenericAgent):
|
|
|
self.device = args.device
|
|
|
self.dtype = args.dtype
|
|
|
self.eos_idx = text_tokenizer.vocab_info.eos_idx
|
|
|
- if hasattr(args, "tgt_lang") and hasattr(args, "prefix_tgt_lang"):
|
|
|
+ if getattr(args, "tgt_lang", None) and getattr(args, "prefix_tgt_lang", None):
|
|
|
assert args.tgt_lang == args.prefix_tgt_lang
|
|
|
- tgt_lang = getattr(args, "tgt_lang", None) or getattr(args, "prefix_tgt_lang", None)
|
|
|
+ tgt_lang = getattr(args, "tgt_lang", None) or getattr(
|
|
|
+ args, "prefix_tgt_lang", None
|
|
|
+ )
|
|
|
token_encoder = text_tokenizer.create_encoder(lang=tgt_lang, mode="target")
|
|
|
prefix_indices = token_encoder.prefix_indices
|
|
|
assert prefix_indices is not None
|
|
@@ -116,12 +119,6 @@ class OnlineTextDecoderAgent(GenericAgent):
|
|
|
default=1,
|
|
|
help="Minimal starting waiting source steps",
|
|
|
)
|
|
|
- parser.add_argument(
|
|
|
- "--min-starting-wait-reset",
|
|
|
- type=int,
|
|
|
- default=0,
|
|
|
- help="Minimal starting waiting source steps",
|
|
|
- )
|
|
|
parser.add_argument(
|
|
|
"--no-early-stop",
|
|
|
action="store_true",
|
|
@@ -157,6 +154,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
|
|
|
self.decision_threshold = args.decision_threshold
|
|
|
self.decision_method = args.decision_method
|
|
|
+ self.block_ngrams = args.block_ngrams
|
|
|
self.p_choose_start_layer = args.p_choose_start_layer
|
|
|
|
|
|
@staticmethod
|
|
@@ -181,6 +179,10 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
default=0,
|
|
|
help="Encoder layer from which p_choose should be considered for selection.",
|
|
|
)
|
|
|
+ parser.add_argument(
|
|
|
+ "--block-ngrams",
|
|
|
+ action="store_true",
|
|
|
+ )
|
|
|
|
|
|
@classmethod
|
|
|
def from_args(
|
|
@@ -224,6 +226,10 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
)
|
|
|
|
|
|
logits = self.model.project(decoder_output)
|
|
|
+ if self.block_ngrams and states.source_finished:
|
|
|
+ all_indices = states.target_indices + pred_indices
|
|
|
+ blocked_indices = all_indices[-4:]
|
|
|
+ logits[:, :, blocked_indices] = float("-inf")
|
|
|
|
|
|
index = int(logits[0, -1].argmax().item())
|
|
|
_, tgt_len, src_len = p_choose.size()
|
|
@@ -250,6 +256,46 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
tgt_lang=states.tgt_lang,
|
|
|
)
|
|
|
|
|
|
def get_blocked_ngrams(self, target_indices: List[int]):
    """Build the initial set of n-grams that must not be repeated.

    Args:
        target_indices: Token indices to inspect; only the last four
            entries are considered.

    Returns:
        ``None`` when ``self.block_ngrams`` is falsy. Otherwise a new set
        holding the ``str()`` form of every contiguous run of two or more
        tokens inside the last four entries of ``target_indices``. The set
        is empty when fewer than two tokens are available.
    """
    # TODO: make it configurable and use itertools
    if not self.block_ngrams:
        return None

    tail = target_indices[-4:]
    tail_len = len(tail)
    # Every (start, stop) pair spanning at least two tokens of the tail.
    return {
        str(tail[start:stop])
        for start in range(tail_len - 1)
        for stop in range(start + 2, tail_len + 1)
    }
|
|
|
+
|
|
|
def maybe_block_ngrams(
    self,
    states: "DecoderAgentStates",
    pred_indices: List[int],
    blocked_ngrams: Set[str],
    index: int,
) -> bool:
    """Report whether appending ``index`` would repeat a blocked n-gram.

    This check is used to force a READ decision when an n-gram repeat
    happens before ``source_finished``: the caller is expected to stop
    writing when this returns ``True``.

    Args:
        states: Agent states; ``source_finished``, ``target_indices`` and
            ``ngram_block_count`` are read, and ``ngram_block_count`` is
            incremented when a repeat is detected.
        pred_indices: Tokens predicted so far in the current step. Trimmed
            **in place** when a repeat is detected.
        blocked_ngrams: Set of stringified n-grams (``str(list_of_ints)``)
            seen so far. Newly observed n-grams are added to it in place.
        index: Candidate token that has not been appended to
            ``pred_indices`` yet.

    Returns:
        ``True`` if a 3-gram or 2-gram ending in ``index`` is already in
        ``blocked_ngrams``; ``False`` otherwise, including when blocking is
        disabled, the source is finished, or the per-step block budget has
        been exceeded.
    """
    if not self.block_ngrams or states.source_finished:
        return False

    # Blocking stops once the counter exceeds this value.
    max_block_count = 4  # TODO: make it configurable
    all_indices = states.target_indices + pred_indices + [index]
    for n in (3, 2):  # TODO: make it configurable
        if len(all_indices) < n or states.ngram_block_count > max_block_count:
            continue
        ngram = str(all_indices[-n:])
        if ngram in blocked_ngrams:
            states.ngram_block_count += 1
            # ``index`` itself was never appended, so dropping the last
            # n - 1 predictions removes the rest of the repeated n-gram.
            pred_indices[:] = pred_indices[: -(n - 1)]
            # TODO: trim decoder_features_out by the same n - 1 steps (unity2)
            return True
        blocked_ngrams.add(ngram)
    return False
|
|
|
+
|
|
|
@torch.inference_mode()
|
|
|
def policy(self, states: DecoderAgentStates) -> Action:
|
|
|
if len(states.source) == 0:
|
|
@@ -272,6 +318,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
index = None
|
|
|
prob = None
|
|
|
finished = False
|
|
|
+ blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
|
|
|
|
|
|
while (
|
|
|
len(states.target_indices + pred_indices) < self.max_len(states)
|
|
@@ -281,18 +328,17 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
|
|
|
if (
|
|
|
self.no_early_stop
|
|
|
- and prob < self.decision_threshold
|
|
|
- and not states.source_finished
|
|
|
- ):
|
|
|
- break
|
|
|
- if (
|
|
|
- self.no_early_stop
|
|
|
- and index == self.eos_idx
|
|
|
and not states.source_finished
|
|
|
+ and (prob < self.decision_threshold or index == self.eos_idx)
|
|
|
):
|
|
|
if prob == 1.0:
|
|
|
pred_indices = []
|
|
|
break
|
|
|
+ block_ngram = self.maybe_block_ngrams(
|
|
|
+ states, pred_indices, blocked_ngrams, index
|
|
|
+ ) # TODO: add back decoder_features_out processing for unity2
|
|
|
+ if block_ngram:
|
|
|
+ break
|
|
|
if (
|
|
|
finished
|
|
|
or index == self.eos_idx
|
|
@@ -301,11 +347,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
finished = True
|
|
|
break
|
|
|
|
|
|
- if (
|
|
|
- not self.no_early_stop
|
|
|
- and prob < self.decision_threshold
|
|
|
- and not states.source_finished
|
|
|
- ):
|
|
|
+ if prob < self.decision_threshold and not states.source_finished:
|
|
|
break
|
|
|
|
|
|
pred_indices.append(index)
|
|
@@ -322,6 +364,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
finished = finished or len(
|
|
|
states.target_indices + pred_indices
|
|
|
) > self.max_len(states)
|
|
|
+ states.ngram_block_count = 0
|
|
|
return WriteAction(
|
|
|
self.postprocess(states, torch.tensor(pred_indices), finished),
|
|
|
finished=finished,
|