
Migrates SC to the new fairseq2 sequence generator API (#171)

Can Balioglu committed 1 year ago
parent commit f11026d271

+ 6 - 2
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -20,7 +20,6 @@ from fairseq2.data.audio import (
     WaveformToFbankOutput,
 )
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
-from fairseq2.generation import SequenceGeneratorOptions
 from fairseq2.typing import DataType, Device
 from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from torch import Tensor
@@ -34,7 +33,12 @@ from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
-from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
+from seamless_communication.inference import (
+    BatchedSpeechOutput,
+    Modality,
+    SequenceGeneratorOptions,
+    Translator,
+)
 from seamless_communication.inference.pretssel_generator import PretsselGenerator
 from seamless_communication.models.unity import (
     load_gcmvn_stats,

+ 6 - 2
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -20,7 +20,6 @@ from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
 from fairseq2.data.typing import StringLike
-from fairseq2.generation import SequenceGeneratorOptions
 from fairseq2.typing import DataType, Device
 from torch import Tensor
 from tqdm import tqdm
@@ -32,7 +31,12 @@ from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
-from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
+from seamless_communication.inference import (
+    BatchedSpeechOutput,
+    Modality,
+    SequenceGeneratorOptions,
+    Translator,
+)
 from seamless_communication.models.unity import load_unity_text_tokenizer
 
 logging.basicConfig(

+ 2 - 2
src/seamless_communication/cli/m4t/predict/predict.py

@@ -11,9 +11,9 @@ from typing import Tuple
 
 import torch
 import torchaudio
-from fairseq2.generation import NGramRepeatBlockProcessor, SequenceGeneratorOptions
+from fairseq2.generation import NGramRepeatBlockProcessor
 
-from seamless_communication.inference import Translator
+from seamless_communication.inference import SequenceGeneratorOptions, Translator
 
 logging.basicConfig(
     level=logging.INFO,
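
A minimal sketch of how the relocated options class combines with the fairseq2 processor kept in the import above. Attaching the processor via `step_processor` follows the new dataclass added in `generator.py` below; the `ngram_size` value and the exact wiring done by `set_generation_opts` are assumptions, not taken from this diff.

```python
from fairseq2.generation import NGramRepeatBlockProcessor

from seamless_communication.inference import SequenceGeneratorOptions

# Assumed usage: block repeated 4-grams during text generation.
text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
text_opts.step_processor = NGramRepeatBlockProcessor(ngram_size=4)
```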

+ 3 - 3
src/seamless_communication/cli/m4t/train/run_eval.py

@@ -20,7 +20,7 @@ from jiwer import wer  # type: ignore
 
 import seamless_communication.cli.m4t.train.cleaners as cleaners
 from fairseq2.data.audio import WaveformToFbankConverter
-from fairseq2.generation import NGramRepeatBlockProcessor, SequenceGeneratorOptions
+from fairseq2.generation import NGramRepeatBlockProcessor
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from seamless_communication.cli.m4t.train import model as _model
 from seamless_communication.cli.m4t.train import trainer as _trainer
@@ -28,7 +28,7 @@ from seamless_communication.cli.m4t.train.configs import (
     DataLoadingConfig,
     WorkflowParams,
 )
-from seamless_communication.inference.generator import UnitYGenerator
+from seamless_communication.inference import SequenceGeneratorOptions, UnitYGenerator
 from seamless_communication.models.tokenizer import SPMTokenizer
 from seamless_communication.models.unity import (
     UnitTokenizer,
@@ -194,7 +194,7 @@ def translate(
     ngram_filtering: bool = True,
     text_max_len_a: int = 1,
     text_max_len_b: int = 200,
-    unit_max_len_a: int = 1,
+    unit_max_len_a: int = 25,
     unit_max_len_b: int = 50,
 ) -> Tuple[str, Any]:
     """Runs S2T translation. TBD: add S2S"""

+ 3 - 0
src/seamless_communication/inference/__init__.py

@@ -4,6 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from seamless_communication.inference.generator import (
+    SequenceGeneratorOptions as SequenceGeneratorOptions,
+)
 from seamless_communication.inference.generator import UnitYGenerator as UnitYGenerator
 from seamless_communication.inference.translator import (
     BatchedSpeechOutput as BatchedSpeechOutput,
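
With the re-export above, callers import the options from the package rather than from fairseq2. A sketch of the before/after for downstream code:

```python
# Before this change:
# from fairseq2.generation import SequenceGeneratorOptions

# After this change:
from seamless_communication.inference import (
    BatchedSpeechOutput,
    SequenceGeneratorOptions,
    Translator,
    UnitYGenerator,
)
```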

+ 105 - 47
src/seamless_communication/inference/generator.py

@@ -8,19 +8,19 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
 import torch
-from fairseq2.data import SequenceData
+from fairseq2.data import SequenceData, StringLike
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
+    BeamSearchSeq2SeqGenerator,
     Seq2SeqGenerator,
-    SequenceGeneratorOptions,
-    SequenceGeneratorOutput,
-    SequenceToTextGenerator,
-    SequenceToTextOutput,
+    SequenceToTextConverter,
+    StepProcessor,
 )
 from fairseq2.nn.padding import (
     PaddingMask,
     apply_padding_mask,
     get_seqs_and_padding_mask,
+    pad_seqs,
 )
 from fairseq2.nn.utils.module import infer_device
 from torch import Tensor
@@ -56,13 +56,34 @@ def remove_consecutive_repeated_ngrams(
     return [token for idx, token in enumerate(sequence) if idx not in drop_idx]
 
 
+@dataclass
+class SequenceGeneratorOptions:
+    """Holds the options to pass to a sequence generator."""
+
+    beam_size: int = 5
+    """The beam size."""
+
+    soft_max_seq_len: Tuple[int, int] = (1, 200)
+    """The terms ``a`` and ``b`` of ``ax + b`` where ``x`` is the source
+    sequence length. The generated sequences (including prefix sequence) will
+    have the maximum length of ``min(hard_max_seq_len, ax + b)``. See also
+    ``hard_max_seq_len``."""
+
+    hard_max_seq_len: int = 1024
+    """The hard limit on maximum length of generated sequences."""
+
+    step_processor: Optional[StepProcessor] = None
+    """The processor called at each generation step."""
+
+
 class UnitYGenerator:
     """Generates text translations and speech units from a UnitY model."""
 
     model: UnitYModel
-    s2t_generator: SequenceToTextGenerator
-    t2t_generator: Optional[SequenceToTextGenerator]
+    s2t_converter: SequenceToTextConverter
+    t2t_converter: Optional[SequenceToTextConverter]
     unit_decoder: Optional[UnitTokenDecoder]
+    unit_prefix_indices: Optional[Tensor]
     unit_generator: Optional[Seq2SeqGenerator]
 
     def __init__(
@@ -92,6 +113,9 @@ class UnitYGenerator:
 
         self.model = model
 
+        if text_opts is None:
+            text_opts = SequenceGeneratorOptions()
+
         if model.text_decoder is None:
             raise ValueError(
                 "`UnitYGenerator` requires a text decoder, but the current UnitY model does not have one."
@@ -107,8 +131,21 @@ class UnitYGenerator:
             final_proj=model.final_proj,
             target_vocab_info=model.target_vocab_info,
         )
-        self.s2t_generator = SequenceToTextGenerator(
-            s2t_model, text_tokenizer, target_lang, text_opts
+
+        step_processors = []
+        if text_opts.step_processor is not None:
+            step_processors.append(text_opts.step_processor)
+
+        generator = BeamSearchSeq2SeqGenerator(
+            s2t_model,
+            beam_size=text_opts.beam_size,
+            max_gen_len=text_opts.soft_max_seq_len,
+            max_seq_len=text_opts.hard_max_seq_len,
+            echo_prompt=True,
+            step_processors=step_processors,
+        )
+        self.s2t_converter = SequenceToTextConverter(
+            generator, text_tokenizer, "translation", target_lang
         )
 
         if model.text_encoder is None:
@@ -124,8 +161,16 @@ class UnitYGenerator:
                 final_proj=model.final_proj,
                 target_vocab_info=model.target_vocab_info,
             )
-            self.t2t_generator = SequenceToTextGenerator(
-                t2t_model, text_tokenizer, target_lang, text_opts
+            generator = BeamSearchSeq2SeqGenerator(
+                t2t_model,
+                beam_size=text_opts.beam_size,
+                max_gen_len=text_opts.soft_max_seq_len,
+                max_seq_len=text_opts.hard_max_seq_len,
+                echo_prompt=True,
+                step_processors=step_processors,
+            )
+            self.t2t_converter = SequenceToTextConverter(
+                generator, text_tokenizer, "translation", target_lang
             )
 
         self.unit_generator = None
@@ -143,18 +188,26 @@ class UnitYGenerator:
                 lang=target_lang, device=infer_device(model.t2u_model)
             )
 
+            self.unit_prefix_indices = unit_encoder.prefix_indices
+
             if isinstance(self.model.t2u_model, UnitYT2UModel):
                 if unit_opts is None:
                     # Speech sequences are typically much longer than text sequences.
                     unit_opts = SequenceGeneratorOptions(
-                        soft_max_seq_len=(1, 50), hard_max_seq_len=5000
+                        soft_max_seq_len=(25, 50), hard_max_seq_len=5000
                     )
 
-                self.unit_generator = Seq2SeqGenerator(
+                step_processors = []
+                if unit_opts.step_processor is not None:
+                    step_processors.append(unit_opts.step_processor)
+
+                self.unit_generator = BeamSearchSeq2SeqGenerator(
                     self.model.t2u_model,
-                    unit_tokenizer.vocab_info,
-                    unit_encoder.prefix_indices,
-                    unit_opts,
+                    beam_size=unit_opts.beam_size,
+                    max_gen_len=unit_opts.soft_max_seq_len,
+                    max_seq_len=unit_opts.hard_max_seq_len,
+                    echo_prompt=True,
+                    step_processors=step_processors,
                 )
 
     @torch.inference_mode()
@@ -167,7 +220,7 @@ class UnitYGenerator:
         ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-    ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
+    ) -> Tuple[List[StringLike], Optional[Tensor]]:
         """
         :param source_seqs:
             The source sequences to use for generation. *Shape:* :math:`(N,S,*)`,
@@ -191,25 +244,31 @@ class UnitYGenerator:
         """
 
         if input_modality == "speech":
-            text_output = self.s2t_generator.generate_ex(
+            texts, text_gen_output = self.s2t_converter.batch_convert(
                 source_seqs, source_padding_mask
             )
-        elif input_modality == "text" and self.t2t_generator is not None:
-            text_output = self.t2t_generator.generate_ex(
+        elif input_modality == "text":
+            if self.t2t_converter is None:
+                raise ValueError(
+                    "Please set `use_text_encoder` to `True` in your model config to encode text."
+                )
+            texts, text_gen_output = self.t2t_converter.batch_convert(
                 source_seqs, source_padding_mask
             )
-        elif input_modality == "text" and self.t2t_generator is None:
-            raise ValueError(
-                "Please set use_text_encoder to True in your model config to encode text."
-            )
         else:
             raise ValueError(f"Unsupported input_modality: {input_modality}")
 
         # We skip T2U when we only need to output text.
         if output_modality == "text":
-            return text_output, None
+            return texts, None
+
+        assert self.model.target_vocab_info.pad_idx is not None
 
-        text_seqs, text_padding_mask = text_output.generator_output.collate()
+        text_seq_list = [h[0].seq for h in text_gen_output.hypotheses]
+
+        text_seqs, text_padding_mask = pad_seqs(
+            text_seq_list, self.model.target_vocab_info.pad_idx
+        )
 
         # Manually trim the final EOS token to be consistent with fairseq.
         text_seqs = text_seqs[:, :-1]
@@ -221,8 +280,8 @@ class UnitYGenerator:
         decoder_output, decoder_padding_mask = self.model.decode(
             text_seqs,
             text_padding_mask,
-            text_output.encoder_output,
-            text_output.encoder_padding_mask,
+            text_gen_output.encoder_output,
+            text_gen_output.encoder_padding_mask,
         )
 
         assert self.model.t2u_model is not None
@@ -242,15 +301,25 @@ class UnitYGenerator:
 
         if isinstance(self.model.t2u_model, UnitYT2UModel):
             assert self.unit_generator is not None
-            t2u_encoder_output, t2u_encoder_padding_mask = self.model.t2u_model.encode(
-                decoder_output, decoder_padding_mask
-            )
+            assert self.unit_prefix_indices is not None
+
+            # (S_pre) -> (N, S_pre)
+            prefix_seqs = self.unit_prefix_indices.expand(decoder_output.size(0), -1)
+
             unit_gen_output = self.unit_generator(
-                t2u_encoder_output,
-                t2u_encoder_padding_mask,
-                source_seq_len=source_seqs.size(1),
+                source_seqs=decoder_output,
+                source_padding_mask=decoder_padding_mask,
+                prompt_seqs=prefix_seqs,
+                prompt_padding_mask=None,
+            )
+
+            assert self.model.t2u_model.target_vocab_info.pad_idx is not None
+
+            unit_seq_list = [h[0].seq for h in unit_gen_output.hypotheses]
+
+            unit_seqs, _ = pad_seqs(
+                unit_seq_list, self.model.t2u_model.target_vocab_info.pad_idx
             )
-            unit_seqs, _ = unit_gen_output.collate()
         else:
             t2u_model_output, decoder_padding_mask, _ = self.model.t2u_model(
                 text_decoder_output=decoder_output,
@@ -273,15 +342,4 @@ class UnitYGenerator:
             arr = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
             units = torch.tensor(arr)
 
-        unit_output = SequenceToUnitOutput(units, unit_gen_output)
-
-        return text_output, unit_output
-
-
-@dataclass
-class SequenceToUnitOutput:
-    units: Tensor
-    """The generated units."""
-
-    generator_output: Optional[SequenceGeneratorOutput]
-    """The output of the underlying :class:`Seq2SeqGenerator`."""
+        return texts, units
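
Taken together, the generator changes replace the old `SequenceToTextGenerator` wrapper with an explicit `BeamSearchSeq2SeqGenerator` plus `SequenceToTextConverter` pair. A minimal sketch of that wiring, mirroring the constructor code above; `s2t_model`, `text_tokenizer`, `target_lang`, `source_seqs`, and `source_padding_mask` are assumed to be already built, as in `UnitYGenerator.__init__` and `__call__`.

```python
from fairseq2.generation import BeamSearchSeq2SeqGenerator, SequenceToTextConverter

from seamless_communication.inference import SequenceGeneratorOptions

opts = SequenceGeneratorOptions()  # beam_size=5, soft_max_seq_len=(1, 200)

step_processors = []
if opts.step_processor is not None:
    step_processors.append(opts.step_processor)

# s2t_model: the speech-to-text sub-model assembled in UnitYGenerator.__init__.
generator = BeamSearchSeq2SeqGenerator(
    s2t_model,
    beam_size=opts.beam_size,
    max_gen_len=opts.soft_max_seq_len,
    max_seq_len=opts.hard_max_seq_len,
    echo_prompt=True,
    step_processors=step_processors,
)
converter = SequenceToTextConverter(generator, text_tokenizer, "translation", target_lang)

# batch_convert returns the decoded strings plus the raw generator output,
# whose hypotheses are re-padded with pad_seqs before being fed to the T2U model.
texts, text_gen_output = converter.batch_convert(source_seqs, source_padding_mask)
```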

+ 20 - 18
src/seamless_communication/inference/translator.py

@@ -13,18 +13,16 @@ import torch
 import torch.nn as nn
 from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
-from fairseq2.data import Collater, SequenceData
+from fairseq2.data import Collater, SequenceData, StringLike
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
 from fairseq2.data.text import TextTokenizer
-from fairseq2.data.typing import StringLike
-from fairseq2.generation import SequenceGeneratorOptions, SequenceToTextOutput
 from fairseq2.memory import MemoryBlock
 from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch import Tensor
 
 from seamless_communication.inference.generator import (
-    SequenceToUnitOutput,
+    SequenceGeneratorOptions,
     UnitYGenerator,
 )
 from seamless_communication.models.unity import (
@@ -171,7 +169,7 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-    ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
+    ) -> Tuple[List[StringLike], Optional[Tensor]]:
         # We disregard unit generations opts for the NAR T2U decoder.
         if output_modality != Modality.SPEECH or isinstance(
             model.t2u_model, UnitYNART2UModel
@@ -228,7 +226,7 @@ class Translator(nn.Module):
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-        src_text: Optional[str] = None,
+        src_text: Optional[StringLike] = None,
     ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
@@ -318,7 +316,7 @@ class Translator(nn.Module):
                 beam_size=5, soft_max_seq_len=(25, 50)
             )
 
-        text_output, unit_output = self.get_prediction(
+        texts, units = self.get_prediction(
             self.model,
             self.text_tokenizer,
             self.unit_tokenizer,
@@ -339,7 +337,7 @@ class Translator(nn.Module):
                 if src_text is not None:
                     src_texts = [src_text]
                 else:
-                    asr_text, _, = self.predict(
+                    src_texts, _, = self.predict(
                         input=input,
                         task_str=Task.ASR.name,
                         tgt_lang=tgt_lang,
@@ -350,11 +348,16 @@ class Translator(nn.Module):
                         sample_rate=sample_rate,
                         unit_generation_ngram_filtering=unit_generation_ngram_filtering,
                     )
-                    src_texts = [str(asr_text)]
             else:
-                src_texts = [str(input)]
+                assert isinstance(input, str)
 
-            text_output, unit_output = mintox_pipeline(
+                src_texts = [input]
+
+            assert src_lang is not None
+            assert self.unit_tokenizer is not None
+            assert self.bad_word_checker is not None
+
+            texts, units = mintox_pipeline(
                 model=self.model,
                 text_tokenizer=self.text_tokenizer,
                 unit_tokenizer=self.unit_tokenizer,
@@ -365,8 +368,8 @@ class Translator(nn.Module):
                 input_modality=input_modality,
                 output_modality=output_modality,
                 src_texts=src_texts,
-                original_text_out=text_output,
-                original_unit_out=unit_output,
+                original_texts=texts,
+                original_units=units,
                 unit_generation_ngram_filtering=unit_generation_ngram_filtering,
                 text_generation_opts=text_generation_opts,
                 unit_generation_opts=unit_generation_opts,
@@ -376,17 +379,16 @@ class Translator(nn.Module):
             )
 
         if output_modality == Modality.TEXT:
-            return text_output.sentences, None
+            return texts, None
         else:
-            assert unit_output is not None
+            assert units is not None
 
             if isinstance(self.model.t2u_model, UnitYT2UModel):
                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
                 # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
-                units = unit_output.units[:, 1:]
+                units = units[:, 1:]
                 duration_prediction = True
             else:
-                units = unit_output.units
                 # Vocoder duration predictions not required since the NAR
                 # T2U model already predicts duration in the units.
                 duration_prediction = False
@@ -417,7 +419,7 @@ class Translator(nn.Module):
                     ].unsqueeze(0)
                     audio_wavs.append(padding_removed_audio_wav)
             return (
-                text_output.sentences,
+                texts,
                 BatchedSpeechOutput(
                     units=speech_units,
                     audio_wavs=audio_wavs,
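
The public `predict` contract keeps its shape, only the intermediate types are simpler: it still yields the translated texts plus a `BatchedSpeechOutput` for speech tasks. A usage sketch; the constructor arguments and card names are illustrative assumptions and may differ from your asset store.

```python
import torch

from seamless_communication.inference import Translator

# Card names are placeholders; substitute the model and vocoder cards you have installed.
translator = Translator(
    "seamlessM4T_v2_large", "vocoder_v2", device=torch.device("cuda"), dtype=torch.float16
)

texts, speech_output = translator.predict(
    input="/path/to/audio.wav",
    task_str="S2ST",
    tgt_lang="fra",
)
print(texts[0])  # translated text (List[StringLike])
# speech_output.audio_wavs[0] holds the generated waveform; it is None for text-only tasks.
```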

+ 38 - 81
src/seamless_communication/toxicity/mintox.py

@@ -12,14 +12,11 @@ import torch
 from torch.nn import functional as F
 
 
-from seamless_communication.inference.generator import (
-    SequenceToUnitOutput,
-    SequenceGeneratorOptions,
-)
+from seamless_communication.inference import SequenceGeneratorOptions
 from seamless_communication.toxicity.etox_bad_word_checker import (
     ETOXBadWordChecker,
 )
-from fairseq2.generation import SequenceToTextOutput, BannedSequenceProcessor
+from fairseq2.generation import BannedSequenceProcessor
 from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.typing import Device
@@ -57,76 +54,40 @@ def _extract_bad_words_with_batch_indices(
 
 
 def _replace_with_new_text_output_in_batch(
-    original_text_out: SequenceToTextOutput,
+    original_texts: List[StringLike],
     indices_with_toxicity: List[int],
-    indices_with_toxicity_tensor: Tensor,
-    new_text_output: SequenceToTextOutput,
-    batch_size: int,
+    new_texts: List[StringLike],
 ) -> None:
-    original_text_out.encoder_output[
-        indices_with_toxicity_tensor
-    ] = new_text_output.encoder_output
-    if original_text_out.encoder_padding_mask is not None:
-        assert new_text_output.encoder_padding_mask is not None
-
-        original_text_out.encoder_padding_mask.seq_lens[
-            indices_with_toxicity_tensor
-        ] = new_text_output.encoder_padding_mask.seq_lens
-
-    new_i = 0
-    for original_i in range(batch_size):
-        if (
-            original_i in indices_with_toxicity
-        ):  # indices_with_toxicity is a small list, using list should be fast enough
-            original_text_out.sentences[original_i] = new_text_output.sentences[new_i]
-            original_text_out.generator_output.results[
-                original_i
-            ] = new_text_output.generator_output.results[new_i]
-            new_i += 1
+    new_idx = 0
+    # indices_with_toxicity is a small list, using list should be fast enough.
+    for original_idx in range(len(original_texts)):
+        if original_idx in indices_with_toxicity:
+            original_texts[original_idx] = new_texts[new_idx]
+            new_idx += 1
 
 
 def _replace_with_new_unit_output_in_batch(
     unit_tokenizer: UnitTokenizer,
-    original_unit_out: SequenceToUnitOutput,
-    indices_with_toxicity: List[int],
+    original_units: Tensor,
     indices_with_toxicity_tensor: Tensor,
-    new_unit_output: SequenceToUnitOutput,
-    batch_size: int,
+    new_units: Tensor,
 ) -> None:
-    original_units_length = original_unit_out.units.size(1)
-    new_units_length = new_unit_output.units.size(1)
+    original_units_length = original_units.size(1)
+    new_units_length = new_units.size(1)
     length_diff = abs(new_units_length - original_units_length)
     nb_pads = (0, length_diff)
     pad_idx = unit_tokenizer.vocab_info.pad_idx or 1
     if new_units_length > original_units_length:
         # pad on the original units
-        original_unit_out.units = F.pad(
-            original_unit_out.units,
-            pad=nb_pads,
-            mode="constant",
-            value=pad_idx,
+        original_units = F.pad(
+            original_units, pad=nb_pads, mode="constant", value=pad_idx
         )
     else:
         # pad on the new units
-        new_unit_output.units = F.pad(
-            new_unit_output.units,
-            pad=nb_pads,
-            mode="constant",
-            value=pad_idx,
+        new_units = F.pad(
+            new_units, pad=nb_pads, mode="constant", value=pad_idx
         )
-    original_unit_out.units[indices_with_toxicity_tensor] = new_unit_output.units
-
-    new_i = 0
-    if original_unit_out.generator_output is not None:
-        for original_i in range(batch_size):
-            if (
-                original_i in indices_with_toxicity
-                and new_unit_output.generator_output is not None
-            ):
-                original_unit_out.generator_output.results[
-                    original_i
-                ] = new_unit_output.generator_output.results[new_i]
-                new_i += 1
+    original_units[indices_with_toxicity_tensor] = new_units
 
 
 def mintox_pipeline(
@@ -140,15 +101,15 @@ def mintox_pipeline(
     input_modality: "Modality",
     output_modality: "Modality",
     src_texts: List[StringLike],
-    original_text_out: SequenceToTextOutput,
-    original_unit_out: Optional[SequenceToUnitOutput] = None,
+    original_texts: List[StringLike],
+    original_units: Optional[Tensor] = None,
     unit_generation_ngram_filtering: bool = False,
     text_generation_opts: Optional[SequenceGeneratorOptions] = None,
     unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
     bad_word_checker: ETOXBadWordChecker = None,
     duration_factor: float = 1.0,
     prosody_encoder_input: Optional[SequenceData] = None,
-) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
+) -> Tuple[List[StringLike], Optional[Tensor]]:
     """MinTox: Mitigation at INference time of added TOXicity."""
     from seamless_communication.inference.translator import Modality, Translator
 
@@ -175,7 +136,7 @@ def mintox_pipeline(
 
     bad_words, indices_with_toxicity = _extract_bad_words_with_batch_indices(
         src_texts,
-        original_text_out.sentences,
+        original_texts,
         src_lang,
         tgt_lang,
         bad_word_checker,
@@ -184,9 +145,9 @@ def mintox_pipeline(
     if len(indices_with_toxicity) == 0:
        # if no added toxicity is found, return the original output
         if output_modality == Modality.TEXT:
-            return original_text_out, None
+            return original_texts, None
         else:
-            return original_text_out, original_unit_out
+            return original_texts, original_units
     else:
         logger.info(
             "TOX src_lang=%s tgt_lang=%s added_tox=%d",
@@ -216,7 +177,7 @@ def mintox_pipeline(
             )
         seqs, padding_mask = get_seqs_and_padding_mask(model_input)
         # redo the prediction
-        new_text_output, new_unit_output = Translator.get_prediction(
+        new_texts, new_units = Translator.get_prediction(
             model=model,
             text_tokenizer=text_tokenizer,
             unit_tokenizer=unit_tokenizer,
@@ -231,34 +192,30 @@ def mintox_pipeline(
             duration_factor=duration_factor,
             prosody_encoder_input=prosody_encoder_input,
         )
-        batch_size = len(original_text_out.sentences)
+        batch_size = len(original_texts)
         if batch_size > 1:
             # reconstruct the text output by updating the original one in place
             _replace_with_new_text_output_in_batch(
-                original_text_out,
-                indices_with_toxicity,
-                indices_with_toxicity_tensor,
-                new_text_output,
-                batch_size,
+                original_texts, indices_with_toxicity, new_texts
             )
-            final_text_output = original_text_out
+            final_texts = original_texts
         else:
-            final_text_output = new_text_output
+            final_texts = new_texts
 
         if output_modality == Modality.TEXT:
-            return final_text_output, None
+            return final_texts, None
         else:
             if batch_size > 1:
+                assert original_units is not None
+                assert new_units is not None
                 # reconstruct the unit output by updating the original one in place
                 _replace_with_new_unit_output_in_batch(
                     unit_tokenizer,
-                    original_unit_out,
-                    indices_with_toxicity,
+                    original_units,
                     indices_with_toxicity_tensor,
-                    new_unit_output,
-                    batch_size,
+                    new_units,
                 )
-                final_unit_out = original_unit_out
+                final_units = original_units
             else:
-                final_unit_out = new_unit_output
-            return final_text_output, final_unit_out
+                final_units = new_units
+            return final_texts, final_units
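
The unit-replacement helper now operates directly on tensors: the original and regenerated batches are padded to a common length, then the toxic rows are overwritten by index. A standalone sketch of that technique with toy values:

```python
import torch
import torch.nn.functional as F

pad_idx = 1  # unit_tokenizer.vocab_info.pad_idx or 1, as above

original_units = torch.tensor([[7, 7, 7, 7], [8, 8, 8, 8]])
new_units = torch.tensor([[9, 9]])              # detoxified regeneration of row 1
indices_with_toxicity_tensor = torch.tensor([1])

length_diff = abs(new_units.size(1) - original_units.size(1))
if new_units.size(1) > original_units.size(1):
    original_units = F.pad(original_units, (0, length_diff), value=pad_idx)
else:
    new_units = F.pad(new_units, (0, length_diff), value=pad_idx)

original_units[indices_with_toxicity_tensor] = new_units
# original_units -> tensor([[7, 7, 7, 7], [9, 9, 1, 1]])
```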