From bb7ce715ad11ac02da16f02a50483d49a6e463c6 Mon Sep 17 00:00:00 2001 From: Laura Gao Date: Fri, 21 Jun 2024 01:20:05 +0200 Subject: [PATCH 1/6] creating fourier series basic --- {tests => src/simpy/debug}/test_utils.py | 13 +- src/simpy/expr.py | 162 ++++++++++++++++++++++- src/simpy/integration.py | 27 +++- tests/test_derivatives.py | 3 +- tests/test_expand_power.py | 3 +- tests/test_exprs.py | 2 +- tests/test_fourier.py | 103 ++++++++++++++ tests/test_integrals.py | 3 +- tests/test_khan_academy_integrals.py | 3 +- tests/test_latex.py | 3 +- tests/test_nums.py | 14 +- tests/test_pts.py | 3 +- tests/test_regex.py | 2 +- tests/test_transforms.py | 2 +- tests/test_trig_simplify.py | 3 +- 15 files changed, 308 insertions(+), 38 deletions(-) rename {tests => src/simpy/debug}/test_utils.py (86%) create mode 100644 tests/test_fourier.py diff --git a/tests/test_utils.py b/src/simpy/debug/test_utils.py similarity index 86% rename from tests/test_utils.py rename to src/simpy/debug/test_utils.py index f984e67..693f976 100644 --- a/tests/test_utils.py +++ b/src/simpy/debug/test_utils.py @@ -1,7 +1,7 @@ from typing import Optional, Tuple, Union from simpy.debug.utils import debug_repr -from simpy.expr import Expr, Symbol, TrigFunctionNotInverse, cast, log, symbols +from simpy.expr import Expr, Float, Symbol, TrigFunctionNotInverse, cast, log, symbols from simpy.integration import integrate from simpy.simplify import expand_logs, trig_simplify @@ -83,3 +83,14 @@ def unhashable_set_eq(a: list, b: list) -> bool: if el not in a: return False return True + + +def eq_float(e1: Expr, e2: Expr, atol=1e-6): + if type(e1) != type(e2): + return False + if isinstance(e1, Float) and isinstance(e2, Float): + return abs(e1.value - e2.value) < atol + + if not len(e1.children()) == len(e2.children()): + return False + return all(eq_float(c1, c2) for c1, c2 in zip(e1.children(), e2.children())) diff --git a/src/simpy/expr.py b/src/simpy/expr.py index 0f8311c..88d4e85 100644 --- a/src/simpy/expr.py +++ b/src/simpy/expr.py @@ -20,7 +20,7 @@ from dataclasses import dataclass, fields from fractions import Fraction from functools import cmp_to_key, reduce -from typing import Callable, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import Callable, Dict, List, Literal, NamedTuple, Optional, Tuple, Type, Union from .combinatorics import generate_permutations, multinomial_coefficient @@ -124,6 +124,7 @@ def __post_init__(self): # Experimental & not guaranteed to be robust so it's private API for now. self._strictly_positive = False + self._is_int = False def simplify(self) -> "Expr": from .simplify import simplify @@ -273,6 +274,10 @@ def is_int(self) -> bool: """Returns True if the expression is an integer.""" return False + @property + def symbolless(self) -> bool: + return len(self.symbols()) == 0 + @dataclass class Associative: @@ -794,7 +799,7 @@ def subs(self, subs: Dict[str, "Rat"]): return subs.get(self.name, self) def _evalf(self, subs): - return super()._evalf() + return subs.get(self.name, self) def diff(self, var) -> Rat: return Rat(1) if self == var else Rat(0) @@ -1152,6 +1157,11 @@ def __new__(cls, terms: List[Expr], *, skip_checks: bool = False) -> "Expr": return terms[0] return super().__new__(cls) + if any(isinstance(t, Piecewise) for t in terms): + piecewise = [t for t in terms if isinstance(t, Piecewise)][0] + other_terms = [t for t in terms if not isinstance(t, Piecewise)] + return piecewise * Prod(other_terms) + # We need to flatten BEFORE we accumulate like terms # ex: Prod(x, Prod(Power(x, -1), y)) terms = _cast(terms) @@ -1748,6 +1758,9 @@ def __new__(cls, inner: Expr, *, skip_checks: bool = False) -> "Expr": if str(pi_coeff.value) in cls._SPECIAL_KEYS: return cls.special_values[str(pi_coeff.value)] + if (pi_coeff / 2)._is_int: + return cls.special_values["0"] + # check if inner has ... a 2pi term in a sum if isinstance(inner, Sum): for t in inner.terms: @@ -1854,9 +1867,17 @@ def diff(self, var) -> Expr: @cast def __new__(cls, inner: Expr) -> Expr: - if inner.is_subtraction: - return -sin(-inner) - return super().__new__(cls, inner) + new = super().__new__(cls, inner) + if not isinstance(new, sin): + return new + if new.inner.is_subtraction: + return -sin(-new.inner) + if isinstance(new.inner, Sum): + for t in new.inner.terms: + if t == pi or t == -pi: + return sin(new.inner - t) + + return new class cos(TrigFunctionNotInverse): @@ -1898,8 +1919,12 @@ def __new__(cls, inner: Expr) -> "Expr": new = super().__new__(cls, inner) if not isinstance(new, cos): return new - if inner.is_subtraction: + if new.inner.is_subtraction: new.inner = -inner + if isinstance(new.inner, Sum): + for t in new.inner.terms: + if t == pi or t == -pi: + return -cos(new.inner - t) return new @@ -2107,3 +2132,128 @@ def remove_const_factor(expr: Expr, include_factor=False) -> Expr: def latex(expr: Expr) -> str: return expr.latex() + + +class Bound(NamedTuple): + value: Expr + inclusive: bool + + +@dataclass +class Piece: + expr: Expr + lower_bound: Bound + upper_bound: Bound + + def __post_init__(self): + if self.expr.has(Piecewise): + raise ValueError("Piecewise functions cannot be nested.") + + def __repr__(self) -> str: + return f"{self.lower_bound.value} <= x < {self.upper_bound.value}: {self.expr}" + + +class Piecewise(Expr): + pieces: List[Piece] + + def __init__(self, *args: List[Tuple[Expr, Expr, Expr]], var: Symbol = None): + pieces = [] + for arg in args: + if isinstance(arg, Piece): + pieces.append(arg) + continue + + pieces.append(Piece(_cast(arg[0]), Bound(_cast(arg[1]), True), Bound(_cast(arg[2]), False))) + self.pieces = pieces + self.var = var + + def __repr__(self) -> str: + return ( + "Piecewise(" + + ", ".join([f"{f.lower_bound.value} <= x < {f.upper_bound.value}: {f.expr}" for f in self.pieces]) + + ")" + ) + # return "Piecewise(...)" + + def latex(self) -> str: + return ( + "\\begin{cases} " + + " \\\\ ".join([f"{f.lower_bound.value} \\leq x < {f.upper_bound.value}: {f.expr}" for f in self.pieces]) + + " \\end{cases}" + ) + + def _evalf(self, subs) -> "Piecewise": + return Piecewise(*[Piece(p.expr._evalf(subs), p.lower_bound, p.upper_bound) for p in self.pieces]) + + def children(self) -> List[Expr]: + return [p.expr for p in self.pieces] + + def diff(self, var) -> "Piecewise": + return Piecewise(*[Piece(p.expr.diff(var), p.lower_bound, p.upper_bound) for p in self.pieces]) + + def subs(self, subs) -> "Piecewise": + return Piecewise(*[Piece(p.expr.subs(subs), p.lower_bound, p.upper_bound) for p in self.pieces]) + + def __add__(self, other: Expr) -> "Piecewise": + return self._operate(other, fn=lambda x, y: x + y) + + def _operate(self, other: Expr, fn) -> "Piecewise": + if not isinstance(other, Piecewise): + return Piecewise(*[Piece(fn(p.expr, other), p.lower_bound, p.upper_bound) for p in self.pieces]) + + if self.var != other.var: + raise NotImplementedError + if len(self.pieces) != len(other.pieces): + raise NotImplementedError + + pieces = [] + for p1, p2 in zip(self.pieces, other.pieces): + if p1.lower_bound != p2.lower_bound or p1.upper_bound != p2.upper_bound: + raise NotImplementedError + pieces.append(Piece(fn(p1.expr, p2.expr), p1.lower_bound, p1.upper_bound)) + + return Piecewise(*pieces, var=self.var) + + def __radd__(self, other: Expr) -> "Piecewise": + return self + other + + def __sub__(self, other: Expr) -> "Piecewise": + return self._operate(other, fn=lambda x, y: x - y) + + def __rsub__(self, other: Expr) -> "Piecewise": + return -self + other + + def __mul__(self, other: Expr) -> "Piecewise": + return self._operate(other, fn=lambda x, y: x * y) + + def __rmul__(self, other: Expr) -> "Piecewise": + return self * other + + def __truediv__(self, other: Expr) -> "Piecewise": + return self._operate(other, fn=lambda x, y: x / y) + + def __rtruediv__(self, other: Expr) -> "Piecewise": + return self**-1 * other + + def __neg__(self) -> "Piecewise": + # neg each expr + return Piecewise(*[Piece(-p.expr, p.lower_bound, p.upper_bound) for p in self.pieces]) + + def __pow__(self, other: Expr) -> "Piecewise": + return self._operate(other, fn=lambda x, y: x**y) + + def __rpow__(self, other: Expr) -> "Piecewise": + return self._operate(other, fn=lambda x, y: y**x) + + def __eq__(self, other: Expr) -> bool: + if not isinstance(other, Piecewise): + return False + if len(self.pieces) != len(other.pieces): + return False + for p1, p2 in zip(self.pieces, other.pieces): + if p1.lower_bound != p2.lower_bound or p1.upper_bound != p2.upper_bound or p1.expr != p2.expr: + return False + return True + + def __ne__(self, other: Expr) -> bool: + return not self == other diff --git a/src/simpy/integration.py b/src/simpy/integration.py index a63999d..14e4f2d 100644 --- a/src/simpy/integration.py +++ b/src/simpy/integration.py @@ -6,7 +6,7 @@ from typing import Callable, List, Literal, Tuple, Union from .debug.tree import print_solution_tree, print_tree -from .expr import Expr, Optional, Symbol, cast, nesting +from .expr import Expr, Optional, Piece, Piecewise, Sum, Symbol, cast, nesting from .integral_table import check_integral_table from .transforms import HEURISTICS, SAFE_TRANSFORMS, Node @@ -123,6 +123,25 @@ def __init__(self, *, debug: bool = False, debug_hardcore: bool = False, breadth def integrate_bounds(self, expr: Expr, bounds: Tuple[Symbol, Expr, Expr]) -> Optional[Expr]: """Performs definite integral.""" x, a, b = bounds + + if isinstance(expr, Piecewise): + assert a.symbolless + assert b.symbolless + assert all(p.lower_bound.value.symbolless and p.upper_bound.value.symbolless for p in expr.pieces) + + total = [] + for piece in expr.pieces: + if piece.lower_bound.value.evalf() >= b.evalf() or piece.upper_bound.value.evalf() <= a.evalf(): + continue + ans = self.integrate_bounds( + piece.expr, (x, max(a, piece.lower_bound.value), min(b, piece.upper_bound.value)) + ) + if ans is None: + return None + total.append(ans) + + return Sum(total) + integral = self.integrate(expr, bounds[0], final=False) if integral is None: return None @@ -233,6 +252,12 @@ def _check_if_depth_first_bad(self, ans: Node) -> bool: def integrate(self, integrand: Expr, var: Symbol, final=True) -> Optional[Expr]: """Performs indefinite integral.""" + if isinstance(integrand, Piecewise): + return Piecewise( + *[Piece(self.integrate(p.expr, var), p.lower_bound, p.upper_bound) for p in integrand.pieces], + var=integrand.var, + ) + root = Node(integrand.simplify(), var) _integrate_safely(root) curr_node = root diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index e97a326..86aabcc 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -1,5 +1,4 @@ -from test_utils import assert_eq_strict, assert_eq_value, x, y - +from simpy.debug.test_utils import assert_eq_strict, assert_eq_value, x, y from simpy.expr import * diff --git a/tests/test_expand_power.py b/tests/test_expand_power.py index 4062f61..a339087 100644 --- a/tests/test_expand_power.py +++ b/tests/test_expand_power.py @@ -1,6 +1,5 @@ -from test_utils import assert_eq_strict, unhashable_set_eq - from simpy.combinatorics import generate_permutations +from simpy.debug.test_utils import assert_eq_strict, unhashable_set_eq from simpy.expr import Prod, Sum, symbols diff --git a/tests/test_exprs.py b/tests/test_exprs.py index 74fd9c7..5758d6d 100644 --- a/tests/test_exprs.py +++ b/tests/test_exprs.py @@ -1,8 +1,8 @@ from fractions import Fraction import pytest -from test_utils import assert_eq_strict, unhashable_set_eq, x, y +from simpy.debug.test_utils import assert_eq_strict, unhashable_set_eq, x, y from simpy.debug.utils import debug_repr from simpy.expr import * from simpy.integration import * diff --git a/tests/test_fourier.py b/tests/test_fourier.py new file mode 100644 index 0000000..6de9db7 --- /dev/null +++ b/tests/test_fourier.py @@ -0,0 +1,103 @@ +# finding the fourier series of various fucking functions +from typing import Literal + +import simpy as sp +from simpy.debug.test_utils import eq_float + + +def get_fourier_series( + f: sp.expr.Expr, T: sp.expr.Expr, x: sp.expr.Symbol, *, settings: Literal["center", "right"] = "right" +): + """gets the fourier series of a 1d function + + args: + f: the function to get the fourier series of + T: period of the function + x: variable of the function + + returns a tuple of: + a_0: coeff of constant + a_n: coeff of cos terms + b_n: coeff of sin terms + + a_0 = 1 \cdot f + a_n = 2 * cos(2nx*pi/T) \cdot f + b_n = 2 * sin(2nx*pi/T) \cdot f + + where \cdot is the dot product defined as + f \cdot g = 1/T * integrate over one period(f*g) + """ + c_n = sp.cos(2 * sp.pi * n * x / T) + s_n = sp.sin(2 * sp.pi * n * x / T) + + bounds = (x, 0, T) if settings == "right" else (x, -T / 2, T / 2) + + a_0 = sp.integrate(f, bounds) / T + a_n = 2 * sp.integrate(c_n * f, bounds) / T + b_n = 2 * sp.integrate(s_n * f, bounds) / T + + # make the summation + summation = a_0 + for i in range(1, 5): + subs = {"n": i} + summation += a_n.subs(subs) * c_n.subs(subs) + summation += b_n.subs(subs) * s_n.subs(subs) + + return a_0, a_n, b_n, summation + + +a, n, x = sp.symbols("a n x") +n._is_int = True + + +def test_q3(): + """MATH 2410 Recitation #5 question 3 + Piecewise fn can't be integrated directly bc it has a variable bound. + """ + + f = sp.sin(sp.pi * x / a) + c_n = sp.cos(2 * n * x) + s_n = sp.sin(2 * n * x) + + a_0 = 1 / sp.pi * sp.integrate(f, (x, 0, a)) + a_n = 2 / sp.pi * sp.integrate(c_n * f, (x, 0, a)) + b_n = 2 / sp.pi * sp.integrate(s_n * f, (x, 0, a)) + + # for some reason a_n is negative the one from class notes idk why + assert a_0 == 2 * a / sp.pi**2 + + expected_an = 2 * a * (1 + sp.cos(2 * a * n)) / (sp.pi**2 - (2 * a * n) ** 2) + expected_bn = 2 * a * sp.sin(2 * a * n) / (sp.pi**2 - (2 * a * n) ** 2) + + for i in range(1, 5): + assert eq_float(a_n.evalf({"n": i}), expected_an.evalf({"n": i})) + assert eq_float(b_n.evalf({"n": i}), expected_bn.evalf({"n": i})) + + # T = sp.pi + # fp = sp.expr.Piecewise((f, 0, a)) + # a0, an, bn, summation = get_fourier_series(fp, T, x) + + +def test_odd_function(): + """MATH 2410 Recitation #5 question 2""" + + f = (sp.pi - x) / 2 + T = 2 * sp.pi + + a_0, a_n, b_n, summation = get_fourier_series(f, T, x) + assert a_0 == 0 == a_n + assert b_n == 1 / n + + +def test_even_function(): + """MATH 2410 Recitation #5 question 1 + even piecewise function with period 4 + """ + f = sp.expr.Piecewise((0, -2, -1), (1 + x, -1, 0), (1 - x, 0, 1), var=x) + T = 4 + a0, an, bn, summation = get_fourier_series(f, T, x, settings="center") + # for an even function, all the b_n terms are zero + assert a0 == 1 / 4 + expected_an = 8 * sp.sin(n * sp.pi / 4) ** 2 / (n * sp.pi) ** 2 + for i in range(1, 5): + assert eq_float(an.evalf({"n": i}), expected_an.evalf({"n": i})) diff --git a/tests/test_integrals.py b/tests/test_integrals.py index 3481dcc..f830b03 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -1,4 +1,4 @@ -from test_utils import ( +from simpy.debug.test_utils import ( assert_definite_integral, assert_eq_plusc, assert_eq_strict, @@ -7,7 +7,6 @@ x, y, ) - from simpy.expr import * from simpy.integration import * diff --git a/tests/test_khan_academy_integrals.py b/tests/test_khan_academy_integrals.py index 201af52..61509a2 100644 --- a/tests/test_khan_academy_integrals.py +++ b/tests/test_khan_academy_integrals.py @@ -3,8 +3,7 @@ and make sure simpy can do them """ -from test_utils import assert_definite_integral, assert_eq_plusc, assert_eq_value, assert_integral, x, y - +from simpy.debug.test_utils import assert_definite_integral, assert_eq_plusc, assert_eq_value, assert_integral, x, y from simpy.expr import * from simpy.integration import * diff --git a/tests/test_latex.py b/tests/test_latex.py index cd2f82e..18eaf37 100644 --- a/tests/test_latex.py +++ b/tests/test_latex.py @@ -1,8 +1,7 @@ """latex is deployed to production now so I'll test it somewhat more rigorously.""" -from test_utils import * - from simpy import * +from simpy.debug.test_utils import * def assert_latex(expr: Expr, expected_latex: str): diff --git a/tests/test_nums.py b/tests/test_nums.py index a3d9380..5c9edfc 100644 --- a/tests/test_nums.py +++ b/tests/test_nums.py @@ -1,19 +1,7 @@ -from test_utils import * - +from simpy.debug.test_utils import * from simpy.expr import * -def eq_float(e1: Expr, e2: Expr, atol=1e-6): - if type(e1) != type(e2): - return False - if isinstance(e1, Float) and isinstance(e2, Float): - return abs(e1.value - e2.value) < atol - - if not len(e1.children()) == len(e2.children()): - return False - return all(eq_float(c1, c2) for c1, c2 in zip(e1.children(), e2.children())) - - def test_infinity_basic_ops(): assert 0 < inf assert Rat(1) == 1 diff --git a/tests/test_pts.py b/tests/test_pts.py index 0aee80a..8494727 100644 --- a/tests/test_pts.py +++ b/tests/test_pts.py @@ -1,6 +1,5 @@ -from test_utils import * - from simpy import * +from simpy.debug.test_utils import * from simpy.simplify.product_to_sum import * diff --git a/tests/test_regex.py b/tests/test_regex.py index 8a745cd..ed3483f 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -1,6 +1,6 @@ import pytest -from test_utils import x, y +from simpy.debug.test_utils import x, y from simpy.expr import * from simpy.regex import Any_, any_, eq diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 60b23b8..a672fdd 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -6,8 +6,8 @@ """ import numpy as np -from test_utils import assert_eq_strict, x +from simpy.debug.test_utils import assert_eq_strict, x from simpy.expr import Rat from simpy.transforms import CompleteTheSquare, Node, PolynomialDivision, PullConstant, to_const_polynomial diff --git a/tests/test_trig_simplify.py b/tests/test_trig_simplify.py index 50257b5..39a67a5 100644 --- a/tests/test_trig_simplify.py +++ b/tests/test_trig_simplify.py @@ -1,6 +1,5 @@ -from test_utils import * - from simpy import * +from simpy.debug.test_utils import * def assert_simplified(e1: Expr, e2: Expr): From ee4042f8dddc301fb2a5942541bdc40827c90668 Mon Sep 17 00:00:00 2001 From: Laura Gao Date: Fri, 21 Jun 2024 01:28:57 +0200 Subject: [PATCH 2/6] fix WRONG VALUE sin(x+pi) ! --- src/simpy/expr.py | 2 +- tests/test_exprs.py | 7 +++++++ tests/test_fourier.py | 10 +++++++--- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/simpy/expr.py b/src/simpy/expr.py index 88d4e85..2fbc7e7 100644 --- a/src/simpy/expr.py +++ b/src/simpy/expr.py @@ -1875,7 +1875,7 @@ def __new__(cls, inner: Expr) -> Expr: if isinstance(new.inner, Sum): for t in new.inner.terms: if t == pi or t == -pi: - return sin(new.inner - t) + return -sin(new.inner - t) return new diff --git a/tests/test_exprs.py b/tests/test_exprs.py index 5758d6d..d331824 100644 --- a/tests/test_exprs.py +++ b/tests/test_exprs.py @@ -402,3 +402,10 @@ def test_power_abs(): assert sqrt(x) ** 2 == x assert (x**6) ** Rat(1, 6) == abs(x) assert (x**3) ** Rat(1, 3) == x + + +def test_trigfunctions_plusminuspi(): + assert sin(x + pi) == -sin(x) + assert cos(x + pi) == -cos(x) + assert sin(x - pi) == -sin(x) + assert cos(x - pi) == -cos(x) diff --git a/tests/test_fourier.py b/tests/test_fourier.py index 6de9db7..1aae96d 100644 --- a/tests/test_fourier.py +++ b/tests/test_fourier.py @@ -1,4 +1,6 @@ -# finding the fourier series of various fucking functions +"""finding the (symbolic, non-discrete) fourier series of various functions""" + +import random from typing import Literal import simpy as sp @@ -70,8 +72,9 @@ def test_q3(): expected_bn = 2 * a * sp.sin(2 * a * n) / (sp.pi**2 - (2 * a * n) ** 2) for i in range(1, 5): - assert eq_float(a_n.evalf({"n": i}), expected_an.evalf({"n": i})) - assert eq_float(b_n.evalf({"n": i}), expected_bn.evalf({"n": i})) + subs = {"n": i, "a": random.random()} + assert eq_float(a_n.evalf(subs), expected_an.evalf(subs)) + assert eq_float(b_n.evalf(subs), expected_bn.evalf(subs)) # T = sp.pi # fp = sp.expr.Piecewise((f, 0, a)) @@ -101,3 +104,4 @@ def test_even_function(): expected_an = 8 * sp.sin(n * sp.pi / 4) ** 2 / (n * sp.pi) ** 2 for i in range(1, 5): assert eq_float(an.evalf({"n": i}), expected_an.evalf({"n": i})) + assert eq_float(bn.evalf({"n": i}), sp.expr.Float(0.0)) From 21c4758de6d5285f4b3ee95157188f56a4329b18 Mon Sep 17 00:00:00 2001 From: Laura Gao Date: Fri, 28 Jun 2024 18:43:04 +0200 Subject: [PATCH 3/6] add kh piecewise fns to tests --- src/simpy/equation.py | 37 ++++++++++++++++++++++++++++ src/simpy/expr.py | 6 +++++ src/simpy/integration.py | 21 +++++++++++++++- tests/test_equation.py | 9 +++++++ tests/test_khan_academy_integrals.py | 16 ++++++++++++ 5 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 src/simpy/equation.py create mode 100644 tests/test_equation.py diff --git a/src/simpy/equation.py b/src/simpy/equation.py new file mode 100644 index 0000000..f0d69bd --- /dev/null +++ b/src/simpy/equation.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass + +from .expr import Expr, Symbol + + +@dataclass +class Equation: + lhs: Expr + rhs: Expr + + +def solve(equation: Equation, x: Symbol): + # Ensure the equation is set to 0 (ax + b = 0) + expr = equation.lhs - equation.rhs + + # Initialize coefficients + coeff_a = 0 + coeff_b = 0 + + # Iterate over the terms of the expression + for term in expr.as_terms(): + if (term / x).symbolless: + coeff_a += term / x + elif not term.contains(x): + coeff_b += term + else: + raise NotImplementedError("can only solve linear eqs") + + # Calculate the solution + if coeff_a == 0: + if coeff_b == 0: + return "Infinite solutions (identity equation)." + else: + return "No solution (contradictory equation)." + + x_solution = -coeff_b / coeff_a + return x_solution diff --git a/src/simpy/expr.py b/src/simpy/expr.py index 2fbc7e7..81492fd 100644 --- a/src/simpy/expr.py +++ b/src/simpy/expr.py @@ -277,6 +277,11 @@ def is_int(self) -> bool: @property def symbolless(self) -> bool: return len(self.symbols()) == 0 + + def as_terms(self): + if isinstance(self, Sum): + return self.terms + return [self] @dataclass @@ -484,6 +489,7 @@ def __new__(cls, value): def __init__(self, value): self.value = value + super().__post_init__() def latex(self): return repr(self) diff --git a/src/simpy/integration.py b/src/simpy/integration.py index 14e4f2d..0ccab67 100644 --- a/src/simpy/integration.py +++ b/src/simpy/integration.py @@ -6,7 +6,8 @@ from typing import Callable, List, Literal, Tuple, Union from .debug.tree import print_solution_tree, print_tree -from .expr import Expr, Optional, Piece, Piecewise, Sum, Symbol, cast, nesting +from .equation import Equation, solve +from .expr import Abs, Expr, Optional, Piece, Piecewise, Sum, Symbol, cast, nesting from .integral_table import check_integral_table from .transforms import HEURISTICS, SAFE_TRANSFORMS, Node @@ -142,6 +143,24 @@ def integrate_bounds(self, expr: Expr, bounds: Tuple[Symbol, Expr, Expr]) -> Opt return Sum(total) + if isinstance(expr, Abs): + # it's tricky because you have to identify the x-value where expr.inner = 0 + # and then split the integral at that point. + critical_x_value = solve(Equation(expr.inner, 0), x) + assert a.symbolless + assert b.symbolless + assert critical_x_value.symbolless + + # ok this assumes a < b + if critical_x_value <= a: + return self.integrate_bounds(expr.inner, (x, a, b)) + if critical_x_value >= b: + return self.integrate_bounds(-expr.inner, (x, a, b)) + + return self.integrate_bounds(expr.inner, (x, a, critical_x_value)) + self.integrate_bounds( + -expr.inner, (x, critical_x_value, b) + ) + integral = self.integrate(expr, bounds[0], final=False) if integral is None: return None diff --git a/tests/test_equation.py b/tests/test_equation.py new file mode 100644 index 0000000..e806405 --- /dev/null +++ b/tests/test_equation.py @@ -0,0 +1,9 @@ +import simpy as sp +from simpy.equation import Equation, solve + + +def test_simple_linear(): + x = sp.symbols("x") + equation = Equation(-2 * x + 4, 0) + solution = solve(equation, x) + assert solution == 2 diff --git a/tests/test_khan_academy_integrals.py b/tests/test_khan_academy_integrals.py index 61509a2..d8cd44a 100644 --- a/tests/test_khan_academy_integrals.py +++ b/tests/test_khan_academy_integrals.py @@ -201,3 +201,19 @@ def test_more_complicated_trig(): expr = tan(x) ** 5 * sec(x) ** 4 expected_ans = tan(x) ** 6 / 6 + tan(x) ** 8 / 8 assert_integral(expr, expected_ans) + + +def test_abs(): + expr = abs(-2 * x + 4) + assert_definite_integral(expr, bounds=(-2, 4), expected=20) + + +def test_piecewise(): + expr = Piecewise((9 * sqrt(x), 0, inf), (-2 * x, -inf, 0), var=x) + assert_definite_integral(expr, bounds=(-3, 1), expected=15) + + expr = Piecewise((3 * x**2 - 1, 0, inf), (6 * x - 1, -inf, 0), var=x) + assert_definite_integral(expr, bounds=(-1, 1), expected=-4) + + expr = Piecewise((1 / x, 1, inf), (x, -inf, 1), var=x) + assert_definite_integral(expr, bounds=(0, 3), expected=log(3) + Rat(1, 2)) From 5b6f150cb9c0052ed41a88240254e8884c9b0d07 Mon Sep 17 00:00:00 2001 From: Laura Gao Date: Fri, 28 Jun 2024 21:09:31 +0200 Subject: [PATCH 4/6] add a few more integral solves with exponential usub --- src/simpy/expr.py | 43 +++++++++++++++++++++------- src/simpy/transforms.py | 41 ++++++++++++++++++++++---- tests/test_khan_academy_integrals.py | 12 ++++++-- 3 files changed, 78 insertions(+), 18 deletions(-) diff --git a/src/simpy/expr.py b/src/simpy/expr.py index 81492fd..945e834 100644 --- a/src/simpy/expr.py +++ b/src/simpy/expr.py @@ -277,7 +277,7 @@ def is_int(self) -> bool: @property def symbolless(self) -> bool: return len(self.symbols()) == 0 - + def as_terms(self): if isinstance(self, Sum): return self.terms @@ -1359,6 +1359,10 @@ def _multiply_exponents(b: Expr, x1: Expr, x2: Expr) -> Expr: return b ** (x1 * x2) +def exp(x: Expr) -> Expr: + return e**x + + @dataclass class Power(Expr): base: Expr @@ -1673,6 +1677,8 @@ def __get__(self, instance, cls): class TrigFunction(SingleFunc, ABC): is_inverse: bool # class property _fields_already_casted = True + _odd = False + _even = False _SPECIAL_KEYS = [ "0", @@ -1777,6 +1783,12 @@ def __new__(cls, inner: Expr, *, skip_checks: bool = False) -> "Expr": instance.inner = inner - t return instance + # Odd and even shit + if cls._odd and inner.is_subtraction: + return -cls(-inner) + if cls._even and inner.is_subtraction: + return cls(-inner) + # 2. Check if inner is trigfunction # things like sin(cos(x)) cannot be more simplified. if isinstance(inner, TrigFunction) and inner.is_inverse != cls.is_inverse: @@ -1842,6 +1854,7 @@ class TrigFunctionNotInverse(TrigFunction, ABC): class sin(TrigFunctionNotInverse): func = "sin" _func = math.sin + _odd = True @classproperty def reciprocal_class(cls): @@ -1876,19 +1889,22 @@ def __new__(cls, inner: Expr) -> Expr: new = super().__new__(cls, inner) if not isinstance(new, sin): return new - if new.inner.is_subtraction: - return -sin(-new.inner) if isinstance(new.inner, Sum): for t in new.inner.terms: if t == pi or t == -pi: return -sin(new.inner - t) + # sin(n * pi) = 0 + if new.inner.has(Pi) and (new.inner / pi)._is_int: + return Rat(0) + return new class cos(TrigFunctionNotInverse): func = "cos" _func = math.cos + _even = True @classproperty def reciprocal_class(cls): @@ -1925,8 +1941,6 @@ def __new__(cls, inner: Expr) -> "Expr": new = super().__new__(cls, inner) if not isinstance(new, cos): return new - if new.inner.is_subtraction: - new.inner = -inner if isinstance(new.inner, Sum): for t in new.inner.terms: if t == pi or t == -pi: @@ -1938,17 +1952,12 @@ class tan(TrigFunctionNotInverse): func = "tan" _func = math.tan _period = 1 + _odd = True @classproperty def reciprocal_class(cls): return cot - @cast - def __new__(cls, inner: Expr) -> Expr: - if inner.is_subtraction: - return -tan(-inner) - return super().__new__(cls, inner) - @classproperty def _special_values(cls): return {k: sin.special_values[k] / cos.special_values[k] for k in cls._SPECIAL_KEYS} @@ -1961,6 +1970,7 @@ class csc(TrigFunctionNotInverse): func = "csc" _func = lambda x: 1 / math.sin(x) reciprocal_class = sin + _odd = True @classproperty def _special_values(cls): @@ -1974,6 +1984,7 @@ class sec(TrigFunctionNotInverse): func = "sec" _func = lambda x: 1 / math.cos(x) reciprocal_class = cos + _even = True @classproperty def _special_values(cls): @@ -1988,6 +1999,7 @@ class cot(TrigFunctionNotInverse): func = "cot" _func = lambda x: 1 / math.tan(x) _period = 1 + _odd = True @classproperty def _special_values(cls): @@ -2001,6 +2013,7 @@ class asin(TrigFunction): func = "sin" is_inverse = True _func = math.asin + _odd = True def diff(self, var): return 1 / sqrt(1 - self.inner**2) * self.inner.diff(var) @@ -2027,6 +2040,14 @@ class atan(TrigFunction): func = "tan" is_inverse = True _func = math.atan + _odd = True + + def __new__(cls, inner): + # TODO: standardize special value for inverse trig functions + if inner == 1: + return pi / 4 + + return super().__new__(cls, inner) def diff(self, var): return 1 / (1 + self.inner**2) * self.inner.diff(var) diff --git a/src/simpy/transforms.py b/src/simpy/transforms.py index d470eb2..994daf8 100644 --- a/src/simpy/transforms.py +++ b/src/simpy/transforms.py @@ -22,6 +22,8 @@ cos, cot, csc, + e, + exp, log, remove_const_factor, sec, @@ -33,7 +35,7 @@ from .integral_table import check_integral_table from .linalg import invert from .polynomial import Polynomial, is_polynomial, polynomial_to_expr, rid_ending_zeros, to_const_polynomial -from .regex import count, general_contains, replace, replace_class, replace_factory +from .regex import count, general_contains, general_count, replace, replace_class, replace_factory from .simplify import pythagorean_simplification from .simplify.product_to_sum import product_to_sum_unit from .utils import ExprFn, eq_with_var, random_id @@ -935,8 +937,6 @@ def backward(self, node: Node) -> None: class GenericUSub(USub): """Generic u-substitution""" - _u: Expr = None - def check(self, node: Node) -> bool: if super().check(node) is False: return False @@ -950,7 +950,7 @@ def check(self, node: Node) -> bool: integral = remove_const_factor(integral) # assume term appears only once in integrand # because node.expr is simplified - rest = Prod(node.expr.terms[:i] + node.expr.terms[i + 1 :]) + rest = Prod(node.expr.terms[:i] + node.expr.terms[i + 1 :], skip_checks=True) if count(rest, integral) == count(rest, node.var) / count(integral, node.var): self._u = integral return True @@ -962,7 +962,37 @@ def forward(self, node: Node) -> None: du_dx = self._u.diff(node.var) new_integrand = replace((node.expr / du_dx), self._u, intermediate) node.add_child(Node(new_integrand, intermediate, self, node)) - self._u = self._u + + +class ExponentialUSub(USub): + def check(self, node: Node) -> bool: + if super().check(node) is False: + return False + + if not isinstance(node.expr, Prod): + return False + + for i, term in enumerate(node.expr.terms): + if term == exp(node.var): + # now make sure all other occurances of x are in the exponent of an exp + # like it can be exp(2x) or smtn where it can be rewritten as exp(x)^2 + + rest = Prod(node.expr.terms[:i] + node.expr.terms[i + 1 :], skip_checks=True) + if count(rest, node.var) == general_count( + rest, + lambda x: isinstance(x, Power) and x.base == e and not (x.exponent / node.var).contains(node.var), + ): + self._u = exp(node.var) + self._rest = rest + return True + + return False + + def forward(self, node: Node) -> None: + intermediate = generate_intermediate_var() + x_in_terms_of_u = log(intermediate) + new_integrand = replace(self._rest, node.var, x_in_terms_of_u) + node.add_child(Node(new_integrand, intermediate, self, node)) class CompleteTheSquare(Transform): @@ -1120,6 +1150,7 @@ def backward(self, node: Node) -> None: RewritePythagorean, InverseTrigUSub, CompleteTheSquare, + ExponentialUSub, GenericUSub, ] SAFE_TRANSFORMS: List[Type[Transform]] = [ diff --git a/tests/test_khan_academy_integrals.py b/tests/test_khan_academy_integrals.py index d8cd44a..0b39019 100644 --- a/tests/test_khan_academy_integrals.py +++ b/tests/test_khan_academy_integrals.py @@ -94,11 +94,17 @@ def test_misc(): assert_integral(3 * x**5 - x**3 + 6, 6 * x - x**4 / 4 + x**6 / 2) assert_integral(x**3 * e ** (x**4), (e ** (x**4) / 4)) + assert_definite_integral(8 * x / sqrt(1 - 4 * x**2), (0, Fraction(1, 4)), 2 - sqrt(3)) + assert_definite_integral(sin(4 * x), (0, pi / 4), Fraction(1, 2)) + assert_integral(exp(x) / (1 + exp(2 * x)), atan(exp(x))) + + +def test_usub(): # Uses generic u-sub assert_definite_integral(e**x / (1 + e**x), (log(2), log(8)), log(9) - log(3)) - assert_definite_integral(8 * x / sqrt(1 - 4 * x**2), (0, Fraction(1, 4)), 2 - sqrt(3)) - assert_definite_integral(sin(4 * x), (0, pi / 4), Fraction(1, 2)) + # ideally have this work with just log(x) + assert_definite_integral(log(abs(x)) ** 2 / x, bounds=(1, e), expected=Rat(1, 3)) def test_csc_x_squared(): @@ -141,6 +147,8 @@ def test_complete_the_square_integrals(): assert_integral(1 / (x**2 - 8 * x + 65), atan((-4 + x) / 7) / 7) assert_integral(1 / sqrt(-(x**2) - 6 * x + 40), asin((3 + x) / 7)) + assert_definite_integral(1 / (1 + 9 * x**2), (-Rat(1, 3), Rat(1, 3)), expected=pi / 6) + def test_neg_inf(): assert integrate(-(e**x), (-inf, 1)) == -e From 5c77b844ad904e4ab6f5d07a3a4f1631994964f8 Mon Sep 17 00:00:00 2001 From: Laura Gao Date: Fri, 28 Jun 2024 21:16:38 +0200 Subject: [PATCH 5/6] add more fourier examples from exam --- tests/test_fourier.py | 52 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/tests/test_fourier.py b/tests/test_fourier.py index 1aae96d..bcd6a6a 100644 --- a/tests/test_fourier.py +++ b/tests/test_fourier.py @@ -1,14 +1,19 @@ """finding the (symbolic, non-discrete) fourier series of various functions""" import random -from typing import Literal +from typing import Literal, Union import simpy as sp from simpy.debug.test_utils import eq_float +@sp.expr.cast def get_fourier_series( - f: sp.expr.Expr, T: sp.expr.Expr, x: sp.expr.Symbol, *, settings: Literal["center", "right"] = "right" + f: sp.expr.Expr, + T: sp.expr.Expr, + x: sp.expr.Symbol, + *, + settings: Union[Literal["center", "right"], sp.expr.Expr] = "right" ): """gets the fourier series of a 1d function @@ -32,7 +37,12 @@ def get_fourier_series( c_n = sp.cos(2 * sp.pi * n * x / T) s_n = sp.sin(2 * sp.pi * n * x / T) - bounds = (x, 0, T) if settings == "right" else (x, -T / 2, T / 2) + if settings == "center": + bounds = (x, -T / 2, T / 2) + elif settings == "right": + bounds = (x, 0, T) + else: + bounds = (x, settings, T + settings) a_0 = sp.integrate(f, bounds) / T a_n = 2 * sp.integrate(c_n * f, bounds) / T @@ -65,7 +75,6 @@ def test_q3(): a_n = 2 / sp.pi * sp.integrate(c_n * f, (x, 0, a)) b_n = 2 / sp.pi * sp.integrate(s_n * f, (x, 0, a)) - # for some reason a_n is negative the one from class notes idk why assert a_0 == 2 * a / sp.pi**2 expected_an = 2 * a * (1 + sp.cos(2 * a * n)) / (sp.pi**2 - (2 * a * n) ** 2) @@ -105,3 +114,38 @@ def test_even_function(): for i in range(1, 5): assert eq_float(an.evalf({"n": i}), expected_an.evalf({"n": i})) assert eq_float(bn.evalf({"n": i}), sp.expr.Float(0.0)) + + +def test_x_plus_x_squared(): + """Exam 2 q3""" + f = x + x**2 + T = 1 + a0, an, bn, summation = get_fourier_series(f, T, x) + assert a0 == sp.expr.Rat(5, 6) + assert an == 1 / (n**2 * sp.pi**2) + assert bn == -2 / (n * sp.pi) + + +def test_x_plus_x_squared_even(): + """Exam 2 q3""" + f = x + x**2 + f_neg = (-x) + (-x) ** 2 + f_tilde = sp.expr.Piecewise((f, 0, 1), (f_neg, -1, 0), var=x) + T = 2 + a0, an, bn, summation = get_fourier_series(f_tilde, T, x, settings="center") + assert a0 == sp.expr.Rat(5, 6) + assert bn == 0 + assert an == ((6 * sp.cos(n * sp.pi) - 2) / (n**2 * sp.pi**2)).expand() + + +def test_parseval(): + """idk why but the dot product with itself is 1 or smtn ?? idk i forget + + the dot product of a function with itself is 1 + why do we * 2 then? + + anyways i forget exactly what this was, but if you want to check, it's from q4 on exam 2. + """ + expr = sp.cos(3 * x / 2) ** 2 + ans = sp.integrate(expr, (x, 0, sp.pi)) * 2 / sp.pi + assert ans == 1 From 1f0f509717f8ba1d6b25faffceb28207a9eaecc1 Mon Sep 17 00:00:00 2001 From: laura gao Date: Wed, 11 Jun 2025 01:53:18 -0400 Subject: [PATCH 6/6] replacing assertion checks for symbolless integration bounds with explicit error handling Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/simpy/integration.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/simpy/integration.py b/src/simpy/integration.py index 0ccab67..26f9604 100644 --- a/src/simpy/integration.py +++ b/src/simpy/integration.py @@ -126,9 +126,12 @@ def integrate_bounds(self, expr: Expr, bounds: Tuple[Symbol, Expr, Expr]) -> Opt x, a, b = bounds if isinstance(expr, Piecewise): - assert a.symbolless - assert b.symbolless - assert all(p.lower_bound.value.symbolless and p.upper_bound.value.symbolless for p in expr.pieces) + if not a.symbolless: + raise ValueError("Lower bound 'a' must be symbolless for integration.") + if not b.symbolless: + raise ValueError("Upper bound 'b' must be symbolless for integration.") + if not all(p.lower_bound.value.symbolless and p.upper_bound.value.symbolless for p in expr.pieces): + raise ValueError("All piecewise bounds must be symbolless for integration.") total = [] for piece in expr.pieces: