Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
const uint32_t h = (uint32_t) src2->ne[1];
const uint32_t n_tokens = (uint32_t) src2->ne[2];
const uint32_t n_seqs = (uint32_t) src2->ne[3];
const uint32_t K = (uint32_t) src5->ne[1];
const float scale = 1.0f / sqrtf((float) s_v);
uint32_t scale_u32;
memcpy(&scale_u32, &scale, sizeof(scale_u32));
Expand All @@ -1258,6 +1259,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,

(uint32_t) src0->ne[1],
(uint32_t) (src2->ne[3] / src0->ne[3]),
K,
scale_u32,
};

Expand Down
24 changes: 20 additions & 4 deletions ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct Params {

neq1: u32,
rq3: u32,
K: u32,
scale: f32,
};

Expand All @@ -62,11 +63,14 @@ fn main(
let iq3 = seq_id / params.rq3;

let state_size = S_V * S_V;
let state_base = (seq_id * params.h + head_id) * state_size;
let state_in_base = (seq_id * params.K * params.h + head_id) * state_size;
let state_out_base = (seq_id * params.h + head_id) * state_size;
let state_size_per_snap = state_size * params.h * params.n_seqs;
let shift = i32(params.n_tokens) - i32(params.K);

var state: array<f32, S_V>;
for (var i = 0u; i < S_V; i++) {
state[i] = src_state[state_base + col * S_V + i];
state[i] = src_state[state_in_base + col * S_V + i];
}

var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
Expand Down Expand Up @@ -123,10 +127,22 @@ fn main(
dst[attn_off + col] = attn_col * params.scale;
attn_off += S_V * params.h;

if (params.K > 1u) {
let target_slot = i32(t) - shift;
if (target_slot >= 0 && target_slot < i32(params.K)) {
let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base;
for (var i = 0u; i < S_V; i++) {
dst[slot_base + col * S_V + i] = state[i];
}
}
}

workgroupBarrier();
}

for (var i = 0u; i < S_V; i++) {
dst[params.s_off + state_base + col * S_V + i] = state[i];
if (params.K == 1u) {
for (var i = 0u; i < S_V; i++) {
dst[params.s_off + state_out_base + col * S_V + i] = state[i];
}
}
}