From 75be7e395a7621d24fb47417425aebd67fe6b3b7 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Mon, 25 May 2026 13:52:32 +0200 Subject: [PATCH 01/30] hamiltonian pinn experiments --- ...structurePreservingPINN_2D_trainableACS.py | 758 +++++++++++++++++ ...ePreservingPINN_2D_trainableACS_pytorch.py | 279 +++++++ ...ePreservingPINN_CamassaHolm_trainableH1.py | 729 +++++++++++++++++ ...ingPINN_CamassaHolm_trainableH1_pytorch.py | 389 +++++++++ ...structurePreservingPINN_kdv_trainableH1.py | 769 ++++++++++++++++++ ...ePreservingPINN_kdv_trainableH1_pytorch.py | 573 +++++++++++++ 6 files changed, 3497 insertions(+) create mode 100644 experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS.py create mode 100644 experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py create mode 100644 experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1.py create mode 100644 experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py create mode 100644 experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1.py create mode 100644 experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1_pytorch.py diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS.py b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS.py new file mode 100644 index 0000000..82beaa6 --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS.py @@ -0,0 +1,758 @@ +import numpy as np +from sympy import N +import tensorflow as tf +from time import time +import matplotlib.pyplot as plt + +# TF_USE_LEGACY_KERAS=True + +DTYPE = np.float32 +Nx = 10 #100 +Ny = 10 #100 +Nt = 10 #50 +N_collocation = Nx*Ny*Nt +d_in = 3 + +xMin = 0.0 +xMax = 8.0 +yMin = 0.0 +yMax = 8.0 +tMax = 50. +d_in = 2 + +def choose_width_depth(N_collocation=N_collocation, overparam_factor=3.0, d_in=d_in, d_out=1): + """ + Returns a single (width, depth) pair for a mildly overparameterized PINN. + + N_collocation : Nx * Ny * Nt + overparam_factor : how many times larger the model than data (default 3x) + """ + + N_target = int(overparam_factor * N_collocation) + + # For 3rd-order PDEs like ZK: + # depth 5–7 is stable. We fix depth=6 (good compromise). + depth = 8 + + # Parameter formula: + # N = (depth-1) w^2 + (d_in + depth + d_out) w + d_out + a = depth - 1 + b = d_in + depth + d_out + c = d_out - N_target + + # Solve quadratic for width + width = int((-b + np.sqrt(b*b - 4*a*c)) / (2*a)) + + return width, depth + +width, depth = choose_width_depth() +dx = (xMax - xMin) / (Nx - 1) +dy = (yMax - yMin) / (Ny - 1) +dt = tMax / (Nt - 1) +h = np.max([dx, dy, dt]) + +lambdas = [1., 1., 1.] #[1.7, 0.2, 1.4] #for cs #[1., 0.5, 1.] # [1.7, 0.2, 1.4] for cs # +lambdas = tf.Variable(lambdas, trainable=False, name='lambdas', dtype=DTYPE) + +cheb_par = tf.Variable(0.5, trainable=True, name='cheb_par', dtype=DTYPE) + +x = np.linspace(xMin, xMax, Nx).reshape((-1, 1)).astype(DTYPE) +y = np.linspace(yMin, yMax, Ny).reshape((-1, 1)).astype(DTYPE) +t = np.linspace(0, tMax, Nt).reshape((-1, 1)).astype(DTYPE) +x_grid, y_grid, t_grid = np.meshgrid(x, y, t, indexing='ij') +x_train = x_grid.flatten(); x_train = tf.convert_to_tensor(x_train); x_train = tf.expand_dims(x_train, axis=-1) +y_train = y_grid.flatten(); y_train = tf.convert_to_tensor(y_train); y_train = tf.expand_dims(y_train, axis=-1) +t_train = t_grid.flatten(); t_train = tf.convert_to_tensor(t_train); t_train = tf.expand_dims(t_train, axis=-1) +xyt_train = tf.concat([x_train, y_train, t_train], axis=-1) + +save_fig = True + +# Define the initial condition +def u_0(x, y): + ##1 + epsilon = 0.01 + theta = 0. + y1 = 0. + y2 = 0. + c1 = 0.45 + c2 = 0.25 + x1 = 2.5 + x2 = 3.3 + out = 3*c1/(tf.math.cosh(0.5*tf.sqrt(c1/epsilon)*((x-x1)*tf.math.cos(theta) + (y-y1)*tf.math.sin(theta))))**2 + + 3*c2/(tf.math.cosh(0.5*tf.sqrt(c2/epsilon)*((x-x2)*tf.math.cos(theta) + (y-y2)*tf.math.sin(theta))))**2 + ##2 + # epsilon = 0.01 + # theta = 0. + # y1 = 4. + # c1 = 1. + # x1 = 2.5 + # out = 3*c1/(tf.math.cosh(0.5*tf.sqrt(c1/epsilon)*((x-x1)*tf.math.cos(theta) + (y-y1)*tf.math.sin(theta))))**2 + + return out + # mpmath for sech + +# def periodic_boundary_conditions(model, Nbc=2000): +# x = tf.random.uniform((Nbc,1), xMin, xMax) +# y = tf.random.uniform((Nbc,1), yMin, yMax) +# t = tf.random.uniform((Nbc,1), 0, tMax) + +# xL = tf.ones_like(x)*xMin; xR = tf.ones_like(x)*xMax +# yL = tf.ones_like(y)*yMin; yR = tf.ones_like(y)*yMax + +# uLx = model(tf.concat([xL,y,t],1)) +# uRx = model(tf.concat([xR,y,t],1)) +# uLy = model(tf.concat([x,yL,t],1)) +# uRy = model(tf.concat([x,yR,t],1)) + +# return tf.reduce_mean((uLx-uRx)**2 + (uLy-uRy)**2) + + +def periodic_boundary_conditions(model, Nbc=2000): + + # Random boundary sampling (correct choice) + x = tf.random.uniform((Nbc,1), xMin, xMax) + y = tf.random.uniform((Nbc,1), yMin, yMax) + t = tf.random.uniform((Nbc,1), 0.0, tMax) + + xL = tf.ones_like(x) * xMin + xR = tf.ones_like(x) * xMax + yL = tf.ones_like(y) * yMin + yR = tf.ones_like(y) * yMax + + with tf.GradientTape(persistent=True) as tape: + tape.watch([xL, xR, yL, yR]) + + uLx = model(tf.concat([xL, y, t], 1)) + uRx = model(tf.concat([xR, y, t], 1)) + + uLy = model(tf.concat([x, yL, t], 1)) + uRy = model(tf.concat([x, yR, t], 1)) + + # First derivatives + uxL = tape.gradient(uLx, xL) + uxR = tape.gradient(uRx, xR) + + uyL = tape.gradient(uLy, yL) + uyR = tape.gradient(uRy, yR) + + del tape + + # Enforce periodicity of values AND derivatives + loss = tf.reduce_mean( + (uLx - uRx)**2 + + (uLy - uRy)**2 + + (uxL - uxR)**2 + + (uyL - uyR)**2 + ) + + return loss + + + +def H(u, u_x, u_y): + return tf.reduce_sum((tf.pow(u_x,2) + tf.pow(u_y,2))/2.0-tf.pow(u,3)/6.0, axis=[0,1]) * dx*dy + + +def linear_loss_function(tensors, weights): + """ + Computes the sum of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the sum. + + Returns: + tf.Tensor: The sum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) # shape (n_losses,) + weights = weights / tf.reduce_sum(weights) + loss = tf.reduce_sum(weights * stacked) + loss_type = 'ls' + return loss, loss_type + + +def chebyshev_loss_function(tensors, weights): + """ + Computes the max of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The maximum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + loss = tf.reduce_max(weights*stacked) + loss_type = 'cs' + return loss, loss_type + + +def smooth_chebyshev_loss_function(mu, tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + exp_sum = tf.reduce_sum(tf.math.exp(stacked/mu), axis=0) + loss = mu*tf.math.log(exp_sum) + loss_type = 'scs' + return loss, loss_type + + +def augmentedChebyshev_loss_function(tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + weights (list of tf.Tensor): List of weights for each tensor. + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + loss_type = 'acs' + par = tf.sigmoid(cheb_par) # par is between 0 and 1 + return par*chebyshev_loss_function(tensors, weights)[0] + (1-par)*linear_loss_function(tensors, weights)[0], loss_type + + +class FourierFeatures(tf.keras.layers.Layer): + def __init__(self, n_modes=5): + super().__init__() + self.n_modes = n_modes + + def call(self, inputs): + x = inputs[:, 0:1] + y = inputs[:, 1:2] + t = inputs[:, 2:3] + + features = [t] + + for k in range(1, self.n_modes + 1): + features.append(tf.sin(2*np.pi*k*(x - xMin)/(xMax-xMin))) + features.append(tf.cos(2*np.pi*k*(x - xMin)/(xMax-xMin))) + features.append(tf.sin(2*np.pi*k*(y - yMin)/(yMax-yMin))) + features.append(tf.cos(2*np.pi*k*(y - yMin)/(yMax-yMin))) + + return tf.concat(features, axis=1) + +def PINNModel(num_hidden_layers=depth, num_neurons_per_layer=width): # 8,80 OK (# 8,40 # 10,40) + xyt_input = tf.keras.Input(shape=(3,)) + output_u = FourierFeatures(n_modes=4)(xyt_input) + for _ in range(num_hidden_layers): + output_u = tf.keras.layers.Dense(num_neurons_per_layer, + activation='tanh', # tanh + kernel_initializer='glorot_uniform', # glorot_normal + )(output_u) + + output_u = tf.keras.layers.Dense(units=1, + activation='linear', # mish + kernel_initializer='glorot_uniform', # glorot_normal + )(output_u) + + return tf.keras.Model(inputs=xyt_input, outputs=output_u) #tf.keras.Model(inputs=[x_input, t_input], outputs=output_u) + + +# def PINNModel(num_hidden_layers=depth, num_neurons_per_layer=width): # 8,80 OK (# 8,40 # 10,40) +# xyt_input = tf.keras.Input(shape=(3,)) +# output_u = xyt_input +# for _ in range(num_hidden_layers): +# output_u = tf.keras.layers.Dense(num_neurons_per_layer, +# activation='tanh', # tanh +# kernel_initializer='glorot_uniform', # glorot_normal +# )(output_u) + +# output_u = tf.keras.layers.Dense(units=1, +# activation='linear', # mish +# kernel_initializer='glorot_uniform', # glorot_normal +# )(output_u) + +# # Define the initial condition +# # x_input = tf.reshape(xt_input[:, 0], shape=[-1, 1]) +# # t_input = tf.reshape(xt_input[:, 1], shape=[-1, 1]) +# # initial_u = u_0(x_input) +# # output_u = tf.where(tf.equal(t_input, 0), initial_u, output_u) + +# return tf.keras.Model(inputs=xyt_input, outputs=output_u) #tf.keras.Model(inputs=[x_input, t_input], outputs=output_u) + + +@tf.function +def custom_loss(inputs, model): + xyt = inputs + x, y, t = xyt[:, 0:1], xyt[:, 1:2], xyt[:, 2:3] + # zeros = tf.zeros_like(x) + + with tf.GradientTape(persistent=True) as tape: + tape.watch(t) + tape.watch(x) + tape.watch(y) + with tf.GradientTape(persistent=True) as tape2: + tape2.watch(x) + tape2.watch(y) + with tf.GradientTape(persistent=True) as tape3: + tape3.watch(t) + tape3.watch(x) + tape3.watch(y) + u_model = model(tf.concat([x,y,t], axis=1)) + u_x = tape3.gradient(u_model, x) + u_y = tape3.gradient(u_model, y) + u_t = tape3.gradient(u_model, t) + u_xx = tape2.gradient(u_x, x) + u_xy = tape2.gradient(u_x, y) + u_xxx = tape.gradient(u_xx, x) + u_xyy = tape.gradient(u_xy, y) + del tape, tape2, tape3 + + + # v = -nu*u_x + # phi_t = Vprime(u_model) - nu*u_xx - Vprime(u_model_0) + nu*u_0_xx + # w = -nu * u_xx + phi_t/2. - Vprime(u_model) + + # Compute the components of loss function + pde_loss = tf.reduce_mean((u_t + u_model * u_x + u_xxx + u_xyy) ** 2) + + # x_ic = tf.random.uniform((Nx*Ny,1), xMin, xMax) + # y_ic = tf.random.uniform((Nx*Ny,1), yMin, yMax) + x_ic = tf.expand_dims(tf.linspace(xMin, xMax, Nx*Ny), axis=-1) # For grid sampling + y_ic = tf.expand_dims(tf.linspace(yMin, yMax, Nx*Ny), axis=-1) # For grid sampling + t_ic = tf.zeros_like(x_ic) + u_ic = u_0(x_ic, y_ic) # Initial condition + t_ic = tf.zeros_like(x_ic) # t=0 for initial condition + u_ic_pred = model(tf.concat([x_ic, y_ic, t_ic], axis=1)) # Predicted initial condition + data_fitting_loss_0 = tf.reduce_mean((u_ic_pred - u_ic) ** 2) + data_fitting_loss_l_r = periodic_boundary_conditions(model) + + # Combine the components of the loss functions + # loss, loss_type = linear_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], tf.exp(lambdas)) + # loss, loss_type = linear_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + # loss, loss_type = chebyshev_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], tf.exp(lambdas)) + # loss, loss_type = chebyshev_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + # loss, loss_type = smooth_chebyshev_loss_function(.1, [pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + loss, loss_type = augmentedChebyshev_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + + # S_loss = S(u_model, v, w) + H_loss = H(tf.reshape(u_model, shape=[Nx, Ny, Nt]), tf.reshape(u_x, shape=[Nx, Ny, Nt]), tf.reshape(u_y, shape=[Nx, Ny, Nt])) + # beta = 1e-3 + # data_fitting_loss = loss = beta*tf.math.log(tf.math.exp(data_fitting_loss_weight_0 * data_fitting_loss_0 / beta) + # + tf.math.exp(data_fitting_loss_weight_l * data_fitting_loss_l / beta) + # + tf.math.exp(data_fitting_loss_weight_r * data_fitting_loss_r / beta)) + # loss = beta*tf.math.log(tf.math.exp(pde_loss_weight * pde_loss / beta) + # + tf.math.exp(data_fitting_loss_weight_0 * data_fitting_loss_0 / beta) + # + tf.math.exp(data_fitting_loss_weight_l * data_fitting_loss_l / beta) + # + tf.math.exp(data_fitting_loss_weight_r * data_fitting_loss_r / beta)) + # data_fitting_loss = tf.math.reduce_max(tf.constant([data_fitting_loss_weight_0 * data_fitting_loss_0, + # data_fitting_loss_weight_l * data_fitting_loss_l, + # data_fitting_loss_weight_r * data_fitting_loss_r])) + # loss = tf.math.reduce_max(tf.constant([pde_loss_weight * pde_loss, + # data_fitting_loss_weight_0 * data_fitting_loss_0, + # data_fitting_loss_weight_l * data_fitting_loss_l, + # data_fitting_loss_weight_r * data_fitting_loss_r])) + + return loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss#, S_loss + + +# Create the PINN model +model = PINNModel() +model.summary() + +epochs = 500 # 5000 # 1000 +# # Compile the model +# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), +# loss=lambda y_true, y_pred: custom_loss([x_train, t_train, theta_train], model)[1]) + +# Create the optimizer with a smaller learning rate +# learning_rate = 1e-3 # 1e-4 +# learning_rate_type = 'constant' +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([10, 100], [1e-1, 5e-2, 1e-2]) #OK +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([100, 300], [1e-2, 1e-3, 1e-4]) +learning_rate = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=1e-2, + decay_steps=epochs, + end_learning_rate=1e-4, + power=3., + cycle=False, + name= 'PolynomialDecay' +) +# learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( +# initial_learning_rate=1e-3, +# decay_steps=50, # 100 +# decay_rate=0.9, +# staircase=False, +# name='ExponentialDecay' +# ) +# learning_rate = tf.keras.optimizers.schedules.CosineDecay( +# initial_learning_rate=1e-3, +# decay_steps=1000, +# alpha=0.0, +# warmup_target=None, +# warmup_steps=0, +# name='CosineDecay' +# ) +learning_rate_type = learning_rate.name + +trainable = model.trainable_variables +if lambdas.trainable: + trainable += [lambdas] + +if cheb_par.trainable: + trainable += [cheb_par] + +# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, amsgrad=True) +# optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True) +optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate, beta_1=0.8, beta_2=0.9, epsilon=1e-07) +# optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate, rho=0.9, momentum=0.0, epsilon=1e-07, centered=False) + +# Training loop +losses = [] +pde_losses = [] +data_fitting_losses_0 = [] +data_fitting_losses_l_r = [] +delta_gradients = [] +# S_losses_min = [] +# S_losses_max = [] +H_losses_min = [] +H_losses_max = [] +H_losses_mean = [] +H_losses_std = [] +H_losses_abs_error = [] +H_losses_rel_error = [] +lambdas_values = [] +lambdas_values.append(lambdas.numpy()) +cheb_par_values = [] +cheb_par_values.append(cheb_par.numpy()) + + +# Convert data to tensor because tf.GradientTape() can only watch tensor and not numpy arrays +inputs = xyt_train +stop = False +# Start timer +t0 = time() +for epoch in range(epochs): + if not stop: + # print("# STARTING EPOCH", epoch + 1) + + # Create a LearningRateScheduler to update the learning rate + # current_lr = scheduler(epoch, learning_rate) + # tf.keras.backend.set_value(optimizer.lr, current_lr) + + with tf.GradientTape() as tape: + loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss = custom_loss(inputs, model) + + # print("Computing gradients") + gradients = tape.gradient(loss, trainable) + # print(gradients[-1]) + # print("Applying gradients") + optimizer.apply_gradients(zip(gradients, trainable)) + # print("Appending losses") + losses.append(loss.numpy()) + pde_losses.append(pde_loss.numpy()) + data_fitting_losses_0.append(data_fitting_loss_0.numpy()) + data_fitting_losses_l_r.append(data_fitting_loss_l_r.numpy()) + # param_values.append((trainable[-1]).numpy()) + # delta_gradients.append((gradients[-1]).numpy()) + # S_loss_min = tf.reduce_min(S_loss) + # S_loss_max = tf.reduce_max(S_loss) + # S_losses_min.append(S_loss_min.numpy()) + # S_losses_max.append(S_loss_max.numpy()) + H_loss_min = tf.reduce_min(H_loss) + H_loss_max = tf.reduce_max(H_loss) + H_losses_min.append(H_loss_min.numpy()) + H_losses_max.append(H_loss_max.numpy()) + H_loss_mean = tf.reduce_mean(H_loss) + H_loss_std = tf.math.reduce_std(H_loss) + H_losses_mean.append(H_loss_mean.numpy()) + H_losses_std.append(H_loss_std.numpy()) + # lambdas_values.append((trainable[-1]).numpy()) + + H0 = H_loss[0].numpy() + Hf = H_loss[-1].numpy() + H_abs_error = tf.abs(Hf - H0) + H_losses_abs_error.append(H_abs_error.numpy()) + H_rel_error = H_abs_error / tf.abs((H0 + 1e-16)) + H_losses_rel_error.append(H_rel_error.numpy()) + + # # Print S_loss, H_loss + # print(f"S_loss at epoch {epoch + 1}: {S_loss.numpy()}") + # print(f"H_loss at epoch {epoch + 1}: {H_loss.numpy()}") + + if len(losses) > 1 and not lambdas.trainable:# and False: + # SoftAdaptive weights update + # num1 = tf.math.exp(pde_losses[-1] - pde_losses[-2]) + # num2 = tf.math.exp(data_fitting_losses_0[-1] - data_fitting_losses_0[-2]) + # num3 = tf.math.exp(data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2]) + num = tf.nn.softmax([pde_losses[-1] - pde_losses[-2], data_fitting_losses_0[-1] - data_fitting_losses_0[-2], data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2]]) + num1 = num[0] + num2 = num[1] + num3 = num[2] + den = num1 + num2 + num3 + + new_lambdas = tf.stack([num1 / den, num2 / den, num3 / den]) + lambdas.assign(new_lambdas) + # lambdas_values.append((lambdas).numpy()) + + if cheb_par.trainable: + cheb_par_values.append(cheb_par.numpy()) + + del tape + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.numpy()}") + + if len(losses) > 2 and np.abs(losses[-1] - losses[-2]) / np.abs(losses[-2]) < 1e-8: + stop = True + +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean.numpy()}") +print(f"Hamiltonian standard deviation: {H_loss_std.numpy()}") +print(f"Hamiltonian maximum: {H_loss_max.numpy()}") +print(f"Hamiltonian minimum: {H_loss_min.numpy()}") +print(f"Hamiltonian absolute error: {H_abs_error.numpy()}") +print(f"Hamiltonian relative error: {H_rel_error.numpy()}") +# Print computation time +print('\nComputation time: {} seconds'.format(time() - t0)) + + +def generate_save_fig_string(type, epochs, learning_rate_type, loss_type): + """ + Generates a string for saving figures that includes the number of epochs and the type of learning rate. + + Args: + epochs (int): The number of epochs. + learning_rate_type (str): The type of learning rate. + + Returns: + str: The generated string for saving figures. + """ + return f"./results/{type}_epochs_{epochs}_lr_{learning_rate_type}_{loss_type}.png" + +# Plot the loss history +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_fitting_losses_0, label='Initial Conditions Loss') +plt.semilogy(data_fitting_losses_l_r, label='Periodic Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'loss.pdf', dpi=300) + +# # Evaluate the function +# x_eval = np.linspace(x_train[0].numpy(), x_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +# y_eval = np.linspace(y_train[0].numpy(), y_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +# t_eval = np.linspace(t_train[0].numpy(), t_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +# inputs_eval = [x_eval, y_eval, t_eval] + +# # Plot the parameters over epochs +# plt.plot(S_losses_min, label='S_loss_min') +# plt.plot(S_losses_max, label='S_loss_max') +# plt.xlabel('Epoch') +# plt.ylabel('Multisymplectic Constant') +# plt.title('Multisymplectic Constant over epochs') +# plt.legend() +# plt.grid() +# +# if save_fig: +# save_fig_string = generate_save_fig_string('S_loss', epochs, learning_rate_type, loss_type) +# # save png +# plt.savefig(save_fig_string, dpi=300) +# # # save pdf +# # plt.savefig('../results/' + 'S_loss.pdf', dpi=300) + + +# Plot the Hamiltonian over epochs +plt.plot(H_losses_min, label='H_loss_min') +plt.plot(H_losses_max, label='H_loss_max') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss.pdf', dpi=300) + +# Plot the average Hamiltonian over epochs with standard deviation +H_losses_mean = np.array(H_losses_mean) +H_losses_std = np.array(H_losses_std) +H_losses_abs_error = np.array(H_losses_abs_error) +H_losses_rel_error = np.array(H_losses_rel_error) + +plt.plot(H_losses_mean) +plt.fill_between(range(len(H_losses_mean)), H_losses_mean - H_losses_std, H_losses_mean + H_losses_std, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_mean', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_std.pdf', dpi=300) + +# Plot the standard deviation of the Hamiltonian over epochs +plt.plot(H_losses_std) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_std', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_std.pdf', dpi=300) + + +# Plot the absolute error of the Hamiltonian over epochs +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_abs_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_rel_error.pdf', dpi=300) + + +# Plot the relative error of the Hamiltonian over epochs +plt.plot(H_losses_rel_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_rel_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_rel_error.pdf', dpi=300) + + +# Plot the Chebyshev parameter over epochs +if cheb_par.trainable: + plt.plot(tf.sigmoid(cheb_par_values)) + plt.xlabel('Epoch') + plt.ylabel('Chebyshev parameter') + plt.title('Chebyshev parameter over epochs') + plt.grid() + + if save_fig: + save_fig_string = generate_save_fig_string('cheb_par', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'cheb_par.pdf', dpi=300) + + +import pandas as pd +df = pd.DataFrame() + +df['total_loss'] = losses +df['pde_loss'] = pde_losses +df['data_fitting_loss_0'] = data_fitting_losses_0 +df['data_fitting_loss_l_r'] = data_fitting_losses_l_r +df['H_loss_min'] = H_losses_min +df['H_loss_max'] = H_losses_max +df['H_loss_mean'] = H_losses_mean +df['H_loss_std'] = H_losses_std +df['H_loss_abs_error'] = H_losses_abs_error +df['H_loss_rel_error'] = H_losses_rel_error +# df['cheb_par'] = cheb_par_values + +df.to_csv('./results/2D/training_history.csv', index=False) +# from mpl_toolkits.mplot3d import Axes3D + +# # Set up meshgrid +# N = 600 +# tspace = np.linspace(0, 2, N + 1) +# xspace = np.linspace(0, 2, N + 1) +# yspace = np.linspace(0, 2, N + 1) +# T, X , Y= np.meshgrid(tspace, xspace, yspace) +# XYTgrid = np.vstack([X.flatten(),Y.flatten(),T.flatten()]).T + +# # Determine predictions of u(t, x) +# u_pred = model(tf.cast(XYTgrid,DTYPE)) + +# # Reshape upred +# U = u_pred.numpy().reshape(N+1,N+1,N+1) + +# # Surface plot of solution u(t,x) +# fig = plt.figure(figsize=(9,6)) +# ax = fig.add_subplot(111, projection='3d') +# ax.plot_surface(X, Y, U, cmap='viridis') +# ax.view_init(35,35) +# ax.set_xlabel('$x$') +# ax.set_ylabel('$y$') +# ax.set_zlabel('$u_\\theta(x,y,t)$') +# ax.set_title('Solution to KdV equation') +# if save_fig: +# save_fig_string = generate_save_fig_string('sol', epochs, learning_rate_type, loss_type) +# # save png +# plt.savefig(save_fig_string, dpi=300) +# # # save pdf +# # plt.savefig('../results/' + 'solution.pdf', dpi=300) + +# # Extract the components of lambdas over epochs +# lambda_1 = [l[0] for l in lambdas_values] +# lambda_2 = [l[1] for l in lambdas_values] +# lambda_3 = [l[2] for l in lambdas_values] + +# # Plot the components of lambdas +# plt.figure(figsize=(10, 6)) +# plt.plot(lambda_1, label='$\lambda_1$', color='r') +# plt.plot(lambda_2, label='$\lambda_2$', color='g') +# plt.plot(lambda_3, label='$\lambda_3$', color='b') +# plt.xlabel('Epochs') +# plt.ylabel('Weights Values') +# plt.title('Evolution of weight components over training') +# plt.legend() +# plt.grid() +# + +# # Save the plot if required +# if save_fig: +# save_fig_string = generate_save_fig_string('lambdas', epochs, learning_rate_type, loss_type) +# plt.savefig(save_fig_string, dpi=300) + diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py new file mode 100644 index 0000000..e2350db --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py @@ -0,0 +1,279 @@ +import numpy as np +import torch +import torch.nn as nn +from time import time +import matplotlib.pyplot as plt + +DTYPE = torch.float32 +torch.set_default_dtype(DTYPE) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +Nx, Ny, Nt = 10, 10, 10 +N_collocation = Nx*Ny*Nt +d_in = 3 + +xMin, xMax = 0.0, 8.0 +yMin, yMax = 0.0, 8.0 +tMax = 50. + +def choose_width_depth(N_collocation=N_collocation, overparam_factor=3.0, d_in=d_in, d_out=1): + N_target = int(overparam_factor * N_collocation) + depth = 8 + a, b = depth - 1, d_in + depth + d_out + c = d_out - N_target + width = int((-b + np.sqrt(b*b - 4*a*c)) / (2*a)) + return width, depth + +width, depth = choose_width_depth() +dx = (xMax - xMin) / (Nx - 1) +dy = (yMax - yMin) / (Ny - 1) +dt = tMax / (Nt - 1) +h = np.max([dx, dy, dt]) + +lambdas = torch.tensor([1., 1., 1.], dtype=DTYPE, device=device, requires_grad=False) +cheb_par = torch.tensor(0.5, dtype=DTYPE, device=device, requires_grad=True) + +x = torch.linspace(xMin, xMax, Nx, dtype=DTYPE, device=device).reshape(-1, 1) +y = torch.linspace(yMin, yMax, Ny, dtype=DTYPE, device=device).reshape(-1, 1) +t = torch.linspace(0, tMax, Nt, dtype=DTYPE, device=device).reshape(-1, 1) +y_grid, x_grid, t_grid = torch.meshgrid(y.flatten(), x.flatten(), t.flatten(), indexing='ij') +x_train = x_grid.flatten().reshape(-1, 1) +y_train = y_grid.flatten().reshape(-1, 1) +t_train = t_grid.flatten().reshape(-1, 1) +xyt_train = torch.stack([x_train.flatten(), y_train.flatten(), t_train.flatten()], dim=1) + +save_fig = True + +def u_0(x, y): + epsilon = 0.01 + c1, c2 = 0.45, 0.25 + x1, x2 = 2.5, 3.3 + y1 = 0. + out = 3*c1/(torch.cosh(0.5*torch.sqrt(torch.tensor(c1/epsilon))*((x-x1)**2 + (y-y1)**2)**0.5))**2 + out += 3*c2/(torch.cosh(0.5*torch.sqrt(torch.tensor(c2/epsilon))*((x-x2)**2 + (y-y1)**2)**0.5))**2 + return out + +def periodic_boundary_conditions(model, Nbc=2000): + x = torch.rand(Nbc, 1, device=device) * (xMax - xMin) + xMin + y = torch.rand(Nbc, 1, device=device) * (yMax - yMin) + yMin + t = torch.rand(Nbc, 1, device=device) * tMax + + xL = torch.full_like(x, xMin) + xR = torch.full_like(x, xMax) + yL = torch.full_like(y, yMin) + yR = torch.full_like(y, yMax) + + uLx = model(torch.cat([xL, y, t], 1)) + uRx = model(torch.cat([xR, y, t], 1)) + uLy = model(torch.cat([x, yL, t], 1)) + uRy = model(torch.cat([x, yR, t], 1)) + + loss = torch.mean((uLx - uRx)**2 + (uLy - uRy)**2) + return loss + +def H(u, u_x, u_y): + return torch.sum((u_x**2 + u_y**2)/2 - u**3/6) * dx * dy + +def linear_loss_function(tensors, weights): + stacked = torch.stack(tensors) + weights = weights / torch.sum(weights) + loss = torch.sum(weights * stacked) + return loss, 'ls' + +def chebyshev_loss_function(tensors, weights): + stacked = torch.stack(tensors) + loss = torch.max(weights * stacked) + return loss, 'cs' + +def augmentedChebyshev_loss_function(tensors, weights): + par = torch.sigmoid(cheb_par) + ls = linear_loss_function(tensors, weights)[0] + cs = chebyshev_loss_function(tensors, weights)[0] + return par*cs + (1-par)*ls, 'acs' + +class FourierFeatures(nn.Module): + def __init__(self, n_modes=5): + super().__init__() + self.n_modes = n_modes + + def forward(self, inputs): + x, y, t = inputs[:, 0:1], inputs[:, 1:2], inputs[:, 2:3] + features = [t] + for k in range(1, self.n_modes + 1): + features.append(torch.sin(2*np.pi*k*(x - xMin)/(xMax-xMin))) + features.append(torch.cos(2*np.pi*k*(x - xMin)/(xMax-xMin))) + features.append(torch.sin(2*np.pi*k*(y - yMin)/(yMax-yMin))) + features.append(torch.cos(2*np.pi*k*(y - yMin)/(yMax-yMin))) + return torch.cat(features, dim=1) + +class PINNModel(nn.Module): + def __init__(self, num_hidden_layers=depth, num_neurons_per_layer=width): + super().__init__() + self.ff = FourierFeatures(n_modes=4) + layers = [] + # input_dim = 3 + 4 * 2 * 4 + input_dim = 17 + for _ in range(num_hidden_layers): + layers.append(nn.Linear(input_dim, num_neurons_per_layer)) + layers.append(nn.Tanh()) + input_dim = num_neurons_per_layer + layers.append(nn.Linear(input_dim, 1)) + self.net = nn.Sequential(*layers) + + def forward(self, x): + x = self.ff(x) + return self.net(x) + +def custom_loss(inputs, model): + x, y, t = inputs[:, 0:1], inputs[:, 1:2], inputs[:, 2:3] + x.requires_grad_(True) + y.requires_grad_(True) + t.requires_grad_(True) + u_model = model(torch.cat([x, y, t], dim=1)) + + u_t = torch.autograd.grad(u_model.sum(), t, create_graph=True)[0] + u_x = torch.autograd.grad(u_model.sum(), x, create_graph=True)[0] + u_y = torch.autograd.grad(u_model.sum(), y, create_graph=True)[0] + + u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0] + u_yy = torch.autograd.grad(u_y.sum(), y, create_graph=True)[0] + + u_xxx = torch.autograd.grad(u_xx.sum(), x, create_graph=True)[0] + u_xyy = torch.autograd.grad(u_y.sum(), y, create_graph=True)[0] + + pde_loss = torch.mean((u_t + u_model * u_x + u_xxx + u_xyy) ** 2) + + x_ic = torch.linspace(xMin, xMax, Nx*Ny).reshape(-1, 1).to(device) + y_ic = torch.linspace(yMin, yMax, Nx*Ny).reshape(-1, 1).to(device) + t_ic = torch.zeros_like(x_ic) + u_ic = u_0(x_ic, y_ic) + u_ic_pred = model(torch.cat([x_ic, y_ic, t_ic], dim=1)) + data_fitting_loss_0 = torch.mean((u_ic_pred - u_ic) ** 2) + + data_fitting_loss_l_r = periodic_boundary_conditions(model) + loss, loss_type = augmentedChebyshev_loss_function([pde_loss, data_fitting_loss_0, data_fitting_loss_l_r], lambdas) + + H_loss = H(u_model.reshape(Nx, Ny, Nt), u_x.reshape(Nx, Ny, Nt), u_y.reshape(Nx, Ny, Nt)) + + return loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss + +model = PINNModel().to(device) +epochs = 1000 +lr_schedule = torch.optim.lr_scheduler.PolynomialLR( + torch.optim.Adam(model.parameters(), lr=1e-2), + total_iters=epochs, power=3.0 +) +optimizer = lr_schedule.optimizer + +losses, pde_losses, data_losses_0, bc_losses = [], [], [], [] +H_losses_min, H_losses_max, H_losses_mean, H_losses_std = [], [], [], [] +H_losses_abs_error, H_losses_rel_error = [], [] +t0 = time() + +for epoch in range(epochs): + optimizer.zero_grad() + loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = custom_loss(xyt_train, model) + loss.backward() + optimizer.step() + lr_schedule.step() + + with torch.no_grad(): + losses.append(loss.item()) + pde_losses.append(pde_loss.item()) + data_losses_0.append(data_loss_0.item()) + bc_losses.append(bc_loss.item()) + + H_loss_min = torch.min(H_loss).item() + H_loss_max = torch.max(H_loss).item() + H_losses_min.append(H_loss_min) + H_losses_max.append(H_loss_max) + H_loss_mean = torch.mean(H_loss).item() + H_loss_std = torch.std(H_loss).item() + H_losses_mean.append(H_loss_mean) + H_losses_std.append(H_loss_std) + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.6e}") + +print(f'\nComputation time: {time() - t0:.2f}s') + +plt.figure(figsize=(10, 6)) +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_losses_0, label='Initial Conditions Loss') +plt.semilogy(bc_losses, label='Periodic Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() +plt.savefig('./results/2D_loss.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_min, label='H_loss_min') +plt.plot(H_losses_max, label='H_loss_max') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() +plt.savefig('./results/2D_H_minmax.png', dpi=300) if save_fig else None +plt.show() + +H_losses_mean_arr = np.array(H_losses_mean) +H_losses_std_arr = np.array(H_losses_std) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_mean_arr) +plt.fill_between(range(len(H_losses_mean_arr)), H_losses_mean_arr - H_losses_std_arr, H_losses_mean_arr + H_losses_std_arr, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +plt.grid() +plt.savefig('./results/2D_H_mean_std.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_std_arr) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +plt.grid() +plt.savefig('./results/2D_H_std.png', dpi=300) if save_fig else None +plt.show() + +if cheb_par.requires_grad: + plt.figure(figsize=(10, 6)) + plt.plot(torch.sigmoid(cheb_par).detach()) + plt.xlabel('Epoch') + plt.ylabel('Chebyshev parameter') + plt.title('Chebyshev parameter over epochs') + plt.grid() + plt.savefig('./results/2D_cheb_par.png', dpi=300) if save_fig else None + plt.show() + +from mpl_toolkits.mplot3d import Axes3D +N = 100 +xspace = torch.linspace(xMin, xMax, N, dtype=DTYPE, device=device) +yspace = torch.linspace(yMin, yMax, N, dtype=DTYPE, device=device) +tspace_val = torch.tensor(tMax, dtype=DTYPE, device=device) +X_grid, Y_grid = torch.meshgrid(xspace, yspace, indexing='ij') +T_grid = torch.full_like(X_grid, tMax) +XYTgrid = torch.stack([X_grid.flatten(), Y_grid.flatten(), T_grid.flatten()], dim=1) + +with torch.no_grad(): + u_pred = model(XYTgrid) +U = u_pred.reshape(N, N) + +X_np = X_grid.cpu().numpy() +Y_np = Y_grid.cpu().numpy() +U_np = U.cpu().numpy() + +fig = plt.figure(figsize=(9, 6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X_np, Y_np, U_np, cmap='viridis') +ax.set_xlabel('$x$') +ax.set_ylabel('$y$') +ax.set_zlabel('$u(x,y,t)$') +ax.set_title('2D PDE Solution') +plt.savefig('./results/2D_solution.png', dpi=300) if save_fig else None +plt.show() diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1.py b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1.py new file mode 100644 index 0000000..55ba1d3 --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1.py @@ -0,0 +1,729 @@ +import numpy as np +import tensorflow as tf +from time import time +import matplotlib.pyplot as plt + +# TF_USE_LEGACY_KERAS=True + +DTYPE = np.float32 +# Nx = 50 +# Nt = 50 +# N_collocation = Nx*Nt + +xMin = -np.pi +xMax = np.pi +tMax = 5. # 10. +d_in = 2 + +def Nx_from_arch(width, depth, fac=2.0, d_in=2, d_out=1): + """ + Given a PINN architecture (width, depth) and an overparam factor fac, + compute Nx = Nt such that: + + Nx * Nt ≈ N_params / fac, + Nx = Nt, + + where N_params is the number of trainable parameters. + + Parameters + ---------- + width : int + Number of neurons per hidden layer. + depth : int + Number of hidden layers. + fac : float + Over-parameterization factor. Typical values: fac = 2 or 3. + d_in : int + Input dimension (usually 2: x,t). + d_out : int + Output dimension (usually 1: u). + + Returns + ------- + Nx : int + Nt : int + Ntheta : int + Total number of trainable parameters. + Ncoll_target : int + Target number of collocation points = Ntheta/fac. + """ + + # Parameter count + Ntheta = (d_in + 1) * width \ + + (depth - 1) * (width * width + width) \ + + d_out * (width + 1) + + # Target collocation count + Ncoll_target = int(Ntheta / fac) + + # Square grid Nx = Nt + Nx = int(np.sqrt(Ncoll_target)) + Nt = Nx + + return Nx, Nt, Ntheta, Ncoll_target + +width = 80 +depth = 4 + +Nx, Nt, Ntheta, Ncoll = Nx_from_arch(width=width, depth=depth, fac=10.) + +def h_from_NxNt(Nx, Nt, xMin, xMax, tMax): + """ + Compute dx, dt, and h from Nx, Nt and the domain extents. + h is defined as max(dx, dt). + + Returns + ------- + dx : float + dt : float + h : float + """ + + Lx = xMax - xMin + Lt = tMax + + dx = Lx / (Nx - 1) + dt = Lt / (Nt - 1) + + h = max(dx, dt) + + return dx, dt, h + +dx, dt, h = h_from_NxNt(Nx, Nt, xMin, xMax, tMax) + +lambdas = [1., 1., 1.] +lambdas = tf.Variable(lambdas, trainable=False, name='lambdas', dtype=DTYPE) +do_training = True +cheb_par = tf.Variable(0.5, trainable=False, name='cheb_par', dtype=DTYPE) + +x = np.linspace(xMin, xMax, Nx).reshape((-1, 1)).astype(DTYPE) +t = np.linspace(0, tMax, Nt).reshape((-1, 1)).astype(DTYPE) + +x_train = tf.expand_dims(tf.convert_to_tensor(x.flatten()), axis=-1) +t_train = tf.expand_dims(tf.convert_to_tensor(t.flatten()), axis=-1) + +save_fig = True + +# Define the initial condition +def u_0(x): + return 0.2+0.1*tf.math.cos(2 * x) + + +def u_0_x(x): + return -0.2*tf.math.sin(2 * x) + + +def periodic_boundary_conditions(model, Nbc=2000): + + # Random boundary sampling (correct choice) + x = tf.random.uniform((Nbc,1), xMin, xMax) + t = tf.random.uniform((Nbc,1), 0.0, tMax) + + xL = tf.ones_like(x) * xMin + xR = tf.ones_like(x) * xMax + + with tf.GradientTape(persistent=True) as tape: + tape.watch([xL, xR]) + + uLx = model(tf.concat([xL, t], 1)) + uRx = model(tf.concat([xR, t], 1)) + + # First derivatives + uxL = tape.gradient(uLx, xL) + uxR = tape.gradient(uRx, xR) + + del tape + + # Enforce periodicity of values AND derivatives + loss = tf.reduce_mean( + (uLx - uRx)**2 + + (uxL - uxR)**2 + ) + + return loss + + +# def H(u, u_x, dx): +# return tf.reduce_sum(tf.pow(u, 3)+u*tf.pow(u_x, 2), axis = -1) * dx +# # return tf.reduce_sum((tf.pow(u, 2)+tf.pow(u_x, 2))/2, axis = -1) * dx + +def ch_density(u, u_x): + return tf.pow(u, 3) + u * tf.pow(u_x, 2) + + +# @tf.function +def H(u, u_x, dx, density_fn=ch_density, axis=-1): + """ + Boole’s rule (8th order) along 'axis' for uniform grid with spacing dx. + Requires (N-1) % 4 == 0. Otherwise uses Boole on the largest prefix and trapezoid on remainder. + """ + f = density_fn(u, u_x) # [..., N] + n = tf.shape(f)[axis] + + # Trapezoid as a fallback on short tails + def _trap_rem(rem): + # rem: [..., M] contiguous tail; integrate with trapezoid + return tf.reduce_sum(0.5*(rem[..., 1:] + rem[..., :-1]), axis=-1) * tf.cast(dx, f.dtype) + + # Degenerate + if tf.less_equal(n, 1): + return tf.reduce_sum(f, axis=axis) * dx + + # Largest prefix with (n1-1) % 4 == 0 + n1 = n - ((n - 1) % 4) + # Boole constant for uniform spacing: 2*dx/45 + c = (2.0 * dx) / 45.0 + + # Indices for prefix + idx_prefix = tf.range(n1) + f0 = tf.gather(f, idx_prefix[0::4], axis=axis) # 0,4,8,... + f1 = tf.gather(f, idx_prefix[1::4], axis=axis) # 1,5,9,... + f2 = tf.gather(f, idx_prefix[2::4], axis=axis) # 2,6,10,... + f3 = tf.gather(f, idx_prefix[3::4], axis=axis) # 3,7,11,... + f4 = tf.gather(f, idx_prefix[4::4], axis=axis) # 4,8,12,... (last block end) + + # Weighted sum across blocks + # Boole's block weights per 5 nodes: [7, 32, 12, 32, 7] + # Aggregate across all blocks by summing slices + s = 7.0 * tf.reduce_sum(f0, axis=axis) + s += 32.0 * tf.reduce_sum(f1, axis=axis) + s += 12.0 * tf.reduce_sum(f2, axis=axis) + s += 32.0 * tf.reduce_sum(f3, axis=axis) + s += 7.0 * tf.reduce_sum(f4, axis=axis) + + boole_part = c * s + + # Tail remainder + if tf.equal(n1, n): + return boole_part + + rem = tf.gather(f, tf.range(n1-1, n), axis=axis) # nodes: n1-1 .. n-1 + tail = _trap_rem(rem) + return boole_part + tail + + +def linear_loss_function(tensors, weights): + """ + Computes the sum of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the sum. + + Returns: + tf.Tensor: The sum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) # shape (n_losses,) + weights = weights / tf.reduce_sum(weights) + loss = tf.reduce_sum(weights * stacked) + loss_type = 'ls' + return loss, loss_type + + +def chebyshev_loss_function(tensors, weights): + """ + Computes the max of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The maximum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + loss = tf.reduce_max(weights*stacked) + loss_type = 'cs' + return loss, loss_type + + +def smooth_chebyshev_loss_function(mu, tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + exp_sum = tf.reduce_sum(tf.math.exp(stacked/mu), axis=0) + loss = mu*tf.math.log(exp_sum) + loss_type = 'scs' + return loss, loss_type + + +def augmentedChebyshev_loss_function(tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + loss_type = 'acs' + par = tf.sigmoid(cheb_par) # par is between 0 and 1 + return par*chebyshev_loss_function(tensors, weights)[0] + (1-par)*linear_loss_function(tensors, weights)[0], loss_type + + +def sigmoid_centered(x): + return 2*tf.nn.sigmoid(.5*x) - 1 + +def PINNModel(num_hidden_layers=depth, num_neurons_per_layer=width): # 8,40 + xt_input = tf.keras.Input(shape=(2,)) + output_u = xt_input + for _ in range(num_hidden_layers): + output_u = tf.keras.layers.Dense(num_neurons_per_layer, + activation='gelu', # tanh + kernel_initializer='glorot_normal', #'glorot_uniform', # glorot_normal + )(output_u) + + output_u = tf.keras.layers.Dense(units=1, + activation='linear', + kernel_initializer='glorot_normal', #'glorot_uniform', # glorot_normal + )(output_u) + + return tf.keras.Model(inputs=xt_input, outputs=output_u) #tf.keras.Model(inputs=[x_input, t_input], outputs=output_u) + + +def lambda_grad(epoch, + start=1000, + lam_max=1e-0, + kappa=1e-3): + epoch = tf.cast(epoch, tf.float32) + return lam_max * (1.0 - tf.exp(-kappa * tf.maximum(epoch - start, 0.0))) + + + +# @tf.function +def custom_loss(inputs, model): + x, t = inputs[:, 0:1], inputs[:, 1:2] + + with tf.GradientTape(persistent=True) as outerTape: + outerTape.watch(x) + with tf.GradientTape(persistent=True) as tape: + tape.watch(t) + tape.watch(x) + with tf.GradientTape(persistent=False) as tape2: + tape2.watch(x) + tape2.watch(t) + with tf.GradientTape(persistent=True) as tape3: + tape3.watch(x) + tape3.watch(t) + u_model = model(tf.stack([x[:, 0], t[:, 0]], axis=1)) + u_x = tape3.gradient(u_model, x) + u_t = tape3.gradient(u_model, t) + u_xx = tape2.gradient(u_x, x) + u_xxt = tape.gradient(u_xx, t) + u_xxx = tape.gradient(u_xx, x) + + # === Camassa–Holm residual === + r = ( + u_t + - u_xxt + + 3.0 * u_model * u_x + - 2.0 * u_x * u_xx + - u_model * u_xxx + ) + r_x = outerTape.gradient(r, x) + + # Clean up + del tape, tape2, tape3, outerTape + + lam = lambda_grad(epoch) + + # === H1 norm of residual === + pde_loss_L2 = tf.reduce_mean(tf.square(r)) + pde_loss_grad = tf.reduce_mean(tf.square(r_x)) + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + # === Initial condition === + ic_mask = tf.where(tf.abs(t) < 1e-6) + x_ic = tf.gather(x, ic_mask[:, 0]) + u_ic = u_0(x_ic) + t_ic = tf.zeros_like(x_ic) + u_ic_pred = model(tf.concat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = tf.reduce_mean(tf.square(u_ic_pred - u_ic)) + + # === Periodic boundary conditions === + data_fitting_loss_l_r = periodic_boundary_conditions(model) + + # === Chebyshev aggregation === + loss, loss_type = chebyshev_loss_function( + [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + lambdas + ) + + # === Hamiltonian (monitor only) === + H_loss = H( + tf.reshape(u_model, shape=[Nt, Nx]), + tf.reshape(u_x, shape=[Nt, Nx]), + dx + ) + + return ( + loss, + loss_type, + pde_loss_H1, + data_fitting_loss_0, + data_fitting_loss_l_r, + H_loss, + ) + + +# Create the PINN model +model = PINNModel() +model.summary() + +epochs = 1000 # 3000 # 5000 # 2000 +# # Compile the model +# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), +# loss=lambda y_true, y_pred: custom_loss([x_train, t_train, theta_train], model)[1]) + +# Create the optimizer with a smaller learning rate +# learning_rate = 1e-3 # 1e-4 +# learning_rate_type = 'constant' +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([10, 100], [1e-1, 5e-2, 1e-2]) #OK +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([100, 300], [1e-2, 1e-3, 1e-4]) +# learning_rate = tf.keras.optimizers.schedules.PolynomialDecay( +# initial_learning_rate=1e-3, +# decay_steps=epochs, +# end_learning_rate=1e-5, +# power=2., +# cycle=False, # True +# name='PolynomialDecay' +# ) +# learning_rate_type = 'polynomialDecay' +learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( + initial_learning_rate=1e-2, + decay_steps=epochs, # 100 + decay_rate=0.9, + staircase=False, + name='ExponentialDecay' +) +learning_rate_type = 'exponentialDecay' +# learning_rate = tf.keras.optimizers.schedules.CosineDecay( +# initial_learning_rate=1e-3, +# decay_steps=1000, +# alpha=0.0, +# name='CosineDecay', +# warmup_target=None, +# warmup_steps=0 +# ) +# learning_rate_type = 'cosineDecay' + +trainable = model.trainable_variables +if lambdas.trainable: + trainable += [lambdas] + +if cheb_par.trainable: + trainable += [cheb_par] + +# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, amsgrad=True) +# optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True) +optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate, beta_1=0.8, beta_2=0.9, epsilon=1e-07) +# optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate, rho=0.9, momentum=0.0, epsilon=1e-07, centered=False) + +# Training loop +losses = [] +pde_losses = [] +data_fitting_losses_0 = [] +data_fitting_losses_l_r = [] +delta_gradients = [] +H_losses_min = [] +H_losses_max = [] +H_losses_mean = [] +H_losses_std = [] +H_losses_abs_error = [] +H_losses_rel_error = [] +lambdas_values = [] +lambdas_values.append(lambdas.numpy()) +cheb_par_values = [] +cheb_par_values.append(cheb_par.numpy()) + +# Convert data to tensor because tf.GradientTape() can only watch tensor and not numpy arrays +x_train = tf.convert_to_tensor(x_train) +t_train = tf.convert_to_tensor(t_train) +x_grid, t_grid = np.meshgrid(x.flatten(), t.flatten()) +inputs = tf.convert_to_tensor(np.vstack([x_grid.flatten(), t_grid.flatten()]).T) +stop = False +# Start timer +t0 = time() +for epoch in range(epochs): + if not stop: + # print("# STARTING EPOCH", epoch + 1) + + # Create a LearningRateScheduler to update the learning rate + # current_lr = scheduler(epoch, learning_rate) + # tf.keras.backend.set_value(optimizer.lr, current_lr) + + with tf.GradientTape() as tape: + loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss = custom_loss(inputs, model) + + # print("Computing gradients") + gradients = tape.gradient(loss, trainable) + # print(gradients[-1]) + # print("Applying gradients") + optimizer.apply_gradients(zip(gradients, trainable)) + # print("Appending losses") + losses.append(loss.numpy()) + pde_losses.append(pde_loss.numpy()) + data_fitting_losses_0.append(data_fitting_loss_0.numpy()) + data_fitting_losses_l_r.append(data_fitting_loss_l_r.numpy()) + H_loss_min = tf.reduce_min(H_loss) + H_loss_max = tf.reduce_max(H_loss) + H_losses_min.append(H_loss_min.numpy()) + H_losses_max.append(H_loss_max.numpy()) + H_loss_mean = tf.reduce_mean(H_loss) + H_loss_std = tf.math.reduce_std(H_loss) + H_losses_mean.append(H_loss_mean.numpy()) + H_losses_std.append(H_loss_std.numpy()) + + H0 = H(u_0(x_grid), u_0_x(x_grid), dx) # H0 = H_loss[0].numpy() + Hf = H_loss.numpy() + H_abs_error = tf.abs(Hf - H0) + H_losses_abs_error.append(tf.reduce_max(H_abs_error).numpy()) + H_rel_error = H_abs_error / tf.abs((H0 + 1e-16)) + H_losses_rel_error.append(H_rel_error[-1].numpy()) + + # lambdas_values.append((trainable[-1]).numpy()) + if len(losses) > 1 and not lambdas.trainable and do_training: + # SoftAdaptive weights update + num1 = tf.math.exp(pde_losses[-1] - pde_losses[-2]) + num2 = tf.math.exp(data_fitting_losses_0[-1] - data_fitting_losses_0[-2]) + num3 = tf.math.exp(data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2]) + den = num1 + num2 + num3 + + new_lambdas = tf.stack([num1 / den, num2 / den, num3 / den]) + lambdas.assign(new_lambdas) + lambdas_values.append((lambdas).numpy()) + + if cheb_par.trainable: + cheb_par_values.append(cheb_par.numpy()) + + del tape + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.numpy()}") + + if len(losses) > 2 and np.abs(losses[-1] - losses[-2]) / np.abs(losses[-2]) < 1e-8: + stop = True + +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean.numpy()}") +print(f"Hamiltonian standard deviation: {H_loss_std.numpy()}") +print(f"Hamiltonian maximum: {H_loss_max.numpy()}") +print(f"Hamiltonian minimum: {H_loss_min.numpy()}") +# print(f"Hamiltonian absolute error: {H_abs_error.numpy()}") +# print(f"Hamiltonian relative error: {H_rel_error.numpy()}") +print(f"Hamitonian relative error: {H_rel_error[-1].numpy()}") +# Print computation time +print('\nComputation time: {} seconds'.format(time() - t0)) + + +def generate_save_fig_string(type, epochs, learning_rate_type, loss_type): + """ + Generates a string for saving figures that includes the number of epochs and the type of learning rate. + + Args: + epochs (int): The number of epochs. + learning_rate_type (str): The type of learning rate. + + Returns: + str: The generated string for saving figures. + """ + return f"./results/{type}_epochs_{epochs}_lr_{learning_rate_type}_{loss_type}.png" + +# Plot the loss history +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_fitting_losses_0, label='Initial Conditions Loss') +plt.semilogy(data_fitting_losses_l_r, label='Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'loss.pdf', dpi=300) + +# Evaluate the function +x_eval = np.linspace(x_train[0].numpy(), x_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +t_eval = np.linspace(t_train[0].numpy(), t_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +inputs_eval = [x_eval, t_eval] + +# Plot the Hamiltonian over epochs +plt.plot(H_losses_min, label='min H') +plt.plot(H_losses_max, label='max H') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss.pdf', dpi=300) + +# Plot the average Hamiltonian over epochs with standard deviation +H_losses_mean = np.array(H_losses_mean) +H_losses_std = np.array(H_losses_std) +H_losses_rel_error = np.array(H_losses_rel_error) +H_losses_rel_error = np.array(H_losses_rel_error) + +plt.plot(H_losses_mean) +plt.fill_between(range(len(H_losses_mean)), H_losses_mean - H_losses_std, H_losses_mean + H_losses_std, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_mean', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_std.pdf', dpi=300) + +# Plot the standard deviation of the Hamiltonian over epochs +plt.plot(H_losses_std) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_std', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_std.pdf', dpi=300) + +# Plot the absolute error of the Hamiltonian over epochs +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_abs_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_rel_error.pdf', dpi=300) + + +# Plot the relative error of the Hamiltonian over epochs +plt.plot(H_losses_rel_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_rel_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'H_loss_rel_error.pdf', dpi=300) + + +# Plot the Chebyshev parameter over epochs +if cheb_par.trainable: + plt.plot(tf.sigmoid(cheb_par_values)) + plt.xlabel('Epoch') + plt.ylabel('Chebyshev parameter') + plt.title('Chebyshev parameter over epochs') + plt.grid() + if save_fig: + save_fig_string = generate_save_fig_string('cheb_par', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'cheb_par.pdf', dpi=300) + + +from mpl_toolkits.mplot3d import Axes3D + +# Set up meshgrid +N = 600 +tspace = np.linspace(0, tMax, N + 1) +xspace = np.linspace(xMin, xMax, N + 1) +T, X = np.meshgrid(tspace, xspace) +XTgrid = np.vstack([X.flatten(),T.flatten()]).T + +# Determine predictions of u(t, x) +u_pred = model(tf.cast(XTgrid,DTYPE)) + +# Reshape upred +U = u_pred.numpy().reshape(N+1,N+1) + +# Surface plot of solution u(t,x) +fig = plt.figure(figsize=(9,6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X, T, U, cmap='viridis') +ax.view_init(35,35) +ax.set_xlabel('$x$') +ax.set_ylabel('$t$') +ax.set_zlabel('$u_\\theta(x,t)$') +ax.set_title('Solution to Camassa-Holm equation') +ax.set_box_aspect(None, zoom=0.85) + +if save_fig: + save_fig_string = generate_save_fig_string('sol', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + # # save pdf + # plt.savefig('../results/' + 'solution.pdf', dpi=300) + + + + +import pandas as pd +df = pd.DataFrame() + +df['total_loss'] = losses +df['pde_loss'] = pde_losses +df['data_fitting_loss_0'] = data_fitting_losses_0 +df['data_fitting_loss_l_r'] = data_fitting_losses_l_r +df['H_loss_min'] = H_losses_min +df['H_loss_max'] = H_losses_max +df['H_loss_mean'] = H_losses_mean +df['H_loss_std'] = H_losses_std +df['H_loss_abs_error'] = H_losses_abs_error +df['H_loss_rel_error'] = H_losses_rel_error +# df['cheb_par'] = cheb_par_values + +df.to_csv('./results/camassa/training_history.csv', index=False) \ No newline at end of file diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py new file mode 100644 index 0000000..915aedc --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py @@ -0,0 +1,389 @@ +import numpy as np +import torch +import torch.nn as nn +from time import time +import matplotlib.pyplot as plt + +from humancompatible.train.dual_optim import ALM, iALM + +DTYPE = torch.float32 +torch.set_default_dtype(DTYPE) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +xMin, xMax = -np.pi, np.pi +tMax = 5. +d_in = 2 + +def Nx_from_arch(width, depth, fac=2.0, d_in=2, d_out=1): + Ntheta = (d_in + 1) * width + (depth - 1) * (width * width + width) + d_out * (width + 1) + Ncoll_target = int(Ntheta / fac) + Nx = int(np.sqrt(Ncoll_target)) + Nt = Nx + return Nx, Nt, Ntheta, Ncoll_target + +width, depth = 80, 4 +Nx, Nt, Ntheta, Ncoll = Nx_from_arch(width=width, depth=depth, fac=10.) + +dx = (xMax - xMin) / (Nx - 1) +dt = tMax / (Nt - 1) +h = max(dx, dt) + +lambdas = torch.tensor([1., 1., 1.], dtype=DTYPE, device=device, requires_grad=False) +do_training = True +cheb_par = torch.tensor(0.5, dtype=DTYPE, device=device, requires_grad=False) + +x = torch.linspace(xMin, xMax, Nx, dtype=DTYPE, device=device).reshape(-1, 1) +t = torch.linspace(0, tMax, Nt, dtype=DTYPE, device=device).reshape(-1, 1) +x_train = x.reshape(-1, 1) +t_train = t.reshape(-1, 1) +t_grid, x_grid = torch.meshgrid(t.flatten(), x.flatten(), indexing='ij') +inputs = torch.stack([x_grid.flatten(), t_grid.flatten()], dim=1) + +save_fig = True + +def u_0(x): + return 0.2 + 0.1 * torch.cos(2 * x) + +def u_0_x(x): + return -0.2 * torch.sin(2 * x) + +def periodic_boundary_conditions(model, Nbc=2000): + x = torch.rand(Nbc, 1, device=device) * (xMax - xMin) + xMin + t = torch.rand(Nbc, 1, device=device) * tMax + + xL = torch.full_like(x, xMin) + xR = torch.full_like(x, xMax) + + uLx = model(torch.cat([xL, t], 1)) + uRx = model(torch.cat([xR, t], 1)) + + loss = torch.mean((uLx - uRx)**2) + return loss + +def ch_density(u, u_x): + return u**3 + u * u_x**2 + +def H(u, u_x, dx, density_fn=ch_density): + f = density_fn(u, u_x) + return torch.sum(f) * dx + +def linear_loss_function(tensors, weights): + stacked = torch.stack(tensors) + weights = weights / torch.sum(weights) + loss = torch.sum(weights * stacked) + return loss, 'ls' + +def chebyshev_loss_function(tensors, weights): + stacked = torch.stack(tensors) + loss = torch.max(weights * stacked) + return loss, 'cs' + +### MODEL ### + +# @torch.compile +class PINNModel(nn.Module): + def __init__(self, num_hidden_layers=depth, num_neurons_per_layer=width): + super().__init__() + layers = [] + in_dim = 2 + for _ in range(num_hidden_layers): + layers.append(nn.Linear(in_dim, num_neurons_per_layer)) + layers.append(nn.GELU()) + in_dim = num_neurons_per_layer + layers.append(nn.Linear(in_dim, 1)) + self.net = nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + +def lambda_grad(epoch, + start=1000, + lam_max=1e-0, + kappa=1e-3): + epoch = float(epoch) + return lam_max * (1.0 - np.exp(-kappa * max(epoch - start, 0.0))) + + +##### UNCONSTRAINED LOSS FUNCTION WITH H1 REGULARIZATION ##### + + +# @torch.compile +def custom_loss(inputs, model, epoch): + x, t = inputs[:, 0:1], inputs[:, 1:2] + x.requires_grad_(True) + t.requires_grad_(True) + + u_model = model(torch.cat([x, t], dim=1)) + + u_t = torch.autograd.grad(u_model.sum(), t, create_graph=True)[0] + u_x = torch.autograd.grad(u_model.sum(), x, create_graph=True)[0] + + u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0] + u_xxt = torch.autograd.grad(u_xx.sum(), t, create_graph=True)[0] + u_xxx = torch.autograd.grad(u_xx.sum(), x, create_graph=True)[0] + + r = u_t - u_xxt + 3.0 * u_model * u_x - 2.0 * u_x * u_xx - u_model * u_xxx + r_x = torch.autograd.grad(r.sum(), x, create_graph=True)[0] + + pde_loss_L2 = torch.mean(torch.square(r)) + pde_loss_grad = torch.mean(torch.square(r_x)) + + lam = lambda_grad(epoch) + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + ic_mask = torch.abs(t) < 1e-6 + x_ic = x[ic_mask[:, 0]] + u_ic = u_0(x_ic) + t_ic = torch.zeros_like(x_ic) + u_ic_pred = model(torch.cat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = torch.mean(torch.square(u_ic_pred - u_ic)) + + data_fitting_loss_l_r = periodic_boundary_conditions(model) + + loss, loss_type = chebyshev_loss_function( + [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + lambdas + ) + + H_loss = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) + + return loss, loss_type, pde_loss_L2, data_fitting_loss_0, data_fitting_loss_l_r, H_loss + + +#### LOSS FUNCTION WITH H1 CONSTRAINT #### + +def lagrangian_loss(inputs, model, dual_opt, epoch): + x, t = inputs[:, 0:1], inputs[:, 1:2] + x.requires_grad_(True) + t.requires_grad_(True) + + u_model = model(torch.cat([x, t], dim=1)) + + u_t = torch.autograd.grad(u_model.sum(), t, create_graph=True)[0] + u_x = torch.autograd.grad(u_model.sum(), x, create_graph=True)[0] + + u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0] + u_xxt = torch.autograd.grad(u_xx.sum(), t, create_graph=True)[0] + u_xxx = torch.autograd.grad(u_xx.sum(), x, create_graph=True)[0] + + r = u_t - u_xxt + 3.0 * u_model * u_x - 2.0 * u_x * u_xx - u_model * u_xxx + r_x = torch.autograd.grad(r.sum(), x, create_graph=True)[0] + + pde_loss_L2 = torch.mean(torch.square(r)) + pde_loss_grad = torch.mean(torch.square(r_x)) + + lam = lambda_grad(epoch) + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + ic_mask = torch.abs(t) < 1e-6 + x_ic = x[ic_mask[:, 0]] + u_ic = u_0(x_ic) + t_ic = torch.zeros_like(x_ic) + u_ic_pred = model(torch.cat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = torch.mean(torch.square(u_ic_pred - u_ic)) + + data_fitting_loss_l_r = periodic_boundary_conditions(model) + + loss, loss_type = chebyshev_loss_function( + [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + lambdas + ) + + # constraint + H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) + + Hf = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) + H_constraint = torch.abs(Hf - H0)/torch.abs(H0) + + eps = 5/(epoch+1) + H_constraint = torch.max(H_constraint - eps, torch.zeros_like(H_constraint)).unsqueeze(0) + + loss = dual_opt.forward_update(loss, H_constraint) + + return loss, loss_type, pde_loss_L2, data_fitting_loss_0, data_fitting_loss_l_r, Hf + + +####### TRAINING LOOP ####### + + +model = PINNModel().to(device) +epochs = 1000 + +optimizer = torch.optim.NAdam(model.parameters(), lr=1e-2, betas=(0.8, 0.9), eps=1e-07) +lr_schedule = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9**(1/epochs)) + +losses, pde_losses, data_losses_0, bc_losses = [], [], [], [] +H_losses_min, H_losses_max, H_losses_mean, H_losses_std = [], [], [], [] +H_losses_abs_error, H_losses_rel_error = [], [] +t0 = time() + + +# dual_opt = ALM(m=1, lr=5e-5, dual_range=(0.,100.), device=device, ctol=1e-3, penalty=0.) +dual_opt = iALM(m=1, lr=0.1, beta=0.01, sigma=1.0001, gamma=1., dual_range=(0.,10.), ctol=1e-3) + +for epoch in range(epochs): + optimizer.zero_grad() + # loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = custom_loss(inputs, model, epoch) + + loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = lagrangian_loss(inputs, model, dual_opt, epoch) + loss.backward() + optimizer.step() + + + if epoch % 1 == 0: + lr_schedule.step() + + with torch.no_grad(): + + losses.append(loss.item()) + pde_losses.append(pde_loss.item()) + data_losses_0.append(data_loss_0.item()) + bc_losses.append(bc_loss.item()) + + H_loss_min = torch.min(H_loss).item() + H_loss_max = torch.max(H_loss).item() + H_losses_min.append(H_loss_min) + H_losses_max.append(H_loss_max) + H_loss_mean = torch.mean(H_loss).item() + H_loss_std = torch.std(H_loss).item() + H_losses_mean.append(H_loss_mean) + H_losses_std.append(H_loss_std) + + H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) + Hf = H_loss.detach() + H_abs_error = torch.abs(Hf - H0) + H_losses_abs_error.append(torch.max(H_abs_error).item()) + H_rel_error = H_abs_error / (torch.abs(H0) + 1e-16) + if isinstance(H_rel_error, torch.Tensor): + H_rel_error = H_rel_error.item() if H_rel_error.numel() == 1 else H_rel_error.max().item() + H_losses_rel_error.append(H_rel_error) + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.6e}") + + # lambdas_values.append((trainable[-1]).numpy()) + if len(losses) > 1: + # SoftAdaptive weights update + num1 = np.exp(pde_losses[-1] - pde_losses[-2]) + num2 = np.exp(data_losses_0[-1] - data_losses_0[-2]) + num3 = np.exp(bc_losses[-1] - bc_losses[-2]) + den = num1 + num2 + num3 + + new_lambdas = torch.tensor([num1 / den, num2 / den, num3 / den]) + lambdas = new_lambdas + # lambdas_values.append((lambdas).numpy()) + +print(f'\nComputation time: {time() - t0:.2f}s') +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean}") +print(f"Hamiltonian std: {H_loss_std}") +print(f"Hamiltonian max: {H_loss_max}") +print(f"Hamiltonian min: {H_loss_min}") + +plt.figure(figsize=(10, 6)) +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_losses_0, label='Initial Conditions Loss') +plt.semilogy(bc_losses, label='Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() +plt.savefig('./results/ch_loss.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_min, label='min H') +plt.plot(H_losses_max, label='max H') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() +plt.savefig('./results/ch_H_loss.png', dpi=300) if save_fig else None +plt.show() + +H_losses_mean_arr = np.array(H_losses_mean) +H_losses_std_arr = np.array(H_losses_std) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_mean_arr) +plt.fill_between(range(len(H_losses_mean_arr)), H_losses_mean_arr - H_losses_std_arr, H_losses_mean_arr + H_losses_std_arr, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +plt.grid() +plt.savefig('./results/ch_H_loss_mean.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_std_arr) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +plt.grid() +plt.savefig('./results/ch_H_loss_std.png', dpi=300) if save_fig else None +plt.show() + +H_losses_abs_error = np.array(H_losses_abs_error) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +plt.grid() +plt.savefig('./results/ch_H_loss_abs_error.png', dpi=300) if save_fig else None +plt.show() + +H_losses_rel_error = np.array(H_losses_rel_error) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_rel_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +plt.grid() +plt.savefig('./results/ch_H_loss_rel_error.png', dpi=300) if save_fig else None +plt.show() + +N = 600 +tspace = torch.linspace(0, tMax, N + 1, dtype=DTYPE, device=device) +xspace = torch.linspace(xMin, xMax, N + 1, dtype=DTYPE, device=device) +T_grid, X_grid = torch.meshgrid(tspace, xspace, indexing='ij') +XTgrid = torch.stack([X_grid.flatten(), T_grid.flatten()], dim=1) + +with torch.no_grad(): + u_pred = model(XTgrid) +U = u_pred.reshape(N+1, N+1) + +X_np = X_grid.cpu().numpy() +T_np = T_grid.cpu().numpy() +U_np = U.cpu().numpy() + +from mpl_toolkits.mplot3d import Axes3D +fig = plt.figure(figsize=(9, 6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X_np, T_np, U_np, cmap='viridis') +ax.set_xlabel('$x$') +ax.set_ylabel('$t$') +ax.set_zlabel('$u(x,t)$') +ax.set_title('Camassa-Holm equation') +plt.savefig('./results/ch_solution.png', dpi=300) if save_fig else None +plt.show() + + +import pandas as pd +df = pd.DataFrame() + +df['total_loss'] = losses +df['pde_loss'] = pde_losses +df['data_fitting_loss_0'] = data_losses_0 +df['data_fitting_loss_l_r'] = bc_losses +df['H_loss_min'] = H_losses_min +df['H_loss_max'] = H_losses_max +df['H_loss_mean'] = H_losses_mean +df['H_loss_std'] = H_losses_std +df['H_loss_abs_error'] = H_losses_abs_error +df['H_loss_rel_error'] = H_losses_rel_error +# df['cheb_par'] = cheb_par_values + +df.to_csv('./results/camassa/torch_training_history.csv', index=False) \ No newline at end of file diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1.py b/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1.py new file mode 100644 index 0000000..781639e --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1.py @@ -0,0 +1,769 @@ +import numpy as np +from sympy import false, per +import tensorflow as tf +from time import time +import matplotlib.pyplot as plt + +# TF_USE_LEGACY_KERAS=True + +DTYPE = np.float32 +# Nx = 100 +# Nt = 100 +# N_collocation = Nx*Nt + +type = 1 # 1 for KdV, 2 + +if type == 1: + nu = -0.022**2 + alpha = -0.5 + rho = 0. + xMin = 0. + xMax = 2. + tMax = 5. # 10. +elif type == 2: + nu = -1. + alpha = -3. + rho = 0. + xMin = -20. + xMax = 20. + tMax = 100. # 4. + + +def Nx_from_arch(width, depth, fac=1.5, d_in=2, d_out=1): + """ + Given a PINN architecture (width, depth) and an overparam factor fac, + compute Nx = Nt such that: + + Nx * Nt ≈ N_params / fac, + Nx = Nt, + + where N_params is the number of trainable parameters. + + Parameters + ---------- + width : int + Number of neurons per hidden layer. + depth : int + Number of hidden layers. + fac : float + Over-parameterization factor. Typical values: fac = 2 or 3. + d_in : int + Input dimension (usually 2: x,t). + d_out : int + Output dimension (usually 1: u). + + Returns + ------- + Nx : int + Nt : int + Ntheta : int + Total number of trainable parameters. + Ncoll_target : int + Target number of collocation points = Ntheta/fac. + """ + + # Parameter count + Ntheta = (d_in + 1) * width \ + + (depth - 1) * (width * width + width) \ + + d_out * (width + 1) + + # Target collocation count + Ncoll_target = int(Ntheta / fac) + + # Square grid Nx = Nt + Nx = int(np.sqrt(Ncoll_target)) + Nt = Nx + + return Nx, Nt, Ntheta, Ncoll_target + + +width = 80 +depth = 4 + +Nx, Nt, Ntheta, Ncoll = Nx_from_arch(width=width, depth=depth, fac=10.) + +def h_from_NxNt(Nx, Nt, xMin, xMax, tMax): + """ + Compute dx, dt, and h from Nx, Nt and the domain extents. + h is defined as max(dx, dt). + + Returns + ------- + dx : float + dt : float + h : float + """ + + Lx = xMax - xMin + Lt = tMax + + dx = Lx / (Nx - 1) + dt = Lt / (Nt - 1) + + h = max(dx, dt) + + return dx, dt, h + +dx, dt, h = h_from_NxNt(Nx, Nt, xMin, xMax, tMax) + +x = np.linspace(xMin, xMax, Nx).reshape((-1, 1)).astype(DTYPE) +t = np.linspace(0, tMax, Nt).reshape((-1, 1)).astype(DTYPE) + +x_train = tf.expand_dims(tf.convert_to_tensor(x.flatten()), axis=-1) +t_train = tf.expand_dims(tf.convert_to_tensor(t.flatten()), axis=-1) + +lambdas = [1., 1., 1.] +lambdas = tf.Variable(lambdas, trainable=False, name='lambdas', dtype=DTYPE) +do_training = False +cheb_par = tf.Variable(0.5, trainable=False, name='cheb_par', dtype=DTYPE) + +save_fig = True + +# Define the initial condition +def u_0(x): + if type == 1: + return tf.math.cos(np.pi * x) + elif type == 2: + return 6./(tf.math.cosh(x)**2) + + +def u_0_x(x): + if type == 1: + return -np.pi*tf.math.sin(np.pi * x) + elif type == 2: + return -12.*tf.math.sinh(x)/(tf.math.cosh(x)**3) + + +def periodic_bc(model, x, t): + xL = tf.ones_like(x) * xMin + xR = tf.ones_like(x) * xMax + uL = model(tf.concat([xL, t], axis=1)) + uR = model(tf.concat([xR, t], axis=1)) + return tf.reduce_mean((uL - uR)**2) + + +def V(u): + return alpha*tf.pow(u, 3)/3 + rho*tf.pow(u, 2)/2 + + +def kdv_density(u, u_x): + return V(u)-nu*tf.pow(u_x, 2)/2 + +# @tf.function +def H(u, u_x, dx, density_fn=kdv_density, axis=-1): + """ + Boole’s rule (8th order) along 'axis' for uniform grid with spacing dx. + Requires (N-1) % 4 == 0. Otherwise uses Boole on the largest prefix and trapezoid on remainder. + """ + f = density_fn(u, u_x) # [..., N] + n = tf.shape(f)[axis] + + # Trapezoid as a fallback on short tails + def _trap_rem(rem): + # rem: [..., M] contiguous tail; integrate with trapezoid + return tf.reduce_sum(0.5*(rem[..., 1:] + rem[..., :-1]), axis=-1) * tf.cast(dx, f.dtype) + + # Degenerate + if tf.less_equal(n, 1): + return tf.reduce_sum(f, axis=axis) * dx + + # Largest prefix with (n1-1) % 4 == 0 + n1 = n - ((n - 1) % 4) + # Boole constant for uniform spacing: 2*dx/45 + c = (2.0 * dx) / 45.0 + + # Indices for prefix + idx_prefix = tf.range(n1) + f0 = tf.gather(f, idx_prefix[0::4], axis=axis) # 0,4,8,... + f1 = tf.gather(f, idx_prefix[1::4], axis=axis) # 1,5,9,... + f2 = tf.gather(f, idx_prefix[2::4], axis=axis) # 2,6,10,... + f3 = tf.gather(f, idx_prefix[3::4], axis=axis) # 3,7,11,... + f4 = tf.gather(f, idx_prefix[4::4], axis=axis) # 4,8,12,... (last block end) + + # Weighted sum across blocks + # Boole's block weights per 5 nodes: [7, 32, 12, 32, 7] + # Aggregate across all blocks by summing slices + s = 7.0 * tf.reduce_sum(f0, axis=axis) + s += 32.0 * tf.reduce_sum(f1, axis=axis) + s += 12.0 * tf.reduce_sum(f2, axis=axis) + s += 32.0 * tf.reduce_sum(f3, axis=axis) + s += 7.0 * tf.reduce_sum(f4, axis=axis) + + boole_part = c * s + + # Tail remainder + if tf.equal(n1, n): + return boole_part + + rem = tf.gather(f, tf.range(n1-1, n), axis=axis) # nodes: n1-1 .. n-1 + tail = _trap_rem(rem) + return boole_part + tail + + +def linear_loss_function(tensors, weights): + """ + Computes the sum of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the sum. + + Returns: + tf.Tensor: The sum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) # shape (n_losses,) + weights = weights / tf.reduce_sum(weights) + loss = tf.reduce_sum(weights * stacked) + loss_type = 'ls' + return loss, loss_type + + +def chebyshev_loss_function(tensors, weights): + """ + Computes the max of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The maximum of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + loss = tf.reduce_max(weights*stacked) + loss_type = 'cs' + return loss, loss_type + + +def smooth_chebyshev_loss_function(mu, tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + weights = weights / tf.reduce_sum(weights) # Normalize weights + # stack_tensor = tf.stack(tf.multiply(tensors, weights)) + # stack_tensor = tf.stack([w * t for w, t in zip(weights, tensors)], axis=0) + stacked = tf.stack(tensors, axis=0) + exp_sum = tf.reduce_sum(tf.math.exp(stacked/mu), axis=0) + loss = mu*tf.math.log(exp_sum) + loss_type = 'scs' + return loss, loss_type + + +def augmentedChebyshev_loss_function(tensors, weights): + """ + Computes the log of the sum of the exponentials of the input tensors. + + Args: + tensors (list of tf.Tensor): List of tensors to compute the log-sum-exp. + + Returns: + tf.Tensor: The log-sum-exp of the input tensors. + """ + # weights = weights / tf.reduce_sum(weights) # Normalize weights + loss_type = 'acs' + par = tf.sigmoid(cheb_par) # par is between 0 and 1 + return par*chebyshev_loss_function(tensors, weights)[0] + (1-par)*linear_loss_function(tensors, weights)[0], loss_type + + +def sigmoid_centered(x): + return 2*tf.nn.sigmoid(x) - 1 + + +def PINNModel(num_hidden_layers=depth, num_neurons_per_layer=width): # 8,40 + xt_input = tf.keras.Input(shape=(2,)) + output_u = xt_input + for _ in range(num_hidden_layers): + output_u = tf.keras.layers.Dense(num_neurons_per_layer, + activation=sigmoid_centered, # mish + kernel_initializer='glorot_uniform', # glorot_normal + # kernel_constraint=tf.keras.constraints.UnitNorm(axis=0) + # lora_rank=10 + )(output_u) + + output_u = tf.keras.layers.Dense(units=1, + activation='linear', + kernel_initializer='glorot_uniform', # glorot_normal + # kernel_constraint=tf.keras.constraints.UnitNorm(axis=0) + # lora_rank=10 + )(output_u) + + return tf.keras.Model(inputs=xt_input, outputs=output_u) #tf.keras.Model(inputs=[x_input, t_input], outputs=output_u) + + +def lambda_grad(epoch, + start=1000, + lam_max=1e-0, + kappa=1e-3): + epoch = tf.cast(epoch, tf.float32) + return lam_max * (1.0 - tf.exp(-kappa * tf.maximum(epoch - start, 0.0))) + + +@tf.function +def grad_L2_fft_batch(r, L): + """ + r : shape (Nt, Nx) + returns : shape (Nt,) + """ + r = tf.cast(r, tf.complex64) + Nx = tf.shape(r)[-1] + + k_pos = tf.range(0, Nx//2 + 1, dtype=tf.float32) + k_neg = tf.range(-Nx//2 + 1, 0, dtype=tf.float32) + k = tf.concat([k_pos, k_neg], axis=0) + k = (2.0 * tf.constant(np.pi) / L) * k + k = tf.cast(k, tf.complex64) + + r_hat = tf.signal.fft(r) + + grad_energy = tf.reduce_sum(tf.abs(1j * k * r_hat)**2, axis=-1) + + dx = L / tf.cast(Nx, tf.float32) + return tf.math.real(grad_energy) * dx + + +@tf.function +def H1_norm_fft_batch(r, L): + """ + Compute ||r||_{H^1}^2 for each time slice. + + r : shape (Nt, Nx) + returns : shape (Nt,) + """ + r = tf.cast(r, tf.complex64) + Nx = tf.shape(r)[-1] + + k_pos = tf.range(0, Nx//2 + 1, dtype=tf.float32) + k_neg = tf.range(-Nx//2 + 1, 0, dtype=tf.float32) + k = tf.concat([k_pos, k_neg], axis=0) + k = (2.0 * tf.constant(np.pi) / L) * k + k = tf.cast(k, tf.complex64) + + r_hat = tf.signal.fft(r) + weight = 1.0 + tf.abs(k)**2 + + H1_sq = tf.reduce_sum(weight * tf.abs(r_hat)**2, axis=-1) + + dx = L / tf.cast(Nx, tf.float32) + return tf.math.real(H1_sq) * dx + + +# @tf.function +def custom_loss(inputs, model, epoch): + x, t = inputs[:, 0:1], inputs[:, 1:2] + + with tf.GradientTape(persistent=True) as tape: + tape.watch(t) + tape.watch(x) + with tf.GradientTape(persistent=True) as tape2: + tape2.watch(x) + tape2.watch(t) + with tf.GradientTape(persistent=True) as tape3: + tape3.watch(x) + tape3.watch(t) + u_model = model(tf.concat([x, t], axis=1)) + u_x = tape3.gradient(u_model, x) + u_t = tape3.gradient(u_model, t) + u_xx = tape2.gradient(u_x, x) + u_xxx = tape.gradient(u_xx, x) + u_squared_x = 2*u_model*u_x + r = u_t - alpha * u_squared_x - rho*u_x - nu*u_xxx + del tape, tape2, tape3 + + # === PDE residual loss (stabilized, consistent) === + pde_loss_L2 = tf.reduce_mean(tf.square(r)) + + r_grid = tf.reshape(r, [Nt, Nx]) + L = xMax - xMin + + pde_loss_grad = tf.reduce_mean( + grad_L2_fft_batch(r_grid, L) + ) + + # mesh-scaled stabilization parameter + lam = 0.01 * (dx**2) * tf.minimum(1.0, epoch / 1000.0) + + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + # === Initial condition === + ic_mask = tf.where(tf.abs(t) < 1e-6) + x_ic = tf.gather(x, ic_mask[:, 0]) + u_ic = u_0(x_ic) + t_ic = tf.zeros_like(x_ic) + u_ic_pred = model(tf.concat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = tf.reduce_mean(tf.square(u_ic_pred - u_ic)) + + # === Periodic BC === + data_fitting_loss_l_r = periodic_bc(model, x, t) + + # === Chebyshev aggregation === + # loss, loss_type = chebyshev_loss_function( + # [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + # lambdas + # ) + # loss, loss_type = augmentedChebyshev_loss_function( + # [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + # lambdas + # ) + loss, loss_type = linear_loss_function( + [pde_loss_H1, data_fitting_loss_0, data_fitting_loss_l_r], + lambdas + ) + + # === Hamiltonian (monitor only) === + H_loss = H( + tf.reshape(u_model, shape=[Nt, Nx]), + tf.reshape(u_x, shape=[Nt, Nx]), + dx + ) + + return ( + loss, + loss_type, + pde_loss_H1, + data_fitting_loss_0, + data_fitting_loss_l_r, + H_loss, + ) + + +# Create the PINN model +model = PINNModel() +model.summary() + +epochs = 5000 # 1000, 2000, 5000 +# # Compile the model +# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), +# loss=lambda y_true, y_pred: custom_loss([x_train, t_train, theta_train], model)[1]) + +# Create the optimizer with a smaller learning rate +# learning_rate = 1e-3 # 1e-4 +# learning_rate_type = 'constant' +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([10, 100], [1e-1, 5e-2, 1e-2]) #OK +# learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay([100, 300], [1e-2, 1e-3, 1e-4]) +# learning_rate = tf.keras.optimizers.schedules.PolynomialDecay( +# initial_learning_rate=1e-3, +# decay_steps=epochs, +# end_learning_rate=1e-5, +# power=2., +# cycle=False, +# name='PolynomialDecay' +# ) +# learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( +# initial_learning_rate=1e-3, +# decay_steps=epochs, # 100 +# decay_rate=0.9, # 0.9 +# staircase=False, +# name='ExponentialDecay' +# ) +# learning_rate_type = 'exponentialDecay' +learning_rate = tf.keras.optimizers.schedules.CosineDecay( + initial_learning_rate=1e-4, + decay_steps=1000, + alpha=0.5, + name='CosineDecay', + warmup_target=None, + warmup_steps=100 +) +learning_rate_type = 'cosineDecay' +# param_values = [delta.numpy()] +trainable = model.trainable_variables +if lambdas.trainable: + trainable += [lambdas] + +if cheb_par.trainable: + trainable += [cheb_par] + +# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, amsgrad=True) +# optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True) +optimizer = tf.keras.optimizers.AdamW(learning_rate=learning_rate, beta_1=0.8, beta_2=0.9, epsilon=1e-07) +# optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate, rho=0.9, momentum=0.0, epsilon=1e-07, centered=False) + +# Training loop +losses = [] +pde_losses = [] +data_fitting_losses_0 = [] +data_fitting_losses_l_r = [] +delta_gradients = [] +# S_losses_min = [] +# S_losses_max = [] +H_losses_min = [] +H_losses_max = [] +H_losses_mean = [] +H_losses_std = [] +H_losses_abs_error = [] +H_losses_rel_error = [] +lambdas_values = [] +lambdas_values.append(lambdas.numpy()) +cheb_par_values = [] +cheb_par_values.append(cheb_par.numpy()) + +# Convert data to tensor because tf.GradientTape() can only watch tensor and not numpy arrays +x_train = tf.convert_to_tensor(x_train) +t_train = tf.convert_to_tensor(t_train) +x_grid, t_grid = np.meshgrid(x.flatten(), t.flatten()) +inputs = tf.convert_to_tensor(np.vstack([x_grid.flatten(), t_grid.flatten()]).T) +stop = False +# Start timer +t0 = time() +for epoch in range(epochs): + if not stop: + # print("# STARTING EPOCH", epoch + 1) + + with tf.GradientTape() as tape: + loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss = custom_loss(inputs, model, epoch) + + # print("Computing gradients") + gradients = tape.gradient(loss, trainable) + # print(gradients[-1]) + # print("Applying gradients") + optimizer.apply_gradients(zip(gradients, trainable)) + # print("Appending losses") + losses.append(loss.numpy()) + pde_losses.append(pde_loss.numpy()) + data_fitting_losses_0.append(data_fitting_loss_0.numpy()) + data_fitting_losses_l_r.append(data_fitting_loss_l_r.numpy()) + H_loss_min = tf.reduce_min(H_loss) + H_loss_max = tf.reduce_max(H_loss) + H_losses_min.append(H_loss_min.numpy()) + H_losses_max.append(H_loss_max.numpy()) + H_loss_mean = tf.reduce_mean(H_loss) + H_loss_std = tf.math.reduce_std(H_loss) + H_losses_mean.append(H_loss_mean.numpy()) + H_losses_std.append(H_loss_std.numpy()) + + H0 = H(u_0(x_grid), u_0_x(x_grid), dx) # H0 = H_loss[0].numpy() + Hf = H_loss.numpy() + H_abs_error = tf.abs(Hf - H0) + H_losses_abs_error.append(tf.reduce_max(H_abs_error).numpy()) + H_rel_error = H_abs_error / tf.abs((H0 + 1e-16)) + H_losses_rel_error.append(H_rel_error[-1].numpy()) + + # # Print S_loss, H_loss + # print(f"S_loss at epoch {epoch + 1}: {S_loss.numpy()}") + # print(f"H_loss at epoch {epoch + 1}: {H_loss.numpy()}") + + if len(losses) > 1 and not lambdas.trainable and do_training: + # SoftAdaptive weights update + # num1 = tf.math.exp(tf.experimental.numpy.cbrt(pde_losses[-1] - pde_losses[-2])) + # num2 = tf.math.exp(tf.experimental.numpy.cbrt(data_fitting_losses_0[-1] - data_fitting_losses_0[-2])) + # num3 = tf.math.exp(tf.experimental.numpy.cbrt(data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2])) + num1 = tf.math.exp((pde_losses[-1] - pde_losses[-2])) + num2 = tf.math.exp((data_fitting_losses_0[-1] - data_fitting_losses_0[-2])) + num3 = tf.math.exp((data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2])) + den = num1 + num2 + num3 + + new_lambdas = tf.stack([num1 / den, num2 / den, num3 / den]) + lambdas.assign(new_lambdas) + lambdas_values.append((lambdas).numpy()) + + if cheb_par.trainable: + cheb_par_values.append(cheb_par.numpy()) + + del tape + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.numpy()}") + + # if len(losses) > 2 and np.abs(losses[-1] - losses[-2]) / np.abs(losses[-2]) < 1e-8: + # stop = True + +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean.numpy()}") +print(f"Hamiltonian standard deviation: {H_loss_std.numpy()}") +print(f"Hamiltonian maximum: {H_loss_max.numpy()}") +print(f"Hamiltonian minimum: {H_loss_min.numpy()}") +# print(f"Hamiltonian absolute error: {H_abs_error.numpy()}") +# print(f"Hamiltonian relative error: {H_rel_error.numpy()}") +print(f"Hamitonian relative error: {H_rel_error[-1].numpy()}") +# Print computation time +print('\nComputation time: {} seconds'.format(time() - t0)) + +import pandas as pd + +df = pd.DataFrame() +df['epoch'] = range(1, epochs + 1) + +def generate_save_fig_string(type, epochs, learning_rate_type, loss_type): + """ + Generates a string for saving figures that includes the number of epochs and the type of learning rate. + + Args: + epochs (int): The number of epochs. + learning_rate_type (str): The type of learning rate. + + Returns: + str: The generated string for saving figures. + """ + return f"./results/{type}_epochs_{epochs}_lr_{learning_rate_type}_{loss_type}.png" + +# Plot the loss history +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_fitting_losses_0, label='Initial Conditions Loss') +plt.semilogy(data_fitting_losses_l_r, label='Periodic Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() + +df['total_loss'] = losses +df['pde_loss'] = pde_losses +df['data_fitting_loss_0'] = data_fitting_losses_0 +df['data_fitting_loss_l_r'] = data_fitting_losses_l_r +df['H_loss_min'] = H_losses_min +df['H_loss_max'] = H_losses_max +df['H_loss_mean'] = H_losses_mean +df['H_loss_std'] = H_losses_std +df['H_loss_abs_error'] = H_losses_abs_error +df['H_loss_rel_error'] = H_losses_rel_error +# df['cheb_par'] = cheb_par_values + +df.to_csv('./results/kdv/training_history.csv', index=False) + + +if save_fig: + save_fig_string = generate_save_fig_string('loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Evaluate the function +x_eval = np.linspace(x_train[0].numpy(), x_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +t_eval = np.linspace(t_train[0].numpy(), t_train[-1].numpy(), 100).reshape((-1, 1)).astype(np.float32) +inputs_eval = [x_eval, t_eval] + + +# Plot the Hamiltonian over epochs +plt.plot(H_losses_min, label='H_loss_min') +plt.plot(H_losses_max, label='H_loss_max') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Plot the average Hamiltonian over epochs with standard deviation +H_losses_mean = np.array(H_losses_mean) +H_losses_std = np.array(H_losses_std) +H_losses_abs_error = np.array(H_losses_abs_error) +H_losses_rel_error = np.array(H_losses_rel_error) + +plt.plot(H_losses_mean) +plt.fill_between(range(len(H_losses_mean)), H_losses_mean - H_losses_std, H_losses_mean + H_losses_std, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_mean', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Plot the standard deviation of the Hamiltonian over epochs +plt.plot(H_losses_std) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_std', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Plot the absolute error of the Hamiltonian over epochs +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_abs_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +# Plot the relative error of the Hamiltonian over epochs +plt.plot(H_losses_rel_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +# plt.legend() +plt.grid() + +if save_fig: + save_fig_string = generate_save_fig_string('H_loss_rel_error', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + + +# Plot the Chebyshev parameter over epochs +if cheb_par.trainable: + plt.plot(tf.sigmoid(cheb_par_values)) + plt.xlabel('Epoch') + plt.ylabel('Chebyshev parameter') + plt.title('Chebyshev parameter over epochs') + plt.grid() + + if save_fig: + save_fig_string = generate_save_fig_string('cheb_par', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() + +from mpl_toolkits.mplot3d import Axes3D + +# Set up meshgrid +N = 600 +tspace = np.linspace(0, tMax, N + 1) +xspace = np.linspace(xMin, xMax, N + 1) +T, X = np.meshgrid(tspace, xspace) +XTgrid = np.vstack([X.flatten(),T.flatten()]).T + +# Determine predictions of u(t, x) +u_pred = model(tf.cast(XTgrid,DTYPE)) + +# Reshape upred +U = u_pred.numpy().reshape(N+1,N+1) + +# Surface plot of solution u(t,x) +fig = plt.figure(figsize=(9,6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X, T, U, cmap='viridis') +ax.view_init(35,35) +ax.set_xlabel('$x$') +ax.set_ylabel('$t$') +ax.set_zlabel('$u_(x,t)$') +ax.set_title('Solution to KdV equation') +ax.set_box_aspect(None, zoom=0.85) + +if save_fig: + save_fig_string = generate_save_fig_string('sol', epochs, learning_rate_type, loss_type) + # save png + plt.savefig(save_fig_string, dpi=300) +plt.show() \ No newline at end of file diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1_pytorch.py b/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1_pytorch.py new file mode 100644 index 0000000..4a028ea --- /dev/null +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_kdv_trainableH1_pytorch.py @@ -0,0 +1,573 @@ +import numpy as np +import torch +import torch.nn as nn +from time import time +import matplotlib.pyplot as plt +from humancompatible.train.dual_optim import ALM + +DTYPE = torch.float32 +torch.set_default_dtype(DTYPE) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +type_pde = 1 +if type_pde == 1: + nu, alpha, rho = -0.022**2, -0.5, 0. + xMin, xMax, tMax = 0., 2., 5. +elif type_pde == 2: + nu, alpha, rho = -1., -3., 0. + xMin, xMax, tMax = -20., 20., 100. + +def Nx_from_arch(width, depth, fac=1.5, d_in=2, d_out=1): + Ntheta = (d_in + 1) * width + (depth - 1) * (width * width + width) + d_out * (width + 1) + Ncoll_target = int(Ntheta / fac) + Nx = int(np.sqrt(Ncoll_target)) + Nt = Nx + return Nx, Nt, Ntheta, Ncoll_target + +width, depth = 80, 4 +Nx, Nt, Ntheta, Ncoll = Nx_from_arch(width=width, depth=depth, fac=10.) + +dx = (xMax - xMin) / (Nx - 1) +dt = tMax / (Nt - 1) +h = max(dx, dt) + +lambdas = torch.tensor([1., 1., 1.], dtype=DTYPE, device=device, requires_grad=False) +# do_training = False +cheb_par = torch.tensor(0.5, dtype=DTYPE, device=device, requires_grad=False) + +x = torch.linspace(xMin, xMax, Nx, dtype=DTYPE, device=device).reshape(-1, 1) +t = torch.linspace(0, tMax, Nt, dtype=DTYPE, device=device).reshape(-1, 1) +x_train = x.reshape(-1, 1) +t_train = t.reshape(-1, 1) +# t_grid, x_grid = torch.meshgrid(t.flatten(), x.flatten(), indexing='ij') +x_grid, t_grid = torch.meshgrid(x.flatten(), t.flatten(), indexing='xy') +inputs = torch.stack([x_grid.flatten(), t_grid.flatten()], dim=1) + +save_fig = True + +def u_0(x): + if type_pde == 1: + return torch.cos(np.pi * x) + elif type_pde == 2: + return 6. / (torch.cosh(x)**2) + +def u_0_x(x): + if type_pde == 1: + return -np.pi * torch.sin(np.pi * x) + elif type_pde == 2: + return -12. * torch.sinh(x) / (torch.cosh(x)**3) + +def periodic_bc(model, x, t): + xL = torch.full_like(x, xMin) + xR = torch.full_like(x, xMax) + uL = model(torch.cat([xL, t], dim=1)) + uR = model(torch.cat([xR, t], dim=1)) + return torch.mean((uL - uR)**2) + +def V(u): + return alpha * (u**3) / 3 + (rho * u**2) / 2 + +def kdv_density(u, u_x): + return V(u) - nu * torch.pow(u_x, 2) / 2 + + +def H(u, u_x, dx, density_fn=kdv_density, axis=-1): + """ + Boole’s rule (8th order) along `axis` for a uniform grid with spacing dx. + Requires (N-1) % 4 == 0. Otherwise applies Boole on the largest valid + prefix and trapezoid rule on the remainder. + """ + f = density_fn(u, u_x) # [..., N] + n = f.shape[axis] + + # Normalize negative axis + axis = axis % f.ndim + + def _trap_rem(rem): + """ + rem: contiguous tail segment integrated with trapezoid rule + """ + left = rem.narrow(-1, 0, rem.shape[-1] - 1) + right = rem.narrow(-1, 1, rem.shape[-1] - 1) + return torch.sum(0.5 * (left + right), dim=-1) * dx + + # Degenerate case + if n <= 1: + return torch.sum(f, dim=axis) * dx + + # Largest prefix satisfying (n1 - 1) % 4 == 0 + n1 = n - ((n - 1) % 4) + + # Boole constant + c = (2.0 * dx) / 45.0 + + # Build index slices + idx_prefix = torch.arange(n1, device=f.device) + + f0 = torch.index_select(f, axis, idx_prefix[0::4]) + f1 = torch.index_select(f, axis, idx_prefix[1::4]) + f2 = torch.index_select(f, axis, idx_prefix[2::4]) + f3 = torch.index_select(f, axis, idx_prefix[3::4]) + f4 = torch.index_select(f, axis, idx_prefix[4::4]) + + # Weighted Boole sum + s = 7.0 * torch.sum(f0, dim=axis) + s += 32.0 * torch.sum(f1, dim=axis) + s += 12.0 * torch.sum(f2, dim=axis) + s += 32.0 * torch.sum(f3, dim=axis) + s += 7.0 * torch.sum(f4, dim=axis) + + boole_part = c * s + + # No remainder + if n1 == n: + return boole_part + + # Tail remainder + rem_idx = torch.arange(n1 - 1, n, device=f.device) + rem = torch.index_select(f, axis, rem_idx) + + tail = _trap_rem(rem) + + return boole_part + tail + +def linear_loss_function(tensors, weights): + stacked = torch.stack(tensors) + weights = weights / torch.sum(weights) + loss = torch.sum(weights * stacked) + return loss, 'ls' + +def chebyshev_loss_function(tensors, weights): + stacked = torch.stack(tensors) + loss = torch.max(weights * stacked) + return loss, 'cs' + +def sigmoid_centered(x): + return 2 * torch.sigmoid(x) - 1 + +class PINNModel(nn.Module): + def __init__(self, num_hidden_layers=depth, num_neurons_per_layer=width): + super().__init__() + layers = [] + in_dim = 2 + for _ in range(num_hidden_layers): + layers.append(nn.Linear(in_dim, num_neurons_per_layer)) + layers.append(nn.Tanh()) + in_dim = num_neurons_per_layer + layers.append(nn.Linear(in_dim, 1)) + self.net = nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + + +def lambda_grad(epoch, + start=1000, + lam_max=1e0, + kappa=1e-3): + epoch = torch.as_tensor(epoch, dtype=torch.float32) + return lam_max * ( + 1.0 - torch.exp(-kappa * torch.clamp(epoch - start, min=0.0)) + ) + + +def grad_L2_fft_batch(r, L): + """ + r : shape (Nt, Nx) + returns : shape (Nt,) + """ + r = r.to(torch.complex64) + Nx = r.shape[-1] + + device = r.device + + k_pos = torch.arange(0, Nx // 2 + 1, + dtype=torch.float32, + device=device) + + k_neg = torch.arange(-Nx // 2 + 1, 0, + dtype=torch.float32, + device=device) + + k = torch.cat([k_pos, k_neg], dim=0) + k = (2.0 * np.pi / L) * k + k = k.to(torch.complex64) + + r_hat = torch.fft.fft(r, dim=-1) + + grad_energy = torch.sum(torch.abs(1j * k * r_hat) ** 2, dim=-1) + + dx = L / float(Nx) + + return torch.real(grad_energy) * dx + + +def H1_norm_fft_batch(r, L): + """ + Compute ||r||_{H^1}^2 for each time slice. + + r : shape (Nt, Nx) + returns : shape (Nt,) + """ + r = r.to(torch.complex64) + Nx = r.shape[-1] + + device = r.device + + k_pos = torch.arange(0, Nx // 2 + 1, + dtype=torch.float32, + device=device) + + k_neg = torch.arange(-Nx // 2 + 1, 0, + dtype=torch.float32, + device=device) + + k = torch.cat([k_pos, k_neg], dim=0) + k = (2.0 * np.pi / L) * k + k = k.to(torch.complex64) + + r_hat = torch.fft.fft(r, dim=-1) + + weight = 1.0 + torch.abs(k) ** 2 + + H1_sq = torch.sum(weight * torch.abs(r_hat) ** 2, dim=-1) + + dx = L / float(Nx) + + return torch.real(H1_sq) * dx + + +def custom_loss(inputs, model, epoch): + """ + Assumes the following globals/functions exist: + + alpha, rho, nu + Nt, Nx + dx, xMin, xMax + lambdas + + u_0(...) + periodic_bc(...) + linear_loss_function(...) + H(...) + """ + + x = inputs[:, 0:1].clone().detach().requires_grad_(True) + t = inputs[:, 1:2].clone().detach().requires_grad_(True) + + xt = torch.cat([x, t], dim=1) + + # Forward pass + u_model = model(xt) + + # First derivatives + u_x = torch.autograd.grad( + u_model, + x, + grad_outputs=torch.ones_like(u_model), + create_graph=True, + retain_graph=True, + )[0] + + u_t = torch.autograd.grad( + u_model, + t, + grad_outputs=torch.ones_like(u_model), + create_graph=True, + retain_graph=True, + )[0] + + # Second derivative + u_xx = torch.autograd.grad( + u_x, + x, + grad_outputs=torch.ones_like(u_x), + create_graph=True, + retain_graph=True, + )[0] + + # Third derivative + u_xxx = torch.autograd.grad( + u_xx, + x, + grad_outputs=torch.ones_like(u_xx), + create_graph=True, + retain_graph=True, + )[0] + + # PDE residual + u_squared_x = 2 * u_model * u_x + + r = ( + u_t + - alpha * u_squared_x + - rho * u_x + - nu * u_xxx + ) + + # === PDE residual loss === + pde_loss_L2 = torch.mean(r ** 2) + + r_grid = r.reshape(Nt, Nx) + + L = xMax - xMin + + pde_loss_grad = torch.mean( + grad_L2_fft_batch(r_grid, L) + ) + + # mesh-scaled stabilization parameter + lam = 0.01 * (dx ** 2) * min(1.0, float(epoch) / 1000.0) + + pde_loss_H1 = pde_loss_L2 + lam * pde_loss_grad + + # === Initial condition === + ic_mask = torch.where(torch.abs(t) < 1e-6)[0] + + x_ic = x[ic_mask] + + u_ic = u_0(x_ic) + + t_ic = torch.zeros_like(x_ic) + + u_ic_pred = model(torch.cat([x_ic, t_ic], dim=1)) + + data_fitting_loss_0 = torch.mean( + (u_ic_pred - u_ic) ** 2 + ) + + # === Periodic BC === + data_fitting_loss_l_r = periodic_bc(model, x, t) + + # === Aggregated loss === + loss, loss_type = linear_loss_function( + [ + pde_loss_H1, + data_fitting_loss_0, + data_fitting_loss_l_r, + ], + lambdas + ) + + # === Hamiltonian (monitor only) === + # breakpoint() + H_loss = H( + u_model.reshape(Nt, Nx), + u_x.reshape(Nt, Nx), + dx + ) + + return ( + loss, + loss_type, + pde_loss_H1, + data_fitting_loss_0, + data_fitting_loss_l_r, + H_loss, + ) + +def lagrangian_loss(inputs, model, dual_opt): + x, t = inputs[:, 0:1], inputs[:, 1:2] + x.requires_grad_(True) + t.requires_grad_(True) + + u_model = model(torch.cat([x, t], dim=1)) + + u_t = torch.autograd.grad(u_model.sum(), t, create_graph=True)[0] + u_x = torch.autograd.grad(u_model.sum(), x, create_graph=True)[0] + + u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0] + u_xxx = torch.autograd.grad(u_xx.sum(), x, create_graph=True)[0] + + u_squared_x = 2 * u_model * u_x + r = u_t - alpha * u_squared_x - rho * u_x - nu * u_xxx + + pde_loss_L2 = torch.mean(torch.square(r)) + + # constraint: pde_loss_L2 = 0 or <= eps + + ic_mask = torch.abs(t) < 1e-6 + x_ic = x[ic_mask[:, 0]] + u_ic = u_0(x_ic) + t_ic = torch.zeros_like(x_ic) + u_ic_pred = model(torch.cat([x_ic, t_ic], axis=1)) + data_fitting_loss_0 = torch.mean((u_ic_pred - u_ic) ** 2) # IC loss + + data_fitting_loss_l_r = periodic_bc(model, x, t) # BC loss + + # ask: what's H_loss here, and should we use it in a constraint + H_loss = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) + + data_fitting_loss = 0.5 * data_fitting_loss_0 + 0.5 * data_fitting_loss_l_r + + lagr = dual_opt.forward_update(data_fitting_loss, pde_loss_L2.unsqueeze(0)) + loss_type = 'ls' + + return lagr, loss_type, pde_loss_L2, data_fitting_loss_0, data_fitting_loss_l_r, H_loss + + + +model = PINNModel().to(device) +epochs = 5000 + +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + +dual_opt = ALM(m=1, lr=1e-3, dual_range=(0.,10.), device=device) + +lr_schedule = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=100, + T_mult=1, + eta_min=0.5*1e-4 +) + +losses, pde_losses, data_losses_0, bc_losses = [], [], [], [] +H_losses_min, H_losses_max, H_losses_mean, H_losses_std = [], [], [], [] +H_losses_abs_error, H_losses_rel_error = [], [] +t0 = time() + +for epoch in range(epochs): + optimizer.zero_grad() + loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = custom_loss(inputs, model, epoch) + loss.backward() + optimizer.step() + lr_schedule.step() + + with torch.no_grad(): + losses.append(loss.item()) + pde_losses.append(pde_loss.item()) + data_losses_0.append(data_loss_0.item()) + bc_losses.append(bc_loss.item()) + + H_loss_min = torch.min(H_loss).item() + H_loss_max = torch.max(H_loss).item() + H_losses_min.append(H_loss_min) + H_losses_max.append(H_loss_max) + H_loss_mean = torch.mean(H_loss).item() + H_loss_std = torch.std(H_loss).item() + H_losses_mean.append(H_loss_mean) + H_losses_std.append(H_loss_std) + + H0 = H(u_0(x_grid.flatten().reshape(-1, 1)).reshape(Nt, Nx), u_0_x(x_grid.flatten().reshape(-1, 1)).reshape(Nt, Nx), dx) + Hf = H_loss.detach() + # breakpoint() + H_abs_error = torch.abs(Hf - H0) + H_losses_abs_error.append(torch.max(H_abs_error).item()) + H_rel_error = H_abs_error / (torch.abs(H0) + 1e-16) + if isinstance(H_rel_error, torch.Tensor): + H_rel_error = H_rel_error.item() if H_rel_error.numel() == 1 else H_rel_error.max().item() + H_losses_rel_error.append(H_rel_error) + + if epoch > 1: + # SoftAdaptive weights update + # num1 = tf.math.exp(tf.experimental.numpy.cbrt(pde_losses[-1] - pde_losses[-2])) + # num2 = tf.math.exp(tf.experimental.numpy.cbrt(data_fitting_losses_0[-1] - data_fitting_losses_0[-2])) + # num3 = tf.math.exp(tf.experimental.numpy.cbrt(data_fitting_losses_l_r[-1] - data_fitting_losses_l_r[-2])) + num1 = np.exp((pde_losses[-1] - pde_losses[-2])) + num2 = np.exp((data_losses_0[-1] - data_losses_0[-2])) + num3 = np.exp((bc_losses[-1] - bc_losses[-2])) + den = num1 + num2 + num3 + + new_lambdas = torch.tensor([num1 / den, num2 / den, num3 / den]) + lambdas = new_lambdas + + if epoch % 100 == 0 or epoch == epochs - 1: + print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.6e}") + +print(f'\nComputation time: {time() - t0:.2f}s') +print(f"Loss type: {loss_type}") +print(f"Hamiltonian mean: {H_loss_mean}") +print(f"Hamiltonian std: {H_loss_std}") +print(f"Hamiltonian max: {H_loss_max}") +print(f"Hamiltonian min: {H_loss_min}") + +plt.figure(figsize=(10, 6)) +plt.semilogy(losses, label='Total Loss') +plt.semilogy(pde_losses, label='PDE Loss') +plt.semilogy(data_losses_0, label='Initial Conditions Loss') +plt.semilogy(bc_losses, label='Periodic Boundary Conditions Loss') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.title('Loss Contributions') +plt.legend() +plt.grid() +plt.savefig('./results/kdv_loss.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_min, label='H_loss_min') +plt.plot(H_losses_max, label='H_loss_max') +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian') +plt.title('Hamiltonian over epochs') +plt.legend() +plt.grid() +plt.savefig('./results/kdv_H_loss.png', dpi=300) if save_fig else None +plt.show() + +H_losses_mean_arr = np.array(H_losses_mean) +H_losses_std_arr = np.array(H_losses_std) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_mean_arr) +plt.fill_between(range(len(H_losses_mean_arr)), H_losses_mean_arr - H_losses_std_arr, H_losses_mean_arr + H_losses_std_arr, alpha=0.2) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian mean') +plt.title('Hamiltonian mean over epochs with standard deviation') +plt.grid() +plt.savefig('./results/kdv_H_loss_mean.png', dpi=300) if save_fig else None +plt.show() + +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_std_arr) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian std') +plt.title('Hamiltonian standard deviation over epochs') +plt.grid() +plt.savefig('./results/kdv_H_loss_std.png', dpi=300) if save_fig else None +plt.show() + +H_losses_abs_error = np.array(H_losses_abs_error) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_abs_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian absolute error') +plt.title('Hamiltonian absolute error over epochs') +plt.grid() +plt.savefig('./results/kdv_H_loss_abs_error.png', dpi=300) if save_fig else None +plt.show() + +H_losses_rel_error = np.array(H_losses_rel_error) +plt.figure(figsize=(10, 6)) +plt.plot(H_losses_rel_error) +plt.xlabel('Epoch') +plt.ylabel('Hamiltonian relative error') +plt.title('Hamiltonian relative error over epochs') +plt.grid() +plt.savefig('./results/kdv_H_loss_rel_error.png', dpi=300) if save_fig else None +plt.show() + +N = 600 +tspace = torch.linspace(0, tMax, N + 1, dtype=DTYPE, device=device) +xspace = torch.linspace(xMin, xMax, N + 1, dtype=DTYPE, device=device) +T_grid, X_grid = torch.meshgrid(tspace, xspace, indexing='ij') +XTgrid = torch.stack([X_grid.flatten(), T_grid.flatten()], dim=1) + +with torch.no_grad(): + u_pred = model(XTgrid) +U = u_pred.reshape(N+1, N+1) + +X_np = X_grid.cpu().numpy() +T_np = T_grid.cpu().numpy() +U_np = U.cpu().numpy() + +from mpl_toolkits.mplot3d import Axes3D +fig = plt.figure(figsize=(9, 6)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X_np, T_np, U_np, cmap='viridis') +ax.set_xlabel('$x$') +ax.set_ylabel('$t$') +ax.set_zlabel('$u(x,t)$') +ax.set_title('KdV equation') +ax.set_box_aspect(None, zoom=0.85) +plt.savefig('./results/kdv_solution.png', dpi=300) if save_fig else None +plt.show() \ No newline at end of file From b6a06d6dabfbde8213f1ee1f57a4d7e78d780ec0 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Tue, 26 May 2026 13:44:21 +0200 Subject: [PATCH 02/30] move h0 in camassa --- ...eservingPINN_CamassaHolm_trainableH1_pytorch.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py index 915aedc..e366a3d 100644 --- a/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py @@ -152,7 +152,7 @@ def custom_loss(inputs, model, epoch): #### LOSS FUNCTION WITH H1 CONSTRAINT #### -def lagrangian_loss(inputs, model, dual_opt, epoch): +def lagrangian_loss(inputs, model, dual_opt, epoch, H0): x, t = inputs[:, 0:1], inputs[:, 1:2] x.requires_grad_(True) t.requires_grad_(True) @@ -190,12 +190,12 @@ def lagrangian_loss(inputs, model, dual_opt, epoch): ) # constraint - H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) + Hf = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) H_constraint = torch.abs(Hf - H0)/torch.abs(H0) - eps = 5/(epoch+1) + eps = 1/(epoch+1) H_constraint = torch.max(H_constraint - eps, torch.zeros_like(H_constraint)).unsqueeze(0) loss = dual_opt.forward_update(loss, H_constraint) @@ -219,13 +219,15 @@ def lagrangian_loss(inputs, model, dual_opt, epoch): # dual_opt = ALM(m=1, lr=5e-5, dual_range=(0.,100.), device=device, ctol=1e-3, penalty=0.) -dual_opt = iALM(m=1, lr=0.1, beta=0.01, sigma=1.0001, gamma=1., dual_range=(0.,10.), ctol=1e-3) +dual_opt = iALM(m=1, beta=0.01, sigma=1.0001, gamma=1., dual_range=(0.,10.), ctol=1e-3) + +H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) for epoch in range(epochs): optimizer.zero_grad() # loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = custom_loss(inputs, model, epoch) - loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = lagrangian_loss(inputs, model, dual_opt, epoch) + loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = lagrangian_loss(inputs, model, dual_opt, epoch, H0) loss.backward() optimizer.step() @@ -249,7 +251,7 @@ def lagrangian_loss(inputs, model, dual_opt, epoch): H_losses_mean.append(H_loss_mean) H_losses_std.append(H_loss_std) - H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) + # H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) Hf = H_loss.detach() H_abs_error = torch.abs(Hf - H0) H_losses_abs_error.append(torch.max(H_abs_error).item()) From 13f1bb96975a1a156229a1568b929c23fab7dfef Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Tue, 26 May 2026 13:52:52 +0200 Subject: [PATCH 03/30] docs --- .readthedocs.yaml | 26 ++++++++++++++++++++++++++ docs/Makefile | 20 ++++++++++++++++++++ docs/make.bat | 35 +++++++++++++++++++++++++++++++++++ docs/source/conf.py | 28 ++++++++++++++++++++++++++++ docs/source/index.rst | 17 +++++++++++++++++ 5 files changed, 126 insertions(+) create mode 100644 .readthedocs.yaml create mode 100644 docs/Makefile create mode 100644 docs/make.bat create mode 100644 docs/source/conf.py create mode 100644 docs/source/index.rst diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..74e0648 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,26 @@ + + +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version, and other tools you might need +build: + os: ubuntu-24.04 + tools: + python: "3.13" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/source/conf.py + +# Optionally, but recommended, +# declare the Python requirements required to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +# python: +# install: +# - requirements: docs/requirements.txt + + diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d0c3cbf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..747ffb7 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..c1921d4 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,28 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'humancompatible-train' +copyright = '2026, Andrii Kliachkin, Gilles Bareillies, Jana Lepsova, Jakub Marecek' +author = 'Andrii Kliachkin, Gilles Bareillies, Jana Lepsova, Jakub Marecek' +release = '0.3.1' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [] + +templates_path = ['_templates'] +exclude_patterns = [] + + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'alabaster' +html_static_path = ['_static'] diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..2d3c295 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,17 @@ +.. humancompatible-train documentation master file, created by + sphinx-quickstart on Tue May 26 14:11:39 2026. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +humancompatible-train documentation +=================================== + +Add your content using ``reStructuredText`` syntax. See the +`reStructuredText `_ +documentation for details. + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + From 9f50604e923929fce0b579409762c9b4bfddec04 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Wed, 27 May 2026 11:45:56 +0200 Subject: [PATCH 04/30] docs update --- .readthedocs.yaml | 7 +- docs/requirements.txt | 3 + docs/source/conf.py | 35 +- docs/source/examples/learn_DAG.rst | 209 +++++++++++ docs/source/getting_started.rst | 49 +++ docs/source/index.rst | 22 +- docs/source/install.rst | 60 ++++ docs/source/support.rst | 53 +++ docs/source/troubleshooting.rst | 87 +++++ docs/source/tutorials/basic_usage.ipynb | 339 ++++++++++++++++++ .../tutorials/inequality_constraints.rst | 6 + 11 files changed, 860 insertions(+), 10 deletions(-) create mode 100644 docs/requirements.txt create mode 100644 docs/source/examples/learn_DAG.rst create mode 100644 docs/source/getting_started.rst create mode 100644 docs/source/install.rst create mode 100644 docs/source/support.rst create mode 100644 docs/source/troubleshooting.rst create mode 100644 docs/source/tutorials/basic_usage.ipynb create mode 100644 docs/source/tutorials/inequality_constraints.rst diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 74e0648..ea94d92 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -19,8 +19,9 @@ sphinx: # Optionally, but recommended, # declare the Python requirements required to build your documentation # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html -# python: -# install: -# - requirements: docs/requirements.txt +python: + install: + - requirements: docs/requirements.txt + - requirements: requirements.txt diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..81556cc --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,3 @@ +sphinx +myst-nb +. \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index c1921d4..85a9577 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,6 +6,12 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information +extensions = ["myst_nb"] + + +import os + + project = 'humancompatible-train' copyright = '2026, Andrii Kliachkin, Gilles Bareillies, Jana Lepsova, Jakub Marecek' author = 'Andrii Kliachkin, Gilles Bareillies, Jana Lepsova, Jakub Marecek' @@ -14,15 +20,38 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = [] templates_path = ['_templates'] exclude_patterns = [] - +nb_toctree = False +nb_number_headings = False +nb_execution_show_tb = False +nb_execution_mode = "cache" +nb_execution_timeout = 60 # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'alabaster' +html_theme = 'furo' + html_static_path = ['_static'] + +html_baseurl = os.environ.get("READTHEDOCS_CANONICAL_URL", "/") + +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "html_image", +] +myst_url_schemes = ("http", "https", "mailto") + +source_suffix = { + ".rst": "restructuredtext", + ".md": "markdown", + ".ipynb": "myst-nb", +} + +nb_execution_mode = "cache" \ No newline at end of file diff --git a/docs/source/examples/learn_DAG.rst b/docs/source/examples/learn_DAG.rst new file mode 100644 index 0000000..e7d072c --- /dev/null +++ b/docs/source/examples/learn_DAG.rst @@ -0,0 +1,209 @@ +Learning Directed Acyclic Graphs (DAGs) from Data +================================================== + +Overview +-------- + +This example demonstrates how to learn a **Directed Acyclic Graph (DAG)** from data using constrained optimization. We follow an approach inspired by the `Cooper `_ library's DAG learning example. + +In this example, we: + +1. Generate synthetic data from a linear structural equation model +2. Define a constrained optimization problem to recover the underlying graph +3. Use the Augmented Lagrangian Method (ALM) to solve the problem +4. Visualize both the learned and ground truth graphs + +What is a DAG? +-------------- + +A Directed Acyclic Graph (DAG) is a graph where: + +- Nodes represent variables or features +- Directed edges represent causal relationships +- There are no cycles (acyclic property) +- The acyclic property ensures a topological ordering exists + +DAG learning is useful in causal inference, discovering variable dependencies, and understanding structural relationships in data. + +Data Generation +--------------- + +We start by generating synthetic data from a linear structural equation model with Gaussian noise: + +.. code-block:: python + + import torch + import numpy as np + import math + + def generate_data(n, d, n_causes, noise_std, device): + """Generate data from a linear structural equation model with Gaussian noise. + + Args: + n: number of samples + d: number of features + n_causes: number of root nodes (nodes with no parents) + noise_std: standard deviation of the noise + device: torch.device + + Returns: + X: Data matrix of shape (n, d) + A: Adjacency matrix of shape (d, d) + """ + # Generate adjacency matrix + A = torch.zeros(d, d, device=device) + + for i in range(n_causes, d): + # Each node (except roots) has random parents from previous nodes + parents = 0 if i == 1 else torch.randperm(i)[:np.random.randint(1, i)] + A[i, parents] = 1 + + # Verify acyclic property + assert torch.trace(torch.linalg.matrix_exp(A)).item() == d, "A is not a DAG" + + # Generate data: X_i = sum(X_parents_i) + noise_i + noise = noise_std * torch.randn(n, d, device=device) + X = torch.zeros(n, d, device=device) + + for i in range(d): + parents = torch.nonzero(A[i]).flatten() + X[:, i] = X[:, parents].sum(dim=1) + noise[:, i] + + # Improve conditioning + X /= math.sqrt(d) + + return X, A + +**Parameters:** + +- ``n``: Number of samples (5,000 in this example) +- ``d``: Number of features/nodes (8 in this example) +- ``n_causes``: Number of root nodes with no parents (2 in this example) +- ``noise_std``: Standard deviation of Gaussian noise (0.01 in this example) + +Training Setup +-------------- + +We formulate the DAG learning problem as a constrained optimization problem: + +.. math:: + + \min_{A \in \{0, 1\}^{d \times d}} \left\| X - XA \right\|_F^2 + + \text{subject to:} \quad \text{tr}(e^A) = d + +The constraint ensures the adjacency matrix ``A`` represents a valid DAG: + +- The exponential matrix ``exp(A)`` has trace equal to ``d`` if and only if ``A`` is acyclic +- This is an algebraic constraint that replaces the combinatorial acyclicity check + +**Implementation:** + +.. code-block:: python + + from humancompatible.train.dual_optim import ALM + from torch.optim import AdamW + + # Initialize adjacency matrix as a learnable parameter + A = torch.nn.Parameter(torch.randn(D, D, device=DEVICE) / math.sqrt(D)) + + # Optimizer for the primal variable (adjacency matrix) + optimizer = AdamW(params=[A], lr=PRIMAL_LR) + + # Dual optimizer using Augmented Lagrangian Method + dual_opt = ALM(m=1) # m=1 constraint + + # Constraint function + constraint = lambda A: torch.trace(torch.linalg.matrix_exp(A)) - d + +Training Loop +------------- + +The training procedure alternates between: + +1. **Primal step**: Update ``A`` to minimize the Lagrangian +2. **Dual step**: Update Lagrange multipliers to enforce constraint satisfaction + +.. code-block:: python + + for i in range(N_STEPS): + # Project to valid range [0, 1] and remove diagonal + A.data.fill_diagonal_(0) + A.data.clamp_(min=0, max=1.0) + + # Compute loss: reconstruction error + loss = torch.square(torch.linalg.norm(X - X @ A.T, ord="fro")) + + # Compute constraint violation + cviol = constraint(A) + + # Update Lagrangian + lagrangian = dual_opt.forward_update(loss, cviol.unsqueeze(0)) + + # Gradient descent on primal variable + lagrangian.backward() + optimizer.step() + optimizer.zero_grad() + +**Key steps:** + +- **Diagonal removal**: No self-loops allowed (``A.fill_diagonal_(0)``) +- **Value clamping**: Adjacency values are bounded to [0, 1] (``A.clamp_(min=0, max=1.0)``) +- **Loss computation**: Measures how well ``A`` predicts the data +- **Constraint enforcement**: The ALM solver tracks dual variables to enforce the acyclicity constraint + +Results and Visualization +-------------------------- + +After training, we can visualize the learned adjacency matrix alongside the ground truth: + +.. code-block:: python + + import networkx as nx + import seaborn as sns + from matplotlib import pyplot as plt + + # Create network graph + G = nx.DiGraph() + G.add_nodes_from(range(D)) + + for i in range(D): + for j in range(D): + if A[i, j] != 0: + G.add_edge(j, i) + + # Visualize + pos = nx.shell_layout(G) + plt.figure(figsize=(5, 2)) + nx.draw(G, pos, with_labels=True, font_weight="bold") + plt.show() + +**Visualization outputs:** + +1. **Adjacency Heatmaps**: Compare learned, ground truth, and difference matrices +2. **Training Progress**: Track loss, constraint violation, and dual parameters over iterations + +The quality of recovery depends on: + +- **Dataset size**: Larger datasets improve recovery +- **Noise level**: Lower noise enables better recovery +- **Training iterations**: More iterations improve convergence +- **Graph density**: Sparser graphs are easier to recover + +Applications +----------- + +DAG learning is useful for: + +- **Causal discovery**: Inferring causal relationships from observational data +- **Biological networks**: Discovering gene regulatory networks +- **Financial systems**: Understanding dependencies between economic indicators +- **Knowledge graphs**: Learning structured relationships from data +- **Feature importance**: Understanding variable interactions + +See Also +-------- + +- :doc:`api_reference` for the ALM solver and optimization utilities +- `Cooper Documentation `_ for more constrained optimization examples +- The full notebook: ``examples/learn_DAG.ipynb`` diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst new file mode 100644 index 0000000..61ce3dd --- /dev/null +++ b/docs/source/getting_started.rst @@ -0,0 +1,49 @@ +Getting Started +=============== + +Quick Start +----------- + +After installing humancompatible-train, you can import it in your Python code: + +.. code-block:: python + + from humancompatible.train.dual_optim import * + +Basic Example +-------------- + +This is an abstract code sample; you can find runnable examples in the :doc:`tutorials/basic_usage` section. + +.. code-block:: python + + from humancompatible.train.dual_optim import ALM + + device = ... + num_constraints = ... + + optimizer = torch.optim.Adam(model.parameters(), ...) + dual_optimizer = ALM(m=num_constraints, ..., device=device) + + for inputs, labels in dataloader: + # evaluate objective + outputs = model(inputs) + loss = criterion(outputs, labels) + # evaluate tensor of constraints + constraints = evaluate_constraints(inputs, labels, ...) + # evaluate lagrangian and update dual variables + lagrangian = dual_optimizer.forward_update(loss, constraints) + # backward pass and step + lagrangian.backward() + optimizer.step() + optimizer.zero_grad() + +.. note:: + + For detailed examples (including inequality constraints), see the :doc:`tutorials/basic_usage` and :doc:`tutorials/inequality_constraints` sections. + +Next Steps +---------- + +- Read the :doc:`Basic Usage ` guide for a complete example +- If you encounter issues, visit the :doc:`Troubleshooting ` page diff --git a/docs/source/index.rst b/docs/source/index.rst index 2d3c295..c26c647 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -6,12 +6,26 @@ humancompatible-train documentation =================================== -Add your content using ``reStructuredText`` syntax. See the -`reStructuredText `_ -documentation for details. +Welcome to the **humancompatible-train** documentation. + +What is **humancompatible-train**? + +**humancompatible-train** is a PyTorch-based package for constrained optimization, aimed at constrained deep learning tasks. +We implement several first-order Lagrangian-based methods for constrained optimization with a PyTorch-based API that allow seamless integration of constraints into the training loop. .. toctree:: :maxdepth: 2 - :caption: Contents: + :caption: Getting Started + :titlesonly: + + install + getting_started + +.. toctree:: + :maxdepth: 1 + :caption: Tutorials + :titlesonly: + Basic usage: Fairness + Handling inequality constraints diff --git a/docs/source/install.rst b/docs/source/install.rst new file mode 100644 index 0000000..0e0ac8b --- /dev/null +++ b/docs/source/install.rst @@ -0,0 +1,60 @@ +Installation +============ + +Prerequisites +------------- + +- Python 3.8 or higher +- pip (Python package manager) +- Virtual environment (recommended) + +Basic Installation +------------------ + +Install the package using pip: + +.. code-block:: bash + + pip install humancompatible-train + +Installation from Source +------------------------ + +To install the development version from source: + +.. code-block:: bash + + git clone https://github.com/humancompatible-train.git + cd humancompatible-train + pip install -e . + +Using Virtual Environment (Recommended) +---------------------------------------- + +It's recommended to install in a virtual environment: + +.. code-block:: bash + + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + pip install humancompatible-train + +Verifying Installation +---------------------- + +To verify your installation was successful: + +.. code-block:: python + + import humancompatible_train + print(humancompatible_train.__version__) + +Optional Dependencies +--------------------- + +For specific features, you may need additional packages: + +.. code-block:: bash + + pip install humancompatible-train[dev] # Development tools + pip install humancompatible-train[docs] # Documentation building \ No newline at end of file diff --git a/docs/source/support.rst b/docs/source/support.rst new file mode 100644 index 0000000..75c0981 --- /dev/null +++ b/docs/source/support.rst @@ -0,0 +1,53 @@ +Support +======= + +Getting Help +------------ + +If you need help with humancompatible-train, here are the recommended channels: + +GitHub Issues +~~~~~~~~~~~~~ + +For bug reports and feature requests, please open an issue on the GitHub repository: + +https://github.com/humancompatible-train + +When reporting an issue, please include: + +- Description of the problem +- Steps to reproduce the issue +- Python version and environment information +- Relevant code snippets or error messages + +Email Support +~~~~~~~~~~~~~ + +You can contact the maintainers at: + +kliacand@fel.cvut.cz + +Documentation +~~~~~~~~~~~~~ + +- Check the :doc:`getting_started` guide for basic information +- Review the :doc:`examples/basic_usage` for common usage patterns +- Consult the :doc:`troubleshooting` page for known issues and solutions +- See the :doc:`examples/api_reference` for API documentation + +Contributing +~~~~~~~~~~~~ + +We welcome contributions! If you'd like to contribute: + +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Submit a pull request + +Additional Resources +~~~~~~~~~~~~~~~~~~~~ + +- Project Homepage: https://github.com/humancompatible-train +- Documentation: https://humancompatible-train.readthedocs.io +- Issue Tracker: https://github.com/humancompatible-train/issues \ No newline at end of file diff --git a/docs/source/troubleshooting.rst b/docs/source/troubleshooting.rst new file mode 100644 index 0000000..c622c76 --- /dev/null +++ b/docs/source/troubleshooting.rst @@ -0,0 +1,87 @@ +Troubleshooting +=============== + +Common Issues and Solutions +--------------------------- + +Installation Issues +~~~~~~~~~~~~~~~~~~~ + +**Problem: "ModuleNotFoundError" when importing humancompatible-train** + +Solution: + 1. Verify installation: ``pip list | grep humancompatible-train`` + 2. Reinstall the package: ``pip install --upgrade humancompatible-train`` + 3. Check your Python version: ``python --version`` (requires Python 3.8+) + +**Problem: "Permission denied" during installation** + +Solution: + Use a virtual environment (recommended): + + .. code-block:: bash + + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + pip install humancompatible-train + +Training Issues +~~~~~~~~~~~~~~~ + +**Problem: Constraints are not being satisfied** + +Solutions: + 1. Verify constraint definitions are correct + 2. Check that constraints are compatible with your data + 3. Increase training time or adjust hyperparameters + 4. Review constraint priorities and weights + +**Problem: Out of memory during training** + +Solutions: + 1. Reduce batch size + 2. Use a smaller dataset for testing + 3. Enable gradient checkpointing if available + 4. Consider distributed training + +Performance Issues +~~~~~~~~~~~~~~~~~~ + +**Problem: Training is very slow** + +Solutions: + 1. Profile your code to identify bottlenecks + 2. Use fewer constraints if possible + 3. Optimize your data loading pipeline + 4. Consider using GPU acceleration + 5. Try reducing dataset size for experimentation + +Getting More Help +----------------- + +If you can't find a solution here: + +1. Check the :doc:`support` page for contact information +2. Review the :doc:`API Reference ` for function signatures +3. Open an issue on the GitHub repository +4. Consult the project's issue tracker for similar problems + +FAQ +--- + +**Q: Which Python versions are supported?** + +A: Python 3.8 and higher. We recommend using Python 3.9 or later. + +.. **Q: Can I use my custom model architecture?** + +.. A: Yes, see the :doc:`Advanced Usage ` section for details on custom constraints and models. + +**Q: How do I report a bug?** + +A: Please open an issue on GitHub with: + + - Description of the problem + - Steps to reproduce + - Python version and environment information + - Relevant code snippets or error messages diff --git a/docs/source/tutorials/basic_usage.ipynb b/docs/source/tutorials/basic_usage.ipynb new file mode 100644 index 0000000..204b7d5 --- /dev/null +++ b/docs/source/tutorials/basic_usage.ipynb @@ -0,0 +1,339 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c6eb906f", + "metadata": {}, + "source": [ + "# Basic Usage\n", + "\n", + "This page provides an overview of using humancompatible-train for constrained deep learning on a simple example." + ] + }, + { + "cell_type": "markdown", + "id": "99f5c31d", + "metadata": {}, + "source": [ + "## Idea" + ] + }, + { + "cell_type": "markdown", + "id": "c44ad563", + "metadata": {}, + "source": [ + "\n", + "The core of the package is formed by Lagrangian-based **dual optimizers**, which are PyTorch Optimizer-like objects that handle the **constrained** part of **constrained deep learning**.\n", + "\n", + "They create, keep track of, and update the **dual parameters** of the constrained minimization problem, as well as calculate the Lagrangian that is then minimized by a standard PyTorch optimizer in place of a loss." + ] + }, + { + "cell_type": "markdown", + "id": "a882bebf", + "metadata": {}, + "source": [ + "Formally, our constrained learning task looks as follows:\n", + "\n", + "$$ \\min_{x\\in\\mathbb{R}^n} \\mathbb{E}[f(x,\\xi)] \\quad \\text{s.t.} \\quad \\mathbb{E}[g(x,\\xi)] \\leq 0 $$\n", + "\n", + "where $ \\xi\\sim\\Xi $ is a random variable, $ f(x,\\xi) $ is the loss function, and $ g(x,\\xi) $ is the constraint function. Of course, the loss and constraint functions need not depend on one random variable; $ \\xi $ can be treated as a concatenation of several of them.\n", + "\n", + "In this notebook, we shall start with the simpler case of **equality constraints**, i.e. $ \\mathbb{E}[g(x,\\xi)] = 0 $; later, are welcome to check out the *using inequality constraints* tutorial.\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "caf7c1f1", + "metadata": {}, + "source": [ + "## Simple Example" + ] + }, + { + "cell_type": "markdown", + "id": "0d282892", + "metadata": {}, + "source": [ + "Let us demonstrate using a **fairness-constrained learning** task, where we want to learn a classifier that is accurate but also satisfies a **demographic parity constraint** - i.e. we would like\n", + "\n", + "$$ | P( Y = 1 | \\text{X is Male}) - P ( Y = 1 | \\text{X is Female} ) | = \\epsilon $$\n", + "\n", + "where $ Y $ is the prediction given by our model for sample $ X $, and $ \\epsilon $ is some small threshold.\n", + "\n", + "*Note: Normally, this would be an inequality constraint, but for the sake of this example let us handle this case first.*\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "86f459d9", + "metadata": {}, + "source": [ + "To enforce demographic parity, we will define a **constraint function** (using the [fairret](https://github.com/aida-ugent/fairret) package) that measures the difference in positive prediction rates between two demographic groups.\n", + "\n", + "The **dual optimizer** will then update the Lagrange multipliers to enforce this constraint during training." + ] + }, + { + "cell_type": "markdown", + "id": "ef078073", + "metadata": {}, + "source": [ + "First, let us load and prepare the data. We will use the ACS dataset, containing U.S. Census data, provided by the [folktables](https://github.com/socialfoundations/folktables) package. Feel free to skip this section." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cfe37f7", + "metadata": {}, + "outputs": [], + "source": [ + "# load data\n", + "import torch\n", + "import numpy as np\n", + "from sklearn.preprocessing import StandardScaler\n", + "from folktables import ACSDataSource, ACSIncome, generate_categories\n", + "\n", + "torch.set_default_dtype(torch.float32)\n", + "\n", + "# load folktables data\n", + "data_source = ACSDataSource(survey_year=\"2018\", horizon=\"1-Year\", survey=\"person\")\n", + "acs_data = data_source.get_data(states=[\"FL\"], download=True)\n", + "definition_df = data_source.get_definitions(download=True)\n", + "categories = generate_categories(\n", + " features=ACSIncome.features, definition_df=definition_df\n", + ")\n", + "df_feat, df_labels, _ = ACSIncome.df_to_pandas(\n", + " acs_data, categories=categories, dummies=True\n", + ")\n", + "sens_cols = [\"SEX_Female\", \"SEX_Male\"]\n", + "features = df_feat.drop(columns=sens_cols).to_numpy(dtype=np.float32)\n", + "labels = df_labels.to_numpy(dtype=np.float32)\n", + "# one-hot encoding of the sensitive attribute (gender)\n", + "groups = df_feat[sens_cols].to_numpy(dtype=np.float32)\n", + "\n", + "# standardize features\n", + "scaler = StandardScaler()\n", + "features = scaler.fit_transform(features)\n", + "# convert to torch tensors\n", + "X = torch.tensor(features) ; y = torch.tensor(labels) ; groups = torch.tensor(groups)\n", + "\n", + "dataset_train = torch.utils.data.TensorDataset(X, groups, y)\n", + "loader = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True)\n", + "criterion = torch.nn.BCEWithLogitsLoss()" + ] + }, + { + "cell_type": "markdown", + "id": "d62fd731", + "metadata": {}, + "source": [ + "Initialize the model and optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "014c8597", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.nn import Sequential\n", + "from torch.optim import AdamW\n", + "\n", + "def setup_model():\n", + "\n", + " model = Sequential(\n", + " torch.nn.Linear(features.shape[1], 64),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(64, 32),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(32, 1),\n", + " )\n", + "\n", + " optimizer = AdamW(model.parameters())\n", + " return model, optimizer" + ] + }, + { + "cell_type": "markdown", + "id": "d189818f", + "metadata": {}, + "source": [ + "Next, we define the **constraint function** for demographic parity, which uses the `fairret.statistic.PositiveRate` class to evaluate positive rates for both groups." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a3e5e31", + "metadata": {}, + "outputs": [], + "source": [ + "from fairret.statistic import PositiveRate\n", + "\n", + "statistic = PositiveRate()\n", + "\n", + "def pr_diff(logit, groups):\n", + " preds = torch.sigmoid(logit)\n", + " stats = PositiveRate()(preds, groups)\n", + " stat_diff = torch.abs(stats[0] - stats[1])\n", + " return stat_diff" + ] + }, + { + "cell_type": "markdown", + "id": "fad51b80", + "metadata": {}, + "source": [ + "As a last step, we define our **dual optimizer**. To set it up, we only need to define the **number of constraints** -- in our case, it is 1 -- so it can create the corresponding dual variables." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e97715f3", + "metadata": {}, + "outputs": [], + "source": [ + "from humancompatible.train.dual_optim import ALM\n", + "\n", + "dual_optimizer = ALM(m=1, lr=0.01)" + ] + }, + { + "cell_type": "markdown", + "id": "9d74bed3", + "metadata": {}, + "source": [ + "Finally, we write our training loop. In addition to the forward pass and loss calculation, we add a constraint calculation step (0.05 is our $ \\epsilon $ threshold).\n", + "\n", + "Then, the `forward_update` step does two things:\n", + "- Updates the dual variables based on the constraint violation,\n", + "- Calculates the Lagrangian based on loss and constraint violation.\n", + "\n", + "We then perform a backward pass on the Lagrangian and minimize it using a normal PyTorch optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07e7f2cd", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 10\n", + "\n", + "model, optimizer = setup_model()\n", + "\n", + "for epoch in range(epochs):\n", + " # eval\n", + " model.eval()\n", + " logit = model(X)\n", + " train_loss = criterion(logit, y).item()\n", + " train_fair = pr_diff(logit, groups).item()\n", + " print(f\"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}\")\n", + " \n", + " # train\n", + " model.train()\n", + " for batch_feat, batch_groups, batch_label in loader:\n", + " optimizer.zero_grad()\n", + " logit = model(batch_feat)\n", + " loss = criterion(logit, batch_label)\n", + " \n", + " constraint = pr_diff(logit, batch_groups) - 0.05\n", + " lagr = dual_optimizer.forward_update(loss, constraint.unsqueeze(0))\n", + " lagr.backward()\n", + " \n", + " optimizer.step()\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "1fc9dd6e", + "metadata": {}, + "source": [ + "Due to noise, it is difficult to obtain the exact correct value, but the method attempts to keep the constraint around 0.05!" + ] + }, + { + "cell_type": "markdown", + "id": "5898dc39", + "metadata": {}, + "source": [ + "Just in case, let's check what happens if we train the model without constraints:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a506bb87", + "metadata": {}, + "outputs": [], + "source": [ + "model, optimizer = setup_model()\n", + "\n", + "for epoch in range(epochs):\n", + " # eval\n", + " model.eval()\n", + " logit = model(X)\n", + " train_loss = criterion(logit, y).item()\n", + " train_fair = pr_diff(logit, groups).item()\n", + " print(f\"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}\")\n", + " \n", + " # train\n", + " model.train()\n", + " for batch_feat, batch_groups, batch_label in loader:\n", + " optimizer.zero_grad()\n", + " logit = model(batch_feat)\n", + " loss = criterion(logit, batch_label)\n", + "\n", + " loss.backward()\n", + " optimizer.step()" + ] + }, + { + "cell_type": "markdown", + "id": "c40628c4", + "metadata": {}, + "source": [ + "The absolute difference in positive rates is two times higher than what we wanted!" + ] + }, + { + "cell_type": "markdown", + "id": "cb735dea", + "metadata": {}, + "source": [ + "Further reading:" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hc-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/inequality_constraints.rst b/docs/source/tutorials/inequality_constraints.rst new file mode 100644 index 0000000..e1fadac --- /dev/null +++ b/docs/source/tutorials/inequality_constraints.rst @@ -0,0 +1,6 @@ +Handling Inequality Constraints +================================================== + +Here, we'll demonstrate how to handle inequality constraints. +TODO. +weight constraint exanmple \ No newline at end of file From a28b8bc6900186be4d04606c416ce0313b49f174 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Mon, 1 Jun 2026 16:31:52 +0200 Subject: [PATCH 05/30] skeleton for other tutorials --- docs/requirements.txt | 3 +- .../tutorials/inequality_constraints.ipynb | 40 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 docs/source/tutorials/inequality_constraints.ipynb diff --git a/docs/requirements.txt b/docs/requirements.txt index 81556cc..5059523 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,4 @@ sphinx myst-nb -. \ No newline at end of file +. +furo \ No newline at end of file diff --git a/docs/source/tutorials/inequality_constraints.ipynb b/docs/source/tutorials/inequality_constraints.ipynb new file mode 100644 index 0000000..bd85f3d --- /dev/null +++ b/docs/source/tutorials/inequality_constraints.ipynb @@ -0,0 +1,40 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e5f39da5", + "metadata": {}, + "source": [ + "# Handling inequality constraints" + ] + }, + { + "cell_type": "markdown", + "id": "c79e284f", + "metadata": {}, + "source": [ + "This notebook will demonstrate how to work with inequality constraints." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hc-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From dd5bd7d07f5ceac92945285cfadbd28576149c95 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Tue, 2 Jun 2026 11:35:47 +0200 Subject: [PATCH 06/30] change dual_range in alm and ialm to allow negative dual values --- docs/source/install.rst | 26 +---- src/humancompatible/train/dual_optim/alm.py | 2 +- src/humancompatible/train/dual_optim/ialm.py | 104 ++++++++++--------- 3 files changed, 58 insertions(+), 74 deletions(-) diff --git a/docs/source/install.rst b/docs/source/install.rst index 0e0ac8b..6e797db 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -4,7 +4,7 @@ Installation Prerequisites ------------- -- Python 3.8 or higher +- Python 3.11 or higher - pip (Python package manager) - Virtual environment (recommended) @@ -28,27 +28,6 @@ To install the development version from source: cd humancompatible-train pip install -e . -Using Virtual Environment (Recommended) ----------------------------------------- - -It's recommended to install in a virtual environment: - -.. code-block:: bash - - python -m venv venv - source venv/bin/activate # On Windows: venv\Scripts\activate - pip install humancompatible-train - -Verifying Installation ----------------------- - -To verify your installation was successful: - -.. code-block:: python - - import humancompatible_train - print(humancompatible_train.__version__) - Optional Dependencies --------------------- @@ -56,5 +35,4 @@ For specific features, you may need additional packages: .. code-block:: bash - pip install humancompatible-train[dev] # Development tools - pip install humancompatible-train[docs] # Documentation building \ No newline at end of file + pip install humancompatible-train[examples] # Example notebooks \ No newline at end of file diff --git a/src/humancompatible/train/dual_optim/alm.py b/src/humancompatible/train/dual_optim/alm.py index 9cb4a87..41236c5 100644 --- a/src/humancompatible/train/dual_optim/alm.py +++ b/src/humancompatible/train/dual_optim/alm.py @@ -16,7 +16,7 @@ def __init__( init_duals: float | Tensor = None, penalty: float = 1.0, *, - dual_range: Tuple[float, float] = (0.0, 100.0), + dual_range: Tuple[float, float] = (-100.0, 100.0), momentum: float = 0.0, dampening: float = 0.0, ctol: float = 0., diff --git a/src/humancompatible/train/dual_optim/ialm.py b/src/humancompatible/train/dual_optim/ialm.py index 82d1965..5d9028c 100644 --- a/src/humancompatible/train/dual_optim/ialm.py +++ b/src/humancompatible/train/dual_optim/ialm.py @@ -12,16 +12,15 @@ class iALM(Optimizer): def __init__( self, m: int = None, - lr: float = 0.01, + beta: float = 1.0, + sigma: float = 1.0, + gamma: float = 1.0, init_duals: float | Tensor = None, penalty: float = 1.0, *, - dual_range: Tuple[float, float] = (0.0, 100.0), + dual_range: Tuple[float, float] = (-100., 100.), momentum: float = 0.0, dampening: float = 0.0, - beta: float = 1.0, - sigma: float = 1.0, - gamma: float = 1.0, ctol: float = 1e-4, device=None, ) -> None: @@ -30,8 +29,12 @@ def __init__( :param m: Number of constraints (determines the number of dual variables to create) :type m: int - :param lr: Dual variable update rate - :type lr: float + :param beta: Dual variable update rate. + :type beta: float + :param sigma: Multiplier for increasing`beta`. + :type sigma: float + :param gamma: Penalty update parameter. + :type gamma: float :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. :type init_duals: float | Tensor :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` @@ -42,12 +45,6 @@ def __init__( :type momentum: float :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. :type dampening: float - :param beta: Dual variable update rate - :type beta: float - :param sigma: Multiplier for increasing`beta`. - :type sigma: float - :param gamma: Penalty update parameter - :type gamma: float :param ctol: Constraint tolerance; value that allows tiny violations of constraints to account for noise. :type ctol: float """ @@ -57,14 +54,14 @@ def __init__( self.dual_range = dual_range - self.beta = beta + # self.beta = beta self.penalty = penalty - self.gamma = gamma - self.sigma = sigma + # self.gamma = gamma + # self.sigma = sigma self.ctol = ctol duals, defaults = self._init_constraint_group( - m, lr, momentum, dampening, init_duals, dual_range, device + m, beta, sigma, gamma, momentum, dampening, init_duals, dual_range, device ) super().__init__(duals, defaults) @@ -72,7 +69,9 @@ def __init__( @staticmethod def _init_constraint_group( m: int = None, - lr: float = None, + beta: float = None, + sigma: float = None, + gamma: float = None, momentum: float = None, dampening: float = None, init_duals: float | Tensor = None, @@ -98,7 +97,9 @@ def _init_constraint_group( duals = Parameter(init_duals, requires_grad=False) settings_dict = { - "lr": lr, + "beta": Parameter(torch.tensor(beta), requires_grad=False), + "sigma": Parameter(torch.tensor(sigma), requires_grad=False), + "gamma": Parameter(torch.tensor(gamma), requires_grad=False), "momentum": momentum, "dampening": dampening, "momentum_buffer": torch.zeros_like( @@ -121,7 +122,9 @@ def duals(self) -> Tensor: def add_constraint_group( self, m: int = None, - lr: float = None, + beta: float = 1.0, + sigma: float = 1.0, + gamma: float = 1.0, momentum: float = None, dampening: float = None, init_duals: Tensor = None, @@ -131,13 +134,21 @@ def add_constraint_group( :param m: Size of group (number of dual variables to add) :type m: int - :param lr: Dual variable update rate - :type lr: float + :param beta: Dual variable update rate + :type beta: float + :param sigma: Multiplier for increasing `beta` + :type sigma: float + :param gamma: Penalty update parameter + :type gamma: float + :param momentum: Momentum for dual variable updates + :type momentum: float + :param dampening: Dampening for momentum + :type dampening: float :param init_duals: Initial values for the new dual variables :type init_duals: Tensor """ duals, settings_dict = self._init_constraint_group( - m, lr, momentum, dampening, init_duals, self.dual_range + m, beta, sigma, gamma, momentum, dampening, init_duals, self.dual_range, self.device ) param_group_dict = {"params": duals, **settings_dict} self.add_param_group(param_group_dict) @@ -156,9 +167,11 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) for i, group in enumerate(self.param_groups): - duals, lr, momentum, dampening, momentum_buffer = ( + duals, beta, _, _, momentum, dampening, momentum_buffer = ( group["params"][0], - group["lr"], + group["beta"], + group["sigma"], + group["gamma"], group["momentum"], group["dampening"], group["momentum_buffer"], @@ -170,7 +183,7 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) - lagrangian.add_(0.5 * self.beta * torch.dot(constraints, constraints)) + lagrangian.add_(0.5 * beta * torch.dot(constraints, constraints)) return lagrangian @@ -182,9 +195,11 @@ def update(self, constraints: Tensor) -> None: :type constraints: Tensor """ for i, group in enumerate(self.param_groups): - duals, lr, momentum, dampening, momentum_buffer = ( + duals, beta, sigma, gamma, momentum, dampening, momentum_buffer = ( group["params"][0], - group["lr"], + group["beta"], + group["sigma"], + group["gamma"], group["momentum"], group["dampening"], group["momentum_buffer"], @@ -196,16 +211,13 @@ def update(self, constraints: Tensor) -> None: _update_duals( duals, group_constraints, - lr, - self.beta, - self.gamma, - momentum, - dampening, + beta, + gamma, momentum_buffer, ) clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) - self.beta *= self.sigma + beta.mul_(sigma) # evaluate the Lagrangian and update the dual variables def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: @@ -222,9 +234,11 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) for i, group in enumerate(self.param_groups): - duals, lr, momentum, dampening, momentum_buffer = ( + duals, beta, sigma, gamma, momentum, dampening, momentum_buffer = ( group["params"][0], - group["lr"], + group["beta"], + group["sigma"], + group["gamma"], group["momentum"], group["dampening"], group["momentum_buffer"], @@ -238,12 +252,8 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: ) _update_duals( duals, - group_constraints, - lr, - self.beta, - self.gamma, - momentum, - dampening, + beta, + gamma, momentum_buffer, ) clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) @@ -252,11 +262,11 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian.add_( 0.5 - * self.beta + * beta * torch.dot(constraints - self.ctol, constraints - self.ctol) ) - self.beta *= self.sigma + beta.mul_(sigma) return lagrangian @@ -296,14 +306,10 @@ def _update_c_buffers( def _update_duals( duals: Tensor, - constraints: Tensor, - lr: float, beta: float, gamma: float, - momentum: float, - dampening: float, buffer: Tensor, ) -> None: - update_mult = min(beta, gamma / (buffer @ buffer)) + update_mult = torch.min(beta, gamma / (buffer @ buffer)) duals.add_(buffer, alpha=update_mult) From 708a36976685b110d6d59620d132c3b6fb3127e8 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Tue, 2 Jun 2026 12:36:08 +0200 Subject: [PATCH 07/30] further docs --- docs/source/index.rst | 2 + docs/source/tutorials/basic_usage.ipynb | 42 ++++++++++++++-- docs/source/tutorials/copt_overview.rst | 65 +++++++++++++++++++++++++ docs/source/tutorials/tips.rst | 4 ++ 4 files changed, 109 insertions(+), 4 deletions(-) create mode 100644 docs/source/tutorials/copt_overview.rst create mode 100644 docs/source/tutorials/tips.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index c26c647..907cadf 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -27,5 +27,7 @@ We implement several first-order Lagrangian-based methods for constrained optimi :caption: Tutorials :titlesonly: + Constrained Optimization Overview Basic usage: Fairness Handling inequality constraints + Tips and Tricks \ No newline at end of file diff --git a/docs/source/tutorials/basic_usage.ipynb b/docs/source/tutorials/basic_usage.ipynb index 204b7d5..a9da905 100644 --- a/docs/source/tutorials/basic_usage.ipynb +++ b/docs/source/tutorials/basic_usage.ipynb @@ -40,7 +40,7 @@ "\n", "where $ \\xi\\sim\\Xi $ is a random variable, $ f(x,\\xi) $ is the loss function, and $ g(x,\\xi) $ is the constraint function. Of course, the loss and constraint functions need not depend on one random variable; $ \\xi $ can be treated as a concatenation of several of them.\n", "\n", - "In this notebook, we shall start with the simpler case of **equality constraints**, i.e. $ \\mathbb{E}[g(x,\\xi)] = 0 $; later, are welcome to check out the *using inequality constraints* tutorial.\n", + "In this notebook, we shall start with the simpler case of **equality constraints**, i.e. $ \\mathbb{E}[g(x,\\xi)] = 0 $. You are also welcome to check out the c tutorial.\n", "\n", "---" ] @@ -204,7 +204,7 @@ "source": [ "from humancompatible.train.dual_optim import ALM\n", "\n", - "dual_optimizer = ALM(m=1, lr=0.01)" + "dual_optimizer = ALM(m=1, lr=0.01, dual_range=(-100, 100), init_duals=0.)" ] }, { @@ -226,7 +226,24 @@ "execution_count": null, "id": "07e7f2cd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 0, loss: 0.7124184966087341, constraint: 0.0017570257186889648\n", + "Epoch: 1, loss: 0.39874470233917236, constraint: 0.06620797514915466\n", + "Epoch: 2, loss: 0.39147812128067017, constraint: 0.056648969650268555\n", + "Epoch: 3, loss: 0.3884408175945282, constraint: 0.04519534111022949\n", + "Epoch: 4, loss: 0.38256096839904785, constraint: 0.046931684017181396\n", + "Epoch: 5, loss: 0.37441322207450867, constraint: 0.04030010104179382\n", + "Epoch: 6, loss: 0.3695501685142517, constraint: 0.036804407835006714\n", + "Epoch: 7, loss: 0.35965678095817566, constraint: 0.03822091221809387\n", + "Epoch: 8, loss: 0.354665070772171, constraint: 0.03923347592353821\n", + "Epoch: 9, loss: 0.34533509612083435, constraint: 0.03696507215499878\n" + ] + } + ], "source": [ "epochs = 10\n", "\n", @@ -276,7 +293,24 @@ "execution_count": null, "id": "a506bb87", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 0, loss: 0.6904992461204529, constraint: 0.0005057156085968018\n", + "Epoch: 1, loss: 0.39664506912231445, constraint: 0.08427131175994873\n", + "Epoch: 2, loss: 0.38788166642189026, constraint: 0.09328612685203552\n", + "Epoch: 3, loss: 0.3807773292064667, constraint: 0.0948910117149353\n", + "Epoch: 4, loss: 0.3751002252101898, constraint: 0.09216496348381042\n", + "Epoch: 5, loss: 0.3642417788505554, constraint: 0.09508025646209717\n", + "Epoch: 6, loss: 0.35605600476264954, constraint: 0.09377643465995789\n", + "Epoch: 7, loss: 0.3492068350315094, constraint: 0.09720361232757568\n", + "Epoch: 8, loss: 0.3390880823135376, constraint: 0.09501060843467712\n", + "Epoch: 9, loss: 0.32931336760520935, constraint: 0.09538483619689941\n" + ] + } + ], "source": [ "model, optimizer = setup_model()\n", "\n", diff --git a/docs/source/tutorials/copt_overview.rst b/docs/source/tutorials/copt_overview.rst new file mode 100644 index 0000000..a2091ae --- /dev/null +++ b/docs/source/tutorials/copt_overview.rst @@ -0,0 +1,65 @@ +Constrained Optimization Overview +================================= + +This tutorial provides an overview of constrained optimization problems, and how this relates to Deep Learning. We will cover problem formulation, .... + + +Formulation +--------------- + +In `humancompatible-train`, and in Constrained Machine Learning more generally, we are interested in solving problems of the form: + +.. math:: + \min_{x\in\mathbb{R}^n} \quad & \mathbb{E}[f(x,\xi)] \\ + \text{s.t.} \quad & \mathbb{E}[g(x,\xi)] \leq 0, \\ + & \mathbb{E}[h(x,\xi)] = 0, \\ + +where :math:`f` is the **objective function** we want to minimize, :math:`g` are the **inequality constraints**, and :math:`h` are the **equality constraints**. The expectation is taken over some random variable :math:`\xi`, which represents the data. + +You may recognize the first line of the above formula as the standard formulation of an optimization problem in machine learning, where we want to minimize the expected loss over the data. \ +The second line introduces a constraint -- this could be anything from some bound on the parameters of the model, or a requirement on the model's predictions to satisfy some fairness criterion, to the boundary conditions of a physical system. + + +.. note:: + - As is standard in the field, we adopt the convention of writing the constraints as :math:`g(x) \leq 0`, and :math:`h(x) = 0`. This is just a notational choice, and does not affect the generality of the formulation. It is trivial to transform :math:`g(x) \geq 0` into :math:`-g(x) \leq 0`, or :math:`g(x) \leq \epsilon` into :math:`g(x) - \epsilon \leq 0` for some non-zero bound. + - It is also easy to switch between equality and inequality constraints: to achieve :math:`g(x) = 0`, one can set :math:`-g(x) \leq 0` and :math:`g(x) \leq 0` simultaneously. In fact, different algorithms are designed to handle either equality or inequality constraints natively, but, again, it is trivial to switch between the two. We shall see more concrete examples later on. + + +Solving Constrained Problems +-------------------------------- + +We all know how to solve an unconstrained optimization problem -- we can use gradient descent, or any of its variants. But how do we solve a constrained optimization problem? +The Constrained Machine Learning field, including us, seems to have converged on **Lagrangian-based methods**, which utilize the Lagrangian function to transform the **constrained** problem into an **unconstrained** one. + +Going forward in this tutorial, we will focus on the **deterministic case** to simplify notation; the stochastic case is more complex, but utilizes the same principles (imagine Gradient Descent vs. SGD). For more rigorous mathematical treatment of the stochastic case, see **TODO**, as well as the references included in the documentation for each of the algorithms in the package. + +In a deterministic case, the Lagrangian function is defined as follows: + +.. math:: + \mathcal{L}(x, \lambda, \mu) = f(x) + \lambda^T g(x) + \mu^T h(x) + +where :math:`\lambda` is the Lagrange multiplier associated with the constraint :math:`g(x) \leq 0`, and :math:`\mu` is the Lagrange multiplier associated with the constraint :math:`h(x) = 0`. + +It is then possible to show that the original constrained optimization problem is equivalent to the following unconstrained optimization problem: + +.. math:: + \min_{x\in\mathbb{R}^n} \max_{\lambda \geq 0, \mu} \mathcal{L}(x, \lambda, \mu) + + +We refer to the original problem as the **primal problem**, with :math:`x` as the **primal variables**, and to the transformed problem as the **dual problem**, with :math:`\lambda` and :math:`\mu` as the **dual variables**. The dual problem is unconstrained, and can be solved using a clever application of standard optimization techniques. + +In particular, the most common approach is to use **alternating optimization**: we fix the primal variables, and optimize the dual variables using gradient ascent; then we fix the dual variables, and optimize the primal variables using gradient descent. This process is repeated until convergence. + +In `humancompatible-train`, we implement several variants of this approach: the Augmented Lagrangian Method (ALM), the Inexact Augmented Lagrangian Method (iALM), and the Penalty-Barrier Method (PBM). For more details, see the corresponding documentation; for now, it is important to understand that they are all based on the same principle of alternating optimization of the primal and dual variables. + +In the simplest case of the Lagrangian method, this gives us the following update rules: + +.. math:: + \lambda_{t+1} & = \lambda_t + \beta \nabla_\lambda \mathcal{L}(x_{t}, \lambda_t, \mu_t) = \lambda_t + \beta g(x_{t}) \\ + \mu_{t+1} & = \mu_t + \gamma \nabla_\mu \mathcal{L}(x_{t}, \lambda_t, \mu_t) = \mu_t + \gamma h(x_{t}) \\ + x_{t+1} & = x_t - \alpha \nabla_x \mathcal{L}(x_t, \lambda_t, \mu_t) + +where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are the learning rates for the primal and dual variables, respectively. + +.. note:: + - The above update rules are for the simplest variant of the Lagrangian method. The methods implemented in this package are all more complex. Even beyond our implementation, one can (and sometimes should!) modify the update rules by e.g. tweaking the training loop code, as we show in the :doc:`tips` tutorial. \ No newline at end of file diff --git a/docs/source/tutorials/tips.rst b/docs/source/tutorials/tips.rst new file mode 100644 index 0000000..a514784 --- /dev/null +++ b/docs/source/tutorials/tips.rst @@ -0,0 +1,4 @@ +Tips and Tricks +================================================== + +Here, we discuss some miscellaneous tricks you can use to improve your experience. \ No newline at end of file From 5ecd9f6c9fb9b7af0a8ca49ffdbe52c240e86edf Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Tue, 2 Jun 2026 16:23:39 +0200 Subject: [PATCH 08/30] move doc notebooks to myst --- docs/source/conf.py | 2 +- docs/source/index.rst | 2 +- docs/source/tutorials/basic_usage.ipynb | 373 ------------------ docs/source/tutorials/basic_usage.md | 210 ++++++++++ docs/source/tutorials/copt_overview.rst | 11 +- .../tutorials/inequality_constraints.ipynb | 40 -- .../tutorials/inequality_constraints.md | 144 +++++++ docs/source/tutorials/tips.rst | 9 +- 8 files changed, 371 insertions(+), 420 deletions(-) delete mode 100644 docs/source/tutorials/basic_usage.ipynb create mode 100644 docs/source/tutorials/basic_usage.md delete mode 100644 docs/source/tutorials/inequality_constraints.ipynb create mode 100644 docs/source/tutorials/inequality_constraints.md diff --git a/docs/source/conf.py b/docs/source/conf.py index 85a9577..705d2c4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -50,7 +50,7 @@ source_suffix = { ".rst": "restructuredtext", - ".md": "markdown", + ".md": "myst-nb", ".ipynb": "myst-nb", } diff --git a/docs/source/index.rst b/docs/source/index.rst index 907cadf..7c32d9c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,4 +30,4 @@ We implement several first-order Lagrangian-based methods for constrained optimi Constrained Optimization Overview Basic usage: Fairness Handling inequality constraints - Tips and Tricks \ No newline at end of file + Tips and Tricks \ No newline at end of file diff --git a/docs/source/tutorials/basic_usage.ipynb b/docs/source/tutorials/basic_usage.ipynb deleted file mode 100644 index a9da905..0000000 --- a/docs/source/tutorials/basic_usage.ipynb +++ /dev/null @@ -1,373 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "c6eb906f", - "metadata": {}, - "source": [ - "# Basic Usage\n", - "\n", - "This page provides an overview of using humancompatible-train for constrained deep learning on a simple example." - ] - }, - { - "cell_type": "markdown", - "id": "99f5c31d", - "metadata": {}, - "source": [ - "## Idea" - ] - }, - { - "cell_type": "markdown", - "id": "c44ad563", - "metadata": {}, - "source": [ - "\n", - "The core of the package is formed by Lagrangian-based **dual optimizers**, which are PyTorch Optimizer-like objects that handle the **constrained** part of **constrained deep learning**.\n", - "\n", - "They create, keep track of, and update the **dual parameters** of the constrained minimization problem, as well as calculate the Lagrangian that is then minimized by a standard PyTorch optimizer in place of a loss." - ] - }, - { - "cell_type": "markdown", - "id": "a882bebf", - "metadata": {}, - "source": [ - "Formally, our constrained learning task looks as follows:\n", - "\n", - "$$ \\min_{x\\in\\mathbb{R}^n} \\mathbb{E}[f(x,\\xi)] \\quad \\text{s.t.} \\quad \\mathbb{E}[g(x,\\xi)] \\leq 0 $$\n", - "\n", - "where $ \\xi\\sim\\Xi $ is a random variable, $ f(x,\\xi) $ is the loss function, and $ g(x,\\xi) $ is the constraint function. Of course, the loss and constraint functions need not depend on one random variable; $ \\xi $ can be treated as a concatenation of several of them.\n", - "\n", - "In this notebook, we shall start with the simpler case of **equality constraints**, i.e. $ \\mathbb{E}[g(x,\\xi)] = 0 $. You are also welcome to check out the c tutorial.\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "id": "caf7c1f1", - "metadata": {}, - "source": [ - "## Simple Example" - ] - }, - { - "cell_type": "markdown", - "id": "0d282892", - "metadata": {}, - "source": [ - "Let us demonstrate using a **fairness-constrained learning** task, where we want to learn a classifier that is accurate but also satisfies a **demographic parity constraint** - i.e. we would like\n", - "\n", - "$$ | P( Y = 1 | \\text{X is Male}) - P ( Y = 1 | \\text{X is Female} ) | = \\epsilon $$\n", - "\n", - "where $ Y $ is the prediction given by our model for sample $ X $, and $ \\epsilon $ is some small threshold.\n", - "\n", - "*Note: Normally, this would be an inequality constraint, but for the sake of this example let us handle this case first.*\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "id": "86f459d9", - "metadata": {}, - "source": [ - "To enforce demographic parity, we will define a **constraint function** (using the [fairret](https://github.com/aida-ugent/fairret) package) that measures the difference in positive prediction rates between two demographic groups.\n", - "\n", - "The **dual optimizer** will then update the Lagrange multipliers to enforce this constraint during training." - ] - }, - { - "cell_type": "markdown", - "id": "ef078073", - "metadata": {}, - "source": [ - "First, let us load and prepare the data. We will use the ACS dataset, containing U.S. Census data, provided by the [folktables](https://github.com/socialfoundations/folktables) package. Feel free to skip this section." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7cfe37f7", - "metadata": {}, - "outputs": [], - "source": [ - "# load data\n", - "import torch\n", - "import numpy as np\n", - "from sklearn.preprocessing import StandardScaler\n", - "from folktables import ACSDataSource, ACSIncome, generate_categories\n", - "\n", - "torch.set_default_dtype(torch.float32)\n", - "\n", - "# load folktables data\n", - "data_source = ACSDataSource(survey_year=\"2018\", horizon=\"1-Year\", survey=\"person\")\n", - "acs_data = data_source.get_data(states=[\"FL\"], download=True)\n", - "definition_df = data_source.get_definitions(download=True)\n", - "categories = generate_categories(\n", - " features=ACSIncome.features, definition_df=definition_df\n", - ")\n", - "df_feat, df_labels, _ = ACSIncome.df_to_pandas(\n", - " acs_data, categories=categories, dummies=True\n", - ")\n", - "sens_cols = [\"SEX_Female\", \"SEX_Male\"]\n", - "features = df_feat.drop(columns=sens_cols).to_numpy(dtype=np.float32)\n", - "labels = df_labels.to_numpy(dtype=np.float32)\n", - "# one-hot encoding of the sensitive attribute (gender)\n", - "groups = df_feat[sens_cols].to_numpy(dtype=np.float32)\n", - "\n", - "# standardize features\n", - "scaler = StandardScaler()\n", - "features = scaler.fit_transform(features)\n", - "# convert to torch tensors\n", - "X = torch.tensor(features) ; y = torch.tensor(labels) ; groups = torch.tensor(groups)\n", - "\n", - "dataset_train = torch.utils.data.TensorDataset(X, groups, y)\n", - "loader = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True)\n", - "criterion = torch.nn.BCEWithLogitsLoss()" - ] - }, - { - "cell_type": "markdown", - "id": "d62fd731", - "metadata": {}, - "source": [ - "Initialize the model and optimizer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "014c8597", - "metadata": {}, - "outputs": [], - "source": [ - "from torch.nn import Sequential\n", - "from torch.optim import AdamW\n", - "\n", - "def setup_model():\n", - "\n", - " model = Sequential(\n", - " torch.nn.Linear(features.shape[1], 64),\n", - " torch.nn.ReLU(),\n", - " torch.nn.Linear(64, 32),\n", - " torch.nn.ReLU(),\n", - " torch.nn.Linear(32, 1),\n", - " )\n", - "\n", - " optimizer = AdamW(model.parameters())\n", - " return model, optimizer" - ] - }, - { - "cell_type": "markdown", - "id": "d189818f", - "metadata": {}, - "source": [ - "Next, we define the **constraint function** for demographic parity, which uses the `fairret.statistic.PositiveRate` class to evaluate positive rates for both groups." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2a3e5e31", - "metadata": {}, - "outputs": [], - "source": [ - "from fairret.statistic import PositiveRate\n", - "\n", - "statistic = PositiveRate()\n", - "\n", - "def pr_diff(logit, groups):\n", - " preds = torch.sigmoid(logit)\n", - " stats = PositiveRate()(preds, groups)\n", - " stat_diff = torch.abs(stats[0] - stats[1])\n", - " return stat_diff" - ] - }, - { - "cell_type": "markdown", - "id": "fad51b80", - "metadata": {}, - "source": [ - "As a last step, we define our **dual optimizer**. To set it up, we only need to define the **number of constraints** -- in our case, it is 1 -- so it can create the corresponding dual variables." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e97715f3", - "metadata": {}, - "outputs": [], - "source": [ - "from humancompatible.train.dual_optim import ALM\n", - "\n", - "dual_optimizer = ALM(m=1, lr=0.01, dual_range=(-100, 100), init_duals=0.)" - ] - }, - { - "cell_type": "markdown", - "id": "9d74bed3", - "metadata": {}, - "source": [ - "Finally, we write our training loop. In addition to the forward pass and loss calculation, we add a constraint calculation step (0.05 is our $ \\epsilon $ threshold).\n", - "\n", - "Then, the `forward_update` step does two things:\n", - "- Updates the dual variables based on the constraint violation,\n", - "- Calculates the Lagrangian based on loss and constraint violation.\n", - "\n", - "We then perform a backward pass on the Lagrangian and minimize it using a normal PyTorch optimizer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "07e7f2cd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, loss: 0.7124184966087341, constraint: 0.0017570257186889648\n", - "Epoch: 1, loss: 0.39874470233917236, constraint: 0.06620797514915466\n", - "Epoch: 2, loss: 0.39147812128067017, constraint: 0.056648969650268555\n", - "Epoch: 3, loss: 0.3884408175945282, constraint: 0.04519534111022949\n", - "Epoch: 4, loss: 0.38256096839904785, constraint: 0.046931684017181396\n", - "Epoch: 5, loss: 0.37441322207450867, constraint: 0.04030010104179382\n", - "Epoch: 6, loss: 0.3695501685142517, constraint: 0.036804407835006714\n", - "Epoch: 7, loss: 0.35965678095817566, constraint: 0.03822091221809387\n", - "Epoch: 8, loss: 0.354665070772171, constraint: 0.03923347592353821\n", - "Epoch: 9, loss: 0.34533509612083435, constraint: 0.03696507215499878\n" - ] - } - ], - "source": [ - "epochs = 10\n", - "\n", - "model, optimizer = setup_model()\n", - "\n", - "for epoch in range(epochs):\n", - " # eval\n", - " model.eval()\n", - " logit = model(X)\n", - " train_loss = criterion(logit, y).item()\n", - " train_fair = pr_diff(logit, groups).item()\n", - " print(f\"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}\")\n", - " \n", - " # train\n", - " model.train()\n", - " for batch_feat, batch_groups, batch_label in loader:\n", - " optimizer.zero_grad()\n", - " logit = model(batch_feat)\n", - " loss = criterion(logit, batch_label)\n", - " \n", - " constraint = pr_diff(logit, batch_groups) - 0.05\n", - " lagr = dual_optimizer.forward_update(loss, constraint.unsqueeze(0))\n", - " lagr.backward()\n", - " \n", - " optimizer.step()\n", - " " - ] - }, - { - "cell_type": "markdown", - "id": "1fc9dd6e", - "metadata": {}, - "source": [ - "Due to noise, it is difficult to obtain the exact correct value, but the method attempts to keep the constraint around 0.05!" - ] - }, - { - "cell_type": "markdown", - "id": "5898dc39", - "metadata": {}, - "source": [ - "Just in case, let's check what happens if we train the model without constraints:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a506bb87", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, loss: 0.6904992461204529, constraint: 0.0005057156085968018\n", - "Epoch: 1, loss: 0.39664506912231445, constraint: 0.08427131175994873\n", - "Epoch: 2, loss: 0.38788166642189026, constraint: 0.09328612685203552\n", - "Epoch: 3, loss: 0.3807773292064667, constraint: 0.0948910117149353\n", - "Epoch: 4, loss: 0.3751002252101898, constraint: 0.09216496348381042\n", - "Epoch: 5, loss: 0.3642417788505554, constraint: 0.09508025646209717\n", - "Epoch: 6, loss: 0.35605600476264954, constraint: 0.09377643465995789\n", - "Epoch: 7, loss: 0.3492068350315094, constraint: 0.09720361232757568\n", - "Epoch: 8, loss: 0.3390880823135376, constraint: 0.09501060843467712\n", - "Epoch: 9, loss: 0.32931336760520935, constraint: 0.09538483619689941\n" - ] - } - ], - "source": [ - "model, optimizer = setup_model()\n", - "\n", - "for epoch in range(epochs):\n", - " # eval\n", - " model.eval()\n", - " logit = model(X)\n", - " train_loss = criterion(logit, y).item()\n", - " train_fair = pr_diff(logit, groups).item()\n", - " print(f\"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}\")\n", - " \n", - " # train\n", - " model.train()\n", - " for batch_feat, batch_groups, batch_label in loader:\n", - " optimizer.zero_grad()\n", - " logit = model(batch_feat)\n", - " loss = criterion(logit, batch_label)\n", - "\n", - " loss.backward()\n", - " optimizer.step()" - ] - }, - { - "cell_type": "markdown", - "id": "c40628c4", - "metadata": {}, - "source": [ - "The absolute difference in positive rates is two times higher than what we wanted!" - ] - }, - { - "cell_type": "markdown", - "id": "cb735dea", - "metadata": {}, - "source": [ - "Further reading:" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "hc-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/source/tutorials/basic_usage.md b/docs/source/tutorials/basic_usage.md new file mode 100644 index 0000000..f1292ff --- /dev/null +++ b/docs/source/tutorials/basic_usage.md @@ -0,0 +1,210 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.19.3 +kernelspec: + display_name: hc-dev + language: python + name: python3 +--- + +# Basic Usage + +This page provides an overview of using humancompatible-train for constrained deep learning on a simple example. + ++++ + +## Idea + ++++ + +The core of the package is formed by Lagrangian-based **dual optimizers**, which are PyTorch Optimizer-like objects that handle the **constrained** part of **constrained deep learning**. + +They create, keep track of, and update the **dual parameters** of the constrained minimization problem, as well as calculate the Lagrangian that is then minimized by a standard PyTorch optimizer in place of a loss. + ++++ + +## Simple Example + ++++ + +Let us demonstrate using a **fairness-constrained learning** task, where we want to learn a classifier that is accurate but also satisfies a **demographic parity constraint** - i.e. we would like + +$$ | P( Y = 1 | \text{X is Male}) - P ( Y = 1 | \text{X is Female} ) | = \epsilon $$ + +where $ Y $ is the prediction given by our model for sample $ X $, and $ \epsilon $ is some small threshold. + +:::{note} +Note: Normally, this would be an inequality constraint, but for the sake of this example let us handle this case first. +::: +--- + ++++ + +To enforce demographic parity, we will define a **constraint function** (using the [fairret](https://github.com/aida-ugent/fairret) package) that measures the difference in positive prediction rates between two demographic groups. + +The **dual optimizer** will then update the Lagrange multipliers to enforce this constraint during training. + ++++ + +First, let us load and prepare the data. We will use the ACS dataset, containing U.S. Census data, provided by the [folktables](https://github.com/socialfoundations/folktables) package. Feel free to skip this section. + +```{code-cell} ipython3 +--- +tags: [hide-cell] +--- +# load data +import torch +import numpy as np +from sklearn.preprocessing import StandardScaler +from folktables import ACSDataSource, ACSIncome, generate_categories + +torch.set_default_dtype(torch.float32) + +# load folktables data +data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person") +acs_data = data_source.get_data(states=["FL"], download=True) +definition_df = data_source.get_definitions(download=True) +categories = generate_categories( + features=ACSIncome.features, definition_df=definition_df +) +df_feat, df_labels, _ = ACSIncome.df_to_pandas( + acs_data, categories=categories, dummies=True +) +sens_cols = ["SEX_Female", "SEX_Male"] +features = df_feat.drop(columns=sens_cols).to_numpy(dtype=np.float32) +labels = df_labels.to_numpy(dtype=np.float32) +# one-hot encoding of the sensitive attribute (gender) +groups = df_feat[sens_cols].to_numpy(dtype=np.float32) + +# standardize features +scaler = StandardScaler() +features = scaler.fit_transform(features) +# convert to torch tensors +X = torch.tensor(features) ; y = torch.tensor(labels) ; groups = torch.tensor(groups) + +dataset_train = torch.utils.data.TensorDataset(X, groups, y) +loader = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True) +criterion = torch.nn.BCEWithLogitsLoss() +``` + +Initialize the model and optimizer. + +```{code-cell} ipython3 +from torch.nn import Sequential +from torch.optim import AdamW + +def setup_model(): + + model = Sequential( + torch.nn.Linear(features.shape[1], 64), + torch.nn.ReLU(), + torch.nn.Linear(64, 32), + torch.nn.ReLU(), + torch.nn.Linear(32, 1), + ) + model.forward(torch.zeros(features.shape[1])).backward() # dummy forward/backward pass to construct torch graph for fair comparison + optimizer = AdamW(model.parameters()) + return model, optimizer +``` + +Next, we define the **constraint function** for demographic parity, which uses the `fairret.statistic.PositiveRate` class to evaluate positive rates for both groups. + +```{code-cell} ipython3 +from fairret.statistic import PositiveRate + +statistic = PositiveRate() + +def pr_diff(logit, groups): + preds = torch.sigmoid(logit) + stats = PositiveRate()(preds, groups) + stat_diff = torch.abs(stats[0] - stats[1]) + return stat_diff +``` + +As a last step, we define our **dual optimizer**. To set it up, we only need to define the **number of constraints** -- in our case, it is 1 -- so it can create the corresponding dual variables. + +```{code-cell} ipython3 +from humancompatible.train.dual_optim import ALM + +dual_optimizer = ALM(m=1, lr=0.01) +``` + +Finally, we write our training loop. In addition to the forward pass and loss calculation, we add a constraint calculation step (0.05 is our $ \epsilon $ threshold). + +Then, the `forward_update` step does two things: +- Updates the dual variables based on the constraint violation, +- Calculates the Lagrangian based on loss and constraint violation. + +We then perform a backward pass on the Lagrangian and minimize it using a normal PyTorch optimizer. + +```{code-cell} ipython3 +model, optimizer = setup_model() +epochs = 10 +``` + +```{code-cell} ipython3 +for epoch in range(epochs): + # eval + model.eval() + logit = model(X) + train_loss = criterion(logit, y).item() + train_fair = pr_diff(logit, groups).item() + print(f"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}") + + # train + model.train() + for batch_feat, batch_groups, batch_label in loader: + optimizer.zero_grad() + logit = model(batch_feat) + loss = criterion(logit, batch_label) + + constraint = pr_diff(logit, batch_groups) - 0.05 + lagr = dual_optimizer.forward_update(loss, constraint.unsqueeze(0)) + lagr.backward() + + optimizer.step() +``` + +Due to noise, it is difficult to obtain the exact correct value, but the method attempts to keep the constraint around 0.05! + ++++ + +Just in case, let's check what happens if we train the model without constraints: + +```{code-cell} ipython3 +model, optimizer = setup_model() +``` + +```{code-cell} ipython3 +for epoch in range(epochs): + # eval + model.eval() + logit = model(X) + train_loss = criterion(logit, y).item() + train_fair = pr_diff(logit, groups).item() + print(f"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}") + + # train + model.train() + for batch_feat, batch_groups, batch_label in loader: + optimizer.zero_grad() + logit = model(batch_feat) + loss = criterion(logit, batch_label) + + loss.backward() + optimizer.step() +``` + +The absolute difference in positive rates is two times higher than what we wanted! + ++++ + +Further reading: + +```{code-cell} ipython3 + +``` diff --git a/docs/source/tutorials/copt_overview.rst b/docs/source/tutorials/copt_overview.rst index a2091ae..a0bd734 100644 --- a/docs/source/tutorials/copt_overview.rst +++ b/docs/source/tutorials/copt_overview.rst @@ -17,7 +17,7 @@ In `humancompatible-train`, and in Constrained Machine Learning more generally, where :math:`f` is the **objective function** we want to minimize, :math:`g` are the **inequality constraints**, and :math:`h` are the **equality constraints**. The expectation is taken over some random variable :math:`\xi`, which represents the data. You may recognize the first line of the above formula as the standard formulation of an optimization problem in machine learning, where we want to minimize the expected loss over the data. \ -The second line introduces a constraint -- this could be anything from some bound on the parameters of the model, or a requirement on the model's predictions to satisfy some fairness criterion, to the boundary conditions of a physical system. +We then introduce constraints -- they could express anything from some bounds on the parameters of the model, or a requirement on the model's predictions to satisfy some fairness criterion, to the boundary conditions of a physical system. .. note:: @@ -48,9 +48,9 @@ It is then possible to show that the original constrained optimization problem i We refer to the original problem as the **primal problem**, with :math:`x` as the **primal variables**, and to the transformed problem as the **dual problem**, with :math:`\lambda` and :math:`\mu` as the **dual variables**. The dual problem is unconstrained, and can be solved using a clever application of standard optimization techniques. -In particular, the most common approach is to use **alternating optimization**: we fix the primal variables, and optimize the dual variables using gradient ascent; then we fix the dual variables, and optimize the primal variables using gradient descent. This process is repeated until convergence. +In particular, the most common approach is to use **alternating updates**: we fix the primal variables, and optimize the dual variables using gradient ascent; then we fix the dual variables, and optimize the primal variables using gradient descent. This process is repeated until convergence. -In `humancompatible-train`, we implement several variants of this approach: the Augmented Lagrangian Method (ALM), the Inexact Augmented Lagrangian Method (iALM), and the Penalty-Barrier Method (PBM). For more details, see the corresponding documentation; for now, it is important to understand that they are all based on the same principle of alternating optimization of the primal and dual variables. +In `humancompatible-train`, we implement several variants of this approach, based on methods present in the literature. For more details, see the corresponding documentation; for now, it is important to understand that they are all based on the same principle of alternating updates to the primal and dual variables. In the simplest case of the Lagrangian method, this gives us the following update rules: @@ -62,4 +62,7 @@ In the simplest case of the Lagrangian method, this gives us the following updat where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are the learning rates for the primal and dual variables, respectively. .. note:: - - The above update rules are for the simplest variant of the Lagrangian method. The methods implemented in this package are all more complex. Even beyond our implementation, one can (and sometimes should!) modify the update rules by e.g. tweaking the training loop code, as we show in the :doc:`tips` tutorial. \ No newline at end of file + - The above update rules are for the simplest variant of the Lagrangian method. The methods implemented in this package are all more complex. Even beyond our implementation, one can (and sometimes should!) modify the update rules by e.g. tweaking the training loop code, as we show in the :doc:`tips` tutorial. + - The above update rules are for the deterministic case. In the stochastic case, the gradients are estimated using mini-batches of data, which introduces additional noise into the optimization process. This can make convergence more challenging, but we have some tricks up our sleeves, such as momentum, LR scheduling, and so on. + +In our package, the `dual optimizers` handle the updates to the dual variables, while the primal updates are handled by the standard PyTorch optimizers. This allows for seamless integration of constraints into the training loop, as we will see in the next tutorial. diff --git a/docs/source/tutorials/inequality_constraints.ipynb b/docs/source/tutorials/inequality_constraints.ipynb deleted file mode 100644 index bd85f3d..0000000 --- a/docs/source/tutorials/inequality_constraints.ipynb +++ /dev/null @@ -1,40 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "e5f39da5", - "metadata": {}, - "source": [ - "# Handling inequality constraints" - ] - }, - { - "cell_type": "markdown", - "id": "c79e284f", - "metadata": {}, - "source": [ - "This notebook will demonstrate how to work with inequality constraints." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "hc-dev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/source/tutorials/inequality_constraints.md b/docs/source/tutorials/inequality_constraints.md new file mode 100644 index 0000000..50214a9 --- /dev/null +++ b/docs/source/tutorials/inequality_constraints.md @@ -0,0 +1,144 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.19.3 +kernelspec: + display_name: hc-dev + language: python + name: python3 +--- + +# Inequality Constraints + +In the [Basic Usage tutorial](#basic_usage), we learned how to set up the constrained optimization problem with **equality constraints**. Here, we will generalize this to **inequality constraints**. For simplicity, we will keep the same setup, but our constraint will look like: + +$$ | P( Y = 1 | \text{X is Male}) - P ( Y = 1 | \text{X is Female} ) | \leq \epsilon $$ + +where $ Y $ is the prediction given by our model for sample $ X $, and $ \epsilon $ is some small threshold. + ++++ + +Prepare [folktables](https://github.com/socialfoundations/folktables) data: + +```{code-cell} ipython3 +--- +tags: [hide-cell] +--- +# load data +import torch +import numpy as np +from sklearn.preprocessing import StandardScaler +from folktables import ACSDataSource, ACSIncome, generate_categories + +torch.set_default_dtype(torch.float32) + +# load folktables data +data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person") +acs_data = data_source.get_data(states=["FL"], download=True) +definition_df = data_source.get_definitions(download=True) +categories = generate_categories( + features=ACSIncome.features, definition_df=definition_df +) +df_feat, df_labels, _ = ACSIncome.df_to_pandas( + acs_data, categories=categories, dummies=True +) +sens_cols = ["SEX_Female", "SEX_Male"] +features = df_feat.drop(columns=sens_cols).to_numpy(dtype=np.float32) +labels = df_labels.to_numpy(dtype=np.float32) +# one-hot encoding of the sensitive attribute (gender) +groups = df_feat[sens_cols].to_numpy(dtype=np.float32) + +# standardize features +scaler = StandardScaler() +features = scaler.fit_transform(features) +# convert to torch tensors +X = torch.tensor(features) ; y = torch.tensor(labels) ; groups = torch.tensor(groups) + +dataset_train = torch.utils.data.TensorDataset(X, groups, y) +loader = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True) +criterion = torch.nn.BCEWithLogitsLoss() +``` + +Initialize the model and optimizer. + +```{code-cell} ipython3 +--- +tags: [hide-cell] +--- +from torch.nn import Sequential +from torch.optim import AdamW + +def setup_model(): + + model = Sequential( + torch.nn.Linear(features.shape[1], 64), + torch.nn.ReLU(), + torch.nn.Linear(64, 32), + torch.nn.ReLU(), + torch.nn.Linear(32, 1), + ) + model.forward(torch.zeros(features.shape[1])).backward() # dummy forward/backward pass to construct torch graph for fair comparison + optimizer = AdamW(model.parameters()) + return model, optimizer +``` + +Next, we define the **constraint function** for demographic parity, which uses the `fairret.statistic.PositiveRate` class to evaluate positive rates for both groups. + +```{code-cell} ipython3 +from fairret.statistic import PositiveRate + +statistic = PositiveRate() + +def pr_diff(logit, groups): + preds = torch.sigmoid(logit) + stats = PositiveRate()(preds, groups) + stat_diff = torch.abs(stats[0] - stats[1]) + return stat_diff +``` + +As a last step, we define our **dual optimizer**. To set it up, we only need to define the **number of constraints** -- in our case, it is 1 -- so it can create the corresponding dual variables. + +```{code-cell} ipython3 +from humancompatible.train.dual_optim import ALM + +dual_optimizer = ALM(m=1, lr=0.01) +``` + +Finally, we write our training loop. In addition to the forward pass and loss calculation, we add a constraint calculation step (0.05 is our $ \epsilon $ threshold). + +Then, the `forward_update` step does two things: +- Updates the dual variables based on the constraint violation, +- Calculates the Lagrangian based on loss and constraint violation. + +We then perform a backward pass on the Lagrangian and minimize it using a normal PyTorch optimizer. + +```{code-cell} ipython3 +model, optimizer = setup_model() +epochs = 10 +``` + +```{code-cell} ipython3 +for epoch in range(epochs): + # eval + model.eval() + logit = model(X) + train_loss = criterion(logit, y).item() + train_fair = pr_diff(logit, groups).item() + print(f"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}") + + # train + model.train() + for batch_feat, batch_groups, batch_label in loader: + optimizer.zero_grad() + logit = model(batch_feat) + loss = criterion(logit, batch_label) + + constraint = pr_diff(logit, batch_groups) - 0.05 + lagr = dual_optimizer.forward_update(loss, constraint.unsqueeze(0)) + lagr.backward() + + optimizer.step() +``` diff --git a/docs/source/tutorials/tips.rst b/docs/source/tutorials/tips.rst index a514784..c0b10ce 100644 --- a/docs/source/tutorials/tips.rst +++ b/docs/source/tutorials/tips.rst @@ -1,4 +1,11 @@ Tips and Tricks ================================================== -Here, we discuss some miscellaneous tricks you can use to improve your experience. \ No newline at end of file +Here, we discuss some miscellaneous tricks and tips for using the package, which are not specific to any particular method, but can be useful in general when working with constrained optimization problems. + +Dealing with Noise +------------------ + +In the stochastic case, the gradients are estimated using mini-batches of data, which introduces additional noise into the optimization process. This can make convergence more challenging, but this can be mitigated. + +**Momentum**: Just like in standard optimization, using momentum can help smooth out the updates and mitigate the noise. In `humancompatible-train`, the dual optimizers support momentum, which can be enabled by setting the `momentum` parameter to a non-zero value. From 715af2d8c3dd61729d189541f99e192650073445 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Tue, 2 Jun 2026 18:30:56 +0200 Subject: [PATCH 09/30] alm: add and fix built-in inequality constraint handling; relax on strict satisfaction --- src/humancompatible/train/dual_optim/alm.py | 90 +++++++++++++-------- 1 file changed, 57 insertions(+), 33 deletions(-) diff --git a/src/humancompatible/train/dual_optim/alm.py b/src/humancompatible/train/dual_optim/alm.py index 41236c5..8c1626f 100644 --- a/src/humancompatible/train/dual_optim/alm.py +++ b/src/humancompatible/train/dual_optim/alm.py @@ -19,6 +19,7 @@ def __init__( dual_range: Tuple[float, float] = (-100.0, 100.0), momentum: float = 0.0, dampening: float = 0.0, + is_ineq: bool = False, ctol: float = 0., device=None, ) -> None: @@ -39,6 +40,8 @@ def __init__( :type momentum: float :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. :type dampening: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. :type ctol: float """ @@ -46,12 +49,12 @@ def __init__( if momentum > 0 and dampening == 0: dampening = momentum - self.dual_range = dual_range - self.ctol = ctol + # self.dual_range = dual_range + # self.ctol = ctol self.penalty = penalty duals, defaults = _init_constraint_group( - m, lr, momentum, dampening, init_duals, dual_range, device + m, lr, momentum, dampening, init_duals, dual_range, is_ineq, device ) super().__init__(duals, defaults) @@ -71,6 +74,9 @@ def add_constraint_group( momentum: float = None, dampening: float = None, init_duals: Tensor = None, + dual_range: tuple[float, float] = None, + is_ineq: bool = False, + device = None ) -> None: """ Allows to add a group of dual variables with separate initial values and learning rates. @@ -81,9 +87,13 @@ def add_constraint_group( :type lr: float :param init_duals: Initial values for the new dual variables :type init_duals: Tensor + :param dual_range: After each dual update, the dual variables will be clamped to this range. + :type dual_range: Tuple[float, float] + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be relaxed on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool """ duals, settings_dict = _init_constraint_group( - m, lr, momentum, dampening, init_duals, self.dual_range + m, lr, momentum, dampening, init_duals, dual_range, is_ineq, device ) param_group_dict = {"params": duals, **settings_dict} self.add_param_group(param_group_dict) @@ -94,7 +104,7 @@ def _add_penalty_term(self, lagrangian: Tensor, constraints: Tensor) -> None: lagrangian.add_( 0.5 * self.penalty - * torch.dot(constraints - self.ctol, constraints - self.ctol) + * torch.dot(constraints, constraints) ) @@ -114,7 +124,7 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: for i in range(len(self.param_groups)): duals, group_constraints = _process_constraint_group( - self.param_groups[i], i, constraints, self.ctol, self.dual_range, update_duals=False + self.param_groups[i], i, constraints, update_duals=False ) lagrangian.add_(duals @ group_constraints) @@ -131,7 +141,7 @@ def update(self, constraints: Tensor) -> None: """ for i in range(len(self.param_groups)): _process_constraint_group( - self.param_groups[i], i, constraints, self.ctol, self.dual_range, update_duals=True + self.param_groups[i], i, constraints, update_duals=True ) # evaluate the Lagrangian and update the dual variables @@ -151,7 +161,7 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: for i in range(len(self.param_groups)): duals, group_constraints = _process_constraint_group( - self.param_groups[i], i, constraints, self.ctol, self.dual_range, update_duals=True + self.param_groups[i], i, constraints, update_duals=True ) lagrangian.add_(duals @ group_constraints) @@ -162,7 +172,6 @@ def state_dict(self) -> dict[str, Any]: state_dict = super().state_dict() state_dict["state"]["penalty"] = self.penalty - state_dict["state"]["dual_range"] = self.dual_range # save params themselves in state_dict instead of param ID in default PyTorch for id_pg, pg in enumerate(state_dict["param_groups"]): pg["params"] = [ @@ -184,8 +193,6 @@ def _process_constraint_group( group: dict[str, Any], group_idx: int, constraints: Tensor, - ctol: float, - dual_range: Tuple[float, float], update_duals: bool = False, ) -> Tuple[Tensor, Tensor]: """ @@ -194,27 +201,29 @@ def _process_constraint_group( :param group: The constraint group dictionary :param group_idx: Index of the constraint group :param constraints: Full constraints tensor - :param ctol: Constraint tolerance - :param dual_range: Safeguarding range for dual variables :param update_duals: Whether to update dual variables :return: Tuple of (duals, group_constraints) """ duals = group["params"][0] group_constraints = ( - constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] - ctol + constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] ) - if update_duals: - lr = group.get("lr") - momentum = group.get("momentum", 0.0) - dampening = group.get("dampening", 0.0) - momentum_buffer = group["momentum_buffer"] + lr = group.get("lr") + momentum = group.get("momentum", 0.0) + dampening = group.get("dampening", 0.0) + momentum_buffer = group["momentum_buffer"] + dual_lb = group.get("lower_bound") + dual_ub = group.get("upper_bound") + is_ineq = group.get("is_ineq") + + with torch.no_grad(): + if momentum > 0: + _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) + if update_duals: + _update_duals(duals, momentum_buffer if momentum > 0 else group_constraints, lr) + clamp_(duals, min=dual_lb, max=dual_ub) - with torch.no_grad(): - _update_duals( - duals, group_constraints, lr, momentum, dampening, momentum_buffer - ) - clamp_(duals, min=dual_range[0], max=dual_range[1]) return duals, group_constraints @@ -226,26 +235,31 @@ def _init_constraint_group( dampening: float = None, init_duals: float | Tensor = None, dual_range: Tuple[float, float] = None, - device=None, + is_ineq: bool = False, + device = None, ): ## checks ## if init_duals is None and m is None: - raise ValueError("At least one of`m`,`init_duals` must be set") + raise ValueError("At least one of m, init_duals must be set") if momentum is not None and (momentum < 0 or momentum > 1): - raise ValueError(f"`momentum`must be within [0,1]; got {momentum}") + raise ValueError(f"momentum must be within [0,1]; got {momentum}") + + if not isinstance(is_ineq, bool): + raise ValueError(f"Expected a Boolean value for is_ineq, got {is_ineq}") m = m if m is not None else len(init_duals) if init_duals is None: # initialize duals if not set or set to scalar - init_duals = ( - torch.zeros(m, requires_grad=False, device=device) + dual_range[0] - ) + init_duals = torch.zeros(m, requires_grad=False, device=device) elif isinstance(init_duals, float): init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals duals = Parameter(init_duals, requires_grad=False) + if dual_range is None: + dual_range = (None, None) + settings_dict = { "lr": lr, "momentum": momentum, @@ -253,6 +267,9 @@ def _init_constraint_group( "momentum_buffer": torch.zeros_like( init_duals, requires_grad=False, device=device ), + "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], + "upper_bound": dual_range[1], + "is_ineq": is_ineq } settings_dict = {k: v for k, v in settings_dict.items() if v is not None} @@ -260,16 +277,23 @@ def _init_constraint_group( return param_group -def _update_duals( - duals: Tensor, +def _update_c_buffers( constraints: Tensor, - lr: float, momentum: float, dampening: float, buffer: Tensor, ) -> None: + """Update the constraint buffer with momentum.""" if momentum == 0: buffer = constraints else: buffer.mul_(momentum).add_(constraints, alpha=1 - dampening) + + +def _update_duals( + duals: Tensor, + buffer: Tensor, + lr: float, +) -> None: + """Update duals using the buffered constraint gradients.""" duals.add_(buffer, alpha=lr) From fa69883488de71a486cf234f58b27d1fc5f2dd54 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Wed, 3 Jun 2026 10:52:06 +0200 Subject: [PATCH 10/30] alm and ialm ineq constraint handling --- src/humancompatible/train/dual_optim/alm.py | 94 +++---- src/humancompatible/train/dual_optim/ialm.py | 251 ++++++++++--------- 2 files changed, 179 insertions(+), 166 deletions(-) diff --git a/src/humancompatible/train/dual_optim/alm.py b/src/humancompatible/train/dual_optim/alm.py index 8c1626f..9586464 100644 --- a/src/humancompatible/train/dual_optim/alm.py +++ b/src/humancompatible/train/dual_optim/alm.py @@ -182,7 +182,7 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.penalty = state_dict["state"]["penalty"] - self.dual_range = state_dict["state"]["dual_range"] + # self.dual_range = state_dict["state"]["dual_range"] params = state_dict["param_groups"] self.param_groups = [] for param in params: @@ -229,52 +229,52 @@ def _process_constraint_group( def _init_constraint_group( - m: int = None, - lr: float = None, - momentum: float = None, - dampening: float = None, - init_duals: float | Tensor = None, - dual_range: Tuple[float, float] = None, - is_ineq: bool = False, - device = None, - ): - ## checks ## - if init_duals is None and m is None: - raise ValueError("At least one of m, init_duals must be set") - - if momentum is not None and (momentum < 0 or momentum > 1): - raise ValueError(f"momentum must be within [0,1]; got {momentum}") - - if not isinstance(is_ineq, bool): - raise ValueError(f"Expected a Boolean value for is_ineq, got {is_ineq}") - - m = m if m is not None else len(init_duals) - - if init_duals is None: # initialize duals if not set or set to scalar - init_duals = torch.zeros(m, requires_grad=False, device=device) - elif isinstance(init_duals, float): - init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals - - duals = Parameter(init_duals, requires_grad=False) - - if dual_range is None: - dual_range = (None, None) - - settings_dict = { - "lr": lr, - "momentum": momentum, - "dampening": dampening, - "momentum_buffer": torch.zeros_like( - init_duals, requires_grad=False, device=device - ), - "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], - "upper_bound": dual_range[1], - "is_ineq": is_ineq - } - settings_dict = {k: v for k, v in settings_dict.items() if v is not None} - - param_group = ([duals], settings_dict) - return param_group + m: int = None, + lr: float = None, + momentum: float = None, + dampening: float = None, + init_duals: float | Tensor = None, + dual_range: Tuple[float, float] = None, + is_ineq: bool = None, + device = None, +): + ## checks ## + if init_duals is None and m is None: + raise ValueError("At least one of m, init_duals must be set") + + if momentum is not None and (momentum < 0 or momentum > 1): + raise ValueError(f"momentum must be within [0,1]; got {momentum}") + + if not isinstance(is_ineq, bool): + raise ValueError(f"Expected a Boolean value for is_ineq, got {is_ineq}") + + m = m if m is not None else len(init_duals) + + if init_duals is None: # initialize duals if not set or set to scalar + init_duals = torch.zeros(m, requires_grad=False, device=device) + elif isinstance(init_duals, float): + init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals + + duals = Parameter(init_duals, requires_grad=False) + + if dual_range is None: + dual_range = (None, None) + + settings_dict = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "momentum_buffer": torch.zeros_like( + init_duals, requires_grad=False, device=device + ), + "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], + "upper_bound": dual_range[1], + "is_ineq": is_ineq + } + settings_dict = {k: v for k, v in settings_dict.items() if v is not None} + + param_group = ([duals], settings_dict) + return param_group def _update_c_buffers( diff --git a/src/humancompatible/train/dual_optim/ialm.py b/src/humancompatible/train/dual_optim/ialm.py index 5d9028c..66be076 100644 --- a/src/humancompatible/train/dual_optim/ialm.py +++ b/src/humancompatible/train/dual_optim/ialm.py @@ -21,6 +21,7 @@ def __init__( dual_range: Tuple[float, float] = (-100., 100.), momentum: float = 0.0, dampening: float = 0.0, + is_ineq: bool = False, ctol: float = 1e-4, device=None, ) -> None: @@ -52,65 +53,20 @@ def __init__( if momentum > 0 and dampening == 0: dampening = momentum - self.dual_range = dual_range + # self.dual_range = dual_range # self.beta = beta self.penalty = penalty # self.gamma = gamma # self.sigma = sigma - self.ctol = ctol + # self.ctol = ctol - duals, defaults = self._init_constraint_group( - m, beta, sigma, gamma, momentum, dampening, init_duals, dual_range, device + duals, defaults = _init_constraint_group( + m, beta, sigma, gamma, momentum, dampening, init_duals, dual_range, is_ineq, device ) super().__init__(duals, defaults) - @staticmethod - def _init_constraint_group( - m: int = None, - beta: float = None, - sigma: float = None, - gamma: float = None, - momentum: float = None, - dampening: float = None, - init_duals: float | Tensor = None, - dual_range: Tuple[float, float] = None, - device=None, - ): - ## checks ## - if init_duals is None and m is None: - raise ValueError("At least one of`m`,`init_duals` must be set") - - if momentum is not None and (momentum < 0 or momentum > 1): - raise ValueError(f"`momentum`must be within [0,1]; got {momentum}") - - m = m if m is not None else len(init_duals) - - if init_duals is None: # initialize duals if not set or set to scalar - init_duals = ( - torch.zeros(m, requires_grad=False, device=device) + dual_range[0] - ) - elif isinstance(init_duals, float): - init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals - - duals = Parameter(init_duals, requires_grad=False) - - settings_dict = { - "beta": Parameter(torch.tensor(beta), requires_grad=False), - "sigma": Parameter(torch.tensor(sigma), requires_grad=False), - "gamma": Parameter(torch.tensor(gamma), requires_grad=False), - "momentum": momentum, - "dampening": dampening, - "momentum_buffer": torch.zeros_like( - init_duals, requires_grad=False, device=device - ), - } - settings_dict = {k: v for k, v in settings_dict.items() if v is not None} - - param_group = ([duals], settings_dict) - return param_group - @property def duals(self) -> Tensor: """ @@ -128,6 +84,9 @@ def add_constraint_group( momentum: float = None, dampening: float = None, init_duals: Tensor = None, + dual_range: tuple[float, float] = None, + is_ineq: bool = False, + device = None ) -> None: """ Allows to add a group of dual variables with separate initial values and learning rates. @@ -146,9 +105,13 @@ def add_constraint_group( :type dampening: float :param init_duals: Initial values for the new dual variables :type init_duals: Tensor + :param dual_range: After each dual update, the dual variables will be clamped to this range. + :type dual_range: Tuple[float, float] + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be relaxed on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool """ - duals, settings_dict = self._init_constraint_group( - m, beta, sigma, gamma, momentum, dampening, init_duals, self.dual_range, self.device + duals, settings_dict = _init_constraint_group( + m, beta, sigma, gamma, momentum, dampening, init_duals, dual_range, is_ineq, device ) param_group_dict = {"params": duals, **settings_dict} self.add_param_group(param_group_dict) @@ -166,23 +129,15 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: """ lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i, group in enumerate(self.param_groups): - duals, beta, _, _, momentum, dampening, momentum_buffer = ( - group["params"][0], - group["beta"], - group["sigma"], - group["gamma"], - group["momentum"], - group["dampening"], - group["momentum_buffer"], - ) - group_constraints = ( - constraints[i * len(duals) : (i + 1) * len(duals)] - self.ctol + + for i in range(len(self.param_groups)): + duals, beta, group_constraints = _process_constraint_group_ialm( + self.param_groups, i, constraints, update_duals=False ) lagrangian.add_(duals @ group_constraints) - _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) - + # Use beta from first group for penalty term + beta = self.param_groups[0]["beta"] lagrangian.add_(0.5 * beta * torch.dot(constraints, constraints)) return lagrangian @@ -194,30 +149,14 @@ def update(self, constraints: Tensor) -> None: :param constraints: Tensor of constraint values :type constraints: Tensor """ - for i, group in enumerate(self.param_groups): - duals, beta, sigma, gamma, momentum, dampening, momentum_buffer = ( - group["params"][0], - group["beta"], - group["sigma"], - group["gamma"], - group["momentum"], - group["dampening"], - group["momentum_buffer"], - ) - group_constraints = ( - constraints[i * len(duals) : (i + 1) * len(duals)] - self.ctol + for i in range(len(self.param_groups)): + _process_constraint_group_ialm( + self.param_groups, i, constraints, update_duals=True ) - with torch.no_grad(): - _update_duals( - duals, - group_constraints, - beta, - gamma, - momentum_buffer, - ) - clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) - - beta.mul_(sigma) + + # Update beta by sigma for each group + for group in self.param_groups: + group["beta"].mul_(group["sigma"]) # evaluate the Lagrangian and update the dual variables def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: @@ -233,40 +172,22 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: """ lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i, group in enumerate(self.param_groups): - duals, beta, sigma, gamma, momentum, dampening, momentum_buffer = ( - group["params"][0], - group["beta"], - group["sigma"], - group["gamma"], - group["momentum"], - group["dampening"], - group["momentum_buffer"], - ) - group_constraints = ( - constraints[i * len(duals) : (i + 1) * len(duals)] - self.ctol - ) - with torch.no_grad(): - _update_c_buffers( - group_constraints, momentum, dampening, momentum_buffer - ) - _update_duals( - duals, - beta, - gamma, - momentum_buffer, - ) - clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) + for i in range(len(self.param_groups)): + duals, beta, group_constraints = _process_constraint_group_ialm( + self.param_groups, i, constraints, update_duals=True + ) lagrangian.add_(duals @ group_constraints) + # Use beta from first group for penalty term + beta = self.param_groups[0]["beta"] lagrangian.add_( - 0.5 - * beta - * torch.dot(constraints - self.ctol, constraints - self.ctol) + 0.5 * beta * torch.dot(constraints, constraints) ) - beta.mul_(sigma) + # Update beta by sigma for each group + for group in self.param_groups: + group["beta"].mul_(group["sigma"]) return lagrangian @@ -274,7 +195,7 @@ def state_dict(self) -> dict[str, Any]: state_dict = super().state_dict() state_dict["state"]["penalty"] = self.penalty - state_dict["state"]["dual_range"] = self.dual_range + # state_dict["state"]["dual_range"] = self.dual_range # save params themselves in state_dict instead of param ID in default PyTorch for id_pg, pg in enumerate(state_dict["param_groups"]): pg["params"] = [ @@ -285,13 +206,105 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.penalty = state_dict["state"]["penalty"] - self.dual_range = state_dict["state"]["dual_range"] + # self.dual_range = state_dict["state"]["dual_range"] params = state_dict["param_groups"] self.param_groups = [] for param in params: self.param_groups.append(param) +def _process_constraint_group_ialm( + param_groups: list, + group_idx: int, + constraints: Tensor, + update_duals: bool = False, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Process a single constraint group: extract parameters and optionally update duals. + + :param param_groups: List of parameter groups from optimizer + :param group_idx: Index of the constraint group + :param constraints: Full constraints tensor + :param ctol: Constraint tolerance + :param dual_range: Safeguarding range for dual variables + :param update_duals: Whether to update dual variables + :return: Tuple of (duals, beta, group_constraints) + """ + group = param_groups[group_idx] + duals = group["params"][0] + beta = group.get("beta") + sigma = group.get("sigma") + gamma = group.get("gamma") + momentum = group.get("momentum", 0.0) + dampening = group.get("dampening", 0.0) + momentum_buffer = group.get("momentum_buffer") + dual_lb = group.get("lower_bound") + dual_ub = group.get("upper_bound") + is_ineq = group.get("is_ineq") + + group_constraints = constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] + + with torch.no_grad(): + if momentum > 0: + _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) + if update_duals: + _update_duals(duals, beta, gamma, momentum_buffer if momentum > 0 else group_constraints) + clamp_(duals, min=dual_lb, max=dual_ub) + + return duals, beta, group_constraints + + +def _init_constraint_group( + m: int = None, + beta: float = None, + sigma: float = None, + gamma: float = None, + momentum: float = None, + dampening: float = None, + init_duals: float | Tensor = None, + dual_range: Tuple[float, float] = None, + is_ineq: bool = None, + device=None, +): + ## checks ## + if init_duals is None and m is None: + raise ValueError("At least one of`m`,`init_duals` must be set") + + if momentum is not None and (momentum < 0 or momentum > 1): + raise ValueError(f"`momentum`must be within [0,1]; got {momentum}") + + m = m if m is not None else len(init_duals) + + if init_duals is None: # initialize duals if not set or set to scalar + init_duals = ( + torch.zeros(m, requires_grad=False, device=device) + dual_range[0] + ) + elif isinstance(init_duals, float): + init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals + + duals = Parameter(init_duals, requires_grad=False) + + if dual_range is None: + dual_range = (None, None) + + settings_dict = { + "beta": Parameter(torch.tensor(beta), requires_grad=False), + "sigma": Parameter(torch.tensor(sigma), requires_grad=False), + "gamma": Parameter(torch.tensor(gamma), requires_grad=False), + "momentum": momentum, + "dampening": dampening, + "momentum_buffer": torch.zeros_like( + init_duals, requires_grad=False, device=device + ), + "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], + "upper_bound": dual_range[1], + "is_ineq": is_ineq + } + settings_dict = {k: v for k, v in settings_dict.items() if v is not None} + + param_group = ([duals], settings_dict) + return param_group + def _update_c_buffers( constraints: Tensor, momentum: float, @@ -311,5 +324,5 @@ def _update_duals( buffer: Tensor, ) -> None: - update_mult = torch.min(beta, gamma / (buffer @ buffer)) + update_mult = torch.min(beta, gamma / torch.linalg.norm(buffer)) duals.add_(buffer, alpha=update_mult) From 7206c063791f5c4cb82074e8504ab127d614d9d0 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Wed, 3 Jun 2026 10:53:21 +0200 Subject: [PATCH 11/30] fix bug with neg dual init --- src/humancompatible/train/dual_optim/alm.py | 2 +- src/humancompatible/train/dual_optim/ialm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/humancompatible/train/dual_optim/alm.py b/src/humancompatible/train/dual_optim/alm.py index 41236c5..d11d2e2 100644 --- a/src/humancompatible/train/dual_optim/alm.py +++ b/src/humancompatible/train/dual_optim/alm.py @@ -239,7 +239,7 @@ def _init_constraint_group( if init_duals is None: # initialize duals if not set or set to scalar init_duals = ( - torch.zeros(m, requires_grad=False, device=device) + dual_range[0] + torch.zeros(m, requires_grad=False, device=device) ) elif isinstance(init_duals, float): init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals diff --git a/src/humancompatible/train/dual_optim/ialm.py b/src/humancompatible/train/dual_optim/ialm.py index 5d9028c..320d291 100644 --- a/src/humancompatible/train/dual_optim/ialm.py +++ b/src/humancompatible/train/dual_optim/ialm.py @@ -89,7 +89,7 @@ def _init_constraint_group( if init_duals is None: # initialize duals if not set or set to scalar init_duals = ( - torch.zeros(m, requires_grad=False, device=device) + dual_range[0] + torch.zeros(m, requires_grad=False, device=device) ) elif isinstance(init_duals, float): init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals From ce19362ecab899f5d0df06380e28180fb1767bbe Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Wed, 3 Jun 2026 12:31:13 +0200 Subject: [PATCH 12/30] doc update --- docs/requirements.txt | 3 ++- docs/source/api.rst | 6 ++++++ docs/source/conf.py | 11 ++++++++--- docs/source/index.rst | 9 ++++++++- docs/source/support.rst | 8 ++++---- docs/source/tutorials/copt_overview.rst | 19 ++++++++++--------- docs/source/tutorials/tips.rst | 3 ++- 7 files changed, 40 insertions(+), 19 deletions(-) create mode 100644 docs/source/api.rst diff --git a/docs/requirements.txt b/docs/requirements.txt index 5059523..a92337a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,5 @@ sphinx myst-nb . -furo \ No newline at end of file +furo +sphinx_rtd_theme \ No newline at end of file diff --git a/docs/source/api.rst b/docs/source/api.rst new file mode 100644 index 0000000..8b72ac4 --- /dev/null +++ b/docs/source/api.rst @@ -0,0 +1,6 @@ +Dual Optimizers +================= + + +.. autoclass:: humancompatible.train.dual_optim.ALM + :members: diff --git a/docs/source/conf.py b/docs/source/conf.py index 705d2c4..8f1c53f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,7 +6,7 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -extensions = ["myst_nb"] +extensions = ["myst_nb", "sphinx.ext.autodoc"] import os @@ -33,7 +33,8 @@ # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'furo' +# html_theme = 'furo' +html_theme = 'sphinx_rtd_theme' html_static_path = ['_static'] @@ -54,4 +55,8 @@ ".ipynb": "myst-nb", } -nb_execution_mode = "cache" \ No newline at end of file +nb_execution_mode = "cache" + +import sys + +sys.path.insert(0, os.path.abspath('./../..')) \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 7c32d9c..2bc0c9d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,4 +30,11 @@ We implement several first-order Lagrangian-based methods for constrained optimi Constrained Optimization Overview Basic usage: Fairness Handling inequality constraints - Tips and Tricks \ No newline at end of file + Tips and Tricks + +.. toctree:: + :maxdepth: 1 + :caption: API reference + :titlesonly: + + api.rst diff --git a/docs/source/support.rst b/docs/source/support.rst index 75c0981..a329fa2 100644 --- a/docs/source/support.rst +++ b/docs/source/support.rst @@ -30,10 +30,10 @@ kliacand@fel.cvut.cz Documentation ~~~~~~~~~~~~~ -- Check the :doc:`getting_started` guide for basic information -- Review the :doc:`examples/basic_usage` for common usage patterns -- Consult the :doc:`troubleshooting` page for known issues and solutions -- See the :doc:`examples/api_reference` for API documentation +.. - Check the :doc:`getting_started` guide for basic information +.. - Review the :doc:`examples/basic_usage` for common usage patterns +.. - Consult the :doc:`troubleshooting` page for known issues and solutions +.. - See the :doc:`examples/api_reference` for API documentation Contributing ~~~~~~~~~~~~ diff --git a/docs/source/tutorials/copt_overview.rst b/docs/source/tutorials/copt_overview.rst index a0bd734..330b70f 100644 --- a/docs/source/tutorials/copt_overview.rst +++ b/docs/source/tutorials/copt_overview.rst @@ -16,22 +16,23 @@ In `humancompatible-train`, and in Constrained Machine Learning more generally, where :math:`f` is the **objective function** we want to minimize, :math:`g` are the **inequality constraints**, and :math:`h` are the **equality constraints**. The expectation is taken over some random variable :math:`\xi`, which represents the data. -You may recognize the first line of the above formula as the standard formulation of an optimization problem in machine learning, where we want to minimize the expected loss over the data. \ -We then introduce constraints -- they could express anything from some bounds on the parameters of the model, or a requirement on the model's predictions to satisfy some fairness criterion, to the boundary conditions of a physical system. +You may recognize the first line of the above formula as the standard formulation of a machine learning problem, where we want to **minimize the expected loss** over the data. +We then introduce **constraints** -- they could express anything from some bounds on the weights of the model, or a requirement on the model's predictions to satisfy some fairness criterion, to the boundary conditions of a physical system. .. note:: - As is standard in the field, we adopt the convention of writing the constraints as :math:`g(x) \leq 0`, and :math:`h(x) = 0`. This is just a notational choice, and does not affect the generality of the formulation. It is trivial to transform :math:`g(x) \geq 0` into :math:`-g(x) \leq 0`, or :math:`g(x) \leq \epsilon` into :math:`g(x) - \epsilon \leq 0` for some non-zero bound. - - It is also easy to switch between equality and inequality constraints: to achieve :math:`g(x) = 0`, one can set :math:`-g(x) \leq 0` and :math:`g(x) \leq 0` simultaneously. In fact, different algorithms are designed to handle either equality or inequality constraints natively, but, again, it is trivial to switch between the two. We shall see more concrete examples later on. + - It is also easy to switch between equality and inequality constraints: to get :math:`g(x) = 0`, one can set :math:`-g(x) \leq 0` and :math:`g(x) \leq 0` simultaneously. In fact, different algorithms are designed to handle either equality or inequality constraints natively, but it is trivial to switch between the two. We shall see more concrete examples later on. Solving Constrained Problems -------------------------------- -We all know how to solve an unconstrained optimization problem -- we can use gradient descent, or any of its variants. But how do we solve a constrained optimization problem? -The Constrained Machine Learning field, including us, seems to have converged on **Lagrangian-based methods**, which utilize the Lagrangian function to transform the **constrained** problem into an **unconstrained** one. +We know that to solve an unconstrained optimization problem, we can use gradient descent, or any of its myriad variants. But how do we solve a constrained optimization problem? -Going forward in this tutorial, we will focus on the **deterministic case** to simplify notation; the stochastic case is more complex, but utilizes the same principles (imagine Gradient Descent vs. SGD). For more rigorous mathematical treatment of the stochastic case, see **TODO**, as well as the references included in the documentation for each of the algorithms in the package. +The Constrained Machine Learning field seems to have converged on **Lagrangian-based methods**, which utilize the Lagrangian function to transform the **constrained** problem into an **unconstrained** one. + +Going forward in this tutorial, we will focus on the **deterministic case** to simplify notation; the stochastic case is more complex, but utilizes the same principles (think full-batch vs. mini-batch Gradient Descent). For more rigorous mathematical treatment of the stochastic case, see **TODO**, as well as the references included in the documentation for each of the algorithms. In a deterministic case, the Lagrangian function is defined as follows: @@ -48,7 +49,7 @@ It is then possible to show that the original constrained optimization problem i We refer to the original problem as the **primal problem**, with :math:`x` as the **primal variables**, and to the transformed problem as the **dual problem**, with :math:`\lambda` and :math:`\mu` as the **dual variables**. The dual problem is unconstrained, and can be solved using a clever application of standard optimization techniques. -In particular, the most common approach is to use **alternating updates**: we fix the primal variables, and optimize the dual variables using gradient ascent; then we fix the dual variables, and optimize the primal variables using gradient descent. This process is repeated until convergence. +In particular, we can use **alternating updates**: fix the primal variables, and optimize the dual variables using gradient ascent; then fix the dual variables, and optimize the primal variables using gradient descent. This process is repeated until convergence. In `humancompatible-train`, we implement several variants of this approach, based on methods present in the literature. For more details, see the corresponding documentation; for now, it is important to understand that they are all based on the same principle of alternating updates to the primal and dual variables. @@ -63,6 +64,6 @@ where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are the learning rates f .. note:: - The above update rules are for the simplest variant of the Lagrangian method. The methods implemented in this package are all more complex. Even beyond our implementation, one can (and sometimes should!) modify the update rules by e.g. tweaking the training loop code, as we show in the :doc:`tips` tutorial. - - The above update rules are for the deterministic case. In the stochastic case, the gradients are estimated using mini-batches of data, which introduces additional noise into the optimization process. This can make convergence more challenging, but we have some tricks up our sleeves, such as momentum, LR scheduling, and so on. + - The Lagrangian approach is by far not the only way to do constrained optimization. It is, however, the best fit for a large-scale deep learning setting thanks to its "alternating updates" interpretation, which allows one to use well-established first-order iterative algorithms. We are still looking to implement other families of methods in `humancompatible-train`, such as SQP-based solvers! -In our package, the `dual optimizers` handle the updates to the dual variables, while the primal updates are handled by the standard PyTorch optimizers. This allows for seamless integration of constraints into the training loop, as we will see in the next tutorial. +In our package, the `dual optimizers` handle the updates to the dual variables, while the primal updates are handled by the standard PyTorch optimizers. This allows for seamless integration of constraints into the training loop, as we will see in the next tutorial. \ No newline at end of file diff --git a/docs/source/tutorials/tips.rst b/docs/source/tutorials/tips.rst index c0b10ce..89bf410 100644 --- a/docs/source/tutorials/tips.rst +++ b/docs/source/tutorials/tips.rst @@ -8,4 +8,5 @@ Dealing with Noise In the stochastic case, the gradients are estimated using mini-batches of data, which introduces additional noise into the optimization process. This can make convergence more challenging, but this can be mitigated. -**Momentum**: Just like in standard optimization, using momentum can help smooth out the updates and mitigate the noise. In `humancompatible-train`, the dual optimizers support momentum, which can be enabled by setting the `momentum` parameter to a non-zero value. +**Momentum**: Just like in standard optimization, using momentum can help smooth out the updates and mitigate the noise. In `humancompatible-train`, the dual optimizers support momentum, which can be enabled by setting the `momentum` parameter to a non-zero value. \ +Some dual update strategies, such as nuPI, explicitly rely on momentum. From 38cead9e57abd63bc6ca008ad38ea0146512d641 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Wed, 3 Jun 2026 14:29:28 +0200 Subject: [PATCH 13/30] bug in ineq init when range none --- src/humancompatible/train/dual_optim/alm.py | 86 ++++++++++++-------- src/humancompatible/train/dual_optim/ialm.py | 4 +- 2 files changed, 57 insertions(+), 33 deletions(-) diff --git a/src/humancompatible/train/dual_optim/alm.py b/src/humancompatible/train/dual_optim/alm.py index 9586464..e0ac355 100644 --- a/src/humancompatible/train/dual_optim/alm.py +++ b/src/humancompatible/train/dual_optim/alm.py @@ -9,6 +9,28 @@ class ALM(Optimizer): + r""" + A Dual Optimizer that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param lr: Dual variable update rate. + :type lr: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float + """ def __init__( self, m: int = None, @@ -23,28 +45,6 @@ def __init__( ctol: float = 0., device=None, ) -> None: - """ - A wrapper over a PyTorch`Optimizer` that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. - - :param m: Number of constraints (determines the number of dual variables to create) - :type m: int - :param lr: Dual variable update rate - :type lr: float - :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. - :type init_duals: float | Tensor - :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` - :type penalty: float - :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. - :type dual_range: Tuple[float, float] - :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. - :type momentum: float - :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. - :type dampening: float - :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. - :type is_ineq: bool - :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. - :type ctol: float - """ if momentum > 0 and dampening == 0: dampening = momentum @@ -69,7 +69,7 @@ def duals(self) -> Tensor: def add_constraint_group( self, - m: int = None, + m: int, lr: float = None, momentum: float = None, dampening: float = None, @@ -83,14 +83,22 @@ def add_constraint_group( :param m: Size of group (number of dual variables to add) :type m: int - :param lr: Dual variable update rate + :param lr: Dual variable update rate. :type lr: float - :param init_duals: Initial values for the new dual variables + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param init_duals: Initial values for the new dual variables. Defaults to the value set when creating the optimizer. :type init_duals: Tensor :param dual_range: After each dual update, the dual variables will be clamped to this range. :type dual_range: Tuple[float, float] :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be relaxed on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. :type is_ineq: bool + + .. note:: + Parameters here will default to values set when initializing the dual optimizer. + """ duals, settings_dict = _init_constraint_group( m, lr, momentum, dampening, init_duals, dual_range, is_ineq, device @@ -100,12 +108,20 @@ def add_constraint_group( def _add_penalty_term(self, lagrangian: Tensor, constraints: Tensor) -> None: """Add penalty term to lagrangian in-place.""" - if self.penalty > 0: + if self.penalty == 0: + return + elif constraints.ndim > 0: lagrangian.add_( 0.5 * self.penalty * torch.dot(constraints, constraints) ) + else: + lagrangian.add_( + 0.5 + * self.penalty + * torch.square(constraints) + ) def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: @@ -169,7 +185,7 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: return lagrangian def state_dict(self) -> dict[str, Any]: - + """""" state_dict = super().state_dict() state_dict["state"]["penalty"] = self.penalty # save params themselves in state_dict instead of param ID in default PyTorch @@ -181,6 +197,7 @@ def state_dict(self) -> dict[str, Any]: return state_dict def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """""" self.penalty = state_dict["state"]["penalty"] # self.dual_range = state_dict["state"]["dual_range"] params = state_dict["param_groups"] @@ -205,10 +222,13 @@ def _process_constraint_group( :return: Tuple of (duals, group_constraints) """ duals = group["params"][0] - group_constraints = ( - constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] - ) - + if constraints.ndim > 0: + group_constraints = ( + constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] + ) + else: + group_constraints = constraints.unsqueeze(0) + lr = group.get("lr") momentum = group.get("momentum", 0.0) dampening = group.get("dampening", 0.0) @@ -257,8 +277,10 @@ def _init_constraint_group( duals = Parameter(init_duals, requires_grad=False) - if dual_range is None: + if dual_range is None and not is_ineq: dual_range = (None, None) + elif dual_range is None and is_ineq: + dual_range = (0, None) settings_dict = { "lr": lr, diff --git a/src/humancompatible/train/dual_optim/ialm.py b/src/humancompatible/train/dual_optim/ialm.py index 66be076..a9ef9a2 100644 --- a/src/humancompatible/train/dual_optim/ialm.py +++ b/src/humancompatible/train/dual_optim/ialm.py @@ -284,8 +284,10 @@ def _init_constraint_group( duals = Parameter(init_duals, requires_grad=False) - if dual_range is None: + if dual_range is None and not is_ineq: dual_range = (None, None) + elif dual_range is None and is_ineq: + dual_range = (0, None) settings_dict = { "beta": Parameter(torch.tensor(beta), requires_grad=False), From 3c431695280b28be52b3d3bcac2fbf79ef1e4029 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Wed, 3 Jun 2026 14:44:58 +0200 Subject: [PATCH 14/30] little updates to tests, structure, toml --- pyproject.toml | 4 ++-- src/humancompatible/__init__.py | 0 tests/test_alm.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) delete mode 100644 src/humancompatible/__init__.py diff --git a/pyproject.toml b/pyproject.toml index a04d84f..dbbf5e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,12 +4,12 @@ build-backend = "setuptools.build_meta" [project] name = "humancompatible-train" -version = "0.3.0" +version = "0.3.2" dependencies = [ "torch", "numpy", ] -requires-python = ">= 3.11, <3.14" +requires-python = ">= 3.11" authors = [ {name = "Andrii Kliachkin", email = "kliacand@fel.cvut.cz"}, {name = "Gilles Bareilles"}, diff --git a/src/humancompatible/__init__.py b/src/humancompatible/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_alm.py b/tests/test_alm.py index 4c4b138..98cb0dd 100644 --- a/tests/test_alm.py +++ b/tests/test_alm.py @@ -21,8 +21,6 @@ def setUp(self): def test_alm_initialization(self): # Test initialization with m self.assertEqual(len(self.alm_default.duals), 3) - self.assertEqual(self.alm_default.penalty, 1.0) - self.assertEqual(self.alm_default.dual_range, (0.0, 100.0)) # Test initialization with init_duals init_duals = torch.tensor([1.0, 2.0, 3.0]) @@ -40,7 +38,10 @@ def test_alm_forward(self): def test_alm_update(self): expected_duals = self.alm_default.duals + 0.1 * self.constraints + # breakpoint() self.alm_default.update(self.constraints) + print(self.alm_default.duals) + print(expected_duals) self.assertTrue(torch.allclose(self.alm_default.duals, expected_duals)) def test_alm_momentum_update(self): @@ -84,7 +85,6 @@ def test_alm_state_dict(self): alm = ALM(m=3, lr=0.1, penalty=2.0, dual_range=(-1.0, 1.0)) state_dict = alm.state_dict() self.assertEqual(state_dict["state"]["penalty"], 2.0) - self.assertEqual(state_dict["state"]["dual_range"], (-1.0, 1.0)) if __name__ == "__main__": unittest.main() \ No newline at end of file From 188a2b3ffbdb6e7d2a9d4b36f18047b2d30370b1 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Wed, 3 Jun 2026 14:45:52 +0200 Subject: [PATCH 15/30] docs update --- docs/source/api_reference/dual_optimizers.rst | 8 + .../dual_opts/alm.rst} | 2 +- docs/source/api_reference/dual_opts/pbm.rst | 6 + docs/source/api_reference/utils.rst | 2 + docs/source/conf.py | 2 +- docs/source/index.rst | 9 +- docs/source/tutorials/basic_usage.md | 16 +- docs/source/tutorials/copt_overview.rst | 11 +- .../tutorials/inequality_constraints.md | 144 ------------------ .../tutorials/inequality_constraints.rst | 6 - 10 files changed, 38 insertions(+), 168 deletions(-) create mode 100644 docs/source/api_reference/dual_optimizers.rst rename docs/source/{api.rst => api_reference/dual_opts/alm.rst} (84%) create mode 100644 docs/source/api_reference/dual_opts/pbm.rst create mode 100644 docs/source/api_reference/utils.rst delete mode 100644 docs/source/tutorials/inequality_constraints.md delete mode 100644 docs/source/tutorials/inequality_constraints.rst diff --git a/docs/source/api_reference/dual_optimizers.rst b/docs/source/api_reference/dual_optimizers.rst new file mode 100644 index 0000000..bb31bf1 --- /dev/null +++ b/docs/source/api_reference/dual_optimizers.rst @@ -0,0 +1,8 @@ +Dual Optimizers +=============== + +.. toctree:: + :titlesonly: + :glob: + + dual_opts/* \ No newline at end of file diff --git a/docs/source/api.rst b/docs/source/api_reference/dual_opts/alm.rst similarity index 84% rename from docs/source/api.rst rename to docs/source/api_reference/dual_opts/alm.rst index 8b72ac4..74bb2f4 100644 --- a/docs/source/api.rst +++ b/docs/source/api_reference/dual_opts/alm.rst @@ -1,4 +1,4 @@ -Dual Optimizers +ALM ================= diff --git a/docs/source/api_reference/dual_opts/pbm.rst b/docs/source/api_reference/dual_opts/pbm.rst new file mode 100644 index 0000000..49f0c72 --- /dev/null +++ b/docs/source/api_reference/dual_opts/pbm.rst @@ -0,0 +1,6 @@ +PBM +================= + + +.. autoclass:: humancompatible.train.dual_optim.PBM + :members: diff --git a/docs/source/api_reference/utils.rst b/docs/source/api_reference/utils.rst new file mode 100644 index 0000000..2085850 --- /dev/null +++ b/docs/source/api_reference/utils.rst @@ -0,0 +1,2 @@ +Utils +===== \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 8f1c53f..5ae81c7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,7 +28,7 @@ nb_number_headings = False nb_execution_show_tb = False nb_execution_mode = "cache" -nb_execution_timeout = 60 +nb_execution_timeout = 180 # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/docs/source/index.rst b/docs/source/index.rst index 2bc0c9d..4805ce1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,8 +1,3 @@ -.. humancompatible-train documentation master file, created by - sphinx-quickstart on Tue May 26 14:11:39 2026. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - humancompatible-train documentation =================================== @@ -33,8 +28,8 @@ We implement several first-order Lagrangian-based methods for constrained optimi Tips and Tricks .. toctree:: - :maxdepth: 1 :caption: API reference :titlesonly: - api.rst + Dual Optimizers + Utils \ No newline at end of file diff --git a/docs/source/tutorials/basic_usage.md b/docs/source/tutorials/basic_usage.md index f1292ff..059bc28 100644 --- a/docs/source/tutorials/basic_usage.md +++ b/docs/source/tutorials/basic_usage.md @@ -33,13 +33,13 @@ They create, keep track of, and update the **dual parameters** of the constraine Let us demonstrate using a **fairness-constrained learning** task, where we want to learn a classifier that is accurate but also satisfies a **demographic parity constraint** - i.e. we would like -$$ | P( Y = 1 | \text{X is Male}) - P ( Y = 1 | \text{X is Female} ) | = \epsilon $$ +$$ | P( Y = 1 | \text{X is Male}) - P ( Y = 1 | \text{X is Female} ) | \leq \epsilon $$ where $ Y $ is the prediction given by our model for sample $ X $, and $ \epsilon $ is some small threshold. -:::{note} + --- +++ @@ -111,7 +111,8 @@ def setup_model(): return model, optimizer ``` -Next, we define the **constraint function** for demographic parity, which uses the `fairret.statistic.PositiveRate` class to evaluate positive rates for both groups. +Next, we define the **constraint function** for demographic parity, which uses the `fairret.statistic.PositiveRate` class to evaluate positive rates for both groups. \ +As a reminder, we expect our constraints to be of the form $ g(...) \leq 0 $ or $ h(...) = 0 $. We want $ g(...) \leq \epsilon $, so we will subtract $ \epsilon $ in the training loop. ```{code-cell} ipython3 from fairret.statistic import PositiveRate @@ -125,17 +126,18 @@ def pr_diff(logit, groups): return stat_diff ``` -As a last step, we define our **dual optimizer**. To set it up, we only need to define the **number of constraints** -- in our case, it is 1 -- so it can create the corresponding dual variables. +As a last step, we define our **dual optimizer**. To set it up, we only need to define the **number of constraints** -- in our case, it is 1 -- so it can create the corresponding dual variables, and the **type** of constraint -- equality or inequality. In a following tutorial, we will see how to create *constraint groups* with different types and hyperparameters. ```{code-cell} ipython3 from humancompatible.train.dual_optim import ALM -dual_optimizer = ALM(m=1, lr=0.01) +dual_optimizer = ALM(m=1, lr=0.01, is_ineq=True) ``` Finally, we write our training loop. In addition to the forward pass and loss calculation, we add a constraint calculation step (0.05 is our $ \epsilon $ threshold). Then, the `forward_update` step does two things: + - Updates the dual variables based on the constraint violation, - Calculates the Lagrangian based on loss and constraint violation. @@ -169,7 +171,7 @@ for epoch in range(epochs): optimizer.step() ``` -Due to noise, it is difficult to obtain the exact correct value, but the method attempts to keep the constraint around 0.05! +We obtain a respectable loss value, while keeping the fairness violation below the threshold! +++ diff --git a/docs/source/tutorials/copt_overview.rst b/docs/source/tutorials/copt_overview.rst index 330b70f..dbc7f42 100644 --- a/docs/source/tutorials/copt_overview.rst +++ b/docs/source/tutorials/copt_overview.rst @@ -34,6 +34,13 @@ The Constrained Machine Learning field seems to have converged on **Lagrangian-b Going forward in this tutorial, we will focus on the **deterministic case** to simplify notation; the stochastic case is more complex, but utilizes the same principles (think full-batch vs. mini-batch Gradient Descent). For more rigorous mathematical treatment of the stochastic case, see **TODO**, as well as the references included in the documentation for each of the algorithms. +So, we have the following constrained problem: + +.. math:: + \min_{x\in\mathbb{R}^n} \quad & f(x,\xi) \\ + \text{s.t.} \quad & g(x,\xi) \leq 0, \\ + & h(x,\xi) = 0, \\ + In a deterministic case, the Lagrangian function is defined as follows: .. math:: @@ -41,7 +48,7 @@ In a deterministic case, the Lagrangian function is defined as follows: where :math:`\lambda` is the Lagrange multiplier associated with the constraint :math:`g(x) \leq 0`, and :math:`\mu` is the Lagrange multiplier associated with the constraint :math:`h(x) = 0`. -It is then possible to show that the original constrained optimization problem is equivalent to the following unconstrained optimization problem: +It is then possible to show that the original **constrained** problem is equivalent to the following **unconstrained** problem: .. math:: \min_{x\in\mathbb{R}^n} \max_{\lambda \geq 0, \mu} \mathcal{L}(x, \lambda, \mu) @@ -58,7 +65,7 @@ In the simplest case of the Lagrangian method, this gives us the following updat .. math:: \lambda_{t+1} & = \lambda_t + \beta \nabla_\lambda \mathcal{L}(x_{t}, \lambda_t, \mu_t) = \lambda_t + \beta g(x_{t}) \\ \mu_{t+1} & = \mu_t + \gamma \nabla_\mu \mathcal{L}(x_{t}, \lambda_t, \mu_t) = \mu_t + \gamma h(x_{t}) \\ - x_{t+1} & = x_t - \alpha \nabla_x \mathcal{L}(x_t, \lambda_t, \mu_t) + x_{t+1} & = x_t - \alpha \nabla_x \mathcal{L}(x_t, \lambda_{t+1}, \mu_{t+1}) where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are the learning rates for the primal and dual variables, respectively. diff --git a/docs/source/tutorials/inequality_constraints.md b/docs/source/tutorials/inequality_constraints.md deleted file mode 100644 index 50214a9..0000000 --- a/docs/source/tutorials/inequality_constraints.md +++ /dev/null @@ -1,144 +0,0 @@ ---- -jupytext: - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.19.3 -kernelspec: - display_name: hc-dev - language: python - name: python3 ---- - -# Inequality Constraints - -In the [Basic Usage tutorial](#basic_usage), we learned how to set up the constrained optimization problem with **equality constraints**. Here, we will generalize this to **inequality constraints**. For simplicity, we will keep the same setup, but our constraint will look like: - -$$ | P( Y = 1 | \text{X is Male}) - P ( Y = 1 | \text{X is Female} ) | \leq \epsilon $$ - -where $ Y $ is the prediction given by our model for sample $ X $, and $ \epsilon $ is some small threshold. - -+++ - -Prepare [folktables](https://github.com/socialfoundations/folktables) data: - -```{code-cell} ipython3 ---- -tags: [hide-cell] ---- -# load data -import torch -import numpy as np -from sklearn.preprocessing import StandardScaler -from folktables import ACSDataSource, ACSIncome, generate_categories - -torch.set_default_dtype(torch.float32) - -# load folktables data -data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person") -acs_data = data_source.get_data(states=["FL"], download=True) -definition_df = data_source.get_definitions(download=True) -categories = generate_categories( - features=ACSIncome.features, definition_df=definition_df -) -df_feat, df_labels, _ = ACSIncome.df_to_pandas( - acs_data, categories=categories, dummies=True -) -sens_cols = ["SEX_Female", "SEX_Male"] -features = df_feat.drop(columns=sens_cols).to_numpy(dtype=np.float32) -labels = df_labels.to_numpy(dtype=np.float32) -# one-hot encoding of the sensitive attribute (gender) -groups = df_feat[sens_cols].to_numpy(dtype=np.float32) - -# standardize features -scaler = StandardScaler() -features = scaler.fit_transform(features) -# convert to torch tensors -X = torch.tensor(features) ; y = torch.tensor(labels) ; groups = torch.tensor(groups) - -dataset_train = torch.utils.data.TensorDataset(X, groups, y) -loader = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True) -criterion = torch.nn.BCEWithLogitsLoss() -``` - -Initialize the model and optimizer. - -```{code-cell} ipython3 ---- -tags: [hide-cell] ---- -from torch.nn import Sequential -from torch.optim import AdamW - -def setup_model(): - - model = Sequential( - torch.nn.Linear(features.shape[1], 64), - torch.nn.ReLU(), - torch.nn.Linear(64, 32), - torch.nn.ReLU(), - torch.nn.Linear(32, 1), - ) - model.forward(torch.zeros(features.shape[1])).backward() # dummy forward/backward pass to construct torch graph for fair comparison - optimizer = AdamW(model.parameters()) - return model, optimizer -``` - -Next, we define the **constraint function** for demographic parity, which uses the `fairret.statistic.PositiveRate` class to evaluate positive rates for both groups. - -```{code-cell} ipython3 -from fairret.statistic import PositiveRate - -statistic = PositiveRate() - -def pr_diff(logit, groups): - preds = torch.sigmoid(logit) - stats = PositiveRate()(preds, groups) - stat_diff = torch.abs(stats[0] - stats[1]) - return stat_diff -``` - -As a last step, we define our **dual optimizer**. To set it up, we only need to define the **number of constraints** -- in our case, it is 1 -- so it can create the corresponding dual variables. - -```{code-cell} ipython3 -from humancompatible.train.dual_optim import ALM - -dual_optimizer = ALM(m=1, lr=0.01) -``` - -Finally, we write our training loop. In addition to the forward pass and loss calculation, we add a constraint calculation step (0.05 is our $ \epsilon $ threshold). - -Then, the `forward_update` step does two things: -- Updates the dual variables based on the constraint violation, -- Calculates the Lagrangian based on loss and constraint violation. - -We then perform a backward pass on the Lagrangian and minimize it using a normal PyTorch optimizer. - -```{code-cell} ipython3 -model, optimizer = setup_model() -epochs = 10 -``` - -```{code-cell} ipython3 -for epoch in range(epochs): - # eval - model.eval() - logit = model(X) - train_loss = criterion(logit, y).item() - train_fair = pr_diff(logit, groups).item() - print(f"Epoch: {epoch}, loss: {train_loss}, constraint: {train_fair}") - - # train - model.train() - for batch_feat, batch_groups, batch_label in loader: - optimizer.zero_grad() - logit = model(batch_feat) - loss = criterion(logit, batch_label) - - constraint = pr_diff(logit, batch_groups) - 0.05 - lagr = dual_optimizer.forward_update(loss, constraint.unsqueeze(0)) - lagr.backward() - - optimizer.step() -``` diff --git a/docs/source/tutorials/inequality_constraints.rst b/docs/source/tutorials/inequality_constraints.rst deleted file mode 100644 index e1fadac..0000000 --- a/docs/source/tutorials/inequality_constraints.rst +++ /dev/null @@ -1,6 +0,0 @@ -Handling Inequality Constraints -================================================== - -Here, we'll demonstrate how to handle inequality constraints. -TODO. -weight constraint exanmple \ No newline at end of file From 80ae8e0ff86882aee32de745405b84f9fe4774cc Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Wed, 3 Jun 2026 17:18:03 +0200 Subject: [PATCH 16/30] alm, pbm docstrings --- src/humancompatible/train/dual_optim/alm.py | 96 +++++++++++++++------ src/humancompatible/train/dual_optim/pbm.py | 59 ++++++++----- 2 files changed, 109 insertions(+), 46 deletions(-) diff --git a/src/humancompatible/train/dual_optim/alm.py b/src/humancompatible/train/dual_optim/alm.py index e0ac355..b291e88 100644 --- a/src/humancompatible/train/dual_optim/alm.py +++ b/src/humancompatible/train/dual_optim/alm.py @@ -9,28 +9,6 @@ class ALM(Optimizer): - r""" - A Dual Optimizer that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. - - :param m: Number of constraints (determines the number of dual variables to create) - :type m: int - :param lr: Dual variable update rate. - :type lr: float - :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. - :type init_duals: float | Tensor - :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` - :type penalty: float - :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. - :type dual_range: Tuple[float, float] - :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. - :type momentum: float - :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. - :type dampening: float - :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. - :type is_ineq: bool - :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. - :type ctol: float - """ def __init__( self, m: int = None, @@ -128,6 +106,13 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: """ Calculates and returns the Augmented Lagrangian. + Computes the augmented Lagrangian:: + + L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2 + + where `loss` is the objective value, `duals_i` are the dual variables, `constraints_i` are constraint values, + `penalty` is the penalty parameter, and the sum is over all constraint groups. + :param loss: Loss (objective function) value :type loss: Tensor :param constraints: Tensor of constraint values @@ -150,7 +135,23 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: def update(self, constraints: Tensor) -> None: """ - Updates the dual variables + Updates the dual variables using constrained gradient ascent with optional momentum. + + For each constraint group, performs the dual variable update. + + First, update the momentum buffer (if momentum > 0):: + + if momentum > 0: + buffer_i = momentum * buffer_i + (1 - dampening) * constraints_i + else: + buffer_i = constraints_i + + Then, update the dual variables with clamping:: + + duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound) + + where `buffer_i` is the momentum buffer, `constraints_i` are constraint values, `duals_i` are dual variables, + and `clamp(x, lb, ub)` projects to the dual range. :param constraints: Tensor of constraint values :type constraints: Tensor @@ -163,7 +164,18 @@ def update(self, constraints: Tensor) -> None: # evaluate the Lagrangian and update the dual variables def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: """ - Combines `forward` and `update`; slightly faster. + Combines `forward` and `update`; slightly faster than calling both separately. + + Updates dual variables:: + + duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound) + + Then computes the augmented Lagrangian:: + + L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2 + + + where the momentum buffer is updated as in :meth:`update`. :param loss: Loss (objective function) value :type loss: Tensor @@ -319,3 +331,39 @@ def _update_duals( ) -> None: """Update duals using the buffered constraint gradients.""" duals.add_(buffer, alpha=lr) + + + +ALM.__doc__ = ( + + # \textbf{input}: \gamma \text{ (lr) }, \pmb{\lambda}_t \text{ (dual variables, created by method) }, \\ + # \mathbf{c}(\theta) \text{ (constraints) }, f(\theta) \text{ (objective) }, \rho \text{ (penalty coefficient) } \\ + r""" + A Dual Optimizer that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. Reference: https://doi.org/10.48550/arXiv.2504.07607 + + .. math:: + + \pmb{\lambda}_{t+1} & \leftarrow \pmb{\lambda}_t + \gamma \mathbf{c}_t(\theta_{t}) + + \mathcal{L}_{t+1} & \leftarrow f_t(\theta_{t}) + \pmb{\lambda}_{t+1}^T \mathbf{c}_t(\theta_{t}) + \frac{\rho}{2} \| \mathbf{c}_t(\theta_{t}) \|^2_2 + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param lr: Dual variable update rate. + :type lr: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float + """ +) \ No newline at end of file diff --git a/src/humancompatible/train/dual_optim/pbm.py b/src/humancompatible/train/dual_optim/pbm.py index 07c7665..c263062 100644 --- a/src/humancompatible/train/dual_optim/pbm.py +++ b/src/humancompatible/train/dual_optim/pbm.py @@ -23,28 +23,7 @@ def __init__( device=None, primal_update_process_length=1, # length of the primal update process - if =1, is the original algorithm ) -> None: - """ - A wrapper over a PyTorch`Optimizer` that works on the dual maximization tasks according to the Penalty-Barrier Method rule. Creates and updates dual variables. - - :param m: Number of constraints (determines the number of dual variables to create) - :type m: int - :param penalty_mult: Multiplier for penalty update (K1 or K2). For K2 (adaptive penalty update), values close to 1 correspond to a high "momentum". - :type penalty_mult: float - :param gamma: Multiplier for dual parameter update. Values close to 1 correspond to a high "momentum". - :type gamma: float - :param delta: Violation/satisfaction parameter for penalty update; values > 1 make the penalties decrease faster on violated constraints and vice versa. - :type delta: float - :param penalty_update: Penalty update strategy; must be one of `dimin`,`dimin_dual`,`dimin_adapt`,`const`. Defaults to`dimin_adapt`. - :type penalty_update: str - :param pbf: Penalty-Barrier Function to use. Must be one of `quadratic_logarithmic`,`quadratic_reciprocal` - :type pbf: str - :param init_duals: Initial values for the dual variables. Defaults to dual lower bound for all. - :type init_duals: float | Tensor - :param init_penalties: Initial values for the penalty variables. Defaults to the penalty upper bound for all. - :type init_penalties: float | Tensor - :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. - :type dual_range: Tuple[float, float] - """ + self.dual_range = dual_range self.penalty_range = penalty_range @@ -486,3 +465,39 @@ def _update_penalties_dimin_dual( "adapt": _update_penalties_adapt, "dimin_dual": _update_penalties_dimin_dual, } + + +PBM.__doc__ = ( + + r""" + A Dual Optimizer that works on the dual maximization tasks according to the Penalty-Barrier Method rule. Creates and updates dual variables. Reference: https://doi.org/10.48550/arXiv.2605.18618 + + .. note:: + + Natively, this method only supports inequality constraints (see reference). However, it is easy to transform one into the other: + + .. math:: + g(x) = |h(x)| \leq 0 + + We suggest using a small tolerance parameter on the right-hand side instead of 0. + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param penalty_mult: Multiplier for penalty update (K1 or K2). For K2 (adaptive penalty update), values close to 1 correspond to a high "momentum". + :type penalty_mult: float + :param gamma: Multiplier for dual parameter update. Values close to 1 correspond to a high "momentum". + :type gamma: float + :param delta: Violation/satisfaction parameter for penalty update; values > 1 make the penalties decrease faster on violated constraints and vice versa. + :type delta: float + :param penalty_update: Penalty update strategy; must be one of `dimin`,`dimin_dual`,`dimin_adapt`,`const`. Defaults to`dimin_adapt`. + :type penalty_update: str + :param pbf: Penalty-Barrier Function to use. Must be one of `quadratic_logarithmic`,`quadratic_reciprocal` + :type pbf: str + :param init_duals: Initial values for the dual variables. Defaults to dual lower bound for all. + :type init_duals: float | Tensor + :param init_penalties: Initial values for the penalty variables. Defaults to the penalty upper bound for all. + :type init_penalties: float | Tensor + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + """ +) \ No newline at end of file From 83ac4df266cd20c016936866f1f27ceea51d7260 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Wed, 3 Jun 2026 17:18:35 +0200 Subject: [PATCH 17/30] add skeleton for nupi --- src/humancompatible/train/dual_optim/nupi.py | 357 +++++++++++++++++++ 1 file changed, 357 insertions(+) create mode 100644 src/humancompatible/train/dual_optim/nupi.py diff --git a/src/humancompatible/train/dual_optim/nupi.py b/src/humancompatible/train/dual_optim/nupi.py new file mode 100644 index 0000000..8eb3b28 --- /dev/null +++ b/src/humancompatible/train/dual_optim/nupi.py @@ -0,0 +1,357 @@ +import torch +from torch.nn import Parameter +from torch.optim import Optimizer +from typing import Any, Tuple +from torch import clamp_, Tensor + +# cite: On PI Controllers for Updating Lagrange Multipliers in Constrained Optimization +# https://arxiv.org/pdf/2406.04558v1 + + +class nuPI(Optimizer): + r""" + A Dual Optimizer that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. Reference: https://doi.org/10.48550/arXiv.2504.07607 + + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param lr: Dual variable update rate. + :type lr: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float + """ + def __init__( + self, + m: int = None, + lr: float = 0.01, + init_duals: float | Tensor = None, + penalty: float = 1.0, + *, + dual_range: Tuple[float, float] = (-100.0, 100.0), + momentum: float = 0.0, + dampening: float = 0.0, + is_ineq: bool = False, + ctol: float = 0., + device=None, + ) -> None: + + if momentum > 0 and dampening == 0: + dampening = momentum + + # self.dual_range = dual_range + # self.ctol = ctol + + self.penalty = penalty + duals, defaults = _init_constraint_group( + m, lr, momentum, dampening, init_duals, dual_range, is_ineq, device + ) + + super().__init__(duals, defaults) + + @property + def duals(self) -> Tensor: + """ + :return: Dual variables, concatenated into a single tensor. + :rtype: Tensor + """ + return torch.cat([group["params"][0] for group in self.param_groups]) + + def add_constraint_group( + self, + m: int, + lr: float = None, + momentum: float = None, + dampening: float = None, + init_duals: Tensor = None, + dual_range: tuple[float, float] = None, + is_ineq: bool = False, + device = None + ) -> None: + """ + Allows to add a group of dual variables with separate initial values and learning rates. + + :param m: Size of group (number of dual variables to add) + :type m: int + :param lr: Dual variable update rate. + :type lr: float + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param init_duals: Initial values for the new dual variables. Defaults to the value set when creating the optimizer. + :type init_duals: Tensor + :param dual_range: After each dual update, the dual variables will be clamped to this range. + :type dual_range: Tuple[float, float] + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be relaxed on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + + .. note:: + Parameters here will default to values set when initializing the dual optimizer. + + """ + duals, settings_dict = _init_constraint_group( + m, lr, momentum, dampening, init_duals, dual_range, is_ineq, device + ) + param_group_dict = {"params": duals, **settings_dict} + self.add_param_group(param_group_dict) + + def _add_penalty_term(self, lagrangian: Tensor, constraints: Tensor) -> None: + """Add penalty term to lagrangian in-place.""" + if self.penalty == 0: + return + elif constraints.ndim > 0: + lagrangian.add_( + 0.5 + * self.penalty + * torch.dot(constraints, constraints) + ) + else: + lagrangian.add_( + 0.5 + * self.penalty + * torch.square(constraints) + ) + + + def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: + """ + Calculates and returns the Augmented Lagrangian. + + Computes the augmented Lagrangian:: + + L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2 + + where `loss` is the objective value, `duals_i` are the dual variables, `constraints_i` are constraint values, + `penalty` is the penalty parameter, and the sum is over all constraint groups. + + :param loss: Loss (objective function) value + :type loss: Tensor + :param constraints: Tensor of constraint values + :type constraints: Tensor + :return: Lagrangian + :rtype: Tensor + """ + lagrangian = torch.zeros_like(loss) + lagrangian.add_(loss) + + for i in range(len(self.param_groups)): + duals, group_constraints = _process_constraint_group( + self.param_groups[i], i, constraints, update_duals=False + ) + lagrangian.add_(duals @ group_constraints) + + self._add_penalty_term(lagrangian, constraints) + return lagrangian + + + def update(self, constraints: Tensor) -> None: + """ + Updates the dual variables using constrained gradient ascent with optional momentum. + + For each constraint group, performs the dual variable update. + + First, update the momentum buffer (if momentum > 0):: + + if momentum > 0: + buffer_i = momentum * buffer_i + (1 - dampening) * constraints_i + else: + buffer_i = constraints_i + + Then, update the dual variables with clamping:: + + duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound) + + where `buffer_i` is the momentum buffer, `constraints_i` are constraint values, `duals_i` are dual variables, + and `clamp(x, lb, ub)` projects to the dual range. + + :param constraints: Tensor of constraint values + :type constraints: Tensor + """ + for i in range(len(self.param_groups)): + _process_constraint_group( + self.param_groups[i], i, constraints, update_duals=True + ) + + # evaluate the Lagrangian and update the dual variables + def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: + """ + Combines `forward` and `update`; slightly faster than calling both separately. + + Computes the augmented Lagrangian and updates dual variables in one pass:: + + L = loss + sum(duals_i @ constraints_i for all groups) + 0.5 * penalty * ||constraints||^2 + + Then updates dual variables:: + + duals_i = clamp(duals_i + lr * buffer_i, lower_bound, upper_bound) + + where the momentum buffer is updated as in :meth:`update`. + + :param loss: Loss (objective function) value + :type loss: Tensor + :param constraints: Tensor of constraint values + :type constraints: Tensor + :return: Lagrangian + :rtype: Tensor + """ + lagrangian = torch.zeros_like(loss) + lagrangian.add_(loss) + + for i in range(len(self.param_groups)): + duals, group_constraints = _process_constraint_group( + self.param_groups[i], i, constraints, update_duals=True + ) + lagrangian.add_(duals @ group_constraints) + + self._add_penalty_term(lagrangian, constraints) + return lagrangian + + def state_dict(self) -> dict[str, Any]: + """""" + state_dict = super().state_dict() + state_dict["state"]["penalty"] = self.penalty + # save params themselves in state_dict instead of param ID in default PyTorch + for id_pg, pg in enumerate(state_dict["param_groups"]): + pg["params"] = [ + self.param_groups[id_pg]["params"][param_id] + for param_id in pg["params"] + ] + return state_dict + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """""" + self.penalty = state_dict["state"]["penalty"] + # self.dual_range = state_dict["state"]["dual_range"] + params = state_dict["param_groups"] + self.param_groups = [] + for param in params: + self.param_groups.append(param) + + +def _process_constraint_group( + group: dict[str, Any], + group_idx: int, + constraints: Tensor, + update_duals: bool = False, +) -> Tuple[Tensor, Tensor]: + """ + Process a single constraint group: extract duals/constraints and optionally update duals. + + :param group: The constraint group dictionary + :param group_idx: Index of the constraint group + :param constraints: Full constraints tensor + :param update_duals: Whether to update dual variables + :return: Tuple of (duals, group_constraints) + """ + duals = group["params"][0] + if constraints.ndim > 0: + group_constraints = ( + constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] + ) + else: + group_constraints = constraints.unsqueeze(0) + + lr = group.get("lr") + momentum = group.get("momentum", 0.0) + dampening = group.get("dampening", 0.0) + momentum_buffer = group["momentum_buffer"] + dual_lb = group.get("lower_bound") + dual_ub = group.get("upper_bound") + is_ineq = group.get("is_ineq") + + with torch.no_grad(): + if momentum > 0: + _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) + if update_duals: + _update_duals(duals, momentum_buffer if momentum > 0 else group_constraints, lr) + clamp_(duals, min=dual_lb, max=dual_ub) + + + return duals, group_constraints + + +def _init_constraint_group( + m: int = None, + lr: float = None, + momentum: float = None, + dampening: float = None, + init_duals: float | Tensor = None, + dual_range: Tuple[float, float] = None, + is_ineq: bool = None, + device = None, +): + ## checks ## + if init_duals is None and m is None: + raise ValueError("At least one of m, init_duals must be set") + + if momentum is not None and (momentum < 0 or momentum > 1): + raise ValueError(f"momentum must be within [0,1]; got {momentum}") + + if not isinstance(is_ineq, bool): + raise ValueError(f"Expected a Boolean value for is_ineq, got {is_ineq}") + + m = m if m is not None else len(init_duals) + + if init_duals is None: # initialize duals if not set or set to scalar + init_duals = torch.zeros(m, requires_grad=False, device=device) + elif isinstance(init_duals, float): + init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals + + duals = Parameter(init_duals, requires_grad=False) + + if dual_range is None and not is_ineq: + dual_range = (None, None) + elif dual_range is None and is_ineq: + dual_range = (0, None) + + settings_dict = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "momentum_buffer": torch.zeros_like( + init_duals, requires_grad=False, device=device + ), + "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], + "upper_bound": dual_range[1], + "is_ineq": is_ineq + } + settings_dict = {k: v for k, v in settings_dict.items() if v is not None} + + param_group = ([duals], settings_dict) + return param_group + + +def _update_c_buffers( + constraints: Tensor, + momentum: float, + dampening: float, + buffer: Tensor, +) -> None: + """Update the constraint buffer with momentum.""" + raise NotImplementedError + if momentum == 0: + buffer = constraints + else: + buffer.mul_(momentum).add_(constraints, alpha=1 - dampening) + + +def _update_duals( + duals: Tensor, + buffer: Tensor, + lr: float, +) -> None: + raise NotImplementedError + """Update duals using the buffered constraint gradients.""" + duals.add_(buffer, alpha=lr) From ef05eb08e4665f634d5407d99e0f5b161599458d Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Thu, 4 Jun 2026 11:33:39 +0200 Subject: [PATCH 18/30] camassa update --- ...ePreservingPINN_2D_trainableACS_pytorch.py | 14 ++++++++++- ...ingPINN_CamassaHolm_trainableH1_pytorch.py | 25 +++++++++++-------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py index e2350db..8fc5ae0 100644 --- a/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_2D_trainableACS_pytorch.py @@ -123,7 +123,7 @@ def forward(self, x): x = self.ff(x) return self.net(x) -def custom_loss(inputs, model): +def custom_loss(inputs, model, dual_opt): x, y, t = inputs[:, 0:1], inputs[:, 1:2], inputs[:, 2:3] x.requires_grad_(True) y.requires_grad_(True) @@ -154,6 +154,18 @@ def custom_loss(inputs, model): H_loss = H(u_model.reshape(Nx, Ny, Nt), u_x.reshape(Nx, Ny, Nt), u_y.reshape(Nx, Ny, Nt)) + # constraint + H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) + + Hf = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) + H_constraint = torch.abs(Hf - H0)/torch.abs(H0) + + eps = 5/(epoch+1) + H_constraint = torch.max(H_constraint - eps, torch.zeros_like(H_constraint)).unsqueeze(0) + + loss = dual_opt.forward_update(loss, H_constraint) + + return loss, loss_type, pde_loss, data_fitting_loss_0, data_fitting_loss_l_r, H_loss model = PINNModel().to(device) diff --git a/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py index e366a3d..c10d84f 100644 --- a/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py +++ b/experiments/structure_preserving_pinns/structurePreservingPINN_CamassaHolm_trainableH1_pytorch.py @@ -152,15 +152,17 @@ def custom_loss(inputs, model, epoch): #### LOSS FUNCTION WITH H1 CONSTRAINT #### -def lagrangian_loss(inputs, model, dual_opt, epoch, H0): +def lagrangian_loss(inputs, model, dual_opt, epoch, H0=None): x, t = inputs[:, 0:1], inputs[:, 1:2] x.requires_grad_(True) t.requires_grad_(True) u_model = model(torch.cat([x, t], dim=1)) + u_model_0 = model(torch.cat([x, torch.zeros_like(t)], dim=1)) u_t = torch.autograd.grad(u_model.sum(), t, create_graph=True)[0] u_x = torch.autograd.grad(u_model.sum(), x, create_graph=True)[0] + u_x_0 = torch.autograd.grad(u_model_0.sum(), x, create_graph=True)[0] u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0] u_xxt = torch.autograd.grad(u_xx.sum(), t, create_graph=True)[0] @@ -193,14 +195,16 @@ def lagrangian_loss(inputs, model, dual_opt, epoch, H0): Hf = H(u_model.reshape(Nt, Nx), u_x.reshape(Nt, Nx), dx) - H_constraint = torch.abs(Hf - H0)/torch.abs(H0) + H0 = H(u_model_0.reshape(Nt, Nx), u_x_0.reshape(Nt, Nx), dx) - eps = 1/(epoch+1) - H_constraint = torch.max(H_constraint - eps, torch.zeros_like(H_constraint)).unsqueeze(0) + H_constraint = (torch.abs(Hf - H0)/torch.abs(H0)).unsqueeze(0) + + eps = 1/(epoch+1)**2 + H_constraint = H_constraint - eps loss = dual_opt.forward_update(loss, H_constraint) - return loss, loss_type, pde_loss_L2, data_fitting_loss_0, data_fitting_loss_l_r, Hf + return loss, loss_type, pde_loss_L2, data_fitting_loss_0, data_fitting_loss_l_r, Hf, H0 ####### TRAINING LOOP ####### @@ -218,8 +222,8 @@ def lagrangian_loss(inputs, model, dual_opt, epoch, H0): t0 = time() -# dual_opt = ALM(m=1, lr=5e-5, dual_range=(0.,100.), device=device, ctol=1e-3, penalty=0.) -dual_opt = iALM(m=1, beta=0.01, sigma=1.0001, gamma=1., dual_range=(0.,10.), ctol=1e-3) +# dual_opt = ALM(m=1, lr=1e-3, device=device, penalty=0., is_ineq=True) +dual_opt = iALM(m=1, beta=0.001, sigma=1.001, gamma=1., dual_range=(-10.,10.), is_ineq=True) H0 = H(u_0(x_grid.flatten().reshape(-1, 1)), u_0_x(x_grid.flatten().reshape(-1, 1)), dx) @@ -227,13 +231,11 @@ def lagrangian_loss(inputs, model, dual_opt, epoch, H0): optimizer.zero_grad() # loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = custom_loss(inputs, model, epoch) - loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss = lagrangian_loss(inputs, model, dual_opt, epoch, H0) + loss, loss_type, pde_loss, data_loss_0, bc_loss, H_loss, H0 = lagrangian_loss(inputs, model, dual_opt, epoch, H0) loss.backward() optimizer.step() - - if epoch % 1 == 0: - lr_schedule.step() + lr_schedule.step() with torch.no_grad(): @@ -340,6 +342,7 @@ def lagrangian_loss(inputs, model, dual_opt, epoch, H0): H_losses_rel_error = np.array(H_losses_rel_error) plt.figure(figsize=(10, 6)) plt.plot(H_losses_rel_error) +# plt.plot() plt.xlabel('Epoch') plt.ylabel('Hamiltonian relative error') plt.title('Hamiltonian relative error over epochs') From 961f2235f26208fad69bc9ca520f15a8253eea9b Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Thu, 4 Jun 2026 15:23:49 +0200 Subject: [PATCH 19/30] add nupi algorithm --- src/humancompatible/train/dual_optim/nupi.py | 126 +++++++++++-------- 1 file changed, 75 insertions(+), 51 deletions(-) diff --git a/src/humancompatible/train/dual_optim/nupi.py b/src/humancompatible/train/dual_optim/nupi.py index 8eb3b28..2bc6c42 100644 --- a/src/humancompatible/train/dual_optim/nupi.py +++ b/src/humancompatible/train/dual_optim/nupi.py @@ -15,18 +15,18 @@ class nuPI(Optimizer): :param m: Number of constraints (determines the number of dual variables to create) :type m: int - :param lr: Dual variable update rate. - :type lr: float + :param nu: Momentum parameter. + :type nu: float :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. :type init_duals: float | Tensor :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` :type penalty: float :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. :type dual_range: Tuple[float, float] - :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. - :type momentum: float - :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. - :type dampening: float + :param ki: Momentum parameter. + :type ki: float + :param kp: Momentum parameter. + :type kp: float :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. :type is_ineq: bool :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. @@ -35,27 +35,25 @@ class nuPI(Optimizer): def __init__( self, m: int = None, - lr: float = 0.01, + nu: float = 0.01, init_duals: float | Tensor = None, penalty: float = 1.0, *, dual_range: Tuple[float, float] = (-100.0, 100.0), - momentum: float = 0.0, - dampening: float = 0.0, + ki: float = 0.0, + kp: float = 0.0, is_ineq: bool = False, ctol: float = 0., device=None, ) -> None: - if momentum > 0 and dampening == 0: - dampening = momentum - # self.dual_range = dual_range # self.ctol = ctol self.penalty = penalty + self._is_initialized = False duals, defaults = _init_constraint_group( - m, lr, momentum, dampening, init_duals, dual_range, is_ineq, device + m, nu, ki, kp, init_duals, dual_range, is_ineq, device ) super().__init__(duals, defaults) @@ -71,9 +69,9 @@ def duals(self) -> Tensor: def add_constraint_group( self, m: int, - lr: float = None, - momentum: float = None, - dampening: float = None, + nu: float = None, + ki: float = None, + kp: float = None, init_duals: Tensor = None, dual_range: tuple[float, float] = None, is_ineq: bool = False, @@ -84,12 +82,12 @@ def add_constraint_group( :param m: Size of group (number of dual variables to add) :type m: int - :param lr: Dual variable update rate. - :type lr: float - :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. - :type momentum: float - :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. - :type dampening: float + :param nu: Momentum parameter. + :type nu: float + :param ki: Momentum parameter. + :type ki: float + :param kp: Momentum parameter. + :type kp: float :param init_duals: Initial values for the new dual variables. Defaults to the value set when creating the optimizer. :type init_duals: Tensor :param dual_range: After each dual update, the dual variables will be clamped to this range. @@ -102,7 +100,7 @@ def add_constraint_group( """ duals, settings_dict = _init_constraint_group( - m, lr, momentum, dampening, init_duals, dual_range, is_ineq, device + m, nu, ki, kp, init_duals, dual_range, is_ineq, device ) param_group_dict = {"params": duals, **settings_dict} self.add_param_group(param_group_dict) @@ -244,7 +242,7 @@ def _process_constraint_group( group: dict[str, Any], group_idx: int, constraints: Tensor, - update_duals: bool = False, + update_duals: bool = False ) -> Tuple[Tensor, Tensor]: """ Process a single constraint group: extract duals/constraints and optionally update duals. @@ -263,30 +261,31 @@ def _process_constraint_group( else: group_constraints = constraints.unsqueeze(0) - lr = group.get("lr") - momentum = group.get("momentum", 0.0) - dampening = group.get("dampening", 0.0) + nu = group.get("nu") + ki = group.get("ki", 0.0) + kp = group.get("kp", 0.0) momentum_buffer = group["momentum_buffer"] dual_lb = group.get("lower_bound") dual_ub = group.get("upper_bound") is_ineq = group.get("is_ineq") + _momentum_initialized = group.get("_momentum_initialized") with torch.no_grad(): - if momentum > 0: - _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) + if update_duals: - _update_duals(duals, momentum_buffer if momentum > 0 else group_constraints, lr) + _update_duals(duals, momentum_buffer, group_constraints, nu, ki, kp) clamp_(duals, min=dual_lb, max=dual_ub) - + + _update_c_buffers(group_constraints, nu, momentum_buffer) return duals, group_constraints def _init_constraint_group( m: int = None, - lr: float = None, - momentum: float = None, - dampening: float = None, + nu: float = None, + ki: float = None, + kp: float = None, init_duals: float | Tensor = None, dual_range: Tuple[float, float] = None, is_ineq: bool = None, @@ -296,9 +295,6 @@ def _init_constraint_group( if init_duals is None and m is None: raise ValueError("At least one of m, init_duals must be set") - if momentum is not None and (momentum < 0 or momentum > 1): - raise ValueError(f"momentum must be within [0,1]; got {momentum}") - if not isinstance(is_ineq, bool): raise ValueError(f"Expected a Boolean value for is_ineq, got {is_ineq}") @@ -317,15 +313,16 @@ def _init_constraint_group( dual_range = (0, None) settings_dict = { - "lr": lr, - "momentum": momentum, - "dampening": dampening, + "nu": nu, + "ki": ki, + "kp": kp, "momentum_buffer": torch.zeros_like( init_duals, requires_grad=False, device=device ), "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], "upper_bound": dual_range[1], - "is_ineq": is_ineq + "is_ineq": is_ineq, + "_momentum_initialized": False } settings_dict = {k: v for k, v in settings_dict.items() if v is not None} @@ -335,23 +332,50 @@ def _init_constraint_group( def _update_c_buffers( constraints: Tensor, - momentum: float, - dampening: float, + nu: float, buffer: Tensor, ) -> None: """Update the constraint buffer with momentum.""" - raise NotImplementedError - if momentum == 0: - buffer = constraints - else: - buffer.mul_(momentum).add_(constraints, alpha=1 - dampening) + buffer.mul_(nu).add_(constraints, alpha=1 - nu) def _update_duals( duals: Tensor, buffer: Tensor, - lr: float, + constraints: Tensor, + nu: float, + ki: float, + kp: float ) -> None: - raise NotImplementedError """Update duals using the buffered constraint gradients.""" - duals.add_(buffer, alpha=lr) + # duals.add_(buffer, alpha=lr).add_() + duals.add_( constraints, alpha=ki + kp * (1-nu) ).add_( buffer, alpha = -kp * (1-nu) ) + + +nuPI.__doc__ = ( + + # \textbf{input}: \gamma \text{ (lr) }, \pmb{\lambda}_t \text{ (dual variables, created by method) }, \\ + # \mathbf{c}(\theta) \text{ (constraints) }, f(\theta) \text{ (objective) }, \rho \text{ (penalty coefficient) } \\ + r""" + A Dual Optimizer that works on the dual maximization tasks according to the nuPI Augmented Lagrangian rule, based on https://doi.org/10.48550/arXiv.2406.04558. Creates and updates dual variables. + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param lr: Dual variable update rate. + :type lr: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float + """ +) \ No newline at end of file From a0a1d9943ad8285ec3cfb090dbdb182647663a9c Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Thu, 4 Jun 2026 15:24:05 +0200 Subject: [PATCH 20/30] update docstring for ialm --- docs/source/api_reference/dual_opts/ialm.rst | 6 ++++ src/humancompatible/train/dual_optim/ialm.py | 35 ++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 docs/source/api_reference/dual_opts/ialm.rst diff --git a/docs/source/api_reference/dual_opts/ialm.rst b/docs/source/api_reference/dual_opts/ialm.rst new file mode 100644 index 0000000..eb26e99 --- /dev/null +++ b/docs/source/api_reference/dual_opts/ialm.rst @@ -0,0 +1,6 @@ +iALM +================= + + +.. autoclass:: humancompatible.train.dual_optim.iALM + :members: diff --git a/src/humancompatible/train/dual_optim/ialm.py b/src/humancompatible/train/dual_optim/ialm.py index a9ef9a2..ce87b3e 100644 --- a/src/humancompatible/train/dual_optim/ialm.py +++ b/src/humancompatible/train/dual_optim/ialm.py @@ -328,3 +328,38 @@ def _update_duals( update_mult = torch.min(beta, gamma / torch.linalg.norm(buffer)) duals.add_(buffer, alpha=update_mult) + + +iALM.__doc__ = ( + + # \textbf{input}: \gamma \text{ (lr) }, \pmb{\lambda}_t \text{ (dual variables, created by method) }, \\ + # \mathbf{c}(\theta) \text{ (constraints) }, f(\theta) \text{ (objective) }, \rho \text{ (penalty coefficient) } \\ + r""" + A Dual Optimizer that works on the dual maximization tasks according to the Augmented Lagrangian rule, with adaptive stepsize based on https://doi.org/10.1007/s10589-023-00521-z, Algorithm 1. Creates and updates dual variables. + + .. math:: + + \pmb{\lambda}_{t+1} & \leftarrow \pmb{\lambda}_t + \min\left\{ \beta_k, \frac{\gamma_k}{\|\mathbf{c}_t(\theta_t)\|} \right\} \mathbf{c}_t(\theta_{t}) + + \mathcal{L}_{t+1} & \leftarrow f_t(\theta_{t}) + \pmb{\lambda}_{t+1}^T \mathbf{c}_t(\theta_{t}) + \frac{\rho}{2} \| \mathbf{c}_t(\theta_{t}) \|^2_2 + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param lr: Dual variable update rate. + :type lr: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. + :type is_ineq: bool + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float + """ +) \ No newline at end of file From 8e4b658dad643ab2890ff72d0d40e1df29030242 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Thu, 4 Jun 2026 15:24:54 +0200 Subject: [PATCH 21/30] update init for nupi --- src/humancompatible/train/dual_optim/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/humancompatible/train/dual_optim/__init__.py b/src/humancompatible/train/dual_optim/__init__.py index 5dafcbd..53d5796 100644 --- a/src/humancompatible/train/dual_optim/__init__.py +++ b/src/humancompatible/train/dual_optim/__init__.py @@ -1,4 +1,5 @@ from .alm import ALM from .ialm import iALM from .pbm import PBM +from .nupi import nuPI from .moreau import MoreauEnvelope From f02ab8810f2608048623c2f314b647d67b0b286e Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Thu, 4 Jun 2026 15:59:26 +0200 Subject: [PATCH 22/30] fix ialm doc --- src/humancompatible/train/dual_optim/ialm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/humancompatible/train/dual_optim/ialm.py b/src/humancompatible/train/dual_optim/ialm.py index ce87b3e..8396a61 100644 --- a/src/humancompatible/train/dual_optim/ialm.py +++ b/src/humancompatible/train/dual_optim/ialm.py @@ -345,8 +345,12 @@ def _update_duals( :param m: Number of constraints (determines the number of dual variables to create) :type m: int - :param lr: Dual variable update rate. - :type lr: float + :param beta: Dual variable update rate. + :type beta: float + :param sigma: Multiplier for increasing`beta`. + :type sigma: float + :param gamma: Penalty update parameter. + :type gamma: float :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. :type init_duals: float | Tensor :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` From 659170ecbc33093226104e3c1b78b200fbce9aa9 Mon Sep 17 00:00:00 2001 From: Andrii Kliachkin Date: Mon, 8 Jun 2026 10:42:04 +0200 Subject: [PATCH 23/30] Add Read the Docs badge to README Added documentation badge for Read the Docs. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 192ef3b..535c1ff 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # humancompatible-train: a package for constrained machine learning -[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) [![Docs](https://app.readthedocs.org/projects/humancompatible-train/badge/) The toolkit implements algorithms for constrained training of neural networks based on PyTorch, and inspired by PyTorch's API. From 37cdeb54da3c674cade45c0c45e0f84ac5d9e493 Mon Sep 17 00:00:00 2001 From: Andrii Kliachkin Date: Mon, 8 Jun 2026 10:42:31 +0200 Subject: [PATCH 24/30] Fix badge formatting in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 535c1ff..1423e47 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # humancompatible-train: a package for constrained machine learning -[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) [![Docs](https://app.readthedocs.org/projects/humancompatible-train/badge/) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) [![Docs](https://app.readthedocs.org/projects/humancompatible-train/badge/)] The toolkit implements algorithms for constrained training of neural networks based on PyTorch, and inspired by PyTorch's API. From 9ee4e578bf597b2bca715748cb996019273aec35 Mon Sep 17 00:00:00 2001 From: Andrii Kliachkin Date: Mon, 8 Jun 2026 10:42:51 +0200 Subject: [PATCH 25/30] Fix docs badge formatting --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1423e47..fbf1951 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # humancompatible-train: a package for constrained machine learning -[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) [![Docs](https://app.readthedocs.org/projects/humancompatible-train/badge/)] +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) ![Docs](https://app.readthedocs.org/projects/humancompatible-train/badge/) The toolkit implements algorithms for constrained training of neural networks based on PyTorch, and inspired by PyTorch's API. From 2a59eb06195a6625de299fb96eb17a9e8c728643 Mon Sep 17 00:00:00 2001 From: Andrii Kliachkin Date: Mon, 8 Jun 2026 10:50:43 +0200 Subject: [PATCH 26/30] Update documentation badge in README.md Updated documentation badge to link to the latest version. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index fbf1951..7d94a8e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # humancompatible-train: a package for constrained machine learning -[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) ![Docs](https://app.readthedocs.org/projects/humancompatible-train/badge/) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Setup](https://github.com/humancompatible/train/actions/workflows/setup.yml/badge.svg)](https://github.com/humancompatible/train/actions/workflows/setup.yml) [![docs](https://app.readthedocs.org/projects/humancompatible-train/badge/?version=latest)](https://humancompatible-train.readthedocs.io/en/latest/?badge=latest) The toolkit implements algorithms for constrained training of neural networks based on PyTorch, and inspired by PyTorch's API. From 00c6b092bd5103e369cdc67aeb3fd70fa4198b2e Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Mon, 8 Jun 2026 11:05:48 +0200 Subject: [PATCH 27/30] slight docs updates --- README.md | 2 +- docs/source/tutorials/copt_overview.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7d94a8e..6a48159 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ The only dependencies of this package are `numpy` and `torch`. ## Using the toolkit -The toolkit implements algorithms for constrained training of neural networks based on PyTorch. +The toolkit implements algorithms for constrained training of neural networks based on PyTorch. For the documentation, please visit [our Read the Docs page!](https://humancompatible-train.readthedocs.io?version=latest) The algorithms are intended for use in tandem with classic PyTorch optimizers, calculating the Lagrangian and keeping track of the dual variables. diff --git a/docs/source/tutorials/copt_overview.rst b/docs/source/tutorials/copt_overview.rst index dbc7f42..28b5d4d 100644 --- a/docs/source/tutorials/copt_overview.rst +++ b/docs/source/tutorials/copt_overview.rst @@ -21,7 +21,7 @@ We then introduce **constraints** -- they could express anything from some bound .. note:: - - As is standard in the field, we adopt the convention of writing the constraints as :math:`g(x) \leq 0`, and :math:`h(x) = 0`. This is just a notational choice, and does not affect the generality of the formulation. It is trivial to transform :math:`g(x) \geq 0` into :math:`-g(x) \leq 0`, or :math:`g(x) \leq \epsilon` into :math:`g(x) - \epsilon \leq 0` for some non-zero bound. + - As is standard in the field, we adopt the convention of writing the constraints as :math:`g(x) \leq 0`, and :math:`h(x) = 0`. This is just a notational choice, and does not affect the generality of the formulation. It is trivial to transform :math:`g(x) \geq 0` into :math:`-g(x) \leq 0`, or :math:`h(x) = \epsilon` into :math:`g(x) - \epsilon = 0` for some :math:`\epsilon`. We refer to this :math:`\epsilon` as the constraint **bound**, or **threshold**. - It is also easy to switch between equality and inequality constraints: to get :math:`g(x) = 0`, one can set :math:`-g(x) \leq 0` and :math:`g(x) \leq 0` simultaneously. In fact, different algorithms are designed to handle either equality or inequality constraints natively, but it is trivial to switch between the two. We shall see more concrete examples later on. From b3cc1ab921266d7fd0fe73d0f42372fc6b44e6e2 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Thu, 11 Jun 2026 16:17:43 +0200 Subject: [PATCH 28/30] docs update --- docs/source/api_reference/utils.rst | 8 +++++- docs/source/api_reference/utils/sampler.rst | 6 +++++ docs/source/conf.py | 2 +- docs/source/tutorials/tips.rst | 30 +++++++++++++++++++-- 4 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 docs/source/api_reference/utils/sampler.rst diff --git a/docs/source/api_reference/utils.rst b/docs/source/api_reference/utils.rst index 2085850..3f22410 100644 --- a/docs/source/api_reference/utils.rst +++ b/docs/source/api_reference/utils.rst @@ -1,2 +1,8 @@ Utils -===== \ No newline at end of file +===== + +.. toctree:: + :titlesonly: + :glob: + + utils/* \ No newline at end of file diff --git a/docs/source/api_reference/utils/sampler.rst b/docs/source/api_reference/utils/sampler.rst new file mode 100644 index 0000000..adb3d29 --- /dev/null +++ b/docs/source/api_reference/utils/sampler.rst @@ -0,0 +1,6 @@ +BalancedBatchSampler +==================== + + +.. autoclass:: humancompatible.train.fairness.utils.BalancedBatchSampler + :members: diff --git a/docs/source/conf.py b/docs/source/conf.py index 5ae81c7..3ac9411 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ project = 'humancompatible-train' copyright = '2026, Andrii Kliachkin, Gilles Bareillies, Jana Lepsova, Jakub Marecek' author = 'Andrii Kliachkin, Gilles Bareillies, Jana Lepsova, Jakub Marecek' -release = '0.3.1' +# release = '0.3.1' # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/source/tutorials/tips.rst b/docs/source/tutorials/tips.rst index 89bf410..ad56f73 100644 --- a/docs/source/tutorials/tips.rst +++ b/docs/source/tutorials/tips.rst @@ -8,5 +8,31 @@ Dealing with Noise In the stochastic case, the gradients are estimated using mini-batches of data, which introduces additional noise into the optimization process. This can make convergence more challenging, but this can be mitigated. -**Momentum**: Just like in standard optimization, using momentum can help smooth out the updates and mitigate the noise. In `humancompatible-train`, the dual optimizers support momentum, which can be enabled by setting the `momentum` parameter to a non-zero value. \ -Some dual update strategies, such as nuPI, explicitly rely on momentum. +**Momentum**: Just like in standard optimization, using momentum can help smooth out the updates and mitigate the noise. In ``humancompatible-train``, the ``ALM``, ``iALM``, and ``nuPI`` dual optimizers support momentum, which can be enabled by setting the ``momentum`` parameter to a non-zero value. +Some dual update strategies, such as ``nuPI``, explicitly rely on momentum. + +Without momentum, the dual update at each step is a direct ascent step on the constraint values: + +.. math:: + + \pmb{\lambda}_{t+1} \leftarrow \text{clamp}\!\left(\pmb{\lambda}_t + \gamma\, \mathbf{c}_t(\theta_t),\; \lambda_{\min},\; \lambda_{\max}\right) + +With momentum enabled, a running buffer :math:`\mathbf{b}_t` accumulates a weighted history of past constraint values before being used for the dual update: + +.. math:: + + \mathbf{b}_{t+1} &\leftarrow \mu\, \mathbf{b}_t + (1 - \delta)\, \mathbf{c}_t(\theta_t) \\ + \pmb{\lambda}_{t+1} &\leftarrow \text{clamp}\!\left(\pmb{\lambda}_t + \gamma\, \mathbf{b}_{t+1},\; \lambda_{\min},\; \lambda_{\max}\right) + +where :math:`\mu` is the ``momentum`` coefficient, :math:`\delta` is the ``dampening`` coefficient, and :math:`\gamma` is the dual learning rate. + +.. note:: + + When ``momentum > 0`` and ``dampening`` is not explicitly provided, the library automatically sets ``dampening = momentum``. + This conservative choice prioritises stability: the buffer update becomes + + .. math:: + + \mathbf{b}_{t+1} \leftarrow \mu\, \mathbf{b}_t + (1 - \mu)\, \mathbf{c}_t(\theta_t) + + which is a standard exponential moving average of the constraint values with smoothing factor :math:`\mu`. \ No newline at end of file From dc00fa985a9493ce6c2e11e4c043319f4264d4c8 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Thu, 11 Jun 2026 16:18:06 +0200 Subject: [PATCH 29/30] initial ddp support --- src/humancompatible/train/dual_optim/alm.py | 97 +++++++++++++------- src/humancompatible/train/dual_optim/ialm.py | 48 +++++----- src/humancompatible/train/dual_optim/nupi.py | 44 ++++----- 3 files changed, 102 insertions(+), 87 deletions(-) diff --git a/src/humancompatible/train/dual_optim/alm.py b/src/humancompatible/train/dual_optim/alm.py index b291e88..c4717eb 100644 --- a/src/humancompatible/train/dual_optim/alm.py +++ b/src/humancompatible/train/dual_optim/alm.py @@ -1,7 +1,8 @@ import torch +import torch.distributed as dist from torch.nn import Parameter from torch.optim import Optimizer -from typing import Any, Tuple +from typing import Any, Optional, Tuple from torch import clamp_, Tensor # cite: Stochastic Smoothed Primal-Dual Algorithms for Nonconvex Optimization with Linear Inequality Constraints @@ -20,19 +21,19 @@ def __init__( momentum: float = 0.0, dampening: float = 0.0, is_ineq: bool = False, + restart: bool = False, ctol: float = 0., device=None, + process_group: Optional[dist.ProcessGroup] = None, ) -> None: if momentum > 0 and dampening == 0: dampening = momentum - # self.dual_range = dual_range - # self.ctol = ctol - self.penalty = penalty + self.process_group = process_group duals, defaults = _init_constraint_group( - m, lr, momentum, dampening, init_duals, dual_range, is_ineq, device + m, lr, momentum, dampening, init_duals, dual_range, is_ineq, restart, device ) super().__init__(duals, defaults) @@ -54,6 +55,7 @@ def add_constraint_group( init_duals: Tensor = None, dual_range: tuple[float, float] = None, is_ineq: bool = False, + restart: bool = False, device = None ) -> None: """ @@ -73,13 +75,15 @@ def add_constraint_group( :type dual_range: Tuple[float, float] :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be relaxed on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. :type is_ineq: bool + :param restart: Whether to set the dual variables to zero immediately on strict satisfaction of corresponding constraints. Not recommended for stochastic constraints. + :type restart: bool .. note:: Parameters here will default to values set when initializing the dual optimizer. """ duals, settings_dict = _init_constraint_group( - m, lr, momentum, dampening, init_duals, dual_range, is_ineq, device + m, lr, momentum, dampening, init_duals, dual_range, is_ineq, restart, device ) param_group_dict = {"params": duals, **settings_dict} self.add_param_group(param_group_dict) @@ -123,11 +127,11 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i in range(len(self.param_groups)): - duals, group_constraints = _process_constraint_group( - self.param_groups[i], i, constraints, update_duals=False - ) + offset = 0 + for group in self.param_groups: + duals, group_constraints = _process_constraint_group(group, offset, constraints, update_duals=False) lagrangian.add_(duals @ group_constraints) + offset += len(duals) self._add_penalty_term(lagrangian, constraints) return lagrangian @@ -156,10 +160,16 @@ def update(self, constraints: Tensor) -> None: :param constraints: Tensor of constraint values :type constraints: Tensor """ - for i in range(len(self.param_groups)): - _process_constraint_group( - self.param_groups[i], i, constraints, update_duals=True - ) + if self.process_group is not None: + with torch.no_grad(): + constraints = constraints.detach().clone() + dist.all_reduce(constraints, op=dist.ReduceOp.AVG, group=self.process_group) + offset = 0 + for group in self.param_groups: + _process_constraint_group(group, offset, constraints, update_duals=True) + offset += len(group["params"][0]) + + step = update # evaluate the Lagrangian and update the dual variables def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: @@ -184,14 +194,25 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: :return: Lagrangian :rtype: Tensor """ + if self.process_group is not None: + with torch.no_grad(): + constraints_for_update = constraints.detach().clone() + dist.all_reduce(constraints_for_update, op=dist.ReduceOp.AVG, group=self.process_group) + else: + constraints_for_update = constraints + lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i in range(len(self.param_groups)): - duals, group_constraints = _process_constraint_group( - self.param_groups[i], i, constraints, update_duals=True - ) + offset = 0 + for group in self.param_groups: + duals, _ = _process_constraint_group(group, offset, constraints_for_update, update_duals=True) + # Always use the original (non-reduced) constraints for the Lagrangian term + # so that autograd can flow through ∂c/∂θ during backward(). + n = len(duals) + group_constraints = constraints[offset : offset + n] if constraints.ndim > 0 else constraints.unsqueeze(0) lagrangian.add_(duals @ group_constraints) + offset += n self._add_penalty_term(lagrangian, constraints) return lagrangian @@ -220,7 +241,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: def _process_constraint_group( group: dict[str, Any], - group_idx: int, + offset: int, constraints: Tensor, update_duals: bool = False, ) -> Tuple[Tensor, Tensor]: @@ -228,35 +249,30 @@ def _process_constraint_group( Process a single constraint group: extract duals/constraints and optionally update duals. :param group: The constraint group dictionary - :param group_idx: Index of the constraint group + :param offset: Start index of this group's slice within the full constraints tensor :param constraints: Full constraints tensor :param update_duals: Whether to update dual variables :return: Tuple of (duals, group_constraints) """ duals = group["params"][0] - if constraints.ndim > 0: - group_constraints = ( - constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] - ) - else: - group_constraints = constraints.unsqueeze(0) - + n = len(duals) + group_constraints = constraints[offset : offset + n] if constraints.ndim > 0 else constraints.unsqueeze(0) + lr = group.get("lr") momentum = group.get("momentum", 0.0) dampening = group.get("dampening", 0.0) momentum_buffer = group["momentum_buffer"] dual_lb = group.get("lower_bound") dual_ub = group.get("upper_bound") - is_ineq = group.get("is_ineq") + restart = group.get("restart") with torch.no_grad(): - if momentum > 0: - _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) if update_duals: - _update_duals(duals, momentum_buffer if momentum > 0 else group_constraints, lr) + if momentum > 0: + _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) + _update_duals(duals, momentum_buffer if momentum > 0 else group_constraints, lr, restart) clamp_(duals, min=dual_lb, max=dual_ub) - return duals, group_constraints @@ -268,6 +284,7 @@ def _init_constraint_group( init_duals: float | Tensor = None, dual_range: Tuple[float, float] = None, is_ineq: bool = None, + restart: bool = None, device = None, ): ## checks ## @@ -278,7 +295,10 @@ def _init_constraint_group( raise ValueError(f"momentum must be within [0,1]; got {momentum}") if not isinstance(is_ineq, bool): - raise ValueError(f"Expected a Boolean value for is_ineq, got {is_ineq}") + raise ValueError(f"Expected a Boolean value for is_ineq, got {type(is_ineq)}") + + if not isinstance(restart, bool): + raise ValueError(f"Expected a Boolean value for restart, got {type(restart)}") m = m if m is not None else len(init_duals) @@ -303,7 +323,8 @@ def _init_constraint_group( ), "lower_bound": max(dual_range[0], 0) if is_ineq else dual_range[0], "upper_bound": dual_range[1], - "is_ineq": is_ineq + "is_ineq": is_ineq, + "restart": restart } settings_dict = {k: v for k, v in settings_dict.items() if v is not None} @@ -328,9 +349,13 @@ def _update_duals( duals: Tensor, buffer: Tensor, lr: float, + restart: bool ) -> None: """Update duals using the buffered constraint gradients.""" duals.add_(buffer, alpha=lr) + # Set duals to 0 where buffer < 0 + if restart: + duals[buffer < 0] = 0 @@ -363,7 +388,11 @@ def _update_duals( :type dampening: float :param is_ineq: Whether to treat the constraints as equality or inequality. If`True`, dual variables will be decreased on strict satisfaction and lower-bounded by `max(dual_range[0], 0)`. :type is_ineq: bool + :param restart: Whether to set the dual variables to zero immediately on strict satisfaction of corresponding constraints. Not recommended for stochastic constraints. + :type restart: bool :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. :type ctol: float + :param process_group: Distributed process group for DDP. When set, constraint values are averaged across all workers via ``dist.all_reduce`` before each dual update, keeping dual variables consistent across replicas. Defaults to ``None`` (no synchronization). + :type process_group: dist.ProcessGroup, optional """ ) \ No newline at end of file diff --git a/src/humancompatible/train/dual_optim/ialm.py b/src/humancompatible/train/dual_optim/ialm.py index 8396a61..233ba9c 100644 --- a/src/humancompatible/train/dual_optim/ialm.py +++ b/src/humancompatible/train/dual_optim/ialm.py @@ -130,11 +130,11 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i in range(len(self.param_groups)): - duals, beta, group_constraints = _process_constraint_group_ialm( - self.param_groups, i, constraints, update_duals=False - ) + offset = 0 + for group in self.param_groups: + duals, beta, group_constraints = _process_constraint_group_ialm(group, offset, constraints, update_duals=False) lagrangian.add_(duals @ group_constraints) + offset += len(duals) # Use beta from first group for penalty term beta = self.param_groups[0]["beta"] @@ -149,10 +149,10 @@ def update(self, constraints: Tensor) -> None: :param constraints: Tensor of constraint values :type constraints: Tensor """ - for i in range(len(self.param_groups)): - _process_constraint_group_ialm( - self.param_groups, i, constraints, update_duals=True - ) + offset = 0 + for group in self.param_groups: + _process_constraint_group_ialm(group, offset, constraints, update_duals=True) + offset += len(group["params"][0]) # Update beta by sigma for each group for group in self.param_groups: @@ -173,11 +173,11 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i in range(len(self.param_groups)): - duals, beta, group_constraints = _process_constraint_group_ialm( - self.param_groups, i, constraints, update_duals=True - ) + offset = 0 + for group in self.param_groups: + duals, beta, group_constraints = _process_constraint_group_ialm(group, offset, constraints, update_duals=True) lagrangian.add_(duals @ group_constraints) + offset += len(duals) # Use beta from first group for penalty term beta = self.param_groups[0]["beta"] @@ -214,40 +214,36 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: def _process_constraint_group_ialm( - param_groups: list, - group_idx: int, + group: dict[str, Any], + offset: int, constraints: Tensor, update_duals: bool = False, ) -> Tuple[Tensor, Tensor, Tensor]: """ Process a single constraint group: extract parameters and optionally update duals. - :param param_groups: List of parameter groups from optimizer - :param group_idx: Index of the constraint group + :param group: The constraint group dictionary + :param offset: Start index of this group's slice within the full constraints tensor :param constraints: Full constraints tensor - :param ctol: Constraint tolerance - :param dual_range: Safeguarding range for dual variables :param update_duals: Whether to update dual variables :return: Tuple of (duals, beta, group_constraints) """ - group = param_groups[group_idx] duals = group["params"][0] + n = len(duals) beta = group.get("beta") - sigma = group.get("sigma") gamma = group.get("gamma") momentum = group.get("momentum", 0.0) dampening = group.get("dampening", 0.0) momentum_buffer = group.get("momentum_buffer") dual_lb = group.get("lower_bound") dual_ub = group.get("upper_bound") - is_ineq = group.get("is_ineq") - group_constraints = constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] + group_constraints = constraints[offset : offset + n] if constraints.ndim > 0 else constraints.unsqueeze(0) with torch.no_grad(): - if momentum > 0: - _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) if update_duals: + if momentum > 0: + _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) _update_duals(duals, beta, gamma, momentum_buffer if momentum > 0 else group_constraints) clamp_(duals, min=dual_lb, max=dual_ub) @@ -276,9 +272,7 @@ def _init_constraint_group( m = m if m is not None else len(init_duals) if init_duals is None: # initialize duals if not set or set to scalar - init_duals = ( - torch.zeros(m, requires_grad=False, device=device) + dual_range[0] - ) + init_duals = torch.zeros(m, requires_grad=False, device=device) elif isinstance(init_duals, float): init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals diff --git a/src/humancompatible/train/dual_optim/nupi.py b/src/humancompatible/train/dual_optim/nupi.py index 2bc6c42..dd013d5 100644 --- a/src/humancompatible/train/dual_optim/nupi.py +++ b/src/humancompatible/train/dual_optim/nupi.py @@ -144,11 +144,11 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i in range(len(self.param_groups)): - duals, group_constraints = _process_constraint_group( - self.param_groups[i], i, constraints, update_duals=False - ) + offset = 0 + for group in self.param_groups: + duals, group_constraints = _process_constraint_group(group, offset, constraints, update_duals=False) lagrangian.add_(duals @ group_constraints) + offset += len(duals) self._add_penalty_term(lagrangian, constraints) return lagrangian @@ -177,10 +177,10 @@ def update(self, constraints: Tensor) -> None: :param constraints: Tensor of constraint values :type constraints: Tensor """ - for i in range(len(self.param_groups)): - _process_constraint_group( - self.param_groups[i], i, constraints, update_duals=True - ) + offset = 0 + for group in self.param_groups: + _process_constraint_group(group, offset, constraints, update_duals=True) + offset += len(group["params"][0]) # evaluate the Lagrangian and update the dual variables def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: @@ -207,11 +207,11 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i in range(len(self.param_groups)): - duals, group_constraints = _process_constraint_group( - self.param_groups[i], i, constraints, update_duals=True - ) + offset = 0 + for group in self.param_groups: + duals, group_constraints = _process_constraint_group(group, offset, constraints, update_duals=True) lagrangian.add_(duals @ group_constraints) + offset += len(duals) self._add_penalty_term(lagrangian, constraints) return lagrangian @@ -240,7 +240,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: def _process_constraint_group( group: dict[str, Any], - group_idx: int, + offset: int, constraints: Tensor, update_duals: bool = False ) -> Tuple[Tensor, Tensor]: @@ -248,35 +248,27 @@ def _process_constraint_group( Process a single constraint group: extract duals/constraints and optionally update duals. :param group: The constraint group dictionary - :param group_idx: Index of the constraint group + :param offset: Start index of this group's slice within the full constraints tensor :param constraints: Full constraints tensor :param update_duals: Whether to update dual variables :return: Tuple of (duals, group_constraints) """ duals = group["params"][0] - if constraints.ndim > 0: - group_constraints = ( - constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] - ) - else: - group_constraints = constraints.unsqueeze(0) - + n = len(duals) + group_constraints = constraints[offset : offset + n] if constraints.ndim > 0 else constraints.unsqueeze(0) + nu = group.get("nu") ki = group.get("ki", 0.0) kp = group.get("kp", 0.0) momentum_buffer = group["momentum_buffer"] dual_lb = group.get("lower_bound") dual_ub = group.get("upper_bound") - is_ineq = group.get("is_ineq") - _momentum_initialized = group.get("_momentum_initialized") with torch.no_grad(): - if update_duals: _update_duals(duals, momentum_buffer, group_constraints, nu, ki, kp) clamp_(duals, min=dual_lb, max=dual_ub) - - _update_c_buffers(group_constraints, nu, momentum_buffer) + _update_c_buffers(group_constraints, nu, momentum_buffer) return duals, group_constraints From b54b13db327489065b3ec3c95872c84428cb2883 Mon Sep 17 00:00:00 2001 From: andrewklayk Date: Thu, 11 Jun 2026 16:18:48 +0200 Subject: [PATCH 30/30] update tests --- tests/test_alm.py | 158 ++++++++++++++++++++++++++++++++++++++++++++- tests/test_ialm.py | 95 +++++++++++++++++++++++++++ tests/test_nupi.py | 104 +++++++++++++++++++++++++++++ 3 files changed, 355 insertions(+), 2 deletions(-) create mode 100644 tests/test_ialm.py create mode 100644 tests/test_nupi.py diff --git a/tests/test_alm.py b/tests/test_alm.py index 98cb0dd..4c3a22c 100644 --- a/tests/test_alm.py +++ b/tests/test_alm.py @@ -1,7 +1,9 @@ import unittest +from unittest.mock import patch from torch.optim import Optimizer from humancompatible.train.dual_optim import ALM import torch +import torch.distributed as dist # Unit tests class TestALM(unittest.TestCase): @@ -40,8 +42,6 @@ def test_alm_update(self): expected_duals = self.alm_default.duals + 0.1 * self.constraints # breakpoint() self.alm_default.update(self.constraints) - print(self.alm_default.duals) - print(expected_duals) self.assertTrue(torch.allclose(self.alm_default.duals, expected_duals)) def test_alm_momentum_update(self): @@ -81,10 +81,164 @@ def test_alm_dual_range_clamping(self): self.assertTrue(torch.all(self.alm_custom_range.duals <= 1.0) and torch.all(self.alm_custom_range.duals >= -1.0)) + def test_step_is_update_alias(self): + alm = ALM(m=3, lr=0.1, penalty=1.0) + duals_before = alm.duals.clone() + alm.step(self.constraints) + alm2 = ALM(m=3, lr=0.1, penalty=1.0) + alm2.update(self.constraints) + self.assertTrue(torch.allclose(alm.duals, alm2.duals)) + self.assertFalse(torch.allclose(alm.duals, duals_before)) + def test_alm_state_dict(self): alm = ALM(m=3, lr=0.1, penalty=2.0, dual_range=(-1.0, 1.0)) state_dict = alm.state_dict() self.assertEqual(state_dict["state"]["penalty"], 2.0) +class TestALMFixes(unittest.TestCase): + """Tests for fix 1 (momentum buffer in forward) and fix 2 (multi-group slicing).""" + + def setUp(self): + self.loss = torch.tensor(5.0) + self.constraints = torch.tensor([1.0, 2.0, 3.0, 10.0, 20.0, 30.0]) + + # --- Fix 1: forward() must not advance the momentum buffer --- + + def test_forward_does_not_corrupt_momentum_buffer(self): + # Calling forward() then update() must give the same duals as update() alone. + c = torch.tensor([1.0, 2.0, 3.0]) + alm_direct = ALM(m=3, lr=0.1, penalty=1.0, momentum=0.9) + alm_via_forward = ALM(m=3, lr=0.1, penalty=1.0, momentum=0.9) + + alm_direct.update(c) + + alm_via_forward.forward(self.loss, c) + alm_via_forward.update(c) + + self.assertTrue(torch.allclose(alm_direct.duals, alm_via_forward.duals)) + + def test_forward_update_and_separate_forward_update_agree(self): + # forward_update() and forward() + update() must produce identical duals. + c = torch.tensor([1.0, 2.0, 3.0]) + alm_combined = ALM(m=3, lr=0.1, penalty=1.0, momentum=0.9) + alm_separate = ALM(m=3, lr=0.1, penalty=1.0, momentum=0.9) + + alm_combined.forward_update(self.loss, c) + alm_separate.forward(self.loss, c) + alm_separate.update(c) + + self.assertTrue(torch.allclose(alm_combined.duals, alm_separate.duals)) + + # --- Fix 2: multi-group constraint slicing --- + + def test_multi_group_update_slices_correctly(self): + alm = ALM(m=2, lr=0.1, penalty=1.0) + alm.add_constraint_group(m=3, lr=0.2) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + alm.update(c) + + self.assertTrue(torch.allclose(alm.param_groups[0]["params"][0], 0.1 * c[:2])) + self.assertTrue(torch.allclose(alm.param_groups[1]["params"][0], 0.2 * c[2:])) + + def test_multi_group_forward_lagrangian_correct(self): + init0 = torch.tensor([1.0, 1.0]) + init1 = torch.tensor([1.0, 1.0, 1.0]) + alm = ALM(m=2, lr=0.1, penalty=1.0, init_duals=init0) + alm.add_constraint_group(m=3, lr=0.2, init_duals=init1) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + lagrangian = alm.forward(self.loss, c) + + expected = (self.loss + + init0 @ c[:2] + + init1 @ c[2:] + + 0.5 * alm.penalty * torch.dot(c, c)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + def test_multi_group_forward_update_slices_correctly(self): + alm = ALM(m=2, lr=0.1, penalty=1.0) + alm.add_constraint_group(m=3, lr=0.2) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + alm.forward_update(self.loss, c) + + self.assertTrue(torch.allclose(alm.param_groups[0]["params"][0], 0.1 * c[:2])) + self.assertTrue(torch.allclose(alm.param_groups[1]["params"][0], 0.2 * c[2:])) + + +class TestALMDDP(unittest.TestCase): + def setUp(self): + self.loss = torch.tensor(5.0) + self.constraints = torch.tensor([1.0, 2.0, 3.0]) + self.pg = object() # sentinel; real value only matters to dist.all_reduce + + def test_no_process_group_skips_all_reduce(self): + alm = ALM(m=3, lr=0.1, penalty=1.0) + with patch('torch.distributed.all_reduce') as mock_ar: + alm.update(self.constraints) + alm.forward_update(self.loss, self.constraints) + mock_ar.assert_not_called() + + def test_update_calls_all_reduce_with_correct_args(self): + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce') as mock_ar: + alm.update(self.constraints) + mock_ar.assert_called_once() + _, kwargs = mock_ar.call_args + self.assertEqual(kwargs['op'], dist.ReduceOp.AVG) + self.assertEqual(kwargs['group'], self.pg) + + def test_update_uses_reduced_constraints(self): + # Simulate all_reduce replacing the tensor with worker-averaged values. + reduced = torch.tensor([2.0, 4.0, 6.0]) + def fake_all_reduce(tensor, **kwargs): + tensor.copy_(reduced) + + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce', side_effect=fake_all_reduce): + alm.update(self.constraints) + + self.assertTrue(torch.allclose(alm.duals, 0.1 * reduced)) + + def test_update_does_not_mutate_input(self): + # The all_reduce clone must be a detached copy; original tensor must be untouched. + original = self.constraints.clone() + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce', side_effect=lambda t, **kw: t.fill_(99.0)): + alm.update(self.constraints) + self.assertTrue(torch.allclose(self.constraints, original)) + + def test_forward_update_uses_reduced_constraints_for_dual(self): + reduced = torch.tensor([2.0, 4.0, 6.0]) + def fake_all_reduce(tensor, **kwargs): + tensor.copy_(reduced) + + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce', side_effect=fake_all_reduce): + alm.forward_update(self.loss, self.constraints) + + self.assertTrue(torch.allclose(alm.duals, 0.1 * reduced)) + + def test_forward_update_lagrangian_uses_original_constraints(self): + # Duals are updated with reduced constraints, but the Lagrangian must be + # computed with the original constraints so autograd flows through ∂c/∂θ. + reduced = torch.tensor([2.0, 4.0, 6.0]) + def fake_all_reduce(tensor, **kwargs): + tensor.copy_(reduced) + + alm = ALM(m=3, lr=0.1, penalty=1.0, process_group=self.pg) + with patch('torch.distributed.all_reduce', side_effect=fake_all_reduce): + lagrangian = alm.forward_update(self.loss, self.constraints) + + updated_duals = 0.1 * reduced + expected = ( + self.loss + + updated_duals @ self.constraints + + 0.5 * alm.penalty * torch.dot(self.constraints, self.constraints) + ) + self.assertTrue(torch.allclose(lagrangian, expected)) + + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/tests/test_ialm.py b/tests/test_ialm.py new file mode 100644 index 0000000..2a50d22 --- /dev/null +++ b/tests/test_ialm.py @@ -0,0 +1,95 @@ +import unittest +import torch +from humancompatible.train.dual_optim import iALM + + +class TestiALM(unittest.TestCase): + def setUp(self): + self.loss = torch.tensor(5.0) + self.constraints = torch.tensor([1.0, 2.0, 3.0]) + + def test_ialm_initialization(self): + alm = iALM(m=3, beta=1.0, penalty=1.0) + self.assertEqual(len(alm.duals), 3) + + def test_ialm_forward(self): + alm = iALM(m=3, beta=1.0, penalty=1.0) + lagrangian = alm.forward(self.loss, self.constraints) + beta = alm.param_groups[0]["beta"] + expected = (self.loss + + alm.duals @ self.constraints + + 0.5 * beta * torch.dot(self.constraints, self.constraints)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + def test_ialm_update(self): + # init_duals=zeros so the baseline is 0; duals should increase toward constraints + alm = iALM(m=3, beta=1.0, gamma=1.0, sigma=1.0, penalty=1.0, init_duals=torch.zeros(3)) + duals_before = alm.duals.clone() + alm.update(self.constraints) + self.assertTrue(torch.all(alm.duals > duals_before)) + + +class TestiALMFixes(unittest.TestCase): + """Tests for fix 1 (momentum buffer in forward) and fix 2 (multi-group slicing).""" + + def setUp(self): + self.loss = torch.tensor(5.0) + + # --- Fix 1: forward() must not advance the momentum buffer --- + + def test_forward_does_not_corrupt_momentum_buffer(self): + # Calling forward() then update() must give the same duals as update() alone. + c = torch.tensor([1.0, 2.0, 3.0]) + alm_direct = iALM(m=3, beta=1.0, gamma=1e6, sigma=1.0, momentum=0.9) + alm_via_forward = iALM(m=3, beta=1.0, gamma=1e6, sigma=1.0, momentum=0.9) + + alm_direct.update(c) + + alm_via_forward.forward(self.loss, c) + alm_via_forward.update(c) + + self.assertTrue(torch.allclose(alm_direct.duals, alm_via_forward.duals)) + + def test_forward_update_and_separate_forward_update_agree(self): + c = torch.tensor([1.0, 2.0, 3.0]) + alm_combined = iALM(m=3, beta=1.0, gamma=1e6, sigma=1.0, momentum=0.9) + alm_separate = iALM(m=3, beta=1.0, gamma=1e6, sigma=1.0, momentum=0.9) + + alm_combined.forward_update(self.loss, c) + alm_separate.forward(self.loss, c) + alm_separate.update(c) + + self.assertTrue(torch.allclose(alm_combined.duals, alm_separate.duals)) + + # --- Fix 2: multi-group constraint slicing --- + + def test_multi_group_update_slices_correctly(self): + alm = iALM(m=2, beta=1.0, gamma=1e6, sigma=1.0, init_duals=torch.zeros(2)) + alm.add_constraint_group(m=3, beta=1.0, gamma=1e6, sigma=1.0, init_duals=torch.zeros(3)) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + alm.update(c) + + # With large gamma step ≈ 1.0 for both groups, so duals ≈ c_slice + self.assertTrue(torch.allclose(alm.param_groups[0]["params"][0], c[:2], atol=1e-4)) + self.assertTrue(torch.allclose(alm.param_groups[1]["params"][0], c[2:], atol=1e-4)) + + def test_multi_group_forward_lagrangian_correct(self): + init0 = torch.tensor([1.0, 1.0]) + init1 = torch.tensor([1.0, 1.0, 1.0]) + alm = iALM(m=2, beta=1.0, gamma=1.0, sigma=1.0, init_duals=init0) + alm.add_constraint_group(m=3, beta=1.0, gamma=1.0, sigma=1.0, init_duals=init1) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + lagrangian = alm.forward(self.loss, c) + + beta = alm.param_groups[0]["beta"] + expected = (self.loss + + init0 @ c[:2] + + init1 @ c[2:] + + 0.5 * beta * torch.dot(c, c)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_nupi.py b/tests/test_nupi.py new file mode 100644 index 0000000..c70a17b --- /dev/null +++ b/tests/test_nupi.py @@ -0,0 +1,104 @@ +import unittest +import torch +from humancompatible.train.dual_optim import nuPI + + +class TestnuPI(unittest.TestCase): + def setUp(self): + self.loss = torch.tensor(5.0) + self.constraints = torch.tensor([1.0, 2.0, 3.0]) + + def test_nupi_initialization(self): + opt = nuPI(m=3, nu=0.9, ki=0.01, kp=0.01, penalty=1.0) + self.assertEqual(len(opt.duals), 3) + + def test_nupi_forward(self): + opt = nuPI(m=3, nu=0.9, ki=0.01, kp=0.01, penalty=1.0) + lagrangian = opt.forward(self.loss, self.constraints) + expected = (self.loss + + opt.duals @ self.constraints + + 0.5 * opt.penalty * torch.dot(self.constraints, self.constraints)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + def test_nupi_update(self): + # With zero buffer (initial state) and kp=0, update is purely integral: λ += ki * c + opt = nuPI(m=3, nu=0.9, ki=0.1, kp=0.0, penalty=1.0) + opt.update(self.constraints) + self.assertTrue(torch.allclose(opt.duals, 0.1 * self.constraints)) + + +class TestnuPIFixes(unittest.TestCase): + """Tests for fix 1 (buffer in forward) and fix 2 (multi-group slicing).""" + + def setUp(self): + self.loss = torch.tensor(5.0) + + # --- Fix 1: forward() must not advance the EMA buffer --- + # nuPI's buffer is unconditionally updated in the original code, making this + # the most severe instance of the bug: it fires even without any momentum setting. + + def test_forward_does_not_corrupt_ema_buffer(self): + # Calling forward() then update() must give the same duals as update() alone. + c = torch.tensor([1.0, 2.0, 3.0]) + opt_direct = nuPI(m=3, nu=0.9, ki=0.01, kp=0.05, penalty=1.0) + opt_via_forward = nuPI(m=3, nu=0.9, ki=0.01, kp=0.05, penalty=1.0) + + opt_direct.update(c) + + opt_via_forward.forward(self.loss, c) + opt_via_forward.update(c) + + self.assertTrue(torch.allclose(opt_direct.duals, opt_via_forward.duals)) + + def test_forward_update_and_separate_forward_update_agree(self): + c = torch.tensor([1.0, 2.0, 3.0]) + opt_combined = nuPI(m=3, nu=0.9, ki=0.01, kp=0.05, penalty=1.0) + opt_separate = nuPI(m=3, nu=0.9, ki=0.01, kp=0.05, penalty=1.0) + + opt_combined.forward_update(self.loss, c) + opt_separate.forward(self.loss, c) + opt_separate.update(c) + + self.assertTrue(torch.allclose(opt_combined.duals, opt_separate.duals)) + + # --- Fix 2: multi-group constraint slicing --- + + def test_multi_group_update_slices_correctly(self): + # kp=0 so update is purely λ += ki * c, easy to verify + opt = nuPI(m=2, nu=0.9, ki=0.1, kp=0.0, penalty=1.0) + opt.add_constraint_group(m=3, nu=0.9, ki=0.2, kp=0.0) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + opt.update(c) + + self.assertTrue(torch.allclose(opt.param_groups[0]["params"][0], 0.1 * c[:2])) + self.assertTrue(torch.allclose(opt.param_groups[1]["params"][0], 0.2 * c[2:])) + + def test_multi_group_forward_lagrangian_correct(self): + init0 = torch.tensor([1.0, 1.0]) + init1 = torch.tensor([1.0, 1.0, 1.0]) + opt = nuPI(m=2, nu=0.9, ki=0.1, kp=0.0, penalty=1.0, init_duals=init0) + opt.add_constraint_group(m=3, nu=0.9, ki=0.2, kp=0.0, init_duals=init1) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + lagrangian = opt.forward(self.loss, c) + + expected = (self.loss + + init0 @ c[:2] + + init1 @ c[2:] + + 0.5 * opt.penalty * torch.dot(c, c)) + self.assertTrue(torch.allclose(lagrangian, expected)) + + def test_multi_group_forward_update_slices_correctly(self): + opt = nuPI(m=2, nu=0.9, ki=0.1, kp=0.0, penalty=1.0) + opt.add_constraint_group(m=3, nu=0.9, ki=0.2, kp=0.0) + + c = torch.tensor([1.0, 2.0, 10.0, 20.0, 30.0]) + opt.forward_update(self.loss, c) + + self.assertTrue(torch.allclose(opt.param_groups[0]["params"][0], 0.1 * c[:2])) + self.assertTrue(torch.allclose(opt.param_groups[1]["params"][0], 0.2 * c[2:])) + + +if __name__ == "__main__": + unittest.main()