diff --git a/adept/_vlasov1d/helpers.py b/adept/_vlasov1d/helpers.py index 2a3b5c7c..a2b24d2a 100644 --- a/adept/_vlasov1d/helpers.py +++ b/adept/_vlasov1d/helpers.py @@ -10,7 +10,8 @@ from diffrax import Solution from jax import numpy as jnp from matplotlib import pyplot as plt -from scipy.special import gamma +from jax.scipy.special import gamma + from adept._base_ import get_envelope from adept._vlasov1d.storage import store_f, store_fields @@ -36,7 +37,7 @@ def _initialize_distribution_( m=2.0, T0=1.0, vmax=6.0, - n_prof=np.ones(1), + n_prof=jnp.ones(1), noise_val=0.0, noise_seed=42, noise_type="Uniform", @@ -60,23 +61,23 @@ def _initialize_distribution_( :return: """ - # noise_generator = np.random.default_rng(seed=noise_seed) + # noise_generator = jnp.random.default_rng(seed=noise_seed) dv = 2.0 * vmax / nv - vax = np.linspace(-vmax + dv / 2.0, vmax - dv / 2.0, nv) + vax = jnp.linspace(-vmax + dv / 2.0, vmax - dv / 2.0, nv) - alpha = np.sqrt(3.0 * gamma_3_over_m(m) / gamma_5_over_m(m)) - # cst = m / (4 * np.pi * alpha**3.0 * gamma(3.0 / m)) + alpha = jnp.sqrt(3.0 * gamma_3_over_m(m) / gamma_5_over_m(m)) + # cst = m / (4 * jnp.pi * alpha**3.0 * gamma(3.0 / m)) - single_dist = -(np.power(np.abs((vax[None, :] - v0) / alpha / np.sqrt(T0)), m)) + single_dist = -(jnp.power(jnp.abs((vax[None, :] - v0) / alpha / jnp.sqrt(T0)), m)) - single_dist = np.exp(single_dist) - # single_dist = np.exp(-(vaxs[0][None, None, :, None]**2.+vaxs[1][None, None, None, :]**2.)/2/T0) + single_dist = jnp.exp(single_dist) + # single_dist = jnp.exp(-(vaxs[0][None, None, :, None]**2.+vaxs[1][None, None, None, :]**2.)/2/T0) # for ix in range(nx): - f = np.repeat(single_dist, nx, axis=0) + f = jnp.repeat(single_dist, nx, axis=0) # normalize - f = f / np.trapz(f, dx=dv, axis=1)[:, None] + f = f / jnp.trapezoid(f, dx=dv, axis=1)[:, None] if n_prof.size > 1: # scale by density profile @@ -92,8 +93,8 @@ def _initialize_distribution_( def _initialize_total_distribution_(cfg, cfg_grid): params = 
cfg["density"] - n_prof_total = np.zeros([cfg_grid["nx"]]) - f = np.zeros([cfg_grid["nx"], cfg_grid["nv"]]) + n_prof_total = jnp.zeros([cfg_grid["nx"]]) + f = jnp.zeros([cfg_grid["nx"], cfg_grid["nv"]]) species_found = False for name, species_params in cfg["density"].items(): if name.startswith("species-"): @@ -111,7 +112,7 @@ def _initialize_total_distribution_(cfg, cfg_grid): m = params[name]["m"] if species_params["basis"] == "uniform": - nprof = np.ones_like(n_prof_total) + nprof = jnp.ones_like(n_prof_total) elif species_params["basis"] == "linear": left = species_params["center"] - species_params["width"] * 0.5 @@ -141,7 +142,7 @@ def _initialize_total_distribution_(cfg, cfg_grid): _Q(species_params["gradient scale length"]).to("nm").magnitude / cfg["units"]["derived"]["x0"].to("nm").magnitude ) - nprof = species_params["val at center"] * np.exp((cfg_grid["x"] - species_params["center"]) / L) + nprof = species_params["val at center"] * jnp.exp((cfg_grid["x"] - species_params["center"]) / L) nprof = mask * nprof elif species_params["basis"] == "tanh": diff --git a/tests/test_vlasov1d/configs/twostream_opt.yaml b/tests/test_vlasov1d/configs/twostream_opt.yaml new file mode 100644 index 00000000..11849a33 --- /dev/null +++ b/tests/test_vlasov1d/configs/twostream_opt.yaml @@ -0,0 +1,110 @@ +units: + laser_wavelength: 351nm + normalizing_temperature: 2000eV + normalizing_density: 1.5e21/cc + Z: 10 + Zp: 10 + + +density: + quasineutrality: true + species-electron1: + noise_seed: 420 + noise_type: gaussian + noise_val: 0.0 + v0: -1.5 + T0: 0.2 + m: 2.0 + basis: sine + baseline: 0.5 + amplitude: 1.0e-4 + wavenumber: 0.3 + species-electron2: + noise_seed: 420 + noise_type: gaussian + noise_val: 0.0 + v0: 1.5 + T0: 0.2 + m: 2.0 + basis: sine + baseline: 0.5 + amplitude: -1.0e-4 + wavenumber: 0.3 + +grid: + dt: 0.1 + nv: 4096 + nx: 64 + tmin: 0. 
+ tmax: 100.0 + vmax: 6.4 + xmax: 20.94 + xmin: 0.0 + +save: + fields: + t: + tmin: 0.0 + tmax: 100.0 + nt: 51 + electron: + t: + tmin: 0.0 + tmax: 100.0 + nt: 51 + +solver: vlasov-1d + +mlflow: + experiment: twostream-optimize + run: opt-iter-0 + +drivers: + ex: {} + ey: {} + +diagnostics: + diag-vlasov-dfdt: False + diag-fp-dfdt: False + +terms: + field: poisson + edfdv: exponential + time: sixth + fokker_planck: + is_on: True + type: Dougherty + time: + baseline: 1.0e-5 + bump_or_trough: bump + center: 0.0 + rise: 25.0 + slope: 0.0 + bump_height: 0.0 + width: 100000.0 + space: + baseline: 1.0 + bump_or_trough: bump + center: 0.0 + rise: 25.0 + slope: 0.0 + bump_height: 0.0 + width: 100000.0 + krook: + is_on: True + time: + baseline: 1.0e-6 + bump_or_trough: bump + center: 0.0 + rise: 25.0 + slope: 0.0 + bump_height: 0.0 + width: 100000.0 + space: + baseline: 1.0 + bump_or_trough: bump + center: 0.0 + rise: 25.0 + slope: 0.0 + bump_height: 0.0 + width: 100000.0 diff --git a/tests/test_vlasov1d/test_twostream_opt.py b/tests/test_vlasov1d/test_twostream_opt.py new file mode 100644 index 00000000..2f8b013a --- /dev/null +++ b/tests/test_vlasov1d/test_twostream_opt.py @@ -0,0 +1,140 @@ +import xarray as xr, numpy as np, os, sys +import yaml +import mlflow +from tqdm import tqdm +import time + +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" +sys.path.append(os.getcwd()) # To load adept + +from jax import config +import jax +import jax.numpy as jnp +config.update("jax_enable_x64", True) +import equinox as eqx +import optax + +from diffrax import diffeqsolve, SaveAt + +from adept import ergoExo +from adept._vlasov1d.modules import BaseVlasov1D +from adept._vlasov1d.helpers import _initialize_total_distribution_ + +import matplotlib +matplotlib.use("Agg") + + +def set_dict_leaves(src, dst, key=None): + for k, v in src.items(): + if isinstance(dst[k], dict): + set_dict_leaves(v, dst[k]) + else: + dst[k] = v + + +class OptVlasov1D(BaseVlasov1D): + def __init__(self, 
cfg): + super().__init__(cfg) + + def reinitialize_distribution(self, cfg, state): + # return super().init_state_and_args() + _, f = _initialize_total_distribution_(cfg, cfg["grid"]) + state["electron"] = f + + return state + + def __call__(self, params: dict, args: dict) -> dict: + if args is None: + args = self.args + # Overwrite cfg with passed args + cfg = self.cfg + set_dict_leaves(params, cfg) + # Reinitialize the distribution based on args + state = self.reinitialize_distribution(cfg, self.state) + # Solve the equations + solver_result = diffeqsolve( + terms=self.diffeqsolve_quants["terms"], + solver=self.diffeqsolve_quants["solver"], + t0=self.time_quantities["t0"], + t1=self.time_quantities["t1"], + max_steps=cfg["grid"]["max_steps"], + dt0=cfg["grid"]["dt"], + y0=state, + args=args, + saveat=SaveAt(**self.diffeqsolve_quants["saveat"]), + ) + # Compute the mean growth rate at the start of the simulation + opt_quantity = jnp.mean(jnp.log10(solver_result.ys['default']['mean_e2'][10:200])) + return_val = (opt_quantity, {"solver result": solver_result}) + return return_val + + def vg(self, params: dict, args: dict) -> tuple[float, dict, dict]: + return eqx.filter_value_and_grad(self.__call__, has_aux=True)(params, args) + + +if __name__ == "__main__": + + with open("tests/test_vlasov1d/configs/twostream_opt.yaml", 'r') as stream: + cfg = yaml.safe_load(stream) + cfg['mlflow']['experiment'] = "twostream-optimize" + mlflow.set_experiment("twostream-optimize") + + params = {"density": { + "species-electron1": { + "v0": jnp.array(cfg["density"]["species-electron1"]["v0"]), + "T0": jnp.array(cfg["density"]["species-electron1"]["T0"]), + }, + "species-electron2": { + "v0": jnp.array(cfg["density"]["species-electron2"]["v0"]), + "T0": jnp.array(cfg["density"]["species-electron2"]["T0"]), + } + } + } + + mlflow.log_metrics({ + "e1_v0": params["density"]["species-electron1"]["v0"].item(), + "e1_T0": params["density"]["species-electron1"]["T0"].item(), + "e2_v0": 
params["density"]["species-electron2"]["v0"].item(), + "e2_T0": params["density"]["species-electron2"]["T0"].item(), + }, step=0) + + optimizer = optax.adam(0.1) + opt_state = optimizer.init(params) + + loop_t0 = time.time() + mlflow.log_metrics({"time_loop": time.time() - loop_t0}, step=0) + for i in tqdm(range(5)): + iter_t0 = time.time() + cfg['mlflow']['run'] = f"opt-iter-{i}" + + exo = ergoExo(mlflow_nested=True) + exo.setup(cfg=cfg, adept_module=OptVlasov1D) + # Potential optimization to shift post-processing to another thread + val, grad, (sim_out, post_processed_output, mlflow_run_id) = exo.val_and_grad(params) + + mlflow.log_metrics({ + "gamma-e2": val.item(), + "grad_l2": jnp.linalg.norm(jnp.array(jax.tree.flatten(grad)[0])).item(), + "e1_v0": params["density"]["species-electron1"]["v0"].item(), + "e1_T0": params["density"]["species-electron1"]["T0"].item(), + "e2_v0": params["density"]["species-electron2"]["v0"].item(), + "e2_T0": params["density"]["species-electron2"]["T0"].item(), + }, step=i+1) + + updates, opt_state = optimizer.update(grad, opt_state, params) + params = optax.apply_updates(params, updates) + + print(f"Mean-log e2 growth rate : {val}") + + mlflow.log_metrics({ + "time_iter": time.time() - iter_t0, + "time_loop": time.time() - loop_t0, + }, step=i+1) + + # The final parameter values are not logged because they do not correspond to + # the final optimized quantity (the update step has been applied to them) + np.testing.assert_almost_equal(val, -8.572186655748087, decimal=2) + np.testing.assert_almost_equal(np.abs(params["density"]["species-electron1"]["v0"]), + np.abs(params["density"]["species-electron2"]["v0"]), decimal=2) + np.testing.assert_almost_equal(params["density"]["species-electron1"]["T0"], + params["density"]["species-electron2"]["T0"], decimal=2) diff --git a/two-stream_gridexplore.py b/two-stream_gridexplore.py new file mode 100644 index 00000000..3dd2f614 --- /dev/null +++ b/two-stream_gridexplore.py @@ -0,0 +1,139 @@ 
"""Grid exploration of Vlasov1D two-stream resolution parameters.

Sweeps (dt, nx, nv) combinations of the two-stream base config and runs each
simulation in its own process, logging wall time and the mean log10 field
energy to MLflow.
"""

import xarray as xr, numpy as np, os, sys
import yaml
import mlflow
from tqdm import tqdm
import time
import itertools
import copy

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# BUG FIX: was "CUDA_VISIBLE_DEVIES" (typo), so the device restriction was
# silently ignored and JAX could claim every visible GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
sys.path.append(os.getcwd())  # To load adept

import jax
from jax import config
import jax.numpy as jnp

config.update("jax_enable_x64", True)
# Persistent compilation cache so repeated sweeps skip XLA recompilation.
config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
config.update("jax_persistent_cache_min_entry_size_bytes", -1)
config.update("jax_persistent_cache_min_compile_time_secs", 0)
config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")
import equinox as eqx
import optax

from adept import ergoExo
from adept._vlasov1d.modules import BaseVlasov1D
from adept._vlasov1d.helpers import _initialize_total_distribution_

import multiprocess as mp
from multiprocess import get_context
from mlflow.tracking import MlflowClient  # for run_id-based logging


def set_dict_leaves(src, dst, key=None):
    """Recursively copy the leaf values of ``src`` into ``dst`` in place.

    Assumes every key of ``src`` already exists in ``dst``; nested dicts in
    ``dst`` are descended into, any other value is overwritten.

    :param src: dict (possibly nested) of new leaf values
    :param dst: dict to be updated in place
    :param key: unused; kept for interface compatibility
    """
    for k, v in src.items():
        if isinstance(dst[k], dict):
            set_dict_leaves(v, dst[k])
        else:
            dst[k] = v


class OptVlasov1D(BaseVlasov1D):
    """Vlasov1D module whose ``__call__`` ignores ``params`` and just runs the solver."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, params: dict, args: dict) -> dict:
        # ``params`` is accepted for interface compatibility but unused here;
        # the simulation always runs from the stored cfg/state.
        if args is None:
            args = self.args
        solver_result = super().__call__(self.cfg, self.state, args)
        return solver_result


def run_one(args):
    """Run a single (dt, nx, nv) simulation and log its metrics to MLflow.

    :param args: tuple of ``(dt, nx, nv, base_cfg)``
    :return: dict with the grid parameters, wall time, MLflow run id, and value
    """
    dt, nx, nv, base_cfg = args

    t0 = time.time()

    # Deep copy the config so each worker can mutate independently
    cfg = copy.deepcopy(base_cfg)
    cfg["grid"]["nx"] = nx
    cfg["grid"]["nv"] = nv
    cfg["grid"]["dt"] = dt
    cfg["mlflow"]["run"] = f"nx-{nx}_nv-{nv}_dt-{dt}"

    # Ensure MLflow experiment is set inside this process
    mlflow.set_experiment(cfg["mlflow"].get("experiment", "twostream-gridsearch"))

    # Run the simulation
    exo = ergoExo(mlflow_nested=True)
    exo.setup(cfg=cfg, adept_module=OptVlasov1D)
    sim_out, post_processed_output, mlflow_run_id = exo(exo.cfg)

    # Mean of log10 field energy, nan-safe, skipping the t=0 sample
    val = float(
        jnp.mean(
            jnp.nan_to_num(
                jnp.log10(sim_out["solver result"].ys["default"]["mean_e2"][1:])
            )
        )
    )

    elapsed = time.time() - t0

    # Log metrics against the run_id to avoid relying on implicit active run state
    try:
        client = MlflowClient()
        client.log_metric(mlflow_run_id, key="time_sim", value=elapsed)
        client.log_metric(mlflow_run_id, key="val", value=val)
    except Exception:
        # Fallback to implicit logging if needed (not ideal in multi-proc)
        mlflow.log_metrics({"time_sim": elapsed, "val": val})

    return {
        "dt": dt,
        "nx": nx,
        "nv": nv,
        "time_sim": elapsed,
        "mlflow_run_id": mlflow_run_id,
        "val": val,
    }


if __name__ == "__main__":
    n_procs = 20

    # Load base config
    with open("tests/test_vlasov1d/configs/twostream_opt.yaml", "r") as stream:
        cfg = yaml.safe_load(stream)

    # Set experiment (also done per worker)
    cfg["mlflow"]["experiment"] = "twostream-gridsearch"
    mlflow.set_experiment(cfg["mlflow"]["experiment"])

    # Parameter grid
    nx_list = [8, 16, 32, 64, 128]
    nv_list = [16, 64, 256, 1024, 4096]
    dt_list = [0.1, 1, 10]

    tasks = [(dt, nx, nv, cfg) for dt, nx, nv in itertools.product(dt_list, nx_list, nv_list)]
    total = len(tasks)

    # Choose number of processes (n_procs is hard-coded above; this guard keeps
    # the cpu_count fallback available if it is ever set to None)
    if n_procs is None:
        n_procs = os.cpu_count() or 1

    # Use spawn context for JAX/MLflow friendliness
    ctx = get_context("spawn")
    with ctx.Pool(processes=n_procs) as pool:
        # chunksize=1 ensures fair scheduling; increase for fewer, heavier tasks
        results_iter = pool.imap_unordered(run_one, tasks, chunksize=1)
        for res in tqdm(results_iter, total=total):
            tqdm.write(f"time: {res['time_sim']:.3f} run_id: {res['mlflow_run_id']} val: {res['val']:.6g}")