Improve shape validation and broadcasting in atoms and canonicalization #13
… impossible CSC failure” fallbacks with explicit expectations.
…casting for minimum and maximum atoms
SteveDiamond
left a comment
Thanks — failing loudly instead of silently returning wrong results is the right direction. However, three of the new checks fire on inputs that solve correctly on main, so this PR currently breaks norm1/norm_inf and sum_squares/quad_over_lin on matrix variables, and maximum/minimum mixing (n,) with (n,1). I reproduced all three on this branch vs. the merge base (details in the inline comments). The test suite passes because it never exercises these atoms with matrix-shaped arguments — worth adding matrix-variable tests for norm1, norm_inf, sum_squares, quad_over_lin, and mixed-shape maximum along with the fixes.
One smaller maintainability note: the same vstack/hstack check now exists at four layers (atom constructor, Expr::shape(), vstack_lin/stack_vertical, csc_vstack), and the stack_vertical/stack_horizontal copies are unreachable behind the vstack_lin asserts. Long-term it may be easier to keep one authoritative validation point and debug_assert! in the inner layers.
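As a rough illustration of the "one authoritative validation point" idea, the sketch below keeps the full check in a single function and has the inner layer re-check only in debug builds. `Dims`, `validate_vstack`, `stack_rows_unchecked`, and `vstack_shape` are hypothetical names for this example, not the crate's actual types or functions.

```rust
/// Hypothetical 2-D shape used only for this sketch (not the crate's `Shape`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Dims {
    pub rows: usize,
    pub cols: usize,
}

/// The single authoritative check: returns the stacked shape or an error message.
pub fn validate_vstack(parts: &[Dims]) -> Result<Dims, String> {
    let first = parts
        .first()
        .ok_or_else(|| "vstack requires at least one argument".to_string())?;
    let mut rows = 0;
    for p in parts {
        if p.cols != first.cols {
            return Err(format!(
                "vstack requires matching column counts: {} vs {}",
                first.cols, p.cols
            ));
        }
        rows += p.rows;
    }
    Ok(Dims { rows, cols: first.cols })
}

/// Inner layer: trusts its caller and repeats the check only in debug builds.
fn stack_rows_unchecked(parts: &[Dims]) -> Dims {
    debug_assert!(
        validate_vstack(parts).is_ok(),
        "caller must validate vstack shapes"
    );
    Dims {
        rows: parts.iter().map(|p| p.rows).sum(),
        cols: parts[0].cols,
    }
}

/// Public entry point: validate once, then call the unchecked inner layer.
pub fn vstack_shape(parts: &[Dims]) -> Result<Dims, String> {
    validate_vstack(parts)?;
    Ok(stack_rows_unchecked(parts))
}
```

With this layout the user-facing error message lives in one place, and release builds do not pay for the duplicated inner checks.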
/// Add two linear expressions.
pub fn add(&self, other: &LinExpr) -> LinExpr {
    let new_shape = self.shape.broadcast(&other.shape).unwrap_or_else(|| {
This check breaks norm1 and norm_inf on matrix expressions. canonicalize_norm1/canonicalize_norm_inf build the aux variable t with flat shape (size,) and then call t.add(&cx.neg()) where cx keeps its matrix shape — Shape::broadcast((6,), (2,3)) returns None, so this panics. Both sides store coefficients/constants column-flat, so the addition itself was always well-defined.
Repro (solves to 6.0 on main, panics here with cannot add linear expressions with shapes (6,) and (2, 3)):
let x = variable((2, 3));
let prob = Problem::minimize(norm1(&x))
    .subject_to(vec![constraint!(x >= 1.0)])
    .build();
prob.solve()?;
Separately, Shape::broadcast is also wider than what this function implements: it accepts e.g. (1,4)+(3,4) and (3,1)+(3,3), but the body only handles equal dims or 1×1-scalar promotion. Those pairs pass this guard and then panic later in the constants branch (with the misleading "cannot add linear expression constants" message) or inside csc_add. Since LinExpr::add never actually broadcasts coefficients, the predicate here should be exactly what the body supports: equal dims, scalar-like, or equal flattened size (which also fixes the norm1/norm_inf case).
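A minimal sketch of such a predicate, assuming a simplified stand-in for the shape type. The `Shape` enum and `add_result_shape` here are hypothetical, and the choice of which shape the result takes when only the flattened sizes agree is an assumption of this example, not something the PR specifies.

```rust
/// Hypothetical stand-in for the crate's shape type: 1-D `(n,)` or 2-D `(r, c)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Shape {
    Vector(usize),
    Matrix(usize, usize),
}

impl Shape {
    /// Number of entries after flattening.
    pub fn size(&self) -> usize {
        match *self {
            Shape::Vector(n) => n,
            Shape::Matrix(r, c) => r * c,
        }
    }

    /// True for shapes holding exactly one entry, e.g. `(1,)` or `(1, 1)`.
    pub fn is_scalar_like(&self) -> bool {
        self.size() == 1
    }
}

/// Returns the result shape when addition needs no coefficient broadcasting:
/// equal shapes, one scalar-like operand, or equal flattened size.
pub fn add_result_shape(a: Shape, b: Shape) -> Option<Shape> {
    if a == b {
        return Some(a);
    }
    if a.is_scalar_like() {
        return Some(b);
    }
    if b.is_scalar_like() {
        return Some(a);
    }
    if a.size() == b.size() {
        // Assumption: prefer the matrix shape over the flat one. Two vectors of
        // equal size were already handled by the `a == b` branch above.
        return Some(match (a, b) {
            (Shape::Vector(_), other) => other,
            (matrix, _) => matrix,
        });
    }
    None
}
```

Under this rule `(6,)` plus `(2, 3)` is accepted, while `(1, 4)` plus `(3, 4)` and `(3, 1)` plus `(3, 3)` are rejected at the guard rather than further down. The equal-size branch also admits pairs such as `(2, 3)` and `(3, 2)`; whether that should be allowed is a separate decision.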
assert_eq!(
    a.shape.cols(),
    b.shape.cols(),
    "vstack requires matching column counts"
This assert breaks sum_squares and quad_over_lin on matrix arguments. The SOC constructions in canonicalize_sum_squares and canonicalize_quad_over_lin vstack a scalar LinExpr (t - 1, cols=1) with the matrix-shaped argument (cols=n). The coefficient stacking operates on flattened (size × var_size) blocks and was always consistent — only the constant DMatrix stacking relied on the old max() widening.
Repro (solves to √6 ≈ 2.449 on main, panics here with vstack requires matching column counts: left: 1, right: 3):
let x = variable((2, 3));
let prob = Problem::maximize(sum(&x))
    .subject_to(vec![constraint!((sum_squares(&x)) <= 1.0)])
    .build();
prob.solve()?;
One option is to have the SOC paths flatten their arguments before stacking, so the assert can stay strict for genuine user-facing vstacks.
Also a heads-up while hardening this function: for inputs with more than one column, vstack_lin's plain row-concatenation of coefficient blocks is incorrect under the column-major flattening convention (rows need per-column interleaving — hstack_lin's comment notes the convention). I verified with vstack([x_{2×2}, y_{1×2}]) == known 3×2 constant: the solver returns x = [[1,3],[2,4]] instead of [[1,4],[2,5]]. That bug predates this PR, but the new assert makes these inputs look validated.
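To make the interleaving concrete, here is a small sketch on plain value vectors rather than coefficient blocks. `vstack_col_major` is a hypothetical helper, not a function in this crate; the same per-column ordering would apply to the rows of the stacked coefficient matrix.

```rust
/// Vertically stack two matrices stored as column-major flat slices.
/// `a` is `a_rows x cols`, `b` is `b_rows x cols`; the result is
/// `(a_rows + b_rows) x cols`, also column-major.
pub fn vstack_col_major(
    a: &[f64],
    a_rows: usize,
    b: &[f64],
    b_rows: usize,
    cols: usize,
) -> Vec<f64> {
    assert_eq!(a.len(), a_rows * cols, "`a` has the wrong length");
    assert_eq!(b.len(), b_rows * cols, "`b` has the wrong length");
    let mut out = Vec::with_capacity(a.len() + b.len());
    for j in 0..cols {
        // Column j of the result is column j of `a` followed by column j of `b`.
        out.extend_from_slice(&a[j * a_rows..(j + 1) * a_rows]);
        out.extend_from_slice(&b[j * b_rows..(j + 1) * b_rows]);
    }
    out
}
```

For the example above, x = [[1,4],[2,5]] is stored as [1, 2, 4, 5] and y = [[3,6]] as [3, 6]; interleaving per column yields [1, 2, 3, 4, 5, 6], the column-major form of the 3×2 constant. Plain concatenation would instead pair the first four entries of the constant with x, which matches the incorrect x = [[1,3],[2,4]] result. With a single column the two orderings coincide, which is presumably why the existing vector tests pass.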
    broadcast_2d_to(expr, expr_shape, target_shape)
}

pub(crate) fn broadcast_elementwise_exprs(
broadcast_elementwise_exprs panics on shapes that the binary broadcast_exprs (and therefore +, -, and constraints) accepts: mixing (n,) with (n,1). Shape::broadcast maps the pair to (n,), but broadcast_to then fails because broadcast_2d_to requires the target to be 2-D — the binary version tolerates this via its rows/cols-equal early return above, which this function omits.
Repro (works on main, panics here with cannot broadcast shape (3, 1) to (3,)):
let x = variable(3);
let y = variable((3, 1));
let m = maximum(vec![x, y]); // also minimum, max2, min2
Suggestion: move the rows/cols-equivalence tolerance into broadcast_to's "already matches target" check, and make broadcast_exprs a thin wrapper over this n-ary function. That fixes the regression and keeps the two entry points from drifting apart.
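The suggestion could look roughly like the sketch below, which works on shapes only. `Shape`, `layout_matches`, `merge`, `broadcast_shapes`, and `broadcast_pair` are hypothetical names; treating `(n,)` as having n rows and one column follows the rows/cols equivalence described above, and keeping the earlier shape when two layouts match is an assumption of this example.

```rust
/// Hypothetical stand-in for the crate's shape type: 1-D `(n,)` or 2-D `(r, c)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Shape {
    Vector(usize),
    Matrix(usize, usize),
}

impl Shape {
    /// Effective (rows, cols); `(n,)` is treated as n rows and one column.
    pub fn dims(self) -> (usize, usize) {
        match self {
            Shape::Vector(n) => (n, 1),
            Shape::Matrix(r, c) => (r, c),
        }
    }
}

/// "Already matches target" check that tolerates `(n,)` versus `(n, 1)`.
pub fn layout_matches(src: Shape, target: Shape) -> bool {
    src.dims() == target.dims()
}

/// Merge one dimension: equal sizes, or one side of size 1.
fn merge_dim(x: usize, y: usize) -> Option<usize> {
    if x == y || y == 1 {
        Some(x)
    } else if x == 1 {
        Some(y)
    } else {
        None
    }
}

/// Merge two shapes; matching layouts keep the earlier shape unchanged.
fn merge(a: Shape, b: Shape) -> Option<Shape> {
    if layout_matches(a, b) {
        return Some(a);
    }
    let ((ar, ac), (br, bc)) = (a.dims(), b.dims());
    Some(Shape::Matrix(merge_dim(ar, br)?, merge_dim(ac, bc)?))
}

/// n-ary entry point: common shape of all arguments, or `None` if incompatible.
pub fn broadcast_shapes(shapes: &[Shape]) -> Option<Shape> {
    let (first, rest) = shapes.split_first()?;
    rest.iter().try_fold(*first, |acc, s| merge(acc, *s))
}

/// Binary entry point as a thin wrapper over the n-ary one.
pub fn broadcast_pair(a: Shape, b: Shape) -> Option<Shape> {
    broadcast_shapes(&[a, b])
}
```

Because the binary function only delegates, any tolerance added to the n-ary path is picked up by `+`, `-`, and constraints automatically.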
let shape = exprs[0].shape();
let (shape, exprs) =
    broadcast_elementwise_exprs(exprs.iter().map(|expr| expr.as_ref().clone()));
Minor: this deep-clones every argument tree out of its Arc (Constant variants copy their dense arrays) and re-runs broadcasting on every canonicalization, but the maximum()/minimum() constructors already broadcast at build time — so for anything built through the public API this is a no-op that still pays O(total tree size) clones per solve. Consider checking whether all exprs[i].shape() already match first and iterating the Arc slice directly in that (common) case, only cloning when a shape actually needs rewriting.
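One way to express that fast path is with `std::borrow::Cow`, borrowing the `Arc` slice when nothing needs rewriting. This is only a sketch: `Expr`, `prepare_args`, and `broadcast_scalar` are hypothetical stand-ins, and the slow path here handles just the 1×1 case to keep the example short.

```rust
use std::borrow::Cow;
use std::sync::Arc;

/// Hypothetical expression node for this sketch: a shape plus dense data.
#[derive(Clone, Debug, PartialEq)]
pub struct Expr {
    pub shape: (usize, usize),
    pub data: Vec<f64>,
}

/// Expand a 1x1 expression to `target` by repeating its single value.
fn broadcast_scalar(e: &Expr, target: (usize, usize)) -> Expr {
    assert_eq!(e.shape, (1, 1), "this sketch only broadcasts 1x1 expressions");
    Expr {
        shape: target,
        data: vec![e.data[0]; target.0 * target.1],
    }
}

/// Borrow the argument slice when every shape already equals `target`;
/// otherwise build a new vector, sharing the `Arc`s that need no rewrite.
pub fn prepare_args<'a>(
    args: &'a [Arc<Expr>],
    target: (usize, usize),
) -> Cow<'a, [Arc<Expr>]> {
    if args.iter().all(|e| e.shape == target) {
        // Common case: nothing to do, no clones at all.
        return Cow::Borrowed(args);
    }
    Cow::Owned(
        args.iter()
            .map(|e| {
                if e.shape == target {
                    Arc::clone(e) // bumps a reference count, no deep copy
                } else {
                    Arc::new(broadcast_scalar(e, target))
                }
            })
            .collect(),
    )
}
```

Callers can index or iterate the returned `Cow` the same way in both cases, so the canonicalizer body would not need two code paths.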
Route binary broadcasting through the n-ary helper and treat shapes with matching effective rows/cols as already compatible. This lets maximum and minimum accept mixed (n,) and (n,1) arguments consistently with binary ops and constraints.
The first push consists of regression tests which should fail now. Fixes will follow.
@SteveDiamond Thanks for the review! The issues you mentioned are fixed. Here is a summary of the changes; it would be great if you could take a look.
One note on the stack validation cleanup suggestion: I tried moving the inner stack checks to
Nice job!
This PR tightens shape checks for affine atoms, sparse stacking, linear expression addition, and canonicalizer fallback paths.
Changes
sum_axis, reshape, vstack, and hstack
broadcast.rs
maximum and minimum and canonicalization
LinExpr addition shape checks
P.S. The current fix is rather simple: it uses panic to report the error instead of silently returning a wrong result. Again, panicking on some user inputs might not be ideal in the long term. Improving the current error-handling policy is expected to be a rather large change, and I might need to think about it carefully.