diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index ee54e32..3d1e821 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -9,7 +9,8 @@ use std::ops::{Add, Div, Mul, Neg, Sub}; use std::sync::Arc; -use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant, ones}; +use super::broadcast::broadcast_exprs; +use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant}; // ============================================================================ // Operator overloading for Expr @@ -192,60 +193,6 @@ fn mul_exprs(lhs: Expr, rhs: Expr) -> Expr { Expr::Mul(Arc::new(lhs), Arc::new(rhs)) } -pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) { - let lhs_shape = lhs.shape(); - let rhs_shape = rhs.shape(); - let target_shape = lhs_shape - .broadcast(&rhs_shape) - .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", lhs_shape, rhs_shape)); - - if lhs_shape == target_shape && rhs_shape == target_shape { - return (lhs, rhs); - } - if lhs_shape.rows() == rhs_shape.rows() && lhs_shape.cols() == rhs_shape.cols() { - return (lhs, rhs); - } - - let lhs = broadcast_to(lhs, &lhs_shape, &target_shape) - .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", lhs_shape, target_shape)); - let rhs = broadcast_to(rhs, &rhs_shape, &target_shape) - .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", rhs_shape, target_shape)); - - (lhs, rhs) -} - -fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { - if expr_shape == target_shape { - return Some(expr); - } - - if expr_shape.is_scalar_like() { - return Some(promote(&expr, target_shape.clone())); - } - - broadcast_2d_to(expr, expr_shape, target_shape) -} - -fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { - if !expr_shape.is_matrix() || !target_shape.is_matrix() { - return None; - } - - if expr_shape.rows() == 1 && target_shape.rows() > 1 && expr_shape.cols() == target_shape.cols() - { - let left = ones((target_shape.rows(), 1)); - return Some(matmul(&left, &expr)); - } - - if expr_shape.cols() == 1 && target_shape.cols() > 1 && expr_shape.rows() == target_shape.rows() - { - let right = ones((1, target_shape.cols())); - return Some(matmul(&expr, &right)); - } - - None -} - // ============================================================================ // Affine atom functions // ============================================================================ @@ -257,6 +204,13 @@ pub fn sum(expr: &Expr) -> Expr { /// Sum along a specific axis. pub fn sum_axis(expr: &Expr, axis: usize) -> Expr { + let shape = expr.shape(); + assert!( + axis < shape.ndim().max(1), + "axis {} out of bounds for shape {}", + axis, + shape + ); Expr::Sum(Arc::new(expr.clone()), Some(axis)) } @@ -279,7 +233,15 @@ pub fn promote(expr: &Expr, shape: impl Into) -> Expr { /// Reshape an expression to a new shape. pub fn reshape(expr: &Expr, shape: impl Into) -> Expr { - Expr::Reshape(Arc::new(expr.clone()), shape.into()) + let shape = shape.into(); + assert_eq!( + expr.shape().size(), + shape.size(), + "cannot reshape size {} into shape {}", + expr.shape().size(), + shape + ); + Expr::Reshape(Arc::new(expr.clone()), shape) } /// Flatten an expression to a vector. @@ -300,11 +262,31 @@ pub fn trace(expr: &Expr) -> Expr { /// Vertical stack (row-wise concatenation). pub fn vstack(exprs: Vec) -> Expr { + if let Some(first) = exprs.first() { + let cols = first.shape().cols(); + for expr in &exprs[1..] { + assert_eq!( + expr.shape().cols(), + cols, + "vstack requires matching column counts" + ); + } + } Expr::VStack(exprs.into_iter().map(Arc::new).collect()) } /// Horizontal stack (column-wise concatenation). pub fn hstack(exprs: Vec) -> Expr { + if let Some(first) = exprs.first() { + let rows = first.shape().rows(); + for expr in &exprs[1..] { + assert_eq!( + expr.shape().rows(), + rows, + "hstack requires matching row counts" + ); + } + } Expr::HStack(exprs.into_iter().map(Arc::new).collect()) } @@ -422,7 +404,7 @@ pub fn cumsum(expr: &Expr) -> Expr { /// Diagonal matrix from vector, or diagonal of matrix. /// /// - Vector input: Creates diagonal matrix with vector on diagonal -/// - Matrix input: Extracts diagonal as vector (v1.0: returns input as fallback) +/// - Matrix input: Extracts diagonal as vector pub fn diag(expr: &Expr) -> Expr { Expr::Diag(Arc::new(expr.clone())) } @@ -472,6 +454,20 @@ mod tests { assert_eq!(s.shape(), Shape::scalar()); } + #[test] + fn test_sum_axis_vector_shape() { + let x = variable(3); + let s = sum_axis(&x, 0); + assert_eq!(s.shape(), Shape::scalar()); + } + + #[test] + #[should_panic(expected = "axis 1 out of bounds for shape (3,)")] + fn test_sum_axis_invalid_axis_panics() { + let x = variable(3); + let _ = sum_axis(&x, 1); + } + #[test] fn test_promote_shape_and_metadata() { let x = nonneg_variable(()); @@ -498,6 +494,13 @@ mod tests { assert_eq!(promote(&x, 3).shape(), Shape::vector(3)); } + #[test] + #[should_panic(expected = "cannot reshape size 3 into shape (2, 2)")] + fn test_reshape_size_mismatch_panics() { + let x = variable(3); + let _ = reshape(&x, (2, 2)); + } + #[test] #[should_panic(expected = "only scalar expressions can be promoted")] fn test_promote_rejects_non_scalar_like_shape() { @@ -657,6 +660,22 @@ mod tests { assert_eq!(z.shape(), Shape::matrix(5, 3)); } + #[test] + #[should_panic(expected = "vstack requires matching column counts")] + fn test_vstack_mismatched_columns_panics() { + let x = variable((2, 3)); + let y = variable((3, 2)); + let _ = vstack(vec![x, y]); + } + + #[test] + #[should_panic(expected = "hstack requires matching row counts")] + fn test_hstack_mismatched_rows_panics() { + let x = variable((2, 3)); + let y = variable((3, 3)); + let _ = hstack(vec![x, y]); + } + #[test] fn test_affine_is_affine() { let x = variable(5); diff --git a/src/atoms/broadcast.rs b/src/atoms/broadcast.rs new file mode 100644 index 0000000..bbbf5f2 --- /dev/null +++ b/src/atoms/broadcast.rs @@ -0,0 +1,71 @@ +//! Internal broadcasting helpers for atom construction and canonicalization. + +use crate::atoms::affine::{matmul, promote}; +use crate::expr::{Expr, Shape, ones}; + +pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) { + let mut exprs = broadcast_elementwise_exprs([lhs, rhs]).1; + let rhs = exprs.pop().expect("binary broadcast should return rhs"); + let lhs = exprs.pop().expect("binary broadcast should return lhs"); + (lhs, rhs) +} + +pub(crate) fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { + if expr_shape == target_shape + || (expr_shape.rows() == target_shape.rows() && expr_shape.cols() == target_shape.cols()) + { + return Some(expr); + } + + if expr_shape.is_scalar_like() { + return Some(promote(&expr, target_shape.clone())); + } + + broadcast_2d_to(expr, expr_shape, target_shape) +} + +pub(crate) fn broadcast_elementwise_exprs( + exprs: impl IntoIterator, +) -> (Shape, Vec) { + let exprs: Vec = exprs.into_iter().collect(); + let target_shape = exprs + .iter() + .map(|expr| expr.shape()) + .reduce(|acc, shape| { + acc.broadcast(&shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) + }) + .expect("elementwise atom requires at least one expression"); + + let exprs = exprs + .into_iter() + .map(|expr| { + let expr_shape = expr.shape(); + broadcast_to(expr, &expr_shape, &target_shape).unwrap_or_else(|| { + panic!("cannot broadcast shape {} to {}", expr_shape, target_shape) + }) + }) + .collect(); + + (target_shape, exprs) +} + +fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { + if !expr_shape.is_matrix() || !target_shape.is_matrix() { + return None; + } + + if expr_shape.rows() == 1 && target_shape.rows() > 1 && expr_shape.cols() == target_shape.cols() + { + let left = ones((target_shape.rows(), 1)); + return Some(matmul(&left, &expr)); + } + + if expr_shape.cols() == 1 && target_shape.cols() > 1 && expr_shape.rows() == target_shape.rows() + { + let right = ones((1, target_shape.cols())); + return Some(matmul(&expr, &right)); + } + + None +} diff --git a/src/atoms/mod.rs b/src/atoms/mod.rs index 4f401a1..f5c7c91 100644 --- a/src/atoms/mod.rs +++ b/src/atoms/mod.rs @@ -6,6 +6,7 @@ //! - **Nonlinear atoms**: Operations with specific curvature (norms, quadratic forms, etc.) pub mod affine; +pub(crate) mod broadcast; pub mod nonlinear; // Re-export affine operations diff --git a/src/atoms/nonlinear.rs b/src/atoms/nonlinear.rs index 6d8affa..a3e206e 100644 --- a/src/atoms/nonlinear.rs +++ b/src/atoms/nonlinear.rs @@ -5,6 +5,7 @@ use std::sync::Arc; +use crate::atoms::broadcast::broadcast_elementwise_exprs; use crate::expr::Expr; // ============================================================================ @@ -136,9 +137,14 @@ pub fn neg_part(x: &Expr) -> Expr { /// - Sign: Depends on arguments /// - Monotonicity: Increasing in all arguments pub fn maximum(exprs: Vec) -> Expr { + if exprs.is_empty() { + panic!("maximum requires at least one expression"); + } if exprs.len() == 1 { return exprs.into_iter().next().unwrap(); } + + let (_, exprs) = broadcast_elementwise_exprs(exprs); Expr::Maximum(exprs.into_iter().map(Arc::new).collect()) } @@ -154,9 +160,14 @@ pub fn max2(a: &Expr, b: &Expr) -> Expr { /// - Sign: Depends on arguments /// - Monotonicity: Increasing in all arguments pub fn minimum(exprs: Vec) -> Expr { + if exprs.is_empty() { + panic!("minimum requires at least one expression"); + } if exprs.len() == 1 { return exprs.into_iter().next().unwrap(); } + + let (_, exprs) = broadcast_elementwise_exprs(exprs); Expr::Minimum(exprs.into_iter().map(Arc::new).collect()) } @@ -287,6 +298,56 @@ mod tests { assert_eq!(m.curvature(), Curvature::Convex); } + #[test] + #[should_panic(expected = "maximum requires at least one expression")] + fn test_maximum_empty_panics() { + let _ = maximum(Vec::new()); + } + + #[test] + fn test_maximum_single_expression_is_identity() { + let x = variable(5); + let m = maximum(vec![x.clone()]); + assert_eq!(m.shape(), x.shape()); + assert_eq!(m.variables(), x.variables()); + } + + #[test] + fn test_maximum_broadcasts_scalar_to_vector() { + let x = variable(5); + let y = variable(()); + let m = maximum(vec![x, y]); + assert_eq!(m.shape(), crate::expr::Shape::vector(5)); + } + + #[test] + fn test_maximum_three_args_stays_flat() { + let x = variable(5); + let y = variable(5); + let z = variable(()); + let m = maximum(vec![x, y, z]); + + match m { + Expr::Maximum(args) => { + assert_eq!(args.len(), 3); + assert!( + args.iter() + .all(|arg| arg.shape() == crate::expr::Shape::vector(5)) + ); + assert!(!matches!(&*args[0], Expr::Maximum(_))); + } + _ => panic!("expected maximum expression"), + } + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2,) and (3,)")] + fn test_maximum_incompatible_shapes_panic() { + let x = variable(2); + let y = variable(3); + let _ = maximum(vec![x, y]); + } + #[test] fn test_minimum_concave() { let x = variable(5); @@ -295,6 +356,56 @@ mod tests { assert_eq!(m.curvature(), Curvature::Concave); } + #[test] + #[should_panic(expected = "minimum requires at least one expression")] + fn test_minimum_empty_panics() { + let _ = minimum(Vec::new()); + } + + #[test] + fn test_minimum_single_expression_is_identity() { + let x = variable(5); + let m = minimum(vec![x.clone()]); + assert_eq!(m.shape(), x.shape()); + assert_eq!(m.variables(), x.variables()); + } + + #[test] + fn test_minimum_broadcasts_scalar_to_vector() { + let x = variable(5); + let y = variable(()); + let m = minimum(vec![x, y]); + assert_eq!(m.shape(), crate::expr::Shape::vector(5)); + } + + #[test] + fn test_minimum_three_args_stays_flat() { + let x = variable(5); + let y = variable(5); + let z = variable(()); + let m = minimum(vec![x, y, z]); + + match m { + Expr::Minimum(args) => { + assert_eq!(args.len(), 3); + assert!( + args.iter() + .all(|arg| arg.shape() == crate::expr::Shape::vector(5)) + ); + assert!(!matches!(&*args[0], Expr::Minimum(_))); + } + _ => panic!("expected minimum expression"), + } + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2,) and (3,)")] + fn test_minimum_incompatible_shapes_panic() { + let x = variable(2); + let y = variable(3); + let _ = minimum(vec![x, y]); + } + #[test] fn test_sum_squares_convex() { let x = variable(5); diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index f8a6c50..645e3f4 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -11,6 +11,7 @@ use nalgebra::DMatrix; use nalgebra_sparse::CscMatrix; use super::lin_expr::{LinExpr, QuadExpr}; +use crate::atoms::broadcast::broadcast_elementwise_exprs; use crate::expr::{Array, Expr, ExprId, IndexSpec, Shape, VariableBuilder}; use crate::sparse::{csc_add, csc_repeat_rows, csc_to_dense, csc_vstack, dense_to_csc}; @@ -512,9 +513,10 @@ impl CanonContext { }) } - fn canonicalize_sum_axis_lin(&self, x: &LinExpr, axis: usize) -> CanonExpr { + fn canonicalize_sum_axis_lin(&mut self, x: &LinExpr, axis: usize) -> CanonExpr { if x.shape.ndim() <= 1 { - return CanonExpr::Linear(x.clone()); + assert_eq!(axis, 0, "axis {} out of bounds for shape {}", axis, x.shape); + return self.canonicalize_sum_lin(x); } let rows = x.shape.rows(); @@ -522,7 +524,7 @@ impl CanonContext { let (out_size, mut s_rows, mut s_cols, mut s_vals) = match axis { 0 => (cols, Vec::new(), Vec::new(), Vec::new()), 1 => (rows, Vec::new(), Vec::new(), Vec::new()), - _ => return CanonExpr::Linear(x.clone()), + _ => panic!("axis {} out of bounds for shape {}", axis, x.shape), }; for col in 0..cols { @@ -558,6 +560,13 @@ impl CanonContext { fn canonicalize_reshape(&mut self, a: &Expr, shape: &Shape) -> CanonExpr { let ca = self.canonicalize_expr(a, false).as_linear().clone(); + assert_eq!( + ca.shape.size(), + shape.size(), + "cannot reshape size {} into shape {}", + ca.shape.size(), + shape + ); // Reshape doesn't change the linear structure, just the shape interpretation CanonExpr::Linear(LinExpr { coeffs: ca.coeffs, @@ -650,6 +659,11 @@ impl CanonContext { } fn vstack_lin(&self, a: &LinExpr, b: &LinExpr) -> LinExpr { + assert_eq!( + a.shape.cols(), + b.shape.cols(), + "vstack requires matching column counts" + ); // Stack constants vertically let new_const = stack_vertical(&a.constant, &b.constant); let new_shape = Shape::matrix(new_const.nrows(), new_const.ncols()); @@ -663,14 +677,14 @@ impl CanonContext { let ca = a.coeffs.get(&var_id); let cb = b.coeffs.get(&var_id); let stacked = match (ca, cb) { - (Some(ma), Some(mb)) => stack_csc_vertical(ma, mb), + (Some(ma), Some(mb)) => stack_csc_vertical_for_shapes(ma, mb, &a.shape, &b.shape), (Some(ma), None) => { let zeros = CscMatrix::zeros(b.size(), ma.ncols()); - stack_csc_vertical(ma, &zeros) + stack_csc_vertical_for_shapes(ma, &zeros, &a.shape, &b.shape) } (None, Some(mb)) => { let zeros = CscMatrix::zeros(a.size(), mb.ncols()); - stack_csc_vertical(&zeros, mb) + stack_csc_vertical_for_shapes(&zeros, mb, &a.shape, &b.shape) } (None, None) => continue, }; @@ -698,6 +712,11 @@ impl CanonContext { } fn hstack_lin(&self, a: &LinExpr, b: &LinExpr) -> LinExpr { + assert_eq!( + a.shape.rows(), + b.shape.rows(), + "hstack requires matching row counts" + ); // Stack constants horizontally let new_const = stack_horizontal(&a.constant, &b.constant); let new_shape = Shape::matrix(new_const.nrows(), new_const.ncols()); @@ -828,6 +847,7 @@ impl CanonContext { // Introduce t_i >= 0, -t_i <= x_i <= t_i // Then ||x||_1 = sum(t_i) let cx = self.canonicalize_expr(x, false).as_linear().clone(); + let cx = self.flatten_lin(&cx); let size = cx.size(); let (_, t) = self.new_nonneg_aux_var(Shape::vector(size)); @@ -861,6 +881,7 @@ impl CanonContext { // ||x||_inf = max(|x_i|) // Introduce t >= 0, -t <= x_i <= t for all i let cx = self.canonicalize_expr(x, false).as_linear().clone(); + let cx = self.flatten_lin(&cx); let size = cx.size(); let (_, t) = self.new_nonneg_aux_var(Shape::scalar()); @@ -918,13 +939,41 @@ impl CanonContext { fn canonicalize_maximum(&mut self, exprs: &[Arc]) -> CanonExpr { // max(x1, ..., xn): Introduce t, t >= x_i for all i if exprs.is_empty() { - return CanonExpr::Linear(LinExpr::zeros(Shape::scalar())); + panic!("maximum requires at least one expression"); } - - let shape = exprs[0].shape(); + let first_shape = exprs[0].shape(); + let all_same_shape = exprs.iter().all(|expr| expr.shape() == first_shape); + let shape = if all_same_shape { + first_shape + } else { + exprs + .iter() + .map(|expr| expr.shape()) + .reduce(|acc, shape| { + acc.broadcast(&shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) + }) + .expect("maximum requires at least one expression") + }; let (_, t) = self.new_aux_var(shape); - for e in exprs { + if all_same_shape { + for e in exprs { + let ce = self + .canonicalize_expr(e.as_ref(), false) + .as_linear() + .clone(); + // t >= x_i, i.e., t - x_i >= 0 + self.constraints.push(ConeConstraint::NonNeg { + a: t.add(&ce.neg()), + }); + } + return CanonExpr::Linear(t); + } + + let (_, exprs) = + broadcast_elementwise_exprs(exprs.iter().map(|expr| expr.as_ref().clone())); + for e in &exprs { let ce = self.canonicalize_expr(e, false).as_linear().clone(); // t >= x_i, i.e., t - x_i >= 0 self.constraints.push(ConeConstraint::NonNeg { @@ -938,13 +987,41 @@ impl CanonContext { fn canonicalize_minimum(&mut self, exprs: &[Arc]) -> CanonExpr { // min(x1, ..., xn): Introduce t, t <= x_i for all i if exprs.is_empty() { - return CanonExpr::Linear(LinExpr::zeros(Shape::scalar())); + panic!("minimum requires at least one expression"); } - - let shape = exprs[0].shape(); + let first_shape = exprs[0].shape(); + let all_same_shape = exprs.iter().all(|expr| expr.shape() == first_shape); + let shape = if all_same_shape { + first_shape + } else { + exprs + .iter() + .map(|expr| expr.shape()) + .reduce(|acc, shape| { + acc.broadcast(&shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) + }) + .expect("minimum requires at least one expression") + }; let (_, t) = self.new_aux_var(shape); - for e in exprs { + if all_same_shape { + for e in exprs { + let ce = self + .canonicalize_expr(e.as_ref(), false) + .as_linear() + .clone(); + // t <= x_i, i.e., x_i - t >= 0 + self.constraints.push(ConeConstraint::NonNeg { + a: ce.add(&t.neg()), + }); + } + return CanonExpr::Linear(t); + } + + let (_, exprs) = + broadcast_elementwise_exprs(exprs.iter().map(|expr| expr.as_ref().clone())); + for e in &exprs { let ce = self.canonicalize_expr(e, false).as_linear().clone(); // t <= x_i, i.e., x_i - t >= 0 self.constraints.push(ConeConstraint::NonNeg { @@ -983,6 +1060,7 @@ impl CanonContext { fn canonicalize_sum_squares(&mut self, x: &Expr, for_objective: bool) -> CanonExpr { // ||x||_2^2 = x' x let cx = self.canonicalize_expr(x, false).as_linear().clone(); + let cx = self.flatten_lin(&cx); self.canonicalize_sum_squares_lin(&cx, for_objective) } @@ -1040,6 +1118,7 @@ impl CanonContext { fn canonicalize_quad_over_lin(&mut self, x: &Expr, y: &Expr) -> CanonExpr { // ||x||_2^2 / y: introduce t with ||[2x; t-y]||_2 <= t+y. let cx = self.canonicalize_expr(x, false).as_linear().clone(); + let cx = self.flatten_lin(&cx); let cy = self.canonicalize_expr(y, false).as_linear().clone(); let (_, t) = self.new_nonneg_aux_var(Shape::scalar()); @@ -1079,6 +1158,18 @@ impl CanonContext { }) } + fn flatten_lin(&self, x: &LinExpr) -> LinExpr { + let size = x.size(); + LinExpr { + coeffs: x.coeffs.clone(), + constant: x + .constant + .clone() + .reshape_generic(nalgebra::Dyn(size), nalgebra::Dyn(1)), + shape: Shape::vector(size), + } + } + fn expand_scalar(&self, scalar: &LinExpr, size: usize) -> LinExpr { // Expand a scalar to a vector by repeating let ones = DMatrix::from_element(size, 1, 1.0); @@ -1428,7 +1519,12 @@ fn dense_sparse_matmul(dense: &DMatrix, sparse: &CscMatrix) -> CscMatr } fn stack_vertical(a: &DMatrix, b: &DMatrix) -> DMatrix { - let mut result = DMatrix::zeros(a.nrows() + b.nrows(), a.ncols().max(b.ncols())); + assert_eq!( + a.ncols(), + b.ncols(), + "vstack requires matching column counts" + ); + let mut result = DMatrix::zeros(a.nrows() + b.nrows(), a.ncols()); result.view_mut((0, 0), (a.nrows(), a.ncols())).copy_from(a); result .view_mut((a.nrows(), 0), (b.nrows(), b.ncols())) @@ -1437,7 +1533,8 @@ fn stack_vertical(a: &DMatrix, b: &DMatrix) -> DMatrix { } fn stack_horizontal(a: &DMatrix, b: &DMatrix) -> DMatrix { - let mut result = DMatrix::zeros(a.nrows().max(b.nrows()), a.ncols() + b.ncols()); + assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); + let mut result = DMatrix::zeros(a.nrows(), a.ncols() + b.ncols()); result.view_mut((0, 0), (a.nrows(), a.ncols())).copy_from(a); result .view_mut((0, a.ncols()), (b.nrows(), b.ncols())) @@ -1445,8 +1542,52 @@ fn stack_horizontal(a: &DMatrix, b: &DMatrix) -> DMatrix { result } -fn stack_csc_vertical(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { - csc_vstack(a, b) +fn stack_csc_vertical_for_shapes( + a: &CscMatrix, + b: &CscMatrix, + a_shape: &Shape, + b_shape: &Shape, +) -> CscMatrix { + debug_assert_eq!( + a_shape.cols(), + b_shape.cols(), + "vstack requires matching column counts" + ); + debug_assert_eq!(a.ncols(), b.ncols()); + + let a_rows = a_shape.rows(); + let b_rows = b_shape.rows(); + let cols = a_shape.cols(); + let out_rows = a_rows + b_rows; + + let mut rows = Vec::new(); + let mut col_indices = Vec::new(); + let mut vals = Vec::new(); + + for (r, c, v) in a.triplet_iter() { + let matrix_col = r / a_rows; + let matrix_row = r % a_rows; + debug_assert!(matrix_col < cols); + rows.push(matrix_row + matrix_col * out_rows); + col_indices.push(c); + vals.push(*v); + } + for (r, c, v) in b.triplet_iter() { + let matrix_col = r / b_rows; + let matrix_row = r % b_rows; + debug_assert!(matrix_col < cols); + rows.push(a_rows + matrix_row + matrix_col * out_rows); + col_indices.push(c); + vals.push(*v); + } + + crate::sparse::triplets_to_csc( + a_shape.size() + b_shape.size(), + a.ncols(), + &rows, + &col_indices, + &vals, + ) } fn repeat_rows_csc(m: &CscMatrix, times: usize) -> CscMatrix { @@ -1464,7 +1605,7 @@ fn scalar_constant_value(expr: &Expr) -> Option { #[cfg(test)] mod tests { use super::*; - use crate::atoms::{matmul, promote}; + use crate::atoms::{matmul, promote, sum_axis}; use crate::expr::{constant, constant_matrix, variable}; #[test] @@ -1494,6 +1635,97 @@ mod tests { assert!(matches!(result.expr, CanonExpr::Quadratic(_)) || !result.constraints.is_empty()); } + #[test] + fn test_canonicalize_sum_axis_vector_returns_scalar() { + let x = variable(3); + let result = canonicalize(&sum_axis(&x, 0), false); + + assert_eq!(result.expr.as_linear().shape, Shape::scalar()); + } + + #[test] + #[should_panic(expected = "axis 2 out of bounds for shape (2, 3)")] + fn test_canonicalize_sum_axis_invalid_direct_expr_panics() { + let x = variable((2, 3)); + let _ = canonicalize(&Expr::Sum(Arc::new(x), Some(2)), false); + } + + #[test] + #[should_panic(expected = "cannot reshape size 3 into shape (2, 2)")] + fn test_canonicalize_reshape_size_mismatch_direct_expr_panics() { + let x = variable(3); + let expr = Expr::Reshape(Arc::new(x), Shape::matrix(2, 2)); + let _ = canonicalize(&expr, false); + } + + #[test] + #[should_panic(expected = "vstack requires matching column counts")] + fn test_canonicalize_vstack_mismatched_columns_direct_expr_panics() { + let x = variable((2, 3)); + let y = variable((3, 2)); + let expr = Expr::VStack(vec![Arc::new(x), Arc::new(y)]); + let _ = canonicalize(&expr, false); + } + + #[test] + #[should_panic(expected = "hstack requires matching row counts")] + fn test_canonicalize_hstack_mismatched_rows_direct_expr_panics() { + let x = variable((2, 3)); + let y = variable((3, 3)); + let expr = Expr::HStack(vec![Arc::new(x), Arc::new(y)]); + let _ = canonicalize(&expr, false); + } + + #[test] + #[should_panic(expected = "maximum requires at least one expression")] + fn test_canonicalize_maximum_empty_direct_expr_panics() { + let _ = canonicalize(&Expr::Maximum(Vec::new()), false); + } + + #[test] + fn test_canonicalize_maximum_direct_expr_broadcasts_scalar_to_vector() { + let x = variable(3); + let y = variable(()); + let expr = Expr::Maximum(vec![Arc::new(x), Arc::new(y)]); + let result = canonicalize(&expr, false); + + assert_eq!(result.expr.as_linear().shape, Shape::vector(3)); + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2,) and (3,)")] + fn test_canonicalize_maximum_direct_expr_incompatible_shapes_panics() { + let x = variable(2); + let y = variable(3); + let expr = Expr::Maximum(vec![Arc::new(x), Arc::new(y)]); + let _ = canonicalize(&expr, false); + } + + #[test] + #[should_panic(expected = "minimum requires at least one expression")] + fn test_canonicalize_minimum_empty_direct_expr_panics() { + let _ = canonicalize(&Expr::Minimum(Vec::new()), false); + } + + #[test] + fn test_canonicalize_minimum_direct_expr_broadcasts_scalar_to_vector() { + let x = variable(3); + let y = variable(()); + let expr = Expr::Minimum(vec![Arc::new(x), Arc::new(y)]); + let result = canonicalize(&expr, false); + + assert_eq!(result.expr.as_linear().shape, Shape::vector(3)); + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2,) and (3,)")] + fn test_canonicalize_minimum_direct_expr_incompatible_shapes_panics() { + let x = variable(2); + let y = variable(3); + let expr = Expr::Minimum(vec![Arc::new(x), Arc::new(y)]); + let _ = canonicalize(&expr, false); + } + #[test] fn test_canonicalize_matmul_preserves_vector_result_shape() { let a = constant_matrix(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 2, 3); diff --git a/src/canon/lin_expr.rs b/src/canon/lin_expr.rs index a9c21de..089f5ef 100644 --- a/src/canon/lin_expr.rs +++ b/src/canon/lin_expr.rs @@ -84,20 +84,54 @@ impl LinExpr { /// Add two linear expressions. pub fn add(&self, other: &LinExpr) -> LinExpr { + let new_shape = + if self.shape.rows() == other.shape.rows() && self.shape.cols() == other.shape.cols() { + if self.shape.is_vector() { + self.shape.clone() + } else if other.shape.is_vector() { + other.shape.clone() + } else { + self.shape.clone() + } + } else if self.shape.is_scalar_like() && !other.shape.is_scalar_like() { + other.shape.clone() + } else if other.shape.is_scalar_like() && !self.shape.is_scalar_like() { + self.shape.clone() + } else if self.shape.size() == other.shape.size() { + Shape::vector(self.shape.size()) + } else { + panic!( + "cannot add linear expressions with shapes {} and {}", + self.shape, other.shape + ) + }; + let new_size = new_shape.size(); + // Optimization: if self has no coefficients, just clone other's let coeffs = if self.coeffs.is_empty() { - other.coeffs.clone() + other + .coeffs + .iter() + .map(|(var_id, coeff)| (*var_id, coeff_for_size(coeff, other.size(), new_size))) + .collect() } else if other.coeffs.is_empty() { - self.coeffs.clone() + self.coeffs + .iter() + .map(|(var_id, coeff)| (*var_id, coeff_for_size(coeff, self.size(), new_size))) + .collect() } else { // Both have coefficients - clone the larger one and merge smaller into it - let mut coeffs = self.coeffs.clone(); + let mut coeffs = HashMap::new(); + for (var_id, coeff) in &self.coeffs { + coeffs.insert(*var_id, coeff_for_size(coeff, self.size(), new_size)); + } coeffs.reserve(other.coeffs.len()); for (var_id, coeff) in &other.coeffs { + let coeff = coeff_for_size(coeff, other.size(), new_size); coeffs .entry(*var_id) - .and_modify(|c| *c = csc_add(c, coeff)) - .or_insert_with(|| coeff.clone()); + .and_modify(|c| *c = csc_add(c, &coeff)) + .or_insert(coeff); } coeffs }; @@ -115,15 +149,21 @@ impl LinExpr { // Broadcast scalar self to match other's shape let scalar = self.constant[(0, 0)]; other.constant.map(|v| v + scalar) + } else if self.shape.size() == other.shape.size() { + let self_flat = self + .constant + .clone() + .reshape_generic(nalgebra::Dyn(new_size), nalgebra::Dyn(1)); + let other_flat = other + .constant + .clone() + .reshape_generic(nalgebra::Dyn(new_size), nalgebra::Dyn(1)); + self_flat + other_flat } else { - // Incompatible shapes, just use self (will likely error later) - self.constant.clone() - }; - - let new_shape = if self.shape.size() >= other.shape.size() { - self.shape.clone() - } else { - other.shape.clone() + panic!( + "cannot add linear expression constants with shapes {} and {}", + self.shape, other.shape + ); }; LinExpr { @@ -165,6 +205,35 @@ impl LinExpr { } } +fn coeff_for_size( + coeff: &CscMatrix, + source_size: usize, + target_size: usize, +) -> CscMatrix { + if coeff.nrows() == target_size { + return coeff.clone(); + } + if source_size == 1 { + let mut rows = Vec::new(); + let mut cols = Vec::new(); + let mut vals = Vec::new(); + for (row, col, val) in coeff.triplet_iter() { + debug_assert_eq!(row, 0); + for target_row in 0..target_size { + rows.push(target_row); + cols.push(col); + vals.push(*val); + } + } + return crate::sparse::triplets_to_csc(target_size, coeff.ncols(), &rows, &cols, &vals); + } + panic!( + "cannot resize coefficient rows from {} to {}", + coeff.nrows(), + target_size + ); +} + /// A quadratic expression: (1/2) x' P x + q' x + r /// /// Used for quadratic objectives in QP problems. @@ -286,6 +355,14 @@ mod tests { assert_eq!(sum.variables().len(), 2); } + #[test] + #[should_panic(expected = "cannot add linear expressions with shapes (2,) and (3,)")] + fn test_lin_expr_add_incompatible_shapes_panics() { + let e1 = LinExpr::variable(ExprId::new(), Shape::vector(2)); + let e2 = LinExpr::variable(ExprId::new(), Shape::vector(3)); + let _ = e1.add(&e2); + } + #[test] fn test_quad_expr_from_linear() { let var_id = ExprId::new(); diff --git a/src/constraints/constraint.rs b/src/constraints/constraint.rs index 1731a88..9b67220 100644 --- a/src/constraints/constraint.rs +++ b/src/constraints/constraint.rs @@ -62,7 +62,7 @@ impl Constraint { /// Broadcast scalars to match shapes if needed. fn broadcast_if_needed(lhs: Expr, rhs: Expr) -> (Expr, Expr) { - crate::atoms::affine::broadcast_exprs(lhs, rhs) + crate::atoms::broadcast::broadcast_exprs(lhs, rhs) } /// Create a SOC constraint: ||x||_2 <= t. diff --git a/src/expr/expression.rs b/src/expr/expression.rs index 32ae206..6c2dfa9 100644 --- a/src/expr/expression.rs +++ b/src/expr/expression.rs @@ -356,12 +356,11 @@ impl Expr { if axis.is_some() { // Sum along axis reduces that dimension let dims = a.shape(); - if dims.ndim() <= 1 { - Shape::scalar() - } else if *axis == Some(0) { - Shape::vector(dims.cols()) - } else { - Shape::vector(dims.rows()) + match (dims.ndim(), axis.unwrap()) { + (0 | 1, 0) => Shape::scalar(), + (2, 0) => Shape::vector(dims.cols()), + (2, 1) => Shape::vector(dims.rows()), + (_, axis) => panic!("axis {} out of bounds for shape {}", axis, dims), } } else { Shape::scalar() @@ -374,6 +373,13 @@ impl Expr { return Shape::scalar(); } let first = exprs[0].shape(); + for e in &exprs[1..] { + assert_eq!( + e.shape().cols(), + first.cols(), + "vstack requires matching column counts" + ); + } let total_rows: usize = exprs.iter().map(|e| e.shape().rows()).sum(); Shape::matrix(total_rows, first.cols()) } @@ -382,6 +388,13 @@ impl Expr { return Shape::scalar(); } let first = exprs[0].shape(); + for e in &exprs[1..] { + assert_eq!( + e.shape().rows(), + first.rows(), + "hstack requires matching row counts" + ); + } let total_cols: usize = exprs.iter().map(|e| e.shape().cols()).sum(); Shape::matrix(first.rows(), total_cols) } @@ -616,4 +629,60 @@ mod tests { ); assert_eq!(Expr::Sum(Arc::new(x), Some(1)).shape(), Shape::vector(2)); } + + #[test] + #[should_panic(expected = "axis 2 out of bounds for shape (2, 3)")] + fn test_sum_axis_invalid_direct_expr_shape_panics() { + let x = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(2, 3), + name: None, + nonneg: false, + nonpos: false, + }); + + let _ = Expr::Sum(Arc::new(x), Some(2)).shape(); + } + + #[test] + #[should_panic(expected = "vstack requires matching column counts")] + fn test_vstack_mismatched_columns_direct_expr_shape_panics() { + let x = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(2, 3), + name: None, + nonneg: false, + nonpos: false, + }); + let y = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(3, 2), + name: None, + nonneg: false, + nonpos: false, + }); + + let _ = Expr::VStack(vec![Arc::new(x), Arc::new(y)]).shape(); + } + + #[test] + #[should_panic(expected = "hstack requires matching row counts")] + fn test_hstack_mismatched_rows_direct_expr_shape_panics() { + let x = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(2, 3), + name: None, + nonneg: false, + nonpos: false, + }); + let y = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(3, 3), + name: None, + nonneg: false, + nonpos: false, + }); + + let _ = Expr::HStack(vec![Arc::new(x), Arc::new(y)]).shape(); + } } diff --git a/src/sparse.rs b/src/sparse.rs index 6a24c96..402af77 100644 --- a/src/sparse.rs +++ b/src/sparse.rs @@ -67,6 +67,11 @@ pub fn csc_to_dense(sparse: &CscMatrix) -> DMatrix { /// Stack two CSC matrices vertically. pub fn csc_vstack(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { + assert_eq!( + a.ncols(), + b.ncols(), + "vstack requires matching column counts" + ); let mut rows = Vec::new(); let mut cols = Vec::new(); let mut vals = Vec::new(); @@ -82,13 +87,7 @@ pub fn csc_vstack(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { vals.push(*v); } - csc_from_triplets( - a.nrows() + b.nrows(), - a.ncols().max(b.ncols()), - rows, - cols, - vals, - ) + csc_from_triplets(a.nrows() + b.nrows(), a.ncols(), rows, cols, vals) } /// Add two CSC matrices. @@ -103,7 +102,7 @@ pub fn csc_neg(a: &CscMatrix) -> CscMatrix { let col_offsets: Vec = a.col_offsets().to_vec(); let row_indices: Vec = a.row_indices().to_vec(); CscMatrix::try_from_csc_data(a.nrows(), a.ncols(), col_offsets, row_indices, values) - .unwrap_or_else(|_| CscMatrix::zeros(a.nrows(), a.ncols())) + .expect("valid CSC structure should remain valid when negating values") } /// Scale a CSC matrix. @@ -112,7 +111,7 @@ pub fn csc_scale(a: &CscMatrix, scalar: f64) -> CscMatrix { let col_offsets: Vec = a.col_offsets().to_vec(); let row_indices: Vec = a.row_indices().to_vec(); CscMatrix::try_from_csc_data(a.nrows(), a.ncols(), col_offsets, row_indices, values) - .unwrap_or_else(|_| CscMatrix::zeros(a.nrows(), a.ncols())) + .expect("valid CSC structure should remain valid when scaling values") } /// Multiply sparse matrix by dense matrix on the right: A_sparse @ B_dense @@ -126,6 +125,7 @@ pub fn sparse_dense_matmul(a: &CscMatrix, b: &DMatrix) -> CscMatrix, b: &CscMatrix) -> CscMatrix { + assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); let mut rows = Vec::new(); let mut cols = Vec::new(); let mut vals = Vec::new(); @@ -141,13 +141,7 @@ pub fn csc_hstack(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { vals.push(*v); } - csc_from_triplets( - a.nrows().max(b.nrows()), - a.ncols() + b.ncols(), - rows, - cols, - vals, - ) + csc_from_triplets(a.nrows(), a.ncols() + b.ncols(), rows, cols, vals) } /// Multiply two CSC matrices: A @ B @@ -220,4 +214,64 @@ mod tests { assert!((v - 2.0).abs() < 1e-10); } } + + #[test] + fn test_csc_neg_preserves_structure_and_negates_values() { + let a = csc_from_triplets(2, 2, vec![0, 1], vec![0, 1], vec![2.0, -3.0]); + let neg = csc_neg(&a); + let dense = csc_to_dense(&neg); + + assert_eq!(neg.nrows(), 2); + assert_eq!(neg.ncols(), 2); + assert_eq!(dense[(0, 0)], -2.0); + assert_eq!(dense[(1, 1)], 3.0); + } + + #[test] + fn test_csc_scale_preserves_structure_and_scales_values() { + let a = csc_from_triplets(2, 2, vec![0, 1], vec![0, 1], vec![2.0, -3.0]); + let scaled = csc_scale(&a, 4.0); + let dense = csc_to_dense(&scaled); + + assert_eq!(scaled.nrows(), 2); + assert_eq!(scaled.ncols(), 2); + assert_eq!(dense[(0, 0)], 8.0); + assert_eq!(dense[(1, 1)], -12.0); + } + + #[test] + fn test_csc_vstack_matching_columns() { + let a = CscMatrix::::identity(2); + let b = CscMatrix::::zeros(3, 2); + let stacked = csc_vstack(&a, &b); + + assert_eq!(stacked.nrows(), 5); + assert_eq!(stacked.ncols(), 2); + } + + #[test] + #[should_panic(expected = "vstack requires matching column counts")] + fn test_csc_vstack_mismatched_columns_panics() { + let a = CscMatrix::::zeros(2, 2); + let b = CscMatrix::::zeros(3, 3); + let _ = csc_vstack(&a, &b); + } + + #[test] + fn test_csc_hstack_matching_rows() { + let a = CscMatrix::::identity(2); + let b = CscMatrix::::identity(2); + let stacked = csc_hstack(&a, &b); + + assert_eq!(stacked.nrows(), 2); + assert_eq!(stacked.ncols(), 4); + } + + #[test] + #[should_panic(expected = "hstack requires matching row counts")] + fn test_csc_hstack_mismatched_rows_panics() { + let a = CscMatrix::::zeros(2, 2); + let b = CscMatrix::::zeros(3, 2); + let _ = csc_hstack(&a, &b); + } } diff --git a/tests/canon_tests.rs b/tests/canon_tests.rs index 4f11aa7..cfd4a89 100644 --- a/tests/canon_tests.rs +++ b/tests/canon_tests.rs @@ -314,6 +314,103 @@ fn test_constant_vector_times_scalar_affine_broadcasts() { } } +#[test] +fn test_norm1_matrix_variable_canonicalizes_flat_aux_constraints() { + let x = variable((2, 3)); + + let sol = Problem::minimize(norm1(&x)) + .subject_to([x.ge(1.0)]) + .solve() + .expect("matrix norm1 problem should solve"); + + assert!((sol.value.unwrap() - 6.0).abs() < TOL); +} + +#[test] +fn test_norm_inf_matrix_variable_canonicalizes_flat_aux_constraints() { + let x = variable((2, 3)); + + let sol = Problem::minimize(norm_inf(&x)) + .subject_to([sum(&x).eq(6.0)]) + .solve() + .expect("matrix norm_inf problem should solve"); + + assert!((sol.value.unwrap() - 1.0).abs() < TOL); +} + +#[test] +fn test_sum_squares_matrix_argument_flattens_soc_stack() { + let x = variable((2, 3)); + + let sol = Problem::maximize(sum(&x)) + .subject_to([sum_squares(&x).le(1.0)]) + .solve() + .expect("matrix sum_squares constraint should solve"); + + assert!((sol.value.unwrap() - 6.0_f64.sqrt()).abs() < TOL); +} + +#[test] +fn test_quad_over_lin_matrix_argument_flattens_soc_stack() { + let x = variable((2, 3)); + + let sol = Problem::maximize(sum(&x)) + .subject_to([quad_over_lin(&x, &constant(1.0)).le(1.0)]) + .solve() + .expect("matrix quad_over_lin constraint should solve"); + + assert!((sol.value.unwrap() - 6.0_f64.sqrt()).abs() < TOL); +} + +#[test] +fn test_maximum_minimum_accept_vector_and_column_shapes() { + let x = variable(3); + let y = variable((3, 1)); + + assert_eq!( + maximum(vec![x.clone(), y.clone()]).shape(), + Shape::vector(3) + ); + assert_eq!( + minimum(vec![x.clone(), y.clone()]).shape(), + Shape::vector(3) + ); + assert_eq!(max2(&x, &y).shape(), Shape::vector(3)); + assert_eq!(min2(&x, &y).shape(), Shape::vector(3)); +} + +#[test] +fn test_vstack_matrix_constraint_uses_column_major_interleaving() { + let x = variable((2, 2)); + let y = variable((1, 2)); + let target = constant_dmatrix(DMatrix::from_row_slice( + 3, + 2, + &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], + )); + + let sol = Problem::minimize(sum(&x)) + .subject_to([vstack(vec![x.clone(), y.clone()]).eq(target)]) + .solve() + .expect("matrix vstack equality should solve"); + + if let Array::Dense(x_vals) = x.value(&sol) { + assert!((x_vals[(0, 0)] - 1.0).abs() < TOL); + assert!((x_vals[(1, 0)] - 2.0).abs() < TOL); + assert!((x_vals[(0, 1)] - 4.0).abs() < TOL); + assert!((x_vals[(1, 1)] - 5.0).abs() < TOL); + } else { + panic!("expected dense matrix solution for x"); + } + + if let Array::Dense(y_vals) = y.value(&sol) { + assert!((y_vals[(0, 0)] - 3.0).abs() < TOL); + assert!((y_vals[(0, 1)] - 6.0).abs() < TOL); + } else { + panic!("expected dense matrix solution for y"); + } +} + fn solution_value(sol: &Solution, expr: &Expr) -> f64 { expr.value(sol).as_scalar().expect("expected scalar") }