|
@@ -16,9 +16,9 @@ from seamless_communication.models.monotonic_decoder import (
|
|
|
MonotonicDecoderConfig,
|
|
|
MonotonicDecoderModel,
|
|
|
)
|
|
|
+from seamless_communication.streaming.agents.common import AgentStates
|
|
|
from simuleval.agents import GenericAgent
|
|
|
from simuleval.agents.actions import Action, ReadAction, WriteAction
|
|
|
-from seamless_communication.streaming.agents.common import AgentStates
|
|
|
from simuleval.data.segments import Segment, TextSegment
|
|
|
from torch import Tensor
|
|
|
|
|
@@ -78,16 +78,8 @@ class OnlineTextDecoderAgent(GenericAgent):
|
|
|
self.device = args.device
|
|
|
self.dtype = args.dtype
|
|
|
self.eos_idx = text_tokenizer.vocab_info.eos_idx
|
|
|
- if (
|
|
|
- hasattr(args, "tgt_lang")
|
|
|
- and hasattr(args, "prefix_tgt_lang")
|
|
|
- and args.tgt_lang is not None
|
|
|
- and args.prefix_tgt_lang is not None
|
|
|
- ):
|
|
|
- assert args.tgt_lang == args.prefix_tgt_lang
|
|
|
- tgt_lang = getattr(args, "tgt_lang", None) or getattr(
|
|
|
- args, "prefix_tgt_lang", None
|
|
|
- )
|
|
|
+
|
|
|
+ tgt_lang = getattr(args, "tgt_lang", None)
|
|
|
assert tgt_lang is not None
|
|
|
self.token_encoder = text_tokenizer.create_encoder(lang=tgt_lang, mode="target")
|
|
|
prefix_indices = self.token_encoder.prefix_indices
|
|
@@ -132,7 +124,7 @@ class OnlineTextDecoderAgent(GenericAgent):
|
|
|
default=False,
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
- "--prefix-tgt-lang",
|
|
|
+ "--tgt-lang",
|
|
|
type=str,
|
|
|
default=None,
|
|
|
)
|
|
@@ -265,7 +257,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
tgt_lang=states.tgt_lang,
|
|
|
)
|
|
|
|
|
|
- def get_blocked_ngrams(self, target_indices: List[int]):
|
|
|
+ def get_blocked_ngrams(self, target_indices: List[int]) -> Optional[Set[str]]:
|
|
|
# TODO: make it configurable and use itertools
|
|
|
if not self.block_ngrams:
|
|
|
return None
|
|
@@ -285,25 +277,26 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
self,
|
|
|
states: DecoderAgentStates,
|
|
|
pred_indices: List[int],
|
|
|
- blocked_ngrams: Set[int],
|
|
|
+ decoder_features_out: Tensor,
|
|
|
+ blocked_ngrams: Set[str],
|
|
|
index: int,
|
|
|
- ):
|
|
|
+ ) -> bool:
|
|
|
"""
|
|
|
This check is used to force a READ decision when n-gram repeat
|
|
|
happens before source_finished
|
|
|
"""
|
|
|
if not self.block_ngrams or states.source_finished:
|
|
|
- return False
|
|
|
+ return False, decoder_features_out
|
|
|
all_indices = states.target_indices + pred_indices + [index]
|
|
|
for n in [3, 2]: # TODO: make it configurable
|
|
|
if len(all_indices) >= n and states.ngram_block_count <= 4:
|
|
|
if str(all_indices[-n:]) in blocked_ngrams:
|
|
|
states.ngram_block_count += 1
|
|
|
pred_indices[:] = pred_indices[: -(n - 1)]
|
|
|
- # decoder_features_out = decoder_features_out[:, : -(n - 1)]
|
|
|
- return True
|
|
|
+ decoder_features_out = decoder_features_out[:, : -(n - 1)]
|
|
|
+ return True, decoder_features_out
|
|
|
blocked_ngrams.add(str(all_indices[-n:]))
|
|
|
- return False
|
|
|
+ return False, decoder_features_out
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
def policy(self, states: DecoderAgentStates) -> Action:
|
|
@@ -350,9 +343,9 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
if prob == 1.0:
|
|
|
pred_indices = []
|
|
|
break
|
|
|
- block_ngram = self.maybe_block_ngrams(
|
|
|
- states, pred_indices, blocked_ngrams, index
|
|
|
- ) # TODO: add back decoder_features_out processing for unity2
|
|
|
+ block_ngram, decoder_features_out = self.maybe_block_ngrams(
|
|
|
+ states, pred_indices, decoder_features_out, blocked_ngrams, index
|
|
|
+ )
|
|
|
if block_ngram:
|
|
|
break
|
|
|
if (
|