|
@@ -6,7 +6,8 @@
|
|
from __future__ import annotations
|
|
from __future__ import annotations
|
|
|
|
|
|
from argparse import ArgumentParser, Namespace
|
|
from argparse import ArgumentParser, Namespace
|
|
-from typing import Any, Dict, List, Set, Tuple
|
|
|
|
|
|
+from dataclasses import dataclass
|
|
|
|
+from typing import Any, Dict, List, Optional, Set, Tuple
|
|
|
|
|
|
import torch
|
|
import torch
|
|
from fairseq2.models.nllb.tokenizer import NllbTokenizer
|
|
from fairseq2.models.nllb.tokenizer import NllbTokenizer
|
|
@@ -77,13 +78,19 @@ class OnlineTextDecoderAgent(GenericAgent):
|
|
self.device = args.device
|
|
self.device = args.device
|
|
self.dtype = args.dtype
|
|
self.dtype = args.dtype
|
|
self.eos_idx = text_tokenizer.vocab_info.eos_idx
|
|
self.eos_idx = text_tokenizer.vocab_info.eos_idx
|
|
- if getattr(args, "tgt_lang", None) and getattr(args, "prefix_tgt_lang", None):
|
|
|
|
|
|
+ if (
|
|
|
|
+ hasattr(args, "tgt_lang")
|
|
|
|
+ and hasattr(args, "prefix_tgt_lang")
|
|
|
|
+ and args.tgt_lang is not None
|
|
|
|
+ and args.prefix_tgt_lang is not None
|
|
|
|
+ ):
|
|
assert args.tgt_lang == args.prefix_tgt_lang
|
|
assert args.tgt_lang == args.prefix_tgt_lang
|
|
tgt_lang = getattr(args, "tgt_lang", None) or getattr(
|
|
tgt_lang = getattr(args, "tgt_lang", None) or getattr(
|
|
args, "prefix_tgt_lang", None
|
|
args, "prefix_tgt_lang", None
|
|
)
|
|
)
|
|
- token_encoder = text_tokenizer.create_encoder(lang=tgt_lang, mode="target")
|
|
|
|
- prefix_indices = token_encoder.prefix_indices
|
|
|
|
|
|
+ assert tgt_lang is not None
|
|
|
|
+ self.token_encoder = text_tokenizer.create_encoder(lang=tgt_lang, mode="target")
|
|
|
|
+ prefix_indices = self.token_encoder.prefix_indices
|
|
assert prefix_indices is not None
|
|
assert prefix_indices is not None
|
|
self.prefix_indices: List[int] = prefix_indices.tolist()
|
|
self.prefix_indices: List[int] = prefix_indices.tolist()
|
|
|
|
|
|
@@ -218,8 +225,6 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
pred_indices[-1:], device=self.device, dtype=torch.int64
|
|
pred_indices[-1:], device=self.device, dtype=torch.int64
|
|
).unsqueeze(0)
|
|
).unsqueeze(0)
|
|
|
|
|
|
- torch.cuda.empty_cache()
|
|
|
|
-
|
|
|
|
encoder_output = states.source
|
|
encoder_output = states.source
|
|
decoder_output, _, p_choose = self.model.decode(
|
|
decoder_output, _, p_choose = self.model.decode(
|
|
target_input, None, encoder_output, None, state_bag=self.state_bag
|
|
target_input, None, encoder_output, None, state_bag=self.state_bag
|
|
@@ -246,7 +251,11 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
return index, prob, decoder_output
|
|
return index, prob, decoder_output
|
|
|
|
|
|
def postprocess(
|
|
def postprocess(
|
|
- self, states: DecoderAgentStates, pred_indices: Tensor, finished: bool
|
|
|
|
|
|
+ self,
|
|
|
|
+ states: DecoderAgentStates,
|
|
|
|
+ pred_indices: List[int],
|
|
|
|
+ finished: bool,
|
|
|
|
+ decoder_features_out: Optional[Tensor] = None,
|
|
) -> TextSegment:
|
|
) -> TextSegment:
|
|
return TextSegment(
|
|
return TextSegment(
|
|
content=" ".join(
|
|
content=" ".join(
|
|
@@ -319,12 +328,19 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
prob = None
|
|
prob = None
|
|
finished = False
|
|
finished = False
|
|
blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
|
|
blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
|
|
|
|
+ decoder_features_out = None
|
|
|
|
|
|
while (
|
|
while (
|
|
len(states.target_indices + pred_indices) < self.max_len(states)
|
|
len(states.target_indices + pred_indices) < self.max_len(states)
|
|
and len(pred_indices) < self.max_consecutive_writes
|
|
and len(pred_indices) < self.max_consecutive_writes
|
|
):
|
|
):
|
|
- index, prob, _ = self.run_decoder(states, pred_indices)
|
|
|
|
|
|
+ index, prob, decoder_features = self.run_decoder(states, pred_indices)
|
|
|
|
+
|
|
|
|
+ if decoder_features_out is None:
|
|
|
|
+ decoder_features_out = decoder_features.new(0)
|
|
|
|
+ decoder_features_out = torch.cat(
|
|
|
|
+ [decoder_features_out, decoder_features], dim=1
|
|
|
|
+ )
|
|
|
|
|
|
if (
|
|
if (
|
|
self.no_early_stop
|
|
self.no_early_stop
|
|
@@ -366,7 +382,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
) > self.max_len(states)
|
|
) > self.max_len(states)
|
|
states.ngram_block_count = 0
|
|
states.ngram_block_count = 0
|
|
return WriteAction(
|
|
return WriteAction(
|
|
- self.postprocess(states, torch.tensor(pred_indices), finished),
|
|
|
|
|
|
+ self.postprocess(states, pred_indices, finished, decoder_features_out),
|
|
finished=finished,
|
|
finished=finished,
|
|
)
|
|
)
|
|
else:
|
|
else:
|
|
@@ -375,3 +391,56 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
|
|
|
|
|
|
class MMASpeechToTextDecoderAgent(MMATextDecoderAgent):
|
|
class MMASpeechToTextDecoderAgent(MMATextDecoderAgent):
|
|
source_type = "speech"
|
|
source_type = "speech"
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@dataclass
|
|
|
|
+class UnitYTextDecoderOutput:
|
|
|
|
+ decoder_features: Tensor
|
|
|
|
+ tokens: List[str]
|
|
|
|
+ target_indices: Optional[Tensor] = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class UnitYMMATextDecoderAgent(MMASpeechToTextDecoderAgent):
|
|
|
|
+ """
|
|
|
|
+ MMA UnitY text decoder agent which just prepares the decoder
|
|
|
|
+ output for the downstream agent.
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ def postprocess(
|
|
|
|
+ self,
|
|
|
|
+ states: DecoderAgentStates,
|
|
|
|
+ pred_indices: List[int],
|
|
|
|
+ finished: bool,
|
|
|
|
+ decoder_features_out: Optional[Tensor] = None,
|
|
|
|
+ ) -> TextSegment:
|
|
|
|
+ tokens: List[str] = [
|
|
|
|
+ self.text_tokenizer.model.index_to_token(idx) for idx in pred_indices
|
|
|
|
+ ]
|
|
|
|
+ assert decoder_features_out is not None
|
|
|
|
+ token_list = self.prefix_indices + states.target_indices
|
|
|
|
+ if (
|
|
|
|
+ len(pred_indices) > 0
|
|
|
|
+ and pred_indices[-1] != self.text_tokenizer.vocab_info.eos_idx
|
|
|
|
+ ):
|
|
|
|
+ # Append "," to make speech smooth
|
|
|
|
+ # TODO: a temporary solution.
|
|
|
|
+ ending_token_index = self.text_tokenizer.model.token_to_index(",")
|
|
|
|
+ token_list.append(ending_token_index)
|
|
|
|
+ self.state_bag.increment_step()
|
|
|
|
+
|
|
|
|
+ _, _, decoder_features = self.run_decoder(states, [ending_token_index])
|
|
|
|
+ decoder_features_out = torch.cat(
|
|
|
|
+ [decoder_features_out, decoder_features], dim=1
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ target_input = torch.tensor(
|
|
|
|
+ token_list,
|
|
|
|
+ device=self.device,
|
|
|
|
+ dtype=torch.int64,
|
|
|
|
+ ).unsqueeze(0)
|
|
|
|
+
|
|
|
|
+ return TextSegment(
|
|
|
|
+ content=UnitYTextDecoderOutput(decoder_features_out, tokens, target_input),
|
|
|
|
+ finished=finished,
|
|
|
|
+ tgt_lang=states.tgt_lang,
|
|
|
|
+ )
|