diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index a6439a0..ee54e32 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -9,7 +9,7 @@ use std::ops::{Add, Div, Mul, Neg, Sub}; use std::sync::Arc; -use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant}; +use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant, ones}; // ============================================================================ // Operator overloading for Expr @@ -35,7 +35,7 @@ impl Add for Expr { type Output = Expr; fn add(self, rhs: Expr) -> Expr { - Expr::Add(Arc::new(self), Arc::new(rhs)) + add_exprs(self, rhs) } } @@ -43,7 +43,7 @@ impl Add for &Expr { type Output = Expr; fn add(self, rhs: &Expr) -> Expr { - Expr::Add(Arc::new(self.clone()), Arc::new(rhs.clone())) + add_exprs(self.clone(), rhs.clone()) } } @@ -51,7 +51,7 @@ impl Add<&Expr> for Expr { type Output = Expr; fn add(self, rhs: &Expr) -> Expr { - Expr::Add(Arc::new(self), Arc::new(rhs.clone())) + add_exprs(self, rhs.clone()) } } @@ -59,7 +59,7 @@ impl Add for &Expr { type Output = Expr; fn add(self, rhs: Expr) -> Expr { - Expr::Add(Arc::new(self.clone()), Arc::new(rhs)) + add_exprs(self.clone(), rhs) } } @@ -67,7 +67,7 @@ impl Sub for Expr { type Output = Expr; fn sub(self, rhs: Expr) -> Expr { - Expr::Add(Arc::new(self), Arc::new(Expr::Neg(Arc::new(rhs)))) + sub_exprs(self, rhs) } } @@ -75,10 +75,7 @@ impl Sub for &Expr { type Output = Expr; fn sub(self, rhs: &Expr) -> Expr { - Expr::Add( - Arc::new(self.clone()), - Arc::new(Expr::Neg(Arc::new(rhs.clone()))), - ) + sub_exprs(self.clone(), rhs.clone()) } } @@ -86,7 +83,7 @@ impl Sub<&Expr> for Expr { type Output = Expr; fn sub(self, rhs: &Expr) -> Expr { - Expr::Add(Arc::new(self), Arc::new(Expr::Neg(Arc::new(rhs.clone())))) + sub_exprs(self, rhs.clone()) } } @@ -94,7 +91,7 @@ impl Sub for &Expr { type Output = Expr; fn sub(self, rhs: Expr) -> Expr { - Expr::Add(Arc::new(self.clone()), Arc::new(Expr::Neg(Arc::new(rhs)))) + sub_exprs(self.clone(), rhs) } } @@ -102,7 +99,7 @@ impl Mul for Expr { type Output = Expr; fn mul(self, rhs: Expr) -> Expr { - Expr::Mul(Arc::new(self), Arc::new(rhs)) + mul_exprs(self, rhs) } } @@ -110,7 +107,7 @@ impl Mul for &Expr { type Output = Expr; fn mul(self, rhs: &Expr) -> Expr { - Expr::Mul(Arc::new(self.clone()), Arc::new(rhs.clone())) + mul_exprs(self.clone(), rhs.clone()) } } @@ -118,7 +115,7 @@ impl Mul<&Expr> for Expr { type Output = Expr; fn mul(self, rhs: &Expr) -> Expr { - Expr::Mul(Arc::new(self), Arc::new(rhs.clone())) + mul_exprs(self, rhs.clone()) } } @@ -126,7 +123,7 @@ impl Mul for &Expr { type Output = Expr; fn mul(self, rhs: Expr) -> Expr { - Expr::Mul(Arc::new(self.clone()), Arc::new(rhs)) + mul_exprs(self.clone(), rhs) } } @@ -135,7 +132,7 @@ impl Mul for Expr { type Output = Expr; fn mul(self, rhs: f64) -> Expr { - Expr::Mul(Arc::new(constant(rhs)), Arc::new(self)) + mul_exprs(constant(rhs), self) } } @@ -143,7 +140,7 @@ impl Mul for &Expr { type Output = Expr; fn mul(self, rhs: f64) -> Expr { - Expr::Mul(Arc::new(constant(rhs)), Arc::new(self.clone())) + mul_exprs(constant(rhs), self.clone()) } } @@ -151,7 +148,7 @@ impl Mul for f64 { type Output = Expr; fn mul(self, rhs: Expr) -> Expr { - Expr::Mul(Arc::new(constant(self)), Arc::new(rhs)) + mul_exprs(constant(self), rhs) } } @@ -159,7 +156,7 @@ impl Mul<&Expr> for f64 { type Output = Expr; fn mul(self, rhs: &Expr) -> Expr { - Expr::Mul(Arc::new(constant(self)), Arc::new(rhs.clone())) + mul_exprs(constant(self), rhs.clone()) } } @@ -168,7 +165,7 @@ impl Div for Expr { type Output = Expr; fn div(self, rhs: f64) -> Expr { - Expr::Mul(Arc::new(constant(1.0 / rhs)), Arc::new(self)) + mul_exprs(constant(1.0 / rhs), self) } } @@ -176,10 +173,79 @@ impl Div for &Expr { type Output = Expr; fn div(self, rhs: f64) -> Expr { - Expr::Mul(Arc::new(constant(1.0 / rhs)), Arc::new(self.clone())) + mul_exprs(constant(1.0 / rhs), self.clone()) } } +fn add_exprs(lhs: Expr, rhs: Expr) -> Expr { + let (lhs, rhs) = broadcast_exprs(lhs, rhs); + Expr::Add(Arc::new(lhs), Arc::new(rhs)) +} + +fn sub_exprs(lhs: Expr, rhs: Expr) -> Expr { + let (lhs, rhs) = broadcast_exprs(lhs, rhs); + Expr::Add(Arc::new(lhs), Arc::new(Expr::Neg(Arc::new(rhs)))) +} + +fn mul_exprs(lhs: Expr, rhs: Expr) -> Expr { + let (lhs, rhs) = broadcast_exprs(lhs, rhs); + Expr::Mul(Arc::new(lhs), Arc::new(rhs)) +} + +pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) { + let lhs_shape = lhs.shape(); + let rhs_shape = rhs.shape(); + let target_shape = lhs_shape + .broadcast(&rhs_shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", lhs_shape, rhs_shape)); + + if lhs_shape == target_shape && rhs_shape == target_shape { + return (lhs, rhs); + } + if lhs_shape.rows() == rhs_shape.rows() && lhs_shape.cols() == rhs_shape.cols() { + return (lhs, rhs); + } + + let lhs = broadcast_to(lhs, &lhs_shape, &target_shape) + .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", lhs_shape, target_shape)); + let rhs = broadcast_to(rhs, &rhs_shape, &target_shape) + .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", rhs_shape, target_shape)); + + (lhs, rhs) +} + +fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { + if expr_shape == target_shape { + return Some(expr); + } + + if expr_shape.is_scalar_like() { + return Some(promote(&expr, target_shape.clone())); + } + + broadcast_2d_to(expr, expr_shape, target_shape) +} + +fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { + if !expr_shape.is_matrix() || !target_shape.is_matrix() { + return None; + } + + if expr_shape.rows() == 1 && target_shape.rows() > 1 && expr_shape.cols() == target_shape.cols() + { + let left = ones((target_shape.rows(), 1)); + return Some(matmul(&left, &expr)); + } + + if expr_shape.cols() == 1 && target_shape.cols() > 1 && expr_shape.rows() == target_shape.rows() + { + let right = ones((1, target_shape.cols())); + return Some(matmul(&expr, &right)); + } + + None +} + // ============================================================================ // Affine atom functions // ============================================================================ @@ -194,6 +260,23 @@ pub fn sum_axis(expr: &Expr, axis: usize) -> Expr { Expr::Sum(Arc::new(expr.clone()), Some(axis)) } +/// Promote a scalar-like expression to a target shape. +pub fn promote(expr: &Expr, shape: impl Into) -> Expr { + let target_shape = shape.into(); + let expr_shape = expr.shape(); + + if expr_shape == target_shape { + return expr.clone(); + } + + assert!( + expr_shape.is_scalar_like(), + "only scalar expressions can be promoted" + ); + + Expr::Promote(Arc::new(expr.clone()), target_shape) +} + /// Reshape an expression to a new shape. pub fn reshape(expr: &Expr, shape: impl Into) -> Expr { Expr::Reshape(Arc::new(expr.clone()), shape.into()) @@ -347,7 +430,7 @@ pub fn diag(expr: &Expr) -> Expr { #[cfg(test)] mod tests { use super::*; - use crate::expr::{constant, variable}; + use crate::expr::{constant, constant_matrix, constant_vec, nonneg_variable, variable}; #[test] fn test_add() { @@ -389,6 +472,81 @@ mod tests { assert_eq!(s.shape(), Shape::scalar()); } + #[test] + fn test_promote_shape_and_metadata() { + let x = nonneg_variable(()); + let p = promote(&x, (2, 3)); + + assert_eq!(p.shape(), Shape::matrix(2, 3)); + assert_eq!(p.variables(), x.variables()); + assert!(p.curvature().is_affine()); + assert!(p.sign().is_nonneg()); + } + + #[test] + fn test_promote_accepts_scalar_like_shapes() { + let row_scalar = constant_vec(vec![2.0]); + let matrix_scalar = constant_matrix(vec![3.0], 1, 1); + + assert_eq!(promote(&row_scalar, (2, 2)).shape(), Shape::matrix(2, 2)); + assert_eq!(promote(&matrix_scalar, (2, 2)).shape(), Shape::matrix(2, 2)); + } + + #[test] + fn test_promote_same_shape_is_noop() { + let x = variable(3); + assert_eq!(promote(&x, 3).shape(), Shape::vector(3)); + } + + #[test] + #[should_panic(expected = "only scalar expressions can be promoted")] + fn test_promote_rejects_non_scalar_like_shape() { + let x = variable(2); + let _ = promote(&x, (2, 2)); + } + + #[test] + fn test_scalar_like_broadcast_shapes() { + let x = variable((2, 2)); + let scalar = variable(()); + let row_scalar = constant_vec(vec![1.0]); + let matrix_scalar = constant_matrix(vec![1.0], 1, 1); + + assert_eq!((&scalar + &x).shape(), Shape::matrix(2, 2)); + assert_eq!((&x - &row_scalar).shape(), Shape::matrix(2, 2)); + assert_eq!((&matrix_scalar * &x).shape(), Shape::matrix(2, 2)); + } + + #[test] + fn test_row_and_column_broadcast_shapes() { + let m = variable((2, 3)); + let row = variable((1, 3)); + let col = variable((2, 1)); + + assert_eq!((&row + &m).shape(), Shape::matrix(2, 3)); + assert_eq!((&m - &row).shape(), Shape::matrix(2, 3)); + assert_eq!((&col * &m).shape(), Shape::matrix(2, 3)); + assert_eq!((&m + &col).shape(), Shape::matrix(2, 3)); + } + + #[test] + fn test_mutual_row_and_column_broadcast_shape() { + let row = variable((1, 3)); + let col = variable((2, 1)); + + assert_eq!((&row + &col).shape(), Shape::matrix(2, 3)); + assert_eq!((&col * &row).shape(), Shape::matrix(2, 3)); + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2, 2) and (3, 3)")] + fn test_incompatible_broadcast_panics_at_construction() { + let c = constant_matrix(vec![1.0, 2.0, 3.0, 4.0], 2, 2); + let x = variable((3, 3)); + + let _ = c * x; + } + #[test] fn test_transpose() { let x = variable((3, 4)); @@ -404,6 +562,14 @@ mod tests { assert_eq!(b.shape(), Shape::vector(3)); } + #[test] + #[should_panic(expected = "cannot matrix-multiply shapes (3, 4) and (3,)")] + fn test_invalid_matmul_shape_panics() { + let a = variable((3, 4)); + let x = variable(3); + let _ = matmul(&a, &x).shape(); + } + #[test] fn test_index_and_slice_shapes() { let x = variable(10); diff --git a/src/atoms/mod.rs b/src/atoms/mod.rs index 4dac29c..4f401a1 100644 --- a/src/atoms/mod.rs +++ b/src/atoms/mod.rs @@ -10,8 +10,8 @@ pub mod nonlinear; // Re-export affine operations pub use affine::{ - cumsum, diag, dot, flatten, hstack, index, indexc, matmul, reshape, select, slice, slicec, sum, - sum_axis, trace, transpose, vstack, + cumsum, diag, dot, flatten, hstack, index, indexc, matmul, promote, reshape, select, slice, + slicec, sum, sum_axis, trace, transpose, vstack, }; // Re-export nonlinear atoms diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index eea81b0..f8a6c50 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -175,6 +175,7 @@ impl CanonContext { } } Expr::Mul(a, b) => self.canonicalize_mul(a, b, for_objective), + Expr::Promote(a, shape) => self.canonicalize_promote(a, shape), Expr::MatMul(a, b) => self.canonicalize_matmul(a, b), Expr::Sum(a, axis) => self.canonicalize_sum(a, *axis), Expr::Reshape(a, shape) => self.canonicalize_reshape(a, shape), @@ -227,23 +228,19 @@ impl CanonContext { let b_is_const = b.variables().is_empty(); // Handle scalar multiplication first (most common case) - if let Some(arr) = a.constant_value() { - if let Some(scalar) = arr.as_scalar() { - let cb = self.canonicalize_expr(b, for_objective); - return match cb { - CanonExpr::Linear(l) => CanonExpr::Linear(l.scale(scalar)), - CanonExpr::Quadratic(q) => CanonExpr::Quadratic(q.scale(scalar)), - }; - } + if let Some(scalar) = scalar_constant_value(a) { + let cb = self.canonicalize_expr(b, for_objective); + return match cb { + CanonExpr::Linear(l) => CanonExpr::Linear(l.scale(scalar)), + CanonExpr::Quadratic(q) => CanonExpr::Quadratic(q.scale(scalar)), + }; } - if let Some(arr) = b.constant_value() { - if let Some(scalar) = arr.as_scalar() { - let ca = self.canonicalize_expr(a, for_objective); - return match ca { - CanonExpr::Linear(l) => CanonExpr::Linear(l.scale(scalar)), - CanonExpr::Quadratic(q) => CanonExpr::Quadratic(q.scale(scalar)), - }; - } + if let Some(scalar) = scalar_constant_value(b) { + let ca = self.canonicalize_expr(a, for_objective); + return match ca { + CanonExpr::Linear(l) => CanonExpr::Linear(l.scale(scalar)), + CanonExpr::Quadratic(q) => CanonExpr::Quadratic(q.scale(scalar)), + }; } // Handle constant expression that evaluates to scalar @@ -283,21 +280,30 @@ impl CanonContext { return CanonExpr::Linear(LinExpr::constant(result)); } - // Both have variables - not DCP - self.canonicalize_expr(a, false) + panic!("cannot canonicalize product of two non-constant expressions") } fn elementwise_mul_const_lin(&self, c: &DMatrix, lin: &LinExpr) -> LinExpr { // Element-wise multiplication: diag(c) @ lin // For flat representation, this scales each row of coefficients by corresponding c value let c_flat: Vec = c.iter().copied().collect(); - let size = c_flat.len(); + let size = lin.shape.size(); + assert_eq!( + c_flat.len(), + size, + "elementwise multiplication requires matching sizes after broadcasting" + ); let mut new_coeffs = std::collections::HashMap::new(); for (var_id, coeff) in &lin.coeffs { let coeff_dense = csc_to_dense(coeff); + assert_eq!( + coeff_dense.nrows(), + size, + "linear coefficient rows must match expression size" + ); let mut new_coeff = DMatrix::zeros(size, coeff_dense.ncols()); - for i in 0..size.min(coeff_dense.nrows()) { + for i in 0..size { for j in 0..coeff_dense.ncols() { new_coeff[(i, j)] = c_flat[i] * coeff_dense[(i, j)]; } @@ -305,7 +311,8 @@ impl CanonContext { new_coeffs.insert(*var_id, dense_to_csc(&new_coeff)); } - let new_const = c.component_mul(&lin.constant); + let c_shaped = DMatrix::from_vec(lin.shape.rows(), lin.shape.cols(), c_flat); + let new_const = c_shaped.component_mul(&lin.constant); LinExpr { coeffs: new_coeffs, @@ -314,37 +321,71 @@ impl CanonContext { } } + fn canonicalize_promote(&mut self, expr: &Expr, target_shape: &Shape) -> CanonExpr { + let lin = self.canonicalize_expr(expr, false).as_linear().clone(); + if lin.shape.size() != 1 { + return CanonExpr::Linear(lin); + } + + let coeffs = lin + .coeffs + .iter() + .map(|(var_id, coeff)| (*var_id, csc_repeat_rows(coeff, target_shape.size()))) + .collect(); + let constant = DMatrix::from_element( + target_shape.rows(), + target_shape.cols(), + lin.constant[(0, 0)], + ); + + CanonExpr::Linear(LinExpr { + coeffs, + constant, + shape: target_shape.clone(), + }) + } + fn canonicalize_matmul(&mut self, a: &Expr, b: &Expr) -> CanonExpr { // Check if expressions are constant (no variables, not just Constant variant) let a_is_const = a.variables().is_empty(); let b_is_const = b.variables().is_empty(); + let result_shape = a.shape().matmul(&b.shape()).unwrap_or_else(|| { + panic!( + "cannot matrix-multiply shapes {} and {}", + a.shape(), + b.shape() + ) + }); if a_is_const && !b_is_const { // A is constant expression, B has variables: A @ B is affine in B let ca = self.canonicalize_expr(a, false).as_linear().clone(); let cb = self.canonicalize_expr(b, false).as_linear().clone(); let a_arr = Array::Dense(ca.constant); - return CanonExpr::Linear(self.matmul_const_lin(&a_arr, &cb)); + return CanonExpr::Linear(self.matmul_const_lin(&a_arr, &cb, result_shape)); } if b_is_const && !a_is_const { // B is constant expression, A has variables: A @ B is affine in A let ca = self.canonicalize_expr(a, false).as_linear().clone(); let cb = self.canonicalize_expr(b, false).as_linear().clone(); let b_arr = Array::Dense(cb.constant); - return CanonExpr::Linear(self.lin_matmul_const(&ca, &b_arr)); + return CanonExpr::Linear(self.lin_matmul_const(&ca, &b_arr, result_shape)); } if a_is_const && b_is_const { // Both constant - evaluate and return constant let ca = self.canonicalize_expr(a, false).as_linear().clone(); let cb = self.canonicalize_expr(b, false).as_linear().clone(); let result = &ca.constant * &cb.constant; - return CanonExpr::Linear(LinExpr::constant(result)); + return CanonExpr::Linear(LinExpr { + coeffs: std::collections::HashMap::new(), + constant: result, + shape: result_shape, + }); } - // Both have variables - not DCP, return simplified - self.canonicalize_expr(a, false) + panic!("cannot canonicalize matrix product of two non-constant expressions") } - fn matmul_const_lin(&self, a: &Array, b: &LinExpr) -> LinExpr { + fn matmul_const_lin(&self, a: &Array, b: &LinExpr, shape: Shape) -> LinExpr { // For matrix expression A @ E where E has shape (m, n): // vec(A @ E) = (I_n ⊗ A) @ vec(E) // So for coefficient C: new_C = (I_n ⊗ A) @ C @@ -385,7 +426,6 @@ impl CanonContext { } let new_const = &a_mat * &b.constant; - let shape = Shape::matrix(new_const.nrows(), new_const.ncols()); LinExpr { coeffs: new_coeffs, @@ -394,7 +434,7 @@ impl CanonContext { } } - fn lin_matmul_const(&self, a: &LinExpr, b: &Array) -> LinExpr { + fn lin_matmul_const(&self, a: &LinExpr, b: &Array, shape: Shape) -> LinExpr { // For matrix expression E @ B where E has shape (m, n): // vec(E @ B) = (B' ⊗ I_m) @ vec(E) // So for coefficient C: new_C = (B' ⊗ I_m) @ C @@ -435,7 +475,6 @@ impl CanonContext { } let new_const = &a.constant * &b_mat; - let shape = Shape::matrix(new_const.nrows(), new_const.ncols()); LinExpr { coeffs: new_coeffs, @@ -1414,10 +1453,19 @@ fn repeat_rows_csc(m: &CscMatrix, times: usize) -> CscMatrix { csc_repeat_rows(m, times) } +fn scalar_constant_value(expr: &Expr) -> Option { + match expr { + Expr::Constant(c) => c.value.as_scalar(), + Expr::Promote(a, _) => scalar_constant_value(a), + _ => None, + } +} + #[cfg(test)] mod tests { use super::*; - use crate::expr::variable; + use crate::atoms::{matmul, promote}; + use crate::expr::{constant, constant_matrix, variable}; #[test] fn test_canonicalize_variable() { @@ -1445,4 +1493,62 @@ mod tests { // For objective, should produce quadratic or SOC assert!(matches!(result.expr, CanonExpr::Quadratic(_)) || !result.constraints.is_empty()); } + + #[test] + fn test_canonicalize_matmul_preserves_vector_result_shape() { + let a = constant_matrix(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 2, 3); + let x = variable(3); + let result = canonicalize(&matmul(&a, &x), false); + + assert_eq!(result.expr.as_linear().shape, Shape::vector(2)); + } + + #[test] + fn test_canonicalize_matmul_preserves_column_matrix_result_shape() { + let a = constant_matrix(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 2, 3); + let x = variable((3, 1)); + let result = canonicalize(&matmul(&a, &x), false); + + assert_eq!(result.expr.as_linear().shape, Shape::matrix(2, 1)); + } + + #[test] + fn test_canonicalize_mul_sees_promoted_scalar_constant_privately() { + let x = variable((2, 2)); + let x_id = x.variable_id().unwrap(); + let promoted = promote(&constant(3.0), (2, 2)); + assert!(promoted.constant_value().is_none()); + + let result = canonicalize(&(promoted * x), false); + let lin = result.expr.as_linear(); + let coeff = csc_to_dense(&lin.coeffs[&x_id]); + + assert_eq!(lin.shape, Shape::matrix(2, 2)); + assert_eq!(coeff, DMatrix::identity(4, 4) * 3.0); + assert!(lin.constant.iter().all(|v| v.abs() < 1e-10)); + } + + #[test] + #[should_panic(expected = "cannot canonicalize product of two non-constant expressions")] + fn test_canonicalize_variable_product_panics() { + let x = variable(2); + let y = variable(2); + let _ = canonicalize(&(&x * &y), false); + } + + #[test] + #[should_panic(expected = "cannot matrix-multiply shapes (2, 3) and (2,)")] + fn test_canonicalize_invalid_matmul_shape_panics() { + let a = constant_matrix(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 2, 3); + let x = variable(2); + let _ = canonicalize(&Expr::MatMul(Arc::new(a), Arc::new(x)), false); + } + + #[test] + #[should_panic(expected = "cannot canonicalize matrix product of two non-constant expressions")] + fn test_canonicalize_variable_matmul_panics() { + let x = variable((2, 2)); + let y = variable((2, 2)); + let _ = canonicalize(&matmul(&x, &y), false); + } } diff --git a/src/constraints/constraint.rs b/src/constraints/constraint.rs index 9cac849..1731a88 100644 --- a/src/constraints/constraint.rs +++ b/src/constraints/constraint.rs @@ -31,24 +31,6 @@ pub enum Constraint { } impl Constraint { - /// Broadcast scalar to match the shape of target expression. - fn broadcast_scalar(scalar: &Expr, target_shape: &crate::expr::Shape) -> Expr { - use crate::expr::{constant, ones}; - - // Extract scalar value if it's a constant - if let Expr::Constant(data) = scalar { - if let Some(val) = data.value.as_scalar() { - if target_shape.is_scalar() { - return scalar.clone(); - } - // Broadcast: scalar * ones(shape) - return constant(val) * ones(target_shape.clone()); - } - } - // Not a scalar constant, return as-is - scalar.clone() - } - /// Create an equality constraint: lhs == rhs (with broadcasting). pub fn eq(lhs: Expr, rhs: Expr) -> Self { let (lhs, rhs) = Self::broadcast_if_needed(lhs, rhs); @@ -80,23 +62,7 @@ impl Constraint { /// Broadcast scalars to match shapes if needed. fn broadcast_if_needed(lhs: Expr, rhs: Expr) -> (Expr, Expr) { - let lhs_shape = lhs.shape(); - let rhs_shape = rhs.shape(); - - // If shapes match, no broadcasting needed - if lhs_shape == rhs_shape { - return (lhs, rhs); - } - - // Broadcast scalar to match non-scalar - if lhs_shape.is_scalar() && !rhs_shape.is_scalar() { - (Self::broadcast_scalar(&lhs, &rhs_shape), rhs) - } else if rhs_shape.is_scalar() && !lhs_shape.is_scalar() { - (lhs, Self::broadcast_scalar(&rhs, &lhs_shape)) - } else { - // Shapes don't match and neither is scalar - return as-is, will error later - (lhs, rhs) - } + crate::atoms::affine::broadcast_exprs(lhs, rhs) } /// Create a SOC constraint: ||x||_2 <= t. diff --git a/src/dcp/curvature.rs b/src/dcp/curvature.rs index ef0268d..2965b74 100644 --- a/src/dcp/curvature.rs +++ b/src/dcp/curvature.rs @@ -148,6 +148,7 @@ impl Expr { Expr::Neg(a) => a.curvature().negate(), Expr::Mul(a, b) => mul_curvature(a, b), Expr::MatMul(a, b) => matmul_curvature(a, b), + Expr::Promote(a, _) => a.curvature(), Expr::Sum(a, _) => a.curvature(), Expr::Reshape(a, _) => a.curvature(), Expr::Index(a, _) => a.curvature(), diff --git a/src/dcp/sign.rs b/src/dcp/sign.rs index ea6797f..c92e102 100644 --- a/src/dcp/sign.rs +++ b/src/dcp/sign.rs @@ -122,6 +122,7 @@ impl Expr { } } Expr::Sum(a, _) => a.sign(), + Expr::Promote(a, _) => a.sign(), Expr::Reshape(a, _) => a.sign(), Expr::Index(a, _) => a.sign(), Expr::VStack(exprs) => combine_signs(exprs), diff --git a/src/expr/eval.rs b/src/expr/eval.rs index dcf6f68..2ce6d57 100644 --- a/src/expr/eval.rs +++ b/src/expr/eval.rs @@ -101,6 +101,7 @@ impl Expr { let bv = b.eval(ctx)?; eval_mul(av, bv) } + Expr::Promote(a, shape) => eval_promote(a.eval(ctx)?, shape), Expr::Sum(a, axis) => eval_sum(a.eval(ctx)?, *axis), Expr::Reshape(a, shape) => eval_reshape(a.eval(ctx)?, shape), Expr::Index(a, spec) => eval_index(a.eval(ctx)?, spec), @@ -223,6 +224,17 @@ fn eval_mul(a: Array, b: Array) -> crate::Result { } } +fn eval_promote(a: Array, shape: &Shape) -> crate::Result { + let scalar = a.as_scalar().ok_or_else(|| { + crate::CvxError::InvalidProblem("Only scalar expressions can be promoted".into()) + })?; + Ok(Array::Dense(DMatrix::from_element( + shape.rows(), + shape.cols(), + scalar, + ))) +} + fn eval_matmul(a: Array, b: Array) -> crate::Result { let am = arr_to_dense(a); let bm = arr_to_dense(b); @@ -616,6 +628,114 @@ mod tests { assert!((v - 10.0).abs() < 1e-10); } + #[test] + fn test_eval_promote_scalar_to_matrix() { + let (x, ctx) = make_var_scalar(2.0); + let value = promote(&x, (2, 3)).value(&ctx); + + if let Array::Dense(m) = value { + assert_eq!(m.nrows(), 2); + assert_eq!(m.ncols(), 3); + assert!(m.iter().all(|v| (*v - 2.0).abs() < 1e-10)); + } else { + panic!("expected dense promoted value"); + } + } + + #[test] + fn test_eval_scalar_like_broadcast_add_sub_mul() { + let matrix = constant_dmatrix(DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0])); + let row_scalar = constant_vec(vec![10.0]); + let matrix_scalar = constant_dmatrix(DMatrix::from_row_slice(1, 1, &[2.0])); + let (_, ctx) = make_var_scalar(0.0); + + let add_value = (&row_scalar + &matrix).value(&ctx); + let sub_value = (&matrix - &row_scalar).value(&ctx); + let mul_value = (&matrix_scalar * &matrix).value(&ctx); + + if let (Array::Dense(add), Array::Dense(sub), Array::Dense(mul)) = + (add_value, sub_value, mul_value) + { + assert_eq!(add[(0, 0)], 11.0); + assert_eq!(add[(1, 1)], 14.0); + assert_eq!(sub[(0, 0)], -9.0); + assert_eq!(sub[(1, 1)], -6.0); + assert_eq!(mul[(0, 0)], 2.0); + assert_eq!(mul[(1, 1)], 8.0); + } else { + panic!("expected dense broadcast values"); + } + } + + #[test] + fn test_eval_row_broadcast_add_sub_mul() { + let matrix = constant_dmatrix(DMatrix::from_row_slice( + 2, + 3, + &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + )); + let row = constant_dmatrix(DMatrix::from_row_slice(1, 3, &[10.0, 20.0, 30.0])); + let (_, ctx) = make_var_scalar(0.0); + + let add_value = (&row + &matrix).value(&ctx); + let sub_value = (&matrix - &row).value(&ctx); + let mul_value = (&row * &matrix).value(&ctx); + + if let (Array::Dense(add), Array::Dense(sub), Array::Dense(mul)) = + (add_value, sub_value, mul_value) + { + assert_eq!(add[(0, 0)], 11.0); + assert_eq!(add[(1, 0)], 14.0); + assert_eq!(add[(0, 2)], 33.0); + assert_eq!(sub[(0, 1)], -18.0); + assert_eq!(sub[(1, 2)], -24.0); + assert_eq!(mul[(0, 0)], 10.0); + assert_eq!(mul[(1, 1)], 100.0); + assert_eq!(mul[(0, 2)], 90.0); + } else { + panic!("expected dense row broadcast values"); + } + } + + #[test] + fn test_eval_column_broadcast_add_sub_mul() { + let matrix = constant_dmatrix(DMatrix::from_row_slice( + 2, + 3, + &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + )); + let col = constant_dmatrix(DMatrix::from_row_slice(2, 1, &[10.0, 20.0])); + let (_, ctx) = make_var_scalar(0.0); + + let add_value = (&matrix + &col).value(&ctx); + let sub_value = (&matrix - &col).value(&ctx); + let mul_value = (&col * &matrix).value(&ctx); + + if let (Array::Dense(add), Array::Dense(sub), Array::Dense(mul)) = + (add_value, sub_value, mul_value) + { + assert_eq!(add[(0, 0)], 11.0); + assert_eq!(add[(1, 0)], 24.0); + assert_eq!(add[(0, 2)], 13.0); + assert_eq!(sub[(0, 1)], -8.0); + assert_eq!(sub[(1, 2)], -14.0); + assert_eq!(mul[(0, 0)], 10.0); + assert_eq!(mul[(1, 1)], 100.0); + assert_eq!(mul[(1, 2)], 120.0); + } else { + panic!("expected dense column broadcast values"); + } + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2, 3) and (3, 2)")] + fn test_eval_incompatible_broadcast_panics_at_construction() { + let a = constant_dmatrix(DMatrix::zeros(2, 3)); + let b = constant_dmatrix(DMatrix::zeros(3, 2)); + + let _ = &a + &b; + } + #[test] fn test_eval_norm2() { let (x, ctx) = make_var_vec(vec![3.0, 4.0]); diff --git a/src/expr/expression.rs b/src/expr/expression.rs index 8559796..32ae206 100644 --- a/src/expr/expression.rs +++ b/src/expr/expression.rs @@ -278,6 +278,8 @@ pub enum Expr { Neg(Arc), /// Multiplication: a * b (scalar or matrix) Mul(Arc, Arc), + /// Promote a scalar-like expression to a larger shape. + Promote(Arc, Shape), /// Summation with optional axis. Sum(Arc, Option), /// Reshape to new shape. @@ -342,15 +344,14 @@ impl Expr { Expr::Constant(c) => c.shape(), // Affine - Expr::Add(a, b) => a - .shape() - .broadcast(&b.shape()) - .unwrap_or_else(Shape::scalar), + Expr::Add(a, b) => a.shape().broadcast(&b.shape()).unwrap_or_else(|| { + panic!("cannot broadcast shapes {} and {}", a.shape(), b.shape()) + }), Expr::Neg(a) => a.shape(), - Expr::Mul(a, b) => a - .shape() - .broadcast(&b.shape()) - .unwrap_or_else(Shape::scalar), + Expr::Mul(a, b) => a.shape().broadcast(&b.shape()).unwrap_or_else(|| { + panic!("cannot broadcast shapes {} and {}", a.shape(), b.shape()) + }), + Expr::Promote(_, shape) => shape.clone(), Expr::Sum(a, axis) => { if axis.is_some() { // Sum along axis reduces that dimension @@ -386,7 +387,13 @@ impl Expr { } Expr::Transpose(a) => a.shape().transpose(), Expr::Trace(_) => Shape::scalar(), - Expr::MatMul(a, b) => a.shape().matmul(&b.shape()).unwrap_or_else(Shape::scalar), + Expr::MatMul(a, b) => a.shape().matmul(&b.shape()).unwrap_or_else(|| { + panic!( + "cannot matrix-multiply shapes {} and {}", + a.shape(), + b.shape() + ) + }), // Nonlinear - norms return scalars Expr::Norm1(_) | Expr::Norm2(_) | Expr::NormInf(_) => Shape::scalar(), @@ -465,6 +472,7 @@ impl Expr { b.collect_variables(vars); } Expr::Neg(a) + | Expr::Promote(a, _) | Expr::Sum(a, _) | Expr::Reshape(a, _) | Expr::Index(a, _) diff --git a/src/expr/shape.rs b/src/expr/shape.rs index 7a0282c..57efd14 100644 --- a/src/expr/shape.rs +++ b/src/expr/shape.rs @@ -52,6 +52,14 @@ impl Shape { self.0.is_empty() } + /// Check if this shape has exactly one element. + /// + /// This is used for CVXPY-style scalar promotion in broadcasting while + /// preserving `is_scalar()` as the strict `()` shape check. + pub fn is_scalar_like(&self) -> bool { + self.size() == 1 + } + /// Check if this is a vector. pub fn is_vector(&self) -> bool { self.0.len() == 1 @@ -94,38 +102,59 @@ impl Shape { } } - /// Check if shapes are broadcastable and return the result shape. + /// Return the result shape for cvxrust element-wise affine broadcasting. pub fn broadcast(&self, other: &Shape) -> Option { - let max_ndim = self.ndim().max(other.ndim()); - let mut result = Vec::with_capacity(max_ndim); - - // Pad shapes with 1s on the left - let self_padded: Vec = std::iter::repeat_n(1, max_ndim - self.ndim()) - .chain(self.0.iter().copied()) - .collect(); - let other_padded: Vec = std::iter::repeat_n(1, max_ndim - other.ndim()) - .chain(other.0.iter().copied()) - .collect(); - - for (a, b) in self_padded.iter().zip(other_padded.iter()) { - if *a == *b { - result.push(*a); - } else if *a == 1 { - result.push(*b); - } else if *b == 1 { - result.push(*a); - } else { - return None; // Not broadcastable + if self == other { + return Some(self.clone()); + } + + if self.is_scalar_like() && !other.is_scalar_like() { + return Some(other.clone()); + } + if other.is_scalar_like() && !self.is_scalar_like() { + return Some(self.clone()); + } + + if self.rows() == other.rows() && self.cols() == other.cols() { + if self.is_vector() { + return Some(self.clone()); + } + if other.is_vector() { + return Some(other.clone()); } + return Some(Shape::matrix(self.rows(), self.cols())); } - Some(Shape(result)) + if self.is_matrix() && other.is_matrix() { + if self.rows() == 1 && self.cols() == other.cols() { + return Some(other.clone()); + } + if other.rows() == 1 && other.cols() == self.cols() { + return Some(self.clone()); + } + if self.cols() == 1 && self.rows() == other.rows() { + return Some(other.clone()); + } + if other.cols() == 1 && other.rows() == self.rows() { + return Some(self.clone()); + } + if self.rows() == 1 && other.cols() == 1 { + return Some(Shape::matrix(other.rows(), self.cols())); + } + if other.rows() == 1 && self.cols() == 1 { + return Some(Shape::matrix(self.rows(), other.cols())); + } + } + + None } /// Check if matrix multiplication is valid and return result shape. pub fn matmul(&self, other: &Shape) -> Option { // Handle various cases match (self.ndim(), other.ndim()) { + // matrix @ scalar, treating scalar as a 1x1 column + (2, 0) if self.cols() == 1 => Some(Shape::vector(self.rows())), // matrix @ matrix (2, 2) if self.cols() == other.rows() => Some(Shape::matrix(self.rows(), other.cols())), // matrix @ vector @@ -249,9 +278,19 @@ mod tests { Some(Shape::matrix(3, 4)) ); - // Vector broadcasts with matrix + // Row and column matrices broadcast over a matching matrix. + assert_eq!( + Shape::matrix(1, 4).broadcast(&Shape::matrix(3, 4)), + Some(Shape::matrix(3, 4)) + ); + assert_eq!( + Shape::matrix(3, 1).broadcast(&Shape::matrix(3, 4)), + Some(Shape::matrix(3, 4)) + ); + + // Mutual row/column broadcast. assert_eq!( - Shape::vector(4).broadcast(&Shape::matrix(3, 4)), + Shape::matrix(1, 4).broadcast(&Shape::matrix(3, 1)), Some(Shape::matrix(3, 4)) ); diff --git a/src/lib.rs b/src/lib.rs index 7ffc100..e231282 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,6 +42,7 @@ //! - Aggregation: `sum`, `trace` //! - Structural: `reshape`, `transpose`, `vstack`, `hstack` //! - Linear algebra: `matmul`, `dot` +//! - Broadcasting: `promote` //! //! ### Convex //! - Norms: `norm1`, `norm2`, `norm_inf` @@ -87,7 +88,7 @@ pub mod prelude { // Atoms pub use crate::atoms::{ abs, cumsum, diag, dot, entropy, exp, flatten, hstack, indexc, log, matmul, max2, maximum, - min2, minimum, neg_part, norm, norm_inf, norm1, norm2, pos, power, quad_form, + min2, minimum, neg_part, norm, norm_inf, norm1, norm2, pos, power, promote, quad_form, quad_over_lin, reshape, select, slicec, sqrt, sum, sum_axis, sum_squares, trace, transpose, try_norm, vstack, }; diff --git a/src/problem.rs b/src/problem.rs index 22e4b7f..dbfb48f 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -126,6 +126,7 @@ impl Problem { Self::collect_variable_shapes(b, shapes); } Expr::Neg(a) + | Expr::Promote(a, _) | Expr::Sum(a, _) | Expr::Reshape(a, _) | Expr::Index(a, _) diff --git a/tests/canon_tests.rs b/tests/canon_tests.rs index c64baaa..4f11aa7 100644 --- a/tests/canon_tests.rs +++ b/tests/canon_tests.rs @@ -273,6 +273,47 @@ fn test_sum_axis_one_constraints_are_not_total_sum() { } } +#[test] +fn test_scalar_affine_broadcasts_in_vector_constraint() { + let x = variable(2); + let y = variable(()); + + let sol = Problem::maximize(sum(&x)) + .subject_to([x.le(y.clone()), y.le(1.0), x.ge(0.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 2.0).abs() < TOL); + assert!((solution_value(&sol, &y) - 1.0).abs() < TOL); + + if let Array::Dense(x_vals) = x.value(&sol) { + assert!((x_vals[(0, 0)] - 1.0).abs() < TOL); + assert!((x_vals[(1, 0)] - 1.0).abs() < TOL); + } else { + panic!("expected dense vector solution"); + } +} + +#[test] +fn test_constant_vector_times_scalar_affine_broadcasts() { + let y = variable(()); + let weighted = constant_vec(vec![2.0, 3.0]) * y.clone(); + + let sol = Problem::maximize(sum(&weighted)) + .subject_to([y.eq(1.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 5.0).abs() < TOL); + + if let Array::Dense(weighted_vals) = weighted.value(&sol) { + assert!((weighted_vals[(0, 0)] - 2.0).abs() < TOL); + assert!((weighted_vals[(1, 0)] - 3.0).abs() < TOL); + } else { + panic!("expected dense vector value"); + } +} + fn solution_value(sol: &Solution, expr: &Expr) -> f64 { expr.value(sol).as_scalar().expect("expected scalar") } diff --git a/tests/eval_tests.rs b/tests/eval_tests.rs index a6b25a3..b0183a3 100644 --- a/tests/eval_tests.rs +++ b/tests/eval_tests.rs @@ -62,6 +62,46 @@ fn test_value_affine_scale_and_shift() { assert!(approx(v, 7.0), "expected 7.0, got {v}"); } +#[test] +fn test_value_promote_public_atom() { + let x = variable(()); + let promoted = promote(&x, (2, 2)); + let sol = Problem::minimize(x.clone()) + .subject_to([x.eq(3.0)]) + .solve() + .unwrap(); + + let vals = promoted.value(&sol); + assert!(approx(vals[(0, 0)], 3.0)); + assert!(approx(vals[(1, 0)], 3.0)); + assert!(approx(vals[(0, 1)], 3.0)); + assert!(approx(vals[(1, 1)], 3.0)); +} + +#[test] +fn test_value_broadcasted_row_and_column_affine_expressions() { + let matrix = constant_matrix(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 2, 3); + let row = constant_matrix(vec![10.0, 20.0, 30.0], 1, 3); + let col = constant_matrix(vec![100.0, 200.0], 2, 1); + let x = variable(()); + let sol = Problem::minimize(x.clone()) + .subject_to([x.eq(0.0)]) + .solve() + .unwrap(); + + let row_sum = (&matrix + &row).value(&sol); + assert!(approx(row_sum[(0, 0)], 11.0)); + assert!(approx(row_sum[(1, 0)], 14.0)); + assert!(approx(row_sum[(0, 2)], 33.0)); + assert!(approx(row_sum[(1, 2)], 36.0)); + + let col_weighted = (&col * &matrix).value(&sol); + assert!(approx(col_weighted[(0, 0)], 100.0)); + assert!(approx(col_weighted[(1, 0)], 800.0)); + assert!(approx(col_weighted[(0, 2)], 300.0)); + assert!(approx(col_weighted[(1, 2)], 1200.0)); +} + #[test] fn test_value_matmul_residual() { // minimize ||Ax - b||^2 s.t. x >= 0 diff --git a/tests/solve_tests.rs b/tests/solve_tests.rs index f8f0846..7695397 100644 --- a/tests/solve_tests.rs +++ b/tests/solve_tests.rs @@ -844,6 +844,114 @@ fn test_tight_constraints() { ); } +#[test] +fn test_broadcasted_affine_constraints_public_workflow() { + let x = variable((2, 3)); + let row_cap = variable((1, 3)); + let col_floor = variable((2, 1)); + + let prob = Problem::maximize(sum(&x)) + .subject_to([ + row_cap.eq(constant_matrix(vec![2.0, 3.0, 4.0], 1, 3)), + col_floor.eq(constant_matrix(vec![1.0, 2.0], 2, 1)), + x.le(row_cap.clone()), + x.ge(col_floor.clone()), + ]) + .build(); + + assert!(prob.is_dcp()); + + let solution = prob.solve().expect("broadcasted problem should solve"); + let value = solution.value.expect("should have objective"); + assert!( + (value - 18.0).abs() < TOL, + "broadcasted constraints: expected 18, got {}", + value + ); + + let x_vals = &solution[&x]; + for i in 0..2 { + assert!((x_vals[(i, 0)] - 2.0).abs() < TOL); + assert!((x_vals[(i, 1)] - 3.0).abs() < TOL); + assert!((x_vals[(i, 2)] - 4.0).abs() < TOL); + } +} + +#[test] +fn test_row_affine_broadcasts_in_matrix_constraint() { + let x = variable((2, 3)); + let row = variable((1, 3)); + let row_values = constant_matrix(vec![1.0, 2.0, 3.0], 1, 3); + + let sol = Problem::maximize(sum(&x)) + .subject_to([x.le(row.clone()), row.eq(row_values), x.ge(0.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 12.0).abs() < TOL); + + if let Array::Dense(x_vals) = x.value(&sol) { + for i in 0..2 { + assert!((x_vals[(i, 0)] - 1.0).abs() < TOL); + assert!((x_vals[(i, 1)] - 2.0).abs() < TOL); + assert!((x_vals[(i, 2)] - 3.0).abs() < TOL); + } + } else { + panic!("expected dense matrix solution"); + } +} + +#[test] +fn test_column_affine_broadcasts_in_matrix_constraint() { + let x = variable((2, 3)); + let col = variable((2, 1)); + let col_values = constant_matrix(vec![1.0, 2.0], 2, 1); + + let sol = Problem::maximize(sum(&x)) + .subject_to([x.le(col.clone()), col.eq(col_values), x.ge(0.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 9.0).abs() < TOL); + + if let Array::Dense(x_vals) = x.value(&sol) { + for j in 0..3 { + assert!((x_vals[(0, j)] - 1.0).abs() < TOL); + assert!((x_vals[(1, j)] - 2.0).abs() < TOL); + } + } else { + panic!("expected dense matrix solution"); + } +} + +#[test] +fn test_row_broadcast_constant_times_affine_matrix() { + let x = variable((2, 3)); + let weights = constant_matrix(vec![10.0, 20.0, 30.0], 1, 3); + let weighted = weights * x.clone(); + + let sol = Problem::minimize(sum(&weighted)) + .subject_to([x.eq(ones((2, 3)))]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 120.0).abs() < TOL); +} + +#[test] +fn test_column_broadcast_constant_times_affine_matrix() { + let x = variable((2, 3)); + let weights = constant_matrix(vec![10.0, 20.0], 2, 1); + let weighted = weights * x.clone(); + + let sol = Problem::minimize(sum(&weighted)) + .subject_to([x.eq(ones((2, 3)))]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 90.0).abs() < TOL); +} + // ============================================================================ // Scale tests (medium-sized problems) // ============================================================================