@@ -10,14 +10,8 @@ from torch import Tensor
 from torch.nn import Conv1d, Dropout, Module, ReLU
 
 from fairseq2.nn.normalization import LayerNorm
-from fairseq2.nn.transformer import (
-    AttentionMask,
-    TransformerDecoderLayer,
-    MultiheadAttention,
-)
-from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.nn.transformer import MultiheadAttention, create_standard_layer_norm
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
-from fairseq2.nn.transformer import create_standard_layer_norm
 from fairseq2.typing import DataType, Device, finaloverride
 
 
@@ -107,10 +101,11 @@ class Conv1dBlock(Module):
 
 
 @final
-class NARTransformerDecoderLayer(TransformerDecoderLayer):
+class NARTransformerDecoderLayer(Module):
     """Represents the FFT Block as described in
     :cite:t:`https://arxiv.org/pdf/1905.09263.pdf`."""
 
+    model_dim: int
     self_attn: MultiheadAttention
     self_attn_dropout: Optional[Dropout]
     self_attn_layer_norm: LayerNorm
@@ -137,9 +132,9 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         :param conv1d_dropout_p:
             The dropout probability on the outputs of the conv1d block.
         """
-        model_dim = self_attn.model_dim
+        super().__init__()
 
-        super().__init__(model_dim)
+        self.model_dim = self_attn.model_dim
 
         self.self_attn = self_attn
 
@@ -151,7 +146,7 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         layer_norm_factory = create_standard_layer_norm
 
         self.self_attn_layer_norm = layer_norm_factory(
-            model_dim, device=device, dtype=dtype
+            self.model_dim, device=device, dtype=dtype
         )
 
         self.conv1d = conv1d
@@ -162,7 +157,7 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
             self.register_module("conv1d_dropout", None)
 
         self.conv1d_layer_norm = layer_norm_factory(
-            model_dim, device=device, dtype=dtype
+            self.model_dim, device=device, dtype=dtype
         )
 
     @finaloverride
@@ -170,10 +165,6 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         self,
         seqs: Tensor,
         padding_mask: Optional[PaddingMask],
-        self_attn_mask: Optional[AttentionMask] = None,
-        encoder_output: Optional[Tensor] = None,
-        encoder_padding_mask: Optional[PaddingMask] = None,
-        state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs = self._forward_self_attn(seqs, padding_mask)
 
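Illustrative call-site sketch (not part of the patch; `layer`, `seqs`, and `padding_mask` are hypothetical names): with the decoder-specific parameters dropped, the layer's forward takes only the sequence batch and its padding mask, matching the signature above.

    # hypothetical usage, assuming `layer` is an already-constructed NARTransformerDecoderLayer
    # forward returns the transformed sequences together with the padding mask
    seqs, padding_mask = layer(seqs, padding_mask)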