diff --git a/.gitignore b/.gitignore index 4ca5df5..e44113e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ __pycache__/ .pytest_cache/ .coverage htmlcov/ +.DS_Store diff --git a/ops_support_sort_check.csv b/ops_support_sort_check.csv new file mode 100644 index 0000000..61ddfaa --- /dev/null +++ b/ops_support_sort_check.csv @@ -0,0 +1,177 @@ +operator,format,index_dtype,value_dtype,op,route,status +gather,index,int32,float16,non,triton,SUPPORTED +gather,index,int32,bfloat16,non,triton,SUPPORTED +gather,index,int32,float32,non,triton,SUPPORTED +gather,index,int32,float64,non,triton,SUPPORTED +gather,index,int32,complex32,non,triton,SUPPORTED +gather,index,int32,complex64,non,triton,SUPPORTED +gather,index,int32,complex128,non,triton,SUPPORTED +gather,index,int64,float16,non,triton,SUPPORTED +gather,index,int64,bfloat16,non,triton,SUPPORTED +gather,index,int64,float32,non,triton,SUPPORTED +gather,index,int64,float64,non,triton,SUPPORTED +gather,index,int64,complex32,non,triton,SUPPORTED +gather,index,int64,complex64,non,triton,SUPPORTED +gather,index,int64,complex128,non,triton,SUPPORTED +scatter,index,int32,float16,non,triton,SUPPORTED +scatter,index,int32,bfloat16,non,triton,SUPPORTED +scatter,index,int32,float32,non,triton,SUPPORTED +scatter,index,int32,float64,non,triton,SUPPORTED +scatter,index,int32,complex64,non,triton,SUPPORTED +scatter,index,int32,complex128,non,triton,SUPPORTED +scatter,index,int64,float16,non,triton,SUPPORTED +scatter,index,int64,bfloat16,non,triton,SUPPORTED +scatter,index,int64,float32,non,triton,SUPPORTED +scatter,index,int64,float64,non,triton,SUPPORTED +scatter,index,int64,complex64,non,triton,SUPPORTED +scatter,index,int64,complex128,non,triton,SUPPORTED +sddmm,CSR,int32,float32,non,triton,SUPPORTED +sddmm,CSR,int32,float64,non,triton,SUPPORTED +sddmm,CSR,int64,float32,non,triton,SUPPORTED +sddmm,CSR,int64,float64,non,triton,SUPPORTED +spgemm,CSR,int32,float32,non,triton,SUPPORTED +spgemm,CSR,int32,float64,non,triton,SUPPORTED +spgemm,CSR,int64,float32,non,triton,SUPPORTED +spgemm,CSR,int64,float64,non,triton,SUPPORTED +spmm,COO,int32,float16,non,triton,SUPPORTED +spmm,COO,int32,bfloat16,non,triton,SUPPORTED +spmm,COO,int32,float32,non,triton,SUPPORTED +spmm,COO,int32,float64,non,triton,SUPPORTED +spmm,COO,int32,complex32,non,triton,SUPPORTED +spmm,COO,int32,complex64,non,triton,SUPPORTED +spmm,COO,int32,complex128,non,triton,SUPPORTED +spmm,COO,int64,float16,non,triton,SUPPORTED +spmm,COO,int64,bfloat16,non,triton,SUPPORTED +spmm,COO,int64,float32,non,triton,SUPPORTED +spmm,COO,int64,float64,non,triton,SUPPORTED +spmm,COO,int64,complex32,non,triton,SUPPORTED +spmm,COO,int64,complex64,non,triton,SUPPORTED +spmm,COO,int64,complex128,non,triton,SUPPORTED +spmm,CSR,int32,float16,non,triton,SUPPORTED +spmm,CSR,int32,float16,non,triton_opt,SUPPORTED +spmm,CSR,int32,bfloat16,non,triton,SUPPORTED +spmm,CSR,int32,bfloat16,non,triton_opt,SUPPORTED +spmm,CSR,int32,float32,non,triton,SUPPORTED +spmm,CSR,int32,float32,non,triton_opt,SUPPORTED +spmm,CSR,int32,float64,non,triton,SUPPORTED +spmm,CSR,int32,float64,non,triton_opt,SUPPORTED +spmm,CSR,int32,complex32,non,triton,SUPPORTED +spmm,CSR,int32,complex32,non,triton_opt,SUPPORTED +spmm,CSR,int32,complex64,non,triton,SUPPORTED +spmm,CSR,int32,complex64,non,triton_opt,SUPPORTED +spmm,CSR,int32,complex128,non,triton,SUPPORTED +spmm,CSR,int32,complex128,non,triton_opt,SUPPORTED +spmm,CSR,int64,float16,non,triton,SUPPORTED +spmm,CSR,int64,float16,non,triton_opt,SUPPORTED +spmm,CSR,int64,bfloat16,non,triton,SUPPORTED +spmm,CSR,int64,bfloat16,non,triton_opt,SUPPORTED +spmm,CSR,int64,float32,non,triton,SUPPORTED +spmm,CSR,int64,float32,non,triton_opt,SUPPORTED +spmm,CSR,int64,float64,non,triton,SUPPORTED +spmm,CSR,int64,float64,non,triton_opt,SUPPORTED +spmm,CSR,int64,complex32,non,triton,SUPPORTED +spmm,CSR,int64,complex32,non,triton_opt,SUPPORTED +spmm,CSR,int64,complex64,non,triton,SUPPORTED +spmm,CSR,int64,complex64,non,triton_opt,SUPPORTED +spmm,CSR,int64,complex128,non,triton,SUPPORTED +spmm,CSR,int64,complex128,non,triton_opt,SUPPORTED +spmv,COO,int32,float32,non,triton,SUPPORTED +spmv,COO,int32,float64,non,triton,SUPPORTED +spmv,COO,int64,float32,non,triton,SUPPORTED +spmv,COO,int64,float64,non,triton,SUPPORTED +spmv,COO->CSR,int32,float16,non,triton,SUPPORTED +spmv,COO->CSR,int32,bfloat16,non,triton,SUPPORTED +spmv,COO->CSR,int32,float32,non,triton,SUPPORTED +spmv,COO->CSR,int32,float64,non,triton,SUPPORTED +spmv,COO->CSR,int32,complex64,non,triton,SUPPORTED +spmv,COO->CSR,int32,complex128,non,triton,SUPPORTED +spmv,COO->CSR,int64,float16,non,triton,SUPPORTED +spmv,COO->CSR,int64,bfloat16,non,triton,SUPPORTED +spmv,COO->CSR,int64,float32,non,triton,SUPPORTED +spmv,COO->CSR,int64,float64,non,triton,SUPPORTED +spmv,COO->CSR,int64,complex64,non,triton,SUPPORTED +spmv,COO->CSR,int64,complex128,non,triton,SUPPORTED +spmv,CSR,int32,float16,non,triton,SUPPORTED +spmv,CSR,int32,float16,trans,triton,SUPPORTED +spmv,CSR,int32,float16,conj,triton,SUPPORTED +spmv,CSR,int32,bfloat16,non,triton,SUPPORTED +spmv,CSR,int32,bfloat16,trans,triton,SUPPORTED +spmv,CSR,int32,bfloat16,conj,triton,SUPPORTED +spmv,CSR,int32,float32,non,triton,SUPPORTED +spmv,CSR,int32,float32,trans,triton,SUPPORTED +spmv,CSR,int32,float32,conj,triton,SUPPORTED +spmv,CSR,int32,float64,non,triton,SUPPORTED +spmv,CSR,int32,float64,trans,triton,SUPPORTED +spmv,CSR,int32,float64,conj,triton,SUPPORTED +spmv,CSR,int32,complex64,non,triton,SUPPORTED +spmv,CSR,int32,complex64,trans,triton,SUPPORTED +spmv,CSR,int32,complex64,conj,triton,SUPPORTED +spmv,CSR,int32,complex128,non,triton,SUPPORTED +spmv,CSR,int32,complex128,trans,triton,SUPPORTED +spmv,CSR,int32,complex128,conj,triton,SUPPORTED +spmv,CSR,int64,float16,non,triton,SUPPORTED +spmv,CSR,int64,float16,trans,triton,SUPPORTED +spmv,CSR,int64,float16,conj,triton,SUPPORTED +spmv,CSR,int64,bfloat16,non,triton,SUPPORTED +spmv,CSR,int64,bfloat16,trans,triton,SUPPORTED +spmv,CSR,int64,bfloat16,conj,triton,SUPPORTED +spmv,CSR,int64,float32,non,triton,SUPPORTED +spmv,CSR,int64,float32,trans,triton,SUPPORTED +spmv,CSR,int64,float32,conj,triton,SUPPORTED +spmv,CSR,int64,float64,non,triton,SUPPORTED +spmv,CSR,int64,float64,trans,triton,SUPPORTED +spmv,CSR,int64,float64,conj,triton,SUPPORTED +spmv,CSR,int64,complex64,non,triton,SUPPORTED +spmv,CSR,int64,complex64,trans,triton,SUPPORTED +spmv,CSR,int64,complex64,conj,triton,SUPPORTED +spmv,CSR,int64,complex128,non,triton,SUPPORTED +spmv,CSR,int64,complex128,trans,triton,SUPPORTED +spmv,CSR,int64,complex128,conj,triton,SUPPORTED +spsm,COO,int32,float32,non,triton,SUPPORTED +spsm,COO,int32,float64,non,triton,SUPPORTED +spsm,COO,int64,float32,non,triton,SUPPORTED +spsm,COO,int64,float64,non,triton,SUPPORTED +spsm,CSR,int32,float32,non,triton,SUPPORTED +spsm,CSR,int32,float64,non,triton,SUPPORTED +spsm,CSR,int64,float32,non,triton,SUPPORTED +spsm,CSR,int64,float64,non,triton,SUPPORTED +spsv,COO,int32,bfloat16,non,triton,SUPPORTED +spsv,COO,int32,bfloat16,trans,triton,SUPPORTED +spsv,COO,int32,float32,non,triton,SUPPORTED +spsv,COO,int32,float32,trans,triton,SUPPORTED +spsv,COO,int32,float64,non,triton,SUPPORTED +spsv,COO,int32,float64,trans,triton,SUPPORTED +spsv,COO,int32,complex32,non,triton,SUPPORTED +spsv,COO,int32,complex32,trans,triton,SUPPORTED +spsv,COO,int32,complex64,non,triton,SUPPORTED +spsv,COO,int32,complex64,trans,triton,SUPPORTED +spsv,COO,int64,bfloat16,non,triton,SUPPORTED +spsv,COO,int64,bfloat16,trans,triton,SUPPORTED +spsv,COO,int64,float32,non,triton,SUPPORTED +spsv,COO,int64,float32,trans,triton,SUPPORTED +spsv,COO,int64,float64,non,triton,SUPPORTED +spsv,COO,int64,float64,trans,triton,SUPPORTED +spsv,COO,int64,complex32,non,triton,SUPPORTED +spsv,COO,int64,complex32,trans,triton,SUPPORTED +spsv,COO,int64,complex64,non,triton,SUPPORTED +spsv,COO,int64,complex64,trans,triton,SUPPORTED +spsv,CSR,int32,bfloat16,non,triton,SUPPORTED +spsv,CSR,int32,bfloat16,trans,triton,SUPPORTED +spsv,CSR,int32,float32,non,triton,SUPPORTED +spsv,CSR,int32,float32,trans,triton,SUPPORTED +spsv,CSR,int32,float64,non,triton,SUPPORTED +spsv,CSR,int32,float64,trans,triton,SUPPORTED +spsv,CSR,int32,complex32,non,triton,SUPPORTED +spsv,CSR,int32,complex32,trans,triton,SUPPORTED +spsv,CSR,int32,complex64,non,triton,SUPPORTED +spsv,CSR,int32,complex64,trans,triton,SUPPORTED +spsv,CSR,int64,bfloat16,non,triton,SUPPORTED +spsv,CSR,int64,bfloat16,trans,triton,SUPPORTED +spsv,CSR,int64,float32,non,triton,SUPPORTED +spsv,CSR,int64,float32,trans,triton,SUPPORTED +spsv,CSR,int64,float64,non,triton,SUPPORTED +spsv,CSR,int64,float64,trans,triton,SUPPORTED +spsv,CSR,int64,complex32,non,triton,SUPPORTED +spsv,CSR,int64,complex32,trans,triton,SUPPORTED +spsv,CSR,int64,complex64,non,triton,SUPPORTED +spsv,CSR,int64,complex64,trans,triton,SUPPORTED diff --git a/src/flagsparse/__init__.py b/src/flagsparse/__init__.py index b138f4d..bcdb734 100644 --- a/src/flagsparse/__init__.py +++ b/src/flagsparse/__init__.py @@ -45,6 +45,26 @@ "flagsparse_spmm_csr_opt_alg2_symbolic", "flagsparse_spsv_csr", "flagsparse_spsv_coo", + "flagsparse_spsv_buffer_size", + "flagsparse_spsv_buffer_size_ex", + "flagsparse_spsv_analysis_csr", + "flagsparse_spsv_analysis_coo", + "flagsparse_spsv_analysis_ex", + "flagsparse_spsv_solve_csr", + "flagsparse_spsv_solve_coo", + "flagsparse_spsv_solve_ex", + "flagsparse_spsv_preprocess_csr", + "flagsparse_spsv_preprocess_coo", + "flagsparse_spsv_create_workspace", + "flagsparse_create_spsv_handle", + "flagsparse_create_spmat_csr", + "flagsparse_create_spmat_coo", + "flagsparse_create_dnvec", + "FlagSparseSpSVDescr", + "FlagSparseSpSVHandle", + "FlagSparseSpSVWorkspace", + "FlagSparseSpMatDescr", + "FlagSparseDnVecDescr", "flagsparse_spsm_csr", "flagsparse_spsm_coo", "benchmark_spsm_case", @@ -126,6 +146,26 @@ "flagsparse_spmm_csr_opt_alg2_symbolic", "flagsparse_spsv_csr", "flagsparse_spsv_coo", + "flagsparse_spsv_buffer_size", + "flagsparse_spsv_buffer_size_ex", + "flagsparse_spsv_analysis_csr", + "flagsparse_spsv_analysis_coo", + "flagsparse_spsv_analysis_ex", + "flagsparse_spsv_solve_csr", + "flagsparse_spsv_solve_coo", + "flagsparse_spsv_solve_ex", + "flagsparse_spsv_preprocess_csr", + "flagsparse_spsv_preprocess_coo", + "flagsparse_spsv_create_workspace", + "flagsparse_create_spsv_handle", + "flagsparse_create_spmat_csr", + "flagsparse_create_spmat_coo", + "flagsparse_create_dnvec", + "FlagSparseSpSVDescr", + "FlagSparseSpSVHandle", + "FlagSparseSpSVWorkspace", + "FlagSparseSpMatDescr", + "FlagSparseDnVecDescr", "flagsparse_spsm_csr", "flagsparse_spsm_coo", "benchmark_spsm_case", @@ -141,8 +181,8 @@ "benchmark_spgemm_case", "benchmark_sddmm_case", "comprehensive_spmm_test", - "benchmark_spmv_case", "comprehensive_spsm_test", + "benchmark_spmv_case", } _FORMAT_EXPORTS = { diff --git a/src/flagsparse/sparse_operations/__init__.py b/src/flagsparse/sparse_operations/__init__.py index 9114c46..fcb0758 100644 --- a/src/flagsparse/sparse_operations/__init__.py +++ b/src/flagsparse/sparse_operations/__init__.py @@ -9,6 +9,21 @@ is_alpha_spmm_alg1_tle_available, prepare_alpha_spmm_alg1, prepare_alpha_spmm_alg1_tle, +from ._common import SUPPORTED_INDEX_DTYPES, SUPPORTED_VALUE_DTYPES +from .benchmarks import ( + benchmark_gather_case, + benchmark_performance, + benchmark_scatter_case, + benchmark_sddmm_case, + benchmark_spgemm_case, + benchmark_spmm_case, + benchmark_spmm_opt_case, + benchmark_spmv_case, + benchmark_spsm_case, + comprehensive_gather_test, + comprehensive_scatter_test, + comprehensive_spmm_test, + comprehensive_spsm_test, ) from .gather_scatter import ( cusparse_spmv_gather, @@ -64,6 +79,34 @@ "comprehensive_scatter_test", "comprehensive_spsm_test", } +from .spmm_coo import flagsparse_spmm_coo +from .spgemm_csr import SpGEMMPrepared, flagsparse_spgemm_csr, prepare_spgemm_csr +from .sddmm_csr import SDDMMPrepared, flagsparse_sddmm_csr, prepare_sddmm_csr +from .spsv import ( + FlagSparseDnVecDescr, + FlagSparseSpMatDescr, + FlagSparseSpSVDescr, + FlagSparseSpSVHandle, + FlagSparseSpSVWorkspace, + flagsparse_create_dnvec, + flagsparse_create_spmat_coo, + flagsparse_create_spmat_csr, + flagsparse_create_spsv_handle, + flagsparse_spsv_analysis_coo, + flagsparse_spsv_analysis_csr, + flagsparse_spsv_analysis_ex, + flagsparse_spsv_buffer_size, + flagsparse_spsv_buffer_size_ex, + flagsparse_spsv_coo, + flagsparse_spsv_create_workspace, + flagsparse_spsv_csr, + flagsparse_spsv_preprocess_coo, + flagsparse_spsv_preprocess_csr, + flagsparse_spsv_solve_ex, + flagsparse_spsv_solve_coo, + flagsparse_spsv_solve_csr, +) +from .spsm import flagsparse_spsm_coo, flagsparse_spsm_csr __all__ = [ "PreparedCoo", @@ -107,8 +150,28 @@ "flagsparse_spmv_csr", "flagsparse_spsm_coo", "flagsparse_spsm_csr", + "FlagSparseDnVecDescr", + "FlagSparseSpMatDescr", + "FlagSparseSpSVDescr", + "FlagSparseSpSVHandle", + "FlagSparseSpSVWorkspace", + "flagsparse_create_dnvec", + "flagsparse_create_spmat_coo", + "flagsparse_create_spmat_csr", + "flagsparse_create_spsv_handle", + "flagsparse_spsv_analysis_coo", + "flagsparse_spsv_analysis_csr", + "flagsparse_spsv_analysis_ex", + "flagsparse_spsv_buffer_size", + "flagsparse_spsv_buffer_size_ex", "flagsparse_spsv_coo", + "flagsparse_spsv_create_workspace", "flagsparse_spsv_csr", + "flagsparse_spsv_preprocess_coo", + "flagsparse_spsv_preprocess_csr", + "flagsparse_spsv_solve_ex", + "flagsparse_spsv_solve_coo", + "flagsparse_spsv_solve_csr", "prepare_sddmm_csr", "prepare_alpha_spmm_alg1", "prepare_spgemm_csr", diff --git a/src/flagsparse/sparse_operations/_common.py b/src/flagsparse/sparse_operations/_common.py index 7e89526..6f7725d 100644 --- a/src/flagsparse/sparse_operations/_common.py +++ b/src/flagsparse/sparse_operations/_common.py @@ -18,19 +18,14 @@ cp = None cpx_sparse = None -_TORCH_COMPLEX32_DTYPE = getattr(torch, "complex32", None) -if _TORCH_COMPLEX32_DTYPE is None: - _TORCH_COMPLEX32_DTYPE = getattr(torch, "chalf", None) - _SUPPORTED_VALUE_DTYPES = [ torch.float16, torch.bfloat16, torch.float32, torch.float64, + torch.complex64, + torch.complex128, ] -if _TORCH_COMPLEX32_DTYPE is not None: - _SUPPORTED_VALUE_DTYPES.append(_TORCH_COMPLEX32_DTYPE) -_SUPPORTED_VALUE_DTYPES.extend([torch.complex64, torch.complex128]) SUPPORTED_VALUE_DTYPES = tuple(_SUPPORTED_VALUE_DTYPES) SUPPORTED_INDEX_DTYPES = (torch.int32, torch.int64) _INDEX_LIMIT_INT32 = 2**31 - 1 @@ -40,8 +35,6 @@ "SUPPORTED_VALUE_DTYPES", "SUPPORTED_INDEX_DTYPES", "_INDEX_LIMIT_INT32", - "_torch_complex32_dtype", - "_is_complex32_dtype", "_is_complex_dtype", "_resolve_scatter_value_dtype", "_component_dtype_for_complex", @@ -68,17 +61,8 @@ "tl", ) - -def _torch_complex32_dtype(): - return _TORCH_COMPLEX32_DTYPE - - -def _is_complex32_dtype(value_dtype): - return _TORCH_COMPLEX32_DTYPE is not None and value_dtype == _TORCH_COMPLEX32_DTYPE - - def _is_complex_dtype(value_dtype): - return _is_complex32_dtype(value_dtype) or value_dtype in (torch.complex64, torch.complex128) + return value_dtype in (torch.complex64, torch.complex128) def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"): @@ -95,24 +79,13 @@ def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"): "complex64": torch.complex64, "complex128": torch.complex128, } - if token == "complex32": - if _TORCH_COMPLEX32_DTYPE is not None: - return _TORCH_COMPLEX32_DTYPE, False, None - if dtype_policy == "strict": - raise TypeError("complex32 is unavailable in this torch build") - return torch.complex64, True, "complex32 is unavailable; fallback to complex64" if token not in mapping: raise TypeError(f"Unsupported dtype token: {value_dtype}") value_dtype = mapping[token] - if _is_complex32_dtype(value_dtype): - # If complex32 exists in torch dtype table, keep native path. - return value_dtype, False, None return value_dtype, False, None def _component_dtype_for_complex(value_dtype): - if _is_complex32_dtype(value_dtype): - return torch.float16 if value_dtype == torch.complex64: return torch.float32 if value_dtype == torch.complex128: @@ -123,8 +96,6 @@ def _component_dtype_for_complex(value_dtype): def _tolerance_for_dtype(value_dtype): if value_dtype == torch.float16: return 2e-3, 2e-3 - if _is_complex32_dtype(value_dtype): - return 5e-3, 5e-3 if value_dtype == torch.bfloat16: return 1e-1, 1e-1 if value_dtype in (torch.float32, torch.complex64): @@ -155,9 +126,6 @@ def _cupy_dtype_from_torch(torch_dtype): torch.int32: cp.int32, torch.int64: cp.int64, } - if _TORCH_COMPLEX32_DTYPE is not None: - # CuPy has no native complex32 sparse path; use complex64 for baseline parity. - mapping[_TORCH_COMPLEX32_DTYPE] = cp.complex64 if torch_dtype not in mapping: raise TypeError(f"Unsupported dtype conversion to CuPy: {torch_dtype}") return mapping[torch_dtype] @@ -193,8 +161,6 @@ def _to_backend_like(torch_tensor, ref_obj): def _cusparse_baseline_skip_reason(value_dtype): if value_dtype == torch.bfloat16: return "bfloat16 is not supported by the cuSPARSE baseline path; skipped" - if _is_complex32_dtype(value_dtype): - return "complex32 is not supported by the cuSPARSE baseline path; skipped" if cp is None and value_dtype == torch.float16: return "float16 is not supported by torch sparse fallback when CuPy is unavailable; skipped" return None @@ -203,9 +169,6 @@ def _cusparse_baseline_skip_reason(value_dtype): def _build_random_dense(dense_size, value_dtype, device): if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): return torch.randn(dense_size, dtype=value_dtype, device=device) - if _is_complex32_dtype(value_dtype): - stacked = torch.randn((dense_size, 2), dtype=torch.float16, device=device) - return torch.view_as_complex(stacked) if _is_complex_dtype(value_dtype): component_dtype = _component_dtype_for_complex(value_dtype) real = torch.randn(dense_size, dtype=component_dtype, device=device) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index 6464a37..9c6863f 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -3,6 +3,8 @@ from ._common import * from collections import OrderedDict +from contextlib import nullcontext +from dataclasses import dataclass, field import os import time import triton @@ -46,14 +48,153 @@ def _spsv_env_flag(name, default="0"): SPSV_PROMOTE_TRANSPOSE_COMPLEX64_TO_COMPLEX128 = _spsv_env_flag( "FLAGSPARSE_SPSV_PROMOTE_TRANSPOSE_COMPLEX64_TO_COMPLEX128", "0" ) -SPSV_ENABLE_CSR_CW = _spsv_env_flag("FLAGSPARSE_SPSV_ENABLE_CSR_CW", "0") -SPSV_ENABLE_TRANSPOSE_CW = _spsv_env_flag("FLAGSPARSE_SPSV_ENABLE_TRANSPOSE_CW", "0") -SPSV_ENABLE_LEVEL_FRONTIERS = _spsv_env_flag("FLAGSPARSE_SPSV_ENABLE_LEVEL_FRONTIERS", "0") -SPSV_ENABLE_REVERSE_FRONTIERS = _spsv_env_flag("FLAGSPARSE_SPSV_ENABLE_REVERSE_FRONTIERS", "0") _SPSV_CSR_PREPROCESS_CACHE = OrderedDict() _SPSV_CSR_PREPROCESS_CACHE_SIZE = 8 +@dataclass +class FlagSparseSpSVDescr: + """Host-side analysis handle for Triton SpSV. + + This is the Triton/Python equivalent of the CUDA-side SpSV descriptor: + it stores the analyzed matrix metadata, the selected solve route, and the + workspace layout needed by the current implementation. + """ + + format: str + canonical_format: str + shape: tuple + lower: bool + unit_diagonal: bool + fill_mode: str + diag_type: str + matrix_type: str + index_base: int + transpose_mode: str + value_dtype: torch.dtype + compute_dtype: torch.dtype + index_dtype: torch.dtype + solve_kind: str + route_name: str + storage_view: str + buffer_size: int + workspace_layout: tuple + data: torch.Tensor = field(repr=False) + indices: torch.Tensor = field(repr=False) + indptr: torch.Tensor = field(repr=False) + solve_plan: dict = field(repr=False) + + +@dataclass +class FlagSparseSpSVWorkspace: + """Caller-owned workspace object for Triton SpSV host APIs.""" + + buffer_size: int + layout: tuple + device: torch.device + buffers: dict = field(default_factory=dict, repr=False) + prepared_solve_kind: str = "" + prepared_signature: tuple | None = None + + +@dataclass +class FlagSparseSpSVHandle: + """Host-side execution handle for Triton SpSV.""" + + device: torch.device + stream: object = None + + +@dataclass +class FlagSparseSpMatDescr: + """Sparse matrix descriptor mirroring the CUDA SpMat inputs.""" + + format: str + shape: tuple + values: torch.Tensor = field(repr=False) + indices: torch.Tensor = field(repr=False) + indptr_or_col: torch.Tensor = field(repr=False) + lower: bool = True + unit_diagonal: bool = False + diag_type: str = "non_unit" + fill_mode: str = "lower" + matrix_type: str = "triangular" + index_base: int = 0 + + +@dataclass +class FlagSparseDnVecDescr: + """Dense vector descriptor mirroring the CUDA DnVec inputs.""" + + values: torch.Tensor = field(repr=False) + + +def flagsparse_create_spsv_handle(device=None, stream=None): + if device is None: + device = torch.device("cuda") + return FlagSparseSpSVHandle(device=torch.device(device), stream=stream) + + +def flagsparse_create_dnvec(values): + if not torch.is_tensor(values): + raise TypeError("values must be a torch.Tensor") + if values.ndim != 1: + raise ValueError("DnVec values must be 1D") + return FlagSparseDnVecDescr(values=values) + + +def flagsparse_create_spmat_csr( + values, + indices, + indptr, + shape, + *, + lower=True, + unit_diagonal=False, + matrix_type="triangular", + index_base=0, +): + return FlagSparseSpMatDescr( + format="csr", + shape=(int(shape[0]), int(shape[1])), + values=values, + indices=indices, + indptr_or_col=indptr, + lower=bool(lower), + unit_diagonal=bool(unit_diagonal), + diag_type="unit" if unit_diagonal else "non_unit", + fill_mode="lower" if lower else "upper", + matrix_type=str(matrix_type), + index_base=int(index_base), + ) + + +def flagsparse_create_spmat_coo( + values, + row, + col, + shape, + *, + lower=True, + unit_diagonal=False, + matrix_type="triangular", + index_base=0, +): + return FlagSparseSpMatDescr( + format="coo", + shape=(int(shape[0]), int(shape[1])), + values=values, + indices=row, + indptr_or_col=col, + lower=bool(lower), + unit_diagonal=bool(unit_diagonal), + diag_type="unit" if unit_diagonal else "non_unit", + fill_mode="lower" if lower else "upper", + matrix_type=str(matrix_type), + index_base=int(index_base), + ) + + def _clear_spsv_csr_preprocess_cache(): _SPSV_CSR_PREPROCESS_CACHE.clear() @@ -78,12 +219,6 @@ def _attach_spsv_complex_plan_views(plan): if kernel_data is None or not torch.is_complex(kernel_data): return plan plan["kernel_data_ri"] = _complex_interleaved_view(kernel_data) - transpose_diag = plan.get("transpose_diag") - if transpose_diag is not None and torch.is_complex(transpose_diag): - plan["transpose_diag_ri"] = _complex_interleaved_view(transpose_diag) - cw_diag = plan.get("cw_diag") - if cw_diag is not None and torch.is_complex(cw_diag): - plan["cw_diag_ri"] = _complex_interleaved_view(cw_diag) return plan @@ -132,18 +267,16 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): raise ValueError("data, indices, indptr, b must all be CUDA tensors") if data.ndim != 1 or indices.ndim != 1 or indptr.ndim != 1: raise ValueError("data, indices, indptr must be 1D") - if b.ndim not in (1, 2): - raise ValueError("b must be 1D or 2D (vector or multiple RHS)") + if b.ndim != 1: + raise ValueError("b must be a 1D dense vector (DnVec)") n_rows, n_cols = int(shape[0]), int(shape[1]) if indptr.numel() != n_rows + 1: raise ValueError(f"indptr length must be n_rows+1={n_rows + 1}") if data.numel() != indices.numel(): raise ValueError("data and indices must have the same length (nnz)") - if b.ndim == 1 and b.numel() != n_rows: + if b.numel() != n_rows: raise ValueError(f"b length must equal n_rows={n_rows}") - if b.ndim == 2 and b.shape[0] != n_rows: - raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise TypeError( @@ -219,12 +352,251 @@ def _spsv_cache_put(cache, key, value, max_entries): cache.popitem(last=False) -def _csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode, unit_diagonal): +def _normalize_spsv_format(fmt): + token = str(fmt).strip().lower() + if token not in ("csr", "coo"): + raise ValueError("format must be 'csr' or 'coo'") + return token + + +def _normalize_spsv_storage_view(storage_view): + if storage_view is None: + return "csr_as_csc" + token = str(storage_view).strip().lower() + aliases = { + "csr_as_csc": "csr_as_csc", + "csc_view": "csr_as_csc", + "reuse_csr_storage": "csr_as_csc", + } + if token not in aliases: + raise ValueError( + "storage_view must be one of: csr_as_csc, csc_view, reuse_csr_storage" + ) + return aliases[token] + + +def _resolve_spsv_stream(handle, stream, device): + resolved = stream + if handle is not None: + if not isinstance(handle, FlagSparseSpSVHandle): + raise TypeError("handle must be a FlagSparseSpSVHandle or None") + if torch.device(handle.device) != torch.device(device): + raise ValueError("handle device must match the solve device") + if resolved is None: + resolved = handle.stream + return resolved + + +def _coerce_spsv_alpha(alpha, dtype, device): + if torch.is_tensor(alpha): + alpha_tensor = alpha.to(device=device, dtype=dtype).reshape(-1) + if alpha_tensor.numel() != 1: + raise ValueError("alpha must be a scalar tensor") + return alpha_tensor.reshape(()) + return torch.tensor(alpha, device=device, dtype=dtype) + + +def _workspace_entry(name, numel, dtype): + return { + "name": str(name), + "numel": int(numel), + "dtype": dtype, + "bytes": int(numel) * int(torch.empty((), dtype=dtype).element_size()), + } + + +def _workspace_size_bytes(layout): + return int(sum(int(entry["bytes"]) for entry in layout)) + + +def _spsv_effective_compute_dtype(value_dtype, trans_mode, compute_dtype=None): + if compute_dtype is not None: + if compute_dtype not in SUPPORTED_SPSV_VALUE_DTYPES: + raise TypeError( + "compute_dtype must be one of: float32, float64, complex64, complex128" + ) + return compute_dtype + if ( + value_dtype == torch.complex64 + and trans_mode in ("T", "C") + and SPSV_PROMOTE_TRANSPOSE_COMPLEX64_TO_COMPLEX128 + ): + return torch.complex128 + if value_dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: + return torch.float64 + if ( + value_dtype == torch.float32 + and trans_mode in ("T", "C") + and SPSV_PROMOTE_TRANSPOSE_FP32_TO_FP64 + ): + return torch.float64 + return value_dtype + + +def _build_spsv_workspace_layout(n_rows, solve_kind, value_dtype=None): + n_rows = int(n_rows) + if solve_kind == "csr_cw": + return ( + _workspace_entry("ready", n_rows, torch.int32), + _workspace_entry("row_counter", 1, torch.int32), + ) + if solve_kind == "csr_roc": + return ( + _workspace_entry("ready", n_rows, torch.int32), + ) + if solve_kind == "csr_cw_levelschd": + return ( + _workspace_entry("ready", n_rows, torch.int32), + ) + if solve_kind == "csr_nnz_balance": + if value_dtype is None: + raise ValueError("value_dtype is required for csr_nnz_balance workspace sizing") + return ( + _workspace_entry("tmp_sum", n_rows, value_dtype), + _workspace_entry("ready", n_rows, torch.int32), + _workspace_entry("indegree", n_rows, torch.int32), + ) + if solve_kind == "transpose_cw": + if value_dtype is None: + raise ValueError("value_dtype is required for transpose_cw workspace sizing") + return ( + _workspace_entry("residual", n_rows, value_dtype), + _workspace_entry("indegree", n_rows, torch.int32), + _workspace_entry("row_counter", 1, torch.int32), + ) + raise ValueError(f"Unsupported SpSV solve kind for workspace sizing: {solve_kind}") + + +def _clone_spsv_plan(plan): + cloned = dict(plan) + matrix_stats = plan.get("matrix_stats") + if matrix_stats is not None: + cloned["matrix_stats"] = dict(matrix_stats) + return cloned + + +def _alloc_spsv_workspace_buffers(layout, device): + buffers = {} + for entry in layout: + buffers[entry["name"]] = torch.empty( + int(entry["numel"]), dtype=entry["dtype"], device=device + ) + return buffers + + +def _resolve_spsv_workspace(workspace, layout, device): + if workspace is None: + return _alloc_spsv_workspace_buffers(layout, device) + if not isinstance(workspace, FlagSparseSpSVWorkspace): + raise TypeError("workspace must be a FlagSparseSpSVWorkspace or None") + if torch.device(workspace.device) != torch.device(device): + raise ValueError("workspace device must match the solve device") + if int(workspace.buffer_size) < _workspace_size_bytes(layout): + raise ValueError("workspace buffer is smaller than the required SpSV size") + required = {entry["name"]: entry for entry in layout} + for name, entry in required.items(): + buf = workspace.buffers.get(name) + if buf is None: + workspace.buffers[name] = torch.empty( + int(entry["numel"]), dtype=entry["dtype"], device=device + ) + continue + if buf.device != torch.device(device): + raise ValueError(f"workspace buffer {name!r} is on the wrong device") + if buf.dtype != entry["dtype"] or int(buf.numel()) < int(entry["numel"]): + workspace.buffers[name] = torch.empty( + int(entry["numel"]), dtype=entry["dtype"], device=device + ) + return workspace.buffers + + +def _transpose_cw_preprocess_signature( + solve_plan, n_rows, unit_diagonal, block_nnz_use, max_segments_use +): + kernel_indices32 = solve_plan["kernel_indices32"] + kernel_indptr64 = solve_plan["kernel_indptr64"] + return ( + "transpose_cw", + int(n_rows), + bool(solve_plan["lower_eff"]), + bool(unit_diagonal), + int(block_nnz_use), + int(max_segments_use), + _tensor_cache_token(kernel_indices32), + _tensor_cache_token(kernel_indptr64), + ) + + +def flagsparse_spsv_buffer_size( + shape, + value_dtype, + *, + format="csr", + transpose=False, + solve_kind=None, + compute_dtype=None, + alpha=None, + handle=None, + vecX=None, + vecY=None, + storage_view="csr_as_csc", +): + """Return the caller-managed workspace size for the current Triton SpSV route. + + This is the Triton host-side equivalent of the CUDA bufferSize query. + The returned byte count matches the scratch buffers used by the current + Triton implementation, rather than the raw CUDA ABI layout. + """ + + fmt = _normalize_spsv_format(format) + n_rows, n_cols = int(shape[0]), int(shape[1]) + if n_rows != n_cols: + raise ValueError(f"SpSV expects a square matrix, got shape={shape}") + if value_dtype not in SUPPORTED_SPSV_VALUE_DTYPES: + raise TypeError( + "value_dtype must be one of: float32, float64, complex64, complex128" + ) + trans_mode = _normalize_spsv_transpose_mode(transpose) + storage_view = _normalize_spsv_storage_view(storage_view) + compute_dtype = _spsv_effective_compute_dtype( + value_dtype, trans_mode, compute_dtype=compute_dtype + ) + route = _normalize_requested_spsv_route(solve_kind, trans_mode) + if route is None: + route = "transpose_cw" if trans_mode in ("T", "C") else "csr_cw" + if trans_mode in ("T", "C") and storage_view != "csr_as_csc": + raise ValueError("TRANS/CONJ SpSV only supports storage_view='csr_as_csc'") + layout = _build_spsv_workspace_layout(n_rows, route, value_dtype=compute_dtype) + return _workspace_size_bytes(layout) + + +def flagsparse_spsv_create_workspace(descr, device=None): + """Allocate a caller-owned SpSV workspace object from an analysis descriptor.""" + + if not isinstance(descr, FlagSparseSpSVDescr): + raise TypeError("descr must be a FlagSparseSpSVDescr") + if device is None: + device = descr.data.device + device = torch.device(device) + buffers = _alloc_spsv_workspace_buffers(descr.workspace_layout, device) + return FlagSparseSpSVWorkspace( + buffer_size=int(descr.buffer_size), + layout=tuple(descr.workspace_layout), + device=device, + buffers=buffers, + ) + + +def _csr_preprocess_cache_key( + data, indices, indptr, shape, lower, trans_mode, unit_diagonal, requested_route=None, storage_view="csr_as_csc" +): return ( "csr_preprocess", trans_mode, bool(lower), bool(unit_diagonal), + str(requested_route), + str(storage_view), int(shape[0]), int(shape[1]), _tensor_cache_token(data), @@ -233,121 +605,65 @@ def _csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode, u ) -def _build_spsv_frontiers(indptr, indices, levels, lower=True): - """Greedily merge rows from adjacent levels when they do not depend on the - currently active frontier. - - This keeps the same correctness contract as strict level scheduling while - trimming some kernel launches on matrices with narrow but not fully - serialized dependency wavefronts. - """ - if not levels: - return [] - - indptr_h = indptr.to(torch.int64).cpu() - indices_h = indices.to(torch.int64).cpu() - device = indptr.device - frontier_rows = [] - frontier_row_set = set() - merged = [] - - def _flush_frontier(): - nonlocal frontier_rows, frontier_row_set - if frontier_rows: - merged.append(torch.tensor(frontier_rows, dtype=torch.int32, device=device)) - frontier_rows = [] - frontier_row_set = set() - - for rows_lv in levels: - for row in rows_lv.to(torch.int64).cpu().tolist(): - start = int(indptr_h[row].item()) - end = int(indptr_h[row + 1].item()) - depends_on_frontier = False - for p in range(start, end): - col = int(indices_h[p].item()) - if lower: - is_dep = col < row - else: - is_dep = col > row - if is_dep and col in frontier_row_set: - depends_on_frontier = True - break - if depends_on_frontier: - _flush_frontier() - frontier_rows.append(int(row)) - frontier_row_set.add(int(row)) - _flush_frontier() - return merged - +def _normalize_requested_spsv_route(solve_kind, trans_mode): + if solve_kind is None: + return None + token = str(solve_kind).strip().lower() + aliases = { + "csr_cw": "csr_cw", + "csr_roc": "csr_roc", + "roc": "csr_roc", + "alg3": "csr_nnz_balance", + "csr_levelschd": "csr_cw_levelschd", + "csr_cw_levelschd": "csr_cw_levelschd", + "levelschd": "csr_cw_levelschd", + "level_sched": "csr_cw_levelschd", + "alg2": "csr_cw_levelschd", + "csr_nnz_balance": "csr_nnz_balance", + "nnz_balance": "csr_nnz_balance", + "alg8": "csr_nnz_balance", + "cw": "csr_cw" if trans_mode == "N" else "transpose_cw", + "transpose_cw": "transpose_cw", + "csc_cw": "transpose_cw", + } + route = aliases.get(token) + if route is None: + raise ValueError( + "solve_kind must be one of: csr_cw, csr_roc, csr_cw_levelschd, csr_nnz_balance, transpose_cw" + ) + if trans_mode in ("T", "C") and route != "transpose_cw": + raise ValueError("TRANS/CONJ SpSV only supports solve_kind='transpose_cw'") + if trans_mode == "N" and route == "transpose_cw": + raise ValueError("NON_TRANS SpSV cannot use solve_kind='transpose_cw'") + return route -def _build_spsv_reverse_frontiers(indptr, indices, levels, lower=True): - """Greedily merge reverse-topological launch groups for transpose push. - Rows processed by the currently active reverse frontier push residual - updates into their dependency targets. A candidate row can be merged only - when no active row would update it; otherwise it must be delayed until the - current frontier completes. - """ - if not levels: - return [] - - indptr_h = indptr.to(torch.int64).cpu() - indices_h = indices.to(torch.int64).cpu() - device = indptr.device - dependency_targets = {} - for rows_lv in levels: - for row in rows_lv.to(torch.int64).cpu().tolist(): - start = int(indptr_h[row].item()) - end = int(indptr_h[row + 1].item()) - targets = set() - for p in range(start, end): - col = int(indices_h[p].item()) - is_dep = (col < row) if lower else (col > row) - if is_dep: - targets.add(col) - dependency_targets[int(row)] = targets - - frontier_rows = [] - frontier_targets = set() - merged = [] - - def _flush_frontier(): - nonlocal frontier_rows, frontier_targets - if frontier_rows: - merged.append(torch.tensor(frontier_rows, dtype=torch.int32, device=device)) - frontier_rows = [] - frontier_targets = set() - - for rows_lv in reversed(levels): - for row in rows_lv.to(torch.int64).cpu().tolist(): - if int(row) in frontier_targets: - _flush_frontier() - frontier_rows.append(int(row)) - frontier_targets.update(dependency_targets.get(int(row), ())) - _flush_frontier() - return merged - - -def _prepare_spsv_transpose_cw_metadata(data, indices64, indptr64, n_rows, lower, unit_diagonal=False): - indegree = torch.zeros(n_rows, dtype=torch.int32, device=data.device) - diag = torch.ones(n_rows, dtype=data.dtype, device=data.device) - if n_rows == 0: - return diag, indegree - row_ids = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), - indptr64[1:] - indptr64[:-1], - ) - dep_mask = indices64 < row_ids if lower else indices64 > row_ids - if dep_mask.numel() > 0: - dep_rows = row_ids[dep_mask] - dep_counts = torch.bincount(dep_rows, minlength=n_rows) - indegree.copy_(dep_counts.to(torch.int32)) - if not unit_diagonal: - indegree.add_(1) - diag_mask = indices64 == row_ids - if bool(torch.any(diag_mask).item()): - diag.scatter_(0, row_ids[diag_mask], data[diag_mask]) - return diag, indegree +@triton.jit +def _spsv_csc_preprocess_kernel( + indices_ptr, + indptr_ptr, + indegree_ptr, + n_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, +): + col = tl.program_id(0) + if col >= n_rows: + return + start = tl.load(indptr_ptr + col) + end = tl.load(indptr_ptr + col + 1) + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + row = tl.load(indices_ptr + offsets, mask=mask, other=0) + if LOWER: + dep_mask = mask & (row > col if UNIT_DIAG else row >= col) + else: + dep_mask = mask & (row < col if UNIT_DIAG else row <= col) + tl.atomic_add(indegree_ptr + row, 1, mask=dep_mask) def _sort_csr_rows(data, indices64, indptr64, n_rows, n_cols, lower=True): @@ -369,22 +685,6 @@ def _sort_csr_rows(data, indices64, indptr64, n_rows, n_cols, lower=True): return data[order], indices64[order], indptr64 -def _prepare_spsv_nontrans_cw_metadata(data, indices64, indptr64, n_rows, unit_diagonal=False): - diag = torch.ones(n_rows, dtype=data.dtype, device=data.device) - if n_rows == 0: - return diag - if unit_diagonal: - return diag - row_ids = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), - indptr64[1:] - indptr64[:-1], - ) - diag_mask = indices64 == row_ids - if bool(torch.any(diag_mask).item()): - diag.scatter_(0, row_ids[diag_mask], data[diag_mask]) - return diag - - def _cw_rhs_bucket(n_rhs): if n_rhs <= 1: return 1 @@ -416,54 +716,129 @@ def _cw_worker_count(n_rows, max_frontier, avg_nnz_per_row, n_rhs): if n_rows <= 0: return 1 rhs_bucket = _cw_rhs_bucket(n_rhs) - target = max(256, min(n_rows, 4096)) + if rhs_bucket == 1: + target = min(n_rows, 32) + else: + target = max(32, min(n_rows, 512)) if max_frontier > 0: - target = min(target, max(128, min(n_rows, max_frontier * 8))) - if avg_nnz_per_row > 2048: - target = max(128, target // 2) + target = min(target, max(4, min(n_rows, max_frontier * 2))) + if avg_nnz_per_row > 8192: + target = max(4, target // 8) + elif avg_nnz_per_row > 4096: + target = max(4, target // 4) + elif avg_nnz_per_row > 2048: + target = max(4, target // 2) + elif avg_nnz_per_row > 1024: + target = max(8, (target * 2) // 3) + elif avg_nnz_per_row > 512: + target = max(8, (target * 3) // 5) if rhs_bucket >= 16: - target = max(64, target // 4) + target = max(4, target // 4) elif rhs_bucket >= 8: - target = max(64, target // 2) + target = max(4, target // 2) elif rhs_bucket >= 4: - target = max(128, (target * 3) // 4) + target = max(4, (target * 3) // 4) return _snap_cw_worker_count(target, n_rows) + num_re = rhs_re - acc_re + num_im = rhs_im - acc_im + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) def _resolve_cw_worker_count(n_rows, matrix_stats, n_rhs, cached_worker_count=None): rhs_bucket = _cw_rhs_bucket(n_rhs) + max_frontier = int(matrix_stats.get("max_frontier", n_rows)) + avg_frontier = float(matrix_stats.get("avg_frontier", float(max_frontier))) + frontier_ratio = float(matrix_stats.get("frontier_ratio", 1.0 if n_rows > 0 else 0.0)) + num_levels = int(matrix_stats.get("num_levels", 0)) + avg_nnz_per_row = float(matrix_stats.get("avg_nnz_per_row", 0.0)) if cached_worker_count is not None and rhs_bucket == 1: - return int(max(1, min(int(cached_worker_count), int(max(n_rows, 1))))) - return _cw_worker_count( + target = int(max(1, min(int(cached_worker_count), int(max(n_rows, 1))))) + else: + target = _cw_worker_count( + n_rows, + max_frontier, + avg_nnz_per_row, + rhs_bucket, + ) + if frontier_ratio < 0.01 or avg_frontier < 4.0: + target = min(target, max(1, min(n_rows, 4))) + elif frontier_ratio < 0.02 or avg_frontier < 8.0: + target = min(target, max(1, min(n_rows, 8))) + elif frontier_ratio < 0.05 or avg_frontier < 16.0: + target = min(target, max(2, min(n_rows, 16))) + if num_levels > max(1024, n_rows // 2): + target = max(1, target // 2) + if avg_nnz_per_row > 2048: + target = max(1, target // 2) + return _snap_cw_worker_count( + target, n_rows, - int(matrix_stats.get("max_frontier", n_rows)), - float(matrix_stats.get("avg_nnz_per_row", 0.0)), - rhs_bucket, ) +@triton.jit +def _spsv_csr_cw_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + diag_ptr, + b_ptr, + x_ptr, + ready_ptr, + row_counter_ptr, + n_rows, + n_rhs, + stride_b0, + stride_x0, + BLOCK_RHS: tl.constexpr, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + row = tl.atomic_add(row_counter_ptr, 1) + while row < n_rows: + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + diag = tl.load(diag_ptr + row) + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + + for rhs_base in range(0, n_rhs, BLOCK_RHS): + rhs_offsets = rhs_base + tl.arange(0, BLOCK_RHS) + rhs_mask = rhs_offsets < n_rhs + acc = tl.zeros((BLOCK_RHS,), dtype=tl.float32) + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + nnz_offsets = idx + tl.arange(0, BLOCK_NNZ) + nnz_mask = nnz_offsets < end + a = tl.load(data_ptr + nnz_offsets, mask=nnz_mask, other=0.0) + col = tl.load(indices_ptr + nnz_offsets, mask=nnz_mask, other=0) + if LOWER: + dep_mask = nnz_mask & (col < row) + else: + dep_mask = nnz_mask & (col > row) -def _spsv_level_stats(levels, n_rows): - if not levels: - return { - "num_levels": 0, - "max_frontier": 0, - "avg_frontier": 0.0, - "frontier_ratio": 0.0, - } - num_levels = len(levels) - max_frontier = max(int(rows.numel()) for rows in levels) - avg_frontier = float(n_rows) / float(max(num_levels, 1)) - frontier_ratio = float(max_frontier) / float(max(n_rows, 1)) - return { - "num_levels": num_levels, - "max_frontier": max_frontier, - "avg_frontier": avg_frontier, - "frontier_ratio": frontier_ratio, - } + for k in range(BLOCK_NNZ): + if dep_mask[k]: + dep_col = col[k] + while tl.load(ready_ptr + dep_col) == 0: + pass + x_ptrs = x_ptr + dep_col * stride_x0 + rhs_offsets + x_vals = tl.load(x_ptrs, mask=rhs_mask, other=0.0) + acc += a[k] * x_vals + + rhs_ptrs = b_ptr + row * stride_b0 + rhs_offsets + rhs = tl.load(rhs_ptrs, mask=rhs_mask, other=0.0) + x_row = (rhs - acc) / diag_safe + x_row = tl.where(x_row == x_row, x_row, 0.0) + out_ptrs = x_ptr + row * stride_x0 + rhs_offsets + tl.store(out_ptrs, x_row, mask=rhs_mask) + tl.debug_barrier() + tl.store(ready_ptr + row, 1) + row = tl.atomic_add(row_counter_ptr, 1) -def _build_spsv_matrix_stats(indptr64, levels, n_rows): - stats = _spsv_level_stats(levels, n_rows) +def _build_spsv_cw_matrix_stats(indptr64, n_rows): if indptr64.numel() <= 1: avg_nnz_per_row = 0.0 max_nnz_per_row = 0 @@ -471,288 +846,564 @@ def _build_spsv_matrix_stats(indptr64, levels, n_rows): row_lengths = indptr64[1:] - indptr64[:-1] avg_nnz_per_row = float(row_lengths.to(torch.float32).mean().item()) max_nnz_per_row = int(row_lengths.max().item()) - stats["avg_nnz_per_row"] = avg_nnz_per_row - stats["max_nnz_per_row"] = max_nnz_per_row - stats["n_rows"] = int(n_rows) - return stats + return { + "num_levels": 0, + "max_frontier": int(n_rows), + "avg_frontier": float(n_rows), + "frontier_ratio": 1.0 if n_rows > 0 else 0.0, + "avg_nnz_per_row": avg_nnz_per_row, + "max_nnz_per_row": max_nnz_per_row, + "n_rows": int(n_rows), + } -def _choose_spsv_block_rhs(n_rhs, matrix_stats, complex_mode=False): - avg_nnz = float(matrix_stats.get("avg_nnz_per_row", 0.0)) - max_nnz = int(matrix_stats.get("max_nnz_per_row", 0)) - frontier_ratio = float(matrix_stats.get("frontier_ratio", 0.0)) - if n_rhs <= 1: - return 1 - if complex_mode: - if avg_nnz > 512 or max_nnz > 4096: - return 4 if n_rhs >= 4 else n_rhs - return 8 if n_rhs >= 8 else n_rhs - if avg_nnz > 2048 or max_nnz > 16384: - return 4 if n_rhs >= 4 else n_rhs - if avg_nnz > 512 or frontier_ratio < 0.02: - return 8 if n_rhs >= 8 else n_rhs - if avg_nnz > 128: - return 16 if n_rhs >= 16 else n_rhs - if n_rhs <= 8: - return n_rhs - if n_rhs <= 16: - return 16 - if n_rhs <= 32: - return 32 - return 64 - - -def _score_nontrans_levels(matrix_stats, n_rhs, complex_mode): - score = 0.0 - score += float(matrix_stats["frontier_ratio"]) * 10.0 - score += min(float(matrix_stats["avg_frontier"]) / 64.0, 6.0) - score += min(float(matrix_stats["avg_nnz_per_row"]) / 256.0, 4.0) - if n_rhs >= 8: - score += 2.5 - if complex_mode: - score += 2.0 - return score - - -def _score_nontrans_cw(matrix_stats, n_rhs, complex_mode): - score = 0.0 - if matrix_stats["num_levels"] > 1024: - score += 2.0 - if matrix_stats["num_levels"] > 4096: - score += 2.0 - if matrix_stats["frontier_ratio"] < 0.03: - score += 4.0 - elif matrix_stats["frontier_ratio"] < 0.06: - score += 2.0 - if matrix_stats["avg_frontier"] < 16.0: - score += 3.0 - elif matrix_stats["avg_frontier"] < 32.0: - score += 1.5 - if matrix_stats["avg_nnz_per_row"] <= 128.0: - score += 1.5 - elif matrix_stats["avg_nnz_per_row"] <= 160.0: - score += 0.5 - if matrix_stats["max_nnz_per_row"] > 4096: - score -= 2.5 - elif matrix_stats["max_nnz_per_row"] > 1536: - score -= 1.5 - if n_rhs >= 8: - score -= 3.0 - elif n_rhs >= 4: - score -= 1.5 - if complex_mode: - score -= 2.0 - return score - - -def _score_transpose_push(matrix_stats, n_rhs, complex_mode): - score = 0.0 - score += float(matrix_stats["frontier_ratio"]) * 12.0 - score += min(float(matrix_stats["avg_frontier"]) / 48.0, 6.0) - score += min(float(matrix_stats["avg_nnz_per_row"]) / 256.0, 4.0) - if n_rhs >= 4: - score += 2.0 - if complex_mode: - score += 2.5 - return score - - -def _score_transpose_cw(matrix_stats, n_rhs, complex_mode): - score = 0.0 - if matrix_stats["num_levels"] > 2048: - score += 2.0 - if matrix_stats["num_levels"] > 8192: - score += 2.0 - if matrix_stats["frontier_ratio"] < 0.02: - score += 4.0 - elif matrix_stats["frontier_ratio"] < 0.05: - score += 2.0 - if matrix_stats["avg_frontier"] < 12.0: - score += 3.0 - elif matrix_stats["avg_frontier"] < 24.0: - score += 1.5 - if matrix_stats["avg_nnz_per_row"] <= 96.0: - score += 1.0 - if matrix_stats["max_nnz_per_row"] > 2048: - score -= 2.5 - if n_rhs >= 4: - score -= 2.5 - elif n_rhs >= 2: - score -= 1.0 - if complex_mode: - score -= 2.5 - return score - - -def _nontrans_cw_eligible(matrix_stats, n_rhs, complex_mode): - if complex_mode: - return False - if n_rhs != 1: - return False - if int(matrix_stats.get("num_levels", 0)) < 1024: - return False - if float(matrix_stats.get("avg_frontier", 0.0)) > 16.0: - return False - if float(matrix_stats.get("frontier_ratio", 1.0)) > 0.05: +def _spsv_nontrans_prefers_nnz_balance(n_rows, matrix_stats, *, lower, unit_diagonal, value_dtype): + if not _supports_spsv_advanced_nontrans_routes("N", lower, unit_diagonal, value_dtype): return False - avg_nnz = float(matrix_stats.get("avg_nnz_per_row", 0.0)) - max_nnz = int(matrix_stats.get("max_nnz_per_row", 0)) - if avg_nnz <= 0.0 or avg_nnz > 160.0: + if int(n_rows) < 256: return False - if max_nnz > 1536: - return False - return True + avg_nnz_per_row = float(matrix_stats.get("avg_nnz_per_row", 0.0)) + max_nnz_per_row = int(matrix_stats.get("max_nnz_per_row", 0)) + return max_nnz_per_row >= 512 or avg_nnz_per_row >= 96.0 -def _transpose_cw_eligible(matrix_stats, n_rhs, complex_mode, trans_mode): - if trans_mode != "T": - return False - if complex_mode: - return False - if n_rhs != 1: - return False - if int(matrix_stats.get("num_levels", 0)) < 512: - return False - if float(matrix_stats.get("avg_frontier", 0.0)) > 16.0: - return False - if float(matrix_stats.get("frontier_ratio", 1.0)) > 0.04: - return False - avg_nnz = float(matrix_stats.get("avg_nnz_per_row", 0.0)) - max_nnz = int(matrix_stats.get("max_nnz_per_row", 0)) - if avg_nnz <= 0.0 or avg_nnz > 160.0: - return False - if max_nnz > 1024: - return False - return True +def _supports_spsv_advanced_nontrans_routes(trans_mode, lower, unit_diagonal, value_dtype): + return ( + trans_mode == "N" + and bool(lower) + and (not bool(unit_diagonal)) + and value_dtype in (torch.float32, torch.float64) + ) -def _select_nontrans_route(matrix_stats, n_rhs, complex_mode): - if not _nontrans_cw_eligible(matrix_stats, n_rhs, complex_mode): - return "csr_levels" - levels_score = _score_nontrans_levels(matrix_stats, n_rhs, complex_mode) - cw_score = _score_nontrans_cw(matrix_stats, n_rhs, complex_mode) - return "csr_cw" if cw_score >= (levels_score + 2.0) else "csr_levels" +@triton.jit +def _spsv_levelschd_analysis_kernel( + indices_ptr, + indptr_ptr, + levels_ptr, + ready_ptr, + indegree_ptr, + n_rows, + BLOCK_ROWS: tl.constexpr, + UNIT_DIAGONAL: tl.constexpr, +): + first_row = tl.program_id(0) * BLOCK_ROWS + local_rows = tl.arange(0, BLOCK_ROWS) + local_levels = tl.zeros((BLOCK_ROWS,), dtype=tl.int32) + for local_row in range(BLOCK_ROWS): + row = first_row + local_row + if row < n_rows: + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + ptr = start + max_level = tl.zeros((), dtype=tl.int32) + degree = tl.zeros((), dtype=tl.int32) + row_done = 0 + while row_done == 0: + if ptr >= end: + row_done = 1 + else: + col = tl.load(indices_ptr + ptr) + if col < first_row: + dep_ready = _load_ready_flag_i32(ready_ptr, col) + while dep_ready == 0: + dep_ready = _load_ready_flag_i32(ready_ptr, col) + dep_level = tl.atomic_add(levels_ptr + col, 0) + max_level = tl.maximum(max_level, dep_level) + degree += 1 + ptr += 1 + elif col < row: + local_idx = col - first_row + dep_level = tl.sum( + tl.where(local_rows == local_idx, local_levels, 0), + axis=0, + ) + max_level = tl.maximum(max_level, dep_level) + degree += 1 + ptr += 1 + else: + if (not UNIT_DIAGONAL) and (col == row): + degree += 1 + row_done = 1 + row_level = max_level + 1 + _publish_i32_once(levels_ptr, row, row_level) + tl.store(indegree_ptr + row, degree) + local_levels = tl.where(local_rows == local_row, row_level, local_levels) + _publish_ready_flag_i32(ready_ptr, row) + + +def _build_spsv_level_schedule_metadata_lower_gpu(indices64, indptr64, n_rows, *, unit_diagonal): + n_rows = int(n_rows) + device = indices64.device + base_stats = _build_spsv_cw_matrix_stats(indptr64, n_rows) + empty_meta = { + "row_map32": torch.empty(0, dtype=torch.int32, device=device), + "level_ptr32": torch.zeros(1, dtype=torch.int32, device=device), + "indegree_init32": torch.empty(0, dtype=torch.int32, device=device), + "csr_row_idx32": torch.empty(0, dtype=torch.int32, device=device), + "matrix_stats": { + **base_stats, + "num_levels": 0, + "max_frontier": 0, + "avg_frontier": 0.0, + "frontier_ratio": 0.0, + }, + } + if n_rows == 0: + return empty_meta + + indices32 = indices64.to(torch.int32).contiguous() + levels32 = torch.zeros(n_rows, dtype=torch.int32, device=device) + ready32 = torch.zeros(n_rows, dtype=torch.int32, device=device) + indegree32 = torch.empty(n_rows, dtype=torch.int32, device=device) + _spsv_levelschd_analysis_kernel[(triton.cdiv(n_rows, 8),)]( + indices32, + indptr64, + levels32, + ready32, + indegree32, + n_rows, + BLOCK_ROWS=8, + UNIT_DIAGONAL=bool(unit_diagonal), + num_warps=1, + ) + # Stable GPU sort reproduces the row_map stage after roc-style level analysis. -def _select_transpose_route(matrix_stats, n_rhs, complex_mode, trans_mode): - if not _transpose_cw_eligible(matrix_stats, n_rhs, complex_mode, trans_mode): - return "transpose_push" - push_score = _score_transpose_push(matrix_stats, n_rhs, complex_mode) - cw_score = _score_transpose_cw(matrix_stats, n_rhs, complex_mode) - return "transpose_cw" if cw_score >= (push_score + 1.0) else "transpose_push" + try: + row_map64 = torch.argsort(levels32.to(torch.int64), stable=True) + except TypeError: + row_map64 = torch.argsort(levels32.to(torch.int64)) + row_map32 = row_map64.to(torch.int32).contiguous() + sorted_levels32 = levels32.index_select(0, row_map64) + if sorted_levels32.numel() > 0: + _, frontier_counts64 = torch.unique_consecutive(sorted_levels32, return_counts=True) + level_ptr32 = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(frontier_counts64.to(torch.int32), dim=0), + ] + ) + num_levels = int(frontier_counts64.numel()) + max_frontier = int(frontier_counts64.max().item()) + avg_frontier = float(frontier_counts64.to(torch.float32).mean().item()) + else: + level_ptr32 = torch.zeros(1, dtype=torch.int32, device=device) + num_levels = 0 + max_frontier = 0 + avg_frontier = 0.0 + + row_lengths64 = indptr64[1:] - indptr64[:-1] + csr_row_idx32 = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int32), + row_lengths64.to(torch.int64), + ).contiguous() + matrix_stats = { + **base_stats, + "num_levels": int(num_levels), + "max_frontier": int(max_frontier), + "avg_frontier": float(avg_frontier), + "frontier_ratio": (float(max_frontier) / float(n_rows)) if n_rows > 0 else 0.0, + } + return { + "row_map32": row_map32, + "level_ptr32": level_ptr32, + "indegree_init32": indegree32, + "csr_row_idx32": csr_row_idx32, + "matrix_stats": matrix_stats, + } -def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, trans_mode, unit_diagonal): - if trans_mode == "N": - levels = _build_spsv_levels(indptr64, indices64, n_rows, lower=lower) - matrix_stats = _build_spsv_matrix_stats(indptr64, levels, n_rows) - default_block_nnz, default_max_segments = _auto_spsv_launch_config(indptr64) - levels_plan = { - "solve_kind": "csr_levels", - "kernel_data": data, - "kernel_indices32": indices64.to(torch.int32), - "kernel_indptr64": indptr64, - "lower_eff": lower, - "launch_groups": ( - _build_spsv_frontiers(indptr64, indices64, levels, lower=lower) - if SPSV_ENABLE_LEVEL_FRONTIERS - else levels - ), - "default_block_nnz": default_block_nnz, - "default_max_segments": default_max_segments, - "transpose_conjugate": False, - "cw_worker_count": None, - "matrix_stats": matrix_stats, - "route_name": "csr_levels", - "alt_plan": None, +@triton.jit +def _spsv_nnz_balance_preprocess_kernel( + indices_ptr, + indptr_ptr, + indegree_ptr, + row_idx_ptr, + n_rows, + WARP_SIZE: tl.constexpr, + UNIT_DIAGONAL: tl.constexpr, +): + row = tl.program_id(0) + if row >= n_rows: + return + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + lane = tl.arange(0, WARP_SIZE) + ptr = start + lane + degree = tl.zeros((WARP_SIZE,), dtype=tl.int32) + active = ptr < end + while tl.sum(active.to(tl.int32), axis=0) > 0: + cols = tl.load(indices_ptr + ptr, mask=active, other=row + 1) + if UNIT_DIAGONAL: + valid = active & (cols < row) + else: + valid = active & (cols <= row) + tl.store(row_idx_ptr + ptr, row, mask=valid) + degree += valid.to(tl.int32) + ptr = ptr + WARP_SIZE + active = valid & (ptr < end) + tl.store(indegree_ptr + row, tl.sum(degree, axis=0)) + + +def _build_spsv_nnz_balance_metadata(indices64, indptr64, n_rows, *, lower, unit_diagonal): + n_rows = int(n_rows) + device = indices64.device + base_stats = _build_spsv_cw_matrix_stats(indptr64, n_rows) + empty_meta = { + "indegree_init32": torch.empty(0, dtype=torch.int32, device=device), + "csr_row_idx32": torch.empty(0, dtype=torch.int32, device=device), + "matrix_stats": base_stats, + } + if n_rows == 0 or not lower: + return empty_meta + if indices64.is_cuda: + indices32 = indices64.to(torch.int32).contiguous() + indegree32 = torch.zeros(n_rows, dtype=torch.int32, device=device) + row_idx32 = torch.zeros(indices32.numel(), dtype=torch.int32, device=device) + _spsv_nnz_balance_preprocess_kernel[(n_rows,)]( + indices32, + indptr64, + indegree32, + row_idx32, + n_rows, + WARP_SIZE=32, + UNIT_DIAGONAL=bool(unit_diagonal), + num_warps=1, + ) + return { + "indegree_init32": indegree32, + "csr_row_idx32": row_idx32, + "matrix_stats": base_stats, } - _attach_spsv_complex_plan_views(levels_plan) - if not SPSV_ENABLE_CSR_CW: - return levels_plan - data_sorted, indices_sorted64, indptr_sorted64 = _sort_csr_rows( + + indptr_cpu = indptr64.to("cpu", non_blocking=False).tolist() + indices_cpu = indices64.to("cpu", non_blocking=False).tolist() + indegree_init = [0] * n_rows + row_idx = [0] * int(indices64.numel()) + for row in range(n_rows): + start = int(indptr_cpu[row]) + end = int(indptr_cpu[row + 1]) + degree = 0 + for ptr in range(start, end): + col = int(indices_cpu[ptr]) + if col < row: + row_idx[ptr] = row + degree += 1 + continue + if (not unit_diagonal) and col == row: + row_idx[ptr] = row + degree += 1 + break + indegree_init[row] = degree + return { + "indegree_init32": torch.tensor(indegree_init, dtype=torch.int32, device=device), + "csr_row_idx32": torch.tensor(row_idx, dtype=torch.int32, device=device), + "matrix_stats": base_stats, + } + + +def _build_spsv_level_schedule_metadata(indices64, indptr64, n_rows, *, lower, unit_diagonal): + n_rows = int(n_rows) + device = indices64.device + base_stats = _build_spsv_cw_matrix_stats(indptr64, n_rows) + empty_meta = { + "row_map32": torch.empty(0, dtype=torch.int32, device=device), + "level_ptr32": torch.zeros(1, dtype=torch.int32, device=device), + "indegree_init32": torch.empty(0, dtype=torch.int32, device=device), + "csr_row_idx32": torch.empty(0, dtype=torch.int32, device=device), + "matrix_stats": { + **base_stats, + "num_levels": 0, + "max_frontier": 0, + "avg_frontier": 0.0, + "frontier_ratio": 0.0, + }, + } + if n_rows == 0: + return empty_meta + + if not lower: + return empty_meta + + if indices64.is_cuda: + return _build_spsv_level_schedule_metadata_lower_gpu( + indices64, + indptr64, + n_rows, + unit_diagonal=unit_diagonal, + ) + + indptr_cpu = indptr64.to("cpu", non_blocking=False).tolist() + indices_cpu = indices64.to("cpu", non_blocking=False).tolist() + levels = [0] * n_rows + indegree_init = [0] * n_rows + level_buckets = {} + + for row in range(n_rows): + start = int(indptr_cpu[row]) + end = int(indptr_cpu[row + 1]) + deps = [] + degree = 0 + for ptr in range(start, end): + col = int(indices_cpu[ptr]) + if unit_diagonal: + if col < row: + deps.append(col) + degree += 1 + else: + break + else: + if col < row: + deps.append(col) + degree += 1 + continue + if col == row: + degree += 1 + break + indegree_init[row] = degree + row_level = 1 + if deps: + row_level = max(levels[col] for col in deps) + 1 + levels[row] = row_level + level_buckets.setdefault(row_level, []).append(row) + + row_map = [] + level_ptr = [0] + frontier_sizes = [] + for level_id in sorted(level_buckets): + rows = level_buckets[level_id] + frontier_sizes.append(len(rows)) + row_map.extend(rows) + level_ptr.append(len(row_map)) + + row_lengths64 = indptr64[1:] - indptr64[:-1] + csr_row_idx32 = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int32), + row_lengths64.to(torch.int64), + ).contiguous() + num_levels = len(frontier_sizes) + max_frontier = max(frontier_sizes) if frontier_sizes else 0 + avg_frontier = (float(sum(frontier_sizes)) / float(num_levels)) if frontier_sizes else 0.0 + matrix_stats = { + **base_stats, + "num_levels": int(num_levels), + "max_frontier": int(max_frontier), + "avg_frontier": float(avg_frontier), + "frontier_ratio": (float(max_frontier) / float(n_rows)) if n_rows > 0 else 0.0, + } + return { + "row_map32": torch.tensor(row_map, dtype=torch.int32, device=device), + "level_ptr32": torch.tensor(level_ptr, dtype=torch.int32, device=device), + "indegree_init32": torch.tensor(indegree_init, dtype=torch.int32, device=device), + "csr_row_idx32": csr_row_idx32, + "matrix_stats": matrix_stats, + } + + +def _choose_spsv_nontrans_auto_route(n_rows, matrix_stats, *, lower, unit_diagonal, value_dtype): + if not _supports_spsv_advanced_nontrans_routes("N", lower, unit_diagonal, value_dtype): + return "csr_cw" + if int(n_rows) < 256: + return "csr_cw" + max_frontier = int(matrix_stats.get("max_frontier", 0)) + frontier_ratio = float(matrix_stats.get("frontier_ratio", 0.0)) + num_levels = int(matrix_stats.get("num_levels", 0)) + if _spsv_nontrans_prefers_nnz_balance( + n_rows, + matrix_stats, + lower=lower, + unit_diagonal=unit_diagonal, + value_dtype=value_dtype, + ): + return "csr_nnz_balance" + if max_frontier >= 32 and frontier_ratio >= 0.05 and num_levels > 0: + return "csr_cw_levelschd" + return "csr_cw" + + +def _prepare_spsv_csr_system( + data, + indices64, + indptr64, + n_rows, + n_cols, + lower, + trans_mode, + unit_diagonal, + requested_solve_kind=None, + storage_view="csr_as_csc", +): + if trans_mode == "N": + data, indices64, indptr64 = _sort_csr_rows( data, indices64, indptr64, n_rows, n_cols, lower=lower ) - default_block_nnz, default_max_segments = _auto_spsv_launch_config(indptr_sorted64) + requested_route = _normalize_requested_spsv_route(requested_solve_kind, trans_mode) + base_stats = _build_spsv_cw_matrix_stats(indptr64, n_rows) + default_block_nnz, default_max_segments = _auto_spsv_launch_config(indptr64) + if lower: + nontrans_variant = "csr_u_lo_cw" if unit_diagonal else "csr_n_lo_cw" + else: + nontrans_variant = "csr_u_up_cw" if unit_diagonal else "csr_n_up_cw" + level_meta = None + nnz_meta = None + if requested_route == "csr_cw": + default_solve_kind = "csr_cw" + matrix_stats = base_stats + supported_solve_kinds = ("csr_cw",) + elif requested_route == "csr_roc": + if not bool(lower): + raise ValueError("solve_kind='csr_roc' currently supports lower-triangular CSR/COO only") + if bool(unit_diagonal): + raise ValueError("solve_kind='csr_roc' currently supports non-unit diagonal only") + level_meta = _build_spsv_level_schedule_metadata( + indices64, + indptr64, + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + ) + matrix_stats = level_meta["matrix_stats"] + default_solve_kind = "csr_roc" + supported_solve_kinds = ("csr_roc",) + elif requested_route == "csr_cw_levelschd": + level_meta = _build_spsv_level_schedule_metadata( + indices64, + indptr64, + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + ) + matrix_stats = level_meta["matrix_stats"] + default_solve_kind = "csr_cw_levelschd" + supported_solve_kinds = ("csr_cw_levelschd",) + elif requested_route == "csr_nnz_balance": + nnz_meta = _build_spsv_nnz_balance_metadata( + indices64, + indptr64, + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + ) + matrix_stats = nnz_meta["matrix_stats"] + default_solve_kind = "csr_nnz_balance" + supported_solve_kinds = ("csr_nnz_balance",) + else: + if _spsv_nontrans_prefers_nnz_balance( + n_rows, + base_stats, + lower=lower, + unit_diagonal=unit_diagonal, + value_dtype=data.dtype, + ): + nnz_meta = _build_spsv_nnz_balance_metadata( + indices64, + indptr64, + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + ) + matrix_stats = nnz_meta["matrix_stats"] + default_solve_kind = "csr_nnz_balance" + supported_solve_kinds = ("csr_nnz_balance",) + else: + level_meta = _build_spsv_level_schedule_metadata( + indices64, + indptr64, + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + ) + matrix_stats = level_meta["matrix_stats"] + default_solve_kind = _choose_spsv_nontrans_auto_route( + n_rows, + matrix_stats, + lower=lower, + unit_diagonal=unit_diagonal, + value_dtype=data.dtype, + ) + if default_solve_kind == "csr_cw_levelschd": + supported_solve_kinds = ("csr_cw_levelschd",) + else: + supported_solve_kinds = ("csr_cw",) + route_name = nontrans_variant + if default_solve_kind == "csr_roc": + route_name = "csr_n_lo_roc" + elif default_solve_kind == "csr_cw_levelschd": + route_name = "csr_n_lo_cw_levelschd" + elif default_solve_kind == "csr_nnz_balance": + route_name = "csr_n_lo_nnz_balance" cw_plan = { - "solve_kind": "csr_cw", - "kernel_data": data_sorted, - "kernel_indices32": indices_sorted64.to(torch.int32), - "kernel_indptr64": indptr_sorted64, + "solve_kind": default_solve_kind, + "default_solve_kind": default_solve_kind, + "supported_solve_kinds": tuple(supported_solve_kinds), + "nontrans_variant": nontrans_variant, + "kernel_data": data, + "kernel_indices32": indices64.to(torch.int32), + "kernel_indptr64": indptr64, "lower_eff": lower, - "launch_groups": None, "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, - "transpose_conjugate": False, - "cw_diag": _prepare_spsv_nontrans_cw_metadata( - data_sorted, indices_sorted64, indptr_sorted64, n_rows, unit_diagonal=unit_diagonal - ), + "storage_view": "csr", "cw_worker_count": _cw_worker_count( n_rows, matrix_stats["max_frontier"], matrix_stats["avg_nnz_per_row"], 1 ), "matrix_stats": matrix_stats, - "route_name": "csr_cw", - "alt_plan": None, + "route_name": route_name, + "level_row_map32": ( + level_meta["row_map32"] + if level_meta is not None + else torch.empty(0, dtype=torch.int32, device=data.device) + ), + "level_ptr32": ( + level_meta["level_ptr32"] + if level_meta is not None + else torch.zeros(1, dtype=torch.int32, device=data.device) + ), + "nnz_balance_indegree32": ( + nnz_meta["indegree_init32"] + if nnz_meta is not None + else torch.empty(0, dtype=torch.int32, device=data.device) + ), + "nnz_balance_row_idx32": ( + nnz_meta["csr_row_idx32"] + if nnz_meta is not None + else torch.empty(0, dtype=torch.int32, device=data.device) + ), } _attach_spsv_complex_plan_views(cw_plan) - levels_plan["alt_plan"] = cw_plan - return levels_plan - - levels = _build_spsv_levels(indptr64, indices64, n_rows, lower=lower) - matrix_stats = _build_spsv_matrix_stats(indptr64, levels, n_rows) - default_block_nnz, default_max_segments = _choose_transpose_family_launch_config( - indptr64 - ) - push_plan = { - "solve_kind": "transpose_push", - "kernel_data": data, - "kernel_indices32": indices64.to(torch.int32), - "kernel_indptr64": indptr64, - "lower_eff": lower, - "launch_groups": ( - _build_spsv_reverse_frontiers(indptr64, indices64, levels, lower=lower) - if SPSV_ENABLE_REVERSE_FRONTIERS - else list(reversed(levels)) - ), - "default_block_nnz": default_block_nnz, - "default_max_segments": default_max_segments, - "transpose_conjugate": trans_mode == "C", - "cw_worker_count": None, - "matrix_stats": matrix_stats, - "route_name": "transpose_push", - "alt_plan": None, - } - _attach_spsv_complex_plan_views(push_plan) - if not SPSV_ENABLE_TRANSPOSE_CW: - return push_plan - - diag, indegree_init = _prepare_spsv_transpose_cw_metadata( - data, indices64, indptr64, n_rows, lower, unit_diagonal=unit_diagonal - ) + return cw_plan + + lower_eff = not lower + storage_view = _normalize_spsv_storage_view(storage_view) + if storage_view != "csr_as_csc": + raise ValueError("TRANS/CONJ SpSV only supports storage_view='csr_as_csc'") + data_eff = data + indices_eff64 = indices64 + indptr_eff64 = indptr64 + matrix_stats = _build_spsv_cw_matrix_stats(indptr_eff64, n_rows) default_block_nnz, default_max_segments = _choose_transpose_family_launch_config( - indptr64 + indptr_eff64 ) cw_plan = { "solve_kind": "transpose_cw", - "kernel_data": data, - "kernel_indices32": indices64.to(torch.int32), - "kernel_indptr64": indptr64, - "lower_eff": lower, - "launch_groups": None, + "default_solve_kind": "transpose_cw", + "supported_solve_kinds": ("transpose_cw",), + "kernel_data": data_eff, + "kernel_indices32": indices_eff64.to(torch.int32), + "kernel_indptr64": indptr_eff64, + "lower_eff": lower_eff, "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, - "transpose_conjugate": trans_mode == "C", - "transpose_diag": diag, - "transpose_indegree_init": indegree_init, "cw_worker_count": _cw_worker_count( n_rows, matrix_stats["max_frontier"], matrix_stats["avg_nnz_per_row"], 1 ), "matrix_stats": matrix_stats, + "storage_view": storage_view, "route_name": "transpose_cw", - "alt_plan": None, } _attach_spsv_complex_plan_views(cw_plan) - push_plan["alt_plan"] = cw_plan - return push_plan + return cw_plan def _resolve_spsv_csr_runtime( @@ -764,6 +1415,8 @@ def _resolve_spsv_csr_runtime( lower, transpose, unit_diagonal=False, + requested_solve_kind=None, + storage_view="csr_as_csc", ): input_data = data input_indices = indices @@ -781,7 +1434,15 @@ def _resolve_spsv_csr_runtime( _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") preprocess_key = _csr_preprocess_cache_key( - input_data, input_indices, input_indptr, (n_rows, n_cols), lower, trans_mode, unit_diagonal + input_data, + input_indices, + input_indptr, + (n_rows, n_cols), + lower, + trans_mode, + unit_diagonal, + _normalize_requested_spsv_route(requested_solve_kind, trans_mode), + _normalize_spsv_storage_view(storage_view), ) cached = _spsv_cache_get(_SPSV_CSR_PREPROCESS_CACHE, preprocess_key) if cached is None: @@ -794,6 +1455,8 @@ def _resolve_spsv_csr_runtime( lower, trans_mode, unit_diagonal, + requested_solve_kind=requested_solve_kind, + storage_view=storage_view, ) _spsv_cache_put( _SPSV_CSR_PREPROCESS_CACHE, @@ -812,172 +1475,100 @@ def _resolve_spsv_csr_runtime( ) -def _select_spsv_runtime_plan(solve_plan, rhs_cols, compute_dtype, trans_mode): - matrix_stats = solve_plan.get("matrix_stats", {}) - route_name = solve_plan.get("route_name", solve_plan["solve_kind"]) - alt_plan = solve_plan.get("alt_plan") - complex_mode = compute_dtype in (torch.complex64, torch.complex128) - if trans_mode == "N": - desired = _select_nontrans_route(matrix_stats, rhs_cols, complex_mode) - else: - desired = _select_transpose_route( - matrix_stats, rhs_cols, complex_mode, trans_mode +def _select_spsv_runtime_plan(solve_plan, trans_mode, requested_solve_kind=None): + requested_route = _normalize_requested_spsv_route(requested_solve_kind, trans_mode) + routed = _clone_spsv_plan(solve_plan) + if requested_route is None: + requested_route = str( + solve_plan.get("default_solve_kind", solve_plan.get("solve_kind", "csr_cw")) + ) + supported = tuple(solve_plan.get("supported_solve_kinds", (solve_plan.get("solve_kind"),))) + if requested_route not in supported: + raise ValueError( + f"solve_kind={requested_route!r} is not available for this SpSV problem; " + f"supported routes: {', '.join(str(route) for route in supported if route)}" ) - if desired == route_name or alt_plan is None: - return solve_plan - if alt_plan.get("route_name", alt_plan["solve_kind"]) == desired: - return alt_plan - return solve_plan + routed["solve_kind"] = requested_route + if requested_route == "csr_cw": + routed["route_name"] = str(solve_plan.get("nontrans_variant", requested_route)) + elif requested_route == "csr_roc": + routed["route_name"] = "csr_n_lo_roc" + elif requested_route == "csr_cw_levelschd": + routed["route_name"] = "csr_n_lo_cw_levelschd" + elif requested_route == "csr_nnz_balance": + routed["route_name"] = "csr_n_lo_nnz_balance" + else: + routed["route_name"] = requested_route + return routed @triton.jit -def _spsv_csr_level_kernel( - data_ptr, - indices_ptr, - indptr_ptr, - b_ptr, - x_ptr, - rows_ptr, - n_level_rows, - BLOCK_NNZ: tl.constexpr, - MAX_SEGMENTS: tl.constexpr, - LOWER: tl.constexpr, - UNIT_DIAG: tl.constexpr, - DIAG_EPS: tl.constexpr, -): - pid = tl.program_id(0) - if pid >= n_level_rows: - return - row = tl.load(rows_ptr + pid) - start = tl.load(indptr_ptr + row) - end = tl.load(indptr_ptr + row + 1) - acc = tl.load(data_ptr + start, mask=start < end, other=0.0) * 0 - diag = tl.load(data_ptr + start, mask=start < end, other=0.0) * 0 - if UNIT_DIAG: - diag = diag + 1.0 +def _publish_ready_flag_i32(flag_ptr, idx): + """Publish a ready flag through an atomic write-like operation.""" - for seg in range(MAX_SEGMENTS): - idx = start + seg * BLOCK_NNZ - offsets = idx + tl.arange(0, BLOCK_NNZ) - mask = offsets < end - a = tl.load(data_ptr + offsets, mask=mask, other=0.0) - col = tl.load(indices_ptr + offsets, mask=mask, other=0) - x_vals = tl.load(x_ptr + col, mask=mask, other=0.0) + tl.atomic_add(flag_ptr + idx, 1) - if LOWER: - solved = col < row - else: - solved = col > row - is_diag = col == row - acc = acc + tl.sum(tl.where(mask & solved, a * x_vals, 0.0)) - if not UNIT_DIAG: - diag = diag + tl.sum(tl.where(mask & is_diag, a, 0.0)) +@triton.jit +def _load_ready_flag_i32(flag_ptr, idx): + """Mirror the original volatile/atomic polling pattern more closely.""" - rhs = tl.load(b_ptr + row) - diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) - x_row = (rhs - acc) / diag_safe - # Prevent NaN propagation in ill-conditioned rows. - x_row = tl.where(x_row == x_row, x_row, 0.0) - tl.store(x_ptr + row, x_row) + return tl.atomic_add(flag_ptr + idx, 0) @triton.jit -def _spsv_csr_level_kernel_complex( - data_ri_ptr, - indices_ptr, - indptr_ptr, - b_ri_ptr, - x_ri_ptr, - rows_ptr, - n_level_rows, - BLOCK_NNZ: tl.constexpr, - MAX_SEGMENTS: tl.constexpr, - LOWER: tl.constexpr, - UNIT_DIAG: tl.constexpr, - USE_FP64_ACC: tl.constexpr, - DIAG_EPS: tl.constexpr, -): - pid = tl.program_id(0) - if pid >= n_level_rows: - return - row = tl.load(rows_ptr + pid) - start = tl.load(indptr_ptr + row) - end = tl.load(indptr_ptr + row + 1) +def _publish_i32_once(slot_ptr, idx, value): + """Publish a single int32 payload via an atomic write-like update.""" - if USE_FP64_ACC: - acc_re = tl.zeros((1,), dtype=tl.float64) - acc_im = tl.zeros((1,), dtype=tl.float64) - diag_re = tl.zeros((1,), dtype=tl.float64) - diag_im = tl.zeros((1,), dtype=tl.float64) - else: - acc_re = tl.zeros((1,), dtype=tl.float32) - acc_im = tl.zeros((1,), dtype=tl.float32) - diag_re = tl.zeros((1,), dtype=tl.float32) - diag_im = tl.zeros((1,), dtype=tl.float32) + tl.atomic_add(slot_ptr + idx, value) - if UNIT_DIAG: - diag_re = diag_re + 1.0 - for seg in range(MAX_SEGMENTS): - idx = start + seg * BLOCK_NNZ - offsets = idx + tl.arange(0, BLOCK_NNZ) - mask = offsets < end +@triton.jit +def _load_scalar_fp32(ptr, idx): + return tl.atomic_add(ptr + idx, 0.0) - col = tl.load(indices_ptr + offsets, mask=mask, other=0) - a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) - a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) - x_re = tl.load(x_ri_ptr + col * 2, mask=mask, other=0.0) - x_im = tl.load(x_ri_ptr + col * 2 + 1, mask=mask, other=0.0) - if USE_FP64_ACC: - a_re = a_re.to(tl.float64) - a_im = a_im.to(tl.float64) - x_re = x_re.to(tl.float64) - x_im = x_im.to(tl.float64) - else: - a_re = a_re.to(tl.float32) - a_im = a_im.to(tl.float32) - x_re = x_re.to(tl.float32) - x_im = x_im.to(tl.float32) +@triton.jit +def _load_scalar_fp64(ptr, idx): + return tl.atomic_add(ptr + idx, 0.0) - if LOWER: - solved = col < row - else: - solved = col > row - is_diag = col == row - prod_re = a_re * x_re - a_im * x_im - prod_im = a_re * x_im + a_im * x_re - acc_re = acc_re + tl.sum(tl.where(mask & solved, prod_re, 0.0)) - acc_im = acc_im + tl.sum(tl.where(mask & solved, prod_im, 0.0)) +@triton.jit +def _complex_atomic_add_interleaved(ptr_ri, idx, delta_re, delta_im, mask): + """Complex atomicAdd equivalent for interleaved real/imag buffers.""" - if not UNIT_DIAG: - diag_re = diag_re + tl.sum(tl.where(mask & is_diag, a_re, 0.0)) - diag_im = diag_im + tl.sum(tl.where(mask & is_diag, a_im, 0.0)) + tl.atomic_add(ptr_ri + idx * 2, delta_re, mask=mask) + tl.atomic_add(ptr_ri + idx * 2 + 1, delta_im, mask=mask) - rhs_re = tl.load(b_ri_ptr + row * 2) - rhs_im = tl.load(b_ri_ptr + row * 2 + 1) - if USE_FP64_ACC: - rhs_re = rhs_re.to(tl.float64) - rhs_im = rhs_im.to(tl.float64) - else: - rhs_re = rhs_re.to(tl.float32) - rhs_im = rhs_im.to(tl.float32) - num_re = rhs_re - acc_re - num_im = rhs_im - acc_im - den = diag_re * diag_re + diag_im * diag_im - den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) +@triton.jit +def _propagate_real(residual_ptr, idx, delta, mask): + """Publish a real contribution into shared residual state.""" + + tl.atomic_add(residual_ptr + idx, delta, mask=mask) + + +@triton.jit +def _propagate_then_release_real(residual_ptr, indegree_ptr, idx, delta, mask): + """Approximate 'write contribution then decrement dependency count'.""" + + _propagate_real(residual_ptr, idx, delta, mask) + tl.atomic_add(indegree_ptr + idx, -1, mask=mask) + + +@triton.jit +def _propagate_complex(residual_ri_ptr, idx, delta_re, delta_im, mask): + """Publish a complex contribution into shared residual state.""" - x_re_out = (num_re * diag_re + num_im * diag_im) / den_safe - x_im_out = (num_im * diag_re - num_re * diag_im) / den_safe - x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) - x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + _complex_atomic_add_interleaved(residual_ri_ptr, idx, delta_re, delta_im, mask) - offs1 = tl.arange(0, 1) - tl.store(x_ri_ptr + row * 2 + offs1, x_re_out) - tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + +@triton.jit +def _propagate_then_release_complex(residual_ri_ptr, indegree_ptr, idx, delta_re, delta_im, mask): + """Complex propagation + dependency release for transpose-style solve.""" + + _propagate_complex(residual_ri_ptr, idx, delta_re, delta_im, mask) + tl.atomic_add(indegree_ptr + idx, -1, mask=mask) @triton.jit @@ -985,62 +1576,88 @@ def _spsv_csr_cw_kernel( data_ptr, indices_ptr, indptr_ptr, - diag_ptr, b_ptr, x_ptr, ready_ptr, row_counter_ptr, n_rows, - n_rhs, - stride_b0, - stride_x0, - BLOCK_RHS: tl.constexpr, - BLOCK_NNZ: tl.constexpr, - MAX_SEGMENTS: tl.constexpr, LOWER: tl.constexpr, + REVERSE_ORDER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + USE_FP64_ACC: tl.constexpr, DIAG_EPS: tl.constexpr, ): - row = tl.atomic_add(row_counter_ptr, 1) - while row < n_rows: + logical_row = tl.atomic_add(row_counter_ptr, 1) + while logical_row < n_rows: + row = tl.where(REVERSE_ORDER, n_rows - 1 - logical_row, logical_row) start = tl.load(indptr_ptr + row) end = tl.load(indptr_ptr + row + 1) - diag = tl.load(diag_ptr + row) - diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) - - for rhs_base in range(0, n_rhs, BLOCK_RHS): - rhs_offsets = rhs_base + tl.arange(0, BLOCK_RHS) - rhs_mask = rhs_offsets < n_rhs - acc = tl.zeros((BLOCK_RHS,), dtype=tl.float32) - for seg in range(MAX_SEGMENTS): - idx = start + seg * BLOCK_NNZ - nnz_offsets = idx + tl.arange(0, BLOCK_NNZ) - nnz_mask = nnz_offsets < end - a = tl.load(data_ptr + nnz_offsets, mask=nnz_mask, other=0.0) - col = tl.load(indices_ptr + nnz_offsets, mask=nnz_mask, other=0) - if LOWER: - dep_mask = nnz_mask & (col < row) + ptr = start + if USE_FP64_ACC: + rhs = tl.load(b_ptr + row).to(tl.float64) + tmp_sum = tl.zeros((), dtype=tl.float64) + else: + rhs = tl.load(b_ptr + row).to(tl.float32) + tmp_sum = tl.zeros((), dtype=tl.float32) + row_done = 0 + while row_done == 0: + if UNIT_DIAG: + if ptr >= end: + x_row = rhs - tmp_sum + x_row = tl.where(x_row == x_row, x_row, 0.0) + tl.store(x_ptr + row, x_row) + row_done = 1 else: - dep_mask = nnz_mask & (col > row) - - for k in range(BLOCK_NNZ): - if dep_mask[k]: - dep_col = col[k] - while tl.load(ready_ptr + dep_col) == 0: - pass - x_ptrs = x_ptr + dep_col * stride_x0 + rhs_offsets - x_vals = tl.load(x_ptrs, mask=rhs_mask, other=0.0) - acc += a[k] * x_vals - - rhs_ptrs = b_ptr + row * stride_b0 + rhs_offsets - rhs = tl.load(rhs_ptrs, mask=rhs_mask, other=0.0) - x_row = (rhs - acc) / diag_safe - x_row = tl.where(x_row == x_row, x_row, 0.0) - out_ptrs = x_ptr + row * stride_x0 + rhs_offsets - tl.store(out_ptrs, x_row, mask=rhs_mask) - - tl.debug_barrier() - tl.store(ready_ptr + row, 1) - row = tl.atomic_add(row_counter_ptr, 1) + col = tl.load(indices_ptr + ptr) + stop_at_diag = (col >= row) if LOWER else (col <= row) + if stop_at_diag: + x_row = rhs - tmp_sum + x_row = tl.where(x_row == x_row, x_row, 0.0) + tl.store(x_ptr + row, x_row) + row_done = 1 + else: + dep_ready = tl.atomic_add(ready_ptr + col, 0) + while dep_ready != 1: + dep_ready = tl.atomic_add(ready_ptr + col, 0) + if USE_FP64_ACC: + a = tl.load(data_ptr + ptr).to(tl.float64) + y_dep = tl.load(x_ptr + col).to(tl.float64) + else: + a = tl.load(data_ptr + ptr).to(tl.float32) + y_dep = tl.load(x_ptr + col).to(tl.float32) + tmp_sum += a * y_dep + ptr += 1 + else: + if ptr >= end: + x_row = rhs * 0 + tl.store(x_ptr + row, x_row) + row_done = 1 + else: + col = tl.load(indices_ptr + ptr) + if col == row: + if USE_FP64_ACC: + diag = tl.load(data_ptr + ptr).to(tl.float64) + else: + diag = tl.load(data_ptr + ptr).to(tl.float32) + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + x_row = (rhs - tmp_sum) / diag_safe + x_row = tl.where(x_row == x_row, x_row, 0.0) + tl.store(x_ptr + row, x_row) + row_done = 1 + else: + dep_ready = tl.atomic_add(ready_ptr + col, 0) + while dep_ready != 1: + dep_ready = tl.atomic_add(ready_ptr + col, 0) + if USE_FP64_ACC: + a = tl.load(data_ptr + ptr).to(tl.float64) + y_dep = tl.load(x_ptr + col).to(tl.float64) + else: + a = tl.load(data_ptr + ptr).to(tl.float32) + y_dep = tl.load(x_ptr + col).to(tl.float32) + tmp_sum += a * y_dep + ptr += 1 + _publish_ready_flag_i32(ready_ptr, row) + logical_row = tl.atomic_add(row_counter_ptr, 1) @triton.jit @@ -1048,249 +1665,129 @@ def _spsv_csr_cw_kernel_complex( data_ri_ptr, indices_ptr, indptr_ptr, - diag_ri_ptr, b_ri_ptr, x_ri_ptr, ready_ptr, row_counter_ptr, n_rows, - BLOCK_NNZ: tl.constexpr, - MAX_SEGMENTS: tl.constexpr, LOWER: tl.constexpr, + REVERSE_ORDER: tl.constexpr, UNIT_DIAG: tl.constexpr, USE_FP64_ACC: tl.constexpr, DIAG_EPS: tl.constexpr, ): - row = tl.atomic_add(row_counter_ptr, 1) - while row < n_rows: + logical_row = tl.atomic_add(row_counter_ptr, 1) + lane2 = tl.arange(0, 2) + while logical_row < n_rows: + row = tl.where(REVERSE_ORDER, n_rows - 1 - logical_row, logical_row) start = tl.load(indptr_ptr + row) end = tl.load(indptr_ptr + row + 1) - if UNIT_DIAG: - diag_re = 1.0 - diag_im = 0.0 - else: - diag_re = tl.load(diag_ri_ptr + row * 2) - diag_im = tl.load(diag_ri_ptr + row * 2 + 1) + rhs_re = tl.load(b_ri_ptr + row * 2) + rhs_im = tl.load(b_ri_ptr + row * 2 + 1) if USE_FP64_ACC: - diag_re = diag_re.to(tl.float64) - diag_im = diag_im.to(tl.float64) - acc_re = tl.zeros((1,), dtype=tl.float64) - acc_im = tl.zeros((1,), dtype=tl.float64) + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + tmp_re = tl.zeros((), dtype=tl.float64) + tmp_im = tl.zeros((), dtype=tl.float64) else: - diag_re = diag_re.to(tl.float32) - diag_im = diag_im.to(tl.float32) - acc_re = tl.zeros((1,), dtype=tl.float32) - acc_im = tl.zeros((1,), dtype=tl.float32) - - for seg in range(MAX_SEGMENTS): - idx = start + seg * BLOCK_NNZ - offsets = idx + tl.arange(0, BLOCK_NNZ) - mask = offsets < end - col = tl.load(indices_ptr + offsets, mask=mask, other=0) - a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) - a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) - if USE_FP64_ACC: - a_re = a_re.to(tl.float64) - a_im = a_im.to(tl.float64) - else: - a_re = a_re.to(tl.float32) - a_im = a_im.to(tl.float32) - if LOWER: - dep_mask = mask & (col < row) - else: - dep_mask = mask & (col > row) - - for k in range(BLOCK_NNZ): - if dep_mask[k]: - dep_col = col[k] - while tl.load(ready_ptr + dep_col) == 0: - pass - x_re = tl.load(x_ri_ptr + dep_col * 2) - x_im = tl.load(x_ri_ptr + dep_col * 2 + 1) - if USE_FP64_ACC: - x_re = x_re.to(tl.float64) - x_im = x_im.to(tl.float64) - else: - x_re = x_re.to(tl.float32) - x_im = x_im.to(tl.float32) - acc_re += a_re[k] * x_re - a_im[k] * x_im - acc_im += a_re[k] * x_im + a_im[k] * x_re - - rhs_re = tl.load(b_ri_ptr + row * 2) - rhs_im = tl.load(b_ri_ptr + row * 2 + 1) - if USE_FP64_ACC: - rhs_re = rhs_re.to(tl.float64) - rhs_im = rhs_im.to(tl.float64) - else: - rhs_re = rhs_re.to(tl.float32) - rhs_im = rhs_im.to(tl.float32) - - num_re = rhs_re - acc_re - num_im = rhs_im - acc_im - den = diag_re * diag_re + diag_im * diag_im - den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) - x_re_out = (num_re * diag_re + num_im * diag_im) / den_safe - x_im_out = (num_im * diag_re - num_re * diag_im) / den_safe - x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) - x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) - - tl.store(x_ri_ptr + row * 2, x_re_out) - tl.store(x_ri_ptr + row * 2 + 1, x_im_out) - tl.debug_barrier() - tl.store(ready_ptr + row, 1) - row = tl.atomic_add(row_counter_ptr, 1) - - -@triton.jit -def _spsv_csr_transpose_push_kernel( - data_ptr, - indices_ptr, - indptr_ptr, - residual_ptr, - x_ptr, - rows_ptr, - n_level_rows, - BLOCK_NNZ: tl.constexpr, - MAX_SEGMENTS: tl.constexpr, - LOWER: tl.constexpr, - UNIT_DIAG: tl.constexpr, - DIAG_EPS: tl.constexpr, -): - pid = tl.program_id(0) - if pid >= n_level_rows: - return - row = tl.load(rows_ptr + pid) - start = tl.load(indptr_ptr + row) - end = tl.load(indptr_ptr + row + 1) - - diag = tl.load(data_ptr + start, mask=start < end, other=0.0) * 0 - if UNIT_DIAG: - diag = diag + 1.0 - else: - for seg in range(MAX_SEGMENTS): - idx = start + seg * BLOCK_NNZ - offsets = idx + tl.arange(0, BLOCK_NNZ) - mask = offsets < end - a = tl.load(data_ptr + offsets, mask=mask, other=0.0) - col = tl.load(indices_ptr + offsets, mask=mask, other=0) - is_diag = col == row - diag = diag + tl.sum(tl.where(mask & is_diag, a, 0.0)) - - rhs = tl.load(residual_ptr + row) - diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) - x_row = rhs / diag_safe - x_row = tl.where(x_row == x_row, x_row, 0.0) - tl.store(x_ptr + row, x_row) - - for seg in range(MAX_SEGMENTS): - idx = start + seg * BLOCK_NNZ - offsets = idx + tl.arange(0, BLOCK_NNZ) - mask = offsets < end - a = tl.load(data_ptr + offsets, mask=mask, other=0.0) - col = tl.load(indices_ptr + offsets, mask=mask, other=0) - if LOWER: - target_mask = mask & (col < row) - else: - target_mask = mask & (col > row) - tl.atomic_add(residual_ptr + col, -a * x_row, mask=target_mask) - - -@triton.jit -def _spsv_csr_transpose_push_kernel_complex( - data_ri_ptr, - indices_ptr, - indptr_ptr, - residual_ri_ptr, - x_ri_ptr, - rows_ptr, - n_level_rows, - BLOCK_NNZ: tl.constexpr, - MAX_SEGMENTS: tl.constexpr, - LOWER: tl.constexpr, - UNIT_DIAG: tl.constexpr, - CONJ_TRANS: tl.constexpr, - USE_FP64_ACC: tl.constexpr, - DIAG_EPS: tl.constexpr, -): - pid = tl.program_id(0) - if pid >= n_level_rows: - return - row = tl.load(rows_ptr + pid) - start = tl.load(indptr_ptr + row) - end = tl.load(indptr_ptr + row + 1) - - if USE_FP64_ACC: - diag_re = tl.zeros((1,), dtype=tl.float64) - diag_im = tl.zeros((1,), dtype=tl.float64) - else: - diag_re = tl.zeros((1,), dtype=tl.float32) - diag_im = tl.zeros((1,), dtype=tl.float32) - - if UNIT_DIAG: - diag_re = diag_re + 1.0 - else: - for seg in range(MAX_SEGMENTS): - idx = start + seg * BLOCK_NNZ - offsets = idx + tl.arange(0, BLOCK_NNZ) - mask = offsets < end - col = tl.load(indices_ptr + offsets, mask=mask, other=0) - a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) - a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) - if CONJ_TRANS: - a_im = -a_im - if USE_FP64_ACC: - a_re = a_re.to(tl.float64) - a_im = a_im.to(tl.float64) - else: - a_re = a_re.to(tl.float32) - a_im = a_im.to(tl.float32) - is_diag = col == row - diag_re = diag_re + tl.sum(tl.where(mask & is_diag, a_re, 0.0)) - diag_im = diag_im + tl.sum(tl.where(mask & is_diag, a_im, 0.0)) - - rhs_re = tl.load(residual_ri_ptr + row * 2) - rhs_im = tl.load(residual_ri_ptr + row * 2 + 1) - if USE_FP64_ACC: - rhs_re = rhs_re.to(tl.float64) - rhs_im = rhs_im.to(tl.float64) - else: - rhs_re = rhs_re.to(tl.float32) - rhs_im = rhs_im.to(tl.float32) - - den = diag_re * diag_re + diag_im * diag_im - den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) - x_re_out = (rhs_re * diag_re + rhs_im * diag_im) / den_safe - x_im_out = (rhs_im * diag_re - rhs_re * diag_im) / den_safe - x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) - x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) - - offs1 = tl.arange(0, 1) - tl.store(x_ri_ptr + row * 2 + offs1, x_re_out) - tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) - - for seg in range(MAX_SEGMENTS): - idx = start + seg * BLOCK_NNZ - offsets = idx + tl.arange(0, BLOCK_NNZ) - mask = offsets < end - col = tl.load(indices_ptr + offsets, mask=mask, other=0) - a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) - a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) - if CONJ_TRANS: - a_im = -a_im - if USE_FP64_ACC: - a_re = a_re.to(tl.float64) - a_im = a_im.to(tl.float64) - else: - a_re = a_re.to(tl.float32) - a_im = a_im.to(tl.float32) - if LOWER: - target_mask = mask & (col < row) - else: - target_mask = mask & (col > row) - prod_re = a_re * x_re_out - a_im * x_im_out - prod_im = a_re * x_im_out + a_im * x_re_out - tl.atomic_add(residual_ri_ptr + col * 2, -prod_re, mask=target_mask) - tl.atomic_add(residual_ri_ptr + col * 2 + 1, -prod_im, mask=target_mask) + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + tmp_re = tl.zeros((), dtype=tl.float32) + tmp_im = tl.zeros((), dtype=tl.float32) + ptr = start + row_done = 0 + while row_done == 0: + if UNIT_DIAG: + if ptr >= end: + x_re_out = rhs_re - tmp_re + x_im_out = rhs_im - tmp_im + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + out_vals = tl.where(lane2 == 0, x_re_out, x_im_out) + tl.store(x_ri_ptr + row * 2 + lane2, out_vals) + row_done = 1 + else: + col = tl.load(indices_ptr + ptr) + stop_at_diag = (col >= row) if LOWER else (col <= row) + if stop_at_diag: + x_re_out = rhs_re - tmp_re + x_im_out = rhs_im - tmp_im + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + out_vals = tl.where(lane2 == 0, x_re_out, x_im_out) + tl.store(x_ri_ptr + row * 2 + lane2, out_vals) + row_done = 1 + else: + dep_ready = tl.atomic_add(ready_ptr + col, 0) + while dep_ready != 1: + dep_ready = tl.atomic_add(ready_ptr + col, 0) + x_re = tl.load(x_ri_ptr + col * 2) + x_im = tl.load(x_ri_ptr + col * 2 + 1) + a_re = tl.load(data_ri_ptr + ptr * 2) + a_im = tl.load(data_ri_ptr + ptr * 2 + 1) + if USE_FP64_ACC: + x_re = x_re.to(tl.float64) + x_im = x_im.to(tl.float64) + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + else: + x_re = x_re.to(tl.float32) + x_im = x_im.to(tl.float32) + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + tmp_re += a_re * x_re - a_im * x_im + tmp_im += a_re * x_im + a_im * x_re + ptr += 1 + else: + if ptr >= end: + out_vals = tl.where(lane2 == 0, rhs_re * 0, rhs_im * 0) + tl.store(x_ri_ptr + row * 2 + lane2, out_vals) + row_done = 1 + else: + col = tl.load(indices_ptr + ptr) + if col == row: + diag_re = tl.load(data_ri_ptr + ptr * 2) + diag_im = tl.load(data_ri_ptr + ptr * 2 + 1) + if USE_FP64_ACC: + diag_re = diag_re.to(tl.float64) + diag_im = diag_im.to(tl.float64) + else: + diag_re = diag_re.to(tl.float32) + diag_im = diag_im.to(tl.float32) + num_re = rhs_re - tmp_re + num_im = rhs_im - tmp_im + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + x_re_out = (num_re * diag_re + num_im * diag_im) / den_safe + x_im_out = (num_im * diag_re - num_re * diag_im) / den_safe + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + out_vals = tl.where(lane2 == 0, x_re_out, x_im_out) + tl.store(x_ri_ptr + row * 2 + lane2, out_vals) + row_done = 1 + else: + dep_ready = tl.atomic_add(ready_ptr + col, 0) + while dep_ready != 1: + dep_ready = tl.atomic_add(ready_ptr + col, 0) + x_re = tl.load(x_ri_ptr + col * 2) + x_im = tl.load(x_ri_ptr + col * 2 + 1) + a_re = tl.load(data_ri_ptr + ptr * 2) + a_im = tl.load(data_ri_ptr + ptr * 2 + 1) + if USE_FP64_ACC: + x_re = x_re.to(tl.float64) + x_im = x_im.to(tl.float64) + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + else: + x_re = x_re.to(tl.float32) + x_im = x_im.to(tl.float32) + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + tmp_re += a_re * x_re - a_im * x_im + tmp_im += a_re * x_im + a_im * x_re + ptr += 1 + _publish_ready_flag_i32(ready_ptr, row) + logical_row = tl.atomic_add(row_counter_ptr, 1) @triton.jit @@ -1298,7 +1795,6 @@ def _spsv_csr_transpose_cw_kernel( data_ptr, indices_ptr, indptr_ptr, - diag_ptr, indegree_ptr, residual_ptr, x_ptr, @@ -1307,27 +1803,38 @@ def _spsv_csr_transpose_cw_kernel( BLOCK_NNZ: tl.constexpr, MAX_SEGMENTS: tl.constexpr, LOWER: tl.constexpr, + REVERSE_ORDER: tl.constexpr, UNIT_DIAG: tl.constexpr, DIAG_EPS: tl.constexpr, ): - row = tl.atomic_add(row_counter_ptr, 1) - while row < n_rows: + logical_row = tl.atomic_add(row_counter_ptr, 1) + while logical_row < n_rows: + row = tl.where(REVERSE_ORDER, n_rows - 1 - logical_row, logical_row) ready_value = 0 if UNIT_DIAG else 1 - while tl.load(indegree_ptr + row) != ready_value: - pass + dep_ready = tl.atomic_add(indegree_ptr + row, 0) + while dep_ready != ready_value: + dep_ready = tl.atomic_add(indegree_ptr + row, 0) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) rhs = tl.load(residual_ptr + row) if UNIT_DIAG: diag = rhs * 0 + 1.0 else: - diag = tl.load(diag_ptr + row) + diag = rhs * 0 + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + a = tl.load(data_ptr + offsets, mask=mask, other=0.0) + dep_row = tl.load(indices_ptr + offsets, mask=mask, other=0) + is_diag = dep_row == row + diag = diag + tl.sum(tl.where(mask & is_diag, a, 0.0), axis=0) diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) x_row = rhs / diag_safe x_row = tl.where(x_row == x_row, x_row, 0.0) tl.store(x_ptr + row, x_row) - start = tl.load(indptr_ptr + row) - end = tl.load(indptr_ptr + row + 1) for seg in range(MAX_SEGMENTS): idx = start + seg * BLOCK_NNZ offsets = idx + tl.arange(0, BLOCK_NNZ) @@ -1335,12 +1842,13 @@ def _spsv_csr_transpose_cw_kernel( a = tl.load(data_ptr + offsets, mask=mask, other=0.0) col = tl.load(indices_ptr + offsets, mask=mask, other=0) if LOWER: - target_mask = mask & (col < row) - else: target_mask = mask & (col > row) - tl.atomic_add(residual_ptr + col, -a * x_row, mask=target_mask) - tl.atomic_add(indegree_ptr + col, -1, mask=target_mask) - row = tl.atomic_add(row_counter_ptr, 1) + else: + target_mask = mask & (col < row) + _propagate_then_release_real( + residual_ptr, indegree_ptr, col, -a * x_row, target_mask + ) + logical_row = tl.atomic_add(row_counter_ptr, 1) @triton.jit @@ -1348,7 +1856,6 @@ def _spsv_csr_transpose_cw_kernel_complex( data_ri_ptr, indices_ptr, indptr_ptr, - diag_ri_ptr, indegree_ptr, residual_ri_ptr, x_ri_ptr, @@ -1357,16 +1864,22 @@ def _spsv_csr_transpose_cw_kernel_complex( BLOCK_NNZ: tl.constexpr, MAX_SEGMENTS: tl.constexpr, LOWER: tl.constexpr, + REVERSE_ORDER: tl.constexpr, UNIT_DIAG: tl.constexpr, CONJ_TRANS: tl.constexpr, USE_FP64_ACC: tl.constexpr, DIAG_EPS: tl.constexpr, ): - row = tl.atomic_add(row_counter_ptr, 1) - while row < n_rows: + logical_row = tl.atomic_add(row_counter_ptr, 1) + lane2 = tl.arange(0, 2) + while logical_row < n_rows: + row = tl.where(REVERSE_ORDER, n_rows - 1 - logical_row, logical_row) ready_value = 0 if UNIT_DIAG else 1 - while tl.load(indegree_ptr + row) != ready_value: - pass + dep_ready = tl.atomic_add(indegree_ptr + row, 0) + while dep_ready != ready_value: + dep_ready = tl.atomic_add(indegree_ptr + row, 0) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) rhs_re = tl.load(residual_ri_ptr + row * 2) rhs_im = tl.load(residual_ri_ptr + row * 2 + 1) @@ -1381,16 +1894,30 @@ def _spsv_csr_transpose_cw_kernel_complex( diag_re = rhs_re * 0 + 1.0 diag_im = rhs_im * 0 else: - diag_re = tl.load(diag_ri_ptr + row * 2) - diag_im = tl.load(diag_ri_ptr + row * 2 + 1) - if CONJ_TRANS: - diag_im = -diag_im if USE_FP64_ACC: - diag_re = diag_re.to(tl.float64) - diag_im = diag_im.to(tl.float64) + diag_re = tl.zeros((), dtype=tl.float64) + diag_im = tl.zeros((), dtype=tl.float64) else: - diag_re = diag_re.to(tl.float32) - diag_im = diag_im.to(tl.float32) + diag_re = tl.zeros((), dtype=tl.float32) + diag_im = tl.zeros((), dtype=tl.float32) + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + dep_row = tl.load(indices_ptr + offsets, mask=mask, other=0) + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + if CONJ_TRANS: + a_im = -a_im + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + is_diag = dep_row == row + diag_re = diag_re + tl.sum(tl.where(mask & is_diag, a_re, 0.0), axis=0) + diag_im = diag_im + tl.sum(tl.where(mask & is_diag, a_im, 0.0), axis=0) den = diag_re * diag_re + diag_im * diag_im den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) @@ -1399,12 +1926,9 @@ def _spsv_csr_transpose_cw_kernel_complex( x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) - offs1 = tl.arange(0, 1) - tl.store(x_ri_ptr + row * 2 + offs1, x_re_out) - tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + out_vals = tl.where(lane2 == 0, x_re_out, x_im_out) + tl.store(x_ri_ptr + row * 2 + lane2, out_vals) - start = tl.load(indptr_ptr + row) - end = tl.load(indptr_ptr + row + 1) for seg in range(MAX_SEGMENTS): idx = start + seg * BLOCK_NNZ offsets = idx + tl.arange(0, BLOCK_NNZ) @@ -1421,56 +1945,478 @@ def _spsv_csr_transpose_cw_kernel_complex( a_re = a_re.to(tl.float32) a_im = a_im.to(tl.float32) if LOWER: - target_mask = mask & (col < row) - else: target_mask = mask & (col > row) + else: + target_mask = mask & (col < row) prod_re = a_re * x_re_out - a_im * x_im_out prod_im = a_re * x_im_out + a_im * x_re_out - tl.atomic_add(residual_ri_ptr + col * 2, -prod_re, mask=target_mask) - tl.atomic_add(residual_ri_ptr + col * 2 + 1, -prod_im, mask=target_mask) - tl.atomic_add(indegree_ptr + col, -1, mask=target_mask) - row = tl.atomic_add(row_counter_ptr, 1) + _propagate_then_release_complex( + residual_ri_ptr, + indegree_ptr, + col, + -prod_re, + -prod_im, + target_mask, + ) + logical_row = tl.atomic_add(row_counter_ptr, 1) -def _build_spsv_levels(indptr, indices, n_rows, lower=True): - """Build dependency levels for triangular solve so each level can run in parallel.""" - if n_rows == 0: - return [] - indptr_h = indptr.to(torch.int64).cpu() - indices_h = indices.to(torch.int64).cpu() - levels = [0] * n_rows - if lower: - for i in range(n_rows): - s = int(indptr_h[i].item()) - e = int(indptr_h[i + 1].item()) - lvl = 0 - for p in range(s, e): - c = int(indices_h[p].item()) - if c < i: - lvl = max(lvl, levels[c] + 1) - levels[i] = lvl +@triton.jit +def _spsv_csr_roc_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + row_map_ptr, + b_ptr, + x_ptr, + ready_ptr, + n_rows, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, + WARP_SIZE: tl.constexpr, +): + logical_row = tl.program_id(0) + if logical_row >= n_rows: + return + row = tl.load(row_map_ptr + logical_row) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + lanes = tl.arange(0, WARP_SIZE) + ptr = start + lanes + if USE_FP64_ACC: + rhs = tl.load(b_ptr + row).to(tl.float64) + local_sum = tl.where(lanes == 0, rhs, 0.0).to(tl.float64) + zero_vec = tl.zeros((WARP_SIZE,), dtype=tl.float64) + else: + rhs = tl.load(b_ptr + row).to(tl.float32) + local_sum = tl.where(lanes == 0, rhs, 0.0).to(tl.float32) + zero_vec = tl.zeros((WARP_SIZE,), dtype=tl.float32) + + loop_done = 0 + while loop_done == 0: + active = ptr < end + col = tl.load(indices_ptr + ptr, mask=active, other=row) + dep_mask = active & (col < row) + if tl.sum(dep_mask.to(tl.int32), axis=0) == 0: + loop_done = 1 + else: + dep_ready = tl.atomic_add( + ready_ptr + col, + tl.zeros((WARP_SIZE,), dtype=tl.int32), + mask=dep_mask, + ) + advance_mask = dep_mask & (dep_ready != 0) + a = tl.load(data_ptr + ptr, mask=advance_mask, other=0.0) + if USE_FP64_ACC: + a = a.to(tl.float64) + y_dep = tl.atomic_add(x_ptr + col, zero_vec, mask=advance_mask).to(tl.float64) + else: + a = a.to(tl.float32) + y_dep = tl.atomic_add(x_ptr + col, zero_vec, mask=advance_mask).to(tl.float32) + local_sum += tl.where(advance_mask, -a * y_dep, 0.0) + ptr = ptr + tl.where(advance_mask, WARP_SIZE, 0) + + active = ptr < end + col = tl.load(indices_ptr + ptr, mask=active, other=row + 1) + diag_mask = active & (col == row) + diag = tl.load(data_ptr + ptr, mask=diag_mask, other=0.0) + if USE_FP64_ACC: + diag = diag.to(tl.float64) + else: + diag = diag.to(tl.float32) + diag_val = tl.sum(diag, axis=0) + diag_safe = tl.where(tl.abs(diag_val) < DIAG_EPS, 1.0, diag_val) + out = tl.sum(local_sum, axis=0) / diag_safe + out = tl.where(out == out, out, 0.0) + if USE_FP64_ACC: + tl.atomic_add(x_ptr + row, out.to(tl.float64)) + else: + tl.atomic_add(x_ptr + row, out.to(tl.float32)) + _publish_ready_flag_i32(ready_ptr, row) + + +@triton.jit +def _spsv_csr_roc_kernel_complex( + data_ri_ptr, + indices_ptr, + indptr_ptr, + row_map_ptr, + b_ri_ptr, + x_ri_ptr, + ready_ptr, + n_rows, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, + WARP_SIZE: tl.constexpr, +): + logical_row = tl.program_id(0) + if logical_row >= n_rows: + return + row = tl.load(row_map_ptr + logical_row) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + lanes = tl.arange(0, WARP_SIZE) + ptr = start + lanes + + rhs_re = tl.load(b_ri_ptr + row * 2) + rhs_im = tl.load(b_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + local_sum_re = tl.where(lanes == 0, rhs_re, 0.0).to(tl.float64) + local_sum_im = tl.where(lanes == 0, rhs_im, 0.0).to(tl.float64) + zero_vec = tl.zeros((WARP_SIZE,), dtype=tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + local_sum_re = tl.where(lanes == 0, rhs_re, 0.0).to(tl.float32) + local_sum_im = tl.where(lanes == 0, rhs_im, 0.0).to(tl.float32) + zero_vec = tl.zeros((WARP_SIZE,), dtype=tl.float32) + + loop_done = 0 + while loop_done == 0: + active = ptr < end + col = tl.load(indices_ptr + ptr, mask=active, other=row) + dep_mask = active & (col < row) + if tl.sum(dep_mask.to(tl.int32), axis=0) == 0: + loop_done = 1 + else: + dep_ready = tl.atomic_add( + ready_ptr + col, + tl.zeros((WARP_SIZE,), dtype=tl.int32), + mask=dep_mask, + ) + advance_mask = dep_mask & (dep_ready != 0) + a_re = tl.load(data_ri_ptr + ptr * 2, mask=advance_mask, other=0.0) + a_im = tl.load(data_ri_ptr + ptr * 2 + 1, mask=advance_mask, other=0.0) + x_re = tl.atomic_add(x_ri_ptr + col * 2, zero_vec, mask=advance_mask) + x_im = tl.atomic_add(x_ri_ptr + col * 2 + 1, zero_vec, mask=advance_mask) + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + x_re = x_re.to(tl.float64) + x_im = x_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + x_re = x_re.to(tl.float32) + x_im = x_im.to(tl.float32) + prod_re = a_re * x_re - a_im * x_im + prod_im = a_re * x_im + a_im * x_re + local_sum_re += tl.where(advance_mask, -prod_re, 0.0) + local_sum_im += tl.where(advance_mask, -prod_im, 0.0) + ptr = ptr + tl.where(advance_mask, WARP_SIZE, 0) + + active = ptr < end + col = tl.load(indices_ptr + ptr, mask=active, other=row + 1) + diag_mask = active & (col == row) + diag_re = tl.load(data_ri_ptr + ptr * 2, mask=diag_mask, other=0.0) + diag_im = tl.load(data_ri_ptr + ptr * 2 + 1, mask=diag_mask, other=0.0) + if USE_FP64_ACC: + diag_re = diag_re.to(tl.float64) + diag_im = diag_im.to(tl.float64) + else: + diag_re = diag_re.to(tl.float32) + diag_im = diag_im.to(tl.float32) + diag_re = tl.sum(diag_re, axis=0) + diag_im = tl.sum(diag_im, axis=0) + sum_re = tl.sum(local_sum_re, axis=0) + sum_im = tl.sum(local_sum_im, axis=0) + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + out_re = (sum_re * diag_re + sum_im * diag_im) / den_safe + out_im = (sum_im * diag_re - sum_re * diag_im) / den_safe + out_re = tl.where(out_re == out_re, out_re, 0.0) + out_im = tl.where(out_im == out_im, out_im, 0.0) + tl.atomic_add(x_ri_ptr + row * 2, out_re) + tl.atomic_add(x_ri_ptr + row * 2 + 1, out_im) + _publish_ready_flag_i32(ready_ptr, row) + + +@triton.jit +def _spsv_csr_cw_levelschd_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + row_map_ptr, + b_ptr, + x_ptr, + ready_ptr, + n_rows, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + logical_row = tl.program_id(0) + if logical_row >= n_rows: + return + row = tl.load(row_map_ptr + logical_row) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + ptr = start + if USE_FP64_ACC: + rhs = tl.load(b_ptr + row).to(tl.float64) + tmp_sum = tl.zeros((), dtype=tl.float64) else: - for i in range(n_rows - 1, -1, -1): - s = int(indptr_h[i].item()) - e = int(indptr_h[i + 1].item()) - lvl = 0 - for p in range(s, e): - c = int(indices_h[p].item()) - if c > i: - lvl = max(lvl, levels[c] + 1) - levels[i] = lvl - - max_level = max(levels) - buckets = [[] for _ in range(max_level + 1)] - for r, lv in enumerate(levels): - buckets[lv].append(r) - - device = indptr.device - return [ - torch.tensor(rows, dtype=torch.int32, device=device) - for rows in buckets - if rows - ] + rhs = tl.load(b_ptr + row).to(tl.float32) + tmp_sum = tl.zeros((), dtype=tl.float32) + row_done = 0 + while row_done == 0: + if ptr >= end: + x_row = rhs * 0 + if USE_FP64_ACC: + tl.atomic_add(x_ptr + row, x_row.to(tl.float64)) + else: + tl.atomic_add(x_ptr + row, x_row.to(tl.float32)) + row_done = 1 + else: + col = tl.load(indices_ptr + ptr) + if col == row: + if USE_FP64_ACC: + diag = tl.load(data_ptr + ptr).to(tl.float64) + else: + diag = tl.load(data_ptr + ptr).to(tl.float32) + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + x_row = (rhs - tmp_sum) / diag_safe + x_row = tl.where(x_row == x_row, x_row, 0.0) + if USE_FP64_ACC: + tl.atomic_add(x_ptr + row, x_row.to(tl.float64)) + else: + tl.atomic_add(x_ptr + row, x_row.to(tl.float32)) + row_done = 1 + else: + dep_ready = _load_ready_flag_i32(ready_ptr, col) + while dep_ready != 1: + dep_ready = _load_ready_flag_i32(ready_ptr, col) + if USE_FP64_ACC: + a = tl.load(data_ptr + ptr).to(tl.float64) + y_dep = _load_scalar_fp64(x_ptr, col).to(tl.float64) + else: + a = tl.load(data_ptr + ptr).to(tl.float32) + y_dep = _load_scalar_fp32(x_ptr, col).to(tl.float32) + tmp_sum += a * y_dep + ptr += 1 + _publish_ready_flag_i32(ready_ptr, row) + + +@triton.jit +def _spsv_csr_cw_levelschd_kernel_complex( + data_ri_ptr, + indices_ptr, + indptr_ptr, + row_map_ptr, + b_ri_ptr, + x_ri_ptr, + ready_ptr, + n_rows, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + logical_row = tl.program_id(0) + if logical_row >= n_rows: + return + row = tl.load(row_map_ptr + logical_row) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + ptr = start + rhs_re = tl.load(b_ri_ptr + row * 2) + rhs_im = tl.load(b_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + tmp_sum_re = tl.zeros((), dtype=tl.float64) + tmp_sum_im = tl.zeros((), dtype=tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + tmp_sum_re = tl.zeros((), dtype=tl.float32) + tmp_sum_im = tl.zeros((), dtype=tl.float32) + row_done = 0 + while row_done == 0: + if ptr >= end: + zero = rhs_re * 0 + tl.atomic_add(x_ri_ptr + row * 2, zero) + tl.atomic_add(x_ri_ptr + row * 2 + 1, zero) + row_done = 1 + else: + col = tl.load(indices_ptr + ptr) + if col == row: + diag_re = tl.load(data_ri_ptr + ptr * 2) + diag_im = tl.load(data_ri_ptr + ptr * 2 + 1) + if USE_FP64_ACC: + diag_re = diag_re.to(tl.float64) + diag_im = diag_im.to(tl.float64) + else: + diag_re = diag_re.to(tl.float32) + diag_im = diag_im.to(tl.float32) + sum_re = rhs_re - tmp_sum_re + sum_im = rhs_im - tmp_sum_im + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + out_re = (sum_re * diag_re + sum_im * diag_im) / den_safe + out_im = (sum_im * diag_re - sum_re * diag_im) / den_safe + out_re = tl.where(out_re == out_re, out_re, 0.0) + out_im = tl.where(out_im == out_im, out_im, 0.0) + tl.atomic_add(x_ri_ptr + row * 2, out_re) + tl.atomic_add(x_ri_ptr + row * 2 + 1, out_im) + row_done = 1 + else: + dep_ready = _load_ready_flag_i32(ready_ptr, col) + while dep_ready != 1: + dep_ready = _load_ready_flag_i32(ready_ptr, col) + a_re = tl.load(data_ri_ptr + ptr * 2) + a_im = tl.load(data_ri_ptr + ptr * 2 + 1) + x_re = tl.atomic_add(x_ri_ptr + col * 2, 0.0) + x_im = tl.atomic_add(x_ri_ptr + col * 2 + 1, 0.0) + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + x_re = x_re.to(tl.float64) + x_im = x_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + x_re = x_re.to(tl.float32) + x_im = x_im.to(tl.float32) + tmp_sum_re += a_re * x_re - a_im * x_im + tmp_sum_im += a_re * x_im + a_im * x_re + ptr += 1 + _publish_ready_flag_i32(ready_ptr, row) + + +@triton.jit +def _spsv_csr_nnz_balance_kernel( + row_idx_ptr, + col_idx_ptr, + val_ptr, + b_ptr, + x_ptr, + tmp_sum_ptr, + ready_ptr, + indegree_ptr, + nnz, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + val_id = tl.program_id(0) + if val_id >= nnz: + return + row = tl.load(row_idx_ptr + val_id) + col = tl.load(col_idx_ptr + val_id) + if row < col: + return + if USE_FP64_ACC: + a = tl.load(val_ptr + val_id).to(tl.float64) + else: + a = tl.load(val_ptr + val_id).to(tl.float32) + done = 0 + while done == 0: + if row != col: + dep_ready = _load_ready_flag_i32(ready_ptr, col) + if dep_ready == 1: + if USE_FP64_ACC: + dep_x = _load_scalar_fp64(x_ptr, col).to(tl.float64) + else: + dep_x = _load_scalar_fp32(x_ptr, col).to(tl.float32) + tl.atomic_add(tmp_sum_ptr + row, dep_x * a) + tl.atomic_add(indegree_ptr + row, -1) + done = 1 + else: + diag_degree = tl.atomic_add(indegree_ptr + row, 0) + if diag_degree == 1: + if USE_FP64_ACC: + rhs = tl.load(b_ptr + row).to(tl.float64) + sum_val = tl.atomic_add(tmp_sum_ptr + row, 0.0).to(tl.float64) + else: + rhs = tl.load(b_ptr + row).to(tl.float32) + sum_val = tl.atomic_add(tmp_sum_ptr + row, 0.0).to(tl.float32) + diag_safe = tl.where(tl.abs(a) < DIAG_EPS, 1.0, a) + out = (rhs - sum_val) / diag_safe + out = tl.where(out == out, out, 0.0) + if USE_FP64_ACC: + tl.atomic_add(x_ptr + row, out.to(tl.float64)) + else: + tl.atomic_add(x_ptr + row, out.to(tl.float32)) + _publish_ready_flag_i32(ready_ptr, row) + done = 1 + + +@triton.jit +def _spsv_csr_nnz_balance_kernel_complex( + row_idx_ptr, + col_idx_ptr, + val_ri_ptr, + b_ri_ptr, + x_ri_ptr, + tmp_sum_ri_ptr, + ready_ptr, + indegree_ptr, + nnz, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + val_id = tl.program_id(0) + if val_id >= nnz: + return + row = tl.load(row_idx_ptr + val_id) + col = tl.load(col_idx_ptr + val_id) + if row < col: + return + val_re = tl.load(val_ri_ptr + val_id * 2) + val_im = tl.load(val_ri_ptr + val_id * 2 + 1) + if USE_FP64_ACC: + val_re = val_re.to(tl.float64) + val_im = val_im.to(tl.float64) + else: + val_re = val_re.to(tl.float32) + val_im = val_im.to(tl.float32) + done = 0 + while done == 0: + if row != col: + dep_ready = _load_ready_flag_i32(ready_ptr, col) + if dep_ready == 1: + dep_x_re = tl.atomic_add(x_ri_ptr + col * 2, 0.0) + dep_x_im = tl.atomic_add(x_ri_ptr + col * 2 + 1, 0.0) + if USE_FP64_ACC: + dep_x_re = dep_x_re.to(tl.float64) + dep_x_im = dep_x_im.to(tl.float64) + else: + dep_x_re = dep_x_re.to(tl.float32) + dep_x_im = dep_x_im.to(tl.float32) + prod_re = dep_x_re * val_re - dep_x_im * val_im + prod_im = dep_x_re * val_im + dep_x_im * val_re + tl.atomic_add(tmp_sum_ri_ptr + row * 2, prod_re) + tl.atomic_add(tmp_sum_ri_ptr + row * 2 + 1, prod_im) + tl.atomic_add(indegree_ptr + row, -1) + done = 1 + if row == col: + diag_degree = tl.atomic_add(indegree_ptr + row, 0) + if diag_degree == 1: + rhs_re = tl.load(b_ri_ptr + row * 2) + rhs_im = tl.load(b_ri_ptr + row * 2 + 1) + sum_re = tl.atomic_add(tmp_sum_ri_ptr + row * 2, 0.0) + sum_im = tl.atomic_add(tmp_sum_ri_ptr + row * 2 + 1, 0.0) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + sum_re = sum_re.to(tl.float64) + sum_im = sum_im.to(tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + sum_re = sum_re.to(tl.float32) + sum_im = sum_im.to(tl.float32) + num_re = rhs_re - sum_re + num_im = rhs_im - sum_im + den = val_re * val_re + val_im * val_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + out_re = (num_re * val_re + num_im * val_im) / den_safe + out_im = (num_im * val_re - num_re * val_im) / den_safe + out_re = tl.where(out_re == out_re, out_re, 0.0) + out_im = tl.where(out_im == out_im, out_im, 0.0) + tl.atomic_add(x_ri_ptr + row * 2, out_re) + tl.atomic_add(x_ri_ptr + row * 2 + 1, out_im) + _publish_ready_flag_i32(ready_ptr, row) + done = 1 def _auto_spsv_launch_config(indptr, block_nnz=None, max_segments=None): @@ -1522,371 +2468,394 @@ def _auto_spsv_launch_config(indptr, block_nnz=None, max_segments=None): return block_nnz_use, max_segments_use -def _triton_spsv_csr_vector( +def _triton_spsv_csr_cw_vector( data, indices, indptr, b_vec, n_rows, + *, lower=True, unit_diagonal=False, - block_nnz=None, - max_segments=None, diag_eps=1e-12, - levels=None, - block_nnz_use=None, - max_segments_use=None, + worker_count=None, + matrix_stats=None, + ready_in=None, + row_counter_in=None, ): x = torch.zeros_like(b_vec) + ready = ready_in if ready_in is not None else torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) + row_counter = ( + row_counter_in + if row_counter_in is not None + else torch.zeros(1, dtype=torch.int32, device=b_vec.device) + ) + ready.zero_() + row_counter.zero_() if n_rows == 0: return x - if levels is None: - levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) - if block_nnz_use is None or max_segments_use is None: - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - indptr, block_nnz=block_nnz, max_segments=max_segments - ) - - for rows_lv in levels: - n_lv = rows_lv.numel() - if n_lv == 0: - continue - grid = (n_lv,) - _spsv_csr_level_kernel[grid]( - data, - indices, - indptr, - b_vec, - x, - rows_lv, - n_level_rows=n_lv, - BLOCK_NNZ=block_nnz_use, - MAX_SEGMENTS=max_segments_use, - LOWER=lower, - UNIT_DIAG=unit_diagonal, - DIAG_EPS=diag_eps, - ) + if worker_count is None: + matrix_stats = matrix_stats or {} + worker_count = _resolve_cw_worker_count(n_rows, matrix_stats, 1) + use_fp64_acc = data.dtype == torch.float64 + grid = (worker_count,) + _spsv_csr_cw_kernel[grid]( + data, + indices, + indptr, + b_vec, + x, + ready, + row_counter, + n_rows, + LOWER=lower, + REVERSE_ORDER=not lower, + UNIT_DIAG=unit_diagonal, + USE_FP64_ACC=use_fp64_acc, + DIAG_EPS=diag_eps, + ) return x -def _triton_spsv_csr_vector_complex( +def _triton_spsv_csr_cw_vector_complex( data, indices, indptr, b_vec, n_rows, + *, lower=True, unit_diagonal=False, - block_nnz=None, - max_segments=None, diag_eps=1e-12, - levels=None, - block_nnz_use=None, - max_segments_use=None, + worker_count=None, + matrix_stats=None, data_ri_in=None, + ready_in=None, + row_counter_in=None, ): x = torch.zeros_like(b_vec) + ready = ready_in if ready_in is not None else torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) + row_counter = ( + row_counter_in + if row_counter_in is not None + else torch.zeros(1, dtype=torch.int32, device=b_vec.device) + ) + ready.zero_() + row_counter.zero_() if n_rows == 0: return x - if levels is None: - levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) - if block_nnz_use is None or max_segments_use is None: - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - indptr, block_nnz=block_nnz, max_segments=max_segments - ) - # Some PyTorch builds return CSR values with a non-strided layout wrapper. - # Materialize a plain 1D strided buffer before splitting into real/imag parts. data_ri = data_ri_in if data_ri_in is not None else _complex_interleaved_view(data) b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 - if component_dtype == torch.float16: - x_ri_work = torch.zeros((n_rows, 2), dtype=torch.float32, device=b_vec.device) - x_ri = x_ri_work.reshape(-1).contiguous() - else: - x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() - for rows_lv in levels: - n_lv = rows_lv.numel() - if n_lv == 0: - continue - grid = (n_lv,) - _spsv_csr_level_kernel_complex[grid]( - data_ri, - indices, - indptr, - b_ri, - x_ri, - rows_lv, - n_level_rows=n_lv, - BLOCK_NNZ=block_nnz_use, - MAX_SEGMENTS=max_segments_use, - LOWER=lower, - UNIT_DIAG=unit_diagonal, - USE_FP64_ACC=use_fp64, - DIAG_EPS=diag_eps, - ) - if component_dtype == torch.float16: - return torch.view_as_complex(x_ri_work.contiguous()) + if worker_count is None: + matrix_stats = matrix_stats or {} + worker_count = _resolve_cw_worker_count(n_rows, matrix_stats, 1) + grid = (worker_count,) + _spsv_csr_cw_kernel_complex[grid]( + data_ri, + indices, + indptr, + b_ri, + x_ri, + ready, + row_counter, + n_rows, + LOWER=lower, + REVERSE_ORDER=not lower, + UNIT_DIAG=unit_diagonal, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + return x + + +def _triton_spsv_csr_u_lo_cw_vector(*args, **kwargs): + return _triton_spsv_csr_cw_vector(*args, lower=True, unit_diagonal=True, **kwargs) + + +def _triton_spsv_csr_n_lo_cw_vector(*args, **kwargs): + return _triton_spsv_csr_cw_vector(*args, lower=True, unit_diagonal=False, **kwargs) + + +def _triton_spsv_csr_u_up_cw_vector(*args, **kwargs): + return _triton_spsv_csr_cw_vector(*args, lower=False, unit_diagonal=True, **kwargs) + + +def _triton_spsv_csr_n_up_cw_vector(*args, **kwargs): + return _triton_spsv_csr_cw_vector(*args, lower=False, unit_diagonal=False, **kwargs) + + +def _triton_spsv_csr_u_lo_cw_vector_complex(*args, **kwargs): + return _triton_spsv_csr_cw_vector_complex(*args, lower=True, unit_diagonal=True, **kwargs) + + +def _triton_spsv_csr_n_lo_cw_vector_complex(*args, **kwargs): + return _triton_spsv_csr_cw_vector_complex(*args, lower=True, unit_diagonal=False, **kwargs) + + +def _triton_spsv_csr_u_up_cw_vector_complex(*args, **kwargs): + return _triton_spsv_csr_cw_vector_complex(*args, lower=False, unit_diagonal=True, **kwargs) + + +def _triton_spsv_csr_n_up_cw_vector_complex(*args, **kwargs): + return _triton_spsv_csr_cw_vector_complex(*args, lower=False, unit_diagonal=False, **kwargs) + + +def _triton_spsv_csr_n_lo_roc_vector( + data, + indices, + indptr, + row_map, + b_vec, + n_rows, + *, + diag_eps=1e-12, + ready_in=None, +): + x = torch.zeros_like(b_vec) + ready = ready_in if ready_in is not None else torch.zeros( + n_rows, dtype=torch.int32, device=b_vec.device + ) + ready.zero_() + if n_rows == 0: + return x + use_fp64_acc = data.dtype == torch.float64 + _spsv_csr_roc_kernel[(n_rows,)]( + data, + indices, + indptr, + row_map, + b_vec, + x, + ready, + n_rows, + USE_FP64_ACC=use_fp64_acc, + DIAG_EPS=diag_eps, + WARP_SIZE=32, + num_warps=1, + ) + return x + + +def _triton_spsv_csr_n_lo_roc_vector_complex( + data, + indices, + indptr, + row_map, + b_vec, + n_rows, + *, + diag_eps=1e-12, + data_ri_in=None, + ready_in=None, +): + x = torch.zeros_like(b_vec) + ready = ready_in if ready_in is not None else torch.zeros( + n_rows, dtype=torch.int32, device=b_vec.device + ) + ready.zero_() + if n_rows == 0: + return x + data_ri = data_ri_in if data_ri_in is not None else _complex_interleaved_view(data) + b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + component_dtype = _component_dtype_for_complex(data.dtype) + use_fp64 = component_dtype == torch.float64 + _spsv_csr_roc_kernel_complex[(n_rows,)]( + data_ri, + indices, + indptr, + row_map, + b_ri, + x_ri, + ready, + n_rows, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + WARP_SIZE=32, + num_warps=1, + ) return x -def _triton_spsv_csr_cw_vector( +def _triton_spsv_csr_n_lo_cw_levelschd_vector( data, indices, indptr, - diag, + row_map, b_vec, n_rows, - lower=True, - block_nnz=None, - max_segments=None, + *, diag_eps=1e-12, - block_nnz_use=None, - max_segments_use=None, - worker_count=None, - matrix_stats=None, + ready_in=None, ): - if b_vec.ndim == 1: - b_mat = b_vec.unsqueeze(1).contiguous() - else: - b_mat = b_vec.contiguous() - x = torch.zeros_like(b_mat) - ready = torch.zeros(n_rows, dtype=torch.int32, device=b_mat.device) - row_counter = torch.zeros(1, dtype=torch.int32, device=b_mat.device) - n_rhs = int(b_mat.shape[1]) - if n_rows == 0 or n_rhs == 0: - return x.squeeze(1) if b_vec.ndim == 1 else x - if block_nnz_use is None or max_segments_use is None: - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - indptr, block_nnz=block_nnz, max_segments=max_segments - ) - matrix_stats = matrix_stats or {} - block_rhs = _choose_spsv_block_rhs(n_rhs, matrix_stats, complex_mode=False) - if worker_count is None: - worker_count = _resolve_cw_worker_count(n_rows, matrix_stats, n_rhs) - grid = (worker_count,) - _spsv_csr_cw_kernel[grid]( + x = torch.zeros_like(b_vec) + ready = ready_in if ready_in is not None else torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) + ready.zero_() + if n_rows == 0: + return x + use_fp64_acc = data.dtype == torch.float64 + grid = (n_rows,) + _spsv_csr_cw_levelschd_kernel[grid]( data, indices, indptr, - diag, - b_mat, + row_map, + b_vec, x, ready, - row_counter, n_rows, - n_rhs, - b_mat.stride(0), - x.stride(0), - BLOCK_RHS=block_rhs, - BLOCK_NNZ=block_nnz_use, - MAX_SEGMENTS=max_segments_use, - LOWER=lower, + USE_FP64_ACC=use_fp64_acc, DIAG_EPS=diag_eps, + num_warps=1, ) - return x.squeeze(1) if b_vec.ndim == 1 else x + return x -def _triton_spsv_csr_cw_vector_complex( +def _triton_spsv_csr_n_lo_cw_levelschd_vector_complex( data, indices, indptr, - diag, + row_map, b_vec, n_rows, - lower=True, - unit_diagonal=False, - block_nnz=None, - max_segments=None, + *, diag_eps=1e-12, - block_nnz_use=None, - max_segments_use=None, - worker_count=None, - matrix_stats=None, data_ri_in=None, - diag_ri_in=None, + ready_in=None, ): - if b_vec.ndim != 1: - shared_b = b_vec if b_vec.is_contiguous() else b_vec.contiguous() - cols = [] - for bj in torch.unbind(shared_b, dim=1): - cols.append( - _triton_spsv_csr_cw_vector_complex( - data, - indices, - indptr, - diag, - bj, - n_rows, - lower=lower, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - worker_count=worker_count, - matrix_stats=matrix_stats, - data_ri_in=data_ri_in, - diag_ri_in=diag_ri_in, - ) - ) - return torch.stack(cols, dim=1) - x = torch.zeros_like(b_vec) - ready = torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) - row_counter = torch.zeros(1, dtype=torch.int32, device=b_vec.device) + ready = ( + ready_in + if ready_in is not None + else torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) + ) + ready.zero_() if n_rows == 0: return x - if block_nnz_use is None or max_segments_use is None: - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - indptr, block_nnz=block_nnz, max_segments=max_segments - ) - data_ri = data_ri_in if data_ri_in is not None else _complex_interleaved_view(data) - diag_ri = diag_ri_in if diag_ri_in is not None else _complex_interleaved_view(diag) b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 - x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() - - if worker_count is None: - matrix_stats = matrix_stats or {} - worker_count = _resolve_cw_worker_count(n_rows, matrix_stats, 1) - grid = (worker_count,) - _spsv_csr_cw_kernel_complex[grid]( + grid = (n_rows,) + _spsv_csr_cw_levelschd_kernel_complex[grid]( data_ri, indices, indptr, - diag_ri, + row_map, b_ri, x_ri, ready, - row_counter, n_rows, - BLOCK_NNZ=block_nnz_use, - MAX_SEGMENTS=max_segments_use, - LOWER=lower, - UNIT_DIAG=unit_diagonal, USE_FP64_ACC=use_fp64, DIAG_EPS=diag_eps, + num_warps=1, ) return x -def _triton_spsv_csr_transpose_push_vector( +def _triton_spsv_csr_n_lo_nnz_balance_vector( data, indices, - indptr, + row_idx, + indegree_init, b_vec, n_rows, - lower=True, - unit_diagonal=False, - block_nnz=None, - max_segments=None, + *, diag_eps=1e-12, - launch_groups=None, - block_nnz_use=None, - max_segments_use=None, + tmp_sum_in=None, + ready_in=None, + indegree_in=None, ): x = torch.zeros_like(b_vec) if n_rows == 0: return x - residual = b_vec.clone() - if launch_groups is None: - levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) - launch_groups = list(reversed(levels)) - if block_nnz_use is None or max_segments_use is None: - block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( - indptr, block_nnz=block_nnz, max_segments=max_segments - ) - - for rows_lv in launch_groups: - n_lv = rows_lv.numel() - if n_lv == 0: - continue - grid = (n_lv,) - _spsv_csr_transpose_push_kernel[grid]( - data, - indices, - indptr, - residual, - x, - rows_lv, - n_level_rows=n_lv, - BLOCK_NNZ=block_nnz_use, - MAX_SEGMENTS=max_segments_use, - LOWER=lower, - UNIT_DIAG=unit_diagonal, - DIAG_EPS=diag_eps, - ) + tmp_sum = tmp_sum_in if tmp_sum_in is not None else torch.zeros_like(b_vec) + ready = ready_in if ready_in is not None else torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) + indegree = ( + indegree_in + if indegree_in is not None + else torch.empty(n_rows, dtype=torch.int32, device=b_vec.device) + ) + tmp_sum.zero_() + ready.zero_() + indegree.copy_(indegree_init) + use_fp64_acc = data.dtype == torch.float64 + grid = (int(data.numel()),) + _spsv_csr_nnz_balance_kernel[grid]( + row_idx, + indices, + data, + b_vec, + x, + tmp_sum, + ready, + indegree, + int(data.numel()), + USE_FP64_ACC=use_fp64_acc, + DIAG_EPS=diag_eps, + num_warps=1, + ) return x -def _triton_spsv_csr_transpose_push_vector_complex( +def _triton_spsv_csr_n_lo_nnz_balance_vector_complex( data, indices, - indptr, + row_idx, + indegree_init, b_vec, n_rows, - lower=True, - unit_diagonal=False, - conjugate=False, - block_nnz=None, - max_segments=None, + *, diag_eps=1e-12, - launch_groups=None, - block_nnz_use=None, - max_segments_use=None, data_ri_in=None, + tmp_sum_in=None, + ready_in=None, + indegree_in=None, ): x = torch.zeros_like(b_vec) if n_rows == 0: return x - if launch_groups is None: - levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) - launch_groups = list(reversed(levels)) - if block_nnz_use is None or max_segments_use is None: - block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( - indptr, block_nnz=block_nnz, max_segments=max_segments - ) - - residual_work = b_vec.contiguous().clone() + tmp_sum = tmp_sum_in if tmp_sum_in is not None else torch.zeros_like(b_vec) + ready = ( + ready_in + if ready_in is not None + else torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) + ) + indegree = ( + indegree_in + if indegree_in is not None + else torch.empty(n_rows, dtype=torch.int32, device=b_vec.device) + ) + tmp_sum.zero_() + ready.zero_() + indegree.copy_(indegree_init) data_ri = data_ri_in if data_ri_in is not None else _complex_interleaved_view(data) - residual_ri = torch.view_as_real(residual_work).reshape(-1).contiguous() + b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + tmp_sum_ri = torch.view_as_real(tmp_sum.contiguous()).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 - if component_dtype == torch.float16: - x_ri_work = torch.zeros((n_rows, 2), dtype=torch.float32, device=b_vec.device) - x_ri = x_ri_work.reshape(-1).contiguous() - else: - x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() - - for rows_lv in launch_groups: - n_lv = rows_lv.numel() - if n_lv == 0: - continue - grid = (n_lv,) - _spsv_csr_transpose_push_kernel_complex[grid]( - data_ri, - indices, - indptr, - residual_ri, - x_ri, - rows_lv, - n_level_rows=n_lv, - BLOCK_NNZ=block_nnz_use, - MAX_SEGMENTS=max_segments_use, - LOWER=lower, - UNIT_DIAG=unit_diagonal, - CONJ_TRANS=conjugate, - USE_FP64_ACC=use_fp64, - DIAG_EPS=diag_eps, - ) - if component_dtype == torch.float16: - return torch.view_as_complex(x_ri_work.contiguous()) + grid = (int(data.numel()),) + _spsv_csr_nnz_balance_kernel_complex[grid]( + row_idx, + indices, + data_ri, + b_ri, + x_ri, + tmp_sum_ri, + ready, + indegree, + int(data.numel()), + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + num_warps=1, + ) return x @@ -1894,8 +2863,6 @@ def _triton_spsv_csr_transpose_cw_vector( data, indices, indptr, - diag, - indegree_init, b_vec, n_rows, lower=True, @@ -1907,17 +2874,42 @@ def _triton_spsv_csr_transpose_cw_vector( max_segments_use=None, worker_count=None, matrix_stats=None, + residual_in=None, + indegree_in=None, + row_counter_in=None, + preprocessed=False, ): x = torch.zeros_like(b_vec) if n_rows == 0: return x - residual = b_vec.clone() - indegree = indegree_init.clone() - row_counter = torch.zeros(1, dtype=torch.int32, device=b_vec.device) + residual = residual_in if residual_in is not None else b_vec.clone() + indegree = ( + indegree_in + if indegree_in is not None + else torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) + ) + row_counter = ( + row_counter_in + if row_counter_in is not None + else torch.zeros(1, dtype=torch.int32, device=b_vec.device) + ) + residual.copy_(b_vec) + row_counter.zero_() if block_nnz_use is None or max_segments_use is None: block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( indptr, block_nnz=block_nnz, max_segments=max_segments ) + if not preprocessed: + _run_spsv_csc_preprocess( + indices, + indptr, + indegree, + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) if worker_count is None: matrix_stats = matrix_stats or {} worker_count = _resolve_cw_worker_count(n_rows, matrix_stats, 1) @@ -1926,7 +2918,6 @@ def _triton_spsv_csr_transpose_cw_vector( data, indices, indptr, - diag, indegree, residual, x, @@ -1935,6 +2926,7 @@ def _triton_spsv_csr_transpose_cw_vector( BLOCK_NNZ=block_nnz_use, MAX_SEGMENTS=max_segments_use, LOWER=lower, + REVERSE_ORDER=not lower, UNIT_DIAG=unit_diagonal, DIAG_EPS=diag_eps, ) @@ -1945,8 +2937,6 @@ def _triton_spsv_csr_transpose_cw_vector_complex( data, indices, indptr, - diag, - indegree_init, b_vec, n_rows, lower=True, @@ -1960,7 +2950,10 @@ def _triton_spsv_csr_transpose_cw_vector_complex( worker_count=None, matrix_stats=None, data_ri_in=None, - diag_ri_in=None, + residual_in=None, + indegree_in=None, + row_counter_in=None, + preprocessed=False, ): x = torch.zeros_like(b_vec) if n_rows == 0: @@ -1970,11 +2963,31 @@ def _triton_spsv_csr_transpose_cw_vector_complex( indptr, block_nnz=block_nnz, max_segments=max_segments ) - residual_work = b_vec.contiguous().clone() - indegree = indegree_init.clone() - row_counter = torch.zeros(1, dtype=torch.int32, device=b_vec.device) + residual_work = residual_in if residual_in is not None else b_vec.contiguous().clone() + indegree = ( + indegree_in + if indegree_in is not None + else torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) + ) + row_counter = ( + row_counter_in + if row_counter_in is not None + else torch.zeros(1, dtype=torch.int32, device=b_vec.device) + ) + residual_work.copy_(b_vec.contiguous()) + row_counter.zero_() + if not preprocessed: + _run_spsv_csc_preprocess( + indices, + indptr, + indegree, + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) data_ri = data_ri_in if data_ri_in is not None else _complex_interleaved_view(data) - diag_ri = diag_ri_in if diag_ri_in is not None else _complex_interleaved_view(diag) residual_ri = torch.view_as_real(residual_work).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 @@ -1992,7 +3005,6 @@ def _triton_spsv_csr_transpose_cw_vector_complex( data_ri, indices, indptr, - diag_ri, indegree, residual_ri, x_ri, @@ -2001,6 +3013,7 @@ def _triton_spsv_csr_transpose_cw_vector_complex( BLOCK_NNZ=block_nnz_use, MAX_SEGMENTS=max_segments_use, LOWER=lower, + REVERSE_ORDER=not lower, UNIT_DIAG=unit_diagonal, CONJ_TRANS=conjugate, USE_FP64_ACC=use_fp64, @@ -2027,6 +3040,34 @@ def _choose_transpose_family_launch_config(indptr, block_nnz=None, max_segments= return cand, req +def _run_spsv_csc_preprocess( + indices, + indptr, + indegree, + n_rows, + *, + lower, + unit_diagonal, + block_nnz_use, + max_segments_use, +): + indegree.zero_() + if n_rows == 0: + return indegree + grid = (n_rows,) + _spsv_csc_preprocess_kernel[grid]( + indices, + indptr, + indegree, + n_rows, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + ) + return indegree + + def _prepare_spsv_coo_inputs(data, row, col, b, shape): if not all(torch.is_tensor(t) for t in (data, row, col, b)): raise TypeError("data, row, col, b must all be torch.Tensor") @@ -2036,14 +3077,12 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape): raise ValueError("data, row, col must be 1D") if row.numel() != data.numel() or col.numel() != data.numel(): raise ValueError("data, row, col must have the same length") - if b.ndim not in (1, 2): - raise ValueError("b must be 1D or 2D (vector or multiple RHS)") + if b.ndim != 1: + raise ValueError("b must be a 1D dense vector (DnVec)") n_rows, n_cols = int(shape[0]), int(shape[1]) - if b.ndim == 1 and b.numel() != n_rows: + if b.numel() != n_rows: raise ValueError(f"b length must equal n_rows={n_rows}") - if b.ndim == 2 and b.shape[0] != n_rows: - raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") if data.dtype not in ( torch.float32, @@ -2094,26 +3133,6 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape): ) -def _csr_transpose(data, indices64, indptr64, n_rows, n_cols, conjugate=False): - if data.numel() == 0: - out_data = data - out_indices = torch.empty(0, dtype=torch.int64, device=data.device) - out_indptr = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) - return out_data, out_indices, out_indptr - - row_ids = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), - indptr64[1:] - indptr64[:-1], - ) - new_row = indices64 - new_col = row_ids - data_eff = data.conj() if conjugate and torch.is_complex(data) else data - data_t, indices_t, indptr_t = _coo_to_csr_sorted_unique( - data_eff, new_row, new_col, n_cols, n_rows - ) - return data_t, indices_t, indptr_t - - def _build_coo_row_ptr(row_sorted, n_rows): row_ptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=row_sorted.device) if row_sorted.numel() > 0: @@ -2122,499 +3141,500 @@ def _build_coo_row_ptr(row_sorted, n_rows): return row_ptr -def _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols): +def _coo_order_for_spsv(data, row64, col64): + if data.numel() == 0: + return data, row64, col64 + key = row64 + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + return data[order], row64[order], col64[order] + + +def _coo2csr_for_spsv(data, row64, col64, n_rows, assume_ordered=False): nnz = data.numel() if nnz == 0: indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) indices = torch.empty(0, dtype=torch.int64, device=data.device) return data, indices, indptr - key = row64 * max(1, n_cols) + col64 - try: - order = torch.argsort(key, stable=True) - except TypeError: - order = torch.argsort(key) - key_s = key[order] - data_s = data[order] + if not assume_ordered: + data, row64, col64 = _coo_order_for_spsv(data, row64, col64) - unique_key, inverse = torch.unique_consecutive(key_s, return_inverse=True) - out_nnz = unique_key.numel() + indptr = _build_coo_row_ptr(row64, n_rows) + indices = col64.to(torch.int64).contiguous() + return data.contiguous(), indices, indptr - data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) - data_u.scatter_add_(0, inverse, data_s) - row_u = torch.div(unique_key, max(1, n_cols), rounding_mode="floor") - col_u = unique_key - row_u * max(1, n_cols) - indptr = _build_coo_row_ptr(row_u, n_rows) - indices = col_u.to(torch.int64) - return data_u, indices, indptr +def _analyze_spsv_csr_descriptor( + data, + indices, + indptr, + shape, + *, + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind=None, + compute_dtype=None, + handle=None, + workspace=None, + storage_view="csr_as_csc", + format_name="csr", + clear_cache=False, +): + if clear_cache: + _clear_spsv_csr_preprocess_cache() + n_rows = int(shape[0]) + dummy_b = torch.empty(n_rows, dtype=data.dtype, device=data.device) + ( + matrix_data, + _dummy_b, + _original_output_dtype, + trans_mode, + n_rows, + n_cols, + solve_plan, + ) = _resolve_spsv_csr_runtime( + data, + indices, + indptr, + dummy_b, + shape, + lower, + transpose, + unit_diagonal, + requested_solve_kind=solve_kind, + storage_view=storage_view, + ) + input_index_dtype = indices.dtype + solve_plan = _select_spsv_runtime_plan( + solve_plan, trans_mode, requested_solve_kind=solve_kind + ) + compute_dtype = _spsv_effective_compute_dtype( + matrix_data.dtype, trans_mode, compute_dtype=compute_dtype + ) + layout = _build_spsv_workspace_layout( + n_rows, solve_plan["solve_kind"], value_dtype=compute_dtype + ) + if workspace is not None: + _resolve_spsv_workspace(workspace, layout, matrix_data.device) + return FlagSparseSpSVDescr( + format=_normalize_spsv_format(format_name), + canonical_format="csr", + shape=(int(shape[0]), int(shape[1])), + lower=bool(lower), + unit_diagonal=bool(unit_diagonal), + fill_mode="lower" if lower else "upper", + diag_type="unit" if unit_diagonal else "non_unit", + matrix_type="triangular", + index_base=0, + transpose_mode=trans_mode, + value_dtype=matrix_data.dtype, + compute_dtype=compute_dtype, + index_dtype=input_index_dtype, + solve_kind=solve_plan["solve_kind"], + route_name=str(solve_plan.get("route_name", solve_plan["solve_kind"])), + storage_view=str(solve_plan.get("storage_view", "csr")), + buffer_size=_workspace_size_bytes(layout), + workspace_layout=layout, + data=matrix_data, + indices=indices.contiguous(), + indptr=indptr.contiguous(), + solve_plan=_clone_spsv_plan(solve_plan), + ) -def flagsparse_spsv_csr( +def flagsparse_spsv_analysis_csr( data, indices, indptr, + shape, + *, + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind=None, + compute_dtype=None, + handle=None, + workspace=None, + storage_view="csr_as_csc", + clear_cache=False, +): + """Analyze a CSR SpSV problem and return a reusable Triton descriptor.""" + + return _analyze_spsv_csr_descriptor( + data, + indices, + indptr, + shape, + lower=lower, + unit_diagonal=unit_diagonal, + transpose=transpose, + solve_kind=solve_kind, + compute_dtype=compute_dtype, + handle=handle, + workspace=workspace, + storage_view=storage_view, + format_name="csr", + clear_cache=clear_cache, + ) + + +def flagsparse_spsv_analysis_coo( + data, + row, + col, + shape, + *, + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind=None, + compute_dtype=None, + handle=None, + workspace=None, + storage_view="csr_as_csc", +): + """Analyze a COO SpSV problem by canonicalizing COO into CSR first.""" + + dummy_b = torch.empty(int(shape[0]), dtype=data.dtype, device=data.device) + data, _input_index_dtype, row64, col64, _b, n_rows, n_cols = _prepare_spsv_coo_inputs( + data, row, col, dummy_b, shape + ) + trans_mode = _normalize_spsv_transpose_mode(transpose) + if trans_mode == "N": + _validate_spsv_non_trans_combo(data.dtype, row.dtype, "COO") + else: + _validate_spsv_trans_combo(data.dtype, row.dtype, "COO") + data_csr, indices_csr, indptr_csr = _coo2csr_for_spsv( + data, row64, col64, n_rows, assume_ordered=False + ) + return _analyze_spsv_csr_descriptor( + data_csr, + indices_csr, + indptr_csr, + shape, + lower=lower, + unit_diagonal=unit_diagonal, + transpose=transpose, + solve_kind=solve_kind, + compute_dtype=compute_dtype, + handle=handle, + workspace=workspace, + storage_view=storage_view, + format_name="coo", + clear_cache=False, + ) + + +def _execute_spsv_csr_plan( + data, b, - shape, - lower=True, + solve_plan, + trans_mode, + n_rows, + *, + alpha=1, unit_diagonal=False, - transpose=False, block_nnz=None, max_segments=None, out=None, return_time=False, + workspace=None, + original_output_dtype=None, + compute_dtype=None, + handle=None, + stream=None, ): - """Sparse triangular solve using Triton CSR kernels. - - Primary support matrix: - - NON_TRANS: float32/float64/complex64/complex128 with int32/int64 indices - - TRANS/CONJ: float32/float64/complex64/complex128 with int32/int64 indices - """ - ( - data, - b, - original_output_dtype, - trans_mode, - n_rows, - n_cols, - solve_plan, - ) = _resolve_spsv_csr_runtime( - data, - indices, - indptr, - b, - shape, - lower, - transpose, - unit_diagonal, - ) - - rhs_cols = 1 if b.ndim == 1 else int(b.shape[1]) - solve_plan = _select_spsv_runtime_plan( - solve_plan, rhs_cols, data.dtype, trans_mode - ) + solve_plan = _clone_spsv_plan(solve_plan) solve_kind = solve_plan["solve_kind"] kernel_data = solve_plan["kernel_data"] kernel_indices32 = solve_plan["kernel_indices32"] kernel_indptr64 = solve_plan["kernel_indptr64"] - lower_eff = solve_plan["lower_eff"] - launch_groups = solve_plan["launch_groups"] - transpose_conjugate = solve_plan["transpose_conjugate"] default_block_nnz = solve_plan["default_block_nnz"] default_max_segments = solve_plan["default_max_segments"] - transpose_diag = solve_plan.get("transpose_diag") - transpose_indegree_init = solve_plan.get("transpose_indegree_init") - cw_diag = solve_plan.get("cw_diag") cw_worker_count = solve_plan.get("cw_worker_count") + nontrans_variant = solve_plan.get("nontrans_variant", "csr_n_lo_cw") + lower_eff = solve_plan["lower_eff"] matrix_stats = solve_plan.get("matrix_stats", {}) + level_row_map32 = solve_plan.get("level_row_map32") + nnz_balance_row_idx32 = solve_plan.get("nnz_balance_row_idx32") + nnz_balance_indegree32 = solve_plan.get("nnz_balance_indegree32") kernel_indices = kernel_indices32 kernel_indptr = kernel_indptr64 - compute_dtype = data.dtype + compute_dtype = _spsv_effective_compute_dtype( + data.dtype, trans_mode, compute_dtype=compute_dtype + ) data_in = kernel_data b_in = b - if ( - data.dtype == torch.complex64 - and trans_mode in ("T", "C") - and SPSV_PROMOTE_TRANSPOSE_COMPLEX64_TO_COMPLEX128 - ): - compute_dtype = torch.complex128 - data_in = kernel_data.to(torch.complex128) - b_in = b.to(torch.complex128) - elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: - compute_dtype = torch.float64 - data_in = kernel_data.to(torch.float64) - b_in = b.to(torch.float64) - elif ( - data.dtype == torch.float32 - and trans_mode in ("T", "C") - and SPSV_PROMOTE_TRANSPOSE_FP32_TO_FP64 - ): - compute_dtype = torch.float64 - data_in = kernel_data.to(torch.float64) - b_in = b.to(torch.float64) - - if solve_kind in ("transpose_push", "transpose_cw"): + if compute_dtype != data.dtype: + data_in = kernel_data.to(compute_dtype) + b_in = b.to(compute_dtype) + alpha_in = _coerce_spsv_alpha(alpha, compute_dtype, b.device) + b_in = b_in * alpha_in + solve_stream = _resolve_spsv_stream(handle, stream, b.device) + + if solve_kind == "transpose_cw": if block_nnz is None and max_segments is None: block_nnz_use, max_segments_use = default_block_nnz, default_max_segments else: block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( kernel_indptr, block_nnz=block_nnz, max_segments=max_segments ) - if solve_kind == "transpose_cw": - vec_real = _triton_spsv_csr_transpose_cw_vector - vec_complex = _triton_spsv_csr_transpose_cw_vector_complex - else: - vec_real = _triton_spsv_csr_transpose_push_vector - vec_complex = _triton_spsv_csr_transpose_push_vector_complex + vec_real = _triton_spsv_csr_transpose_cw_vector + vec_complex = _triton_spsv_csr_transpose_cw_vector_complex elif solve_kind == "csr_cw": - if block_nnz is None and max_segments is None: - block_nnz_use, max_segments_use = default_block_nnz, default_max_segments - else: - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - kernel_indptr, block_nnz=block_nnz, max_segments=max_segments - ) - vec_real = _triton_spsv_csr_cw_vector - vec_complex = _triton_spsv_csr_cw_vector_complex + block_nnz_use, max_segments_use = default_block_nnz, default_max_segments + nontrans_real_wrappers = { + "csr_u_lo_cw": _triton_spsv_csr_u_lo_cw_vector, + "csr_n_lo_cw": _triton_spsv_csr_n_lo_cw_vector, + "csr_u_up_cw": _triton_spsv_csr_u_up_cw_vector, + "csr_n_up_cw": _triton_spsv_csr_n_up_cw_vector, + } + nontrans_complex_wrappers = { + "csr_u_lo_cw": _triton_spsv_csr_u_lo_cw_vector_complex, + "csr_n_lo_cw": _triton_spsv_csr_n_lo_cw_vector_complex, + "csr_u_up_cw": _triton_spsv_csr_u_up_cw_vector_complex, + "csr_n_up_cw": _triton_spsv_csr_n_up_cw_vector_complex, + } + vec_real = nontrans_real_wrappers[nontrans_variant] + vec_complex = nontrans_complex_wrappers[nontrans_variant] + elif solve_kind == "csr_roc": + block_nnz_use, max_segments_use = default_block_nnz, default_max_segments + vec_real = _triton_spsv_csr_n_lo_roc_vector + vec_complex = _triton_spsv_csr_n_lo_roc_vector_complex + elif solve_kind == "csr_cw_levelschd": + block_nnz_use, max_segments_use = default_block_nnz, default_max_segments + vec_real = _triton_spsv_csr_n_lo_cw_levelschd_vector + vec_complex = _triton_spsv_csr_n_lo_cw_levelschd_vector_complex + elif solve_kind == "csr_nnz_balance": + block_nnz_use, max_segments_use = default_block_nnz, default_max_segments + vec_real = _triton_spsv_csr_n_lo_nnz_balance_vector + vec_complex = _triton_spsv_csr_n_lo_nnz_balance_vector_complex else: - if block_nnz is None and max_segments is None: - block_nnz_use, max_segments_use = default_block_nnz, default_max_segments - else: - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - kernel_indptr, block_nnz=block_nnz, max_segments=max_segments - ) - vec_real = _triton_spsv_csr_vector - vec_complex = _triton_spsv_csr_vector_complex + raise RuntimeError(f"unexpected SpSV solve kind: {solve_kind}") diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) if return_time: torch.cuda.synchronize() t0 = time.perf_counter() - transpose_diag_in = transpose_diag - if transpose_diag is not None and compute_dtype != data.dtype: - transpose_diag_in = transpose_diag.to(compute_dtype) - cw_diag_in = cw_diag - if cw_diag is not None and compute_dtype != data.dtype: - cw_diag_in = cw_diag.to(compute_dtype) worker_count_use = cw_worker_count - rhs_cols = 1 if b_in.ndim == 1 else int(b_in.shape[1]) matrix_stats_use = dict(matrix_stats) if solve_kind in ("csr_cw", "transpose_cw"): worker_count_use = _resolve_cw_worker_count( n_rows, matrix_stats_use, - rhs_cols, + 1, cached_worker_count=cw_worker_count, ) complex_kernel_data_ri = None - complex_transpose_diag_ri = None - complex_cw_diag_ri = None if torch.is_complex(data_in): if compute_dtype == solve_plan["kernel_data"].dtype: complex_kernel_data_ri = solve_plan.get("kernel_data_ri") - complex_transpose_diag_ri = solve_plan.get("transpose_diag_ri") - complex_cw_diag_ri = solve_plan.get("cw_diag_ri") if complex_kernel_data_ri is None: complex_kernel_data_ri = _complex_interleaved_view(data_in) - if transpose_diag_in is not None and complex_transpose_diag_ri is None: - complex_transpose_diag_ri = _complex_interleaved_view(transpose_diag_in) - if cw_diag_in is not None and complex_cw_diag_ri is None: - complex_cw_diag_ri = _complex_interleaved_view(cw_diag_in) - if b_in.ndim == 1: + workspace_buffers = _resolve_spsv_workspace( + workspace, + _build_spsv_workspace_layout(n_rows, solve_kind, value_dtype=compute_dtype), + b.device, + ) + ready_buf = workspace_buffers.get("ready") + tmp_sum_buf = workspace_buffers.get("tmp_sum") + residual_buf = workspace_buffers.get("residual") + indegree_buf = workspace_buffers.get("indegree") + row_counter_buf = workspace_buffers.get("row_counter") + transpose_preprocessed = False + if solve_kind == "csr_nnz_balance": + if tmp_sum_buf is None or ready_buf is None or indegree_buf is None: + raise RuntimeError("csr_nnz_balance workspace is missing required buffers") + tmp_sum_buf.zero_() + ready_buf.zero_() + indegree_buf.copy_(nnz_balance_indegree32) + if solve_kind == "transpose_cw": + transpose_sig = _transpose_cw_preprocess_signature( + solve_plan, + n_rows, + unit_diagonal, + block_nnz_use, + max_segments_use, + ) + if isinstance(workspace, FlagSparseSpSVWorkspace): + transpose_preprocessed = ( + workspace.prepared_solve_kind == "transpose_cw" + and workspace.prepared_signature == transpose_sig + ) + if not transpose_preprocessed: + _run_spsv_csc_preprocess( + kernel_indices, + kernel_indptr, + indegree_buf, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + transpose_preprocessed = True + if isinstance(workspace, FlagSparseSpSVWorkspace): + workspace.prepared_solve_kind = "transpose_cw" + workspace.prepared_signature = transpose_sig + stream_ctx = ( + torch.cuda.stream(solve_stream) + if solve_stream is not None + else nullcontext() + ) + with stream_ctx: if torch.is_complex(data_in): - if solve_kind == "transpose_push": + if vec_complex is None: + raise ValueError(f"solve_kind={solve_kind!r} currently supports real dtypes only") + if solve_kind == "transpose_cw": x = vec_complex( + data_in, + kernel_indices, + kernel_indptr, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + conjugate=(trans_mode == "C"), + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + worker_count=worker_count_use, + matrix_stats=matrix_stats_use, + data_ri_in=complex_kernel_data_ri, + residual_in=residual_buf, + indegree_in=indegree_buf, + row_counter_in=row_counter_buf, + preprocessed=transpose_preprocessed, + ) + else: + if solve_kind == "csr_roc": + x = vec_complex( data_in, kernel_indices, kernel_indptr, + level_row_map32, b_in, n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - conjugate=transpose_conjugate, - block_nnz=block_nnz, - max_segments=max_segments, diag_eps=diag_eps, - launch_groups=launch_groups, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, data_ri_in=complex_kernel_data_ri, - ) - elif solve_kind == "transpose_cw": - x = vec_complex( + ready_in=ready_buf, + ) + elif solve_kind == "csr_cw_levelschd": + x = vec_complex( data_in, kernel_indices, kernel_indptr, - transpose_diag_in, - transpose_indegree_init, + level_row_map32, b_in, n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - conjugate=transpose_conjugate, - block_nnz=block_nnz, - max_segments=max_segments, diag_eps=diag_eps, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - worker_count=worker_count_use, - matrix_stats=matrix_stats_use, data_ri_in=complex_kernel_data_ri, - diag_ri_in=complex_transpose_diag_ri, - ) - elif solve_kind == "csr_cw": - x = vec_complex( + ready_in=ready_buf, + ) + elif solve_kind == "csr_nnz_balance": + x = vec_complex( data_in, kernel_indices, - kernel_indptr, - cw_diag_in, + nnz_balance_row_idx32, + nnz_balance_indegree32, b_in, n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, diag_eps=diag_eps, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - worker_count=worker_count_use, - matrix_stats=matrix_stats_use, data_ri_in=complex_kernel_data_ri, - diag_ri_in=complex_cw_diag_ri, - ) - else: - x = vec_complex( + tmp_sum_in=tmp_sum_buf, + ready_in=ready_buf, + indegree_in=indegree_buf, + ) + else: + x = vec_complex( data_in, kernel_indices, kernel_indptr, b_in, n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, diag_eps=diag_eps, - levels=launch_groups, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, + worker_count=worker_count_use, + matrix_stats=matrix_stats_use, data_ri_in=complex_kernel_data_ri, - ) + ready_in=ready_buf, + row_counter_in=row_counter_buf, + ) else: - if solve_kind == "transpose_push": + if solve_kind == "transpose_cw": x = vec_real( - data_in, - kernel_indices, - kernel_indptr, - b_in, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - launch_groups=launch_groups, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, + data_in, + kernel_indices, + kernel_indptr, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + worker_count=worker_count_use, + matrix_stats=matrix_stats_use, + residual_in=residual_buf, + indegree_in=indegree_buf, + row_counter_in=row_counter_buf, + preprocessed=transpose_preprocessed, ) - elif solve_kind == "transpose_cw": + elif solve_kind == "csr_roc": x = vec_real( - data_in, - kernel_indices, - kernel_indptr, - transpose_diag_in, - transpose_indegree_init, - b_in, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - worker_count=worker_count_use, - matrix_stats=matrix_stats_use, + data_in, + kernel_indices, + kernel_indptr, + level_row_map32, + b_in, + n_rows, + diag_eps=diag_eps, + ready_in=ready_buf, ) - elif solve_kind == "csr_cw": + elif solve_kind == "csr_cw_levelschd": x = vec_real( - data_in, - kernel_indices, - kernel_indptr, - cw_diag_in, - b_in, - n_rows, - lower=lower_eff, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - worker_count=worker_count_use, - matrix_stats=matrix_stats_use, + data_in, + kernel_indices, + kernel_indptr, + level_row_map32, + b_in, + n_rows, + diag_eps=diag_eps, + ready_in=ready_buf, ) - else: + elif solve_kind == "csr_nnz_balance": x = vec_real( - data_in, - kernel_indices, - kernel_indptr, - b_in, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - levels=launch_groups, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, + data_in, + kernel_indices, + nnz_balance_row_idx32, + nnz_balance_indegree32, + b_in, + n_rows, + diag_eps=diag_eps, + tmp_sum_in=tmp_sum_buf, + ready_in=ready_buf, + indegree_in=indegree_buf, ) - else: - b_cols = b_in if b_in.is_contiguous() else b_in.contiguous() - cols = [] - for bj in torch.unbind(b_cols, dim=1): - if torch.is_complex(data_in): - if solve_kind == "transpose_push": - cols.append( - vec_complex( - data_in, - kernel_indices, - kernel_indptr, - bj, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - conjugate=transpose_conjugate, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - launch_groups=launch_groups, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - data_ri_in=complex_kernel_data_ri, - ) - ) - elif solve_kind == "transpose_cw": - cols.append( - vec_complex( - data_in, - kernel_indices, - kernel_indptr, - transpose_diag_in, - transpose_indegree_init, - bj, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - conjugate=transpose_conjugate, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - worker_count=worker_count_use, - matrix_stats=matrix_stats_use, - data_ri_in=complex_kernel_data_ri, - diag_ri_in=complex_transpose_diag_ri, - ) - ) - elif solve_kind == "csr_cw": - cols.append( - vec_complex( - data_in, - kernel_indices, - kernel_indptr, - cw_diag_in, - bj, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - worker_count=worker_count_use, - matrix_stats=matrix_stats_use, - data_ri_in=complex_kernel_data_ri, - diag_ri_in=complex_cw_diag_ri, - ) - ) - else: - cols.append( - vec_complex( - data_in, - kernel_indices, - kernel_indptr, - bj, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - levels=launch_groups, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - data_ri_in=complex_kernel_data_ri, - ) - ) else: - if solve_kind == "transpose_push": - cols.append( - vec_real( - data_in, - kernel_indices, - kernel_indptr, - bj, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - launch_groups=launch_groups, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - ) - ) - elif solve_kind == "transpose_cw": - cols.append( - vec_real( - data_in, - kernel_indices, - kernel_indptr, - transpose_diag_in, - transpose_indegree_init, - bj, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - worker_count=worker_count_use, - matrix_stats=matrix_stats_use, - ) - ) - elif solve_kind == "csr_cw": - cols.append( - vec_real( - data_in, - kernel_indices, - kernel_indptr, - cw_diag_in, - bj, - n_rows, - lower=lower_eff, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - worker_count=worker_count_use, - matrix_stats=matrix_stats_use, - ) - ) - else: - cols.append( - vec_real( - data_in, - kernel_indices, - kernel_indptr, - bj, - n_rows, - lower=lower_eff, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - levels=launch_groups, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - ) - ) - x = torch.stack(cols, dim=1) + x = vec_real( + data_in, + kernel_indices, + kernel_indptr, + b_in, + n_rows, + diag_eps=diag_eps, + worker_count=worker_count_use, + matrix_stats=matrix_stats_use, + ready_in=ready_buf, + row_counter_in=row_counter_buf, + ) target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype if x.dtype != target_dtype: x = x.to(target_dtype) @@ -2632,6 +3652,439 @@ def flagsparse_spsv_csr( return x +def flagsparse_spsv_solve_csr( + descr, + b, + *, + alpha=1, + compute_dtype=None, + block_nnz=None, + max_segments=None, + out=None, + return_time=False, + workspace=None, + handle=None, + stream=None, +): + """Solve a previously analyzed CSR SpSV problem.""" + + if not isinstance(descr, FlagSparseSpSVDescr): + raise TypeError("descr must be a FlagSparseSpSVDescr") + if descr.canonical_format != "csr": + raise ValueError("descr must reference a CSR-canonicalized SpSV analysis") + if not torch.is_tensor(b): + raise TypeError("b must be a torch.Tensor") + if not b.is_cuda: + raise ValueError("b must be a CUDA tensor") + if b.ndim != 1: + raise ValueError("b must be a 1D dense vector (DnVec)") + if int(b.numel()) != int(descr.shape[0]): + raise ValueError(f"b length must equal n_rows={descr.shape[0]}") + if b.dtype != descr.value_dtype: + raise TypeError("b dtype must match the analyzed matrix dtype") + return _execute_spsv_csr_plan( + descr.data, + b.contiguous(), + descr.solve_plan, + descr.transpose_mode, + int(descr.shape[0]), + alpha=alpha, + unit_diagonal=descr.unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + out=out, + return_time=return_time, + workspace=workspace, + original_output_dtype=descr.value_dtype, + compute_dtype=compute_dtype if compute_dtype is not None else descr.compute_dtype, + handle=handle, + stream=stream, + ) + + +def flagsparse_spsv_solve_coo( + descr, + b, + *, + alpha=1, + compute_dtype=None, + block_nnz=None, + max_segments=None, + out=None, + return_time=False, + workspace=None, + handle=None, + stream=None, +): + """Solve a previously analyzed COO SpSV problem via its CSR canonical form.""" + + return flagsparse_spsv_solve_csr( + descr, + b, + alpha=alpha, + compute_dtype=compute_dtype, + block_nnz=block_nnz, + max_segments=max_segments, + out=out, + return_time=return_time, + workspace=workspace, + handle=handle, + stream=stream, + ) + + +def _materialize_spsv_workspace_state(descr, workspace=None): + if not isinstance(descr, FlagSparseSpSVDescr): + raise TypeError("descr must be a FlagSparseSpSVDescr") + buffers = _resolve_spsv_workspace( + workspace, descr.workspace_layout, descr.data.device + ) + solve_kind = descr.solve_kind + preprocess_sig = None + if solve_kind == "csr_cw": + ready = buffers.get("ready") + row_counter = buffers.get("row_counter") + if ready is not None: + ready.zero_() + if row_counter is not None: + row_counter.zero_() + if isinstance(workspace, FlagSparseSpSVWorkspace): + workspace.prepared_solve_kind = "" + workspace.prepared_signature = None + elif solve_kind == "csr_roc": + ready = buffers.get("ready") + if ready is not None: + ready.zero_() + if isinstance(workspace, FlagSparseSpSVWorkspace): + workspace.prepared_solve_kind = "" + workspace.prepared_signature = None + elif solve_kind == "csr_cw_levelschd": + ready = buffers.get("ready") + if ready is not None: + ready.zero_() + if isinstance(workspace, FlagSparseSpSVWorkspace): + workspace.prepared_solve_kind = "" + workspace.prepared_signature = None + elif solve_kind == "csr_nnz_balance": + tmp_sum = buffers.get("tmp_sum") + ready = buffers.get("ready") + indegree = buffers.get("indegree") + if tmp_sum is not None: + tmp_sum.zero_() + if ready is not None: + ready.zero_() + if indegree is not None: + indegree.copy_(descr.solve_plan["nnz_balance_indegree32"]) + preprocess_sig = ( + "csr_nnz_balance", + _tensor_cache_token(descr.solve_plan["nnz_balance_indegree32"]), + ) + if isinstance(workspace, FlagSparseSpSVWorkspace): + workspace.prepared_solve_kind = "csr_nnz_balance" + workspace.prepared_signature = preprocess_sig + elif solve_kind == "transpose_cw": + residual = buffers.get("residual") + indegree = buffers.get("indegree") + row_counter = buffers.get("row_counter") + block_nnz_use = int(descr.solve_plan["default_block_nnz"]) + max_segments_use = int(descr.solve_plan["default_max_segments"]) + preprocess_sig = _transpose_cw_preprocess_signature( + descr.solve_plan, + int(descr.shape[0]), + bool(descr.unit_diagonal), + block_nnz_use, + max_segments_use, + ) + if residual is not None: + residual.zero_() + if indegree is not None: + _run_spsv_csc_preprocess( + descr.solve_plan["kernel_indices32"], + descr.solve_plan["kernel_indptr64"], + indegree, + int(descr.shape[0]), + lower=bool(descr.solve_plan["lower_eff"]), + unit_diagonal=bool(descr.unit_diagonal), + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + if row_counter is not None: + row_counter.zero_() + if isinstance(workspace, FlagSparseSpSVWorkspace): + workspace.prepared_solve_kind = "transpose_cw" + workspace.prepared_signature = preprocess_sig + else: + raise RuntimeError(f"unexpected SpSV solve kind: {solve_kind}") + if workspace is None: + return FlagSparseSpSVWorkspace( + buffer_size=int(descr.buffer_size), + layout=tuple(descr.workspace_layout), + device=descr.data.device, + buffers=buffers, + prepared_solve_kind=( + "transpose_cw" + if solve_kind == "transpose_cw" + else ("csr_nnz_balance" if solve_kind == "csr_nnz_balance" else "") + ), + prepared_signature=preprocess_sig, + ) + return workspace + + +def flagsparse_spsv_preprocess_csr(descr, *, workspace=None): + """Materialize caller-managed workspace for a CSR SpSV descriptor.""" + + return _materialize_spsv_workspace_state(descr, workspace=workspace) + + +def flagsparse_spsv_preprocess_coo(descr, *, workspace=None): + """Materialize caller-managed workspace for a COO SpSV descriptor.""" + + return _materialize_spsv_workspace_state(descr, workspace=workspace) + + +def flagsparse_spsv_buffer_size_ex( + handle, + opA, + alpha, + matA, + vecX, + vecY=None, + *, + compute_dtype=None, + solve_kind=None, + storage_view="csr_as_csc", +): + if not isinstance(matA, FlagSparseSpMatDescr): + raise TypeError("matA must be a FlagSparseSpMatDescr") + if not isinstance(vecX, FlagSparseDnVecDescr): + raise TypeError("vecX must be a FlagSparseDnVecDescr") + return flagsparse_spsv_buffer_size( + matA.shape, + matA.values.dtype, + format=matA.format, + transpose=opA, + solve_kind=solve_kind, + compute_dtype=compute_dtype, + alpha=alpha, + handle=handle, + vecX=vecX, + vecY=vecY, + storage_view=storage_view, + ) + + +def flagsparse_spsv_analysis_ex( + handle, + opA, + alpha, + matA, + vecX, + vecY=None, + *, + compute_dtype=None, + solve_kind=None, + workspace=None, + storage_view="csr_as_csc", + clear_cache=False, +): + if not isinstance(matA, FlagSparseSpMatDescr): + raise TypeError("matA must be a FlagSparseSpMatDescr") + if not isinstance(vecX, FlagSparseDnVecDescr): + raise TypeError("vecX must be a FlagSparseDnVecDescr") + if matA.format == "csr": + return flagsparse_spsv_analysis_csr( + matA.values, + matA.indices, + matA.indptr_or_col, + matA.shape, + lower=matA.lower, + unit_diagonal=matA.unit_diagonal, + transpose=opA, + solve_kind=solve_kind, + compute_dtype=compute_dtype, + handle=handle, + workspace=workspace, + storage_view=storage_view, + clear_cache=clear_cache, + ) + if matA.format == "coo": + return flagsparse_spsv_analysis_coo( + matA.values, + matA.indices, + matA.indptr_or_col, + matA.shape, + lower=matA.lower, + unit_diagonal=matA.unit_diagonal, + transpose=opA, + solve_kind=solve_kind, + compute_dtype=compute_dtype, + handle=handle, + workspace=workspace, + storage_view=storage_view, + ) + raise ValueError("matA.format must be 'csr' or 'coo'") + + +def flagsparse_spsv_solve_ex( + handle, + opA, + alpha, + matA, + vecX, + vecY=None, + descr=None, + *, + compute_dtype=None, + solve_kind=None, + workspace=None, + stream=None, + storage_view="csr_as_csc", + block_nnz=None, + max_segments=None, + return_time=False, +): + if not isinstance(matA, FlagSparseSpMatDescr): + raise TypeError("matA must be a FlagSparseSpMatDescr") + if not isinstance(vecX, FlagSparseDnVecDescr): + raise TypeError("vecX must be a FlagSparseDnVecDescr") + out_tensor = None if vecY is None else vecY.values + if matA.format == "csr": + return flagsparse_spsv_csr( + matA.values, + matA.indices, + matA.indptr_or_col, + vecX.values, + matA.shape, + lower=matA.lower, + unit_diagonal=matA.unit_diagonal, + transpose=opA, + alpha=alpha, + compute_dtype=compute_dtype, + block_nnz=block_nnz, + max_segments=max_segments, + out=out_tensor, + return_time=return_time, + descr=descr, + workspace=workspace, + solve_kind=solve_kind, + handle=handle, + stream=stream, + storage_view=storage_view, + ) + if matA.format == "coo": + return flagsparse_spsv_coo( + matA.values, + matA.indices, + matA.indptr_or_col, + vecX.values, + matA.shape, + lower=matA.lower, + unit_diagonal=matA.unit_diagonal, + transpose=opA, + alpha=alpha, + compute_dtype=compute_dtype, + block_nnz=block_nnz, + max_segments=max_segments, + out=out_tensor, + return_time=return_time, + descr=descr, + workspace=workspace, + solve_kind=solve_kind, + handle=handle, + stream=stream, + storage_view=storage_view, + ) + raise ValueError("matA.format must be 'csr' or 'coo'") + + +def flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=True, + unit_diagonal=False, + transpose=False, + alpha=1, + compute_dtype=None, + block_nnz=None, + max_segments=None, + out=None, + return_time=False, + descr=None, + workspace=None, + solve_kind=None, + handle=None, + stream=None, + storage_view="csr_as_csc", +): + """Sparse triangular solve using Triton CSR CW kernels. + + Current support matrix: + - NON_TRANS: float32/float64/complex64/complex128 with int32/int64 indices + - TRANS/CONJ: float32/float64/complex64/complex128 with int32/int64 indices + """ + if descr is not None: + if not isinstance(descr, FlagSparseSpSVDescr): + raise TypeError("descr must be a FlagSparseSpSVDescr or None") + return flagsparse_spsv_solve_csr( + descr, + b, + alpha=alpha, + compute_dtype=compute_dtype, + block_nnz=block_nnz, + max_segments=max_segments, + out=out, + return_time=return_time, + workspace=workspace, + handle=handle, + stream=stream, + ) + ( + data, + b, + original_output_dtype, + trans_mode, + n_rows, + _n_cols, + solve_plan, + ) = _resolve_spsv_csr_runtime( + data, + indices, + indptr, + b, + shape, + lower, + transpose, + unit_diagonal, + requested_solve_kind=solve_kind, + storage_view=storage_view, + ) + solve_plan = _select_spsv_runtime_plan( + solve_plan, trans_mode, requested_solve_kind=solve_kind + ) + return _execute_spsv_csr_plan( + data, + b, + solve_plan, + trans_mode, + n_rows, + alpha=alpha, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + out=out, + return_time=return_time, + workspace=workspace, + original_output_dtype=original_output_dtype, + compute_dtype=compute_dtype, + handle=handle, + stream=stream, + ) + + def _analyze_spsv_csr( data, indices, @@ -2641,6 +4094,7 @@ def _analyze_spsv_csr( lower=True, unit_diagonal=False, transpose=False, + solve_kind=None, clear_cache=False, return_time=False, ): @@ -2649,7 +4103,15 @@ def _analyze_spsv_csr( if return_time: torch.cuda.synchronize() t0 = time.perf_counter() - _resolve_spsv_csr_runtime( + ( + _data, + _b, + _original_output_dtype, + trans_mode, + _n_rows, + _n_cols, + solve_plan, + ) = _resolve_spsv_csr_runtime( data, indices, indptr, @@ -2658,6 +4120,10 @@ def _analyze_spsv_csr( lower, transpose, unit_diagonal, + requested_solve_kind=solve_kind, + ) + _select_spsv_runtime_plan( + solve_plan, trans_mode, requested_solve_kind=solve_kind ) if return_time: torch.cuda.synchronize() @@ -2673,12 +4139,36 @@ def flagsparse_spsv_coo( lower=True, unit_diagonal=False, transpose=False, + alpha=1, + compute_dtype=None, block_nnz=None, max_segments=None, out=None, return_time=False, + descr=None, + workspace=None, + solve_kind=None, + handle=None, + stream=None, + storage_view="csr_as_csc", ): """COO SpSV by canonicalizing COO into CSR, then reusing CSR SpSV.""" + if descr is not None: + if not isinstance(descr, FlagSparseSpSVDescr): + raise TypeError("descr must be a FlagSparseSpSVDescr or None") + return flagsparse_spsv_solve_coo( + descr, + b, + alpha=alpha, + compute_dtype=compute_dtype, + block_nnz=block_nnz, + max_segments=max_segments, + out=out, + return_time=return_time, + workspace=workspace, + handle=handle, + stream=stream, + ) data, input_index_dtype, row64, col64, b, n_rows, n_cols = _prepare_spsv_coo_inputs( data, row, col, b, shape ) @@ -2690,8 +4180,8 @@ def flagsparse_spsv_coo( _validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "COO") else: _validate_spsv_trans_combo(data.dtype, input_index_dtype, "COO") - data_csr, indices_csr, indptr_csr = _coo_to_csr_sorted_unique( - data, row64, col64, n_rows, n_cols + data_csr, indices_csr, indptr_csr = _coo2csr_for_spsv( + data, row64, col64, n_rows, assume_ordered=False ) return flagsparse_spsv_csr( data_csr, @@ -2702,8 +4192,15 @@ def flagsparse_spsv_coo( lower=lower, unit_diagonal=unit_diagonal, transpose=transpose, + alpha=alpha, + compute_dtype=compute_dtype, block_nnz=block_nnz, max_segments=max_segments, out=out, return_time=return_time, + workspace=workspace, + solve_kind=solve_kind, + handle=handle, + stream=stream, + storage_view=storage_view, ) diff --git a/tests/pytest/test_spsv_coo_accuracy.py b/tests/pytest/test_spsv_coo_accuracy.py index 7ba683d..f8888ff 100644 --- a/tests/pytest/test_spsv_coo_accuracy.py +++ b/tests/pytest/test_spsv_coo_accuracy.py @@ -1,19 +1,28 @@ import pytest import torch -from flagsparse import flagsparse_spsv_coo +from flagsparse import ( + FlagSparseSpSVDescr, + flagsparse_spsv_analysis_coo, + flagsparse_spsv_coo, + flagsparse_spsv_create_workspace, + flagsparse_spsv_preprocess_coo, + flagsparse_spsv_solve_coo, +) import flagsparse.sparse_operations.spsv as fs_spsv_impl from tests.pytest.param_shapes import SPSV_N from tests.pytest.test_spsv_csr_accuracy import ( _apply_ref_op, _build_triangular, + _dense_ref_spsv, _dtype_id, _effective_upper, _rand_like, _tol, _transpose_arg, NON_TRANS_DTYPES, + SUPPORTED_COMPLEX_DTYPES, TRANS_CONJ_MODES, ) @@ -150,3 +159,444 @@ def _wrapped_flagsparse_spsv_csr(*args, **kwargs): rtol, atol = _tol(dtype) assert called["hit"] assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_coo_to_csr_keeps_duplicates(): + device = torch.device("cuda") + data = torch.tensor([1.0, 2.0, 3.0], device=device) + row = torch.tensor([0, 0, 1], dtype=torch.int64, device=device) + col = torch.tensor([1, 1, 0], dtype=torch.int64, device=device) + + data_csr, indices_csr, indptr_csr = fs_spsv_impl._coo2csr_for_spsv( + data, row, col, 2, assume_ordered=False + ) + + assert data_csr.numel() == 3 + assert indices_csr.tolist() == [1, 1, 0] + assert indptr_csr.tolist() == [0, 2, 3] + + +@pytest.mark.spsv +def test_spsv_coo_analysis_workspace_solve_matches_direct(): + device = torch.device("cuda") + dtype = torch.complex64 + n = SPSV_N[0] + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + descr = flagsparse_spsv_analysis_coo( + data, + row.to(torch.int64), + col.to(torch.int64), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + ) + assert isinstance(descr, FlagSparseSpSVDescr) + assert descr.format == "coo" + assert descr.canonical_format == "csr" + + workspace = flagsparse_spsv_create_workspace(descr) + x_via_descr = flagsparse_spsv_solve_coo(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_coo( + data, + row.to(torch.int64), + col.to(torch.int64), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_coo_explicit_roc_route_matches_dense(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_roc", + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("dtype", SUPPORTED_COMPLEX_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("solve_kind", ["csr_roc", "csr_cw_levelschd", "alg2"]) +def test_spsv_coo_explicit_complex_level_routes_match_dense(dtype, solve_kind): + device = torch.device("cuda") + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind=solve_kind, + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("dtype", SUPPORTED_COMPLEX_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("solve_kind", ["csr_nnz_balance", "alg3"]) +def test_spsv_coo_explicit_complex_nnz_balance_routes_match_dense(dtype, solve_kind): + device = torch.device("cuda") + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind=solve_kind, + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_coo_explicit_levelschd_route_matches_dense(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_cw_levelschd", + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_coo_explicit_nnz_balance_route_matches_dense(): + device = torch.device("cuda") + dtype = torch.float64 + n = 96 + A = torch.tril(torch.randn(n, n, dtype=dtype, device=device) * 0.02) + A = A + torch.eye(n, dtype=dtype, device=device) * 3.0 + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_nnz_balance", + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_coo_roc_analysis_workspace_solve_matches_direct(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + descr = flagsparse_spsv_analysis_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_roc", + ) + assert descr.solve_kind == "csr_roc" + assert descr.canonical_format == "csr" + workspace = flagsparse_spsv_preprocess_coo( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_via_descr = flagsparse_spsv_solve_coo(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_roc", + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_coo_levelschd_analysis_workspace_solve_matches_direct(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + descr = flagsparse_spsv_analysis_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_cw_levelschd", + ) + assert descr.solve_kind == "csr_cw_levelschd" + assert descr.canonical_format == "csr" + workspace = flagsparse_spsv_preprocess_coo( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_via_descr = flagsparse_spsv_solve_coo(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_cw_levelschd", + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_coo_nnz_balance_analysis_workspace_solve_matches_direct(): + device = torch.device("cuda") + dtype = torch.float64 + n = 96 + A = torch.tril(torch.randn(n, n, dtype=dtype, device=device) * 0.02) + A = A + torch.eye(n, dtype=dtype, device=device) * 3.0 + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + descr = flagsparse_spsv_analysis_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_nnz_balance", + ) + assert descr.solve_kind == "csr_nnz_balance" + assert descr.canonical_format == "csr" + workspace = flagsparse_spsv_preprocess_coo( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_via_descr = flagsparse_spsv_solve_coo(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_nnz_balance", + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("dtype", SUPPORTED_COMPLEX_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("solve_kind", ["csr_nnz_balance", "alg3"]) +def test_spsv_coo_complex_nnz_balance_analysis_workspace_matches_direct(dtype, solve_kind): + device = torch.device("cuda") + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + descr = flagsparse_spsv_analysis_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind=solve_kind, + ) + assert descr.solve_kind == "csr_nnz_balance" + assert descr.canonical_format == "csr" + workspace = flagsparse_spsv_preprocess_coo( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_via_descr = flagsparse_spsv_solve_coo(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind=solve_kind, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("lower", [True, False], ids=["lower", "upper"]) +def test_spsv_coo_non_trans_unit_supported_combos(n, dtype, index_dtype, lower): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=lower) + b = _rand_like(dtype, (n,), device) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=lower, unit_diagonal=True) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(index_dtype), + col.to(index_dtype), + b, + (n, n), + lower=lower, + unit_diagonal=True, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("lower", [True, False], ids=["lower", "upper"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_coo_unit_transpose_family_supported_combos( + n, dtype, index_dtype, lower, op_mode +): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=lower) + b = _rand_like(dtype, (n,), device) + x_ref = _dense_ref_spsv( + A.to(dtype), + b.to(dtype), + lower=lower, + op_mode=op_mode, + unit_diagonal=True, + ) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(index_dtype), + col.to(index_dtype), + b, + (n, n), + lower=lower, + unit_diagonal=True, + transpose=_transpose_arg(op_mode), + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py index 4503825..a450e35 100644 --- a/tests/pytest/test_spsv_csr_accuracy.py +++ b/tests/pytest/test_spsv_csr_accuracy.py @@ -1,7 +1,26 @@ import pytest import torch -from flagsparse import flagsparse_spsv_csr +from flagsparse import ( + FlagSparseDnVecDescr, + FlagSparseSpMatDescr, + FlagSparseSpSVDescr, + FlagSparseSpSVHandle, + FlagSparseSpSVWorkspace, + flagsparse_create_dnvec, + flagsparse_create_spmat_csr, + flagsparse_create_spsv_handle, + flagsparse_spsv_analysis_csr, + flagsparse_spsv_analysis_ex, + flagsparse_spsv_buffer_size, + flagsparse_spsv_buffer_size_ex, + flagsparse_spsv_create_workspace, + flagsparse_spsv_csr, + flagsparse_spsv_preprocess_csr, + flagsparse_spsv_solve_ex, + flagsparse_spsv_solve_csr, +) +import flagsparse.sparse_operations.spsv as fs_spsv_impl from tests.pytest.param_shapes import SPSV_N @@ -64,6 +83,17 @@ def _transpose_arg(op_mode): return op_mode +def _dense_ref_spsv(A, b, *, lower, op_mode="NON", unit_diagonal=False): + A_eff = _apply_ref_op(A, op_mode) + x = torch.linalg.solve_triangular( + A_eff, + b.unsqueeze(-1), + upper=_effective_upper(lower, op_mode), + unitriangular=unit_diagonal, + ) + return x.squeeze(-1) + + def _cupy_apply_op(A_cp, op_mode): if op_mode == "TRANS": return A_cp.transpose().tocsr() @@ -168,6 +198,877 @@ def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) +@pytest.mark.spsv +def test_spsv_csr_rejects_matrix_rhs(): + device = torch.device("cuda") + dtype = torch.float32 + n = SPSV_N[0] + A = _build_triangular(n, dtype, device, lower=True) + b = torch.randn(n, 2, dtype=dtype, device=device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + + with pytest.raises(ValueError, match="DnVec"): + flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + ) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("lower", [True, False], ids=["lower", "upper"]) +def test_spsv_csr_non_trans_unit_supported_combos(n, dtype, index_dtype, lower): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=lower) + b = _rand_like(dtype, (n,), device) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=lower, unit_diagonal=True) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=lower, + unit_diagonal=True, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("lower", [True, False], ids=["lower", "upper"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_unit_transpose_family_supported_combos( + n, dtype, index_dtype, lower, op_mode +): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=lower) + b = _rand_like(dtype, (n,), device) + x_ref = _dense_ref_spsv( + A.to(dtype), + b.to(dtype), + lower=lower, + op_mode=op_mode, + unit_diagonal=True, + ) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=lower, + unit_diagonal=True, + transpose=_transpose_arg(op_mode), + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_csr_complex_non_trans_defaults_to_cw_route(): + device = torch.device("cuda") + n = SPSV_N[0] + dtype = torch.complex64 + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + b = _rand_like(dtype, (n,), device) + + _, _, _, trans_mode, _, _, solve_plan = fs_spsv_impl._resolve_spsv_csr_runtime( + data, indices, indptr, b, (n, n), True, False, False + ) + selected = fs_spsv_impl._select_spsv_runtime_plan(solve_plan, trans_mode) + assert selected["solve_kind"] == "csr_cw" + + +@pytest.mark.spsv +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_transpose_family_defaults_to_cw_route(op_mode): + device = torch.device("cuda") + n = SPSV_N[0] + dtype = torch.complex128 + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + b = _rand_like(dtype, (n,), device) + + _, _, _, trans_mode, _, _, solve_plan = fs_spsv_impl._resolve_spsv_csr_runtime( + data, indices, indptr, b, (n, n), True, _transpose_arg(op_mode), False + ) + selected = fs_spsv_impl._select_spsv_runtime_plan(solve_plan, trans_mode) + assert selected["solve_kind"] == "transpose_cw" + + +@pytest.mark.spsv +def test_spsv_internal_cw_worker_count_limits_narrow_frontier(): + matrix_stats = { + "max_frontier": 3, + "avg_frontier": 2.5, + "frontier_ratio": 0.006, + "num_levels": 4096, + "avg_nnz_per_row": 4096.0, + } + worker_count = fs_spsv_impl._resolve_cw_worker_count(4096, matrix_stats, 1) + assert worker_count <= 4 + + +@pytest.mark.spsv +def test_spsv_auto_route_promotes_dense_real_lower_to_nnz_balance(): + matrix_stats = { + "num_levels": 256, + "max_frontier": 1, + "avg_frontier": 1.0, + "frontier_ratio": 1.0 / 256.0, + "avg_nnz_per_row": 128.0, + "max_nnz_per_row": 256, + } + route = fs_spsv_impl._choose_spsv_nontrans_auto_route( + 256, + matrix_stats, + lower=True, + unit_diagonal=False, + value_dtype=torch.float64, + ) + assert route == "csr_nnz_balance" + + +@pytest.mark.spsv +def test_spsv_levelschd_analysis_builds_sorted_row_map(): + device = torch.device("cuda") + data = torch.tensor([4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 4.0, 7.0], dtype=torch.float64, device=device) + indices = torch.tensor([0, 0, 1, 0, 2, 1, 2, 3], dtype=torch.int64, device=device) + indptr = torch.tensor([0, 1, 3, 5, 8], dtype=torch.int64, device=device) + meta = fs_spsv_impl._build_spsv_level_schedule_metadata( + indices, + indptr, + 4, + lower=True, + unit_diagonal=False, + ) + assert meta["row_map32"].tolist() == [0, 1, 2, 3] + assert meta["level_ptr32"].tolist() == [0, 1, 3, 4] + assert meta["indegree_init32"].tolist() == [1, 2, 2, 3] + assert meta["matrix_stats"]["num_levels"] == 3 + assert meta["matrix_stats"]["max_frontier"] == 2 + + +@pytest.mark.spsv +def test_spsv_nnz_balance_analysis_builds_row_idx_and_indegree(): + device = torch.device("cuda") + indices = torch.tensor([0, 0, 1, 0, 2, 1, 2, 3], dtype=torch.int64, device=device) + indptr = torch.tensor([0, 1, 3, 5, 8], dtype=torch.int64, device=device) + meta = fs_spsv_impl._build_spsv_nnz_balance_metadata( + indices, + indptr, + 4, + lower=True, + unit_diagonal=False, + ) + assert meta["indegree_init32"].tolist() == [1, 2, 2, 3] + assert meta["csr_row_idx32"].tolist() == [0, 1, 1, 2, 2, 3, 3, 3] + + +@pytest.mark.spsv +def test_spsv_auto_route_promotes_wide_frontier_real_lower_to_levelschd(): + matrix_stats = { + "num_levels": 48, + "max_frontier": 64, + "avg_frontier": 12.0, + "frontier_ratio": 0.125, + "avg_nnz_per_row": 12.0, + "max_nnz_per_row": 48, + } + route = fs_spsv_impl._choose_spsv_nontrans_auto_route( + 512, + matrix_stats, + lower=True, + unit_diagonal=False, + value_dtype=torch.float32, + ) + assert route == "csr_cw_levelschd" + + +@pytest.mark.spsv +def test_spsv_csr_transpose_descriptor_keeps_preprocess_metadata(): + device = torch.device("cuda") + dtype = torch.float64 + n = SPSV_N[0] + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + + descr = flagsparse_spsv_analysis_csr( + data, + indices, + indptr, + (n, n), + lower=True, + unit_diagonal=False, + transpose="TRANS", + ) + assert descr.solve_kind == "transpose_cw" + assert descr.storage_view == "csr_as_csc" + assert descr.solve_plan.get("transpose_indegree_init") is None + assert descr.solve_plan.get("transpose_diag") is None + assert descr.solve_plan.get("transpose_dep_start") is None + assert descr.solve_plan.get("transpose_dep_end") is None + + +@pytest.mark.spsv +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_transpose_public_solve_uses_transpose_kernel(monkeypatch, op_mode): + device = torch.device("cuda") + dtype = torch.complex128 + n = SPSV_N[0] + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + b = _rand_like(dtype, (n,), device) + + called = {"transpose_complex": False} + real_impl = fs_spsv_impl._triton_spsv_csr_transpose_cw_vector_complex + + def _wrapped(*args, **kwargs): + called["transpose_complex"] = True + return real_impl(*args, **kwargs) + + monkeypatch.setattr( + fs_spsv_impl, + "_triton_spsv_csr_transpose_cw_vector_complex", + _wrapped, + ) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + x_ref = _dense_ref_spsv( + A.to(dtype), + b.to(dtype), + lower=True, + op_mode=op_mode, + unit_diagonal=False, + ) + rtol, atol = _tol(dtype) + assert called["transpose_complex"] + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_csr_explicit_roc_route_matches_dense(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + x = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind="csr_roc", + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("dtype", SUPPORTED_COMPLEX_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("solve_kind", ["csr_roc", "csr_cw_levelschd", "alg2"]) +def test_spsv_csr_explicit_complex_level_routes_match_dense(dtype, solve_kind): + device = torch.device("cuda") + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + x = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind=solve_kind, + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("dtype", SUPPORTED_COMPLEX_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("solve_kind", ["csr_nnz_balance", "alg3"]) +def test_spsv_csr_explicit_complex_nnz_balance_routes_match_dense(dtype, solve_kind): + device = torch.device("cuda") + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + x = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind=solve_kind, + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_csr_explicit_roc_analysis_builds_only_level_metadata(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_roc", + ) + assert descr.solve_kind == "csr_roc" + assert int(descr.solve_plan["level_row_map32"].numel()) == n + assert int(descr.solve_plan["nnz_balance_row_idx32"].numel()) == 0 + assert int(descr.solve_plan["nnz_balance_indegree32"].numel()) == 0 + + +@pytest.mark.spsv +def test_spsv_csr_explicit_levelschd_route_matches_dense(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + x = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind="csr_cw_levelschd", + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_csr_explicit_levelschd_analysis_builds_only_level_metadata(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_cw_levelschd", + ) + assert descr.solve_kind == "csr_cw_levelschd" + assert int(descr.solve_plan["level_row_map32"].numel()) == n + assert int(descr.solve_plan["nnz_balance_row_idx32"].numel()) == 0 + assert int(descr.solve_plan["nnz_balance_indegree32"].numel()) == 0 + + +@pytest.mark.spsv +def test_spsv_csr_explicit_nnz_balance_route_matches_dense(): + device = torch.device("cuda") + dtype = torch.float64 + n = 96 + A = torch.tril(torch.randn(n, n, dtype=dtype, device=device) * 0.02) + A = A + torch.eye(n, dtype=dtype, device=device) * 3.0 + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + x = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind="csr_nnz_balance", + ) + x_ref = _dense_ref_spsv(A.to(dtype), b.to(dtype), lower=True, unit_diagonal=False) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_csr_explicit_nnz_balance_analysis_builds_only_nnz_metadata(): + device = torch.device("cuda") + dtype = torch.float64 + n = 96 + A = torch.tril(torch.randn(n, n, dtype=dtype, device=device) * 0.02) + A = A + torch.eye(n, dtype=dtype, device=device) * 3.0 + Asp = A.to_sparse_csr() + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_nnz_balance", + ) + assert descr.solve_kind == "csr_nnz_balance" + assert int(descr.solve_plan["level_row_map32"].numel()) == 0 + assert int(descr.solve_plan["nnz_balance_row_idx32"].numel()) == int(Asp.values().numel()) + assert int(descr.solve_plan["nnz_balance_indegree32"].numel()) == n + + +@pytest.mark.spsv +def test_spsv_csr_roc_analysis_workspace_solve_matches_direct(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_roc", + ) + assert descr.solve_kind == "csr_roc" + workspace = flagsparse_spsv_preprocess_csr( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_via_descr = flagsparse_spsv_solve_csr(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind="csr_roc", + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("dtype", SUPPORTED_COMPLEX_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("solve_kind", ["csr_roc", "csr_cw_levelschd", "alg2"]) +def test_spsv_csr_complex_level_route_analysis_workspace_matches_direct(dtype, solve_kind): + device = torch.device("cuda") + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind=solve_kind, + ) + workspace = flagsparse_spsv_preprocess_csr( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_via_descr = flagsparse_spsv_solve_csr(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind=solve_kind, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("dtype", SUPPORTED_COMPLEX_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("solve_kind", ["csr_nnz_balance", "alg3"]) +def test_spsv_csr_complex_nnz_balance_analysis_workspace_matches_direct(dtype, solve_kind): + device = torch.device("cuda") + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind=solve_kind, + ) + assert descr.solve_kind == "csr_nnz_balance" + workspace = flagsparse_spsv_preprocess_csr( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_via_descr = flagsparse_spsv_solve_csr(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind=solve_kind, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_csr_levelschd_analysis_workspace_solve_matches_direct(): + device = torch.device("cuda") + dtype = torch.float64 + n = 64 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_cw_levelschd", + ) + assert descr.solve_kind == "csr_cw_levelschd" + workspace = flagsparse_spsv_preprocess_csr( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_via_descr = flagsparse_spsv_solve_csr(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind="csr_cw_levelschd", + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_csr_nnz_balance_analysis_workspace_solve_matches_direct(): + device = torch.device("cuda") + dtype = torch.float64 + n = 96 + A = torch.tril(torch.randn(n, n, dtype=dtype, device=device) * 0.02) + A = A + torch.eye(n, dtype=dtype, device=device) * 3.0 + b = _rand_like(dtype, (n,), device) + Asp = A.to_sparse_csr() + + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + solve_kind="csr_nnz_balance", + ) + assert descr.solve_kind == "csr_nnz_balance" + workspace = flagsparse_spsv_preprocess_csr( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_via_descr = flagsparse_spsv_solve_csr(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + solve_kind="csr_nnz_balance", + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_csr_analysis_workspace_solve_matches_direct(): + device = torch.device("cuda") + dtype = torch.float64 + n = SPSV_N[0] + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + b = _rand_like(dtype, (n,), device) + + descr = flagsparse_spsv_analysis_csr( + data, + indices, + indptr, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + ) + assert isinstance(descr, FlagSparseSpSVDescr) + assert descr.solve_kind == "csr_cw" + assert descr.buffer_size == flagsparse_spsv_buffer_size((n, n), dtype, format="csr") + + workspace = flagsparse_spsv_create_workspace(descr) + assert isinstance(workspace, FlagSparseSpSVWorkspace) + assert workspace.buffer_size == descr.buffer_size + + x_direct = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False + ) + x_via_descr = flagsparse_spsv_solve_csr(descr, b, workspace=workspace) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_transpose_analysis_workspace_route(op_mode): + device = torch.device("cuda") + dtype = torch.complex128 + n = SPSV_N[0] + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + b = _rand_like(dtype, (n,), device) + + descr = flagsparse_spsv_analysis_csr( + data, + indices, + indptr, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + assert descr.solve_kind == "transpose_cw" + assert descr.route_name == "transpose_cw" + assert descr.solve_plan.get("transpose_lower") is True + layout_names = [entry["name"] for entry in descr.workspace_layout] + assert layout_names == ["residual", "indegree", "row_counter"] + + workspace = flagsparse_spsv_create_workspace(descr) + x_via_descr = flagsparse_spsv_solve_csr(descr, b, workspace=workspace) + x_direct = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x_via_descr, x_direct, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +def test_spsv_csr_descriptor_exposes_cuda_style_fields(): + device = torch.device("cuda") + dtype = torch.float64 + n = SPSV_N[0] + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + ) + assert descr.fill_mode == "lower" + assert descr.diag_type == "non_unit" + assert descr.matrix_type == "triangular" + assert descr.index_base == 0 + assert descr.storage_view == "csr" + + +@pytest.mark.spsv +def test_spsv_csr_preprocess_initializes_workspace(): + device = torch.device("cuda") + dtype = torch.float64 + n = SPSV_N[0] + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + descr = flagsparse_spsv_analysis_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + transpose="TRANS", + solve_kind="transpose_cw", + ) + workspace = flagsparse_spsv_create_workspace(descr) + workspace = flagsparse_spsv_preprocess_csr(descr, workspace=workspace) + assert isinstance(workspace, FlagSparseSpSVWorkspace) + indegree_expected = torch.zeros(n, dtype=torch.int32, device=device) + row_ids = torch.repeat_interleave( + torch.arange(n, device=device, dtype=torch.int64), + Asp.crow_indices().to(torch.int64)[1:] - Asp.crow_indices().to(torch.int64)[:-1], + ) + mask = Asp.col_indices().to(torch.int64) <= row_ids + if bool(torch.any(mask).item()): + counts = torch.bincount( + Asp.col_indices().to(torch.int64)[mask], minlength=n + ).to(torch.int32) + indegree_expected.copy_(counts) + assert torch.equal(workspace.buffers["indegree"], indegree_expected) + + +@pytest.mark.spsv +def test_spsv_ex_interfaces_match_direct_route(): + device = torch.device("cuda") + dtype = torch.float64 + n = SPSV_N[0] + A = _build_triangular(n, dtype, device, lower=True) + Asp = A.to_sparse_csr() + b = _rand_like(dtype, (n,), device) + + handle = flagsparse_create_spsv_handle(device=device) + mat = flagsparse_create_spmat_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + (n, n), + lower=True, + unit_diagonal=False, + ) + vec = flagsparse_create_dnvec(b) + assert isinstance(handle, FlagSparseSpSVHandle) + assert isinstance(mat, FlagSparseSpMatDescr) + assert isinstance(vec, FlagSparseDnVecDescr) + + descr = flagsparse_spsv_analysis_ex( + handle, + False, + 1, + mat, + vec, + compute_dtype=torch.float64, + ) + workspace = flagsparse_spsv_preprocess_csr( + descr, workspace=flagsparse_spsv_create_workspace(descr) + ) + x_ex = flagsparse_spsv_solve_ex(handle, False, 1, mat, vec, descr=descr, workspace=workspace) + x_direct = flagsparse_spsv_csr( + Asp.values(), + Asp.col_indices().to(torch.int32), + Asp.crow_indices().to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + ) + assert flagsparse_spsv_buffer_size_ex(handle, False, 1, mat, vec) == descr.buffer_size + rtol, atol = _tol(dtype) + assert torch.allclose(x_ex, x_direct, rtol=rtol, atol=atol) + + @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) @pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) diff --git a/tests/test_spsm.py b/tests/test_spsm.py index 78652f7..8792d1f 100644 --- a/tests/test_spsm.py +++ b/tests/test_spsm.py @@ -19,12 +19,12 @@ try: import cupy as cp + import cupyx.cusparse as cpx_cusparse import cupyx.scipy.sparse as cpx_sparse - from cupyx.scipy.sparse.linalg import spsolve_triangular as cpx_spsolve_triangular except Exception: cp = None + cpx_cusparse = None cpx_sparse = None - cpx_spsolve_triangular = None FORMATS = ("csr", "coo") @@ -32,8 +32,8 @@ INDEX_DTYPES = [torch.int32] CSV_VALUE_DTYPES = [torch.float32, torch.float64] CSV_INDEX_DTYPES = [torch.int32] -WARMUP = 5 -ITERS = 20 +WARMUP = 1 +ITERS = 1 SPSM_OP_MODES = ["NON", "NON_TRANS"] @@ -72,6 +72,11 @@ def _sum_ms(*values): return sum(values) +def _spsm_benchmark_schedule(nnz, n_rhs, value_dtype, fmt="csr"): + del nnz, n_rhs, value_dtype, fmt + return int(WARMUP), int(ITERS) + + def _csv_export_row_spsm(row): return { "matrix": row.get("matrix"), @@ -180,8 +185,8 @@ def _benchmark_pytorch_reference(data, indices, indptr, shape, B): return None, None, "unavailable", f"PyTorch sparse solve unavailable ({exc})" -def _benchmark_cusparse_reference(data, row, col, indptr, B, shape, fmt): - if cp is None or cpx_sparse is None or cpx_spsolve_triangular is None: +def _benchmark_cusparse_reference(data, row, col, indptr, B, shape, fmt, warmup, iters): + if cp is None or cpx_sparse is None or cpx_cusparse is None: return None, None, "cusparse unavailable" try: data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) @@ -194,17 +199,18 @@ def _benchmark_cusparse_reference(data, row, col, indptr, B, shape, fmt): idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col.contiguous())) ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.contiguous())) A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) - for _ in range(WARMUP): - _ = cpx_spsolve_triangular(A_cp, B_cp, lower=True, unit_diagonal=False) + A_cp.sum_duplicates() + for _ in range(warmup): + _ = cpx_cusparse.spsm(A_cp, B_cp, lower=True, unit_diag=False, transa=False) cp.cuda.runtime.deviceSynchronize() c0 = cp.cuda.Event() c1 = cp.cuda.Event() c0.record() - for _ in range(ITERS): - X_cp = cpx_spsolve_triangular(A_cp, B_cp, lower=True, unit_diagonal=False) + for _ in range(iters): + X_cp = cpx_cusparse.spsm(A_cp, B_cp, lower=True, unit_diag=False, transa=False) c1.record() c1.synchronize() - ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + ms = cp.cuda.get_elapsed_time(c0, c1) / iters X_t = torch.utils.dlpack.from_dlpack(X_cp.toDlpack()).to(B.dtype) return X_t, ms, None except Exception as exc: @@ -227,22 +233,25 @@ def _solution_residual_metrics(data, indices, indptr, shape, X, B, value_dtype): return err, ok -def _benchmark_flagsparse(call): +def _benchmark_flagsparse(call, warmup, iters): X = None - for _ in range(WARMUP): + for _ in range(warmup): X = call() torch.cuda.synchronize() e0 = torch.cuda.Event(True) e1 = torch.cuda.Event(True) e0.record() - for _ in range(ITERS): + for _ in range(iters): X = call() e1.record() torch.cuda.synchronize() - return X, e0.elapsed_time(e1) / ITERS + return X, e0.elapsed_time(e1) / iters def _benchmark_flagsparse_spsm_csr_split(data, indices, indptr, B, shape): + warmup, iters = _spsm_benchmark_schedule( + data.numel(), B.shape[1], data.dtype, fmt="csr" + ) analysis_ms = fs_spsm_impl._analyze_spsm_csr( data, indices, @@ -266,12 +275,17 @@ def _benchmark_flagsparse_spsm_csr_split(data, indices, indptr, B, shape): opA="NON_TRANS", opB="NON_TRANS", major="row", - ) + ), + warmup, + iters, ) return X, analysis_ms, solve_ms def _benchmark_flagsparse_spsm_coo_split(data, row, col, B, shape): + warmup, iters = _spsm_benchmark_schedule( + data.numel(), B.shape[1], data.dtype, fmt="coo" + ) analysis_ms = fs_spsm_impl._analyze_spsm_coo( data, row, @@ -295,7 +309,9 @@ def _benchmark_flagsparse_spsm_coo_split(data, row, col, B, shape): opA="NON_TRANS", opB="NON_TRANS", major="row", - ) + ), + warmup, + iters, ) return X, analysis_ms, solve_ms @@ -390,6 +406,9 @@ def _run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n B = torch.randn((n_rows, n_rhs), dtype=value_dtype, device=data.device).contiguous() atol, rtol = _tol(value_dtype) row, col = _csr_to_coo(indices, indptr, n_rows) + warmup, iters = _spsm_benchmark_schedule( + data.numel(), n_rhs, value_dtype, fmt=fmt + ) if fmt == "csr": X_fs, analysis_ms, solve_ms = _benchmark_flagsparse_spsm_csr_split( @@ -408,7 +427,7 @@ def _run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n shape, ) X_cu, cusparse_ms, _cusparse_reason = _benchmark_cusparse_reference( - data, row, col, indptr, B, shape, fmt + data, row, col, indptr, B, shape, fmt, warmup, iters ) X_pt, pytorch_ms, _pt_backend, pytorch_reason = _benchmark_pytorch_reference( data, indices, indptr, shape, B @@ -544,6 +563,10 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): "cuSPARSE solves the full matrix RHS in one interface call)" ) print("=" * 176) + print( + f"Benchmark schedule: warmup={WARMUP}, iter={ITERS} " + "(Library-main style defaults; override with --warmup/--iters)" + ) print( "PT.total is the aggregated time of one torch.sparse.spsolve call per RHS column; " "CU.total is one official sparse triangular solve call on the full matrix RHS. " @@ -676,6 +699,7 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): def main(): + global WARMUP, ITERS parser = argparse.ArgumentParser( description="SpSM test: synthetic triangular systems and optional .mtx batch CSV." ) @@ -695,7 +719,25 @@ def main(): default="NON", help="comma-separated op(A) modes; currently only NON/NON_TRANS is supported", ) + parser.add_argument( + "--warmup", + type=int, + default=WARMUP, + help="Benchmark warmup iterations (Library-main style default: 1)", + ) + parser.add_argument( + "--iters", + type=int, + default=ITERS, + help="Benchmark timed iterations (Library-main style default: 1)", + ) args = parser.parse_args() + WARMUP = max(0, int(args.warmup)) + ITERS = max(1, int(args.iters)) + + ops = _parse_ops_filter(args.ops) + if any(op != "NON" for op in ops): + raise ValueError("SpSM test currently supports only --ops NON/NON_TRANS") ops = _parse_ops_filter(args.ops) if any(op != "NON" for op in ops): diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 1c84369..b7be7a8 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -30,8 +30,8 @@ VALUE_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] INDEX_DTYPES = [torch.int32, torch.int64] TEST_SIZES = [256, 512, 1024, 2048] -WARMUP = 5 -ITERS = 20 +WARMUP = 1 +ITERS = 1 SPSV_TRIANGULAR_DIAG_DOMINANCE = 4.0 # CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) @@ -59,6 +59,12 @@ def _dtype_name(dtype): INDEX_DTYPE_NAME_MAP = { _dtype_name(dtype): dtype for dtype in CSR_FULL_INDEX_DTYPES } +SPSV_ALG_NUM_TO_SOLVE_KIND = { + 1: "csr_cw", + 2: "csr_cw_levelschd", + 3: "csr_nnz_balance", + 8: "csr_nnz_balance", +} def _parse_csv_tokens(raw): @@ -89,6 +95,41 @@ def _parse_op_modes_filter(raw): return tokens +def _parse_alg_num(raw): + value = int(raw) + if value not in SPSV_ALG_NUM_TO_SOLVE_KIND: + raise ValueError( + "unsupported alg_num: " + f"{value}. Supported values: {sorted(SPSV_ALG_NUM_TO_SOLVE_KIND)}" + ) + return value + + +def _solve_kind_from_alg_num(alg_num): + if alg_num is None: + return None + return SPSV_ALG_NUM_TO_SOLVE_KIND[int(alg_num)] + + +def _alg_label(alg_num): + return "AUTO" if alg_num is None else f"ALG{int(alg_num)}" + + +def _alg_num_supports_case(alg_num, fmt, op_mode, lower, value_dtype): + if alg_num is None: + return True + alg_num = int(alg_num) + if alg_num == 1: + return True + if alg_num in (2, 3, 8): + return ( + fmt in ("CSR", "COO") + and op_mode == "NON" + and bool(lower) + ) + return False + + def _fmt_ms(v): return "N/A" if v is None else f"{v:.4f}" @@ -119,6 +160,11 @@ def _sum_ms(*values): return sum(values) +def _spsv_benchmark_schedule(nnz, op_mode, value_dtype, fmt="CSR"): + del nnz, op_mode, value_dtype, fmt + return int(WARMUP), int(ITERS) + + def _status_str(ok_flag, has_value): if ok_flag: return "PASS" @@ -527,19 +573,19 @@ def _solution_residual_metrics(data, indices, indptr, shape, x, b, value_dtype, return err_res, ok_res -def _benchmark_flagsparse(call): +def _benchmark_flagsparse(call, *, warmup=WARMUP, iters=ITERS): x = None - for _ in range(WARMUP): + for _ in range(warmup): x = call() torch.cuda.synchronize() e0 = torch.cuda.Event(True) e1 = torch.cuda.Event(True) e0.record() - for _ in range(ITERS): + for _ in range(iters): x = call() e1.record() torch.cuda.synchronize() - return x, e0.elapsed_time(e1) / ITERS + return x, e0.elapsed_time(e1) / iters def _benchmark_flagsparse_spsv_csr( @@ -551,6 +597,9 @@ def _benchmark_flagsparse_spsv_csr( *, lower=True, transpose=False, + solve_kind=None, + warmup=WARMUP, + iters=ITERS, ): return _benchmark_flagsparse( lambda: fs.flagsparse_spsv_csr( @@ -561,7 +610,10 @@ def _benchmark_flagsparse_spsv_csr( shape, lower=lower, transpose=transpose, - ) + solve_kind=solve_kind, + ), + warmup=warmup, + iters=iters, ) @@ -574,7 +626,15 @@ def _benchmark_flagsparse_spsv_csr_split( *, lower=True, transpose=False, + solve_kind=None, ): + op_mode = fs_spsv_impl._normalize_spsv_transpose_mode(transpose) + warmup, iters = _spsv_benchmark_schedule( + int(data.numel()), + "NON" if op_mode == "N" else ("TRANS" if op_mode == "T" else "CONJ"), + data.dtype, + fmt="CSR", + ) analysis_ms = fs_spsv_impl._analyze_spsv_csr( data, indices, @@ -583,6 +643,7 @@ def _benchmark_flagsparse_spsv_csr_split( shape, lower=lower, transpose=transpose, + solve_kind=solve_kind, clear_cache=True, return_time=True, ) @@ -594,6 +655,9 @@ def _benchmark_flagsparse_spsv_csr_split( shape, lower=lower, transpose=transpose, + solve_kind=solve_kind, + warmup=warmup, + iters=iters, ) return x, analysis_ms, solve_ms @@ -607,6 +671,7 @@ def _benchmark_flagsparse_spsv_coo_split( *, lower=True, transpose=False, + solve_kind=None, ): data, input_index_dtype, row64, col64, b, n_rows, n_cols = fs_spsv_impl._prepare_spsv_coo_inputs( data, row, col, b, shape @@ -616,8 +681,14 @@ def _benchmark_flagsparse_spsv_coo_split( fs_spsv_impl._validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "COO") else: fs_spsv_impl._validate_spsv_trans_combo(data.dtype, input_index_dtype, "COO") - data_csr, indices_csr, indptr_csr = fs_spsv_impl._coo_to_csr_sorted_unique( - data, row64, col64, n_rows, n_cols + data_csr, indices_csr, indptr_csr = fs_spsv_impl._coo2csr_for_spsv( + data, row64, col64, n_rows, assume_ordered=False + ) + warmup, iters = _spsv_benchmark_schedule( + int(data_csr.numel()), + "NON" if trans_mode == "N" else ("TRANS" if trans_mode == "T" else "CONJ"), + data.dtype, + fmt="COO", ) analysis_ms = fs_spsv_impl._analyze_spsv_csr( data_csr, @@ -627,6 +698,7 @@ def _benchmark_flagsparse_spsv_coo_split( (n_rows, n_cols), lower=lower, transpose=transpose, + solve_kind=solve_kind, clear_cache=True, return_time=True, ) @@ -638,6 +710,9 @@ def _benchmark_flagsparse_spsv_coo_split( (n_rows, n_cols), lower=lower, transpose=transpose, + solve_kind=solve_kind, + warmup=warmup, + iters=iters, ) return x, analysis_ms, solve_ms @@ -711,6 +786,7 @@ def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): ): return None, None try: + warmup, iters = _spsv_benchmark_schedule(int(data.numel()), op_mode, data.dtype, fmt="CSR") data_ref, b_ref = _cupy_ref_inputs(data, b) data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) idx_cp = cp.from_dlpack( @@ -731,7 +807,7 @@ def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): A_eff = A_cp lower_eff = lower - for _ in range(WARMUP): + for _ in range(warmup): _ = cpx_spsolve_triangular( A_eff, b_cp, lower=lower_eff, unit_diagonal=False ) @@ -739,20 +815,20 @@ def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): c0 = cp.cuda.Event() c1 = cp.cuda.Event() c0.record() - for _ in range(ITERS): + for _ in range(iters): x_cp = cpx_spsolve_triangular( A_eff, b_cp, lower=lower_eff, unit_diagonal=False ) c1.record() c1.synchronize() - ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + ms = cp.cuda.get_elapsed_time(c0, c1) / iters x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()).to(b.dtype) return ms, x_t except Exception: return None, None -def run_spsv_synthetic_all(lower=True): +def run_spsv_synthetic_all(lower=True, alg_num=None): if not torch.cuda.is_available(): print("CUDA is not available. Please run on a GPU-enabled system.") return @@ -762,8 +838,12 @@ def run_spsv_synthetic_all(lower=True): print("FLAGSPARSE SpSV BENCHMARK (synthetic triangular systems, CSR + COO)") print(sep) print(f"GPU: {torch.cuda.get_device_name(0)}") - print(f"Warmup: {WARMUP} | Iters: {ITERS}") + print( + f"Benchmark schedule: warmup={WARMUP}, iter={ITERS} " + "(Library-main style defaults; override with --warmup/--iters)" + ) print(f"Triangle: {'LOWER' if lower else 'UPPER'}") + print(f"Algorithm: {_alg_label(alg_num)}") print() hdr = ( @@ -792,6 +872,10 @@ def run_spsv_synthetic_all(lower=True): else ["NON"] ) for op_mode in op_modes: + if not _alg_num_supports_case( + alg_num, fmt, op_mode, lower, value_dtype + ): + continue data, indices, indptr, shape = _build_random_triangular_csr( n, value_dtype, index_dtype, device, lower=lower ) @@ -822,6 +906,7 @@ def run_spsv_synthetic_all(lower=True): shape, lower=lower, transpose=op_mode, + solve_kind=_solve_kind_from_alg_num(alg_num), ) else: dc, rr, cc = _csr_to_coo( @@ -835,6 +920,7 @@ def run_spsv_synthetic_all(lower=True): shape, lower=lower, transpose=op_mode, + solve_kind=_solve_kind_from_alg_num(alg_num), ) torch.cuda.synchronize() @@ -920,7 +1006,7 @@ def run_spsv_synthetic_all(lower=True): print(sep) -def _run_one_csv_row_coo(path, value_dtype, index_dtype, op_mode, device, lower=True): +def _run_one_csv_row_coo(path, value_dtype, index_dtype, op_mode, device, lower=True, alg_num=None): data, indices, indptr, shape = _load_mtx_to_csr_torch( path, dtype=value_dtype, device=device, lower=lower ) @@ -952,6 +1038,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, op_mode, device, lower= shape, lower=lower, transpose=op_mode, + solve_kind=_solve_kind_from_alg_num(alg_num), ) return _finalize_csv_row( path, @@ -1070,7 +1157,7 @@ def _finalize_csv_row( return row, pt_skip_reason -def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, lower=True): +def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, lower=True, alg_num=None): data, indices, indptr, shape = _load_mtx_to_csr_torch( path, dtype=value_dtype, device=device, lower=lower ) @@ -1099,6 +1186,7 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l shape, lower=lower, transpose=op_mode, + solve_kind=_solve_kind_from_alg_num(alg_num), ) return _finalize_csv_row_csr_full( path, @@ -1224,6 +1312,7 @@ def run_all_supported_spsv_csr_csv( value_dtypes=None, index_dtypes=None, op_modes=None, + alg_num=None, ): if not torch.cuda.is_available(): print("CUDA is not available.") @@ -1240,14 +1329,23 @@ def run_all_supported_spsv_csr_csv( if op in selected_op_modes ] for op_mode in supported_op_modes: + if not _alg_num_supports_case( + alg_num, "CSR", op_mode, lower, value_dtype + ): + continue print("=" * 150) print( f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}" ) + print(f"Algorithm: {_alg_label(alg_num)}") print( "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, " "PyTorch(ms)=official sparse solve reference" ) + print( + f"Benchmark schedule: warmup={WARMUP}, iter={ITERS} " + "(Library-main style defaults; override with --warmup/--iters)" + ) print( "RHS is generated directly, matching Library-main's SpSV test style. " "PT.total / CU.total are single official interface call times. " @@ -1267,7 +1365,7 @@ def run_all_supported_spsv_csr_csv( for path in mtx_paths: try: row, pt_skip = _run_one_csv_row_csr_full( - path, value_dtype, index_dtype, op_mode, device, lower=lower + path, value_dtype, index_dtype, op_mode, device, lower=lower, alg_num=alg_num ) rows_out.append(row) name = os.path.basename(path)[:27] @@ -1377,6 +1475,7 @@ def run_all_dtypes_spsv_coo_csv( value_dtypes=None, index_dtypes=None, op_modes=None, + alg_num=None, ): if not torch.cuda.is_available(): print("CUDA is not available.") @@ -1392,16 +1491,25 @@ def run_all_dtypes_spsv_coo_csv( if op in (op_modes or SPSV_OP_MODES) ] for op_mode in supported_op_modes: + if not _alg_num_supports_case( + alg_num, "COO", op_mode, lower, value_dtype + ): + continue print("=" * 150) print( f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | COO" f" triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}" ) + print(f"Algorithm: {_alg_label(alg_num)}") print( "Formats: FlagSparse=COO input routed through CSR SpSV, cuSPARSE=CSR ref, " "PyTorch(ms)=official sparse solve reference. " "RHS is generated directly, matching Library-main's SpSV test style." ) + print( + f"Benchmark schedule: warmup={WARMUP}, iter={ITERS} " + "(Library-main style defaults; override with --warmup/--iters)" + ) print( "PT.total / CU.total are single official interface call times. " "PT.spdS and CU.spdS compare against FS.solve; PT.spdT and CU.spdT compare against FS.total." @@ -1425,7 +1533,7 @@ def run_all_dtypes_spsv_coo_csv( for path in mtx_paths: try: row, pt_skip = _run_one_csv_row_coo( - path, value_dtype, index_dtype, op_mode, device, lower=lower + path, value_dtype, index_dtype, op_mode, device, lower=lower, alg_num=alg_num ) rows_out.append(row) name = os.path.basename(path)[:27] @@ -1713,6 +1821,7 @@ def run_csr_transpose_check( def main(): + global WARMUP, ITERS parser = argparse.ArgumentParser( description="SpSV test: synthetic triangular systems and optional .mtx (CSR/COO), same baselines as CSR." ) @@ -1754,6 +1863,18 @@ def main(): default=None, help="Comma-separated opA filter for CSR/COO CSV, e.g. NON,TRANS,CONJ", ) + parser.add_argument( + "--alg-num", + "--alg_num", + dest="alg_num", + type=_parse_alg_num, + default=None, + help=( + "Algorithm selection compatible with allinone style. " + "Supported: 1=ALG1(csr_cw), 2=ALG2(csr_cw_levelschd), 3=ALG3(csr_nnz_balance), 8=ALG8(csr_nnz_balance). " + "Omit to use AUTO routing." + ), + ) parser.add_argument( "--value-dtypes", type=str, @@ -1766,11 +1887,40 @@ def main(): default=None, help="Comma-separated index dtype filter for CSR CSV, e.g. int32,int64", ) + parser.add_argument( + "--warmup", + type=int, + default=WARMUP, + help="Benchmark warmup iterations (Library-main style default: 1)", + ) + parser.add_argument( + "--iters", + type=int, + default=ITERS, + help="Benchmark timed iterations (Library-main style default: 1)", + ) args = parser.parse_args() + WARMUP = max(0, int(args.warmup)) + ITERS = max(1, int(args.iters)) lower = not args.upper + if args.alg_num in (2, 3, 8): + if args.check_transpose: + raise ValueError( + f"ALG{args.alg_num} matches allinone's NON-only path; --check-transpose is not supported" + ) + if args.upper: + raise ValueError( + f"ALG{args.alg_num} matches allinone's lower-triangular path; --upper is not supported" + ) + if args.ops: + op_modes_cli = _parse_op_modes_filter(args.ops) + if any(op != "NON" for op in op_modes_cli): + raise ValueError( + f"ALG{args.alg_num} matches allinone's NON-only path; use --ops NON" + ) if args.synthetic: - run_spsv_synthetic_all(lower=lower) + run_spsv_synthetic_all(lower=lower, alg_num=args.alg_num) return paths = [] @@ -1836,6 +1986,7 @@ def main(): value_dtypes=value_dtypes, index_dtypes=index_dtypes, op_modes=op_modes, + alg_num=args.alg_num, ) return if args.csv_coo: @@ -1866,6 +2017,7 @@ def main(): value_dtypes=value_dtypes, index_dtypes=index_dtypes, op_modes=op_modes, + alg_num=args.alg_num, ) return