From 3f70f442fcd846829cddef3fd0a6df724f6c4028 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 13 Apr 2026 15:25:56 +0800 Subject: [PATCH 1/5] test --- src/flagsparse/sparse_operations/spsv.py | 428 ++++++++++++++++--- tests/pytest/test_spsv_csr_accuracy.py | 192 ++++++++- tests/test_spsv.py | 501 +++++++++++++++++++++-- 3 files changed, 1041 insertions(+), 80 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index cdfbf50..2976ad1 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -11,11 +11,27 @@ torch.bfloat16, torch.float32, torch.float64, + *((_torch_complex32_dtype(),) if _torch_complex32_dtype() is not None else ()), + torch.complex64, ) SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64) SPSV_NON_TRANS_PRIMARY_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), + *(((_torch_complex32_dtype(), torch.int32),) if _torch_complex32_dtype() is not None else ()), + (torch.complex64, torch.int32), +) +SPSV_NON_TRANS_EXTENDED_COMBOS = ( + (torch.float32, torch.int64), + (torch.float64, torch.int64), + *(((_torch_complex32_dtype(), torch.int64),) if _torch_complex32_dtype() is not None else ()), + (torch.complex64, torch.int64), +) +SPSV_TRANS_PRIMARY_COMBOS = ( + (torch.float32, torch.int32), + (torch.float64, torch.int32), + *(((_torch_complex32_dtype(), torch.int32),) if _torch_complex32_dtype() is not None else ()), + (torch.complex64, torch.int32), ) SPSV_PROMOTE_FP32_TO_FP64 = str( os.environ.get("FLAGSPARSE_SPSV_PROMOTE_FP32_TO_FP64", "0") @@ -46,15 +62,40 @@ def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): """Validate NON_TRANS support matrix and keep error messages explicit.""" if (data_dtype, index_dtype) in SPSV_NON_TRANS_PRIMARY_COMBOS: return + if (data_dtype, index_dtype) in SPSV_NON_TRANS_EXTENDED_COMBOS: + return if data_dtype == torch.bfloat16 and index_dtype == torch.int32: return raise TypeError( - f"{fmt_name} SpSV currently supports NON_TRANS combinations with int32 kernel " - "indices: (float32, int32), (float64, int32), (bfloat16, int32)" + f"{fmt_name} SpSV currently supports NON_TRANS combinations: " + "(float32, int32/int64), (float64, int32/int64), " + "(complex32, int32/int64), (complex64, int32/int64), (bfloat16, int32)" ) -def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): +def _validate_spsv_trans_combo(data_dtype, index_dtype, fmt_name): + if (data_dtype, index_dtype) in SPSV_TRANS_PRIMARY_COMBOS: + return + raise TypeError( + f"{fmt_name} SpSV currently supports TRANS combinations with int32 indices only: " + "(float32, int32), (float64, int32), (complex32, int32), (complex64, int32)" + ) + + +def _normalize_spsv_transpose_mode(transpose): + if isinstance(transpose, bool): + return "T" if transpose else "N" + token = str(transpose).strip().upper() + if token in ("N", "NON", "NON_TRANS"): + return "N" + if token in ("T", "TRANS"): + return "T" + raise ValueError( + "transpose must be bool or one of: N/NON/NON_TRANS, T/TRANS" + ) + + +def _prepare_spsv_inputs(data, indices, indptr, b, shape): """Validate and normalize inputs for sparse solve A x = b with CSR A.""" if not all(torch.is_tensor(t) for t in (data, indices, indptr, b)): raise TypeError("data, indices, indptr, b must all be torch.Tensor") @@ -77,7 +118,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise TypeError( - "data dtype must be one of: bfloat16, float32, float64" + "data dtype must be one of: bfloat16, float32, float64, complex32, complex64" ) if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("indices dtype must be torch.int32 or torch.int64") @@ -85,8 +126,6 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): raise TypeError("indptr dtype must be torch.int32 or torch.int64") if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") - if transpose: - raise NotImplementedError("transpose=True is not implemented in Triton SpSV yet") indices64 = indices.to(torch.int64).contiguous() indptr64 = indptr.to(torch.int64).contiguous() @@ -94,8 +133,6 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): raise ValueError( f"int64 index value {int(indices64.max().item())} exceeds Triton int32 kernel range" ) - _validate_spsv_non_trans_combo(data.dtype, torch.int32, "CSR") - if indptr64.numel() > 0: if int(indptr64[0].item()) != 0: raise ValueError("indptr[0] must be 0") @@ -112,6 +149,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): return ( data.contiguous(), + indices.dtype, indices64, indptr64, b.contiguous(), @@ -120,6 +158,21 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): ) +def _promote_complex32_spsv_inputs(data, b): + if _is_complex32_dtype(data.dtype): + return data.to(torch.complex64), b.to(torch.complex64), data.dtype + return data, b, None + + +def _restore_complex32_spsv_output(x, target_dtype): + if _is_complex32_dtype(target_dtype): + limit = 65504.0 + real = torch.clamp(x.real, min=-limit, max=limit).to(torch.float16) + imag = torch.clamp(x.imag, min=-limit, max=limit).to(torch.float16) + return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) + return x.to(target_dtype) + + @triton.jit def _spsv_csr_level_kernel( data_ptr, @@ -172,6 +225,104 @@ def _spsv_csr_level_kernel( tl.store(x_ptr + row, x_row) +@triton.jit +def _spsv_csr_level_kernel_complex( + data_ri_ptr, + indices_ptr, + indptr_ptr, + b_ri_ptr, + x_ri_ptr, + rows_ptr, + n_level_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_level_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + + if USE_FP64_ACC: + acc_re = tl.zeros((1,), dtype=tl.float64) + acc_im = tl.zeros((1,), dtype=tl.float64) + diag_re = tl.zeros((1,), dtype=tl.float64) + diag_im = tl.zeros((1,), dtype=tl.float64) + else: + acc_re = tl.zeros((1,), dtype=tl.float32) + acc_im = tl.zeros((1,), dtype=tl.float32) + diag_re = tl.zeros((1,), dtype=tl.float32) + diag_im = tl.zeros((1,), dtype=tl.float32) + + if UNIT_DIAG: + diag_re = diag_re + 1.0 + + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + x_re = tl.load(x_ri_ptr + col * 2, mask=mask, other=0.0) + x_im = tl.load(x_ri_ptr + col * 2 + 1, mask=mask, other=0.0) + + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + x_re = x_re.to(tl.float64) + x_im = x_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + x_re = x_re.to(tl.float32) + x_im = x_im.to(tl.float32) + + if LOWER: + solved = col < row + else: + solved = col > row + is_diag = col == row + + prod_re = a_re * x_re - a_im * x_im + prod_im = a_re * x_im + a_im * x_re + acc_re = acc_re + tl.sum(tl.where(mask & solved, prod_re, 0.0)) + acc_im = acc_im + tl.sum(tl.where(mask & solved, prod_im, 0.0)) + + if not UNIT_DIAG: + diag_re = diag_re + tl.sum(tl.where(mask & is_diag, a_re, 0.0)) + diag_im = diag_im + tl.sum(tl.where(mask & is_diag, a_im, 0.0)) + + rhs_re = tl.load(b_ri_ptr + row * 2) + rhs_im = tl.load(b_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + + num_re = rhs_re - acc_re + num_im = rhs_im - acc_im + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + + x_re_out = (num_re * diag_re + num_im * diag_im) / den_safe + x_im_out = (num_im * diag_re - num_re * diag_im) / den_safe + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + + offs1 = tl.arange(0, 1) + tl.store(x_ri_ptr + row * 2 + offs1, x_re_out) + tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + + @triton.jit def _spsv_coo_level_kernel_real( data_ptr, @@ -360,6 +511,66 @@ def _triton_spsv_csr_vector( return x +def _triton_spsv_csr_vector_complex( + data, + indices, + indptr, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + levels=None, + block_nnz_use=None, + max_segments_use=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + if levels is None: + levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + + data_ri = torch.view_as_real(data.contiguous()).reshape(-1).contiguous() + b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() + component_dtype = _component_dtype_for_complex(data.dtype) + use_fp64 = component_dtype == torch.float64 + if component_dtype == torch.float16: + x_ri_work = torch.zeros((n_rows, 2), dtype=torch.float32, device=b_vec.device) + x_ri = x_ri_work.reshape(-1).contiguous() + else: + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + + for rows_lv in levels: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv,) + _spsv_csr_level_kernel_complex[grid]( + data_ri, + indices, + indptr, + b_ri, + x_ri, + rows_lv, + n_level_rows=n_lv, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + if component_dtype == torch.float16: + return torch.view_as_complex(x_ri_work.contiguous()) + return x + + def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if not all(torch.is_tensor(t) for t in (data, row, col, b)): raise TypeError("data, row, col, b must all be torch.Tensor") @@ -378,7 +589,7 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if b.ndim == 2 and b.shape[0] != n_rows: raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") - if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: + if data.dtype not in (torch.bfloat16, torch.float32, torch.float64): raise TypeError("data dtype must be one of: bfloat16, float32, float64") if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") @@ -395,7 +606,6 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): raise ValueError( f"int64 index value {int(col64.max().item())} exceeds Triton int32 kernel range" ) - _validate_spsv_non_trans_combo(data.dtype, torch.int32, "COO") if row64.numel() > 0: if bool(torch.any(row64 < 0).item()): raise IndexError("row indices must be non-negative") @@ -408,6 +618,7 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if max_col >= n_cols: raise IndexError(f"col indices out of range for n_cols={n_cols}") + _validate_spsv_non_trans_combo(data.dtype, torch.int32, "COO") return ( data.contiguous(), row64, @@ -418,6 +629,44 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): ) +def _csr_transpose(data, indices64, indptr64, n_rows, n_cols): + if data.numel() == 0: + out_data = data + out_indices = torch.empty(0, dtype=torch.int64, device=data.device) + out_indptr = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) + return out_data, out_indices, out_indptr + + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + new_row = indices64 + new_col = row_ids + data_t, indices_t, indptr_t = _coo_to_csr_sorted_unique( + data, new_row, new_col, n_cols, n_rows + ) + return data_t, indices_t, indptr_t + + +def _csr_reverse_rows_cols(data, indices64, indptr64, n_rows): + if data.numel() == 0: + out_data = data + out_indices = torch.empty(0, dtype=torch.int64, device=data.device) + out_indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) + return out_data, out_indices, out_indptr + + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + new_row = (n_rows - 1) - row_ids + new_col = (n_rows - 1) - indices64 + data_r, indices_r, indptr_r = _coo_to_csr_sorted_unique( + data, new_row, new_col, n_rows, n_rows + ) + return data_r, indices_r, indptr_r + + def _coo_is_sorted_unique(row64, col64, n_cols): nnz = row64.numel() if nnz <= 1: @@ -532,30 +781,61 @@ def flagsparse_spsv_csr( ): """Sparse triangular solve using Triton level-scheduling kernels. - Primary NON_TRANS support matrix: - - float32 + int32 indices - - float64 + int32 indices + Primary support matrix: + - NON_TRANS: float32/float64/complex32/complex64 with int32/int64 indices + - TRANS: float32/float64/complex32/complex64 with int32 indices + - bfloat16 remains NON_TRANS + int32 """ - data, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( - data, indices, indptr, b, shape, transpose=transpose + trans_mode = _normalize_spsv_transpose_mode(transpose) + data, input_index_dtype, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( + data, indices, indptr, b, shape ) + original_output_dtype = None + data, b, original_output_dtype = _promote_complex32_spsv_inputs(data, b) if n_rows != n_cols: raise ValueError(f"A must be square, got shape={shape}") - kernel_indices = indices.to(torch.int32) if indices.dtype != torch.int32 else indices - kernel_indptr = indptr + if trans_mode == "N": + _validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "CSR") + lower_eff = lower + kernel_data = data + kernel_indices64 = indices + kernel_indptr64 = indptr + else: + _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") + lower_eff = not lower + kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( + data, indices, indptr, n_rows, n_cols + ) + + kernel_indices = ( + kernel_indices64.to(torch.int32) + if kernel_indices64.dtype != torch.int32 + else kernel_indices64 + ) + kernel_indptr = kernel_indptr64 compute_dtype = data.dtype - data_in = data + data_in = kernel_data b_in = b if data.dtype == torch.bfloat16: compute_dtype = torch.float32 - data_in = data.to(torch.float32) + data_in = kernel_data.to(torch.float32) b_in = b.to(torch.float32) + elif data.dtype == torch.complex64 and trans_mode == "T": + compute_dtype = torch.complex128 + data_in = kernel_data.to(torch.complex128) + b_in = b.to(torch.complex128) elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: # Optional high-precision mode; disabled by default for throughput. compute_dtype = torch.float64 - data_in = data.to(torch.float64) + data_in = kernel_data.to(torch.float64) + b_in = b.to(torch.float64) + elif data.dtype == torch.float32 and trans_mode == "T": + compute_dtype = torch.float64 + data_in = kernel_data.to(torch.float64) b_in = b.to(torch.float64) - levels = _build_spsv_levels(kernel_indptr, kernel_indices, n_rows, lower=lower) + levels = _build_spsv_levels( + kernel_indptr, kernel_indices, n_rows, lower=lower_eff + ) block_nnz_use, max_segments_use = _auto_spsv_launch_config( kernel_indptr, block_nnz=block_nnz, max_segments=max_segments ) @@ -563,44 +843,82 @@ def flagsparse_spsv_csr( torch.cuda.synchronize() t0 = time.perf_counter() if b_in.ndim == 1: - x = _triton_spsv_csr_vector( - data_in, - kernel_indices, - kernel_indptr, - b_in, - n_rows, - lower=lower, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - levels=levels, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - ) + if torch.is_complex(data_in): + x = _triton_spsv_csr_vector_complex( + data_in, + kernel_indices, + kernel_indptr, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + else: + x = _triton_spsv_csr_vector( + data_in, + kernel_indices, + kernel_indptr, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) else: cols = [] for j in range(b_in.shape[1]): - cols.append( - _triton_spsv_csr_vector( - data_in, - kernel_indices, - kernel_indptr, - b_in[:, j].contiguous(), - n_rows, - lower=lower, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - levels=levels, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, + bj = b_in[:, j].contiguous() + if torch.is_complex(data_in): + cols.append( + _triton_spsv_csr_vector_complex( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + ) + else: + cols.append( + _triton_spsv_csr_vector( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) ) - ) x = torch.stack(cols, dim=1) - if compute_dtype != data.dtype: - x = x.to(data.dtype) + target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype + if x.dtype != target_dtype: + x = _restore_complex32_spsv_output(x, target_dtype) torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: @@ -747,4 +1065,4 @@ def flagsparse_spsv_coo( if return_time: return x, elapsed_ms - return x \ No newline at end of file + return x diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py index f2cf944..94788f2 100644 --- a/tests/pytest/test_spsv_csr_accuracy.py +++ b/tests/pytest/test_spsv_csr_accuracy.py @@ -5,8 +5,105 @@ from tests.pytest.param_shapes import SPSV_N +try: + import cupy as cp + import cupyx.scipy.sparse as cpx_sparse + from cupyx.scipy.sparse.linalg import spsolve_triangular as cpx_spsolve_triangular +except Exception: + cp = None + cpx_sparse = None + cpx_spsolve_triangular = None + pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +COMPLEX32_DTYPE = getattr(torch, "complex32", None) +if COMPLEX32_DTYPE is None: + COMPLEX32_DTYPE = getattr(torch, "chalf", None) + +SUPPORTED_COMPLEX_DTYPES = [] +if COMPLEX32_DTYPE is not None: + SUPPORTED_COMPLEX_DTYPES.append(COMPLEX32_DTYPE) +SUPPORTED_COMPLEX_DTYPES.append(torch.complex64) + +SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] + + +def _dtype_id(dtype): + return str(dtype).replace("torch.", "") + + +def _tol(dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + return 5e-3, 5e-3 + if dtype in (torch.float32, torch.complex64): + return 1e-4, 1e-3 + return 1e-10, 1e-8 + + +def _rand_like(dtype, shape, device): + if dtype in (torch.float32, torch.float64): + return torch.randn(shape, dtype=dtype, device=device) + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + pair = torch.randn((*shape, 2), dtype=torch.float16, device=device) * 0.1 + return torch.view_as_complex(pair) + base = torch.float32 + r = torch.randn(shape, dtype=base, device=device) + i = torch.randn(shape, dtype=base, device=device) + return torch.complex(r, i) + + +def _ref_dtype(dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + return torch.complex64 + return dtype + + +def _safe_cast_tensor(tensor, dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + real = tensor.real.to(torch.float16) + imag = tensor.imag.to(torch.float16) + return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) + return tensor.to(dtype) + + +def _cmp_view(tensor, dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + return tensor.to(torch.complex64) + return tensor + + +def _build_lower_triangular(n, dtype, device): + off = _rand_like(dtype, (n, n), device) * 0.02 + A = torch.tril(off) + if torch.is_complex(A): + diag = (torch.rand(n, device=device, dtype=A.real.dtype) + 2.0).to(A.real.dtype) + A = A + torch.diag(torch.complex(diag, torch.zeros_like(diag))) + else: + diag = torch.rand(n, device=device, dtype=A.dtype) + 2.0 + A = A + torch.diag(diag) + return A + + +def _cupy_csr_from_torch(data, indices, indptr, shape): + if cp is None or cpx_sparse is None: + return None + data_ref = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) + idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous())) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous())) + return cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + + +def _cupy_ref_spsv(A_cp, b_t, *, lower, unit_diagonal=False): + if cp is None or cpx_spsolve_triangular is None: + return None + b_ref = b_t.to(torch.complex64) if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE else b_t + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) + x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) + x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) + if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE: + return x_t.to(torch.complex64) + return x_t.to(b_t.dtype) @pytest.mark.spsv @@ -17,12 +114,12 @@ ids=["float32", "float64"], ) def test_spsv_csr_lower_matches_dense(n, dtype): + # Keep the original baseline test case untouched in semantics. device = torch.device("cuda") base = torch.tril(torch.randn(n, n, dtype=dtype, device=device)) eye = torch.eye(n, dtype=dtype, device=device) A = base + eye * (float(n) * 0.5 + 2.0) b = torch.randn(n, dtype=dtype, device=device) - # PyTorch 2.x requires B with rank >= 2 for solve_triangular. x_ref = torch.linalg.solve_triangular( A, b.unsqueeze(-1), upper=False ).squeeze(-1) @@ -42,3 +139,96 @@ def test_spsv_csr_lower_matches_dense(n, dtype): rtol = 1e-4 if dtype == torch.float32 else 1e-10 atol = 1e-5 if dtype == torch.float32 else 1e-10 assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): + device = torch.device("cuda") + A = _build_lower_triangular(n, dtype, device) + b = _rand_like(dtype, (n,), device) + x_ref = torch.linalg.solve_triangular( + A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=False + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) +def test_spsv_csr_trans_int32_supported_combos(n, dtype): + device = torch.device("cuda") + A = _build_lower_triangular(n, dtype, device) + b = _rand_like(dtype, (n,), device) + A_ref = A.to(_ref_dtype(dtype)) + b_ref = b.to(_ref_dtype(dtype)) + x_ref = torch.linalg.solve_triangular( + A_ref.transpose(-2, -1), b_ref.unsqueeze(-1), upper=True + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=True, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif(cp is None or cpx_spsolve_triangular is None, reason="CuPy/cuSPARSE required") +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype): + device = torch.device("cuda") + A = _build_lower_triangular(n, dtype, device) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_non = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=False + ) + x_non_ref = _cupy_ref_spsv(A_cp, b, lower=True, unit_diagonal=False) + + x_trans = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=True + ) + x_trans_ref = _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=False, unit_diagonal=False) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) + assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 3123774..162dfc8 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -29,6 +29,20 @@ ITERS = 20 DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB +FLOAT16_LIMIT = 65504.0 +COMPLEX32_DTYPE = getattr(torch, "complex32", None) +if COMPLEX32_DTYPE is None: + COMPLEX32_DTYPE = getattr(torch, "chalf", None) + +# CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) +CSR_FULL_VALUE_DTYPES = [ + torch.float32, + torch.float64, +] +if COMPLEX32_DTYPE is not None: + CSR_FULL_VALUE_DTYPES.append(COMPLEX32_DTYPE) +CSR_FULL_VALUE_DTYPES.append(torch.complex64) +CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64] def _dtype_name(dtype): @@ -50,11 +64,81 @@ def _fmt_err(v): def _tol_for_dtype(dtype): - if dtype == torch.float32: + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + return 5e-3, 5e-3 + if dtype in (torch.float32, torch.complex64): return 1e-4, 1e-2 return 1e-12, 1e-10 +def _randn_by_dtype(n, dtype, device): + if dtype in (torch.float32, torch.float64): + return torch.randn(n, dtype=dtype, device=device) + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + pair = torch.randn((n, 2), dtype=torch.float16, device=device) * 0.1 + return torch.view_as_complex(pair) + base = torch.float32 + real = torch.randn(n, dtype=base, device=device) + imag = torch.randn(n, dtype=base, device=device) + return torch.complex(real, imag) + + +def _dense_ref_dtype(dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + return torch.complex64 + return dtype + + +def _tensor_from_scalar_values(values, dtype, device): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + real = torch.clamp( + torch.tensor(values, dtype=torch.float32, device=device), + min=-FLOAT16_LIMIT, + max=FLOAT16_LIMIT, + ).to(torch.float16) + imag = torch.zeros_like(real) + return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) + return torch.tensor(values, dtype=dtype, device=device) + + +def _safe_cast_tensor(tensor, dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + real = torch.clamp(tensor.real, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) + imag = torch.clamp(tensor.imag, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) + return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) + return tensor.to(dtype) + + +def _cast_real_tensor_to_value_dtype(values, value_dtype): + if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: + real = torch.clamp(values, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) + imag = torch.zeros_like(real) + return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) + return values.to(value_dtype) + + +def _cupy_ref_inputs(data, b): + if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE: + return data.to(torch.complex64), b.to(torch.complex64) + return data, b + + +def _compare_view(tensor, value_dtype): + if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: + return tensor.to(torch.complex64) + return tensor + + +def _supported_csr_full_ops(value_dtype, index_dtype): + if value_dtype not in CSR_FULL_VALUE_DTYPES: + return [] + if index_dtype == torch.int32: + return ["NON", "TRANS"] + if index_dtype == torch.int64: + return ["NON"] + return [] + + def _allow_dense_pytorch_ref(shape, dtype): n_rows, n_cols = int(shape[0]), int(shape[1]) elem_bytes = torch.empty((), dtype=dtype).element_size() @@ -68,9 +152,14 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True rows_host = [] cols_host = [] vals_host = [] - base_real_dtype = ( - torch.float32 if value_dtype == torch.float32 else torch.float64 - ) + if value_dtype == torch.float32: + base_real_dtype = torch.float32 + elif value_dtype == torch.float64: + base_real_dtype = torch.float64 + elif COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: + base_real_dtype = torch.float16 + else: + base_real_dtype = torch.float32 for i in range(n): if lower: @@ -100,7 +189,10 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device) cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device) - vals_t = torch.tensor(vals_host, dtype=base_real_dtype, device=device).to(value_dtype) + vals_t = _cast_real_tensor_to_value_dtype( + torch.tensor(vals_host, dtype=base_real_dtype, device=device), + value_dtype, + ) order = torch.argsort(rows_t * max(1, n) + cols_t) rows_t = rows_t[order] cols_t = cols_t[order] @@ -114,15 +206,16 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True def _csr_to_dense(data, indices, indptr, shape): n_rows, n_cols = shape + coo_data = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data row_ind = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), + torch.arange(n_rows, device=coo_data.device, dtype=torch.int64), indptr[1:] - indptr[:-1], ) coo = torch.sparse_coo_tensor( torch.stack([row_ind, indices.to(torch.int64)]), - data, + coo_data, (n_rows, n_cols), - device=data.device, + device=coo_data.device, ).coalesce() return coo.to_dense() @@ -137,6 +230,33 @@ def _csr_to_coo(data, indices, indptr, shape): return data, row, col +def _csr_transpose(data, indices, indptr, shape): + n_rows, n_cols = int(shape[0]), int(shape[1]) + if data.numel() == 0: + return ( + data, + torch.empty(0, dtype=torch.int64, device=data.device), + torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device), + ) + + row, col = _csr_to_coo(data, indices, indptr, shape)[1:] + row_t = col + col_t = row + key = row_t * max(1, n_rows) + col_t + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + + row_t = row_t[order] + col_t = col_t[order] + data_t = data[order] + nnz_per_row = torch.bincount(row_t, minlength=n_cols) + indptr_t = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) + indptr_t[1:] = torch.cumsum(nnz_per_row, dim=0) + return data_t, col_t.to(torch.int64), indptr_t + + def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -214,7 +334,7 @@ def _accum(r, c, v): cols_s.append(c) vals_s.append(row[c]) indptr_list.append(len(cols_s)) - data = torch.tensor(vals_s, dtype=dtype, device=device) + data = _tensor_from_scalar_values(vals_s, dtype, device) indices = torch.tensor(cols_s, dtype=torch.int64, device=device) indptr = torch.tensor(indptr_list, dtype=torch.int64, device=device) return data, indices, indptr, (n_rows, n_cols) @@ -239,6 +359,46 @@ def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode): return data_c, row_c, col_c +def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): + if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE: + data_ref = data.to(torch.complex64) + x_ref = x_true.to(torch.complex64) + if op_mode == "NON": + b_ref, _ = fs.flagsparse_spmv_csr( + data_ref, indices, indptr, x_ref, shape, return_time=True + ) + return _safe_cast_tensor(b_ref, x_true.dtype) + if op_mode == "TRANS": + data_t, indices_t, indptr_t = _csr_transpose(data_ref, indices, indptr, shape) + b_ref, _ = fs.flagsparse_spmv_csr( + data_t, + indices_t.to(indices.dtype), + indptr_t.to(indptr.dtype), + x_ref, + (shape[1], shape[0]), + return_time=True, + ) + return _safe_cast_tensor(b_ref, x_true.dtype) + raise ValueError("op_mode must be 'NON' or 'TRANS'") + if op_mode == "NON": + b, _ = fs.flagsparse_spmv_csr( + data, indices, indptr, x_true, shape, return_time=True + ) + return b + if op_mode == "TRANS": + data_t, indices_t, indptr_t = _csr_transpose(data, indices, indptr, shape) + b, _ = fs.flagsparse_spmv_csr( + data_t, + indices_t.to(indices.dtype), + indptr_t.to(indptr.dtype), + x_true, + (shape[1], shape[0]), + return_time=True, + ) + return b + raise ValueError("op_mode must be 'NON' or 'TRANS'") + + def _cupy_spsolve_lower_csr_or_coo( fmt, data, @@ -292,12 +452,62 @@ def _cupy_spsolve_lower_csr_or_coo( t1.record() t1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(t0, t1) / iters - x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()).to(b.dtype) + x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) + if COMPLEX32_DTYPE is not None and b.dtype == COMPLEX32_DTYPE: + x_cu_t = x_cu_t.to(torch.complex64) + else: + x_cu_t = x_cu_t.to(b.dtype) return cupy_ms, x_cu_t except Exception: return None, None +def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode): + if ( + cp is None + or cpx_sparse is None + or cpx_spsolve_triangular is None + ): + return None, None + try: + data_ref, b_ref = _cupy_ref_inputs(data, b) + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) + idx_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous()) + ) + ptr_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous()) + ) + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) + A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + if op_mode == "TRANS": + A_eff = A_cp.transpose().tocsr() + lower_eff = False + else: + A_eff = A_cp + lower_eff = True + + for _ in range(WARMUP): + _ = cpx_spsolve_triangular( + A_eff, b_cp, lower=lower_eff, unit_diagonal=False + ) + cp.cuda.runtime.deviceSynchronize() + c0 = cp.cuda.Event() + c1 = cp.cuda.Event() + c0.record() + for _ in range(ITERS): + x_cp = cpx_spsolve_triangular( + A_eff, b_cp, lower=lower_eff, unit_diagonal=False + ) + c1.record() + c1.synchronize() + ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()).to(b.dtype) + return ms, x_t + except Exception: + return None, None + + def run_spsv_synthetic_all(): if not torch.cuda.is_available(): print("CUDA is not available. Please run on a GPU-enabled system.") @@ -524,22 +734,25 @@ def _finalize_csv_row( A_dense = _csr_to_dense( data, indices.to(torch.int64), indptr, shape ) - A_ref = A_dense - b_ref = b + ref_dtype = _dense_ref_dtype(value_dtype) + A_ref = A_dense.to(ref_dtype) + b_ref = b.to(ref_dtype) e0 = torch.cuda.Event(True) e1 = torch.cuda.Event(True) torch.cuda.synchronize() e0.record() x_ref = torch.linalg.solve(A_ref, b_ref.unsqueeze(1)).squeeze(1) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) e1.record() torch.cuda.synchronize() pytorch_ms = e0.elapsed_time(e1) err_pt = ( - float(torch.max(torch.abs(x - x_ref)).item()) + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) if n_rows > 0 else 0.0 ) - ok_pt = torch.allclose(x, x_ref, atol=atol, rtol=rtol) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) except RuntimeError as e: if "out of memory" in str(e).lower(): pt_skip_reason = "PyTorch dense ref OOM; skipped" @@ -558,17 +771,18 @@ def _finalize_csv_row( cp is not None and cpx_sparse is not None and cpx_spsolve_triangular is not None - and value_dtype in (torch.float32, torch.float64) ): try: - b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b.contiguous())) + data_ref, b_ref = _cupy_ref_inputs(data, b) + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) if ( cupy_coo_data is not None and cupy_coo_row is not None and cupy_coo_col is not None ): + coo_data_ref, _ = _cupy_ref_inputs(cupy_coo_data, b) data_cp = cp.from_dlpack( - torch.utils.dlpack.to_dlpack(cupy_coo_data.contiguous()) + torch.utils.dlpack.to_dlpack(coo_data_ref.contiguous()) ) row_cp = cp.from_dlpack( torch.utils.dlpack.to_dlpack( @@ -585,7 +799,7 @@ def _finalize_csv_row( ) else: data_cp = cp.from_dlpack( - torch.utils.dlpack.to_dlpack(data.contiguous()) + torch.utils.dlpack.to_dlpack(data_ref.contiguous()) ) idx_cp = cp.from_dlpack( torch.utils.dlpack.to_dlpack( @@ -613,13 +827,15 @@ def _finalize_csv_row( c1.record() c1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS - x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()).to(x.dtype) + x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) + x_cmp = _compare_view(x, value_dtype) + x_cu_cmp = _compare_view(x_cu_t, value_dtype) err_cu = ( - float(torch.max(torch.abs(x - x_cu_t)).item()) + float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) if n_rows > 0 else 0.0 ) - ok_cu = torch.allclose(x, x_cu_t, atol=atol, rtol=rtol) + ok_cu = torch.allclose(x_cmp, x_cu_cmp, atol=atol, rtol=rtol) except Exception: cupy_ms = None err_cu = None @@ -649,6 +865,243 @@ def _finalize_csv_row( return row, pt_skip_reason +def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device + ) + indices = indices.to(index_dtype) + indptr = indptr.to(index_dtype) + n_rows, n_cols = shape + x_true = _randn_by_dtype(n_rows, value_dtype, device) + b = _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode) + x, t_ms = fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=True, + transpose=(op_mode == "TRANS"), + return_time=True, + ) + return _finalize_csv_row_csr_full( + path, + value_dtype, + index_dtype, + op_mode, + data, + indices, + indptr, + shape, + x, + t_ms, + b, + n_rows, + n_cols, + ) + + +def _finalize_csv_row_csr_full( + path, + value_dtype, + index_dtype, + op_mode, + data, + indices, + indptr, + shape, + x, + t_ms, + b, + n_rows, + n_cols, +): + atol, rtol = _tol_for_dtype(value_dtype) + + pytorch_ms = None + err_pt = None + ok_pt = False + pt_skip_reason = None + if _allow_dense_pytorch_ref(shape, value_dtype): + try: + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(value_dtype)) + A_ref = A_dense.transpose(0, 1) if op_mode == "TRANS" else A_dense + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + torch.cuda.synchronize() + e0.record() + x_ref = torch.linalg.solve(A_ref, b.to(A_ref.dtype).unsqueeze(1)).squeeze(1) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = ( + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + pt_skip_reason = "PyTorch dense ref OOM; skipped" + else: + raise + else: + pt_skip_reason = ( + f"PyTorch dense ref skipped (> {DENSE_REF_MAX_BYTES // (1024**3)} GiB dense matrix)" + ) + + cupy_ms = None + err_cu = None + ok_cu = False + x_cu_t = None + cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( + data, indices, indptr, shape, b, op_mode + ) + if x_cu_t is not None: + x_cmp = _compare_view(x, value_dtype) + x_cu_cmp = _compare_view(x_cu_t, value_dtype) + err_cu = ( + float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_cu = torch.allclose(x_cmp, x_cu_cmp, atol=atol, rtol=rtol) + + status = "PASS" if (ok_pt or ok_cu) else "FAIL" + if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): + status = "REF_FAIL" + + row = { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": int(data.numel()), + "triton_ms": t_ms, + "pytorch_ms": pytorch_ms, + "cusparse_ms": cupy_ms, + "csc_ms": None, + "status": status, + "err_pt": err_pt, + "err_cu": err_cu, + } + return row, pt_skip_reason + + +def run_all_supported_spsv_csr_csv(mtx_paths, csv_path): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + rows_out = [] + for value_dtype in CSR_FULL_VALUE_DTYPES: + for index_dtype in CSR_FULL_INDEX_DTYPES: + op_modes = _supported_csr_full_ops(value_dtype, index_dtype) + for op_mode in op_modes: + print("=" * 150) + print( + f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | opA={op_mode}" + ) + print( + "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." + ) + print( + "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " + "PASS if either error within tolerance." + ) + print("-" * 150) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " + f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + ) + print("-" * 150) + for path in mtx_paths: + try: + row, pt_skip = _run_one_csv_row_csr_full( + path, value_dtype, index_dtype, op_mode, device + ) + rows_out.append(row) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + n_rows, n_cols = row["n_rows"], row["n_cols"] + nnz = row["nnz"] + t_ms = row["triton_ms"] + cupy_ms = row["cusparse_ms"] + pytorch_ms = row["pytorch_ms"] + err_pt, err_cu = row["err_pt"], row["err_cu"] + status = row["status"] + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " + f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " + f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " + f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + ) + if pt_skip: + print(f" NOTE: {pt_skip}") + except Exception as e: + err_msg = str(e) + status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" + rows_out.append( + { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": "ERR", + "n_cols": "ERR", + "nnz": "ERR", + "triton_ms": None, + "pytorch_ms": None, + "cusparse_ms": None, + "csc_ms": None, + "status": status, + "err_pt": None, + "err_cu": None, + } + ) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + print( + f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " + f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " + f"{'N/A':>7} {'N/A':>7} " + f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" {status}: {e}") + print("-" * 150) + fieldnames = [ + "matrix", + "value_dtype", + "index_dtype", + "opA", + "n_rows", + "n_cols", + "nnz", + "triton_ms", + "pytorch_ms", + "cusparse_ms", + "csc_ms", + "status", + "err_pt", + "err_cu", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows_out: + w.writerow(r) + print(f"Wrote {len(rows_out)} rows to {csv_path}") + + def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto"): if not torch.cuda.is_available(): print("CUDA is not available.") @@ -785,7 +1238,7 @@ def main(): type=str, default=None, metavar="FILE", - help="Run all dtypes/index dtypes on .mtx (CSR SpSV) and export CSV", + help="Run full supported CSR SpSV combinations (dtype/index/opA) on .mtx and export CSV", ) parser.add_argument( "--csv-coo", @@ -819,7 +1272,7 @@ def main(): if not paths: print("No .mtx files found for --csv-csr") return - run_all_dtypes_spsv_csv(paths, args.csv_csr, use_coo=False) + run_all_supported_spsv_csr_csv(paths, args.csv_csr) return if args.csv_coo: if not paths: @@ -836,4 +1289,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 0e76bd8d1e404cd974487fd1519a86aff0dbc349 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Thu, 16 Apr 2026 18:34:21 +0800 Subject: [PATCH 2/5] complex128 --- .gitignore | 1 + benchmark/benchmark_spsv.py | 14 + src/flagsparse/sparse_operations/_common.py | 43 +- .../sparse_operations/gather_scatter.py | 55 +- src/flagsparse/sparse_operations/spsv.py | 199 ++++++- tests/pytest/test_gather_scatter_accuracy.py | 119 +--- tests/pytest/test_spsv_csr_accuracy.py | 245 +++++++- tests/test_gather.py | 14 +- tests/test_spsv.py | 545 ++++++++++++++---- 9 files changed, 901 insertions(+), 334 deletions(-) create mode 100644 benchmark/benchmark_spsv.py diff --git a/.gitignore b/.gitignore index 4ca5df5..e44113e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ __pycache__/ .pytest_cache/ .coverage htmlcov/ +.DS_Store diff --git a/benchmark/benchmark_spsv.py b/benchmark/benchmark_spsv.py new file mode 100644 index 0000000..3fa258f --- /dev/null +++ b/benchmark/benchmark_spsv.py @@ -0,0 +1,14 @@ +"""Run SpSV benchmark. From project root: python benchmark/benchmark_spsv.py [--synthetic | --csv-csr out.csv].""" + +import sys +from pathlib import Path + +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) + +from tests.test_spsv import main + + +if __name__ == "__main__": + main() diff --git a/src/flagsparse/sparse_operations/_common.py b/src/flagsparse/sparse_operations/_common.py index 7e89526..6f7725d 100644 --- a/src/flagsparse/sparse_operations/_common.py +++ b/src/flagsparse/sparse_operations/_common.py @@ -18,19 +18,14 @@ cp = None cpx_sparse = None -_TORCH_COMPLEX32_DTYPE = getattr(torch, "complex32", None) -if _TORCH_COMPLEX32_DTYPE is None: - _TORCH_COMPLEX32_DTYPE = getattr(torch, "chalf", None) - _SUPPORTED_VALUE_DTYPES = [ torch.float16, torch.bfloat16, torch.float32, torch.float64, + torch.complex64, + torch.complex128, ] -if _TORCH_COMPLEX32_DTYPE is not None: - _SUPPORTED_VALUE_DTYPES.append(_TORCH_COMPLEX32_DTYPE) -_SUPPORTED_VALUE_DTYPES.extend([torch.complex64, torch.complex128]) SUPPORTED_VALUE_DTYPES = tuple(_SUPPORTED_VALUE_DTYPES) SUPPORTED_INDEX_DTYPES = (torch.int32, torch.int64) _INDEX_LIMIT_INT32 = 2**31 - 1 @@ -40,8 +35,6 @@ "SUPPORTED_VALUE_DTYPES", "SUPPORTED_INDEX_DTYPES", "_INDEX_LIMIT_INT32", - "_torch_complex32_dtype", - "_is_complex32_dtype", "_is_complex_dtype", "_resolve_scatter_value_dtype", "_component_dtype_for_complex", @@ -68,17 +61,8 @@ "tl", ) - -def _torch_complex32_dtype(): - return _TORCH_COMPLEX32_DTYPE - - -def _is_complex32_dtype(value_dtype): - return _TORCH_COMPLEX32_DTYPE is not None and value_dtype == _TORCH_COMPLEX32_DTYPE - - def _is_complex_dtype(value_dtype): - return _is_complex32_dtype(value_dtype) or value_dtype in (torch.complex64, torch.complex128) + return value_dtype in (torch.complex64, torch.complex128) def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"): @@ -95,24 +79,13 @@ def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"): "complex64": torch.complex64, "complex128": torch.complex128, } - if token == "complex32": - if _TORCH_COMPLEX32_DTYPE is not None: - return _TORCH_COMPLEX32_DTYPE, False, None - if dtype_policy == "strict": - raise TypeError("complex32 is unavailable in this torch build") - return torch.complex64, True, "complex32 is unavailable; fallback to complex64" if token not in mapping: raise TypeError(f"Unsupported dtype token: {value_dtype}") value_dtype = mapping[token] - if _is_complex32_dtype(value_dtype): - # If complex32 exists in torch dtype table, keep native path. - return value_dtype, False, None return value_dtype, False, None def _component_dtype_for_complex(value_dtype): - if _is_complex32_dtype(value_dtype): - return torch.float16 if value_dtype == torch.complex64: return torch.float32 if value_dtype == torch.complex128: @@ -123,8 +96,6 @@ def _component_dtype_for_complex(value_dtype): def _tolerance_for_dtype(value_dtype): if value_dtype == torch.float16: return 2e-3, 2e-3 - if _is_complex32_dtype(value_dtype): - return 5e-3, 5e-3 if value_dtype == torch.bfloat16: return 1e-1, 1e-1 if value_dtype in (torch.float32, torch.complex64): @@ -155,9 +126,6 @@ def _cupy_dtype_from_torch(torch_dtype): torch.int32: cp.int32, torch.int64: cp.int64, } - if _TORCH_COMPLEX32_DTYPE is not None: - # CuPy has no native complex32 sparse path; use complex64 for baseline parity. - mapping[_TORCH_COMPLEX32_DTYPE] = cp.complex64 if torch_dtype not in mapping: raise TypeError(f"Unsupported dtype conversion to CuPy: {torch_dtype}") return mapping[torch_dtype] @@ -193,8 +161,6 @@ def _to_backend_like(torch_tensor, ref_obj): def _cusparse_baseline_skip_reason(value_dtype): if value_dtype == torch.bfloat16: return "bfloat16 is not supported by the cuSPARSE baseline path; skipped" - if _is_complex32_dtype(value_dtype): - return "complex32 is not supported by the cuSPARSE baseline path; skipped" if cp is None and value_dtype == torch.float16: return "float16 is not supported by torch sparse fallback when CuPy is unavailable; skipped" return None @@ -203,9 +169,6 @@ def _cusparse_baseline_skip_reason(value_dtype): def _build_random_dense(dense_size, value_dtype, device): if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): return torch.randn(dense_size, dtype=value_dtype, device=device) - if _is_complex32_dtype(value_dtype): - stacked = torch.randn((dense_size, 2), dtype=torch.float16, device=device) - return torch.view_as_complex(stacked) if _is_complex_dtype(value_dtype): component_dtype = _component_dtype_for_complex(value_dtype) real = torch.randn(dense_size, dtype=component_dtype, device=device) diff --git a/src/flagsparse/sparse_operations/gather_scatter.py b/src/flagsparse/sparse_operations/gather_scatter.py index 6a02d8e..67066b0 100644 --- a/src/flagsparse/sparse_operations/gather_scatter.py +++ b/src/flagsparse/sparse_operations/gather_scatter.py @@ -226,6 +226,10 @@ def _launch_triton_scatter_kernel( ) +def _validate_gather_value_dtype(dense_vector, op_name): + return None + + def _cusparse_spmv(selector_matrix, dense_vector): if cp is not None and cpx_sparse is not None and isinstance(selector_matrix, cpx_sparse.spmatrix): if torch.is_tensor(dense_vector): @@ -330,6 +334,7 @@ def flagsparse_gather(a, indices, out=None, mode="raise", block_size=1024, retur dense_vector, dense_backend = _to_torch_tensor(a, "a") indices_tensor, _ = _to_torch_tensor(indices, "indices") dense_vector, indices_tensor, kernel_indices = _prepare_inputs(dense_vector, indices_tensor) + _validate_gather_value_dtype(dense_vector, "flagsparse_gather") torch.cuda.synchronize() start_time = time.perf_counter() @@ -479,6 +484,7 @@ def pytorch_index_scatter( def cusparse_spmv_gather(dense_vector, indices, selector_matrix=None): """Equivalent gather baseline via cuSPARSE-backed COO SpMV.""" dense_vector, indices, _ = _prepare_inputs(dense_vector, indices) + _validate_gather_value_dtype(dense_vector, "cusparse_spmv_gather") skip_reason = _cusparse_baseline_skip_reason(dense_vector.dtype) if skip_reason: raise RuntimeError(skip_reason) @@ -557,21 +563,14 @@ def _cupy_gather_detect_layout(dense_vector): return "scalar16" if dense_vector.ndim == 1 and dense_vector.dtype == torch.complex64: return "complex64" - if ( - dense_vector.ndim == 2 - and dense_vector.shape[1] == 2 - and dense_vector.dtype == torch.float16 - ): - # complex32 alignment with half2 storage. - return "complex16_pair" raise TypeError( "Unsupported gather input format. Expected one of: " - "1D float16/bfloat16, 1D complex64, or 2D (N,2) float16." + "1D float16/bfloat16 or 1D complex64." ) def _cupy_gather_dense_size(dense_vector, layout): - if layout in ("scalar16", "complex64", "complex16_pair"): + if layout in ("scalar16", "complex64"): return int(dense_vector.shape[0]) raise RuntimeError(f"Unknown gather layout: {layout}") @@ -601,7 +600,7 @@ def _cupy_gather_validate_inputs(dense_vector, indices): def _cupy_gather_validate_combo(dense_vector, indices, layout): # Keep only the required extra gather combos: - # Half+Int64, Bfloat16+Int32/Int64, Complex32+Int32/Int64, Complex64+Int32/Int64 + # Half+Int64, Bfloat16+Int32/Int64, Complex64+Int32/Int64 if layout == "scalar16": if dense_vector.dtype == torch.float16: if indices.dtype != torch.int64: @@ -611,11 +610,6 @@ def _cupy_gather_validate_combo(dense_vector, indices, layout): return raise TypeError("scalar16 gather_cupy supports only float16/bfloat16") - if layout == "complex16_pair": - if dense_vector.dtype != torch.float16: - raise TypeError("complex16_pair gather_cupy supports only float16 pairs") - return - if layout == "complex64": return @@ -625,8 +619,6 @@ def _cupy_gather_validate_combo(dense_vector, indices, layout): def _cupy_gather_layout_raw_kind(layout): if layout == "scalar16": return 16 - if layout == "complex16_pair": - return 32 if layout == "complex64": return 64 raise RuntimeError(f"Unknown gather layout: {layout}") @@ -675,9 +667,6 @@ def _cupy_gather_dense_to_raw_torch(dense_t, layout): return dense_t.reshape(-1).view(torch.uint16) if layout == "complex64": return dense_t.reshape(-1).view(torch.uint64) - if layout == "complex16_pair": - lanes_u16 = dense_t.reshape(-1).view(torch.uint16) - return lanes_u16.reshape(-1, 2).view(torch.uint32).reshape(-1) raise RuntimeError(f"Unknown gather layout: {layout}") @@ -686,22 +675,17 @@ def _cupy_gather_raw_to_dense_torch(out_raw_t, layout, dense_t_dtype): return out_raw_t.view(dense_t_dtype).reshape(-1) if layout == "complex64": return out_raw_t.view(torch.complex64).reshape(-1) - if layout == "complex16_pair": - lanes_u16 = out_raw_t.view(torch.uint16).reshape(-1, 2) - return lanes_u16.view(dense_t_dtype).reshape(-1, 2) raise RuntimeError(f"Unknown gather layout: {layout}") def _cupy_gather_empty(layout, dense_dtype, device): if layout in ("scalar16", "complex64"): return torch.empty(0, dtype=dense_dtype, device=device) - if layout == "complex16_pair": - return torch.empty((0, 2), dtype=dense_dtype, device=device) raise RuntimeError(f"Unknown gather layout: {layout}") def _cupy_gather_selector_dtype(layout, dense_dtype): - if layout in ("scalar16", "complex16_pair"): + if layout == "scalar16": return dense_dtype if layout == "complex64": return torch.complex64 @@ -709,21 +693,12 @@ def _cupy_gather_selector_dtype(layout, dense_dtype): def _cupy_gather_prepare_dense(dense_vector, indices): - runtime_dense = dense_vector - restore_mode = None - native_complex32 = _torch_complex32_dtype() - if native_complex32 is not None and dense_vector.ndim == 1 and dense_vector.dtype == native_complex32: - runtime_dense = torch.view_as_real(dense_vector).contiguous() - restore_mode = "native_complex32" - - layout, dense_size = _cupy_gather_validate_inputs(runtime_dense, indices) - _cupy_gather_validate_combo(runtime_dense, indices, layout) - return runtime_dense, layout, dense_size, restore_mode + layout, dense_size = _cupy_gather_validate_inputs(dense_vector, indices) + _cupy_gather_validate_combo(dense_vector, indices, layout) + return dense_vector, layout, dense_size, None def _cupy_gather_restore_output(gathered_t, restore_mode): - if restore_mode == "native_complex32": - return torch.view_as_complex(gathered_t.contiguous()) return gathered_t @@ -880,10 +855,6 @@ def cusparse_spmv_gather_cupy(dense_vector, indices, selector_matrix=None): start_time = time.perf_counter() if layout in ("scalar16", "complex64"): gathered_t = _cusparse_spmv(selector_matrix, runtime_dense_t) - elif layout == "complex16_pair": - gathered_real = _cusparse_spmv(selector_matrix, runtime_dense_t[:, 0]) - gathered_imag = _cusparse_spmv(selector_matrix, runtime_dense_t[:, 1]) - gathered_t = torch.stack([gathered_real, gathered_imag], dim=1) else: raise RuntimeError(f"Unknown gather layout: {layout}") torch.cuda.synchronize() diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index 2976ad1..dd867bc 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -2,6 +2,7 @@ from ._common import * +from collections import OrderedDict import os import time import triton @@ -11,31 +12,34 @@ torch.bfloat16, torch.float32, torch.float64, - *((_torch_complex32_dtype(),) if _torch_complex32_dtype() is not None else ()), torch.complex64, + torch.complex128, + ) SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64) SPSV_NON_TRANS_PRIMARY_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), - *(((_torch_complex32_dtype(), torch.int32),) if _torch_complex32_dtype() is not None else ()), (torch.complex64, torch.int32), + (torch.complex128, torch.int32), ) SPSV_NON_TRANS_EXTENDED_COMBOS = ( (torch.float32, torch.int64), (torch.float64, torch.int64), - *(((_torch_complex32_dtype(), torch.int64),) if _torch_complex32_dtype() is not None else ()), (torch.complex64, torch.int64), + (torch.complex128, torch.int64), ) SPSV_TRANS_PRIMARY_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), - *(((_torch_complex32_dtype(), torch.int32),) if _torch_complex32_dtype() is not None else ()), (torch.complex64, torch.int32), + (torch.complex128, torch.int32), ) SPSV_PROMOTE_FP32_TO_FP64 = str( os.environ.get("FLAGSPARSE_SPSV_PROMOTE_FP32_TO_FP64", "0") ).lower() in ("1", "true", "yes", "on") +_SPSV_CSR_PREPROCESS_CACHE = OrderedDict() +_SPSV_CSR_PREPROCESS_CACHE_SIZE = 8 def _csr_to_dense(data, indices, indptr, shape): """Convert CSR (torch CUDA tensors) to dense matrix on the same device.""" @@ -69,7 +73,11 @@ def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): raise TypeError( f"{fmt_name} SpSV currently supports NON_TRANS combinations: " "(float32, int32/int64), (float64, int32/int64), " +<<<<<<< HEAD "(complex32, int32/int64), (complex64, int32/int64), (bfloat16, int32)" +======= + "(complex64, int32/int64), (complex128, int32/int64), (bfloat16, int32)" +>>>>>>> 5a83e0f (test) ) @@ -78,7 +86,11 @@ def _validate_spsv_trans_combo(data_dtype, index_dtype, fmt_name): return raise TypeError( f"{fmt_name} SpSV currently supports TRANS combinations with int32 indices only: " +<<<<<<< HEAD "(float32, int32), (float64, int32), (complex32, int32), (complex64, int32)" +======= + "(float32, int32), (float64, int32), (complex64, int32), (complex128, int32)" +>>>>>>> 5a83e0f (test) ) @@ -118,7 +130,11 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise TypeError( +<<<<<<< HEAD "data dtype must be one of: bfloat16, float32, float64, complex32, complex64" +======= + "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" +>>>>>>> 5a83e0f (test) ) if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("indices dtype must be torch.int32 or torch.int64") @@ -158,6 +174,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): ) +<<<<<<< HEAD def _promote_complex32_spsv_inputs(data, b): if _is_complex32_dtype(data.dtype): return data.to(torch.complex64), b.to(torch.complex64), data.dtype @@ -173,6 +190,62 @@ def _restore_complex32_spsv_output(x, target_dtype): return x.to(target_dtype) +======= +def _prepare_spsv_working_inputs(data, b): + return data, b, None + + +def _restore_spsv_output(x, target_dtype): + return x.to(target_dtype) + + +def _spsv_diag_eps_for_dtype(value_dtype): + return 1e-12 if value_dtype in (torch.float64, torch.complex128) else 1e-6 + + +def _tensor_cache_token(tensor): + try: + storage_ptr = int(tensor.untyped_storage().data_ptr()) + except Exception: + storage_ptr = 0 + return ( + str(tensor.device), + str(tensor.dtype), + tuple(int(v) for v in tensor.shape), + int(tensor.numel()), + storage_ptr, + int(getattr(tensor, "_version", 0)), + ) + + +def _spsv_cache_get(cache, key): + value = cache.get(key) + if value is not None: + cache.move_to_end(key) + return value + + +def _spsv_cache_put(cache, key, value, max_entries): + cache[key] = value + cache.move_to_end(key) + while len(cache) > max_entries: + cache.popitem(last=False) + + +def _csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode): + return ( + "csr_preprocess", + trans_mode, + bool(lower), + int(shape[0]), + int(shape[1]), + _tensor_cache_token(data), + _tensor_cache_token(indices), + _tensor_cache_token(indptr), + ) + + +>>>>>>> 5a83e0f (test) @triton.jit def _spsv_csr_level_kernel( data_ptr, @@ -536,7 +609,19 @@ def _triton_spsv_csr_vector_complex( indptr, block_nnz=block_nnz, max_segments=max_segments ) +<<<<<<< HEAD data_ri = torch.view_as_real(data.contiguous()).reshape(-1).contiguous() +======= + # Some PyTorch builds return CSR values with a non-strided layout wrapper. + # Materialize a plain 1D strided buffer before splitting into real/imag parts. + if data.layout != torch.strided: + data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) + data_strided.copy_(data) + else: + data_strided = data.contiguous() + + data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() +>>>>>>> 5a83e0f (test) b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 @@ -589,17 +674,27 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if b.ndim == 2 and b.shape[0] != n_rows: raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") +<<<<<<< HEAD if data.dtype not in (torch.bfloat16, torch.float32, torch.float64): raise TypeError("data dtype must be one of: bfloat16, float32, float64") +======= + if data.dtype not in ( + torch.bfloat16, + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ): + raise TypeError( + "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" + ) +>>>>>>> 5a83e0f (test) if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") if row.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("row dtype must be torch.int32 or torch.int64") if col.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("col dtype must be torch.int32 or torch.int64") - if transpose: - raise NotImplementedError("transpose=True is not implemented in Triton SpSV yet") - row64 = row.to(torch.int64).contiguous() col64 = col.to(torch.int64).contiguous() if col64.numel() > 0 and int(col64.max().item()) > _INDEX_LIMIT_INT32: @@ -782,20 +877,36 @@ def flagsparse_spsv_csr( """Sparse triangular solve using Triton level-scheduling kernels. Primary support matrix: +<<<<<<< HEAD - NON_TRANS: float32/float64/complex32/complex64 with int32/int64 indices - TRANS: float32/float64/complex32/complex64 with int32 indices - bfloat16 remains NON_TRANS + int32 """ +======= + - NON_TRANS: float32/float64/complex64/complex128 with int32/int64 indices + - TRANS: float32/float64/complex64/complex128 with int32 indices + - bfloat16 remains NON_TRANS + int32 + """ + input_data = data + input_indices = indices + input_indptr = indptr +>>>>>>> 5a83e0f (test) trans_mode = _normalize_spsv_transpose_mode(transpose) data, input_index_dtype, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( data, indices, indptr, b, shape ) original_output_dtype = None +<<<<<<< HEAD data, b, original_output_dtype = _promote_complex32_spsv_inputs(data, b) +======= + rev_perm = None + data, b, original_output_dtype = _prepare_spsv_working_inputs(data, b) +>>>>>>> 5a83e0f (test) if n_rows != n_cols: raise ValueError(f"A must be square, got shape={shape}") if trans_mode == "N": _validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "CSR") +<<<<<<< HEAD lower_eff = lower kernel_data = data kernel_indices64 = indices @@ -806,6 +917,46 @@ def flagsparse_spsv_csr( kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( data, indices, indptr, n_rows, n_cols ) +======= + else: + _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") + + preprocess_key = _csr_preprocess_cache_key( + input_data, input_indices, input_indptr, (n_rows, n_cols), lower, trans_mode + ) + cached = _spsv_cache_get(_SPSV_CSR_PREPROCESS_CACHE, preprocess_key) + if cached is None: + if trans_mode == "N": + lower_eff = lower + kernel_data = data + kernel_indices64 = indices + kernel_indptr64 = indptr + rev_perm = None + else: + lower_eff = not lower + kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( + data, indices, indptr, n_rows, n_cols + ) + rev_perm = None + levels = _build_spsv_levels( + kernel_indptr64, kernel_indices64, n_rows, lower=lower_eff + ) + cached = ( + kernel_data, + kernel_indices64, + kernel_indptr64, + rev_perm, + lower_eff, + levels, + ) + _spsv_cache_put( + _SPSV_CSR_PREPROCESS_CACHE, + preprocess_key, + cached, + _SPSV_CSR_PREPROCESS_CACHE_SIZE, + ) + kernel_data, kernel_indices64, kernel_indptr64, rev_perm, lower_eff, levels = cached +>>>>>>> 5a83e0f (test) kernel_indices = ( kernel_indices64.to(torch.int32) @@ -828,6 +979,7 @@ def flagsparse_spsv_csr( # Optional high-precision mode; disabled by default for throughput. compute_dtype = torch.float64 data_in = kernel_data.to(torch.float64) +<<<<<<< HEAD b_in = b.to(torch.float64) elif data.dtype == torch.float32 and trans_mode == "T": compute_dtype = torch.float64 @@ -836,10 +988,17 @@ def flagsparse_spsv_csr( levels = _build_spsv_levels( kernel_indptr, kernel_indices, n_rows, lower=lower_eff ) +======= + b_in = b.to(torch.float64) + elif data.dtype == torch.float32 and trans_mode == "T": + compute_dtype = torch.float64 + data_in = kernel_data.to(torch.float64) + b_in = b.to(torch.float64) +>>>>>>> 5a83e0f (test) block_nnz_use, max_segments_use = _auto_spsv_launch_config( kernel_indptr, block_nnz=block_nnz, max_segments=max_segments ) - diag_eps = 1e-12 if compute_dtype == torch.float64 else 1e-6 + diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) torch.cuda.synchronize() t0 = time.perf_counter() if b_in.ndim == 1: @@ -918,7 +1077,11 @@ def flagsparse_spsv_csr( x = torch.stack(cols, dim=1) target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype if x.dtype != target_dtype: +<<<<<<< HEAD x = _restore_complex32_spsv_output(x, target_dtype) +======= + x = _restore_spsv_output(x, target_dtype) +>>>>>>> 5a83e0f (test) torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: @@ -950,11 +1113,11 @@ def flagsparse_spsv_coo( """COO SpSV with dual mode: - direct: use COO level kernel directly (requires sorted+unique COO) - csr: convert COO -> CSR (sorted+deduplicated) then call flagsparse_spsv_csr - - auto: pick direct when sorted+unique, otherwise csr + - auto: pick direct when sorted+unique and supported, otherwise csr - Primary NON_TRANS support matrix: - - float32 + int32 indices - - float64 + int32 indices + Notes: + - direct mode currently supports only non-transposed real-valued inputs + - complex dtypes and transpose=True always route through the CSR implementation """ data, row64, col64, b, n_rows, n_cols = _prepare_spsv_coo_inputs( data, row, col, b, shape, transpose=transpose @@ -967,7 +1130,13 @@ def flagsparse_spsv_coo( raise ValueError("coo_mode must be one of: 'auto', 'direct', 'csr'") sorted_unique = _coo_is_sorted_unique(row64, col64, n_cols) - use_direct = mode == "direct" or (mode == "auto" and sorted_unique) + direct_supported = (not transpose) and (not torch.is_complex(data)) + use_direct = direct_supported and (mode == "direct" or (mode == "auto" and sorted_unique)) + if mode == "direct" and not direct_supported: + raise ValueError( + "coo_mode='direct' supports only non-transposed real-valued inputs; " + "use coo_mode='csr' or 'auto' for transpose or complex dtypes" + ) if mode == "direct" and not sorted_unique: raise ValueError( "coo_mode='direct' requires COO sorted by (row, col) with no duplicate coordinates; " @@ -978,6 +1147,8 @@ def flagsparse_spsv_coo( data_csr, indices_csr, indptr_csr = _coo_to_csr_sorted_unique( data, row64, col64, n_rows, n_cols ) + if transpose: + indices_csr = indices_csr.to(torch.int32) return flagsparse_spsv_csr( data_csr, indices_csr, @@ -1011,7 +1182,7 @@ def flagsparse_spsv_coo( block_nnz_use, max_segments_use = _auto_spsv_launch_config( row_ptr, block_nnz=block_nnz, max_segments=max_segments ) - diag_eps = 1e-12 if compute_dtype == torch.float64 else 1e-6 + diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) torch.cuda.synchronize() t0 = time.perf_counter() diff --git a/tests/pytest/test_gather_scatter_accuracy.py b/tests/pytest/test_gather_scatter_accuracy.py index 3563ef1..e67c648 100644 --- a/tests/pytest/test_gather_scatter_accuracy.py +++ b/tests/pytest/test_gather_scatter_accuracy.py @@ -19,17 +19,8 @@ RESET_OUTPUT_CASES = [True, False] RESET_OUTPUT_IDS = ["reset", "inplace"] - -def _complex32_dtype(): - dtype = getattr(torch, "complex32", None) - if dtype is None: - dtype = getattr(torch, "chalf", None) - return dtype - - def _scatter_dtype_cases(): cases = [(str(dtype).replace("torch.", ""), dtype) for dtype in FLOAT_DTYPES] - cases.append(("complex32", _complex32_dtype())) cases.append(("complex64", torch.complex64)) cases.append(("complex128", torch.complex128)) return cases @@ -59,9 +50,6 @@ def _build_random_values(size, dtype, device): real = torch.randn(size, dtype=torch.float64, device=device) imag = torch.randn(size, dtype=torch.float64, device=device) return torch.complex(real, imag) - if _complex32_dtype() is not None and dtype == _complex32_dtype(): - stacked = torch.randn((size, 2), dtype=torch.float16, device=device) - return torch.view_as_complex(stacked) raise TypeError(f"Unsupported dtype in test: {dtype}") @@ -113,8 +101,6 @@ def _extra_gather_tolerance(value_dtype): ("scalar16", torch.float16, torch.int64), ("scalar16", torch.bfloat16, torch.int32), ("scalar16", torch.bfloat16, torch.int64), - ("complex16_pair", torch.float16, torch.int32), - ("complex16_pair", torch.float16, torch.int64), ("complex64", torch.complex64, torch.int32), ("complex64", torch.complex64, torch.int64), ] @@ -122,8 +108,6 @@ def _extra_gather_tolerance(value_dtype): "half_i64", "bf16_i32", "bf16_i64", - "c16f_i32", - "c16f_i64", "c64_i32", "c64_i64", ] @@ -144,6 +128,21 @@ def test_gather_matches_indexing(dense_size, nnz, dtype, index_dtype): assert torch.equal(ref, got) +@pytest.mark.gather +@pytest.mark.parametrize("index_dtype", INDEX_DTYPES, ids=INDEX_DTYPE_IDS) +def test_gather_complex128_matches_indexing(index_dtype): + device = torch.device("cuda") + dense_size = 4096 + nnz = 1024 + real = torch.randn(dense_size, dtype=torch.float64, device=device) + imag = torch.randn(dense_size, dtype=torch.float64, device=device) + dense = torch.complex(real, imag) + indices = torch.randperm(dense_size, device=device)[:nnz].to(index_dtype) + ref = dense.index_select(0, indices.to(torch.int64)) + got = flagsparse_gather(dense, indices) + assert torch.allclose(got, ref, atol=1e-10, rtol=1e-8) + + @pytest.mark.scatter @pytest.mark.parametrize("dense_size, nnz", GATHER_SCATTER_SHAPES) @pytest.mark.parametrize("dtype_name,dtype", SCATTER_DTYPE_CASES, ids=SCATTER_DTYPE_IDS) @@ -323,94 +322,6 @@ def test_gather_cupy_same_backend_out_float16_i64(backend): assert torch.allclose(_as_torch_tensor(out), reference, atol=5e-3, rtol=5e-3) -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -@pytest.mark.parametrize("backend", ["torch", "cupy"]) -def test_gather_cupy_same_backend_out_pair_complex32(backend): - device = torch.device("cuda") - dense_size = 65536 - nnz = 4096 - dense_t = torch.randn(dense_size, 2, dtype=torch.float16, device=device) - indices_t = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size - reference = dense_t.index_select(0, indices_t) - - dense_in = _to_backend_tensor(dense_t, backend) - indices_in = _to_backend_tensor(indices_t, backend) - out = _to_backend_tensor(torch.empty_like(reference), backend) - result = flagsparse_gather_cupy(dense_in, indices_in, out=out) - - assert result is out - assert torch.allclose(_as_torch_tensor(out), reference, atol=5e-3, rtol=5e-3) - - -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -def test_gather_cupy_native_complex32_out_matches_reference(): - native_dtype = _complex32_dtype() - _skip_unavailable_dtype("complex32", native_dtype) - - device = torch.device("cuda") - dense_size = 65536 - nnz = 4096 - dense_pair = torch.randn(dense_size, 2, dtype=torch.float16, device=device) - dense_native = torch.view_as_complex(dense_pair.contiguous()) - indices = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size - reference = dense_native.index_select(0, indices) - - out = torch.empty_like(reference) - result = flagsparse_gather_cupy(dense_native, indices, out=out) - - assert result is out - assert out.dtype == native_dtype - assert torch.allclose( - torch.view_as_real(out).contiguous(), - torch.view_as_real(reference).contiguous(), - atol=5e-3, - rtol=5e-3, - ) - - -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -def test_gather_cupy_native_complex32_matches_reference_and_pair_layout(): - native_dtype = _complex32_dtype() - _skip_unavailable_dtype("complex32", native_dtype) - - device = torch.device("cuda") - dense_size = 65536 - nnz = 4096 - dense_pair = torch.randn(dense_size, 2, dtype=torch.float16, device=device) - dense_native = torch.view_as_complex(dense_pair.contiguous()) - indices = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size - - reference = dense_native.index_select(0, indices) - reference_pair = torch.view_as_real(reference).contiguous() - pair_got = flagsparse_gather_cupy(dense_pair, indices) - native_got = flagsparse_gather_cupy(dense_native, indices) - cusparse_values, _, _ = cusparse_spmv_gather_cupy(dense_native, indices) - - atol, rtol = 5e-3, 5e-3 - assert pair_got.shape == reference_pair.shape - assert pair_got.dtype == torch.float16 - assert native_got.shape == reference.shape - assert native_got.dtype == native_dtype - assert cusparse_values.shape == reference.shape - assert cusparse_values.dtype == native_dtype - assert torch.allclose(pair_got, reference_pair, atol=atol, rtol=rtol) - assert torch.allclose( - torch.view_as_real(native_got).contiguous(), - reference_pair, - atol=atol, - rtol=rtol, - ) - assert torch.allclose( - torch.view_as_real(cusparse_values).contiguous(), - reference_pair, - atol=atol, - rtol=rtol, - ) - - @pytest.mark.gather @pytest.mark.skipif(cp is None, reason="CuPy required") def test_gather_cupy_int64_auto_fallback_to_int32(monkeypatch): diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py index 94788f2..d44d7b5 100644 --- a/tests/pytest/test_spsv_csr_accuracy.py +++ b/tests/pytest/test_spsv_csr_accuracy.py @@ -1,7 +1,7 @@ import pytest import torch -from flagsparse import flagsparse_spsv_csr +from flagsparse import flagsparse_spsv_coo, flagsparse_spsv_csr from tests.pytest.param_shapes import SPSV_N @@ -16,6 +16,7 @@ pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +<<<<<<< HEAD COMPLEX32_DTYPE = getattr(torch, "complex32", None) if COMPLEX32_DTYPE is None: COMPLEX32_DTYPE = getattr(torch, "chalf", None) @@ -26,6 +27,13 @@ SUPPORTED_COMPLEX_DTYPES.append(torch.complex64) SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] +======= +SUPPORTED_COMPLEX_DTYPES = [torch.complex64, torch.complex128] + +SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] +NON_TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +>>>>>>> 5a83e0f (test) def _dtype_id(dtype): @@ -33,8 +41,11 @@ def _dtype_id(dtype): def _tol(dtype): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: return 5e-3, 5e-3 +======= +>>>>>>> 5a83e0f (test) if dtype in (torch.float32, torch.complex64): return 1e-4, 1e-3 return 1e-10, 1e-8 @@ -43,30 +54,41 @@ def _tol(dtype): def _rand_like(dtype, shape, device): if dtype in (torch.float32, torch.float64): return torch.randn(shape, dtype=dtype, device=device) +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: pair = torch.randn((*shape, 2), dtype=torch.float16, device=device) * 0.1 return torch.view_as_complex(pair) base = torch.float32 +======= + base = torch.float32 if dtype == torch.complex64 else torch.float64 +>>>>>>> 5a83e0f (test) r = torch.randn(shape, dtype=base, device=device) i = torch.randn(shape, dtype=base, device=device) return torch.complex(r, i) def _ref_dtype(dtype): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: return torch.complex64 +======= +>>>>>>> 5a83e0f (test) return dtype def _safe_cast_tensor(tensor, dtype): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: real = tensor.real.to(torch.float16) imag = tensor.imag.to(torch.float16) return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) +======= +>>>>>>> 5a83e0f (test) return tensor.to(dtype) def _cmp_view(tensor, dtype): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: return tensor.to(torch.complex64) return tensor @@ -75,6 +97,14 @@ def _cmp_view(tensor, dtype): def _build_lower_triangular(n, dtype, device): off = _rand_like(dtype, (n, n), device) * 0.02 A = torch.tril(off) +======= + return tensor + + +def _build_triangular(n, dtype, device, lower=True): + off = _rand_like(dtype, (n, n), device) * 0.02 + A = torch.tril(off) if lower else torch.triu(off) +>>>>>>> 5a83e0f (test) if torch.is_complex(A): diag = (torch.rand(n, device=device, dtype=A.real.dtype) + 2.0).to(A.real.dtype) A = A + torch.diag(torch.complex(diag, torch.zeros_like(diag))) @@ -87,8 +117,12 @@ def _build_lower_triangular(n, dtype, device): def _cupy_csr_from_torch(data, indices, indptr, shape): if cp is None or cpx_sparse is None: return None +<<<<<<< HEAD data_ref = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) +======= + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) +>>>>>>> 5a83e0f (test) idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous())) ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous())) return cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) @@ -97,12 +131,18 @@ def _cupy_csr_from_torch(data, indices, indptr, shape): def _cupy_ref_spsv(A_cp, b_t, *, lower, unit_diagonal=False): if cp is None or cpx_spsolve_triangular is None: return None +<<<<<<< HEAD b_ref = b_t.to(torch.complex64) if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE else b_t b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE: return x_t.to(torch.complex64) +======= + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_t.contiguous())) + x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) + x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) +>>>>>>> 5a83e0f (test) return x_t.to(b_t.dtype) @@ -143,11 +183,19 @@ def test_spsv_csr_lower_matches_dense(n, dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) +<<<<<<< HEAD @pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) @pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): device = torch.device("cuda") A = _build_lower_triangular(n, dtype, device) +======= +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) +>>>>>>> 5a83e0f (test) b = _rand_like(dtype, (n,), device) x_ref = torch.linalg.solve_triangular( A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=False @@ -174,10 +222,17 @@ def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) +<<<<<<< HEAD @pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) def test_spsv_csr_trans_int32_supported_combos(n, dtype): device = torch.device("cuda") A = _build_lower_triangular(n, dtype, device) +======= +@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_trans_int32_supported_combos(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) +>>>>>>> 5a83e0f (test) b = _rand_like(dtype, (n,), device) A_ref = A.to(_ref_dtype(dtype)) b_ref = b.to(_ref_dtype(dtype)) @@ -205,12 +260,24 @@ def test_spsv_csr_trans_int32_supported_combos(n, dtype): @pytest.mark.spsv +<<<<<<< HEAD @pytest.mark.skipif(cp is None or cpx_spsolve_triangular is None, reason="CuPy/cuSPARSE required") @pytest.mark.parametrize("n", SPSV_N) @pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype): device = torch.device("cuda") A = _build_lower_triangular(n, dtype, device) +======= +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_non_trans(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) +>>>>>>> 5a83e0f (test) b = _rand_like(dtype, (n,), device) Asp = A.to_sparse_csr() @@ -224,11 +291,187 @@ def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype): ) x_non_ref = _cupy_ref_spsv(A_cp, b, lower=True, unit_diagonal=False) +<<<<<<< HEAD +======= + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_trans(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + +>>>>>>> 5a83e0f (test) x_trans = flagsparse_spsv_csr( data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=True ) x_trans_ref = _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=False, unit_diagonal=False) rtol, atol = _tol(dtype) +<<<<<<< HEAD assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) +======= + assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +def test_spsv_csr_non_trans_upper_supported_combos(n, dtype, index_dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + x_ref = torch.linalg.solve_triangular( + A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=True + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=False, + unit_diagonal=False, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_trans_upper_int32_supported_combos(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + A_ref = A.to(_ref_dtype(dtype)) + b_ref = b.to(_ref_dtype(dtype)) + x_ref = torch.linalg.solve_triangular( + A_ref.transpose(-2, -1), b_ref.unsqueeze(-1), upper=False + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=False, + unit_diagonal=False, + transpose=True, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_upper_non_trans(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_non = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=False, unit_diagonal=False, transpose=False + ) + x_non_ref = _cupy_ref_spsv(A_cp, b, lower=False, unit_diagonal=False) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_upper_trans(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_trans = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=False, unit_diagonal=False, transpose=True + ) + x_trans_ref = _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=True, unit_diagonal=False) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +def test_spsv_coo_transpose_complex128_routes_through_csr(n): + device = torch.device("cuda") + dtype = torch.complex128 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + x_ref = torch.linalg.solve_triangular( + A.transpose(-2, -1), b.unsqueeze(-1), upper=True + ).squeeze(-1) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=True, + coo_mode="auto", + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) +>>>>>>> 5a83e0f (test) diff --git a/tests/test_gather.py b/tests/test_gather.py index 481138a..b00a6c4 100644 --- a/tests/test_gather.py +++ b/tests/test_gather.py @@ -16,7 +16,7 @@ (524_288, 16_384), (1_048_576, 65_536), ] -DEFAULT_VALUE_DTYPES = "float16,bfloat16,float32,float64,complex32,complex64,complex128" +DEFAULT_VALUE_DTYPES = "float16,bfloat16,float32,float64,complex64,complex128" DEFAULT_INDEX_DTYPES = "int32,int64" WARMUP = 20 ITERS = 200 @@ -48,7 +48,6 @@ def _parse_value_dtypes(raw): "bfloat16", "float32", "float64", - "complex32", "complex64", "complex128", } @@ -164,7 +163,7 @@ def _collect_samples(case_id, expected, flagsparse_out, limit): def _dtype_mode(value_dtype_req): - if value_dtype_req in ("float16", "bfloat16", "complex32", "complex64"): + if value_dtype_req in ("float16", "bfloat16", "complex64"): return "gather_cupy" return "gather_triton" @@ -193,8 +192,6 @@ def _build_dense(value_dtype_req, dense_size, device): real = torch.randn(dense_size, dtype=torch.float64, device=device) imag = torch.randn(dense_size, dtype=torch.float64, device=device) return torch.complex(real, imag) - if value_dtype_req == "complex32": - return torch.randn(dense_size, 2, dtype=torch.float16, device=device) raise ValueError(f"Unsupported value dtype request: {value_dtype_req}") @@ -206,13 +203,12 @@ def _effective_dtype_name(value_dtype_req): "float64": "float64", "complex64": "complex64", "complex128": "complex128", - "complex32": "complex16_pair_f16", } return mapping[value_dtype_req] def _tolerance(value_dtype_req): - if value_dtype_req in ("float16", "complex32"): + if value_dtype_req == "float16": return 5e-3, 5e-3 if value_dtype_req in ("bfloat16",): return 1e-2, 1e-2 @@ -230,10 +226,10 @@ def _check_dtype_supported(value_dtype_req): def _is_supported_extra_gather_combo(value_dtype_req, index_dtype): # Required extra gather combos only: - # Half+Int32/Int64, Bfloat16+Int32/Int64, Complex32+Int32/Int64, Complex64+Int32/Int64 + # Half+Int32/Int64, Bfloat16+Int32/Int64, Complex64+Int32/Int64 if value_dtype_req == "float16": return index_dtype in (torch.int32, torch.int64) - if value_dtype_req in ("bfloat16", "complex32", "complex64"): + if value_dtype_req in ("bfloat16", "complex64"): return index_dtype in (torch.int32, torch.int64) # Original gather path dtypes keep original behavior. return True diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 162dfc8..cdc7a31 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -22,8 +22,8 @@ cpx_sparse = None cpx_spsolve_triangular = None -VALUE_DTYPES = [torch.float32, torch.float64] -INDEX_DTYPES = [torch.int32] +VALUE_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +INDEX_DTYPES = [torch.int32, torch.int64] TEST_SIZES = [256, 512, 1024, 2048] WARMUP = 5 ITERS = 20 @@ -44,6 +44,15 @@ CSR_FULL_VALUE_DTYPES.append(torch.complex64) CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64] +# CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) +CSR_FULL_VALUE_DTYPES = [ + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, +] +CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64] + def _dtype_name(dtype): return str(dtype).replace("torch.", "") @@ -64,8 +73,11 @@ def _fmt_err(v): def _tol_for_dtype(dtype): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: return 5e-3, 5e-3 +======= +>>>>>>> 5a83e0f (test) if dtype in (torch.float32, torch.complex64): return 1e-4, 1e-2 return 1e-12, 1e-10 @@ -74,22 +86,30 @@ def _tol_for_dtype(dtype): def _randn_by_dtype(n, dtype, device): if dtype in (torch.float32, torch.float64): return torch.randn(n, dtype=dtype, device=device) +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: pair = torch.randn((n, 2), dtype=torch.float16, device=device) * 0.1 return torch.view_as_complex(pair) base = torch.float32 +======= + base = torch.float32 if dtype == torch.complex64 else torch.float64 +>>>>>>> 5a83e0f (test) real = torch.randn(n, dtype=base, device=device) imag = torch.randn(n, dtype=base, device=device) return torch.complex(real, imag) def _dense_ref_dtype(dtype): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: return torch.complex64 +======= +>>>>>>> 5a83e0f (test) return dtype def _tensor_from_scalar_values(values, dtype, device): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: real = torch.clamp( torch.tensor(values, dtype=torch.float32, device=device), @@ -98,18 +118,24 @@ def _tensor_from_scalar_values(values, dtype, device): ).to(torch.float16) imag = torch.zeros_like(real) return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) +======= +>>>>>>> 5a83e0f (test) return torch.tensor(values, dtype=dtype, device=device) def _safe_cast_tensor(tensor, dtype): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: real = torch.clamp(tensor.real, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) imag = torch.clamp(tensor.imag, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) +======= +>>>>>>> 5a83e0f (test) return tensor.to(dtype) def _cast_real_tensor_to_value_dtype(values, value_dtype): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: real = torch.clamp(values, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) imag = torch.zeros_like(real) @@ -120,12 +146,45 @@ def _cast_real_tensor_to_value_dtype(values, value_dtype): def _cupy_ref_inputs(data, b): if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE: return data.to(torch.complex64), b.to(torch.complex64) +======= + return values.to(value_dtype) + + +def _matrix_market_value(parts, mm_field): + if mm_field == "complex": + if len(parts) < 4: + raise ValueError("MatrixMarket complex entry requires real and imag parts") + return complex(float(parts[2]), float(parts[3])) + if len(parts) >= 3: + return float(parts[2]) + if mm_field == "pattern": + return 1.0 + raise ValueError("MatrixMarket entry is missing a numeric value") + + +def _triangular_solve_reference(A, b, *, lower, op_mode="NON"): + if op_mode == "TRANS": + A_eff = A.transpose(0, 1) + upper = lower + else: + A_eff = A + upper = not lower + return torch.linalg.solve_triangular( + A_eff, b.unsqueeze(1), upper=upper + ).squeeze(1) + + +def _cupy_ref_inputs(data, b): +>>>>>>> 5a83e0f (test) return data, b def _compare_view(tensor, value_dtype): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: return tensor.to(torch.complex64) +======= +>>>>>>> 5a83e0f (test) return tensor @@ -147,7 +206,7 @@ def _allow_dense_pytorch_ref(shape, dtype): def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True): - """Build a well-conditioned triangular CSR (float32/float64).""" + """Build a well-conditioned triangular CSR for real and complex dtypes.""" max_bandwidth = max(4, min(n, 16)) rows_host = [] cols_host = [] @@ -156,10 +215,17 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True base_real_dtype = torch.float32 elif value_dtype == torch.float64: base_real_dtype = torch.float64 +<<<<<<< HEAD elif COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: base_real_dtype = torch.float16 else: base_real_dtype = torch.float32 +======= + elif value_dtype == torch.complex64: + base_real_dtype = torch.float32 + else: + base_real_dtype = torch.float64 +>>>>>>> 5a83e0f (test) for i in range(n): if lower: @@ -176,23 +242,44 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True off_cols = [off_cand[j] for j in perm] else: off_cols = [] - off_vals_real = torch.randn(len(off_cols), dtype=base_real_dtype).mul_(0.01) - sum_abs = float(torch.sum(torch.abs(off_vals_real)).item()) if off_vals_real.numel() else 0.0 - diag_val = sum_abs + 1.0 + if value_dtype in (torch.complex64, torch.complex128): + off_vals = torch.complex( + torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), + torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), + ) + sum_abs = ( + float(torch.sum(torch.abs(off_vals)).item()) if off_vals.numel() else 0.0 + ) + diag_imag = float( + torch.randn((), dtype=base_real_dtype, device=device).mul_(0.05).item() + ) + diag_val = complex(sum_abs + 1.0, diag_imag) + off_vals_host = [complex(v) for v in off_vals.cpu().tolist()] + else: + off_vals = torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01) + sum_abs = ( + float(torch.sum(torch.abs(off_vals)).item()) if off_vals.numel() else 0.0 + ) + diag_val = sum_abs + 1.0 + off_vals_host = off_vals.cpu().tolist() rows_host.append(i) cols_host.append(diag_col) vals_host.append(diag_val) - for c, v in zip(off_cols, off_vals_real.tolist()): + for c, v in zip(off_cols, off_vals_host): rows_host.append(i) cols_host.append(int(c)) vals_host.append(v) rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device) cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device) +<<<<<<< HEAD vals_t = _cast_real_tensor_to_value_dtype( torch.tensor(vals_host, dtype=base_real_dtype, device=device), value_dtype, ) +======= + vals_t = torch.tensor(vals_host, dtype=value_dtype, device=device) +>>>>>>> 5a83e0f (test) order = torch.argsort(rows_t * max(1, n) + cols_t) rows_t = rows_t[order] cols_t = cols_t[order] @@ -257,7 +344,11 @@ def _csr_transpose(data, indices, indptr, shape): return data_t, col_t.to(torch.int64), indptr_t +<<<<<<< HEAD def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): +======= +def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None, lower=True): +>>>>>>> 5a83e0f (test) if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") with open(file_path, "r", encoding="utf-8") as f: @@ -302,28 +393,26 @@ def _accum(r, c, v): continue r = int(parts[0]) - 1 c = int(parts[1]) - 1 - if len(parts) >= 3: - v = float(parts[2]) - elif mm_field == "pattern": - v = 1.0 - else: - continue + v = _matrix_market_value(parts, mm_field) _accum(r, c, v) - if mm_symmetry in ("symmetric", "hermitian") and r != c: + if mm_symmetry == "symmetric" and r != c: _accum(c, r, v) + elif mm_symmetry == "hermitian" and r != c: + _accum(c, r, v.conjugate() if isinstance(v, complex) else v) elif mm_symmetry == "skew-symmetric" and r != c: _accum(c, r, -v) for r in range(n_rows): row = row_maps[r] - lower_row = {} + tri_row = {} off_abs_sum = 0.0 for c, v in row.items(): - if c < r: - lower_row[c] = lower_row.get(c, 0.0) + v + keep = c < r if lower else c > r + if keep: + tri_row[c] = tri_row.get(c, 0.0) + v off_abs_sum += abs(v) - lower_row[r] = off_abs_sum + 1.0 - row_maps[r] = lower_row + tri_row[r] = off_abs_sum + 1.0 + row_maps[r] = tri_row cols_s = [] vals_s = [] @@ -360,6 +449,7 @@ def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode): def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE: data_ref = data.to(torch.complex64) x_ref = x_true.to(torch.complex64) @@ -380,6 +470,8 @@ def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): ) return _safe_cast_tensor(b_ref, x_true.dtype) raise ValueError("op_mode must be 'NON' or 'TRANS'") +======= +>>>>>>> 5a83e0f (test) if op_mode == "NON": b, _ = fs.flagsparse_spmv_csr( data, indices, indptr, x_true, shape, return_time=True @@ -408,6 +500,7 @@ def _cupy_spsolve_lower_csr_or_coo( b, warmup, iters, + lower, ): """Triangular solve via CuPy: CSR or COO storage. Returns (ms, x_torch) or (None, None).""" if ( @@ -439,7 +532,7 @@ def _cupy_spsolve_lower_csr_or_coo( A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) for _ in range(warmup): _ = cpx_spsolve_triangular( - A_cp, b_cp, lower=True, unit_diagonal=False + A_cp, b_cp, lower=lower, unit_diagonal=False ) cp.cuda.runtime.deviceSynchronize() t0 = cp.cuda.Event() @@ -447,22 +540,30 @@ def _cupy_spsolve_lower_csr_or_coo( t0.record() for _ in range(iters): x_cu = cpx_spsolve_triangular( - A_cp, b_cp, lower=True, unit_diagonal=False + A_cp, b_cp, lower=lower, unit_diagonal=False ) t1.record() t1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(t0, t1) / iters x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) +<<<<<<< HEAD if COMPLEX32_DTYPE is not None and b.dtype == COMPLEX32_DTYPE: x_cu_t = x_cu_t.to(torch.complex64) else: x_cu_t = x_cu_t.to(b.dtype) +======= + x_cu_t = x_cu_t.to(b.dtype) +>>>>>>> 5a83e0f (test) return cupy_ms, x_cu_t except Exception: return None, None +<<<<<<< HEAD def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode): +======= +def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): +>>>>>>> 5a83e0f (test) if ( cp is None or cpx_sparse is None @@ -482,10 +583,17 @@ def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode): A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) if op_mode == "TRANS": A_eff = A_cp.transpose().tocsr() +<<<<<<< HEAD lower_eff = False else: A_eff = A_cp lower_eff = True +======= + lower_eff = not lower + else: + A_eff = A_cp + lower_eff = lower +>>>>>>> 5a83e0f (test) for _ in range(WARMUP): _ = cpx_spsolve_triangular( @@ -508,7 +616,11 @@ def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode): return None, None +<<<<<<< HEAD def run_spsv_synthetic_all(): +======= +def run_spsv_synthetic_all(lower=True): +>>>>>>> 5a83e0f (test) if not torch.cuda.is_available(): print("CUDA is not available. Please run on a GPU-enabled system.") return @@ -519,10 +631,11 @@ def run_spsv_synthetic_all(): print(sep) print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"Warmup: {WARMUP} | Iters: {ITERS}") + print(f"Triangle: {'LOWER' if lower else 'UPPER'}") print() hdr = ( - f"{'Fmt':>5} {'N':>6} {'FlagSparse(ms)':>14} {'PyTorch(ms)':>12} {'CuPy(ms)':>10} " + f"{'Fmt':>5} {'opA':>5} {'N':>6} {'FlagSparse(ms)':>14} {'PyTorch(ms)':>12} {'CuPy(ms)':>10} " f"{'FS/PT':>8} {'FS/CU':>8} {'Status':>8} {'Err(PT)':>12} {'Err(CU)':>12}" ) @@ -540,101 +653,121 @@ def run_spsv_synthetic_all(): print("-" * 110) for n in TEST_SIZES: for fmt in ("CSR", "COO"): - data, indices, indptr, shape = _build_random_triangular_csr( - n, value_dtype, index_dtype, device, lower=True + op_modes = ( + _supported_csr_full_ops(value_dtype, index_dtype) + if fmt == "CSR" + else ["NON"] ) - A_dense = _csr_to_dense( - data, indices.to(torch.int64), indptr, shape - ) - x_true = torch.randn(n, dtype=value_dtype, device=device) - b = A_dense @ x_true - - torch.cuda.synchronize() - if fmt == "CSR": - x, t_ms = fs.flagsparse_spsv_csr( - data, - indices, - indptr, - b, - shape, - lower=True, - return_time=True, + for op_mode in op_modes: + data, indices, indptr, shape = _build_random_triangular_csr( + n, value_dtype, index_dtype, device, lower=lower ) - else: - dc, rr, cc = _csr_to_coo(data, indices, indptr, shape) - x, t_ms = fs.flagsparse_spsv_coo( - dc, - rr, - cc, - b, - shape, - lower=True, - coo_mode="auto", - return_time=True, + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr, shape ) - torch.cuda.synchronize() - - A_ref = A_dense - b_ref = b - torch.cuda.synchronize() - e0 = torch.cuda.Event(True) - e1 = torch.cuda.Event(True) - e0.record() - x_pt = torch.linalg.solve( - A_ref, b_ref.unsqueeze(1) - ).squeeze(1) - e1.record() - torch.cuda.synchronize() - pytorch_ms = e0.elapsed_time(e1) - err_pt = float(torch.max(torch.abs(x - x_pt)).item()) if n > 0 else 0.0 - - cupy_ms = None - err_cu = None - x_cu_t = None - if value_dtype in (torch.float32, torch.float64): - cupy_ms, x_cu_t = _cupy_spsolve_lower_csr_or_coo( - fmt, - data, - indices, - indptr, - shape, - b, - WARMUP, - ITERS, + x_true = _randn_by_dtype(n, value_dtype, device) + if fmt == "CSR": + b = _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode) + else: + b = A_dense @ x_true + + torch.cuda.synchronize() + if fmt == "CSR": + x, t_ms = fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=(op_mode == "TRANS"), + return_time=True, + ) + else: + dc, rr, cc = _csr_to_coo(data, indices, indptr, shape) + x, t_ms = fs.flagsparse_spsv_coo( + dc, + rr, + cc, + b, + shape, + lower=lower, + coo_mode="auto", + return_time=True, + ) + torch.cuda.synchronize() + + A_ref = A_dense + b_ref = b + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + x_pt = _triangular_solve_reference( + A_ref, b_ref, lower=lower, op_mode=op_mode ) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = float(torch.max(torch.abs(x - x_pt)).item()) if n > 0 else 0.0 + + cupy_ms = None + err_cu = None + x_cu_t = None + if fmt == "CSR": + cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( + data, indices, indptr, shape, b, op_mode, lower + ) + elif value_dtype in ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ): + cupy_ms, x_cu_t = _cupy_spsolve_lower_csr_or_coo( + fmt, + data, + indices, + indptr, + shape, + b, + WARMUP, + ITERS, + lower, + ) if x_cu_t is not None and n > 0: err_cu = float( torch.max(torch.abs(x - x_cu_t)).item() ) - atol, rtol = _tol_for_dtype(value_dtype) - ok_pt = torch.allclose(x, x_pt, atol=atol, rtol=rtol) - ok_cu = ( - True - if x_cu_t is None - else torch.allclose(x, x_cu_t, atol=atol, rtol=rtol) - ) - ok = ok_pt or ok_cu - status = "PASS" if ok else "FAIL" - if not ok: - failed += 1 - total += 1 - - fs_vs_pt = ( - (pytorch_ms / t_ms) if (t_ms and t_ms > 0) else None - ) - fs_vs_cu = ( - (cupy_ms / t_ms) - if (cupy_ms is not None and t_ms and t_ms > 0) - else None - ) - print( - f"{fmt:>5} {n:>6} {_fmt_ms(t_ms):>14} {_fmt_ms(pytorch_ms):>12} " - f"{_fmt_ms(cupy_ms):>10} " - f"{(f'{fs_vs_pt:.2f}x' if fs_vs_pt is not None else 'N/A'):>8} " - f"{(f'{fs_vs_cu:.2f}x' if fs_vs_cu is not None else 'N/A'):>8} " - f"{status:>8} {_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}" - ) + atol, rtol = _tol_for_dtype(value_dtype) + ok_pt = torch.allclose(x, x_pt, atol=atol, rtol=rtol) + ok_cu = ( + True + if x_cu_t is None + else torch.allclose(x, x_cu_t, atol=atol, rtol=rtol) + ) + ok = ok_pt or ok_cu + status = "PASS" if ok else "FAIL" + if not ok: + failed += 1 + total += 1 + + fs_vs_pt = ( + (pytorch_ms / t_ms) if (t_ms and t_ms > 0) else None + ) + fs_vs_cu = ( + (cupy_ms / t_ms) + if (cupy_ms is not None and t_ms and t_ms > 0) + else None + ) + print( + f"{fmt:>5} {op_mode:>5} {n:>6} {_fmt_ms(t_ms):>14} {_fmt_ms(pytorch_ms):>12} " + f"{_fmt_ms(cupy_ms):>10} " + f"{(f'{fs_vs_pt:.2f}x' if fs_vs_pt is not None else 'N/A'):>8} " + f"{(f'{fs_vs_cu:.2f}x' if fs_vs_cu is not None else 'N/A'):>8} " + f"{status:>8} {_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}" + ) print("-" * 110) print() @@ -643,32 +776,32 @@ def run_spsv_synthetic_all(): print(sep) -def _run_one_csv_row_csr(path, value_dtype, index_dtype, device): +def _run_one_csv_row_csr(path, value_dtype, index_dtype, device, lower=True): data, indices, indptr, shape = _load_mtx_to_csr_torch( - path, dtype=value_dtype, device=device + path, dtype=value_dtype, device=device, lower=lower ) indices = indices.to(index_dtype) n_rows, n_cols = shape - x_true = torch.randn(n_rows, dtype=value_dtype, device=device) + x_true = _randn_by_dtype(n_rows, value_dtype, device) b, _ = fs.flagsparse_spmv_csr( data, indices, indptr, x_true, shape, return_time=True ) x, t_ms = fs.flagsparse_spsv_csr( - data, indices, indptr, b, shape, lower=True, return_time=True + data, indices, indptr, b, shape, lower=lower, return_time=True ) return _finalize_csv_row( path, value_dtype, index_dtype, data, indices, indptr, shape, - x, t_ms, b, n_rows, n_cols, + x, t_ms, b, n_rows, n_cols, lower=lower, ) -def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode): +def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower=True): data, indices, indptr, shape = _load_mtx_to_csr_torch( - path, dtype=value_dtype, device=device + path, dtype=value_dtype, device=device, lower=lower ) indices = indices.to(index_dtype) n_rows, n_cols = shape - x_true = torch.randn(n_rows, dtype=value_dtype, device=device) + x_true = _randn_by_dtype(n_rows, value_dtype, device) b, _ = fs.flagsparse_spmv_csr( data, indices, indptr, x_true, shape, return_time=True ) @@ -681,7 +814,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode): c_in, b, shape, - lower=True, + lower=lower, coo_mode=coo_mode, return_time=True, ) @@ -698,6 +831,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode): b, n_rows, n_cols, + lower=lower, nnz_display=int(d_in.numel()), cupy_coo_data=d_in, cupy_coo_row=r_in, @@ -719,6 +853,7 @@ def _finalize_csv_row( n_rows, n_cols, *, + lower=True, nnz_display=None, cupy_coo_data=None, cupy_coo_row=None, @@ -741,7 +876,13 @@ def _finalize_csv_row( e1 = torch.cuda.Event(True) torch.cuda.synchronize() e0.record() +<<<<<<< HEAD x_ref = torch.linalg.solve(A_ref, b_ref.unsqueeze(1)).squeeze(1) +======= + x_ref = _triangular_solve_reference( + A_ref, b_ref, lower=lower, op_mode="NON" + ) +>>>>>>> 5a83e0f (test) x_cmp = _compare_view(x, value_dtype) x_ref_cmp = _compare_view(x_ref, value_dtype) e1.record() @@ -814,7 +955,7 @@ def _finalize_csv_row( ) for _ in range(WARMUP): _ = cpx_spsolve_triangular( - A_cp, b_cp, lower=True, unit_diagonal=False + A_cp, b_cp, lower=lower, unit_diagonal=False ) cp.cuda.runtime.deviceSynchronize() c0 = cp.cuda.Event() @@ -822,7 +963,7 @@ def _finalize_csv_row( c0.record() for _ in range(ITERS): x_cu = cpx_spsolve_triangular( - A_cp, b_cp, lower=True, unit_diagonal=False + A_cp, b_cp, lower=lower, unit_diagonal=False ) c1.record() c1.synchronize() @@ -865,9 +1006,15 @@ def _finalize_csv_row( return row, pt_skip_reason +<<<<<<< HEAD def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device): data, indices, indptr, shape = _load_mtx_to_csr_torch( path, dtype=value_dtype, device=device +======= +def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, lower=True): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device, lower=lower +>>>>>>> 5a83e0f (test) ) indices = indices.to(index_dtype) indptr = indptr.to(index_dtype) @@ -880,7 +1027,11 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device): indptr, b, shape, +<<<<<<< HEAD lower=True, +======= + lower=lower, +>>>>>>> 5a83e0f (test) transpose=(op_mode == "TRANS"), return_time=True, ) @@ -898,6 +1049,10 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device): b, n_rows, n_cols, +<<<<<<< HEAD +======= + lower=lower, +>>>>>>> 5a83e0f (test) ) @@ -915,6 +1070,10 @@ def _finalize_csv_row_csr_full( b, n_rows, n_cols, +<<<<<<< HEAD +======= + lower=True, +>>>>>>> 5a83e0f (test) ): atol, rtol = _tol_for_dtype(value_dtype) @@ -927,12 +1086,21 @@ def _finalize_csv_row_csr_full( A_dense = _csr_to_dense( data, indices.to(torch.int64), indptr.to(torch.int64), shape ).to(_dense_ref_dtype(value_dtype)) +<<<<<<< HEAD A_ref = A_dense.transpose(0, 1) if op_mode == "TRANS" else A_dense +======= +>>>>>>> 5a83e0f (test) e0 = torch.cuda.Event(True) e1 = torch.cuda.Event(True) torch.cuda.synchronize() e0.record() +<<<<<<< HEAD x_ref = torch.linalg.solve(A_ref, b.to(A_ref.dtype).unsqueeze(1)).squeeze(1) +======= + x_ref = _triangular_solve_reference( + A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode + ) +>>>>>>> 5a83e0f (test) x_cmp = _compare_view(x, value_dtype) x_ref_cmp = _compare_view(x_ref, value_dtype) e1.record() @@ -959,7 +1127,11 @@ def _finalize_csv_row_csr_full( ok_cu = False x_cu_t = None cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( +<<<<<<< HEAD data, indices, indptr, shape, b, op_mode +======= + data, indices, indptr, shape, b, op_mode, lower +>>>>>>> 5a83e0f (test) ) if x_cu_t is not None: x_cmp = _compare_view(x, value_dtype) @@ -994,6 +1166,7 @@ def _finalize_csv_row_csr_full( return row, pt_skip_reason +<<<<<<< HEAD def run_all_supported_spsv_csr_csv(mtx_paths, csv_path): if not torch.cuda.is_available(): print("CUDA is not available.") @@ -1103,11 +1276,125 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path): def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto"): +======= +def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): +>>>>>>> 5a83e0f (test) if not torch.cuda.is_available(): print("CUDA is not available.") return device = torch.device("cuda") rows_out = [] + for value_dtype in CSR_FULL_VALUE_DTYPES: + for index_dtype in CSR_FULL_INDEX_DTYPES: + op_modes = _supported_csr_full_ops(value_dtype, index_dtype) + for op_mode in op_modes: + print("=" * 150) + print( + f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}" + ) + print( + "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." + ) + print( + "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " + "PASS if either error within tolerance." + ) + print("-" * 150) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " + f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + ) + print("-" * 150) + for path in mtx_paths: + try: + row, pt_skip = _run_one_csv_row_csr_full( + path, value_dtype, index_dtype, op_mode, device, lower=lower + ) + rows_out.append(row) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + n_rows, n_cols = row["n_rows"], row["n_cols"] + nnz = row["nnz"] + t_ms = row["triton_ms"] + cupy_ms = row["cusparse_ms"] + pytorch_ms = row["pytorch_ms"] + err_pt, err_cu = row["err_pt"], row["err_cu"] + status = row["status"] + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " + f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " + f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " + f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + ) + if pt_skip: + print(f" NOTE: {pt_skip}") + except Exception as e: + err_msg = str(e) + status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" + rows_out.append( + { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": "ERR", + "n_cols": "ERR", + "nnz": "ERR", + "triton_ms": None, + "pytorch_ms": None, + "cusparse_ms": None, + "csc_ms": None, + "status": status, + "err_pt": None, + "err_cu": None, + } + ) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + print( + f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " + f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " + f"{'N/A':>7} {'N/A':>7} " + f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" {status}: {e}") + print("-" * 150) + fieldnames = [ + "matrix", + "value_dtype", + "index_dtype", + "opA", + "n_rows", + "n_cols", + "nnz", + "triton_ms", + "pytorch_ms", + "cusparse_ms", + "csc_ms", + "status", + "err_pt", + "err_cu", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows_out: + w.writerow(r) + print(f"Wrote {len(rows_out)} rows to {csv_path}") + + +def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", lower=True): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + if not use_coo: + run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=lower) + return + device = torch.device("cuda") + rows_out = [] label = "COO" if use_coo else "CSR" cu_col = "COO(ms)" if use_coo else "CSR(ms)" fs_cu_hdr = "FS/COO" if use_coo else "FS/CSR" @@ -1117,7 +1404,7 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto") print("=" * 150) print( f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | {label}" - + (f" coo_mode={coo_mode}" if use_coo else "") + + (f" triA={'LOWER' if lower else 'UPPER'}" if not use_coo else f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}") ) if use_coo: print( @@ -1143,11 +1430,11 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto") try: if use_coo: row, pt_skip = _run_one_csv_row_coo( - path, value_dtype, index_dtype, device, coo_mode + path, value_dtype, index_dtype, device, coo_mode, lower=lower ) else: row, pt_skip = _run_one_csv_row_csr( - path, value_dtype, index_dtype, device + path, value_dtype, index_dtype, device, lower=lower ) rows_out.append(row) name = os.path.basename(path)[:27] @@ -1254,10 +1541,16 @@ def main(): choices=["auto", "direct", "csr"], help="COO mode for --csv-coo (default: auto)", ) + parser.add_argument( + "--upper", + action="store_true", + help="Use upper-triangular inputs instead of the default lower-triangular inputs", + ) args = parser.parse_args() + lower = not args.upper if args.synthetic: - run_spsv_synthetic_all() + run_spsv_synthetic_all(lower=lower) return paths = [] @@ -1272,7 +1565,11 @@ def main(): if not paths: print("No .mtx files found for --csv-csr") return +<<<<<<< HEAD run_all_supported_spsv_csr_csv(paths, args.csv_csr) +======= + run_all_supported_spsv_csr_csv(paths, args.csv_csr, lower=lower) +>>>>>>> 5a83e0f (test) return if args.csv_coo: if not paths: @@ -1281,7 +1578,7 @@ def main(): print("No .mtx files found for --csv-coo") return run_all_dtypes_spsv_csv( - paths, args.csv_coo, use_coo=True, coo_mode=args.coo_mode + paths, args.csv_coo, use_coo=True, coo_mode=args.coo_mode, lower=lower ) return From f2b0b5bf2385295fd73b6629bb9e3d5963b52c11 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 20 Apr 2026 20:56:03 +0800 Subject: [PATCH 3/5] Support CONJ transpose mode in spsv --- src/flagsparse/sparse_operations/spsv.py | 253 +++++------- tests/pytest/test_spsv_csr_accuracy.py | 210 +++++----- tests/test_scatter.py | 1 - tests/test_spsv.py | 491 +++++++++-------------- 4 files changed, 379 insertions(+), 576 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index dd867bc..5f6174c 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -17,23 +17,25 @@ ) SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64) -SPSV_NON_TRANS_PRIMARY_COMBOS = ( +SPSV_NON_TRANS_SUPPORTED_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), (torch.complex64, torch.int32), (torch.complex128, torch.int32), -) -SPSV_NON_TRANS_EXTENDED_COMBOS = ( (torch.float32, torch.int64), (torch.float64, torch.int64), (torch.complex64, torch.int64), (torch.complex128, torch.int64), ) -SPSV_TRANS_PRIMARY_COMBOS = ( +SPSV_TRANS_SUPPORTED_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), (torch.complex64, torch.int32), (torch.complex128, torch.int32), + (torch.float32, torch.int64), + (torch.float64, torch.int64), + (torch.complex64, torch.int64), + (torch.complex128, torch.int64), ) SPSV_PROMOTE_FP32_TO_FP64 = str( os.environ.get("FLAGSPARSE_SPSV_PROMOTE_FP32_TO_FP64", "0") @@ -64,33 +66,24 @@ def _csr_to_dense(data, indices, indptr, shape): def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): """Validate NON_TRANS support matrix and keep error messages explicit.""" - if (data_dtype, index_dtype) in SPSV_NON_TRANS_PRIMARY_COMBOS: - return - if (data_dtype, index_dtype) in SPSV_NON_TRANS_EXTENDED_COMBOS: + if (data_dtype, index_dtype) in SPSV_NON_TRANS_SUPPORTED_COMBOS: return if data_dtype == torch.bfloat16 and index_dtype == torch.int32: return raise TypeError( f"{fmt_name} SpSV currently supports NON_TRANS combinations: " "(float32, int32/int64), (float64, int32/int64), " -<<<<<<< HEAD - "(complex32, int32/int64), (complex64, int32/int64), (bfloat16, int32)" -======= "(complex64, int32/int64), (complex128, int32/int64), (bfloat16, int32)" ->>>>>>> 5a83e0f (test) ) def _validate_spsv_trans_combo(data_dtype, index_dtype, fmt_name): - if (data_dtype, index_dtype) in SPSV_TRANS_PRIMARY_COMBOS: + if (data_dtype, index_dtype) in SPSV_TRANS_SUPPORTED_COMBOS: return raise TypeError( - f"{fmt_name} SpSV currently supports TRANS combinations with int32 indices only: " -<<<<<<< HEAD - "(float32, int32), (float64, int32), (complex32, int32), (complex64, int32)" -======= - "(float32, int32), (float64, int32), (complex64, int32), (complex128, int32)" ->>>>>>> 5a83e0f (test) + f"{fmt_name} SpSV currently supports TRANS/CONJ combinations: " + "(float32, int32/int64), (float64, int32/int64), " + "(complex64, int32/int64), (complex128, int32/int64)" ) @@ -102,8 +95,11 @@ def _normalize_spsv_transpose_mode(transpose): return "N" if token in ("T", "TRANS"): return "T" + if token in ("C", "H", "CONJ", "CONJ_TRANS", "CONJUGATE_TRANSPOSE"): + return "C" raise ValueError( - "transpose must be bool or one of: N/NON/NON_TRANS, T/TRANS" + "transpose must be bool or one of: " + "N/NON/NON_TRANS, T/TRANS, C/H/CONJ/CONJ_TRANS/CONJUGATE_TRANSPOSE" ) @@ -130,11 +126,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise TypeError( -<<<<<<< HEAD - "data dtype must be one of: bfloat16, float32, float64, complex32, complex64" -======= "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" ->>>>>>> 5a83e0f (test) ) if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("indices dtype must be torch.int32 or torch.int64") @@ -174,23 +166,6 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): ) -<<<<<<< HEAD -def _promote_complex32_spsv_inputs(data, b): - if _is_complex32_dtype(data.dtype): - return data.to(torch.complex64), b.to(torch.complex64), data.dtype - return data, b, None - - -def _restore_complex32_spsv_output(x, target_dtype): - if _is_complex32_dtype(target_dtype): - limit = 65504.0 - real = torch.clamp(x.real, min=-limit, max=limit).to(torch.float16) - imag = torch.clamp(x.imag, min=-limit, max=limit).to(torch.float16) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) - return x.to(target_dtype) - - -======= def _prepare_spsv_working_inputs(data, b): return data, b, None @@ -245,7 +220,40 @@ def _csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode): ) ->>>>>>> 5a83e0f (test) +def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, trans_mode): + """Prepare an equivalent CSR triangular system before launching the solve kernel. + + Keep TRANS/CONJ handling outside the Triton solve kernels so the kernels only + execute one fixed CSR triangular solve semantics. + """ + if trans_mode == "N": + kernel_data = data + kernel_indices64 = indices64 + kernel_indptr64 = indptr64 + lower_eff = lower + else: + kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( + data, + indices64, + indptr64, + n_rows, + n_cols, + conjugate=(trans_mode == "C"), + ) + lower_eff = not lower + + levels = _build_spsv_levels( + kernel_indptr64, kernel_indices64, n_rows, lower=lower_eff + ) + return ( + kernel_data, + kernel_indices64, + kernel_indptr64, + lower_eff, + levels, + ) + + @triton.jit def _spsv_csr_level_kernel( data_ptr, @@ -396,6 +404,7 @@ def _spsv_csr_level_kernel_complex( tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + @triton.jit def _spsv_coo_level_kernel_real( data_ptr, @@ -609,9 +618,6 @@ def _triton_spsv_csr_vector_complex( indptr, block_nnz=block_nnz, max_segments=max_segments ) -<<<<<<< HEAD - data_ri = torch.view_as_real(data.contiguous()).reshape(-1).contiguous() -======= # Some PyTorch builds return CSR values with a non-strided layout wrapper. # Materialize a plain 1D strided buffer before splitting into real/imag parts. if data.layout != torch.strided: @@ -621,7 +627,6 @@ def _triton_spsv_csr_vector_complex( data_strided = data.contiguous() data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() ->>>>>>> 5a83e0f (test) b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 @@ -656,6 +661,22 @@ def _triton_spsv_csr_vector_complex( return x +def _choose_transpose_family_launch_config(indptr, block_nnz=None, max_segments=None): + if block_nnz is not None or max_segments is not None: + return _auto_spsv_launch_config(indptr, block_nnz=block_nnz, max_segments=max_segments) + + if indptr.numel() <= 1: + return 32, 1 + max_nnz_per_row = int((indptr[1:] - indptr[:-1]).max().item()) + for cand in (32, 64, 128, 256, 512, 1024): + req = max((max_nnz_per_row + cand - 1) // cand, 1) + if req <= 2048: + return cand, req + cand = 2048 + req = max((max_nnz_per_row + cand - 1) // cand, 1) + return cand, req + + def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if not all(torch.is_tensor(t) for t in (data, row, col, b)): raise TypeError("data, row, col, b must all be torch.Tensor") @@ -674,10 +695,6 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if b.ndim == 2 and b.shape[0] != n_rows: raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") -<<<<<<< HEAD - if data.dtype not in (torch.bfloat16, torch.float32, torch.float64): - raise TypeError("data dtype must be one of: bfloat16, float32, float64") -======= if data.dtype not in ( torch.bfloat16, torch.float32, @@ -688,7 +705,6 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): raise TypeError( "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" ) ->>>>>>> 5a83e0f (test) if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") if row.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: @@ -724,7 +740,7 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): ) -def _csr_transpose(data, indices64, indptr64, n_rows, n_cols): +def _csr_transpose(data, indices64, indptr64, n_rows, n_cols, conjugate=False): if data.numel() == 0: out_data = data out_indices = torch.empty(0, dtype=torch.int64, device=data.device) @@ -737,31 +753,13 @@ def _csr_transpose(data, indices64, indptr64, n_rows, n_cols): ) new_row = indices64 new_col = row_ids + data_eff = data.conj() if conjugate and torch.is_complex(data) else data data_t, indices_t, indptr_t = _coo_to_csr_sorted_unique( - data, new_row, new_col, n_cols, n_rows + data_eff, new_row, new_col, n_cols, n_rows ) return data_t, indices_t, indptr_t -def _csr_reverse_rows_cols(data, indices64, indptr64, n_rows): - if data.numel() == 0: - out_data = data - out_indices = torch.empty(0, dtype=torch.int64, device=data.device) - out_indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) - return out_data, out_indices, out_indptr - - row_ids = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), - indptr64[1:] - indptr64[:-1], - ) - new_row = (n_rows - 1) - row_ids - new_col = (n_rows - 1) - indices64 - data_r, indices_r, indptr_r = _coo_to_csr_sorted_unique( - data, new_row, new_col, n_rows, n_rows - ) - return data_r, indices_r, indptr_r - - def _coo_is_sorted_unique(row64, col64, n_cols): nnz = row64.numel() if nnz <= 1: @@ -877,47 +875,23 @@ def flagsparse_spsv_csr( """Sparse triangular solve using Triton level-scheduling kernels. Primary support matrix: -<<<<<<< HEAD - - NON_TRANS: float32/float64/complex32/complex64 with int32/int64 indices - - TRANS: float32/float64/complex32/complex64 with int32 indices - - bfloat16 remains NON_TRANS + int32 - """ -======= - NON_TRANS: float32/float64/complex64/complex128 with int32/int64 indices - - TRANS: float32/float64/complex64/complex128 with int32 indices + - TRANS/CONJ: float32/float64/complex64/complex128 with int32/int64 indices - bfloat16 remains NON_TRANS + int32 """ input_data = data input_indices = indices input_indptr = indptr ->>>>>>> 5a83e0f (test) trans_mode = _normalize_spsv_transpose_mode(transpose) data, input_index_dtype, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( data, indices, indptr, b, shape ) original_output_dtype = None -<<<<<<< HEAD - data, b, original_output_dtype = _promote_complex32_spsv_inputs(data, b) -======= - rev_perm = None data, b, original_output_dtype = _prepare_spsv_working_inputs(data, b) ->>>>>>> 5a83e0f (test) if n_rows != n_cols: raise ValueError(f"A must be square, got shape={shape}") if trans_mode == "N": _validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "CSR") -<<<<<<< HEAD - lower_eff = lower - kernel_data = data - kernel_indices64 = indices - kernel_indptr64 = indptr - else: - _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") - lower_eff = not lower - kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( - data, indices, indptr, n_rows, n_cols - ) -======= else: _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") @@ -926,28 +900,14 @@ def flagsparse_spsv_csr( ) cached = _spsv_cache_get(_SPSV_CSR_PREPROCESS_CACHE, preprocess_key) if cached is None: - if trans_mode == "N": - lower_eff = lower - kernel_data = data - kernel_indices64 = indices - kernel_indptr64 = indptr - rev_perm = None - else: - lower_eff = not lower - kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( - data, indices, indptr, n_rows, n_cols - ) - rev_perm = None - levels = _build_spsv_levels( - kernel_indptr64, kernel_indices64, n_rows, lower=lower_eff - ) - cached = ( - kernel_data, - kernel_indices64, - kernel_indptr64, - rev_perm, - lower_eff, - levels, + cached = _prepare_spsv_csr_system( + data, + indices, + indptr, + n_rows, + n_cols, + lower, + trans_mode, ) _spsv_cache_put( _SPSV_CSR_PREPROCESS_CACHE, @@ -955,8 +915,7 @@ def flagsparse_spsv_csr( cached, _SPSV_CSR_PREPROCESS_CACHE_SIZE, ) - kernel_data, kernel_indices64, kernel_indptr64, rev_perm, lower_eff, levels = cached ->>>>>>> 5a83e0f (test) + kernel_data, kernel_indices64, kernel_indptr64, lower_eff, levels = cached kernel_indices = ( kernel_indices64.to(torch.int32) @@ -971,39 +930,37 @@ def flagsparse_spsv_csr( compute_dtype = torch.float32 data_in = kernel_data.to(torch.float32) b_in = b.to(torch.float32) - elif data.dtype == torch.complex64 and trans_mode == "T": + elif data.dtype == torch.complex64 and trans_mode in ("T", "C"): compute_dtype = torch.complex128 data_in = kernel_data.to(torch.complex128) b_in = b.to(torch.complex128) elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: - # Optional high-precision mode; disabled by default for throughput. - compute_dtype = torch.float64 - data_in = kernel_data.to(torch.float64) -<<<<<<< HEAD - b_in = b.to(torch.float64) - elif data.dtype == torch.float32 and trans_mode == "T": compute_dtype = torch.float64 data_in = kernel_data.to(torch.float64) b_in = b.to(torch.float64) - levels = _build_spsv_levels( - kernel_indptr, kernel_indices, n_rows, lower=lower_eff - ) -======= - b_in = b.to(torch.float64) - elif data.dtype == torch.float32 and trans_mode == "T": + elif data.dtype == torch.float32 and trans_mode in ("T", "C"): compute_dtype = torch.float64 data_in = kernel_data.to(torch.float64) b_in = b.to(torch.float64) ->>>>>>> 5a83e0f (test) - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - kernel_indptr, block_nnz=block_nnz, max_segments=max_segments - ) + + is_transpose_family_op = trans_mode != "N" + if is_transpose_family_op: + block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( + kernel_indptr, block_nnz=block_nnz, max_segments=max_segments + ) + else: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + kernel_indptr, block_nnz=block_nnz, max_segments=max_segments + ) diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) + vec_real = _triton_spsv_csr_vector + vec_complex = _triton_spsv_csr_vector_complex + torch.cuda.synchronize() t0 = time.perf_counter() if b_in.ndim == 1: if torch.is_complex(data_in): - x = _triton_spsv_csr_vector_complex( + x = vec_complex( data_in, kernel_indices, kernel_indptr, @@ -1019,7 +976,7 @@ def flagsparse_spsv_csr( max_segments_use=max_segments_use, ) else: - x = _triton_spsv_csr_vector( + x = vec_real( data_in, kernel_indices, kernel_indptr, @@ -1040,7 +997,7 @@ def flagsparse_spsv_csr( bj = b_in[:, j].contiguous() if torch.is_complex(data_in): cols.append( - _triton_spsv_csr_vector_complex( + vec_complex( data_in, kernel_indices, kernel_indptr, @@ -1058,7 +1015,7 @@ def flagsparse_spsv_csr( ) else: cols.append( - _triton_spsv_csr_vector( + vec_real( data_in, kernel_indices, kernel_indptr, @@ -1077,11 +1034,7 @@ def flagsparse_spsv_csr( x = torch.stack(cols, dim=1) target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype if x.dtype != target_dtype: -<<<<<<< HEAD - x = _restore_complex32_spsv_output(x, target_dtype) -======= x = _restore_spsv_output(x, target_dtype) ->>>>>>> 5a83e0f (test) torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: @@ -1096,6 +1049,7 @@ def flagsparse_spsv_csr( def flagsparse_spsv_coo( + data, row, col, @@ -1117,7 +1071,7 @@ def flagsparse_spsv_coo( Notes: - direct mode currently supports only non-transposed real-valued inputs - - complex dtypes and transpose=True always route through the CSR implementation + - complex dtypes and TRANS/CONJ always route through the CSR implementation """ data, row64, col64, b, n_rows, n_cols = _prepare_spsv_coo_inputs( data, row, col, b, shape, transpose=transpose @@ -1130,12 +1084,13 @@ def flagsparse_spsv_coo( raise ValueError("coo_mode must be one of: 'auto', 'direct', 'csr'") sorted_unique = _coo_is_sorted_unique(row64, col64, n_cols) - direct_supported = (not transpose) and (not torch.is_complex(data)) + trans_mode = _normalize_spsv_transpose_mode(transpose) + direct_supported = (trans_mode == "N") and (not torch.is_complex(data)) use_direct = direct_supported and (mode == "direct" or (mode == "auto" and sorted_unique)) if mode == "direct" and not direct_supported: raise ValueError( "coo_mode='direct' supports only non-transposed real-valued inputs; " - "use coo_mode='csr' or 'auto' for transpose or complex dtypes" + "use coo_mode='csr' or 'auto' for TRANS/CONJ or complex dtypes" ) if mode == "direct" and not sorted_unique: raise ValueError( @@ -1147,8 +1102,6 @@ def flagsparse_spsv_coo( data_csr, indices_csr, indptr_csr = _coo_to_csr_sorted_unique( data, row64, col64, n_rows, n_cols ) - if transpose: - indices_csr = indices_csr.to(torch.int32) return flagsparse_spsv_csr( data_csr, indices_csr, diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py index d44d7b5..fc9c891 100644 --- a/tests/pytest/test_spsv_csr_accuracy.py +++ b/tests/pytest/test_spsv_csr_accuracy.py @@ -16,24 +16,12 @@ pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -<<<<<<< HEAD -COMPLEX32_DTYPE = getattr(torch, "complex32", None) -if COMPLEX32_DTYPE is None: - COMPLEX32_DTYPE = getattr(torch, "chalf", None) - -SUPPORTED_COMPLEX_DTYPES = [] -if COMPLEX32_DTYPE is not None: - SUPPORTED_COMPLEX_DTYPES.append(COMPLEX32_DTYPE) -SUPPORTED_COMPLEX_DTYPES.append(torch.complex64) - -SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] -======= SUPPORTED_COMPLEX_DTYPES = [torch.complex64, torch.complex128] SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] NON_TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] -TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] ->>>>>>> 5a83e0f (test) +TRANS_CONJ_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +TRANS_CONJ_MODES = ["TRANS", "CONJ"] def _dtype_id(dtype): @@ -41,11 +29,6 @@ def _dtype_id(dtype): def _tol(dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return 5e-3, 5e-3 -======= ->>>>>>> 5a83e0f (test) if dtype in (torch.float32, torch.complex64): return 1e-4, 1e-3 return 1e-10, 1e-8 @@ -54,57 +37,57 @@ def _tol(dtype): def _rand_like(dtype, shape, device): if dtype in (torch.float32, torch.float64): return torch.randn(shape, dtype=dtype, device=device) -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - pair = torch.randn((*shape, 2), dtype=torch.float16, device=device) * 0.1 - return torch.view_as_complex(pair) - base = torch.float32 -======= base = torch.float32 if dtype == torch.complex64 else torch.float64 ->>>>>>> 5a83e0f (test) r = torch.randn(shape, dtype=base, device=device) i = torch.randn(shape, dtype=base, device=device) return torch.complex(r, i) def _ref_dtype(dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return torch.complex64 -======= ->>>>>>> 5a83e0f (test) return dtype def _safe_cast_tensor(tensor, dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - real = tensor.real.to(torch.float16) - imag = tensor.imag.to(torch.float16) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) -======= ->>>>>>> 5a83e0f (test) return tensor.to(dtype) def _cmp_view(tensor, dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return tensor.to(torch.complex64) return tensor -def _build_lower_triangular(n, dtype, device): - off = _rand_like(dtype, (n, n), device) * 0.02 - A = torch.tril(off) -======= - return tensor +def _apply_ref_op(A, op_mode): + if op_mode == "TRANS": + return A.transpose(-2, -1) + if op_mode == "CONJ": + return A.transpose(-2, -1).conj() if torch.is_complex(A) else A.transpose(-2, -1) + return A + + +def _effective_upper(lower, op_mode): + return lower if op_mode in ("TRANS", "CONJ") else not lower + + +def _effective_lower_for_op(lower, op_mode): + return (not lower) if op_mode in ("TRANS", "CONJ") else lower + + +def _transpose_arg(op_mode): + if op_mode == "NON": + return False + return op_mode + + +def _cupy_apply_op(A_cp, op_mode): + if op_mode == "TRANS": + return A_cp.transpose().tocsr() + if op_mode == "CONJ": + return A_cp.transpose().conj().tocsr() + return A_cp def _build_triangular(n, dtype, device, lower=True): off = _rand_like(dtype, (n, n), device) * 0.02 A = torch.tril(off) if lower else torch.triu(off) ->>>>>>> 5a83e0f (test) if torch.is_complex(A): diag = (torch.rand(n, device=device, dtype=A.real.dtype) + 2.0).to(A.real.dtype) A = A + torch.diag(torch.complex(diag, torch.zeros_like(diag))) @@ -117,12 +100,7 @@ def _build_triangular(n, dtype, device, lower=True): def _cupy_csr_from_torch(data, indices, indptr, shape): if cp is None or cpx_sparse is None: return None -<<<<<<< HEAD - data_ref = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data - data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) -======= data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) ->>>>>>> 5a83e0f (test) idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous())) ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous())) return cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) @@ -131,18 +109,9 @@ def _cupy_csr_from_torch(data, indices, indptr, shape): def _cupy_ref_spsv(A_cp, b_t, *, lower, unit_diagonal=False): if cp is None or cpx_spsolve_triangular is None: return None -<<<<<<< HEAD - b_ref = b_t.to(torch.complex64) if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE else b_t - b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) - x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) - x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) - if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE: - return x_t.to(torch.complex64) -======= b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_t.contiguous())) x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) ->>>>>>> 5a83e0f (test) return x_t.to(b_t.dtype) @@ -183,19 +152,11 @@ def test_spsv_csr_lower_matches_dense(n, dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) -<<<<<<< HEAD -@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) -@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) -def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): - device = torch.device("cuda") - A = _build_lower_triangular(n, dtype, device) -======= @pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) @pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=True) ->>>>>>> 5a83e0f (test) b = _rand_like(dtype, (n,), device) x_ref = torch.linalg.solve_triangular( A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=False @@ -222,28 +183,23 @@ def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) -<<<<<<< HEAD -@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) -def test_spsv_csr_trans_int32_supported_combos(n, dtype): - device = torch.device("cuda") - A = _build_lower_triangular(n, dtype, device) -======= -@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) -def test_spsv_csr_trans_int32_supported_combos(n, dtype): +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_transpose_family_supported_combos(n, dtype, index_dtype, op_mode): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=True) ->>>>>>> 5a83e0f (test) b = _rand_like(dtype, (n,), device) A_ref = A.to(_ref_dtype(dtype)) b_ref = b.to(_ref_dtype(dtype)) x_ref = torch.linalg.solve_triangular( - A_ref.transpose(-2, -1), b_ref.unsqueeze(-1), upper=True + _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(True, op_mode) ).squeeze(-1) Asp = A.to_sparse_csr() data = Asp.values() - indices = Asp.col_indices().to(torch.int32) - indptr = Asp.crow_indices().to(torch.int32) + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) x = flagsparse_spsv_csr( data, @@ -253,21 +209,13 @@ def test_spsv_csr_trans_int32_supported_combos(n, dtype): (n, n), lower=True, unit_diagonal=False, - transpose=True, + transpose=_transpose_arg(op_mode), ) rtol, atol = _tol(dtype) assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) @pytest.mark.spsv -<<<<<<< HEAD -@pytest.mark.skipif(cp is None or cpx_spsolve_triangular is None, reason="CuPy/cuSPARSE required") -@pytest.mark.parametrize("n", SPSV_N) -@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) -def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype): - device = torch.device("cuda") - A = _build_lower_triangular(n, dtype, device) -======= @pytest.mark.skipif( cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, reason="CuPy/cuSPARSE required", @@ -277,7 +225,6 @@ def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype): def test_spsv_csr_matches_cusparse_non_trans(n, dtype): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=True) ->>>>>>> 5a83e0f (test) b = _rand_like(dtype, (n,), device) Asp = A.to_sparse_csr() @@ -291,8 +238,6 @@ def test_spsv_csr_matches_cusparse_non_trans(n, dtype): ) x_non_ref = _cupy_ref_spsv(A_cp, b, lower=True, unit_diagonal=False) -<<<<<<< HEAD -======= rtol, atol = _tol(dtype) assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) @@ -303,29 +248,38 @@ def test_spsv_csr_matches_cusparse_non_trans(n, dtype): reason="CuPy/cuSPARSE required", ) @pytest.mark.parametrize("n", SPSV_N) -@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) -def test_spsv_csr_matches_cusparse_trans(n, dtype): +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_matches_cusparse_transpose_family(n, dtype, index_dtype, op_mode): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=True) b = _rand_like(dtype, (n,), device) Asp = A.to_sparse_csr() data = Asp.values() - indices = Asp.col_indices().to(torch.int32) - indptr = Asp.crow_indices().to(torch.int32) + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) ->>>>>>> 5a83e0f (test) x_trans = flagsparse_spsv_csr( - data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=True + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + x_trans_ref = _cupy_ref_spsv( + _cupy_apply_op(A_cp, op_mode), + b, + lower=_effective_lower_for_op(True, op_mode), + unit_diagonal=False, ) - x_trans_ref = _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=False, unit_diagonal=False) rtol, atol = _tol(dtype) -<<<<<<< HEAD - assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) - assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) -======= assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) @@ -362,21 +316,23 @@ def test_spsv_csr_non_trans_upper_supported_combos(n, dtype, index_dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) -@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) -def test_spsv_csr_trans_upper_int32_supported_combos(n, dtype): +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_upper_transpose_family_supported_combos(n, dtype, index_dtype, op_mode): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=False) b = _rand_like(dtype, (n,), device) A_ref = A.to(_ref_dtype(dtype)) b_ref = b.to(_ref_dtype(dtype)) x_ref = torch.linalg.solve_triangular( - A_ref.transpose(-2, -1), b_ref.unsqueeze(-1), upper=False + _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(False, op_mode) ).squeeze(-1) Asp = A.to_sparse_csr() data = Asp.values() - indices = Asp.col_indices().to(torch.int32) - indptr = Asp.crow_indices().to(torch.int32) + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) x = flagsparse_spsv_csr( data, @@ -386,7 +342,7 @@ def test_spsv_csr_trans_upper_int32_supported_combos(n, dtype): (n, n), lower=False, unit_diagonal=False, - transpose=True, + transpose=_transpose_arg(op_mode), ) rtol, atol = _tol(dtype) assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) @@ -425,22 +381,36 @@ def test_spsv_csr_matches_cusparse_upper_non_trans(n, dtype): reason="CuPy/cuSPARSE required", ) @pytest.mark.parametrize("n", SPSV_N) -@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) -def test_spsv_csr_matches_cusparse_upper_trans(n, dtype): +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_matches_cusparse_upper_transpose_family(n, dtype, index_dtype, op_mode): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=False) b = _rand_like(dtype, (n,), device) Asp = A.to_sparse_csr() data = Asp.values() - indices = Asp.col_indices().to(torch.int32) - indptr = Asp.crow_indices().to(torch.int32) + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) x_trans = flagsparse_spsv_csr( - data, indices, indptr, b, (n, n), lower=False, unit_diagonal=False, transpose=True + data, + indices, + indptr, + b, + (n, n), + lower=False, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + x_trans_ref = _cupy_ref_spsv( + _cupy_apply_op(A_cp, op_mode), + b, + lower=_effective_lower_for_op(False, op_mode), + unit_diagonal=False, ) - x_trans_ref = _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=True, unit_diagonal=False) rtol, atol = _tol(dtype) assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) @@ -448,13 +418,14 @@ def test_spsv_csr_matches_cusparse_upper_trans(n, dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) -def test_spsv_coo_transpose_complex128_routes_through_csr(n): +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_coo_transpose_family_complex128_routes_through_csr(n, op_mode): device = torch.device("cuda") dtype = torch.complex128 A = _build_triangular(n, dtype, device, lower=True) b = _rand_like(dtype, (n,), device) x_ref = torch.linalg.solve_triangular( - A.transpose(-2, -1), b.unsqueeze(-1), upper=True + _apply_ref_op(A, op_mode), b.unsqueeze(-1), upper=_effective_upper(True, op_mode) ).squeeze(-1) A_coo = A.to_sparse_coo().coalesce() @@ -469,9 +440,8 @@ def test_spsv_coo_transpose_complex128_routes_through_csr(n): (n, n), lower=True, unit_diagonal=False, - transpose=True, + transpose=_transpose_arg(op_mode), coo_mode="auto", ) rtol, atol = _tol(dtype) assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) ->>>>>>> 5a83e0f (test) diff --git a/tests/test_scatter.py b/tests/test_scatter.py index e92b41e..604adeb 100644 --- a/tests/test_scatter.py +++ b/tests/test_scatter.py @@ -54,7 +54,6 @@ def _parse_value_dtypes(raw): "bfloat16", "float32", "float64", - "complex32", "complex64", "complex128", } diff --git a/tests/test_spsv.py b/tests/test_spsv.py index cdc7a31..2737db2 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -29,21 +29,6 @@ ITERS = 20 DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB -FLOAT16_LIMIT = 65504.0 -COMPLEX32_DTYPE = getattr(torch, "complex32", None) -if COMPLEX32_DTYPE is None: - COMPLEX32_DTYPE = getattr(torch, "chalf", None) - -# CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) -CSR_FULL_VALUE_DTYPES = [ - torch.float32, - torch.float64, -] -if COMPLEX32_DTYPE is not None: - CSR_FULL_VALUE_DTYPES.append(COMPLEX32_DTYPE) -CSR_FULL_VALUE_DTYPES.append(torch.complex64) -CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64] - # CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) CSR_FULL_VALUE_DTYPES = [ torch.float32, @@ -52,12 +37,53 @@ torch.complex128, ] CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64] +SPSV_OP_MODES = ["NON", "TRANS", "CONJ"] def _dtype_name(dtype): return str(dtype).replace("torch.", "") +VALUE_DTYPE_NAME_MAP = { + _dtype_name(dtype): dtype for dtype in CSR_FULL_VALUE_DTYPES +} +VALUE_DTYPE_NAME_MAP.update({ + "float": torch.float32, + "double": torch.float64, +}) +INDEX_DTYPE_NAME_MAP = { + _dtype_name(dtype): dtype for dtype in CSR_FULL_INDEX_DTYPES +} + + +def _parse_csv_tokens(raw): + return [tok.strip() for tok in str(raw).split(",") if tok.strip()] + + +def _parse_value_dtypes_filter(raw): + tokens = [tok.lower() for tok in _parse_csv_tokens(raw)] + invalid = [tok for tok in tokens if tok not in VALUE_DTYPE_NAME_MAP] + if invalid: + raise ValueError(f"unsupported value dtypes: {invalid}") + return [VALUE_DTYPE_NAME_MAP[tok] for tok in tokens] + + +def _parse_index_dtypes_filter(raw): + tokens = [tok.lower() for tok in _parse_csv_tokens(raw)] + invalid = [tok for tok in tokens if tok not in INDEX_DTYPE_NAME_MAP] + if invalid: + raise ValueError(f"unsupported index dtypes: {invalid}") + return [INDEX_DTYPE_NAME_MAP[tok] for tok in tokens] + + +def _parse_op_modes_filter(raw): + tokens = [tok.upper() for tok in _parse_csv_tokens(raw)] + invalid = [tok for tok in tokens if tok not in SPSV_OP_MODES] + if invalid: + raise ValueError(f"unsupported ops: {invalid}") + return tokens + + def _fmt_ms(v): return "N/A" if v is None else f"{v:.4f}" @@ -73,11 +99,6 @@ def _fmt_err(v): def _tol_for_dtype(dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return 5e-3, 5e-3 -======= ->>>>>>> 5a83e0f (test) if dtype in (torch.float32, torch.complex64): return 1e-4, 1e-2 return 1e-12, 1e-10 @@ -86,67 +107,25 @@ def _tol_for_dtype(dtype): def _randn_by_dtype(n, dtype, device): if dtype in (torch.float32, torch.float64): return torch.randn(n, dtype=dtype, device=device) -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - pair = torch.randn((n, 2), dtype=torch.float16, device=device) * 0.1 - return torch.view_as_complex(pair) - base = torch.float32 -======= base = torch.float32 if dtype == torch.complex64 else torch.float64 ->>>>>>> 5a83e0f (test) real = torch.randn(n, dtype=base, device=device) imag = torch.randn(n, dtype=base, device=device) return torch.complex(real, imag) def _dense_ref_dtype(dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return torch.complex64 -======= ->>>>>>> 5a83e0f (test) return dtype def _tensor_from_scalar_values(values, dtype, device): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - real = torch.clamp( - torch.tensor(values, dtype=torch.float32, device=device), - min=-FLOAT16_LIMIT, - max=FLOAT16_LIMIT, - ).to(torch.float16) - imag = torch.zeros_like(real) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) -======= ->>>>>>> 5a83e0f (test) return torch.tensor(values, dtype=dtype, device=device) def _safe_cast_tensor(tensor, dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - real = torch.clamp(tensor.real, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) - imag = torch.clamp(tensor.imag, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) -======= ->>>>>>> 5a83e0f (test) return tensor.to(dtype) def _cast_real_tensor_to_value_dtype(values, value_dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: - real = torch.clamp(values, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) - imag = torch.zeros_like(real) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) - return values.to(value_dtype) - - -def _cupy_ref_inputs(data, b): - if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE: - return data.to(torch.complex64), b.to(torch.complex64) -======= return values.to(value_dtype) @@ -166,6 +145,9 @@ def _triangular_solve_reference(A, b, *, lower, op_mode="NON"): if op_mode == "TRANS": A_eff = A.transpose(0, 1) upper = lower + elif op_mode == "CONJ": + A_eff = A.transpose(0, 1).conj() if torch.is_complex(A) else A.transpose(0, 1) + upper = lower else: A_eff = A upper = not lower @@ -175,16 +157,10 @@ def _triangular_solve_reference(A, b, *, lower, op_mode="NON"): def _cupy_ref_inputs(data, b): ->>>>>>> 5a83e0f (test) return data, b def _compare_view(tensor, value_dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: - return tensor.to(torch.complex64) -======= ->>>>>>> 5a83e0f (test) return tensor @@ -192,9 +168,9 @@ def _supported_csr_full_ops(value_dtype, index_dtype): if value_dtype not in CSR_FULL_VALUE_DTYPES: return [] if index_dtype == torch.int32: - return ["NON", "TRANS"] + return ["NON", "TRANS", "CONJ"] if index_dtype == torch.int64: - return ["NON"] + return ["NON", "TRANS", "CONJ"] return [] @@ -215,17 +191,10 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True base_real_dtype = torch.float32 elif value_dtype == torch.float64: base_real_dtype = torch.float64 -<<<<<<< HEAD - elif COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: - base_real_dtype = torch.float16 - else: - base_real_dtype = torch.float32 -======= elif value_dtype == torch.complex64: base_real_dtype = torch.float32 else: base_real_dtype = torch.float64 ->>>>>>> 5a83e0f (test) for i in range(n): if lower: @@ -272,14 +241,7 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device) cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device) -<<<<<<< HEAD - vals_t = _cast_real_tensor_to_value_dtype( - torch.tensor(vals_host, dtype=base_real_dtype, device=device), - value_dtype, - ) -======= vals_t = torch.tensor(vals_host, dtype=value_dtype, device=device) ->>>>>>> 5a83e0f (test) order = torch.argsort(rows_t * max(1, n) + cols_t) rows_t = rows_t[order] cols_t = cols_t[order] @@ -293,16 +255,15 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True def _csr_to_dense(data, indices, indptr, shape): n_rows, n_cols = shape - coo_data = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data row_ind = torch.repeat_interleave( - torch.arange(n_rows, device=coo_data.device, dtype=torch.int64), + torch.arange(n_rows, device=data.device, dtype=torch.int64), indptr[1:] - indptr[:-1], ) coo = torch.sparse_coo_tensor( torch.stack([row_ind, indices.to(torch.int64)]), - coo_data, + data, (n_rows, n_cols), - device=coo_data.device, + device=data.device, ).coalesce() return coo.to_dense() @@ -317,7 +278,7 @@ def _csr_to_coo(data, indices, indptr, shape): return data, row, col -def _csr_transpose(data, indices, indptr, shape): +def _csr_transpose(data, indices, indptr, shape, conjugate=False): n_rows, n_cols = int(shape[0]), int(shape[1]) if data.numel() == 0: return ( @@ -337,18 +298,15 @@ def _csr_transpose(data, indices, indptr, shape): row_t = row_t[order] col_t = col_t[order] - data_t = data[order] + data_eff = data.conj() if conjugate and torch.is_complex(data) else data + data_t = data_eff[order] nnz_per_row = torch.bincount(row_t, minlength=n_cols) indptr_t = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) indptr_t[1:] = torch.cumsum(nnz_per_row, dim=0) return data_t, col_t.to(torch.int64), indptr_t -<<<<<<< HEAD -def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): -======= def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None, lower=True): ->>>>>>> 5a83e0f (test) if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") with open(file_path, "r", encoding="utf-8") as f: @@ -449,29 +407,6 @@ def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode): def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE: - data_ref = data.to(torch.complex64) - x_ref = x_true.to(torch.complex64) - if op_mode == "NON": - b_ref, _ = fs.flagsparse_spmv_csr( - data_ref, indices, indptr, x_ref, shape, return_time=True - ) - return _safe_cast_tensor(b_ref, x_true.dtype) - if op_mode == "TRANS": - data_t, indices_t, indptr_t = _csr_transpose(data_ref, indices, indptr, shape) - b_ref, _ = fs.flagsparse_spmv_csr( - data_t, - indices_t.to(indices.dtype), - indptr_t.to(indptr.dtype), - x_ref, - (shape[1], shape[0]), - return_time=True, - ) - return _safe_cast_tensor(b_ref, x_true.dtype) - raise ValueError("op_mode must be 'NON' or 'TRANS'") -======= ->>>>>>> 5a83e0f (test) if op_mode == "NON": b, _ = fs.flagsparse_spmv_csr( data, indices, indptr, x_true, shape, return_time=True @@ -488,7 +423,41 @@ def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): return_time=True, ) return b - raise ValueError("op_mode must be 'NON' or 'TRANS'") + if op_mode == "CONJ": + data_h, indices_h, indptr_h = _csr_transpose( + data, indices, indptr, shape, conjugate=True + ) + b, _ = fs.flagsparse_spmv_csr( + data_h, + indices_h.to(indices.dtype), + indptr_h.to(indptr.dtype), + x_true, + (shape[1], shape[0]), + return_time=True, + ) + return b + raise ValueError("op_mode must be 'NON', 'TRANS', or 'CONJ'") + + +def _known_solution_metrics(data, indices, indptr, shape, x, x_true, b, value_dtype, op_mode): + atol, rtol = _tol_for_dtype(value_dtype) + x_cmp = _compare_view(x, value_dtype) + x_true_cmp = _compare_view(x_true, value_dtype) + err_x = ( + float(torch.max(torch.abs(x_cmp - x_true_cmp)).item()) + if x.numel() > 0 + else 0.0 + ) + ok_x = torch.allclose(x_cmp, x_true_cmp, atol=atol, rtol=rtol) + + b_recon = _build_rhs_for_csr_op(data, indices, indptr, x, shape, op_mode) + err_res = ( + float(torch.max(torch.abs(b_recon - b)).item()) + if b.numel() > 0 + else 0.0 + ) + ok_res = torch.allclose(b_recon, b, atol=atol, rtol=rtol) + return err_x, ok_x, err_res, ok_res def _cupy_spsolve_lower_csr_or_coo( @@ -546,24 +515,13 @@ def _cupy_spsolve_lower_csr_or_coo( t1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(t0, t1) / iters x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and b.dtype == COMPLEX32_DTYPE: - x_cu_t = x_cu_t.to(torch.complex64) - else: - x_cu_t = x_cu_t.to(b.dtype) -======= x_cu_t = x_cu_t.to(b.dtype) ->>>>>>> 5a83e0f (test) return cupy_ms, x_cu_t except Exception: return None, None -<<<<<<< HEAD -def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode): -======= def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): ->>>>>>> 5a83e0f (test) if ( cp is None or cpx_sparse is None @@ -583,17 +541,13 @@ def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) if op_mode == "TRANS": A_eff = A_cp.transpose().tocsr() -<<<<<<< HEAD - lower_eff = False - else: - A_eff = A_cp - lower_eff = True -======= + lower_eff = not lower + elif op_mode == "CONJ": + A_eff = A_cp.transpose().conj().tocsr() lower_eff = not lower else: A_eff = A_cp lower_eff = lower ->>>>>>> 5a83e0f (test) for _ in range(WARMUP): _ = cpx_spsolve_triangular( @@ -616,11 +570,7 @@ def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): return None, None -<<<<<<< HEAD -def run_spsv_synthetic_all(): -======= def run_spsv_synthetic_all(lower=True): ->>>>>>> 5a83e0f (test) if not torch.cuda.is_available(): print("CUDA is not available. Please run on a GPU-enabled system.") return @@ -680,7 +630,7 @@ def run_spsv_synthetic_all(lower=True): b, shape, lower=lower, - transpose=(op_mode == "TRANS"), + transpose=op_mode, return_time=True, ) else: @@ -791,7 +741,7 @@ def _run_one_csv_row_csr(path, value_dtype, index_dtype, device, lower=True): ) return _finalize_csv_row( path, value_dtype, index_dtype, data, indices, indptr, shape, - x, t_ms, b, n_rows, n_cols, lower=lower, + x, t_ms, b, x_true, n_rows, n_cols, lower=lower, ) @@ -829,6 +779,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower x, t_ms, b, + x_true, n_rows, n_cols, lower=lower, @@ -850,6 +801,7 @@ def _finalize_csv_row( x, t_ms, b, + x_true, n_rows, n_cols, *, @@ -860,6 +812,9 @@ def _finalize_csv_row( cupy_coo_col=None, ): atol, rtol = _tol_for_dtype(value_dtype) + err_x, ok_x, err_res, ok_res = _known_solution_metrics( + data, indices, indptr, shape, x, x_true, b, value_dtype, "NON" + ) pytorch_ms = None err_pt = None ok_pt = False @@ -876,13 +831,9 @@ def _finalize_csv_row( e1 = torch.cuda.Event(True) torch.cuda.synchronize() e0.record() -<<<<<<< HEAD - x_ref = torch.linalg.solve(A_ref, b_ref.unsqueeze(1)).squeeze(1) -======= x_ref = _triangular_solve_reference( A_ref, b_ref, lower=lower, op_mode="NON" ) ->>>>>>> 5a83e0f (test) x_cmp = _compare_view(x, value_dtype) x_ref_cmp = _compare_view(x_ref, value_dtype) e1.record() @@ -1000,21 +951,17 @@ def _finalize_csv_row( "cusparse_ms": cupy_ms, "csc_ms": None, "status": status, + "err_x": err_x, + "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, } return row, pt_skip_reason -<<<<<<< HEAD -def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device): - data, indices, indptr, shape = _load_mtx_to_csr_torch( - path, dtype=value_dtype, device=device -======= def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, lower=True): data, indices, indptr, shape = _load_mtx_to_csr_torch( path, dtype=value_dtype, device=device, lower=lower ->>>>>>> 5a83e0f (test) ) indices = indices.to(index_dtype) indptr = indptr.to(index_dtype) @@ -1027,12 +974,8 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l indptr, b, shape, -<<<<<<< HEAD - lower=True, -======= lower=lower, ->>>>>>> 5a83e0f (test) - transpose=(op_mode == "TRANS"), + transpose=op_mode, return_time=True, ) return _finalize_csv_row_csr_full( @@ -1047,12 +990,10 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l x, t_ms, b, + x_true, n_rows, n_cols, -<<<<<<< HEAD -======= lower=lower, ->>>>>>> 5a83e0f (test) ) @@ -1068,14 +1009,15 @@ def _finalize_csv_row_csr_full( x, t_ms, b, + x_true, n_rows, n_cols, -<<<<<<< HEAD -======= lower=True, ->>>>>>> 5a83e0f (test) ): atol, rtol = _tol_for_dtype(value_dtype) + err_x, ok_x, err_res, ok_res = _known_solution_metrics( + data, indices, indptr, shape, x, x_true, b, value_dtype, op_mode + ) pytorch_ms = None err_pt = None @@ -1086,21 +1028,13 @@ def _finalize_csv_row_csr_full( A_dense = _csr_to_dense( data, indices.to(torch.int64), indptr.to(torch.int64), shape ).to(_dense_ref_dtype(value_dtype)) -<<<<<<< HEAD - A_ref = A_dense.transpose(0, 1) if op_mode == "TRANS" else A_dense -======= ->>>>>>> 5a83e0f (test) e0 = torch.cuda.Event(True) e1 = torch.cuda.Event(True) torch.cuda.synchronize() e0.record() -<<<<<<< HEAD - x_ref = torch.linalg.solve(A_ref, b.to(A_ref.dtype).unsqueeze(1)).squeeze(1) -======= x_ref = _triangular_solve_reference( A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode ) ->>>>>>> 5a83e0f (test) x_cmp = _compare_view(x, value_dtype) x_ref_cmp = _compare_view(x_ref, value_dtype) e1.record() @@ -1127,11 +1061,7 @@ def _finalize_csv_row_csr_full( ok_cu = False x_cu_t = None cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( -<<<<<<< HEAD - data, indices, indptr, shape, b, op_mode -======= data, indices, indptr, shape, b, op_mode, lower ->>>>>>> 5a83e0f (test) ) if x_cu_t is not None: x_cmp = _compare_view(x, value_dtype) @@ -1160,134 +1090,37 @@ def _finalize_csv_row_csr_full( "cusparse_ms": cupy_ms, "csc_ms": None, "status": status, + "err_x": err_x, + "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, } return row, pt_skip_reason -<<<<<<< HEAD -def run_all_supported_spsv_csr_csv(mtx_paths, csv_path): - if not torch.cuda.is_available(): - print("CUDA is not available.") - return - device = torch.device("cuda") - rows_out = [] - for value_dtype in CSR_FULL_VALUE_DTYPES: - for index_dtype in CSR_FULL_INDEX_DTYPES: - op_modes = _supported_csr_full_ops(value_dtype, index_dtype) - for op_mode in op_modes: - print("=" * 150) - print( - f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | opA={op_mode}" - ) - print( - "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." - ) - print( - "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " - "PASS if either error within tolerance." - ) - print("-" * 150) - print( - f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " - f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" - ) - print("-" * 150) - for path in mtx_paths: - try: - row, pt_skip = _run_one_csv_row_csr_full( - path, value_dtype, index_dtype, op_mode, device - ) - rows_out.append(row) - name = os.path.basename(path)[:27] - if len(os.path.basename(path)) > 27: - name = name + "…" - n_rows, n_cols = row["n_rows"], row["n_cols"] - nnz = row["nnz"] - t_ms = row["triton_ms"] - cupy_ms = row["cusparse_ms"] - pytorch_ms = row["pytorch_ms"] - err_pt, err_cu = row["err_pt"], row["err_cu"] - status = row["status"] - print( - f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " - f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " - f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" - ) - if pt_skip: - print(f" NOTE: {pt_skip}") - except Exception as e: - err_msg = str(e) - status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" - rows_out.append( - { - "matrix": os.path.basename(path), - "value_dtype": _dtype_name(value_dtype), - "index_dtype": _dtype_name(index_dtype), - "opA": op_mode, - "n_rows": "ERR", - "n_cols": "ERR", - "nnz": "ERR", - "triton_ms": None, - "pytorch_ms": None, - "cusparse_ms": None, - "csc_ms": None, - "status": status, - "err_pt": None, - "err_cu": None, - } - ) - name = os.path.basename(path)[:27] - if len(os.path.basename(path)) > 27: - name = name + "…" - print( - f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " - f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " - f"{'N/A':>7} {'N/A':>7} " - f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" - ) - print(f" {status}: {e}") - print("-" * 150) - fieldnames = [ - "matrix", - "value_dtype", - "index_dtype", - "opA", - "n_rows", - "n_cols", - "nnz", - "triton_ms", - "pytorch_ms", - "cusparse_ms", - "csc_ms", - "status", - "err_pt", - "err_cu", - ] - with open(csv_path, "w", newline="", encoding="utf-8") as f: - w = csv.DictWriter(f, fieldnames=fieldnames) - w.writeheader() - for r in rows_out: - w.writerow(r) - print(f"Wrote {len(rows_out)} rows to {csv_path}") - - -def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto"): -======= -def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): ->>>>>>> 5a83e0f (test) +def run_all_supported_spsv_csr_csv( + mtx_paths, + csv_path, + lower=True, + value_dtypes=None, + index_dtypes=None, + op_modes=None, +): if not torch.cuda.is_available(): print("CUDA is not available.") return device = torch.device("cuda") rows_out = [] - for value_dtype in CSR_FULL_VALUE_DTYPES: - for index_dtype in CSR_FULL_INDEX_DTYPES: - op_modes = _supported_csr_full_ops(value_dtype, index_dtype) - for op_mode in op_modes: + selected_value_dtypes = value_dtypes or CSR_FULL_VALUE_DTYPES + selected_index_dtypes = index_dtypes or CSR_FULL_INDEX_DTYPES + selected_op_modes = op_modes or SPSV_OP_MODES + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + supported_op_modes = [ + op for op in _supported_csr_full_ops(value_dtype, index_dtype) + if op in selected_op_modes + ] + for op_mode in supported_op_modes: print("=" * 150) print( f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}" @@ -1296,14 +1129,15 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." ) print( + "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, " "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " - "PASS if either error within tolerance." + "PASS if PyTorch / cuSPARSE reference passes. x_true / residual are diagnostics only." ) print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: @@ -1320,13 +1154,14 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): t_ms = row["triton_ms"] cupy_ms = row["cusparse_ms"] pytorch_ms = row["pytorch_ms"] + err_x, err_res = row["err_x"], row["err_res"] err_pt, err_cu = row["err_pt"], row["err_cu"] status = row["status"] print( f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + f"{status:>6} {_fmt_err(err_x):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) if pt_skip: print(f" NOTE: {pt_skip}") @@ -1347,6 +1182,8 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): "cusparse_ms": None, "csc_ms": None, "status": status, + "err_x": None, + "err_res": None, "err_pt": None, "err_cu": None, } @@ -1358,7 +1195,7 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " f"{'N/A':>7} {'N/A':>7} " - f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" + f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" ) print(f" {status}: {e}") print("-" * 150) @@ -1375,6 +1212,8 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): "cusparse_ms", "csc_ms", "status", + "err_x", + "err_res", "err_pt", "err_cu", ] @@ -1416,14 +1255,15 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." ) print( + "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, " "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " - "PASS if either error within tolerance." + "PASS if PyTorch / cuSPARSE reference passes. x_true / residual are diagnostics only." ) print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " f"{'FlagSparse(ms)':>10} {cu_col:>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{fs_cu_hdr:>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{fs_cu_hdr:>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: @@ -1445,13 +1285,14 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", t_ms = row["triton_ms"] cupy_ms = row["cusparse_ms"] pytorch_ms = row["pytorch_ms"] + err_x, err_res = row["err_x"], row["err_res"] err_pt, err_cu = row["err_pt"], row["err_cu"] status = row["status"] print( f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + f"{status:>6} {_fmt_err(err_x):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) if pt_skip: print(f" NOTE: {pt_skip}") @@ -1471,6 +1312,8 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", "cusparse_ms": None, "csc_ms": None, "status": status, + "err_x": None, + "err_res": None, "err_pt": None, "err_cu": None, } @@ -1481,7 +1324,7 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", print( f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " - f"{'N/A':>7} {'N/A':>7} {status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" + f"{'N/A':>7} {'N/A':>7} {status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" ) print(f" {status}: {e}") print("-" * 150) @@ -1497,6 +1340,8 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", "cusparse_ms", "csc_ms", "status", + "err_x", + "err_res", "err_pt", "err_cu", ] @@ -1546,6 +1391,24 @@ def main(): action="store_true", help="Use upper-triangular inputs instead of the default lower-triangular inputs", ) + parser.add_argument( + "--ops", + type=str, + default=None, + help="Comma-separated opA filter for CSR CSV, e.g. TRANS,CONJ", + ) + parser.add_argument( + "--value-dtypes", + type=str, + default=None, + help="Comma-separated value dtype filter for CSR CSV, e.g. float,double,complex64,complex128", + ) + parser.add_argument( + "--index-dtypes", + type=str, + default=None, + help="Comma-separated index dtype filter for CSR CSV, e.g. int32,int64", + ) args = parser.parse_args() lower = not args.upper @@ -1565,11 +1428,29 @@ def main(): if not paths: print("No .mtx files found for --csv-csr") return -<<<<<<< HEAD - run_all_supported_spsv_csr_csv(paths, args.csv_csr) -======= - run_all_supported_spsv_csr_csv(paths, args.csv_csr, lower=lower) ->>>>>>> 5a83e0f (test) + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + op_modes = ( + _parse_op_modes_filter(args.ops) + if args.ops + else None + ) + run_all_supported_spsv_csr_csv( + paths, + args.csv_csr, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, + op_modes=op_modes, + ) return if args.csv_coo: if not paths: From c8768642922531701cbbf7771dc50559d48b1e31 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 21 Apr 2026 11:38:57 +0800 Subject: [PATCH 4/5] test --- tests/test_spsv.py | 119 +++++++++++++++++++++++---------------------- 1 file changed, 60 insertions(+), 59 deletions(-) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 2737db2..b64e6ff 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -8,9 +8,16 @@ import csv import glob import os +import sys +from pathlib import Path import torch +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) + import flagsparse as fs try: @@ -268,13 +275,13 @@ def _csr_to_dense(data, indices, indptr, shape): return coo.to_dense() -def _csr_to_coo(data, indices, indptr, shape): +def _csr_to_coo(data, indices, indptr, shape, index_dtype=torch.int64): n_rows = int(shape[0]) row = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), + torch.arange(n_rows, device=data.device, dtype=index_dtype), indptr[1:] - indptr[:-1], ) - col = indices.to(torch.int64) + col = indices.to(index_dtype) return data, row, col @@ -387,9 +394,11 @@ def _accum(r, c, v): return data, indices, indptr, (n_rows, n_cols) -def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode): +def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode, index_dtype=torch.int64): """Sorted COO from CSR; optional shuffle/duplicate for csr|auto (与原先 CSV 行为一致).""" - data_c, row_c, col_c = _csr_to_coo(data, indices, indptr, shape) + data_c, row_c, col_c = _csr_to_coo( + data, indices, indptr, shape, index_dtype=index_dtype + ) if coo_mode in ("csr", "auto"): if data_c.numel() == 0: return data_c, row_c, col_c @@ -634,7 +643,9 @@ def run_spsv_synthetic_all(lower=True): return_time=True, ) else: - dc, rr, cc = _csr_to_coo(data, indices, indptr, shape) + dc, rr, cc = _csr_to_coo( + data, indices, indptr, shape, index_dtype=index_dtype + ) x, t_ms = fs.flagsparse_spsv_coo( dc, rr, @@ -726,25 +737,6 @@ def run_spsv_synthetic_all(lower=True): print(sep) -def _run_one_csv_row_csr(path, value_dtype, index_dtype, device, lower=True): - data, indices, indptr, shape = _load_mtx_to_csr_torch( - path, dtype=value_dtype, device=device, lower=lower - ) - indices = indices.to(index_dtype) - n_rows, n_cols = shape - x_true = _randn_by_dtype(n_rows, value_dtype, device) - b, _ = fs.flagsparse_spmv_csr( - data, indices, indptr, x_true, shape, return_time=True - ) - x, t_ms = fs.flagsparse_spsv_csr( - data, indices, indptr, b, shape, lower=lower, return_time=True - ) - return _finalize_csv_row( - path, value_dtype, index_dtype, data, indices, indptr, shape, - x, t_ms, b, x_true, n_rows, n_cols, lower=lower, - ) - - def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower=True): data, indices, indptr, shape = _load_mtx_to_csr_torch( path, dtype=value_dtype, device=device, lower=lower @@ -756,7 +748,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower data, indices, indptr, x_true, shape, return_time=True ) d_in, r_in, c_in = _coo_inputs_for_csv( - data, indices, indptr, shape, coo_mode + data, indices, indptr, shape, coo_mode, index_dtype=index_dtype ) x, t_ms = fs.flagsparse_spsv_coo( d_in, @@ -1225,35 +1217,32 @@ def run_all_supported_spsv_csr_csv( print(f"Wrote {len(rows_out)} rows to {csv_path}") -def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", lower=True): +def run_all_dtypes_spsv_coo_csv( + mtx_paths, + csv_path, + coo_mode="auto", + lower=True, + value_dtypes=None, + index_dtypes=None, +): if not torch.cuda.is_available(): print("CUDA is not available.") return - if not use_coo: - run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=lower) - return device = torch.device("cuda") rows_out = [] - label = "COO" if use_coo else "CSR" - cu_col = "COO(ms)" if use_coo else "CSR(ms)" - fs_cu_hdr = "FS/COO" if use_coo else "FS/CSR" - for value_dtype in VALUE_DTYPES: - for index_dtype in INDEX_DTYPES: - atol, rtol = _tol_for_dtype(value_dtype) + selected_value_dtypes = value_dtypes or VALUE_DTYPES + selected_index_dtypes = index_dtypes or INDEX_DTYPES + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: print("=" * 150) print( - f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | {label}" - + (f" triA={'LOWER' if lower else 'UPPER'}" if not use_coo else f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}") + f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | COO" + f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}" + ) + print( + "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. " + "b 由 CSR SpMV 构造,与 CSR 测试一致。" ) - if use_coo: - print( - "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. " - "b 由 CSR SpMV 构造,与 CSR 测试一致。" - ) - else: - print( - "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." - ) print( "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, " "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " @@ -1262,20 +1251,15 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " - f"{'FlagSparse(ms)':>10} {cu_col:>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{fs_cu_hdr:>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{'FlagSparse(ms)':>10} {'COO(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " + f"{'FS/COO':>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: try: - if use_coo: - row, pt_skip = _run_one_csv_row_coo( - path, value_dtype, index_dtype, device, coo_mode, lower=lower - ) - else: - row, pt_skip = _run_one_csv_row_csr( - path, value_dtype, index_dtype, device, lower=lower - ) + row, pt_skip = _run_one_csv_row_coo( + path, value_dtype, index_dtype, device, coo_mode, lower=lower + ) rows_out.append(row) name = os.path.basename(path)[:27] if len(os.path.basename(path)) > 27: @@ -1453,13 +1437,30 @@ def main(): ) return if args.csv_coo: + if args.ops: + parser.error("--ops is only supported with --csv-csr; COO tests only run opA=NON") if not paths: paths = sorted(glob.glob("*.mtx")) if not paths: print("No .mtx files found for --csv-coo") return - run_all_dtypes_spsv_csv( - paths, args.csv_coo, use_coo=True, coo_mode=args.coo_mode, lower=lower + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + run_all_dtypes_spsv_coo_csv( + paths, + args.csv_coo, + coo_mode=args.coo_mode, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, ) return From 2b103136a247f98ec0b4d8d7a5d0f7c6335dcc9f Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 21 Apr 2026 15:35:59 +0800 Subject: [PATCH 5/5] Update spsv csr tests --- src/flagsparse/sparse_operations/spsv.py | 33 +- tests/test_spsv.py | 469 ++++++++++++++++++----- 2 files changed, 379 insertions(+), 123 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index 5f6174c..ffdb6ea 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -9,12 +9,10 @@ import triton.language as tl SUPPORTED_SPSV_VALUE_DTYPES = ( - torch.bfloat16, torch.float32, torch.float64, torch.complex64, torch.complex128, - ) SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64) SPSV_NON_TRANS_SUPPORTED_COMBOS = ( @@ -68,12 +66,10 @@ def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): """Validate NON_TRANS support matrix and keep error messages explicit.""" if (data_dtype, index_dtype) in SPSV_NON_TRANS_SUPPORTED_COMBOS: return - if data_dtype == torch.bfloat16 and index_dtype == torch.int32: - return raise TypeError( f"{fmt_name} SpSV currently supports NON_TRANS combinations: " "(float32, int32/int64), (float64, int32/int64), " - "(complex64, int32/int64), (complex128, int32/int64), (bfloat16, int32)" + "(complex64, int32/int64), (complex128, int32/int64)" ) @@ -126,7 +122,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise TypeError( - "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" + "data dtype must be one of: float32, float64, complex64, complex128" ) if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("indices dtype must be torch.int32 or torch.int64") @@ -696,14 +692,13 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") if data.dtype not in ( - torch.bfloat16, torch.float32, torch.float64, torch.complex64, torch.complex128, ): raise TypeError( - "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" + "data dtype must be one of: float32, float64, complex64, complex128" ) if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") @@ -796,13 +791,8 @@ def _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols): unique_key, inverse = torch.unique_consecutive(key_s, return_inverse=True) out_nnz = unique_key.numel() - if data_s.dtype == torch.bfloat16: - reduced_f32 = torch.zeros(out_nnz, dtype=torch.float32, device=data.device) - reduced_f32.scatter_add_(0, inverse, data_s.to(torch.float32)) - data_u = reduced_f32.to(torch.bfloat16) - else: - data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) - data_u.scatter_add_(0, inverse, data_s) + data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) + data_u.scatter_add_(0, inverse, data_s) row_u = torch.div(unique_key, max(1, n_cols), rounding_mode="floor") col_u = unique_key - row_u * max(1, n_cols) @@ -877,7 +867,6 @@ def flagsparse_spsv_csr( Primary support matrix: - NON_TRANS: float32/float64/complex64/complex128 with int32/int64 indices - TRANS/CONJ: float32/float64/complex64/complex128 with int32/int64 indices - - bfloat16 remains NON_TRANS + int32 """ input_data = data input_indices = indices @@ -926,11 +915,7 @@ def flagsparse_spsv_csr( compute_dtype = data.dtype data_in = kernel_data b_in = b - if data.dtype == torch.bfloat16: - compute_dtype = torch.float32 - data_in = kernel_data.to(torch.float32) - b_in = b.to(torch.float32) - elif data.dtype == torch.complex64 and trans_mode in ("T", "C"): + if data.dtype == torch.complex64 and trans_mode in ("T", "C"): compute_dtype = torch.complex128 data_in = kernel_data.to(torch.complex128) b_in = b.to(torch.complex128) @@ -1123,11 +1108,7 @@ def flagsparse_spsv_coo( compute_dtype = data.dtype data_in = data b_in = b - if data.dtype == torch.bfloat16: - compute_dtype = torch.float32 - data_in = data.to(torch.float32) - b_in = b.to(torch.float32) - elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: + if data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: compute_dtype = torch.float64 data_in = data.to(torch.float64) b_in = b.to(torch.float64) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index b64e6ff..cb28cb4 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -7,6 +7,7 @@ import argparse import csv import glob +import hashlib import os import sys from pathlib import Path @@ -19,6 +20,7 @@ sys.path.insert(0, str(_SRC_ROOT)) import flagsparse as fs +import flagsparse.sparse_operations.spsv as fs_spsv_impl try: import cupy as cp @@ -36,6 +38,7 @@ ITERS = 20 DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB +SPSV_TRIANGULAR_DIAG_DOMINANCE = 4.0 # CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) CSR_FULL_VALUE_DTYPES = [ torch.float32, @@ -111,12 +114,25 @@ def _tol_for_dtype(dtype): return 1e-12, 1e-10 -def _randn_by_dtype(n, dtype, device): +def _stable_case_seed(*parts): + raw = "|".join(str(part) for part in parts).encode("utf-8") + return int.from_bytes(hashlib.sha256(raw).digest()[:8], "little") % (2**63) + + +def _generator_for_seed(seed): + if seed is None: + return None + gen = torch.Generator() + gen.manual_seed(int(seed)) + return gen + + +def _randn_by_dtype(n, dtype, device, generator=None): if dtype in (torch.float32, torch.float64): - return torch.randn(n, dtype=dtype, device=device) + return torch.randn(n, dtype=dtype, device=device, generator=generator) base = torch.float32 if dtype == torch.complex64 else torch.float64 - real = torch.randn(n, dtype=base, device=device) - imag = torch.randn(n, dtype=base, device=device) + real = torch.randn(n, dtype=base, device=device, generator=generator) + imag = torch.randn(n, dtype=base, device=device, generator=generator) return torch.complex(real, imag) @@ -194,6 +210,8 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True rows_host = [] cols_host = [] vals_host = [] + row_off_abs = [0.0] * n + col_off_abs = [0.0] * n if value_dtype == torch.float32: base_real_dtype = torch.float32 elif value_dtype == torch.float64: @@ -223,28 +241,30 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), ) - sum_abs = ( - float(torch.sum(torch.abs(off_vals)).item()) if off_vals.numel() else 0.0 - ) - diag_imag = float( - torch.randn((), dtype=base_real_dtype, device=device).mul_(0.05).item() - ) - diag_val = complex(sum_abs + 1.0, diag_imag) off_vals_host = [complex(v) for v in off_vals.cpu().tolist()] else: off_vals = torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01) - sum_abs = ( - float(torch.sum(torch.abs(off_vals)).item()) if off_vals.numel() else 0.0 - ) - diag_val = sum_abs + 1.0 off_vals_host = off_vals.cpu().tolist() - rows_host.append(i) - cols_host.append(diag_col) - vals_host.append(diag_val) for c, v in zip(off_cols, off_vals_host): rows_host.append(i) cols_host.append(int(c)) vals_host.append(v) + mag = abs(v) + row_off_abs[i] += mag + col_off_abs[int(c)] += mag + + for i in range(n): + diag_mag = ( + SPSV_TRIANGULAR_DIAG_DOMINANCE * max(row_off_abs[i], col_off_abs[i]) + 1.0 + ) + diag_val = ( + complex(diag_mag, 0.0) + if value_dtype in (torch.complex64, torch.complex128) + else diag_mag + ) + rows_host.append(i) + cols_host.append(i) + vals_host.append(diag_val) rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device) cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device) @@ -367,17 +387,27 @@ def _accum(r, c, v): elif mm_symmetry == "skew-symmetric" and r != c: _accum(c, r, -v) + tri_rows = [dict() for _ in range(n_rows)] + row_off_abs = [0.0] * n_rows + col_off_abs = [0.0] * n_cols for r in range(n_rows): - row = row_maps[r] - tri_row = {} - off_abs_sum = 0.0 - for c, v in row.items(): + for c, v in row_maps[r].items(): keep = c < r if lower else c > r if keep: - tri_row[c] = tri_row.get(c, 0.0) + v - off_abs_sum += abs(v) - tri_row[r] = off_abs_sum + 1.0 - row_maps[r] = tri_row + tri_rows[r][c] = tri_rows[r].get(c, 0.0) + v + + for r, row in enumerate(tri_rows): + for c, v in row.items(): + mag = abs(v) + row_off_abs[r] += mag + col_off_abs[c] += mag + + for r in range(n_rows): + # Make the generated triangular system stable for both A and op(A). + tri_rows[r][r] = ( + SPSV_TRIANGULAR_DIAG_DOMINANCE * max(row_off_abs[r], col_off_abs[r]) + 1.0 + ) + row_maps = tri_rows cols_s = [] vals_s = [] @@ -415,58 +445,53 @@ def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode, index_dtype=torc return data_c, row_c, col_c -def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): +def _random_rhs_for_spsv(shape, value_dtype, device, op_mode="NON", seed=None): + n_rows, n_cols = int(shape[0]), int(shape[1]) + rhs_size = n_rows if op_mode == "NON" else n_cols + if seed is None: + return _randn_by_dtype(rhs_size, value_dtype, device) + rhs = _randn_by_dtype( + rhs_size, + value_dtype, + torch.device("cpu"), + generator=_generator_for_seed(seed), + ) + return rhs.to(device) + + +def _apply_csr_op(data, indices, indptr, x, shape, op_mode): + n_rows, n_cols = int(shape[0]), int(shape[1]) + row = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr.to(torch.int64)[1:] - indptr.to(torch.int64)[:-1], + ) + col = indices.to(torch.int64) if op_mode == "NON": - b, _ = fs.flagsparse_spmv_csr( - data, indices, indptr, x_true, shape, return_time=True - ) + b = torch.zeros(n_rows, dtype=data.dtype, device=data.device) + b.scatter_add_(0, row, data * x[col]) return b if op_mode == "TRANS": - data_t, indices_t, indptr_t = _csr_transpose(data, indices, indptr, shape) - b, _ = fs.flagsparse_spmv_csr( - data_t, - indices_t.to(indices.dtype), - indptr_t.to(indptr.dtype), - x_true, - (shape[1], shape[0]), - return_time=True, - ) + b = torch.zeros(n_cols, dtype=data.dtype, device=data.device) + b.scatter_add_(0, col, data * x[row]) return b if op_mode == "CONJ": - data_h, indices_h, indptr_h = _csr_transpose( - data, indices, indptr, shape, conjugate=True - ) - b, _ = fs.flagsparse_spmv_csr( - data_h, - indices_h.to(indices.dtype), - indptr_h.to(indptr.dtype), - x_true, - (shape[1], shape[0]), - return_time=True, - ) + b = torch.zeros(n_cols, dtype=data.dtype, device=data.device) + data_eff = data.conj() if torch.is_complex(data) else data + b.scatter_add_(0, col, data_eff * x[row]) return b raise ValueError("op_mode must be 'NON', 'TRANS', or 'CONJ'") -def _known_solution_metrics(data, indices, indptr, shape, x, x_true, b, value_dtype, op_mode): +def _solution_residual_metrics(data, indices, indptr, shape, x, b, value_dtype, op_mode): atol, rtol = _tol_for_dtype(value_dtype) - x_cmp = _compare_view(x, value_dtype) - x_true_cmp = _compare_view(x_true, value_dtype) - err_x = ( - float(torch.max(torch.abs(x_cmp - x_true_cmp)).item()) - if x.numel() > 0 - else 0.0 - ) - ok_x = torch.allclose(x_cmp, x_true_cmp, atol=atol, rtol=rtol) - - b_recon = _build_rhs_for_csr_op(data, indices, indptr, x, shape, op_mode) + b_recon = _apply_csr_op(data, indices, indptr, x, shape, op_mode) err_res = ( float(torch.max(torch.abs(b_recon - b)).item()) if b.numel() > 0 else 0.0 ) ok_res = torch.allclose(b_recon, b, atol=atol, rtol=rtol) - return err_x, ok_x, err_res, ok_res + return err_res, ok_res def _cupy_spsolve_lower_csr_or_coo( @@ -624,11 +649,22 @@ def run_spsv_synthetic_all(lower=True): A_dense = _csr_to_dense( data, indices.to(torch.int64), indptr, shape ) - x_true = _randn_by_dtype(n, value_dtype, device) - if fmt == "CSR": - b = _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode) - else: - b = A_dense @ x_true + rhs_op = op_mode if fmt == "CSR" else "NON" + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=rhs_op, + seed=_stable_case_seed( + "synthetic", + "LOWER" if lower else "UPPER", + fmt, + op_mode, + n, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) torch.cuda.synchronize() if fmt == "CSR": @@ -743,9 +779,19 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower ) indices = indices.to(index_dtype) n_rows, n_cols = shape - x_true = _randn_by_dtype(n_rows, value_dtype, device) - b, _ = fs.flagsparse_spmv_csr( - data, indices, indptr, x_true, shape, return_time=True + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode="NON", + seed=_stable_case_seed( + "csv-coo", + os.path.basename(path), + "LOWER" if lower else "UPPER", + coo_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), ) d_in, r_in, c_in = _coo_inputs_for_csv( data, indices, indptr, shape, coo_mode, index_dtype=index_dtype @@ -771,7 +817,6 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower x, t_ms, b, - x_true, n_rows, n_cols, lower=lower, @@ -793,7 +838,6 @@ def _finalize_csv_row( x, t_ms, b, - x_true, n_rows, n_cols, *, @@ -804,8 +848,8 @@ def _finalize_csv_row( cupy_coo_col=None, ): atol, rtol = _tol_for_dtype(value_dtype) - err_x, ok_x, err_res, ok_res = _known_solution_metrics( - data, indices, indptr, shape, x, x_true, b, value_dtype, "NON" + err_res, _ = _solution_residual_metrics( + data, indices, indptr, shape, x, b, value_dtype, "NON" ) pytorch_ms = None err_pt = None @@ -927,6 +971,8 @@ def _finalize_csv_row( status = "PASS" if (ok_pt or ok_cu) else "FAIL" if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): status = "REF_FAIL" + ref_errors = [err for err in (err_pt, err_cu) if err is not None] + err_ref = min(ref_errors) if ref_errors else None nnz_out = ( int(data.numel()) if nnz_display is None else int(nnz_display) @@ -943,7 +989,7 @@ def _finalize_csv_row( "cusparse_ms": cupy_ms, "csc_ms": None, "status": status, - "err_x": err_x, + "err_ref": err_ref, "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, @@ -958,8 +1004,20 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l indices = indices.to(index_dtype) indptr = indptr.to(index_dtype) n_rows, n_cols = shape - x_true = _randn_by_dtype(n_rows, value_dtype, device) - b = _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode) + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=op_mode, + seed=_stable_case_seed( + "csv-csr", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) x, t_ms = fs.flagsparse_spsv_csr( data, indices, @@ -982,7 +1040,6 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l x, t_ms, b, - x_true, n_rows, n_cols, lower=lower, @@ -1001,14 +1058,13 @@ def _finalize_csv_row_csr_full( x, t_ms, b, - x_true, n_rows, n_cols, lower=True, ): atol, rtol = _tol_for_dtype(value_dtype) - err_x, ok_x, err_res, ok_res = _known_solution_metrics( - data, indices, indptr, shape, x, x_true, b, value_dtype, op_mode + err_res, _ = _solution_residual_metrics( + data, indices, indptr, shape, x, b, value_dtype, op_mode ) pytorch_ms = None @@ -1068,6 +1124,8 @@ def _finalize_csv_row_csr_full( status = "PASS" if (ok_pt or ok_cu) else "FAIL" if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): status = "REF_FAIL" + ref_errors = [err for err in (err_pt, err_cu) if err is not None] + err_ref = min(ref_errors) if ref_errors else None row = { "matrix": os.path.basename(path), @@ -1082,7 +1140,7 @@ def _finalize_csv_row_csr_full( "cusparse_ms": cupy_ms, "csc_ms": None, "status": status, - "err_x": err_x, + "err_ref": err_ref, "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, @@ -1121,15 +1179,16 @@ def run_all_supported_spsv_csr_csv( "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." ) print( - "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, " + "RHS is generated directly, matching Library-main's SpSV test style. " + "Err(Ref)=best |FlagSparse-reference|, Err(Res)=|op(A)*x-b|, " "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " - "PASS if PyTorch / cuSPARSE reference passes. x_true / residual are diagnostics only." + "PASS if PyTorch / cuSPARSE reference passes. Residual is diagnostic only." ) print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: @@ -1146,14 +1205,14 @@ def run_all_supported_spsv_csr_csv( t_ms = row["triton_ms"] cupy_ms = row["cusparse_ms"] pytorch_ms = row["pytorch_ms"] - err_x, err_res = row["err_x"], row["err_res"] + err_ref, err_res = row["err_ref"], row["err_res"] err_pt, err_cu = row["err_pt"], row["err_cu"] status = row["status"] print( f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_x):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) if pt_skip: print(f" NOTE: {pt_skip}") @@ -1174,7 +1233,7 @@ def run_all_supported_spsv_csr_csv( "cusparse_ms": None, "csc_ms": None, "status": status, - "err_x": None, + "err_ref": None, "err_res": None, "err_pt": None, "err_cu": None, @@ -1204,7 +1263,7 @@ def run_all_supported_spsv_csr_csv( "cusparse_ms", "csc_ms", "status", - "err_x", + "err_ref", "err_res", "err_pt", "err_cu", @@ -1241,18 +1300,18 @@ def run_all_dtypes_spsv_coo_csv( ) print( "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. " - "b 由 CSR SpMV 构造,与 CSR 测试一致。" + "RHS is generated directly, matching Library-main's SpSV test style." ) print( - "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, " + "Err(Ref)=best |FlagSparse-reference|, Err(Res)=|A*x-b|, " "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " - "PASS if PyTorch / cuSPARSE reference passes. x_true / residual are diagnostics only." + "PASS if PyTorch / cuSPARSE reference passes. Residual is diagnostic only." ) print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " f"{'FlagSparse(ms)':>10} {'COO(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{'FS/COO':>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{'FS/COO':>7} {'FS/PT':>7} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: @@ -1269,14 +1328,14 @@ def run_all_dtypes_spsv_coo_csv( t_ms = row["triton_ms"] cupy_ms = row["cusparse_ms"] pytorch_ms = row["pytorch_ms"] - err_x, err_res = row["err_x"], row["err_res"] + err_ref, err_res = row["err_ref"], row["err_res"] err_pt, err_cu = row["err_pt"], row["err_cu"] status = row["status"] print( f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_x):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) if pt_skip: print(f" NOTE: {pt_skip}") @@ -1296,7 +1355,7 @@ def run_all_dtypes_spsv_coo_csv( "cusparse_ms": None, "csc_ms": None, "status": status, - "err_x": None, + "err_ref": None, "err_res": None, "err_pt": None, "err_cu": None, @@ -1324,7 +1383,7 @@ def run_all_dtypes_spsv_coo_csv( "cusparse_ms", "csc_ms", "status", - "err_x", + "err_ref", "err_res", "err_pt", "err_cu", @@ -1337,6 +1396,188 @@ def run_all_dtypes_spsv_coo_csv( print(f"Wrote {len(rows_out)} rows to {csv_path}") +def _check_one_csr_transpose_case(path, value_dtype, index_dtype, op_mode, device, lower=True): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device, lower=lower + ) + indices = indices.to(index_dtype) + indptr = indptr.to(index_dtype) + n_rows, n_cols = shape + trans_data, trans_indices64, trans_indptr64 = fs_spsv_impl._csr_transpose( + data, + indices.to(torch.int64), + indptr.to(torch.int64), + n_rows, + n_cols, + conjugate=(op_mode == "CONJ"), + ) + trans_shape = (n_cols, n_rows) + trans_indices = trans_indices64.to(index_dtype) + trans_indptr = trans_indptr64.to(index_dtype) + + probe = _random_rhs_for_spsv( + trans_shape, + value_dtype, + device, + op_mode="NON", + seed=_stable_case_seed( + "check-transpose-action", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + action_ref = _apply_csr_op(data, indices, indptr, probe, shape, op_mode) + action_trans = _apply_csr_op( + trans_data, trans_indices, trans_indptr, probe, trans_shape, "NON" + ) + action_err = ( + float(torch.max(torch.abs(action_trans - action_ref)).item()) + if action_ref.numel() > 0 + else 0.0 + ) + atol, rtol = _tol_for_dtype(value_dtype) + action_ok = torch.allclose(action_trans, action_ref, atol=atol, rtol=rtol) + + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=op_mode, + seed=_stable_case_seed( + "check-transpose-solve", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + x_op = fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=op_mode, + ) + x_mat = fs.flagsparse_spsv_csr( + trans_data, + trans_indices, + trans_indptr, + b, + trans_shape, + lower=not lower, + transpose="NON", + ) + solve_err = ( + float(torch.max(torch.abs(x_op - x_mat)).item()) if x_op.numel() > 0 else 0.0 + ) + solve_ok = torch.allclose(x_op, x_mat, atol=atol, rtol=rtol) + + ref_err = None + ref_ok = None + if _allow_dense_pytorch_ref(shape, value_dtype): + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(value_dtype)) + x_ref = _triangular_solve_reference( + A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode + ) + ref_err = ( + float(torch.max(torch.abs(x_op - x_ref)).item()) if x_op.numel() > 0 else 0.0 + ) + ref_ok = torch.allclose(x_op, x_ref, atol=atol, rtol=rtol) + + status = "PASS" if action_ok and solve_ok and (ref_ok is not False) else "FAIL" + return { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": n_rows, + "nnz": int(data.numel()), + "action_err": action_err, + "solve_err": solve_err, + "ref_err": ref_err, + "status": status, + } + + +def run_csr_transpose_check( + mtx_paths, + lower=True, + value_dtypes=None, + index_dtypes=None, + op_modes=None, +): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + selected_value_dtypes = value_dtypes or CSR_FULL_VALUE_DTYPES + selected_index_dtypes = index_dtypes or CSR_FULL_INDEX_DTYPES + selected_op_modes = [op for op in (op_modes or ("TRANS", "CONJ")) if op in ("TRANS", "CONJ")] + if not selected_op_modes: + print("--check-transpose only checks TRANS/CONJ; no matching op selected.") + return + + print("=" * 150) + print( + "CSR TRANS/CONJ preprocessing check: " + "ActionErr compares materialized op(A) against direct CSR scatter; " + "SolveErr compares transpose path against materialized NON path." + ) + print("-" * 150) + print( + f"{'Matrix':<28} {'dtype':>10} {'index':>7} {'opA':>5} " + f"{'N':>7} {'NNZ':>10} {'Status':>6} {'ActionErr':>10} {'SolveErr':>10} {'RefErr':>10}" + ) + print("-" * 150) + total = 0 + failed = 0 + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + for op_mode in selected_op_modes: + for path in mtx_paths: + try: + row = _check_one_csr_transpose_case( + path, + value_dtype, + index_dtype, + op_mode, + device, + lower=lower, + ) + total += 1 + failed += int(row["status"] != "PASS") + name = row["matrix"][:27] + if len(row["matrix"]) > 27: + name += "..." + print( + f"{name:<28} {row['value_dtype']:>10} {row['index_dtype']:>7} {row['opA']:>5} " + f"{row['n_rows']:>7} {row['nnz']:>10} {row['status']:>6} " + f"{_fmt_err(row['action_err']):>10} {_fmt_err(row['solve_err']):>10} {_fmt_err(row['ref_err']):>10}" + ) + except Exception as e: + total += 1 + failed += 1 + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name += "..." + print( + f"{name:<28} {_dtype_name(value_dtype):>10} {_dtype_name(index_dtype):>7} {op_mode:>5} " + f"{'ERR':>7} {'ERR':>10} {'ERROR':>6} " + f"{_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" ERROR: {e}") + print("-" * 150) + print(f"Total cases: {total} Failed: {failed}") + + def main(): parser = argparse.ArgumentParser( description="SpSV test: synthetic triangular systems and optional .mtx (CSR/COO), same baselines as CSR." @@ -1363,6 +1604,11 @@ def main(): metavar="FILE", help="Run all dtypes on .mtx (COO SpSV), same CSV columns as --csv-csr", ) + parser.add_argument( + "--check-transpose", + action="store_true", + help="Check CSR TRANS/CONJ preprocessing against direct CSR scatter and materialized NON solve", + ) parser.add_argument( "--coo-mode", type=str, @@ -1406,6 +1652,35 @@ def main(): paths.append(p) elif os.path.isdir(p): paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx")))) + if args.check_transpose: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found for --check-transpose") + return + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + op_modes = ( + _parse_op_modes_filter(args.ops) + if args.ops + else None + ) + run_csr_transpose_check( + paths, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, + op_modes=op_modes, + ) + return if args.csv_csr: if not paths: paths = sorted(glob.glob("*.mtx"))