diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..ea94d92 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,27 @@ + + +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version, and other tools you might need +build: + os: ubuntu-24.04 + tools: + python: "3.13" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/source/conf.py + +# Optionally, but recommended, +# declare the Python requirements required to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt + - requirements: requirements.txt + + diff --git a/README.md b/README.md index 192ef3b..6a48159 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # humancompatible-train: a package for constrained machine learning -[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) [![docs](https://app.readthedocs.org/projects/humancompatible-train/badge/?version=latest)](https://humancompatible-train.readthedocs.io/en/latest/?badge=latest) The toolkit implements algorithms for constrained training of neural networks based on PyTorch, and inspired by PyTorch's API. @@ -29,7 +29,7 @@ The only dependencies of this package are `numpy` and `torch`. ## Using the toolkit -The toolkit implements algorithms for constrained training of neural networks based on PyTorch. +The toolkit implements algorithms for constrained training of neural networks based on PyTorch. For the documentation, please visit [our Read the Docs page!](https://humancompatible-train.readthedocs.io?version=latest) The algorithms are intended for use in tandem with classic PyTorch optimizers, calculating the Lagrangian and keeping track of the dual variables. diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d0c3cbf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..747ffb7 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..a92337a --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,5 @@ +sphinx +myst-nb +. +furo +sphinx_rtd_theme \ No newline at end of file diff --git a/docs/source/api_reference/dual_optimizers.rst b/docs/source/api_reference/dual_optimizers.rst new file mode 100644 index 0000000..bb31bf1 --- /dev/null +++ b/docs/source/api_reference/dual_optimizers.rst @@ -0,0 +1,8 @@ +Dual Optimizers +=============== + +.. toctree:: + :titlesonly: + :glob: + + dual_opts/* \ No newline at end of file diff --git a/docs/source/api_reference/dual_opts/alm.rst b/docs/source/api_reference/dual_opts/alm.rst new file mode 100644 index 0000000..74bb2f4 --- /dev/null +++ b/docs/source/api_reference/dual_opts/alm.rst @@ -0,0 +1,6 @@ +ALM +================= + + +.. autoclass:: humancompatible.train.dual_optim.ALM + :members: diff --git a/docs/source/api_reference/dual_opts/ialm.rst b/docs/source/api_reference/dual_opts/ialm.rst new file mode 100644 index 0000000..eb26e99 --- /dev/null +++ b/docs/source/api_reference/dual_opts/ialm.rst @@ -0,0 +1,6 @@ +iALM +================= + + +.. autoclass:: humancompatible.train.dual_optim.iALM + :members: diff --git a/docs/source/api_reference/dual_opts/pbm.rst b/docs/source/api_reference/dual_opts/pbm.rst new file mode 100644 index 0000000..49f0c72 --- /dev/null +++ b/docs/source/api_reference/dual_opts/pbm.rst @@ -0,0 +1,6 @@ +PBM +================= + + +.. autoclass:: humancompatible.train.dual_optim.PBM + :members: diff --git a/docs/source/api_reference/utils.rst b/docs/source/api_reference/utils.rst new file mode 100644 index 0000000..3f22410 --- /dev/null +++ b/docs/source/api_reference/utils.rst @@ -0,0 +1,8 @@ +Utils +===== + +.. toctree:: + :titlesonly: + :glob: + + utils/* \ No newline at end of file diff --git a/docs/source/api_reference/utils/sampler.rst b/docs/source/api_reference/utils/sampler.rst new file mode 100644 index 0000000..adb3d29 --- /dev/null +++ b/docs/source/api_reference/utils/sampler.rst @@ -0,0 +1,6 @@ +BalancedBatchSampler +==================== + + +.. autoclass:: humancompatible.train.fairness.utils.BalancedBatchSampler + :members: diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..3ac9411 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,62 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +extensions = ["myst_nb", "sphinx.ext.autodoc"] + + +import os + + +project = 'humancompatible-train' +copyright = '2026, Andrii Kliachkin, Gilles Bareillies, Jana Lepsova, Jakub Marecek' +author = 'Andrii Kliachkin, Gilles Bareillies, Jana Lepsova, Jakub Marecek' +# release = '0.3.1' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + + +templates_path = ['_templates'] +exclude_patterns = [] + +nb_toctree = False +nb_number_headings = False +nb_execution_show_tb = False +nb_execution_mode = "cache" +nb_execution_timeout = 180 + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +# html_theme = 'furo' +html_theme = 'sphinx_rtd_theme' + +html_static_path = ['_static'] + +html_baseurl = os.environ.get("READTHEDOCS_CANONICAL_URL", "/") + +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "html_image", +] +myst_url_schemes = ("http", "https", "mailto") + +source_suffix = { + ".rst": "restructuredtext", + ".md": "myst-nb", + ".ipynb": "myst-nb", +} + +nb_execution_mode = "cache" + +import sys + +sys.path.insert(0, os.path.abspath('./../..')) \ No newline at end of file diff --git a/docs/source/examples/learn_DAG.rst b/docs/source/examples/learn_DAG.rst new file mode 100644 index 0000000..e7d072c --- /dev/null +++ b/docs/source/examples/learn_DAG.rst @@ -0,0 +1,209 @@ +Learning Directed Acyclic Graphs (DAGs) from Data +================================================== + +Overview +-------- + +This example demonstrates how to learn a **Directed Acyclic Graph (DAG)** from data using constrained optimization. We follow an approach inspired by the `Cooper `_ library's DAG learning example. + +In this example, we: + +1. Generate synthetic data from a linear structural equation model +2. Define a constrained optimization problem to recover the underlying graph +3. Use the Augmented Lagrangian Method (ALM) to solve the problem +4. Visualize both the learned and ground truth graphs + +What is a DAG? +-------------- + +A Directed Acyclic Graph (DAG) is a graph where: + +- Nodes represent variables or features +- Directed edges represent causal relationships +- There are no cycles (acyclic property) +- The acyclic property ensures a topological ordering exists + +DAG learning is useful in causal inference, discovering variable dependencies, and understanding structural relationships in data. + +Data Generation +--------------- + +We start by generating synthetic data from a linear structural equation model with Gaussian noise: + +.. code-block:: python + + import torch + import numpy as np + import math + + def generate_data(n, d, n_causes, noise_std, device): + """Generate data from a linear structural equation model with Gaussian noise. + + Args: + n: number of samples + d: number of features + n_causes: number of root nodes (nodes with no parents) + noise_std: standard deviation of the noise + device: torch.device + + Returns: + X: Data matrix of shape (n, d) + A: Adjacency matrix of shape (d, d) + """ + # Generate adjacency matrix + A = torch.zeros(d, d, device=device) + + for i in range(n_causes, d): + # Each node (except roots) has random parents from previous nodes + parents = 0 if i == 1 else torch.randperm(i)[:np.random.randint(1, i)] + A[i, parents] = 1 + + # Verify acyclic property + assert torch.trace(torch.linalg.matrix_exp(A)).item() == d, "A is not a DAG" + + # Generate data: X_i = sum(X_parents_i) + noise_i + noise = noise_std * torch.randn(n, d, device=device) + X = torch.zeros(n, d, device=device) + + for i in range(d): + parents = torch.nonzero(A[i]).flatten() + X[:, i] = X[:, parents].sum(dim=1) + noise[:, i] + + # Improve conditioning + X /= math.sqrt(d) + + return X, A + +**Parameters:** + +- ``n``: Number of samples (5,000 in this example) +- ``d``: Number of features/nodes (8 in this example) +- ``n_causes``: Number of root nodes with no parents (2 in this example) +- ``noise_std``: Standard deviation of Gaussian noise (0.01 in this example) + +Training Setup +-------------- + +We formulate the DAG learning problem as a constrained optimization problem: + +.. math:: + + \min_{A \in \{0, 1\}^{d \times d}} \left\| X - XA \right\|_F^2 + + \text{subject to:} \quad \text{tr}(e^A) = d + +The constraint ensures the adjacency matrix ``A`` represents a valid DAG: + +- The exponential matrix ``exp(A)`` has trace equal to ``d`` if and only if ``A`` is acyclic +- This is an algebraic constraint that replaces the combinatorial acyclicity check + +**Implementation:** + +.. code-block:: python + + from humancompatible.train.dual_optim import ALM + from torch.optim import AdamW + + # Initialize adjacency matrix as a learnable parameter + A = torch.nn.Parameter(torch.randn(D, D, device=DEVICE) / math.sqrt(D)) + + # Optimizer for the primal variable (adjacency matrix) + optimizer = AdamW(params=[A], lr=PRIMAL_LR) + + # Dual optimizer using Augmented Lagrangian Method + dual_opt = ALM(m=1) # m=1 constraint + + # Constraint function + constraint = lambda A: torch.trace(torch.linalg.matrix_exp(A)) - d + +Training Loop +------------- + +The training procedure alternates between: + +1. **Primal step**: Update ``A`` to minimize the Lagrangian +2. **Dual step**: Update Lagrange multipliers to enforce constraint satisfaction + +.. code-block:: python + + for i in range(N_STEPS): + # Project to valid range [0, 1] and remove diagonal + A.data.fill_diagonal_(0) + A.data.clamp_(min=0, max=1.0) + + # Compute loss: reconstruction error + loss = torch.square(torch.linalg.norm(X - X @ A.T, ord="fro")) + + # Compute constraint violation + cviol = constraint(A) + + # Update Lagrangian + lagrangian = dual_opt.forward_update(loss, cviol.unsqueeze(0)) + + # Gradient descent on primal variable + lagrangian.backward() + optimizer.step() + optimizer.zero_grad() + +**Key steps:** + +- **Diagonal removal**: No self-loops allowed (``A.fill_diagonal_(0)``) +- **Value clamping**: Adjacency values are bounded to [0, 1] (``A.clamp_(min=0, max=1.0)``) +- **Loss computation**: Measures how well ``A`` predicts the data +- **Constraint enforcement**: The ALM solver tracks dual variables to enforce the acyclicity constraint + +Results and Visualization +-------------------------- + +After training, we can visualize the learned adjacency matrix alongside the ground truth: + +.. code-block:: python + + import networkx as nx + import seaborn as sns + from matplotlib import pyplot as plt + + # Create network graph + G = nx.DiGraph() + G.add_nodes_from(range(D)) + + for i in range(D): + for j in range(D): + if A[i, j] != 0: + G.add_edge(j, i) + + # Visualize + pos = nx.shell_layout(G) + plt.figure(figsize=(5, 2)) + nx.draw(G, pos, with_labels=True, font_weight="bold") + plt.show() + +**Visualization outputs:** + +1. **Adjacency Heatmaps**: Compare learned, ground truth, and difference matrices +2. **Training Progress**: Track loss, constraint violation, and dual parameters over iterations + +The quality of recovery depends on: + +- **Dataset size**: Larger datasets improve recovery +- **Noise level**: Lower noise enables better recovery +- **Training iterations**: More iterations improve convergence +- **Graph density**: Sparser graphs are easier to recover + +Applications +----------- + +DAG learning is useful for: + +- **Causal discovery**: Inferring causal relationships from observational data +- **Biological networks**: Discovering gene regulatory networks +- **Financial systems**: Understanding dependencies between economic indicators +- **Knowledge graphs**: Learning structured relationships from data +- **Feature importance**: Understanding variable interactions + +See Also +-------- + +- :doc:`api_reference` for the ALM solver and optimization utilities +- `Cooper Documentation `_ for more constrained optimization examples +- The full notebook: ``examples/learn_DAG.ipynb`` diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst new file mode 100644 index 0000000..61ce3dd --- /dev/null +++ b/docs/source/getting_started.rst @@ -0,0 +1,49 @@ +Getting Started +=============== + +Quick Start +----------- + +After installing humancompatible-train, you can import it in your Python code: + +.. code-block:: python + + from humancompatible.train.dual_optim import * + +Basic Example +-------------- + +This is an abstract code sample; you can find runnable examples in the :doc:`tutorials/basic_usage` section. + +.. code-block:: python + + from humancompatible.train.dual_optim import ALM + + device = ... + num_constraints = ... + + optimizer = torch.optim.Adam(model.parameters(), ...) + dual_optimizer = ALM(m=num_constraints, ..., device=device) + + for inputs, labels in dataloader: + # evaluate objective + outputs = model(inputs) + loss = criterion(outputs, labels) + # evaluate tensor of constraints + constraints = evaluate_constraints(inputs, labels, ...) + # evaluate lagrangian and update dual variables + lagrangian = dual_optimizer.forward_update(loss, constraints) + # backward pass and step + lagrangian.backward() + optimizer.step() + optimizer.zero_grad() + +.. note:: + + For detailed examples (including inequality constraints), see the :doc:`tutorials/basic_usage` and :doc:`tutorials/inequality_constraints` sections. + +Next Steps +---------- + +- Read the :doc:`Basic Usage ` guide for a complete example +- If you encounter issues, visit the :doc:`Troubleshooting ` page diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..4805ce1 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,35 @@ +humancompatible-train documentation +=================================== + +Welcome to the **humancompatible-train** documentation. + +What is **humancompatible-train**? + +**humancompatible-train** is a PyTorch-based package for constrained optimization, aimed at constrained deep learning tasks. +We implement several first-order Lagrangian-based methods for constrained optimization with a PyTorch-based API that allow seamless integration of constraints into the training loop. + + +.. toctree:: + :maxdepth: 2 + :caption: Getting Started + :titlesonly: + + install + getting_started + +.. toctree:: + :maxdepth: 1 + :caption: Tutorials + :titlesonly: + + Constrained Optimization Overview + Basic usage: Fairness + Handling inequality constraints + Tips and Tricks + +.. toctree:: + :caption: API reference + :titlesonly: + + Dual Optimizers + Utils \ No newline at end of file diff --git a/docs/source/install.rst b/docs/source/install.rst new file mode 100644 index 0000000..6e797db --- /dev/null +++ b/docs/source/install.rst @@ -0,0 +1,38 @@ +Installation +============ + +Prerequisites +------------- + +- Python 3.11 or higher +- pip (Python package manager) +- Virtual environment (recommended) + +Basic Installation +------------------ + +Install the package using pip: + +.. code-block:: bash + + pip install humancompatible-train + +Installation from Source +------------------------ + +To install the development version from source: + +.. code-block:: bash + + git clone https://github.com/humancompatible-train.git + cd humancompatible-train + pip install -e . + +Optional Dependencies +--------------------- + +For specific features, you may need additional packages: + +.. code-block:: bash + + pip install humancompatible-train[examples] # Example notebooks \ No newline at end of file diff --git a/docs/source/support.rst b/docs/source/support.rst new file mode 100644 index 0000000..a329fa2 --- /dev/null +++ b/docs/source/support.rst @@ -0,0 +1,53 @@ +Support +======= + +Getting Help +------------ + +If you need help with humancompatible-train, here are the recommended channels: + +GitHub Issues +~~~~~~~~~~~~~ + +For bug reports and feature requests, please open an issue on the GitHub repository: + +https://github.com/humancompatible-train + +When reporting an issue, please include: + +- Description of the problem +- Steps to reproduce the issue +- Python version and environment information +- Relevant code snippets or error messages + +Email Support +~~~~~~~~~~~~~ + +You can contact the maintainers at: + +kliacand@fel.cvut.cz + +Documentation +~~~~~~~~~~~~~ + +.. - Check the :doc:`getting_started` guide for basic information +.. - Review the :doc:`examples/basic_usage` for common usage patterns +.. - Consult the :doc:`troubleshooting` page for known issues and solutions +.. - See the :doc:`examples/api_reference` for API documentation + +Contributing +~~~~~~~~~~~~ + +We welcome contributions! If you'd like to contribute: + +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Submit a pull request + +Additional Resources +~~~~~~~~~~~~~~~~~~~~ + +- Project Homepage: https://github.com/humancompatible-train +- Documentation: https://humancompatible-train.readthedocs.io +- Issue Tracker: https://github.com/humancompatible-train/issues \ No newline at end of file diff --git a/docs/source/troubleshooting.rst b/docs/source/troubleshooting.rst new file mode 100644 index 0000000..c622c76 --- /dev/null +++ b/docs/source/troubleshooting.rst @@ -0,0 +1,87 @@ +Troubleshooting +=============== + +Common Issues and Solutions +--------------------------- + +Installation Issues +~~~~~~~~~~~~~~~~~~~ + +**Problem: "ModuleNotFoundError" when importing humancompatible-train** + +Solution: + 1. Verify installation: ``pip list | grep humancompatible-train`` + 2. Reinstall the package: ``pip install --upgrade humancompatible-train`` + 3. Check your Python version: ``python --version`` (requires Python 3.8+) + +**Problem: "Permission denied" during installation** + +Solution: + Use a virtual environment (recommended): + + .. code-block:: bash + + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + pip install humancompatible-train + +Training Issues +~~~~~~~~~~~~~~~ + +**Problem: Constraints are not being satisfied** + +Solutions: + 1. Verify constraint definitions are correct + 2. Check that constraints are compatible with your data + 3. Increase training time or adjust hyperparameters + 4. Review constraint priorities and weights + +**Problem: Out of memory during training** + +Solutions: + 1. Reduce batch size + 2. Use a smaller dataset for testing + 3. Enable gradient checkpointing if available + 4. Consider distributed training + +Performance Issues +~~~~~~~~~~~~~~~~~~ + +**Problem: Training is very slow** + +Solutions: + 1. Profile your code to identify bottlenecks + 2. Use fewer constraints if possible + 3. Optimize your data loading pipeline + 4. Consider using GPU acceleration + 5. Try reducing dataset size for experimentation + +Getting More Help +----------------- + +If you can't find a solution here: + +1. Check the :doc:`support` page for contact information +2. Review the :doc:`API Reference ` for function signatures +3. Open an issue on the GitHub repository +4. Consult the project's issue tracker for similar problems + +FAQ +--- + +**Q: Which Python versions are supported?** + +A: Python 3.8 and higher. We recommend using Python 3.9 or later. + +.. **Q: Can I use my custom model architecture?** + +.. A: Yes, see the :doc:`Advanced Usage ` section for details on custom constraints and models. + +**Q: How do I report a bug?** + +A: Please open an issue on GitHub with: + + - Description of the problem + - Steps to reproduce + - Python version and environment information + - Relevant code snippets or error messages diff --git a/docs/source/tutorials/basic_usage.md b/docs/source/tutorials/basic_usage.md new file mode 100644 index 0000000..059bc28 --- /dev/null +++ b/docs/source/tutorials/basic_usage.md @@ -0,0 +1,212 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.19.3 +kernelspec: + display_name: hc-dev + language: python + name: python3 +--- + +# Basic Usage + +This page provides an overview of using humancompatible-train for constrained deep learning on a simple example. + ++++ + +## Idea + ++++ + +The core of the package is formed by Lagrangian-based **dual optimizers**, which are PyTorch Optimizer-like objects that handle the **constrained** part of **constrained deep learning**. + +They create, keep track of, and update the **dual parameters** of the constrained minimization problem, as well as calculate the Lagrangian that is then minimized by a standard PyTorch optimizer in place of a loss. + ++++ + +## Simple Example + ++++ + +Let us demonstrate using a **fairness-constrained learning** task, where we want to learn a classifier that is accurate but also satisfies a **demographic parity constraint** - i.e. we would like + +$$ | P( Y = 1 | \text{X is Male}) - P ( Y = 1 | \text{X is Female} ) | \leq \epsilon $$ + +where $ Y $ is the prediction given by our model for sample $ X $, and $ \epsilon $ is some small threshold. + + +--- + ++++ + +To enforce demographic parity, we will define a **constraint function** (using the [fairret](https://github.com/aida-ugent/fairret) package) that measures the difference in positive prediction rates between two demographic groups. + +The **dual optimizer** will then update the Lagrange multipliers to enforce this constraint during training. + ++++ + +First, let us load and prepare the data. We will use the ACS dataset, containing U.S. Census data, provided by the [folktables](https://github.com/socialfoundations/folktables) package. Feel free to skip this section. + +```{code-cell} ipython3 +--- +tags: [hide-cell] +--- +# load data +import torch +import numpy as np +from sklearn.preprocessing import StandardScaler +from folktables import ACSDataSource, ACSIncome, generate_categories + +torch.set_default_dtype(torch.float32) + +# load folktables data +data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person") +acs_data = data_source.get_data(states=["FL"], download=True) +definition_df = data_source.get_definitions(download=True) +categories = generate_categories( + features=ACSIncome.features, definition_df=definition_df +) +df_feat, df_labels, _ = ACSIncome.df_to_pandas( + acs_data, categories=categories, dummies=True +) +sens_cols = ["SEX_Female", "SEX_Male"] +features = df_feat.drop(columns=sens_cols).to_numpy(dtype=np.float32) +labels = df_labels.to_numpy(dtype=np.float32) +# one-hot encoding of the sensitive attribute (gender) +groups = df_feat[sens_cols].to_numpy(dtype=np.float32) + +# standardize features +scaler = StandardScaler() +features = scaler.fit_transform(features) +# convert to torch tensors +X = torch.tensor(features) ; y = torch.tensor(labels) ; groups = torch.tensor(groups) + +dataset_train = torch.utils.data.TensorDataset(X, groups, y) +loader = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True) +criterion = torch.nn.BCEWithLogitsLoss() +``` + +Initialize the model and optimizer. + +```{code-cell} ipython3 +from torch.nn import Sequential +from torch.optim import AdamW + +def setup_model(): + + model = Sequential( + torch.nn.Linear(features.shape[1], 64), + torch.nn.ReLU(), + torch.nn.Linear(64, 32), + torch.nn.ReLU(), + torch.nn.Linear(32, 1), + ) + model.forward(torch.zeros(features.shape[1])).backward() # dummy forward/backward pass to construct torch graph for fair comparison + optimizer = AdamW(model.parameters()) + return model, optimizer +``` + +Next, we define the **constraint function** for demographic parity, which uses the `fairret.statistic.PositiveRate` class to evaluate positive rates for both groups. \ +As a reminder, we expect our constraints to be of the form $ g(...) \leq 0 $ or $ h(...) = 0 $. We want $ g(...) \leq \epsilon $, so we will subtract $ \epsilon $ in the training loop. + +```{code-cell} ipython3 +from fairret.statistic import PositiveRate + +statistic = PositiveRate() + +def pr_diff(logit, groups): + preds = torch.sigmoid(logit) + stats = PositiveRate()(preds, groups) + stat_diff = torch.abs(stats[0] - stats[1]) + return stat_diff +``` + +As a last step, we define our **dual optimizer**. To set it up, we only need to define the **number of constraints** -- in our case, it is 1 -- so it can create the corresponding dual variables, and the **type** of constraint -- equality or inequality. In a following tutorial, we will see how to create *constraint groups* with different types and hyperparameters. + +```{code-cell} ipython3 +from humancompatible.train.dual_optim import ALM + +dual_optimizer = ALM(m=1, lr=0.01, is_ineq=True) +``` + +Finally, we write our training loop. In addition to the forward pass and loss calculation, we add a constraint calculation step (0.05 is our $ \epsilon $ threshold). + +Then, the `forward_update` step does two things: + +- Updates the dual variables based on the constraint violation, +- Calculates the Lagrangian based on loss and constraint violation. + +We then perform a backward pass on the Lagrangian and minimize it using a normal PyTorch optimizer. + +```{code-cell} ipython3 +model, optimizer = setup_model() +epochs = 10 +``` + +```{code-cell} ipython3 +for epoch in range(epochs): + # eval + model.eval() + logit = model(X) + train_loss = criterion(logit, y).item() + train_fair = pr_diff(logit, groups).item() + print(f"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}") + + # train + model.train() + for batch_feat, batch_groups, batch_label in loader: + optimizer.zero_grad() + logit = model(batch_feat) + loss = criterion(logit, batch_label) + + constraint = pr_diff(logit, batch_groups) - 0.05 + lagr = dual_optimizer.forward_update(loss, constraint.unsqueeze(0)) + lagr.backward() + + optimizer.step() +``` + +We obtain a respectable loss value, while keeping the fairness violation below the threshold! + ++++ + +Just in case, let's check what happens if we train the model without constraints: + +```{code-cell} ipython3 +model, optimizer = setup_model() +``` + +```{code-cell} ipython3 +for epoch in range(epochs): + # eval + model.eval() + logit = model(X) + train_loss = criterion(logit, y).item() + train_fair = pr_diff(logit, groups).item() + print(f"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}") + + # train + model.train() + for batch_feat, batch_groups, batch_label in loader: + optimizer.zero_grad() + logit = model(batch_feat) + loss = criterion(logit, batch_label) + + loss.backward() + optimizer.step() +``` + +The absolute difference in positive rates is two times higher than what we wanted! + ++++ + +Further reading: + +```{code-cell} ipython3 + +``` diff --git a/docs/source/tutorials/copt_overview.rst b/docs/source/tutorials/copt_overview.rst new file mode 100644 index 0000000..28b5d4d --- /dev/null +++ b/docs/source/tutorials/copt_overview.rst @@ -0,0 +1,76 @@ +Constrained Optimization Overview +================================= + +This tutorial provides an overview of constrained optimization problems, and how this relates to Deep Learning. We will cover problem formulation, .... + + +Formulation +--------------- + +In `humancompatible-train`, and in Constrained Machine Learning more generally, we are interested in solving problems of the form: + +.. math:: + \min_{x\in\mathbb{R}^n} \quad & \mathbb{E}[f(x,\xi)] \\ + \text{s.t.} \quad & \mathbb{E}[g(x,\xi)] \leq 0, \\ + & \mathbb{E}[h(x,\xi)] = 0, \\ + +where :math:`f` is the **objective function** we want to minimize, :math:`g` are the **inequality constraints**, and :math:`h` are the **equality constraints**. The expectation is taken over some random variable :math:`\xi`, which represents the data. + +You may recognize the first line of the above formula as the standard formulation of a machine learning problem, where we want to **minimize the expected loss** over the data. +We then introduce **constraints** -- they could express anything from some bounds on the weights of the model, or a requirement on the model's predictions to satisfy some fairness criterion, to the boundary conditions of a physical system. + + +.. note:: + - As is standard in the field, we adopt the convention of writing the constraints as :math:`g(x) \leq 0`, and :math:`h(x) = 0`. This is just a notational choice, and does not affect the generality of the formulation. It is trivial to transform :math:`g(x) \geq 0` into :math:`-g(x) \leq 0`, or :math:`h(x) = \epsilon` into :math:`g(x) - \epsilon = 0` for some :math:`\epsilon`. We refer to this :math:`\epsilon` as the constraint **bound**, or **threshold**. + - It is also easy to switch between equality and inequality constraints: to get :math:`g(x) = 0`, one can set :math:`-g(x) \leq 0` and :math:`g(x) \leq 0` simultaneously. In fact, different algorithms are designed to handle either equality or inequality constraints natively, but it is trivial to switch between the two. We shall see more concrete examples later on. + + +Solving Constrained Problems +-------------------------------- + +We know that to solve an unconstrained optimization problem, we can use gradient descent, or any of its myriad variants. But how do we solve a constrained optimization problem? + +The Constrained Machine Learning field seems to have converged on **Lagrangian-based methods**, which utilize the Lagrangian function to transform the **constrained** problem into an **unconstrained** one. + +Going forward in this tutorial, we will focus on the **deterministic case** to simplify notation; the stochastic case is more complex, but utilizes the same principles (think full-batch vs. mini-batch Gradient Descent). For more rigorous mathematical treatment of the stochastic case, see **TODO**, as well as the references included in the documentation for each of the algorithms. + +So, we have the following constrained problem: + +.. math:: + \min_{x\in\mathbb{R}^n} \quad & f(x,\xi) \\ + \text{s.t.} \quad & g(x,\xi) \leq 0, \\ + & h(x,\xi) = 0, \\ + +In a deterministic case, the Lagrangian function is defined as follows: + +.. math:: + \mathcal{L}(x, \lambda, \mu) = f(x) + \lambda^T g(x) + \mu^T h(x) + +where :math:`\lambda` is the Lagrange multiplier associated with the constraint :math:`g(x) \leq 0`, and :math:`\mu` is the Lagrange multiplier associated with the constraint :math:`h(x) = 0`. + +It is then possible to show that the original **constrained** problem is equivalent to the following **unconstrained** problem: + +.. math:: + \min_{x\in\mathbb{R}^n} \max_{\lambda \geq 0, \mu} \mathcal{L}(x, \lambda, \mu) + + +We refer to the original problem as the **primal problem**, with :math:`x` as the **primal variables**, and to the transformed problem as the **dual problem**, with :math:`\lambda` and :math:`\mu` as the **dual variables**. The dual problem is unconstrained, and can be solved using a clever application of standard optimization techniques. + +In particular, we can use **alternating updates**: fix the primal variables, and optimize the dual variables using gradient ascent; then fix the dual variables, and optimize the primal variables using gradient descent. This process is repeated until convergence. + +In `humancompatible-train`, we implement several variants of this approach, based on methods present in the literature. For more details, see the corresponding documentation; for now, it is important to understand that they are all based on the same principle of alternating updates to the primal and dual variables. + +In the simplest case of the Lagrangian method, this gives us the following update rules: + +.. math:: + \lambda_{t+1} & = \lambda_t + \beta \nabla_\lambda \mathcal{L}(x_{t}, \lambda_t, \mu_t) = \lambda_t + \beta g(x_{t}) \\ + \mu_{t+1} & = \mu_t + \gamma \nabla_\mu \mathcal{L}(x_{t}, \lambda_t, \mu_t) = \mu_t + \gamma h(x_{t}) \\ + x_{t+1} & = x_t - \alpha \nabla_x \mathcal{L}(x_t, \lambda_{t+1}, \mu_{t+1}) + +where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are the learning rates for the primal and dual variables, respectively. + +.. note:: + - The above update rules are for the simplest variant of the Lagrangian method. The methods implemented in this package are all more complex. Even beyond our implementation, one can (and sometimes should!) modify the update rules by e.g. tweaking the training loop code, as we show in the :doc:`tips` tutorial. + - The Lagrangian approach is by far not the only way to do constrained optimization. It is, however, the best fit for a large-scale deep learning setting thanks to its "alternating updates" interpretation, which allows one to use well-established first-order iterative algorithms. We are still looking to implement other families of methods in `humancompatible-train`, such as SQP-based solvers! + +In our package, the `dual optimizers` handle the updates to the dual variables, while the primal updates are handled by the standard PyTorch optimizers. This allows for seamless integration of constraints into the training loop, as we will see in the next tutorial. \ No newline at end of file diff --git a/docs/source/tutorials/tips.rst b/docs/source/tutorials/tips.rst new file mode 100644 index 0000000..ad56f73 --- /dev/null +++ b/docs/source/tutorials/tips.rst @@ -0,0 +1,38 @@ +Tips and Tricks +================================================== + +Here, we discuss some miscellaneous tricks and tips for using the package, which are not specific to any particular method, but can be useful in general when working with constrained optimization problems. + +Dealing with Noise +------------------ + +In the stochastic case, the gradients are estimated using mini-batches of data, which introduces additional noise into the optimization process. This can make convergence more challenging, but this can be mitigated. + +**Momentum**: Just like in standard optimization, using momentum can help smooth out the updates and mitigate the noise. In ``humancompatible-train``, the ``ALM``, ``iALM``, and ``nuPI`` dual optimizers support momentum, which can be enabled by setting the ``momentum`` parameter to a non-zero value. +Some dual update strategies, such as ``nuPI``, explicitly rely on momentum. + +Without momentum, the dual update at each step is a direct ascent step on the constraint values: + +.. math:: + + \pmb{\lambda}_{t+1} \leftarrow \text{clamp}\!\left(\pmb{\lambda}_t + \gamma\, \mathbf{c}_t(\theta_t),\; \lambda_{\min},\; \lambda_{\max}\right) + +With momentum enabled, a running buffer :math:`\mathbf{b}_t` accumulates a weighted history of past constraint values before being used for the dual update: + +.. math:: + + \mathbf{b}_{t+1} &\leftarrow \mu\, \mathbf{b}_t + (1 - \delta)\, \mathbf{c}_t(\theta_t) \\ + \pmb{\lambda}_{t+1} &\leftarrow \text{clamp}\!\left(\pmb{\lambda}_t + \gamma\, \mathbf{b}_{t+1},\; \lambda_{\min},\; \lambda_{\max}\right) + +where :math:`\mu` is the ``momentum`` coefficient, :math:`\delta` is the ``dampening`` coefficient, and :math:`\gamma` is the dual learning rate. + +.. note:: + + When ``momentum > 0`` and ``dampening`` is not explicitly provided, the library automatically sets ``dampening = momentum``. + This conservative choice prioritises stability: the buffer update becomes + + .. math:: + + \mathbf{b}_{t+1} \leftarrow \mu\, \mathbf{b}_t + (1 - \mu)\, \mathbf{c}_t(\theta_t) + + which is a standard exponential moving average of the constraint values with smoothing factor :math:`\mu`. \ No newline at end of file diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS.py b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS.py new file mode 100644 index 0000000..82beaa6 --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS.py @@ -0,0 +1,758 @@ +import numpy as np +from sympy import N +import tensorflow as tf +from time import time +import matplotlib.pyplot as plt + +# TF_USE_LEGACY_KERAS=True + +DTYPE = np.float32 +Nx = 10 #100 +Ny = 10 #100 +Nt = 10 #50 +N_collocation = Nx*Ny*Nt +d_in = 3 + +xMin = 0.0 +xMax = 8.0 +yMin = 0.0 +yMax = 8.0 +tMax = 50. +d_in = 2 + +def choose_width_depth(N_collocation=N_collocation, overparam_factor=3.0, d_in=d_in, d_out=1): + """ + Returns a single (width, depth) pair for a mildly overparameterized PINN. + + N_collocation : Nx * Ny * Nt + overparam_factor : how many times larger the model than data (default 3x) + """ + + N_target = int(overparam_factor * N_collocation) + + # For 3rd-order PDEs like ZK: + # depth 5–7 is stable. We fix depth=6 (good compromise). + depth = 8 + + # Parameter formula: + # N = (depth-1) w^2 + (d_in + depth + d_out) w + d_out + a = depth - 1 + b = d_in + depth + d_out + c = d_out - N_target + + # Solve quadratic for width + width = int((-b + np.sqrt(b*b - 4*a*c)) / (2*a)) + + return width, depth + +width, depth = choose_width_depth() +dx = (xMax - xMin) / (Nx - 1) +dy = (yMax - yMin) / (Ny - 1) +dt = tMax / (Nt - 1) +h = np.max([dx, dy, dt]) + +lambdas = [1., 1., 1.] #[1.7, 0.2, 1.4] #for cs #[1., 0.5, 1.] # [1.7, 0.2, 1.4] for cs # +lambdas = tf.Variable(lambdas, trainable=False, name='lambdas', dtype=DTYPE) + +cheb_par = tf.Variable(0.5, trainable=True, name='cheb_par', dtype=DTYPE) + +x = np.linspace(xMin, xMax, Nx).reshape((-1, 1)).astype(DTYPE) +y = np.linspace(yMin, yMax, Ny).reshape((-1, 1)).astype(DTYPE) +t = np.linspace(0, tMax, Nt).reshape((-1, 1)).astype(DTYPE) +x_grid, y_grid, t_grid = np.meshgrid(x, y, t, indexing='ij') +x_train = x_grid.flatten(); x_train = tf.convert_to_tensor(x_train); x_train = tf.expand_dims(x_train, axis=-1) +y_train = y_grid.flatten(); y_train = tf.convert_to_tensor(y_train); y_train = tf.expand_dims(y_train, axis=-1) +t_train = t_grid.flatten(); t_train = tf.convert_to_tensor(t_train); t_train = tf.expand_dims(t_train, axis=-1) +xyt_train = tf.concat([x_train, y_train, t_train], axis=-1) + +save_fig = True + +# Define the initial condition +def u_0(x, y): + ##1 + epsilon = 0.01 + theta = 0. + y1 = 0. + y2 = 0. + c1 = 0.45 + c2 = 0.25 + x1 = 2.5 + x2 = 3.3 + out = 3*c1/(tf.math.cosh(0.5*tf.sqrt(c1/epsilon)*((x-x1)*tf.math.cos(theta) + (y-y1)*tf.math.sin(theta))))**2 + + 3*c2/(tf.math.cosh(0.5*tf.sqrt(c2/epsilon)*((x-x2)*tf.math.cos(theta) + (y-y2)*tf.math.sin(theta))))**2 + ##2 + # epsilon = 0.01 + # theta = 0. + # y1 = 4. + # c1 = 1. + # x1 = 2.5 + # out = 3*c1/(tf.math.cosh(0.5*tf.sqrt(c1/epsilon)*((x-x1)*tf.math.cos(theta) + (y-y1)*tf.math.sin(theta))))**2 + + return out + # mpmath for sech + +# def periodic_boundary_conditions(model, Nbc=2000): +# x = tf.random.uniform((Nbc,1), xMin, xMax) +# y = tf.random.uniform((Nbc,1), yMin, yMax) +# t = tf.random.uniform((Nbc,1), 0, tMax) + +# xL = tf.ones_like(x)*xMin; xR = tf.ones_like(x)*xMax +# yL = tf.ones_like(y)*yMin; yR = tf.ones_like(y)*yMax + +# uLx = model(tf.concat([xL,y,t],1)) +# uRx = model(tf.concat([xR,y,t],1)) +# uLy = model(tf.concat([x,yL,t],1)) +# uRy = model(tf.concat([x,yR,t],1)) + +# return tf.reduce_mean((uLx-uRx)**2 + (uLy-uRy)**2) + + +def periodic_boundary_conditions(model, Nbc=2000): + + # Random boundary sampling (correct choice) + x = tf.random.uniform((Nbc,1), xMin, xMax) + y = tf.random.uniform((Nbc,1), yMin, yMax) + t = tf.random.uniform((Nbc,1), 0.0, tMax) + + xL = tf.ones_like(x) * xMin + xR = tf.ones_like(x) * xMax + yL = tf.ones_like(y) * yMin + yR = tf.ones_like(y) * yMax + + with tf.GradientTape(persistent=True) as tape: + tape.watch([xL, xR, yL, yR]) + + uLx = model(tf.concat([xL, y, t], 1)) + uRx = model(tf.concat([xR, y, t], 1)) + + uLy = model(tf.concat([x, yL, t], 1)) + uRy = model(tf.concat([x, yR, t], 1)) + + # First derivatives + uxL = tape.gradient(uLx, xL) + uxR = tape.gradient(uRx, xR) + + uyL = tape.gradient(uLy, yL) + uyR = tape.gradient(uRy, yR) + + del tape + + # Enforce periodicity of values AND derivatives + loss = tf.reduce_mean( + (uLx - uRx)**2 + + (uLy - uRy)**2 + + (uxL - uxR)**2 + + (uyL - uyR)**2 + ) + + return loss + + + +def H(u, u_x, u_y): + return tf.reduce_sum((tf.pow(u_x,2) + tf.pow(u_y,2))/2.0-tf.pow(u,3)/6.0, axis=[0,1]) * dx*dy + + +def linear_loss_function(tensors, weights): + """ + Computes the sum of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the sum. + + Returns: + tf.Tensor: The sum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) # shape (n_losses,) + weights = weights / tf.reduce_sum(weights) + loss = tf.reduce_sum(weights * stacked) + loss_type = 'ls' + return loss, loss_type + + +def chebyshev_loss_function(tensors, weights): + """ + Computes the max of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The maximum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + loss = tf.reduce_max(weights*stacked) + loss_type = 'cs' + return loss, loss_type + + +def smooth_chebyshev_loss_function(mu, tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + exp_sum = tf.reduce_sum(tf.math.exp(stacked/mu), axis=0) + loss = mu*tf.math.log(exp_sum) + loss_type = 'scs' + return loss, loss_type + + +def augmentedChebyshev_loss_function(tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + weights (list of tf.Tensor): List of weights for each tensor. + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + loss_type = 'acs' + par = tf.sigmoid(cheb_par) # par is between 0 and 1 + return par*chebyshev_loss_function(tensors, weights)[0] + (1-par)*linear_loss_function(tensors, weights)[0], loss_type + + +class FourierFeatures(tf.keras.layers.Layer): + def __init__(self, n_modes=5): + super().__init__() + self.n_modes = n_modes + + def call(self, inputs): + x = inputs[:, 0:1] + y = inputs[:, 1:2] + t = inputs[:, 2:3] + + features = [t] + + for k in range(1, self.n_modes + 1): + features.append(tf.sin(2*np.pi*k*(x - xMin)/(xMax-xMin))) + features.append(tf.cos(2*np.pi*k*(x - xMin)/(xMax-xMin))) + features.append(tf.sin(2*np.pi*k*(y - yMin)/(yMax-yMin))) + features.append(tf.cos(2*np.pi*k*(y - yMin)/(yMax-yMin))) + + return tf.concat(features, axis=1) + +def PINNModel(num_hidden_layers=depth, num_neurons_per_layer=width): # 8,80 OK (# 8,40 # 10,40) + xyt_input = tf.keras.Input(shape=(3,)) + output_u = FourierFeatures(n_modes=4)(xyt_input) + for _ in range(num_hidden_layers): + output_u = tf.keras.layers.Dense(num_neurons_per_layer, + activation='tanh', # tanh + kernel_initializer='glorot_uniform', # glorot_normal + )(output_u) + + output_u = tf.keras.layers.Dense(units=1, + activation='linear', # mish + kernel_initializer='glorot_uniform', # glorot_normal + )(output_u) + + return tf.keras.Model(inputs=xyt_input, outputs=output_u) #tf.keras.Model(inputs=[x_input, t_input], outputs=output_u) + + +# def PINNModel(num_hidden_layers=depth, num_neurons_per_layer=width): # 8,80 OK (# 8,40 # 10,40) +# xyt_input = tf.keras.Input(shape=(3,)) +# output_u = xyt_input +# for _ in range(num_hidden_layers): +# output_u = tf.keras.layers.Dense(num_neurons_per_layer, +# activation='tanh', # tanh +# kernel_initializer='glorot_uniform', # glorot_normal +# )(output_u) + +# output_u = tf.keras.layers.Dense(units=1, +# activation='linear', # mish +# kernel_initializer='glorot_uniform', # glorot_normal +# )(output_u) + +# # Define the initial condition +# # x_input = tf.reshape(xt_input[:, 0], shape=[-1, 1]) +# # t_input = tf.reshape(xt_input[:, 1], shape=[-1, 1]) +# # initial_u = u_0(x_input) +# # output_u = tf.where(tf.equal(t_input, 0), initial_u, output_u) + +# return tf.keras.Model(inputs=xyt_input, outputs=output_u) #tf.keras.Model(inputs=[x_input, t_input], outputs=output_u) + + +@tf.function +def custom_loss(inputs, model): + xyt = inputs + x, y, t = xyt[:, 0:1], xyt[:, 1:2], xyt[:, 2:3] + # zeros = tf.zeros_like(x) + + with tf.GradientTape(persistent=True) as tape: + tape.watch(t) + tape.watch(x) + tape.watch(y) + with tf.GradientTape(persistent=True) as tape2: + tape2.watch(x) + tape2.watch(y) + with tf.GradientTape(persistent=True) as tape3: + tape3.watch(t) + tape3.watch(x) + tape3.watch(y) + u_model = model(tf.concat([x,y,t], axis=1)) + u_x = tape3.gradient(u_model, x) + u_y = tape3.gradient(u_model, y) + u_t = tape3.gradient(u_model, t) + u_xx = tape2.gradient(u_x, x) + u_xy = tape2.gradient(u_x, y) + u_xxx = tape.gradient(u_xx, x) + u_xyy = tape.gradient(u_xy, y) + del tape, tape2, tape3 + + + # v = -nu*u_x + # phi_t = Vprime(u_model) - nu*u_xx - Vprime(u_model_0) + nu*u_0_xx + # w = -nu * u_xx + phi_t/2. - Vprime(u_model) + + # Compute the components of loss function + pde_loss = tf.reduce_mean((u_t + u_model * u_x + u_xxx + u_xyy) ** 2) + + # x_ic = tf.random.uniform((Nx*Ny,1), xMin, xMax) + # y_ic = tf.random.uniform((Nx*Ny,1), yMin, yMax) + x_ic = tf.expand_dims(tf.linspace(xMin, xMax, Nx*Ny), axis=-1) # For grid sampling + y_ic = tf.expand_dims(tf.linspace(yMin, yMax, Nx*Ny), axis=-1) # For grid sampling + t_ic = tf.zeros_like(x_ic) + u_ic = u_0(x_ic, y_ic) # Initial condition + t_ic = tf.zeros_like(x_ic) # t=0 for initial condition + u_ic_pred = model(tf.concat([x_ic, y_ic, t_ic], axis=1)) # Predicted initial condition + data_fitting_loss_0 = tf.reduce_mean((u_ic_pred - u_ic) ** 2) + data_fitting_loss_l_r = periodic_boundary_conditions(model) + + # Combine the components of the loss functions + # loss, loss_type = linear_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], tf.exp(lambdas)) + # loss, loss_type = linear_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + # loss, loss_type = chebyshev_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], tf.exp(lambdas)) + # loss, loss_type = chebyshev_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + # loss, loss_type = smooth_chebyshev_loss_function(.1, [pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + loss, loss_type = augmentedChebyshev_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + + # S_loss = S(u_model, v, w) + H_loss = H(tf.reshape(u_model, shape=[Nx, Ny, Nt]), tf.reshape(u_x, shape=[Nx, Ny, Nt]), tf.reshape(u_y, shape=[Nx, Ny, Nt])) + # beta = 1e-3 + # data_fitting_loss = loss = beta*tf.math.log(tf.math.exp(data_fitting_loss_weight_0 * data_fitting_loss_0 / beta) + # + tf.math.exp(data_fitting_loss_weight_l * data_fitting_loss_l / beta) + # + tf.math.exp(data_fitting_loss_weight_r * data_fitting_loss_r / beta)) + # loss = beta*tf.math.log(tf.math.exp(pde_loss_weight * pde_loss / beta) + # + tf.math.exp(data_fitting_loss_weight_0 * data_fitting_loss_0 / beta) + # + tf.math.exp(data_fitting_loss_weight_l * data_fitting_loss_l / beta) + # + tf.math.exp(data_fitting_loss_weight_r * data_fitting_loss_r / beta)) + # data_fitting_loss = tf.math.reduce_max(tf.constant([data_fitting_loss_weight_0 * data_fitting_loss_0, + # data_fitting_loss_weight_l * data_fitting_loss_l, + # data_fitting_loss_weight_r * data_fitting_loss_r])) + # loss = tf.math.reduce_max(tf.constant([pde_loss_weight * pde_loss, + # data_fitting_loss_weight_0 * data_fitting_loss_0, + # data_fitting_loss_weight_l * data_fitting_loss_l, + # data_fitting_loss_weight_r * data_fitting_loss_r])) + + return loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss#, S_loss + + +# Create the PINN model +model = PINNModel() +model.summary() + +epochs = 500 # 5000 # 1000 +# # Compile the model +# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), +# loss=lambda y_true, y_pred: custom_loss([x_train, t_train, theta_train], model)[1]) + +# Create the optimizer with a smaller learning rate +# learning_rate = 1e-3 # 1e-4 +# learning_rate_type = 'constant' +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([10, 100], [1e-1, 5e-2, 1e-2]) #OK +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([100, 300], [1e-2, 1e-3, 1e-4]) +learning_rate = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=1e-2, + decay_steps=epochs, + end_learning_rate=1e-4, + power=3., + cycle=False, + name= 'PolynomialDecay' +) +# learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( +# initial_learning_rate=1e-3, +# decay_steps=50, # 100 +# decay_rate=0.9, +# staircase=False, +# name='ExponentialDecay' +# ) +# learning_rate = tf.keras.optimizers.schedules.CosineDecay( +# initial_learning_rate=1e-3, +# decay_steps=1000, +# alpha=0.0, +# warmup_target=None, +# warmup_steps=0, +# name='CosineDecay' +# ) +learning_rate_type = learning_rate.name + +trainable = model.trainable_variables +if lambdas.trainable: + trainable += [lambdas] + +if cheb_par.trainable: + trainable += [cheb_par] + +# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, amsgrad=True) +# optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True) +optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate, beta_1=0.8, beta_2=0.9, epsilon=1e-07) +# optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate, rho=0.9, momentum=0.0, epsilon=1e-07, centered=False) + +# Training loop +losses = [] +pde_losses = [] +data_fitting_losses_0 = [] +data_fitting_losses_l_r = [] +delta_gradients = [] +# S_losses_min = [] +# S_losses_max = [] +H_losses_min = [] +H_losses_max = [] +H_losses_mean = [] +H_losses_std = [] +H_losses_abs_error = [] +H_losses_rel_error = [] +lambdas_values = [] +lambdas_values.append(lambdas.numpy()) +cheb_par_values = [] +cheb_par_values.append(cheb_par.numpy()) + + +# Convert data to tensor because tf.GradientTape() can only watch tensor and not numpy arrays +inputs = xyt_train +stop = False +# Start timer +t0 = time() +for epoch in range(epochs): + if not stop: + # print("# STARTING EPOCH", epoch + 1) + + # Create a LearningRateScheduler to update the learning rate + # current_lr = scheduler(epoch, learning_rate) + # tf.keras.backend.set_value(optimizer.lr, current_lr) + + with tf.GradientTape() as tape: + loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss = custom_loss(inputs, model) + + # print("Computing gradients") + gradients = tape.gradient(loss, trainable) + # print(gradients[-1]) + # print("Applying gradients") + optimizer.apply_gradients(zip(gradients, trainable)) + # print("Appending losses") + losses.append(loss.numpy()) + pde_losses.append(pde_loss.numpy()) + data_fitting_losses_0.append(data_fitting_loss_0.numpy()) + data_fitting_losses_l_r.append(data_fitting_loss_l_r.numpy()) + # param_values.append((trainable[-1]).numpy()) + # delta_gradients.append((gradients[-1]).numpy()) + # S_loss_min = tf.reduce_min(S_loss) + # S_loss_max = tf.reduce_max(S_loss) + # S_losses_min.append(S_loss_min.numpy()) + # S_losses_max.append(S_loss_max.numpy()) + H_loss_min = tf.reduce_min(H_loss) + H_loss_max = tf.reduce_max(H_loss) + H_losses_min.append(H_loss_min.numpy()) + H_losses_max.append(H_loss_max.numpy()) + H_loss_mean = tf.reduce_mean(H_loss) + H_loss_std = tf.math.reduce_std(H_loss) + H_losses_mean.append(H_loss_mean.numpy()) + H_losses_std.append(H_loss_std.numpy()) + # lambdas_values.append((trainable[-1]).numpy()) + + H0 = H_loss[0].numpy() + Hf = H_loss[-1].numpy() + H_abs_error = tf.abs(Hf - H0) + H_losses_abs_error.append(H_abs_error.numpy()) + H_rel_error = H_abs_error / tf.abs((H0 + 1e-16)) + H_losses_rel_error.append(H_rel_error.numpy()) + + # # Print S_loss, H_loss + # print(f"S_loss at epoch {epoch + 1}: {S_loss.numpy()}") + # print(f"H_loss at epoch {epoch + 1}: {H_loss.numpy()}") + + if len(losses) > 1 and not lambdas.trainable:# and False: + # SoftAdaptive weights update + # num1 = tf.math.exp(pde_losses[-1] - pde_losses[-2]) + # num2 = tf.math.exp(data_fitting_losses_0[-1] - data_fitting_losses_0[-2]) + # num3 = tf.math.exp(data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2]) + num = tf.nn.softmax([pde_losses[-1] - pde_losses[-2], data_fitting_losses_0[-1] - data_fitting_losses_0[-2], data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2]]) + num1 = num[0] + num2 = num[1] + num3 = num[2] + den = num1 + num2 + num3 + + new_lambdas = tf.stack([num1 / den, num2 / den, num3 / den]) + lambdas.assign(new_lambdas) + # lambdas_values.append((lambdas).numpy()) + + if cheb_par.trainable: + cheb_par_values.append(cheb_par.numpy()) + + del tape + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.numpy()}") + + if len(losses) > 2 and np.abs(losses[-1] - losses[-2]) / np.abs(losses[-2]) < 1e-8: + stop = True + +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean.numpy()}") +print(f"Hamiltonian standard deviation: {H_loss_std.numpy()}") +print(f"Hamiltonian maximum: {H_loss_max.numpy()}") +print(f"Hamiltonian minimum: {H_loss_min.numpy()}") +print(f"Hamiltonian absolute error: {H_abs_error.numpy()}") +print(f"Hamiltonian relative error: {H_rel_error.numpy()}") +# Print computation time +print('\nComputation time: {} seconds'.format(time() - t0)) + + +def generate_save_fig_string(type, epochs, learning_rate_type, loss_type): + """ + Generates a string for saving figures that includes the number of epochs and the type of learning rate. + + Args: + epochs (int): The number of epochs. + learning_rate_type (str): The type of learning rate. + + Returns: + str: The generated string for saving figures. + """ + return f"./results/{type}_epochs_{epochs}_lr_{learning_rate_type}_{loss_type}.png" + +# Plot the loss history +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_fitting_losses_0, label='Initial Conditions Loss') +plt.semilogy(data_fitting_losses_l_r, label='Periodic Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'loss.pdf', dpi=300) + +# # Evaluate the function +# x_eval = np.linspace(x_train[0].numpy(), x_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +# y_eval = np.linspace(y_train[0].numpy(), y_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +# t_eval = np.linspace(t_train[0].numpy(), t_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +# inputs_eval = [x_eval, y_eval, t_eval] + +# # Plot the parameters over epochs +# plt.plot(S_losses_min, label='S_loss_min') +# plt.plot(S_losses_max, label='S_loss_max') +# plt.xlabel('Epoch') +# plt.ylabel('Multisymplectic Constant') +# plt.title('Multisymplectic Constant over epochs') +# plt.legend() +# plt.grid() +# +# if save_fig: +# save_fig_string = generate_save_fig_string('S_loss', epochs, learning_rate_type, loss_type) +# # save png +# plt.savefig(save_fig_string, dpi=300) +# # # save pdf +# # plt.savefig('../results/' + 'S_loss.pdf', dpi=300) + + +# Plot the Hamiltonian over epochs +plt.plot(H_losses_min, label='H_loss_min') +plt.plot(H_losses_max, label='H_loss_max') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss.pdf', dpi=300) + +# Plot the average Hamiltonian over epochs with standard deviation +H_losses_mean = np.array(H_losses_mean) +H_losses_std = np.array(H_losses_std) +H_losses_abs_error = np.array(H_losses_abs_error) +H_losses_rel_error = np.array(H_losses_rel_error) + +plt.plot(H_losses_mean) +plt.fill_between(range(len(H_losses_mean)), H_losses_mean - H_losses_std, H_losses_mean + H_losses_std, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_mean', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_std.pdf', dpi=300) + +# Plot the standard deviation of the Hamiltonian over epochs +plt.plot(H_losses_std) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_std', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_std.pdf', dpi=300) + + +# Plot the absolute error of the Hamiltonian over epochs +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_abs_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_rel_error.pdf', dpi=300) + + +# Plot the relative error of the Hamiltonian over epochs +plt.plot(H_losses_rel_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_rel_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_rel_error.pdf', dpi=300) + + +# Plot the Chebyshev parameter over epochs +if cheb_par.trainable: + plt.plot(tf.sigmoid(cheb_par_values)) + plt.xlabel('Epoch') + plt.ylabel('Chebyshev parameter') + plt.title('Chebyshev parameter over epochs') + plt.grid() + + if save_fig: + save_fig_string = generate_save_fig_string('cheb_par', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'cheb_par.pdf', dpi=300) + + +import pandas as pd +df = pd.DataFrame() + +df['total_loss'] = losses +df['pde_loss'] = pde_losses +df['data_fitting_loss_0'] = data_fitting_losses_0 +df['data_fitting_loss_l_r'] = data_fitting_losses_l_r +df['H_loss_min'] = H_losses_min +df['H_loss_max'] = H_losses_max +df['H_loss_mean'] = H_losses_mean +df['H_loss_std'] = H_losses_std +df['H_loss_abs_error'] = H_losses_abs_error +df['H_loss_rel_error'] = H_losses_rel_error +# df['cheb_par'] = cheb_par_values + +df.to_csv('./results/2D/training_history.csv', index=False) +# from mpl_toolkits.mplot3d import Axes3D + +# # Set up meshgrid +# N = 600 +# tspace = np.linspace(0, 2, N + 1) +# xspace = np.linspace(0, 2, N + 1) +# yspace = np.linspace(0, 2, N + 1) +# T, X , Y= np.meshgrid(tspace, xspace, yspace) +# XYTgrid = np.vstack([X.flatten(),Y.flatten(),T.flatten()]).T + +# # Determine predictions of u(t, x) +# u_pred = model(tf.cast(XYTgrid,DTYPE)) + +# # Reshape upred +# U = u_pred.numpy().reshape(N+1,N+1,N+1) + +# # Surface plot of solution u(t,x) +# fig = plt.figure(figsize=(9,6)) +# ax = fig.add_subplot(111, projection='3d') +# ax.plot_surface(X, Y, U, cmap='viridis') +# ax.view_init(35,35) +# ax.set_xlabel('$x$') +# ax.set_ylabel('$y$') +# ax.set_zlabel('$u_\\theta(x,y,t)$') +# ax.set_title('Solution to KdV equation') +# if save_fig: +# save_fig_string = generate_save_fig_string('sol', epochs, learning_rate_type, loss_type) +# # save png +# plt.savefig(save_fig_string, dpi=300) +# # # save pdf +# # plt.savefig('../results/' + 'solution.pdf', dpi=300) + +# # Extract the components of lambdas over epochs +# lambda_1 = [l[0] for l in lambdas_values] +# lambda_2 = [l[1] for l in lambdas_values] +# lambda_3 = [l[2] for l in lambdas_values] + +# # Plot the components of lambdas +# plt.figure(figsize=(10, 6)) +# plt.plot(lambda_1, label='$\lambda_1$', color='r') +# plt.plot(lambda_2, label='$\lambda_2$', color='g') +# plt.plot(lambda_3, label='$\lambda_3$', color='b') +# plt.xlabel('Epochs') +# plt.ylabel('Weights Values') +# plt.title('Evolution of weight components over training') +# plt.legend() +# plt.grid() +# + +# # Save the plot if required +# if save_fig: +# save_fig_string = generate_save_fig_string('lambdas', epochs, learning_rate_type, loss_type) +# plt.savefig(save_fig_string, dpi=300) + diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py new file mode 100644 index 0000000..8fc5ae0 --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py @@ -0,0 +1,291 @@ +import numpy as np +import torch +import torch.nn as nn +from time import time +import matplotlib.pyplot as plt + +DTYPE = torch.float32 +torch.set_default_dtype(DTYPE) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +Nx, Ny, Nt = 10, 10, 10 +N_collocation = Nx*Ny*Nt +d_in = 3 + +xMin, xMax = 0.0, 8.0 +yMin, yMax = 0.0, 8.0 +tMax = 50. + +def choose_width_depth(N_collocation=N_collocation, overparam_factor=3.0, d_in=d_in, d_out=1): + N_target = int(overparam_factor * N_collocation) + depth = 8 + a, b = depth - 1, d_in + depth + d_out + c = d_out - N_target + width = int((-b + np.sqrt(b*b - 4*a*c)) / (2*a)) + return width, depth + +width, depth = choose_width_depth() +dx = (xMax - xMin) / (Nx - 1) +dy = (yMax - yMin) / (Ny - 1) +dt = tMax / (Nt - 1) +h = np.max([dx, dy, dt]) + +lambdas = torch.tensor([1., 1., 1.], dtype=DTYPE, device=device, requires_grad=False) +cheb_par = torch.tensor(0.5, dtype=DTYPE, device=device, requires_grad=True) + +x = torch.linspace(xMin, xMax, Nx, dtype=DTYPE, device=device).reshape(-1, 1) +y = torch.linspace(yMin, yMax, Ny, dtype=DTYPE, device=device).reshape(-1, 1) +t = torch.linspace(0, tMax, Nt, dtype=DTYPE, device=device).reshape(-1, 1) +y_grid, x_grid, t_grid = torch.meshgrid(y.flatten(), x.flatten(), t.flatten(), indexing='ij') +x_train = x_grid.flatten().reshape(-1, 1) +y_train = y_grid.flatten().reshape(-1, 1) +t_train = t_grid.flatten().reshape(-1, 1) +xyt_train = torch.stack([x_train.flatten(), y_train.flatten(), t_train.flatten()], dim=1) + +save_fig = True + +def u_0(x, y): + epsilon = 0.01 + c1, c2 = 0.45, 0.25 + x1, x2 = 2.5, 3.3 + y1 = 0. + out = 3*c1/(torch.cosh(0.5*torch.sqrt(torch.tensor(c1/epsilon))*((x-x1)**2 + (y-y1)**2)**0.5))**2 + out += 3*c2/(torch.cosh(0.5*torch.sqrt(torch.tensor(c2/epsilon))*((x-x2)**2 + (y-y1)**2)**0.5))**2 + return out + +def periodic_boundary_conditions(model, Nbc=2000): + x = torch.rand(Nbc, 1, device=device) * (xMax - xMin) + xMin + y = torch.rand(Nbc, 1, device=device) * (yMax - yMin) + yMin + t = torch.rand(Nbc, 1, device=device) * tMax + + xL = torch.full_like(x, xMin) + xR = torch.full_like(x, xMax) + yL = torch.full_like(y, yMin) + yR = torch.full_like(y, yMax) + + uLx = model(torch.cat([xL, y, t], 1)) + uRx = model(torch.cat([xR, y, t], 1)) + uLy = model(torch.cat([x, yL, t], 1)) + uRy = model(torch.cat([x, yR, t], 1)) + + loss = torch.mean((uLx - uRx)**2 + (uLy - uRy)**2) + return loss + +def H(u, u_x, u_y): + return torch.sum((u_x**2 + u_y**2)/2 - u**3/6) * dx * dy + +def linear_loss_function(tensors, weights): + stacked = torch.stack(tensors) + weights = weights / torch.sum(weights) + loss = torch.sum(weights * stacked) + return loss, 'ls' + +def chebyshev_loss_function(tensors, weights): + stacked = torch.stack(tensors) + loss = torch.max(weights * stacked) + return loss, 'cs' + +def augmentedChebyshev_loss_function(tensors, weights): + par = torch.sigmoid(cheb_par) + ls = linear_loss_function(tensors, weights)[0] + cs = chebyshev_loss_function(tensors, weights)[0] + return par*cs + (1-par)*ls, 'acs' + +class FourierFeatures(nn.Module): + def __init__(self, n_modes=5): + super().__init__() + self.n_modes = n_modes + + def forward(self, inputs): + x, y, t = inputs[:, 0:1], inputs[:, 1:2], inputs[:, 2:3] + features = [t] + for k in range(1, self.n_modes + 1): + features.append(torch.sin(2*np.pi*k*(x - xMin)/(xMax-xMin))) + features.append(torch.cos(2*np.pi*k*(x - xMin)/(xMax-xMin))) + features.append(torch.sin(2*np.pi*k*(y - yMin)/(yMax-yMin))) + features.append(torch.cos(2*np.pi*k*(y - yMin)/(yMax-yMin))) + return torch.cat(features, dim=1) + +class PINNModel(nn.Module): + def __init__(self, num_hidden_layers=depth, num_neurons_per_layer=width): + super().__init__() + self.ff = FourierFeatures(n_modes=4) + layers = [] + # input_dim = 3 + 4 * 2 * 4 + input_dim = 17 + for _ in range(num_hidden_layers): + layers.append(nn.Linear(input_dim, num_neurons_per_layer)) + layers.append(nn.Tanh()) + input_dim = num_neurons_per_layer + layers.append(nn.Linear(input_dim, 1)) + self.net = nn.Sequential(*layers) + + def forward(self, x): + x = self.ff(x) + return self.net(x) + +def custom_loss(inputs, model, dual_opt): + x, y, t = inputs[:, 0:1], inputs[:, 1:2], inputs[:, 2:3] + x.requires_grad_(True) + y.requires_grad_(True) + t.requires_grad_(True) + u_model = model(torch.cat([x, y, t], dim=1)) + + u_t = torch.autograd.grad(u_model.sum(), t, create_graph=True)[0] + u_x = torch.autograd.grad(u_model.sum(), x, create_graph=True)[0] + u_y = torch.autograd.grad(u_model.sum(), y, create_graph=True)[0] + + u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0] + u_yy = torch.autograd.grad(u_y.sum(), y, create_graph=True)[0] + + u_xxx = torch.autograd.grad(u_xx.sum(), x, create_graph=True)[0] + u_xyy = torch.autograd.grad(u_y.sum(), y, create_graph=True)[0] + + pde_loss = torch.mean((u_t + u_model * u_x + u_xxx + u_xyy) ** 2) + + x_ic = torch.linspace(xMin, xMax, Nx*Ny).reshape(-1, 1).to(device) + y_ic = torch.linspace(yMin, yMax, Nx*Ny).reshape(-1, 1).to(device) + t_ic = torch.zeros_like(x_ic) + u_ic = u_0(x_ic, y_ic) + u_ic_pred = model(torch.cat([x_ic, y_ic, t_ic], dim=1)) + data_fitting_loss_0 = torch.mean((u_ic_pred - u_ic) ** 2) + + data_fitting_loss_l_r = periodic_boundary_conditions(model) + loss, loss_type = augmentedChebyshev_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + + H_loss = H(u_model.reshape(Nx, Ny, Nt), u_x.reshape(Nx, Ny, Nt), u_y.reshape(Nx, Ny, Nt)) + + # constraint + H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) + + Hf = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) + H_constraint = torch.abs(Hf - H0)/torch.abs(H0) + + eps = 5/(epoch+1) + H_constraint = torch.max(H_constraint - eps, torch.zeros_like(H_constraint)).unsqueeze(0) + + loss = dual_opt.forward_update(loss, H_constraint) + + + return loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss + +model = PINNModel().to(device) +epochs = 1000 +lr_schedule = torch.optim.lr_scheduler.PolynomialLR( + torch.optim.Adam(model.parameters(), lr=1e-2), + total_iters=epochs, power=3.0 +) +optimizer = lr_schedule.optimizer + +losses, pde_losses, data_losses_0, bc_losses = [], [], [], [] +H_losses_min, H_losses_max, H_losses_mean, H_losses_std = [], [], [], [] +H_losses_abs_error, H_losses_rel_error = [], [] +t0 = time() + +for epoch in range(epochs): + optimizer.zero_grad() + loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = custom_loss(xyt_train, model) + loss.backward() + optimizer.step() + lr_schedule.step() + + with torch.no_grad(): + losses.append(loss.item()) + pde_losses.append(pde_loss.item()) + data_losses_0.append(data_loss_0.item()) + bc_losses.append(bc_loss.item()) + + H_loss_min = torch.min(H_loss).item() + H_loss_max = torch.max(H_loss).item() + H_losses_min.append(H_loss_min) + H_losses_max.append(H_loss_max) + H_loss_mean = torch.mean(H_loss).item() + H_loss_std = torch.std(H_loss).item() + H_losses_mean.append(H_loss_mean) + H_losses_std.append(H_loss_std) + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.6e}") + +print(f'\nComputation time: {time() - t0:.2f}s') + +plt.figure(figsize=(10, 6)) +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_losses_0, label='Initial Conditions Loss') +plt.semilogy(bc_losses, label='Periodic Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() +plt.savefig('./results/2D_loss.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_min, label='H_loss_min') +plt.plot(H_losses_max, label='H_loss_max') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() +plt.savefig('./results/2D_H_minmax.png', dpi=300) if save_fig else None +plt.show() + +H_losses_mean_arr = np.array(H_losses_mean) +H_losses_std_arr = np.array(H_losses_std) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_mean_arr) +plt.fill_between(range(len(H_losses_mean_arr)), H_losses_mean_arr - H_losses_std_arr, H_losses_mean_arr + H_losses_std_arr, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +plt.grid() +plt.savefig('./results/2D_H_mean_std.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_std_arr) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +plt.grid() +plt.savefig('./results/2D_H_std.png', dpi=300) if save_fig else None +plt.show() + +if cheb_par.requires_grad: + plt.figure(figsize=(10, 6)) + plt.plot(torch.sigmoid(cheb_par).detach()) + plt.xlabel('Epoch') + plt.ylabel('Chebyshev parameter') + plt.title('Chebyshev parameter over epochs') + plt.grid() + plt.savefig('./results/2D_cheb_par.png', dpi=300) if save_fig else None + plt.show() + +from mpl_toolkits.mplot3d import Axes3D +N = 100 +xspace = torch.linspace(xMin, xMax, N, dtype=DTYPE, device=device) +yspace = torch.linspace(yMin, yMax, N, dtype=DTYPE, device=device) +tspace_val = torch.tensor(tMax, dtype=DTYPE, device=device) +X_grid, Y_grid = torch.meshgrid(xspace, yspace, indexing='ij') +T_grid = torch.full_like(X_grid, tMax) +XYTgrid = torch.stack([X_grid.flatten(), Y_grid.flatten(), T_grid.flatten()], dim=1) + +with torch.no_grad(): + u_pred = model(XYTgrid) +U = u_pred.reshape(N, N) + +X_np = X_grid.cpu().numpy() +Y_np = Y_grid.cpu().numpy() +U_np = U.cpu().numpy() + +fig = plt.figure(figsize=(9, 6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X_np, Y_np, U_np, cmap='viridis') +ax.set_xlabel('$x$') +ax.set_ylabel('$y$') +ax.set_zlabel('$u(x,y,t)$') +ax.set_title('2D PDE Solution') +plt.savefig('./results/2D_solution.png', dpi=300) if save_fig else None +plt.show() diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1.py b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1.py new file mode 100644 index 0000000..55ba1d3 --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1.py @@ -0,0 +1,729 @@ +import numpy as np +import tensorflow as tf +from time import time +import matplotlib.pyplot as plt + +# TF_USE_LEGACY_KERAS=True + +DTYPE = np.float32 +# Nx = 50 +# Nt = 50 +# N_collocation = Nx*Nt + +xMin = -np.pi +xMax = np.pi +tMax = 5. # 10. +d_in = 2 + +def Nx_from_arch(width, depth, fac=2.0, d_in=2, d_out=1): + """ + Given a PINN architecture (width, depth) and an overparam factor fac, + compute Nx = Nt such that: + + Nx * Nt ≈ N_params / fac, + Nx = Nt, + + where N_params is the number of trainable parameters. + + Parameters + ---------- + width : int + Number of neurons per hidden layer. + depth : int + Number of hidden layers. + fac : float + Over-parameterization factor. Typical values: fac = 2 or 3. + d_in : int + Input dimension (usually 2: x,t). + d_out : int + Output dimension (usually 1: u). + + Returns + ------- + Nx : int + Nt : int + Ntheta : int + Total number of trainable parameters. + Ncoll_target : int + Target number of collocation points = Ntheta/fac. + """ + + # Parameter count + Ntheta = (d_in + 1) * width \ + + (depth - 1) * (width * width + width) \ + + d_out * (width + 1) + + # Target collocation count + Ncoll_target = int(Ntheta / fac) + + # Square grid Nx = Nt + Nx = int(np.sqrt(Ncoll_target)) + Nt = Nx + + return Nx, Nt, Ntheta, Ncoll_target + +width = 80 +depth = 4 + +Nx, Nt, Ntheta, Ncoll = Nx_from_arch(width=width, depth=depth, fac=10.) + +def h_from_NxNt(Nx, Nt, xMin, xMax, tMax): + """ + Compute dx, dt, and h from Nx, Nt and the domain extents. + h is defined as max(dx, dt). + + Returns + ------- + dx : float + dt : float + h : float + """ + + Lx = xMax - xMin + Lt = tMax + + dx = Lx / (Nx - 1) + dt = Lt / (Nt - 1) + + h = max(dx, dt) + + return dx, dt, h + +dx, dt, h = h_from_NxNt(Nx, Nt, xMin, xMax, tMax) + +lambdas = [1., 1., 1.] +lambdas = tf.Variable(lambdas, trainable=False, name='lambdas', dtype=DTYPE) +do_training = True +cheb_par = tf.Variable(0.5, trainable=False, name='cheb_par', dtype=DTYPE) + +x = np.linspace(xMin, xMax, Nx).reshape((-1, 1)).astype(DTYPE) +t = np.linspace(0, tMax, Nt).reshape((-1, 1)).astype(DTYPE) + +x_train = tf.expand_dims(tf.convert_to_tensor(x.flatten()), axis=-1) +t_train = tf.expand_dims(tf.convert_to_tensor(t.flatten()), axis=-1) + +save_fig = True + +# Define the initial condition +def u_0(x): + return 0.2+0.1*tf.math.cos(2 * x) + + +def u_0_x(x): + return -0.2*tf.math.sin(2 * x) + + +def periodic_boundary_conditions(model, Nbc=2000): + + # Random boundary sampling (correct choice) + x = tf.random.uniform((Nbc,1), xMin, xMax) + t = tf.random.uniform((Nbc,1), 0.0, tMax) + + xL = tf.ones_like(x) * xMin + xR = tf.ones_like(x) * xMax + + with tf.GradientTape(persistent=True) as tape: + tape.watch([xL, xR]) + + uLx = model(tf.concat([xL, t], 1)) + uRx = model(tf.concat([xR, t], 1)) + + # First derivatives + uxL = tape.gradient(uLx, xL) + uxR = tape.gradient(uRx, xR) + + del tape + + # Enforce periodicity of values AND derivatives + loss = tf.reduce_mean( + (uLx - uRx)**2 + + (uxL - uxR)**2 + ) + + return loss + + +# def H(u, u_x, dx): +# return tf.reduce_sum(tf.pow(u, 3)+u*tf.pow(u_x, 2), axis = -1) * dx +# # return tf.reduce_sum((tf.pow(u, 2)+tf.pow(u_x, 2))/2, axis = -1) * dx + +def ch_density(u, u_x): + return tf.pow(u, 3) + u * tf.pow(u_x, 2) + + +# @tf.function +def H(u, u_x, dx, density_fn=ch_density, axis=-1): + """ + Boole’s rule (8th order) along 'axis' for uniform grid with spacing dx. + Requires (N-1) % 4 == 0. Otherwise uses Boole on the largest prefix and trapezoid on remainder. + """ + f = density_fn(u, u_x) # [..., N] + n = tf.shape(f)[axis] + + # Trapezoid as a fallback on short tails + def _trap_rem(rem): + # rem: [..., M] contiguous tail; integrate with trapezoid + return tf.reduce_sum(0.5*(rem[..., 1:] + rem[..., :-1]), axis=-1) * tf.cast(dx, f.dtype) + + # Degenerate + if tf.less_equal(n, 1): + return tf.reduce_sum(f, axis=axis) * dx + + # Largest prefix with (n1-1) % 4 == 0 + n1 = n - ((n - 1) % 4) + # Boole constant for uniform spacing: 2*dx/45 + c = (2.0 * dx) / 45.0 + + # Indices for prefix + idx_prefix = tf.range(n1) + f0 = tf.gather(f, idx_prefix[0::4], axis=axis) # 0,4,8,... + f1 = tf.gather(f, idx_prefix[1::4], axis=axis) # 1,5,9,... + f2 = tf.gather(f, idx_prefix[2::4], axis=axis) # 2,6,10,... + f3 = tf.gather(f, idx_prefix[3::4], axis=axis) # 3,7,11,... + f4 = tf.gather(f, idx_prefix[4::4], axis=axis) # 4,8,12,... (last block end) + + # Weighted sum across blocks + # Boole's block weights per 5 nodes: [7, 32, 12, 32, 7] + # Aggregate across all blocks by summing slices + s = 7.0 * tf.reduce_sum(f0, axis=axis) + s += 32.0 * tf.reduce_sum(f1, axis=axis) + s += 12.0 * tf.reduce_sum(f2, axis=axis) + s += 32.0 * tf.reduce_sum(f3, axis=axis) + s += 7.0 * tf.reduce_sum(f4, axis=axis) + + boole_part = c * s + + # Tail remainder + if tf.equal(n1, n): + return boole_part + + rem = tf.gather(f, tf.range(n1-1, n), axis=axis) # nodes: n1-1 .. n-1 + tail = _trap_rem(rem) + return boole_part + tail + + +def linear_loss_function(tensors, weights): + """ + Computes the sum of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the sum. + + Returns: + tf.Tensor: The sum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) # shape (n_losses,) + weights = weights / tf.reduce_sum(weights) + loss = tf.reduce_sum(weights * stacked) + loss_type = 'ls' + return loss, loss_type + + +def chebyshev_loss_function(tensors, weights): + """ + Computes the max of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The maximum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + loss = tf.reduce_max(weights*stacked) + loss_type = 'cs' + return loss, loss_type + + +def smooth_chebyshev_loss_function(mu, tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + exp_sum = tf.reduce_sum(tf.math.exp(stacked/mu), axis=0) + loss = mu*tf.math.log(exp_sum) + loss_type = 'scs' + return loss, loss_type + + +def augmentedChebyshev_loss_function(tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + loss_type = 'acs' + par = tf.sigmoid(cheb_par) # par is between 0 and 1 + return par*chebyshev_loss_function(tensors, weights)[0] + (1-par)*linear_loss_function(tensors, weights)[0], loss_type + + +def sigmoid_centered(x): + return 2*tf.nn.sigmoid(.5*x) - 1 + +def PINNModel(num_hidden_layers=depth, num_neurons_per_layer=width): # 8,40 + xt_input = tf.keras.Input(shape=(2,)) + output_u = xt_input + for _ in range(num_hidden_layers): + output_u = tf.keras.layers.Dense(num_neurons_per_layer, + activation='gelu', # tanh + kernel_initializer='glorot_normal', #'glorot_uniform', # glorot_normal + )(output_u) + + output_u = tf.keras.layers.Dense(units=1, + activation='linear', + kernel_initializer='glorot_normal', #'glorot_uniform', # glorot_normal + )(output_u) + + return tf.keras.Model(inputs=xt_input, outputs=output_u) #tf.keras.Model(inputs=[x_input, t_input], outputs=output_u) + + +def lambda_grad(epoch, + start=1000, + lam_max=1e-0, + kappa=1e-3): + epoch = tf.cast(epoch, tf.float32) + return lam_max * (1.0 - tf.exp(-kappa * tf.maximum(epoch - start, 0.0))) + + + +# @tf.function +def custom_loss(inputs, model): + x, t = inputs[:, 0:1], inputs[:, 1:2] + + with tf.GradientTape(persistent=True) as outerTape: + outerTape.watch(x) + with tf.GradientTape(persistent=True) as tape: + tape.watch(t) + tape.watch(x) + with tf.GradientTape(persistent=False) as tape2: + tape2.watch(x) + tape2.watch(t) + with tf.GradientTape(persistent=True) as tape3: + tape3.watch(x) + tape3.watch(t) + u_model = model(tf.stack([x[:, 0], t[:, 0]], axis=1)) + u_x = tape3.gradient(u_model, x) + u_t = tape3.gradient(u_model, t) + u_xx = tape2.gradient(u_x, x) + u_xxt = tape.gradient(u_xx, t) + u_xxx = tape.gradient(u_xx, x) + + # === Camassa–Holm residual === + r = ( + u_t + - u_xxt + + 3.0 * u_model * u_x + - 2.0 * u_x * u_xx + - u_model * u_xxx + ) + r_x = outerTape.gradient(r, x) + + # Clean up + del tape, tape2, tape3, outerTape + + lam = lambda_grad(epoch) + + # === H1 norm of residual === + pde_loss_L2 = tf.reduce_mean(tf.square(r)) + pde_loss_grad = tf.reduce_mean(tf.square(r_x)) + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + # === Initial condition === + ic_mask = tf.where(tf.abs(t) < 1e-6) + x_ic = tf.gather(x, ic_mask[:, 0]) + u_ic = u_0(x_ic) + t_ic = tf.zeros_like(x_ic) + u_ic_pred = model(tf.concat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = tf.reduce_mean(tf.square(u_ic_pred - u_ic)) + + # === Periodic boundary conditions === + data_fitting_loss_l_r = periodic_boundary_conditions(model) + + # === Chebyshev aggregation === + loss, loss_type = chebyshev_loss_function( + [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + lambdas + ) + + # === Hamiltonian (monitor only) === + H_loss = H( + tf.reshape(u_model, shape=[Nt, Nx]), + tf.reshape(u_x, shape=[Nt, Nx]), + dx + ) + + return ( + loss, + loss_type, + pde_loss_H1, + data_fitting_loss_0, + data_fitting_loss_l_r, + H_loss, + ) + + +# Create the PINN model +model = PINNModel() +model.summary() + +epochs = 1000 # 3000 # 5000 # 2000 +# # Compile the model +# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), +# loss=lambda y_true, y_pred: custom_loss([x_train, t_train, theta_train], model)[1]) + +# Create the optimizer with a smaller learning rate +# learning_rate = 1e-3 # 1e-4 +# learning_rate_type = 'constant' +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([10, 100], [1e-1, 5e-2, 1e-2]) #OK +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([100, 300], [1e-2, 1e-3, 1e-4]) +# learning_rate = tf.keras.optimizers.schedules.PolynomialDecay( +# initial_learning_rate=1e-3, +# decay_steps=epochs, +# end_learning_rate=1e-5, +# power=2., +# cycle=False, # True +# name='PolynomialDecay' +# ) +# learning_rate_type = 'polynomialDecay' +learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( + initial_learning_rate=1e-2, + decay_steps=epochs, # 100 + decay_rate=0.9, + staircase=False, + name='ExponentialDecay' +) +learning_rate_type = 'exponentialDecay' +# learning_rate = tf.keras.optimizers.schedules.CosineDecay( +# initial_learning_rate=1e-3, +# decay_steps=1000, +# alpha=0.0, +# name='CosineDecay', +# warmup_target=None, +# warmup_steps=0 +# ) +# learning_rate_type = 'cosineDecay' + +trainable = model.trainable_variables +if lambdas.trainable: + trainable += [lambdas] + +if cheb_par.trainable: + trainable += [cheb_par] + +# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, amsgrad=True) +# optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True) +optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate, beta_1=0.8, beta_2=0.9, epsilon=1e-07) +# optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate, rho=0.9, momentum=0.0, epsilon=1e-07, centered=False) + +# Training loop +losses = [] +pde_losses = [] +data_fitting_losses_0 = [] +data_fitting_losses_l_r = [] +delta_gradients = [] +H_losses_min = [] +H_losses_max = [] +H_losses_mean = [] +H_losses_std = [] +H_losses_abs_error = [] +H_losses_rel_error = [] +lambdas_values = [] +lambdas_values.append(lambdas.numpy()) +cheb_par_values = [] +cheb_par_values.append(cheb_par.numpy()) + +# Convert data to tensor because tf.GradientTape() can only watch tensor and not numpy arrays +x_train = tf.convert_to_tensor(x_train) +t_train = tf.convert_to_tensor(t_train) +x_grid, t_grid = np.meshgrid(x.flatten(), t.flatten()) +inputs = tf.convert_to_tensor(np.vstack([x_grid.flatten(), t_grid.flatten()]).T) +stop = False +# Start timer +t0 = time() +for epoch in range(epochs): + if not stop: + # print("# STARTING EPOCH", epoch + 1) + + # Create a LearningRateScheduler to update the learning rate + # current_lr = scheduler(epoch, learning_rate) + # tf.keras.backend.set_value(optimizer.lr, current_lr) + + with tf.GradientTape() as tape: + loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss = custom_loss(inputs, model) + + # print("Computing gradients") + gradients = tape.gradient(loss, trainable) + # print(gradients[-1]) + # print("Applying gradients") + optimizer.apply_gradients(zip(gradients, trainable)) + # print("Appending losses") + losses.append(loss.numpy()) + pde_losses.append(pde_loss.numpy()) + data_fitting_losses_0.append(data_fitting_loss_0.numpy()) + data_fitting_losses_l_r.append(data_fitting_loss_l_r.numpy()) + H_loss_min = tf.reduce_min(H_loss) + H_loss_max = tf.reduce_max(H_loss) + H_losses_min.append(H_loss_min.numpy()) + H_losses_max.append(H_loss_max.numpy()) + H_loss_mean = tf.reduce_mean(H_loss) + H_loss_std = tf.math.reduce_std(H_loss) + H_losses_mean.append(H_loss_mean.numpy()) + H_losses_std.append(H_loss_std.numpy()) + + H0 = H(u_0(x_grid), u_0_x(x_grid), dx) # H0 = H_loss[0].numpy() + Hf = H_loss.numpy() + H_abs_error = tf.abs(Hf - H0) + H_losses_abs_error.append(tf.reduce_max(H_abs_error).numpy()) + H_rel_error = H_abs_error / tf.abs((H0 + 1e-16)) + H_losses_rel_error.append(H_rel_error[-1].numpy()) + + # lambdas_values.append((trainable[-1]).numpy()) + if len(losses) > 1 and not lambdas.trainable and do_training: + # SoftAdaptive weights update + num1 = tf.math.exp(pde_losses[-1] - pde_losses[-2]) + num2 = tf.math.exp(data_fitting_losses_0[-1] - data_fitting_losses_0[-2]) + num3 = tf.math.exp(data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2]) + den = num1 + num2 + num3 + + new_lambdas = tf.stack([num1 / den, num2 / den, num3 / den]) + lambdas.assign(new_lambdas) + lambdas_values.append((lambdas).numpy()) + + if cheb_par.trainable: + cheb_par_values.append(cheb_par.numpy()) + + del tape + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.numpy()}") + + if len(losses) > 2 and np.abs(losses[-1] - losses[-2]) / np.abs(losses[-2]) < 1e-8: + stop = True + +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean.numpy()}") +print(f"Hamiltonian standard deviation: {H_loss_std.numpy()}") +print(f"Hamiltonian maximum: {H_loss_max.numpy()}") +print(f"Hamiltonian minimum: {H_loss_min.numpy()}") +# print(f"Hamiltonian absolute error: {H_abs_error.numpy()}") +# print(f"Hamiltonian relative error: {H_rel_error.numpy()}") +print(f"Hamitonian relative error: {H_rel_error[-1].numpy()}") +# Print computation time +print('\nComputation time: {} seconds'.format(time() - t0)) + + +def generate_save_fig_string(type, epochs, learning_rate_type, loss_type): + """ + Generates a string for saving figures that includes the number of epochs and the type of learning rate. + + Args: + epochs (int): The number of epochs. + learning_rate_type (str): The type of learning rate. + + Returns: + str: The generated string for saving figures. + """ + return f"./results/{type}_epochs_{epochs}_lr_{learning_rate_type}_{loss_type}.png" + +# Plot the loss history +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_fitting_losses_0, label='Initial Conditions Loss') +plt.semilogy(data_fitting_losses_l_r, label='Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'loss.pdf', dpi=300) + +# Evaluate the function +x_eval = np.linspace(x_train[0].numpy(), x_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +t_eval = np.linspace(t_train[0].numpy(), t_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +inputs_eval = [x_eval, t_eval] + +# Plot the Hamiltonian over epochs +plt.plot(H_losses_min, label='min H') +plt.plot(H_losses_max, label='max H') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss.pdf', dpi=300) + +# Plot the average Hamiltonian over epochs with standard deviation +H_losses_mean = np.array(H_losses_mean) +H_losses_std = np.array(H_losses_std) +H_losses_rel_error = np.array(H_losses_rel_error) +H_losses_rel_error = np.array(H_losses_rel_error) + +plt.plot(H_losses_mean) +plt.fill_between(range(len(H_losses_mean)), H_losses_mean - H_losses_std, H_losses_mean + H_losses_std, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_mean', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_std.pdf', dpi=300) + +# Plot the standard deviation of the Hamiltonian over epochs +plt.plot(H_losses_std) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_std', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_std.pdf', dpi=300) + +# Plot the absolute error of the Hamiltonian over epochs +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_abs_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_rel_error.pdf', dpi=300) + + +# Plot the relative error of the Hamiltonian over epochs +plt.plot(H_losses_rel_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_rel_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_rel_error.pdf', dpi=300) + + +# Plot the Chebyshev parameter over epochs +if cheb_par.trainable: + plt.plot(tf.sigmoid(cheb_par_values)) + plt.xlabel('Epoch') + plt.ylabel('Chebyshev parameter') + plt.title('Chebyshev parameter over epochs') + plt.grid() + if save_fig: + save_fig_string = generate_save_fig_string('cheb_par', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'cheb_par.pdf', dpi=300) + + +from mpl_toolkits.mplot3d import Axes3D + +# Set up meshgrid +N = 600 +tspace = np.linspace(0, tMax, N + 1) +xspace = np.linspace(xMin, xMax, N + 1) +T, X = np.meshgrid(tspace, xspace) +XTgrid = np.vstack([X.flatten(),T.flatten()]).T + +# Determine predictions of u(t, x) +u_pred = model(tf.cast(XTgrid,DTYPE)) + +# Reshape upred +U = u_pred.numpy().reshape(N+1,N+1) + +# Surface plot of solution u(t,x) +fig = plt.figure(figsize=(9,6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X, T, U, cmap='viridis') +ax.view_init(35,35) +ax.set_xlabel('$x$') +ax.set_ylabel('$t$') +ax.set_zlabel('$u_\\theta(x,t)$') +ax.set_title('Solution to Camassa-Holm equation') +ax.set_box_aspect(None, zoom=0.85) + +if save_fig: + save_fig_string = generate_save_fig_string('sol', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'solution.pdf', dpi=300) + + + + +import pandas as pd +df = pd.DataFrame() + +df['total_loss'] = losses +df['pde_loss'] = pde_losses +df['data_fitting_loss_0'] = data_fitting_losses_0 +df['data_fitting_loss_l_r'] = data_fitting_losses_l_r +df['H_loss_min'] = H_losses_min +df['H_loss_max'] = H_losses_max +df['H_loss_mean'] = H_losses_mean +df['H_loss_std'] = H_losses_std +df['H_loss_abs_error'] = H_losses_abs_error +df['H_loss_rel_error'] = H_losses_rel_error +# df['cheb_par'] = cheb_par_values + +df.to_csv('./results/camassa/training_history.csv', index=False) \ No newline at end of file diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py new file mode 100644 index 0000000..c10d84f --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py @@ -0,0 +1,394 @@ +import numpy as np +import torch +import torch.nn as nn +from time import time +import matplotlib.pyplot as plt + +from humancompatible.train.dual_optim import ALM, iALM + +DTYPE = torch.float32 +torch.set_default_dtype(DTYPE) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +xMin, xMax = -np.pi, np.pi +tMax = 5. +d_in = 2 + +def Nx_from_arch(width, depth, fac=2.0, d_in=2, d_out=1): + Ntheta = (d_in + 1) * width + (depth - 1) * (width * width + width) + d_out * (width + 1) + Ncoll_target = int(Ntheta / fac) + Nx = int(np.sqrt(Ncoll_target)) + Nt = Nx + return Nx, Nt, Ntheta, Ncoll_target + +width, depth = 80, 4 +Nx, Nt, Ntheta, Ncoll = Nx_from_arch(width=width, depth=depth, fac=10.) + +dx = (xMax - xMin) / (Nx - 1) +dt = tMax / (Nt - 1) +h = max(dx, dt) + +lambdas = torch.tensor([1., 1., 1.], dtype=DTYPE, device=device, requires_grad=False) +do_training = True +cheb_par = torch.tensor(0.5, dtype=DTYPE, device=device, requires_grad=False) + +x = torch.linspace(xMin, xMax, Nx, dtype=DTYPE, device=device).reshape(-1, 1) +t = torch.linspace(0, tMax, Nt, dtype=DTYPE, device=device).reshape(-1, 1) +x_train = x.reshape(-1, 1) +t_train = t.reshape(-1, 1) +t_grid, x_grid = torch.meshgrid(t.flatten(), x.flatten(), indexing='ij') +inputs = torch.stack([x_grid.flatten(), t_grid.flatten()], dim=1) + +save_fig = True + +def u_0(x): + return 0.2 + 0.1 * torch.cos(2 * x) + +def u_0_x(x): + return -0.2 * torch.sin(2 * x) + +def periodic_boundary_conditions(model, Nbc=2000): + x = torch.rand(Nbc, 1, device=device) * (xMax - xMin) + xMin + t = torch.rand(Nbc, 1, device=device) * tMax + + xL = torch.full_like(x, xMin) + xR = torch.full_like(x, xMax) + + uLx = model(torch.cat([xL, t], 1)) + uRx = model(torch.cat([xR, t], 1)) + + loss = torch.mean((uLx - uRx)**2) + return loss + +def ch_density(u, u_x): + return u**3 + u * u_x**2 + +def H(u, u_x, dx, density_fn=ch_density): + f = density_fn(u, u_x) + return torch.sum(f) * dx + +def linear_loss_function(tensors, weights): + stacked = torch.stack(tensors) + weights = weights / torch.sum(weights) + loss = torch.sum(weights * stacked) + return loss, 'ls' + +def chebyshev_loss_function(tensors, weights): + stacked = torch.stack(tensors) + loss = torch.max(weights * stacked) + return loss, 'cs' + +### MODEL ### + +# @torch.compile +class PINNModel(nn.Module): + def __init__(self, num_hidden_layers=depth, num_neurons_per_layer=width): + super().__init__() + layers = [] + in_dim = 2 + for _ in range(num_hidden_layers): + layers.append(nn.Linear(in_dim, num_neurons_per_layer)) + layers.append(nn.GELU()) + in_dim = num_neurons_per_layer + layers.append(nn.Linear(in_dim, 1)) + self.net = nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + +def lambda_grad(epoch, + start=1000, + lam_max=1e-0, + kappa=1e-3): + epoch = float(epoch) + return lam_max * (1.0 - np.exp(-kappa * max(epoch - start, 0.0))) + + +##### UNCONSTRAINED LOSS FUNCTION WITH H1 REGULARIZATION ##### + + +# @torch.compile +def custom_loss(inputs, model, epoch): + x, t = inputs[:, 0:1], inputs[:, 1:2] + x.requires_grad_(True) + t.requires_grad_(True) + + u_model = model(torch.cat([x, t], dim=1)) + + u_t = torch.autograd.grad(u_model.sum(), t, create_graph=True)[0] + u_x = torch.autograd.grad(u_model.sum(), x, create_graph=True)[0] + + u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0] + u_xxt = torch.autograd.grad(u_xx.sum(), t, create_graph=True)[0] + u_xxx = torch.autograd.grad(u_xx.sum(), x, create_graph=True)[0] + + r = u_t - u_xxt + 3.0 * u_model * u_x - 2.0 * u_x * u_xx - u_model * u_xxx + r_x = torch.autograd.grad(r.sum(), x, create_graph=True)[0] + + pde_loss_L2 = torch.mean(torch.square(r)) + pde_loss_grad = torch.mean(torch.square(r_x)) + + lam = lambda_grad(epoch) + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + ic_mask = torch.abs(t) < 1e-6 + x_ic = x[ic_mask[:, 0]] + u_ic = u_0(x_ic) + t_ic = torch.zeros_like(x_ic) + u_ic_pred = model(torch.cat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = torch.mean(torch.square(u_ic_pred - u_ic)) + + data_fitting_loss_l_r = periodic_boundary_conditions(model) + + loss, loss_type = chebyshev_loss_function( + [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + lambdas + ) + + H_loss = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) + + return loss, loss_type, pde_loss_L2, data_fitting_loss_0, data_fitting_loss_l_r, H_loss + + +#### LOSS FUNCTION WITH H1 CONSTRAINT #### + +def lagrangian_loss(inputs, model, dual_opt, epoch, H0=None): + x, t = inputs[:, 0:1], inputs[:, 1:2] + x.requires_grad_(True) + t.requires_grad_(True) + + u_model = model(torch.cat([x, t], dim=1)) + u_model_0 = model(torch.cat([x, torch.zeros_like(t)], dim=1)) + + u_t = torch.autograd.grad(u_model.sum(), t, create_graph=True)[0] + u_x = torch.autograd.grad(u_model.sum(), x, create_graph=True)[0] + u_x_0 = torch.autograd.grad(u_model_0.sum(), x, create_graph=True)[0] + + u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0] + u_xxt = torch.autograd.grad(u_xx.sum(), t, create_graph=True)[0] + u_xxx = torch.autograd.grad(u_xx.sum(), x, create_graph=True)[0] + + r = u_t - u_xxt + 3.0 * u_model * u_x - 2.0 * u_x * u_xx - u_model * u_xxx + r_x = torch.autograd.grad(r.sum(), x, create_graph=True)[0] + + pde_loss_L2 = torch.mean(torch.square(r)) + pde_loss_grad = torch.mean(torch.square(r_x)) + + lam = lambda_grad(epoch) + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + ic_mask = torch.abs(t) < 1e-6 + x_ic = x[ic_mask[:, 0]] + u_ic = u_0(x_ic) + t_ic = torch.zeros_like(x_ic) + u_ic_pred = model(torch.cat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = torch.mean(torch.square(u_ic_pred - u_ic)) + + data_fitting_loss_l_r = periodic_boundary_conditions(model) + + loss, loss_type = chebyshev_loss_function( + [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + lambdas + ) + + # constraint + + + Hf = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) + H0 = H(u_model_0.reshape(Nt, Nx), u_x_0.reshape(Nt, Nx), dx) + + H_constraint = (torch.abs(Hf - H0)/torch.abs(H0)).unsqueeze(0) + + eps = 1/(epoch+1)**2 + H_constraint = H_constraint - eps + + loss = dual_opt.forward_update(loss, H_constraint) + + return loss, loss_type, pde_loss_L2, data_fitting_loss_0, data_fitting_loss_l_r, Hf, H0 + + +####### TRAINING LOOP ####### + + +model = PINNModel().to(device) +epochs = 1000 + +optimizer = torch.optim.NAdam(model.parameters(), lr=1e-2, betas=(0.8, 0.9), eps=1e-07) +lr_schedule = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9**(1/epochs)) + +losses, pde_losses, data_losses_0, bc_losses = [], [], [], [] +H_losses_min, H_losses_max, H_losses_mean, H_losses_std = [], [], [], [] +H_losses_abs_error, H_losses_rel_error = [], [] +t0 = time() + + +# dual_opt = ALM(m=1, lr=1e-3, device=device, penalty=0., is_ineq=True) +dual_opt = iALM(m=1, beta=0.001, sigma=1.001, gamma=1., dual_range=(-10.,10.), is_ineq=True) + +H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) + +for epoch in range(epochs): + optimizer.zero_grad() + # loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = custom_loss(inputs, model, epoch) + + loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss, H0 = lagrangian_loss(inputs, model, dual_opt, epoch, H0) + loss.backward() + optimizer.step() + + lr_schedule.step() + + with torch.no_grad(): + + losses.append(loss.item()) + pde_losses.append(pde_loss.item()) + data_losses_0.append(data_loss_0.item()) + bc_losses.append(bc_loss.item()) + + H_loss_min = torch.min(H_loss).item() + H_loss_max = torch.max(H_loss).item() + H_losses_min.append(H_loss_min) + H_losses_max.append(H_loss_max) + H_loss_mean = torch.mean(H_loss).item() + H_loss_std = torch.std(H_loss).item() + H_losses_mean.append(H_loss_mean) + H_losses_std.append(H_loss_std) + + # H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) + Hf = H_loss.detach() + H_abs_error = torch.abs(Hf - H0) + H_losses_abs_error.append(torch.max(H_abs_error).item()) + H_rel_error = H_abs_error / (torch.abs(H0) + 1e-16) + if isinstance(H_rel_error, torch.Tensor): + H_rel_error = H_rel_error.item() if H_rel_error.numel() == 1 else H_rel_error.max().item() + H_losses_rel_error.append(H_rel_error) + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.6e}") + + # lambdas_values.append((trainable[-1]).numpy()) + if len(losses) > 1: + # SoftAdaptive weights update + num1 = np.exp(pde_losses[-1] - pde_losses[-2]) + num2 = np.exp(data_losses_0[-1] - data_losses_0[-2]) + num3 = np.exp(bc_losses[-1] - bc_losses[-2]) + den = num1 + num2 + num3 + + new_lambdas = torch.tensor([num1 / den, num2 / den, num3 / den]) + lambdas = new_lambdas + # lambdas_values.append((lambdas).numpy()) + +print(f'\nComputation time: {time() - t0:.2f}s') +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean}") +print(f"Hamiltonian std: {H_loss_std}") +print(f"Hamiltonian max: {H_loss_max}") +print(f"Hamiltonian min: {H_loss_min}") + +plt.figure(figsize=(10, 6)) +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_losses_0, label='Initial Conditions Loss') +plt.semilogy(bc_losses, label='Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() +plt.savefig('./results/ch_loss.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_min, label='min H') +plt.plot(H_losses_max, label='max H') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() +plt.savefig('./results/ch_H_loss.png', dpi=300) if save_fig else None +plt.show() + +H_losses_mean_arr = np.array(H_losses_mean) +H_losses_std_arr = np.array(H_losses_std) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_mean_arr) +plt.fill_between(range(len(H_losses_mean_arr)), H_losses_mean_arr - H_losses_std_arr, H_losses_mean_arr + H_losses_std_arr, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +plt.grid() +plt.savefig('./results/ch_H_loss_mean.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_std_arr) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +plt.grid() +plt.savefig('./results/ch_H_loss_std.png', dpi=300) if save_fig else None +plt.show() + +H_losses_abs_error = np.array(H_losses_abs_error) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +plt.grid() +plt.savefig('./results/ch_H_loss_abs_error.png', dpi=300) if save_fig else None +plt.show() + +H_losses_rel_error = np.array(H_losses_rel_error) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_rel_error) +# plt.plot() +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +plt.grid() +plt.savefig('./results/ch_H_loss_rel_error.png', dpi=300) if save_fig else None +plt.show() + +N = 600 +tspace = torch.linspace(0, tMax, N + 1, dtype=DTYPE, device=device) +xspace = torch.linspace(xMin, xMax, N + 1, dtype=DTYPE, device=device) +T_grid, X_grid = torch.meshgrid(tspace, xspace, indexing='ij') +XTgrid = torch.stack([X_grid.flatten(), T_grid.flatten()], dim=1) + +with torch.no_grad(): + u_pred = model(XTgrid) +U = u_pred.reshape(N+1, N+1) + +X_np = X_grid.cpu().numpy() +T_np = T_grid.cpu().numpy() +U_np = U.cpu().numpy() + +from mpl_toolkits.mplot3d import Axes3D +fig = plt.figure(figsize=(9, 6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X_np, T_np, U_np, cmap='viridis') +ax.set_xlabel('$x$') +ax.set_ylabel('$t$') +ax.set_zlabel('$u(x,t)$') +ax.set_title('Camassa-Holm equation') +plt.savefig('./results/ch_solution.png', dpi=300) if save_fig else None +plt.show() + + +import pandas as pd +df = pd.DataFrame() + +df['total_loss'] = losses +df['pde_loss'] = pde_losses +df['data_fitting_loss_0'] = data_losses_0 +df['data_fitting_loss_l_r'] = bc_losses +df['H_loss_min'] = H_losses_min +df['H_loss_max'] = H_losses_max +df['H_loss_mean'] = H_losses_mean +df['H_loss_std'] = H_losses_std +df['H_loss_abs_error'] = H_losses_abs_error +df['H_loss_rel_error'] = H_losses_rel_error +# df['cheb_par'] = cheb_par_values + +df.to_csv('./results/camassa/torch_training_history.csv', index=False) \ No newline at end of file diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1.py b/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1.py new file mode 100644 index 0000000..781639e --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1.py @@ -0,0 +1,769 @@ +import numpy as np +from sympy import false, per +import tensorflow as tf +from time import time +import matplotlib.pyplot as plt + +# TF_USE_LEGACY_KERAS=True + +DTYPE = np.float32 +# Nx = 100 +# Nt = 100 +# N_collocation = Nx*Nt + +type = 1 # 1 for KdV, 2 + +if type == 1: + nu = -0.022**2 + alpha = -0.5 + rho = 0. + xMin = 0. + xMax = 2. + tMax = 5. # 10. +elif type == 2: + nu = -1. + alpha = -3. + rho = 0. + xMin = -20. + xMax = 20. + tMax = 100. # 4. + + +def Nx_from_arch(width, depth, fac=1.5, d_in=2, d_out=1): + """ + Given a PINN architecture (width, depth) and an overparam factor fac, + compute Nx = Nt such that: + + Nx * Nt ≈ N_params / fac, + Nx = Nt, + + where N_params is the number of trainable parameters. + + Parameters + ---------- + width : int + Number of neurons per hidden layer. + depth : int + Number of hidden layers. + fac : float + Over-parameterization factor. Typical values: fac = 2 or 3. + d_in : int + Input dimension (usually 2: x,t). + d_out : int + Output dimension (usually 1: u). + + Returns + ------- + Nx : int + Nt : int + Ntheta : int + Total number of trainable parameters. + Ncoll_target : int + Target number of collocation points = Ntheta/fac. + """ + + # Parameter count + Ntheta = (d_in + 1) * width \ + + (depth - 1) * (width * width + width) \ + + d_out * (width + 1) + + # Target collocation count + Ncoll_target = int(Ntheta / fac) + + # Square grid Nx = Nt + Nx = int(np.sqrt(Ncoll_target)) + Nt = Nx + + return Nx, Nt, Ntheta, Ncoll_target + + +width = 80 +depth = 4 + +Nx, Nt, Ntheta, Ncoll = Nx_from_arch(width=width, depth=depth, fac=10.) + +def h_from_NxNt(Nx, Nt, xMin, xMax, tMax): + """ + Compute dx, dt, and h from Nx, Nt and the domain extents. + h is defined as max(dx, dt). + + Returns + ------- + dx : float + dt : float + h : float + """ + + Lx = xMax - xMin + Lt = tMax + + dx = Lx / (Nx - 1) + dt = Lt / (Nt - 1) + + h = max(dx, dt) + + return dx, dt, h + +dx, dt, h = h_from_NxNt(Nx, Nt, xMin, xMax, tMax) + +x = np.linspace(xMin, xMax, Nx).reshape((-1, 1)).astype(DTYPE) +t = np.linspace(0, tMax, Nt).reshape((-1, 1)).astype(DTYPE) + +x_train = tf.expand_dims(tf.convert_to_tensor(x.flatten()), axis=-1) +t_train = tf.expand_dims(tf.convert_to_tensor(t.flatten()), axis=-1) + +lambdas = [1., 1., 1.] +lambdas = tf.Variable(lambdas, trainable=False, name='lambdas', dtype=DTYPE) +do_training = False +cheb_par = tf.Variable(0.5, trainable=False, name='cheb_par', dtype=DTYPE) + +save_fig = True + +# Define the initial condition +def u_0(x): + if type == 1: + return tf.math.cos(np.pi * x) + elif type == 2: + return 6./(tf.math.cosh(x)**2) + + +def u_0_x(x): + if type == 1: + return -np.pi*tf.math.sin(np.pi * x) + elif type == 2: + return -12.*tf.math.sinh(x)/(tf.math.cosh(x)**3) + + +def periodic_bc(model, x, t): + xL = tf.ones_like(x) * xMin + xR = tf.ones_like(x) * xMax + uL = model(tf.concat([xL, t], axis=1)) + uR = model(tf.concat([xR, t], axis=1)) + return tf.reduce_mean((uL - uR)**2) + + +def V(u): + return alpha*tf.pow(u, 3)/3 + rho*tf.pow(u, 2)/2 + + +def kdv_density(u, u_x): + return V(u)-nu*tf.pow(u_x, 2)/2 + +# @tf.function +def H(u, u_x, dx, density_fn=kdv_density, axis=-1): + """ + Boole’s rule (8th order) along 'axis' for uniform grid with spacing dx. + Requires (N-1) % 4 == 0. Otherwise uses Boole on the largest prefix and trapezoid on remainder. + """ + f = density_fn(u, u_x) # [..., N] + n = tf.shape(f)[axis] + + # Trapezoid as a fallback on short tails + def _trap_rem(rem): + # rem: [..., M] contiguous tail; integrate with trapezoid + return tf.reduce_sum(0.5*(rem[..., 1:] + rem[..., :-1]), axis=-1) * tf.cast(dx, f.dtype) + + # Degenerate + if tf.less_equal(n, 1): + return tf.reduce_sum(f, axis=axis) * dx + + # Largest prefix with (n1-1) % 4 == 0 + n1 = n - ((n - 1) % 4) + # Boole constant for uniform spacing: 2*dx/45 + c = (2.0 * dx) / 45.0 + + # Indices for prefix + idx_prefix = tf.range(n1) + f0 = tf.gather(f, idx_prefix[0::4], axis=axis) # 0,4,8,... + f1 = tf.gather(f, idx_prefix[1::4], axis=axis) # 1,5,9,... + f2 = tf.gather(f, idx_prefix[2::4], axis=axis) # 2,6,10,... + f3 = tf.gather(f, idx_prefix[3::4], axis=axis) # 3,7,11,... + f4 = tf.gather(f, idx_prefix[4::4], axis=axis) # 4,8,12,... (last block end) + + # Weighted sum across blocks + # Boole's block weights per 5 nodes: [7, 32, 12, 32, 7] + # Aggregate across all blocks by summing slices + s = 7.0 * tf.reduce_sum(f0, axis=axis) + s += 32.0 * tf.reduce_sum(f1, axis=axis) + s += 12.0 * tf.reduce_sum(f2, axis=axis) + s += 32.0 * tf.reduce_sum(f3, axis=axis) + s += 7.0 * tf.reduce_sum(f4, axis=axis) + + boole_part = c * s + + # Tail remainder + if tf.equal(n1, n): + return boole_part + + rem = tf.gather(f, tf.range(n1-1, n), axis=axis) # nodes: n1-1 .. n-1 + tail = _trap_rem(rem) + return boole_part + tail + + +def linear_loss_function(tensors, weights): + """ + Computes the sum of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the sum. + + Returns: + tf.Tensor: The sum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) # shape (n_losses,) + weights = weights / tf.reduce_sum(weights) + loss = tf.reduce_sum(weights * stacked) + loss_type = 'ls' + return loss, loss_type + + +def chebyshev_loss_function(tensors, weights): + """ + Computes the max of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The maximum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + loss = tf.reduce_max(weights*stacked) + loss_type = 'cs' + return loss, loss_type + + +def smooth_chebyshev_loss_function(mu, tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + exp_sum = tf.reduce_sum(tf.math.exp(stacked/mu), axis=0) + loss = mu*tf.math.log(exp_sum) + loss_type = 'scs' + return loss, loss_type + + +def augmentedChebyshev_loss_function(tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + loss_type = 'acs' + par = tf.sigmoid(cheb_par) # par is between 0 and 1 + return par*chebyshev_loss_function(tensors, weights)[0] + (1-par)*linear_loss_function(tensors, weights)[0], loss_type + + +def sigmoid_centered(x): + return 2*tf.nn.sigmoid(x) - 1 + + +def PINNModel(num_hidden_layers=depth, num_neurons_per_layer=width): # 8,40 + xt_input = tf.keras.Input(shape=(2,)) + output_u = xt_input + for _ in range(num_hidden_layers): + output_u = tf.keras.layers.Dense(num_neurons_per_layer, + activation=sigmoid_centered, # mish + kernel_initializer='glorot_uniform', # glorot_normal + # kernel_constraint=tf.keras.constraints.UnitNorm(axis=0) + # lora_rank=10 + )(output_u) + + output_u = tf.keras.layers.Dense(units=1, + activation='linear', + kernel_initializer='glorot_uniform', # glorot_normal + # kernel_constraint=tf.keras.constraints.UnitNorm(axis=0) + # lora_rank=10 + )(output_u) + + return tf.keras.Model(inputs=xt_input, outputs=output_u) #tf.keras.Model(inputs=[x_input, t_input], outputs=output_u) + + +def lambda_grad(epoch, + start=1000, + lam_max=1e-0, + kappa=1e-3): + epoch = tf.cast(epoch, tf.float32) + return lam_max * (1.0 - tf.exp(-kappa * tf.maximum(epoch - start, 0.0))) + + +@tf.function +def grad_L2_fft_batch(r, L): + """ + r : shape (Nt, Nx) + returns : shape (Nt,) + """ + r = tf.cast(r, tf.complex64) + Nx = tf.shape(r)[-1] + + k_pos = tf.range(0, Nx//2 + 1, dtype=tf.float32) + k_neg = tf.range(-Nx//2 + 1, 0, dtype=tf.float32) + k = tf.concat([k_pos, k_neg], axis=0) + k = (2.0 * tf.constant(np.pi) / L) * k + k = tf.cast(k, tf.complex64) + + r_hat = tf.signal.fft(r) + + grad_energy = tf.reduce_sum(tf.abs(1j * k * r_hat)**2, axis=-1) + + dx = L / tf.cast(Nx, tf.float32) + return tf.math.real(grad_energy) * dx + + +@tf.function +def H1_norm_fft_batch(r, L): + """ + Compute ||r||_{H^1}^2 for each time slice. + + r : shape (Nt, Nx) + returns : shape (Nt,) + """ + r = tf.cast(r, tf.complex64) + Nx = tf.shape(r)[-1] + + k_pos = tf.range(0, Nx//2 + 1, dtype=tf.float32) + k_neg = tf.range(-Nx//2 + 1, 0, dtype=tf.float32) + k = tf.concat([k_pos, k_neg], axis=0) + k = (2.0 * tf.constant(np.pi) / L) * k + k = tf.cast(k, tf.complex64) + + r_hat = tf.signal.fft(r) + weight = 1.0 + tf.abs(k)**2 + + H1_sq = tf.reduce_sum(weight * tf.abs(r_hat)**2, axis=-1) + + dx = L / tf.cast(Nx, tf.float32) + return tf.math.real(H1_sq) * dx + + +# @tf.function +def custom_loss(inputs, model, epoch): + x, t = inputs[:, 0:1], inputs[:, 1:2] + + with tf.GradientTape(persistent=True) as tape: + tape.watch(t) + tape.watch(x) + with tf.GradientTape(persistent=True) as tape2: + tape2.watch(x) + tape2.watch(t) + with tf.GradientTape(persistent=True) as tape3: + tape3.watch(x) + tape3.watch(t) + u_model = model(tf.concat([x, t], axis=1)) + u_x = tape3.gradient(u_model, x) + u_t = tape3.gradient(u_model, t) + u_xx = tape2.gradient(u_x, x) + u_xxx = tape.gradient(u_xx, x) + u_squared_x = 2*u_model*u_x + r = u_t - alpha * u_squared_x - rho*u_x - nu*u_xxx + del tape, tape2, tape3 + + # === PDE residual loss (stabilized, consistent) === + pde_loss_L2 = tf.reduce_mean(tf.square(r)) + + r_grid = tf.reshape(r, [Nt, Nx]) + L = xMax - xMin + + pde_loss_grad = tf.reduce_mean( + grad_L2_fft_batch(r_grid, L) + ) + + # mesh-scaled stabilization parameter + lam = 0.01 * (dx**2) * tf.minimum(1.0, epoch / 1000.0) + + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + # === Initial condition === + ic_mask = tf.where(tf.abs(t) < 1e-6) + x_ic = tf.gather(x, ic_mask[:, 0]) + u_ic = u_0(x_ic) + t_ic = tf.zeros_like(x_ic) + u_ic_pred = model(tf.concat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = tf.reduce_mean(tf.square(u_ic_pred - u_ic)) + + # === Periodic BC === + data_fitting_loss_l_r = periodic_bc(model, x, t) + + # === Chebyshev aggregation === + # loss, loss_type = chebyshev_loss_function( + # [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + # lambdas + # ) + # loss, loss_type = augmentedChebyshev_loss_function( + # [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + # lambdas + # ) + loss, loss_type = linear_loss_function( + [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + lambdas + ) + + # === Hamiltonian (monitor only) === + H_loss = H( + tf.reshape(u_model, shape=[Nt, Nx]), + tf.reshape(u_x, shape=[Nt, Nx]), + dx + ) + + return ( + loss, + loss_type, + pde_loss_H1, + data_fitting_loss_0, + data_fitting_loss_l_r, + H_loss, + ) + + +# Create the PINN model +model = PINNModel() +model.summary() + +epochs = 5000 # 1000, 2000, 5000 +# # Compile the model +# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), +# loss=lambda y_true, y_pred: custom_loss([x_train, t_train, theta_train], model)[1]) + +# Create the optimizer with a smaller learning rate +# learning_rate = 1e-3 # 1e-4 +# learning_rate_type = 'constant' +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([10, 100], [1e-1, 5e-2, 1e-2]) #OK +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([100, 300], [1e-2, 1e-3, 1e-4]) +# learning_rate = tf.keras.optimizers.schedules.PolynomialDecay( +# initial_learning_rate=1e-3, +# decay_steps=epochs, +# end_learning_rate=1e-5, +# power=2., +# cycle=False, +# name='PolynomialDecay' +# ) +# learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( +# initial_learning_rate=1e-3, +# decay_steps=epochs, # 100 +# decay_rate=0.9, # 0.9 +# staircase=False, +# name='ExponentialDecay' +# ) +# learning_rate_type = 'exponentialDecay' +learning_rate = tf.keras.optimizers.schedules.CosineDecay( + initial_learning_rate=1e-4, + decay_steps=1000, + alpha=0.5, + name='CosineDecay', + warmup_target=None, + warmup_steps=100 +) +learning_rate_type = 'cosineDecay' +# param_values = [delta.numpy()] +trainable = model.trainable_variables +if lambdas.trainable: + trainable += [lambdas] + +if cheb_par.trainable: + trainable += [cheb_par] + +# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, amsgrad=True) +# optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True) +optimizer = tf.keras.optimizers.AdamW(learning_rate=learning_rate, beta_1=0.8, beta_2=0.9, epsilon=1e-07) +# optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate, rho=0.9, momentum=0.0, epsilon=1e-07, centered=False) + +# Training loop +losses = [] +pde_losses = [] +data_fitting_losses_0 = [] +data_fitting_losses_l_r = [] +delta_gradients = [] +# S_losses_min = [] +# S_losses_max = [] +H_losses_min = [] +H_losses_max = [] +H_losses_mean = [] +H_losses_std = [] +H_losses_abs_error = [] +H_losses_rel_error = [] +lambdas_values = [] +lambdas_values.append(lambdas.numpy()) +cheb_par_values = [] +cheb_par_values.append(cheb_par.numpy()) + +# Convert data to tensor because tf.GradientTape() can only watch tensor and not numpy arrays +x_train = tf.convert_to_tensor(x_train) +t_train = tf.convert_to_tensor(t_train) +x_grid, t_grid = np.meshgrid(x.flatten(), t.flatten()) +inputs = tf.convert_to_tensor(np.vstack([x_grid.flatten(), t_grid.flatten()]).T) +stop = False +# Start timer +t0 = time() +for epoch in range(epochs): + if not stop: + # print("# STARTING EPOCH", epoch + 1) + + with tf.GradientTape() as tape: + loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss = custom_loss(inputs, model, epoch) + + # print("Computing gradients") + gradients = tape.gradient(loss, trainable) + # print(gradients[-1]) + # print("Applying gradients") + optimizer.apply_gradients(zip(gradients, trainable)) + # print("Appending losses") + losses.append(loss.numpy()) + pde_losses.append(pde_loss.numpy()) + data_fitting_losses_0.append(data_fitting_loss_0.numpy()) + data_fitting_losses_l_r.append(data_fitting_loss_l_r.numpy()) + H_loss_min = tf.reduce_min(H_loss) + H_loss_max = tf.reduce_max(H_loss) + H_losses_min.append(H_loss_min.numpy()) + H_losses_max.append(H_loss_max.numpy()) + H_loss_mean = tf.reduce_mean(H_loss) + H_loss_std = tf.math.reduce_std(H_loss) + H_losses_mean.append(H_loss_mean.numpy()) + H_losses_std.append(H_loss_std.numpy()) + + H0 = H(u_0(x_grid), u_0_x(x_grid), dx) # H0 = H_loss[0].numpy() + Hf = H_loss.numpy() + H_abs_error = tf.abs(Hf - H0) + H_losses_abs_error.append(tf.reduce_max(H_abs_error).numpy()) + H_rel_error = H_abs_error / tf.abs((H0 + 1e-16)) + H_losses_rel_error.append(H_rel_error[-1].numpy()) + + # # Print S_loss, H_loss + # print(f"S_loss at epoch {epoch + 1}: {S_loss.numpy()}") + # print(f"H_loss at epoch {epoch + 1}: {H_loss.numpy()}") + + if len(losses) > 1 and not lambdas.trainable and do_training: + # SoftAdaptive weights update + # num1 = tf.math.exp(tf.experimental.numpy.cbrt(pde_losses[-1] - pde_losses[-2])) + # num2 = tf.math.exp(tf.experimental.numpy.cbrt(data_fitting_losses_0[-1] - data_fitting_losses_0[-2])) + # num3 = tf.math.exp(tf.experimental.numpy.cbrt(data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2])) + num1 = tf.math.exp((pde_losses[-1] - pde_losses[-2])) + num2 = tf.math.exp((data_fitting_losses_0[-1] - data_fitting_losses_0[-2])) + num3 = tf.math.exp((data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2])) + den = num1 + num2 + num3 + + new_lambdas = tf.stack([num1 / den, num2 / den, num3 / den]) + lambdas.assign(new_lambdas) + lambdas_values.append((lambdas).numpy()) + + if cheb_par.trainable: + cheb_par_values.append(cheb_par.numpy()) + + del tape + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.numpy()}") + + # if len(losses) > 2 and np.abs(losses[-1] - losses[-2]) / np.abs(losses[-2]) < 1e-8: + # stop = True + +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean.numpy()}") +print(f"Hamiltonian standard deviation: {H_loss_std.numpy()}") +print(f"Hamiltonian maximum: {H_loss_max.numpy()}") +print(f"Hamiltonian minimum: {H_loss_min.numpy()}") +# print(f"Hamiltonian absolute error: {H_abs_error.numpy()}") +# print(f"Hamiltonian relative error: {H_rel_error.numpy()}") +print(f"Hamitonian relative error: {H_rel_error[-1].numpy()}") +# Print computation time +print('\nComputation time: {} seconds'.format(time() - t0)) + +import pandas as pd + +df = pd.DataFrame() +df['epoch'] = range(1, epochs + 1) + +def generate_save_fig_string(type, epochs, learning_rate_type, loss_type): + """ + Generates a string for saving figures that includes the number of epochs and the type of learning rate. + + Args: + epochs (int): The number of epochs. + learning_rate_type (str): The type of learning rate. + + Returns: + str: The generated string for saving figures. + """ + return f"./results/{type}_epochs_{epochs}_lr_{learning_rate_type}_{loss_type}.png" + +# Plot the loss history +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_fitting_losses_0, label='Initial Conditions Loss') +plt.semilogy(data_fitting_losses_l_r, label='Periodic Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() + +df['total_loss'] = losses +df['pde_loss'] = pde_losses +df['data_fitting_loss_0'] = data_fitting_losses_0 +df['data_fitting_loss_l_r'] = data_fitting_losses_l_r +df['H_loss_min'] = H_losses_min +df['H_loss_max'] = H_losses_max +df['H_loss_mean'] = H_losses_mean +df['H_loss_std'] = H_losses_std +df['H_loss_abs_error'] = H_losses_abs_error +df['H_loss_rel_error'] = H_losses_rel_error +# df['cheb_par'] = cheb_par_values + +df.to_csv('./results/kdv/training_history.csv', index=False) + + +if save_fig: + save_fig_string = generate_save_fig_string('loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Evaluate the function +x_eval = np.linspace(x_train[0].numpy(), x_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +t_eval = np.linspace(t_train[0].numpy(), t_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +inputs_eval = [x_eval, t_eval] + + +# Plot the Hamiltonian over epochs +plt.plot(H_losses_min, label='H_loss_min') +plt.plot(H_losses_max, label='H_loss_max') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Plot the average Hamiltonian over epochs with standard deviation +H_losses_mean = np.array(H_losses_mean) +H_losses_std = np.array(H_losses_std) +H_losses_abs_error = np.array(H_losses_abs_error) +H_losses_rel_error = np.array(H_losses_rel_error) + +plt.plot(H_losses_mean) +plt.fill_between(range(len(H_losses_mean)), H_losses_mean - H_losses_std, H_losses_mean + H_losses_std, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_mean', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Plot the standard deviation of the Hamiltonian over epochs +plt.plot(H_losses_std) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_std', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Plot the absolute error of the Hamiltonian over epochs +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_abs_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Plot the relative error of the Hamiltonian over epochs +plt.plot(H_losses_rel_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_rel_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + + +# Plot the Chebyshev parameter over epochs +if cheb_par.trainable: + plt.plot(tf.sigmoid(cheb_par_values)) + plt.xlabel('Epoch') + plt.ylabel('Chebyshev parameter') + plt.title('Chebyshev parameter over epochs') + plt.grid() + + if save_fig: + save_fig_string = generate_save_fig_string('cheb_par', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +from mpl_toolkits.mplot3d import Axes3D + +# Set up meshgrid +N = 600 +tspace = np.linspace(0, tMax, N + 1) +xspace = np.linspace(xMin, xMax, N + 1) +T, X = np.meshgrid(tspace, xspace) +XTgrid = np.vstack([X.flatten(),T.flatten()]).T + +# Determine predictions of u(t, x) +u_pred = model(tf.cast(XTgrid,DTYPE)) + +# Reshape upred +U = u_pred.numpy().reshape(N+1,N+1) + +# Surface plot of solution u(t,x) +fig = plt.figure(figsize=(9,6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X, T, U, cmap='viridis') +ax.view_init(35,35) +ax.set_xlabel('$x$') +ax.set_ylabel('$t$') +ax.set_zlabel('$u_(x,t)$') +ax.set_title('Solution to KdV equation') +ax.set_box_aspect(None, zoom=0.85) + +if save_fig: + save_fig_string = generate_save_fig_string('sol', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() \ No newline at end of file diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1_pytorch.py b/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1_pytorch.py new file mode 100644 index 0000000..4a028ea --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1_pytorch.py @@ -0,0 +1,573 @@ +import numpy as np +import torch +import torch.nn as nn +from time import time +import matplotlib.pyplot as plt +from humancompatible.train.dual_optim import ALM + +DTYPE = torch.float32 +torch.set_default_dtype(DTYPE) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +type_pde = 1 +if type_pde == 1: + nu, alpha, rho = -0.022**2, -0.5, 0. + xMin, xMax, tMax = 0., 2., 5. +elif type_pde == 2: + nu, alpha, rho = -1., -3., 0. + xMin, xMax, tMax = -20., 20., 100. + +def Nx_from_arch(width, depth, fac=1.5, d_in=2, d_out=1): + Ntheta = (d_in + 1) * width + (depth - 1) * (width * width + width) + d_out * (width + 1) + Ncoll_target = int(Ntheta / fac) + Nx = int(np.sqrt(Ncoll_target)) + Nt = Nx + return Nx, Nt, Ntheta, Ncoll_target + +width, depth = 80, 4 +Nx, Nt, Ntheta, Ncoll = Nx_from_arch(width=width, depth=depth, fac=10.) + +dx = (xMax - xMin) / (Nx - 1) +dt = tMax / (Nt - 1) +h = max(dx, dt) + +lambdas = torch.tensor([1., 1., 1.], dtype=DTYPE, device=device, requires_grad=False) +# do_training = False +cheb_par = torch.tensor(0.5, dtype=DTYPE, device=device, requires_grad=False) + +x = torch.linspace(xMin, xMax, Nx, dtype=DTYPE, device=device).reshape(-1, 1) +t = torch.linspace(0, tMax, Nt, dtype=DTYPE, device=device).reshape(-1, 1) +x_train = x.reshape(-1, 1) +t_train = t.reshape(-1, 1) +# t_grid, x_grid = torch.meshgrid(t.flatten(), x.flatten(), indexing='ij') +x_grid, t_grid = torch.meshgrid(x.flatten(), t.flatten(), indexing='xy') +inputs = torch.stack([x_grid.flatten(), t_grid.flatten()], dim=1) + +save_fig = True + +def u_0(x): + if type_pde == 1: + return torch.cos(np.pi * x) + elif type_pde == 2: + return 6. / (torch.cosh(x)**2) + +def u_0_x(x): + if type_pde == 1: + return -np.pi * torch.sin(np.pi * x) + elif type_pde == 2: + return -12. * torch.sinh(x) / (torch.cosh(x)**3) + +def periodic_bc(model, x, t): + xL = torch.full_like(x, xMin) + xR = torch.full_like(x, xMax) + uL = model(torch.cat([xL, t], dim=1)) + uR = model(torch.cat([xR, t], dim=1)) + return torch.mean((uL - uR)**2) + +def V(u): + return alpha * (u**3) / 3 + (rho * u**2) / 2 + +def kdv_density(u, u_x): + return V(u) - nu * torch.pow(u_x, 2) / 2 + + +def H(u, u_x, dx, density_fn=kdv_density, axis=-1): + """ + Boole’s rule (8th order) along `axis` for a uniform grid with spacing dx. + Requires (N-1) % 4 == 0. Otherwise applies Boole on the largest valid + prefix and trapezoid rule on the remainder. + """ + f = density_fn(u, u_x) # [..., N] + n = f.shape[axis] + + # Normalize negative axis + axis = axis % f.ndim + + def _trap_rem(rem): + """ + rem: contiguous tail segment integrated with trapezoid rule + """ + left = rem.narrow(-1, 0, rem.shape[-1] - 1) + right = rem.narrow(-1, 1, rem.shape[-1] - 1) + return torch.sum(0.5 * (left + right), dim=-1) * dx + + # Degenerate case + if n <= 1: + return torch.sum(f, dim=axis) * dx + + # Largest prefix satisfying (n1 - 1) % 4 == 0 + n1 = n - ((n - 1) % 4) + + # Boole constant + c = (2.0 * dx) / 45.0 + + # Build index slices + idx_prefix = torch.arange(n1, device=f.device) + + f0 = torch.index_select(f, axis, idx_prefix[0::4]) + f1 = torch.index_select(f, axis, idx_prefix[1::4]) + f2 = torch.index_select(f, axis, idx_prefix[2::4]) + f3 = torch.index_select(f, axis, idx_prefix[3::4]) + f4 = torch.index_select(f, axis, idx_prefix[4::4]) + + # Weighted Boole sum + s = 7.0 * torch.sum(f0, dim=axis) + s += 32.0 * torch.sum(f1, dim=axis) + s += 12.0 * torch.sum(f2, dim=axis) + s += 32.0 * torch.sum(f3, dim=axis) + s += 7.0 * torch.sum(f4, dim=axis) + + boole_part = c * s + + # No remainder + if n1 == n: + return boole_part + + # Tail remainder + rem_idx = torch.arange(n1 - 1, n, device=f.device) + rem = torch.index_select(f, axis, rem_idx) + + tail = _trap_rem(rem) + + return boole_part + tail + +def linear_loss_function(tensors, weights): + stacked = torch.stack(tensors) + weights = weights / torch.sum(weights) + loss = torch.sum(weights * stacked) + return loss, 'ls' + +def chebyshev_loss_function(tensors, weights): + stacked = torch.stack(tensors) + loss = torch.max(weights * stacked) + return loss, 'cs' + +def sigmoid_centered(x): + return 2 * torch.sigmoid(x) - 1 + +class PINNModel(nn.Module): + def __init__(self, num_hidden_layers=depth, num_neurons_per_layer=width): + super().__init__() + layers = [] + in_dim = 2 + for _ in range(num_hidden_layers): + layers.append(nn.Linear(in_dim, num_neurons_per_layer)) + layers.append(nn.Tanh()) + in_dim = num_neurons_per_layer + layers.append(nn.Linear(in_dim, 1)) + self.net = nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + + +def lambda_grad(epoch, + start=1000, + lam_max=1e0, + kappa=1e-3): + epoch = torch.as_tensor(epoch, dtype=torch.float32) + return lam_max * ( + 1.0 - torch.exp(-kappa * torch.clamp(epoch - start, min=0.0)) + ) + + +def grad_L2_fft_batch(r, L): + """ + r : shape (Nt, Nx) + returns : shape (Nt,) + """ + r = r.to(torch.complex64) + Nx = r.shape[-1] + + device = r.device + + k_pos = torch.arange(0, Nx // 2 + 1, + dtype=torch.float32, + device=device) + + k_neg = torch.arange(-Nx // 2 + 1, 0, + dtype=torch.float32, + device=device) + + k = torch.cat([k_pos, k_neg], dim=0) + k = (2.0 * np.pi / L) * k + k = k.to(torch.complex64) + + r_hat = torch.fft.fft(r, dim=-1) + + grad_energy = torch.sum(torch.abs(1j * k * r_hat) ** 2, dim=-1) + + dx = L / float(Nx) + + return torch.real(grad_energy) * dx + + +def H1_norm_fft_batch(r, L): + """ + Compute ||r||_{H^1}^2 for each time slice. + + r : shape (Nt, Nx) + returns : shape (Nt,) + """ + r = r.to(torch.complex64) + Nx = r.shape[-1] + + device = r.device + + k_pos = torch.arange(0, Nx // 2 + 1, + dtype=torch.float32, + device=device) + + k_neg = torch.arange(-Nx // 2 + 1, 0, + dtype=torch.float32, + device=device) + + k = torch.cat([k_pos, k_neg], dim=0) + k = (2.0 * np.pi / L) * k + k = k.to(torch.complex64) + + r_hat = torch.fft.fft(r, dim=-1) + + weight = 1.0 + torch.abs(k) ** 2 + + H1_sq = torch.sum(weight * torch.abs(r_hat) ** 2, dim=-1) + + dx = L / float(Nx) + + return torch.real(H1_sq) * dx + + +def custom_loss(inputs, model, epoch): + """ + Assumes the following globals/functions exist: + + alpha, rho, nu + Nt, Nx + dx, xMin, xMax + lambdas + + u_0(...) + periodic_bc(...) + linear_loss_function(...) + H(...) + """ + + x = inputs[:, 0:1].clone().detach().requires_grad_(True) + t = inputs[:, 1:2].clone().detach().requires_grad_(True) + + xt = torch.cat([x, t], dim=1) + + # Forward pass + u_model = model(xt) + + # First derivatives + u_x = torch.autograd.grad( + u_model, + x, + grad_outputs=torch.ones_like(u_model), + create_graph=True, + retain_graph=True, + )[0] + + u_t = torch.autograd.grad( + u_model, + t, + grad_outputs=torch.ones_like(u_model), + create_graph=True, + retain_graph=True, + )[0] + + # Second derivative + u_xx = torch.autograd.grad( + u_x, + x, + grad_outputs=torch.ones_like(u_x), + create_graph=True, + retain_graph=True, + )[0] + + # Third derivative + u_xxx = torch.autograd.grad( + u_xx, + x, + grad_outputs=torch.ones_like(u_xx), + create_graph=True, + retain_graph=True, + )[0] + + # PDE residual + u_squared_x = 2 * u_model * u_x + + r = ( + u_t + - alpha * u_squared_x + - rho * u_x + - nu * u_xxx + ) + + # === PDE residual loss === + pde_loss_L2 = torch.mean(r ** 2) + + r_grid = r.reshape(Nt, Nx) + + L = xMax - xMin + + pde_loss_grad = torch.mean( + grad_L2_fft_batch(r_grid, L) + ) + + # mesh-scaled stabilization parameter + lam = 0.01 * (dx ** 2) * min(1.0, float(epoch) / 1000.0) + + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + # === Initial condition === + ic_mask = torch.where(torch.abs(t) < 1e-6)[0] + + x_ic = x[ic_mask] + + u_ic = u_0(x_ic) + + t_ic = torch.zeros_like(x_ic) + + u_ic_pred = model(torch.cat([x_ic, t_ic], dim=1)) + + data_fitting_loss_0 = torch.mean( + (u_ic_pred - u_ic) ** 2 + ) + + # === Periodic BC === + data_fitting_loss_l_r = periodic_bc(model, x, t) + + # === Aggregated loss === + loss, loss_type = linear_loss_function( + [ + pde_loss_H1, + data_fitting_loss_0, + data_fitting_loss_l_r, + ], + lambdas + ) + + # === Hamiltonian (monitor only) === + # breakpoint() + H_loss = H( + u_model.reshape(Nt, Nx), + u_x.reshape(Nt, Nx), + dx + ) + + return ( + loss, + loss_type, + pde_loss_H1, + data_fitting_loss_0, + data_fitting_loss_l_r, + H_loss, + ) + +def lagrangian_loss(inputs, model, dual_opt): + x, t = inputs[:, 0:1], inputs[:, 1:2] + x.requires_grad_(True) + t.requires_grad_(True) + + u_model = model(torch.cat([x, t], dim=1)) + + u_t = torch.autograd.grad(u_model.sum(), t, create_graph=True)[0] + u_x = torch.autograd.grad(u_model.sum(), x, create_graph=True)[0] + + u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0] + u_xxx = torch.autograd.grad(u_xx.sum(), x, create_graph=True)[0] + + u_squared_x = 2 * u_model * u_x + r = u_t - alpha * u_squared_x - rho * u_x - nu * u_xxx + + pde_loss_L2 = torch.mean(torch.square(r)) + + # constraint: pde_loss_L2 = 0 or <= eps + + ic_mask = torch.abs(t) < 1e-6 + x_ic = x[ic_mask[:, 0]] + u_ic = u_0(x_ic) + t_ic = torch.zeros_like(x_ic) + u_ic_pred = model(torch.cat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = torch.mean((u_ic_pred - u_ic) ** 2) # IC loss + + data_fitting_loss_l_r = periodic_bc(model, x, t) # BC loss + + # ask: what's H_loss here, and should we use it in a constraint + H_loss = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) + + data_fitting_loss = 0.5 * data_fitting_loss_0 + 0.5 * data_fitting_loss_l_r + + lagr = dual_opt.forward_update(data_fitting_loss, pde_loss_L2.unsqueeze(0)) + loss_type = 'ls' + + return lagr, loss_type, pde_loss_L2, data_fitting_loss_0, data_fitting_loss_l_r, H_loss + + + +model = PINNModel().to(device) +epochs = 5000 + +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + +dual_opt = ALM(m=1, lr=1e-3, dual_range=(0.,10.), device=device) + +lr_schedule = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=100, + T_mult=1, + eta_min=0.5*1e-4 +) + +losses, pde_losses, data_losses_0, bc_losses = [], [], [], [] +H_losses_min, H_losses_max, H_losses_mean, H_losses_std = [], [], [], [] +H_losses_abs_error, H_losses_rel_error = [], [] +t0 = time() + +for epoch in range(epochs): + optimizer.zero_grad() + loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = custom_loss(inputs, model, epoch) + loss.backward() + optimizer.step() + lr_schedule.step() + + with torch.no_grad(): + losses.append(loss.item()) + pde_losses.append(pde_loss.item()) + data_losses_0.append(data_loss_0.item()) + bc_losses.append(bc_loss.item()) + + H_loss_min = torch.min(H_loss).item() + H_loss_max = torch.max(H_loss).item() + H_losses_min.append(H_loss_min) + H_losses_max.append(H_loss_max) + H_loss_mean = torch.mean(H_loss).item() + H_loss_std = torch.std(H_loss).item() + H_losses_mean.append(H_loss_mean) + H_losses_std.append(H_loss_std) + + H0 = H(u_0(x_grid.flatten().reshape(-1, 1)).reshape(Nt, Nx), u_0_x(x_grid.flatten().reshape(-1, 1)).reshape(Nt, Nx), dx) + Hf = H_loss.detach() + # breakpoint() + H_abs_error = torch.abs(Hf - H0) + H_losses_abs_error.append(torch.max(H_abs_error).item()) + H_rel_error = H_abs_error / (torch.abs(H0) + 1e-16) + if isinstance(H_rel_error, torch.Tensor): + H_rel_error = H_rel_error.item() if H_rel_error.numel() == 1 else H_rel_error.max().item() + H_losses_rel_error.append(H_rel_error) + + if epoch > 1: + # SoftAdaptive weights update + # num1 = tf.math.exp(tf.experimental.numpy.cbrt(pde_losses[-1] - pde_losses[-2])) + # num2 = tf.math.exp(tf.experimental.numpy.cbrt(data_fitting_losses_0[-1] - data_fitting_losses_0[-2])) + # num3 = tf.math.exp(tf.experimental.numpy.cbrt(data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2])) + num1 = np.exp((pde_losses[-1] - pde_losses[-2])) + num2 = np.exp((data_losses_0[-1] - data_losses_0[-2])) + num3 = np.exp((bc_losses[-1] - bc_losses[-2])) + den = num1 + num2 + num3 + + new_lambdas = torch.tensor([num1 / den, num2 / den, num3 / den]) + lambdas = new_lambdas + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.6e}") + +print(f'\nComputation time: {time() - t0:.2f}s') +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean}") +print(f"Hamiltonian std: {H_loss_std}") +print(f"Hamiltonian max: {H_loss_max}") +print(f"Hamiltonian min: {H_loss_min}") + +plt.figure(figsize=(10, 6)) +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_losses_0, label='Initial Conditions Loss') +plt.semilogy(bc_losses, label='Periodic Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() +plt.savefig('./results/kdv_loss.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_min, label='H_loss_min') +plt.plot(H_losses_max, label='H_loss_max') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() +plt.savefig('./results/kdv_H_loss.png', dpi=300) if save_fig else None +plt.show() + +H_losses_mean_arr = np.array(H_losses_mean) +H_losses_std_arr = np.array(H_losses_std) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_mean_arr) +plt.fill_between(range(len(H_losses_mean_arr)), H_losses_mean_arr - H_losses_std_arr, H_losses_mean_arr + H_losses_std_arr, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +plt.grid() +plt.savefig('./results/kdv_H_loss_mean.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_std_arr) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +plt.grid() +plt.savefig('./results/kdv_H_loss_std.png', dpi=300) if save_fig else None +plt.show() + +H_losses_abs_error = np.array(H_losses_abs_error) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +plt.grid() +plt.savefig('./results/kdv_H_loss_abs_error.png', dpi=300) if save_fig else None +plt.show() + +H_losses_rel_error = np.array(H_losses_rel_error) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_rel_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +plt.grid() +plt.savefig('./results/kdv_H_loss_rel_error.png', dpi=300) if save_fig else None +plt.show() + +N = 600 +tspace = torch.linspace(0, tMax, N + 1, dtype=DTYPE, device=device) +xspace = torch.linspace(xMin, xMax, N + 1, dtype=DTYPE, device=device) +T_grid, X_grid = torch.meshgrid(tspace, xspace, indexing='ij') +XTgrid = torch.stack([X_grid.flatten(), T_grid.flatten()], dim=1) + +with torch.no_grad(): + u_pred = model(XTgrid) +U = u_pred.reshape(N+1, N+1) + +X_np = X_grid.cpu().numpy() +T_np = T_grid.cpu().numpy() +U_np = U.cpu().numpy() + +from mpl_toolkits.mplot3d import Axes3D +fig = plt.figure(figsize=(9, 6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X_np, T_np, U_np, cmap='viridis') +ax.set_xlabel('$x$') +ax.set_ylabel('$t$') +ax.set_zlabel('$u(x,t)$') +ax.set_title('KdV equation') +ax.set_box_aspect(None, zoom=0.85) +plt.savefig('./results/kdv_solution.png', dpi=300) if save_fig else None +plt.show() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a04d84f..dbbf5e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,12 +4,12 @@ build-backend = "setuptools.build_meta" [project] name = "humancompatible-train" -version = "0.3.0" +version = "0.3.2" dependencies = [ "torch", "numpy", ] -requires-python = ">= 3.11, <3.14" +requires-python = ">= 3.11" authors = [ {name = "Andrii Kliachkin", email = "kliacand@fel.cvut.cz"}, {name = "Gilles Bareilles"}, diff --git a/src/humancompatible/__init__.py b/src/humancompatible/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/humancompatible/train/dual_optim/__init__.py b/src/humancompatible/train/dual_optim/__init__.py index 5dafcbd..53d5796 100644 --- a/src/humancompatible/train/dual_optim/__init__.py +++ b/src/humancompatible/train/dual_optim/__init__.py @@ -1,4 +1,5 @@ from .alm import ALM from .ialm import iALM from .pbm import PBM +from .nupi import nuPI from .moreau import MoreauEnvelope diff --git a/src/humancompatible/train/dual_optim/alm.py b/src/humancompatible/train/dual_optim/alm.py index 9cb4a87..c4717eb 100644 --- a/src/humancompatible/train/dual_optim/alm.py +++ b/src/humancompatible/train/dual_optim/alm.py @@ -1,7 +1,8 @@ import torch +import torch.distributed as dist from torch.nn import Parameter from torch.optim import Optimizer -from typing import Any, Tuple +from typing import Any, Optional, Tuple from torch import clamp_, Tensor # cite: Stochastic Smoothed Primal-Dual Algorithms for Nonconvex Optimization with Linear Inequality Constraints @@ -16,42 +17,23 @@ def __init__( init_duals: float | Tensor = None, penalty: float = 1.0, *, - dual_range: Tuple[float, float] = (0.0, 100.0), + dual_range: Tuple[float, float] = (-100.0, 100.0), momentum: float = 0.0, dampening: float = 0.0, + is_ineq: bool = False, + restart: bool = False, ctol: float = 0., device=None, + process_group: Optional[dist.ProcessGroup] = None, ) -> None: - """ - A wrapper over a PyTorch`Optimizer` that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. - - :param m: Number of constraints (determines the number of dual variables to create) - :type m: int - :param lr: Dual variable update rate - :type lr: float - :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. - :type init_duals: float | Tensor - :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` - :type penalty: float - :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. - :type dual_range: Tuple[float, float] - :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. - :type momentum: float - :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. - :type dampening: float - :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. - :type ctol: float - """ if momentum > 0 and dampening == 0: dampening = momentum - self.dual_range = dual_range - self.ctol = ctol - self.penalty = penalty + self.process_group = process_group duals, defaults = _init_constraint_group( - m, lr, momentum, dampening, init_duals, dual_range, device + m, lr, momentum, dampening, init_duals, dual_range, is_ineq, restart, device ) super().__init__(duals, defaults) @@ -66,35 +48,61 @@ def duals(self) -> Tensor: def add_constraint_group( self, - m: int = None, + m: int, lr: float = None, momentum: float = None, dampening: float = None, init_duals: Tensor = None, + dual_range: tuple[float, float] = None, + is_ineq: bool = False, + restart: bool = False, + device = None ) -> None: """ Allows to add a group of dual variables with separate initial values and learning rates. :param m: Size of group (number of dual variables to add) :type m: int - :param lr: Dual variable update rate + :param lr: Dual variable update rate. :type lr: float - :param init_duals: Initial values for the new dual variables + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param init_duals: Initial values for the new dual variables. Defaults to the value set when creating the optimizer. :type init_duals: Tensor + :param dual_range: After each dual update, the dual variables will be clamped to this range. + :type dual_range: Tuple[float, float] + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be relaxed on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param restart: Whether to set the dual variables to zero immediately on strict satisfaction of corresponding constraints. Not recommended for stochastic constraints. + :type restart: bool + + .. note:: + Parameters here will default to values set when initializing the dual optimizer. + """ duals, settings_dict = _init_constraint_group( - m, lr, momentum, dampening, init_duals, self.dual_range + m, lr, momentum, dampening, init_duals, dual_range, is_ineq, restart, device ) param_group_dict = {"params": duals, **settings_dict} self.add_param_group(param_group_dict) def _add_penalty_term(self, lagrangian: Tensor, constraints: Tensor) -> None: """Add penalty term to lagrangian in-place.""" - if self.penalty > 0: + if self.penalty == 0: + return + elif constraints.ndim > 0: + lagrangian.add_( + 0.5 + * self.penalty + * torch.dot(constraints, constraints) + ) + else: lagrangian.add_( 0.5 * self.penalty - * torch.dot(constraints - self.ctol, constraints - self.ctol) + * torch.square(constraints) ) @@ -102,6 +110,13 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: """ Calculates and returns the Augmented Lagrangian. + Computes the augmented Lagrangian:: + + L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2 + + where `loss` is the objective value, `duals_i` are the dual variables, `constraints_i` are constraint values, + `penalty` is the penalty parameter, and the sum is over all constraint groups. + :param loss: Loss (objective function) value :type loss: Tensor :param constraints: Tensor of constraint values @@ -112,11 +127,11 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i in range(len(self.param_groups)): - duals, group_constraints = _process_constraint_group( - self.param_groups[i], i, constraints, self.ctol, self.dual_range, update_duals=False - ) + offset = 0 + for group in self.param_groups: + duals, group_constraints = _process_constraint_group(group, offset, constraints, update_duals=False) lagrangian.add_(duals @ group_constraints) + offset += len(duals) self._add_penalty_term(lagrangian, constraints) return lagrangian @@ -124,20 +139,53 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: def update(self, constraints: Tensor) -> None: """ - Updates the dual variables + Updates the dual variables using constrained gradient ascent with optional momentum. + + For each constraint group, performs the dual variable update. + + First, update the momentum buffer (if momentum > 0):: + + if momentum > 0: + buffer_i = momentum * buffer_i + (1 - dampening) * constraints_i + else: + buffer_i = constraints_i + + Then, update the dual variables with clamping:: + + duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound) + + where `buffer_i` is the momentum buffer, `constraints_i` are constraint values, `duals_i` are dual variables, + and `clamp(x, lb, ub)` projects to the dual range. :param constraints: Tensor of constraint values :type constraints: Tensor """ - for i in range(len(self.param_groups)): - _process_constraint_group( - self.param_groups[i], i, constraints, self.ctol, self.dual_range, update_duals=True - ) + if self.process_group is not None: + with torch.no_grad(): + constraints = constraints.detach().clone() + dist.all_reduce(constraints, op=dist.ReduceOp.AVG, group=self.process_group) + offset = 0 + for group in self.param_groups: + _process_constraint_group(group, offset, constraints, update_duals=True) + offset += len(group["params"][0]) + + step = update # evaluate the Lagrangian and update the dual variables def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: """ - Combines `forward` and `update`; slightly faster. + Combines `forward` and `update`; slightly faster than calling both separately. + + Updates dual variables:: + + duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound) + + Then computes the augmented Lagrangian:: + + L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2 + + + where the momentum buffer is updated as in :meth:`update`. :param loss: Loss (objective function) value :type loss: Tensor @@ -146,23 +194,33 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: :return: Lagrangian :rtype: Tensor """ + if self.process_group is not None: + with torch.no_grad(): + constraints_for_update = constraints.detach().clone() + dist.all_reduce(constraints_for_update, op=dist.ReduceOp.AVG, group=self.process_group) + else: + constraints_for_update = constraints + lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i in range(len(self.param_groups)): - duals, group_constraints = _process_constraint_group( - self.param_groups[i], i, constraints, self.ctol, self.dual_range, update_duals=True - ) + offset = 0 + for group in self.param_groups: + duals, _ = _process_constraint_group(group, offset, constraints_for_update, update_duals=True) + # Always use the original (non-reduced) constraints for the Lagrangian term + # so that autograd can flow through ∂c/∂θ during backward(). + n = len(duals) + group_constraints = constraints[offset : offset + n] if constraints.ndim > 0 else constraints.unsqueeze(0) lagrangian.add_(duals @ group_constraints) + offset += n self._add_penalty_term(lagrangian, constraints) return lagrangian def state_dict(self) -> dict[str, Any]: - + """""" state_dict = super().state_dict() state_dict["state"]["penalty"] = self.penalty - state_dict["state"]["dual_range"] = self.dual_range # save params themselves in state_dict instead of param ID in default PyTorch for id_pg, pg in enumerate(state_dict["param_groups"]): pg["params"] = [ @@ -172,8 +230,9 @@ def state_dict(self) -> dict[str, Any]: return state_dict def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """""" self.penalty = state_dict["state"]["penalty"] - self.dual_range = state_dict["state"]["dual_range"] + # self.dual_range = state_dict["state"]["dual_range"] params = state_dict["param_groups"] self.param_groups = [] for param in params: @@ -182,94 +241,158 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: def _process_constraint_group( group: dict[str, Any], - group_idx: int, + offset: int, constraints: Tensor, - ctol: float, - dual_range: Tuple[float, float], update_duals: bool = False, ) -> Tuple[Tensor, Tensor]: """ Process a single constraint group: extract duals/constraints and optionally update duals. :param group: The constraint group dictionary - :param group_idx: Index of the constraint group + :param offset: Start index of this group's slice within the full constraints tensor :param constraints: Full constraints tensor - :param ctol: Constraint tolerance - :param dual_range: Safeguarding range for dual variables :param update_duals: Whether to update dual variables :return: Tuple of (duals, group_constraints) """ duals = group["params"][0] - group_constraints = ( - constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] - ctol - ) - - if update_duals: - lr = group.get("lr") - momentum = group.get("momentum", 0.0) - dampening = group.get("dampening", 0.0) - momentum_buffer = group["momentum_buffer"] - - with torch.no_grad(): - _update_duals( - duals, group_constraints, lr, momentum, dampening, momentum_buffer - ) - clamp_(duals, min=dual_range[0], max=dual_range[1]) + n = len(duals) + group_constraints = constraints[offset : offset + n] if constraints.ndim > 0 else constraints.unsqueeze(0) + + lr = group.get("lr") + momentum = group.get("momentum", 0.0) + dampening = group.get("dampening", 0.0) + momentum_buffer = group["momentum_buffer"] + dual_lb = group.get("lower_bound") + dual_ub = group.get("upper_bound") + restart = group.get("restart") + + with torch.no_grad(): + if update_duals: + if momentum > 0: + _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) + _update_duals(duals, momentum_buffer if momentum > 0 else group_constraints, lr, restart) + clamp_(duals, min=dual_lb, max=dual_ub) return duals, group_constraints def _init_constraint_group( - m: int = None, - lr: float = None, - momentum: float = None, - dampening: float = None, - init_duals: float | Tensor = None, - dual_range: Tuple[float, float] = None, - device=None, - ): - ## checks ## - if init_duals is None and m is None: - raise ValueError("At least one of`m`,`init_duals` must be set") - - if momentum is not None and (momentum < 0 or momentum > 1): - raise ValueError(f"`momentum`must be within [0,1]; got {momentum}") - - m = m if m is not None else len(init_duals) - - if init_duals is None: # initialize duals if not set or set to scalar - init_duals = ( - torch.zeros(m, requires_grad=False, device=device) + dual_range[0] - ) - elif isinstance(init_duals, float): - init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals - - duals = Parameter(init_duals, requires_grad=False) - - settings_dict = { - "lr": lr, - "momentum": momentum, - "dampening": dampening, - "momentum_buffer": torch.zeros_like( - init_duals, requires_grad=False, device=device - ), - } - settings_dict = {k: v for k, v in settings_dict.items() if v is not None} - - param_group = ([duals], settings_dict) - return param_group - - -def _update_duals( - duals: Tensor, + m: int = None, + lr: float = None, + momentum: float = None, + dampening: float = None, + init_duals: float | Tensor = None, + dual_range: Tuple[float, float] = None, + is_ineq: bool = None, + restart: bool = None, + device = None, +): + ## checks ## + if init_duals is None and m is None: + raise ValueError("At least one of m, init_duals must be set") + + if momentum is not None and (momentum < 0 or momentum > 1): + raise ValueError(f"momentum must be within [0,1]; got {momentum}") + + if not isinstance(is_ineq, bool): + raise ValueError(f"Expected a Boolean value for is_ineq, got {type(is_ineq)}") + + if not isinstance(restart, bool): + raise ValueError(f"Expected a Boolean value for restart, got {type(restart)}") + + m = m if m is not None else len(init_duals) + + if init_duals is None: # initialize duals if not set or set to scalar + init_duals = torch.zeros(m, requires_grad=False, device=device) + elif isinstance(init_duals, float): + init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals + + duals = Parameter(init_duals, requires_grad=False) + + if dual_range is None and not is_ineq: + dual_range = (None, None) + elif dual_range is None and is_ineq: + dual_range = (0, None) + + settings_dict = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "momentum_buffer": torch.zeros_like( + init_duals, requires_grad=False, device=device + ), + "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], + "upper_bound": dual_range[1], + "is_ineq": is_ineq, + "restart": restart + } + settings_dict = {k: v for k, v in settings_dict.items() if v is not None} + + param_group = ([duals], settings_dict) + return param_group + + +def _update_c_buffers( constraints: Tensor, - lr: float, momentum: float, dampening: float, buffer: Tensor, ) -> None: + """Update the constraint buffer with momentum.""" if momentum == 0: buffer = constraints else: buffer.mul_(momentum).add_(constraints, alpha=1 - dampening) + + +def _update_duals( + duals: Tensor, + buffer: Tensor, + lr: float, + restart: bool +) -> None: + """Update duals using the buffered constraint gradients.""" duals.add_(buffer, alpha=lr) + # Set duals to 0 where buffer < 0 + if restart: + duals[buffer < 0] = 0 + + + +ALM.__doc__ = ( + + # \textbf{input}: \gamma \text{ (lr) }, \pmb{\lambda}_t \text{ (dual variables, created by method) }, \\ + # \mathbf{c}(\theta) \text{ (constraints) }, f(\theta) \text{ (objective) }, \rho \text{ (penalty coefficient) } \\ + r""" + A Dual Optimizer that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. Reference: https://doi.org/10.48550/arXiv.2504.07607 + + .. math:: + + \pmb{\lambda}_{t+1} & \leftarrow \pmb{\lambda}_t + \gamma \mathbf{c}_t(\theta_{t}) + + \mathcal{L}_{t+1} & \leftarrow f_t(\theta_{t}) + \pmb{\lambda}_{t+1}^T \mathbf{c}_t(\theta_{t}) + \frac{\rho}{2} \| \mathbf{c}_t(\theta_{t}) \|^2_2 + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param lr: Dual variable update rate. + :type lr: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param restart: Whether to set the dual variables to zero immediately on strict satisfaction of corresponding constraints. Not recommended for stochastic constraints. + :type restart: bool + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float + :param process_group: Distributed process group for DDP. When set, constraint values are averaged across all workers via ``dist.all_reduce`` before each dual update, keeping dual variables consistent across replicas. Defaults to ``None`` (no synchronization). + :type process_group: dist.ProcessGroup, optional + """ +) \ No newline at end of file diff --git a/src/humancompatible/train/dual_optim/ialm.py b/src/humancompatible/train/dual_optim/ialm.py index 82d1965..233ba9c 100644 --- a/src/humancompatible/train/dual_optim/ialm.py +++ b/src/humancompatible/train/dual_optim/ialm.py @@ -12,16 +12,16 @@ class iALM(Optimizer): def __init__( self, m: int = None, - lr: float = 0.01, + beta: float = 1.0, + sigma: float = 1.0, + gamma: float = 1.0, init_duals: float | Tensor = None, penalty: float = 1.0, *, - dual_range: Tuple[float, float] = (0.0, 100.0), + dual_range: Tuple[float, float] = (-100., 100.), momentum: float = 0.0, dampening: float = 0.0, - beta: float = 1.0, - sigma: float = 1.0, - gamma: float = 1.0, + is_ineq: bool = False, ctol: float = 1e-4, device=None, ) -> None: @@ -30,8 +30,12 @@ def __init__( :param m: Number of constraints (determines the number of dual variables to create) :type m: int - :param lr: Dual variable update rate - :type lr: float + :param beta: Dual variable update rate. + :type beta: float + :param sigma: Multiplier for increasing`beta`. + :type sigma: float + :param gamma: Penalty update parameter. + :type gamma: float :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. :type init_duals: float | Tensor :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` @@ -42,12 +46,6 @@ def __init__( :type momentum: float :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. :type dampening: float - :param beta: Dual variable update rate - :type beta: float - :param sigma: Multiplier for increasing`beta`. - :type sigma: float - :param gamma: Penalty update parameter - :type gamma: float :param ctol: Constraint tolerance; value that allows tiny violations of constraints to account for noise. :type ctol: float """ @@ -55,61 +53,20 @@ def __init__( if momentum > 0 and dampening == 0: dampening = momentum - self.dual_range = dual_range + # self.dual_range = dual_range - self.beta = beta + # self.beta = beta self.penalty = penalty - self.gamma = gamma - self.sigma = sigma - self.ctol = ctol + # self.gamma = gamma + # self.sigma = sigma + # self.ctol = ctol - duals, defaults = self._init_constraint_group( - m, lr, momentum, dampening, init_duals, dual_range, device + duals, defaults = _init_constraint_group( + m, beta, sigma, gamma, momentum, dampening, init_duals, dual_range, is_ineq, device ) super().__init__(duals, defaults) - @staticmethod - def _init_constraint_group( - m: int = None, - lr: float = None, - momentum: float = None, - dampening: float = None, - init_duals: float | Tensor = None, - dual_range: Tuple[float, float] = None, - device=None, - ): - ## checks ## - if init_duals is None and m is None: - raise ValueError("At least one of`m`,`init_duals` must be set") - - if momentum is not None and (momentum < 0 or momentum > 1): - raise ValueError(f"`momentum`must be within [0,1]; got {momentum}") - - m = m if m is not None else len(init_duals) - - if init_duals is None: # initialize duals if not set or set to scalar - init_duals = ( - torch.zeros(m, requires_grad=False, device=device) + dual_range[0] - ) - elif isinstance(init_duals, float): - init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals - - duals = Parameter(init_duals, requires_grad=False) - - settings_dict = { - "lr": lr, - "momentum": momentum, - "dampening": dampening, - "momentum_buffer": torch.zeros_like( - init_duals, requires_grad=False, device=device - ), - } - settings_dict = {k: v for k, v in settings_dict.items() if v is not None} - - param_group = ([duals], settings_dict) - return param_group - @property def duals(self) -> Tensor: """ @@ -121,23 +78,40 @@ def duals(self) -> Tensor: def add_constraint_group( self, m: int = None, - lr: float = None, + beta: float = 1.0, + sigma: float = 1.0, + gamma: float = 1.0, momentum: float = None, dampening: float = None, init_duals: Tensor = None, + dual_range: tuple[float, float] = None, + is_ineq: bool = False, + device = None ) -> None: """ Allows to add a group of dual variables with separate initial values and learning rates. :param m: Size of group (number of dual variables to add) :type m: int - :param lr: Dual variable update rate - :type lr: float + :param beta: Dual variable update rate + :type beta: float + :param sigma: Multiplier for increasing `beta` + :type sigma: float + :param gamma: Penalty update parameter + :type gamma: float + :param momentum: Momentum for dual variable updates + :type momentum: float + :param dampening: Dampening for momentum + :type dampening: float :param init_duals: Initial values for the new dual variables :type init_duals: Tensor + :param dual_range: After each dual update, the dual variables will be clamped to this range. + :type dual_range: Tuple[float, float] + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be relaxed on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool """ - duals, settings_dict = self._init_constraint_group( - m, lr, momentum, dampening, init_duals, self.dual_range + duals, settings_dict = _init_constraint_group( + m, beta, sigma, gamma, momentum, dampening, init_duals, dual_range, is_ineq, device ) param_group_dict = {"params": duals, **settings_dict} self.add_param_group(param_group_dict) @@ -155,22 +129,16 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: """ lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i, group in enumerate(self.param_groups): - duals, lr, momentum, dampening, momentum_buffer = ( - group["params"][0], - group["lr"], - group["momentum"], - group["dampening"], - group["momentum_buffer"], - ) - group_constraints = ( - constraints[i * len(duals) : (i + 1) * len(duals)] - self.ctol - ) - lagrangian.add_(duals @ group_constraints) - _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) + offset = 0 + for group in self.param_groups: + duals, beta, group_constraints = _process_constraint_group_ialm(group, offset, constraints, update_duals=False) + lagrangian.add_(duals @ group_constraints) + offset += len(duals) - lagrangian.add_(0.5 * self.beta * torch.dot(constraints, constraints)) + # Use beta from first group for penalty term + beta = self.param_groups[0]["beta"] + lagrangian.add_(0.5 * beta * torch.dot(constraints, constraints)) return lagrangian @@ -181,31 +149,14 @@ def update(self, constraints: Tensor) -> None: :param constraints: Tensor of constraint values :type constraints: Tensor """ - for i, group in enumerate(self.param_groups): - duals, lr, momentum, dampening, momentum_buffer = ( - group["params"][0], - group["lr"], - group["momentum"], - group["dampening"], - group["momentum_buffer"], - ) - group_constraints = ( - constraints[i * len(duals) : (i + 1) * len(duals)] - self.ctol - ) - with torch.no_grad(): - _update_duals( - duals, - group_constraints, - lr, - self.beta, - self.gamma, - momentum, - dampening, - momentum_buffer, - ) - clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) - - self.beta *= self.sigma + offset = 0 + for group in self.param_groups: + _process_constraint_group_ialm(group, offset, constraints, update_duals=True) + offset += len(group["params"][0]) + + # Update beta by sigma for each group + for group in self.param_groups: + group["beta"].mul_(group["sigma"]) # evaluate the Lagrangian and update the dual variables def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: @@ -221,42 +172,22 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: """ lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i, group in enumerate(self.param_groups): - duals, lr, momentum, dampening, momentum_buffer = ( - group["params"][0], - group["lr"], - group["momentum"], - group["dampening"], - group["momentum_buffer"], - ) - group_constraints = ( - constraints[i * len(duals) : (i + 1) * len(duals)] - self.ctol - ) - with torch.no_grad(): - _update_c_buffers( - group_constraints, momentum, dampening, momentum_buffer - ) - _update_duals( - duals, - group_constraints, - lr, - self.beta, - self.gamma, - momentum, - dampening, - momentum_buffer, - ) - clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) + offset = 0 + for group in self.param_groups: + duals, beta, group_constraints = _process_constraint_group_ialm(group, offset, constraints, update_duals=True) lagrangian.add_(duals @ group_constraints) + offset += len(duals) + # Use beta from first group for penalty term + beta = self.param_groups[0]["beta"] lagrangian.add_( - 0.5 - * self.beta - * torch.dot(constraints - self.ctol, constraints - self.ctol) + 0.5 * beta * torch.dot(constraints, constraints) ) - self.beta *= self.sigma + # Update beta by sigma for each group + for group in self.param_groups: + group["beta"].mul_(group["sigma"]) return lagrangian @@ -264,7 +195,7 @@ def state_dict(self) -> dict[str, Any]: state_dict = super().state_dict() state_dict["state"]["penalty"] = self.penalty - state_dict["state"]["dual_range"] = self.dual_range + # state_dict["state"]["dual_range"] = self.dual_range # save params themselves in state_dict instead of param ID in default PyTorch for id_pg, pg in enumerate(state_dict["param_groups"]): pg["params"] = [ @@ -275,13 +206,101 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.penalty = state_dict["state"]["penalty"] - self.dual_range = state_dict["state"]["dual_range"] + # self.dual_range = state_dict["state"]["dual_range"] params = state_dict["param_groups"] self.param_groups = [] for param in params: self.param_groups.append(param) +def _process_constraint_group_ialm( + group: dict[str, Any], + offset: int, + constraints: Tensor, + update_duals: bool = False, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Process a single constraint group: extract parameters and optionally update duals. + + :param group: The constraint group dictionary + :param offset: Start index of this group's slice within the full constraints tensor + :param constraints: Full constraints tensor + :param update_duals: Whether to update dual variables + :return: Tuple of (duals, beta, group_constraints) + """ + duals = group["params"][0] + n = len(duals) + beta = group.get("beta") + gamma = group.get("gamma") + momentum = group.get("momentum", 0.0) + dampening = group.get("dampening", 0.0) + momentum_buffer = group.get("momentum_buffer") + dual_lb = group.get("lower_bound") + dual_ub = group.get("upper_bound") + + group_constraints = constraints[offset : offset + n] if constraints.ndim > 0 else constraints.unsqueeze(0) + + with torch.no_grad(): + if update_duals: + if momentum > 0: + _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) + _update_duals(duals, beta, gamma, momentum_buffer if momentum > 0 else group_constraints) + clamp_(duals, min=dual_lb, max=dual_ub) + + return duals, beta, group_constraints + + +def _init_constraint_group( + m: int = None, + beta: float = None, + sigma: float = None, + gamma: float = None, + momentum: float = None, + dampening: float = None, + init_duals: float | Tensor = None, + dual_range: Tuple[float, float] = None, + is_ineq: bool = None, + device=None, +): + ## checks ## + if init_duals is None and m is None: + raise ValueError("At least one of`m`,`init_duals` must be set") + + if momentum is not None and (momentum < 0 or momentum > 1): + raise ValueError(f"`momentum`must be within [0,1]; got {momentum}") + + m = m if m is not None else len(init_duals) + + if init_duals is None: # initialize duals if not set or set to scalar + init_duals = torch.zeros(m, requires_grad=False, device=device) + elif isinstance(init_duals, float): + init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals + + duals = Parameter(init_duals, requires_grad=False) + + if dual_range is None and not is_ineq: + dual_range = (None, None) + elif dual_range is None and is_ineq: + dual_range = (0, None) + + settings_dict = { + "beta": Parameter(torch.tensor(beta), requires_grad=False), + "sigma": Parameter(torch.tensor(sigma), requires_grad=False), + "gamma": Parameter(torch.tensor(gamma), requires_grad=False), + "momentum": momentum, + "dampening": dampening, + "momentum_buffer": torch.zeros_like( + init_duals, requires_grad=False, device=device + ), + "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], + "upper_bound": dual_range[1], + "is_ineq": is_ineq + } + settings_dict = {k: v for k, v in settings_dict.items() if v is not None} + + param_group = ([duals], settings_dict) + return param_group + def _update_c_buffers( constraints: Tensor, momentum: float, @@ -296,14 +315,49 @@ def _update_c_buffers( def _update_duals( duals: Tensor, - constraints: Tensor, - lr: float, beta: float, gamma: float, - momentum: float, - dampening: float, buffer: Tensor, ) -> None: - update_mult = min(beta, gamma / (buffer @ buffer)) + update_mult = torch.min(beta, gamma / torch.linalg.norm(buffer)) duals.add_(buffer, alpha=update_mult) + + +iALM.__doc__ = ( + + # \textbf{input}: \gamma \text{ (lr) }, \pmb{\lambda}_t \text{ (dual variables, created by method) }, \\ + # \mathbf{c}(\theta) \text{ (constraints) }, f(\theta) \text{ (objective) }, \rho \text{ (penalty coefficient) } \\ + r""" + A Dual Optimizer that works on the dual maximization tasks according to the Augmented Lagrangian rule, with adaptive stepsize based on https://doi.org/10.1007/s10589-023-00521-z, Algorithm 1. Creates and updates dual variables. + + .. math:: + + \pmb{\lambda}_{t+1} & \leftarrow \pmb{\lambda}_t + \min\left\{ \beta_k, \frac{\gamma_k}{\|\mathbf{c}_t(\theta_t)\|} \right\} \mathbf{c}_t(\theta_{t}) + + \mathcal{L}_{t+1} & \leftarrow f_t(\theta_{t}) + \pmb{\lambda}_{t+1}^T \mathbf{c}_t(\theta_{t}) + \frac{\rho}{2} \| \mathbf{c}_t(\theta_{t}) \|^2_2 + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param beta: Dual variable update rate. + :type beta: float + :param sigma: Multiplier for increasing`beta`. + :type sigma: float + :param gamma: Penalty update parameter. + :type gamma: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float + """ +) \ No newline at end of file diff --git a/src/humancompatible/train/dual_optim/nupi.py b/src/humancompatible/train/dual_optim/nupi.py new file mode 100644 index 0000000..dd013d5 --- /dev/null +++ b/src/humancompatible/train/dual_optim/nupi.py @@ -0,0 +1,373 @@ +import torch +from torch.nn import Parameter +from torch.optim import Optimizer +from typing import Any, Tuple +from torch import clamp_, Tensor + +# cite: On PI Controllers for Updating Lagrange Multipliers in Constrained Optimization +# https://arxiv.org/pdf/2406.04558v1 + + +class nuPI(Optimizer): + r""" + A Dual Optimizer that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. Reference: https://doi.org/10.48550/arXiv.2504.07607 + + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param nu: Momentum parameter. + :type nu: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param ki: Momentum parameter. + :type ki: float + :param kp: Momentum parameter. + :type kp: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float + """ + def __init__( + self, + m: int = None, + nu: float = 0.01, + init_duals: float | Tensor = None, + penalty: float = 1.0, + *, + dual_range: Tuple[float, float] = (-100.0, 100.0), + ki: float = 0.0, + kp: float = 0.0, + is_ineq: bool = False, + ctol: float = 0., + device=None, + ) -> None: + + # self.dual_range = dual_range + # self.ctol = ctol + + self.penalty = penalty + self._is_initialized = False + duals, defaults = _init_constraint_group( + m, nu, ki, kp, init_duals, dual_range, is_ineq, device + ) + + super().__init__(duals, defaults) + + @property + def duals(self) -> Tensor: + """ + :return: Dual variables, concatenated into a single tensor. + :rtype: Tensor + """ + return torch.cat([group["params"][0] for group in self.param_groups]) + + def add_constraint_group( + self, + m: int, + nu: float = None, + ki: float = None, + kp: float = None, + init_duals: Tensor = None, + dual_range: tuple[float, float] = None, + is_ineq: bool = False, + device = None + ) -> None: + """ + Allows to add a group of dual variables with separate initial values and learning rates. + + :param m: Size of group (number of dual variables to add) + :type m: int + :param nu: Momentum parameter. + :type nu: float + :param ki: Momentum parameter. + :type ki: float + :param kp: Momentum parameter. + :type kp: float + :param init_duals: Initial values for the new dual variables. Defaults to the value set when creating the optimizer. + :type init_duals: Tensor + :param dual_range: After each dual update, the dual variables will be clamped to this range. + :type dual_range: Tuple[float, float] + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be relaxed on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + + .. note:: + Parameters here will default to values set when initializing the dual optimizer. + + """ + duals, settings_dict = _init_constraint_group( + m, nu, ki, kp, init_duals, dual_range, is_ineq, device + ) + param_group_dict = {"params": duals, **settings_dict} + self.add_param_group(param_group_dict) + + def _add_penalty_term(self, lagrangian: Tensor, constraints: Tensor) -> None: + """Add penalty term to lagrangian in-place.""" + if self.penalty == 0: + return + elif constraints.ndim > 0: + lagrangian.add_( + 0.5 + * self.penalty + * torch.dot(constraints, constraints) + ) + else: + lagrangian.add_( + 0.5 + * self.penalty + * torch.square(constraints) + ) + + + def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: + """ + Calculates and returns the Augmented Lagrangian. + + Computes the augmented Lagrangian:: + + L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2 + + where `loss` is the objective value, `duals_i` are the dual variables, `constraints_i` are constraint values, + `penalty` is the penalty parameter, and the sum is over all constraint groups. + + :param loss: Loss (objective function) value + :type loss: Tensor + :param constraints: Tensor of constraint values + :type constraints: Tensor + :return: Lagrangian + :rtype: Tensor + """ + lagrangian = torch.zeros_like(loss) + lagrangian.add_(loss) + + offset = 0 + for group in self.param_groups: + duals, group_constraints = _process_constraint_group(group, offset, constraints, update_duals=False) + lagrangian.add_(duals @ group_constraints) + offset += len(duals) + + self._add_penalty_term(lagrangian, constraints) + return lagrangian + + + def update(self, constraints: Tensor) -> None: + """ + Updates the dual variables using constrained gradient ascent with optional momentum. + + For each constraint group, performs the dual variable update. + + First, update the momentum buffer (if momentum > 0):: + + if momentum > 0: + buffer_i = momentum * buffer_i + (1 - dampening) * constraints_i + else: + buffer_i = constraints_i + + Then, update the dual variables with clamping:: + + duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound) + + where `buffer_i` is the momentum buffer, `constraints_i` are constraint values, `duals_i` are dual variables, + and `clamp(x, lb, ub)` projects to the dual range. + + :param constraints: Tensor of constraint values + :type constraints: Tensor + """ + offset = 0 + for group in self.param_groups: + _process_constraint_group(group, offset, constraints, update_duals=True) + offset += len(group["params"][0]) + + # evaluate the Lagrangian and update the dual variables + def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: + """ + Combines `forward` and `update`; slightly faster than calling both separately. + + Computes the augmented Lagrangian and updates dual variables in one pass:: + + L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2 + + Then updates dual variables:: + + duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound) + + where the momentum buffer is updated as in :meth:`update`. + + :param loss: Loss (objective function) value + :type loss: Tensor + :param constraints: Tensor of constraint values + :type constraints: Tensor + :return: Lagrangian + :rtype: Tensor + """ + lagrangian = torch.zeros_like(loss) + lagrangian.add_(loss) + + offset = 0 + for group in self.param_groups: + duals, group_constraints = _process_constraint_group(group, offset, constraints, update_duals=True) + lagrangian.add_(duals @ group_constraints) + offset += len(duals) + + self._add_penalty_term(lagrangian, constraints) + return lagrangian + + def state_dict(self) -> dict[str, Any]: + """""" + state_dict = super().state_dict() + state_dict["state"]["penalty"] = self.penalty + # save params themselves in state_dict instead of param ID in default PyTorch + for id_pg, pg in enumerate(state_dict["param_groups"]): + pg["params"] = [ + self.param_groups[id_pg]["params"][param_id] + for param_id in pg["params"] + ] + return state_dict + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """""" + self.penalty = state_dict["state"]["penalty"] + # self.dual_range = state_dict["state"]["dual_range"] + params = state_dict["param_groups"] + self.param_groups = [] + for param in params: + self.param_groups.append(param) + + +def _process_constraint_group( + group: dict[str, Any], + offset: int, + constraints: Tensor, + update_duals: bool = False +) -> Tuple[Tensor, Tensor]: + """ + Process a single constraint group: extract duals/constraints and optionally update duals. + + :param group: The constraint group dictionary + :param offset: Start index of this group's slice within the full constraints tensor + :param constraints: Full constraints tensor + :param update_duals: Whether to update dual variables + :return: Tuple of (duals, group_constraints) + """ + duals = group["params"][0] + n = len(duals) + group_constraints = constraints[offset : offset + n] if constraints.ndim > 0 else constraints.unsqueeze(0) + + nu = group.get("nu") + ki = group.get("ki", 0.0) + kp = group.get("kp", 0.0) + momentum_buffer = group["momentum_buffer"] + dual_lb = group.get("lower_bound") + dual_ub = group.get("upper_bound") + + with torch.no_grad(): + if update_duals: + _update_duals(duals, momentum_buffer, group_constraints, nu, ki, kp) + clamp_(duals, min=dual_lb, max=dual_ub) + _update_c_buffers(group_constraints, nu, momentum_buffer) + + return duals, group_constraints + + +def _init_constraint_group( + m: int = None, + nu: float = None, + ki: float = None, + kp: float = None, + init_duals: float | Tensor = None, + dual_range: Tuple[float, float] = None, + is_ineq: bool = None, + device = None, +): + ## checks ## + if init_duals is None and m is None: + raise ValueError("At least one of m, init_duals must be set") + + if not isinstance(is_ineq, bool): + raise ValueError(f"Expected a Boolean value for is_ineq, got {is_ineq}") + + m = m if m is not None else len(init_duals) + + if init_duals is None: # initialize duals if not set or set to scalar + init_duals = torch.zeros(m, requires_grad=False, device=device) + elif isinstance(init_duals, float): + init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals + + duals = Parameter(init_duals, requires_grad=False) + + if dual_range is None and not is_ineq: + dual_range = (None, None) + elif dual_range is None and is_ineq: + dual_range = (0, None) + + settings_dict = { + "nu": nu, + "ki": ki, + "kp": kp, + "momentum_buffer": torch.zeros_like( + init_duals, requires_grad=False, device=device + ), + "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], + "upper_bound": dual_range[1], + "is_ineq": is_ineq, + "_momentum_initialized": False + } + settings_dict = {k: v for k, v in settings_dict.items() if v is not None} + + param_group = ([duals], settings_dict) + return param_group + + +def _update_c_buffers( + constraints: Tensor, + nu: float, + buffer: Tensor, +) -> None: + """Update the constraint buffer with momentum.""" + buffer.mul_(nu).add_(constraints, alpha=1 - nu) + + +def _update_duals( + duals: Tensor, + buffer: Tensor, + constraints: Tensor, + nu: float, + ki: float, + kp: float +) -> None: + """Update duals using the buffered constraint gradients.""" + # duals.add_(buffer, alpha=lr).add_() + duals.add_( constraints, alpha=ki + kp * (1-nu) ).add_( buffer, alpha = -kp * (1-nu) ) + + +nuPI.__doc__ = ( + + # \textbf{input}: \gamma \text{ (lr) }, \pmb{\lambda}_t \text{ (dual variables, created by method) }, \\ + # \mathbf{c}(\theta) \text{ (constraints) }, f(\theta) \text{ (objective) }, \rho \text{ (penalty coefficient) } \\ + r""" + A Dual Optimizer that works on the dual maximization tasks according to the nuPI Augmented Lagrangian rule, based on https://doi.org/10.48550/arXiv.2406.04558. Creates and updates dual variables. + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param lr: Dual variable update rate. + :type lr: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float + """ +) \ No newline at end of file diff --git a/src/humancompatible/train/dual_optim/pbm.py b/src/humancompatible/train/dual_optim/pbm.py index 07c7665..c263062 100644 --- a/src/humancompatible/train/dual_optim/pbm.py +++ b/src/humancompatible/train/dual_optim/pbm.py @@ -23,28 +23,7 @@ def __init__( device=None, primal_update_process_length=1, # length of the primal update process - if =1, is the original algorithm ) -> None: - """ - A wrapper over a PyTorch`Optimizer` that works on the dual maximization tasks according to the Penalty-Barrier Method rule. Creates and updates dual variables. - - :param m: Number of constraints (determines the number of dual variables to create) - :type m: int - :param penalty_mult: Multiplier for penalty update (K1 or K2). For K2 (adaptive penalty update), values close to 1 correspond to a high "momentum". - :type penalty_mult: float - :param gamma: Multiplier for dual parameter update. Values close to 1 correspond to a high "momentum". - :type gamma: float - :param delta: Violation/satisfaction parameter for penalty update; values > 1 make the penalties decrease faster on violated constraints and vice versa. - :type delta: float - :param penalty_update: Penalty update strategy; must be one of `dimin`,`dimin_dual`,`dimin_adapt`,`const`. Defaults to`dimin_adapt`. - :type penalty_update: str - :param pbf: Penalty-Barrier Function to use. Must be one of `quadratic_logarithmic`,`quadratic_reciprocal` - :type pbf: str - :param init_duals: Initial values for the dual variables. Defaults to dual lower bound for all. - :type init_duals: float | Tensor - :param init_penalties: Initial values for the penalty variables. Defaults to the penalty upper bound for all. - :type init_penalties: float | Tensor - :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. - :type dual_range: Tuple[float, float] - """ + self.dual_range = dual_range self.penalty_range = penalty_range @@ -486,3 +465,39 @@ def _update_penalties_dimin_dual( "adapt": _update_penalties_adapt, "dimin_dual": _update_penalties_dimin_dual, } + + +PBM.__doc__ = ( + + r""" + A Dual Optimizer that works on the dual maximization tasks according to the Penalty-Barrier Method rule. Creates and updates dual variables. Reference: https://doi.org/10.48550/arXiv.2605.18618 + + .. note:: + + Natively, this method only supports inequality constraints (see reference). However, it is easy to transform one into the other: + + .. math:: + g(x) = |h(x)| \leq 0 + + We suggest using a small tolerance parameter on the right-hand side instead of 0. + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param penalty_mult: Multiplier for penalty update (K1 or K2). For K2 (adaptive penalty update), values close to 1 correspond to a high "momentum". + :type penalty_mult: float + :param gamma: Multiplier for dual parameter update. Values close to 1 correspond to a high "momentum". + :type gamma: float + :param delta: Violation/satisfaction parameter for penalty update; values > 1 make the penalties decrease faster on violated constraints and vice versa. + :type delta: float + :param penalty_update: Penalty update strategy; must be one of `dimin`,`dimin_dual`,`dimin_adapt`,`const`. Defaults to`dimin_adapt`. + :type penalty_update: str + :param pbf: Penalty-Barrier Function to use. Must be one of `quadratic_logarithmic`,`quadratic_reciprocal` + :type pbf: str + :param init_duals: Initial values for the dual variables. Defaults to dual lower bound for all. + :type init_duals: float | Tensor + :param init_penalties: Initial values for the penalty variables. Defaults to the penalty upper bound for all. + :type init_penalties: float | Tensor + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + """ +) \ No newline at end of file diff --git a/tests/test_alm.py b/tests/test_alm.py index 4c4b138..4c3a22c 100644 --- a/tests/test_alm.py +++ b/tests/test_alm.py @@ -1,7 +1,9 @@ import unittest +from unittest.mock import patch from torch.optim import Optimizer from humancompatible.train.dual_optim import ALM import torch +import torch.distributed as dist # Unit tests class TestALM(unittest.TestCase): @@ -21,8 +23,6 @@ def setUp(self): def test_alm_initialization(self): # Test initialization with m self.assertEqual(len(self.alm_default.duals), 3) - self.assertEqual(self.alm_default.penalty, 1.0) - self.assertEqual(self.alm_default.dual_range, (0.0, 100.0)) # Test initialization with init_duals init_duals = torch.tensor([1.0, 2.0, 3.0]) @@ -40,6 +40,7 @@ def test_alm_forward(self): def test_alm_update(self): expected_duals = self.alm_default.duals + 0.1 * self.constraints + # breakpoint() self.alm_default.update(self.constraints) self.assertTrue(torch.allclose(self.alm_default.duals, expected_duals)) @@ -80,11 +81,164 @@ def test_alm_dual_range_clamping(self): self.assertTrue(torch.all(self.alm_custom_range.duals <= 1.0) and torch.all(self.alm_custom_range.duals >= -1.0)) + def test_step_is_update_alias(self): + alm = ALM(m=3, lr=0.1, penalty=1.0) + duals_before = alm.duals.clone() + alm.step(self.constraints) + alm2 = ALM(m=3, lr=0.1, penalty=1.0) + alm2.update(self.constraints) + self.assertTrue(torch.allclose(alm.duals, alm2.duals)) + self.assertFalse(torch.allclose(alm.duals, duals_before)) + def test_alm_state_dict(self): alm = ALM(m=3, lr=0.1, penalty=2.0, dual_range=(-1.0, 1.0)) state_dict = alm.state_dict() self.assertEqual(state_dict["state"]["penalty"], 2.0) - self.assertEqual(state_dict["state"]["dual_range"], (-1.0, 1.0)) + +class TestALMFixes(unittest.TestCase): + """Tests for fix 1 (momentum buffer in forward) and fix 2 (multi-group slicing).""" + + def setUp(self): + self.loss = torch.tensor(5.0) + self.constraints = torch.tensor([1.0, 2.0, 3.0, 10.0, 20.0, 30.0]) + + # --- Fix 1: forward() must not advance the momentum buffer --- + + def test_forward_does_not_corrupt_momentum_buffer(self): + # Calling forward() then update() must give the same duals as update() alone. + c = torch.tensor([1.0, 2.0, 3.0]) + alm_direct = ALM(m=3, lr=0.1, penalty=1.0, momentum=0.9) + alm_via_forward = ALM(m=3, lr=0.1, penalty=1.0, momentum=0.9) + + alm_direct.update(c) + + alm_via_forward.forward(self.loss, c) + alm_via_forward.update(c) + + self.assertTrue(torch.allclose(alm_direct.duals, alm_via_forward.duals)) + + def test_forward_update_and_separate_forward_update_agree(self): + # forward_update() and forward() + update() must produce identical duals. + c = torch.tensor([1.0, 2.0, 3.0]) + alm_combined = ALM(m=3, lr=0.1, penalty=1.0, momentum=0.9) + alm_separate = ALM(m=3, lr=0.1, penalty=1.0, momentum=0.9) + + alm_combined.forward_update(self.loss, c) + alm_separate.forward(self.loss, c) + alm_separate.update(c) + + self.assertTrue(torch.allclose(alm_combined.duals, alm_separate.duals)) + + # --- Fix 2: multi-group constraint slicing --- + + def test_multi_group_update_slices_correctly(self): + alm = ALM(m=2, lr=0.1, penalty=1.0) + alm.add_constraint_group(m=3, lr=0.2) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + alm.update(c) + + self.assertTrue(torch.allclose(alm.param_groups[0]["params"][0], 0.1 * c[:2])) + self.assertTrue(torch.allclose(alm.param_groups[1]["params"][0], 0.2 * c[2:])) + + def test_multi_group_forward_lagrangian_correct(self): + init0 = torch.tensor([1.0, 1.0]) + init1 = torch.tensor([1.0, 1.0, 1.0]) + alm = ALM(m=2, lr=0.1, penalty=1.0, init_duals=init0) + alm.add_constraint_group(m=3, lr=0.2, init_duals=init1) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + lagrangian = alm.forward(self.loss, c) + + expected = (self.loss + + init0 @ c[:2] + + init1 @ c[2:] + + 0.5 * alm.penalty * torch.dot(c, c)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + def test_multi_group_forward_update_slices_correctly(self): + alm = ALM(m=2, lr=0.1, penalty=1.0) + alm.add_constraint_group(m=3, lr=0.2) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + alm.forward_update(self.loss, c) + + self.assertTrue(torch.allclose(alm.param_groups[0]["params"][0], 0.1 * c[:2])) + self.assertTrue(torch.allclose(alm.param_groups[1]["params"][0], 0.2 * c[2:])) + + +class TestALMDDP(unittest.TestCase): + def setUp(self): + self.loss = torch.tensor(5.0) + self.constraints = torch.tensor([1.0, 2.0, 3.0]) + self.pg = object() # sentinel; real value only matters to dist.all_reduce + + def test_no_process_group_skips_all_reduce(self): + alm = ALM(m=3, lr=0.1, penalty=1.0) + with patch('torch.distributed.all_reduce') as mock_ar: + alm.update(self.constraints) + alm.forward_update(self.loss, self.constraints) + mock_ar.assert_not_called() + + def test_update_calls_all_reduce_with_correct_args(self): + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce') as mock_ar: + alm.update(self.constraints) + mock_ar.assert_called_once() + _, kwargs = mock_ar.call_args + self.assertEqual(kwargs['op'], dist.ReduceOp.AVG) + self.assertEqual(kwargs['group'], self.pg) + + def test_update_uses_reduced_constraints(self): + # Simulate all_reduce replacing the tensor with worker-averaged values. + reduced = torch.tensor([2.0, 4.0, 6.0]) + def fake_all_reduce(tensor, **kwargs): + tensor.copy_(reduced) + + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce', side_effect=fake_all_reduce): + alm.update(self.constraints) + + self.assertTrue(torch.allclose(alm.duals, 0.1 * reduced)) + + def test_update_does_not_mutate_input(self): + # The all_reduce clone must be a detached copy; original tensor must be untouched. + original = self.constraints.clone() + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce', side_effect=lambda t, **kw: t.fill_(99.0)): + alm.update(self.constraints) + self.assertTrue(torch.allclose(self.constraints, original)) + + def test_forward_update_uses_reduced_constraints_for_dual(self): + reduced = torch.tensor([2.0, 4.0, 6.0]) + def fake_all_reduce(tensor, **kwargs): + tensor.copy_(reduced) + + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce', side_effect=fake_all_reduce): + alm.forward_update(self.loss, self.constraints) + + self.assertTrue(torch.allclose(alm.duals, 0.1 * reduced)) + + def test_forward_update_lagrangian_uses_original_constraints(self): + # Duals are updated with reduced constraints, but the Lagrangian must be + # computed with the original constraints so autograd flows through ∂c/∂θ. + reduced = torch.tensor([2.0, 4.0, 6.0]) + def fake_all_reduce(tensor, **kwargs): + tensor.copy_(reduced) + + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce', side_effect=fake_all_reduce): + lagrangian = alm.forward_update(self.loss, self.constraints) + + updated_duals = 0.1 * reduced + expected = ( + self.loss + + updated_duals @ self.constraints + + 0.5 * alm.penalty * torch.dot(self.constraints, self.constraints) + ) + self.assertTrue(torch.allclose(lagrangian, expected)) + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/tests/test_ialm.py b/tests/test_ialm.py new file mode 100644 index 0000000..2a50d22 --- /dev/null +++ b/tests/test_ialm.py @@ -0,0 +1,95 @@ +import unittest +import torch +from humancompatible.train.dual_optim import iALM + + +class TestiALM(unittest.TestCase): + def setUp(self): + self.loss = torch.tensor(5.0) + self.constraints = torch.tensor([1.0, 2.0, 3.0]) + + def test_ialm_initialization(self): + alm = iALM(m=3, beta=1.0, penalty=1.0) + self.assertEqual(len(alm.duals), 3) + + def test_ialm_forward(self): + alm = iALM(m=3, beta=1.0, penalty=1.0) + lagrangian = alm.forward(self.loss, self.constraints) + beta = alm.param_groups[0]["beta"] + expected = (self.loss + + alm.duals @ self.constraints + + 0.5 * beta * torch.dot(self.constraints, self.constraints)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + def test_ialm_update(self): + # init_duals=zeros so the baseline is 0; duals should increase toward constraints + alm = iALM(m=3, beta=1.0, gamma=1.0, sigma=1.0, penalty=1.0, init_duals=torch.zeros(3)) + duals_before = alm.duals.clone() + alm.update(self.constraints) + self.assertTrue(torch.all(alm.duals > duals_before)) + + +class TestiALMFixes(unittest.TestCase): + """Tests for fix 1 (momentum buffer in forward) and fix 2 (multi-group slicing).""" + + def setUp(self): + self.loss = torch.tensor(5.0) + + # --- Fix 1: forward() must not advance the momentum buffer --- + + def test_forward_does_not_corrupt_momentum_buffer(self): + # Calling forward() then update() must give the same duals as update() alone. + c = torch.tensor([1.0, 2.0, 3.0]) + alm_direct = iALM(m=3, beta=1.0, gamma=1e6, sigma=1.0, momentum=0.9) + alm_via_forward = iALM(m=3, beta=1.0, gamma=1e6, sigma=1.0, momentum=0.9) + + alm_direct.update(c) + + alm_via_forward.forward(self.loss, c) + alm_via_forward.update(c) + + self.assertTrue(torch.allclose(alm_direct.duals, alm_via_forward.duals)) + + def test_forward_update_and_separate_forward_update_agree(self): + c = torch.tensor([1.0, 2.0, 3.0]) + alm_combined = iALM(m=3, beta=1.0, gamma=1e6, sigma=1.0, momentum=0.9) + alm_separate = iALM(m=3, beta=1.0, gamma=1e6, sigma=1.0, momentum=0.9) + + alm_combined.forward_update(self.loss, c) + alm_separate.forward(self.loss, c) + alm_separate.update(c) + + self.assertTrue(torch.allclose(alm_combined.duals, alm_separate.duals)) + + # --- Fix 2: multi-group constraint slicing --- + + def test_multi_group_update_slices_correctly(self): + alm = iALM(m=2, beta=1.0, gamma=1e6, sigma=1.0, init_duals=torch.zeros(2)) + alm.add_constraint_group(m=3, beta=1.0, gamma=1e6, sigma=1.0, init_duals=torch.zeros(3)) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + alm.update(c) + + # With large gamma step ≈ 1.0 for both groups, so duals ≈ c_slice + self.assertTrue(torch.allclose(alm.param_groups[0]["params"][0], c[:2], atol=1e-4)) + self.assertTrue(torch.allclose(alm.param_groups[1]["params"][0], c[2:], atol=1e-4)) + + def test_multi_group_forward_lagrangian_correct(self): + init0 = torch.tensor([1.0, 1.0]) + init1 = torch.tensor([1.0, 1.0, 1.0]) + alm = iALM(m=2, beta=1.0, gamma=1.0, sigma=1.0, init_duals=init0) + alm.add_constraint_group(m=3, beta=1.0, gamma=1.0, sigma=1.0, init_duals=init1) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + lagrangian = alm.forward(self.loss, c) + + beta = alm.param_groups[0]["beta"] + expected = (self.loss + + init0 @ c[:2] + + init1 @ c[2:] + + 0.5 * beta * torch.dot(c, c)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_nupi.py b/tests/test_nupi.py new file mode 100644 index 0000000..c70a17b --- /dev/null +++ b/tests/test_nupi.py @@ -0,0 +1,104 @@ +import unittest +import torch +from humancompatible.train.dual_optim import nuPI + + +class TestnuPI(unittest.TestCase): + def setUp(self): + self.loss = torch.tensor(5.0) + self.constraints = torch.tensor([1.0, 2.0, 3.0]) + + def test_nupi_initialization(self): + opt = nuPI(m=3, nu=0.9, ki=0.01, kp=0.01, penalty=1.0) + self.assertEqual(len(opt.duals), 3) + + def test_nupi_forward(self): + opt = nuPI(m=3, nu=0.9, ki=0.01, kp=0.01, penalty=1.0) + lagrangian = opt.forward(self.loss, self.constraints) + expected = (self.loss + + opt.duals @ self.constraints + + 0.5 * opt.penalty * torch.dot(self.constraints, self.constraints)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + def test_nupi_update(self): + # With zero buffer (initial state) and kp=0, update is purely integral: λ += ki * c + opt = nuPI(m=3, nu=0.9, ki=0.1, kp=0.0, penalty=1.0) + opt.update(self.constraints) + self.assertTrue(torch.allclose(opt.duals, 0.1 * self.constraints)) + + +class TestnuPIFixes(unittest.TestCase): + """Tests for fix 1 (buffer in forward) and fix 2 (multi-group slicing).""" + + def setUp(self): + self.loss = torch.tensor(5.0) + + # --- Fix 1: forward() must not advance the EMA buffer --- + # nuPI's buffer is unconditionally updated in the original code, making this + # the most severe instance of the bug: it fires even without any momentum setting. + + def test_forward_does_not_corrupt_ema_buffer(self): + # Calling forward() then update() must give the same duals as update() alone. + c = torch.tensor([1.0, 2.0, 3.0]) + opt_direct = nuPI(m=3, nu=0.9, ki=0.01, kp=0.05, penalty=1.0) + opt_via_forward = nuPI(m=3, nu=0.9, ki=0.01, kp=0.05, penalty=1.0) + + opt_direct.update(c) + + opt_via_forward.forward(self.loss, c) + opt_via_forward.update(c) + + self.assertTrue(torch.allclose(opt_direct.duals, opt_via_forward.duals)) + + def test_forward_update_and_separate_forward_update_agree(self): + c = torch.tensor([1.0, 2.0, 3.0]) + opt_combined = nuPI(m=3, nu=0.9, ki=0.01, kp=0.05, penalty=1.0) + opt_separate = nuPI(m=3, nu=0.9, ki=0.01, kp=0.05, penalty=1.0) + + opt_combined.forward_update(self.loss, c) + opt_separate.forward(self.loss, c) + opt_separate.update(c) + + self.assertTrue(torch.allclose(opt_combined.duals, opt_separate.duals)) + + # --- Fix 2: multi-group constraint slicing --- + + def test_multi_group_update_slices_correctly(self): + # kp=0 so update is purely λ += ki * c, easy to verify + opt = nuPI(m=2, nu=0.9, ki=0.1, kp=0.0, penalty=1.0) + opt.add_constraint_group(m=3, nu=0.9, ki=0.2, kp=0.0) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + opt.update(c) + + self.assertTrue(torch.allclose(opt.param_groups[0]["params"][0], 0.1 * c[:2])) + self.assertTrue(torch.allclose(opt.param_groups[1]["params"][0], 0.2 * c[2:])) + + def test_multi_group_forward_lagrangian_correct(self): + init0 = torch.tensor([1.0, 1.0]) + init1 = torch.tensor([1.0, 1.0, 1.0]) + opt = nuPI(m=2, nu=0.9, ki=0.1, kp=0.0, penalty=1.0, init_duals=init0) + opt.add_constraint_group(m=3, nu=0.9, ki=0.2, kp=0.0, init_duals=init1) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + lagrangian = opt.forward(self.loss, c) + + expected = (self.loss + + init0 @ c[:2] + + init1 @ c[2:] + + 0.5 * opt.penalty * torch.dot(c, c)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + def test_multi_group_forward_update_slices_correctly(self): + opt = nuPI(m=2, nu=0.9, ki=0.1, kp=0.0, penalty=1.0) + opt.add_constraint_group(m=3, nu=0.9, ki=0.2, kp=0.0) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + opt.forward_update(self.loss, c) + + self.assertTrue(torch.allclose(opt.param_groups[0]["params"][0], 0.1 * c[:2])) + self.assertTrue(torch.allclose(opt.param_groups[1]["params"][0], 0.2 * c[2:])) + + +if __name__ == "__main__": + unittest.main()