Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions examples/scripts/benchmark/benchmark_fused_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup"
default=K_TOP,
help="k for top-k benchmarks (default: %(default)s)",
)
parser.add_argument(
"--tag",
type=str,
default=None,
help="Tag appended to output filenames (e.g. --tag=before / --tag=after)",
)
args = parser.parse_args()

assert torch.cuda.is_available(), "CUDA required"
Expand All @@ -217,7 +223,8 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup"
sparsity=args.sparsity,
k_top=args.topk,
)
csv_path = "out/bench_fused_sample.csv"
suffix = f"_{args.tag}" if args.tag else ""
csv_path = f"out/bench_fused_sample{suffix}.csv"
df.to_csv(csv_path, index=False)
print(f"\nSaved {csv_path}\n")

Expand All @@ -228,14 +235,14 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup"
df,
value_col="speedup_fused_vs_sparse_pytorch_sample",
title=f"Fused sample speedup vs compile(sparse_linear_pytorch)+multinomial (K={K})",
filename="out/heatmap_fused_sample_vs_sparse_pytorch.jpg",
filename=f"out/heatmap_fused_sample_vs_sparse_pytorch{suffix}.jpg",
cbar_label="Speedup (>1 = fused faster)",
)
if "speedup_fused_topk_vs_sparse_pytorch_topk" in df.columns:
plot_heatmap(
df,
value_col="speedup_fused_topk_vs_sparse_pytorch_topk",
title=f"Fused top-k speedup vs compile(sparse_linear_pytorch)+topk (K={K}, k={args.topk})",
filename="out/heatmap_fused_topk_vs_sparse_pytorch_topk.jpg",
filename=f"out/heatmap_fused_topk_vs_sparse_pytorch_topk{suffix}.jpg",
cbar_label="Speedup (>1 = fused faster)",
)
17 changes: 12 additions & 5 deletions examples/scripts/benchmark/benchmark_vtnk.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ def plot_heatmap(
default=DEFAULT_SPARSITY,
help="Fraction of vocab used as max branches (default: %(default)s)",
)
parser.add_argument(
"--tag",
type=str,
default=None,
help="Tag appended to output filenames (e.g. --tag=before / --tag=after)",
)
args = parser.parse_args()

assert torch.cuda.is_available(), "CUDA required"
Expand All @@ -251,7 +257,8 @@ def plot_heatmap(
df = benchmark_grid(
B_vals, N_vals, algorithms=args.algorithms, sparsity=args.sparsity
)
csv_path = "out/bench_vtnk.csv"
suffix = f"_{args.tag}" if args.tag else ""
csv_path = f"out/bench_vtnk{suffix}.csv"
df.to_csv(csv_path, index=False)
print(f"\nSaved {csv_path}\n")

Expand All @@ -262,30 +269,30 @@ def plot_heatmap(
df,
value_col="speedup_fused_vs_kernel",
title=f"Fused speedup vs compiled_linear+constrained_kernel (K={K})",
filename="out/heatmap_fused_vs_kernel.jpg",
filename=f"out/heatmap_fused_vs_kernel{suffix}.jpg",
cbar_label="Speedup (>1 = fused faster)",
)
if "speedup_fused_vs_pytorch" in df.columns:
plot_heatmap(
df,
value_col="speedup_fused_vs_pytorch",
title=f"Fused speedup vs compiled_linear+vtnk_pytorch (K={K})",
filename="out/heatmap_fused_vs_pytorch.jpg",
filename=f"out/heatmap_fused_vs_pytorch{suffix}.jpg",
cbar_label="Speedup (>1 = fused faster)",
)
if "speedup_fused_vs_sparse_pytorch" in df.columns:
plot_heatmap(
df,
value_col="speedup_fused_vs_sparse_pytorch",
title=f"Fused speedup vs sparse_linear_pytorch (K={K})",
filename="out/heatmap_fused_vs_sparse_pytorch.jpg",
filename=f"out/heatmap_fused_vs_sparse_pytorch{suffix}.jpg",
cbar_label="Speedup (>1 = fused faster)",
)
if "speedup_fused_vs_trie_cpu" in df.columns:
plot_heatmap(
df,
value_col="speedup_fused_vs_trie_cpu",
title=f"Fused speedup vs CPU trie traversal (K={K})",
filename="out/heatmap_fused_vs_trie_cpu.jpg",
filename=f"out/heatmap_fused_vs_trie_cpu{suffix}.jpg",
cbar_label="Speedup (>1 = fused faster)",
)
55 changes: 55 additions & 0 deletions examples/scripts/benchmark/compare_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Compare two benchmark CSV files (e.g. before vs after an optimisation) and
report per-algorithm timing improvements.

Usage:
python -m examples.scripts.benchmark.compare_benchmarks \
--before out/bench_vtnk_before.csv \
--after out/bench_vtnk_after.csv

Produces a summary table with absolute timings and percentage speedups.
"""

import argparse
import pandas as pd


def compare(before_path: str, after_path: str) -> pd.DataFrame:
before = pd.read_csv(before_path)
after = pd.read_csv(after_path)

key_cols = [c for c in ("B", "N", "D") if c in before.columns and c in after.columns]
ms_cols = [c for c in before.columns if c.startswith("ms_") and c in after.columns]

if not ms_cols:
raise ValueError("No common ms_* timing columns found in both CSVs.")

merged = before[key_cols + ms_cols].merge(
after[key_cols + ms_cols],
on=key_cols,
suffixes=("_before", "_after"),
)

rows = []
for _, r in merged.iterrows():
row = {k: r[k] for k in key_cols}
for col in ms_cols:
bval = r[f"{col}_before"]
aval = r[f"{col}_after"]
row[f"{col}_before"] = round(bval, 4)
row[f"{col}_after"] = round(aval, 4)
if bval > 0:
row[f"{col}_speedup"] = round(bval / aval, 3)
rows.append(row)

return pd.DataFrame(rows)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Compare two benchmark CSV files.")
parser.add_argument("--before", required=True, help="Path to baseline CSV")
parser.add_argument("--after", required=True, help="Path to optimised CSV")
args = parser.parse_args()

df = compare(args.before, args.after)
print(df.to_string(index=False))
49 changes: 35 additions & 14 deletions rectokens/kernels/constrained_node_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,16 @@ def _constrained_node_transition_op(

@triton.autotune(
configs=[
triton.Config({"BLOCK_B": 32, "BLOCK_N": 128, "GROUP_SIZE_M": 4}),
triton.Config({"BLOCK_B": 64, "BLOCK_N": 64, "GROUP_SIZE_M": 4}),
triton.Config({"BLOCK_B": 64, "BLOCK_N": 128, "GROUP_SIZE_M": 4}),
triton.Config({"BLOCK_B": 128, "BLOCK_N": 64, "GROUP_SIZE_M": 4}),
triton.Config({"BLOCK_B": 128, "BLOCK_N": 128, "GROUP_SIZE_M": 8}),
triton.Config({"BLOCK_B": 32, "BLOCK_N": 128, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_B": 64, "BLOCK_N": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_B": 64, "BLOCK_N": 128, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_B": 128, "BLOCK_N": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_B": 128, "BLOCK_N": 128, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_B": 16, "BLOCK_N": 128, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_B": 32, "BLOCK_N": 256, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_B": 64, "BLOCK_N": 256, "GROUP_SIZE_M": 4}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_B": 64, "BLOCK_N": 64, "GROUP_SIZE_M": 4}, num_warps=2, num_stages=3),
triton.Config({"BLOCK_B": 128, "BLOCK_N": 256, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=4),
],
key=["B", "N"],
restore_value=["corrected_logits_ptr", "next_node_ptr", "valid_idxs_ptr"],
Expand Down Expand Up @@ -177,13 +182,29 @@ def _constrained_node_transition_kernel(


_FUSED_AUTOTUNE_CONFIGS = [
triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}),
triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}),
triton.Config({"BLOCK_B": 256, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}),
triton.Config({"BLOCK_B": 64, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}),
triton.Config({"BLOCK_B": 128, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}),
triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}),
triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}),
# ── original block shapes with explicit warp / pipeline tuning ──
triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_B": 256, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_B": 64, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_B": 128, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}, num_warps=8, num_stages=2),
# ── small-batch configs ──
triton.Config({"BLOCK_B": 32, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=2, num_stages=2),
triton.Config({"BLOCK_B": 32, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_warps=4, num_stages=3),
# ── larger K-tiles for better compute-to-load ratio ──
triton.Config({"BLOCK_B": 64, "BLOCK_K": 256, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=4),
triton.Config({"BLOCK_B": 128, "BLOCK_K": 256, "BLOCK_BRANCHES": 4}, num_warps=8, num_stages=4),
# ── wider batch with medium K ──
triton.Config({"BLOCK_B": 256, "BLOCK_K": 128, "BLOCK_BRANCHES": 4}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_B": 64, "BLOCK_K": 128, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=3),
# ── high-branch configs for wide trie fan-out / reduced branch-block contention ──
triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 32}, num_warps=8, num_stages=2),
# ── extra pipeline depth ──
triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 8}, num_warps=4, num_stages=4),
triton.Config({"BLOCK_B": 128, "BLOCK_K": 128, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=4),
]


Expand Down Expand Up @@ -343,15 +364,15 @@ def _compute_branch_logits(
other=0.0,
) # [BLOCK_B, BLOCK_K]
dot = tl.sum(a_chunk * b_chunk, axis=1) # [BLOCK_B]
logits = tl.where(br_sel[None, :], logits + dot[:, None], logits)
logits += tl.where(br_sel[None, :], dot[:, None], 0.0)

if HAS_BIAS:
for local_br in tl.static_range(BLOCK_BRANCHES):
br_sel, col_k, c_mask = _select_branch(
local_br, branch_cols, branch_valid, BLOCK_BRANCHES
)
bias_k = tl.load(bias_ptr + col_k, mask=c_mask, other=0.0)
logits = tl.where(br_sel[None, :], logits + bias_k[:, None], logits)
logits += tl.where(br_sel[None, :], bias_k[:, None], 0.0)

return branch_cols, branch_vals, branch_valid, logits

Expand Down