@@ -12,9 +12,9 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 import torchaudio
+from fairseq2.assets import asset_store
 from fairseq2.data import Collater, CString, DataPipeline, FileMapper
 from fairseq2.data.audio import (
     AudioDecoder,
@@ -24,17 +24,26 @@ from fairseq2.data.audio import (
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
 from fairseq2.data.typing import PathLike, StringLike
 from fairseq2.generation import SequenceGeneratorOptions
+from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from torch import Tensor
 from tqdm import tqdm
 
+from seamless_communication.cli.m4t.evaluate.evaluate import (
+    adjust_output_for_corrupted_inputs,
+    count_lines,
+)
 from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
 from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
-from seamless_communication.models.unity import load_unity_text_tokenizer
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
+from seamless_communication.models.unity import (
+    load_unity_text_tokenizer,
+    load_gcmvn_stats,
+)
 
 logging.basicConfig(
     level=logging.INFO,
@@ -97,13 +106,20 @@ class EvalContext:
     """If True, removes consecutive repeating ngrams
     from the decoded unit output."""
 
-    gcmvn_stats: Optional[PathLike] = None
-    """the stats for gcmvn, used by Prosody Encoder"""
+    pretssel_model: str
+    """The name of the PretsselModel"""
+
+    vocoder_name: str
+    """The name of the Vocoder Model"""
+
+    gcmvn_mean: Optional[Tensor]
+    """The mean stats for global-normalized fbank"""
 
+    gcmvn_std: Optional[Tensor]
+    """The std stats for global-normalized fbank"""
 
-def count_lines(filename: Path) -> int:
-    result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
-    return int(result.stdout.decode().split()[0])
+    duration_factor: float = 1.1
+    """The duration factor for NAR T2U model. The Expressivity model uses 1.1"""
 
 
 def build_data_pipeline(
@@ -118,15 +134,6 @@ def build_data_pipeline(
 
     split_tsv = StrSplitter(names=header)
 
-    if ctx.gcmvn_stats is not None:
-        if isinstance(ctx.gcmvn_stats, CString):
-            ctx.gcmvn_stats = str(ctx.gcmvn_stats)
-        gcmvn_stats: Dict[str, np.ndarray] = np.load(ctx.gcmvn_stats)  # type: ignore[type-arg]
-        gcmvn_mean = torch.tensor(
-            gcmvn_stats["mean"], device=ctx.device, dtype=ctx.dtype
-        )
-        gcmvn_std = torch.tensor(gcmvn_stats["std"], device=ctx.device, dtype=ctx.dtype)
-
     pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
 
     assert ctx.audio_root_dir is not None
@@ -150,8 +157,8 @@ def build_data_pipeline(
         fbank = data["fbank"]
         std, mean = torch.std_mean(fbank, dim=0)
         data["fbank"] = fbank.subtract(mean).divide(std)
-        if ctx.gcmvn_stats is not None:
-            data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
+        if ctx.gcmvn_mean is not None and ctx.gcmvn_std is not None:
+            data["gcmvn_fbank"] = fbank.subtract(ctx.gcmvn_mean).divide(ctx.gcmvn_std)
         return data
 
     pipeline_builder.map(
@@ -171,52 +178,19 @@ def build_data_pipeline(
     return pipeline_builder.and_return()
 
 
-def adjust_output_for_corrupted_inputs(
-    valid_sequences: Tensor,
-    text_output: List[StringLike],
-    speech_output: Optional[BatchedSpeechOutput],
-) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
-    adjusted_text_output: List[StringLike] = []
-    adjusted_speech_output: Optional[BatchedSpeechOutput] = None
-
-    if speech_output is not None:
-        assert (
-            len(text_output)
-            == len(speech_output.units)
-            == len(speech_output.audio_wavs)
-        )
-        adjusted_speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
-
-    batch_counter = 0
-    for is_valid in valid_sequences:
-        if is_valid:
-            adjusted_text_output.append(text_output[batch_counter])
-            if speech_output is not None:
-                assert adjusted_speech_output is not None
-                adjusted_speech_output.units.append(speech_output.units[batch_counter])
-                adjusted_speech_output.audio_wavs.append(
-                    speech_output.audio_wavs[batch_counter]
-                )
-            batch_counter += 1
-        else:
-            # For the corrupted inputs, we save the following dummy outputs:
-            # empty string for text, empty list for units, 1 second of silence for audio.
-            adjusted_text_output.append("")
-            if adjusted_speech_output is not None:
-                sample_rate = adjusted_speech_output.sample_rate
-                adjusted_speech_output.units.append([])
-                adjusted_speech_output.audio_wavs.append(
-                    torch.zeros(sample_rate).unsqueeze(0).unsqueeze(0)
-                )
-    return (
-        adjusted_text_output,
-        adjusted_speech_output,
-    )
-
-
 def run_eval(
     translator: Translator, text_tokenizer: TextTokenizer, ctx: EvalContext
 ) -> None:
+    pretssel_generator = PretsselGenerator(
+        ctx.model_name,
+        ctx.vocoder_name,
+        ctx.pretssel_model,
+        ctx.device,
+        ctx.gcmvn_mean,
+        ctx.gcmvn_std,
+        ctx.dtype,
+    )
+
     pipeline = build_data_pipeline(ctx, text_tokenizer)
 
     total_steps = count_lines(ctx.data_file) - 1
@@ -226,7 +200,7 @@ def run_eval(
     output_path.mkdir(parents=True, exist_ok=True)
 
     if ctx.output_modality == Modality.SPEECH:
-        waveforms_dir = output_path / f"waveform_{ctx.data_file.stem}"
+        waveforms_dir = output_path / "waveform"
         waveforms_dir.mkdir(parents=True, exist_ok=True)
 
     hyps = []
@@ -258,10 +232,10 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (
-                    text_output,
-                    speech_output,
-                ) = translator.predict(
+                gcmvn_fbank, padding_mask = get_seqs_and_padding_mask(
+                    example["audio"]["data"]["gcmvn_fbank"]
+                )
+                text_output, unit_output = translator.predict(
                     src,
                     ctx.task,
                     ctx.target_lang,
@@ -269,8 +243,18 @@ def run_eval(
                     text_generation_opts=ctx.text_generation_opts,
                     unit_generation_opts=ctx.unit_generation_opts,
                     unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
-                    gcmvn_fbank=example["audio"]["data"].get("gcmvn_fbank", None),
+                    duration_factor=ctx.duration_factor,
+                    gcmvn_fbank=gcmvn_fbank,
+                )
+
+                assert unit_output is not None
+                speech_output = pretssel_generator.predict(
+                    unit_output.units,
+                    tgt_lang=ctx.target_lang,
+                    padding_mask=padding_mask,
+                    gcmvn_fbank=gcmvn_fbank,
                 )
+
             else:
                 text_output = []
                 if ctx.output_modality == Modality.SPEECH:
@@ -279,10 +263,7 @@ def run_eval(
                     speech_output = None
 
             if valid_sequences is not None and not valid_sequences.all():
-                (
-                    text_output,
-                    speech_output,
-                ) = adjust_output_for_corrupted_inputs(
+                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
                     valid_sequences,
                     text_output,
                     speech_output,
@@ -293,6 +274,7 @@ def run_eval(
 
             for i in range(len(text_output)):
                 t = text_output[i]
+                idx = str(example["id"][i])
                 hyp_file.write(f"{t}\n")
 
                 if ctx.output_modality == Modality.SPEECH:
@@ -301,8 +283,8 @@ def run_eval(
                     str_units = [str(i) for i in u]
                     unit_file.write(" ".join(str_units) + "\n")
                     torchaudio.save(
-                        waveforms_dir / f"{sample_id}_pred.wav",
-                        speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
+                        waveforms_dir / f"{idx}_pred.wav",
+                        speech_output.audio_wavs[i].to(torch.float32).cpu(),
                         sample_rate=speech_output.sample_rate,
                     )
 
@@ -353,9 +335,9 @@ def main() -> None:
         default="tgt_text",
     )
     parser.add_argument(
-        "--gcmvn_stats",
+        "--pretssel_model",
         type=str,
-        help="The path to gcmvn fbank stats, if provided, the DataPipeline'd have another copy of gcmvn fbank features (for P2V enc)",
+        help="Model card name for PretsselModel",
        default=None,
     )
     args = parser.parse_args()
@@ -369,19 +351,19 @@ def main() -> None:
 
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
-        dtype = torch.float32
+        dtype = torch.float16
     else:
         device = torch.device("cpu")
         dtype = torch.float32
 
     text_tokenizer = load_unity_text_tokenizer(args.model_name)
 
-    # TODO: Avoid loading the T2U model, vocoder when the output
-    # modality is text.
+    gcmvn_mean, gcmvn_std = load_gcmvn_stats(args.pretssel_model)
+
     translator = Translator(
         args.model_name,
-        args.vocoder_name,
-        device,
+        vocoder_name_or_card=None,
+        device=device,
         text_tokenizer=text_tokenizer,
         dtype=dtype,
     )
@@ -411,7 +393,10 @@ def main() -> None:
         unit_generation_opts=unit_generation_opts,
         unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
         output_path=Path(args.output_path),
-        gcmvn_stats=args.gcmvn_stats,
+        gcmvn_mean=torch.tensor(gcmvn_mean, device=device, dtype=dtype),
+        gcmvn_std=torch.tensor(gcmvn_std, device=device, dtype=dtype),
+        pretssel_model=args.pretssel_model,
+        vocoder_name=args.vocoder_name,
     )
     # fmt: on
     logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")