Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions erfa_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,9 @@ def __init__(self, name, t_erfa_c, nin, ninout, nout):
self.nin = nin
self.ninout = ninout
self.nout = nout
# Dict of dtypes for variables, filled by define_arrays().
self.var_dtypes = {}
self.dt_pv_vars: Final = frozenset(
re.findall(r"(\w+)\[2\]\[3\]", search.group(1))
)

@classmethod
def from_function(cls, func, t_erfa_c):
Expand Down Expand Up @@ -429,7 +430,6 @@ def define_arrays(self, line):
v_dtype = v.dtype
v_shape = v.shape if v.signature_shape != '()' else '()'
extra = ""
self.var_dtypes[name] = v_dtype
v_dtype = "float" if v_dtype == "dt_double" else "erfa_ufunc." + v_dtype
defines.append(f"{name} = np.empty({v_shape}, {v_dtype}){extra}")

Expand Down Expand Up @@ -540,8 +540,7 @@ def to_python(self):
# that were not caught by the general replacement above (e.g.,
# with names not equal to 'pv')
name, _, rest = line.partition('[')
if (rest and rest[0] in '01' and name in self.var_dtypes
and self.var_dtypes[name] == 'dt_pv'):
if name in self.dt_pv_vars and rest.startswith(("0", "1")):
line = name + "[" + ("'p'" if rest[0] == "0" else "'v'") + rest[1:]

out.append(line)
Expand Down
Loading