Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 12 additions & 29 deletions erfa_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,29 +354,16 @@ def __init__(self, func: Function, t_erfa_c: str) -> None:
# Get lines that test the given erfa function: capture everything
# between a line starting with '{' after the test function definition
# and the first line starting with '}' or ' }'.
pattern = rf"^static void t_{func.pyname}\(" + r".+?(^\{.+?^\s?\})"
search = re.search(pattern, t_erfa_c, flags=re.DOTALL | re.MULTILINE)
self.lines = search.group(1).split('\n')
self.dt_pv_vars: Final = frozenset(
re.findall(r"(\w+)\[2\]\[3\]", search.group(1))
search = re.search(
rf"^static void t_{func.pyname}\(" + r".+?^\{(.+?)^\s?\}",
t_erfa_c,
re.DOTALL | re.MULTILINE,
)

def pre_process_lines(self):
"""Basic pre-processing.

Combine multi-part lines, strip braces, semi-colons, empty lines.
"""
lines = []
line = ''
for part in self.lines:
part = part.strip()
if part in ('', '{', '}'):
continue
line += part + ' '
if part.endswith(';'):
lines.append(line.strip()[:-1])
line = ''
return lines
if search is None:
raise RuntimeError(f"cannot find the test for {func.name}")
source = re.sub(r"\s\s+", " ", search.group(1))
self.lines: Final = re.findall(r"\s(.*?);", source, re.DOTALL)
self.dt_pv_vars: Final = frozenset(re.findall(r"(\w+)\[2\]\[3\]", source))

def define_arrays(self, line):
"""Check variable definition line for items also needed in python.
Expand Down Expand Up @@ -417,10 +404,8 @@ def to_python(self):
# TODO: this is quite hacky right now! Would be good to let function
# calls be understood by the Function class.

# Collect actual code lines, without ";", braces, etc.
lines = self.pre_process_lines()
out = []
for line in lines:
for line in self.lines:
# In ldn ufunc, the number of bodies is inferred from the array size,
# so no need to keep the definition.
if line == "n = 3" and self.func.pyname == "ldn":
Expand Down Expand Up @@ -453,11 +438,11 @@ def to_python(self):
.replace("s, '+'", "s[0], b'+'") # Rather hacky...
.strip())

if m := re.match(r"viv ?\( ?([\w\[\]]+), +(.+?),", line):
if m := re.match(r"viv ?\( ?([\w\[\]]+), (.+?),", line):
line = f"assert {m.group(1)} == {m.group(2)}"

elif m := re.match(
r"vvd\( ?(.+) ?, +([\d\.e-]+), *([\d\.e-]+), .+?, .+?, +status\)", line
r"vvd\( ?(.+) ?, ([\d\.e-]+), ?([\d\.e-]+), .+?, .+?, status\)", line
):
expr = m.group(1).replace(
self.func.name, f"erfa_ufunc.{self.func.pyname}"
Expand Down Expand Up @@ -509,8 +494,6 @@ def to_python(self):

# Input number setting.
elif '=' in line:
# Small clean-up.
line = line.replace('= ', '= ')
# Hack to make astrom element assignment work.
if line.startswith('astrom'):
out.append('astrom = np.zeros((), erfa_ufunc.dt_eraASTROM).view(np.recarray)')
Expand Down
Loading