@@ -82,7 +82,7 @@ def test_convert_linear(tmp_path: Path) -> None:
     module = fairseq2.nn.Linear(16, 24, True)

     layer_config = read_layer_config(module)
-    assert layer_config == {"input_dim": 16, "output_dim": 24, "skip_init": False}
+    assert layer_config == {"input_dim": 16, "output_dim": 24}

     module_file = Path("module.ggml")
     convert_model(module, module_file)
@@ -96,8 +96,8 @@ def test_convert_linear(tmp_path: Path) -> None:

 def test_causal_attention_mask(ctx: Ctx):
     x = torch.zeros((1, 10, 32))
-    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
-    mask_exp = generator(x).numpy()
+    generator = fairseq2.nn.transformer.CausalAttentionMaskFactory()
+    mask_exp = generator(x, x).materialize().numpy()

     gx = ggml.from_numpy(ctx, x)
     gmask = ggml.causal_attention_mask(ctx, gx)
@@ -111,7 +111,7 @@ def test_causal_attention_mask(ctx: Ctx):
     assert np.all(mask == mask_exp)

     x = x[:, :8, :]
-    mask_exp = generator(x).numpy()
+    mask_exp = generator(x, x).materialize().numpy()
     gx = ggml.from_numpy(ctx, x)
     gmask = ggml.causal_attention_mask(ctx, gx)
     mask = ggml.to_numpy(gmask)
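Note on the two hunks above: fairseq2 replaced `CausalAttentionMaskGenerator` with `CausalAttentionMaskFactory`. The factory is called with both the queries and the keys, and returns a lazy mask object that has to be materialized into a tensor before it can be compared with the ggml output. A minimal sketch of the new call pattern, reusing the shapes from the test and only names visible in this diff:

    import torch
    import fairseq2.nn.transformer

    x = torch.zeros((1, 10, 32))  # (batch, seq_len, model_dim)
    factory = fairseq2.nn.transformer.CausalAttentionMaskFactory()
    # materialize() turns the lazy mask into the dense tensor that
    # ggml.causal_attention_mask() is tested against.
    mask_exp = factory(x, x).materialize().numpy()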
@@ -209,10 +209,10 @@ def test_MultiheadAttention_forward(
     y = ggml.to_numpy(gy)
     nodes = ggml.nodes(gf)

-    attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
+    attn_weights_hook = fairseq2.nn.transformer.AttentionWeightStoreHook([])
     self_attn.register_attn_weight_hook(attn_weights_hook)

-    y_exp = self_attn(xq, None, xk, xk).numpy()
+    y_exp = self_attn(xq, None, xk, None, xk).numpy()

     q = ggml.to_numpy(nodes[b"q"])
     assert q.shape == q_exp.shape
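Note: every attention call in this diff gains an extra `None` argument. Judging from the call sites alone, the forward signature now takes a separate key padding mask between the keys and the values, i.e. (seqs, padding_mask, keys, key_padding_mask, values); this reading is inferred from the diff, not from fairseq2 documentation:

    # old: y_exp = self_attn(xq, None, xk, xk).numpy()
    # new: a fourth positional argument (the key padding mask, None here)
    y_exp = self_attn(xq, None, xk, None, xk).numpy()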
@@ -221,15 +221,15 @@ def test_MultiheadAttention_forward(
     # with flash_attn we don't have attn_weights
     naive_attn = b"attn_weights" in nodes
     if naive_attn:
-        attn_weights = ggml.to_numpy(nodes[b"attn_weights"])
-        [attn_weights_exp] = attn_weights_hook._storage
+        attn_weights = ggml.to_numpy(nodes[b"attn_weights"]).reshape(-1, 16, qlen, klen)
+        [(_, attn_weights_exp)] = attn_weights_hook._storage
         attn_weights_exp = attn_weights_exp.numpy()
         assert attn_weights_exp.shape == attn_weights.shape
         # GGML is very aggressively reducing small softmax weights to 0,
         # so the error isn't that small
         assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
         # But the sums should be close to 1
-        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, qlen)))
+        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2, 16, qlen)))
         # And the maximum indices should match the original ones.
         assert np.allclose(
             np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
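Note: `StoreAttentionWeights` is now `AttentionWeightStoreHook`, and each entry in its storage list is a pair whose second element is the weight tensor, hence the `[(_, attn_weights_exp)]` unpacking. The ggml weights are reshaped to (batch, num_heads, qlen, klen) so the shape assertion and the softmax row-sum check both operate per head. A short sketch of the hook wiring, with `self_attn`, `xq` and `xk` as in the test:

    storage = []  # the hook appends one entry per attention forward call
    hook = fairseq2.nn.transformer.AttentionWeightStoreHook(storage)
    self_attn.register_attn_weight_hook(hook)

    self_attn(xq, None, xk, None, xk)
    [(_, attn_weights)] = storage  # the second tuple element is the weights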
@@ -248,13 +248,13 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)

-    state_bag = fairseq2.nn.IncrementalStateBag()
+    state_bag = fairseq2.nn.IncrementalStateBag(100)

     with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
         # Incremental decoding
         for t in range(3):
             xq = x[:, t : t + 1]
-            y_exp = attn(xq, None, xq, xq, state_bag=state_bag).numpy()
+            y_exp = attn(xq, None, xq, None, xq, state_bag=state_bag).numpy()
             assert y_exp.shape == (2, 1, 1024)

             gxq = ggml.from_numpy(ctx, xq.contiguous())
@@ -272,13 +272,11 @@ def test_MultiheadAttention_forward_self_attn_with_cache(
             ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

             nodes = ggml.nodes(gf)
-            state = state_bag.get_state(
-                attn, fairseq2.nn.transformer.MultiheadAttentionState
-            )
+            state = state_bag.get_state(attn, fairseq2.nn.transformer.AttentionState)
             state_bag.increment_step()
             assert state is not None
             assert np.allclose(
-                state.prev_k.numpy(),
+                state.get()[0].transpose(1, 2).reshape(2, t + 1, -1).numpy(),
                 ggml.to_numpy(
                     nodes[b"text_decoder.layers.0.self_attn.k_cache (step=%d)" % t]
                 ),
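Note: two fairseq2 changes meet in this and the previous hunk. `IncrementalStateBag` now takes a required maximum step count (100 throughout these tests), and `MultiheadAttentionState` with its `prev_k` attribute became `AttentionState`, whose `get()` returns the cached projections laid out per head; the `[0]` indexing above selects the keys. The transpose/reshape folds the heads back into the model dimension before comparing with the ggml k_cache. A sketch with the shapes implied by the assertions (batch 2, model dim 1024):

    state_bag = fairseq2.nn.IncrementalStateBag(100)  # max step count is now required

    attn(xq, None, xq, None, xq, state_bag=state_bag)

    state = state_bag.get_state(attn, fairseq2.nn.transformer.AttentionState)
    k = state.get()[0]  # cached keys, (batch, num_heads, seq_len, head_dim)
    k_flat = k.transpose(1, 2).reshape(2, -1, 1024)  # (batch, seq_len, model_dim)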
@@ -299,7 +297,7 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)

-    state_bag = fairseq2.nn.IncrementalStateBag()
+    state_bag = fairseq2.nn.IncrementalStateBag(100)

     with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
         # Incremental decoding: the keys come from the encoder and don't change during decoding
@@ -329,11 +327,11 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
             if t > 0:
                 # the cache only appears in the graph during the second call
                 state = state_bag.get_state(
-                    attn, fairseq2.nn.transformer.MultiheadAttentionState
+                    attn, fairseq2.nn.transformer.AttentionState
                 )
                 assert state is not None
                 assert np.allclose(
-                    state.prev_k.numpy(),
+                    state.get()[0].transpose(1, 2).reshape(2, 11, -1).numpy(),
                     ggml.to_numpy(
                         nodes[b"text_decoder.layers.0.encoder_decoder_attn.k_cache"]
                     ),
@@ -341,14 +339,13 @@ def test_MultiheadAttention_forward_cross_attn_with_cache(
                 )

             state_bag.increment_step()
-            y_exp = attn(xq, None, xk, xk, state_bag=state_bag).numpy()
+            y_exp = attn(xq, None, xk, None, xk, state_bag=state_bag).numpy()
             assert y_exp.shape == (2, 1, 1024)
             assert np.allclose(y, y_exp, atol=1e-2)


 def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
-    padding_mask = torch.ones((2, 21))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)

@@ -357,7 +354,8 @@ def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) ->

     gx = ggml.from_numpy(ctx, x)
     ggml.ggml_set_name(gx, b"x")
-    gpad = ggml.from_numpy(ctx, padding_mask)
+    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([21, 21]), 21)
+    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
     ggml.ggml_set_name(gpad, b"padding_mask")
     gy = ggml.forward(
         "StandardTransformerEncoderLayer",
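Note: the float `torch.ones((batch, seq_len))` masks are gone; fairseq2 now models padding with `fairseq2.nn.padding.PaddingMask`, built from per-sequence lengths plus the padded batch length, and `materialize()` produces the tensor that gets copied into the ggml graph. A sketch matching the encoder tests in the surrounding hunks, where both sequences use the full 21 steps:

    import torch
    from fairseq2.nn.padding import PaddingMask

    padding_mask = PaddingMask(torch.tensor([21, 21]), 21)  # seq_lens, batch seq len
    gpad = ggml.from_numpy(ctx, padding_mask.materialize())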
@@ -371,7 +369,7 @@ def test_StandardTransformerEncoderLayer_forward(ctx: Ctx, g_model: c_void_p) ->

     y = ggml.to_numpy(gy)

-    y_exp, _ = layer(x, padding_mask)
+    y_exp, _ = layer(x, padding_mask=None)
     y_exp = y_exp.numpy()

     assert y.shape == y_exp.shape
@@ -440,13 +438,13 @@ def test_StandardConformerEncoderAdaptorLayer_forward(

 def test_StandardTransformerEncoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 21, 1024))
-    padding_mask = torch.ones((2, 21))
+    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([21, 21]), 21)
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)

     gx = ggml.from_numpy(ctx, x)
     ggml.ggml_set_name(gx, b"x")
-    gpad = ggml.from_numpy(ctx, padding_mask)
+    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
     ggml.ggml_set_name(gpad, b"padding_mask")
     gy = ggml.forward(
         "StandardTransformerEncoder",
@@ -532,8 +530,8 @@ def test_WaveformToFbank_forward(ctx: Ctx, g_model: c_void_p) -> None:
         "sample_rate": 16000.0,
         "format": -1,
     }
-    y_exp = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)[0]
-    y_exp = y_exp.numpy()
+    y_exp, _ = extractor(converter(converter_input)["fbank"].unsqueeze(0), None)
+    y_exp = y_exp.squeeze(0).numpy()

     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=4e-3)  # reduce? error is from standardization
@@ -563,7 +561,7 @@ def test_PositionalEmbedding_forward_with_cache(ctx: Ctx, g_model: c_void_p) ->
     seq = torch.zeros((4, 20, 1024), dtype=torch.float32)
     pos_encoder = fairseq2.nn.SinusoidalPositionEncoder(1024, 55, _legacy_pad_idx=0)
     pos_encoder.eval()
-    state_bag = fairseq2.nn.IncrementalStateBag()
+    state_bag = fairseq2.nn.IncrementalStateBag(100)

     with ggml.fairseq2_kv_cache_alloc(g_model, 2, 21):
         # Incremental decoding
@@ -607,13 +605,13 @@ def test_TransformerEmbeddingFrontend_forward(ctx: Ctx, g_model: c_void_p) -> No
 def test_StandardTransformerDecoder_forward(ctx: Ctx, g_model: c_void_p) -> None:
     x = torch.empty((2, 13, 1024))
     encoder_out = torch.empty((2, 21, 1024))
-    padding_mask = torch.ones((2, 13))
+    padding_mask = fairseq2.nn.padding.PaddingMask(torch.tensor([13, 13]), 13)
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
     torch.nn.init.uniform_(encoder_out, -1, 1)
     gx = ggml.from_numpy(ctx, x)
     ggml.ggml_set_name(gx, b"x")
-    gpad = ggml.from_numpy(ctx, padding_mask)
+    gpad = ggml.from_numpy(ctx, padding_mask.materialize())
     ggml.ggml_set_name(gpad, b"padding_mask")
     genc = ggml.from_numpy(ctx, encoder_out)
     gy = ggml.forward(
@@ -655,11 +653,13 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
             translator.model,
             translator.text_tokenizer,
             translator.unit_tokenizer,
-            src,
+            src["seqs"],
+            None,
             input_modality=Modality.TEXT,
             output_modality=Modality.TEXT,
             tgt_lang=tgt_lang,
-            beam_size=beam_size,
+            text_generation_opts=SequenceGeneratorOptions(beam_size=beam_size),
+            unit_generation_opts=None,
         )

         tgt_text = str(text_out.sentences[0])
@@ -675,14 +675,12 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
         np.savez(
             sample_file,
             encoder_output=text_out.encoder_output.numpy(),
-            encoder_padding_mask=text_out.encoder_padding_mask.numpy(),
             hypotheses=hypotheses,
         )

     # allow_pickle to load the hyp dicts
     text_out = np.load(sample_file, allow_pickle=True)
     encoder_out = ggml.from_numpy(ctx, text_out["encoder_output"])
-    encoder_padding_mask = ggml.from_numpy(ctx, text_out["encoder_padding_mask"])
     prefix_seq = np.array(text_out["hypotheses"][0]["seq"][:2]).astype(np.int32)
     max_seq_len = max(len(h["seq"]) for h in text_out["hypotheses"])

@@ -705,9 +703,7 @@ def test_t2tt(ctx: Ctx, g_model: c_void_p) -> None:
         eos_idx=3,
     )

-    result_ptr = ggml.generate_sequence(
-        g_model, job, encoder_out, encoder_padding_mask, ctx
-    )
+    result_ptr = ggml.generate_sequence(g_model, job, encoder_out, NULLPTR, ctx)
     results = [result_ptr[i] for i in range(beam_size) if result_ptr[i].seq != None]

     # The step score error is big; this may negatively impact the beam search.
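Note: the generation API changed in two ways visible in the last three hunks. Decoding options now travel per output modality, so the bare `beam_size=` argument becomes `text_generation_opts=SequenceGeneratorOptions(beam_size=beam_size)` with a separate `unit_generation_opts` (the `SequenceGeneratorOptions` import is whatever the test file already uses; its module path isn't shown in this diff). And the encoder padding mask is no longer threaded through generation, so `ggml.generate_sequence` takes `NULLPTR` in its place. A condensed sketch of the updated call, using only names from this diff:

    text_opts = SequenceGeneratorOptions(beam_size=beam_size)
    # the former encoder_padding_mask slot is now filled with a null pointer
    result_ptr = ggml.generate_sequence(g_model, job, encoder_out, NULLPTR, ctx)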