Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions adept/_vlasov1d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from diffrax import Solution
from jax import numpy as jnp
from matplotlib import pyplot as plt
from scipy.special import gamma
from jax.scipy.special import gamma


from adept._base_ import get_envelope
from adept._vlasov1d.storage import store_f, store_fields
Expand All @@ -36,7 +37,7 @@ def _initialize_distribution_(
m=2.0,
T0=1.0,
vmax=6.0,
n_prof=np.ones(1),
n_prof=jnp.ones(1),
noise_val=0.0,
noise_seed=42,
noise_type="Uniform",
Expand All @@ -60,23 +61,23 @@ def _initialize_distribution_(
:return:
"""

# noise_generator = np.random.default_rng(seed=noise_seed)
# noise_generator = jnp.random.default_rng(seed=noise_seed)

dv = 2.0 * vmax / nv
vax = np.linspace(-vmax + dv / 2.0, vmax - dv / 2.0, nv)
vax = jnp.linspace(-vmax + dv / 2.0, vmax - dv / 2.0, nv)

alpha = np.sqrt(3.0 * gamma_3_over_m(m) / gamma_5_over_m(m))
# cst = m / (4 * np.pi * alpha**3.0 * gamma(3.0 / m))
alpha = jnp.sqrt(3.0 * gamma_3_over_m(m) / gamma_5_over_m(m))
# cst = m / (4 * jnp.pi * alpha**3.0 * gamma(3.0 / m))

single_dist = -(np.power(np.abs((vax[None, :] - v0) / alpha / np.sqrt(T0)), m))
single_dist = -(jnp.power(jnp.abs((vax[None, :] - v0) / alpha / jnp.sqrt(T0)), m))

single_dist = np.exp(single_dist)
# single_dist = np.exp(-(vaxs[0][None, None, :, None]**2.+vaxs[1][None, None, None, :]**2.)/2/T0)
single_dist = jnp.exp(single_dist)
# single_dist = jnp.exp(-(vaxs[0][None, None, :, None]**2.+vaxs[1][None, None, None, :]**2.)/2/T0)

# for ix in range(nx):
f = np.repeat(single_dist, nx, axis=0)
f = jnp.repeat(single_dist, nx, axis=0)
# normalize
f = f / np.trapz(f, dx=dv, axis=1)[:, None]
f = f / jnp.trapezoid(f, dx=dv, axis=1)[:, None]

if n_prof.size > 1:
# scale by density profile
Expand All @@ -92,8 +93,8 @@ def _initialize_distribution_(

def _initialize_total_distribution_(cfg, cfg_grid):
params = cfg["density"]
n_prof_total = np.zeros([cfg_grid["nx"]])
f = np.zeros([cfg_grid["nx"], cfg_grid["nv"]])
n_prof_total = jnp.zeros([cfg_grid["nx"]])
f = jnp.zeros([cfg_grid["nx"], cfg_grid["nv"]])
species_found = False
for name, species_params in cfg["density"].items():
if name.startswith("species-"):
Expand All @@ -111,7 +112,7 @@ def _initialize_total_distribution_(cfg, cfg_grid):
m = params[name]["m"]

if species_params["basis"] == "uniform":
nprof = np.ones_like(n_prof_total)
nprof = jnp.ones_like(n_prof_total)

elif species_params["basis"] == "linear":
left = species_params["center"] - species_params["width"] * 0.5
Expand Down Expand Up @@ -141,7 +142,7 @@ def _initialize_total_distribution_(cfg, cfg_grid):
_Q(species_params["gradient scale length"]).to("nm").magnitude
/ cfg["units"]["derived"]["x0"].to("nm").magnitude
)
nprof = species_params["val at center"] * np.exp((cfg_grid["x"] - species_params["center"]) / L)
nprof = species_params["val at center"] * jnp.exp((cfg_grid["x"] - species_params["center"]) / L)
nprof = mask * nprof

elif species_params["basis"] == "tanh":
Expand Down
110 changes: 110 additions & 0 deletions tests/test_vlasov1d/configs/twostream_opt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
units:
laser_wavelength: 351nm
normalizing_temperature: 2000eV
normalizing_density: 1.5e21/cc
Z: 10
Zp: 10


density:
quasineutrality: true
species-electron1:
noise_seed: 420
noise_type: gaussian
noise_val: 0.0
v0: -1.5
T0: 0.2
m: 2.0
basis: sine
baseline: 0.5
amplitude: 1.0e-4
wavenumber: 0.3
species-electron2:
noise_seed: 420
noise_type: gaussian
noise_val: 0.0
v0: 1.5
T0: 0.2
m: 2.0
basis: sine
baseline: 0.5
amplitude: -1.0e-4
wavenumber: 0.3

grid:
dt: 0.1
nv: 4096
nx: 64
tmin: 0.
tmax: 100.0
vmax: 6.4
xmax: 20.94
xmin: 0.0

save:
fields:
t:
tmin: 0.0
tmax: 100.0
nt: 51
electron:
t:
tmin: 0.0
tmax: 100.0
nt: 51

solver: vlasov-1d

mlflow:
experiment: twostream-optimize
run: opt-iter-0

drivers:
ex: {}
ey: {}

diagnostics:
diag-vlasov-dfdt: False
diag-fp-dfdt: False

terms:
field: poisson
edfdv: exponential
time: sixth
fokker_planck:
is_on: True
type: Dougherty
time:
baseline: 1.0e-5
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
space:
baseline: 1.0
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
krook:
is_on: True
time:
baseline: 1.0e-6
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
space:
baseline: 1.0
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
140 changes: 140 additions & 0 deletions tests/test_vlasov1d/test_twostream_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import xarray as xr, numpy as np, os, sys
import yaml
import mlflow
from tqdm import tqdm
import time

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
sys.path.append(os.getcwd()) # To load adept

from jax import config
import jax
import jax.numpy as jnp
config.update("jax_enable_x64", True)
import equinox as eqx
import optax

from diffrax import diffeqsolve, SaveAt

from adept import ergoExo
from adept._vlasov1d.modules import BaseVlasov1D
from adept._vlasov1d.helpers import _initialize_total_distribution_

import matplotlib
matplotlib.use("Agg")


def set_dict_leaves(src, dst, key=None):
for k, v in src.items():
if isinstance(dst[k], dict):
set_dict_leaves(v, dst[k])
else:
dst[k] = v


class OptVlasov1D(BaseVlasov1D):
def __init__(self, cfg):
super().__init__(cfg)

def reinitialize_distribution(self, cfg, state):
# return super().init_state_and_args()
_, f = _initialize_total_distribution_(cfg, cfg["grid"])
state["electron"] = f

return state

def __call__(self, params: dict, args: dict) -> dict:
if args is None:
args = self.args
# Overwrite cfg with passed args
cfg = self.cfg
set_dict_leaves(params, cfg)
# Reinitialize the distribution based on args
state = self.reinitialize_distribution(cfg, self.state)
# Solve the equations
solver_result = diffeqsolve(
terms=self.diffeqsolve_quants["terms"],
solver=self.diffeqsolve_quants["solver"],
t0=self.time_quantities["t0"],
t1=self.time_quantities["t1"],
max_steps=cfg["grid"]["max_steps"],
dt0=cfg["grid"]["dt"],
y0=state,
args=args,
saveat=SaveAt(**self.diffeqsolve_quants["saveat"]),
)
# Compute the mean growth rate at the start of the simulation
opt_quantity = jnp.mean(jnp.log10(solver_result.ys['default']['mean_e2'][10:200]))
return_val = (opt_quantity, {"solver result": solver_result})
return return_val

def vg(self, params: dict, args: dict) -> tuple[float, dict, dict]:
return eqx.filter_value_and_grad(self.__call__, has_aux=True)(params, args)


if __name__ == "__main__":

with open("tests/test_vlasov1d/configs/twostream_opt.yaml", 'r') as stream:
cfg = yaml.safe_load(stream)
cfg['mlflow']['experiment'] = "twostream-optimize"
mlflow.set_experiment("twostream-optimize")

params = {"density": {
"species-electron1": {
"v0": jnp.array(cfg["density"]["species-electron1"]["v0"]),
"T0": jnp.array(cfg["density"]["species-electron1"]["T0"]),
},
"species-electron2": {
"v0": jnp.array(cfg["density"]["species-electron2"]["v0"]),
"T0": jnp.array(cfg["density"]["species-electron2"]["T0"]),
}
}
}

mlflow.log_metrics({
"e1_v0": params["density"]["species-electron1"]["v0"].item(),
"e1_T0": params["density"]["species-electron1"]["T0"].item(),
"e2_v0": params["density"]["species-electron2"]["v0"].item(),
"e2_T0": params["density"]["species-electron2"]["T0"].item(),
}, step=0)

optimizer = optax.adam(0.1)
opt_state = optimizer.init(params)

loop_t0 = time.time()
mlflow.log_metrics({"time_loop": time.time() - loop_t0}, step=0)
for i in tqdm(range(5)):
iter_t0 = time.time()
cfg['mlflow']['run'] = f"opt-iter-{i}"

exo = ergoExo(mlflow_nested=True)
exo.setup(cfg=cfg, adept_module=OptVlasov1D)
# Potential optimization to shift post-processing to another thread
val, grad, (sim_out, post_processed_output, mlflow_run_id) = exo.val_and_grad(params)

mlflow.log_metrics({
"gamma-e2": val.item(),
"grad_l2": jnp.linalg.norm(jnp.array(jax.tree.flatten(grad)[0])).item(),
"e1_v0": params["density"]["species-electron1"]["v0"].item(),
"e1_T0": params["density"]["species-electron1"]["T0"].item(),
"e2_v0": params["density"]["species-electron2"]["v0"].item(),
"e2_T0": params["density"]["species-electron2"]["T0"].item(),
}, step=i+1)

updates, opt_state = optimizer.update(grad, opt_state, params)
params = optax.apply_updates(params, updates)

print(f"Mean-log e2 growth rate : {val}")

mlflow.log_metrics({
"time_iter": time.time() - iter_t0,
"time_loop": time.time() - loop_t0,
}, step=i+1)

# The final parameter values are not logged because they do not correspond to
# the final optimized quantity (the update step has been applied to them)
np.testing.assert_almost_equal(val, -8.572186655748087, decimal=2)
np.testing.assert_almost_equal(np.abs(params["density"]["species-electron1"]["v0"]),
np.abs(params["density"]["species-electron2"]["v0"]), decimal=2)
np.testing.assert_almost_equal(params["density"]["species-electron1"]["T0"],
params["density"]["species-electron2"]["T0"], decimal=2)
Loading