diff --git a/erfa_generator.py b/erfa_generator.py index 94bc80a..3c1315d 100644 --- a/erfa_generator.py +++ b/erfa_generator.py @@ -354,29 +354,16 @@ def __init__(self, func: Function, t_erfa_c: str) -> None: # Get lines that test the given erfa function: capture everything # between a line starting with '{' after the test function definition # and the first line starting with '}' or ' }'. - pattern = rf"^static void t_{func.pyname}\(" + r".+?(^\{.+?^\s?\})" - search = re.search(pattern, t_erfa_c, flags=re.DOTALL | re.MULTILINE) - self.lines = search.group(1).split('\n') - self.dt_pv_vars: Final = frozenset( - re.findall(r"(\w+)\[2\]\[3\]", search.group(1)) + search = re.search( + rf"^static void t_{func.pyname}\(" + r".+?^\{(.+?)^\s?\}", + t_erfa_c, + re.DOTALL | re.MULTILINE, ) - - def pre_process_lines(self): - """Basic pre-processing. - - Combine multi-part lines, strip braces, semi-colons, empty lines. - """ - lines = [] - line = '' - for part in self.lines: - part = part.strip() - if part in ('', '{', '}'): - continue - line += part + ' ' - if part.endswith(';'): - lines.append(line.strip()[:-1]) - line = '' - return lines + if search is None: + raise RuntimeError(f"cannot find the test for {func.name}") + source = re.sub(r"\s\s+", " ", search.group(1)) + self.lines: Final = re.findall(r"\s(.*?);", source, re.DOTALL) + self.dt_pv_vars: Final = frozenset(re.findall(r"(\w+)\[2\]\[3\]", source)) def define_arrays(self, line): """Check variable definition line for items also needed in python. @@ -417,10 +404,8 @@ def to_python(self): # TODO: this is quite hacky right now! Would be good to let function # calls be understood by the Function class. - # Collect actual code lines, without ";", braces, etc. - lines = self.pre_process_lines() out = [] - for line in lines: + for line in self.lines: # In ldn ufunc, the number of bodies is inferred from the array size, # so no need to keep the definition. if line == "n = 3" and self.func.pyname == "ldn": @@ -453,11 +438,11 @@ def to_python(self): .replace("s, '+'", "s[0], b'+'") # Rather hacky... .strip()) - if m := re.match(r"viv ?\( ?([\w\[\]]+), +(.+?),", line): + if m := re.match(r"viv ?\( ?([\w\[\]]+), (.+?),", line): line = f"assert {m.group(1)} == {m.group(2)}" elif m := re.match( - r"vvd\( ?(.+) ?, +([\d\.e-]+), *([\d\.e-]+), .+?, .+?, +status\)", line + r"vvd\( ?(.+) ?, ([\d\.e-]+), ?([\d\.e-]+), .+?, .+?, status\)", line ): expr = m.group(1).replace( self.func.name, f"erfa_ufunc.{self.func.pyname}" @@ -509,8 +494,6 @@ def to_python(self): # Input number setting. elif '=' in line: - # Small clean-up. - line = line.replace('= ', '= ') # Hack to make astrom element assignment work. if line.startswith('astrom'): out.append('astrom = np.zeros((), erfa_ufunc.dt_eraASTROM).view(np.recarray)')