-
Notifications
You must be signed in to change notification settings - Fork 3
Improve shape validation and broadcasting in atoms and canonicalization #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+843
−112
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
2c61fda
Removes sparse stack dimension widening and replaces “return zeros on…
haozhu10015 0d29cfc
Add axis index check for sum_axis
haozhu10015 cf41c4a
Add shape compatibility check for reshape
haozhu10015 a5cf1b5
Add column/row matching check for vstack/hstack
haozhu10015 bb2e3ef
Update outdated docs
haozhu10015 eae8e9c
Add number of expressions and dimension compatibility check and broad…
haozhu10015 92b8321
Update maximum/minimum atom canonicalization
haozhu10015 fc9daf6
Move broadcasting related functions to broadcast.rs
haozhu10015 311ec4e
Update canonicalizer and tests
haozhu10015 dd83800
Tighten LinExpr add shape validation
haozhu10015 be945ac
Validate expression shape fallbacks
haozhu10015 1aac153
Add regression tests for matrix variables
haozhu10015 ec6602b
Add regression tests for column major vstack
haozhu10015 1b14393
Replace assert_eq! with debug_assert_eq!
haozhu10015 098cd92
Fix matrix-shaped norm/quadratic canonicalization
haozhu10015 6c60242
Fix canonicalize vstack_lin
haozhu10015 703c845
Unify binary and n-ary elementwise broadcasting
haozhu10015 3377fc1
Optimize canonicalize_minimum/maximum performance
haozhu10015 e78578b
Revert assert_eq! checks
haozhu10015 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| //! Internal broadcasting helpers for atom construction and canonicalization. | ||
|
|
||
| use crate::atoms::affine::{matmul, promote}; | ||
| use crate::expr::{Expr, Shape, ones}; | ||
|
|
||
| pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) { | ||
| let mut exprs = broadcast_elementwise_exprs([lhs, rhs]).1; | ||
| let rhs = exprs.pop().expect("binary broadcast should return rhs"); | ||
| let lhs = exprs.pop().expect("binary broadcast should return lhs"); | ||
| (lhs, rhs) | ||
| } | ||
|
|
||
| pub(crate) fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option<Expr> { | ||
| if expr_shape == target_shape | ||
| || (expr_shape.rows() == target_shape.rows() && expr_shape.cols() == target_shape.cols()) | ||
| { | ||
| return Some(expr); | ||
| } | ||
|
|
||
| if expr_shape.is_scalar_like() { | ||
| return Some(promote(&expr, target_shape.clone())); | ||
| } | ||
|
|
||
| broadcast_2d_to(expr, expr_shape, target_shape) | ||
| } | ||
|
|
||
| pub(crate) fn broadcast_elementwise_exprs( | ||
| exprs: impl IntoIterator<Item = Expr>, | ||
| ) -> (Shape, Vec<Expr>) { | ||
| let exprs: Vec<Expr> = exprs.into_iter().collect(); | ||
| let target_shape = exprs | ||
| .iter() | ||
| .map(|expr| expr.shape()) | ||
| .reduce(|acc, shape| { | ||
| acc.broadcast(&shape) | ||
| .unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape)) | ||
| }) | ||
| .expect("elementwise atom requires at least one expression"); | ||
|
|
||
| let exprs = exprs | ||
| .into_iter() | ||
| .map(|expr| { | ||
| let expr_shape = expr.shape(); | ||
| broadcast_to(expr, &expr_shape, &target_shape).unwrap_or_else(|| { | ||
| panic!("cannot broadcast shape {} to {}", expr_shape, target_shape) | ||
| }) | ||
| }) | ||
| .collect(); | ||
|
|
||
| (target_shape, exprs) | ||
| } | ||
|
|
||
| fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option<Expr> { | ||
| if !expr_shape.is_matrix() || !target_shape.is_matrix() { | ||
| return None; | ||
| } | ||
|
|
||
| if expr_shape.rows() == 1 && target_shape.rows() > 1 && expr_shape.cols() == target_shape.cols() | ||
| { | ||
| let left = ones((target_shape.rows(), 1)); | ||
| return Some(matmul(&left, &expr)); | ||
| } | ||
|
|
||
| if expr_shape.cols() == 1 && target_shape.cols() > 1 && expr_shape.rows() == target_shape.rows() | ||
| { | ||
| let right = ones((1, target_shape.cols())); | ||
| return Some(matmul(&expr, &right)); | ||
| } | ||
|
|
||
| None | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
broadcast_elementwise_exprspanics on shapes that the binarybroadcast_exprs(and therefore+,-, and constraints) accepts: mixing(n,)with(n,1).Shape::broadcastmaps the pair to(n,), butbroadcast_tothen fails becausebroadcast_2d_torequires the target to be 2-D — the binary version tolerates this via its rows/cols-equal early return above, which this function omits.Repro (works on
main, panics here withcannot broadcast shape (3, 1) to (3,)):Suggestion: move the rows/cols-equivalence tolerance into
broadcast_to's "already matches target" check, and makebroadcast_exprsa thin wrapper over this n-ary function — that fixes the regression and keeps the two entry points from drifting apart.