@@ -3,34 +3,36 @@
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

+from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Callable, Dict, List, Tuple, Union, Optional
+from typing import Callable, Dict, List, Optional, Tuple, Union

+import numpy as np
+import torch
+import torch.nn as nn
+from fairseq2.assets import asset_store, download_manager
from fairseq2.assets.card import AssetCard
from fairseq2.data import Collater
-from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
-from fairseq2.generation import (
-    BeamSearchSeq2SeqGenerator,
-    SequenceGeneratorOutput,
+from fairseq2.data.audio import (
+    AudioDecoder,
+    AudioDecoderOutput,
+    WaveformToFbankConverter,
)
+from fairseq2.generation import BeamSearchSeq2SeqGenerator, Seq2SeqGeneratorOutput
from fairseq2.memory import MemoryBlock
+from fairseq2.models.nllb.tokenizer import NllbTokenizer
from fairseq2.nn.transformer.multihead_attention import AttentionWeightHook
from fairseq2.typing import DataType, Device
-
-import numpy as np
from scipy.signal import medfilt2d
-from argparse import Namespace
-
-import torch
-import torch.nn as nn
from torch import Tensor

+from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
+from seamless_communication.models.tokenizer import SPMTokenizer
from seamless_communication.models.unity import (
    UnitYX2TModel,
    load_unity_model,
    load_unity_text_tokenizer,
)
-from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
from seamless_communication.segment.silero_vad import SileroVADSegmenter


@@ -58,16 +60,19 @@ class EncDecAttentionsCollect(AttentionWeightHook):
        self.attn_scores = []


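+# Mutable accumulator for one word while merging subword pieces; later collapsed into a TranscriptionToken.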
+@dataclass
+class TranscriptionTokenStats:
+    text: str
+    time_s: float
+    scores: List[float]
+
+
+@dataclass
class TranscriptionToken:
    text: str
    time_s: float
    prob: float

-    def __init__(self, text: str, time_s: float, prob: float):
-        self.text = text
-        self.time_s = time_s
-        self.prob = prob
-

class Transcription:
    text: str
@@ -77,6 +82,11 @@ class Transcription:
        self.text = " ".join([t.text for t in tokens])
        self.tokens = tokens

+    def __add__(self, other: "Transcription") -> "Transcription":
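+        # In-place concatenation: mutates and returns self; used via "+=" when merging segments.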
+        self.text += " " + other.text
+        self.tokens += other.tokens
+        return self
+
    def __str__(self):
        return self.text

@@ -90,39 +100,34 @@ class Transcriber(nn.Module):
        model_name_or_card: Union[str, AssetCard],
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
-        encoder_layers: int = 6,
-        decoder_layers: int = 3,
-        embed_dim: int = 512,
-        depthwise_conv_kernel_size: int = 31,
    ):
        super().__init__()
+
        self.device = device
        self.dtype = dtype
-        self.embed_dim = embed_dim
-        self.encoder_layers = encoder_layers
-        self.decoder_layers = decoder_layers
-        self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
-        self.tokenizer = load_unity_text_tokenizer(model_name_or_card)
-        self.decoder_vocab_info = self.tokenizer.vocab_info
-        self.langs = self.tokenizer.langs
+
+        self.tokenizer = self.load_tokenizer(model_name_or_card)

        model = self.load_model_for_inference(
            load_unity_model, model_name_or_card, device, dtype
        )
+
        self.s2t = UnitYX2TModel(
            encoder_frontend=model.speech_encoder_frontend,
            encoder=model.speech_encoder,
            decoder_frontend=model.text_decoder_frontend,
            decoder=model.text_decoder,
            final_proj=model.final_proj,
-            target_vocab_info=self.decoder_vocab_info,
+            target_vocab_info=self.tokenizer.vocab_info,
        )
+
        self.enc_dec_attn_collector = EncDecAttentionsCollect()
        self.s2t.decoder.layers[-1].encoder_decoder_attn.register_attn_weight_hook(
            self.enc_dec_attn_collector
        )

        self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+
        self.convert_to_fbank = WaveformToFbankConverter(
            num_mel_bins=80,
            waveform_scale=2**15,
@@ -131,10 +136,34 @@ class Transcriber(nn.Module):
            device=device,
            dtype=dtype,
        )
+
        self.collate = Collater(
            pad_value=self.tokenizer.vocab_info.pad_idx, pad_to_multiple=2
        )

+    @staticmethod
+    def load_tokenizer(
+        model_name_or_card: Union[AssetCard, str]
+    ) -> Union[SPMTokenizer, NllbTokenizer]:
+        if isinstance(model_name_or_card, AssetCard):
+            model_card = model_name_or_card
+        else:
+            model_card = asset_store.retrieve_card(model_name_or_card)
+
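+        # The model card's "tokenizer_type" field says which tokenizer the checkpoint expects.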
+        tokenizer_type = model_card.field("tokenizer_type").as_(str)
+
+        if tokenizer_type == "nllb":
+            return load_unity_text_tokenizer(model_name_or_card)
+
+        if tokenizer_type == "plain_spm":
+            tokenizer_uri = model_card.field("tokenizer").as_(str)
+            tokenizer_langs = model_card.field("langs").as_(list)
+            tokenizer_path = download_manager.download_tokenizer(
+                tokenizer_uri, model_name=""
+            )
+            return SPMTokenizer(pathname=tokenizer_path, langs=tokenizer_langs)
+        raise NotImplementedError(f"Unknown tokenizer type '{tokenizer_type}'")
+
    @staticmethod
    def load_model_for_inference(
        load_model_fn: Callable[..., nn.Module],
@@ -168,7 +197,7 @@ class Transcriber(nn.Module):
        while idx != prev[idx]:
            idx = prev[idx]
            seq.append(arr[idx])
-        return (maximum, reversed(seq))
+        return (maximum, list(reversed(seq)))  # materialize the one-shot reversed iterator

    @classmethod
    def _extract_timestamps(
@@ -212,24 +241,30 @@ class Transcriber(nn.Module):
        assert len(pieces) == len(token_timestamps) and len(token_timestamps) == len(
            step_scores
        )
-        word_stats: List[List[Any]] = []
+        word_stats: List[TranscriptionTokenStats] = []
        for (
            time_s,
            token,
            score,
        ) in zip(token_timestamps, pieces, step_scores):
-            if not word_stats or token.startswith("▁") and time_s > word_stats[-1][1]:
+            if (
+                not word_stats
+                or token.startswith("▁")
+                and time_s > word_stats[-1].time_s
+            ):
                word_stats.append(
-                    [token.replace("▁", " ").strip(), time_s, [np.exp(score)]]
+                    TranscriptionTokenStats(
+                        token.replace("▁", " ").strip(), time_s, [np.exp(score)]
+                    )
                )
            else:
-                word_stats[-1][0] += token.replace("▁", " ")
-                word_stats[-1][2].append(np.exp(score))
-        word_stats = [
-            TranscriptionToken(word, start, np.mean(probs))
-            for word, start, probs in word_stats
+                word_stats[-1].text += token.replace("▁", " ")
+                word_stats[-1].scores.append(np.exp(score))
+        words = [
+            TranscriptionToken(token.text, token.time_s, np.mean(token.scores).item())
+            for token in word_stats
        ]
-        return word_stats
+        return words

    def run_inference(
        self,
@@ -251,15 +286,19 @@ class Transcriber(nn.Module):
        )

        self.enc_dec_attn_collector.reset()
-        output: SequenceGeneratorOutput = generator(
+        assert prefix is not None
+        output: Seq2SeqGeneratorOutput = generator(
            source_seqs=fbanks.unsqueeze(0),
            source_padding_mask=None,
            prompt_seqs=prefix.unsqueeze(0),
            prompt_padding_mask=None,
        )
-
-        token_ids = output.hypotheses[0][0].seq.squeeze(0).tolist()[:-1]
-        step_scores = output.hypotheses[0][0].step_scores.tolist()[:-1]
+        highest_prob_hypo = output.hypotheses[0][0]
+        token_tensor = highest_prob_hypo.seq.squeeze(0)
+        token_ids = token_tensor.tolist()[:-1]
+        step_scores_tensor = highest_prob_hypo.step_scores
+        assert step_scores_tensor is not None
+        step_scores = step_scores_tensor.tolist()[:-1]
        enc_dec_attn_scores = self.enc_dec_attn_collector.attn_scores[:-1]
        token_timestamps = self._extract_timestamps(
            enc_dec_attn_scores,
@@ -275,15 +314,13 @@ class Transcriber(nn.Module):
            step_scores=step_scores,
        )
        return Transcription(stats)
-
+
    def denoise_audio(
-        self,
-        audio: Union[str, Tensor],
-        denoise_config: Optional[DenoisingConfig]
-    ) -> Dict:
-        demucs = Demucs(
-            denoise_config=denoise_config)
+        self, audio: Union[str, Tensor], denoise_config: Optional[DenoisingConfig]
+    ) -> AudioDecoderOutput:
+        demucs = Demucs(denoise_config=denoise_config)
        audio = demucs.denoise(audio)
+        assert isinstance(audio, MemoryBlock)
        return self.decode_audio(audio)

    @torch.inference_mode()
@@ -298,7 +335,7 @@ class Transcriber(nn.Module):
        chunk_size_sec: int = 20,
        pause_length_sec: float = 1,
        **sequence_generator_options: Dict,
-    ) -> Transcription:
+    ) -> Optional[Transcription]:
        """
        The main method used to perform transcription.

@@ -324,16 +361,16 @@ class Transcriber(nn.Module):
            Configuration for denoising.

        :returns:
-            - List of Tokens with timestamps.
+            - Transcription: tokens with timestamps plus the joined text, or None if nothing was transcribed.
        """

        if denoise:
            decoded_audio = self.denoise_audio(audio, denoise_config)
-        else:
+        else:
            if isinstance(audio, str):
-                with Path(audio).open("rb") as fb:
-                    block = MemoryBlock(fb.read())
-                decoded_audio = self.decode_audio(block)
+                with Path(audio).open("rb") as fb:
+                    block = MemoryBlock(fb.read())
+                decoded_audio = self.decode_audio(block)
            else:
                decoded_audio = {
                    "waveform": audio,
@@ -341,37 +378,59 @@ class Transcriber(nn.Module):
                    "format": -1,
                }

-        length_seconds = (
-            decoded_audio["waveform"].size(0) / decoded_audio["sample_rate"]
-        )
+        wav = decoded_audio.get("waveform")
+        assert wav is not None
+
+        decoded_sample_rate = decoded_audio.get("sample_rate")
+        assert decoded_sample_rate is not None
+        assert int(decoded_sample_rate) == sample_rate
+
+        length_seconds = wav.size(0) / sample_rate

-        waveform_2d = decoded_audio.get("waveform")
-        waveform_1d = decoded_audio.get("waveform").view(-1)
-        segmenter = SileroVADSegmenter(
-            sample_rate=sample_rate,
-            chunk_size_sec=chunk_size_sec,
-            pause_length=pause_length_sec,
+        waveform_2d = wav
+        waveform_1d = wav.view(-1)
+        segmenter = SileroVADSegmenter(
+            sample_rate=sample_rate,
+            chunk_size_sec=chunk_size_sec,
+            pause_length=pause_length_sec,
+        )
+
+        if length_seconds > chunk_size_sec:
+            src_segments = segmenter.segment_long_input(waveform_1d)  # type: ignore
+        else:
+            src_segments = [(0, waveform_1d.size(0))]
+
+        transcriptions: List[Transcription] = []
+        for start, end in src_segments:
+            segment = waveform_2d[start:end, :]
+            src_segment = self.convert_to_fbank(
+                {
+                    "waveform": segment,
+                    "sample_rate": sample_rate,
+                }
+            )["fbank"]
+            length_seconds_segment = segment.size(0) / sample_rate
+            transcription_segment = self.run_inference(
+                src_segment,
+                src_lang,
+                length_seconds_segment,
+                filter_width,
+                sequence_generator_options,
            )
+            transcriptions.append(transcription_segment)

-        if length_seconds > chunk_size_sec:
-            src_segments = segmenter.segment_long_input(waveform_1d)
-        else:
-            src_segments = [(0, waveform_1d.size(0))]
-
-        transcriptions = []
-        for start, end in src_segments:
-            segment = waveform_2d[start:end, :]
-            src_segment = self.convert_to_fbank(
-                {"waveform": segment, "sample_rate": decoded_audio.get("sample_rate"),
-                 "format": decoded_audio.get("format")})["fbank"]
-            length_seconds_segment = segment.size(0) / sample_rate
-            transcription_segment = self.run_inference(
-                src_segment,
-                src_lang,
-                length_seconds_segment,
-                filter_width,
-                sequence_generator_options,
-            )
-            transcriptions.append(str(transcription_segment))
+        if not transcriptions:
+            return None
+
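+        # Fold every segment's transcription into the first via Transcription.__add__ ("+=").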
+        for idx in range(1, len(transcriptions)):
+            transcriptions[0] += transcriptions[idx]
+
+        return transcriptions[0]
+
+
+if __name__ == "__main__":
+    transcriber = Transcriber("seamless_nano")
+    print(transcriber.transcribe("/private/home/mavlyutov/input.wav", src_lang="eng"))

-        return " ".join(transcriptions)
+    transcriber = Transcriber("seamless_micro")
+    print(transcriber.transcribe("/private/home/mavlyutov/input.wav", src_lang="eng"))