Online feature extractor SimulEval agent. (#107)

* Initial commit adding online_feature_extractor.

* Fix args issue and the tokenizer download path.

* Fix mypy issues in dataloader.py

* Fix model name argparsing; set UnitY use_text_decoder to False.

* Import the dataloader to register it, remove unnecessary args, fix dataloader issues.

* Simplify script and args (#108)

* Fix bug in convert_to_fbank waveform_scale.

* Dataloader next iterator explicit try-catch (#109)

* Next iterator explicit try-catch

* Type hints - mypy fixes

* Update src/seamless_communication/agents/dataloader.py

Co-authored-by: Abinesh Ramakrishnan <3632454+ibanesh@users.noreply.github.com>

* Address comments: move to streaming.agents; fix dtype and device in the UnitY pipeline.

* Remove remnant of refactor.

* Enforce start and end index in iterable dataloader (#117)

* Remove eval file in cli/streaming, create a new dataloaders dir under streaming/

* Simplify __next__ of SimulEvalSpeechToTextDataloader, and set range [start_index, end_index).

* Fix bug in dataloader: self.data_file not declared.

* Fix bug in s2tt dataloader when end_index is not set.

* Dataloader improvements

* Move skip to the appropriate position

---------

Co-authored-by: Abinesh Ramakrishnan <3632454+ibanesh@users.noreply.github.com>
Kaushik Ram Sadagopan 1 year ago
parent
commit
521a374213

+ 1 - 0
setup.py

@@ -25,6 +25,7 @@ setup(
         "fairseq2==0.2.*",
         "librosa",
         "openai-whisper",
+        "simuleval",
         "soundfile",
         "torchaudio",
         "tqdm",

+ 5 - 0
src/seamless_communication/cli/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 16 - 0
src/seamless_communication/cli/eval_utils/__init__.py

@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from seamless_communication.cli.eval_utils.compute_metrics import (
+    compute_quality_metrics as compute_quality_metrics,
+)
+from seamless_communication.cli.eval_utils.compute_metrics import (
+    get_tokenizer as get_tokenizer,
+)
+from seamless_communication.cli.eval_utils.lang_mapping import (
+    LANG2_LANG3 as LANG2_LANG3,
+)

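Note: the "name as name" pattern above is the explicit re-export idiom, which keeps these symbols public under mypy's strict re-export checking. A minimal sketch of the import it enables (the m4t evaluate.py change below does exactly this):

    # Callers can now import from the package root instead of the submodule:
    from seamless_communication.cli.eval_utils import (
        compute_quality_metrics,
        get_tokenizer,
        LANG2_LANG3,
    )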
+ 2 - 6
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -7,24 +7,20 @@
 import argparse
 import contextlib
 import logging
-import subprocess
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 import torchaudio
-from fairseq2.assets import asset_store
-from fairseq2.data import Collater, CString, DataPipeline, FileMapper
+from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data.audio import (
     AudioDecoder,
     WaveformToFbankConverter,
     WaveformToFbankOutput,
 )
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
-from fairseq2.data.typing import PathLike, StringLike
 from fairseq2.generation import SequenceGeneratorOptions
-from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from torch import Tensor

+ 4 - 1
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -25,7 +25,7 @@ from fairseq2.typing import DataType, Device
 from torch import Tensor
 from tqdm import tqdm
 
-from seamless_communication.cli.eval_utils.compute_metrics import (
+from seamless_communication.cli.eval_utils import (
     compute_quality_metrics,
 )
 from seamless_communication.cli.m4t.predict import (
@@ -368,6 +368,9 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
             "Please provide required arguments for evaluation - data_file, task, tgt_lang"
         )
 
+    if not Path(args.data_file).exists():
+        raise ValueError(f"Invalid data_file to be evaluated: {args.data_file}")
+
     input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
 
     if input_modality == Modality.SPEECH and not Path(args.audio_root_dir).exists():

+ 5 - 0
src/seamless_communication/streaming/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 9 - 0
src/seamless_communication/streaming/agents/__init__.py

@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.streaming.agents.mma_m4t_s2t import (
+    MonotonicM4TS2TSPMAgent as MonotonicM4TS2TSPMAgent,
+)

+ 18 - 0
src/seamless_communication/streaming/agents/mixins.py

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Mixins for fairseq2 simuleval agents
+"""
+
+
+class EarlyStoppingMixin:
+    def reset_early(self) -> None:
+        """
+        Implement to override for different behavior on a reset that
+        happens before EOS
+        """
+        raise NotImplementedError()

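Note: a hypothetical sketch (not part of this commit) of how an agent's states class could opt into the mixin, assuming AgentStates' standard target/target_finished fields. Subclass both AgentStates and EarlyStoppingMixin and decide in reset_early what survives a pre-EOS reset:

    from simuleval.agents.states import AgentStates

    class MyAgentStates(AgentStates, EarlyStoppingMixin):  # hypothetical
        def reset_early(self) -> None:
            # Drop the partial target but keep buffered source audio, so an
            # early stop does not discard unread input.
            self.target = []
            self.target_finished = False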
+ 16 - 0
src/seamless_communication/streaming/agents/mma_m4t_s2t.py

@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.streaming.agents.online_feature_extractor import (
+    OnlineFeatureExtractorAgent,
+)
+from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
+from simuleval.utils import entrypoint
+
+
+@entrypoint
+class MonotonicM4TS2TSPMAgent(UnitYAgentPipeline):
+    pipeline = [OnlineFeatureExtractorAgent]

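Note: the @entrypoint decorator registers the class so SimulEval can discover it. A hypothetical construction sketch (argument values assumed; device and fp16 normally arrive via SimulEval's own CLI options, and from_args loads the full UnitY and monotonic decoder models, so this is illustrative rather than a quick test):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    MonotonicM4TS2TSPMAgent.add_args(parser)   # adds pipeline + per-agent args
    args = parser.parse_args(["--task", "s2tt"])
    args.device, args.fp16 = "cpu", False      # normally injected by SimulEval
    agent = MonotonicM4TS2TSPMAgent.from_args(args)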
+ 152 - 0
src/seamless_communication/streaming/agents/online_feature_extractor.py

@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import math
+import torch
+
+from argparse import ArgumentParser, Namespace
+from typing import Any, List
+
+from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
+
+from simuleval.agents import SpeechToSpeechAgent
+from simuleval.agents.actions import Action, ReadAction, WriteAction
+from simuleval.agents.states import AgentStates
+from simuleval.data.segments import Segment, SpeechSegment
+
+
+SHIFT_SIZE = 10
+WINDOW_SIZE = 25
+SAMPLE_RATE = 16000
+FEATURE_DIM = 80
+
+
+class FeatureStates(AgentStates):
+    def reset(self) -> None:
+        super().reset()
+        self.previous_residual_samples: List[float] = []
+        self.tgt_lang = None
+
+    def update_source(self, segment: Segment) -> None:
+        """
+        Update states from input segment
+        Args:
+            segment (~simuleval.agents.segments.Segment): input segment
+        """
+        self.source_finished = segment.finished
+        if self.tgt_lang is None and segment.tgt_lang is not None:
+            self.tgt_lang = segment.tgt_lang
+        if not segment.is_empty:
+            self.source.append(segment.content)
+
+
+class OnlineFeatureExtractorAgent(SpeechToSpeechAgent):
+    """
+    Extract speech features on the fly.
+    """
+
+    def __init__(self, args: Namespace):
+        super().__init__(args)
+        self.shift_size = args.shift_size
+        self.window_size = args.window_size
+        assert self.window_size >= self.shift_size
+
+        self.sample_rate = args.sample_rate
+        self.feature_dim = args.feature_dim
+        self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000)
+        self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000)
+        self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000
+
+        self.convert_to_fbank = WaveformToFbankConverter(
+            num_mel_bins=80,
+            waveform_scale=2**15 if args.denormalize else 1.0,
+            standardize=False,
+            device=args.device,
+            dtype=args.dtype,
+        )
+
+    def build_states(self) -> FeatureStates:
+        return FeatureStates()
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        parser.add_argument(
+            "--shift-size",
+            type=int,
+            default=SHIFT_SIZE,
+            help="Shift size of feature extraction window.",
+        )
+        parser.add_argument(
+            "--window-size",
+            type=int,
+            default=WINDOW_SIZE,
+            help="Window size of feature extraction window.",
+        )
+        parser.add_argument(
+            "--feature-dim",
+            type=int,
+            default=FEATURE_DIM,
+            help="Acoustic feature dimension.",
+        )
+        parser.add_argument(
+            "--denormalize",
+            action="store_true",
+            help="denormalized to 16-bit signed integers",
+        )
+
+    def policy(self, states: FeatureStates) -> Action:
+        if len(states.source) == 0:
+            if states.source_finished:
+                return WriteAction({}, finished=states.source_finished)
+            else:
+                return ReadAction()
+
+        samples = states.source[-1]
+
+        samples = states.previous_residual_samples + samples
+        if len(samples) < self.num_samples_per_window:
+            states.previous_residual_samples = samples
+            return ReadAction()
+
+        # num_frames is the number of frames from the new segment
+        num_frames = math.floor(
+            (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size))
+            / self.num_samples_per_shift
+        )
+
+        # the number of frames used for feature extraction
+        # including some part of the previous segment
+        effective_num_samples = int(
+            num_frames * self.len_ms_to_samples(self.shift_size)
+            + self.len_ms_to_samples(self.window_size - self.shift_size)
+        )
+
+        input_samples = samples[:effective_num_samples]
+        states.previous_residual_samples = samples[
+            num_frames * self.num_samples_per_shift :
+        ]
+
+        data: WaveformToFbankInput = {
+            "waveform": torch.tensor(input_samples).unsqueeze(0),
+            "sample_rate": self.sample_rate,
+        }
+
+        output = self.convert_to_fbank(data)["fbank"]
+
+        return WriteAction(
+            SpeechSegment(
+                content=output,
+                tgt_lang=states.tgt_lang,
+                finished=states.source_finished,
+            ),
+            finished=states.source_finished,
+        )
+
+    @classmethod
+    def from_args(cls, args: Any, **kwargs: Any) -> OnlineFeatureExtractorAgent:
+        return cls(args)

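Note: a worked example of the frame arithmetic in policy() above, using the defaults (16 kHz audio, 10 ms shift = 160 samples, 25 ms window = 400 samples):

    import math

    SHIFT, WINDOW = 160, 400          # samples per 10 ms shift / 25 ms window
    samples = [0.0] * 480             # residual + newly read chunk

    num_frames = math.floor((len(samples) - (WINDOW - SHIFT)) / SHIFT)  # -> 1
    effective = num_frames * SHIFT + (WINDOW - SHIFT)                   # -> 400
    residual = samples[num_frames * SHIFT:]                             # 320 left

    # One 25 ms fbank frame is produced from samples[:400]; the trailing 320
    # samples carry over so the next window overlaps by WINDOW - SHIFT.
    assert (num_frames, effective, len(residual)) == (1, 400, 320)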
+ 169 - 0
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -0,0 +1,169 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+from simuleval.agents.agent import GenericAgent
+
+import logging
+import torch
+
+from argparse import ArgumentParser, Namespace
+from typing import Any, List, Optional
+
+from fairseq2.assets import asset_store
+from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
+from seamless_communication.inference.translator import Modality, Translator
+from seamless_communication.models.unity import (
+    load_unity_config,
+    load_unity_model,
+    load_unity_text_tokenizer,
+    load_unity_unit_tokenizer,
+)
+from seamless_communication.models.monotonic_decoder import load_monotonic_decoder_model
+
+from simuleval.agents import AgentPipeline, AgentStates
+from simuleval.data.segments import Segment
+
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+def maybe_reset_states(states: Optional[List[Optional[AgentStates]]]) -> None:
+    for s in states:
+        if s is not None:
+            if isinstance(s, EarlyStoppingMixin):
+                s.reset_early()
+            else:
+                s.reset()
+
+
+class UnitYPipelineMixin:
+    """
+    Mixin for fairseq pipeline which works with both AgentPipeline
+    and TreeAgentPipeline
+    """
+
+    @classmethod
+    def add_args(cls, parser: ArgumentParser) -> None:
+        super().add_args(parser)
+        parser.add_argument("--task", type=str, help="Task type")
+        parser.add_argument(
+            "--unity-model-name",
+            type=str,
+            help="Unity model name.",
+            default="unity_sans_decoder",
+        )
+        parser.add_argument(
+            "--monotonic-decoder-model-name",
+            type=str,
+            help="Monotonic decoder model name.",
+            default="monotonic_decoder",
+        )
+        parser.add_argument(
+            "--sample-rate",
+            default=16000,
+            type=float,
+        )
+        parser.add_argument(
+            "--dtype",
+            default="fp16",
+            type=str,
+        )
+
+    @classmethod
+    def from_args(cls, args: Any) -> UnitYPipelineMixin:
+        return cls(args)
+
+
+class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
+    pipeline: List[GenericAgent] = []
+
+    def __init__(self, args: Namespace):
+
+        if not torch.cuda.is_available() and "cuda" in args.device:
+            raise ValueError("CUDA not available, use CPU.")
+
+        args.device = torch.device(args.device)
+        if (args.fp16 or args.dtype == "fp16") and args.device != torch.device("cpu"):
+            args.dtype = torch.float16
+        else:
+            args.dtype = torch.float32
+
+        input_modality, output_modality = Translator.get_modalities_from_task_str(
+            args.task
+        )
+
+        if input_modality != Modality.SPEECH:
+            raise ValueError("`UnitYAgentPipeline` only supports speech input.")
+
+        unity_config = load_unity_config(args.unity_model_name)
+        unity_config.use_text_decoder = False
+        unity_config.use_text_encoder = False
+
+        text_tokenizer = load_unity_text_tokenizer(args.unity_model_name)
+
+        # Skip loading the T2U model.
+        if output_modality == Modality.TEXT:
+            unity_config.t2u_config = None
+            unit_tokenizer = None
+        else:
+            unit_tokenizer = load_unity_unit_tokenizer(args.unity_model_name)
+
+        asset_card = asset_store.retrieve_card(args.unity_model_name)
+        asset_card.field("model_config").set(unity_config)
+
+        logger.info(
+            f"Loading the UnitY model: {args.unity_model_name} on device={args.device}, dtype={args.dtype}"
+        )
+        unity_model = load_unity_model(asset_card, device=args.device, dtype=args.dtype)
+        unity_model.eval()
+
+        logger.info(
+            f"Loading the Monotonic Decoder model: {args.monotonic_decoder_model_name} on device={args.device}, dtype={args.dtype}"
+        )
+        monotonic_decoder_model = load_monotonic_decoder_model(
+            args.monotonic_decoder_model_name, device=args.device, dtype=args.dtype
+        )
+        monotonic_decoder_model.eval()
+
+        module_list = []
+        for p in self.pipeline:
+            module_list.append(
+                p.from_args(
+                    args,
+                    unity_model=unity_model,
+                    unity_config=unity_config,
+                    monotonic_decoder_model=monotonic_decoder_model,
+                    text_tokenizer=text_tokenizer,
+                    unit_tokenizer=unit_tokenizer,
+                )
+            )
+
+        super().__init__(module_list)
+
+    def pop(self, states: Optional[List[Optional[AgentStates]]] = None) -> Segment:
+        output_segment = super().pop(states)
+        if states is None:
+            # Not stateless
+            first_states = self.module_list[0].states
+        else:
+            assert len(states) == len(self.module_list)
+            first_states = states[0]
+
+        if not first_states.source_finished and output_segment.finished:
+            # An early stop.
+            # The temporary solution is to start over
+            if states is not None:
+                maybe_reset_states(states)
+            else:
+                self.reset()
+            output_segment.finished = False
+
+        return output_segment

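Note: maybe_reset_states() dispatches on the mixin: EarlyStoppingMixin states get reset_early(), everything else gets a full reset(). pop() relies on this to recover from a premature finished flag without discarding buffered input. A hypothetical standalone sketch of the dispatch (the states classes are illustrative, not from this commit):

    from simuleval.agents.states import AgentStates
    from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
    from seamless_communication.streaming.agents.unity_pipeline import (
        maybe_reset_states,
    )

    class PlainStates(AgentStates):
        pass

    class EarlyStates(AgentStates, EarlyStoppingMixin):
        def reset_early(self) -> None:
            self.target = []          # keep source, drop the partial target

    maybe_reset_states([PlainStates(), EarlyStates(), None])
    # PlainStates.reset() ran, EarlyStates.reset_early() ran, None was skipped.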
+ 9 - 0
src/seamless_communication/streaming/dataloaders/__init__.py

@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.streaming.dataloaders.s2tt import (
+    SimulEvalSpeechToTextDataloader as SimulEvalSpeechToTextDataloader,
+)

+ 169 - 0
src/seamless_communication/streaming/dataloaders/s2tt.py

@@ -0,0 +1,169 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import subprocess
+from argparse import ArgumentParser, Namespace
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from fairseq2.data.audio import AudioDecoder
+from fairseq2.data.data_pipeline import Collater, DataPipeline, FileMapper
+from fairseq2.data.text.converters import StrSplitter
+from fairseq2.data.text.text_reader import read_text
+from simuleval.data.dataloader import register_dataloader
+from simuleval.data.dataloader.dataloader import IterableDataloader
+from simuleval.data.dataloader.s2t_dataloader import SpeechToTextDataloader
+
+
+@dataclass
+class SoundFileInfo:
+    samplerate: float
+    path: str
+
+    def __repr__(self) -> str:
+        return "\n".join([f"samplerate: {str(self.samplerate)}", f"path: {self.path}"])
+
+
+def count_lines(filename: Path) -> int:
+    result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
+    return int(result.stdout.decode().split()[0]) - 1
+
+
+@register_dataloader("fairseq2_s2tt")
+class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader):
+    def __init__(self, data_pipeline: DataPipeline, args: Namespace) -> None:
+        self.args = args
+        self.data_file: Path = Path(getattr(self.args, "data_file", ""))
+        if not self.data_file.exists():
+            raise ValueError(f"data_file: {self.data_file} does not exist.")
+        self.start_index: int = getattr(self.args, "start_index", 0)
+        self.end_index: int = getattr(self.args, "end_index", -1)
+        self.data_pipeline = data_pipeline
+        self.data_itr = iter(self.data_pipeline)
+        self.cur_index = self.start_index - 1
+        self.item = None
+
+    def __iter__(self) -> SimulEvalSpeechToTextDataloader:
+        return self
+
+    def __next__(self) -> SimulEvalSpeechToTextDataloader:
+        if self.cur_index >= self.end_index - 1:
+            raise StopIteration
+        self.item = next(self.data_itr)
+        self.cur_index += 1
+        return self
+
+    def reset(self) -> None:
+        self.cur_index = 0
+        self.data_pipeline.reset()
+
+    def __len__(self) -> int:
+        if self.end_index > 0:
+            return self.end_index - self.start_index
+        self.end_index = count_lines(self.data_file)
+        return self.end_index - self.start_index
+
+    def get_source(self, index: Optional[int] = None) -> List[float]:
+        source: List[float] = (
+            self.item["audio"]["data"]["waveform"]["seqs"].squeeze().tolist()
+        )
+        return source
+
+    def get_target(self, index: Optional[int] = None) -> str:
+        return str(self.item[self.args.ref_field][0])
+
+    def get_tgt_lang(self, index: Optional[int] = None) -> Optional[str]:
+        if self.args.tgt_lang:
+            tgt_lang: str = self.args.tgt_lang
+            return tgt_lang
+
+        tgt_lang = self.item.get("tgt_lang")
+        return str(tgt_lang[0]) if tgt_lang else None
+
+    def get_source_audio_info(self, index: Optional[int] = None) -> SoundFileInfo:
+        samplerate = self.item["audio"]["data"]["sample_rate"][0]
+        path = f'{self.args.audio_root_dir}/{str(self.item["audio"]["path"][0])}'
+        return SoundFileInfo(samplerate, path)
+
+    def get_source_audio_path(self, index: Optional[int] = None) -> str:
+        return str(self.item["audio"]["path"][0])
+
+    @classmethod
+    def from_args(cls, args: Namespace) -> SimulEvalSpeechToTextDataloader:
+        with open(args.data_file, "r") as f:
+            header = f.readline().strip("\n").split("\t")
+
+        split_tsv = StrSplitter(names=header)
+
+        start_index: int = getattr(args, "start_index", 0)
+
+        pipeline_builder = (
+            read_text(args.data_file, rtrim=True).skip(1 + start_index).map(split_tsv)
+        )
+
+        map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)
+
+        pipeline_builder.map(map_file, selector="audio")
+
+        device = getattr(args, "device", None)
+        assert device is not None
+
+        decode_audio = AudioDecoder(dtype=torch.float32, device=torch.device(device))
+
+        pipeline_builder.map(
+            decode_audio,
+            selector="audio.data",
+        )
+
+        pipeline_builder.map(
+            lambda x: F.layer_norm(x, x.shape),
+            selector="audio.data.waveform",
+        )
+
+        collate = Collater(pad_value=0, pad_to_multiple=1)
+
+        pipeline_builder.map(collate)
+
+        pipeline_builder.prefetch(1)
+
+        data_pipeline = pipeline_builder.and_return()
+
+        return cls(data_pipeline, args)
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        parser.add_argument(
+            "--data-file",
+            type=str,
+            required=True,
+            help="Data file (.tsv) to be evaluated.",
+        )
+        parser.add_argument(
+            "--audio-root-dir",
+            type=str,
+            help="Root directory for the audio filenames in the data file.",
+            default="",
+        )
+        parser.add_argument(
+            "--ref-field",
+            type=str,
+            help="Reference target text field to compute the BLEU score against.",
+            default="tgt_text",
+        )
+        parser.add_argument(
+            "--source-segment-size",
+            type=int,
+            default=1,
+            help="Source segment size, For text the unit is # token, for speech is ms",
+        )
+        parser.add_argument(
+            "--tgt-lang", type=str, help="Target language to translate/transcribe into."
+        )