From eabc46fa9563ce12665e62e32c09f5e0b0e20901 Mon Sep 17 00:00:00 2001 From: James Tocknell Date: Wed, 21 Nov 2018 12:24:46 +1100 Subject: [PATCH 1/3] ENH: Add some useful SUNDIALS user functions --- docs/solvers.rst | 14 +++++ scikits/odes/sundials/__init__.py | 88 +++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/docs/solvers.rst b/docs/solvers.rst index 55d9bbe4..3ad752c7 100644 --- a/docs/solvers.rst +++ b/docs/solvers.rst @@ -37,3 +37,17 @@ A comparison of different methods is given in following image. In this BDF, RK23 .. image:: ../ipython_examples/PerformanceTests.png You can generate above graph via the `Performance notebook `_. + +Solver Specific Options +####################### + +The high level interfaces allow the option of passing solver specific options to +the solvers. These options are covered in more detail in the `API docs `_, but some features specific to ``odes`` are mentioned below. + +SUNDIALS +======== + +There are a number of SUNDIALS specific utilities in :py:mod:`scikits.odes.sundials`. +Firstly there are :py:func:`scikits.odes.sundials.ontstop_stop`, :py:func:`scikits.odes.sundials.ontstop_continue`, :py:func:`scikits.odes.sundials.onroot_stop` and :py:func:`scikits.odes.sundials.onroot_continue`, which can be used with the `ontstop` or `onroot` options to either stop or continue evaluation when tstop or a root is encountered. +Secondly, there are functions which can be passed to the `err_handler` option to either stop all messages from SUNDIALS being printed (:py:func:`scikits.odes.sundials.drop_all_error_handler`), or to pass them to Python's logging machinery (:py:func:`scikits.odes.sundials.log_error_handler`). +Finally, the module contains the exceptions which can be caught in user code raised when using the `validate_flags` option. diff --git a/scikits/odes/sundials/__init__.py b/scikits/odes/sundials/__init__.py index fb1f36d3..fbd8744f 100644 --- a/scikits/odes/sundials/__init__.py +++ b/scikits/odes/sundials/__init__.py @@ -3,6 +3,11 @@ # import inspect +from logging import getLogger + +logger = getLogger(__name__) +DEFAULT_LOG_FORMAT = "SUNDIALS message in %s:%s: %s" + class CVODESolveException(Exception): """Base class for exceptions raised by ``CVODE.validate_flags``.""" @@ -64,3 +69,86 @@ def _get_num_args(func): return numargs else: return len(inspect.getargspec(func).args) + + +def drop_all_error_handler(error_code, module, func, msg, user_data): + """ + Drop all CVODE/IDA messages, rather than printing them. + + Examples + -------- + >>> scikits.odes.ode('cvode', rhsfuc, err_handler=drop_all_error_handler) + """ + # pylint: disable=unused-argument + pass + + +def log_error_handler(error_code, module, func, msg, user_data): + """ + Log all CVODE/IDA messages using the builtin python logging. + + Examples + -------- + >>> scikits.odes.ode('cvode', rhsfuc, err_handler=log_error_handler) + """ + # pylint: disable=unused-argument + if error_code > 0: + logger.warning(DEFAULT_LOG_FORMAT, module, func, msg) + else: + logger.error(DEFAULT_LOG_FORMAT, module, func, msg) + + +def onroot_continue(*args): + """ + Always continue after finding root. + + Examples + -------- + >>> scikits.odes.ode( + ... 'cvode', rhsfuc, rootfn=rootfn, nr_rootfns=nroots, + ... onroot=onroot_continue + ... ) + """ + # pylint: disable=unused-argument + return 0 + + +def onroot_stop(*args): + """ + Always stop after finding root. + + Examples + -------- + >>> scikits.odes.ode( + ... 'cvode', rhsfuc, rootfn=rootfn, nr_rootfns=nroots, + ... onroot=onroot_stop + ... ) + """ + # pylint: disable=unused-argument + return 1 + + +def ontstop_continue(*args): + """ + Always continue after finding tstop. + + Examples + -------- + >>> scikits.odes.ode( + ... 'cvode', rhsfuc, tstop=tstop, ontstop=ontstop_continue + ... ) + """ + # pylint: disable=unused-argument + return 0 + + +def ontstop_stop(*args): + """ + Always stop after finding tstop. + + Examples + -------- + >>> scikits.odes.ode('cvode', rhsfuc, tstop=tstop, ontstop=ontstop_stop) + """ + # pylint: disable=unused-argument + return 1 From 63fc4b3675f5efa88c2061a710c46d2d4794e905 Mon Sep 17 00:00:00 2001 From: James Tocknell Date: Sun, 26 May 2019 12:57:39 +1000 Subject: [PATCH 2/3] TST: Use on*_(stop,continue) and logging tools in tests --- MANIFEST.in | 2 +- pytest.ini | 5 ++ scikits/odes/tests/test_dae.py | 4 +- scikits/odes/tests/test_dop.py | 4 +- scikits/odes/tests/test_get_info.py | 12 +++-- scikits/odes/tests/test_odeint.py | 5 +- scikits/odes/tests/test_on_funcs.py | 52 ++++++++---------- scikits/odes/tests/test_on_funcs_ida.py | 52 ++++++++---------- .../odes/tests/test_user_return_vals_cvode.py | 54 ++++++++++--------- .../odes/tests/test_user_return_vals_ida.py | 42 ++++++++------- tox.ini | 1 - 11 files changed, 122 insertions(+), 111 deletions(-) create mode 100644 pytest.ini diff --git a/MANIFEST.in b/MANIFEST.in index f2920b70..ff9df26b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,7 +2,7 @@ include common.py setup_build.py recursive-include scikits *.pyx *.pxd *.pyf *.f include CONTRIBUTING.md README.md LICENSE.txt -include tox.ini +include tox.ini pytest.ini include pyproject.toml prune ci_support diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..3474cd54 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +log_cli = 1 +log_cli_level = INFO +log_cli_format = %(asctime)s %(levelname)s %(message)s +log_cli_date_format = %H:%M:%S diff --git a/scikits/odes/tests/test_dae.py b/scikits/odes/tests/test_dae.py index fb02ed35..2640ffaf 100644 --- a/scikits/odes/tests/test_dae.py +++ b/scikits/odes/tests/test_dae.py @@ -11,8 +11,8 @@ from numpy.testing import TestCase, run_module_suite from scipy.integrate import ode as Iode -from scikits.odes import ode,dae -from scikits.odes.sundials.common_defs import DTYPE +from .. import ode, dae +from ..sundials.common_defs import DTYPE class TestDae(TestCase): """ diff --git a/scikits/odes/tests/test_dop.py b/scikits/odes/tests/test_dop.py index 1dcb0fc8..0a0767d9 100644 --- a/scikits/odes/tests/test_dop.py +++ b/scikits/odes/tests/test_dop.py @@ -12,8 +12,8 @@ from numpy.testing import ( assert_, TestCase, run_module_suite, assert_array_almost_equal, assert_raises, assert_allclose, assert_array_equal, assert_equal) -from scikits.odes import ode -from scikits.odes.dopri5 import StatusEnumDOP +from .. import ode +from ..dopri5 import StatusEnumDOP class SimpleOscillator(): diff --git a/scikits/odes/tests/test_get_info.py b/scikits/odes/tests/test_get_info.py index 933de91d..bd926bdc 100644 --- a/scikits/odes/tests/test_get_info.py +++ b/scikits/odes/tests/test_get_info.py @@ -1,7 +1,13 @@ from __future__ import print_function import numpy as np import unittest -from scikits.odes import ode +from .. import ode +from ..sundials import log_error_handler + +COMMON_ARGS = { + "old_api": False, + "err_handler": log_error_handler +} xs = np.linspace(1, 10, 10) @@ -20,7 +26,7 @@ def rhs(x, y, ydot): class GetInfoTest(unittest.TestCase): def setUp(self): - self.ode = ode('cvode', rhs, old_api=False) + self.ode = ode('cvode', rhs, **COMMON_ARGS) self.solution = self.ode.solve(xs, np.array([1])) def test_we_integrated_correctly(self): @@ -47,7 +53,7 @@ def test_ode_exposes_num_rhs_evals(self): class GetInfoTestSpils(unittest.TestCase): def setUp(self): - self.ode = ode('cvode', rhs, linsolver="spgmr", old_api=False) + self.ode = ode('cvode', rhs, linsolver="spgmr", **COMMON_ARGS) self.solution = self.ode.solve(xs, np.array([1])) def test_ode_exposes_num_njtimes_evals(self): diff --git a/scikits/odes/tests/test_odeint.py b/scikits/odes/tests/test_odeint.py index 04174f61..c6202728 100644 --- a/scikits/odes/tests/test_odeint.py +++ b/scikits/odes/tests/test_odeint.py @@ -19,8 +19,9 @@ assert_, TestCase, run_module_suite, assert_array_almost_equal, assert_raises, assert_allclose, assert_array_equal, assert_equal) -from scikits.odes.odeint import odeint -from scikits.odes.sundials.common_defs import DTYPE +from ..odeint import odeint +from ..sundials import log_error_handler +from ..sundials.common_defs import DTYPE TEST_LAPACK = DTYPE == np.double diff --git a/scikits/odes/tests/test_on_funcs.py b/scikits/odes/tests/test_on_funcs.py index 326db18f..2eaa4a46 100644 --- a/scikits/odes/tests/test_on_funcs.py +++ b/scikits/odes/tests/test_on_funcs.py @@ -12,9 +12,15 @@ from numpy.testing import TestCase, run_module_suite -from scikits.odes import ode -from scikits.odes.sundials.cvode import StatusEnum -from scikits.odes.sundials.common_defs import DTYPE +from .. import ode +from ..sundials.cvode import StatusEnum +from ..sundials.common_defs import DTYPE +from ..sundials import log_error_handler, ontstop_stop, onroot_stop + +COMMON_ARGS = { + "old_api": False, + "err_handler": log_error_handler +} #data g = 9.81 # gravitational constant @@ -63,12 +69,6 @@ def onroot_va(t, y, solver): return 0 -def onroot_vb(t, y, solver): - """ - onroot function to stop solver when root is found - """ - return 1 - def onroot_vc(t, y, solver): """ onroot function to reset the solver back at the start, but keep the current @@ -103,12 +103,6 @@ def ontstop_va(t, y, solver): return 0 -def ontstop_vb(t, y, solver): - """ - ontstop function to stop solver when tstop is reached - """ - return 1 - def ontstop_vc(t, y, solver): """ ontstop function to reset the solver back at the start, but keep the current @@ -132,7 +126,7 @@ def test_cvode_rootfn_noroot(self): #test calling sequence. End is reached before root is found tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -143,7 +137,7 @@ def test_cvode_rootfn(self): #test root finding and stopping: End is reached at a root tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Root not found!" assert allclose([soln.roots.t[0], soln.roots.y[0,0], soln.roots.y[0,1]], @@ -155,7 +149,7 @@ def test_cvode_rootfnacc(self): tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn, onroot=onroot_va, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -170,8 +164,8 @@ def test_cvode_rootfn_stop(self): #test root finding and stopping: End is reached at a root with a function tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn, - onroot=onroot_vb, - old_api=False) + onroot=onroot_stop, + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Root not found!" assert allclose([soln.roots.t[-1], soln.roots.y[-1,0], soln.roots.y[-1,1]], @@ -183,7 +177,7 @@ def test_cvode_rootfn_test(self): tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn, onroot=onroot_vc, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Not sufficient root found" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -199,7 +193,7 @@ def test_cvode_rootfn_two(self): tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, nr_rootfns=2, rootfn=root_fn2, onroot=onroot_vc, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Not sufficient root found" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -215,7 +209,7 @@ def test_cvode_rootfn_end(self): tspan = np.arange(0, 30 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn3, onroot=onroot_vc, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Not sufficient root found" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -232,7 +226,7 @@ def test_cvode_tstopfn_notstop(self): n = 0 tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, tstop=T1+1, ontstop=ontstop_va, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred" @@ -246,7 +240,7 @@ def test_cvode_tstopfn(self): n = 0 tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, tstop=T1, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.TSTOP_RETURN, "ERROR: Tstop not found!" assert allclose([soln.tstop.t[0], soln.tstop.y[0,0], soln.tstop.y[0,1]], @@ -262,7 +256,7 @@ def test_cvode_tstopfnacc(self): n = 0 tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_va, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -278,8 +272,8 @@ def test_cvode_tstopfn_stop(self): global n n = 0 tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) - solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_vb, - old_api=False) + solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_stop, + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.TSTOP_RETURN, "ERROR: Error occurred" @@ -299,7 +293,7 @@ def test_cvode_tstopfn_test(self): n = 0 tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_vc, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0) assert soln.flag==StatusEnum.TSTOP_RETURN, "ERROR: Error occurred" diff --git a/scikits/odes/tests/test_on_funcs_ida.py b/scikits/odes/tests/test_on_funcs_ida.py index 7388428a..a9aee741 100644 --- a/scikits/odes/tests/test_on_funcs_ida.py +++ b/scikits/odes/tests/test_on_funcs_ida.py @@ -12,9 +12,15 @@ from numpy.testing import TestCase, run_module_suite -from scikits.odes import dae -from scikits.odes.sundials.ida import StatusEnumIDA -from scikits.odes.sundials.common_defs import DTYPE +from .. import dae +from ..sundials import log_error_handler, ontstop_stop, onroot_stop +from ..sundials.ida import StatusEnumIDA +from ..sundials.common_defs import DTYPE + +COMMON_ARGS = { + "old_api": False, + "err_handler": log_error_handler +} #data g = 9.81 # gravitational constant @@ -65,12 +71,6 @@ def onroot_va(t, y, ydot, solver): return 0 -def onroot_vb(t, y, ydot, solver): - """ - onroot function to stop solver when root is found - """ - return 1 - def onroot_vc(t, y, ydot, solver): """ onroot function to reset the solver back at the start, but keep the current @@ -105,12 +105,6 @@ def ontstop_va(t, y, ydot, solver): return 0 -def ontstop_vb(t, y, ydot, solver): - """ - ontstop function to stop solver when tstop is reached - """ - return 1 - def ontstop_vc(t, y, ydot, solver): """ ontstop function to reset the solver back at the start, but keep the current @@ -134,7 +128,7 @@ def test_ida_rootfn_noroot(self): #test calling sequence. End is reached before root is found tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -145,7 +139,7 @@ def test_ida_rootfn(self): #test root finding and stopping: End is reached at a root tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Root not found!" assert allclose([soln.roots.t[0], soln.roots.y[0,0], soln.roots.y[0,1]], @@ -157,7 +151,7 @@ def test_ida_rootfnacc(self): tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn, onroot=onroot_va, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -172,8 +166,8 @@ def test_ida_rootfn_stop(self): #test root finding and stopping: End is reached at a root with a function tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn, - onroot=onroot_vb, - old_api=False) + onroot=onroot_stop, + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Root not found!" assert allclose([soln.roots.t[-1], soln.roots.y[-1,0], soln.roots.y[-1,1]], @@ -185,7 +179,7 @@ def test_ida_rootfn_test(self): tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn, onroot=onroot_vc, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Not sufficient root found" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -201,7 +195,7 @@ def test_ida_rootfn_two(self): tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, nr_rootfns=2, rootfn=root_fn2, onroot=onroot_vc, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Not sufficient root found" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -217,7 +211,7 @@ def test_ida_rootfn_end(self): tspan = np.arange(0, 30 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn3, onroot=onroot_vc, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Not sufficient root found" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -234,7 +228,7 @@ def test_ida_tstopfn_notstop(self): n = 0 tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, tstop=T1+1, ontstop=ontstop_va, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -247,7 +241,7 @@ def test_ida_tstopfn(self): n = 0 tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, tstop=T1, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.TSTOP_RETURN, "ERROR: Tstop not found!" assert allclose([soln.tstop.t[0], soln.tstop.y[0,0], soln.tstop.y[0,1]], @@ -263,7 +257,7 @@ def test_ida_tstopfnacc(self): n = 0 tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_va, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred" assert allclose([soln.values.t[-1], soln.values.y[-1,0], soln.values.y[-1,1]], @@ -279,8 +273,8 @@ def test_ida_tstopfn_stop(self): global n n = 0 tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) - solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_vb, - old_api=False) + solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_stop, + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.TSTOP_RETURN, "ERROR: Error occurred" @@ -300,7 +294,7 @@ def test_ida_tstopfn_test(self): n = 0 tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE) solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_vc, - old_api=False) + **COMMON_ARGS) soln = solver.solve(tspan, y0, yp0) assert soln.flag==StatusEnumIDA.TSTOP_RETURN, "ERROR: Error occurred" diff --git a/scikits/odes/tests/test_user_return_vals_cvode.py b/scikits/odes/tests/test_user_return_vals_cvode.py index 8af4b0e9..afe44845 100644 --- a/scikits/odes/tests/test_user_return_vals_cvode.py +++ b/scikits/odes/tests/test_user_return_vals_cvode.py @@ -2,8 +2,14 @@ from numpy.testing import TestCase, run_module_suite from .. import ode +from ..sundials import log_error_handler from ..sundials.cvode import StatusEnum +COMMON_ARGS = { + "old_api": False, + "err_handler": log_error_handler +} + def normal_rhs(t, y, ydot): ydot[0] = t @@ -102,7 +108,7 @@ def jac_vec_error_immediate(v, Jv, t, y): class TestCVodeReturn(TestCase): def test_normal_rhs(self): - solver = ode("cvode", normal_rhs, old_api=False) + solver = ode("cvode", normal_rhs, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.SUCCESS, @@ -110,7 +116,7 @@ def test_normal_rhs(self): ) def test_rhs_with_return(self): - solver = ode("cvode", rhs_with_return, old_api=False) + solver = ode("cvode", rhs_with_return, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.SUCCESS, @@ -118,7 +124,7 @@ def test_rhs_with_return(self): ) def test_rhs_problem_late(self): - solver = ode("cvode", rhs_problem_late, old_api=False) + solver = ode("cvode", rhs_problem_late, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.TOO_MUCH_WORK, @@ -126,7 +132,7 @@ def test_rhs_problem_late(self): ) def test_rhs_problem_immediate(self): - solver = ode("cvode", rhs_problem_immediate, old_api=False) + solver = ode("cvode", rhs_problem_immediate, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.FIRST_RHSFUNC_ERR, @@ -134,7 +140,7 @@ def test_rhs_problem_immediate(self): ) def test_rhs_error_late(self): - solver = ode("cvode", rhs_error_late, old_api=False) + solver = ode("cvode", rhs_error_late, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.RHSFUNC_FAIL, @@ -142,7 +148,7 @@ def test_rhs_error_late(self): ) def test_rhs_error_immediate(self): - solver = ode("cvode", rhs_error_immediate, old_api=False) + solver = ode("cvode", rhs_error_immediate, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.RHSFUNC_FAIL, @@ -151,7 +157,7 @@ def test_rhs_error_immediate(self): def test_normal_root(self): solver = ode("cvode", normal_rhs, rootfn=normal_root, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.SUCCESS, @@ -160,7 +166,7 @@ def test_normal_root(self): def test_root_with_return(self): solver = ode("cvode", normal_rhs, rootfn=root_with_return, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.SUCCESS, @@ -169,7 +175,7 @@ def test_root_with_return(self): def test_root_late(self): solver = ode("cvode", normal_rhs, rootfn=root_late, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.ROOT_RETURN, @@ -178,7 +184,7 @@ def test_root_late(self): def test_root_immediate(self): solver = ode("cvode", normal_rhs, rootfn=root_immediate, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.SUCCESS, @@ -187,7 +193,7 @@ def test_root_immediate(self): def test_root_error_late(self): solver = ode("cvode", normal_rhs, rootfn=root_error_late, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.RTFUNC_FAIL, @@ -196,7 +202,7 @@ def test_root_error_late(self): def test_root_error_immediate(self): solver = ode("cvode", normal_rhs, rootfn=root_error_immediate, - nr_rootfns=1, old_api=False) + nr_rootfns=1, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.RTFUNC_FAIL, @@ -204,7 +210,7 @@ def test_root_error_immediate(self): ) def test_normal_jac(self): - solver = ode("cvode", normal_rhs, jacfn=normal_jac, old_api=False) + solver = ode("cvode", normal_rhs, jacfn=normal_jac, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.SUCCESS, @@ -212,7 +218,7 @@ def test_normal_jac(self): ) def test_jac_with_return(self): - solver = ode("cvode", normal_rhs, jacfn=jac_with_return, old_api=False) + solver = ode("cvode", normal_rhs, jacfn=jac_with_return, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.SUCCESS, @@ -220,7 +226,7 @@ def test_jac_with_return(self): ) def test_jac_problem_late(self): - solver = ode("cvode", complex_rhs, jacfn=jac_problem_late, old_api=False) + solver = ode("cvode", complex_rhs, jacfn=jac_problem_late, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.CONV_FAILURE, @@ -229,7 +235,7 @@ def test_jac_problem_late(self): def test_jac_problem_immediate(self): solver = ode("cvode", normal_rhs, jacfn=jac_problem_immediate, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.CONV_FAILURE, @@ -237,7 +243,7 @@ def test_jac_problem_immediate(self): ) def test_jac_error_late(self): - solver = ode("cvode", complex_rhs, jacfn=jac_error_late, old_api=False) + solver = ode("cvode", complex_rhs, jacfn=jac_error_late, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.LSETUP_FAIL, @@ -246,7 +252,7 @@ def test_jac_error_late(self): def test_jac_error_immediate(self): solver = ode("cvode", normal_rhs, jacfn=jac_error_immediate, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.LSETUP_FAIL, @@ -255,7 +261,7 @@ def test_jac_error_immediate(self): def test_normal_jac_vec(self): - solver = ode("cvode", normal_rhs, jac_times_vecfn=normal_jac_vec, old_api=False) + solver = ode("cvode", normal_rhs, jac_times_vecfn=normal_jac_vec, **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.SUCCESS, @@ -263,7 +269,7 @@ def test_normal_jac_vec(self): ) def test_jac_vec_with_return(self): - solver = ode("cvode", normal_rhs, jac_times_vecfn=jac_vec_with_return, linsolver="spgmr", old_api=False) + solver = ode("cvode", normal_rhs, jac_times_vecfn=jac_vec_with_return, linsolver="spgmr", **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.SUCCESS, @@ -271,7 +277,7 @@ def test_jac_vec_with_return(self): ) def test_jac_vec_problem_late(self): - solver = ode("cvode", complex_rhs, jac_times_vecfn=jac_vec_problem_late, linsolver="spgmr", old_api=False) + solver = ode("cvode", complex_rhs, jac_times_vecfn=jac_vec_problem_late, linsolver="spgmr", **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.TOO_MUCH_WORK, @@ -281,7 +287,7 @@ def test_jac_vec_problem_late(self): def test_jac_vec_problem_immediate(self): solver = ode("cvode", normal_rhs, jac_times_vecfn=jac_vec_problem_immediate, linsolver="spgmr", - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.TOO_MUCH_WORK, @@ -290,7 +296,7 @@ def test_jac_vec_problem_immediate(self): def test_jac_vec_error_late(self): solver = ode("cvode", complex_rhs, jac_times_vecfn=jac_vec_error_late, - linsolver="spgmr", old_api=False) + linsolver="spgmr", **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.LSOLVE_FAIL, @@ -299,7 +305,7 @@ def test_jac_vec_error_late(self): def test_jac_vec_error_immediate(self): solver = ode("cvode", normal_rhs, jac_times_vecfn=jac_vec_error_immediate, linsolver="spgmr", - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1]) self.assertEqual( StatusEnum.LSOLVE_FAIL, diff --git a/scikits/odes/tests/test_user_return_vals_ida.py b/scikits/odes/tests/test_user_return_vals_ida.py index 3fc9fd85..a51343be 100644 --- a/scikits/odes/tests/test_user_return_vals_ida.py +++ b/scikits/odes/tests/test_user_return_vals_ida.py @@ -2,8 +2,14 @@ from numpy.testing import TestCase, run_module_suite from .. import dae +from ..sundials import log_error_handler from ..sundials.ida import StatusEnumIDA +COMMON_ARGS = { + "old_api": False, + "err_handler": log_error_handler +} + def normal_rhs(t, y, ydot, res): res[0] = ydot - t @@ -78,7 +84,7 @@ def jac_error_immediate(t, y, ydot, residual, cj, J): class TestIdaReturn(TestCase): def test_normal_rhs(self): - solver = dae("ida", normal_rhs, old_api=False) + solver = dae("ida", normal_rhs, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.SUCCESS, @@ -86,7 +92,7 @@ def test_normal_rhs(self): ) def test_rhs_with_return(self): - solver = dae("ida", rhs_with_return, old_api=False) + solver = dae("ida", rhs_with_return, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.SUCCESS, @@ -94,7 +100,7 @@ def test_rhs_with_return(self): ) def test_rhs_problem_late(self): - solver = dae("ida", rhs_problem_late, old_api=False) + solver = dae("ida", rhs_problem_late, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.TOO_MUCH_WORK, @@ -102,7 +108,7 @@ def test_rhs_problem_late(self): ) def test_rhs_problem_immediate(self): - solver = dae("ida", rhs_problem_immediate, old_api=False) + solver = dae("ida", rhs_problem_immediate, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.REP_RES_ERR, @@ -110,7 +116,7 @@ def test_rhs_problem_immediate(self): ) def test_rhs_error_late(self): - solver = dae("ida", rhs_error_late, old_api=False) + solver = dae("ida", rhs_error_late, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.RES_FAIL, @@ -118,7 +124,7 @@ def test_rhs_error_late(self): ) def test_rhs_error_immediate(self): - solver = dae("ida", rhs_error_immediate, old_api=False) + solver = dae("ida", rhs_error_immediate, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.RES_FAIL, @@ -127,7 +133,7 @@ def test_rhs_error_immediate(self): def test_normal_root(self): solver = dae("ida", normal_rhs, rootfn=normal_root, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.SUCCESS, @@ -136,7 +142,7 @@ def test_normal_root(self): def test_root_with_return(self): solver = dae("ida", normal_rhs, rootfn=root_with_return, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.SUCCESS, @@ -145,7 +151,7 @@ def test_root_with_return(self): def test_root_late(self): solver = dae("ida", normal_rhs, rootfn=root_late, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.ROOT_RETURN, @@ -154,7 +160,7 @@ def test_root_late(self): def test_root_immediate(self): solver = dae("ida", normal_rhs, rootfn=root_immediate, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.SUCCESS, @@ -163,7 +169,7 @@ def test_root_immediate(self): def test_root_error_late(self): solver = dae("ida", normal_rhs, rootfn=root_error_late, nr_rootfns=1, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.RTFUNC_FAIL, @@ -172,7 +178,7 @@ def test_root_error_late(self): def test_root_error_immediate(self): solver = dae("ida", normal_rhs, rootfn=root_error_immediate, - nr_rootfns=1, old_api=False) + nr_rootfns=1, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.RTFUNC_FAIL, @@ -180,7 +186,7 @@ def test_root_error_immediate(self): ) def test_normal_jac(self): - solver = dae("ida", normal_rhs, jacfn=normal_jac, old_api=False) + solver = dae("ida", normal_rhs, jacfn=normal_jac, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.SUCCESS, @@ -188,7 +194,7 @@ def test_normal_jac(self): ) def test_jac_with_return(self): - solver = dae("ida", normal_rhs, jacfn=jac_with_return, old_api=False) + solver = dae("ida", normal_rhs, jacfn=jac_with_return, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.SUCCESS, @@ -196,7 +202,7 @@ def test_jac_with_return(self): ) def test_jac_problem_late(self): - solver = dae("ida", complex_rhs, jacfn=jac_problem_late, old_api=False) + solver = dae("ida", complex_rhs, jacfn=jac_problem_late, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.CONV_FAIL, @@ -205,7 +211,7 @@ def test_jac_problem_late(self): def test_jac_problem_immediate(self): solver = dae("ida", normal_rhs, jacfn=jac_problem_immediate, - old_api=False) + **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.CONV_FAIL, @@ -213,7 +219,7 @@ def test_jac_problem_immediate(self): ) def test_jac_error_late(self): - solver = dae("ida", complex_rhs, jacfn=jac_error_late, old_api=False) + solver = dae("ida", complex_rhs, jacfn=jac_error_late, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.LSETUP_FAIL, @@ -221,7 +227,7 @@ def test_jac_error_late(self): ) def test_jac_error_immediate(self): - solver = dae("ida", normal_rhs, jacfn=jac_error_immediate, old_api=False) + solver = dae("ida", normal_rhs, jacfn=jac_error_immediate, **COMMON_ARGS) soln = solver.solve([0, 1], [1], [0]) self.assertEqual( StatusEnumIDA.LSETUP_FAIL, diff --git a/tox.ini b/tox.ini index 9b3efa9b..e807c01a 100644 --- a/tox.ini +++ b/tox.ini @@ -14,7 +14,6 @@ passenv= deps = numpy cython - nose pytest wheel commands = From 3e2855cecc4dc0487fa3aa1bd7383869c2ec33b6 Mon Sep 17 00:00:00 2001 From: James Tocknell Date: Sun, 26 May 2019 18:21:09 +1000 Subject: [PATCH 3/3] DOC: Use on(root,tstop)_stop in example --- docs/examples/ode/freefall.py | 13 ++++--------- docs/examples/ode/freefall_tstop.py | 13 ++++--------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/docs/examples/ode/freefall.py b/docs/examples/ode/freefall.py index 8746b49f..373a911b 100644 --- a/docs/examples/ode/freefall.py +++ b/docs/examples/ode/freefall.py @@ -11,6 +11,7 @@ from __future__ import print_function import numpy as np from scikits.odes import ode +from scikits.odes.sundials import onroot_stop #data g = 9.81 # gravitational constant @@ -31,7 +32,7 @@ # On the other hand experiments 1 and 3 don't use the 'onroot', # experiments 2 and 4 do and compute until the time t_end is reached # (function onroot_va()). Experiment 5 stops after the first interruption -# (function onroot_vb()) occurs, whereas experiment 6 stops after the +# (function onroot_stop()) occurs, whereas experiment 6 stops after the # first interruption at time t>28 (s) (function onroot_vc()). # Otherwise all experiments are the same. @@ -66,12 +67,6 @@ def onroot_va(t, y, solver): return 0 -def onroot_vb(t, y, solver): - """ - onroot function to stop solver when root is found - """ - return 1 - def onroot_vc(t, y, solver): """ onroot function to reset the solver back at the start, but keep the current @@ -157,10 +152,10 @@ def print_results(experiment_no, result, require_no_roots=False): print_results(4, solver.solve(tspan, y0)) # -# 5. Solve the problem with onroot function onroot_vb, which behaves similarly +# 5. Solve the problem with onroot function onroot_stop, which behaves similarly # to the default, which is to compute until a root is found. # -solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn, onroot=onroot_vb, old_api=False) +solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn, onroot=onroot_stop, old_api=False) print_results(5, solver.solve(tspan, y0)) # diff --git a/docs/examples/ode/freefall_tstop.py b/docs/examples/ode/freefall_tstop.py index 4066ab48..13551b4a 100644 --- a/docs/examples/ode/freefall_tstop.py +++ b/docs/examples/ode/freefall_tstop.py @@ -11,6 +11,7 @@ from __future__ import print_function import numpy as np from scikits.odes import ode +from scikits.odes.sundials import ontstop_stop #data g = 9.81 # gravitational constant @@ -32,7 +33,7 @@ # On the other hand experiments 1 and 3 don't use the 'ontstop', # experiments 2 and 4 do and compute until the time t_end is reached # (function ontstop_va()). Experiment 5 stops after the first interruption -# (function ontstop_vb()) occurs, whereas experiment 6 stops after the +# (function ontstop_stop()) occurs, whereas experiment 6 stops after the # first interruption at time t>28 (s) (function ontstop_vc()). # Otherwise all experiments are the same. @@ -54,12 +55,6 @@ def ontstop_va(t, y, solver): return 0 -def ontstop_vb(t, y, solver): - """ - ontstop function to stop solver when tstop is found - """ - return 1 - def ontstop_vc(t, y, solver): """ ontstop function to reset the solver back at the start, but keep the current @@ -155,11 +150,11 @@ def print_results(experiment_no, result, require_no_tstop=False): print_results(4, solver.solve(tspan, y0)) # -# 5. Solve the problem with ontstop function ontstop_vb, which behaves similarly +# 5. Solve the problem with ontstop function ontstop_stop, which behaves similarly # to the default, which is to compute until a root is found. # n = 0 -solver = ode('cvode', rhs_fn, tstop=Y1, ontstop=ontstop_vb, old_api=False) +solver = ode('cvode', rhs_fn, tstop=Y1, ontstop=ontstop_stop, old_api=False) print_results(5, solver.solve(tspan, y0)) #