
Adding Seamless Nano transcription models

Ruslan Mavlyutov 1 year ago
commit c320b7f4df

+ 145 - 86
src/seamless_communication/inference/transcriber.py

@@ -3,34 +3,36 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Tuple, Union, Optional
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
+import numpy as np
+import torch
+import torch.nn as nn
+from fairseq2.assets import asset_store, download_manager
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater
-from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
-from fairseq2.generation import (
-    BeamSearchSeq2SeqGenerator,
-    SequenceGeneratorOutput,
+from fairseq2.data.audio import (
+    AudioDecoder,
+    AudioDecoderOutput,
+    WaveformToFbankConverter,
 )
+from fairseq2.generation import BeamSearchSeq2SeqGenerator, Seq2SeqGeneratorOutput
 from fairseq2.memory import MemoryBlock
+from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.nn.transformer.multihead_attention import AttentionWeightHook
 from fairseq2.typing import DataType, Device
-
-import numpy as np
 from scipy.signal import medfilt2d
-from argparse import Namespace
-
-import torch
-import torch.nn as nn
 from torch import Tensor
 
+from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
+from seamless_communication.models.tokenizer import SPMTokenizer
 from seamless_communication.models.unity import (
     UnitYX2TModel,
     load_unity_model,
     load_unity_text_tokenizer,
 )
-from seamless_communication.denoise.demucs import Demucs, DenoisingConfig
 from seamless_communication.segment.silero_vad import SileroVADSegmenter
 
 
@@ -58,16 +60,19 @@ class EncDecAttentionsCollect(AttentionWeightHook):
         self.attn_scores = []
 
 
+@dataclass
+class TranscriptionTokenStats:
+    text: str
+    time_s: float
+    scores: List[float]
+
+
+@dataclass
 class TranscriptionToken:
     text: str
     time_s: float
     prob: float
 
-    def __init__(self, text: str, time_s: float, prob: float):
-        self.text = text
-        self.time_s = time_s
-        self.prob = prob
-
 
 class Transcription:
     text: str
@@ -77,6 +82,11 @@ class Transcription:
         self.text = " ".join([t.text for t in tokens])
         self.tokens = tokens
 
+    def __add__(self, other: "Transcription") -> "Transcription":
+        self.text += " " + other.text
+        self.tokens += other.tokens
+        return self
+
     def __str__(self):
         return self.text
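
The new in-place __add__ is what lets per-segment results be folded into a single Transcription later in transcribe(). A minimal sketch of the behaviour (the token values here are made up for illustration):

    part_a = Transcription([TranscriptionToken("hello", 0.0, 0.98)])
    part_b = Transcription([TranscriptionToken("world", 1.2, 0.95)])

    combined = part_a + part_b    # __add__ mutates and returns the left operand
    print(combined)               # "hello world"
    print(len(combined.tokens))   # 2

Since the left operand is modified in place, callers should not expect part_a to remain unchanged after the addition.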
 
@@ -90,39 +100,34 @@ class Transcriber(nn.Module):
         model_name_or_card: Union[str, AssetCard],
         device: torch.device = torch.device("cpu"),
         dtype: torch.dtype = torch.float32,
-        encoder_layers: int = 6,
-        decoder_layers: int = 3,
-        embed_dim: int = 512,
-        depthwise_conv_kernel_size: int = 31,
     ):
         super().__init__()
+
         self.device = device
         self.dtype = dtype
-        self.embed_dim = embed_dim
-        self.encoder_layers = encoder_layers
-        self.decoder_layers = decoder_layers
-        self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
-        self.tokenizer = load_unity_text_tokenizer(model_name_or_card)
-        self.decoder_vocab_info = self.tokenizer.vocab_info
-        self.langs = self.tokenizer.langs
+
+        self.tokenizer = self.load_tokenizer(model_name_or_card)
 
         model = self.load_model_for_inference(
             load_unity_model, model_name_or_card, device, dtype
         )
+
         self.s2t = UnitYX2TModel(
             encoder_frontend=model.speech_encoder_frontend,
             encoder=model.speech_encoder,
             decoder_frontend=model.text_decoder_frontend,
             decoder=model.text_decoder,
             final_proj=model.final_proj,
-            target_vocab_info=self.decoder_vocab_info,
+            target_vocab_info=self.tokenizer.vocab_info,
         )
+
         self.enc_dec_attn_collector = EncDecAttentionsCollect()
         self.s2t.decoder.layers[-1].encoder_decoder_attn.register_attn_weight_hook(
             self.enc_dec_attn_collector
         )
 
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+
         self.convert_to_fbank = WaveformToFbankConverter(
             num_mel_bins=80,
             waveform_scale=2**15,
@@ -131,10 +136,34 @@ class Transcriber(nn.Module):
             device=device,
             dtype=dtype,
         )
+
         self.collate = Collater(
             pad_value=self.tokenizer.vocab_info.pad_idx, pad_to_multiple=2
         )
 
+    @staticmethod
+    def load_tokenizer(
+        model_name_or_card: Union[AssetCard, str]
+    ) -> Union[SPMTokenizer, NllbTokenizer]:
+        if isinstance(model_name_or_card, AssetCard):
+            model_card = model_name_or_card
+        else:
+            model_card = asset_store.retrieve_card(model_name_or_card)
+
+        tokenizer_type = model_card.field("tokenizer_type").as_(str)
+
+        if tokenizer_type == "nllb":
+            return load_unity_text_tokenizer(model_name_or_card)
+
+        if tokenizer_type == "plain_spm":
+            tokenizer_uri = model_card.field("tokenizer").as_(str)
+            tokenizer_langs = model_card.field("langs").as_(list)
+            tokenizer_path = download_manager.download_tokenizer(
+                tokenizer_uri, model_name=""
+            )
+            return SPMTokenizer(pathname=tokenizer_path, langs=tokenizer_langs)
+        raise NotImplementedError(f"Unknown tokenizer type '{tokenizer_type}'")
+
     @staticmethod
     def load_model_for_inference(
         load_model_fn: Callable[..., nn.Module],
@@ -168,7 +197,7 @@ class Transcriber(nn.Module):
         while idx != prev[idx]:
             idx = prev[idx]
             seq.append(arr[idx])
-        return (maximum, reversed(seq))
+        return (maximum, list(reversed(seq)))
 
     @classmethod
     def _extract_timestamps(
@@ -212,24 +241,30 @@ class Transcriber(nn.Module):
         assert len(pieces) == len(token_timestamps) and len(token_timestamps) == len(
             step_scores
         )
-        word_stats: List[List[Any]] = []
+        word_stats: List[TranscriptionTokenStats] = []
         for (
             time_s,
             token,
             score,
         ) in zip(token_timestamps, pieces, step_scores):
-            if not word_stats or token.startswith("▁") and time_s > word_stats[-1][1]:
+            if (
+                not word_stats
+                or token.startswith("▁")
+                and time_s > word_stats[-1].time_s
+            ):
                 word_stats.append(
-                    [token.replace("▁", " ").strip(), time_s, [np.exp(score)]]
+                    TranscriptionTokenStats(
+                        token.replace("▁", " ").strip(), time_s, [np.exp(score)]
+                    )
                 )
             else:
-                word_stats[-1][0] += token.replace("▁", " ")
-                word_stats[-1][2].append(np.exp(score))
-        word_stats = [
-            TranscriptionToken(word, start, np.mean(probs))
-            for word, start, probs in word_stats
+                word_stats[-1].text += token.replace("▁", " ")
+                word_stats[-1].scores.append(np.exp(score))
+        words = [
+            TranscriptionToken(token.text, token.time_s, np.mean(token.scores).item())
+            for token in word_stats
         ]
-        return word_stats
+        return words
 
     def run_inference(
         self,
@@ -251,15 +286,19 @@ class Transcriber(nn.Module):
         )
 
         self.enc_dec_attn_collector.reset()
-        output: SequenceGeneratorOutput = generator(
+        assert prefix is not None
+        output: Seq2SeqGeneratorOutput = generator(
             source_seqs=fbanks.unsqueeze(0),
             source_padding_mask=None,
             prompt_seqs=prefix.unsqueeze(0),
             prompt_padding_mask=None,
         )
-
-        token_ids = output.hypotheses[0][0].seq.squeeze(0).tolist()[:-1]
-        step_scores = output.hypotheses[0][0].step_scores.tolist()[:-1]
+        highest_prob_hypo = output.hypotheses[0][0]
+        token_tensor = highest_prob_hypo.seq.squeeze(0)
+        token_ids = token_tensor.tolist()[:-1]
+        step_scores_tensor = highest_prob_hypo.step_scores
+        assert step_scores_tensor is not None
+        step_scores = step_scores_tensor.tolist()[:-1]
         enc_dec_attn_scores = self.enc_dec_attn_collector.attn_scores[:-1]
         token_timestamps = self._extract_timestamps(
             enc_dec_attn_scores,
@@ -275,15 +314,13 @@ class Transcriber(nn.Module):
             step_scores=step_scores,
         )
         return Transcription(stats)
-    
+
     def denoise_audio(
-            self, 
-            audio: Union[str, Tensor], 
-            denoise_config: Optional[DenoisingConfig]
-            ) -> Dict:
-        demucs = Demucs(
-            denoise_config=denoise_config)
+        self, audio: Union[str, Tensor], denoise_config: Optional[DenoisingConfig]
+    ) -> AudioDecoderOutput:
+        demucs = Demucs(denoise_config=denoise_config)
         audio = demucs.denoise(audio)
+        assert isinstance(audio, MemoryBlock)
         return self.decode_audio(audio)
 
     @torch.inference_mode()
@@ -298,7 +335,7 @@ class Transcriber(nn.Module):
         chunk_size_sec: int = 20,
         pause_length_sec: float = 1,
         **sequence_generator_options: Dict,
-    ) -> Transcription:
+    ) -> Optional[Transcription]:
         """
         The main method used to perform transcription.
 
@@ -324,16 +361,16 @@ class Transcriber(nn.Module):
             Configuration for denoising.
 
         :returns:
-            - List of Tokens with timestamps.
+            - Transcription: list of tokens with timestamps and joined text
         """
 
         if denoise:
             decoded_audio = self.denoise_audio(audio, denoise_config)
-        else:            
+        else:
             if isinstance(audio, str):
-                    with Path(audio).open("rb") as fb:
-                        block = MemoryBlock(fb.read())
-                    decoded_audio = self.decode_audio(block)
+                with Path(audio).open("rb") as fb:
+                    block = MemoryBlock(fb.read())
+                decoded_audio = self.decode_audio(block)
             else:
                 decoded_audio = {
                     "waveform": audio,
@@ -341,37 +378,59 @@ class Transcriber(nn.Module):
                     "format": -1,
                 }
 
-            length_seconds = (
-                decoded_audio["waveform"].size(0) / decoded_audio["sample_rate"]
-            )
+        wav = decoded_audio.get("waveform")
+        assert wav is not None
+
+        decoded_sample_rate = decoded_audio.get("sample_rate")
+        assert decoded_sample_rate is not None
+        assert int(decoded_sample_rate) == sample_rate
+
+        length_seconds = wav.size(0) / sample_rate
 
-            waveform_2d = decoded_audio.get("waveform")
-            waveform_1d = decoded_audio.get("waveform").view(-1)
-            segmenter = SileroVADSegmenter(
-                sample_rate=sample_rate,
-                chunk_size_sec=chunk_size_sec,
-                pause_length=pause_length_sec,
+        waveform_2d = wav
+        waveform_1d = wav.view(-1)
+        segmenter = SileroVADSegmenter(
+            sample_rate=sample_rate,
+            chunk_size_sec=chunk_size_sec,
+            pause_length=pause_length_sec,
+        )
+
+        if length_seconds > chunk_size_sec:
+            src_segments = segmenter.segment_long_input(waveform_1d)  # type: ignore
+        else:
+            src_segments = [(0, waveform_1d.size(0))]
+
+        transcriptions: List[Transcription] = []
+        for start, end in src_segments:
+            segment = waveform_2d[start:end, :]
+            src_segment = self.convert_to_fbank(
+                {
+                    "waveform": segment,
+                    "sample_rate": sample_rate,
+                }
+            )["fbank"]
+            length_seconds_segment = segment.size(0) / sample_rate
+            transcription_segment = self.run_inference(
+                src_segment,
+                src_lang,
+                length_seconds_segment,
+                filter_width,
+                sequence_generator_options,
             )
+            transcriptions.append(transcription_segment)
 
-            if length_seconds > chunk_size_sec:
-                src_segments = segmenter.segment_long_input(waveform_1d)
-            else:
-                src_segments = [(0, waveform_1d.size(0))]
-
-            transcriptions = []
-            for start, end in src_segments:
-                segment = waveform_2d[start:end, :]
-                src_segment = self.convert_to_fbank(
-                    {"waveform": segment, "sample_rate": decoded_audio.get("sample_rate"), 
-                     "format": decoded_audio.get("format")})["fbank"]
-                length_seconds_segment = segment.size(0) / sample_rate
-                transcription_segment = self.run_inference(
-                    src_segment,
-                    src_lang,
-                    length_seconds_segment,
-                    filter_width,
-                    sequence_generator_options,
-                )
-                transcriptions.append(str(transcription_segment))
+        if not transcriptions:
+            return None
+
+        # Fold the per-segment results into the first transcription.
+        for idx in range(1, len(transcriptions)):
+            transcriptions[0] += transcriptions[idx]
+
+        return transcriptions[0]
+
+
+if __name__ == "__main__":
+    transcriber = Transcriber("seamless_nano")
+    print(transcriber.transcribe("/private/home/mavlyutov/input.wav", src_lang="eng"))
 
-            return " ".join(transcriptions)
+    transcriber = Transcriber("seamless_micro")
+    print(transcriber.transcribe("/private/home/mavlyutov/input.wav", src_lang="eng"))
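
Beyond the smoke test above, a sketch of how the updated API is meant to be used; it assumes the "seamless_nano" asset card is installed locally and uses a placeholder audio path:

    import torch

    from seamless_communication.inference.transcriber import Transcriber

    transcriber = Transcriber("seamless_nano", device=torch.device("cpu"), dtype=torch.float32)
    result = transcriber.transcribe("example.wav", src_lang="eng")  # placeholder path, 16 kHz mono assumed

    if result is not None:  # transcribe() now returns Optional[Transcription]
        print(result)       # joined text
        for token in result.tokens:
            print(f"{token.time_s:6.2f}s  p={token.prob:.2f}  {token.text}")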

+ 144 - 8
src/seamless_communication/models/unity/builder.py

@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
+from fairseq2.data.vocabulary_info import VocabularyInfo
 from fairseq2.models.conformer import ConformerBlock, ConformerConvolution
 from fairseq2.models.nllb import NllbBuilder, NllbConfig, nllb_archs
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
@@ -26,6 +27,11 @@ from fairseq2.nn.transformer import (
 from fairseq2.typing import DataType, Device, override
 from torch.nn import GELU, ReLU
 
+from seamless_communication.models.conformer_shaw import (
+    ConformerShawEncoderBuilder,
+    ConformerShawEncoderConfig,
+    conformer_shaw_archs,
+)
 from seamless_communication.models.generator.ecapa_tdnn_builder import (
     EcapaTDNNBuilder,
     EcapaTDNNConfig,
@@ -43,11 +49,6 @@ from seamless_communication.models.unity.t2u_builder import (
     UnitYT2UConfig,
     unity_t2u_archs,
 )
-from seamless_communication.models.conformer_shaw import (
-    ConformerShawEncoderBuilder,
-    ConformerShawEncoderConfig,
-    conformer_shaw_archs,
-)
 
 
 @dataclass
@@ -223,6 +224,138 @@ def _expressivity_v2() -> UnitYConfig:
     )
 
 
+def _build_model_config(
+    model_embed_dim: int,
+    ffn_emb_dim_mult: int,
+    feature_stride: int,
+    text_decoder_layers: int,
+    text_dict_size: int,
+    unit_dict_size: int,
+):
+    num_fbank_channels = 80
+    fbank_stride = feature_stride
+    nllb_ffn_inner_dim = model_embed_dim * ffn_emb_dim_mult
+    w2v2_ffn_inner_dim = model_embed_dim * 4
+    w2v2_encoder_layers_layernorm_features: bool = False
+    w2v2_pos_encoder_type = "relative"
+    w2v2_pos_encoder_depth: int = 0
+    w2v2_pos_conv_kernel_size: int = 0
+    w2v2_num_pos_conv_groups: int = 0
+    w2v2_encoder_layers: int = 6
+    w2v2_encoder_layers_use_conformer: bool = True
+    nllb_encoder_layers: int = 1
+    nllb_decoder_layers: int = text_decoder_layers
+    text_vocab_info = VocabularyInfo(
+        size=text_dict_size,
+        unk_idx=3,
+        bos_idx=0,
+        eos_idx=2,
+        pad_idx=1,
+    )
+    unit_vocab_info = VocabularyInfo(
+        size=unit_dict_size,
+        unk_idx=0,
+        bos_idx=0,
+        eos_idx=0,
+        pad_idx=0,  # not used
+    )
+
+    model_config = UnitYConfig(
+        use_gelu=False,
+        use_text_decoder=True,
+        prosody_encoder_config=None,
+        model_dim=model_embed_dim,
+        w2v2_encoder_config=Wav2Vec2EncoderConfig(
+            model_dim=model_embed_dim,
+            max_seq_len=4096,
+            feature_dim=num_fbank_channels * fbank_stride,
+            use_fbank=True,
+            first_pass_dropout_p=0.0,
+            layer_norm_features=w2v2_encoder_layers_layernorm_features,
+            feature_extractor_layer_descs=[],
+            feature_extractor_bias=False,
+            feature_extractor_layer_norm_convs=False,
+            feature_grad_scale=0,
+            num_fbank_channels=num_fbank_channels,
+            fbank_stride=fbank_stride,
+            sample_fbank_every_k=1,
+            pos_encoder_type=w2v2_pos_encoder_type,
+            pos_encoder_depth=w2v2_pos_encoder_depth,
+            pos_conv_kernel_size=w2v2_pos_conv_kernel_size,
+            num_pos_conv_groups=w2v2_num_pos_conv_groups,
+            use_conformer=w2v2_encoder_layers_use_conformer,
+            num_encoder_layers=w2v2_encoder_layers,
+            num_encoder_attn_heads=16,
+            ffn_inner_dim=w2v2_ffn_inner_dim,
+            dropout_p=0.0,
+            attn_dropout_p=0.0,
+            layer_drop_p=0.0,
+            norm_order=TransformerNormOrder.POST,
+            depthwise_conv_kernel_size=31,
+        ),
+        mt_model_config=NllbConfig(
+            model_dim=model_embed_dim,
+            max_seq_len=1024,
+            vocab_info=text_vocab_info,
+            num_encoder_layers=nllb_encoder_layers,
+            num_decoder_layers=nllb_decoder_layers,
+            num_encoder_attn_heads=16,
+            num_decoder_attn_heads=16,
+            ffn_inner_dim=nllb_ffn_inner_dim,
+            dropout_p=0.1,
+        ),
+        t2u_config=UnitYT2UConfig(
+            use_gelu=False,
+            char_pad_idx=0,
+            use_prosody_proj=False,
+            prosody_encoder_dim=0,
+            nar_decoder_frontend_config=None,
+            nar_decoder_config=None,
+            model_dim=model_embed_dim,
+            unit_max_seq_len=2048,
+            target_vocab_info=unit_vocab_info,  # dummy
+            num_encoder_layers=1,
+            num_decoder_layers=1,
+            num_encoder_attn_heads=16,
+            num_decoder_attn_heads=16,
+            ffn_inner_dim=model_embed_dim * 8,
+            dropout_p=0.1,
+        ),
+        use_text_encoder=True,
+        use_conformer_adaptor=False,
+        num_adaptor_layers=1,
+        adaptor_kernel_size=8,
+        adaptor_stride=8,
+        adaptor_layer_norm=True,
+        adaptor_dropout_p=0.1,
+    )
+    return model_config
+
+
+@unity_arch("seamless_micro")
+def _seamless_micro() -> UnitYConfig:
+    return _build_model_config(
+        model_embed_dim=512,
+        ffn_emb_dim_mult=8,
+        feature_stride=4,
+        text_decoder_layers=3,
+        text_dict_size=20010,
+        unit_dict_size=10082,
+    )
+
+
+@unity_arch("seamless_nano")
+def _seamless_nano() -> UnitYConfig:
+    return _build_model_config(
+        model_embed_dim=256,
+        ffn_emb_dim_mult=8,
+        feature_stride=4,
+        text_decoder_layers=3,
+        text_dict_size=20010,
+        unit_dict_size=10082,
+    )
+
+
 class UnitYBuilder:
     """Builds modules of a UnitY model.
 
@@ -265,17 +398,20 @@ class UnitYBuilder:
         """
         if w2v2_encoder_builder.config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `w2v2_encoder_builder.config` must be equal, but are {config.model_dim} and {w2v2_encoder_builder.config.model_dim} instead."
+                "`model_dim` and `model_dim` of `w2v2_encoder_builder.config` must be equal, "
+                f"but are {config.model_dim} and {w2v2_encoder_builder.config.model_dim} instead."
             )
 
         if mt_model_builder.config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `mt_model_builder.config` must be equal, but are {config.model_dim} and {mt_model_builder.config.model_dim} instead."
+                "`model_dim` and `model_dim` of `mt_model_builder.config` must be equal, "
+                f"but are {config.model_dim} and {mt_model_builder.config.model_dim} instead."
             )
 
         if t2u_builder is not None and t2u_builder.config.model_dim != config.model_dim:
             raise ValueError(
-                f"`model_dim` and `model_dim` of `t2u_builder.config` must be equal, but are {config.model_dim} and {t2u_builder.config.model_dim} instead."
+                "`model_dim` and `model_dim` of `t2u_builder.config` must be equal, "
+                f"but are {config.model_dim} and {t2u_builder.config.model_dim} instead."
             )
 
         self.config = config
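
The two new arch entries are plain config factories, so the resulting configs can be sanity-checked directly. A minimal sketch (importing the underscore-prefixed helpers only for illustration; normal loading still goes through the asset cards, e.g. load_unity_model("seamless_nano")):

    from seamless_communication.models.unity.builder import _seamless_micro, _seamless_nano

    nano_cfg = _seamless_nano()
    micro_cfg = _seamless_micro()

    print(nano_cfg.model_dim)                            # 256
    print(micro_cfg.model_dim)                           # 512
    print(nano_cfg.mt_model_config.num_decoder_layers)   # 3 text decoder layers
    print(nano_cfg.mt_model_config.vocab_info.size)      # 20010 text tokens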