diff --git a/erfa/tests/test_ufunc.py.templ b/erfa/tests/test_ufunc.py.templ index d38e0bf..95c8c78 100644 --- a/erfa/tests/test_ufunc.py.templ +++ b/erfa/tests/test_ufunc.py.templ @@ -11,8 +11,6 @@ Basic tests of the ERFA library routines. These are just the tests bundled with ERFA itself, in ``t_erfa_c.c``, but translated to python to make sure that the code returns the same numbers. -The viv and vid functions emulate the corresponding functions in ``t_erfa_c.c``. - The tests are skipped if a system library is used that does not match in version number, since in that case the precise numbers may have changed (e.g., due to the earth orientation changes between 1.7.2 and 1.7.3). @@ -30,20 +28,6 @@ if not erfa.__version__.startswith(erfa.version.erfa_version): allow_module_level=True) -def viv(ival, ivalok, func, test, _): - """Validate an integer result.""" - assert ival == ivalok, f"{func} failed: {test} want {ivalok} got {ival}" - - -def vvd(val, valok, dval, func, test, _): - """Validate a double result.""" - a = val - valok - assert a == 0.0 or abs(a) <= abs(dval), ( - f"{func} failed: {test} want {valok:.20g} got {val:.20g} " - f"(1/{abs(valok / a):.3g})") - - -status = np.zeros((), dtype=int) # <--------------------------Actual test-wrapping code------------------------> {%- for test in test_funcs %} @@ -58,8 +42,4 @@ def test_{{ test.name }}(): {%- endfor %} - - -def test_status(): - assert status == 0, "Sanity check failed!" {# done! (note: this comment also ensures final new line!) #} diff --git a/erfa_generator.py b/erfa_generator.py index e2272bd..bad8386 100644 --- a/erfa_generator.py +++ b/erfa_generator.py @@ -478,12 +478,14 @@ def to_python(self): .replace("s, '+'", "s[0], b'+'") # Rather hacky... .strip()) - # Call of test function vvi or vvd. - if line.startswith('v'): - line = line.replace(era_name, self.name) - # Can call simple functions directly. Those need little modification. - if self.name + '(' in line: - line = line.replace(self.name + '(', f"erfa_ufunc.{self.name}(") + if m := re.match(r"viv ?\( ?([\w\[\]]+), +(.+?),", line): + line = f"assert {m.group(1)} == {m.group(2)}" + + elif m := re.match( + r"vvd\( ?(.+) ?, +([\d\.e-]+), *([\d\.e-]+), .+?, .+?, +status\)", line + ): + expr = m.group(1).replace(era_name, f"erfa_ufunc.{self.name}") + line = f"assert {expr} == pytest.approx({m.group(2)}, abs={m.group(3)})" # Call of function that is being tested. elif era_name in line: