Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
3f70f44
test
berlin020 Apr 13, 2026
0e76bd8
complex128
berlin020 Apr 16, 2026
f2b0b5b
Support CONJ transpose mode in spsv
berlin020 Apr 20, 2026
c876864
test
berlin020 Apr 21, 2026
2b10313
Update spsv csr tests
berlin020 Apr 21, 2026
5bf5ee4
Guard SPSV timing and benchmark gather paths
berlin020 Apr 21, 2026
1b71727
是,已经符合要求:12 组 gather 组合都在,主路径只走 Triton,`flagsparse_gather_cupy`/`cusptr?
berlin020 Apr 21, 2026
f78c4ff
gather update
berlin020 Apr 21, 2026
fe5786a
Add files from flagsparse_new
berlin020 Apr 26, 2026
3218eee
Merge pull request #2 from berlin020/merge-flagsparse-new
berlin020 Apr 27, 2026
3bf7caf
speedup and opt
berlin020 Apr 27, 2026
655d712
pytorch
berlin020 Apr 27, 2026
0413508
pytorch
berlin020 Apr 27, 2026
92717a6
pytorch
berlin020 Apr 27, 2026
d463197
pytorch
berlin020 Apr 27, 2026
4ac0218
pytorch
berlin020 Apr 27, 2026
aa9689b
Refine SpSV PyTorch fallback notes and output
berlin020 Apr 27, 2026
9f8ef1b
spsm-opt
berlin020 Apr 27, 2026
dbdb9d0
spsm-opt
berlin020 Apr 27, 2026
852e38c
spsv&spsm_opt
berlin020 Apr 28, 2026
bf259c2
opt
berlin020 Apr 28, 2026
06053a7
opt
berlin020 Apr 28, 2026
9d3eb68
opt
berlin020 Apr 28, 2026
626aa6f
Merge remote-tracking branch 'origin/spsm' into merge_spsm_into_flags…
berlin020 Apr 29, 2026
82fc7d5
gather/spsm/spsv
berlin020 Apr 29, 2026
590256f
gather update
berlin020 Apr 29, 2026
63c8a9d
spsv coo
berlin020 May 5, 2026
e65b843
spsv
berlin020 May 7, 2026
bbc818d
modify
berlin020 May 11, 2026
7761636
alg2/3/8
berlin020 May 12, 2026
0c5c17f
merge
berlin020 May 12, 2026
328a3d1
Merge pull request #4 from berlin020/spsv
berlin020 May 12, 2026
899a3b6
Merge branch 'flagsparse_merge' into merge
berlin020 May 12, 2026
82e1c9a
Merge pull request #5 from berlin020/merge
berlin020 May 12, 2026
f08e768
Delete .DS_Store
berlin020 May 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ __pycache__/
.pytest_cache/
.coverage
htmlcov/
.DS_Store
177 changes: 177 additions & 0 deletions ops_support_sort_check.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
operator,format,index_dtype,value_dtype,op,route,status
gather,index,int32,float16,non,triton,SUPPORTED
gather,index,int32,bfloat16,non,triton,SUPPORTED
gather,index,int32,float32,non,triton,SUPPORTED
gather,index,int32,float64,non,triton,SUPPORTED
gather,index,int32,complex32,non,triton,SUPPORTED
gather,index,int32,complex64,non,triton,SUPPORTED
gather,index,int32,complex128,non,triton,SUPPORTED
gather,index,int64,float16,non,triton,SUPPORTED
gather,index,int64,bfloat16,non,triton,SUPPORTED
gather,index,int64,float32,non,triton,SUPPORTED
gather,index,int64,float64,non,triton,SUPPORTED
gather,index,int64,complex32,non,triton,SUPPORTED
gather,index,int64,complex64,non,triton,SUPPORTED
gather,index,int64,complex128,non,triton,SUPPORTED
scatter,index,int32,float16,non,triton,SUPPORTED
scatter,index,int32,bfloat16,non,triton,SUPPORTED
scatter,index,int32,float32,non,triton,SUPPORTED
scatter,index,int32,float64,non,triton,SUPPORTED
scatter,index,int32,complex64,non,triton,SUPPORTED
scatter,index,int32,complex128,non,triton,SUPPORTED
scatter,index,int64,float16,non,triton,SUPPORTED
scatter,index,int64,bfloat16,non,triton,SUPPORTED
scatter,index,int64,float32,non,triton,SUPPORTED
scatter,index,int64,float64,non,triton,SUPPORTED
scatter,index,int64,complex64,non,triton,SUPPORTED
scatter,index,int64,complex128,non,triton,SUPPORTED
sddmm,CSR,int32,float32,non,triton,SUPPORTED
sddmm,CSR,int32,float64,non,triton,SUPPORTED
sddmm,CSR,int64,float32,non,triton,SUPPORTED
sddmm,CSR,int64,float64,non,triton,SUPPORTED
spgemm,CSR,int32,float32,non,triton,SUPPORTED
spgemm,CSR,int32,float64,non,triton,SUPPORTED
spgemm,CSR,int64,float32,non,triton,SUPPORTED
spgemm,CSR,int64,float64,non,triton,SUPPORTED
spmm,COO,int32,float16,non,triton,SUPPORTED
spmm,COO,int32,bfloat16,non,triton,SUPPORTED
spmm,COO,int32,float32,non,triton,SUPPORTED
spmm,COO,int32,float64,non,triton,SUPPORTED
spmm,COO,int32,complex32,non,triton,SUPPORTED
spmm,COO,int32,complex64,non,triton,SUPPORTED
spmm,COO,int32,complex128,non,triton,SUPPORTED
spmm,COO,int64,float16,non,triton,SUPPORTED
spmm,COO,int64,bfloat16,non,triton,SUPPORTED
spmm,COO,int64,float32,non,triton,SUPPORTED
spmm,COO,int64,float64,non,triton,SUPPORTED
spmm,COO,int64,complex32,non,triton,SUPPORTED
spmm,COO,int64,complex64,non,triton,SUPPORTED
spmm,COO,int64,complex128,non,triton,SUPPORTED
spmm,CSR,int32,float16,non,triton,SUPPORTED
spmm,CSR,int32,float16,non,triton_opt,SUPPORTED
spmm,CSR,int32,bfloat16,non,triton,SUPPORTED
spmm,CSR,int32,bfloat16,non,triton_opt,SUPPORTED
spmm,CSR,int32,float32,non,triton,SUPPORTED
spmm,CSR,int32,float32,non,triton_opt,SUPPORTED
spmm,CSR,int32,float64,non,triton,SUPPORTED
spmm,CSR,int32,float64,non,triton_opt,SUPPORTED
spmm,CSR,int32,complex32,non,triton,SUPPORTED
spmm,CSR,int32,complex32,non,triton_opt,SUPPORTED
spmm,CSR,int32,complex64,non,triton,SUPPORTED
spmm,CSR,int32,complex64,non,triton_opt,SUPPORTED
spmm,CSR,int32,complex128,non,triton,SUPPORTED
spmm,CSR,int32,complex128,non,triton_opt,SUPPORTED
spmm,CSR,int64,float16,non,triton,SUPPORTED
spmm,CSR,int64,float16,non,triton_opt,SUPPORTED
spmm,CSR,int64,bfloat16,non,triton,SUPPORTED
spmm,CSR,int64,bfloat16,non,triton_opt,SUPPORTED
spmm,CSR,int64,float32,non,triton,SUPPORTED
spmm,CSR,int64,float32,non,triton_opt,SUPPORTED
spmm,CSR,int64,float64,non,triton,SUPPORTED
spmm,CSR,int64,float64,non,triton_opt,SUPPORTED
spmm,CSR,int64,complex32,non,triton,SUPPORTED
spmm,CSR,int64,complex32,non,triton_opt,SUPPORTED
spmm,CSR,int64,complex64,non,triton,SUPPORTED
spmm,CSR,int64,complex64,non,triton_opt,SUPPORTED
spmm,CSR,int64,complex128,non,triton,SUPPORTED
spmm,CSR,int64,complex128,non,triton_opt,SUPPORTED
spmv,COO,int32,float32,non,triton,SUPPORTED
spmv,COO,int32,float64,non,triton,SUPPORTED
spmv,COO,int64,float32,non,triton,SUPPORTED
spmv,COO,int64,float64,non,triton,SUPPORTED
spmv,COO->CSR,int32,float16,non,triton,SUPPORTED
spmv,COO->CSR,int32,bfloat16,non,triton,SUPPORTED
spmv,COO->CSR,int32,float32,non,triton,SUPPORTED
spmv,COO->CSR,int32,float64,non,triton,SUPPORTED
spmv,COO->CSR,int32,complex64,non,triton,SUPPORTED
spmv,COO->CSR,int32,complex128,non,triton,SUPPORTED
spmv,COO->CSR,int64,float16,non,triton,SUPPORTED
spmv,COO->CSR,int64,bfloat16,non,triton,SUPPORTED
spmv,COO->CSR,int64,float32,non,triton,SUPPORTED
spmv,COO->CSR,int64,float64,non,triton,SUPPORTED
spmv,COO->CSR,int64,complex64,non,triton,SUPPORTED
spmv,COO->CSR,int64,complex128,non,triton,SUPPORTED
spmv,CSR,int32,float16,non,triton,SUPPORTED
spmv,CSR,int32,float16,trans,triton,SUPPORTED
spmv,CSR,int32,float16,conj,triton,SUPPORTED
spmv,CSR,int32,bfloat16,non,triton,SUPPORTED
spmv,CSR,int32,bfloat16,trans,triton,SUPPORTED
spmv,CSR,int32,bfloat16,conj,triton,SUPPORTED
spmv,CSR,int32,float32,non,triton,SUPPORTED
spmv,CSR,int32,float32,trans,triton,SUPPORTED
spmv,CSR,int32,float32,conj,triton,SUPPORTED
spmv,CSR,int32,float64,non,triton,SUPPORTED
spmv,CSR,int32,float64,trans,triton,SUPPORTED
spmv,CSR,int32,float64,conj,triton,SUPPORTED
spmv,CSR,int32,complex64,non,triton,SUPPORTED
spmv,CSR,int32,complex64,trans,triton,SUPPORTED
spmv,CSR,int32,complex64,conj,triton,SUPPORTED
spmv,CSR,int32,complex128,non,triton,SUPPORTED
spmv,CSR,int32,complex128,trans,triton,SUPPORTED
spmv,CSR,int32,complex128,conj,triton,SUPPORTED
spmv,CSR,int64,float16,non,triton,SUPPORTED
spmv,CSR,int64,float16,trans,triton,SUPPORTED
spmv,CSR,int64,float16,conj,triton,SUPPORTED
spmv,CSR,int64,bfloat16,non,triton,SUPPORTED
spmv,CSR,int64,bfloat16,trans,triton,SUPPORTED
spmv,CSR,int64,bfloat16,conj,triton,SUPPORTED
spmv,CSR,int64,float32,non,triton,SUPPORTED
spmv,CSR,int64,float32,trans,triton,SUPPORTED
spmv,CSR,int64,float32,conj,triton,SUPPORTED
spmv,CSR,int64,float64,non,triton,SUPPORTED
spmv,CSR,int64,float64,trans,triton,SUPPORTED
spmv,CSR,int64,float64,conj,triton,SUPPORTED
spmv,CSR,int64,complex64,non,triton,SUPPORTED
spmv,CSR,int64,complex64,trans,triton,SUPPORTED
spmv,CSR,int64,complex64,conj,triton,SUPPORTED
spmv,CSR,int64,complex128,non,triton,SUPPORTED
spmv,CSR,int64,complex128,trans,triton,SUPPORTED
spmv,CSR,int64,complex128,conj,triton,SUPPORTED
spsm,COO,int32,float32,non,triton,SUPPORTED
spsm,COO,int32,float64,non,triton,SUPPORTED
spsm,COO,int64,float32,non,triton,SUPPORTED
spsm,COO,int64,float64,non,triton,SUPPORTED
spsm,CSR,int32,float32,non,triton,SUPPORTED
spsm,CSR,int32,float64,non,triton,SUPPORTED
spsm,CSR,int64,float32,non,triton,SUPPORTED
spsm,CSR,int64,float64,non,triton,SUPPORTED
spsv,COO,int32,bfloat16,non,triton,SUPPORTED
spsv,COO,int32,bfloat16,trans,triton,SUPPORTED
spsv,COO,int32,float32,non,triton,SUPPORTED
spsv,COO,int32,float32,trans,triton,SUPPORTED
spsv,COO,int32,float64,non,triton,SUPPORTED
spsv,COO,int32,float64,trans,triton,SUPPORTED
spsv,COO,int32,complex32,non,triton,SUPPORTED
spsv,COO,int32,complex32,trans,triton,SUPPORTED
spsv,COO,int32,complex64,non,triton,SUPPORTED
spsv,COO,int32,complex64,trans,triton,SUPPORTED
spsv,COO,int64,bfloat16,non,triton,SUPPORTED
spsv,COO,int64,bfloat16,trans,triton,SUPPORTED
spsv,COO,int64,float32,non,triton,SUPPORTED
spsv,COO,int64,float32,trans,triton,SUPPORTED
spsv,COO,int64,float64,non,triton,SUPPORTED
spsv,COO,int64,float64,trans,triton,SUPPORTED
spsv,COO,int64,complex32,non,triton,SUPPORTED
spsv,COO,int64,complex32,trans,triton,SUPPORTED
spsv,COO,int64,complex64,non,triton,SUPPORTED
spsv,COO,int64,complex64,trans,triton,SUPPORTED
spsv,CSR,int32,bfloat16,non,triton,SUPPORTED
spsv,CSR,int32,bfloat16,trans,triton,SUPPORTED
spsv,CSR,int32,float32,non,triton,SUPPORTED
spsv,CSR,int32,float32,trans,triton,SUPPORTED
spsv,CSR,int32,float64,non,triton,SUPPORTED
spsv,CSR,int32,float64,trans,triton,SUPPORTED
spsv,CSR,int32,complex32,non,triton,SUPPORTED
spsv,CSR,int32,complex32,trans,triton,SUPPORTED
spsv,CSR,int32,complex64,non,triton,SUPPORTED
spsv,CSR,int32,complex64,trans,triton,SUPPORTED
spsv,CSR,int64,bfloat16,non,triton,SUPPORTED
spsv,CSR,int64,bfloat16,trans,triton,SUPPORTED
spsv,CSR,int64,float32,non,triton,SUPPORTED
spsv,CSR,int64,float32,trans,triton,SUPPORTED
spsv,CSR,int64,float64,non,triton,SUPPORTED
spsv,CSR,int64,float64,trans,triton,SUPPORTED
spsv,CSR,int64,complex32,non,triton,SUPPORTED
spsv,CSR,int64,complex32,trans,triton,SUPPORTED
spsv,CSR,int64,complex64,non,triton,SUPPORTED
spsv,CSR,int64,complex64,trans,triton,SUPPORTED
42 changes: 41 additions & 1 deletion src/flagsparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,26 @@
"flagsparse_spmm_csr_opt_alg2_symbolic",
"flagsparse_spsv_csr",
"flagsparse_spsv_coo",
"flagsparse_spsv_buffer_size",
"flagsparse_spsv_buffer_size_ex",
"flagsparse_spsv_analysis_csr",
"flagsparse_spsv_analysis_coo",
"flagsparse_spsv_analysis_ex",
"flagsparse_spsv_solve_csr",
"flagsparse_spsv_solve_coo",
"flagsparse_spsv_solve_ex",
"flagsparse_spsv_preprocess_csr",
"flagsparse_spsv_preprocess_coo",
"flagsparse_spsv_create_workspace",
"flagsparse_create_spsv_handle",
"flagsparse_create_spmat_csr",
"flagsparse_create_spmat_coo",
"flagsparse_create_dnvec",
"FlagSparseSpSVDescr",
"FlagSparseSpSVHandle",
"FlagSparseSpSVWorkspace",
"FlagSparseSpMatDescr",
"FlagSparseDnVecDescr",
"flagsparse_spsm_csr",
"flagsparse_spsm_coo",
"benchmark_spsm_case",
Expand Down Expand Up @@ -126,6 +146,26 @@
"flagsparse_spmm_csr_opt_alg2_symbolic",
"flagsparse_spsv_csr",
"flagsparse_spsv_coo",
"flagsparse_spsv_buffer_size",
"flagsparse_spsv_buffer_size_ex",
"flagsparse_spsv_analysis_csr",
"flagsparse_spsv_analysis_coo",
"flagsparse_spsv_analysis_ex",
"flagsparse_spsv_solve_csr",
"flagsparse_spsv_solve_coo",
"flagsparse_spsv_solve_ex",
"flagsparse_spsv_preprocess_csr",
"flagsparse_spsv_preprocess_coo",
"flagsparse_spsv_create_workspace",
"flagsparse_create_spsv_handle",
"flagsparse_create_spmat_csr",
"flagsparse_create_spmat_coo",
"flagsparse_create_dnvec",
"FlagSparseSpSVDescr",
"FlagSparseSpSVHandle",
"FlagSparseSpSVWorkspace",
"FlagSparseSpMatDescr",
"FlagSparseDnVecDescr",
"flagsparse_spsm_csr",
"flagsparse_spsm_coo",
"benchmark_spsm_case",
Expand All @@ -141,8 +181,8 @@
"benchmark_spgemm_case",
"benchmark_sddmm_case",
"comprehensive_spmm_test",
"benchmark_spmv_case",
"comprehensive_spsm_test",
"benchmark_spmv_case",
}

_FORMAT_EXPORTS = {
Expand Down
63 changes: 63 additions & 0 deletions src/flagsparse/sparse_operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@
is_alpha_spmm_alg1_tle_available,
prepare_alpha_spmm_alg1,
prepare_alpha_spmm_alg1_tle,
from ._common import SUPPORTED_INDEX_DTYPES, SUPPORTED_VALUE_DTYPES
from .benchmarks import (
benchmark_gather_case,
benchmark_performance,
benchmark_scatter_case,
benchmark_sddmm_case,
benchmark_spgemm_case,
benchmark_spmm_case,
benchmark_spmm_opt_case,
benchmark_spmv_case,
benchmark_spsm_case,
comprehensive_gather_test,
comprehensive_scatter_test,
comprehensive_spmm_test,
comprehensive_spsm_test,
)
from .gather_scatter import (
cusparse_spmv_gather,
Expand Down Expand Up @@ -64,6 +79,34 @@
"comprehensive_scatter_test",
"comprehensive_spsm_test",
}
from .spmm_coo import flagsparse_spmm_coo
from .spgemm_csr import SpGEMMPrepared, flagsparse_spgemm_csr, prepare_spgemm_csr
from .sddmm_csr import SDDMMPrepared, flagsparse_sddmm_csr, prepare_sddmm_csr
from .spsv import (
FlagSparseDnVecDescr,
FlagSparseSpMatDescr,
FlagSparseSpSVDescr,
FlagSparseSpSVHandle,
FlagSparseSpSVWorkspace,
flagsparse_create_dnvec,
flagsparse_create_spmat_coo,
flagsparse_create_spmat_csr,
flagsparse_create_spsv_handle,
flagsparse_spsv_analysis_coo,
flagsparse_spsv_analysis_csr,
flagsparse_spsv_analysis_ex,
flagsparse_spsv_buffer_size,
flagsparse_spsv_buffer_size_ex,
flagsparse_spsv_coo,
flagsparse_spsv_create_workspace,
flagsparse_spsv_csr,
flagsparse_spsv_preprocess_coo,
flagsparse_spsv_preprocess_csr,
flagsparse_spsv_solve_ex,
flagsparse_spsv_solve_coo,
flagsparse_spsv_solve_csr,
)
from .spsm import flagsparse_spsm_coo, flagsparse_spsm_csr

__all__ = [
"PreparedCoo",
Expand Down Expand Up @@ -107,8 +150,28 @@
"flagsparse_spmv_csr",
"flagsparse_spsm_coo",
"flagsparse_spsm_csr",
"FlagSparseDnVecDescr",
"FlagSparseSpMatDescr",
"FlagSparseSpSVDescr",
"FlagSparseSpSVHandle",
"FlagSparseSpSVWorkspace",
"flagsparse_create_dnvec",
"flagsparse_create_spmat_coo",
"flagsparse_create_spmat_csr",
"flagsparse_create_spsv_handle",
"flagsparse_spsv_analysis_coo",
"flagsparse_spsv_analysis_csr",
"flagsparse_spsv_analysis_ex",
"flagsparse_spsv_buffer_size",
"flagsparse_spsv_buffer_size_ex",
"flagsparse_spsv_coo",
"flagsparse_spsv_create_workspace",
"flagsparse_spsv_csr",
"flagsparse_spsv_preprocess_coo",
"flagsparse_spsv_preprocess_csr",
"flagsparse_spsv_solve_ex",
"flagsparse_spsv_solve_coo",
"flagsparse_spsv_solve_csr",
"prepare_sddmm_csr",
"prepare_alpha_spmm_alg1",
"prepare_spgemm_csr",
Expand Down
Loading