Skip to content

Commit 5fa4c21

Browse files
committed
Added generator for params_Maxwell.py
1 parent 97dd3d3 commit 5fa4c21

2 files changed

Lines changed: 194 additions & 0 deletions

File tree

ast_helpers.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import ast
2+
3+
4+
def import_from(module, names):
5+
return ast.ImportFrom(
6+
module=module,
7+
names=[ast.alias(name=n, asname=None) for n in names],
8+
level=0,
9+
)
10+
11+
12+
def assign_constructor(var, cls, **kwargs):
13+
"""Create AST for: var = cls(**kwargs)"""
14+
15+
def ast_value(v):
16+
if isinstance(v, tuple):
17+
return ast.Tuple(elts=[ast_value(x) for x in v], ctx=ast.Load())
18+
return ast.Constant(v)
19+
20+
return ast.Assign(
21+
targets=[ast.Name(id=var, ctx=ast.Store())],
22+
value=ast.Call(
23+
func=ast.Name(id=cls, ctx=ast.Load()),
24+
args=[],
25+
keywords=[
26+
ast.keyword(arg=k, value=ast_value(v)) for k, v in kwargs.items()
27+
],
28+
),
29+
)
30+
31+
32+
def call_attr(obj, attr, args=None, keywords=None):
33+
return ast.Expr(
34+
value=ast.Call(
35+
func=ast.Attribute(value=obj, attr=attr, ctx=ast.Load()),
36+
args=args or [],
37+
keywords=keywords or [],
38+
)
39+
)
40+
41+
42+
def attr_chain(names, ctx=ast.Load()):
43+
"""Create a nested Attribute node from a list: e.g., model.em_fields.b_field"""
44+
node = ast.Name(id=names[0], ctx=ctx)
45+
for name in names[1:]:
46+
node = ast.Attribute(value=node, attr=name, ctx=ctx)
47+
return node

generate_maxwell_ast.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import ast
2+
3+
from ast_helpers import assign_constructor, attr_chain, call_attr, import_from
4+
5+
# Imports
6+
imports = [
7+
import_from(
8+
"struphy.io.options",
9+
[
10+
"EnvironmentOptions",
11+
"BaseUnits",
12+
"Time",
13+
"DerhamOptions",
14+
"FieldsBackground",
15+
],
16+
),
17+
import_from("struphy.geometry", ["domains"]),
18+
import_from("struphy.fields_background", ["equils"]),
19+
import_from("struphy.topology", ["grids"]),
20+
import_from("struphy.initial", ["perturbations"]),
21+
import_from("struphy.kinetic_background", ["maxwellians"]),
22+
import_from(
23+
"struphy.pic.utilities",
24+
[
25+
"LoadingParameters",
26+
"WeightsParameters",
27+
"BoundaryParameters",
28+
"BinningPlot",
29+
"KernelDensityPlot",
30+
],
31+
),
32+
import_from("struphy", ["main"]),
33+
import_from("struphy.models.toy", ["Maxwell"]),
34+
]
35+
36+
# Assignments
37+
assignments = [
38+
assign_constructor("env", "EnvironmentOptions"),
39+
assign_constructor("base_units", "BaseUnits"),
40+
assign_constructor("time_opts", "Time", dt=0.01, Tend=0.10),
41+
assign_constructor("domain", "domains.Cuboid"),
42+
assign_constructor("equil", "equils.HomogenSlab"),
43+
assign_constructor("grid", "grids.TensorProductGrid"),
44+
assign_constructor("derham_opts", "DerhamOptions"),
45+
assign_constructor("model", "Maxwell"),
46+
]
47+
48+
# propagator options
49+
prop_options_assign = ast.Assign(
50+
targets=[
51+
attr_chain(["model", "propagators", "maxwell", "options"], ctx=ast.Store())
52+
],
53+
value=ast.Call(
54+
func=attr_chain(["model", "propagators", "maxwell", "Options"]),
55+
args=[],
56+
keywords=[],
57+
),
58+
)
59+
assignments.append(prop_options_assign)
60+
61+
# Perturbations
62+
perturb_calls = []
63+
for comp in range(3):
64+
perturb_calls.append(
65+
call_attr(
66+
attr_chain(["model", "em_fields", "b_field"]),
67+
"add_perturbation",
68+
args=[
69+
ast.Call(
70+
func=ast.Attribute(
71+
value=ast.Name(id="perturbations", ctx=ast.Load()),
72+
attr="TorusModesCos",
73+
ctx=ast.Load(),
74+
),
75+
args=[],
76+
keywords=[
77+
ast.keyword(arg="given_in_basis", value=ast.Constant("v")),
78+
ast.keyword(arg="comp", value=ast.Constant(comp)),
79+
],
80+
)
81+
],
82+
)
83+
)
84+
85+
# main
86+
main_guard = ast.If(
87+
test=ast.Compare(
88+
left=ast.Name(id="__name__", ctx=ast.Load()),
89+
ops=[ast.Eq()],
90+
comparators=[ast.Constant("__main__")],
91+
),
92+
body=[
93+
ast.Assign(
94+
targets=[ast.Name(id="verbose", ctx=ast.Store())],
95+
value=ast.Constant(True),
96+
),
97+
ast.Expr(
98+
value=ast.Call(
99+
func=ast.Attribute(
100+
value=ast.Name(id="main", ctx=ast.Load()),
101+
attr="run",
102+
ctx=ast.Load(),
103+
),
104+
args=[ast.Name(id="model", ctx=ast.Load())],
105+
keywords=[
106+
ast.keyword(
107+
arg="params_path", value=ast.Name(id="__file__", ctx=ast.Load())
108+
),
109+
ast.keyword(arg="env", value=ast.Name(id="env", ctx=ast.Load())),
110+
ast.keyword(
111+
arg="base_units",
112+
value=ast.Name(id="base_units", ctx=ast.Load()),
113+
),
114+
ast.keyword(
115+
arg="time_opts", value=ast.Name(id="time_opts", ctx=ast.Load())
116+
),
117+
ast.keyword(
118+
arg="domain", value=ast.Name(id="domain", ctx=ast.Load())
119+
),
120+
ast.keyword(
121+
arg="equil", value=ast.Name(id="equil", ctx=ast.Load())
122+
),
123+
ast.keyword(arg="grid", value=ast.Name(id="grid", ctx=ast.Load())),
124+
ast.keyword(
125+
arg="derham_opts",
126+
value=ast.Name(id="derham_opts", ctx=ast.Load()),
127+
),
128+
ast.keyword(
129+
arg="verbose", value=ast.Name(id="verbose", ctx=ast.Load())
130+
),
131+
],
132+
)
133+
),
134+
],
135+
orelse=[],
136+
)
137+
138+
# Assemble module
139+
module = ast.Module(
140+
body=imports + assignments + perturb_calls + [main_guard], type_ignores=[]
141+
)
142+
143+
ast.fix_missing_locations(module)
144+
145+
# print source code
146+
source = ast.unparse(module)
147+
print(source)

0 commit comments

Comments
 (0)