
has_layer + transformer decoder

Guillaume Wenzek, 1 year ago
commit f1f33dbec1
4 changed files with 227 additions and 25 deletions:
  1. ggml/examples/unity/fairseq2.cpp      (+152, -21)
  2. ggml/examples/unity/model_loader.cpp  (+15, -0)
  3. ggml/ggml.py                          (+6, -0)
  4. ggml/test_unity_cpp.py                (+54, -4)

ggml/examples/unity/fairseq2.cpp (+152, -21)

@@ -32,6 +32,9 @@ extern "C" void std_string_free(std::string* str) {
     delete str;
 }
 
+bool has_layer(fairseq2_model& model, const std::string& name) {
+    return model.tensors.find(name) != model.tensors.end();
+}
 
 extern "C" ggml_tensor* Linear_forward(
     fairseq2_model& model,
@@ -80,13 +83,10 @@ extern "C" ggml_tensor* StandardFeedForwardNetwork_forward(
     // inner_activation = ReLu // TODO: allow other activation
     seqs = ggml_relu(model.ctx, seqs);
 
-    if (model.tensors.find(prefix + ".inner_layer_norm.weight") != model.tensors.end()) {
+    if (has_layer(model, prefix + ".inner_layer_norm")) {
         seqs = LayerNorm_forward(model, prefix + ".inner_layer_norm", seqs);
     }
 
-    // TODO: inference dropout
-    // if self.inner_dropout is not None:
-    //     seqs = self.inner_dropout(seqs)
     seqs = Linear_forward(model, prefix + ".output_proj", seqs);
     return seqs;
 }
@@ -167,11 +167,6 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
 }
 
 
-bool has_layer(fairseq2_model& model, const std::string& name) {
-    return model.tensors.find(name) != model.tensors.end();
-}
-
-
 extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
     fairseq2_model& model,
     const std::string& prefix,
@@ -198,11 +193,9 @@ extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
         /*attention masks=*/nullptr
     );
 
-    if (has_layer(model, prefix + ".self_attn_norm.weight"))
+    if (has_layer(model, prefix + ".self_attn_norm"))
         seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
 
-    // TODO: seqs = self.self_attn_dropout(seqs)
-
     seqs = ggml_add(ctx, seqs, residual);
 
     if (norm_order == TRANSFORMER_NORM_ORDER_POST)
@@ -216,9 +209,7 @@ extern "C" ggml_tensor* StandardTransformerEncoderLayer_forward(
 
     seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
 
-    // TODO:
-    // seqs = self.ffn_dropout(seqs)
-    // if self.residual_scale is not None:
+    // TODO: if self.residual_scale is not None:
     // residual = self.residual_scale * residual
 
     seqs = ggml_add(ctx, seqs, residual);
@@ -237,17 +228,157 @@ extern "C" ggml_tensor* StandardTransformerEncoder_forward(
     ggml_tensor* padding_mask
 ) {
     int layer_idx = 0;
-    // TODO: this isn't nice.
-    // When loading model we should add nullptr for the module key to avoid those concatenation.
-    while (has_layer(model, prefix + ".layers." + std::to_string(layer_idx)  + ".self_attn_layer_norm.weight")) {
+    std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
+    while (has_layer(model, layer_name)) {
         seqs = StandardTransformerEncoderLayer_forward(
-            model, prefix + ".layers." + std::to_string(layer_idx), seqs, padding_mask
+            model, layer_name, seqs, padding_mask
+        );
+
+        ggml_set_name(seqs, ("x_enc_" + std::to_string(layer_idx)).c_str());
+        layer_idx += 1;
+        layer_name = prefix + ".layers." + std::to_string(layer_idx);
+    }
+
+    if (has_layer(model, prefix + ".layer_norm"))
+        seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
+
+    return seqs;
+}
+
+extern "C" ggml_tensor* StandardTransformerDecoderLayer_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
+    ggml_tensor* seqs,
+    ggml_tensor* self_attn_mask,
+    ggml_tensor* encoder_output,
+    ggml_tensor* encoder_padding_mask
+) {
+    ggml_context* ctx = model.ctx;
+    // TODO: read norm_order from model
+    auto norm_order = TRANSFORMER_NORM_ORDER_PRE;
+
+    // _forward_self_attn(seqs, padding_mask)
+    auto residual = seqs;
+    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
+        seqs =  LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
+
+    seqs = MultiheadAttention_forward(
+        model,
+        prefix + ".self_attn",
+        seqs,
+        seqs,
+        seqs,
+        /*attention masks=*/self_attn_mask
+    );
+
+    if (has_layer(model, prefix + ".self_attn_norm"))
+        seqs = LayerNorm_forward(model, prefix + ".self_attn_norm", seqs);
+
+    seqs = ggml_add(ctx, seqs, residual);
+
+    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
+        seqs =  LayerNorm_forward(model, prefix + ".self_attn_layer_norm", seqs);
+
+    // _forward_encoder_decoder_attn
+    if (! has_layer(model, prefix + ".encoder_decoder_attn")) {
+        // `encoder_output` must be `None` for decoder-only attention.
+        GGML_ASSERT(encoder_output == nullptr);
+        return seqs;
+    }
+
+    // `encoder_output` must not be `None` for encoder-decoder attention.
+    GGML_ASSERT(encoder_output != nullptr);
+
+    residual = seqs;
+
+    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
+        seqs =  LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
+
+
+    seqs = MultiheadAttention_forward(
+        model,
+        prefix + ".encoder_decoder_attn",
+        seqs,
+        encoder_output,
+        encoder_output,
+        /*attention masks=*/encoder_padding_mask
+    );
+
+    seqs = ggml_add(ctx, seqs, residual);
+
+    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
+        seqs =  LayerNorm_forward(model, prefix + ".encoder_decoder_attn_layer_norm", seqs);
+
+    // _forward_ffn(seqs)
+    residual = seqs;
+
+    if (norm_order != TRANSFORMER_NORM_ORDER_POST)
+        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
+
+    seqs = StandardFeedForwardNetwork_forward(model, prefix + ".ffn", seqs);
+
+    // TODO:
+    // if self.residual_scale is not None:
+    // residual = self.residual_scale * residual
+
+    seqs = ggml_add(ctx, seqs, residual);
+
+    if (norm_order == TRANSFORMER_NORM_ORDER_POST)
+        seqs = LayerNorm_forward(model, prefix + ".ffn_layer_norm", seqs);
+
+    return seqs;
+}
+
+ggml_tensor* causal_mask_cache = nullptr;
+
+extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs) {
+    auto seq_len = seqs->ne[0];
+    auto mask = causal_mask_cache;
+    // TODO: this cache only works as long as we don't change the size/device too often
+    // TODO: allow other ggml_type
+    if (mask == nullptr || mask->backend != seqs->backend || mask->ne[0] < seq_len) {
+        printf("new causal_mask (%ld, %ld) created\n", seq_len, seq_len);
+        mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, seq_len, seq_len);
+        char* data = (char*)mask->data;
+
+        // tensor([[0., -inf, -inf, -inf],
+        //         [0.,   0., -inf, -inf],
+        //         [0.,   0.,   0., -inf],
+        //         [0.,   0.,   0.,   0.]])
+        for (int i = 0; i < seq_len; ++i) {
+            char* row = data + i * mask->nb[1];
+            for (int j = 0; j <= i; ++j) {*(float*)(row + j * mask->nb[0]) = 0;}
+            for (int j = i + 1; j < seq_len; ++j) {*(float*)(row + j * mask->nb[0]) = -INFINITY;}
+        }
+
+        causal_mask_cache = mask;
+    }
+
+    return ggml_view_2d(ctx, mask, seq_len, seq_len, mask->nb[1], 0);
+}
+
+extern "C" ggml_tensor* StandardTransformerDecoder_forward(
+    fairseq2_model& model,
+    const std::string& prefix,
+    ggml_tensor* seqs,
+    ggml_tensor* padding_mask,
+    ggml_tensor* encoder_output,
+    ggml_tensor* encoder_padding_mask
+) {
+    int layer_idx = 0;
+    std::string layer_name = prefix + ".layers." + std::to_string(layer_idx);
+    ggml_tensor* self_attn_mask = causal_attention_mask(model.ctx, seqs);
+    while (has_layer(model, layer_name)) {
+        seqs = StandardTransformerDecoderLayer_forward(
+            model, layer_name, seqs, self_attn_mask, encoder_output, encoder_padding_mask
         );
-        ggml_set_name(seqs, ("x_" + std::to_string(layer_idx)).c_str());
+
+        ggml_set_name(seqs, ("x_dec_" + std::to_string(layer_idx)).c_str());
         layer_idx += 1;
+        layer_name = prefix + ".layers." + std::to_string(layer_idx);
     }
 
-    if (has_layer(model, prefix + ".layer_norm.weight"))
+    if (has_layer(model, prefix + ".layer_norm"))
         seqs = LayerNorm_forward(model, prefix + ".layer_norm", seqs);
 
     return seqs;
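
For context, a minimal C++ sketch (not part of this commit) of how the new decoder entry point can be driven end to end. It assumes a fairseq2_model already filled by model_loader, weights stored under the "text_decoder" prefix (the prefix used in test_unity_cpp.py), and f32 input tensors allocated in model.ctx; the helper name run_text_decoder is hypothetical. Padding masks are still unsupported at this point, so they are passed as nullptr.

    // Illustration only: run the decoder graph once for a given input.
    ggml_tensor* run_text_decoder(fairseq2_model& model,
                                  ggml_tensor* seqs,
                                  ggml_tensor* encoder_output) {
        // The causal self-attention mask is built internally by
        // StandardTransformerDecoder_forward via causal_attention_mask().
        ggml_tensor* out = StandardTransformerDecoder_forward(
            model, "text_decoder", seqs,
            /*padding_mask=*/nullptr,            // TODO in this commit
            encoder_output,
            /*encoder_padding_mask=*/nullptr);

        // Same graph API the Python tests call through ctypes.
        ggml_cgraph gf = ggml_build_forward(out);
        ggml_graph_compute_with_ctx(model.ctx, &gf, /*n_threads=*/1);
        return out;
    }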

ggml/examples/unity/model_loader.cpp (+15, -0)

@@ -21,6 +21,20 @@ std::ifstream open_ggml_file(const char* fname) {
     return fin;
 }
 
+void register_prefix(fairseq2_model &model, const std::string& name) {
+    std::size_t i = name.find_last_of('.');
+    while(i != std::string::npos && i > 0) {
+        std::string prefix = name.substr(0, i);
+        auto prev_tensor = model.tensors.find(prefix);
+        if (prev_tensor != model.tensors.end()) {
+            GGML_ASSERT(prev_tensor->second == nullptr);
+        }
+        model.tensors[prefix] = nullptr;
+        i = name.find_last_of('.', i - 1);
+    }
+}
+
+
 int
 model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
 {
@@ -35,6 +49,7 @@ model_loader::load_model_weights(fairseq2_model &model, std::ifstream &fin)
             printf("Error while reading tensor %s\n", name.c_str() );
             return 1;
         }
+        register_prefix(model, name);
         ggml_set_name(tensor, name.c_str());
         model.tensors[name] = tensor;
         if (DEBUG_MODEL_LOAD) {
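
register_prefix is what makes has_layer usable with module names rather than full tensor names: every ancestor prefix of a loaded weight gets a (null) entry in model.tensors. Below is a self-contained sketch of that convention, using a plain std::map as a stand-in for model.tensors; the tensor name shown is only an example.

    #include <cassert>
    #include <cstddef>
    #include <map>
    #include <string>

    int main() {
        std::map<std::string, const void*> tensors;  // stand-in for model.tensors
        const std::string name = "text_decoder.layers.0.self_attn.q_proj.weight";

        // register_prefix: insert every ancestor prefix with a null value.
        for (std::size_t i = name.find_last_of('.');
             i != std::string::npos && i > 0;
             i = name.find_last_of('.', i - 1))
            tensors[name.substr(0, i)] = nullptr;
        tensors[name] = &name;  // the real ggml_tensor* would go here

        // has_layer: key presence is enough, the mapped value may be null.
        assert(tensors.find("text_decoder.layers.0") != tensors.end());  // module exists
        assert(tensors.find("text_decoder.layers.1") == tensors.end());  // no such layer
        return 0;
    }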

ggml/ggml.py (+6, -0)

@@ -287,3 +287,9 @@ def forward(
 
     with CppStr(prefix) as std_prefix:
         return fwd(model, std_prefix, *inputs)  # ignore: type[no-any-return]
+
+lib.causal_attention_mask.argtypes = [ggml_context_p, ctypes.POINTER(ggml_tensor)]
+lib.causal_attention_mask.restype = ctypes.POINTER(ggml_tensor)
+
+def causal_attention_mask(ctx: ggml_context_p, seqs: ggml_tensor_p) -> ggml_tensor_p:
+    return lib.causal_attention_mask(ctx, seqs)  # type: ignore[no-any-return]
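
For reference, the argtypes/restype declared above mirror the extern "C" symbol defined in fairseq2.cpp earlier in this commit; the C-visible declaration the ctypes binding maps onto is simply:

    // Matched by lib.causal_attention_mask.argtypes / .restype in ggml.py.
    extern "C" ggml_tensor* causal_attention_mask(ggml_context* ctx, ggml_tensor* seqs);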

ggml/test_unity_cpp.py (+54, -4)

@@ -427,7 +427,7 @@ def test_forward_self_attn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
-    q_exp = self_attn._project_q(x[:, :11, :], None, None).squeeze(0).numpy()
+    # q_exp = self_attn._project_q(x[:, :11, :], None, None).squeeze(0).numpy()
 
     y = ggml.to_numpy(gy)
     nodes = {}
@@ -444,9 +444,9 @@ def test_forward_self_attn(ctx: Ctx, g_model: c_void_p, pt_model: Any) -> None:
     y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
     y_exp = y_exp.squeeze(0)  # remove batch dimension
 
-    q = nodes[b"q"]
-    assert q.shape == q_exp.shape
-    assert np.allclose(q_exp, q, atol=1e-5)
+    # q = nodes[b"q"]
+    # assert q.shape == q_exp.shape
+    # assert np.allclose(q_exp, q, atol=1e-5)
 
     attn_exp, attn_weights_exp = map(
         lambda t: t.squeeze(0).numpy(), attn_weights_hook._storage[0]
@@ -535,3 +535,53 @@ def test_StandardTransformerEncoder_forward(
 
     assert y.shape == y_exp.shape
     assert np.allclose(y_exp, y, atol=1e-4)
+
+
+def test_causal_attention_mask(ctx: Ctx):
+    x = torch.zeros((5, 10))
+    generator = fairseq2.nn.transformer.CausalAttentionMaskGenerator()
+    mask_exp = generator(x)
+
+    gx = ggml.from_numpy(ctx, x)
+    gmask = ggml.causal_attention_mask(ctx, gx)
+    mask = ggml.to_numpy(gmask)
+
+    assert mask_exp.shape == (10, 10)
+    assert mask.shape == (10, 10)
+    assert np.allclose(mask, mask_exp)
+
+
+
+def test_StandardTransformerDecoder_forward(
+    ctx: Ctx, g_model: c_void_p, pt_model: Any
+) -> None:
+    x = torch.empty((1, 13, 1024))
+    encoder_out = torch.empty((1, 21, 1024))
+    padding_mask = torch.ones((1, 13))
+    torch.random.manual_seed(0)
+    torch.nn.init.uniform_(x, -1, 1)
+    torch.nn.init.uniform_(encoder_out, -1, 1)
+    gx = ggml.from_numpy(ctx, x[0])
+    ggml.ggml_set_name(gx, b"x")
+    gpad = ggml.from_numpy(ctx, padding_mask[0])
+    ggml.ggml_set_name(gpad, b"padding_mask")
+    genc = ggml.from_numpy(ctx, encoder_out[0])
+    gy = ggml.forward(
+        "StandardTransformerDecoder",
+        g_model,
+        "text_decoder",
+        gx,
+        None,  # TODO support padding mask,
+        genc,
+        None,
+    )
+    gf = ggml.ggml_build_forward(gy)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+    y = ggml.to_numpy(gy)
+
+    y_exp, _ = pt_model.text_decoder(x, padding_mask, encoder_out, None)
+    y_exp = y_exp.squeeze(0).numpy()  # remove batch dimension
+
+    assert y.shape == y_exp.shape
+    assert np.allclose(y_exp, y, atol=1e-4)