Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 76 additions & 57 deletions src/atoms/affine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
use std::ops::{Add, Div, Mul, Neg, Sub};
use std::sync::Arc;

use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant, ones};
use super::broadcast::broadcast_exprs;
use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant};

// ============================================================================
// Operator overloading for Expr
Expand Down Expand Up @@ -192,60 +193,6 @@ fn mul_exprs(lhs: Expr, rhs: Expr) -> Expr {
Expr::Mul(Arc::new(lhs), Arc::new(rhs))
}

pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) {
let lhs_shape = lhs.shape();
let rhs_shape = rhs.shape();
let target_shape = lhs_shape
.broadcast(&rhs_shape)
.unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", lhs_shape, rhs_shape));

if lhs_shape == target_shape && rhs_shape == target_shape {
return (lhs, rhs);
}
if lhs_shape.rows() == rhs_shape.rows() && lhs_shape.cols() == rhs_shape.cols() {
return (lhs, rhs);
}

let lhs = broadcast_to(lhs, &lhs_shape, &target_shape)
.unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", lhs_shape, target_shape));
let rhs = broadcast_to(rhs, &rhs_shape, &target_shape)
.unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", rhs_shape, target_shape));

(lhs, rhs)
}

fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option<Expr> {
if expr_shape == target_shape {
return Some(expr);
}

if expr_shape.is_scalar_like() {
return Some(promote(&expr, target_shape.clone()));
}

broadcast_2d_to(expr, expr_shape, target_shape)
}

fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option<Expr> {
if !expr_shape.is_matrix() || !target_shape.is_matrix() {
return None;
}

if expr_shape.rows() == 1 && target_shape.rows() > 1 && expr_shape.cols() == target_shape.cols()
{
let left = ones((target_shape.rows(), 1));
return Some(matmul(&left, &expr));
}

if expr_shape.cols() == 1 && target_shape.cols() > 1 && expr_shape.rows() == target_shape.rows()
{
let right = ones((1, target_shape.cols()));
return Some(matmul(&expr, &right));
}

None
}

// ============================================================================
// Affine atom functions
// ============================================================================
Expand All @@ -257,6 +204,13 @@ pub fn sum(expr: &Expr) -> Expr {

/// Sum along a specific axis.
pub fn sum_axis(expr: &Expr, axis: usize) -> Expr {
let shape = expr.shape();
assert!(
axis < shape.ndim().max(1),
"axis {} out of bounds for shape {}",
axis,
shape
);
Expr::Sum(Arc::new(expr.clone()), Some(axis))
}

Expand All @@ -279,7 +233,15 @@ pub fn promote(expr: &Expr, shape: impl Into<Shape>) -> Expr {

/// Reshape an expression to a new shape.
pub fn reshape(expr: &Expr, shape: impl Into<Shape>) -> Expr {
Expr::Reshape(Arc::new(expr.clone()), shape.into())
let shape = shape.into();
assert_eq!(
expr.shape().size(),
shape.size(),
"cannot reshape size {} into shape {}",
expr.shape().size(),
shape
);
Expr::Reshape(Arc::new(expr.clone()), shape)
}

/// Flatten an expression to a vector.
Expand All @@ -300,11 +262,31 @@ pub fn trace(expr: &Expr) -> Expr {

/// Vertical stack (row-wise concatenation).
pub fn vstack(exprs: Vec<Expr>) -> Expr {
if let Some(first) = exprs.first() {
let cols = first.shape().cols();
for expr in &exprs[1..] {
assert_eq!(
expr.shape().cols(),
cols,
"vstack requires matching column counts"
);
}
}
Expr::VStack(exprs.into_iter().map(Arc::new).collect())
}

/// Horizontal stack (column-wise concatenation).
pub fn hstack(exprs: Vec<Expr>) -> Expr {
if let Some(first) = exprs.first() {
let rows = first.shape().rows();
for expr in &exprs[1..] {
assert_eq!(
expr.shape().rows(),
rows,
"hstack requires matching row counts"
);
}
}
Expr::HStack(exprs.into_iter().map(Arc::new).collect())
}

Expand Down Expand Up @@ -422,7 +404,7 @@ pub fn cumsum(expr: &Expr) -> Expr {
/// Diagonal matrix from vector, or diagonal of matrix.
///
/// - Vector input: Creates diagonal matrix with vector on diagonal
/// - Matrix input: Extracts diagonal as vector (v1.0: returns input as fallback)
/// - Matrix input: Extracts diagonal as vector
pub fn diag(expr: &Expr) -> Expr {
Expr::Diag(Arc::new(expr.clone()))
}
Expand Down Expand Up @@ -472,6 +454,20 @@ mod tests {
assert_eq!(s.shape(), Shape::scalar());
}

#[test]
fn test_sum_axis_vector_shape() {
let x = variable(3);
let s = sum_axis(&x, 0);
assert_eq!(s.shape(), Shape::scalar());
}

#[test]
#[should_panic(expected = "axis 1 out of bounds for shape (3,)")]
fn test_sum_axis_invalid_axis_panics() {
let x = variable(3);
let _ = sum_axis(&x, 1);
}

#[test]
fn test_promote_shape_and_metadata() {
let x = nonneg_variable(());
Expand All @@ -498,6 +494,13 @@ mod tests {
assert_eq!(promote(&x, 3).shape(), Shape::vector(3));
}

#[test]
#[should_panic(expected = "cannot reshape size 3 into shape (2, 2)")]
fn test_reshape_size_mismatch_panics() {
let x = variable(3);
let _ = reshape(&x, (2, 2));
}

#[test]
#[should_panic(expected = "only scalar expressions can be promoted")]
fn test_promote_rejects_non_scalar_like_shape() {
Expand Down Expand Up @@ -657,6 +660,22 @@ mod tests {
assert_eq!(z.shape(), Shape::matrix(5, 3));
}

#[test]
#[should_panic(expected = "vstack requires matching column counts")]
fn test_vstack_mismatched_columns_panics() {
let x = variable((2, 3));
let y = variable((3, 2));
let _ = vstack(vec![x, y]);
}

#[test]
#[should_panic(expected = "hstack requires matching row counts")]
fn test_hstack_mismatched_rows_panics() {
let x = variable((2, 3));
let y = variable((3, 3));
let _ = hstack(vec![x, y]);
}

#[test]
fn test_affine_is_affine() {
let x = variable(5);
Expand Down
71 changes: 71 additions & 0 deletions src/atoms/broadcast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//! Internal broadcasting helpers for atom construction and canonicalization.

use crate::atoms::affine::{matmul, promote};
use crate::expr::{Expr, Shape, ones};

pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) {
let mut exprs = broadcast_elementwise_exprs([lhs, rhs]).1;
let rhs = exprs.pop().expect("binary broadcast should return rhs");
let lhs = exprs.pop().expect("binary broadcast should return lhs");
(lhs, rhs)
}

pub(crate) fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option<Expr> {
if expr_shape == target_shape
|| (expr_shape.rows() == target_shape.rows() && expr_shape.cols() == target_shape.cols())
{
return Some(expr);
}

if expr_shape.is_scalar_like() {
return Some(promote(&expr, target_shape.clone()));
}

broadcast_2d_to(expr, expr_shape, target_shape)
}

pub(crate) fn broadcast_elementwise_exprs(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

broadcast_elementwise_exprs panics on shapes that the binary broadcast_exprs (and therefore +, -, and constraints) accepts: mixing (n,) with (n,1). Shape::broadcast maps the pair to (n,), but broadcast_to then fails because broadcast_2d_to requires the target to be 2-D — the binary version tolerates this via its rows/cols-equal early return above, which this function omits.

Repro (works on main, panics here with cannot broadcast shape (3, 1) to (3,)):

let x = variable(3);
let y = variable((3, 1));
let m = maximum(vec![x, y]); // also minimum, max2, min2

Suggestion: move the rows/cols-equivalence tolerance into broadcast_to's "already matches target" check, and make broadcast_exprs a thin wrapper over this n-ary function — that fixes the regression and keeps the two entry points from drifting apart.

exprs: impl IntoIterator<Item = Expr>,
) -> (Shape, Vec<Expr>) {
let exprs: Vec<Expr> = exprs.into_iter().collect();
let target_shape = exprs
.iter()
.map(|expr| expr.shape())
.reduce(|acc, shape| {
acc.broadcast(&shape)
.unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", acc, shape))
})
.expect("elementwise atom requires at least one expression");

let exprs = exprs
.into_iter()
.map(|expr| {
let expr_shape = expr.shape();
broadcast_to(expr, &expr_shape, &target_shape).unwrap_or_else(|| {
panic!("cannot broadcast shape {} to {}", expr_shape, target_shape)
})
})
.collect();

(target_shape, exprs)
}

fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option<Expr> {
if !expr_shape.is_matrix() || !target_shape.is_matrix() {
return None;
}

if expr_shape.rows() == 1 && target_shape.rows() > 1 && expr_shape.cols() == target_shape.cols()
{
let left = ones((target_shape.rows(), 1));
return Some(matmul(&left, &expr));
}

if expr_shape.cols() == 1 && target_shape.cols() > 1 && expr_shape.rows() == target_shape.rows()
{
let right = ones((1, target_shape.cols()));
return Some(matmul(&expr, &right));
}

None
}
1 change: 1 addition & 0 deletions src/atoms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//! - **Nonlinear atoms**: Operations with specific curvature (norms, quadratic forms, etc.)

pub mod affine;
pub(crate) mod broadcast;
pub mod nonlinear;

// Re-export affine operations
Expand Down
Loading
Loading