diff --git a/src/canon/canonicalizer.rs b/src/canon/canonicalizer.rs index 44f4563..eea81b0 100644 --- a/src/canon/canonicalizer.rs +++ b/src/canon/canonicalizer.rs @@ -1161,6 +1161,15 @@ impl CanonContext { return CanonExpr::Linear(cx); } + if p.abs() < 1e-10 { + // x^0 = 1 elementwise. + return CanonExpr::Linear(LinExpr::constant(DMatrix::from_element( + cx.shape.rows(), + cx.shape.cols(), + 1.0, + ))); + } + // Create auxiliary variable t for the result let (t_var_id, t) = self.new_nonneg_aux_var(cx.shape.clone()); let _ = t_var_id; diff --git a/tests/canon_tests.rs b/tests/canon_tests.rs index 0ddc596..c64baaa 100644 --- a/tests/canon_tests.rs +++ b/tests/canon_tests.rs @@ -173,6 +173,30 @@ fn test_quad_over_lin_objective_variable_denominator() { assert!((sol.value.unwrap() - 2.0).abs() < TOL); } +#[test] +fn test_power_zero_is_constant_one() { + let x = variable(()); + + let sol = Problem::minimize(power(&x, 0.0)) + .subject_to([x.ge(1.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 1.0).abs() < TOL); +} + +#[test] +fn test_power_zero_vector_is_constant_one_elementwise() { + let x = variable(3); + + let sol = Problem::minimize(sum(&power(&x, 0.0))) + .subject_to([x.ge(1.0)]) + .solve() + .expect("problem should solve"); + + assert!((sol.value.unwrap() - 3.0).abs() < TOL); +} + #[test] fn test_power_two_is_elementwise() { let x = variable(2);