diff --git a/ast_helpers.py b/ast_helpers.py new file mode 100644 index 000000000..4940b604c --- /dev/null +++ b/ast_helpers.py @@ -0,0 +1,194 @@ +import ast +import inspect +from typing import Any, Optional, Union + + +def import_from(module: str, names: list[str]) -> ast.ImportFrom: + return ast.ImportFrom( + module=module, + names=[ast.alias(name=n, asname=None) for n in names], + level=0, + ) + + +def assign_constructor( + var_name: str, + cls_or_str: Union[type, str], + cls_name_for_ast: Optional[str] = None, + **overrides: Any, +) -> ast.Assign: + def ast_value(v): + if isinstance(v, tuple): + return ast.Tuple(elts=[ast_value(x) for x in v], ctx=ast.Load()) + return ast.Constant(v) + + # Determine the class name for AST + if isinstance(cls_or_str, str): + cls_name_for_ast = cls_name_for_ast or cls_or_str + else: + cls_name_for_ast = cls_name_for_ast or cls_or_str.__name__ + + # Build the function AST node from the class name string + if "." in cls_name_for_ast: + parts = cls_name_for_ast.split(".") + func = ast.Name(id=parts[0], ctx=ast.Load()) + for part in parts[1:]: + func = ast.Attribute(value=func, attr=part, ctx=ast.Load()) + else: + func = ast.Name(id=cls_name_for_ast, ctx=ast.Load()) + + return ast.Assign( + targets=[ast.Name(id=var_name, ctx=ast.Store())], + value=ast.Call( + func=func, + args=[], + keywords=[ast.keyword(arg=k, value=ast_value(v)) for k, v in overrides.items()], + ), + ) + + +def assign_constructor_with_defaults( + var_name: str, + cls_or_str: Union[type, str], + cls_name_for_ast: Optional[str] = None, + **overrides: Any, +) -> ast.Assign: + """Create AST for: var_name = cls(**all_defaults_and_overrides) + + Args: + var_name: Variable name for the assignment + cls_or_str: Class object to introspect for defaults, OR a string class name if no introspection needed + cls_name_for_ast: Optional string for how to reference the class in AST (e.g., "domains.Cuboid") + If not provided, will use cls.__name__ or cls_or_str + **overrides: Keyword arguments that override the defaults + """ + keywords = [] + + # If we have a class object, introspect it for defaults + if not isinstance(cls_or_str, str): + cls = cls_or_str + sig = inspect.signature(cls.__init__) + + for name, param in sig.parameters.items(): + if name == "self": + continue + + # Use override if provided, otherwise use default + if name in overrides: + value = overrides[name] + elif param.default is not param.empty: + value = param.default + else: + # No default and no override - skip or set to None + continue + + # Convert to AST constant or tuple if needed + if isinstance(value, tuple): + ast_value = ast.Tuple(elts=[ast.Constant(v) for v in value], ctx=ast.Load()) + else: + ast_value = ast.Constant(value) + keywords.append(ast.keyword(arg=name, value=ast_value)) + + # Determine the class name for AST + if cls_name_for_ast is None: + cls_name_for_ast = cls.__name__ + else: + # It's just a string, use the overrides + cls_name_for_ast = cls_or_str + for name, value in overrides.items(): + if isinstance(value, tuple): + ast_value = ast.Tuple(elts=[ast.Constant(v) for v in value], ctx=ast.Load()) + else: + ast_value = ast.Constant(value) + keywords.append(ast.keyword(arg=name, value=ast_value)) + + # Build the function AST node from the class name string + if "." in cls_name_for_ast: + parts = cls_name_for_ast.split(".") + func = ast.Name(id=parts[0], ctx=ast.Load()) + for part in parts[1:]: + func = ast.Attribute(value=func, attr=part, ctx=ast.Load()) + else: + func = ast.Name(id=cls_name_for_ast, ctx=ast.Load()) + + return ast.Assign( + targets=[ast.Name(id=var_name, ctx=ast.Store())], + value=ast.Call(func=func, args=[], keywords=keywords), + ) + + +def call_attr( + obj: ast.expr, + attr: str, + args: Optional[list[ast.expr]] = None, + keywords: Optional[list[ast.keyword]] = None, +) -> ast.Expr: + return ast.Expr( + value=ast.Call( + func=ast.Attribute(value=obj, attr=attr, ctx=ast.Load()), + args=args or [], + keywords=keywords or [], + ) + ) + + +def attr_chain(names: list[str], ctx: ast.expr_context = ast.Load()) -> ast.expr: + """Create a nested Attribute node from a list: e.g., model.em_fields.b_field""" + node = ast.Name(id=names[0], ctx=ctx) + for name in names[1:]: + node = ast.Attribute(value=node, attr=name, ctx=ctx) + return node + + +def create_dataclass_from_class(cls: type, dataclass_name: Optional[str] = None) -> ast.ClassDef: + """Create a dataclass AST node with fields from a class's __init__ parameters. + + Args: + cls: Class to introspect for parameters + dataclass_name: Name for the dataclass. If None, uses cls.__name__ + "Params" + + Returns: + ast.ClassDef node for a dataclass + """ + if dataclass_name is None: + dataclass_name = f"{cls.__name__}Params" + + sig = inspect.signature(cls.__init__) + fields = [] + + for name, param in sig.parameters.items(): + if name == "self": + continue + + # Create an annotated assignment for each field + # field_name: type = default_value + annotation = ast.Name(id="Any", ctx=ast.Load()) # Use Any for now + + if param.default is not param.empty: + # Has a default value + if isinstance(param.default, tuple): + value = ast.Tuple(elts=[ast.Constant(v) for v in param.default], ctx=ast.Load()) + else: + value = ast.Constant(param.default) + else: + # No default - this is a required field + value = None + + ann_assign = ast.AnnAssign( + target=ast.Name(id=name, ctx=ast.Store()), + annotation=annotation, + value=value, + simple=1, + ) + fields.append(ann_assign) + + # Create the class definition with @dataclass decorator + class_def = ast.ClassDef( + name=dataclass_name, + bases=[], + keywords=[], + body=fields if fields else [ast.Pass()], + decorator_list=[ast.Name(id="dataclass", ctx=ast.Load())], + ) + + return class_def diff --git a/dataclass_generator.py b/dataclass_generator.py new file mode 100644 index 000000000..8b8efa174 --- /dev/null +++ b/dataclass_generator.py @@ -0,0 +1,161 @@ +import ast +import inspect + +from struphy.io.options import ( + BaseUnits, + DerhamOptions, + EnvironmentOptions, + FieldsBackground, + Time, +) + + +def generate_params_dataclass(cls): + """Generate a dataclass AST for parameters from a class's __init__.""" + # Get the signature and type hints + sig = inspect.signature(cls.__init__) + type_hints = {} + try: + import typing + + type_hints = typing.get_type_hints(cls.__init__) + except Exception: + # If we can't get type hints, we'll fall back to Any + pass + + # Build dataclass fields + fields = [] + for name, param in sig.parameters.items(): + if name == "self": + continue + + # Try to get the type annotation + if name in type_hints: + # Convert type hint to AST + type_hint = type_hints[name] + annotation = _type_to_ast(type_hint) + elif param.annotation is not param.empty: + # Try to use the annotation directly + annotation = _annotation_to_ast(param.annotation) + else: + # Fall back to Any + annotation = ast.Name(id="Any", ctx=ast.Load()) + + if param.default is not param.empty: + # Has a default value + if isinstance(param.default, tuple): + value = ast.Tuple(elts=[ast.Constant(v) for v in param.default], ctx=ast.Load()) + else: + value = ast.Constant(param.default) + else: + # No default - required field + value = None + + ann_assign = ast.AnnAssign( + target=ast.Name(id=name, ctx=ast.Store()), + annotation=annotation, + value=value, + simple=1, + ) + fields.append(ann_assign) + + # Create the dataclass + class_def = ast.ClassDef( + name=f"{cls.__name__}Params", + bases=[], + keywords=[], + body=fields if fields else [ast.Pass()], + decorator_list=[ast.Name(id="dataclass", ctx=ast.Load())], + ) + + # Create a module with necessary imports and the dataclass + module = ast.Module( + body=[ + ast.ImportFrom( + module="dataclasses", + names=[ast.alias(name="dataclass", asname=None)], + level=0, + ), + ast.ImportFrom(module="typing", names=[ast.alias(name="Any", asname=None)], level=0), + class_def, + ], + type_ignores=[], + ) + + ast.fix_missing_locations(module) + return ast.unparse(module) + + +def _type_to_ast(type_hint): + """Convert a type hint to an AST node.""" + import typing + + # Handle basic types + if type_hint is type(None): + return ast.Constant(value=None) + elif hasattr(type_hint, "__name__"): + # Simple type like int, str, float, bool + return ast.Name(id=type_hint.__name__, ctx=ast.Load()) + elif hasattr(type_hint, "__origin__"): + # Generic types like Optional[int], list[str], etc. + origin = type_hint.__origin__ + + if origin is typing.Union: + # Union or Optional + args = type_hint.__args__ + if len(args) == 2 and type(None) in args: + # Optional[X] + inner_type = args[0] if args[1] is type(None) else args[1] + return ast.Subscript( + value=ast.Name(id="Optional", ctx=ast.Load()), + slice=_type_to_ast(inner_type), + ctx=ast.Load(), + ) + else: + # Union[X, Y, ...] + return ast.Subscript( + value=ast.Name(id="Union", ctx=ast.Load()), + slice=ast.Tuple( + elts=[_type_to_ast(arg) for arg in args], + ctx=ast.Load(), + ), + ctx=ast.Load(), + ) + elif hasattr(origin, "__name__"): + # list, tuple, dict, etc. + args = type_hint.__args__ + return ast.Subscript( + value=ast.Name(id=origin.__name__, ctx=ast.Load()), + slice=( + ast.Tuple( + elts=[_type_to_ast(arg) for arg in args], + ctx=ast.Load(), + ) + if len(args) > 1 + else _type_to_ast(args[0]) + ), + ctx=ast.Load(), + ) + + # Fall back to Any + return ast.Name(id="Any", ctx=ast.Load()) + + +def _annotation_to_ast(annotation): + """Convert an annotation object to AST.""" + if isinstance(annotation, type): + return ast.Name(id=annotation.__name__, ctx=ast.Load()) + elif isinstance(annotation, str): + # Forward reference + return ast.Name(id=annotation, ctx=ast.Load()) + else: + # Try to extract from the annotation + try: + return _type_to_ast(annotation) + except Exception: + return ast.Name(id="Any", ctx=ast.Load()) + + +if __name__ == "__main__": + source = generate_params_dataclass(EnvironmentOptions) + print(source) diff --git a/generate_maxwell_ast.py b/generate_maxwell_ast.py new file mode 100644 index 000000000..991528558 --- /dev/null +++ b/generate_maxwell_ast.py @@ -0,0 +1,180 @@ +import ast + +from ast_helpers import ( + assign_constructor, + assign_constructor_with_defaults, + attr_chain, + call_attr, + import_from, +) +from struphy import main +from struphy.fields_background import equils +from struphy.geometry import domains +from struphy.initial import perturbations +from struphy.io.options import ( + BaseUnits, + DerhamOptions, + EnvironmentOptions, + FieldsBackground, + Time, +) +from struphy.kinetic_background import maxwellians +from struphy.models.toy import Maxwell +from struphy.pic.utilities import ( + BinningPlot, + BoundaryParameters, + KernelDensityPlot, + LoadingParameters, + WeightsParameters, +) +from struphy.topology import grids + + +def generate_maxwell_ast( + specify_defaults: bool = False, +) -> str: + if specify_defaults: + assign_constructor_func = assign_constructor_with_defaults + else: + assign_constructor_func = assign_constructor + + # Imports + imports = [ + import_from( + "struphy.io.options", + [ + "EnvironmentOptions", + "BaseUnits", + "Time", + "DerhamOptions", + "FieldsBackground", + ], + ), + import_from("struphy.geometry", ["domains"]), + import_from("struphy.fields_background", ["equils"]), + import_from("struphy.topology", ["grids"]), + import_from("struphy.initial", ["perturbations"]), + import_from("struphy.kinetic_background", ["maxwellians"]), + import_from( + "struphy.pic.utilities", + [ + "LoadingParameters", + "WeightsParameters", + "BoundaryParameters", + "BinningPlot", + "KernelDensityPlot", + ], + ), + import_from("struphy", ["main"]), + import_from("struphy.models.toy", ["Maxwell"]), + ] + + # Assignments + assignments = [ + assign_constructor_func("env", EnvironmentOptions), + assign_constructor_func("base_units", BaseUnits), + assign_constructor_func("time_opts", Time, None, dt=0.01, Tend=0.10), + assign_constructor_func("domain", domains.Cuboid, "domains.Cuboid"), + assign_constructor_func("equil", equils.HomogenSlab, "equils.HomogenSlab"), + assign_constructor_func("grid", grids.TensorProductGrid, "grids.TensorProductGrid"), + assign_constructor_func("derham_opts", DerhamOptions), + assign_constructor_func("model", Maxwell), + ] + + # propagator options + prop_options_assign = ast.Assign( + targets=[attr_chain(["model", "propagators", "maxwell", "options"], ctx=ast.Store())], + value=ast.Call( + func=attr_chain(["model", "propagators", "maxwell", "Options"]), + args=[], + keywords=[], + ), + ) + assignments.append(prop_options_assign) + + # Perturbations + perturb_calls = [] + for comp in range(3): + perturb_calls.append( + call_attr( + attr_chain(["model", "em_fields", "b_field"]), + "add_perturbation", + args=[ + ast.Call( + func=ast.Attribute( + value=ast.Name(id="perturbations", ctx=ast.Load()), + attr="TorusModesCos", + ctx=ast.Load(), + ), + args=[], + keywords=[ + ast.keyword(arg="given_in_basis", value=ast.Constant("v")), + ast.keyword(arg="comp", value=ast.Constant(comp)), + ], + ) + ], + ) + ) + + # main + main_guard = ast.If( + test=ast.Compare( + left=ast.Name(id="__name__", ctx=ast.Load()), + ops=[ast.Eq()], + comparators=[ast.Constant("__main__")], + ), + body=[ + ast.Assign( + targets=[ast.Name(id="verbose", ctx=ast.Store())], + value=ast.Constant(True), + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="main", ctx=ast.Load()), + attr="run", + ctx=ast.Load(), + ), + args=[ast.Name(id="model", ctx=ast.Load())], + keywords=[ + ast.keyword( + arg="params_path", + value=ast.Name(id="__file__", ctx=ast.Load()), + ), + ast.keyword(arg="env", value=ast.Name(id="env", ctx=ast.Load())), + ast.keyword( + arg="base_units", + value=ast.Name(id="base_units", ctx=ast.Load()), + ), + ast.keyword( + arg="time_opts", + value=ast.Name(id="time_opts", ctx=ast.Load()), + ), + ast.keyword(arg="domain", value=ast.Name(id="domain", ctx=ast.Load())), + ast.keyword(arg="equil", value=ast.Name(id="equil", ctx=ast.Load())), + ast.keyword(arg="grid", value=ast.Name(id="grid", ctx=ast.Load())), + ast.keyword( + arg="derham_opts", + value=ast.Name(id="derham_opts", ctx=ast.Load()), + ), + ast.keyword(arg="verbose", value=ast.Name(id="verbose", ctx=ast.Load())), + ], + ) + ), + ], + orelse=[], + ) + + # Assemble module + module = ast.Module(body=imports + assignments + perturb_calls + [main_guard], type_ignores=[]) + + ast.fix_missing_locations(module) + + # return source code + return ast.unparse(module) + + +if __name__ == "__main__": + source = generate_maxwell_ast(specify_defaults=False) + # print source code + print(source)