Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e3abbbf
[SM70] Add V100 dense WNA16 TurboMind linear kernel
rivetphilbot May 19, 2026
788812f
[SM70] Admit V100 (CC 7.0) in CompressedTensorsWNA16
rivetphilbot May 19, 2026
1b6a4cd
[SM70] Wire compressed-tensors MoE decode buffers for V100
rivetphilbot May 19, 2026
bd6d0c8
[Qwen3] Keep router gate and split GDN projections unquantized under …
rivetphilbot May 19, 2026
4d13a60
[Qwen3.5] Skip tuple-shard split for non-output-dim CT params
rivetphilbot May 19, 2026
d4f98f3
[SM70] Sync sm70_884_4.cu kernel registry to lmdeploy main (gs32)
rivetphilbot May 19, 2026
74b0ebf
[gemma4][WIP] Add gemma-4 backbone + register Gemma4ForCausalLM / Gem…
philbert440 Jun 5, 2026
3f5998d
[gemma4] Backbone imports clean: vendor GateLinear + expert-mapping shim
philbert440 Jun 5, 2026
cf1781b
[gemma4][WIP] Vendor MTP drafter (gemma4_mtp) + spec-decode proposer
philbert440 Jun 5, 2026
2b28dbc
[gemma4] spec-decode proposer imports clean (redirect base-class import)
philbert440 Jun 5, 2026
c554f75
[gemma4] Config wiring: arch convertors + speculative gemma4_mtp reco…
philbert440 Jun 5, 2026
7827081
[gemma4] Wire proposer dispatch + base-proposer constant-positions path
philbert440 Jun 5, 2026
a8b8614
[gemma4] Add image+audio multimodal (tower-based Gemma4ForConditional…
philbert440 Jun 5, 2026
2aeb1cb
[gemma4] Runtime serve fixes: activation, proportional RoPE, MM guard
philbert440 Jun 6, 2026
b99d0db
[gemma4] Load vision-tower std_bias/std_scale persistent buffers
philbert440 Jun 6, 2026
694a9c2
[gemma4] Fix inference crash: don't assign read-only mm_prefix_range_…
philbert440 Jun 6, 2026
a54babb
[gemma4][perf] FlashAttnV100Backend.supports_head_size override (mixe…
philbert440 Jun 6, 2026
efb414a
[gemma4] Register gemma4_assistant config so MTP drafter loads
philbert440 Jun 7, 2026
f3376f6
[gemma4][mtp] Fix target KV-cache corruption from draft layers
philbert440 Jun 7, 2026
474f985
[gemma4] Load tower clip buffers + MTP ordered-embedding buffer
philbert440 Jun 7, 2026
2b93dd9
[gemma4][FA] Sliding-window support in Volta FLASH_ATTN_V100 kernels
philbert440 Jun 8, 2026
dbbd2df
[gemma4][FA] head_dim-512 decode + prefill window tile-skip
philbert440 Jun 8, 2026
47b5bbc
[gemma4][FA] head_dim-512 prefill (split via small blocks) -> fully-FA
philbert440 Jun 8, 2026
a2f453c
[FA_V100] Guard paged prefill against V100 smem overflow at long context
philbert440 Jun 8, 2026
68d48fc
[FA_V100] Run gemma E2B/E4B KV-shared layers on FA (read target cache)
philbert440 Jun 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flash-attention-v100/flash_attn_v100/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ def forward(
if causal and (window_size_left != -1 or window_size_right != -1):
if window_size_left > 0 and window_size_right > 0:
window_size_left, window_size_right = -1, -1
elif window_size_left >= 0 and window_size_right == 0:
# Causal sliding-window: query attends to window_size_left + 1
# tokens. Supported by the Volta dense kernel.
pass
else:
raise NotImplementedError(f"Unsupported window_size={window_size} with causal=True")

Expand Down Expand Up @@ -292,6 +296,7 @@ def flash_attn_decode_paged(
kv_cache_dtype: str = "auto",
k_scale: float = 1.0,
v_scale: float = 1.0,
window: int = -1,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
Expand Down Expand Up @@ -319,6 +324,7 @@ def flash_attn_decode_paged(
kv_cache_dtype,
float(k_scale),
float(v_scale),
int(window),
)

def flash_attn_prefill_paged(
Expand All @@ -333,6 +339,7 @@ def flash_attn_prefill_paged(
k_scale: float = 1.0,
v_scale: float = 1.0,
causal: bool = True,
window: int = -1,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
Expand All @@ -357,6 +364,7 @@ def flash_attn_prefill_paged(
float(k_scale),
float(v_scale),
causal,
int(window),
)
return out_.permute(0, 2, 1, 3).contiguous()

Expand Down
6 changes: 4 additions & 2 deletions flash-attention-v100/include/fused_mha.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ at::Tensor flash_attention_decode_paged(
const int partition_size,
const std::string& kv_cache_dtype,
const float k_scale,
const float v_scale
const float v_scale,
const int window
);

at::Tensor flash_attention_prefill_paged(
Expand All @@ -51,7 +52,8 @@ at::Tensor flash_attention_prefill_paged(
const std::string& kv_cache_dtype,
const float k_scale,
const float v_scale,
const bool is_causal
const bool is_causal,
const int window
);

std::vector<at::Tensor> flash_attention_backward(
Expand Down
34 changes: 30 additions & 4 deletions flash-attention-v100/kernel/flash_decode_paged.cu
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ __global__ void flash_attention_decode_partition_kernel(
const int64_t v_head_stride,
const float softmax_scale,
const float k_scale,
const float v_scale) {
const float v_scale,
const int window) {
const int batch_idx = blockIdx.x;
const int head_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
Expand All @@ -197,6 +198,19 @@ __global__ void flash_attention_decode_partition_kernel(
}

const int part_tokens = min(PARTITION_SIZE, seq_len - start_token_idx);
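// Sliding-window: the decode query sits at seq_len-1 and attends only to keys
// in [seq_len-window, seq_len-1]. If the whole partition predates the window,
// it contributes nothing -- write neutral stats (max_logit = -1e20,
// exp_sum = 0) so the reduce step can ignore it.
// NOTE(review): this early return leaves the partition's tmp_out slot
// unwritten; confirm the reduce kernel skips zero-weight partitions instead of
// computing 0 * tmp_out, which would be NaN if the stale buffer holds NaN/Inf.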
if (window >= 0 && start_token_idx + part_tokens <= seq_len - window) {
if (threadIdx.x == 0) {
const int64_t stats_index =
static_cast<int64_t>(batch_idx) * stats_stride0 +
static_cast<int64_t>(head_idx) * stats_stride1 + partition_idx;
max_logits[stats_index] = -1.0e20f;
exp_sums[stats_index] = 0.f;
}
return;
}
const int q_per_kv = num_heads_q / num_heads_kv;
const int kv_head_idx = head_idx / q_per_kv;
const int lane = threadIdx.x % kWarpSize;
Expand Down Expand Up @@ -227,6 +241,12 @@ __global__ void flash_attention_decode_partition_kernel(
float local_max = -1.0e20f;
for (int token_local = warp_idx; token_local < part_tokens;
token_local += kWarpsPerBlock) {
// Per-token sliding-window mask (token_local is warp-uniform, so the branch
// is uniform across the warp and we can skip the dot product entirely).
if (window >= 0 && start_token_idx + token_local < seq_len - window) {
if (lane == 0) scores_shared[token_local] = -1.0e20f;
continue;
}
const int physical_block = block_idx_shared[token_local];
const int block_offset = block_offset_shared[token_local];
const int64_t k_index =
Expand Down Expand Up @@ -386,6 +406,7 @@ void launch_flash_attention_decode_paged(
const float softmax_scale,
const float k_scale,
const float v_scale,
const int window,
cudaStream_t stream) {
const int batch_size = q.size(0);
const int num_heads_q = q.size(1);
Expand Down Expand Up @@ -431,7 +452,8 @@ void launch_flash_attention_decode_paged(
v_cache.stride(2),
softmax_scale,
k_scale,
v_scale);
v_scale,
window);

flash_attention_decode_reduce_kernel<D, PARTITION_SIZE><<<reduce_grid, block, reduce_shared_mem, stream>>>(
reinterpret_cast<const __half*>(tmp_out.data_ptr<at::Half>()),
Expand Down Expand Up @@ -467,7 +489,8 @@ at::Tensor flash_attention_decode_paged(
const int partition_size,
const std::string& kv_cache_dtype,
const float k_scale,
const float v_scale) {
const float v_scale,
const int window) {
TORCH_CHECK(q.is_cuda(), "q must be on CUDA");
TORCH_CHECK(k_cache.is_cuda() && v_cache.is_cuda(), "k/v cache must be on CUDA");
TORCH_CHECK(block_table.is_cuda() && seq_lens.is_cuda(), "block_table and seq_lens must be on CUDA");
Expand Down Expand Up @@ -536,7 +559,7 @@ at::Tensor flash_attention_decode_paged(
#define LAUNCH_TYPED(HDIM, PARTITION, KV_DTYPE_CODE) \
launch_flash_attention_decode_paged<HDIM, PARTITION, KV_DTYPE_CODE>( \
q, k_cache, v_cache, out, block_table, seq_lens, tmp_out, max_logits, \
exp_sums, softmax_scale, k_scale, v_scale, stream)
exp_sums, softmax_scale, k_scale, v_scale, window, stream)

#define LAUNCH_BY_KV_DTYPE(HDIM, PARTITION) \
do { \
Expand Down Expand Up @@ -591,6 +614,9 @@ at::Tensor flash_attention_decode_paged(
case 256:
LAUNCH_BY_PARTITION(256);
break;
case 512:
LAUNCH_BY_PARTITION(512);
break;
default:
TORCH_CHECK(false, "Unsupported head_dim for paged decode: ", head_dim);
}
Expand Down
8 changes: 6 additions & 2 deletions flash-attention-v100/kernel/flash_v100_traits.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@ struct FlashV100Traits {
static constexpr int BLOCK_N_128 = 176;
static constexpr int BLOCK_M_256 = 32;
static constexpr int BLOCK_N_256 = 64;
static constexpr int BLOCK_M_512 = 16;
static constexpr int BLOCK_N_512 = 32;
static constexpr int WARPS_PER_BLOCK = 16;
static constexpr int THREADS_PER_WARP = 32;

static constexpr int BLOCK_M = (D == 16) ? BLOCK_M_16 :
(D == 32) ? BLOCK_M_32 :
(D == 64) ? BLOCK_M_64 :
(D == 128) ? BLOCK_M_128 : BLOCK_M_256;
(D == 128) ? BLOCK_M_128 :
(D == 512) ? BLOCK_M_512 : BLOCK_M_256;

static constexpr int BLOCK_N = (D == 16) ? BLOCK_N_16 :
(D == 32) ? BLOCK_N_32 :
(D == 64) ? BLOCK_N_64 :
(D == 128) ? BLOCK_N_128 : BLOCK_N_256;
(D == 128) ? BLOCK_N_128 :
(D == 512) ? BLOCK_N_512 : BLOCK_N_256;

static constexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * THREADS_PER_WARP;
static constexpr int THREADS_PER_ROW = THREADS_PER_BLOCK / BLOCK_M;
Expand Down
60 changes: 45 additions & 15 deletions flash-attention-v100/kernel/fused_mha_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,18 @@ using namespace nvcuda::wmma;
#define BLOCK_N_256 64
#define WARPS_256 16

// head_dim 512 (gemma global layers). Small blocks so the 512-wide Q/K/V/O
// tiles fit in 96KB smem (q+kv+o dominate). The QK k-loop already accumulates
// over all of D in WMMA_K chunks, so no body change is needed.
#define BLOCK_M_512 16
#define BLOCK_N_512 32
#define WARPS_512 16

template<int D>
struct KernelConfig {
static constexpr int BLOCK_M = (D == 16) ? BLOCK_M_16 : (D == 32) ? BLOCK_M_32 : (D == 64) ? BLOCK_M_64 : (D == 128) ? BLOCK_M_128 : BLOCK_M_256;
static constexpr int BLOCK_N = (D == 16) ? BLOCK_N_16 : (D == 32) ? BLOCK_N_32 : (D == 64) ? BLOCK_N_64 : (D == 128) ? BLOCK_N_128 : BLOCK_N_256;
static constexpr int WARPS_PER_BLOCK = (D == 16) ? WARPS_16 : (D == 32) ? WARPS_32 : (D == 64) ? WARPS_64 : (D == 128) ? WARPS_128 : WARPS_256;
static constexpr int BLOCK_M = (D == 16) ? BLOCK_M_16 : (D == 32) ? BLOCK_M_32 : (D == 64) ? BLOCK_M_64 : (D == 128) ? BLOCK_M_128 : (D == 512) ? BLOCK_M_512 : BLOCK_M_256;
static constexpr int BLOCK_N = (D == 16) ? BLOCK_N_16 : (D == 32) ? BLOCK_N_32 : (D == 64) ? BLOCK_N_64 : (D == 128) ? BLOCK_N_128 : (D == 512) ? BLOCK_N_512 : BLOCK_N_256;
static constexpr int WARPS_PER_BLOCK = (D == 16) ? WARPS_16 : (D == 32) ? WARPS_32 : (D == 64) ? WARPS_64 : (D == 128) ? WARPS_128 : (D == 512) ? WARPS_512 : WARPS_256;

static constexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * MAX_THREADS_PER_WARP;
static constexpr int THREADS_PER_ROW = THREADS_PER_BLOCK / BLOCK_M;
Expand Down Expand Up @@ -109,7 +116,8 @@ flash_attention_forward_kernel(
const int H_KV,
const int M,
const int N,
const float softmax_scale
const float softmax_scale,
const int window
) {
using Config = KernelConfig<D>;
constexpr int BLOCK_M = Config::BLOCK_M;
Expand Down Expand Up @@ -197,7 +205,18 @@ flash_attention_forward_kernel(
}
__syncthreads();

for (int block_n = 0; block_n < num_n_tiles; ++block_n) {
// Sliding-window lower-bound tile skip: keys older than `window` from the
// earliest query in this row-block are fully masked, so skip those tiles
// entirely (the per-element mask still guarantees correctness).
int first_n_tile = 0;
if constexpr (IS_CAUSAL) {
if (window >= 0) {
const int earliest_key = (start_row + causal_q_offset) - window + 1;
if (earliest_key > 0) first_n_tile = earliest_key / BLOCK_N;
}
}

for (int block_n = first_n_tile; block_n < num_n_tiles; ++block_n) {
const int start_col = block_n * BLOCK_N;
if (start_col >= N) break;
const int valid_k_rows = min(BLOCK_N, N - start_col);
Expand Down Expand Up @@ -269,9 +288,14 @@ flash_attention_forward_kernel(

const bool is_valid = (global_m < start_row + valid_q_rows) &&
(global_n < start_col + valid_k_rows);
// Causal + optional sliding-window: mask future keys (global_n > global_q_pos)
// and keys at distance >= `window` behind the query, so each query sees
// itself plus at most `window - 1` preceding keys.
const bool masked =
(global_n > global_q_pos) ||
(window >= 0 && global_q_pos - global_n >= window);

acc_frag.x[i] = is_valid
? ((global_n > global_q_pos) ? NEG_INF : acc_frag.x[i] * softmax_scale)
? (masked ? NEG_INF : acc_frag.x[i] * softmax_scale)
: NEG_INF;
}
} else {
Expand Down Expand Up @@ -508,6 +532,7 @@ void launcher_flash_attention_forward(
torch::Tensor& softmax_lse,
float softmax_scale,
bool is_causal,
int window,
cudaStream_t stream
) {
using Config = KernelConfig<D>;
Expand Down Expand Up @@ -538,7 +563,7 @@ void launcher_flash_attention_forward(
reinterpret_cast<const __half*>(V.data_ptr()),
reinterpret_cast<__half*>(Out.data_ptr()),
softmax_lse.data_ptr<float>(),
B, H, H_KV, M, N, softmax_scale
B, H, H_KV, M, N, softmax_scale, window
);
} else {
flash_attention_forward_kernel<D, false><<<grid, block, smem, stream>>>(
Expand All @@ -547,7 +572,7 @@ void launcher_flash_attention_forward(
reinterpret_cast<const __half*>(V.data_ptr()),
reinterpret_cast<__half*>(Out.data_ptr()),
softmax_lse.data_ptr<float>(),
B, H, H_KV, M, N, softmax_scale
B, H, H_KV, M, N, softmax_scale, window
);
}
}
Expand All @@ -570,8 +595,12 @@ std::vector<at::Tensor> flash_attention_forward(

TORCH_CHECK(!alibi_slopes_.has_value(), "alibi_slopes not supported");
TORCH_CHECK(p_dropout == 0.f, "dropout not supported");
TORCH_CHECK(window_size_left == -1, "window_size_left not supported");
TORCH_CHECK(window_size_left == -1 || (is_causal && window_size_left >= 0),
"window_size_left only supported with causal=True");
TORCH_CHECK(window_size_right == -1 || (is_causal && window_size_right == 0), "window not supported");
// Attended-token count for the kernel: left==-1 means unlimited, otherwise
// a query attends to window_size_left + 1 tokens (itself + left preceding).
const int window = (window_size_left < 0) ? -1 : window_size_left + 1;
TORCH_CHECK(softcap == 0.f, "softcap not supported");
TORCH_CHECK(!return_softmax, "return_softmax not supported");
TORCH_CHECK(!gen_.has_value(), "Generator not supported");
Expand All @@ -586,7 +615,7 @@ std::vector<at::Tensor> flash_attention_forward(
const int B = sizes[0], H = sizes[1], M = sizes[2], D = sizes[3];
const int H_KV = k.size(1);
const int N = k.size(2);
TORCH_CHECK(D <= 256 && D % 8 == 0 && D % 2 == 0, "D must be even, <=256, multiple of 8");
TORCH_CHECK((D <= 256 || D == 512) && D % 8 == 0 && D % 2 == 0, "D must be even, multiple of 8, and <=256 or ==512");
TORCH_CHECK(H_KV > 0, "num_kv_heads must be positive");
TORCH_CHECK(H % H_KV == 0, "num_attention_heads must be divisible by num_kv_heads");
TORCH_CHECK(k.size(0) == B && v.size(0) == B, "K/V batch size must match Q");
Expand All @@ -605,11 +634,12 @@ std::vector<at::Tensor> flash_attention_forward(
TORCH_CHECK(sm70, "Kernel supports only Volta GPUs.");

switch (D) {
case 16: launcher_flash_attention_forward<16>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break;
case 32: launcher_flash_attention_forward<32>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break;
case 64: launcher_flash_attention_forward<64>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break;
case 128: launcher_flash_attention_forward<128>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break;
case 256: launcher_flash_attention_forward<256>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break;
case 16: launcher_flash_attention_forward<16>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break;
case 32: launcher_flash_attention_forward<32>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break;
case 64: launcher_flash_attention_forward<64>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break;
case 128: launcher_flash_attention_forward<128>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break;
case 256: launcher_flash_attention_forward<256>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break;
case 512: launcher_flash_attention_forward<512>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break;
default: TORCH_CHECK(false, "Unsupported D: ", D);
}

Expand Down
Loading