From e22c4f2052ce46f905f8fea1f00f93236a234387 Mon Sep 17 00:00:00 2001 From: "Corey R. Randall" Date: Tue, 18 Mar 2025 12:30:03 -0600 Subject: [PATCH 1/2] Fix typos in IDA iterative solvers, add tests --- .../src/scikits_odes_sundials/ida.pyx | 8 +-- .../tests/test_iterative.py | 50 +++++++++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 packages/scikits-odes-sundials/src/scikits_odes_sundials/tests/test_iterative.py diff --git a/packages/scikits-odes-sundials/src/scikits_odes_sundials/ida.pyx b/packages/scikits-odes-sundials/src/scikits_odes_sundials/ida.pyx index f63f10b..4eafc2c 100644 --- a/packages/scikits-odes-sundials/src/scikits_odes_sundials/ida.pyx +++ b/packages/scikits-odes-sundials/src/scikits_odes_sundials/ida.pyx @@ -540,7 +540,7 @@ cdef int _prec_solvefn(sunrealtype tt, N_Vector yy, N_Vector yp, N_Vector r, yp_tmp = aux_data.yp_tmp residual_tmp = aux_data.residual_tmp - if aux_data.r_vec is None: + if aux_data.rvec_tmp is None: N = len(yy_tmp) aux_data.rvec_tmp = np.empty(N, DTYPE) @@ -557,7 +557,7 @@ cdef int _prec_solvefn(sunrealtype tt, N_Vector yy, N_Vector yp, N_Vector r, nv_s2ndarray(rvec, rvec_tmp) nv_s2ndarray(z, z_tmp) - user_flag = aux_data.prec_solvefn.evaluate(tt, yy_tmp, residual_tmp, + user_flag = aux_data.prec_solvefn.evaluate(tt, yy_tmp, yp_tmp, residual_tmp, rvec_tmp, z_tmp, cj, delta, aux_data.user_data) @@ -878,7 +878,7 @@ cdef class IDA(BaseSundialsSolver): Absolute tolerancy 'linsolver': Values: 'dense' (= default), 'lapackdense', 'band', - 'lapackband', 'spgmr', 'spbcg', 'sptfqmr' + 'lapackband', 'spgmr', 'spbcgs', 'sptfqmr' Description: Specifies used linear solver. Limitations: Linear solvers for dense and band matrices can @@ -898,7 +898,7 @@ cdef class IDA(BaseSundialsSolver): Values: 0 (= default), 1, 2, 3, 4, 5 Description: Dimension of the number of used Krylov subspaces - (used only by 'spgmr', 'spbcg', 'sptfqmr' linsolvers) + (used only by 'spgmr', 'spbcgs', 'sptfqmr' linsolvers) 'tstop': Values: float, 0.0 = default Description: diff --git a/packages/scikits-odes-sundials/src/scikits_odes_sundials/tests/test_iterative.py b/packages/scikits-odes-sundials/src/scikits_odes_sundials/tests/test_iterative.py new file mode 100644 index 0000000..9c395a8 --- /dev/null +++ b/packages/scikits-odes-sundials/src/scikits_odes_sundials/tests/test_iterative.py @@ -0,0 +1,50 @@ +import pytest +import numpy as np + +from scikits_odes_sundials.ida import IDA + + +def resfn(t, y, yp, res, user_data): + res[0] = yp[0] + 0.04*y[0] - 1e4*y[1]*y[2] + res[1] = yp[1] - 0.04*y[0] + 1e4*y[1]*y[2] + 3e7*y[1]**2 + res[2] = y[0] + y[1] + y[2] - 1 + + +def jacfn(t, y, yp, res, cj, JJ, user_data): + JJ[0,0] = 0.04 + cj + JJ[0,1] = -1e4*y[2] + JJ[0,2] = -1e4*y[1] + JJ[1,0] = -0.04 + JJ[1,1] = 1e4*y[2] + 6e7*y[1] + cj + JJ[1,2] = 1e4*y[1] + JJ[2,0] = 1 + JJ[2,1] = 1 + JJ[2,2] = 1 + + +def prec_setupfn(t, y, yp, res, cj, user_data): + P = user_data['precond'] + jacfn(t, y, yp, res, cj, P, user_data) + + +def prec_solvefn(t, y, yp, res, rvec, zvec, cj, delta, user_data): + P = user_data['precond'] + zvec[:] = np.linalg.solve(P, rvec) + + +@pytest.mark.parametrize('linsolver', ('spgmr', 'spbcgs', 'sptfqmr')) +def test_iterative_solvers(linsolver): + tspan = np.logspace(-6, 6, 50) + y0 = np.array([1, 0, 0]) + yp0 = np.zeros_like(y0) + + user_data = {'precond': np.zeros((y0.size, y0.size))} + + solver = IDA(resfn, algebraic_vars_idx=[2], compute_initcond='yp0', + atol=1e-8, linsolver=linsolver, precond_type='left', + prec_setupfn=prec_setupfn, prec_solvefn=prec_solvefn, + user_data=user_data) + + soln = solver.solve(tspan, y0, yp0) + assert soln.flag == 0 + \ No newline at end of file From 4e7045b32a3bc53bb2ee88df9c4674e3484a3232 Mon Sep 17 00:00:00 2001 From: "Corey R. Randall" Date: Tue, 18 Mar 2025 13:00:47 -0600 Subject: [PATCH 2/2] Skip iterative tests in extended precision b/c np.linalg.solve does not support longdouble --- .../src/scikits_odes_sundials/tests/test_iterative.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/packages/scikits-odes-sundials/src/scikits_odes_sundials/tests/test_iterative.py b/packages/scikits-odes-sundials/src/scikits_odes_sundials/tests/test_iterative.py index 9c395a8..17d4104 100644 --- a/packages/scikits-odes-sundials/src/scikits_odes_sundials/tests/test_iterative.py +++ b/packages/scikits-odes-sundials/src/scikits_odes_sundials/tests/test_iterative.py @@ -2,6 +2,7 @@ import numpy as np from scikits_odes_sundials.ida import IDA +from scikits_odes_sundials.common_defs import DTYPE def resfn(t, y, yp, res, user_data): @@ -45,6 +46,10 @@ def test_iterative_solvers(linsolver): prec_setupfn=prec_setupfn, prec_solvefn=prec_solvefn, user_data=user_data) - soln = solver.solve(tspan, y0, yp0) - assert soln.flag == 0 + # np.linalg.solve does not support extended precision + if DTYPE == np.longdouble: + pass + else: + soln = solver.solve(tspan, y0, yp0) + assert soln.flag == 0 \ No newline at end of file