diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index da7a9295561c..4cc4a4a16a1c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -567,7 +567,10 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); } - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + // the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live + if (self_kq_mask && self_kq_mask->buffer) { + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { @@ -575,7 +578,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); } - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); @@ -607,7 +612,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there } - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + if (self_kq_mask && self_kq_mask->buffer) { + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + } // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { @@ -615,7 +622,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there } - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + } return res; }