diff --git a/dflash/CMakeLists.txt b/dflash/CMakeLists.txt index 8ac91009..00d13e57 100644 --- a/dflash/CMakeLists.txt +++ b/dflash/CMakeLists.txt @@ -277,6 +277,11 @@ if(DFLASH27B_TESTS) target_include_directories(test_vs_oracle PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) target_link_libraries(test_vs_oracle PRIVATE dflash27b ggml ggml-cuda) endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_swa_mask_contract.cpp") + add_executable(test_draft_swa_mask_contract test/test_draft_swa_mask_contract.cpp) + target_include_directories(test_draft_swa_mask_contract PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(test_draft_swa_mask_contract PRIVATE dflash27b ggml ggml-cuda) + endif() if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/smoke_load_target.cpp") add_executable(smoke_load_target test/smoke_load_target.cpp) target_include_directories(smoke_load_target PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) diff --git a/dflash/src/dflash_graph.h b/dflash/src/dflash_graph.h index 304ff8e3..5b5204b6 100644 --- a/dflash/src/dflash_graph.h +++ b/dflash/src/dflash_graph.h @@ -1,6 +1,9 @@ // Shared inputs/outputs for the DFlash draft graph builder. #pragma once +#include <cstdint> +#include <vector> + #include "ggml.h" namespace dflash27b { @@ -13,11 +16,15 @@ struct DraftGraphInputs { ggml_tensor * target_hidden_cat;// [5*hidden, ctx_len, 1] f32 ggml_tensor * positions_q; // [q_len] i32 values [ctx_len..ctx_len+q_len-1] ggml_tensor * positions_k; // [ctx_len+q_len] i32 values [0..ctx_len+q_len-1] + // Optional SWA mask for long-context sliding-attention layers. + // Shape [kv_len, q_len] or padded [kv_pad, q_pad], type F16, values + // 0 for visible positions and -inf for masked positions. + ggml_tensor * attn_mask = nullptr; // Optional: if non-null, the graph projects final hidden states through // this LM head (shape [hidden, vocab]) and returns logits instead of // hidden states. Used for DFlash integration where the draft shares the // target's lm_head. 
- ggml_tensor * lm_head; + ggml_tensor * lm_head = nullptr; }; struct DraftGraphOutputs { @@ -30,4 +37,10 @@ DraftGraphOutputs build_draft_graph( const DraftWeights & w, const DraftGraphInputs & in); +bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len); +void build_draft_swa_mask(std::vector<uint16_t> & out, + int ctx_len, + int q_len, + int swa_window); + } // namespace dflash27b diff --git a/dflash/src/qwen3_dflash_graph.cpp b/dflash/src/qwen3_dflash_graph.cpp index 638bd873..2060d00e 100644 --- a/dflash/src/qwen3_dflash_graph.cpp +++ b/dflash/src/qwen3_dflash_graph.cpp @@ -31,10 +31,46 @@ #include "internal.h" #include "dflash_graph.h" +#include <algorithm> #include <cmath> +#include <cstdio> namespace dflash27b { +bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len) { + if (w.swa_window <= 0) { + return false; + } + const int total_k = ctx_len + DFLASH27B_DRAFT_BLOCK_SIZE; + if (total_k <= w.swa_window) { + return false; + } + for (int il = 0; il < w.n_layer; ++il) { + if (w.layers[il].is_swa) { + return true; + } + } + return false; +} + +void build_draft_swa_mask(std::vector<uint16_t> & out, + int ctx_len, + int q_len, + int swa_window) { + static constexpr uint16_t F16_ZERO = 0x0000; + static constexpr uint16_t F16_NEG_INF = 0xFC00; + + const int total_k = ctx_len + q_len; + out.assign((size_t)total_k * q_len, F16_NEG_INF); + for (int q = 0; q < q_len; ++q) { + const int abs_q = ctx_len + q; + const int min_k = std::max(0, abs_q - swa_window); + for (int k = min_k; k < total_k; ++k) { + out[(size_t)q * total_k + k] = F16_ZERO; + } + } +} + DraftGraphOutputs build_draft_graph( ggml_context * ctx, const DraftWeights & w, @@ -118,8 +154,36 @@ DraftGraphOutputs build_draft_graph( V = ggml_cont (ctx, V); // ── 2f. Non-causal flash attention; GQA broadcast handled internally. + // For SWA layers (Qwen3.6 draft): apply sliding window mask + // limiting context K/V to the last `swa_window` positions. 
const float scale = 1.0f / std::sqrt((float)head_dim); - ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr, + ggml_tensor * attn_mask = nullptr; + if (L.is_swa && w.swa_window > 0 && total_k > w.swa_window) { + if (!in.attn_mask) { + set_last_error("build_draft_graph: SWA layer requires a non-null attn_mask"); + return {}; + } + if (in.attn_mask->type != GGML_TYPE_F16) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "build_draft_graph: SWA attn_mask must be F16, got %s", + ggml_type_name(in.attn_mask->type)); + set_last_error(buf); + return {}; + } + if (in.attn_mask->ne[0] < total_k || in.attn_mask->ne[1] < q_len) { + char buf[160]; + std::snprintf(buf, sizeof(buf), + "build_draft_graph: SWA attn_mask too small (%lld x %lld, need >= %d x %d)", + (long long)in.attn_mask->ne[0], + (long long)in.attn_mask->ne[1], + total_k, q_len); + set_last_error(buf); + return {}; + } + attn_mask = in.attn_mask; + } + ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, attn_mask, scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f); // attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1] diff --git a/dflash/test/smoke_draft_graph.cpp b/dflash/test/smoke_draft_graph.cpp index 16672216..8a1fa02e 100644 --- a/dflash/test/smoke_draft_graph.cpp +++ b/dflash/test/smoke_draft_graph.cpp @@ -85,6 +85,12 @@ int main(int argc, char ** argv) { ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, fc_in, ctx_len, 1); ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, q_len); ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, ctx_len + q_len); + ggml_tensor * attn_mask = nullptr; + if (draft_graph_needs_swa_mask(w, ctx_len)) { + attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, ctx_len + q_len, q_len); + ggml_set_name(attn_mask, "draft_swa_mask"); + ggml_set_input(attn_mask); + } ggml_set_name(noise_embed, "noise_embed"); ggml_set_name(target_hid, "target_hidden_cat"); ggml_set_name(pos_q, "positions_q"); @@ -101,6 +107,7 
@@ int main(int argc, char ** argv) { gi.target_hidden_cat = target_hid; gi.positions_q = pos_q; gi.positions_k = pos_k; + gi.attn_mask = attn_mask; DraftGraphOutputs go = build_draft_graph(gctx, w, gi); if (!go.hidden_states) { std::fprintf(stderr, "build_draft_graph returned null\n"); return 1; } @@ -141,6 +148,11 @@ int main(int argc, char ** argv) { for (int i = 0; i < ctx_len + q_len; i++) pk[i] = i; ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size()); } + if (attn_mask) { + std::vector<uint16_t> mask; + build_draft_swa_mask(mask, ctx_len, q_len, w.swa_window); + ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size()); + } // ── 7. Compute auto status = ggml_backend_graph_compute(backend, gf); diff --git a/dflash/test/test_draft_swa_mask_contract.cpp b/dflash/test/test_draft_swa_mask_contract.cpp new file mode 100644 index 00000000..460ba35b --- /dev/null +++ b/dflash/test/test_draft_swa_mask_contract.cpp @@ -0,0 +1,177 @@ +#include "dflash_graph.h" +#include "internal.h" + +#include "ggml.h" + +#include <cstdio> +#include <vector> + +using namespace dflash27b; + +namespace { + +struct GraphCase { + bool is_swa = false; + int swa_window = 0; + int ctx_len = 0; + bool provide_mask = false; + bool expect_mask = false; + const char * label = ""; +}; + +ggml_tensor * new_vec(ggml_context * ctx, int64_t n) { + return ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); +} + +ggml_tensor * new_mat(ggml_context * ctx, int64_t ne0, int64_t ne1) { + return ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, ne1); +} + +bool run_case(const GraphCase & tc) { + ggml_init_params ip{}; + ip.mem_size = 2 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + if (!ctx) { + std::fprintf(stderr, "FAIL %s: ggml_init failed\n", tc.label); + return false; + } + + constexpr int hidden = 8; + constexpr int n_head = 2; + constexpr int n_kv = 1; + constexpr int head_dim = 4; + constexpr int q_len = DFLASH27B_DRAFT_BLOCK_SIZE; + 
constexpr int inter = 12; + constexpr int fc_in = 5 * hidden; + const int total_k = tc.ctx_len + q_len; + + DraftWeights w{}; + w.n_layer = 1; + w.n_head = n_head; + w.n_head_kv = n_kv; + w.head_dim = head_dim; + w.swa_window = tc.swa_window; + w.layers.resize(1); + + w.fc = new_mat(ctx, fc_in, hidden); + w.hidden_norm = new_vec(ctx, hidden); + w.out_norm = new_vec(ctx, hidden); + + DraftLayer & layer = w.layers[0]; + layer.attn_norm = new_vec(ctx, hidden); + layer.ffn_norm = new_vec(ctx, hidden); + layer.wq = new_mat(ctx, hidden, n_head * head_dim); + layer.wk = new_mat(ctx, hidden, n_kv * head_dim); + layer.wv = new_mat(ctx, hidden, n_kv * head_dim); + layer.wo = new_mat(ctx, n_head * head_dim, hidden); + layer.q_norm = new_vec(ctx, head_dim); + layer.k_norm = new_vec(ctx, head_dim); + layer.w_gate = new_mat(ctx, hidden, inter); + layer.w_up = new_mat(ctx, hidden, inter); + layer.w_down = new_mat(ctx, inter, hidden); + layer.is_swa = tc.is_swa; + + ggml_tensor * noise_embed = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden, q_len, 1); + ggml_tensor * target_hidden_cat = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, fc_in, tc.ctx_len, 1); + ggml_tensor * positions_q = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, q_len); + ggml_tensor * positions_k = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, total_k); + ggml_tensor * attn_mask = nullptr; + if (tc.provide_mask) { + attn_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, total_k, q_len); + ggml_set_name(attn_mask, "draft_swa_mask"); + ggml_set_input(attn_mask); + } + + ggml_set_input(noise_embed); + ggml_set_input(target_hidden_cat); + ggml_set_input(positions_q); + ggml_set_input(positions_k); + + DraftGraphInputs in{}; + in.ctx_len = tc.ctx_len; + in.noise_embed = noise_embed; + in.target_hidden_cat = target_hidden_cat; + in.positions_q = positions_q; + in.positions_k = positions_k; + in.attn_mask = attn_mask; + + DraftGraphOutputs out = build_draft_graph(ctx, w, in); + if (!out.hidden_states) { + std::fprintf(stderr, "FAIL %s: 
build_draft_graph failed: %s\n", + tc.label, dflash27b_last_error()); + ggml_free(ctx); + return false; + } + + ggml_cgraph * gf = ggml_new_graph_custom(ctx, 256, false); + ggml_build_forward_expand(gf, out.hidden_states); + + ggml_tensor * flash = nullptr; + for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) { + ggml_tensor * node = ggml_graph_node(gf, i); + if (node && node->op == GGML_OP_FLASH_ATTN_EXT) { + flash = node; + break; + } + } + + if (!flash) { + std::fprintf(stderr, "FAIL %s: no flash_attn_ext node found\n", tc.label); + ggml_free(ctx); + return false; + } + + const bool got_mask = flash->src[3] != nullptr; + if (got_mask != tc.expect_mask) { + std::fprintf(stderr, "FAIL %s: expected mask=%d got mask=%d\n", + tc.label, tc.expect_mask ? 1 : 0, got_mask ? 1 : 0); + ggml_free(ctx); + return false; + } + if (tc.expect_mask && flash->src[3] != attn_mask) { + std::fprintf(stderr, "FAIL %s: flash_attn_ext did not use caller mask tensor\n", tc.label); + ggml_free(ctx); + return false; + } + + std::printf("PASS %s\n", tc.label); + ggml_free(ctx); + return true; +} + +} // namespace + +int main() { + std::vector<GraphCase> cases(3); + cases[0].is_swa = true; + cases[0].swa_window = 8; + cases[0].ctx_len = 12; + cases[0].provide_mask = true; + cases[0].expect_mask = true; + cases[0].label = "swa-long-context-wires-mask"; + + cases[1].is_swa = false; + cases[1].swa_window = 8; + cases[1].ctx_len = 12; + cases[1].provide_mask = true; + cases[1].expect_mask = false; + cases[1].label = "non-swa-layer-ignores-mask"; + + cases[2].is_swa = true; + cases[2].swa_window = 64; + cases[2].ctx_len = 12; + cases[2].provide_mask = true; + cases[2].expect_mask = false; + cases[2].label = "short-context-keeps-full-attn"; + + int failed = 0; + for (const GraphCase & tc : cases) { + if (!run_case(tc)) { + ++failed; + } + } + + return failed == 0 ? 
0 : 1; +} diff --git a/dflash/test/test_vs_oracle.cpp b/dflash/test/test_vs_oracle.cpp index b0c247d9..352139bc 100644 --- a/dflash/test/test_vs_oracle.cpp +++ b/dflash/test/test_vs_oracle.cpp @@ -117,6 +117,12 @@ int main(int argc, char ** argv) { ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, m.fc_in, m.ctx_len, 1); ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.q_len); ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.ctx_len + m.q_len); + ggml_tensor * attn_mask = nullptr; + if (draft_graph_needs_swa_mask(w, m.ctx_len)) { + attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, m.ctx_len + m.q_len, m.q_len); + ggml_set_name(attn_mask, "draft_swa_mask"); + ggml_set_input(attn_mask); + } ggml_set_name(noise_embed, "noise_embed"); ggml_set_name(target_hid, "target_hidden_cat"); ggml_set_name(pos_q, "positions_q"); @@ -132,6 +138,7 @@ int main(int argc, char ** argv) { gi.target_hidden_cat = target_hid; gi.positions_q = pos_q; gi.positions_k = pos_k; + gi.attn_mask = attn_mask; DraftGraphOutputs go = build_draft_graph(gctx, w, gi); if (!go.hidden_states) return 1; ggml_set_output(go.hidden_states); @@ -154,6 +161,11 @@ int main(int argc, char ** argv) { for (int i = 0; i < m.ctx_len + m.q_len; i++) pk[i] = i; ggml_backend_tensor_set(pos_q, pq.data(), 0, sizeof(int32_t) * pq.size()); ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size()); + if (attn_mask) { + std::vector<uint16_t> mask; + build_draft_swa_mask(mask, m.ctx_len, m.q_len, w.swa_window); + ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size()); + } // Compute auto status = ggml_backend_graph_compute(backend, gf);