diff --git a/.gitignore b/.gitignore index 4ca5df5..e44113e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ __pycache__/ .pytest_cache/ .coverage htmlcov/ +.DS_Store diff --git a/benchmark/benchmark_spsv.py b/benchmark/benchmark_spsv.py new file mode 100644 index 0000000..3fa258f --- /dev/null +++ b/benchmark/benchmark_spsv.py @@ -0,0 +1,14 @@ +"""Run SpSV benchmark. From project root: python benchmark/benchmark_spsv.py [--synthetic | --csv-csr out.csv].""" + +import sys +from pathlib import Path + +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) + +from tests.test_spsv import main + + +if __name__ == "__main__": + main() diff --git a/src/flagsparse/sparse_operations/_common.py b/src/flagsparse/sparse_operations/_common.py index 7e89526..6f7725d 100644 --- a/src/flagsparse/sparse_operations/_common.py +++ b/src/flagsparse/sparse_operations/_common.py @@ -18,19 +18,14 @@ cp = None cpx_sparse = None -_TORCH_COMPLEX32_DTYPE = getattr(torch, "complex32", None) -if _TORCH_COMPLEX32_DTYPE is None: - _TORCH_COMPLEX32_DTYPE = getattr(torch, "chalf", None) - _SUPPORTED_VALUE_DTYPES = [ torch.float16, torch.bfloat16, torch.float32, torch.float64, + torch.complex64, + torch.complex128, ] -if _TORCH_COMPLEX32_DTYPE is not None: - _SUPPORTED_VALUE_DTYPES.append(_TORCH_COMPLEX32_DTYPE) -_SUPPORTED_VALUE_DTYPES.extend([torch.complex64, torch.complex128]) SUPPORTED_VALUE_DTYPES = tuple(_SUPPORTED_VALUE_DTYPES) SUPPORTED_INDEX_DTYPES = (torch.int32, torch.int64) _INDEX_LIMIT_INT32 = 2**31 - 1 @@ -40,8 +35,6 @@ "SUPPORTED_VALUE_DTYPES", "SUPPORTED_INDEX_DTYPES", "_INDEX_LIMIT_INT32", - "_torch_complex32_dtype", - "_is_complex32_dtype", "_is_complex_dtype", "_resolve_scatter_value_dtype", "_component_dtype_for_complex", @@ -68,17 +61,8 @@ "tl", ) - -def _torch_complex32_dtype(): - return _TORCH_COMPLEX32_DTYPE - - -def _is_complex32_dtype(value_dtype): - return _TORCH_COMPLEX32_DTYPE is not None and value_dtype == _TORCH_COMPLEX32_DTYPE - - def _is_complex_dtype(value_dtype): - return _is_complex32_dtype(value_dtype) or value_dtype in (torch.complex64, torch.complex128) + return value_dtype in (torch.complex64, torch.complex128) def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"): @@ -95,24 +79,13 @@ def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"): "complex64": torch.complex64, "complex128": torch.complex128, } - if token == "complex32": - if _TORCH_COMPLEX32_DTYPE is not None: - return _TORCH_COMPLEX32_DTYPE, False, None - if dtype_policy == "strict": - raise TypeError("complex32 is unavailable in this torch build") - return torch.complex64, True, "complex32 is unavailable; fallback to complex64" if token not in mapping: raise TypeError(f"Unsupported dtype token: {value_dtype}") value_dtype = mapping[token] - if _is_complex32_dtype(value_dtype): - # If complex32 exists in torch dtype table, keep native path. - return value_dtype, False, None return value_dtype, False, None def _component_dtype_for_complex(value_dtype): - if _is_complex32_dtype(value_dtype): - return torch.float16 if value_dtype == torch.complex64: return torch.float32 if value_dtype == torch.complex128: @@ -123,8 +96,6 @@ def _component_dtype_for_complex(value_dtype): def _tolerance_for_dtype(value_dtype): if value_dtype == torch.float16: return 2e-3, 2e-3 - if _is_complex32_dtype(value_dtype): - return 5e-3, 5e-3 if value_dtype == torch.bfloat16: return 1e-1, 1e-1 if value_dtype in (torch.float32, torch.complex64): @@ -155,9 +126,6 @@ def _cupy_dtype_from_torch(torch_dtype): torch.int32: cp.int32, torch.int64: cp.int64, } - if _TORCH_COMPLEX32_DTYPE is not None: - # CuPy has no native complex32 sparse path; use complex64 for baseline parity. - mapping[_TORCH_COMPLEX32_DTYPE] = cp.complex64 if torch_dtype not in mapping: raise TypeError(f"Unsupported dtype conversion to CuPy: {torch_dtype}") return mapping[torch_dtype] @@ -193,8 +161,6 @@ def _to_backend_like(torch_tensor, ref_obj): def _cusparse_baseline_skip_reason(value_dtype): if value_dtype == torch.bfloat16: return "bfloat16 is not supported by the cuSPARSE baseline path; skipped" - if _is_complex32_dtype(value_dtype): - return "complex32 is not supported by the cuSPARSE baseline path; skipped" if cp is None and value_dtype == torch.float16: return "float16 is not supported by torch sparse fallback when CuPy is unavailable; skipped" return None @@ -203,9 +169,6 @@ def _cusparse_baseline_skip_reason(value_dtype): def _build_random_dense(dense_size, value_dtype, device): if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): return torch.randn(dense_size, dtype=value_dtype, device=device) - if _is_complex32_dtype(value_dtype): - stacked = torch.randn((dense_size, 2), dtype=torch.float16, device=device) - return torch.view_as_complex(stacked) if _is_complex_dtype(value_dtype): component_dtype = _component_dtype_for_complex(value_dtype) real = torch.randn(dense_size, dtype=component_dtype, device=device) diff --git a/src/flagsparse/sparse_operations/gather_scatter.py b/src/flagsparse/sparse_operations/gather_scatter.py index 6a02d8e..67066b0 100644 --- a/src/flagsparse/sparse_operations/gather_scatter.py +++ b/src/flagsparse/sparse_operations/gather_scatter.py @@ -226,6 +226,10 @@ def _launch_triton_scatter_kernel( ) +def _validate_gather_value_dtype(dense_vector, op_name): + return None + + def _cusparse_spmv(selector_matrix, dense_vector): if cp is not None and cpx_sparse is not None and isinstance(selector_matrix, cpx_sparse.spmatrix): if torch.is_tensor(dense_vector): @@ -330,6 +334,7 @@ def flagsparse_gather(a, indices, out=None, mode="raise", block_size=1024, retur dense_vector, dense_backend = _to_torch_tensor(a, "a") indices_tensor, _ = _to_torch_tensor(indices, "indices") dense_vector, indices_tensor, kernel_indices = _prepare_inputs(dense_vector, indices_tensor) + _validate_gather_value_dtype(dense_vector, "flagsparse_gather") torch.cuda.synchronize() start_time = time.perf_counter() @@ -479,6 +484,7 @@ def pytorch_index_scatter( def cusparse_spmv_gather(dense_vector, indices, selector_matrix=None): """Equivalent gather baseline via cuSPARSE-backed COO SpMV.""" dense_vector, indices, _ = _prepare_inputs(dense_vector, indices) + _validate_gather_value_dtype(dense_vector, "cusparse_spmv_gather") skip_reason = _cusparse_baseline_skip_reason(dense_vector.dtype) if skip_reason: raise RuntimeError(skip_reason) @@ -557,21 +563,14 @@ def _cupy_gather_detect_layout(dense_vector): return "scalar16" if dense_vector.ndim == 1 and dense_vector.dtype == torch.complex64: return "complex64" - if ( - dense_vector.ndim == 2 - and dense_vector.shape[1] == 2 - and dense_vector.dtype == torch.float16 - ): - # complex32 alignment with half2 storage. - return "complex16_pair" raise TypeError( "Unsupported gather input format. Expected one of: " - "1D float16/bfloat16, 1D complex64, or 2D (N,2) float16." + "1D float16/bfloat16 or 1D complex64." ) def _cupy_gather_dense_size(dense_vector, layout): - if layout in ("scalar16", "complex64", "complex16_pair"): + if layout in ("scalar16", "complex64"): return int(dense_vector.shape[0]) raise RuntimeError(f"Unknown gather layout: {layout}") @@ -601,7 +600,7 @@ def _cupy_gather_validate_inputs(dense_vector, indices): def _cupy_gather_validate_combo(dense_vector, indices, layout): # Keep only the required extra gather combos: - # Half+Int64, Bfloat16+Int32/Int64, Complex32+Int32/Int64, Complex64+Int32/Int64 + # Half+Int64, Bfloat16+Int32/Int64, Complex64+Int32/Int64 if layout == "scalar16": if dense_vector.dtype == torch.float16: if indices.dtype != torch.int64: @@ -611,11 +610,6 @@ def _cupy_gather_validate_combo(dense_vector, indices, layout): return raise TypeError("scalar16 gather_cupy supports only float16/bfloat16") - if layout == "complex16_pair": - if dense_vector.dtype != torch.float16: - raise TypeError("complex16_pair gather_cupy supports only float16 pairs") - return - if layout == "complex64": return @@ -625,8 +619,6 @@ def _cupy_gather_validate_combo(dense_vector, indices, layout): def _cupy_gather_layout_raw_kind(layout): if layout == "scalar16": return 16 - if layout == "complex16_pair": - return 32 if layout == "complex64": return 64 raise RuntimeError(f"Unknown gather layout: {layout}") @@ -675,9 +667,6 @@ def _cupy_gather_dense_to_raw_torch(dense_t, layout): return dense_t.reshape(-1).view(torch.uint16) if layout == "complex64": return dense_t.reshape(-1).view(torch.uint64) - if layout == "complex16_pair": - lanes_u16 = dense_t.reshape(-1).view(torch.uint16) - return lanes_u16.reshape(-1, 2).view(torch.uint32).reshape(-1) raise RuntimeError(f"Unknown gather layout: {layout}") @@ -686,22 +675,17 @@ def _cupy_gather_raw_to_dense_torch(out_raw_t, layout, dense_t_dtype): return out_raw_t.view(dense_t_dtype).reshape(-1) if layout == "complex64": return out_raw_t.view(torch.complex64).reshape(-1) - if layout == "complex16_pair": - lanes_u16 = out_raw_t.view(torch.uint16).reshape(-1, 2) - return lanes_u16.view(dense_t_dtype).reshape(-1, 2) raise RuntimeError(f"Unknown gather layout: {layout}") def _cupy_gather_empty(layout, dense_dtype, device): if layout in ("scalar16", "complex64"): return torch.empty(0, dtype=dense_dtype, device=device) - if layout == "complex16_pair": - return torch.empty((0, 2), dtype=dense_dtype, device=device) raise RuntimeError(f"Unknown gather layout: {layout}") def _cupy_gather_selector_dtype(layout, dense_dtype): - if layout in ("scalar16", "complex16_pair"): + if layout == "scalar16": return dense_dtype if layout == "complex64": return torch.complex64 @@ -709,21 +693,12 @@ def _cupy_gather_selector_dtype(layout, dense_dtype): def _cupy_gather_prepare_dense(dense_vector, indices): - runtime_dense = dense_vector - restore_mode = None - native_complex32 = _torch_complex32_dtype() - if native_complex32 is not None and dense_vector.ndim == 1 and dense_vector.dtype == native_complex32: - runtime_dense = torch.view_as_real(dense_vector).contiguous() - restore_mode = "native_complex32" - - layout, dense_size = _cupy_gather_validate_inputs(runtime_dense, indices) - _cupy_gather_validate_combo(runtime_dense, indices, layout) - return runtime_dense, layout, dense_size, restore_mode + layout, dense_size = _cupy_gather_validate_inputs(dense_vector, indices) + _cupy_gather_validate_combo(dense_vector, indices, layout) + return dense_vector, layout, dense_size, None def _cupy_gather_restore_output(gathered_t, restore_mode): - if restore_mode == "native_complex32": - return torch.view_as_complex(gathered_t.contiguous()) return gathered_t @@ -880,10 +855,6 @@ def cusparse_spmv_gather_cupy(dense_vector, indices, selector_matrix=None): start_time = time.perf_counter() if layout in ("scalar16", "complex64"): gathered_t = _cusparse_spmv(selector_matrix, runtime_dense_t) - elif layout == "complex16_pair": - gathered_real = _cusparse_spmv(selector_matrix, runtime_dense_t[:, 0]) - gathered_imag = _cusparse_spmv(selector_matrix, runtime_dense_t[:, 1]) - gathered_t = torch.stack([gathered_real, gathered_imag], dim=1) else: raise RuntimeError(f"Unknown gather layout: {layout}") torch.cuda.synchronize() diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index cdfbf50..ffdb6ea 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -2,24 +2,44 @@ from ._common import * +from collections import OrderedDict import os import time import triton import triton.language as tl SUPPORTED_SPSV_VALUE_DTYPES = ( - torch.bfloat16, torch.float32, torch.float64, + torch.complex64, + torch.complex128, ) SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64) -SPSV_NON_TRANS_PRIMARY_COMBOS = ( +SPSV_NON_TRANS_SUPPORTED_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), + (torch.complex64, torch.int32), + (torch.complex128, torch.int32), + (torch.float32, torch.int64), + (torch.float64, torch.int64), + (torch.complex64, torch.int64), + (torch.complex128, torch.int64), +) +SPSV_TRANS_SUPPORTED_COMBOS = ( + (torch.float32, torch.int32), + (torch.float64, torch.int32), + (torch.complex64, torch.int32), + (torch.complex128, torch.int32), + (torch.float32, torch.int64), + (torch.float64, torch.int64), + (torch.complex64, torch.int64), + (torch.complex128, torch.int64), ) SPSV_PROMOTE_FP32_TO_FP64 = str( os.environ.get("FLAGSPARSE_SPSV_PROMOTE_FP32_TO_FP64", "0") ).lower() in ("1", "true", "yes", "on") +_SPSV_CSR_PREPROCESS_CACHE = OrderedDict() +_SPSV_CSR_PREPROCESS_CACHE_SIZE = 8 def _csr_to_dense(data, indices, indptr, shape): """Convert CSR (torch CUDA tensors) to dense matrix on the same device.""" @@ -44,17 +64,42 @@ def _csr_to_dense(data, indices, indptr, shape): def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): """Validate NON_TRANS support matrix and keep error messages explicit.""" - if (data_dtype, index_dtype) in SPSV_NON_TRANS_PRIMARY_COMBOS: + if (data_dtype, index_dtype) in SPSV_NON_TRANS_SUPPORTED_COMBOS: return - if data_dtype == torch.bfloat16 and index_dtype == torch.int32: + raise TypeError( + f"{fmt_name} SpSV currently supports NON_TRANS combinations: " + "(float32, int32/int64), (float64, int32/int64), " + "(complex64, int32/int64), (complex128, int32/int64)" + ) + + +def _validate_spsv_trans_combo(data_dtype, index_dtype, fmt_name): + if (data_dtype, index_dtype) in SPSV_TRANS_SUPPORTED_COMBOS: return raise TypeError( - f"{fmt_name} SpSV currently supports NON_TRANS combinations with int32 kernel " - "indices: (float32, int32), (float64, int32), (bfloat16, int32)" + f"{fmt_name} SpSV currently supports TRANS/CONJ combinations: " + "(float32, int32/int64), (float64, int32/int64), " + "(complex64, int32/int64), (complex128, int32/int64)" ) -def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): +def _normalize_spsv_transpose_mode(transpose): + if isinstance(transpose, bool): + return "T" if transpose else "N" + token = str(transpose).strip().upper() + if token in ("N", "NON", "NON_TRANS"): + return "N" + if token in ("T", "TRANS"): + return "T" + if token in ("C", "H", "CONJ", "CONJ_TRANS", "CONJUGATE_TRANSPOSE"): + return "C" + raise ValueError( + "transpose must be bool or one of: " + "N/NON/NON_TRANS, T/TRANS, C/H/CONJ/CONJ_TRANS/CONJUGATE_TRANSPOSE" + ) + + +def _prepare_spsv_inputs(data, indices, indptr, b, shape): """Validate and normalize inputs for sparse solve A x = b with CSR A.""" if not all(torch.is_tensor(t) for t in (data, indices, indptr, b)): raise TypeError("data, indices, indptr, b must all be torch.Tensor") @@ -77,7 +122,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise TypeError( - "data dtype must be one of: bfloat16, float32, float64" + "data dtype must be one of: float32, float64, complex64, complex128" ) if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("indices dtype must be torch.int32 or torch.int64") @@ -85,8 +130,6 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): raise TypeError("indptr dtype must be torch.int32 or torch.int64") if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") - if transpose: - raise NotImplementedError("transpose=True is not implemented in Triton SpSV yet") indices64 = indices.to(torch.int64).contiguous() indptr64 = indptr.to(torch.int64).contiguous() @@ -94,8 +137,6 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): raise ValueError( f"int64 index value {int(indices64.max().item())} exceeds Triton int32 kernel range" ) - _validate_spsv_non_trans_combo(data.dtype, torch.int32, "CSR") - if indptr64.numel() > 0: if int(indptr64[0].item()) != 0: raise ValueError("indptr[0] must be 0") @@ -112,6 +153,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): return ( data.contiguous(), + indices.dtype, indices64, indptr64, b.contiguous(), @@ -120,6 +162,94 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): ) +def _prepare_spsv_working_inputs(data, b): + return data, b, None + + +def _restore_spsv_output(x, target_dtype): + return x.to(target_dtype) + + +def _spsv_diag_eps_for_dtype(value_dtype): + return 1e-12 if value_dtype in (torch.float64, torch.complex128) else 1e-6 + + +def _tensor_cache_token(tensor): + try: + storage_ptr = int(tensor.untyped_storage().data_ptr()) + except Exception: + storage_ptr = 0 + return ( + str(tensor.device), + str(tensor.dtype), + tuple(int(v) for v in tensor.shape), + int(tensor.numel()), + storage_ptr, + int(getattr(tensor, "_version", 0)), + ) + + +def _spsv_cache_get(cache, key): + value = cache.get(key) + if value is not None: + cache.move_to_end(key) + return value + + +def _spsv_cache_put(cache, key, value, max_entries): + cache[key] = value + cache.move_to_end(key) + while len(cache) > max_entries: + cache.popitem(last=False) + + +def _csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode): + return ( + "csr_preprocess", + trans_mode, + bool(lower), + int(shape[0]), + int(shape[1]), + _tensor_cache_token(data), + _tensor_cache_token(indices), + _tensor_cache_token(indptr), + ) + + +def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, trans_mode): + """Prepare an equivalent CSR triangular system before launching the solve kernel. + + Keep TRANS/CONJ handling outside the Triton solve kernels so the kernels only + execute one fixed CSR triangular solve semantics. + """ + if trans_mode == "N": + kernel_data = data + kernel_indices64 = indices64 + kernel_indptr64 = indptr64 + lower_eff = lower + else: + kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( + data, + indices64, + indptr64, + n_rows, + n_cols, + conjugate=(trans_mode == "C"), + ) + lower_eff = not lower + + levels = _build_spsv_levels( + kernel_indptr64, kernel_indices64, n_rows, lower=lower_eff + ) + return ( + kernel_data, + kernel_indices64, + kernel_indptr64, + lower_eff, + levels, + ) + + @triton.jit def _spsv_csr_level_kernel( data_ptr, @@ -172,6 +302,105 @@ def _spsv_csr_level_kernel( tl.store(x_ptr + row, x_row) +@triton.jit +def _spsv_csr_level_kernel_complex( + data_ri_ptr, + indices_ptr, + indptr_ptr, + b_ri_ptr, + x_ri_ptr, + rows_ptr, + n_level_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_level_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + + if USE_FP64_ACC: + acc_re = tl.zeros((1,), dtype=tl.float64) + acc_im = tl.zeros((1,), dtype=tl.float64) + diag_re = tl.zeros((1,), dtype=tl.float64) + diag_im = tl.zeros((1,), dtype=tl.float64) + else: + acc_re = tl.zeros((1,), dtype=tl.float32) + acc_im = tl.zeros((1,), dtype=tl.float32) + diag_re = tl.zeros((1,), dtype=tl.float32) + diag_im = tl.zeros((1,), dtype=tl.float32) + + if UNIT_DIAG: + diag_re = diag_re + 1.0 + + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + x_re = tl.load(x_ri_ptr + col * 2, mask=mask, other=0.0) + x_im = tl.load(x_ri_ptr + col * 2 + 1, mask=mask, other=0.0) + + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + x_re = x_re.to(tl.float64) + x_im = x_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + x_re = x_re.to(tl.float32) + x_im = x_im.to(tl.float32) + + if LOWER: + solved = col < row + else: + solved = col > row + is_diag = col == row + + prod_re = a_re * x_re - a_im * x_im + prod_im = a_re * x_im + a_im * x_re + acc_re = acc_re + tl.sum(tl.where(mask & solved, prod_re, 0.0)) + acc_im = acc_im + tl.sum(tl.where(mask & solved, prod_im, 0.0)) + + if not UNIT_DIAG: + diag_re = diag_re + tl.sum(tl.where(mask & is_diag, a_re, 0.0)) + diag_im = diag_im + tl.sum(tl.where(mask & is_diag, a_im, 0.0)) + + rhs_re = tl.load(b_ri_ptr + row * 2) + rhs_im = tl.load(b_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + + num_re = rhs_re - acc_re + num_im = rhs_im - acc_im + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + + x_re_out = (num_re * diag_re + num_im * diag_im) / den_safe + x_im_out = (num_im * diag_re - num_re * diag_im) / den_safe + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + + offs1 = tl.arange(0, 1) + tl.store(x_ri_ptr + row * 2 + offs1, x_re_out) + tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + + + @triton.jit def _spsv_coo_level_kernel_real( data_ptr, @@ -360,6 +589,90 @@ def _triton_spsv_csr_vector( return x +def _triton_spsv_csr_vector_complex( + data, + indices, + indptr, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + levels=None, + block_nnz_use=None, + max_segments_use=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + if levels is None: + levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + + # Some PyTorch builds return CSR values with a non-strided layout wrapper. + # Materialize a plain 1D strided buffer before splitting into real/imag parts. + if data.layout != torch.strided: + data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) + data_strided.copy_(data) + else: + data_strided = data.contiguous() + + data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() + b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() + component_dtype = _component_dtype_for_complex(data.dtype) + use_fp64 = component_dtype == torch.float64 + if component_dtype == torch.float16: + x_ri_work = torch.zeros((n_rows, 2), dtype=torch.float32, device=b_vec.device) + x_ri = x_ri_work.reshape(-1).contiguous() + else: + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + + for rows_lv in levels: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv,) + _spsv_csr_level_kernel_complex[grid]( + data_ri, + indices, + indptr, + b_ri, + x_ri, + rows_lv, + n_level_rows=n_lv, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + if component_dtype == torch.float16: + return torch.view_as_complex(x_ri_work.contiguous()) + return x + + +def _choose_transpose_family_launch_config(indptr, block_nnz=None, max_segments=None): + if block_nnz is not None or max_segments is not None: + return _auto_spsv_launch_config(indptr, block_nnz=block_nnz, max_segments=max_segments) + + if indptr.numel() <= 1: + return 32, 1 + max_nnz_per_row = int((indptr[1:] - indptr[:-1]).max().item()) + for cand in (32, 64, 128, 256, 512, 1024): + req = max((max_nnz_per_row + cand - 1) // cand, 1) + if req <= 2048: + return cand, req + cand = 2048 + req = max((max_nnz_per_row + cand - 1) // cand, 1) + return cand, req + + def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if not all(torch.is_tensor(t) for t in (data, row, col, b)): raise TypeError("data, row, col, b must all be torch.Tensor") @@ -378,24 +691,27 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if b.ndim == 2 and b.shape[0] != n_rows: raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") - if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: - raise TypeError("data dtype must be one of: bfloat16, float32, float64") + if data.dtype not in ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ): + raise TypeError( + "data dtype must be one of: float32, float64, complex64, complex128" + ) if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") if row.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("row dtype must be torch.int32 or torch.int64") if col.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("col dtype must be torch.int32 or torch.int64") - if transpose: - raise NotImplementedError("transpose=True is not implemented in Triton SpSV yet") - row64 = row.to(torch.int64).contiguous() col64 = col.to(torch.int64).contiguous() if col64.numel() > 0 and int(col64.max().item()) > _INDEX_LIMIT_INT32: raise ValueError( f"int64 index value {int(col64.max().item())} exceeds Triton int32 kernel range" ) - _validate_spsv_non_trans_combo(data.dtype, torch.int32, "COO") if row64.numel() > 0: if bool(torch.any(row64 < 0).item()): raise IndexError("row indices must be non-negative") @@ -408,6 +724,7 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if max_col >= n_cols: raise IndexError(f"col indices out of range for n_cols={n_cols}") + _validate_spsv_non_trans_combo(data.dtype, torch.int32, "COO") return ( data.contiguous(), row64, @@ -418,6 +735,26 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): ) +def _csr_transpose(data, indices64, indptr64, n_rows, n_cols, conjugate=False): + if data.numel() == 0: + out_data = data + out_indices = torch.empty(0, dtype=torch.int64, device=data.device) + out_indptr = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) + return out_data, out_indices, out_indptr + + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + new_row = indices64 + new_col = row_ids + data_eff = data.conj() if conjugate and torch.is_complex(data) else data + data_t, indices_t, indptr_t = _coo_to_csr_sorted_unique( + data_eff, new_row, new_col, n_cols, n_rows + ) + return data_t, indices_t, indptr_t + + def _coo_is_sorted_unique(row64, col64, n_cols): nnz = row64.numel() if nnz <= 1: @@ -454,13 +791,8 @@ def _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols): unique_key, inverse = torch.unique_consecutive(key_s, return_inverse=True) out_nnz = unique_key.numel() - if data_s.dtype == torch.bfloat16: - reduced_f32 = torch.zeros(out_nnz, dtype=torch.float32, device=data.device) - reduced_f32.scatter_add_(0, inverse, data_s.to(torch.float32)) - data_u = reduced_f32.to(torch.bfloat16) - else: - data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) - data_u.scatter_add_(0, inverse, data_s) + data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) + data_u.scatter_add_(0, inverse, data_s) row_u = torch.div(unique_key, max(1, n_cols), rounding_mode="floor") col_u = unique_key - row_u * max(1, n_cols) @@ -532,75 +864,162 @@ def flagsparse_spsv_csr( ): """Sparse triangular solve using Triton level-scheduling kernels. - Primary NON_TRANS support matrix: - - float32 + int32 indices - - float64 + int32 indices + Primary support matrix: + - NON_TRANS: float32/float64/complex64/complex128 with int32/int64 indices + - TRANS/CONJ: float32/float64/complex64/complex128 with int32/int64 indices """ - data, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( - data, indices, indptr, b, shape, transpose=transpose + input_data = data + input_indices = indices + input_indptr = indptr + trans_mode = _normalize_spsv_transpose_mode(transpose) + data, input_index_dtype, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( + data, indices, indptr, b, shape ) + original_output_dtype = None + data, b, original_output_dtype = _prepare_spsv_working_inputs(data, b) if n_rows != n_cols: raise ValueError(f"A must be square, got shape={shape}") - kernel_indices = indices.to(torch.int32) if indices.dtype != torch.int32 else indices - kernel_indptr = indptr + if trans_mode == "N": + _validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "CSR") + else: + _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") + + preprocess_key = _csr_preprocess_cache_key( + input_data, input_indices, input_indptr, (n_rows, n_cols), lower, trans_mode + ) + cached = _spsv_cache_get(_SPSV_CSR_PREPROCESS_CACHE, preprocess_key) + if cached is None: + cached = _prepare_spsv_csr_system( + data, + indices, + indptr, + n_rows, + n_cols, + lower, + trans_mode, + ) + _spsv_cache_put( + _SPSV_CSR_PREPROCESS_CACHE, + preprocess_key, + cached, + _SPSV_CSR_PREPROCESS_CACHE_SIZE, + ) + kernel_data, kernel_indices64, kernel_indptr64, lower_eff, levels = cached + + kernel_indices = ( + kernel_indices64.to(torch.int32) + if kernel_indices64.dtype != torch.int32 + else kernel_indices64 + ) + kernel_indptr = kernel_indptr64 compute_dtype = data.dtype - data_in = data + data_in = kernel_data b_in = b - if data.dtype == torch.bfloat16: - compute_dtype = torch.float32 - data_in = data.to(torch.float32) - b_in = b.to(torch.float32) + if data.dtype == torch.complex64 and trans_mode in ("T", "C"): + compute_dtype = torch.complex128 + data_in = kernel_data.to(torch.complex128) + b_in = b.to(torch.complex128) elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: - # Optional high-precision mode; disabled by default for throughput. compute_dtype = torch.float64 - data_in = data.to(torch.float64) + data_in = kernel_data.to(torch.float64) b_in = b.to(torch.float64) - levels = _build_spsv_levels(kernel_indptr, kernel_indices, n_rows, lower=lower) - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - kernel_indptr, block_nnz=block_nnz, max_segments=max_segments - ) - diag_eps = 1e-12 if compute_dtype == torch.float64 else 1e-6 + elif data.dtype == torch.float32 and trans_mode in ("T", "C"): + compute_dtype = torch.float64 + data_in = kernel_data.to(torch.float64) + b_in = b.to(torch.float64) + + is_transpose_family_op = trans_mode != "N" + if is_transpose_family_op: + block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( + kernel_indptr, block_nnz=block_nnz, max_segments=max_segments + ) + else: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + kernel_indptr, block_nnz=block_nnz, max_segments=max_segments + ) + diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) + vec_real = _triton_spsv_csr_vector + vec_complex = _triton_spsv_csr_vector_complex + torch.cuda.synchronize() t0 = time.perf_counter() if b_in.ndim == 1: - x = _triton_spsv_csr_vector( - data_in, - kernel_indices, - kernel_indptr, - b_in, - n_rows, - lower=lower, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - levels=levels, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - ) + if torch.is_complex(data_in): + x = vec_complex( + data_in, + kernel_indices, + kernel_indptr, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + else: + x = vec_real( + data_in, + kernel_indices, + kernel_indptr, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) else: cols = [] for j in range(b_in.shape[1]): - cols.append( - _triton_spsv_csr_vector( - data_in, - kernel_indices, - kernel_indptr, - b_in[:, j].contiguous(), - n_rows, - lower=lower, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - levels=levels, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, + bj = b_in[:, j].contiguous() + if torch.is_complex(data_in): + cols.append( + vec_complex( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + ) + else: + cols.append( + vec_real( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) ) - ) x = torch.stack(cols, dim=1) - if compute_dtype != data.dtype: - x = x.to(data.dtype) + target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype + if x.dtype != target_dtype: + x = _restore_spsv_output(x, target_dtype) torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: @@ -615,6 +1034,7 @@ def flagsparse_spsv_csr( def flagsparse_spsv_coo( + data, row, col, @@ -632,11 +1052,11 @@ def flagsparse_spsv_coo( """COO SpSV with dual mode: - direct: use COO level kernel directly (requires sorted+unique COO) - csr: convert COO -> CSR (sorted+deduplicated) then call flagsparse_spsv_csr - - auto: pick direct when sorted+unique, otherwise csr + - auto: pick direct when sorted+unique and supported, otherwise csr - Primary NON_TRANS support matrix: - - float32 + int32 indices - - float64 + int32 indices + Notes: + - direct mode currently supports only non-transposed real-valued inputs + - complex dtypes and TRANS/CONJ always route through the CSR implementation """ data, row64, col64, b, n_rows, n_cols = _prepare_spsv_coo_inputs( data, row, col, b, shape, transpose=transpose @@ -649,7 +1069,14 @@ def flagsparse_spsv_coo( raise ValueError("coo_mode must be one of: 'auto', 'direct', 'csr'") sorted_unique = _coo_is_sorted_unique(row64, col64, n_cols) - use_direct = mode == "direct" or (mode == "auto" and sorted_unique) + trans_mode = _normalize_spsv_transpose_mode(transpose) + direct_supported = (trans_mode == "N") and (not torch.is_complex(data)) + use_direct = direct_supported and (mode == "direct" or (mode == "auto" and sorted_unique)) + if mode == "direct" and not direct_supported: + raise ValueError( + "coo_mode='direct' supports only non-transposed real-valued inputs; " + "use coo_mode='csr' or 'auto' for TRANS/CONJ or complex dtypes" + ) if mode == "direct" and not sorted_unique: raise ValueError( "coo_mode='direct' requires COO sorted by (row, col) with no duplicate coordinates; " @@ -681,11 +1108,7 @@ def flagsparse_spsv_coo( compute_dtype = data.dtype data_in = data b_in = b - if data.dtype == torch.bfloat16: - compute_dtype = torch.float32 - data_in = data.to(torch.float32) - b_in = b.to(torch.float32) - elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: + if data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: compute_dtype = torch.float64 data_in = data.to(torch.float64) b_in = b.to(torch.float64) @@ -693,7 +1116,7 @@ def flagsparse_spsv_coo( block_nnz_use, max_segments_use = _auto_spsv_launch_config( row_ptr, block_nnz=block_nnz, max_segments=max_segments ) - diag_eps = 1e-12 if compute_dtype == torch.float64 else 1e-6 + diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) torch.cuda.synchronize() t0 = time.perf_counter() @@ -747,4 +1170,4 @@ def flagsparse_spsv_coo( if return_time: return x, elapsed_ms - return x \ No newline at end of file + return x diff --git a/tests/pytest/test_gather_scatter_accuracy.py b/tests/pytest/test_gather_scatter_accuracy.py index 3563ef1..e67c648 100644 --- a/tests/pytest/test_gather_scatter_accuracy.py +++ b/tests/pytest/test_gather_scatter_accuracy.py @@ -19,17 +19,8 @@ RESET_OUTPUT_CASES = [True, False] RESET_OUTPUT_IDS = ["reset", "inplace"] - -def _complex32_dtype(): - dtype = getattr(torch, "complex32", None) - if dtype is None: - dtype = getattr(torch, "chalf", None) - return dtype - - def _scatter_dtype_cases(): cases = [(str(dtype).replace("torch.", ""), dtype) for dtype in FLOAT_DTYPES] - cases.append(("complex32", _complex32_dtype())) cases.append(("complex64", torch.complex64)) cases.append(("complex128", torch.complex128)) return cases @@ -59,9 +50,6 @@ def _build_random_values(size, dtype, device): real = torch.randn(size, dtype=torch.float64, device=device) imag = torch.randn(size, dtype=torch.float64, device=device) return torch.complex(real, imag) - if _complex32_dtype() is not None and dtype == _complex32_dtype(): - stacked = torch.randn((size, 2), dtype=torch.float16, device=device) - return torch.view_as_complex(stacked) raise TypeError(f"Unsupported dtype in test: {dtype}") @@ -113,8 +101,6 @@ def _extra_gather_tolerance(value_dtype): ("scalar16", torch.float16, torch.int64), ("scalar16", torch.bfloat16, torch.int32), ("scalar16", torch.bfloat16, torch.int64), - ("complex16_pair", torch.float16, torch.int32), - ("complex16_pair", torch.float16, torch.int64), ("complex64", torch.complex64, torch.int32), ("complex64", torch.complex64, torch.int64), ] @@ -122,8 +108,6 @@ def _extra_gather_tolerance(value_dtype): "half_i64", "bf16_i32", "bf16_i64", - "c16f_i32", - "c16f_i64", "c64_i32", "c64_i64", ] @@ -144,6 +128,21 @@ def test_gather_matches_indexing(dense_size, nnz, dtype, index_dtype): assert torch.equal(ref, got) +@pytest.mark.gather +@pytest.mark.parametrize("index_dtype", INDEX_DTYPES, ids=INDEX_DTYPE_IDS) +def test_gather_complex128_matches_indexing(index_dtype): + device = torch.device("cuda") + dense_size = 4096 + nnz = 1024 + real = torch.randn(dense_size, dtype=torch.float64, device=device) + imag = torch.randn(dense_size, dtype=torch.float64, device=device) + dense = torch.complex(real, imag) + indices = torch.randperm(dense_size, device=device)[:nnz].to(index_dtype) + ref = dense.index_select(0, indices.to(torch.int64)) + got = flagsparse_gather(dense, indices) + assert torch.allclose(got, ref, atol=1e-10, rtol=1e-8) + + @pytest.mark.scatter @pytest.mark.parametrize("dense_size, nnz", GATHER_SCATTER_SHAPES) @pytest.mark.parametrize("dtype_name,dtype", SCATTER_DTYPE_CASES, ids=SCATTER_DTYPE_IDS) @@ -323,94 +322,6 @@ def test_gather_cupy_same_backend_out_float16_i64(backend): assert torch.allclose(_as_torch_tensor(out), reference, atol=5e-3, rtol=5e-3) -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -@pytest.mark.parametrize("backend", ["torch", "cupy"]) -def test_gather_cupy_same_backend_out_pair_complex32(backend): - device = torch.device("cuda") - dense_size = 65536 - nnz = 4096 - dense_t = torch.randn(dense_size, 2, dtype=torch.float16, device=device) - indices_t = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size - reference = dense_t.index_select(0, indices_t) - - dense_in = _to_backend_tensor(dense_t, backend) - indices_in = _to_backend_tensor(indices_t, backend) - out = _to_backend_tensor(torch.empty_like(reference), backend) - result = flagsparse_gather_cupy(dense_in, indices_in, out=out) - - assert result is out - assert torch.allclose(_as_torch_tensor(out), reference, atol=5e-3, rtol=5e-3) - - -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -def test_gather_cupy_native_complex32_out_matches_reference(): - native_dtype = _complex32_dtype() - _skip_unavailable_dtype("complex32", native_dtype) - - device = torch.device("cuda") - dense_size = 65536 - nnz = 4096 - dense_pair = torch.randn(dense_size, 2, dtype=torch.float16, device=device) - dense_native = torch.view_as_complex(dense_pair.contiguous()) - indices = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size - reference = dense_native.index_select(0, indices) - - out = torch.empty_like(reference) - result = flagsparse_gather_cupy(dense_native, indices, out=out) - - assert result is out - assert out.dtype == native_dtype - assert torch.allclose( - torch.view_as_real(out).contiguous(), - torch.view_as_real(reference).contiguous(), - atol=5e-3, - rtol=5e-3, - ) - - -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -def test_gather_cupy_native_complex32_matches_reference_and_pair_layout(): - native_dtype = _complex32_dtype() - _skip_unavailable_dtype("complex32", native_dtype) - - device = torch.device("cuda") - dense_size = 65536 - nnz = 4096 - dense_pair = torch.randn(dense_size, 2, dtype=torch.float16, device=device) - dense_native = torch.view_as_complex(dense_pair.contiguous()) - indices = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size - - reference = dense_native.index_select(0, indices) - reference_pair = torch.view_as_real(reference).contiguous() - pair_got = flagsparse_gather_cupy(dense_pair, indices) - native_got = flagsparse_gather_cupy(dense_native, indices) - cusparse_values, _, _ = cusparse_spmv_gather_cupy(dense_native, indices) - - atol, rtol = 5e-3, 5e-3 - assert pair_got.shape == reference_pair.shape - assert pair_got.dtype == torch.float16 - assert native_got.shape == reference.shape - assert native_got.dtype == native_dtype - assert cusparse_values.shape == reference.shape - assert cusparse_values.dtype == native_dtype - assert torch.allclose(pair_got, reference_pair, atol=atol, rtol=rtol) - assert torch.allclose( - torch.view_as_real(native_got).contiguous(), - reference_pair, - atol=atol, - rtol=rtol, - ) - assert torch.allclose( - torch.view_as_real(cusparse_values).contiguous(), - reference_pair, - atol=atol, - rtol=rtol, - ) - - @pytest.mark.gather @pytest.mark.skipif(cp is None, reason="CuPy required") def test_gather_cupy_int64_auto_fallback_to_int32(monkeypatch): diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py index f2cf944..fc9c891 100644 --- a/tests/pytest/test_spsv_csr_accuracy.py +++ b/tests/pytest/test_spsv_csr_accuracy.py @@ -1,12 +1,118 @@ import pytest import torch -from flagsparse import flagsparse_spsv_csr +from flagsparse import flagsparse_spsv_coo, flagsparse_spsv_csr from tests.pytest.param_shapes import SPSV_N +try: + import cupy as cp + import cupyx.scipy.sparse as cpx_sparse + from cupyx.scipy.sparse.linalg import spsolve_triangular as cpx_spsolve_triangular +except Exception: + cp = None + cpx_sparse = None + cpx_spsolve_triangular = None + pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +SUPPORTED_COMPLEX_DTYPES = [torch.complex64, torch.complex128] + +SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] +NON_TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +TRANS_CONJ_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +TRANS_CONJ_MODES = ["TRANS", "CONJ"] + + +def _dtype_id(dtype): + return str(dtype).replace("torch.", "") + + +def _tol(dtype): + if dtype in (torch.float32, torch.complex64): + return 1e-4, 1e-3 + return 1e-10, 1e-8 + + +def _rand_like(dtype, shape, device): + if dtype in (torch.float32, torch.float64): + return torch.randn(shape, dtype=dtype, device=device) + base = torch.float32 if dtype == torch.complex64 else torch.float64 + r = torch.randn(shape, dtype=base, device=device) + i = torch.randn(shape, dtype=base, device=device) + return torch.complex(r, i) + + +def _ref_dtype(dtype): + return dtype + + +def _safe_cast_tensor(tensor, dtype): + return tensor.to(dtype) + + +def _cmp_view(tensor, dtype): + return tensor + + +def _apply_ref_op(A, op_mode): + if op_mode == "TRANS": + return A.transpose(-2, -1) + if op_mode == "CONJ": + return A.transpose(-2, -1).conj() if torch.is_complex(A) else A.transpose(-2, -1) + return A + + +def _effective_upper(lower, op_mode): + return lower if op_mode in ("TRANS", "CONJ") else not lower + + +def _effective_lower_for_op(lower, op_mode): + return (not lower) if op_mode in ("TRANS", "CONJ") else lower + + +def _transpose_arg(op_mode): + if op_mode == "NON": + return False + return op_mode + + +def _cupy_apply_op(A_cp, op_mode): + if op_mode == "TRANS": + return A_cp.transpose().tocsr() + if op_mode == "CONJ": + return A_cp.transpose().conj().tocsr() + return A_cp + + +def _build_triangular(n, dtype, device, lower=True): + off = _rand_like(dtype, (n, n), device) * 0.02 + A = torch.tril(off) if lower else torch.triu(off) + if torch.is_complex(A): + diag = (torch.rand(n, device=device, dtype=A.real.dtype) + 2.0).to(A.real.dtype) + A = A + torch.diag(torch.complex(diag, torch.zeros_like(diag))) + else: + diag = torch.rand(n, device=device, dtype=A.dtype) + 2.0 + A = A + torch.diag(diag) + return A + + +def _cupy_csr_from_torch(data, indices, indptr, shape): + if cp is None or cpx_sparse is None: + return None + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) + idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous())) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous())) + return cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + + +def _cupy_ref_spsv(A_cp, b_t, *, lower, unit_diagonal=False): + if cp is None or cpx_spsolve_triangular is None: + return None + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_t.contiguous())) + x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) + x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) + return x_t.to(b_t.dtype) @pytest.mark.spsv @@ -17,12 +123,12 @@ ids=["float32", "float64"], ) def test_spsv_csr_lower_matches_dense(n, dtype): + # Keep the original baseline test case untouched in semantics. device = torch.device("cuda") base = torch.tril(torch.randn(n, n, dtype=dtype, device=device)) eye = torch.eye(n, dtype=dtype, device=device) A = base + eye * (float(n) * 0.5 + 2.0) b = torch.randn(n, dtype=dtype, device=device) - # PyTorch 2.x requires B with rank >= 2 for solve_triangular. x_ref = torch.linalg.solve_triangular( A, b.unsqueeze(-1), upper=False ).squeeze(-1) @@ -42,3 +148,300 @@ def test_spsv_csr_lower_matches_dense(n, dtype): rtol = 1e-4 if dtype == torch.float32 else 1e-10 atol = 1e-5 if dtype == torch.float32 else 1e-10 assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + x_ref = torch.linalg.solve_triangular( + A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=False + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_transpose_family_supported_combos(n, dtype, index_dtype, op_mode): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + A_ref = A.to(_ref_dtype(dtype)) + b_ref = b.to(_ref_dtype(dtype)) + x_ref = torch.linalg.solve_triangular( + _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(True, op_mode) + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_non_trans(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_non = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=False + ) + x_non_ref = _cupy_ref_spsv(A_cp, b, lower=True, unit_diagonal=False) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_matches_cusparse_transpose_family(n, dtype, index_dtype, op_mode): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_trans = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + x_trans_ref = _cupy_ref_spsv( + _cupy_apply_op(A_cp, op_mode), + b, + lower=_effective_lower_for_op(True, op_mode), + unit_diagonal=False, + ) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +def test_spsv_csr_non_trans_upper_supported_combos(n, dtype, index_dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + x_ref = torch.linalg.solve_triangular( + A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=True + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=False, + unit_diagonal=False, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_upper_transpose_family_supported_combos(n, dtype, index_dtype, op_mode): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + A_ref = A.to(_ref_dtype(dtype)) + b_ref = b.to(_ref_dtype(dtype)) + x_ref = torch.linalg.solve_triangular( + _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(False, op_mode) + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=False, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_upper_non_trans(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_non = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=False, unit_diagonal=False, transpose=False + ) + x_non_ref = _cupy_ref_spsv(A_cp, b, lower=False, unit_diagonal=False) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_matches_cusparse_upper_transpose_family(n, dtype, index_dtype, op_mode): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_trans = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=False, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + x_trans_ref = _cupy_ref_spsv( + _cupy_apply_op(A_cp, op_mode), + b, + lower=_effective_lower_for_op(False, op_mode), + unit_diagonal=False, + ) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_coo_transpose_family_complex128_routes_through_csr(n, op_mode): + device = torch.device("cuda") + dtype = torch.complex128 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + x_ref = torch.linalg.solve_triangular( + _apply_ref_op(A, op_mode), b.unsqueeze(-1), upper=_effective_upper(True, op_mode) + ).squeeze(-1) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + coo_mode="auto", + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) diff --git a/tests/test_gather.py b/tests/test_gather.py index 481138a..b00a6c4 100644 --- a/tests/test_gather.py +++ b/tests/test_gather.py @@ -16,7 +16,7 @@ (524_288, 16_384), (1_048_576, 65_536), ] -DEFAULT_VALUE_DTYPES = "float16,bfloat16,float32,float64,complex32,complex64,complex128" +DEFAULT_VALUE_DTYPES = "float16,bfloat16,float32,float64,complex64,complex128" DEFAULT_INDEX_DTYPES = "int32,int64" WARMUP = 20 ITERS = 200 @@ -48,7 +48,6 @@ def _parse_value_dtypes(raw): "bfloat16", "float32", "float64", - "complex32", "complex64", "complex128", } @@ -164,7 +163,7 @@ def _collect_samples(case_id, expected, flagsparse_out, limit): def _dtype_mode(value_dtype_req): - if value_dtype_req in ("float16", "bfloat16", "complex32", "complex64"): + if value_dtype_req in ("float16", "bfloat16", "complex64"): return "gather_cupy" return "gather_triton" @@ -193,8 +192,6 @@ def _build_dense(value_dtype_req, dense_size, device): real = torch.randn(dense_size, dtype=torch.float64, device=device) imag = torch.randn(dense_size, dtype=torch.float64, device=device) return torch.complex(real, imag) - if value_dtype_req == "complex32": - return torch.randn(dense_size, 2, dtype=torch.float16, device=device) raise ValueError(f"Unsupported value dtype request: {value_dtype_req}") @@ -206,13 +203,12 @@ def _effective_dtype_name(value_dtype_req): "float64": "float64", "complex64": "complex64", "complex128": "complex128", - "complex32": "complex16_pair_f16", } return mapping[value_dtype_req] def _tolerance(value_dtype_req): - if value_dtype_req in ("float16", "complex32"): + if value_dtype_req == "float16": return 5e-3, 5e-3 if value_dtype_req in ("bfloat16",): return 1e-2, 1e-2 @@ -230,10 +226,10 @@ def _check_dtype_supported(value_dtype_req): def _is_supported_extra_gather_combo(value_dtype_req, index_dtype): # Required extra gather combos only: - # Half+Int32/Int64, Bfloat16+Int32/Int64, Complex32+Int32/Int64, Complex64+Int32/Int64 + # Half+Int32/Int64, Bfloat16+Int32/Int64, Complex64+Int32/Int64 if value_dtype_req == "float16": return index_dtype in (torch.int32, torch.int64) - if value_dtype_req in ("bfloat16", "complex32", "complex64"): + if value_dtype_req in ("bfloat16", "complex64"): return index_dtype in (torch.int32, torch.int64) # Original gather path dtypes keep original behavior. return True diff --git a/tests/test_scatter.py b/tests/test_scatter.py index e92b41e..604adeb 100644 --- a/tests/test_scatter.py +++ b/tests/test_scatter.py @@ -54,7 +54,6 @@ def _parse_value_dtypes(raw): "bfloat16", "float32", "float64", - "complex32", "complex64", "complex128", } diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 3123774..cb28cb4 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -7,11 +7,20 @@ import argparse import csv import glob +import hashlib import os +import sys +from pathlib import Path import torch +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) + import flagsparse as fs +import flagsparse.sparse_operations.spsv as fs_spsv_impl try: import cupy as cp @@ -22,19 +31,69 @@ cpx_sparse = None cpx_spsolve_triangular = None -VALUE_DTYPES = [torch.float32, torch.float64] -INDEX_DTYPES = [torch.int32] +VALUE_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +INDEX_DTYPES = [torch.int32, torch.int64] TEST_SIZES = [256, 512, 1024, 2048] WARMUP = 5 ITERS = 20 DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB +SPSV_TRIANGULAR_DIAG_DOMINANCE = 4.0 +# CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) +CSR_FULL_VALUE_DTYPES = [ + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, +] +CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64] +SPSV_OP_MODES = ["NON", "TRANS", "CONJ"] def _dtype_name(dtype): return str(dtype).replace("torch.", "") +VALUE_DTYPE_NAME_MAP = { + _dtype_name(dtype): dtype for dtype in CSR_FULL_VALUE_DTYPES +} +VALUE_DTYPE_NAME_MAP.update({ + "float": torch.float32, + "double": torch.float64, +}) +INDEX_DTYPE_NAME_MAP = { + _dtype_name(dtype): dtype for dtype in CSR_FULL_INDEX_DTYPES +} + + +def _parse_csv_tokens(raw): + return [tok.strip() for tok in str(raw).split(",") if tok.strip()] + + +def _parse_value_dtypes_filter(raw): + tokens = [tok.lower() for tok in _parse_csv_tokens(raw)] + invalid = [tok for tok in tokens if tok not in VALUE_DTYPE_NAME_MAP] + if invalid: + raise ValueError(f"unsupported value dtypes: {invalid}") + return [VALUE_DTYPE_NAME_MAP[tok] for tok in tokens] + + +def _parse_index_dtypes_filter(raw): + tokens = [tok.lower() for tok in _parse_csv_tokens(raw)] + invalid = [tok for tok in tokens if tok not in INDEX_DTYPE_NAME_MAP] + if invalid: + raise ValueError(f"unsupported index dtypes: {invalid}") + return [INDEX_DTYPE_NAME_MAP[tok] for tok in tokens] + + +def _parse_op_modes_filter(raw): + tokens = [tok.upper() for tok in _parse_csv_tokens(raw)] + invalid = [tok for tok in tokens if tok not in SPSV_OP_MODES] + if invalid: + raise ValueError(f"unsupported ops: {invalid}") + return tokens + + def _fmt_ms(v): return "N/A" if v is None else f"{v:.4f}" @@ -50,11 +109,94 @@ def _fmt_err(v): def _tol_for_dtype(dtype): - if dtype == torch.float32: + if dtype in (torch.float32, torch.complex64): return 1e-4, 1e-2 return 1e-12, 1e-10 +def _stable_case_seed(*parts): + raw = "|".join(str(part) for part in parts).encode("utf-8") + return int.from_bytes(hashlib.sha256(raw).digest()[:8], "little") % (2**63) + + +def _generator_for_seed(seed): + if seed is None: + return None + gen = torch.Generator() + gen.manual_seed(int(seed)) + return gen + + +def _randn_by_dtype(n, dtype, device, generator=None): + if dtype in (torch.float32, torch.float64): + return torch.randn(n, dtype=dtype, device=device, generator=generator) + base = torch.float32 if dtype == torch.complex64 else torch.float64 + real = torch.randn(n, dtype=base, device=device, generator=generator) + imag = torch.randn(n, dtype=base, device=device, generator=generator) + return torch.complex(real, imag) + + +def _dense_ref_dtype(dtype): + return dtype + + +def _tensor_from_scalar_values(values, dtype, device): + return torch.tensor(values, dtype=dtype, device=device) + + +def _safe_cast_tensor(tensor, dtype): + return tensor.to(dtype) + + +def _cast_real_tensor_to_value_dtype(values, value_dtype): + return values.to(value_dtype) + + +def _matrix_market_value(parts, mm_field): + if mm_field == "complex": + if len(parts) < 4: + raise ValueError("MatrixMarket complex entry requires real and imag parts") + return complex(float(parts[2]), float(parts[3])) + if len(parts) >= 3: + return float(parts[2]) + if mm_field == "pattern": + return 1.0 + raise ValueError("MatrixMarket entry is missing a numeric value") + + +def _triangular_solve_reference(A, b, *, lower, op_mode="NON"): + if op_mode == "TRANS": + A_eff = A.transpose(0, 1) + upper = lower + elif op_mode == "CONJ": + A_eff = A.transpose(0, 1).conj() if torch.is_complex(A) else A.transpose(0, 1) + upper = lower + else: + A_eff = A + upper = not lower + return torch.linalg.solve_triangular( + A_eff, b.unsqueeze(1), upper=upper + ).squeeze(1) + + +def _cupy_ref_inputs(data, b): + return data, b + + +def _compare_view(tensor, value_dtype): + return tensor + + +def _supported_csr_full_ops(value_dtype, index_dtype): + if value_dtype not in CSR_FULL_VALUE_DTYPES: + return [] + if index_dtype == torch.int32: + return ["NON", "TRANS", "CONJ"] + if index_dtype == torch.int64: + return ["NON", "TRANS", "CONJ"] + return [] + + def _allow_dense_pytorch_ref(shape, dtype): n_rows, n_cols = int(shape[0]), int(shape[1]) elem_bytes = torch.empty((), dtype=dtype).element_size() @@ -63,14 +205,21 @@ def _allow_dense_pytorch_ref(shape, dtype): def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True): - """Build a well-conditioned triangular CSR (float32/float64).""" + """Build a well-conditioned triangular CSR for real and complex dtypes.""" max_bandwidth = max(4, min(n, 16)) rows_host = [] cols_host = [] vals_host = [] - base_real_dtype = ( - torch.float32 if value_dtype == torch.float32 else torch.float64 - ) + row_off_abs = [0.0] * n + col_off_abs = [0.0] * n + if value_dtype == torch.float32: + base_real_dtype = torch.float32 + elif value_dtype == torch.float64: + base_real_dtype = torch.float64 + elif value_dtype == torch.complex64: + base_real_dtype = torch.float32 + else: + base_real_dtype = torch.float64 for i in range(n): if lower: @@ -87,20 +236,39 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True off_cols = [off_cand[j] for j in perm] else: off_cols = [] - off_vals_real = torch.randn(len(off_cols), dtype=base_real_dtype).mul_(0.01) - sum_abs = float(torch.sum(torch.abs(off_vals_real)).item()) if off_vals_real.numel() else 0.0 - diag_val = sum_abs + 1.0 - rows_host.append(i) - cols_host.append(diag_col) - vals_host.append(diag_val) - for c, v in zip(off_cols, off_vals_real.tolist()): + if value_dtype in (torch.complex64, torch.complex128): + off_vals = torch.complex( + torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), + torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), + ) + off_vals_host = [complex(v) for v in off_vals.cpu().tolist()] + else: + off_vals = torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01) + off_vals_host = off_vals.cpu().tolist() + for c, v in zip(off_cols, off_vals_host): rows_host.append(i) cols_host.append(int(c)) vals_host.append(v) + mag = abs(v) + row_off_abs[i] += mag + col_off_abs[int(c)] += mag + + for i in range(n): + diag_mag = ( + SPSV_TRIANGULAR_DIAG_DOMINANCE * max(row_off_abs[i], col_off_abs[i]) + 1.0 + ) + diag_val = ( + complex(diag_mag, 0.0) + if value_dtype in (torch.complex64, torch.complex128) + else diag_mag + ) + rows_host.append(i) + cols_host.append(i) + vals_host.append(diag_val) rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device) cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device) - vals_t = torch.tensor(vals_host, dtype=base_real_dtype, device=device).to(value_dtype) + vals_t = torch.tensor(vals_host, dtype=value_dtype, device=device) order = torch.argsort(rows_t * max(1, n) + cols_t) rows_t = rows_t[order] cols_t = cols_t[order] @@ -127,17 +295,45 @@ def _csr_to_dense(data, indices, indptr, shape): return coo.to_dense() -def _csr_to_coo(data, indices, indptr, shape): +def _csr_to_coo(data, indices, indptr, shape, index_dtype=torch.int64): n_rows = int(shape[0]) row = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), + torch.arange(n_rows, device=data.device, dtype=index_dtype), indptr[1:] - indptr[:-1], ) - col = indices.to(torch.int64) + col = indices.to(index_dtype) return data, row, col -def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): +def _csr_transpose(data, indices, indptr, shape, conjugate=False): + n_rows, n_cols = int(shape[0]), int(shape[1]) + if data.numel() == 0: + return ( + data, + torch.empty(0, dtype=torch.int64, device=data.device), + torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device), + ) + + row, col = _csr_to_coo(data, indices, indptr, shape)[1:] + row_t = col + col_t = row + key = row_t * max(1, n_rows) + col_t + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + + row_t = row_t[order] + col_t = col_t[order] + data_eff = data.conj() if conjugate and torch.is_complex(data) else data + data_t = data_eff[order] + nnz_per_row = torch.bincount(row_t, minlength=n_cols) + indptr_t = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) + indptr_t[1:] = torch.cumsum(nnz_per_row, dim=0) + return data_t, col_t.to(torch.int64), indptr_t + + +def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None, lower=True): if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") with open(file_path, "r", encoding="utf-8") as f: @@ -182,28 +378,36 @@ def _accum(r, c, v): continue r = int(parts[0]) - 1 c = int(parts[1]) - 1 - if len(parts) >= 3: - v = float(parts[2]) - elif mm_field == "pattern": - v = 1.0 - else: - continue + v = _matrix_market_value(parts, mm_field) _accum(r, c, v) - if mm_symmetry in ("symmetric", "hermitian") and r != c: + if mm_symmetry == "symmetric" and r != c: _accum(c, r, v) + elif mm_symmetry == "hermitian" and r != c: + _accum(c, r, v.conjugate() if isinstance(v, complex) else v) elif mm_symmetry == "skew-symmetric" and r != c: _accum(c, r, -v) + tri_rows = [dict() for _ in range(n_rows)] + row_off_abs = [0.0] * n_rows + col_off_abs = [0.0] * n_cols for r in range(n_rows): - row = row_maps[r] - lower_row = {} - off_abs_sum = 0.0 + for c, v in row_maps[r].items(): + keep = c < r if lower else c > r + if keep: + tri_rows[r][c] = tri_rows[r].get(c, 0.0) + v + + for r, row in enumerate(tri_rows): for c, v in row.items(): - if c < r: - lower_row[c] = lower_row.get(c, 0.0) + v - off_abs_sum += abs(v) - lower_row[r] = off_abs_sum + 1.0 - row_maps[r] = lower_row + mag = abs(v) + row_off_abs[r] += mag + col_off_abs[c] += mag + + for r in range(n_rows): + # Make the generated triangular system stable for both A and op(A). + tri_rows[r][r] = ( + SPSV_TRIANGULAR_DIAG_DOMINANCE * max(row_off_abs[r], col_off_abs[r]) + 1.0 + ) + row_maps = tri_rows cols_s = [] vals_s = [] @@ -214,15 +418,17 @@ def _accum(r, c, v): cols_s.append(c) vals_s.append(row[c]) indptr_list.append(len(cols_s)) - data = torch.tensor(vals_s, dtype=dtype, device=device) + data = _tensor_from_scalar_values(vals_s, dtype, device) indices = torch.tensor(cols_s, dtype=torch.int64, device=device) indptr = torch.tensor(indptr_list, dtype=torch.int64, device=device) return data, indices, indptr, (n_rows, n_cols) -def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode): +def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode, index_dtype=torch.int64): """Sorted COO from CSR; optional shuffle/duplicate for csr|auto (与原先 CSV 行为一致).""" - data_c, row_c, col_c = _csr_to_coo(data, indices, indptr, shape) + data_c, row_c, col_c = _csr_to_coo( + data, indices, indptr, shape, index_dtype=index_dtype + ) if coo_mode in ("csr", "auto"): if data_c.numel() == 0: return data_c, row_c, col_c @@ -239,6 +445,55 @@ def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode): return data_c, row_c, col_c +def _random_rhs_for_spsv(shape, value_dtype, device, op_mode="NON", seed=None): + n_rows, n_cols = int(shape[0]), int(shape[1]) + rhs_size = n_rows if op_mode == "NON" else n_cols + if seed is None: + return _randn_by_dtype(rhs_size, value_dtype, device) + rhs = _randn_by_dtype( + rhs_size, + value_dtype, + torch.device("cpu"), + generator=_generator_for_seed(seed), + ) + return rhs.to(device) + + +def _apply_csr_op(data, indices, indptr, x, shape, op_mode): + n_rows, n_cols = int(shape[0]), int(shape[1]) + row = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr.to(torch.int64)[1:] - indptr.to(torch.int64)[:-1], + ) + col = indices.to(torch.int64) + if op_mode == "NON": + b = torch.zeros(n_rows, dtype=data.dtype, device=data.device) + b.scatter_add_(0, row, data * x[col]) + return b + if op_mode == "TRANS": + b = torch.zeros(n_cols, dtype=data.dtype, device=data.device) + b.scatter_add_(0, col, data * x[row]) + return b + if op_mode == "CONJ": + b = torch.zeros(n_cols, dtype=data.dtype, device=data.device) + data_eff = data.conj() if torch.is_complex(data) else data + b.scatter_add_(0, col, data_eff * x[row]) + return b + raise ValueError("op_mode must be 'NON', 'TRANS', or 'CONJ'") + + +def _solution_residual_metrics(data, indices, indptr, shape, x, b, value_dtype, op_mode): + atol, rtol = _tol_for_dtype(value_dtype) + b_recon = _apply_csr_op(data, indices, indptr, x, shape, op_mode) + err_res = ( + float(torch.max(torch.abs(b_recon - b)).item()) + if b.numel() > 0 + else 0.0 + ) + ok_res = torch.allclose(b_recon, b, atol=atol, rtol=rtol) + return err_res, ok_res + + def _cupy_spsolve_lower_csr_or_coo( fmt, data, @@ -248,6 +503,7 @@ def _cupy_spsolve_lower_csr_or_coo( b, warmup, iters, + lower, ): """Triangular solve via CuPy: CSR or COO storage. Returns (ms, x_torch) or (None, None).""" if ( @@ -279,7 +535,7 @@ def _cupy_spsolve_lower_csr_or_coo( A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) for _ in range(warmup): _ = cpx_spsolve_triangular( - A_cp, b_cp, lower=True, unit_diagonal=False + A_cp, b_cp, lower=lower, unit_diagonal=False ) cp.cuda.runtime.deviceSynchronize() t0 = cp.cuda.Event() @@ -287,18 +543,68 @@ def _cupy_spsolve_lower_csr_or_coo( t0.record() for _ in range(iters): x_cu = cpx_spsolve_triangular( - A_cp, b_cp, lower=True, unit_diagonal=False + A_cp, b_cp, lower=lower, unit_diagonal=False ) t1.record() t1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(t0, t1) / iters - x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()).to(b.dtype) + x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) + x_cu_t = x_cu_t.to(b.dtype) return cupy_ms, x_cu_t except Exception: return None, None -def run_spsv_synthetic_all(): +def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): + if ( + cp is None + or cpx_sparse is None + or cpx_spsolve_triangular is None + ): + return None, None + try: + data_ref, b_ref = _cupy_ref_inputs(data, b) + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) + idx_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous()) + ) + ptr_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous()) + ) + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) + A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + if op_mode == "TRANS": + A_eff = A_cp.transpose().tocsr() + lower_eff = not lower + elif op_mode == "CONJ": + A_eff = A_cp.transpose().conj().tocsr() + lower_eff = not lower + else: + A_eff = A_cp + lower_eff = lower + + for _ in range(WARMUP): + _ = cpx_spsolve_triangular( + A_eff, b_cp, lower=lower_eff, unit_diagonal=False + ) + cp.cuda.runtime.deviceSynchronize() + c0 = cp.cuda.Event() + c1 = cp.cuda.Event() + c0.record() + for _ in range(ITERS): + x_cp = cpx_spsolve_triangular( + A_eff, b_cp, lower=lower_eff, unit_diagonal=False + ) + c1.record() + c1.synchronize() + ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()).to(b.dtype) + return ms, x_t + except Exception: + return None, None + + +def run_spsv_synthetic_all(lower=True): if not torch.cuda.is_available(): print("CUDA is not available. Please run on a GPU-enabled system.") return @@ -309,10 +615,11 @@ def run_spsv_synthetic_all(): print(sep) print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"Warmup: {WARMUP} | Iters: {ITERS}") + print(f"Triangle: {'LOWER' if lower else 'UPPER'}") print() hdr = ( - f"{'Fmt':>5} {'N':>6} {'FlagSparse(ms)':>14} {'PyTorch(ms)':>12} {'CuPy(ms)':>10} " + f"{'Fmt':>5} {'opA':>5} {'N':>6} {'FlagSparse(ms)':>14} {'PyTorch(ms)':>12} {'CuPy(ms)':>10} " f"{'FS/PT':>8} {'FS/CU':>8} {'Status':>8} {'Err(PT)':>12} {'Err(CU)':>12}" ) @@ -330,101 +637,134 @@ def run_spsv_synthetic_all(): print("-" * 110) for n in TEST_SIZES: for fmt in ("CSR", "COO"): - data, indices, indptr, shape = _build_random_triangular_csr( - n, value_dtype, index_dtype, device, lower=True + op_modes = ( + _supported_csr_full_ops(value_dtype, index_dtype) + if fmt == "CSR" + else ["NON"] ) - A_dense = _csr_to_dense( - data, indices.to(torch.int64), indptr, shape - ) - x_true = torch.randn(n, dtype=value_dtype, device=device) - b = A_dense @ x_true - - torch.cuda.synchronize() - if fmt == "CSR": - x, t_ms = fs.flagsparse_spsv_csr( - data, - indices, - indptr, - b, - shape, - lower=True, - return_time=True, + for op_mode in op_modes: + data, indices, indptr, shape = _build_random_triangular_csr( + n, value_dtype, index_dtype, device, lower=lower ) - else: - dc, rr, cc = _csr_to_coo(data, indices, indptr, shape) - x, t_ms = fs.flagsparse_spsv_coo( - dc, - rr, - cc, - b, - shape, - lower=True, - coo_mode="auto", - return_time=True, + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr, shape ) - torch.cuda.synchronize() - - A_ref = A_dense - b_ref = b - torch.cuda.synchronize() - e0 = torch.cuda.Event(True) - e1 = torch.cuda.Event(True) - e0.record() - x_pt = torch.linalg.solve( - A_ref, b_ref.unsqueeze(1) - ).squeeze(1) - e1.record() - torch.cuda.synchronize() - pytorch_ms = e0.elapsed_time(e1) - err_pt = float(torch.max(torch.abs(x - x_pt)).item()) if n > 0 else 0.0 - - cupy_ms = None - err_cu = None - x_cu_t = None - if value_dtype in (torch.float32, torch.float64): - cupy_ms, x_cu_t = _cupy_spsolve_lower_csr_or_coo( - fmt, - data, - indices, - indptr, + rhs_op = op_mode if fmt == "CSR" else "NON" + b = _random_rhs_for_spsv( shape, - b, - WARMUP, - ITERS, + value_dtype, + device, + op_mode=rhs_op, + seed=_stable_case_seed( + "synthetic", + "LOWER" if lower else "UPPER", + fmt, + op_mode, + n, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), ) + + torch.cuda.synchronize() + if fmt == "CSR": + x, t_ms = fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=op_mode, + return_time=True, + ) + else: + dc, rr, cc = _csr_to_coo( + data, indices, indptr, shape, index_dtype=index_dtype + ) + x, t_ms = fs.flagsparse_spsv_coo( + dc, + rr, + cc, + b, + shape, + lower=lower, + coo_mode="auto", + return_time=True, + ) + torch.cuda.synchronize() + + A_ref = A_dense + b_ref = b + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + x_pt = _triangular_solve_reference( + A_ref, b_ref, lower=lower, op_mode=op_mode + ) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = float(torch.max(torch.abs(x - x_pt)).item()) if n > 0 else 0.0 + + cupy_ms = None + err_cu = None + x_cu_t = None + if fmt == "CSR": + cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( + data, indices, indptr, shape, b, op_mode, lower + ) + elif value_dtype in ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ): + cupy_ms, x_cu_t = _cupy_spsolve_lower_csr_or_coo( + fmt, + data, + indices, + indptr, + shape, + b, + WARMUP, + ITERS, + lower, + ) if x_cu_t is not None and n > 0: err_cu = float( torch.max(torch.abs(x - x_cu_t)).item() ) - atol, rtol = _tol_for_dtype(value_dtype) - ok_pt = torch.allclose(x, x_pt, atol=atol, rtol=rtol) - ok_cu = ( - True - if x_cu_t is None - else torch.allclose(x, x_cu_t, atol=atol, rtol=rtol) - ) - ok = ok_pt or ok_cu - status = "PASS" if ok else "FAIL" - if not ok: - failed += 1 - total += 1 - - fs_vs_pt = ( - (pytorch_ms / t_ms) if (t_ms and t_ms > 0) else None - ) - fs_vs_cu = ( - (cupy_ms / t_ms) - if (cupy_ms is not None and t_ms and t_ms > 0) - else None - ) - print( - f"{fmt:>5} {n:>6} {_fmt_ms(t_ms):>14} {_fmt_ms(pytorch_ms):>12} " - f"{_fmt_ms(cupy_ms):>10} " - f"{(f'{fs_vs_pt:.2f}x' if fs_vs_pt is not None else 'N/A'):>8} " - f"{(f'{fs_vs_cu:.2f}x' if fs_vs_cu is not None else 'N/A'):>8} " - f"{status:>8} {_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}" - ) + atol, rtol = _tol_for_dtype(value_dtype) + ok_pt = torch.allclose(x, x_pt, atol=atol, rtol=rtol) + ok_cu = ( + True + if x_cu_t is None + else torch.allclose(x, x_cu_t, atol=atol, rtol=rtol) + ) + ok = ok_pt or ok_cu + status = "PASS" if ok else "FAIL" + if not ok: + failed += 1 + total += 1 + + fs_vs_pt = ( + (pytorch_ms / t_ms) if (t_ms and t_ms > 0) else None + ) + fs_vs_cu = ( + (cupy_ms / t_ms) + if (cupy_ms is not None and t_ms and t_ms > 0) + else None + ) + print( + f"{fmt:>5} {op_mode:>5} {n:>6} {_fmt_ms(t_ms):>14} {_fmt_ms(pytorch_ms):>12} " + f"{_fmt_ms(cupy_ms):>10} " + f"{(f'{fs_vs_pt:.2f}x' if fs_vs_pt is not None else 'N/A'):>8} " + f"{(f'{fs_vs_cu:.2f}x' if fs_vs_cu is not None else 'N/A'):>8} " + f"{status:>8} {_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}" + ) print("-" * 110) print() @@ -433,37 +773,28 @@ def run_spsv_synthetic_all(): print(sep) -def _run_one_csv_row_csr(path, value_dtype, index_dtype, device): +def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower=True): data, indices, indptr, shape = _load_mtx_to_csr_torch( - path, dtype=value_dtype, device=device + path, dtype=value_dtype, device=device, lower=lower ) indices = indices.to(index_dtype) n_rows, n_cols = shape - x_true = torch.randn(n_rows, dtype=value_dtype, device=device) - b, _ = fs.flagsparse_spmv_csr( - data, indices, indptr, x_true, shape, return_time=True - ) - x, t_ms = fs.flagsparse_spsv_csr( - data, indices, indptr, b, shape, lower=True, return_time=True - ) - return _finalize_csv_row( - path, value_dtype, index_dtype, data, indices, indptr, shape, - x, t_ms, b, n_rows, n_cols, - ) - - -def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode): - data, indices, indptr, shape = _load_mtx_to_csr_torch( - path, dtype=value_dtype, device=device - ) - indices = indices.to(index_dtype) - n_rows, n_cols = shape - x_true = torch.randn(n_rows, dtype=value_dtype, device=device) - b, _ = fs.flagsparse_spmv_csr( - data, indices, indptr, x_true, shape, return_time=True + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode="NON", + seed=_stable_case_seed( + "csv-coo", + os.path.basename(path), + "LOWER" if lower else "UPPER", + coo_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), ) d_in, r_in, c_in = _coo_inputs_for_csv( - data, indices, indptr, shape, coo_mode + data, indices, indptr, shape, coo_mode, index_dtype=index_dtype ) x, t_ms = fs.flagsparse_spsv_coo( d_in, @@ -471,7 +802,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode): c_in, b, shape, - lower=True, + lower=lower, coo_mode=coo_mode, return_time=True, ) @@ -488,6 +819,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode): b, n_rows, n_cols, + lower=lower, nnz_display=int(d_in.numel()), cupy_coo_data=d_in, cupy_coo_row=r_in, @@ -509,12 +841,16 @@ def _finalize_csv_row( n_rows, n_cols, *, + lower=True, nnz_display=None, cupy_coo_data=None, cupy_coo_row=None, cupy_coo_col=None, ): atol, rtol = _tol_for_dtype(value_dtype) + err_res, _ = _solution_residual_metrics( + data, indices, indptr, shape, x, b, value_dtype, "NON" + ) pytorch_ms = None err_pt = None ok_pt = False @@ -524,22 +860,27 @@ def _finalize_csv_row( A_dense = _csr_to_dense( data, indices.to(torch.int64), indptr, shape ) - A_ref = A_dense - b_ref = b + ref_dtype = _dense_ref_dtype(value_dtype) + A_ref = A_dense.to(ref_dtype) + b_ref = b.to(ref_dtype) e0 = torch.cuda.Event(True) e1 = torch.cuda.Event(True) torch.cuda.synchronize() e0.record() - x_ref = torch.linalg.solve(A_ref, b_ref.unsqueeze(1)).squeeze(1) + x_ref = _triangular_solve_reference( + A_ref, b_ref, lower=lower, op_mode="NON" + ) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) e1.record() torch.cuda.synchronize() pytorch_ms = e0.elapsed_time(e1) err_pt = ( - float(torch.max(torch.abs(x - x_ref)).item()) + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) if n_rows > 0 else 0.0 ) - ok_pt = torch.allclose(x, x_ref, atol=atol, rtol=rtol) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) except RuntimeError as e: if "out of memory" in str(e).lower(): pt_skip_reason = "PyTorch dense ref OOM; skipped" @@ -558,17 +899,18 @@ def _finalize_csv_row( cp is not None and cpx_sparse is not None and cpx_spsolve_triangular is not None - and value_dtype in (torch.float32, torch.float64) ): try: - b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b.contiguous())) + data_ref, b_ref = _cupy_ref_inputs(data, b) + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) if ( cupy_coo_data is not None and cupy_coo_row is not None and cupy_coo_col is not None ): + coo_data_ref, _ = _cupy_ref_inputs(cupy_coo_data, b) data_cp = cp.from_dlpack( - torch.utils.dlpack.to_dlpack(cupy_coo_data.contiguous()) + torch.utils.dlpack.to_dlpack(coo_data_ref.contiguous()) ) row_cp = cp.from_dlpack( torch.utils.dlpack.to_dlpack( @@ -585,7 +927,7 @@ def _finalize_csv_row( ) else: data_cp = cp.from_dlpack( - torch.utils.dlpack.to_dlpack(data.contiguous()) + torch.utils.dlpack.to_dlpack(data_ref.contiguous()) ) idx_cp = cp.from_dlpack( torch.utils.dlpack.to_dlpack( @@ -600,7 +942,7 @@ def _finalize_csv_row( ) for _ in range(WARMUP): _ = cpx_spsolve_triangular( - A_cp, b_cp, lower=True, unit_diagonal=False + A_cp, b_cp, lower=lower, unit_diagonal=False ) cp.cuda.runtime.deviceSynchronize() c0 = cp.cuda.Event() @@ -608,18 +950,20 @@ def _finalize_csv_row( c0.record() for _ in range(ITERS): x_cu = cpx_spsolve_triangular( - A_cp, b_cp, lower=True, unit_diagonal=False + A_cp, b_cp, lower=lower, unit_diagonal=False ) c1.record() c1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS - x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()).to(x.dtype) + x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) + x_cmp = _compare_view(x, value_dtype) + x_cu_cmp = _compare_view(x_cu_t, value_dtype) err_cu = ( - float(torch.max(torch.abs(x - x_cu_t)).item()) + float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) if n_rows > 0 else 0.0 ) - ok_cu = torch.allclose(x, x_cu_t, atol=atol, rtol=rtol) + ok_cu = torch.allclose(x_cmp, x_cu_cmp, atol=atol, rtol=rtol) except Exception: cupy_ms = None err_cu = None @@ -627,6 +971,8 @@ def _finalize_csv_row( status = "PASS" if (ok_pt or ok_cu) else "FAIL" if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): status = "REF_FAIL" + ref_errors = [err for err in (err_pt, err_cu) if err is not None] + err_ref = min(ref_errors) if ref_errors else None nnz_out = ( int(data.numel()) if nnz_display is None else int(nnz_display) @@ -643,59 +989,336 @@ def _finalize_csv_row( "cusparse_ms": cupy_ms, "csc_ms": None, "status": status, + "err_ref": err_ref, + "err_res": err_res, + "err_pt": err_pt, + "err_cu": err_cu, + } + return row, pt_skip_reason + + +def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, lower=True): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device, lower=lower + ) + indices = indices.to(index_dtype) + indptr = indptr.to(index_dtype) + n_rows, n_cols = shape + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=op_mode, + seed=_stable_case_seed( + "csv-csr", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + x, t_ms = fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=op_mode, + return_time=True, + ) + return _finalize_csv_row_csr_full( + path, + value_dtype, + index_dtype, + op_mode, + data, + indices, + indptr, + shape, + x, + t_ms, + b, + n_rows, + n_cols, + lower=lower, + ) + + +def _finalize_csv_row_csr_full( + path, + value_dtype, + index_dtype, + op_mode, + data, + indices, + indptr, + shape, + x, + t_ms, + b, + n_rows, + n_cols, + lower=True, +): + atol, rtol = _tol_for_dtype(value_dtype) + err_res, _ = _solution_residual_metrics( + data, indices, indptr, shape, x, b, value_dtype, op_mode + ) + + pytorch_ms = None + err_pt = None + ok_pt = False + pt_skip_reason = None + if _allow_dense_pytorch_ref(shape, value_dtype): + try: + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(value_dtype)) + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + torch.cuda.synchronize() + e0.record() + x_ref = _triangular_solve_reference( + A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode + ) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = ( + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + pt_skip_reason = "PyTorch dense ref OOM; skipped" + else: + raise + else: + pt_skip_reason = ( + f"PyTorch dense ref skipped (> {DENSE_REF_MAX_BYTES // (1024**3)} GiB dense matrix)" + ) + + cupy_ms = None + err_cu = None + ok_cu = False + x_cu_t = None + cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( + data, indices, indptr, shape, b, op_mode, lower + ) + if x_cu_t is not None: + x_cmp = _compare_view(x, value_dtype) + x_cu_cmp = _compare_view(x_cu_t, value_dtype) + err_cu = ( + float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_cu = torch.allclose(x_cmp, x_cu_cmp, atol=atol, rtol=rtol) + + status = "PASS" if (ok_pt or ok_cu) else "FAIL" + if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): + status = "REF_FAIL" + ref_errors = [err for err in (err_pt, err_cu) if err is not None] + err_ref = min(ref_errors) if ref_errors else None + + row = { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": int(data.numel()), + "triton_ms": t_ms, + "pytorch_ms": pytorch_ms, + "cusparse_ms": cupy_ms, + "csc_ms": None, + "status": status, + "err_ref": err_ref, + "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, } return row, pt_skip_reason -def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto"): +def run_all_supported_spsv_csr_csv( + mtx_paths, + csv_path, + lower=True, + value_dtypes=None, + index_dtypes=None, + op_modes=None, +): if not torch.cuda.is_available(): print("CUDA is not available.") return device = torch.device("cuda") rows_out = [] - label = "COO" if use_coo else "CSR" - cu_col = "COO(ms)" if use_coo else "CSR(ms)" - fs_cu_hdr = "FS/COO" if use_coo else "FS/CSR" - for value_dtype in VALUE_DTYPES: - for index_dtype in INDEX_DTYPES: - atol, rtol = _tol_for_dtype(value_dtype) - print("=" * 150) - print( - f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | {label}" - + (f" coo_mode={coo_mode}" if use_coo else "") - ) - if use_coo: + selected_value_dtypes = value_dtypes or CSR_FULL_VALUE_DTYPES + selected_index_dtypes = index_dtypes or CSR_FULL_INDEX_DTYPES + selected_op_modes = op_modes or SPSV_OP_MODES + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + supported_op_modes = [ + op for op in _supported_csr_full_ops(value_dtype, index_dtype) + if op in selected_op_modes + ] + for op_mode in supported_op_modes: + print("=" * 150) print( - "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. " - "b 由 CSR SpMV 构造,与 CSR 测试一致。" + f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}" ) - else: print( "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." ) + print( + "RHS is generated directly, matching Library-main's SpSV test style. " + "Err(Ref)=best |FlagSparse-reference|, Err(Res)=|op(A)*x-b|, " + "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " + "PASS if PyTorch / cuSPARSE reference passes. Residual is diagnostic only." + ) + print("-" * 150) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " + f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" + ) + print("-" * 150) + for path in mtx_paths: + try: + row, pt_skip = _run_one_csv_row_csr_full( + path, value_dtype, index_dtype, op_mode, device, lower=lower + ) + rows_out.append(row) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + n_rows, n_cols = row["n_rows"], row["n_cols"] + nnz = row["nnz"] + t_ms = row["triton_ms"] + cupy_ms = row["cusparse_ms"] + pytorch_ms = row["pytorch_ms"] + err_ref, err_res = row["err_ref"], row["err_res"] + err_pt, err_cu = row["err_pt"], row["err_cu"] + status = row["status"] + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " + f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " + f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " + f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + ) + if pt_skip: + print(f" NOTE: {pt_skip}") + except Exception as e: + err_msg = str(e) + status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" + rows_out.append( + { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": "ERR", + "n_cols": "ERR", + "nnz": "ERR", + "triton_ms": None, + "pytorch_ms": None, + "cusparse_ms": None, + "csc_ms": None, + "status": status, + "err_ref": None, + "err_res": None, + "err_pt": None, + "err_cu": None, + } + ) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + print( + f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " + f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " + f"{'N/A':>7} {'N/A':>7} " + f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" {status}: {e}") + print("-" * 150) + fieldnames = [ + "matrix", + "value_dtype", + "index_dtype", + "opA", + "n_rows", + "n_cols", + "nnz", + "triton_ms", + "pytorch_ms", + "cusparse_ms", + "csc_ms", + "status", + "err_ref", + "err_res", + "err_pt", + "err_cu", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows_out: + w.writerow(r) + print(f"Wrote {len(rows_out)} rows to {csv_path}") + + +def run_all_dtypes_spsv_coo_csv( + mtx_paths, + csv_path, + coo_mode="auto", + lower=True, + value_dtypes=None, + index_dtypes=None, +): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + rows_out = [] + selected_value_dtypes = value_dtypes or VALUE_DTYPES + selected_index_dtypes = index_dtypes or INDEX_DTYPES + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + print("=" * 150) + print( + f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | COO" + f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}" + ) print( + "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. " + "RHS is generated directly, matching Library-main's SpSV test style." + ) + print( + "Err(Ref)=best |FlagSparse-reference|, Err(Res)=|A*x-b|, " "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " - "PASS if either error within tolerance." + "PASS if PyTorch / cuSPARSE reference passes. Residual is diagnostic only." ) print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " - f"{'FlagSparse(ms)':>10} {cu_col:>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{fs_cu_hdr:>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{'FlagSparse(ms)':>10} {'COO(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " + f"{'FS/COO':>7} {'FS/PT':>7} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: try: - if use_coo: - row, pt_skip = _run_one_csv_row_coo( - path, value_dtype, index_dtype, device, coo_mode - ) - else: - row, pt_skip = _run_one_csv_row_csr( - path, value_dtype, index_dtype, device - ) + row, pt_skip = _run_one_csv_row_coo( + path, value_dtype, index_dtype, device, coo_mode, lower=lower + ) rows_out.append(row) name = os.path.basename(path)[:27] if len(os.path.basename(path)) > 27: @@ -705,13 +1328,14 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto") t_ms = row["triton_ms"] cupy_ms = row["cusparse_ms"] pytorch_ms = row["pytorch_ms"] + err_ref, err_res = row["err_ref"], row["err_res"] err_pt, err_cu = row["err_pt"], row["err_cu"] status = row["status"] print( f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) if pt_skip: print(f" NOTE: {pt_skip}") @@ -731,6 +1355,8 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto") "cusparse_ms": None, "csc_ms": None, "status": status, + "err_ref": None, + "err_res": None, "err_pt": None, "err_cu": None, } @@ -741,7 +1367,7 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto") print( f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " - f"{'N/A':>7} {'N/A':>7} {status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" + f"{'N/A':>7} {'N/A':>7} {status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" ) print(f" {status}: {e}") print("-" * 150) @@ -757,6 +1383,8 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto") "cusparse_ms", "csc_ms", "status", + "err_ref", + "err_res", "err_pt", "err_cu", ] @@ -768,6 +1396,188 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto") print(f"Wrote {len(rows_out)} rows to {csv_path}") +def _check_one_csr_transpose_case(path, value_dtype, index_dtype, op_mode, device, lower=True): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device, lower=lower + ) + indices = indices.to(index_dtype) + indptr = indptr.to(index_dtype) + n_rows, n_cols = shape + trans_data, trans_indices64, trans_indptr64 = fs_spsv_impl._csr_transpose( + data, + indices.to(torch.int64), + indptr.to(torch.int64), + n_rows, + n_cols, + conjugate=(op_mode == "CONJ"), + ) + trans_shape = (n_cols, n_rows) + trans_indices = trans_indices64.to(index_dtype) + trans_indptr = trans_indptr64.to(index_dtype) + + probe = _random_rhs_for_spsv( + trans_shape, + value_dtype, + device, + op_mode="NON", + seed=_stable_case_seed( + "check-transpose-action", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + action_ref = _apply_csr_op(data, indices, indptr, probe, shape, op_mode) + action_trans = _apply_csr_op( + trans_data, trans_indices, trans_indptr, probe, trans_shape, "NON" + ) + action_err = ( + float(torch.max(torch.abs(action_trans - action_ref)).item()) + if action_ref.numel() > 0 + else 0.0 + ) + atol, rtol = _tol_for_dtype(value_dtype) + action_ok = torch.allclose(action_trans, action_ref, atol=atol, rtol=rtol) + + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=op_mode, + seed=_stable_case_seed( + "check-transpose-solve", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + x_op = fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=op_mode, + ) + x_mat = fs.flagsparse_spsv_csr( + trans_data, + trans_indices, + trans_indptr, + b, + trans_shape, + lower=not lower, + transpose="NON", + ) + solve_err = ( + float(torch.max(torch.abs(x_op - x_mat)).item()) if x_op.numel() > 0 else 0.0 + ) + solve_ok = torch.allclose(x_op, x_mat, atol=atol, rtol=rtol) + + ref_err = None + ref_ok = None + if _allow_dense_pytorch_ref(shape, value_dtype): + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(value_dtype)) + x_ref = _triangular_solve_reference( + A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode + ) + ref_err = ( + float(torch.max(torch.abs(x_op - x_ref)).item()) if x_op.numel() > 0 else 0.0 + ) + ref_ok = torch.allclose(x_op, x_ref, atol=atol, rtol=rtol) + + status = "PASS" if action_ok and solve_ok and (ref_ok is not False) else "FAIL" + return { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": n_rows, + "nnz": int(data.numel()), + "action_err": action_err, + "solve_err": solve_err, + "ref_err": ref_err, + "status": status, + } + + +def run_csr_transpose_check( + mtx_paths, + lower=True, + value_dtypes=None, + index_dtypes=None, + op_modes=None, +): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + selected_value_dtypes = value_dtypes or CSR_FULL_VALUE_DTYPES + selected_index_dtypes = index_dtypes or CSR_FULL_INDEX_DTYPES + selected_op_modes = [op for op in (op_modes or ("TRANS", "CONJ")) if op in ("TRANS", "CONJ")] + if not selected_op_modes: + print("--check-transpose only checks TRANS/CONJ; no matching op selected.") + return + + print("=" * 150) + print( + "CSR TRANS/CONJ preprocessing check: " + "ActionErr compares materialized op(A) against direct CSR scatter; " + "SolveErr compares transpose path against materialized NON path." + ) + print("-" * 150) + print( + f"{'Matrix':<28} {'dtype':>10} {'index':>7} {'opA':>5} " + f"{'N':>7} {'NNZ':>10} {'Status':>6} {'ActionErr':>10} {'SolveErr':>10} {'RefErr':>10}" + ) + print("-" * 150) + total = 0 + failed = 0 + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + for op_mode in selected_op_modes: + for path in mtx_paths: + try: + row = _check_one_csr_transpose_case( + path, + value_dtype, + index_dtype, + op_mode, + device, + lower=lower, + ) + total += 1 + failed += int(row["status"] != "PASS") + name = row["matrix"][:27] + if len(row["matrix"]) > 27: + name += "..." + print( + f"{name:<28} {row['value_dtype']:>10} {row['index_dtype']:>7} {row['opA']:>5} " + f"{row['n_rows']:>7} {row['nnz']:>10} {row['status']:>6} " + f"{_fmt_err(row['action_err']):>10} {_fmt_err(row['solve_err']):>10} {_fmt_err(row['ref_err']):>10}" + ) + except Exception as e: + total += 1 + failed += 1 + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name += "..." + print( + f"{name:<28} {_dtype_name(value_dtype):>10} {_dtype_name(index_dtype):>7} {op_mode:>5} " + f"{'ERR':>7} {'ERR':>10} {'ERROR':>6} " + f"{_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" ERROR: {e}") + print("-" * 150) + print(f"Total cases: {total} Failed: {failed}") + + def main(): parser = argparse.ArgumentParser( description="SpSV test: synthetic triangular systems and optional .mtx (CSR/COO), same baselines as CSR." @@ -785,7 +1595,7 @@ def main(): type=str, default=None, metavar="FILE", - help="Run all dtypes/index dtypes on .mtx (CSR SpSV) and export CSV", + help="Run full supported CSR SpSV combinations (dtype/index/opA) on .mtx and export CSV", ) parser.add_argument( "--csv-coo", @@ -794,6 +1604,11 @@ def main(): metavar="FILE", help="Run all dtypes on .mtx (COO SpSV), same CSV columns as --csv-csr", ) + parser.add_argument( + "--check-transpose", + action="store_true", + help="Check CSR TRANS/CONJ preprocessing against direct CSR scatter and materialized NON solve", + ) parser.add_argument( "--coo-mode", type=str, @@ -801,10 +1616,34 @@ def main(): choices=["auto", "direct", "csr"], help="COO mode for --csv-coo (default: auto)", ) + parser.add_argument( + "--upper", + action="store_true", + help="Use upper-triangular inputs instead of the default lower-triangular inputs", + ) + parser.add_argument( + "--ops", + type=str, + default=None, + help="Comma-separated opA filter for CSR CSV, e.g. TRANS,CONJ", + ) + parser.add_argument( + "--value-dtypes", + type=str, + default=None, + help="Comma-separated value dtype filter for CSR CSV, e.g. float,double,complex64,complex128", + ) + parser.add_argument( + "--index-dtypes", + type=str, + default=None, + help="Comma-separated index dtype filter for CSR CSV, e.g. int32,int64", + ) args = parser.parse_args() + lower = not args.upper if args.synthetic: - run_spsv_synthetic_all() + run_spsv_synthetic_all(lower=lower) return paths = [] @@ -813,22 +1652,90 @@ def main(): paths.append(p) elif os.path.isdir(p): paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx")))) + if args.check_transpose: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found for --check-transpose") + return + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + op_modes = ( + _parse_op_modes_filter(args.ops) + if args.ops + else None + ) + run_csr_transpose_check( + paths, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, + op_modes=op_modes, + ) + return if args.csv_csr: if not paths: paths = sorted(glob.glob("*.mtx")) if not paths: print("No .mtx files found for --csv-csr") return - run_all_dtypes_spsv_csv(paths, args.csv_csr, use_coo=False) + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + op_modes = ( + _parse_op_modes_filter(args.ops) + if args.ops + else None + ) + run_all_supported_spsv_csr_csv( + paths, + args.csv_csr, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, + op_modes=op_modes, + ) return if args.csv_coo: + if args.ops: + parser.error("--ops is only supported with --csv-csr; COO tests only run opA=NON") if not paths: paths = sorted(glob.glob("*.mtx")) if not paths: print("No .mtx files found for --csv-coo") return - run_all_dtypes_spsv_csv( - paths, args.csv_coo, use_coo=True, coo_mode=args.coo_mode + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + run_all_dtypes_spsv_coo_csv( + paths, + args.csv_coo, + coo_mode=args.coo_mode, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, ) return @@ -836,4 +1743,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()