|
@@ -3,19 +3,17 @@
|
|
|
#
|
|
|
# This source code is licensed under the BSD-style license found in the
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
-from fairseq2.nn.transformer import TransformerNormOrder
|
|
|
+from fairseq2.nn.transformer import TransformerEncoder, TransformerNormOrder
|
|
|
from fairseq2.models.wav2vec2 import (
|
|
|
Wav2Vec2EncoderConfig,
|
|
|
Wav2Vec2Config,
|
|
|
wav2vec2_arch,
|
|
|
Wav2Vec2Model,
|
|
|
- Wav2Vec2Builder,
|
|
|
- Wav2Vec2EncoderBuilder,
|
|
|
+ create_wav2vec2_model,
|
|
|
+ Wav2Vec2Frontend,
|
|
|
)
|
|
|
from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
|
|
|
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
|
|
|
-from fairseq2.models.utils.model_loader import ModelConfigLoader
|
|
|
-from fairseq2.typing import DataType, Device
|
|
|
from fairseq2.models.sequence import SequenceBatch
|
|
|
|
|
|
|
|
@@ -26,6 +24,8 @@ import torch
|
|
|
from typing import Optional
|
|
|
|
|
|
from torch import Tensor
|
|
|
+import torch.nn as nn
|
|
|
+
|
|
|
|
|
|
wav2vec2_archs = ArchitectureRegistry[Wav2Vec2Config]("wav2vec2")
|
|
|
wav2vec2_arch = wav2vec2_archs.marker
|
|
@@ -86,14 +86,31 @@ def _xlsr2_1b_v2() -> Wav2Vec2Config:
|
|
|
)
|
|
|
|
|
|
|
|
|
-class Wav2Vec2LayerOutputModel(Wav2Vec2Model):
|
|
|
- @torch.no_grad()
|
|
|
+load_wav2vec2_model = Wav2Vec2Loader(
|
|
|
+ asset_store,
|
|
|
+ download_manager,
|
|
|
+ create_wav2vec2_model,
|
|
|
+ wav2vec2_archs,
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
+class Wav2Vec2LayerOutputModel(nn.Module):
|
|
|
+ encoder_frontend: Wav2Vec2Frontend
|
|
|
+ encoder: TransformerEncoder
|
|
|
+
|
|
|
+ def __init__(self, w2v2: Wav2Vec2Model):
|
|
|
+ super().__init__()
|
|
|
+
|
|
|
+ self.encoder_frontend = w2v2.encoder_frontend
|
|
|
+ self.encoder = w2v2.encoder
|
|
|
+
|
|
|
+ @torch.inference_mode()
|
|
|
def forward(self, batch: SequenceBatch, out_layer_idx: int):
|
|
|
"""
|
|
|
:param batch:
|
|
|
The batch of sequences to process.
|
|
|
"""
|
|
|
- seqs, padding_mask, _, _ = self.run_frontend(batch.seqs, batch.seq_lens)
|
|
|
+ seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.seq_lens)
|
|
|
w2v2_layer_output = None
|
|
|
|
|
|
def layer_output_hook(
|
|
@@ -107,70 +124,7 @@ class Wav2Vec2LayerOutputModel(Wav2Vec2Model):
|
|
|
if layer_idx == out_layer_idx:
|
|
|
w2v2_layer_output = layer_output
|
|
|
|
|
|
- # TODO: Should pad for fp16?
|
|
|
_, _ = self.encoder(seqs, padding_mask, layer_output_hook)
|
|
|
|
|
|
assert w2v2_layer_output is not None
|
|
|
return w2v2_layer_output
|
|
|
-
|
|
|
-
|
|
|
-class Wav2Vec2LayerOutputBuilder(Wav2Vec2Builder):
|
|
|
- def build_model(self) -> Wav2Vec2LayerOutputModel:
|
|
|
- """Build a model."""
|
|
|
- encoder_frontend = self.encoder_builder.build_frontend()
|
|
|
-
|
|
|
- encoder = self.encoder_builder.build_encoder()
|
|
|
-
|
|
|
- masker = self.build_masker()
|
|
|
-
|
|
|
- quantizer = self.build_quantizer()
|
|
|
-
|
|
|
- return Wav2Vec2LayerOutputModel(
|
|
|
- encoder_frontend,
|
|
|
- encoder,
|
|
|
- masker,
|
|
|
- quantizer,
|
|
|
- self.config.final_dim,
|
|
|
- self.config.final_proj_bias,
|
|
|
- self.config.num_distractors,
|
|
|
- self.config.logit_temp,
|
|
|
- self.config.diversity_loss_weight,
|
|
|
- device=self.device,
|
|
|
- dtype=self.dtype,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-def create_wav2vec2_layer_output_model(
|
|
|
- config: Wav2Vec2Config,
|
|
|
- device: Optional[Device] = None,
|
|
|
- dtype: Optional[DataType] = None,
|
|
|
-) -> Wav2Vec2Model:
|
|
|
- """Create a wav2vec 2.0 model.
|
|
|
-
|
|
|
- :param config:
|
|
|
- The configuration to use.
|
|
|
- :param device:
|
|
|
- The device on which to initialize modules.
|
|
|
- :param dtype:
|
|
|
- The data type of module parameters and buffers.
|
|
|
- """
|
|
|
- encoder_builder = Wav2Vec2EncoderBuilder(config.encoder_config, device, dtype)
|
|
|
-
|
|
|
- return Wav2Vec2LayerOutputBuilder(
|
|
|
- config, encoder_builder, device, dtype
|
|
|
- ).build_model()
|
|
|
-
|
|
|
-
|
|
|
-load_wav2vec2_layer_output_config = ModelConfigLoader[Wav2Vec2Config](
|
|
|
- asset_store, wav2vec2_archs
|
|
|
-)
|
|
|
-
|
|
|
-load_wav2vec2_layer_output_model = Wav2Vec2Loader(
|
|
|
- asset_store,
|
|
|
- download_manager,
|
|
|
- create_wav2vec2_layer_output_model,
|
|
|
- wav2vec2_archs,
|
|
|
- # `weight_norm` used in `Wav2Vec2PositionEncoder` does not support meta
|
|
|
- # initialization.
|
|
|
- use_meta=False,
|
|
|
-)
|