From 2c61fda64069db9c0e86c176fd137d276d24bebd Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 22:30:38 +0200 Subject: [PATCH 01/19] =?UTF-8?q?Removes=20sparse=20stack=20dimension=20wi?= =?UTF-8?q?dening=20and=20replaces=20=E2=80=9Creturn=20zeros=20on=20imposs?= =?UTF-8?q?ible=20CSC=20failure=E2=80=9D=20fallbacks=20with=20explicit=20e?= =?UTF-8?q?xpectations.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sparse.rs | 86 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 16 deletions(-) diff --git a/src/sparse.rs b/src/sparse.rs index 6a24c96..402af77 100644 --- a/src/sparse.rs +++ b/src/sparse.rs @@ -67,6 +67,11 @@ pub fn csc_to_dense(sparse: &CscMatrix) -> DMatrix { /// Stack two CSC matrices vertically. pub fn csc_vstack(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { + assert_eq!( + a.ncols(), + b.ncols(), + "vstack requires matching column counts" + ); let mut rows = Vec::new(); let mut cols = Vec::new(); let mut vals = Vec::new(); @@ -82,13 +87,7 @@ pub fn csc_vstack(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { vals.push(*v); } - csc_from_triplets( - a.nrows() + b.nrows(), - a.ncols().max(b.ncols()), - rows, - cols, - vals, - ) + csc_from_triplets(a.nrows() + b.nrows(), a.ncols(), rows, cols, vals) } /// Add two CSC matrices. @@ -103,7 +102,7 @@ pub fn csc_neg(a: &CscMatrix) -> CscMatrix { let col_offsets: Vec = a.col_offsets().to_vec(); let row_indices: Vec = a.row_indices().to_vec(); CscMatrix::try_from_csc_data(a.nrows(), a.ncols(), col_offsets, row_indices, values) - .unwrap_or_else(|_| CscMatrix::zeros(a.nrows(), a.ncols())) + .expect("valid CSC structure should remain valid when negating values") } /// Scale a CSC matrix. @@ -112,7 +111,7 @@ pub fn csc_scale(a: &CscMatrix, scalar: f64) -> CscMatrix { let col_offsets: Vec = a.col_offsets().to_vec(); let row_indices: Vec = a.row_indices().to_vec(); CscMatrix::try_from_csc_data(a.nrows(), a.ncols(), col_offsets, row_indices, values) - .unwrap_or_else(|_| CscMatrix::zeros(a.nrows(), a.ncols())) + .expect("valid CSC structure should remain valid when scaling values") } /// Multiply sparse matrix by dense matrix on the right: A_sparse @ B_dense @@ -126,6 +125,7 @@ pub fn sparse_dense_matmul(a: &CscMatrix, b: &DMatrix) -> CscMatrix, b: &CscMatrix) -> CscMatrix { + assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); let mut rows = Vec::new(); let mut cols = Vec::new(); let mut vals = Vec::new(); @@ -141,13 +141,7 @@ pub fn csc_hstack(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { vals.push(*v); } - csc_from_triplets( - a.nrows().max(b.nrows()), - a.ncols() + b.ncols(), - rows, - cols, - vals, - ) + csc_from_triplets(a.nrows(), a.ncols() + b.ncols(), rows, cols, vals) } /// Multiply two CSC matrices: A @ B @@ -220,4 +214,64 @@ mod tests { assert!((v - 2.0).abs() < 1e-10); } } + + #[test] + fn test_csc_neg_preserves_structure_and_negates_values() { + let a = csc_from_triplets(2, 2, vec![0, 1], vec![0, 1], vec![2.0, -3.0]); + let neg = csc_neg(&a); + let dense = csc_to_dense(&neg); + + assert_eq!(neg.nrows(), 2); + assert_eq!(neg.ncols(), 2); + assert_eq!(dense[(0, 0)], -2.0); + assert_eq!(dense[(1, 1)], 3.0); + } + + #[test] + fn test_csc_scale_preserves_structure_and_scales_values() { + let a = csc_from_triplets(2, 2, vec![0, 1], vec![0, 1], vec![2.0, -3.0]); + let scaled = csc_scale(&a, 4.0); + let dense = csc_to_dense(&scaled); + + assert_eq!(scaled.nrows(), 2); + assert_eq!(scaled.ncols(), 2); + assert_eq!(dense[(0, 0)], 8.0); + assert_eq!(dense[(1, 1)], -12.0); + } + + #[test] + fn test_csc_vstack_matching_columns() { + let a = CscMatrix::::identity(2); + let b = CscMatrix::::zeros(3, 2); + let stacked = csc_vstack(&a, &b); + + assert_eq!(stacked.nrows(), 5); + assert_eq!(stacked.ncols(), 2); + } + + #[test] + #[should_panic(expected = "vstack requires matching column counts")] + fn test_csc_vstack_mismatched_columns_panics() { + let a = CscMatrix::::zeros(2, 2); + let b = CscMatrix::::zeros(3, 3); + let _ = csc_vstack(&a, &b); + } + + #[test] + fn test_csc_hstack_matching_rows() { + let a = CscMatrix::::identity(2); + let b = CscMatrix::::identity(2); + let stacked = csc_hstack(&a, &b); + + assert_eq!(stacked.nrows(), 2); + assert_eq!(stacked.ncols(), 4); + } + + #[test] + #[should_panic(expected = "hstack requires matching row counts")] + fn test_csc_hstack_mismatched_rows_panics() { + let a = CscMatrix::::zeros(2, 2); + let b = CscMatrix::::zeros(3, 2); + let _ = csc_hstack(&a, &b); + } } From 0d29cfc0dfb132260686d27e07d29d2fe6fdd716 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 22:38:26 +0200 Subject: [PATCH 02/19] Add axis index check for sum_axis --- src/atoms/affine.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index ee54e32..b018d95 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -257,6 +257,13 @@ pub fn sum(expr: &Expr) -> Expr { /// Sum along a specific axis. pub fn sum_axis(expr: &Expr, axis: usize) -> Expr { + let shape = expr.shape(); + assert!( + axis < shape.ndim().max(1), + "axis {} out of bounds for shape {}", + axis, + shape + ); Expr::Sum(Arc::new(expr.clone()), Some(axis)) } @@ -472,6 +479,20 @@ mod tests { assert_eq!(s.shape(), Shape::scalar()); } + #[test] + fn test_sum_axis_vector_shape() { + let x = variable(3); + let s = sum_axis(&x, 0); + assert_eq!(s.shape(), Shape::scalar()); + } + + #[test] + #[should_panic(expected = "axis 1 out of bounds for shape (3,)")] + fn test_sum_axis_invalid_axis_panics() { + let x = variable(3); + let _ = sum_axis(&x, 1); + } + #[test] fn test_promote_shape_and_metadata() { let x = nonneg_variable(()); From cf41c4aaec5e22b0a9243be35e63e700208a3160 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 22:39:29 +0200 Subject: [PATCH 03/19] Add shape compatibility check for reshape --- src/atoms/affine.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index b018d95..a8229c1 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -286,7 +286,15 @@ pub fn promote(expr: &Expr, shape: impl Into) -> Expr { /// Reshape an expression to a new shape. pub fn reshape(expr: &Expr, shape: impl Into) -> Expr { - Expr::Reshape(Arc::new(expr.clone()), shape.into()) + let shape = shape.into(); + assert_eq!( + expr.shape().size(), + shape.size(), + "cannot reshape size {} into shape {}", + expr.shape().size(), + shape + ); + Expr::Reshape(Arc::new(expr.clone()), shape) } /// Flatten an expression to a vector. @@ -519,6 +527,13 @@ mod tests { assert_eq!(promote(&x, 3).shape(), Shape::vector(3)); } + #[test] + #[should_panic(expected = "cannot reshape size 3 into shape (2, 2)")] + fn test_reshape_size_mismatch_panics() { + let x = variable(3); + let _ = reshape(&x, (2, 2)); + } + #[test] #[should_panic(expected = "only scalar expressions can be promoted")] fn test_promote_rejects_non_scalar_like_shape() { From a5cf1b51716325f108cc19506d5286b810e8f179 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 22:40:36 +0200 Subject: [PATCH 04/19] Add column/row matching check for vstack/hstack --- src/atoms/affine.rs | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index a8229c1..80e3a56 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -315,11 +315,31 @@ pub fn trace(expr: &Expr) -> Expr { /// Vertical stack (row-wise concatenation). pub fn vstack(exprs: Vec) -> Expr { + if let Some(first) = exprs.first() { + let cols = first.shape().cols(); + for expr in &exprs[1..] { + assert_eq!( + expr.shape().cols(), + cols, + "vstack requires matching column counts" + ); + } + } Expr::VStack(exprs.into_iter().map(Arc::new).collect()) } /// Horizontal stack (column-wise concatenation). pub fn hstack(exprs: Vec) -> Expr { + if let Some(first) = exprs.first() { + let rows = first.shape().rows(); + for expr in &exprs[1..] { + assert_eq!( + expr.shape().rows(), + rows, + "hstack requires matching row counts" + ); + } + } Expr::HStack(exprs.into_iter().map(Arc::new).collect()) } @@ -693,6 +713,22 @@ mod tests { assert_eq!(z.shape(), Shape::matrix(5, 3)); } + #[test] + #[should_panic(expected = "vstack requires matching column counts")] + fn test_vstack_mismatched_columns_panics() { + let x = variable((2, 3)); + let y = variable((3, 2)); + let _ = vstack(vec![x, y]); + } + + #[test] + #[should_panic(expected = "hstack requires matching row counts")] + fn test_hstack_mismatched_rows_panics() { + let x = variable((2, 3)); + let y = variable((3, 3)); + let _ = hstack(vec![x, y]); + } + #[test] fn test_affine_is_affine() { let x = variable(5); From bb2e3efeebadad625e640552009e8e0f27bf53a9 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 22:41:52 +0200 Subject: [PATCH 05/19] Update outdated docs --- src/atoms/affine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index 80e3a56..d6a8b0c 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -457,7 +457,7 @@ pub fn cumsum(expr: &Expr) -> Expr { /// Diagonal matrix from vector, or diagonal of matrix. /// /// - Vector input: Creates diagonal matrix with vector on diagonal -/// - Matrix input: Extracts diagonal as vector (v1.0: returns input as fallback) +/// - Matrix input: Extracts diagonal as vector pub fn diag(expr: &Expr) -> Expr { Expr::Diag(Arc::new(expr.clone())) } From eae8e9c74d47a6d6e5fad804ea4fc817d194d917 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 23:11:25 +0200 Subject: [PATCH 06/19] Add number of expressions and dimension compatibility check and broadcasting for minimum and maximum atoms --- src/atoms/affine.rs | 2 +- src/atoms/nonlinear.rs | 135 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 134 insertions(+), 3 deletions(-) diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index d6a8b0c..2571bb0 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -214,7 +214,7 @@ pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) { (lhs, rhs) } -fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { +pub(crate) fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { if expr_shape == target_shape { return Some(expr); } diff --git a/src/atoms/nonlinear.rs b/src/atoms/nonlinear.rs index 6d8affa..c373dfa 100644 --- a/src/atoms/nonlinear.rs +++ b/src/atoms/nonlinear.rs @@ -5,6 +5,7 @@ use std::sync::Arc; +use crate::atoms::affine::broadcast_to; use crate::expr::Expr; // ============================================================================ @@ -136,10 +137,14 @@ pub fn neg_part(x: &Expr) -> Expr { /// - Sign: Depends on arguments /// - Monotonicity: Increasing in all arguments pub fn maximum(exprs: Vec) -> Expr { + if exprs.is_empty() { + panic!("maximum requires at least one expression"); + } if exprs.len() == 1 { return exprs.into_iter().next().unwrap(); } - Expr::Maximum(exprs.into_iter().map(Arc::new).collect()) + + Expr::Maximum(broadcast_elementwise_args(exprs)) } /// Maximum of two expressions. @@ -154,10 +159,14 @@ pub fn max2(a: &Expr, b: &Expr) -> Expr { /// - Sign: Depends on arguments /// - Monotonicity: Increasing in all arguments pub fn minimum(exprs: Vec) -> Expr { + if exprs.is_empty() { + panic!("minimum requires at least one expression"); + } if exprs.len() == 1 { return exprs.into_iter().next().unwrap(); } - Expr::Minimum(exprs.into_iter().map(Arc::new).collect()) + + Expr::Minimum(broadcast_elementwise_args(exprs)) } /// Minimum of two expressions. @@ -165,6 +174,28 @@ pub fn min2(a: &Expr, b: &Expr) -> Expr { minimum(vec![a.clone(), b.clone()]) } +fn broadcast_elementwise_args(exprs: Vec) -> Vec> { + let target_shape = exprs + .iter() + .map(|expr| expr.shape()) + .reduce(|acc, shape| { + acc.broadcast(&shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) + }) + .expect("elementwise atom requires at least one expression"); + + exprs + .into_iter() + .map(|expr| { + let expr_shape = expr.shape(); + let expr = broadcast_to(expr, &expr_shape, &target_shape).unwrap_or_else(|| { + panic!("cannot broadcast shape {} to {}", expr_shape, target_shape) + }); + Arc::new(expr) + }) + .collect() +} + // ============================================================================ // Quadratic atoms // ============================================================================ @@ -287,6 +318,56 @@ mod tests { assert_eq!(m.curvature(), Curvature::Convex); } + #[test] + #[should_panic(expected = "maximum requires at least one expression")] + fn test_maximum_empty_panics() { + let _ = maximum(Vec::new()); + } + + #[test] + fn test_maximum_single_expression_is_identity() { + let x = variable(5); + let m = maximum(vec![x.clone()]); + assert_eq!(m.shape(), x.shape()); + assert_eq!(m.variables(), x.variables()); + } + + #[test] + fn test_maximum_broadcasts_scalar_to_vector() { + let x = variable(5); + let y = variable(()); + let m = maximum(vec![x, y]); + assert_eq!(m.shape(), crate::expr::Shape::vector(5)); + } + + #[test] + fn test_maximum_three_args_stays_flat() { + let x = variable(5); + let y = variable(5); + let z = variable(()); + let m = maximum(vec![x, y, z]); + + match m { + Expr::Maximum(args) => { + assert_eq!(args.len(), 3); + assert!( + args.iter() + .all(|arg| arg.shape() == crate::expr::Shape::vector(5)) + ); + assert!(!matches!(&*args[0], Expr::Maximum(_))); + } + _ => panic!("expected maximum expression"), + } + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2,) and (3,)")] + fn test_maximum_incompatible_shapes_panic() { + let x = variable(2); + let y = variable(3); + let _ = maximum(vec![x, y]); + } + #[test] fn test_minimum_concave() { let x = variable(5); @@ -295,6 +376,56 @@ mod tests { assert_eq!(m.curvature(), Curvature::Concave); } + #[test] + #[should_panic(expected = "minimum requires at least one expression")] + fn test_minimum_empty_panics() { + let _ = minimum(Vec::new()); + } + + #[test] + fn test_minimum_single_expression_is_identity() { + let x = variable(5); + let m = minimum(vec![x.clone()]); + assert_eq!(m.shape(), x.shape()); + assert_eq!(m.variables(), x.variables()); + } + + #[test] + fn test_minimum_broadcasts_scalar_to_vector() { + let x = variable(5); + let y = variable(()); + let m = minimum(vec![x, y]); + assert_eq!(m.shape(), crate::expr::Shape::vector(5)); + } + + #[test] + fn test_minimum_three_args_stays_flat() { + let x = variable(5); + let y = variable(5); + let z = variable(()); + let m = minimum(vec![x, y, z]); + + match m { + Expr::Minimum(args) => { + assert_eq!(args.len(), 3); + assert!( + args.iter() + .all(|arg| arg.shape() == crate::expr::Shape::vector(5)) + ); + assert!(!matches!(&*args[0], Expr::Minimum(_))); + } + _ => panic!("expected minimum expression"), + } + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2,) and (3,)")] + fn test_minimum_incompatible_shapes_panic() { + let x = variable(2); + let y = variable(3); + let _ = minimum(vec![x, y]); + } + #[test] fn test_sum_squares_convex() { let x = variable(5); From 92b832122702c2ee42f5999fa90cbf1195202f39 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 23:24:25 +0200 Subject: [PATCH 07/19] Update maximum/minimum atom canonicalization --- src/atoms/affine.rs | 26 ++++++++++++++++++++++++++ src/atoms/nonlinear.rs | 30 +++++------------------------- src/canon/canonicalizer.rs | 17 +++++++++-------- 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index 2571bb0..819fae6 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -226,6 +226,32 @@ pub(crate) fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) broadcast_2d_to(expr, expr_shape, target_shape) } +pub(crate) fn broadcast_elementwise_exprs( + exprs: impl IntoIterator, +) -> (Shape, Vec) { + let exprs: Vec = exprs.into_iter().collect(); + let target_shape = exprs + .iter() + .map(|expr| expr.shape()) + .reduce(|acc, shape| { + acc.broadcast(&shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) + }) + .expect("elementwise atom requires at least one expression"); + + let exprs = exprs + .into_iter() + .map(|expr| { + let expr_shape = expr.shape(); + broadcast_to(expr, &expr_shape, &target_shape).unwrap_or_else(|| { + panic!("cannot broadcast shape {} to {}", expr_shape, target_shape) + }) + }) + .collect(); + + (target_shape, exprs) +} + fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { if !expr_shape.is_matrix() || !target_shape.is_matrix() { return None; diff --git a/src/atoms/nonlinear.rs b/src/atoms/nonlinear.rs index c373dfa..7d387a1 100644 --- a/src/atoms/nonlinear.rs +++ b/src/atoms/nonlinear.rs @@ -5,7 +5,7 @@ use std::sync::Arc; -use crate::atoms::affine::broadcast_to; +use crate::atoms::affine::broadcast_elementwise_exprs; use crate::expr::Expr; // ============================================================================ @@ -144,7 +144,8 @@ pub fn maximum(exprs: Vec) -> Expr { return exprs.into_iter().next().unwrap(); } - Expr::Maximum(broadcast_elementwise_args(exprs)) + let (_, exprs) = broadcast_elementwise_exprs(exprs); + Expr::Maximum(exprs.into_iter().map(Arc::new).collect()) } /// Maximum of two expressions. @@ -166,7 +167,8 @@ pub fn minimum(exprs: Vec) -> Expr { return exprs.into_iter().next().unwrap(); } - Expr::Minimum(broadcast_elementwise_args(exprs)) + let (_, exprs) = broadcast_elementwise_exprs(exprs); + Expr::Minimum(exprs.into_iter().map(Arc::new).collect()) } /// Minimum of two expressions. @@ -174,28 +176,6 @@ pub fn min2(a: &Expr, b: &Expr) -> Expr { minimum(vec![a.clone(), b.clone()]) } -fn broadcast_elementwise_args(exprs: Vec) -> Vec> { - let target_shape = exprs - .iter() - .map(|expr| expr.shape()) - .reduce(|acc, shape| { - acc.broadcast(&shape) - .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) - }) - .expect("elementwise atom requires at least one expression"); - - exprs - .into_iter() - .map(|expr| { - let expr_shape = expr.shape(); - let expr = broadcast_to(expr, &expr_shape, &target_shape).unwrap_or_else(|| { - panic!("cannot broadcast shape {} to {}", expr_shape, target_shape) - }); - Arc::new(expr) - }) - .collect() -} - // ============================================================================ // Quadratic atoms // ============================================================================ diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index f8a6c50..4168ccb 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -11,6 +11,7 @@ use nalgebra::DMatrix; use nalgebra_sparse::CscMatrix; use super::lin_expr::{LinExpr, QuadExpr}; +use crate::atoms::affine::broadcast_elementwise_exprs; use crate::expr::{Array, Expr, ExprId, IndexSpec, Shape, VariableBuilder}; use crate::sparse::{csc_add, csc_repeat_rows, csc_to_dense, csc_vstack, dense_to_csc}; @@ -918,13 +919,13 @@ impl CanonContext { fn canonicalize_maximum(&mut self, exprs: &[Arc]) -> CanonExpr { // max(x1, ..., xn): Introduce t, t >= x_i for all i if exprs.is_empty() { - return CanonExpr::Linear(LinExpr::zeros(Shape::scalar())); + panic!("maximum requires at least one expression"); } - - let shape = exprs[0].shape(); + let (shape, exprs) = + broadcast_elementwise_exprs(exprs.iter().map(|expr| expr.as_ref().clone())); let (_, t) = self.new_aux_var(shape); - for e in exprs { + for e in &exprs { let ce = self.canonicalize_expr(e, false).as_linear().clone(); // t >= x_i, i.e., t - x_i >= 0 self.constraints.push(ConeConstraint::NonNeg { @@ -938,13 +939,13 @@ impl CanonContext { fn canonicalize_minimum(&mut self, exprs: &[Arc]) -> CanonExpr { // min(x1, ..., xn): Introduce t, t <= x_i for all i if exprs.is_empty() { - return CanonExpr::Linear(LinExpr::zeros(Shape::scalar())); + panic!("minimum requires at least one expression"); } - - let shape = exprs[0].shape(); + let (shape, exprs) = + broadcast_elementwise_exprs(exprs.iter().map(|expr| expr.as_ref().clone())); let (_, t) = self.new_aux_var(shape); - for e in exprs { + for e in &exprs { let ce = self.canonicalize_expr(e, false).as_linear().clone(); // t <= x_i, i.e., x_i - t >= 0 self.constraints.push(ConeConstraint::NonNeg { From fc9daf63e97b568bafcefbd7abce7c6513ba8052 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 23:29:35 +0200 Subject: [PATCH 08/19] Move broadcasting related functions to broadcast.rs --- src/atoms/affine.rs | 83 +--------------------------------- src/atoms/broadcast.rs | 84 +++++++++++++++++++++++++++++++++++ src/atoms/mod.rs | 1 + src/atoms/nonlinear.rs | 2 +- src/canon/canonicalizer.rs | 2 +- src/constraints/constraint.rs | 2 +- 6 files changed, 90 insertions(+), 84 deletions(-) create mode 100644 src/atoms/broadcast.rs diff --git a/src/atoms/affine.rs b/src/atoms/affine.rs index 819fae6..3d1e821 100644 --- a/src/atoms/affine.rs +++ b/src/atoms/affine.rs @@ -9,7 +9,8 @@ use std::ops::{Add, Div, Mul, Neg, Sub}; use std::sync::Arc; -use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant, ones}; +use super::broadcast::broadcast_exprs; +use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant}; // ============================================================================ // Operator overloading for Expr @@ -192,86 +193,6 @@ fn mul_exprs(lhs: Expr, rhs: Expr) -> Expr { Expr::Mul(Arc::new(lhs), Arc::new(rhs)) } -pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) { - let lhs_shape = lhs.shape(); - let rhs_shape = rhs.shape(); - let target_shape = lhs_shape - .broadcast(&rhs_shape) - .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", lhs_shape, rhs_shape)); - - if lhs_shape == target_shape && rhs_shape == target_shape { - return (lhs, rhs); - } - if lhs_shape.rows() == rhs_shape.rows() && lhs_shape.cols() == rhs_shape.cols() { - return (lhs, rhs); - } - - let lhs = broadcast_to(lhs, &lhs_shape, &target_shape) - .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", lhs_shape, target_shape)); - let rhs = broadcast_to(rhs, &rhs_shape, &target_shape) - .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", rhs_shape, target_shape)); - - (lhs, rhs) -} - -pub(crate) fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { - if expr_shape == target_shape { - return Some(expr); - } - - if expr_shape.is_scalar_like() { - return Some(promote(&expr, target_shape.clone())); - } - - broadcast_2d_to(expr, expr_shape, target_shape) -} - -pub(crate) fn broadcast_elementwise_exprs( - exprs: impl IntoIterator, -) -> (Shape, Vec) { - let exprs: Vec = exprs.into_iter().collect(); - let target_shape = exprs - .iter() - .map(|expr| expr.shape()) - .reduce(|acc, shape| { - acc.broadcast(&shape) - .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) - }) - .expect("elementwise atom requires at least one expression"); - - let exprs = exprs - .into_iter() - .map(|expr| { - let expr_shape = expr.shape(); - broadcast_to(expr, &expr_shape, &target_shape).unwrap_or_else(|| { - panic!("cannot broadcast shape {} to {}", expr_shape, target_shape) - }) - }) - .collect(); - - (target_shape, exprs) -} - -fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { - if !expr_shape.is_matrix() || !target_shape.is_matrix() { - return None; - } - - if expr_shape.rows() == 1 && target_shape.rows() > 1 && expr_shape.cols() == target_shape.cols() - { - let left = ones((target_shape.rows(), 1)); - return Some(matmul(&left, &expr)); - } - - if expr_shape.cols() == 1 && target_shape.cols() > 1 && expr_shape.rows() == target_shape.rows() - { - let right = ones((1, target_shape.cols())); - return Some(matmul(&expr, &right)); - } - - None -} - // ============================================================================ // Affine atom functions // ============================================================================ diff --git a/src/atoms/broadcast.rs b/src/atoms/broadcast.rs new file mode 100644 index 0000000..cec4d95 --- /dev/null +++ b/src/atoms/broadcast.rs @@ -0,0 +1,84 @@ +//! Internal broadcasting helpers for atom construction and canonicalization. + +use crate::atoms::affine::{matmul, promote}; +use crate::expr::{Expr, Shape, ones}; + +pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) { + let lhs_shape = lhs.shape(); + let rhs_shape = rhs.shape(); + let target_shape = lhs_shape + .broadcast(&rhs_shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", lhs_shape, rhs_shape)); + + if lhs_shape == target_shape && rhs_shape == target_shape { + return (lhs, rhs); + } + if lhs_shape.rows() == rhs_shape.rows() && lhs_shape.cols() == rhs_shape.cols() { + return (lhs, rhs); + } + + let lhs = broadcast_to(lhs, &lhs_shape, &target_shape) + .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", lhs_shape, target_shape)); + let rhs = broadcast_to(rhs, &rhs_shape, &target_shape) + .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", rhs_shape, target_shape)); + + (lhs, rhs) +} + +pub(crate) fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { + if expr_shape == target_shape { + return Some(expr); + } + + if expr_shape.is_scalar_like() { + return Some(promote(&expr, target_shape.clone())); + } + + broadcast_2d_to(expr, expr_shape, target_shape) +} + +pub(crate) fn broadcast_elementwise_exprs( + exprs: impl IntoIterator, +) -> (Shape, Vec) { + let exprs: Vec = exprs.into_iter().collect(); + let target_shape = exprs + .iter() + .map(|expr| expr.shape()) + .reduce(|acc, shape| { + acc.broadcast(&shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) + }) + .expect("elementwise atom requires at least one expression"); + + let exprs = exprs + .into_iter() + .map(|expr| { + let expr_shape = expr.shape(); + broadcast_to(expr, &expr_shape, &target_shape).unwrap_or_else(|| { + panic!("cannot broadcast shape {} to {}", expr_shape, target_shape) + }) + }) + .collect(); + + (target_shape, exprs) +} + +fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { + if !expr_shape.is_matrix() || !target_shape.is_matrix() { + return None; + } + + if expr_shape.rows() == 1 && target_shape.rows() > 1 && expr_shape.cols() == target_shape.cols() + { + let left = ones((target_shape.rows(), 1)); + return Some(matmul(&left, &expr)); + } + + if expr_shape.cols() == 1 && target_shape.cols() > 1 && expr_shape.rows() == target_shape.rows() + { + let right = ones((1, target_shape.cols())); + return Some(matmul(&expr, &right)); + } + + None +} diff --git a/src/atoms/mod.rs b/src/atoms/mod.rs index 4f401a1..f5c7c91 100644 --- a/src/atoms/mod.rs +++ b/src/atoms/mod.rs @@ -6,6 +6,7 @@ //! - **Nonlinear atoms**: Operations with specific curvature (norms, quadratic forms, etc.) pub mod affine; +pub(crate) mod broadcast; pub mod nonlinear; // Re-export affine operations diff --git a/src/atoms/nonlinear.rs b/src/atoms/nonlinear.rs index 7d387a1..a3e206e 100644 --- a/src/atoms/nonlinear.rs +++ b/src/atoms/nonlinear.rs @@ -5,7 +5,7 @@ use std::sync::Arc; -use crate::atoms::affine::broadcast_elementwise_exprs; +use crate::atoms::broadcast::broadcast_elementwise_exprs; use crate::expr::Expr; // ============================================================================ diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index 4168ccb..436ecc4 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -11,7 +11,7 @@ use nalgebra::DMatrix; use nalgebra_sparse::CscMatrix; use super::lin_expr::{LinExpr, QuadExpr}; -use crate::atoms::affine::broadcast_elementwise_exprs; +use crate::atoms::broadcast::broadcast_elementwise_exprs; use crate::expr::{Array, Expr, ExprId, IndexSpec, Shape, VariableBuilder}; use crate::sparse::{csc_add, csc_repeat_rows, csc_to_dense, csc_vstack, dense_to_csc}; diff --git a/src/constraints/constraint.rs b/src/constraints/constraint.rs index 1731a88..9b67220 100644 --- a/src/constraints/constraint.rs +++ b/src/constraints/constraint.rs @@ -62,7 +62,7 @@ impl Constraint { /// Broadcast scalars to match shapes if needed. fn broadcast_if_needed(lhs: Expr, rhs: Expr) -> (Expr, Expr) { - crate::atoms::affine::broadcast_exprs(lhs, rhs) + crate::atoms::broadcast::broadcast_exprs(lhs, rhs) } /// Create a SOC constraint: ||x||_2 <= t. From 311ec4ed797c9d61f09276ee8fc9ec9feec66b89 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 23:39:05 +0200 Subject: [PATCH 09/19] Update canonicalizer and tests --- src/canon/canonicalizer.rs | 127 +++++++++++++++++++++++++++++++++++-- 1 file changed, 121 insertions(+), 6 deletions(-) diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index 436ecc4..b083e35 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -513,9 +513,10 @@ impl CanonContext { }) } - fn canonicalize_sum_axis_lin(&self, x: &LinExpr, axis: usize) -> CanonExpr { + fn canonicalize_sum_axis_lin(&mut self, x: &LinExpr, axis: usize) -> CanonExpr { if x.shape.ndim() <= 1 { - return CanonExpr::Linear(x.clone()); + assert_eq!(axis, 0, "axis {} out of bounds for shape {}", axis, x.shape); + return self.canonicalize_sum_lin(x); } let rows = x.shape.rows(); @@ -523,7 +524,7 @@ impl CanonContext { let (out_size, mut s_rows, mut s_cols, mut s_vals) = match axis { 0 => (cols, Vec::new(), Vec::new(), Vec::new()), 1 => (rows, Vec::new(), Vec::new(), Vec::new()), - _ => return CanonExpr::Linear(x.clone()), + _ => panic!("axis {} out of bounds for shape {}", axis, x.shape), }; for col in 0..cols { @@ -559,6 +560,13 @@ impl CanonContext { fn canonicalize_reshape(&mut self, a: &Expr, shape: &Shape) -> CanonExpr { let ca = self.canonicalize_expr(a, false).as_linear().clone(); + assert_eq!( + ca.shape.size(), + shape.size(), + "cannot reshape size {} into shape {}", + ca.shape.size(), + shape + ); // Reshape doesn't change the linear structure, just the shape interpretation CanonExpr::Linear(LinExpr { coeffs: ca.coeffs, @@ -651,6 +659,11 @@ impl CanonContext { } fn vstack_lin(&self, a: &LinExpr, b: &LinExpr) -> LinExpr { + assert_eq!( + a.shape.cols(), + b.shape.cols(), + "vstack requires matching column counts" + ); // Stack constants vertically let new_const = stack_vertical(&a.constant, &b.constant); let new_shape = Shape::matrix(new_const.nrows(), new_const.ncols()); @@ -699,6 +712,11 @@ impl CanonContext { } fn hstack_lin(&self, a: &LinExpr, b: &LinExpr) -> LinExpr { + assert_eq!( + a.shape.rows(), + b.shape.rows(), + "hstack requires matching row counts" + ); // Stack constants horizontally let new_const = stack_horizontal(&a.constant, &b.constant); let new_shape = Shape::matrix(new_const.nrows(), new_const.ncols()); @@ -1429,7 +1447,12 @@ fn dense_sparse_matmul(dense: &DMatrix, sparse: &CscMatrix) -> CscMatr } fn stack_vertical(a: &DMatrix, b: &DMatrix) -> DMatrix { - let mut result = DMatrix::zeros(a.nrows() + b.nrows(), a.ncols().max(b.ncols())); + assert_eq!( + a.ncols(), + b.ncols(), + "vstack requires matching column counts" + ); + let mut result = DMatrix::zeros(a.nrows() + b.nrows(), a.ncols()); result.view_mut((0, 0), (a.nrows(), a.ncols())).copy_from(a); result .view_mut((a.nrows(), 0), (b.nrows(), b.ncols())) @@ -1438,7 +1461,8 @@ fn stack_vertical(a: &DMatrix, b: &DMatrix) -> DMatrix { } fn stack_horizontal(a: &DMatrix, b: &DMatrix) -> DMatrix { - let mut result = DMatrix::zeros(a.nrows().max(b.nrows()), a.ncols() + b.ncols()); + assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); + let mut result = DMatrix::zeros(a.nrows(), a.ncols() + b.ncols()); result.view_mut((0, 0), (a.nrows(), a.ncols())).copy_from(a); result .view_mut((0, a.ncols()), (b.nrows(), b.ncols())) @@ -1465,7 +1489,7 @@ fn scalar_constant_value(expr: &Expr) -> Option { #[cfg(test)] mod tests { use super::*; - use crate::atoms::{matmul, promote}; + use crate::atoms::{matmul, promote, sum_axis}; use crate::expr::{constant, constant_matrix, variable}; #[test] @@ -1495,6 +1519,97 @@ mod tests { assert!(matches!(result.expr, CanonExpr::Quadratic(_)) || !result.constraints.is_empty()); } + #[test] + fn test_canonicalize_sum_axis_vector_returns_scalar() { + let x = variable(3); + let result = canonicalize(&sum_axis(&x, 0), false); + + assert_eq!(result.expr.as_linear().shape, Shape::scalar()); + } + + #[test] + #[should_panic(expected = "axis 2 out of bounds for shape (2, 3)")] + fn test_canonicalize_sum_axis_invalid_direct_expr_panics() { + let x = variable((2, 3)); + let _ = canonicalize(&Expr::Sum(Arc::new(x), Some(2)), false); + } + + #[test] + #[should_panic(expected = "cannot reshape size 3 into shape (2, 2)")] + fn test_canonicalize_reshape_size_mismatch_direct_expr_panics() { + let x = variable(3); + let expr = Expr::Reshape(Arc::new(x), Shape::matrix(2, 2)); + let _ = canonicalize(&expr, false); + } + + #[test] + #[should_panic(expected = "vstack requires matching column counts")] + fn test_canonicalize_vstack_mismatched_columns_direct_expr_panics() { + let x = variable((2, 3)); + let y = variable((3, 2)); + let expr = Expr::VStack(vec![Arc::new(x), Arc::new(y)]); + let _ = canonicalize(&expr, false); + } + + #[test] + #[should_panic(expected = "hstack requires matching row counts")] + fn test_canonicalize_hstack_mismatched_rows_direct_expr_panics() { + let x = variable((2, 3)); + let y = variable((3, 3)); + let expr = Expr::HStack(vec![Arc::new(x), Arc::new(y)]); + let _ = canonicalize(&expr, false); + } + + #[test] + #[should_panic(expected = "maximum requires at least one expression")] + fn test_canonicalize_maximum_empty_direct_expr_panics() { + let _ = canonicalize(&Expr::Maximum(Vec::new()), false); + } + + #[test] + fn test_canonicalize_maximum_direct_expr_broadcasts_scalar_to_vector() { + let x = variable(3); + let y = variable(()); + let expr = Expr::Maximum(vec![Arc::new(x), Arc::new(y)]); + let result = canonicalize(&expr, false); + + assert_eq!(result.expr.as_linear().shape, Shape::vector(3)); + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2,) and (3,)")] + fn test_canonicalize_maximum_direct_expr_incompatible_shapes_panics() { + let x = variable(2); + let y = variable(3); + let expr = Expr::Maximum(vec![Arc::new(x), Arc::new(y)]); + let _ = canonicalize(&expr, false); + } + + #[test] + #[should_panic(expected = "minimum requires at least one expression")] + fn test_canonicalize_minimum_empty_direct_expr_panics() { + let _ = canonicalize(&Expr::Minimum(Vec::new()), false); + } + + #[test] + fn test_canonicalize_minimum_direct_expr_broadcasts_scalar_to_vector() { + let x = variable(3); + let y = variable(()); + let expr = Expr::Minimum(vec![Arc::new(x), Arc::new(y)]); + let result = canonicalize(&expr, false); + + assert_eq!(result.expr.as_linear().shape, Shape::vector(3)); + } + + #[test] + #[should_panic(expected = "cannot broadcast shapes (2,) and (3,)")] + fn test_canonicalize_minimum_direct_expr_incompatible_shapes_panics() { + let x = variable(2); + let y = variable(3); + let expr = Expr::Minimum(vec![Arc::new(x), Arc::new(y)]); + let _ = canonicalize(&expr, false); + } + #[test] fn test_canonicalize_matmul_preserves_vector_result_shape() { let a = constant_matrix(vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0], 2, 3); From dd8380002e78a36d7f05e513b8285d3d1ceddb65 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 23:43:31 +0200 Subject: [PATCH 10/19] Tighten LinExpr add shape validation --- src/canon/lin_expr.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/canon/lin_expr.rs b/src/canon/lin_expr.rs index a9c21de..e705dd4 100644 --- a/src/canon/lin_expr.rs +++ b/src/canon/lin_expr.rs @@ -84,6 +84,13 @@ impl LinExpr { /// Add two linear expressions. pub fn add(&self, other: &LinExpr) -> LinExpr { + let new_shape = self.shape.broadcast(&other.shape).unwrap_or_else(|| { + panic!( + "cannot add linear expressions with shapes {} and {}", + self.shape, other.shape + ) + }); + // Optimization: if self has no coefficients, just clone other's let coeffs = if self.coeffs.is_empty() { other.coeffs.clone() @@ -116,14 +123,10 @@ impl LinExpr { let scalar = self.constant[(0, 0)]; other.constant.map(|v| v + scalar) } else { - // Incompatible shapes, just use self (will likely error later) - self.constant.clone() - }; - - let new_shape = if self.shape.size() >= other.shape.size() { - self.shape.clone() - } else { - other.shape.clone() + panic!( + "cannot add linear expression constants with shapes {} and {}", + self.shape, other.shape + ); }; LinExpr { @@ -286,6 +289,14 @@ mod tests { assert_eq!(sum.variables().len(), 2); } + #[test] + #[should_panic(expected = "cannot add linear expressions with shapes (2,) and (3,)")] + fn test_lin_expr_add_incompatible_shapes_panics() { + let e1 = LinExpr::variable(ExprId::new(), Shape::vector(2)); + let e2 = LinExpr::variable(ExprId::new(), Shape::vector(3)); + let _ = e1.add(&e2); + } + #[test] fn test_quad_expr_from_linear() { let var_id = ExprId::new(); From be945ac35fbc115e9833d20851fc98b1f3d666eb Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Tue, 9 Jun 2026 23:46:35 +0200 Subject: [PATCH 11/19] Validate expression shape fallbacks --- src/expr/expression.rs | 81 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 75 insertions(+), 6 deletions(-) diff --git a/src/expr/expression.rs b/src/expr/expression.rs index 32ae206..6c2dfa9 100644 --- a/src/expr/expression.rs +++ b/src/expr/expression.rs @@ -356,12 +356,11 @@ impl Expr { if axis.is_some() { // Sum along axis reduces that dimension let dims = a.shape(); - if dims.ndim() <= 1 { - Shape::scalar() - } else if *axis == Some(0) { - Shape::vector(dims.cols()) - } else { - Shape::vector(dims.rows()) + match (dims.ndim(), axis.unwrap()) { + (0 | 1, 0) => Shape::scalar(), + (2, 0) => Shape::vector(dims.cols()), + (2, 1) => Shape::vector(dims.rows()), + (_, axis) => panic!("axis {} out of bounds for shape {}", axis, dims), } } else { Shape::scalar() @@ -374,6 +373,13 @@ impl Expr { return Shape::scalar(); } let first = exprs[0].shape(); + for e in &exprs[1..] { + assert_eq!( + e.shape().cols(), + first.cols(), + "vstack requires matching column counts" + ); + } let total_rows: usize = exprs.iter().map(|e| e.shape().rows()).sum(); Shape::matrix(total_rows, first.cols()) } @@ -382,6 +388,13 @@ impl Expr { return Shape::scalar(); } let first = exprs[0].shape(); + for e in &exprs[1..] { + assert_eq!( + e.shape().rows(), + first.rows(), + "hstack requires matching row counts" + ); + } let total_cols: usize = exprs.iter().map(|e| e.shape().cols()).sum(); Shape::matrix(first.rows(), total_cols) } @@ -616,4 +629,60 @@ mod tests { ); assert_eq!(Expr::Sum(Arc::new(x), Some(1)).shape(), Shape::vector(2)); } + + #[test] + #[should_panic(expected = "axis 2 out of bounds for shape (2, 3)")] + fn test_sum_axis_invalid_direct_expr_shape_panics() { + let x = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(2, 3), + name: None, + nonneg: false, + nonpos: false, + }); + + let _ = Expr::Sum(Arc::new(x), Some(2)).shape(); + } + + #[test] + #[should_panic(expected = "vstack requires matching column counts")] + fn test_vstack_mismatched_columns_direct_expr_shape_panics() { + let x = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(2, 3), + name: None, + nonneg: false, + nonpos: false, + }); + let y = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(3, 2), + name: None, + nonneg: false, + nonpos: false, + }); + + let _ = Expr::VStack(vec![Arc::new(x), Arc::new(y)]).shape(); + } + + #[test] + #[should_panic(expected = "hstack requires matching row counts")] + fn test_hstack_mismatched_rows_direct_expr_shape_panics() { + let x = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(2, 3), + name: None, + nonneg: false, + nonpos: false, + }); + let y = Expr::Variable(VariableData { + id: ExprId::new(), + shape: Shape::matrix(3, 3), + name: None, + nonneg: false, + nonpos: false, + }); + + let _ = Expr::HStack(vec![Arc::new(x), Arc::new(y)]).shape(); + } } From 1aac15328dfa7f72c86efb6950cb5ddc0dbb2244 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Wed, 10 Jun 2026 19:22:27 +0200 Subject: [PATCH 12/19] Add regression tests for matrix variables --- tests/canon_tests.rs | 65 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/canon_tests.rs b/tests/canon_tests.rs index 4f11aa7..4a0335e 100644 --- a/tests/canon_tests.rs +++ b/tests/canon_tests.rs @@ -314,6 +314,71 @@ fn test_constant_vector_times_scalar_affine_broadcasts() { } } +#[test] +fn test_norm1_matrix_variable_canonicalizes_flat_aux_constraints() { + let x = variable((2, 3)); + + let sol = Problem::minimize(norm1(&x)) + .subject_to([x.ge(1.0)]) + .solve() + .expect("matrix norm1 problem should solve"); + + assert!((sol.value.unwrap() - 6.0).abs() < TOL); +} + +#[test] +fn test_norm_inf_matrix_variable_canonicalizes_flat_aux_constraints() { + let x = variable((2, 3)); + + let sol = Problem::minimize(norm_inf(&x)) + .subject_to([sum(&x).eq(6.0)]) + .solve() + .expect("matrix norm_inf problem should solve"); + + assert!((sol.value.unwrap() - 1.0).abs() < TOL); +} + +#[test] +fn test_sum_squares_matrix_argument_flattens_soc_stack() { + let x = variable((2, 3)); + + let sol = Problem::maximize(sum(&x)) + .subject_to([sum_squares(&x).le(1.0)]) + .solve() + .expect("matrix sum_squares constraint should solve"); + + assert!((sol.value.unwrap() - 6.0_f64.sqrt()).abs() < TOL); +} + +#[test] +fn test_quad_over_lin_matrix_argument_flattens_soc_stack() { + let x = variable((2, 3)); + + let sol = Problem::maximize(sum(&x)) + .subject_to([quad_over_lin(&x, &constant(1.0)).le(1.0)]) + .solve() + .expect("matrix quad_over_lin constraint should solve"); + + assert!((sol.value.unwrap() - 6.0_f64.sqrt()).abs() < TOL); +} + +#[test] +fn test_maximum_minimum_accept_vector_and_column_shapes() { + let x = variable(3); + let y = variable((3, 1)); + + assert_eq!( + maximum(vec![x.clone(), y.clone()]).shape(), + Shape::vector(3) + ); + assert_eq!( + minimum(vec![x.clone(), y.clone()]).shape(), + Shape::vector(3) + ); + assert_eq!(max2(&x, &y).shape(), Shape::vector(3)); + assert_eq!(min2(&x, &y).shape(), Shape::vector(3)); +} + fn solution_value(sol: &Solution, expr: &Expr) -> f64 { expr.value(sol).as_scalar().expect("expected scalar") } From ec6602b5cc017c92458e8ad60e0473017e66539a Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Wed, 10 Jun 2026 19:34:03 +0200 Subject: [PATCH 13/19] Add regression tests for column major vstack --- tests/canon_tests.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/canon_tests.rs b/tests/canon_tests.rs index 4a0335e..cfd4a89 100644 --- a/tests/canon_tests.rs +++ b/tests/canon_tests.rs @@ -379,6 +379,38 @@ fn test_maximum_minimum_accept_vector_and_column_shapes() { assert_eq!(min2(&x, &y).shape(), Shape::vector(3)); } +#[test] +fn test_vstack_matrix_constraint_uses_column_major_interleaving() { + let x = variable((2, 2)); + let y = variable((1, 2)); + let target = constant_dmatrix(DMatrix::from_row_slice( + 3, + 2, + &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], + )); + + let sol = Problem::minimize(sum(&x)) + .subject_to([vstack(vec![x.clone(), y.clone()]).eq(target)]) + .solve() + .expect("matrix vstack equality should solve"); + + if let Array::Dense(x_vals) = x.value(&sol) { + assert!((x_vals[(0, 0)] - 1.0).abs() < TOL); + assert!((x_vals[(1, 0)] - 2.0).abs() < TOL); + assert!((x_vals[(0, 1)] - 4.0).abs() < TOL); + assert!((x_vals[(1, 1)] - 5.0).abs() < TOL); + } else { + panic!("expected dense matrix solution for x"); + } + + if let Array::Dense(y_vals) = y.value(&sol) { + assert!((y_vals[(0, 0)] - 3.0).abs() < TOL); + assert!((y_vals[(0, 1)] - 6.0).abs() < TOL); + } else { + panic!("expected dense matrix solution for y"); + } +} + fn solution_value(sol: &Solution, expr: &Expr) -> f64 { expr.value(sol).as_scalar().expect("expected scalar") } From 1b143937479644c7423a416b0ea4c332e6a745aa Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Wed, 10 Jun 2026 20:05:44 +0200 Subject: [PATCH 14/19] Replace assert_eq! with debug_assert_eq! --- src/canon/canonicalizer.rs | 8 ++++---- src/sparse.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index b083e35..4d54e06 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -659,7 +659,7 @@ impl CanonContext { } fn vstack_lin(&self, a: &LinExpr, b: &LinExpr) -> LinExpr { - assert_eq!( + debug_assert_eq!( a.shape.cols(), b.shape.cols(), "vstack requires matching column counts" @@ -712,7 +712,7 @@ impl CanonContext { } fn hstack_lin(&self, a: &LinExpr, b: &LinExpr) -> LinExpr { - assert_eq!( + debug_assert_eq!( a.shape.rows(), b.shape.rows(), "hstack requires matching row counts" @@ -1447,7 +1447,7 @@ fn dense_sparse_matmul(dense: &DMatrix, sparse: &CscMatrix) -> CscMatr } fn stack_vertical(a: &DMatrix, b: &DMatrix) -> DMatrix { - assert_eq!( + debug_assert_eq!( a.ncols(), b.ncols(), "vstack requires matching column counts" @@ -1461,7 +1461,7 @@ fn stack_vertical(a: &DMatrix, b: &DMatrix) -> DMatrix { } fn stack_horizontal(a: &DMatrix, b: &DMatrix) -> DMatrix { - assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); + debug_assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); let mut result = DMatrix::zeros(a.nrows(), a.ncols() + b.ncols()); result.view_mut((0, 0), (a.nrows(), a.ncols())).copy_from(a); result diff --git a/src/sparse.rs b/src/sparse.rs index 402af77..da4cfca 100644 --- a/src/sparse.rs +++ b/src/sparse.rs @@ -67,7 +67,7 @@ pub fn csc_to_dense(sparse: &CscMatrix) -> DMatrix { /// Stack two CSC matrices vertically. pub fn csc_vstack(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { - assert_eq!( + debug_assert_eq!( a.ncols(), b.ncols(), "vstack requires matching column counts" @@ -125,7 +125,7 @@ pub fn sparse_dense_matmul(a: &CscMatrix, b: &DMatrix) -> CscMatrix, b: &CscMatrix) -> CscMatrix { - assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); + debug_assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); let mut rows = Vec::new(); let mut cols = Vec::new(); let mut vals = Vec::new(); From 098cd925f89219e9aa528ed0e7bc769ed5ad01f1 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Wed, 10 Jun 2026 20:14:57 +0200 Subject: [PATCH 15/19] Fix matrix-shaped norm/quadratic canonicalization --- src/canon/canonicalizer.rs | 16 +++++++ src/canon/lin_expr.rs | 88 +++++++++++++++++++++++++++++++++----- 2 files changed, 93 insertions(+), 11 deletions(-) diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index 4d54e06..1c1387a 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -847,6 +847,7 @@ impl CanonContext { // Introduce t_i >= 0, -t_i <= x_i <= t_i // Then ||x||_1 = sum(t_i) let cx = self.canonicalize_expr(x, false).as_linear().clone(); + let cx = self.flatten_lin(&cx); let size = cx.size(); let (_, t) = self.new_nonneg_aux_var(Shape::vector(size)); @@ -880,6 +881,7 @@ impl CanonContext { // ||x||_inf = max(|x_i|) // Introduce t >= 0, -t <= x_i <= t for all i let cx = self.canonicalize_expr(x, false).as_linear().clone(); + let cx = self.flatten_lin(&cx); let size = cx.size(); let (_, t) = self.new_nonneg_aux_var(Shape::scalar()); @@ -1002,6 +1004,7 @@ impl CanonContext { fn canonicalize_sum_squares(&mut self, x: &Expr, for_objective: bool) -> CanonExpr { // ||x||_2^2 = x' x let cx = self.canonicalize_expr(x, false).as_linear().clone(); + let cx = self.flatten_lin(&cx); self.canonicalize_sum_squares_lin(&cx, for_objective) } @@ -1059,6 +1062,7 @@ impl CanonContext { fn canonicalize_quad_over_lin(&mut self, x: &Expr, y: &Expr) -> CanonExpr { // ||x||_2^2 / y: introduce t with ||[2x; t-y]||_2 <= t+y. let cx = self.canonicalize_expr(x, false).as_linear().clone(); + let cx = self.flatten_lin(&cx); let cy = self.canonicalize_expr(y, false).as_linear().clone(); let (_, t) = self.new_nonneg_aux_var(Shape::scalar()); @@ -1098,6 +1102,18 @@ impl CanonContext { }) } + fn flatten_lin(&self, x: &LinExpr) -> LinExpr { + let size = x.size(); + LinExpr { + coeffs: x.coeffs.clone(), + constant: x + .constant + .clone() + .reshape_generic(nalgebra::Dyn(size), nalgebra::Dyn(1)), + shape: Shape::vector(size), + } + } + fn expand_scalar(&self, scalar: &LinExpr, size: usize) -> LinExpr { // Expand a scalar to a vector by repeating let ones = DMatrix::from_element(size, 1, 1.0); diff --git a/src/canon/lin_expr.rs b/src/canon/lin_expr.rs index e705dd4..089f5ef 100644 --- a/src/canon/lin_expr.rs +++ b/src/canon/lin_expr.rs @@ -84,27 +84,54 @@ impl LinExpr { /// Add two linear expressions. pub fn add(&self, other: &LinExpr) -> LinExpr { - let new_shape = self.shape.broadcast(&other.shape).unwrap_or_else(|| { - panic!( - "cannot add linear expressions with shapes {} and {}", - self.shape, other.shape - ) - }); + let new_shape = + if self.shape.rows() == other.shape.rows() && self.shape.cols() == other.shape.cols() { + if self.shape.is_vector() { + self.shape.clone() + } else if other.shape.is_vector() { + other.shape.clone() + } else { + self.shape.clone() + } + } else if self.shape.is_scalar_like() && !other.shape.is_scalar_like() { + other.shape.clone() + } else if other.shape.is_scalar_like() && !self.shape.is_scalar_like() { + self.shape.clone() + } else if self.shape.size() == other.shape.size() { + Shape::vector(self.shape.size()) + } else { + panic!( + "cannot add linear expressions with shapes {} and {}", + self.shape, other.shape + ) + }; + let new_size = new_shape.size(); // Optimization: if self has no coefficients, just clone other's let coeffs = if self.coeffs.is_empty() { - other.coeffs.clone() + other + .coeffs + .iter() + .map(|(var_id, coeff)| (*var_id, coeff_for_size(coeff, other.size(), new_size))) + .collect() } else if other.coeffs.is_empty() { - self.coeffs.clone() + self.coeffs + .iter() + .map(|(var_id, coeff)| (*var_id, coeff_for_size(coeff, self.size(), new_size))) + .collect() } else { // Both have coefficients - clone the larger one and merge smaller into it - let mut coeffs = self.coeffs.clone(); + let mut coeffs = HashMap::new(); + for (var_id, coeff) in &self.coeffs { + coeffs.insert(*var_id, coeff_for_size(coeff, self.size(), new_size)); + } coeffs.reserve(other.coeffs.len()); for (var_id, coeff) in &other.coeffs { + let coeff = coeff_for_size(coeff, other.size(), new_size); coeffs .entry(*var_id) - .and_modify(|c| *c = csc_add(c, coeff)) - .or_insert_with(|| coeff.clone()); + .and_modify(|c| *c = csc_add(c, &coeff)) + .or_insert(coeff); } coeffs }; @@ -122,6 +149,16 @@ impl LinExpr { // Broadcast scalar self to match other's shape let scalar = self.constant[(0, 0)]; other.constant.map(|v| v + scalar) + } else if self.shape.size() == other.shape.size() { + let self_flat = self + .constant + .clone() + .reshape_generic(nalgebra::Dyn(new_size), nalgebra::Dyn(1)); + let other_flat = other + .constant + .clone() + .reshape_generic(nalgebra::Dyn(new_size), nalgebra::Dyn(1)); + self_flat + other_flat } else { panic!( "cannot add linear expression constants with shapes {} and {}", @@ -168,6 +205,35 @@ impl LinExpr { } } +fn coeff_for_size( + coeff: &CscMatrix, + source_size: usize, + target_size: usize, +) -> CscMatrix { + if coeff.nrows() == target_size { + return coeff.clone(); + } + if source_size == 1 { + let mut rows = Vec::new(); + let mut cols = Vec::new(); + let mut vals = Vec::new(); + for (row, col, val) in coeff.triplet_iter() { + debug_assert_eq!(row, 0); + for target_row in 0..target_size { + rows.push(target_row); + cols.push(col); + vals.push(*val); + } + } + return crate::sparse::triplets_to_csc(target_size, coeff.ncols(), &rows, &cols, &vals); + } + panic!( + "cannot resize coefficient rows from {} to {}", + coeff.nrows(), + target_size + ); +} + /// A quadratic expression: (1/2) x' P x + q' x + r /// /// Used for quadratic objectives in QP problems. From 6c602425766128310f0f6eee254417b0e49fe2bb Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Wed, 10 Jun 2026 20:15:49 +0200 Subject: [PATCH 16/19] Fix canonicalize vstack_lin --- src/canon/canonicalizer.rs | 54 ++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index 1c1387a..36aee20 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -677,14 +677,14 @@ impl CanonContext { let ca = a.coeffs.get(&var_id); let cb = b.coeffs.get(&var_id); let stacked = match (ca, cb) { - (Some(ma), Some(mb)) => stack_csc_vertical(ma, mb), + (Some(ma), Some(mb)) => stack_csc_vertical_for_shapes(ma, mb, &a.shape, &b.shape), (Some(ma), None) => { let zeros = CscMatrix::zeros(b.size(), ma.ncols()); - stack_csc_vertical(ma, &zeros) + stack_csc_vertical_for_shapes(ma, &zeros, &a.shape, &b.shape) } (None, Some(mb)) => { let zeros = CscMatrix::zeros(a.size(), mb.ncols()); - stack_csc_vertical(&zeros, mb) + stack_csc_vertical_for_shapes(&zeros, mb, &a.shape, &b.shape) } (None, None) => continue, }; @@ -1486,8 +1486,52 @@ fn stack_horizontal(a: &DMatrix, b: &DMatrix) -> DMatrix { result } -fn stack_csc_vertical(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { - csc_vstack(a, b) +fn stack_csc_vertical_for_shapes( + a: &CscMatrix, + b: &CscMatrix, + a_shape: &Shape, + b_shape: &Shape, +) -> CscMatrix { + debug_assert_eq!( + a_shape.cols(), + b_shape.cols(), + "vstack requires matching column counts" + ); + debug_assert_eq!(a.ncols(), b.ncols()); + + let a_rows = a_shape.rows(); + let b_rows = b_shape.rows(); + let cols = a_shape.cols(); + let out_rows = a_rows + b_rows; + + let mut rows = Vec::new(); + let mut col_indices = Vec::new(); + let mut vals = Vec::new(); + + for (r, c, v) in a.triplet_iter() { + let matrix_col = r / a_rows; + let matrix_row = r % a_rows; + debug_assert!(matrix_col < cols); + rows.push(matrix_row + matrix_col * out_rows); + col_indices.push(c); + vals.push(*v); + } + for (r, c, v) in b.triplet_iter() { + let matrix_col = r / b_rows; + let matrix_row = r % b_rows; + debug_assert!(matrix_col < cols); + rows.push(a_rows + matrix_row + matrix_col * out_rows); + col_indices.push(c); + vals.push(*v); + } + + crate::sparse::triplets_to_csc( + a_shape.size() + b_shape.size(), + a.ncols(), + &rows, + &col_indices, + &vals, + ) } fn repeat_rows_csc(m: &CscMatrix, times: usize) -> CscMatrix { From 703c8452748da0b3f6385bf2051681ec37ca6ce6 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Wed, 10 Jun 2026 20:27:47 +0200 Subject: [PATCH 17/19] Unify binary and n-ary elementwise broadcasting Route binary broadcasting through the n-ary helper and treat shapes with matching effective rows/cols as already compatible. This lets maximum and minimum accept mixed (n,) and (n,1) arguments consistently with binary ops and constraints. --- src/atoms/broadcast.rs | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/src/atoms/broadcast.rs b/src/atoms/broadcast.rs index cec4d95..bbbf5f2 100644 --- a/src/atoms/broadcast.rs +++ b/src/atoms/broadcast.rs @@ -4,29 +4,16 @@ use crate::atoms::affine::{matmul, promote}; use crate::expr::{Expr, Shape, ones}; pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) { - let lhs_shape = lhs.shape(); - let rhs_shape = rhs.shape(); - let target_shape = lhs_shape - .broadcast(&rhs_shape) - .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", lhs_shape, rhs_shape)); - - if lhs_shape == target_shape && rhs_shape == target_shape { - return (lhs, rhs); - } - if lhs_shape.rows() == rhs_shape.rows() && lhs_shape.cols() == rhs_shape.cols() { - return (lhs, rhs); - } - - let lhs = broadcast_to(lhs, &lhs_shape, &target_shape) - .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", lhs_shape, target_shape)); - let rhs = broadcast_to(rhs, &rhs_shape, &target_shape) - .unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", rhs_shape, target_shape)); - + let mut exprs = broadcast_elementwise_exprs([lhs, rhs]).1; + let rhs = exprs.pop().expect("binary broadcast should return rhs"); + let lhs = exprs.pop().expect("binary broadcast should return lhs"); (lhs, rhs) } pub(crate) fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option { - if expr_shape == target_shape { + if expr_shape == target_shape + || (expr_shape.rows() == target_shape.rows() && expr_shape.cols() == target_shape.cols()) + { return Some(expr); } From 3377fc1749cbb111b408e9fcef30937f68a71573 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Wed, 10 Jun 2026 20:38:43 +0200 Subject: [PATCH 18/19] Optimize canonicalize_minimum/maximum performance --- src/canon/canonicalizer.rs | 64 +++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index 36aee20..fd48fae 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -941,10 +941,38 @@ impl CanonContext { if exprs.is_empty() { panic!("maximum requires at least one expression"); } - let (shape, exprs) = - broadcast_elementwise_exprs(exprs.iter().map(|expr| expr.as_ref().clone())); + let first_shape = exprs[0].shape(); + let all_same_shape = exprs.iter().all(|expr| expr.shape() == first_shape); + let shape = if all_same_shape { + first_shape + } else { + exprs + .iter() + .map(|expr| expr.shape()) + .reduce(|acc, shape| { + acc.broadcast(&shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) + }) + .expect("maximum requires at least one expression") + }; let (_, t) = self.new_aux_var(shape); + if all_same_shape { + for e in exprs { + let ce = self + .canonicalize_expr(e.as_ref(), false) + .as_linear() + .clone(); + // t >= x_i, i.e., t - x_i >= 0 + self.constraints.push(ConeConstraint::NonNeg { + a: t.add(&ce.neg()), + }); + } + return CanonExpr::Linear(t); + } + + let (_, exprs) = + broadcast_elementwise_exprs(exprs.iter().map(|expr| expr.as_ref().clone())); for e in &exprs { let ce = self.canonicalize_expr(e, false).as_linear().clone(); // t >= x_i, i.e., t - x_i >= 0 @@ -961,10 +989,38 @@ impl CanonContext { if exprs.is_empty() { panic!("minimum requires at least one expression"); } - let (shape, exprs) = - broadcast_elementwise_exprs(exprs.iter().map(|expr| expr.as_ref().clone())); + let first_shape = exprs[0].shape(); + let all_same_shape = exprs.iter().all(|expr| expr.shape() == first_shape); + let shape = if all_same_shape { + first_shape + } else { + exprs + .iter() + .map(|expr| expr.shape()) + .reduce(|acc, shape| { + acc.broadcast(&shape) + .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) + }) + .expect("minimum requires at least one expression") + }; let (_, t) = self.new_aux_var(shape); + if all_same_shape { + for e in exprs { + let ce = self + .canonicalize_expr(e.as_ref(), false) + .as_linear() + .clone(); + // t <= x_i, i.e., x_i - t >= 0 + self.constraints.push(ConeConstraint::NonNeg { + a: ce.add(&t.neg()), + }); + } + return CanonExpr::Linear(t); + } + + let (_, exprs) = + broadcast_elementwise_exprs(exprs.iter().map(|expr| expr.as_ref().clone())); for e in &exprs { let ce = self.canonicalize_expr(e, false).as_linear().clone(); // t <= x_i, i.e., x_i - t >= 0 From e78578bc9127259a2ebe388faa06900e8149c0b9 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Wed, 10 Jun 2026 23:21:37 +0200 Subject: [PATCH 19/19] Revert assert_eq! checks --- src/canon/canonicalizer.rs | 8 ++++---- src/sparse.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index fd48fae..645e3f4 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -659,7 +659,7 @@ impl CanonContext { } fn vstack_lin(&self, a: &LinExpr, b: &LinExpr) -> LinExpr { - debug_assert_eq!( + assert_eq!( a.shape.cols(), b.shape.cols(), "vstack requires matching column counts" @@ -712,7 +712,7 @@ impl CanonContext { } fn hstack_lin(&self, a: &LinExpr, b: &LinExpr) -> LinExpr { - debug_assert_eq!( + assert_eq!( a.shape.rows(), b.shape.rows(), "hstack requires matching row counts" @@ -1519,7 +1519,7 @@ fn dense_sparse_matmul(dense: &DMatrix, sparse: &CscMatrix) -> CscMatr } fn stack_vertical(a: &DMatrix, b: &DMatrix) -> DMatrix { - debug_assert_eq!( + assert_eq!( a.ncols(), b.ncols(), "vstack requires matching column counts" @@ -1533,7 +1533,7 @@ fn stack_vertical(a: &DMatrix, b: &DMatrix) -> DMatrix { } fn stack_horizontal(a: &DMatrix, b: &DMatrix) -> DMatrix { - debug_assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); + assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); let mut result = DMatrix::zeros(a.nrows(), a.ncols() + b.ncols()); result.view_mut((0, 0), (a.nrows(), a.ncols())).copy_from(a); result diff --git a/src/sparse.rs b/src/sparse.rs index da4cfca..402af77 100644 --- a/src/sparse.rs +++ b/src/sparse.rs @@ -67,7 +67,7 @@ pub fn csc_to_dense(sparse: &CscMatrix) -> DMatrix { /// Stack two CSC matrices vertically. pub fn csc_vstack(a: &CscMatrix, b: &CscMatrix) -> CscMatrix { - debug_assert_eq!( + assert_eq!( a.ncols(), b.ncols(), "vstack requires matching column counts" @@ -125,7 +125,7 @@ pub fn sparse_dense_matmul(a: &CscMatrix, b: &DMatrix) -> CscMatrix, b: &CscMatrix) -> CscMatrix { - debug_assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); + assert_eq!(a.nrows(), b.nrows(), "hstack requires matching row counts"); let mut rows = Vec::new(); let mut cols = Vec::new(); let mut vals = Vec::new();