diff --git a/examples/scripts/benchmark/benchmark_fused_sample.py b/examples/scripts/benchmark/benchmark_fused_sample.py index ba4ac53..7fc10ed 100644 --- a/examples/scripts/benchmark/benchmark_fused_sample.py +++ b/examples/scripts/benchmark/benchmark_fused_sample.py @@ -197,6 +197,12 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup" default=K_TOP, help="k for top-k benchmarks (default: %(default)s)", ) + parser.add_argument( + "--tag", + type=str, + default=None, + help="Tag appended to output filenames (e.g. --tag=before / --tag=after)", + ) args = parser.parse_args() assert torch.cuda.is_available(), "CUDA required" @@ -217,7 +223,8 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup" sparsity=args.sparsity, k_top=args.topk, ) - csv_path = "out/bench_fused_sample.csv" + suffix = f"_{args.tag}" if args.tag else "" + csv_path = f"out/bench_fused_sample{suffix}.csv" df.to_csv(csv_path, index=False) print(f"\nSaved {csv_path}\n") @@ -228,7 +235,7 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup" df, value_col="speedup_fused_vs_sparse_pytorch_sample", title=f"Fused sample speedup vs compile(sparse_linear_pytorch)+multinomial (K={K})", - filename="out/heatmap_fused_sample_vs_sparse_pytorch.jpg", + filename=f"out/heatmap_fused_sample_vs_sparse_pytorch{suffix}.jpg", cbar_label="Speedup (>1 = fused faster)", ) if "speedup_fused_topk_vs_sparse_pytorch_topk" in df.columns: @@ -236,6 +243,6 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup" df, value_col="speedup_fused_topk_vs_sparse_pytorch_topk", title=f"Fused top-k speedup vs compile(sparse_linear_pytorch)+topk (K={K}, k={args.topk})", - filename="out/heatmap_fused_topk_vs_sparse_pytorch_topk.jpg", + filename=f"out/heatmap_fused_topk_vs_sparse_pytorch_topk{suffix}.jpg", cbar_label="Speedup (>1 = fused faster)", ) diff --git a/examples/scripts/benchmark/benchmark_vtnk.py b/examples/scripts/benchmark/benchmark_vtnk.py index 6c7c182..6418b1c 100644 --- a/examples/scripts/benchmark/benchmark_vtnk.py +++ b/examples/scripts/benchmark/benchmark_vtnk.py @@ -235,6 +235,12 @@ def plot_heatmap( default=DEFAULT_SPARSITY, help="Fraction of vocab used as max branches (default: %(default)s)", ) + parser.add_argument( + "--tag", + type=str, + default=None, + help="Tag appended to output filenames (e.g. --tag=before / --tag=after)", + ) args = parser.parse_args() assert torch.cuda.is_available(), "CUDA required" @@ -251,7 +257,8 @@ def plot_heatmap( df = benchmark_grid( B_vals, N_vals, algorithms=args.algorithms, sparsity=args.sparsity ) - csv_path = "out/bench_vtnk.csv" + suffix = f"_{args.tag}" if args.tag else "" + csv_path = f"out/bench_vtnk{suffix}.csv" df.to_csv(csv_path, index=False) print(f"\nSaved {csv_path}\n") @@ -262,7 +269,7 @@ def plot_heatmap( df, value_col="speedup_fused_vs_kernel", title=f"Fused speedup vs compiled_linear+constrained_kernel (K={K})", - filename="out/heatmap_fused_vs_kernel.jpg", + filename=f"out/heatmap_fused_vs_kernel{suffix}.jpg", cbar_label="Speedup (>1 = fused faster)", ) if "speedup_fused_vs_pytorch" in df.columns: @@ -270,7 +277,7 @@ def plot_heatmap( df, value_col="speedup_fused_vs_pytorch", title=f"Fused speedup vs compiled_linear+vtnk_pytorch (K={K})", - filename="out/heatmap_fused_vs_pytorch.jpg", + filename=f"out/heatmap_fused_vs_pytorch{suffix}.jpg", cbar_label="Speedup (>1 = fused faster)", ) if "speedup_fused_vs_sparse_pytorch" in df.columns: @@ -278,7 +285,7 @@ def plot_heatmap( df, value_col="speedup_fused_vs_sparse_pytorch", title=f"Fused speedup vs sparse_linear_pytorch (K={K})", - filename="out/heatmap_fused_vs_sparse_pytorch.jpg", + filename=f"out/heatmap_fused_vs_sparse_pytorch{suffix}.jpg", cbar_label="Speedup (>1 = fused faster)", ) if "speedup_fused_vs_trie_cpu" in df.columns: @@ -286,6 +293,6 @@ def plot_heatmap( df, value_col="speedup_fused_vs_trie_cpu", title=f"Fused speedup vs CPU trie traversal (K={K})", - filename="out/heatmap_fused_vs_trie_cpu.jpg", + filename=f"out/heatmap_fused_vs_trie_cpu{suffix}.jpg", cbar_label="Speedup (>1 = fused faster)", ) diff --git a/examples/scripts/benchmark/compare_benchmarks.py b/examples/scripts/benchmark/compare_benchmarks.py new file mode 100644 index 0000000..725e8f5 --- /dev/null +++ b/examples/scripts/benchmark/compare_benchmarks.py @@ -0,0 +1,55 @@ +""" +Compare two benchmark CSV files (e.g. before vs after an optimisation) and +report per-algorithm timing improvements. + +Usage: + python -m examples.scripts.benchmark.compare_benchmarks \ + --before out/bench_vtnk_before.csv \ + --after out/bench_vtnk_after.csv + +Produces a summary table with absolute timings and percentage speedups. +""" + +import argparse +import pandas as pd + + +def compare(before_path: str, after_path: str) -> pd.DataFrame: + before = pd.read_csv(before_path) + after = pd.read_csv(after_path) + + key_cols = [c for c in ("B", "N", "D") if c in before.columns and c in after.columns] + ms_cols = [c for c in before.columns if c.startswith("ms_") and c in after.columns] + + if not ms_cols: + raise ValueError("No common ms_* timing columns found in both CSVs.") + + merged = before[key_cols + ms_cols].merge( + after[key_cols + ms_cols], + on=key_cols, + suffixes=("_before", "_after"), + ) + + rows = [] + for _, r in merged.iterrows(): + row = {k: r[k] for k in key_cols} + for col in ms_cols: + bval = r[f"{col}_before"] + aval = r[f"{col}_after"] + row[f"{col}_before"] = round(bval, 4) + row[f"{col}_after"] = round(aval, 4) + if bval > 0: + row[f"{col}_speedup"] = round(bval / aval, 3) + rows.append(row) + + return pd.DataFrame(rows) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Compare two benchmark CSV files.") + parser.add_argument("--before", required=True, help="Path to baseline CSV") + parser.add_argument("--after", required=True, help="Path to optimised CSV") + args = parser.parse_args() + + df = compare(args.before, args.after) + print(df.to_string(index=False)) diff --git a/rectokens/kernels/constrained_node_transition.py b/rectokens/kernels/constrained_node_transition.py index 0808f4f..f8467f3 100644 --- a/rectokens/kernels/constrained_node_transition.py +++ b/rectokens/kernels/constrained_node_transition.py @@ -63,11 +63,16 @@ def _constrained_node_transition_op( @triton.autotune( configs=[ - triton.Config({"BLOCK_B": 32, "BLOCK_N": 128, "GROUP_SIZE_M": 4}), - triton.Config({"BLOCK_B": 64, "BLOCK_N": 64, "GROUP_SIZE_M": 4}), - triton.Config({"BLOCK_B": 64, "BLOCK_N": 128, "GROUP_SIZE_M": 4}), - triton.Config({"BLOCK_B": 128, "BLOCK_N": 64, "GROUP_SIZE_M": 4}), - triton.Config({"BLOCK_B": 128, "BLOCK_N": 128, "GROUP_SIZE_M": 8}), + triton.Config({"BLOCK_B": 32, "BLOCK_N": 128, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_B": 64, "BLOCK_N": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_B": 64, "BLOCK_N": 128, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_B": 128, "BLOCK_N": 64, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_B": 128, "BLOCK_N": 128, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_B": 16, "BLOCK_N": 128, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_B": 32, "BLOCK_N": 256, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_B": 64, "BLOCK_N": 256, "GROUP_SIZE_M": 4}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_B": 64, "BLOCK_N": 64, "GROUP_SIZE_M": 4}, num_warps=2, num_stages=3), + triton.Config({"BLOCK_B": 128, "BLOCK_N": 256, "GROUP_SIZE_M": 8}, num_warps=8, num_stages=4), ], key=["B", "N"], restore_value=["corrected_logits_ptr", "next_node_ptr", "valid_idxs_ptr"], @@ -177,13 +182,29 @@ def _constrained_node_transition_kernel( _FUSED_AUTOTUNE_CONFIGS = [ - triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}), - triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}), - triton.Config({"BLOCK_B": 256, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}), - triton.Config({"BLOCK_B": 64, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}), - triton.Config({"BLOCK_B": 128, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}), - triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}), - triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}), + # ── original block shapes with explicit warp / pipeline tuning ── + triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_B": 256, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_B": 64, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}, num_warps=8, num_stages=2), + # ── small-batch configs ── + triton.Config({"BLOCK_B": 32, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_B": 32, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_warps=4, num_stages=3), + # ── larger K-tiles for better compute-to-load ratio ── + triton.Config({"BLOCK_B": 64, "BLOCK_K": 256, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 256, "BLOCK_BRANCHES": 4}, num_warps=8, num_stages=4), + # ── wider batch with medium K ── + triton.Config({"BLOCK_B": 256, "BLOCK_K": 128, "BLOCK_BRANCHES": 4}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_B": 64, "BLOCK_K": 128, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=3), + # ── high-branch configs for wide trie fan-out / reduced branch-block contention ── + triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 32}, num_warps=8, num_stages=2), + # ── extra pipeline depth ── + triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 8}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 128, "BLOCK_BRANCHES": 4}, num_warps=4, num_stages=4), ] @@ -343,7 +364,7 @@ def _compute_branch_logits( other=0.0, ) # [BLOCK_B, BLOCK_K] dot = tl.sum(a_chunk * b_chunk, axis=1) # [BLOCK_B] - logits = tl.where(br_sel[None, :], logits + dot[:, None], logits) + logits += tl.where(br_sel[None, :], dot[:, None], 0.0) if HAS_BIAS: for local_br in tl.static_range(BLOCK_BRANCHES): @@ -351,7 +372,7 @@ def _compute_branch_logits( local_br, branch_cols, branch_valid, BLOCK_BRANCHES ) bias_k = tl.load(bias_ptr + col_k, mask=c_mask, other=0.0) - logits = tl.where(br_sel[None, :], logits + bias_k[:, None], logits) + logits += tl.where(br_sel[None, :], bias_k[:, None], 0.0) return branch_cols, branch_vals, branch_valid, logits