diff --git a/.gitignore b/.gitignore index 8844369..374a379 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ venv*/ coverage.xml sandbox/ +# working benchmark directory +b/ # from pip install -e . *.egg-info/ diff --git a/benchmark/benchmark-profiler.py b/benchmark/benchmark-profiler.py index 1bfabe3..5b341b5 100644 --- a/benchmark/benchmark-profiler.py +++ b/benchmark/benchmark-profiler.py @@ -14,7 +14,7 @@ BENCHMARKING_SUITE = [ -5 * x**4 / (1 - x**2) ** Rat(5, 2), x**2 / sqrt(1 - x**3), - (Rat(1, 15) - Rat(1, 360) * (x - 6)) * (1 - (40 - x) ** 2 / 875), + # (Rat(1, 15) - Rat(1, 360) * (x - 6)) * (1 - (40 - x) ** 2 / 875), # looks too ugly on the plot cos(w * x - phi) * cos(w * x), sin(2 * x) / cos(2 * x), cos(x) ** 2, diff --git a/benchmark/integration_log.png b/benchmark/integration_log.png new file mode 100644 index 0000000..11841cb Binary files /dev/null and b/benchmark/integration_log.png differ diff --git a/benchmark/integration_log.txt b/benchmark/integration_log.txt new file mode 100644 index 0000000..05e1751 --- /dev/null +++ b/benchmark/integration_log.txt @@ -0,0 +1,93 @@ +Integrand: time taken (s) + +sec(x)^4*tan(x)^5: 0.07898354530334473 +-5*x^4/(-x^2 + 1)^(5/2): 0.06334471702575684 +5*csc(x)^2: 0.0379023551940918 +cos(2*x)*sin(2*x)*sin(x): 0.03209376335144043 +sec(2*x)*tan(2*x): 0.03049325942993164 +2*cot(x)*csc(x): 0.02516341209411621 +1/sqrt(-x^2 + 10*x + 11): 0.02431011199951172 +sin(pi*x)*x^2: 0.02408909797668457 +tan(x)^4: 0.02292919158935547 +sin(x)^2*cos(x)^3: 0.022572040557861328 +cos(w*x)*cos(w*x - phi): 0.02143096923828125 +sin(x)^5: 0.019292116165161133 +x^2/sqrt(-x^3 + 1): 0.01819753646850586 +asin(x): 0.016515731811523438 +e^x/(e^x + 1): 0.01630997657775879 +sin(2*x)/cos(2*x): 0.013149738311767578 +ln(x + 6)/x^2: 0.012115240097045898 +x*e^(-x): 0.011050939559936523 +cos(w*x)*sin(w*x): 0.010923147201538086 +cos(x)^2: 0.010917901992797852 +sin(x)^2: 0.008718252182006836 +x*cos(x): 0.00684356689453125 +(2*x - 5)^10: 
0.00171661376953125 +(x - 5)/(-2*x + 2): 0.00013136863708496094 +(x + 8)/(x*(x + 6)): 0.00010347366333007812 +6*e^x: 8.034706115722656e-05 + + + +Solution tree of the integral with most time spent: +[0] {4} sec(x)^4*tan(x)^5 1/(4*cos(x)^4) - 1/(3*cos(x)^6) + 1/(8*cos(x)^8) (NoneType) (solved) +[1] {4} sin(x)^5/cos(x)^9 1/(4*cos(x)^4) - 1/(3*cos(x)^6) + 1/(8*cos(x)^8) (RewriteTrig) (solved) +[2] {6} sin(x)*(-cos(x)^2 + 1)^2/cos(x)^9 1/(4*cos(x)^4) - 1/(3*cos(x)^6) + 1/(8*cos(x)^8) (RewritePythagorean) (solved) +[3] {5} sin(x)/cos(x)^9 + sin(x)/cos(x)^5 - 2*sin(x)/co... 1/(4*cos(x)^4) - 1/(3*cos(x)^6) + 1/(8*cos(x)^8) (Expand) (solved) +[4] {4} sin(x)/cos(x)^9 1/(8*cos(x)^8) (Additivity) (solved) +[5] {4} filler 1/(8*cos(x)^8) (ByParts) (SOLUTION) + +[5] {2} -1/(u_2)^9 1/(8*(u_2)^8) (GenericUSub) (solved) +[6] {2} filler 1/(8*(u_2)^8) (ByParts) (SOLUTION) + +[4] {4} sin(x)/cos(x)^5 1/(4*cos(x)^4) (Additivity) (solved) +[5] {4} filler 1/(4*cos(x)^4) (ByParts) (SOLUTION) + +[5] {2} -1/(u_3)^5 1/(4*(u_3)^4) (GenericUSub) (solved) +[6] {2} filler 1/(4*(u_3)^4) (ByParts) (SOLUTION) + +[4] {4} -2*sin(x)/cos(x)^7 -1/(3*cos(x)^6) (Additivity) (solved) +[5] {4} sin(x)/cos(x)^7 1/(6*cos(x)^6) (PullConstant) (solved) +[6] {4} filler 1/(6*cos(x)^6) (ByParts) (SOLUTION) + + + + +[0] {4} sec(x)^4*tan(x)^5 (NoneType) (solved) +[1] {4} sin(x)^5/cos(x)^9 (RewriteTrig) (solved) +[2] {6} sin(x)*(-cos(x)^2 + 1)^2/cos(x)^9 (RewritePythagorean) (solved) +[3] {5} sin(x)/cos(x)^9 + sin(x)/cos(x)^5 - 2*sin(x)/co... 
(Expand) (solved) +[4] {4} sin(x)/cos(x)^9 (Additivity) (solved) +[5] {4} filler (ByParts) (SOLUTION) + +[5] {4} csc(x)^8*tan(x)^9 (RewriteTrig) (stale) (UNSET) + +[5] {4} sec(x)^8/cot(x) (RewriteTrig) (stale) (UNSET) + +[5] {2} -1/(u_2)^9 (GenericUSub) (solved) +[6] {2} filler (ByParts) (SOLUTION) + +[4] {4} sin(x)/cos(x)^5 (Additivity) (solved) +[5] {4} filler (ByParts) (SOLUTION) + +[5] {4} csc(x)^4*tan(x)^5 (RewriteTrig) (stale) (UNSET) + +[5] {4} sec(x)^4/cot(x) (RewriteTrig) (stale) (UNSET) + +[5] {2} -1/(u_3)^5 (GenericUSub) (solved) +[6] {2} filler (ByParts) (SOLUTION) + +[4] {4} -2*sin(x)/cos(x)^7 (Additivity) (solved) +[5] {4} sin(x)/cos(x)^7 (PullConstant) (solved) +[6] {4} filler (ByParts) (SOLUTION) + +[6] {4} csc(x)^6*tan(x)^7 (RewriteTrig) (stale) (UNSET) + +[6] {4} sec(x)^6/cot(x) (RewriteTrig) (stale) (UNSET) + +[6] {2} -1/(u_4)^7 (GenericUSub) (stale) (UNSET) + +[1] {4} csc(x)^4*tan(x)^9 (RewriteTrig) (stale) (UNSET) + +[1] {4} sec(x)^4/cot(x)^5 (RewriteTrig) (stale) (UNSET) + diff --git a/benchmark/readme.md b/benchmark/readme.md index 857d57a..e7d4808 100644 --- a/benchmark/readme.md +++ b/benchmark/readme.md @@ -20,3 +20,7 @@ The output would look something like this: 142 0.000 0.000 0.374 0.003 integration.py:147(_cycle) ``` + +## Logging + +`benchmark-profiler.py` also keeps track of the time it takes for each integral to run and creates a plot of it. See `integration_log.png` for the bar chart; see `integration_log.txt` for some more detailed numerical breakdowns. diff --git a/src/simpy/debug/logger.py b/src/simpy/debug/logger.py index 064c0a8..2d28140 100644 --- a/src/simpy/debug/logger.py +++ b/src/simpy/debug/logger.py @@ -1,3 +1,28 @@ +"""See benchmark/benchmark-profiler.py for an example usage of the logger. + +``` +import simpy as sp +from simpy.debug.logger import Logger +from simpy.integration import Integration + +logger = Logger() +Integration.logger = logger + +# do some integration as normal... 
+x = sp.symbols('x') +sp.integrate(x ** 2) + +logger.dump() # dumps information into integration_log.txt +logger.plot() # creates a bar chart of integral speeds to integration_log.png + +``` + +Yeah, this is kind of jank, mutating the whole integration class... What can I say, this whole debug module is jank. + +TODO: make some consistent organizational decisions. Maybe move benchmark to this folder or move logger to +benchmark. Probably the former. Maybe BENCHMARK_SUITE can be exported from the debug module. +""" + import time from typing import Dict, NamedTuple @@ -25,6 +50,12 @@ def __init__(self): self._data = {} def log(self, expr: Expr, time_spent: float, root: Node): + """Log an integration entry. + + expr: integrand + time_spent: time taken to integrate expr, in seconds + root: the root node of the integration tree + """ self._data[str(expr)] = Datum(expr, time_spent, root) @property @@ -39,12 +70,14 @@ def dump(self): self.sort() with open("integration_log.txt", "w") as f: + f.write("Integrand: time taken (s)") + f.write("\n\n") for k, v in self._data.items(): f.write(f"{k}: {v.time_spent}\n") # For the one with the most time spent, print the tree.
f.write("\n\n\n") - f.write("Most time spent: \n") + f.write("Solution tree of the integral with most time spent: \n") print_solution_tree(self._data[list(self._data.keys())[0]].root, func=lambda x: f.write(f"{x}\n")) f.write("\n\n\n") print_tree(self._data[list(self._data.keys())[0]].root, func=lambda x: f.write(f"{x}\n")) @@ -56,6 +89,9 @@ def plot(self): x = list(self._data.keys()) y = [v.time_spent for v in self._data.values()] plt.bar(x, y) + plt.ylabel("Time taken to integrate (s)") + plt.xticks(rotation=90) # rotate labels vertically + plt.tight_layout() # automatically adjust spacing (needed to show the entirety of the vertical labels) plt.savefig("integration_log.png") diff --git a/src/simpy/expr.py b/src/simpy/expr.py index 0f8311c..b8e5e27 100644 --- a/src/simpy/expr.py +++ b/src/simpy/expr.py @@ -228,6 +228,11 @@ def evalf(self, subs: Optional[Dict[str, "Expr"]] = None) -> "Expr": def children(self) -> List["Expr"]: raise NotImplementedError(f"Cannot get children of {self.__class__.__name__}") + @property + def childless(self) -> bool: + """Returns True if it's a basic element like a symbol or a number that doesn't have any subcomponents at all.""" + return len(self.children()) == 0 + def contains(self: "Expr", var: "Symbol") -> bool: is_var = isinstance(self, Symbol) and self.name == var.name return is_var or any(e.contains(var) for e in self.children()) @@ -326,9 +331,11 @@ def _nesting_without_factor(expr: "Expr") -> int: else: expr2 = remove_const_factor(expr) - if isinstance(expr2, (Symbol, Any_)): + if isinstance(expr2, Symbol) or isinstance(expr2, Any_) and not expr2.is_constant: ans = 1 - elif len(expr2.symbols()) == 0 and len(get_anys(expr2)) == 0: + elif len(expr2.symbols()) == 0 and ( + len(get_anys(expr2)) == 0 or all([a.is_constant for a in get_anys(expr2)]) + ): ans = 0 else: ans = 1 + max(_nesting_without_factor(sub_expr) for sub_expr in expr2.children()) @@ -2065,7 +2072,7 @@ def symbols(symbols: str) -> Union[Symbol, List[Symbol]]: 
@cast -def diff(expr: Expr, var: Optional[Symbol]) -> Expr: +def diff(expr: Expr, var: Optional[Symbol] = None) -> Expr: """Takes the derivative of expr relative to var. If expr has only one symbol in it, var doesn't need to be specified.""" if not hasattr(expr, "diff"): raise NotImplementedError(f"Differentiation of {expr} not implemented") diff --git a/src/simpy/regex.py b/src/simpy/regex.py index e7480c0..c59b686 100644 --- a/src/simpy/regex.py +++ b/src/simpy/regex.py @@ -1,24 +1,41 @@ -"""Custom library for checking what shit exists. Replaces searching the repr with regex. +"""Custom library for checking what exists within exprs. Replaces searching the repr with regex. This module is currently still developmental. It does the job often but is not promised to be robust outside of the cases it is currently used for. Use with caution. + +TODO: write a regex quickstart guide. Maybe spin this off into a subfolder and make the outward facing API more +intuitive. Perhaps make it similar to the regex library. ++ Unit testing with comprehensively thought-out cases. """ from collections import defaultdict from dataclasses import dataclass, fields from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Type -from .expr import Expr, Power, Prod, Rat, SingleFunc, Sum, Symbol, cast, log +from .expr import Expr, Num, Power, Prod, Rat, SingleFunc, Sum, Symbol, cast, log from .utils import ExprCondition, ExprFn, OptionalExprFn, random_id class Any_(Expr): _fields_already_casted = True - def __init__(self, key=None, *, is_multiple_terms=False): + def __init__( + self, key: str = None, condition: Callable[[Expr], bool] = None, *, is_constant=False, is_multiple_terms=False + ): + """ + key: unique identifier of each Any_ instance. + this is used for if we want each Any_ object match to be the same thing. + condition: a condition for if sth matches an Any_? by default it's nothing. + is_constant: a decorator actually only for sorting purposes. 
if True, this Any_ object will be considered a + constant for sorting purposes. for matching purposes, pls specify in the condition callable. + """ if not key: key = random_id(10) + if condition is None: + condition = lambda e: True # default condition is always true self._key = key + self._condition = condition + self._is_constant = is_constant self._is_multiple_terms = is_multiple_terms super().__post_init__() @@ -26,6 +43,14 @@ def __init__(self, key=None, *, is_multiple_terms=False): def key(self) -> str: return self._key + @property + def condition(self) -> Callable[[Expr], bool]: + return self._condition + + @property + def is_constant(self) -> bool: + return self._is_constant + @property def is_multiple_terms(self) -> bool: return self._is_multiple_terms @@ -33,9 +58,6 @@ def is_multiple_terms(self) -> bool: def __eq__(self, other): if isinstance(other, Any_): return self.key == other.key - # if isinstance(other, Expr): - # return True - # return NotImplemented return False def __repr__(self) -> str: @@ -59,6 +81,9 @@ def latex(self) -> str: any_ = Any_() +any_constant = Any_( + key="constant", condition=lambda e: isinstance(e, Num), is_constant=True +) # matches any numerical constant. # smallTODO: make this a namedtuple EqResult = Dict[Literal["success", "factor", "rest", "matches"], Any] @@ -253,6 +278,11 @@ def _is_sum_eq(expr: Sum, query: Sum) -> EqResult: return self._result def _eq(self, expr: Any, query: Any) -> bool: + """Recursively checks if `expr` matches `query` one-for-one, no up to factors/sums. + This method is the recursed component. 
+ """ + + ## base cases ## if isinstance(expr, list): if not (isinstance(query, list) and len(expr) == len(query)): return False @@ -262,11 +292,15 @@ def _eq(self, expr: Any, query: Any) -> bool: if not isinstance(expr, Expr) or not isinstance(query, Expr): return False if isinstance(query, Any_): + if not query.condition(expr): + return False self._matches[query.key].append(expr) return True if not query.has(Any_): return False + # ok bro idk why i did this i kinda hate it. + # like what if we just check if both query and expr are products and have special handling for that? if not self._is_divide: # You don't get to divide if we already is --- prevents inf recursion. one, quotient_matches = divide_anys(query, expr) @@ -292,6 +326,8 @@ def _eq(self, expr: Any, query: Any) -> bool: if not expr.__class__ == query.__class__: return False + + ## and here we recurse. ## return all(self._eq(getattr(expr, field.name), getattr(query, field.name)) for field in fields(expr)) @@ -315,8 +351,8 @@ def _make_factors_list(expr: Expr) -> List[Expr]: anys.append(t) else: terms.append(t) - if len([t for t in expr.terms if isinstance(t, Any_)]) > 1: - raise NotImplementedError(f"{expr} is ambiguous") + # if len([t for t in expr.terms if isinstance(t, Any_)]) > 1: + # raise NotImplementedError(f"{expr} is ambiguous") if len(anys) > 0: terms.extend(anys) if len(any_factors) > 0: @@ -326,6 +362,9 @@ def _make_factors_list(expr: Expr) -> List[Expr]: numfactors = _make_factors_list(num) denfactors = _make_factors_list(denom) matches = defaultdict(list) + + # For every factor in the numerator, try to find a matching factor in the denominator. + # If a match is found, remove both factors from their respective lists. And add the match to the matches dict. 
for i in range(len(numfactors)): f = numfactors[i] for j in range(len(denfactors)): @@ -365,6 +404,28 @@ def count(expr: Expr, query: Expr) -> int: return sum(count(e, query) for e in expr.children()) +def contains(expr: Expr, query: Expr) -> EqResult: + """Checks if `query` appears in `expr`. Assumes query contains Any_ objects. + Exact any-matches only, no up to factor or sum. + Returns a results dictionary like eq() does. with `success` and `matches` keys. + """ + ## base cases ## + eq_output = eq(expr, query) + if eq_output["success"]: + return eq_output + if expr.childless: + return {"success": False} + + ## recursive cases ## + # this isn't super sophisticated for dupes but it's fine for now. + for e in expr.children(): + eq_output_ = contains(e, query) + if eq_output_["success"]: + return eq_output_ + + return {"success": False} + + def contains_cls(expr: Expr, cls: Type[Expr]) -> bool: if isinstance(expr, cls): return True @@ -382,6 +443,10 @@ def general_count(expr: Expr, condition: ExprCondition) -> int: def general_contains(expr: Expr, condition: ExprCondition) -> bool: + """contains with a condition function instead of any. + + this is sorta-legacy --- i think we should use contains instead; it's cuter. any-matches are cute. 
+ """ if condition(expr): return True return any(general_contains(e, condition) for e in expr.children()) diff --git a/src/simpy/simplify/product_to_sum.py b/src/simpy/simplify/product_to_sum.py index 26d81e0..984a2bb 100644 --- a/src/simpy/simplify/product_to_sum.py +++ b/src/simpy/simplify/product_to_sum.py @@ -1,7 +1,9 @@ from typing import List, Optional, Union -from ..expr import Expr, Power, Prod, Rat, Sum, cos, remove_const_factor, sin +from ..expr import Expr, Power, Prod, Rat, Sum, TrigFunctionNotInverse, cos, nesting, remove_const_factor, sin +from ..regex import Any_, any_, eq from ..utils import count_symbols +from .utils import is_simpler def _perform_on_terms( @@ -137,3 +139,52 @@ def product_to_sum(expr: Expr) -> Optional[Expr]: if len(final.terms) == len(expr.terms) and count_symbols(final) < count_symbols(expr): # This ensures that e.g. 2*cos(x)*sin(2*x)/3 - cos(2*x)*sin(x)/3 simplifies to -2*sin(x)**3/3 + sin(x) return final + + +def double_angle(expr: Expr) -> Optional[Expr]: + """Applies double angle + Used in simplify + + Assumes that expr.has(TrigFunctionNotInverse) == True + """ + + if not isinstance(expr, (sin, cos)): + return + + any_even_number = Any_( + "even_number", lambda expr: isinstance(expr, Rat) and expr.denominator == 1 and expr % 2 == 0, is_constant=True + ) + query = any_even_number * any_ + out = eq(expr.inner, query) + + if not out["success"]: + return + + x = out["matches"][any_.key] + num = out["matches"]["even_number"] + + if isinstance(expr, sin): + # TODO: make this robust through iteration or recursion + if num == 2: + final = 2 * sin(x) * cos(x) + elif num == 4: + final = 4 * sin(x) * cos(x) - 8 * sin(x) ** 3 * cos(x) + elif num == 6: + final = 6 * sin(x) * cos(x) - 32 * sin(x) ** 3 * cos(x) + 32 * sin(x) ** 5 * cos(x) + else: + return + # raise NotImplementedError("Double angle for sin with num > 4 is not implemented") + else: + return + # raise NotImplementedError("Double angle for cos is not implemented") + + if not 
final.has(TrigFunctionNotInverse) or is_simpler(final, expr): + # If final doesn't have any trig functions, it's definitely simpler. + # this can def be ... improved lol. + # currently im basing it off of the + # sin(4*asin(x/2)) -> 2*x*sqrt(-x^2/4 + 1) - x^3*sqrt(-x^2/4 + 1) + # case. + # like that shit is not simpler by any other metric other than it doesn't have the sin(asin) nesting yk. + # nesting of 2 trig funcs is always ugllyyyyyy. maybe the most robust metric should just rid those ugly + # nests. + return final diff --git a/src/simpy/simplify/simplify.py b/src/simpy/simplify/simplify.py index 2960f29..7db2208 100644 --- a/src/simpy/simplify/simplify.py +++ b/src/simpy/simplify/simplify.py @@ -15,13 +15,15 @@ cot, csc, log, + nesting, sec, sin, tan, ) from ..regex import any_, eq, general_contains, kinder_replace, kinder_replace_many, replace_class, replace_factory -from ..utils import ExprFn, count_symbols -from .product_to_sum import product_to_sum +from ..utils import ExprFn +from .product_to_sum import double_angle, product_to_sum +from .utils import is_simpler def expand_logs(expr: Expr, **kwargs) -> Expr: @@ -128,6 +130,10 @@ def sin_cos_condition(expr: Sum): rest = result["rest"] return factor * perform(inner) + rest + # right now we are doing the sec/tan simplification in a separate function because its implementation is a bit more + # complex? + # but i feel like maybe ideally it should be in this function along with cos^2(x) + sin^2(x) = 1 and we can have + # them both done together with more advanced regex. but for now, this is fine! # other_table = [ # (r"^sec\((.+)\)\^2$", r"^-tan\((.+)\)\^2$", Const(1)), # ] @@ -247,7 +253,7 @@ def reciprocate_trigs(expr: Expr, **kwargs) -> Expr: def simplify(expr: Expr) -> Expr: """Simplifies an expression. - This is the general one that does all heuristics & is for aesthetics (& comparisons). + This is the general simplification function that does all heuristics & is for aesthetics (& comparisons). 
Use more specific simplification functions in integration please. """ if expr.expandable(): @@ -270,7 +276,7 @@ def simplify(expr: Expr) -> Expr: return expr -def trig_simplify(expr): +def trig_simplify(expr: Expr) -> Tuple[Expr, bool]: # reciprocate and combine trigs is last because sometimes the pythag complex simplification will # generate new trigs in the num/denom that can be simplified down. expr, is_hit_1 = kinder_replace_many( @@ -281,10 +287,15 @@ def trig_simplify(expr): ) expr, is_hit_2 = kinder_replace_many( expr, - [_combine_trigs, product_to_sum, sectan], + [_combine_trigs, product_to_sum, double_angle, sectan], overarching_cond=lambda x: x.has(TrigFunctionNotInverse), verbose=True, ) + # sometimes, the double angle will change a non-sum to a sum, so we need to check again for possible expanding. + if is_hit_2 and expr.expandable(): + # q: should we check that the expanded stuff is simpler? for now im not doing that. + # because it is a bit expensive and we can do that later if we want. 
+ expr = expr.expand() return expr, is_hit_1 or is_hit_2 @@ -342,7 +353,7 @@ def perform(e: Power): for s in secs: new = replace_factory(condition, perform)(s) - new = new.expand() if new.expandable() else s + new = new.expand() if new.expandable() else new assert isinstance(new, Sum) tans.extend(new.terms) @@ -351,17 +362,8 @@ def perform(e: Power): # This must mean that we simplified the sum into one term return new_sum - if len(new_sum.terms) < sum.terms: + if is_simpler(new_sum, sum): return new_sum - -def is_simpler(e1, e2) -> bool: - """returns whether e1 is simpler than e2""" - c1 = count_symbols(e1) - c2 = count_symbols(e2) - return c1 < c2 - # if c1 < c2: - # return True - # if c1 == c2: - # return len(repr(e1)) < len(repr(e2)) - # return False + # if we didn't simplify, return original + return sum diff --git a/src/simpy/simplify/utils.py b/src/simpy/simplify/utils.py new file mode 100644 index 0000000..73907ad --- /dev/null +++ b/src/simpy/simplify/utils.py @@ -0,0 +1,15 @@ +from ..expr import nesting +from ..utils import count_symbols + + +def is_simpler(e1, e2) -> bool: + """returns whether e1 is simpler than e2""" + c1 = count_symbols(e1) + c2 = count_symbols(e2) + if c1 < c2: + return True + + if c1 == c2: + return nesting(e1) < nesting(e2) + + return False diff --git a/src/simpy/transforms.py b/src/simpy/transforms.py index b1347b3..94b0762 100644 --- a/src/simpy/transforms.py +++ b/src/simpy/transforms.py @@ -17,6 +17,7 @@ Symbol, TrigFunction, TrigFunctionNotInverse, + acos, asin, atan, cos, @@ -33,7 +34,7 @@ from .integral_table import check_integral_table from .linalg import invert from .polynomial import Polynomial, is_polynomial, polynomial_to_expr, rid_ending_zeros, to_const_polynomial -from .regex import count, general_contains, replace, replace_class, replace_factory +from .regex import Any_, contains, count, general_contains, replace, replace_class, replace_factory from .simplify import pythagorean_simplification from 
.simplify.product_to_sum import product_to_sum_unit from .utils import ExprFn, eq_with_var, random_id @@ -211,6 +212,7 @@ def check(self, node: Node) -> bool: class USub(Transform, ABC): """Base class for u-substituion transforms.""" + # u (new var) written in terms of x (old var) _u: Expr = None def backward(self, node: Node) -> None: @@ -377,8 +379,6 @@ def _get_last_heuristic_transform(node: Node, tup=(PullConstant, Additivity)): return node.transform -# Let's just add all the transforms we've used for now. -# and we will make this shit good and generalized later. class TrigUSub2(USub): """ u-sub of a trig function @@ -410,8 +410,8 @@ def check(self, node: Node) -> bool: if super().check(node) is False: return False - # Since B and C essentially undo each other, we want to make sure that the last - # heuristic transform wasn't C. + # Since TrigUSub2 and InverseTrigUSub essentially undo each other, we want to make sure that the last + # heuristic transform wasn't InverseTrigUSub. t = _get_last_heuristic_transform(node) if isinstance(t, InverseTrigUSub): @@ -966,6 +966,105 @@ def forward(self, node: Node) -> None: self._u = self._u +asec = lambda x: acos(1 / x) + + +class TrigUSub(USub): + """This is the substitution of the form x = 4*cos(theta) + + If the integrand contains sqrt(a^2 - x^2), sqrt(a^2 + x^2), or sqrt(-a^2 + x^2), + you can do this trig sub. + 3 cases: + 1. sqrt(a^2 - x^2): x = a*sin(theta) + 2. sqrt(a^2 + x^2): x = a*tan(theta) + 3. 
sqrt(-a^2 + x^2): x = a*sec(theta) + """ + + _a: Rat = None # constant in the square root + _exponent: Rat = None # exponent of the square root (includes the square root, is a fraction w denom = 2) + _case: int = None # 0 for sqrt(a^2 - x^2), 1 for sqrt(a^2 + x^2), 2 for sqrt(-a^2 + x^2) + + def check(self, node: Node): + """Check if node.expr contains sqrt(a^2 - x^2) or sqrt(a^2 + x^2) or sqrt(-a^2 + x^2) where a is a constant.""" + if super().check(node) is False: + return False + + def squared_integer_condition(expr: Expr) -> bool: + return isinstance(expr, Rat) and isinstance(sqrt(expr), Rat) + + a_squared = Any_("squared_integer", squared_integer_condition, is_constant=True) + any_square_root_exponent = Any_( + "square_root_exponent", lambda expr: isinstance(expr, Rat) and expr.denominator == 2, is_constant=True + ) + queries = [ + (a_squared - node.var**2) ** any_square_root_exponent, + (node.var**2 + a_squared) ** any_square_root_exponent, + (node.var**2 - a_squared) ** any_square_root_exponent, + ] + for i, query in enumerate(queries): + out = contains(node.expr, query) + if out["success"]: + self._a = sqrt(out["matches"]["squared_integer"]) + self._exponent = out["matches"]["square_root_exponent"] + self._case = i + return True + + return False + + def forward(self, node: Node) -> None: + theta = generate_intermediate_var() + a = self._a + x = node.var + if self._case == 0: + # in the case of sqrt(a^2 - x^2): + # x = a * sin(theta) + # dx = a * cos(theta) d(theta) + dx_dtheta = a * cos(theta) + + # so we replace x = a * sin(theta), this effectively leads to + # replacing sqrt(a^2 - x^2) = a * cos(theta) + theta_expr = replace( + node.expr, + (a**2 - x**2) ** self._exponent, + (a * cos(theta)) ** self._exponent.numerator, + ) + if theta_expr.contains(x): + theta_expr = replace(theta_expr, x, a * sin(theta)) + self._u = asin(x / a) # theta in terms of x + + elif self._case == 1: + # in the case of sqrt(a^2 + x^2) + # x = a * tan(theta) + # dx = a *
sec(theta) ** 2 * d(theta) + dx_dtheta = self._a * sec(theta) ** 2 + + # sqrt(a^2 + x^2) = a * sec(theta) + theta_expr = replace( + node.expr, + (a**2 + x**2) ** self._exponent, + (a * sec(theta)) ** self._exponent.numerator, + ) + if theta_expr.contains(node.var): + theta_expr = replace(theta_expr, x, a * tan(theta)) + self._u = atan(x / a) # theta in terms of x + + else: + # x = a * sec(theta) + dx_dtheta = a * sec(theta) * tan(theta) + # sqrt(x^2 - a^2) = a * tan(theta) + theta_expr = replace( + node.expr, + (-(a**2) + x**2) ** self._exponent, + (a * tan(theta)) ** self._exponent.numerator, + ) + if theta_expr.contains(x): + theta_expr = replace(theta_expr, x, a * sec(theta)) + self._u = asec(x / a) # theta in terms of x + + new_integrand = theta_expr * dx_dtheta + node.add_child(Node(new_integrand, theta, self, node)) + + class CompleteTheSquare(Transform): """Integration via completing the square""" @@ -1120,6 +1219,7 @@ def backward(self, node: Node) -> None: RewriteTrig, RewritePythagorean, InverseTrigUSub, + TrigUSub, CompleteTheSquare, GenericUSub, ] diff --git a/tests/test_khan_academy_integrals.py b/tests/test_khan_academy_integrals.py index 201af52..c385a6e 100644 --- a/tests/test_khan_academy_integrals.py +++ b/tests/test_khan_academy_integrals.py @@ -202,3 +202,32 @@ def test_more_complicated_trig(): expr = tan(x) ** 5 * sec(x) ** 4 expected_ans = tan(x) ** 6 / 6 + tan(x) ** 8 / 8 assert_integral(expr, expected_ans) + + +def test_trigonometric_substitution(): + # that's this one: https://www.khanacademy.org/math/integral-calculus/ic-integration/ic-trig-substitution/e/integration-using-trigonometric-substitution + expr = (4 - x**2) ** Fraction(3, 2) + expected_ans = 5 * x * sqrt(1 - x**2 / 4) - x**3 * sqrt(1 - x**2 / 4) / 2 + 6 * asin(x / 2) + ans = integrate(expr) + assert_eq_plusc(expected_ans, ans) + + +def test_trigonometric_substitution_tan_sub(): + expr = 1 / (x**2 + 4) ** Fraction(3, 2) + ans = integrate(expr) + expected_ans = x / (8 * sqrt(x**2
/ 4 + 1)) + assert_eq_plusc(expected_ans, ans) + + +def test_trig_sub_sec_sub(): + # this one is not from KH + # perhaps I should stop sorting tests using KH or not but through transform or smtn. + expr = 3 * (25 - x**2) ** Fraction(5 / 2) + ans = integrate(expr) + expected_ans = ( + 5 * x**5 * sqrt(-(x**2) / 25 + 1) / 2 + + 103125 * x * sqrt(-(x**2) / 25 + 1) / 16 + - 1625 * x**3 * sqrt(-(x**2) / 25 + 1) / 8 + + 234375 * asin(x / 5) / 16 + ) + assert_eq_plusc(ans, expected_ans) diff --git a/tests/test_regex.py b/tests/test_regex.py index 8a745cd..b6b9fd5 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -2,7 +2,7 @@ from test_utils import x, y from simpy.expr import * -from simpy.regex import Any_, any_, eq +from simpy.regex import Any_, any_, any_constant, contains, eq def test_any_basic(): @@ -18,6 +18,14 @@ def test_sort_anys(): assert eq(sin(x) * sec(x), sec(any_) * sin(any_)) +def test_eq_with_different_anys(): + any2 = Any_() + expr = sin(x) + cos(y) + query = sin(any_) + cos(any2) + out = eq(expr, query) + assert out["success"] + + @pytest.mark.parametrize( ["sum", "expected"], [ @@ -95,3 +103,46 @@ def test_cofounder(): query = -sin(any_) ** 2 + 1 out = eq(expr, query, up_to_factor=True, up_to_sum=True) assert out["success"] is False + + +def test_any_constant(): + # Tests that any_constant matches constants. + expr = 2 * x + 3 + query = 2 * x + any_constant + out = eq(expr, query) + assert out["success"] + assert out["matches"] == 3 + + expr = 2 * x + 5 * y + query = 2 * x + any_constant * y + out = eq(expr, query) + assert out["success"] + assert out["matches"] == 5 + + +def test_any_constant_fail(): + # Tests that any_constant does not match variables. 
+ expr = 2 * x + y + query = 2 * x + any_constant + out = eq(expr, query) + assert not out["success"] + + +def test_any_constant_with_multiple_anys(): + expr = 2 * x + 3 + query = 2 * any_ + any_constant + out = eq(expr, query) + assert out["success"] + + +def test_contains(): + query = log(sin(any_) ** 2 + cos(any_) ** 2) + expr = (log(sin(x) ** 2 + cos(x) ** 2) + 3) ** 2 + assert contains(expr, query)["success"] + assert contains(expr, query)["matches"] == x + + +def test_contains_fail(): + query = log(sin(any_) ** 2 + cos(any_) ** 2) + expr = (log(sin(x) ** 2 + cos(x) ** 3) + 3) ** 2 + 1 + assert not contains(expr, query)["success"] diff --git a/tests/test_simplify.py b/tests/test_simplify.py new file mode 100644 index 0000000..8313884 --- /dev/null +++ b/tests/test_simplify.py @@ -0,0 +1,9 @@ +from test_utils import assert_eq_strict, x + +from simpy.expr import * +from simpy.simplify import simplify + + +def test_simplifies_when_expanding_is_simpler(): + expr = 2 * (2 * x + 3) + assert_eq_strict(simplify(expr), 4 * x + 6) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 60b23b8..e5ede10 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -45,6 +45,7 @@ def test_polynomial_division(): tr.forward(test_node) ans = test_node.children[0].expr + assert_eq_strict(ans, -x + x / (-(x**2) + 1)) def test_complete_the_square(): diff --git a/tests/test_trig_simplify.py b/tests/test_trig_simplify.py index 50257b5..ecfcdcf 100644 --- a/tests/test_trig_simplify.py +++ b/tests/test_trig_simplify.py @@ -68,6 +68,19 @@ def test_pts(): assert_simplified(e1, e2) +def test_sectan_simple(): + expr = sec(x) ** 2 - tan(x) ** 2 + assert_simplified(expr, 1) + + +def test_sectan_plus(): + # this one shouldn't be simplified to 1 + # but it can be rewritten to 1 + 2 * tan(x) ** 2 + # the fact that we replace sec^2(x) with tan^2(x) + 1 instead of the other way around is kinda arbitrary. 
+ expr2 = sec(x) ** 2 + tan(x) ** 2 + assert_simplified(expr2, 1 + 2 * tan(x) ** 2) + + def test_sectan(): # need to replace sec^2(x) with tan^2(x) + 1 e1 = 1 / (4 * cos(x) ** 4) - 1 / (3 * cos(x) ** 6) + 1 / (8 * cos(x) ** 8) diff --git a/tests/test_utils.py b/tests/test_utils.py index f984e67..1201171 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -35,14 +35,14 @@ def _assert_eq_plusc(a, b, *vars) -> Tuple[bool, Expr]: diff = a - b if len(diff.symbols()) == 0 or vars and all(var not in diff.symbols() for var in vars): return True, None - diff = simplify_to_same_standard(diff) + diff = _simplify_to_same_standard(diff) if not vars: return len(diff.symbols()) == 0, diff else: return all(var not in diff.symbols() for var in vars), diff -def simplify_to_same_standard(expr: Expr) -> Expr: +def _simplify_to_same_standard(expr: Expr) -> Expr: if expr.expandable(): expr2 = expr.expand() else: @@ -62,8 +62,8 @@ def assert_eq_value(a: Expr, b: Expr): """Tests that the values of a & b are the same in spirit, regardless of how they are represented with the Expr data structures.""" if a == b: return - diff = simplify_to_same_standard(a - b) - assert diff == 0, f"a != b, {simplify_to_same_standard(a)} != {simplify_to_same_standard(b)}" + diff = _simplify_to_same_standard(a - b) + assert diff == 0, f"a != b, {_simplify_to_same_standard(a)} != {_simplify_to_same_standard(b)}" @cast