From 5fa4c21d9bfb5c15d4f2da7a1e91e28a31ddae67 Mon Sep 17 00:00:00 2001 From: Max Date: Sun, 11 Jan 2026 20:55:34 +0100 Subject: [PATCH 1/9] Added generator for params_Maxwell.py --- ast_helpers.py | 47 +++++++++++++ generate_maxwell_ast.py | 147 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+) create mode 100644 ast_helpers.py create mode 100644 generate_maxwell_ast.py diff --git a/ast_helpers.py b/ast_helpers.py new file mode 100644 index 000000000..5416ab332 --- /dev/null +++ b/ast_helpers.py @@ -0,0 +1,47 @@ +import ast + + +def import_from(module, names): + return ast.ImportFrom( + module=module, + names=[ast.alias(name=n, asname=None) for n in names], + level=0, + ) + + +def assign_constructor(var, cls, **kwargs): + """Create AST for: var = cls(**kwargs)""" + + def ast_value(v): + if isinstance(v, tuple): + return ast.Tuple(elts=[ast_value(x) for x in v], ctx=ast.Load()) + return ast.Constant(v) + + return ast.Assign( + targets=[ast.Name(id=var, ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id=cls, ctx=ast.Load()), + args=[], + keywords=[ + ast.keyword(arg=k, value=ast_value(v)) for k, v in kwargs.items() + ], + ), + ) + + +def call_attr(obj, attr, args=None, keywords=None): + return ast.Expr( + value=ast.Call( + func=ast.Attribute(value=obj, attr=attr, ctx=ast.Load()), + args=args or [], + keywords=keywords or [], + ) + ) + + +def attr_chain(names, ctx=ast.Load()): + """Create a nested Attribute node from a list: e.g., model.em_fields.b_field""" + node = ast.Name(id=names[0], ctx=ctx) + for name in names[1:]: + node = ast.Attribute(value=node, attr=name, ctx=ctx) + return node diff --git a/generate_maxwell_ast.py b/generate_maxwell_ast.py new file mode 100644 index 000000000..4b9859bb3 --- /dev/null +++ b/generate_maxwell_ast.py @@ -0,0 +1,147 @@ +import ast + +from ast_helpers import assign_constructor, attr_chain, call_attr, import_from + +# Imports +imports = [ + import_from( + "struphy.io.options", + [ + "EnvironmentOptions", + "BaseUnits", + "Time", + "DerhamOptions", + "FieldsBackground", + ], + ), + import_from("struphy.geometry", ["domains"]), + import_from("struphy.fields_background", ["equils"]), + import_from("struphy.topology", ["grids"]), + import_from("struphy.initial", ["perturbations"]), + import_from("struphy.kinetic_background", ["maxwellians"]), + import_from( + "struphy.pic.utilities", + [ + "LoadingParameters", + "WeightsParameters", + "BoundaryParameters", + "BinningPlot", + "KernelDensityPlot", + ], + ), + import_from("struphy", ["main"]), + import_from("struphy.models.toy", ["Maxwell"]), +] + +# Assignments +assignments = [ + assign_constructor("env", "EnvironmentOptions"), + assign_constructor("base_units", "BaseUnits"), + assign_constructor("time_opts", "Time", dt=0.01, Tend=0.10), + assign_constructor("domain", "domains.Cuboid"), + assign_constructor("equil", "equils.HomogenSlab"), + assign_constructor("grid", "grids.TensorProductGrid"), + assign_constructor("derham_opts", "DerhamOptions"), + assign_constructor("model", "Maxwell"), +] + +# propagator options +prop_options_assign = ast.Assign( + targets=[ + attr_chain(["model", "propagators", "maxwell", "options"], ctx=ast.Store()) + ], + value=ast.Call( + func=attr_chain(["model", "propagators", "maxwell", "Options"]), + args=[], + keywords=[], + ), +) +assignments.append(prop_options_assign) + +# Perturbations +perturb_calls = [] +for comp in range(3): + perturb_calls.append( + call_attr( + attr_chain(["model", "em_fields", "b_field"]), + "add_perturbation", + args=[ + ast.Call( + func=ast.Attribute( + value=ast.Name(id="perturbations", ctx=ast.Load()), + attr="TorusModesCos", + ctx=ast.Load(), + ), + args=[], + keywords=[ + ast.keyword(arg="given_in_basis", value=ast.Constant("v")), + ast.keyword(arg="comp", value=ast.Constant(comp)), + ], + ) + ], + ) + ) + +# main +main_guard = ast.If( + test=ast.Compare( + left=ast.Name(id="__name__", ctx=ast.Load()), + ops=[ast.Eq()], + comparators=[ast.Constant("__main__")], + ), + body=[ + ast.Assign( + targets=[ast.Name(id="verbose", ctx=ast.Store())], + value=ast.Constant(True), + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="main", ctx=ast.Load()), + attr="run", + ctx=ast.Load(), + ), + args=[ast.Name(id="model", ctx=ast.Load())], + keywords=[ + ast.keyword( + arg="params_path", value=ast.Name(id="__file__", ctx=ast.Load()) + ), + ast.keyword(arg="env", value=ast.Name(id="env", ctx=ast.Load())), + ast.keyword( + arg="base_units", + value=ast.Name(id="base_units", ctx=ast.Load()), + ), + ast.keyword( + arg="time_opts", value=ast.Name(id="time_opts", ctx=ast.Load()) + ), + ast.keyword( + arg="domain", value=ast.Name(id="domain", ctx=ast.Load()) + ), + ast.keyword( + arg="equil", value=ast.Name(id="equil", ctx=ast.Load()) + ), + ast.keyword(arg="grid", value=ast.Name(id="grid", ctx=ast.Load())), + ast.keyword( + arg="derham_opts", + value=ast.Name(id="derham_opts", ctx=ast.Load()), + ), + ast.keyword( + arg="verbose", value=ast.Name(id="verbose", ctx=ast.Load()) + ), + ], + ) + ), + ], + orelse=[], +) + +# Assemble module +module = ast.Module( + body=imports + assignments + perturb_calls + [main_guard], type_ignores=[] +) + +ast.fix_missing_locations(module) + +# print source code +source = ast.unparse(module) +print(source) From 95b96b219dbc399256d21af8869286ce820dec70 Mon Sep 17 00:00:00 2001 From: Max Date: Sun, 11 Jan 2026 21:38:11 +0100 Subject: [PATCH 2/9] Generate parameters with defaults --- ast_helpers.py | 73 +++++++++++++++++++++++++++++++++++++++++ generate_maxwell_ast.py | 47 ++++++++++++++++++++------ 2 files changed, 110 insertions(+), 10 deletions(-) diff --git a/ast_helpers.py b/ast_helpers.py index 5416ab332..5f8c95e6e 100644 --- a/ast_helpers.py +++ b/ast_helpers.py @@ -1,5 +1,7 @@ import ast +import inspect + def import_from(module, names): return ast.ImportFrom( @@ -29,6 +31,77 @@ def ast_value(v): ) +def assign_constructor_with_defaults( + var_name, cls_or_str, cls_name_for_ast=None, **overrides +): + """Create AST for: var_name = cls(**all_defaults_and_overrides) + + Args: + var_name: Variable name for the assignment + cls_or_str: Class object to introspect for defaults, OR a string class name if no introspection needed + cls_name_for_ast: Optional string for how to reference the class in AST (e.g., "domains.Cuboid") + If not provided, will use cls.__name__ or cls_or_str + **overrides: Keyword arguments that override the defaults + """ + keywords = [] + + # If we have a class object, introspect it for defaults + if not isinstance(cls_or_str, str): + cls = cls_or_str + sig = inspect.signature(cls.__init__) + + for name, param in sig.parameters.items(): + if name == "self": + continue + + # Use override if provided, otherwise use default + if name in overrides: + value = overrides[name] + elif param.default is not param.empty: + value = param.default + else: + # No default and no override - skip or set to None + continue + + # Convert to AST constant or tuple if needed + if isinstance(value, tuple): + ast_value = ast.Tuple( + elts=[ast.Constant(v) for v in value], ctx=ast.Load() + ) + else: + ast_value = ast.Constant(value) + keywords.append(ast.keyword(arg=name, value=ast_value)) + + # Determine the class name for AST + if cls_name_for_ast is None: + cls_name_for_ast = cls.__name__ + else: + # It's just a string, use the overrides + cls_name_for_ast = cls_or_str + for name, value in overrides.items(): + if isinstance(value, tuple): + ast_value = ast.Tuple( + elts=[ast.Constant(v) for v in value], ctx=ast.Load() + ) + else: + ast_value = ast.Constant(value) + keywords.append(ast.keyword(arg=name, value=ast_value)) + + # Build the function AST node from the class name string + if "." in cls_name_for_ast: + parts = cls_name_for_ast.split(".") + func = ast.Name(id=parts[0], ctx=ast.Load()) + for part in parts[1:]: + func = ast.Attribute(value=func, attr=part, ctx=ast.Load()) + else: + func = ast.Name(id=cls_name_for_ast, ctx=ast.Load()) + + return ast.Assign( + targets=[ast.Name(id=var_name, ctx=ast.Store())], + value=ast.Call(func=func, args=[], keywords=keywords), + ) + + def call_attr(obj, attr, args=None, keywords=None): return ast.Expr( value=ast.Call( diff --git a/generate_maxwell_ast.py b/generate_maxwell_ast.py index 4b9859bb3..8445b722c 100644 --- a/generate_maxwell_ast.py +++ b/generate_maxwell_ast.py @@ -1,6 +1,31 @@ import ast - -from ast_helpers import assign_constructor, attr_chain, call_attr, import_from +from struphy.io.options import ( + EnvironmentOptions, + BaseUnits, + Time, + DerhamOptions, + FieldsBackground, +) +from struphy.geometry import domains +from struphy.fields_background import equils +from struphy.topology import grids +from struphy.initial import perturbations +from struphy.kinetic_background import maxwellians +from struphy.pic.utilities import ( + LoadingParameters, + WeightsParameters, + BoundaryParameters, + BinningPlot, + KernelDensityPlot, +) +from struphy import main +from struphy.models.toy import Maxwell +from ast_helpers import ( + assign_constructor_with_defaults, + attr_chain, + call_attr, + import_from, +) # Imports imports = [ @@ -35,14 +60,16 @@ # Assignments assignments = [ - assign_constructor("env", "EnvironmentOptions"), - assign_constructor("base_units", "BaseUnits"), - assign_constructor("time_opts", "Time", dt=0.01, Tend=0.10), - assign_constructor("domain", "domains.Cuboid"), - assign_constructor("equil", "equils.HomogenSlab"), - assign_constructor("grid", "grids.TensorProductGrid"), - assign_constructor("derham_opts", "DerhamOptions"), - assign_constructor("model", "Maxwell"), + assign_constructor_with_defaults("env", EnvironmentOptions), + assign_constructor_with_defaults("base_units", BaseUnits), + assign_constructor_with_defaults("time_opts", Time, None, dt=0.01, Tend=0.10), + assign_constructor_with_defaults("domain", domains.Cuboid, "domains.Cuboid"), + assign_constructor_with_defaults("equil", equils.HomogenSlab, "equils.HomogenSlab"), + assign_constructor_with_defaults( + "grid", grids.TensorProductGrid, "grids.TensorProductGrid" + ), + assign_constructor_with_defaults("derham_opts", DerhamOptions), + assign_constructor_with_defaults("model", Maxwell), ] # propagator options From 6ec1957ec62478e6722a825c25accc100107b2f4 Mon Sep 17 00:00:00 2001 From: Max Date: Sun, 11 Jan 2026 21:42:38 +0100 Subject: [PATCH 3/9] Added option to specify defaults --- ast_helpers.py | 25 +++++++++++++---- generate_maxwell_ast.py | 61 +++++++++++++++++++++++------------------ 2 files changed, 53 insertions(+), 33 deletions(-) diff --git a/ast_helpers.py b/ast_helpers.py index 5f8c95e6e..326c85676 100644 --- a/ast_helpers.py +++ b/ast_helpers.py @@ -1,5 +1,4 @@ import ast - import inspect @@ -11,21 +10,35 @@ def import_from(module, names): ) -def assign_constructor(var, cls, **kwargs): - """Create AST for: var = cls(**kwargs)""" +def assign_constructor(var_name, cls_or_str, cls_name_for_ast=None, **overrides): def ast_value(v): if isinstance(v, tuple): return ast.Tuple(elts=[ast_value(x) for x in v], ctx=ast.Load()) return ast.Constant(v) + # Determine the class name for AST + if isinstance(cls_or_str, str): + cls_name_for_ast = cls_name_for_ast or cls_or_str + else: + cls_name_for_ast = cls_name_for_ast or cls_or_str.__name__ + + # Build the function AST node from the class name string + if "." in cls_name_for_ast: + parts = cls_name_for_ast.split(".") + func = ast.Name(id=parts[0], ctx=ast.Load()) + for part in parts[1:]: + func = ast.Attribute(value=func, attr=part, ctx=ast.Load()) + else: + func = ast.Name(id=cls_name_for_ast, ctx=ast.Load()) + return ast.Assign( - targets=[ast.Name(id=var, ctx=ast.Store())], + targets=[ast.Name(id=var_name, ctx=ast.Store())], value=ast.Call( - func=ast.Name(id=cls, ctx=ast.Load()), + func=func, args=[], keywords=[ - ast.keyword(arg=k, value=ast_value(v)) for k, v in kwargs.items() + ast.keyword(arg=k, value=ast_value(v)) for k, v in overrides.items() ], ), ) diff --git a/generate_maxwell_ast.py b/generate_maxwell_ast.py index 8445b722c..15e62fed5 100644 --- a/generate_maxwell_ast.py +++ b/generate_maxwell_ast.py @@ -1,31 +1,40 @@ import ast + +from ast_helpers import ( + assign_constructor, + assign_constructor_with_defaults, + attr_chain, + call_attr, + import_from, +) +from struphy import main +from struphy.fields_background import equils +from struphy.geometry import domains +from struphy.initial import perturbations from struphy.io.options import ( - EnvironmentOptions, BaseUnits, - Time, DerhamOptions, + EnvironmentOptions, FieldsBackground, + Time, ) -from struphy.geometry import domains -from struphy.fields_background import equils -from struphy.topology import grids -from struphy.initial import perturbations from struphy.kinetic_background import maxwellians +from struphy.models.toy import Maxwell from struphy.pic.utilities import ( - LoadingParameters, - WeightsParameters, - BoundaryParameters, BinningPlot, + BoundaryParameters, KernelDensityPlot, + LoadingParameters, + WeightsParameters, ) -from struphy import main -from struphy.models.toy import Maxwell -from ast_helpers import ( - assign_constructor_with_defaults, - attr_chain, - call_attr, - import_from, -) +from struphy.topology import grids + +specify_defaults: bool = False + +if specify_defaults: + assign_constructor_func = assign_constructor_with_defaults +else: + assign_constructor_func = assign_constructor # Imports imports = [ @@ -60,16 +69,14 @@ # Assignments assignments = [ - assign_constructor_with_defaults("env", EnvironmentOptions), - assign_constructor_with_defaults("base_units", BaseUnits), - assign_constructor_with_defaults("time_opts", Time, None, dt=0.01, Tend=0.10), - assign_constructor_with_defaults("domain", domains.Cuboid, "domains.Cuboid"), - assign_constructor_with_defaults("equil", equils.HomogenSlab, "equils.HomogenSlab"), - assign_constructor_with_defaults( - "grid", grids.TensorProductGrid, "grids.TensorProductGrid" - ), - assign_constructor_with_defaults("derham_opts", DerhamOptions), - assign_constructor_with_defaults("model", Maxwell), + assign_constructor_func("env", EnvironmentOptions), + assign_constructor_func("base_units", BaseUnits), + assign_constructor_func("time_opts", Time, None, dt=0.01, Tend=0.10), + assign_constructor_func("domain", domains.Cuboid, "domains.Cuboid"), + assign_constructor_func("equil", equils.HomogenSlab, "equils.HomogenSlab"), + assign_constructor_func("grid", grids.TensorProductGrid, "grids.TensorProductGrid"), + assign_constructor_func("derham_opts", DerhamOptions), + assign_constructor_func("model", Maxwell), ] # propagator options From 93367c109a4aaaf5aafeb6d7b3b4a1370e166e98 Mon Sep 17 00:00:00 2001 From: Max Date: Sun, 11 Jan 2026 21:44:02 +0100 Subject: [PATCH 4/9] Added typehints --- ast_helpers.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/ast_helpers.py b/ast_helpers.py index 326c85676..55475ae70 100644 --- a/ast_helpers.py +++ b/ast_helpers.py @@ -1,8 +1,9 @@ import ast import inspect +from typing import Any, Optional, Union -def import_from(module, names): +def import_from(module: str, names: list[str]) -> ast.ImportFrom: return ast.ImportFrom( module=module, names=[ast.alias(name=n, asname=None) for n in names], @@ -10,7 +11,12 @@ def import_from(module, names): ) -def assign_constructor(var_name, cls_or_str, cls_name_for_ast=None, **overrides): +def assign_constructor( + var_name: str, + cls_or_str: Union[type, str], + cls_name_for_ast: Optional[str] = None, + **overrides: Any, +) -> ast.Assign: def ast_value(v): if isinstance(v, tuple): @@ -45,8 +51,11 @@ def ast_value(v): def assign_constructor_with_defaults( - var_name, cls_or_str, cls_name_for_ast=None, **overrides -): + var_name: str, + cls_or_str: Union[type, str], + cls_name_for_ast: Optional[str] = None, + **overrides: Any, +) -> ast.Assign: """Create AST for: var_name = cls(**all_defaults_and_overrides) Args: @@ -115,7 +124,12 @@ def assign_constructor_with_defaults( ) -def call_attr(obj, attr, args=None, keywords=None): +def call_attr( + obj: ast.expr, + attr: str, + args: Optional[list[ast.expr]] = None, + keywords: Optional[list[ast.keyword]] = None, +) -> ast.Expr: return ast.Expr( value=ast.Call( func=ast.Attribute(value=obj, attr=attr, ctx=ast.Load()), @@ -125,7 +139,7 @@ def call_attr(obj, attr, args=None, keywords=None): ) -def attr_chain(names, ctx=ast.Load()): +def attr_chain(names: list[str], ctx: ast.expr_context = ast.Load()) -> ast.expr: """Create a nested Attribute node from a list: e.g., model.em_fields.b_field""" node = ast.Name(id=names[0], ctx=ctx) for name in names[1:]: From 635625190f0ff3fa5c2b04b566aecb07c16b3dfb Mon Sep 17 00:00:00 2001 From: Max Date: Sun, 11 Jan 2026 21:45:30 +0100 Subject: [PATCH 5/9] Moved code into generate_maxwell_ast --- generate_maxwell_ast.py | 281 +++++++++++++++++++++------------------- 1 file changed, 147 insertions(+), 134 deletions(-) diff --git a/generate_maxwell_ast.py b/generate_maxwell_ast.py index 15e62fed5..ce9d19ed6 100644 --- a/generate_maxwell_ast.py +++ b/generate_maxwell_ast.py @@ -29,153 +29,166 @@ ) from struphy.topology import grids -specify_defaults: bool = False -if specify_defaults: - assign_constructor_func = assign_constructor_with_defaults -else: - assign_constructor_func = assign_constructor +def generate_maxwell_ast(specify_defaults: bool = False): + if specify_defaults: + assign_constructor_func = assign_constructor_with_defaults + else: + assign_constructor_func = assign_constructor -# Imports -imports = [ - import_from( - "struphy.io.options", - [ - "EnvironmentOptions", - "BaseUnits", - "Time", - "DerhamOptions", - "FieldsBackground", - ], - ), - import_from("struphy.geometry", ["domains"]), - import_from("struphy.fields_background", ["equils"]), - import_from("struphy.topology", ["grids"]), - import_from("struphy.initial", ["perturbations"]), - import_from("struphy.kinetic_background", ["maxwellians"]), - import_from( - "struphy.pic.utilities", - [ - "LoadingParameters", - "WeightsParameters", - "BoundaryParameters", - "BinningPlot", - "KernelDensityPlot", - ], - ), - import_from("struphy", ["main"]), - import_from("struphy.models.toy", ["Maxwell"]), -] + # Imports + imports = [ + import_from( + "struphy.io.options", + [ + "EnvironmentOptions", + "BaseUnits", + "Time", + "DerhamOptions", + "FieldsBackground", + ], + ), + import_from("struphy.geometry", ["domains"]), + import_from("struphy.fields_background", ["equils"]), + import_from("struphy.topology", ["grids"]), + import_from("struphy.initial", ["perturbations"]), + import_from("struphy.kinetic_background", ["maxwellians"]), + import_from( + "struphy.pic.utilities", + [ + "LoadingParameters", + "WeightsParameters", + "BoundaryParameters", + "BinningPlot", + "KernelDensityPlot", + ], + ), + import_from("struphy", ["main"]), + import_from("struphy.models.toy", ["Maxwell"]), + ] -# Assignments -assignments = [ - assign_constructor_func("env", EnvironmentOptions), - assign_constructor_func("base_units", BaseUnits), - assign_constructor_func("time_opts", Time, None, dt=0.01, Tend=0.10), - assign_constructor_func("domain", domains.Cuboid, "domains.Cuboid"), - assign_constructor_func("equil", equils.HomogenSlab, "equils.HomogenSlab"), - assign_constructor_func("grid", grids.TensorProductGrid, "grids.TensorProductGrid"), - assign_constructor_func("derham_opts", DerhamOptions), - assign_constructor_func("model", Maxwell), -] + # Assignments + assignments = [ + assign_constructor_func("env", EnvironmentOptions), + assign_constructor_func("base_units", BaseUnits), + assign_constructor_func("time_opts", Time, None, dt=0.01, Tend=0.10), + assign_constructor_func("domain", domains.Cuboid, "domains.Cuboid"), + assign_constructor_func("equil", equils.HomogenSlab, "equils.HomogenSlab"), + assign_constructor_func( + "grid", grids.TensorProductGrid, "grids.TensorProductGrid" + ), + assign_constructor_func("derham_opts", DerhamOptions), + assign_constructor_func("model", Maxwell), + ] -# propagator options -prop_options_assign = ast.Assign( - targets=[ - attr_chain(["model", "propagators", "maxwell", "options"], ctx=ast.Store()) - ], - value=ast.Call( - func=attr_chain(["model", "propagators", "maxwell", "Options"]), - args=[], - keywords=[], - ), -) -assignments.append(prop_options_assign) + # propagator options + prop_options_assign = ast.Assign( + targets=[ + attr_chain(["model", "propagators", "maxwell", "options"], ctx=ast.Store()) + ], + value=ast.Call( + func=attr_chain(["model", "propagators", "maxwell", "Options"]), + args=[], + keywords=[], + ), + ) + assignments.append(prop_options_assign) + + # Perturbations + perturb_calls = [] + for comp in range(3): + perturb_calls.append( + call_attr( + attr_chain(["model", "em_fields", "b_field"]), + "add_perturbation", + args=[ + ast.Call( + func=ast.Attribute( + value=ast.Name(id="perturbations", ctx=ast.Load()), + attr="TorusModesCos", + ctx=ast.Load(), + ), + args=[], + keywords=[ + ast.keyword(arg="given_in_basis", value=ast.Constant("v")), + ast.keyword(arg="comp", value=ast.Constant(comp)), + ], + ) + ], + ) + ) -# Perturbations -perturb_calls = [] -for comp in range(3): - perturb_calls.append( - call_attr( - attr_chain(["model", "em_fields", "b_field"]), - "add_perturbation", - args=[ - ast.Call( + # main + main_guard = ast.If( + test=ast.Compare( + left=ast.Name(id="__name__", ctx=ast.Load()), + ops=[ast.Eq()], + comparators=[ast.Constant("__main__")], + ), + body=[ + ast.Assign( + targets=[ast.Name(id="verbose", ctx=ast.Store())], + value=ast.Constant(True), + ), + ast.Expr( + value=ast.Call( func=ast.Attribute( - value=ast.Name(id="perturbations", ctx=ast.Load()), - attr="TorusModesCos", + value=ast.Name(id="main", ctx=ast.Load()), + attr="run", ctx=ast.Load(), ), - args=[], + args=[ast.Name(id="model", ctx=ast.Load())], keywords=[ - ast.keyword(arg="given_in_basis", value=ast.Constant("v")), - ast.keyword(arg="comp", value=ast.Constant(comp)), + ast.keyword( + arg="params_path", + value=ast.Name(id="__file__", ctx=ast.Load()), + ), + ast.keyword( + arg="env", value=ast.Name(id="env", ctx=ast.Load()) + ), + ast.keyword( + arg="base_units", + value=ast.Name(id="base_units", ctx=ast.Load()), + ), + ast.keyword( + arg="time_opts", + value=ast.Name(id="time_opts", ctx=ast.Load()), + ), + ast.keyword( + arg="domain", value=ast.Name(id="domain", ctx=ast.Load()) + ), + ast.keyword( + arg="equil", value=ast.Name(id="equil", ctx=ast.Load()) + ), + ast.keyword( + arg="grid", value=ast.Name(id="grid", ctx=ast.Load()) + ), + ast.keyword( + arg="derham_opts", + value=ast.Name(id="derham_opts", ctx=ast.Load()), + ), + ast.keyword( + arg="verbose", value=ast.Name(id="verbose", ctx=ast.Load()) + ), ], ) - ], - ) + ), + ], + orelse=[], ) -# main -main_guard = ast.If( - test=ast.Compare( - left=ast.Name(id="__name__", ctx=ast.Load()), - ops=[ast.Eq()], - comparators=[ast.Constant("__main__")], - ), - body=[ - ast.Assign( - targets=[ast.Name(id="verbose", ctx=ast.Store())], - value=ast.Constant(True), - ), - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="main", ctx=ast.Load()), - attr="run", - ctx=ast.Load(), - ), - args=[ast.Name(id="model", ctx=ast.Load())], - keywords=[ - ast.keyword( - arg="params_path", value=ast.Name(id="__file__", ctx=ast.Load()) - ), - ast.keyword(arg="env", value=ast.Name(id="env", ctx=ast.Load())), - ast.keyword( - arg="base_units", - value=ast.Name(id="base_units", ctx=ast.Load()), - ), - ast.keyword( - arg="time_opts", value=ast.Name(id="time_opts", ctx=ast.Load()) - ), - ast.keyword( - arg="domain", value=ast.Name(id="domain", ctx=ast.Load()) - ), - ast.keyword( - arg="equil", value=ast.Name(id="equil", ctx=ast.Load()) - ), - ast.keyword(arg="grid", value=ast.Name(id="grid", ctx=ast.Load())), - ast.keyword( - arg="derham_opts", - value=ast.Name(id="derham_opts", ctx=ast.Load()), - ), - ast.keyword( - arg="verbose", value=ast.Name(id="verbose", ctx=ast.Load()) - ), - ], - ) - ), - ], - orelse=[], -) + # Assemble module + module = ast.Module( + body=imports + assignments + perturb_calls + [main_guard], type_ignores=[] + ) -# Assemble module -module = ast.Module( - body=imports + assignments + perturb_calls + [main_guard], type_ignores=[] -) + ast.fix_missing_locations(module) + + # return source code + return ast.unparse(module) -ast.fix_missing_locations(module) -# print source code -source = ast.unparse(module) -print(source) +if __name__ == "__main__": + source = generate_maxwell_ast(specify_defaults=False) + # print source code + print(source) From 703a61d43fde8d10f33fa0ed937d08830a660813 Mon Sep 17 00:00:00 2001 From: Max Date: Sun, 11 Jan 2026 21:46:49 +0100 Subject: [PATCH 6/9] Added typehint --- generate_maxwell_ast.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/generate_maxwell_ast.py b/generate_maxwell_ast.py index ce9d19ed6..1a757fb0e 100644 --- a/generate_maxwell_ast.py +++ b/generate_maxwell_ast.py @@ -30,7 +30,9 @@ from struphy.topology import grids -def generate_maxwell_ast(specify_defaults: bool = False): +def generate_maxwell_ast( + specify_defaults: bool = False, +) -> str: if specify_defaults: assign_constructor_func = assign_constructor_with_defaults else: From 392f1b0f6d447039a5a69d313fae2077685fd1af Mon Sep 17 00:00:00 2001 From: Max Date: Sun, 11 Jan 2026 22:24:29 +0100 Subject: [PATCH 7/9] Added dataclass-parameter generator --- ast_helpers.py | 58 +++++++++++++++ dataclass_generator.py | 152 ++++++++++++++++++++++++++++++++++++++++ generate_maxwell_ast.py | 8 +-- 3 files changed, 214 insertions(+), 4 deletions(-) create mode 100644 dataclass_generator.py diff --git a/ast_helpers.py b/ast_helpers.py index 55475ae70..a50073913 100644 --- a/ast_helpers.py +++ b/ast_helpers.py @@ -145,3 +145,61 @@ def attr_chain(names: list[str], ctx: ast.expr_context = ast.Load()) -> ast.expr for name in names[1:]: node = ast.Attribute(value=node, attr=name, ctx=ctx) return node + + +def create_dataclass_from_class( + cls: type, dataclass_name: Optional[str] = None +) -> ast.ClassDef: + """Create a dataclass AST node with fields from a class's __init__ parameters. + + Args: + cls: Class to introspect for parameters + dataclass_name: Name for the dataclass. If None, uses cls.__name__ + "Params" + + Returns: + ast.ClassDef node for a dataclass + """ + if dataclass_name is None: + dataclass_name = f"{cls.__name__}Params" + + sig = inspect.signature(cls.__init__) + fields = [] + + for name, param in sig.parameters.items(): + if name == "self": + continue + + # Create an annotated assignment for each field + # field_name: type = default_value + annotation = ast.Name(id="Any", ctx=ast.Load()) # Use Any for now + + if param.default is not param.empty: + # Has a default value + if isinstance(param.default, tuple): + value = ast.Tuple( + elts=[ast.Constant(v) for v in param.default], ctx=ast.Load() + ) + else: + value = ast.Constant(param.default) + else: + # No default - this is a required field + value = None + + ann_assign = ast.AnnAssign( + target=ast.Name(id=name, ctx=ast.Store()), + annotation=annotation, + value=value, + simple=1, + ) + fields.append(ann_assign) + + # Create the class definition with @dataclass decorator + class_def = ast.ClassDef( + name=dataclass_name, + bases=[], + keywords=[], + body=fields if fields else [ast.Pass()], + decorator_list=[ast.Name(id="dataclass", ctx=ast.Load())], + ) + + return class_def diff --git a/dataclass_generator.py b/dataclass_generator.py new file mode 100644 index 000000000..5d534047e --- /dev/null +++ b/dataclass_generator.py @@ -0,0 +1,152 @@ +from struphy.io.options import Time, EnvironmentOptions, BaseUnits, DerhamOptions, FieldsBackground + +import ast +import inspect + + +def generate_params_dataclass(cls): + """Generate a dataclass AST for parameters from a class's __init__.""" + # Get the signature and type hints + sig = inspect.signature(cls.__init__) + type_hints = {} + try: + import typing + type_hints = typing.get_type_hints(cls.__init__) + except Exception: + # If we can't get type hints, we'll fall back to Any + pass + + # Build dataclass fields + fields = [] + for name, param in sig.parameters.items(): + if name == "self": + continue + + # Try to get the type annotation + if name in type_hints: + # Convert type hint to AST + type_hint = type_hints[name] + annotation = _type_to_ast(type_hint) + elif param.annotation is not param.empty: + # Try to use the annotation directly + annotation = _annotation_to_ast(param.annotation) + else: + # Fall back to Any + annotation = ast.Name(id="Any", ctx=ast.Load()) + + if param.default is not param.empty: + # Has a default value + if isinstance(param.default, tuple): + value = ast.Tuple( + elts=[ast.Constant(v) for v in param.default], ctx=ast.Load() + ) + else: + value = ast.Constant(param.default) + else: + # No default - required field + value = None + + ann_assign = ast.AnnAssign( + target=ast.Name(id=name, ctx=ast.Store()), + annotation=annotation, + value=value, + simple=1, + ) + fields.append(ann_assign) + + # Create the dataclass + class_def = ast.ClassDef( + name=f"{cls.__name__}Params", + bases=[], + keywords=[], + body=fields if fields else [ast.Pass()], + decorator_list=[ast.Name(id="dataclass", ctx=ast.Load())], + ) + + # Create a module with necessary imports and the dataclass + module = ast.Module( + body=[ + ast.ImportFrom( + module="dataclasses", names=[ast.alias(name="dataclass", asname=None)], level=0 + ), + ast.ImportFrom( + module="typing", names=[ast.alias(name="Any", asname=None)], level=0 + ), + class_def, + ], + type_ignores=[], + ) + + ast.fix_missing_locations(module) + return ast.unparse(module) + + +def _type_to_ast(type_hint): + """Convert a type hint to an AST node.""" + import typing + + # Handle basic types + if type_hint is type(None): + return ast.Constant(value=None) + elif hasattr(type_hint, "__name__"): + # Simple type like int, str, float, bool + return ast.Name(id=type_hint.__name__, ctx=ast.Load()) + elif hasattr(type_hint, "__origin__"): + # Generic types like Optional[int], list[str], etc. + origin = type_hint.__origin__ + + if origin is typing.Union: + # Union or Optional + args = type_hint.__args__ + if len(args) == 2 and type(None) in args: + # Optional[X] + inner_type = args[0] if args[1] is type(None) else args[1] + return ast.Subscript( + value=ast.Name(id="Optional", ctx=ast.Load()), + slice=_type_to_ast(inner_type), + ctx=ast.Load(), + ) + else: + # Union[X, Y, ...] + return ast.Subscript( + value=ast.Name(id="Union", ctx=ast.Load()), + slice=ast.Tuple( + elts=[_type_to_ast(arg) for arg in args], + ctx=ast.Load(), + ), + ctx=ast.Load(), + ) + elif hasattr(origin, "__name__"): + # list, tuple, dict, etc. + args = type_hint.__args__ + return ast.Subscript( + value=ast.Name(id=origin.__name__, ctx=ast.Load()), + slice=ast.Tuple( + elts=[_type_to_ast(arg) for arg in args], + ctx=ast.Load(), + ) if len(args) > 1 else _type_to_ast(args[0]), + ctx=ast.Load(), + ) + + # Fall back to Any + return ast.Name(id="Any", ctx=ast.Load()) + + +def _annotation_to_ast(annotation): + """Convert an annotation object to AST.""" + if isinstance(annotation, type): + return ast.Name(id=annotation.__name__, ctx=ast.Load()) + elif isinstance(annotation, str): + # Forward reference + return ast.Name(id=annotation, ctx=ast.Load()) + else: + # Try to extract from the annotation + try: + return _type_to_ast(annotation) + except Exception: + return ast.Name(id="Any", ctx=ast.Load()) + + +if __name__ == "__main__": + source = generate_params_dataclass(EnvironmentOptions) + print(source) \ No newline at end of file diff --git a/generate_maxwell_ast.py b/generate_maxwell_ast.py index 1a757fb0e..53384fccd 100644 --- a/generate_maxwell_ast.py +++ b/generate_maxwell_ast.py @@ -189,8 +189,8 @@ def generate_maxwell_ast( # return source code return ast.unparse(module) - if __name__ == "__main__": - source = generate_maxwell_ast(specify_defaults=False) - # print source code - print(source) + + # source = generate_maxwell_ast(specify_defaults=False) + # # print source code + # print(source) From 90f7c13ff6e80be5e1c5ab57122f11da36023b77 Mon Sep 17 00:00:00 2001 From: Max Date: Sun, 11 Jan 2026 22:25:22 +0100 Subject: [PATCH 8/9] Formatting --- dataclass_generator.py | 35 ++++++++++++++++++++++++----------- generate_maxwell_ast.py | 9 +++++---- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/dataclass_generator.py b/dataclass_generator.py index 5d534047e..40136c7e9 100644 --- a/dataclass_generator.py +++ b/dataclass_generator.py @@ -1,8 +1,14 @@ -from struphy.io.options import Time, EnvironmentOptions, BaseUnits, DerhamOptions, FieldsBackground - import ast import inspect +from struphy.io.options import ( + BaseUnits, + DerhamOptions, + EnvironmentOptions, + FieldsBackground, + Time, +) + def generate_params_dataclass(cls): """Generate a dataclass AST for parameters from a class's __init__.""" @@ -11,6 +17,7 @@ def generate_params_dataclass(cls): type_hints = {} try: import typing + type_hints = typing.get_type_hints(cls.__init__) except Exception: # If we can't get type hints, we'll fall back to Any @@ -67,7 +74,9 @@ def generate_params_dataclass(cls): module = ast.Module( body=[ ast.ImportFrom( - module="dataclasses", names=[ast.alias(name="dataclass", asname=None)], level=0 + module="dataclasses", + names=[ast.alias(name="dataclass", asname=None)], + level=0, ), ast.ImportFrom( module="typing", names=[ast.alias(name="Any", asname=None)], level=0 @@ -84,7 +93,7 @@ def generate_params_dataclass(cls): def _type_to_ast(type_hint): """Convert a type hint to an AST node.""" import typing - + # Handle basic types if type_hint is type(None): return ast.Constant(value=None) @@ -94,7 +103,7 @@ def _type_to_ast(type_hint): elif hasattr(type_hint, "__origin__"): # Generic types like Optional[int], list[str], etc. origin = type_hint.__origin__ - + if origin is typing.Union: # Union or Optional args = type_hint.__args__ @@ -121,13 +130,17 @@ def _type_to_ast(type_hint): args = type_hint.__args__ return ast.Subscript( value=ast.Name(id=origin.__name__, ctx=ast.Load()), - slice=ast.Tuple( - elts=[_type_to_ast(arg) for arg in args], - ctx=ast.Load(), - ) if len(args) > 1 else _type_to_ast(args[0]), + slice=( + ast.Tuple( + elts=[_type_to_ast(arg) for arg in args], + ctx=ast.Load(), + ) + if len(args) > 1 + else _type_to_ast(args[0]) + ), ctx=ast.Load(), ) - + # Fall back to Any return ast.Name(id="Any", ctx=ast.Load()) @@ -149,4 +162,4 @@ def _annotation_to_ast(annotation): if __name__ == "__main__": source = generate_params_dataclass(EnvironmentOptions) - print(source) \ No newline at end of file + print(source) diff --git a/generate_maxwell_ast.py b/generate_maxwell_ast.py index 53384fccd..94679a9ee 100644 --- a/generate_maxwell_ast.py +++ b/generate_maxwell_ast.py @@ -189,8 +189,9 @@ def generate_maxwell_ast( # return source code return ast.unparse(module) + if __name__ == "__main__": - - # source = generate_maxwell_ast(specify_defaults=False) - # # print source code - # print(source) + + source = generate_maxwell_ast(specify_defaults=False) + # print source code + print(source) From 1e4f7e3f1cc740fe77ccd4dce06781f03bbdd508 Mon Sep 17 00:00:00 2001 From: Max Lindqvist Date: Mon, 12 Jan 2026 10:30:48 +0100 Subject: [PATCH 9/9] ruff format --- ast_helpers.py | 21 +++++---------------- dataclass_generator.py | 8 ++------ generate_maxwell_ast.py | 33 ++++++++------------------------- 3 files changed, 15 insertions(+), 47 deletions(-) diff --git a/ast_helpers.py b/ast_helpers.py index a50073913..4940b604c 100644 --- a/ast_helpers.py +++ b/ast_helpers.py @@ -17,7 +17,6 @@ def assign_constructor( cls_name_for_ast: Optional[str] = None, **overrides: Any, ) -> ast.Assign: - def ast_value(v): if isinstance(v, tuple): return ast.Tuple(elts=[ast_value(x) for x in v], ctx=ast.Load()) @@ -43,9 +42,7 @@ def ast_value(v): value=ast.Call( func=func, args=[], - keywords=[ - ast.keyword(arg=k, value=ast_value(v)) for k, v in overrides.items() - ], + keywords=[ast.keyword(arg=k, value=ast_value(v)) for k, v in overrides.items()], ), ) @@ -87,9 +84,7 @@ def assign_constructor_with_defaults( # Convert to AST constant or tuple if needed if isinstance(value, tuple): - ast_value = ast.Tuple( - elts=[ast.Constant(v) for v in value], ctx=ast.Load() - ) + ast_value = ast.Tuple(elts=[ast.Constant(v) for v in value], ctx=ast.Load()) else: ast_value = ast.Constant(value) keywords.append(ast.keyword(arg=name, value=ast_value)) @@ -102,9 +97,7 @@ def assign_constructor_with_defaults( cls_name_for_ast = cls_or_str for name, value in overrides.items(): if isinstance(value, tuple): - ast_value = ast.Tuple( - elts=[ast.Constant(v) for v in value], ctx=ast.Load() - ) + ast_value = ast.Tuple(elts=[ast.Constant(v) for v in value], ctx=ast.Load()) else: ast_value = ast.Constant(value) keywords.append(ast.keyword(arg=name, value=ast_value)) @@ -147,9 +140,7 @@ def attr_chain(names: list[str], ctx: ast.expr_context = ast.Load()) -> ast.expr return node -def create_dataclass_from_class( - cls: type, dataclass_name: Optional[str] = None -) -> ast.ClassDef: +def create_dataclass_from_class(cls: type, dataclass_name: Optional[str] = None) -> ast.ClassDef: """Create a dataclass AST node with fields from a class's __init__ parameters. Args: @@ -176,9 +167,7 @@ def create_dataclass_from_class( if param.default is not param.empty: # Has a default value if isinstance(param.default, tuple): - value = ast.Tuple( - elts=[ast.Constant(v) for v in param.default], ctx=ast.Load() - ) + value = ast.Tuple(elts=[ast.Constant(v) for v in param.default], ctx=ast.Load()) else: value = ast.Constant(param.default) else: diff --git a/dataclass_generator.py b/dataclass_generator.py index 40136c7e9..8b8efa174 100644 --- a/dataclass_generator.py +++ b/dataclass_generator.py @@ -44,9 +44,7 @@ def generate_params_dataclass(cls): if param.default is not param.empty: # Has a default value if isinstance(param.default, tuple): - value = ast.Tuple( - elts=[ast.Constant(v) for v in param.default], ctx=ast.Load() - ) + value = ast.Tuple(elts=[ast.Constant(v) for v in param.default], ctx=ast.Load()) else: value = ast.Constant(param.default) else: @@ -78,9 +76,7 @@ def generate_params_dataclass(cls): names=[ast.alias(name="dataclass", asname=None)], level=0, ), - ast.ImportFrom( - module="typing", names=[ast.alias(name="Any", asname=None)], level=0 - ), + ast.ImportFrom(module="typing", names=[ast.alias(name="Any", asname=None)], level=0), class_def, ], type_ignores=[], diff --git a/generate_maxwell_ast.py b/generate_maxwell_ast.py index 94679a9ee..991528558 100644 --- a/generate_maxwell_ast.py +++ b/generate_maxwell_ast.py @@ -76,18 +76,14 @@ def generate_maxwell_ast( assign_constructor_func("time_opts", Time, None, dt=0.01, Tend=0.10), assign_constructor_func("domain", domains.Cuboid, "domains.Cuboid"), assign_constructor_func("equil", equils.HomogenSlab, "equils.HomogenSlab"), - assign_constructor_func( - "grid", grids.TensorProductGrid, "grids.TensorProductGrid" - ), + assign_constructor_func("grid", grids.TensorProductGrid, "grids.TensorProductGrid"), assign_constructor_func("derham_opts", DerhamOptions), assign_constructor_func("model", Maxwell), ] # propagator options prop_options_assign = ast.Assign( - targets=[ - attr_chain(["model", "propagators", "maxwell", "options"], ctx=ast.Store()) - ], + targets=[attr_chain(["model", "propagators", "maxwell", "options"], ctx=ast.Store())], value=ast.Call( func=attr_chain(["model", "propagators", "maxwell", "Options"]), args=[], @@ -145,9 +141,7 @@ def generate_maxwell_ast( arg="params_path", value=ast.Name(id="__file__", ctx=ast.Load()), ), - ast.keyword( - arg="env", value=ast.Name(id="env", ctx=ast.Load()) - ), + ast.keyword(arg="env", value=ast.Name(id="env", ctx=ast.Load())), ast.keyword( arg="base_units", value=ast.Name(id="base_units", ctx=ast.Load()), @@ -156,22 +150,14 @@ def generate_maxwell_ast( arg="time_opts", value=ast.Name(id="time_opts", ctx=ast.Load()), ), - ast.keyword( - arg="domain", value=ast.Name(id="domain", ctx=ast.Load()) - ), - ast.keyword( - arg="equil", value=ast.Name(id="equil", ctx=ast.Load()) - ), - ast.keyword( - arg="grid", value=ast.Name(id="grid", ctx=ast.Load()) - ), + ast.keyword(arg="domain", value=ast.Name(id="domain", ctx=ast.Load())), + ast.keyword(arg="equil", value=ast.Name(id="equil", ctx=ast.Load())), + ast.keyword(arg="grid", value=ast.Name(id="grid", ctx=ast.Load())), ast.keyword( arg="derham_opts", value=ast.Name(id="derham_opts", ctx=ast.Load()), ), - ast.keyword( - arg="verbose", value=ast.Name(id="verbose", ctx=ast.Load()) - ), + ast.keyword(arg="verbose", value=ast.Name(id="verbose", ctx=ast.Load())), ], ) ), @@ -180,9 +166,7 @@ def generate_maxwell_ast( ) # Assemble module - module = ast.Module( - body=imports + assignments + perturb_calls + [main_guard], type_ignores=[] - ) + module = ast.Module(body=imports + assignments + perturb_calls + [main_guard], type_ignores=[]) ast.fix_missing_locations(module) @@ -191,7 +175,6 @@ def generate_maxwell_ast( if __name__ == "__main__": - source = generate_maxwell_ast(specify_defaults=False) # print source code print(source)