A GPU library for classical machine-learning operators — kmeans, knn,
ivf-flat, pca, svd, dbscan, hdbscan, umap, t-sne, regression,
GEMM, and more — built on Triton and CuteDSL.
See the blog post for motivation, design, and benchmarks.
Install with pip:
pip install flashlibFrom source:
git clone https://github.com/FlashML-org/flashlib.git
cd flashlib
pip install -e .import torch
from flashlib import flash_kmeans
x = torch.randn(1_000_000, 128, device="cuda", dtype=torch.float32)
labels, centroids, n_iter = flash_kmeans(x, n_clusters=1024, max_iters=20)Every primitive is exposed as a top-level flash_* function and as a
sklearn-style class (KMeans, PCA, HDBSCAN, …).
Index-based primitives like IVF-Flat (GPU approximate nearest neighbours) build an index once and query it many times:
import torch
from flashlib import IVFFlat
db = torch.randn(1_000_000, 128, device="cuda")
queries = torch.randn(10_000, 128, device="cuda")
index = IVFFlat(nlist=1024, nprobe=16).fit(db)
distances, indices = index.kneighbors(queries, n_neighbors=10) # squared L2nprobe is the recall knob: at a fixed (nlist, nprobe) the probed
candidate set — and thus recall — matches a reference IVF-Flat (FAISS /
cuVS), so raising it trades speed for recall without changing the kernel.
The flashlib.info submodule predicts runtime, FLOPs, and HBM bytes for any
primitive in ~5 µs on pure CPU — useful for budgeting a pipeline before
launching it, and small enough for an LLM agent to call in a GPU-less
environment. It does not import torch, triton, or cutlass.
import flashlib.info as info
est = info.estimate("kmeans",
shape=(100_000, 64),
params={"K": 256, "max_iters": 20},
device="H200")
print(est.summary_line())See the blog post for the full API, the tolerance-driven dispatch, and per-primitive benchmarks.
The current release ships 16 high-level primitives across the following families:
| family | primitives |
|---|---|
| Clustering | flash_kmeans, flash_dbscan, flash_hdbscan, flash_spectral_clustering |
| Nearest nbrs | flash_knn, flash_ivf_flat (IVF-Flat ANN) |
| Decomposition | flash_pca, flash_truncated_svd |
| Manifold | flash_umap, flash_tsne |
| Regression | flash_linear_regression, flash_ridge, flash_logistic_regression |
| Classification | flash_multinomial_nb, flash_random_forest |
| Preprocessing | flash_standard_scaler |
Plus low-level linear-algebra primitives (cov_gemm, gram_gemm, ab_gemm,
eigh, polar, msign, cholqr2, split_basis) and a Pareto-frontier set
of multi-precision GEMM variants (gemm, gemm_tf32, gemm_3xtf32,
gemm_bf16, gemm_fp16, gemm_fp16_x9, gemm_fp16_x3_kahan,
gemm_ozaki2_int8, …).
@misc{yang2026flashlib,
title = {FlashLib: Bringing Flash Magic to Classical Machine Learning Operators},
author = {Yang, Shuo and Xi, Haocheng and Zhao, Yilong and Mang, Qiuyang and
Wang, Zhe and Sun, Shanlin and Keutzer, Kurt and Gonzalez, Joseph E. and
Han, Song and Xu, Chenfeng and Stoica, Ion},
year = {2026},
url = {https://flashml-org.github.io/},
}