Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 112 additions & 43 deletions src/struphy/feec/mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from feectools.linalg.solvers import inverse
from feectools.linalg.stencil import StencilDiagonalMatrix, StencilMatrix, StencilVector

from struphy import domains, equils
from struphy.feec import mass_kernels
from struphy.feec.linear_operators import BoundaryOperator, LinOpWithTransp
from struphy.feec.psydac_derham import Derham, SplineFunction
from struphy.feec.utilities import LocalRotationMatrix, get_quad_grids
from struphy.feec.utilities import LocalProjectionMatrix, LocalRotationMatrix, get_quad_grids
from struphy.fields_background.base import MHDequilibrium
from struphy.fields_background.equils import set_defaults
from struphy.geometry.base import Domain
Expand Down Expand Up @@ -59,6 +60,11 @@ def __init__(
self._matrix_free = matrix_free
self._eq_mhd = eq_mhd

if self._eq_mhd is None:
self._eq_mhd = equils.HomogenSlab()
if not hasattr(self.eq_mhd, "_domain"):
self._eq_mhd.domain = self._domain

# only for M1 Mac users
PSYDAC_BACKEND_GPYCCEL["flags"] = "-O3 -march=native -mtune=native -ffast-math -ffree-line-length-none"

Expand Down Expand Up @@ -338,7 +344,7 @@ def M1ninv(self):
weights=(
"Ginv",
"sqrt_g",
"1/eq_n0",
lambda *etas: 1 / self.eq_mhd.n0(*etas),
),
name="M1ninv",
assemble=True,
Expand Down Expand Up @@ -693,14 +699,39 @@ def M1Bninv(self):
rot_B,
"Ginv",
"sqrt_g",
"1/eq_n0",
lambda *etas: 1 / self.eq_mhd.n0(*etas),
),
name="M1Bninv",
assemble=True,
)

return self._M1Bninv

@auto_convert_docstring
@property
def M1para(self):
r"""
Mass matrix

.. math::

\mathbb M^{1,\parallel}_{(\mu,ijk), (\nu,mno)} = \int \vec{\Lambda}^1_{\mu,ijk} b_0 b_0^\top \vec{\Lambda}^1_{\nu,mno} \sqrt{g} \textnormal{d}\boldsymbol{\eta}.
"""
if not hasattr(self, "_M1para"):
bb = LocalProjectionMatrix(self.eq_mhd.unit_bv_1, self.eq_mhd.unit_bv_2, self.eq_mhd.unit_bv_3)

self._M1para = self.create_weighted_mass(
"Hcurl",
"Hcurl",
weights=(
bb,
"sqrt_g",
),
name="M1para",
assemble=True,
)
return self._M1para

@auto_convert_docstring
@property
def M1perp(self):
Expand All @@ -709,25 +740,79 @@ def M1perp(self):

.. math::

\mathbb M^{1,\perp}_{(\mu,ijk), (\nu,mno)} = \int \vec{\Lambda}^1_{\mu,ijk} DF^{-1} \begin{pmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 0 \end{pmatrix} DF^{-\top} \vec{\Lambda}^1_{\nu,mno} \sqrt{g} \textnormal{d}\boldsymbol{\eta}.
\mathbb M^{1,\perp}_{(\mu,ijk), (\nu,mno)} = \int \vec{\Lambda}^1_{\mu,ijk} \left(G^{-1} - b_0 b_0^\top \right) \vec{\Lambda}^1_{\nu,mno} \sqrt{g} \textnormal{d}\boldsymbol{\eta}.
"""
if not hasattr(self, "_M1perp"):
D = [[1, 0, 0], [0, 1, 0], [0, 0, 0]]
self._M1perp = self.M1.copy()
self._M1perp -= self.M1para
return self._M1perp

self._M1perp = self.create_weighted_mass(
@auto_convert_docstring
@property
def M1para_MHDeq(self):
r"""
Mass matrix

.. math::

\mathbb M^{1,\parallel}_{(\mu,ijk), (\nu,mno)} = \int \frac{n^0_{\textnormal{eq}}(\boldsymbol{\eta})}{\|B_0(\boldsymbol{\eta})\|^2} \vec{\Lambda}^1_{\mu,ijk} b_0 b_0^\top \vec{\Lambda}^1_{\nu,mno} \sqrt{g} \textnormal{d}\boldsymbol{\eta}.
"""
if not hasattr(self, "_M1para_MHDeq"):
bb = LocalProjectionMatrix(self.eq_mhd.unit_bv_1, self.eq_mhd.unit_bv_2, self.eq_mhd.unit_bv_3)

self._M1para_MHDeq = self.create_weighted_mass(
"Hcurl",
"Hcurl",
weights=(
"DFinv",
D,
"DFinv",
bb,
lambda *etas: 1 / self.eq_mhd.absB0(*etas) ** 2,
Comment thread
emilegrivet marked this conversation as resolved.
"eq_n0",
"sqrt_g",
),
name="M1perp",
name="M1para_MHDeq",
assemble=True,
)
return self._M1para_MHDeq

return self._M1perp
@auto_convert_docstring
@property
def M1_MHDeq(self):
r"""
Mass matrix

.. math::

\mathbb M^{1}_{(\mu,ijk), (\nu,mno)} = \int \frac{n^0_{\textnormal{eq}}(\boldsymbol{\eta})}{\|B_0(\boldsymbol{\eta})\|^2} \vec{\Lambda}^1_{\mu,ijk} G^{-1} \vec{\Lambda}^1_{\nu,mno} \sqrt{g} \textnormal{d}\boldsymbol{\eta}.
"""
if not hasattr(self, "_M1_MHDeq"):
self._M1_MHDeq = self.create_weighted_mass(
"Hcurl",
"Hcurl",
weights=(
"Ginv",
lambda *etas: 1 / self.eq_mhd.absB0(*etas) ** 2,
"eq_n0",
"sqrt_g",
Comment thread
emilegrivet marked this conversation as resolved.
),
name="M1_MHDeq",
assemble=True,
)
return self._M1_MHDeq

@auto_convert_docstring
@property
def M1gyro(self):
r"""
Mass matrix

.. math::

\mathbb M^{1,\perp}_{(\mu,ijk), (\nu,mno)} = \int \frac{n^0_{\textnormal{eq}}(\boldsymbol{\eta})}{\|B_0(\boldsymbol{\eta})\|^2} \vec{\Lambda}^1_{\mu,ijk} \left(G^{-1} - b_0 b_0^\top \right) \vec{\Lambda}^1_{\nu,mno} \sqrt{g} \textnormal{d}\boldsymbol{\eta}.
"""
if not hasattr(self, "_M1gyro"):
self._M1gyro = self.M1_MHDeq.copy()
self._M1gyro -= self.M1para_MHDeq
return self._M1gyro

@auto_convert_docstring
@property
Expand All @@ -746,7 +831,10 @@ def M0ad(self):
self._M0ad = self.create_weighted_mass(
"H1",
"H1",
weights=("eq_n0", "sqrt_g"),
weights=(
"eq_n0",
"sqrt_g",
),
name="M0ad",
assemble=True,
)
Expand All @@ -755,39 +843,30 @@ def M0ad(self):

@auto_convert_docstring
@property
def M1gyro(self):
def M0ad_withT(self):
r"""
Mass matrix

.. math::

\mathbb M^{1,n}_{(\mu,ijk), (\nu,mno)} = \int n^0_{\textnormal{eq}}(\boldsymbol{\eta}) \Lambda^1_{\mu,ijk} G^{-1}_{\mu,\nu} \Lambda^1_{\nu,mno} \sqrt{g} \textnormal{d}\boldsymbol{\eta},
\mathbb M^0_{ijk, mno} = \int \frac{n^0_{\textnormal{eq}}(\boldsymbol{\eta})}{T^0_{\textnormal{eq}}(\boldsymbol{\eta})} \Lambda^0_{ijk} \Lambda^0_{mno} \sqrt{g} \textnormal{d}\boldsymbol{\eta}.

where :math:`n^0_{\textnormal{eq}}(\boldsymbol{\eta})` is an MHD equilibrium density (0-form).
where :math:`n^0_{\textnormal{eq}}(\boldsymbol{\eta})` and :math:`T^0_{\textnormal{eq}}(\boldsymbol{\eta})` are MHD equilibrium density and electron temperature (0-forms), respectively.
"""

if not hasattr(self, "_M1gyro"):
D = [[1, 0, 0], [0, 1, 0], [0, 0, 0]]

self._M1gyro = self.create_weighted_mass(
"Hcurl",
"Hcurl",
if not hasattr(self, "_M0ad_withT"):
self._M0ad_withT = self.create_weighted_mass(
"H1",
"H1",
weights=(
"eq_n0",
"1/eq_absB0",
"1/eq_absB0",
D,
"Ginv",
D,
lambda *etas: 1 / self.eq_mhd.t0(*etas),
"sqrt_g",
),
name="M1gyro",
name="M0ad_withT",
assemble=True,
)

# 1/eq_absB0**2 written twice instead of square

return self._M1gyro
return self._M0ad_withT

@property
def WMM(self):
Expand Down Expand Up @@ -917,18 +996,8 @@ def create_weighted_mass(
for n, f in enumerate(weights):
if isinstance(f, str):
# determine the callable
if "/" in f:
f_components = f.split("/")
if f_components[-1] == "sqrt_g":
f_call = lambda e1, e2, e3: 1.0 / abs(self.domain.jacobian_det(e1, e2, e3))
elif f_components[-1] == "eq_n0":
f_call = lambda e1, e2, e3: 1.0 / self.eq_mhd.n0(e1, e2, e3)
elif f_components[-1] == "eq_absB0":
f_call = lambda e1, e2, e3: 1.0 / self.eq_mhd.absB0(e1, e2, e3)
else:
raise NotImplementedError(
f"The option {f} is not available for division ('/') yet.",
)
if f == "1/sqrt_g":
f_call = lambda e1, e2, e3: 1.0 / abs(self.domain.jacobian_det(e1, e2, e3))
elif "eq_" in f:
f_components = f.split("q_")
f_call = getattr(self.eq_mhd, f_components[-1])
Expand Down
8 changes: 8 additions & 0 deletions src/struphy/feec/preconditioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ def fun(e):
): # this branch is only entered if comm exists (and thus subcomm has been initialized)
if subcomm != MPI.COMM_NULL:
subcomm.Allgather(local_fun, fun)
"""gathered = subcomm.gather(local_fun, root=selected_ranks[0])
if rank == selected_ranks[0]:
if gathered is None:
raise RuntimeError("MPI gather failed to return data on root rank")
fun[:] = xp.concatenate(gathered)
assert fun.size == npts, (
f"Gathered weight size {fun.size} does not match expected {npts}"
)"""
comm.Bcast(fun, root=selected_ranks[0])
else:
fun[:] = local_fun
Expand Down
15 changes: 10 additions & 5 deletions src/struphy/feec/tests/test_mass_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
def test_mass(num_elements, degree, bcs, map_and_equil, matrix_free, show_plots=False):
"""Test weighted mass matrices by recovering projected functions from the DeRham complex.

For each mass operator in ``{M0, M1, M2, M3, Mv, M1n, M2n, Mvn, M1ninv, M0ad}``,
For each mass operator in ``{M0, M1, M2, M3, Mv, M1n, M2n, Mvn, M1ninv, M0ad, M0ad_withT}``,
the test:

1. Projects known trigonometric right-hand-side functions onto the
Expand All @@ -32,10 +32,12 @@ def test_mass(num_elements, degree, bcs, map_and_equil, matrix_free, show_plots=
it point-wise to the exact function.

The density-weighted operators (``M1n``, ``M2n``, ``Mvn``, ``M0ad``) are
tested against ``exact / n0``, and the inverse-density operator
tested against ``exact / n0``, ``M0ad_withT`` is tested against ``exact * t0 / n0``, and the inverse-density operator
(``M1ninv``) is tested against ``exact * n0``.
"""

from types import MethodType

import cunumpy as xp
from feectools.ddm.mpi import mpi as MPI
from feectools.linalg.solvers import inverse
Expand Down Expand Up @@ -82,7 +84,7 @@ def test_mass(num_elements, degree, bcs, map_and_equil, matrix_free, show_plots=
# derham object
grid = TensorProductGrid(num_elements=num_elements)
derham_opts = DerhamOptions(degree=degree, bcs=bcs)
derham = Derham(grid, derham_opts, comm=mpi_comm)
derham = Derham(grid, derham_opts, comm=mpi_comm, domain=domain)

logger.debug(f"Rank {mpi_rank} | Local domain : " + str(derham.domain_array[mpi_rank]))

Expand Down Expand Up @@ -111,6 +113,7 @@ def rhs_2(e1, e2, e3):
rhs = {}
rhs["M0"] = l2proj_0.get_dofs(rhs_0, apply_bc=True)
rhs["M0ad"] = rhs["M0"]
rhs["M0ad_withT"] = rhs["M0"]
rhs["M1"] = l2proj_1.get_dofs((rhs_0, rhs_1, rhs_2), apply_bc=True)
rhs["M1n"] = rhs["M1"]
rhs["M1ninv"] = rhs["M1"]
Expand All @@ -134,7 +137,7 @@ def rhs_2(e1, e2, e3):
elif min(degree) == 2:
err_bound = 2.6e-2

names = ["M0", "M1", "M2", "M3", "Mv", "M1n", "M2n", "Mvn", "M1ninv", "M0ad", "WMM", "WMMnew"]
names = ["M0", "M1", "M2", "M3", "Mv", "M1n", "M2n", "Mvn", "M1ninv", "M0ad", "M0ad_withT", "WMM", "WMMnew"]
for name in names:
if name == "WMM":
intermediate = mass_ops.WMM
Expand All @@ -155,7 +158,9 @@ def rhs_2(e1, e2, e3):
exact = xp.array([rhs_0(ee1, ee2, ee3), rhs_1(ee1, ee2, ee3), rhs_2(ee1, ee2, ee3)])

solver = "cg"
if name in ["M1n", "M2n", "Mvn", "M0ad", "WMM", "WMMnew"]:
if name == "M0ad_withT":
exact *= equil.t0(e1, e2, e3)
if name in ["M1n", "M2n", "Mvn", "M0ad", "WMM", "WMMnew", "M0ad_withT"]:
# solve n0 * u = f, where n0 is the equilibrium density
exact /= equil.n0(e1, e2, e3)
elif name == "M1ninv":
Expand Down
32 changes: 32 additions & 0 deletions src/struphy/feec/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,38 @@ def get_quad_grids(
return tuple({q: gag} for q, gag in zip(nquads, space.get_assembly_grids(*nquads)))


class LocalProjectionMatrix:
"""For a given triple of callables representing the components of a normalized vector-valued function a(e1, e2, e3),
represents the local projection matrix P defined by P = a a^T at (e1, e2, e3).

LocalProjectionMatrix(e1, e2, e3) returns a five-dimensional array, with the 3x3 matrix in the last two indices.

This can then be used with the following numpy functions:
* matvec for matrix-vector multiplication in the last indices
* @ for matrix-matrix multiplication in the last two indices

Parameters
----------
*vec_fun : list
Three callables that represent the components of the vector-valued function a.
"""

def __init__(self, *vec_fun):
assert len(vec_fun) == 3
assert all([callable(fun) for fun in vec_fun])

self._funs = vec_fun

def __call__(self, e1, e2, e3):
# array from 2d list gives 3x3 array is in the first two indices
tmp = xp.array(
[[self._funs[m](e1, e2, e3) * self._funs[n](e1, e2, e3) for n in range(3)] for m in range(3)],
)

# numpy operates on the last two indices with @
return xp.transpose(tmp, axes=(2, 3, 4, 0, 1))


class LocalRotationMatrix:
"""For a given triple of callables representing the components of a vector-valued function a(e1, e2, e3),
represents the local rotation matrix R defined by Rv = a x v at (e1, e2, e3) for any vector v in R^3.
Expand Down
Loading
Loading