# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import os
from typing import List, Optional, Tuple, Union

import numpy
import torch
import torch.nn as nn
import torchaudio
from fairseq2.data.typing import StringLike
from fairseq2.typing import DataType, Device
from torch import Tensor

from seamless_communication.models.aligner.loader import load_unity2_alignment_model
from seamless_communication.models.unit_extractor import UnitExtractor
try:
    import matplotlib.pyplot as plt

    matplotlib_available = True
except ImportError:
    matplotlib_available = False


class AlignmentExtractor(nn.Module):
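    """Extracts text-to-unit alignments from speech using a UnitY2 aligner model.

    Optionally wraps a ``UnitExtractor`` so that alignments can be computed
    directly from raw audio; otherwise, precomputed unit tensors must be
    passed to ``extract_alignment``.
    """
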
    def __init__(
        self,
        aligner_model_name_or_card: str,
        unit_extractor_model_name_or_card: Optional[str] = None,
        unit_extractor_output_layer: Optional[int] = None,
        unit_extractor_kmeans_model_uri: Optional[str] = None,
        device: Device = Device("cpu"),
        dtype: DataType = torch.float32,
    ):
        super().__init__()
        self.device = device
        self.dtype = dtype
        if self.dtype == torch.float16 and self.device == Device("cpu"):
            raise RuntimeError(
                "FP16 is only supported on GPU; set device and dtype accordingly"
            )
        self.alignment_model = load_unity2_alignment_model(
            aligner_model_name_or_card, device=self.device, dtype=self.dtype
        )
        self.alignment_model.eval()

        self.unit_extractor = None
        self.unit_extractor_output_layer = 0
        if unit_extractor_model_name_or_card is not None:
            self.unit_extractor = UnitExtractor(
                unit_extractor_model_name_or_card,
                unit_extractor_kmeans_model_uri,
                device=device,
                dtype=dtype,
            )
            self.unit_extractor_output_layer = unit_extractor_output_layer

    def load_audio(
        self, audio_path: str, sampling_rate: int = 16_000
    ) -> Tuple[Tensor, int]:
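        """Load an audio file from disk, resampling to ``sampling_rate`` if needed."""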
        assert os.path.exists(audio_path), f"Audio file not found: {audio_path}"
        audio, rate = torchaudio.load(audio_path)
        if rate != sampling_rate:
            audio = torchaudio.functional.resample(audio, rate, sampling_rate)
            rate = sampling_rate
        return audio, rate

    def prepare_audio(self, audio: Union[str, Tensor]) -> Tensor:
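        """Load (if given a path), downmix to mono, and move audio to the model device."""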
        # TODO: switch to fairseq2 data pipeline once it supports resampling
        if isinstance(audio, str):
            audio, _ = self.load_audio(audio, sampling_rate=16_000)
        if audio.ndim > 1:
            # average over channels; a [Channel, Time] layout is assumed,
            # so the channel dimension must be the smaller one
            assert audio.size(0) < audio.size(
                1
            ), "Expected [Channel, Time] shape, but Channel >= Time"
            audio = audio.mean(0)
        assert (
            audio.ndim == 1
        ), "After channel averaging, expected audio shape to be [Time], i.e. mono audio"
        audio = audio.to(self.device, self.dtype)
        return audio

    def extract_units(self, audio: Tensor) -> Tensor:
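        """Quantize an audio tensor into discrete units via the unit extractor."""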
        assert isinstance(
            self.unit_extractor, UnitExtractor
        ), "A unit extractor is required to get units from an audio tensor"
        units = self.unit_extractor.predict(audio, self.unit_extractor_output_layer - 1)
        return units

    @torch.inference_mode()
    def extract_alignment(
        self,
        audio: Union[str, Tensor],
        text: str,
        plot: bool = False,
        add_trailing_silence: bool = False,
    ) -> Tuple[Tensor, Tensor, List[StringLike]]:
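        """Align ``text`` against ``audio``.

        ``audio`` may be a file path, a floating-point waveform tensor, or an
        integer tensor of precomputed units. Returns per-token durations (in
        unit frames), the tokenized text ids, and the text tokens.
        """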
        if isinstance(audio, Tensor) and not torch.is_floating_point(audio):
            # an integer tensor means precomputed units were passed instead of audio
            units = audio.to(self.device)
            audio_tensor = None
        else:
            audio_tensor = self.prepare_audio(audio)
            units = self.extract_units(audio_tensor)

        tokenized_unit_ids = self.alignment_model.alignment_frontend.tokenize_unit(
            units
        ).unsqueeze(0)
        tokenized_text_ids = (
            self.alignment_model.alignment_frontend.tokenize_text(
                text, add_trailing_silence=add_trailing_silence
            )
            .to(self.device)
            .unsqueeze(0)
        )
        tokenized_text_tokens = (
            self.alignment_model.alignment_frontend.tokenize_text_to_tokens(
                text, add_trailing_silence=add_trailing_silence
            )
        )
        _, alignment_durations = self.alignment_model(
            tokenized_text_ids, tokenized_unit_ids
        )
        if plot and (audio_tensor is not None):
            self.plot_alignment(
                audio_tensor.cpu(), tokenized_text_tokens, alignment_durations.cpu()
            )
        return alignment_durations, tokenized_text_ids, tokenized_text_tokens

    def detokenize_text(self, tokenized_text_ids: Tensor) -> StringLike:
        return self.alignment_model.alignment_frontend.decode_text(tokenized_text_ids)

    def plot_alignment(
        self, audio: Tensor, text_tokens: List[StringLike], durations: Tensor
    ) -> None:
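        """Plot the waveform with dashed vertical lines at token boundaries,
        labeling each aligned segment with its text token. Requires matplotlib.
        """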
        if not matplotlib_available:
            raise RuntimeError(
                "Please `pip install matplotlib` to use alignment plotting."
            )
        _, ax = plt.subplots(figsize=(22, 3.5))
        ax.plot(audio, color="gray", linewidth=0.3)
        durations_cumul = numpy.concatenate([numpy.array([0]), numpy.cumsum(durations)])
        alignment_ticks = durations_cumul * 320  # 320 samples per unit frame = 20 ms at 16 kHz
        ax.vlines(
            alignment_ticks,
            ymax=1,
            ymin=-1,
            color="indigo",
            linestyles="dashed",
            lw=0.5,
        )
        middle_tick_positions = (
            durations_cumul[:-1] + (durations_cumul[1:] - durations_cumul[:-1]) / 2
        )
        ax.set_xticks(middle_tick_positions * 320)
        ax.set_xticklabels(text_tokens, fontsize=13)
        ax.set_xlim(0, len(audio))
        ax.set_ylim(audio.min(), audio.max())
        ax.set_yticks([])
        plt.tight_layout()
        plt.show()
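

if __name__ == "__main__":
    # Minimal usage sketch. The asset names below ("nar_t2u_aligner",
    # "xlsr2_1b_v2", output layer 35, and the kmeans URI) are illustrative
    # and must match the cards available in your fairseq2 asset store;
    # "audio.wav" and the transcript are placeholders.
    extractor = AlignmentExtractor(
        aligner_model_name_or_card="nar_t2u_aligner",
        unit_extractor_model_name_or_card="xlsr2_1b_v2",
        unit_extractor_output_layer=35,
        unit_extractor_kmeans_model_uri="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
    )
    durations, text_ids, text_tokens = extractor.extract_alignment(
        "audio.wav",
        "the transcript of the audio",
        plot=False,
        add_trailing_silence=True,
    )
    # durations[i] is the number of 20 ms unit frames aligned to text_tokens[i]
    print(list(zip(text_tokens, durations.squeeze(0).tolist())))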