Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 189 additions & 23 deletions src/atoms/affine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use std::ops::{Add, Div, Mul, Neg, Sub};
use std::sync::Arc;

use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant};
use crate::expr::{AxisIndex, Expr, IndexSpec, Shape, constant, ones};

// ============================================================================
// Operator overloading for Expr
Expand All @@ -35,98 +35,95 @@ impl Add for Expr {
type Output = Expr;

fn add(self, rhs: Expr) -> Expr {
Expr::Add(Arc::new(self), Arc::new(rhs))
add_exprs(self, rhs)
}
}

impl Add for &Expr {
type Output = Expr;

fn add(self, rhs: &Expr) -> Expr {
Expr::Add(Arc::new(self.clone()), Arc::new(rhs.clone()))
add_exprs(self.clone(), rhs.clone())
}
}

impl Add<&Expr> for Expr {
type Output = Expr;

fn add(self, rhs: &Expr) -> Expr {
Expr::Add(Arc::new(self), Arc::new(rhs.clone()))
add_exprs(self, rhs.clone())
}
}

impl Add<Expr> for &Expr {
type Output = Expr;

fn add(self, rhs: Expr) -> Expr {
Expr::Add(Arc::new(self.clone()), Arc::new(rhs))
add_exprs(self.clone(), rhs)
}
}

impl Sub for Expr {
type Output = Expr;

fn sub(self, rhs: Expr) -> Expr {
Expr::Add(Arc::new(self), Arc::new(Expr::Neg(Arc::new(rhs))))
sub_exprs(self, rhs)
}
}

impl Sub for &Expr {
type Output = Expr;

fn sub(self, rhs: &Expr) -> Expr {
Expr::Add(
Arc::new(self.clone()),
Arc::new(Expr::Neg(Arc::new(rhs.clone()))),
)
sub_exprs(self.clone(), rhs.clone())
}
}

impl Sub<&Expr> for Expr {
type Output = Expr;

fn sub(self, rhs: &Expr) -> Expr {
Expr::Add(Arc::new(self), Arc::new(Expr::Neg(Arc::new(rhs.clone()))))
sub_exprs(self, rhs.clone())
}
}

impl Sub<Expr> for &Expr {
type Output = Expr;

fn sub(self, rhs: Expr) -> Expr {
Expr::Add(Arc::new(self.clone()), Arc::new(Expr::Neg(Arc::new(rhs))))
sub_exprs(self.clone(), rhs)
}
}

impl Mul for Expr {
type Output = Expr;

fn mul(self, rhs: Expr) -> Expr {
Expr::Mul(Arc::new(self), Arc::new(rhs))
mul_exprs(self, rhs)
}
}

impl Mul for &Expr {
type Output = Expr;

fn mul(self, rhs: &Expr) -> Expr {
Expr::Mul(Arc::new(self.clone()), Arc::new(rhs.clone()))
mul_exprs(self.clone(), rhs.clone())
}
}

impl Mul<&Expr> for Expr {
type Output = Expr;

fn mul(self, rhs: &Expr) -> Expr {
Expr::Mul(Arc::new(self), Arc::new(rhs.clone()))
mul_exprs(self, rhs.clone())
}
}

impl Mul<Expr> for &Expr {
type Output = Expr;

fn mul(self, rhs: Expr) -> Expr {
Expr::Mul(Arc::new(self.clone()), Arc::new(rhs))
mul_exprs(self.clone(), rhs)
}
}

Expand All @@ -135,31 +132,31 @@ impl Mul<f64> for Expr {
type Output = Expr;

fn mul(self, rhs: f64) -> Expr {
Expr::Mul(Arc::new(constant(rhs)), Arc::new(self))
mul_exprs(constant(rhs), self)
}
}

impl Mul<f64> for &Expr {
type Output = Expr;

fn mul(self, rhs: f64) -> Expr {
Expr::Mul(Arc::new(constant(rhs)), Arc::new(self.clone()))
mul_exprs(constant(rhs), self.clone())
}
}

impl Mul<Expr> for f64 {
type Output = Expr;

fn mul(self, rhs: Expr) -> Expr {
Expr::Mul(Arc::new(constant(self)), Arc::new(rhs))
mul_exprs(constant(self), rhs)
}
}

impl Mul<&Expr> for f64 {
type Output = Expr;

fn mul(self, rhs: &Expr) -> Expr {
Expr::Mul(Arc::new(constant(self)), Arc::new(rhs.clone()))
mul_exprs(constant(self), rhs.clone())
}
}

Expand All @@ -168,18 +165,87 @@ impl Div<f64> for Expr {
type Output = Expr;

fn div(self, rhs: f64) -> Expr {
Expr::Mul(Arc::new(constant(1.0 / rhs)), Arc::new(self))
mul_exprs(constant(1.0 / rhs), self)
}
}

impl Div<f64> for &Expr {
type Output = Expr;

fn div(self, rhs: f64) -> Expr {
Expr::Mul(Arc::new(constant(1.0 / rhs)), Arc::new(self.clone()))
mul_exprs(constant(1.0 / rhs), self.clone())
}
}

fn add_exprs(lhs: Expr, rhs: Expr) -> Expr {
let (lhs, rhs) = broadcast_exprs(lhs, rhs);
Expr::Add(Arc::new(lhs), Arc::new(rhs))
}

fn sub_exprs(lhs: Expr, rhs: Expr) -> Expr {
let (lhs, rhs) = broadcast_exprs(lhs, rhs);
Expr::Add(Arc::new(lhs), Arc::new(Expr::Neg(Arc::new(rhs))))
}

fn mul_exprs(lhs: Expr, rhs: Expr) -> Expr {
let (lhs, rhs) = broadcast_exprs(lhs, rhs);
Expr::Mul(Arc::new(lhs), Arc::new(rhs))
}

pub(crate) fn broadcast_exprs(lhs: Expr, rhs: Expr) -> (Expr, Expr) {
let lhs_shape = lhs.shape();
let rhs_shape = rhs.shape();
let target_shape = lhs_shape
.broadcast(&rhs_shape)
.unwrap_or_else(|| panic!("cannot broadcast shapes {} and {}", lhs_shape, rhs_shape));

if lhs_shape == target_shape && rhs_shape == target_shape {
return (lhs, rhs);
}
if lhs_shape.rows() == rhs_shape.rows() && lhs_shape.cols() == rhs_shape.cols() {
return (lhs, rhs);
}

let lhs = broadcast_to(lhs, &lhs_shape, &target_shape)
.unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", lhs_shape, target_shape));
let rhs = broadcast_to(rhs, &rhs_shape, &target_shape)
.unwrap_or_else(|| panic!("cannot broadcast shape {} to {}", rhs_shape, target_shape));

(lhs, rhs)
}

fn broadcast_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option<Expr> {
if expr_shape == target_shape {
return Some(expr);
}

if expr_shape.is_scalar_like() {
return Some(promote(&expr, target_shape.clone()));
}

broadcast_2d_to(expr, expr_shape, target_shape)
}

fn broadcast_2d_to(expr: Expr, expr_shape: &Shape, target_shape: &Shape) -> Option<Expr> {
if !expr_shape.is_matrix() || !target_shape.is_matrix() {
return None;
}

if expr_shape.rows() == 1 && target_shape.rows() > 1 && expr_shape.cols() == target_shape.cols()
{
let left = ones((target_shape.rows(), 1));
return Some(matmul(&left, &expr));
}

if expr_shape.cols() == 1 && target_shape.cols() > 1 && expr_shape.rows() == target_shape.rows()
{
let right = ones((1, target_shape.cols()));
return Some(matmul(&expr, &right));
}

None
}

// ============================================================================
// Affine atom functions
// ============================================================================
Expand All @@ -194,6 +260,23 @@ pub fn sum_axis(expr: &Expr, axis: usize) -> Expr {
Expr::Sum(Arc::new(expr.clone()), Some(axis))
}

/// Promote a scalar-like expression to a target shape.
pub fn promote(expr: &Expr, shape: impl Into<Shape>) -> Expr {
let target_shape = shape.into();
let expr_shape = expr.shape();

if expr_shape == target_shape {
return expr.clone();
}

assert!(
expr_shape.is_scalar_like(),
"only scalar expressions can be promoted"
);

Expr::Promote(Arc::new(expr.clone()), target_shape)
}

/// Reshape an expression to a new shape.
pub fn reshape(expr: &Expr, shape: impl Into<Shape>) -> Expr {
Expr::Reshape(Arc::new(expr.clone()), shape.into())
Expand Down Expand Up @@ -347,7 +430,7 @@ pub fn diag(expr: &Expr) -> Expr {
#[cfg(test)]
mod tests {
use super::*;
use crate::expr::{constant, variable};
use crate::expr::{constant, constant_matrix, constant_vec, nonneg_variable, variable};

#[test]
fn test_add() {
Expand Down Expand Up @@ -389,6 +472,81 @@ mod tests {
assert_eq!(s.shape(), Shape::scalar());
}

#[test]
fn test_promote_shape_and_metadata() {
let x = nonneg_variable(());
let p = promote(&x, (2, 3));

assert_eq!(p.shape(), Shape::matrix(2, 3));
assert_eq!(p.variables(), x.variables());
assert!(p.curvature().is_affine());
assert!(p.sign().is_nonneg());
}

#[test]
fn test_promote_accepts_scalar_like_shapes() {
let row_scalar = constant_vec(vec![2.0]);
let matrix_scalar = constant_matrix(vec![3.0], 1, 1);

assert_eq!(promote(&row_scalar, (2, 2)).shape(), Shape::matrix(2, 2));
assert_eq!(promote(&matrix_scalar, (2, 2)).shape(), Shape::matrix(2, 2));
}

#[test]
fn test_promote_same_shape_is_noop() {
let x = variable(3);
assert_eq!(promote(&x, 3).shape(), Shape::vector(3));
}

#[test]
#[should_panic(expected = "only scalar expressions can be promoted")]
fn test_promote_rejects_non_scalar_like_shape() {
let x = variable(2);
let _ = promote(&x, (2, 2));
}

#[test]
fn test_scalar_like_broadcast_shapes() {
let x = variable((2, 2));
let scalar = variable(());
let row_scalar = constant_vec(vec![1.0]);
let matrix_scalar = constant_matrix(vec![1.0], 1, 1);

assert_eq!((&scalar + &x).shape(), Shape::matrix(2, 2));
assert_eq!((&x - &row_scalar).shape(), Shape::matrix(2, 2));
assert_eq!((&matrix_scalar * &x).shape(), Shape::matrix(2, 2));
}

#[test]
fn test_row_and_column_broadcast_shapes() {
let m = variable((2, 3));
let row = variable((1, 3));
let col = variable((2, 1));

assert_eq!((&row + &m).shape(), Shape::matrix(2, 3));
assert_eq!((&m - &row).shape(), Shape::matrix(2, 3));
assert_eq!((&col * &m).shape(), Shape::matrix(2, 3));
assert_eq!((&m + &col).shape(), Shape::matrix(2, 3));
}

#[test]
fn test_mutual_row_and_column_broadcast_shape() {
let row = variable((1, 3));
let col = variable((2, 1));

assert_eq!((&row + &col).shape(), Shape::matrix(2, 3));
assert_eq!((&col * &row).shape(), Shape::matrix(2, 3));
}

#[test]
#[should_panic(expected = "cannot broadcast shapes (2, 2) and (3, 3)")]
fn test_incompatible_broadcast_panics_at_construction() {
let c = constant_matrix(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
let x = variable((3, 3));

let _ = c * x;
}

#[test]
fn test_transpose() {
let x = variable((3, 4));
Expand All @@ -404,6 +562,14 @@ mod tests {
assert_eq!(b.shape(), Shape::vector(3));
}

#[test]
#[should_panic(expected = "cannot matrix-multiply shapes (3, 4) and (3,)")]
fn test_invalid_matmul_shape_panics() {
let a = variable((3, 4));
let x = variable(3);
let _ = matmul(&a, &x).shape();
}

#[test]
fn test_index_and_slice_shapes() {
let x = variable(10);
Expand Down
4 changes: 2 additions & 2 deletions src/atoms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ pub mod nonlinear;

// Re-export affine operations
pub use affine::{
cumsum, diag, dot, flatten, hstack, index, indexc, matmul, reshape, select, slice, slicec, sum,
sum_axis, trace, transpose, vstack,
cumsum, diag, dot, flatten, hstack, index, indexc, matmul, promote, reshape, select, slice,
slicec, sum, sum_axis, trace, transpose, vstack,
};

// Re-export nonlinear atoms
Expand Down
Loading
Loading