|
|
@@ -34,24 +34,12 @@ SPACE = "▁"
|
|
|
|
|
|
|
|
|
class TagManager:
|
|
|
- def __init__(self, text_tokenizer: TextTokenizer):
|
|
|
- self.vocab_info: VocabularyInfo = text_tokenizer.vocab_info
|
|
|
- token_encoder = text_tokenizer.create_encoder(mode="target")
|
|
|
- self.prefix_len: int = (
|
|
|
- token_encoder.prefix_indices.shape[0]
|
|
|
- if token_encoder.prefix_indices is not None
|
|
|
- else 0
|
|
|
- )
|
|
|
- self.suffix_len: int = (
|
|
|
- token_encoder.suffix_indices.shape[0]
|
|
|
- if token_encoder.suffix_indices is not None
|
|
|
- else 0
|
|
|
- )
|
|
|
+ def __init__(self, vocab_info: VocabularyInfo):
|
|
|
+ self.vocab_info = vocab_info
|
|
|
|
|
|
def preprocess_text_seqs(self, text_seqs: Tensor) -> Tensor:
|
|
|
- """Remove the prefix, suffix tokens."""
|
|
|
- seq_len = text_seqs.shape[1]
|
|
|
- text_seqs = text_seqs[:, self.prefix_len : seq_len - self.suffix_len]
|
|
|
+ # Remove EOS, lang tokens as per NLLB "target" tokenizer mode.
|
|
|
+ text_seqs = text_seqs[:, 2:]
|
|
|
assert self.vocab_info.pad_idx is not None
|
|
|
text_seqs.masked_fill_(
|
|
|
text_seqs == self.vocab_info.eos_idx, self.vocab_info.pad_idx
|
|
|
@@ -59,16 +47,10 @@ class TagManager:
|
|
|
return text_seqs
|
|
|
|
|
|
def postprocess_dur_or_len(self, dur_or_len: Tensor) -> Tensor:
|
|
|
- """Add back 0s in place of the prefix, suffix tokens."""
|
|
|
N = dur_or_len.shape[0]
|
|
|
-
|
|
|
- # prefix = dur_or_len.new_zeros((N, self.prefix_len))
|
|
|
- # suffix = dur_or_len.new_zeros((N, self.suffix_len))
|
|
|
-
|
|
|
- prefix = dur_or_len.new_zeros((N, 1))
|
|
|
- suffix = dur_or_len.new_zeros((N, 1))
|
|
|
-
|
|
|
- dur_or_len = torch.cat([prefix, dur_or_len, suffix], dim=1)
|
|
|
+ pad_zero = dur_or_len.new_zeros((N, 1))
|
|
|
+ # Add pads for lang, EOS tokens as per NLLB "source" tokenizer mode.
|
|
|
+ dur_or_len = torch.cat([pad_zero, dur_or_len, pad_zero], dim=1)
|
|
|
return dur_or_len
|
|
|
|
|
|
|
|
|
@@ -109,7 +91,7 @@ class NARDecoderFrontend(Module):
|
|
|
self.embed_char = embed_char
|
|
|
self.text_tokenizer = text_tokenizer
|
|
|
self.char_tokenizer = char_tokenizer
|
|
|
- self.tag_manager = TagManager(text_tokenizer)
|
|
|
+ self.tag_manager = TagManager(text_tokenizer.vocab_info)
|
|
|
|
|
|
self.unk_idx = self.text_tokenizer.vocab_info.unk_idx
|
|
|
self.pad_idx = self.text_tokenizer.vocab_info.pad_idx
|