Commit 89082a8

update
1 parent 97fa3f8 commit 89082a8

3 files changed

Lines changed: 417 additions & 0 deletions


minitorch/scalar.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Iterable, Optional, Sequence, Tuple, Type, Union

import numpy as np

from .autodiff import Context, Variable, backpropagate, central_difference
from .scalar_functions import (
    EQ,
    LT,
    Add,
    Exp,
    Inv,
    Log,
    Mul,
    Neg,
    ReLU,
    ScalarFunction,
    Sigmoid,
)

ScalarLike = Union[float, int, "Scalar"]


@dataclass
class ScalarHistory:
    """`ScalarHistory` stores the history of `Function` operations that were
    used to construct the current Variable.

    Attributes
    ----------
    last_fn : The last Function that was called.
    ctx : The context for that Function.
    inputs : The inputs that were given when `last_fn.forward` was called.

    """

    last_fn: Optional[Type[ScalarFunction]] = None
    ctx: Optional[Context] = None
    inputs: Sequence[Scalar] = ()


# ## Task 1.2 and 1.4
# Scalar Forward and Backward

_var_count = 0


@dataclass
class Scalar:
    """A reimplementation of scalar values for autodifferentiation
    tracking. Scalar Variables behave as closely as possible to standard
    Python numbers while also tracking the operations that led to the
    number's creation. They can only be manipulated by
    `ScalarFunction`.
    """

    data: float
    history: Optional[ScalarHistory] = field(default_factory=ScalarHistory)
    derivative: Optional[float] = None
    name: str = field(default="")
    unique_id: int = field(default=0)

    def __post_init__(self):
        global _var_count
        _var_count += 1
        object.__setattr__(self, "unique_id", _var_count)
        object.__setattr__(self, "name", str(self.unique_id))
        object.__setattr__(self, "data", float(self.data))

    def __repr__(self) -> str:
        return f"Scalar({self.data})"

    def __mul__(self, b: ScalarLike) -> Scalar:
        return Mul.apply(self, b)

    def __truediv__(self, b: ScalarLike) -> Scalar:
        return Mul.apply(self, Inv.apply(b))

    def __rtruediv__(self, b: ScalarLike) -> Scalar:
        return Mul.apply(b, Inv.apply(self))

    def __bool__(self) -> bool:
        return bool(self.data)

    def __radd__(self, b: ScalarLike) -> Scalar:
        return self + b

    def __rmul__(self, b: ScalarLike) -> Scalar:
        return self * b

    # Variable elements for backprop

    def accumulate_derivative(self, x: Any) -> None:
        """Add `x` to the derivative accumulated on this variable.
        Should only be called during autodifferentiation on leaf variables.

        Args:
        ----
            x: value to be accumulated

        """
        assert self.is_leaf(), "Only leaf variables can have derivatives."
        if self.derivative is None:
            self.__setattr__("derivative", 0.0)
        self.__setattr__("derivative", self.derivative + x)

    def is_leaf(self) -> bool:
        """True if this variable was created by the user (no `last_fn`)."""
        return self.history is not None and self.history.last_fn is None

    def is_constant(self) -> bool:
        return self.history is None

    @property
    def parents(self) -> Iterable[Variable]:
        """Get the variables used to create this one."""
        assert self.history is not None
        return self.history.inputs

    def chain_rule(self, d_output: Any) -> Iterable[Tuple[Variable, Any]]:
        h = self.history
        assert h is not None
        assert h.last_fn is not None
        assert h.ctx is not None

        raise NotImplementedError("Need to include this file from past assignment.")

    def backward(self, d_output: Optional[float] = None) -> None:
        """Calls autodiff to fill in the derivatives for the history of this object.

        Args:
        ----
            d_output (number, opt): starting derivative to backpropagate through the model
                (typically left out, and assumed to be 1.0).

        """
        if d_output is None:
            d_output = 1.0
        backpropagate(self, d_output)


def derivative_check(f: Any, *scalars: Scalar) -> None:
    """Checks that autodiff works on a python function.
    Raises an AssertionError if a derivative is incorrect.

    Parameters
    ----------
    f : function from n-scalars to 1-scalar.
    *scalars : n input scalar values.

    """
    out = f(*scalars)
    out.backward()

    err_msg = """
Derivative check at arguments f(%s) and received derivative f'=%f for argument %d,
but was expecting derivative f'=%f from central difference."""
    for i, x in enumerate(scalars):
        check = central_difference(f, *scalars, arg=i)
        print(str([x.data for x in scalars]), x.derivative, i, check)
        assert x.derivative is not None
        np.testing.assert_allclose(
            x.derivative,
            check.data,
            1e-2,
            1e-2,
            err_msg=err_msg
            % (str([x.data for x in scalars]), x.derivative, i, check.data),
        )
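`chain_rule` is left as a `NotImplementedError` stub to be carried over from the earlier assignment. For reference, here is a minimal sketch of how the stub is typically completed, using the `_backward` classmethod defined in `scalar_functions.py` below; whether constants are filtered here or inside `backpropagate` depends on the rest of the assignment:

```python
def chain_rule(self, d_output: Any) -> Iterable[Tuple[Variable, Any]]:
    h = self.history
    assert h is not None
    assert h.last_fn is not None
    assert h.ctx is not None

    # Ask the last Function for the local derivatives with respect to
    # each input, then pair them back up with the input Variables.
    grads = h.last_fn._backward(h.ctx, d_output)
    return [
        (inp, grad)
        for inp, grad in zip(h.inputs, grads)
        if not inp.is_constant()
    ]
```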

minitorch/scalar_functions.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import minitorch

from . import operators
from .autodiff import Context

if TYPE_CHECKING:
    from typing import Tuple

    from .scalar import Scalar, ScalarLike


def wrap_tuple(x: float | Tuple[float, ...]) -> Tuple[float, ...]:
    """Turn a possible value into a tuple."""
    if isinstance(x, tuple):
        return x
    return (x,)


class ScalarFunction:
    """A wrapper for a mathematical function that processes and produces
    Scalar variables.

    This is a static class and is never instantiated. We use `class`
    here to group together the `forward` and `backward` code.
    """

    @classmethod
    def _backward(cls, ctx: Context, d_out: float) -> Tuple[float, ...]:
        return wrap_tuple(cls.backward(ctx, d_out))  # type: ignore

    @classmethod
    def _forward(cls, ctx: Context, *inps: float) -> float:
        return cls.forward(ctx, *inps)  # type: ignore

    @classmethod
    def apply(cls, *vals: ScalarLike) -> Scalar:
        raw_vals = []
        scalars = []
        for v in vals:
            if isinstance(v, minitorch.scalar.Scalar):
                scalars.append(v)
                raw_vals.append(v.data)
            else:
                scalars.append(minitorch.scalar.Scalar(v))
                raw_vals.append(v)

        # Create the context.
        ctx = Context(False)

        # Call forward with the variables.
        c = cls._forward(ctx, *raw_vals)
        assert isinstance(c, float), "Expected return type float got %s" % (type(c))

        # Create a new variable from the result with a new history.
        back = minitorch.scalar.ScalarHistory(cls, ctx, scalars)
        return minitorch.scalar.Scalar(c, back)


# Examples
class Add(ScalarFunction):
    """Addition function $f(x, y) = x + y$"""

    @staticmethod
    def forward(ctx: Context, a: float, b: float) -> float:
        return a + b

    @staticmethod
    def backward(ctx: Context, d_output: float) -> Tuple[float, ...]:
        return d_output, d_output


class Log(ScalarFunction):
    """Log function $f(x) = log(x)$"""

    @staticmethod
    def forward(ctx: Context, a: float) -> float:
        ctx.save_for_backward(a)
        return operators.log(a)

    @staticmethod
    def backward(ctx: Context, d_output: float) -> float:
        (a,) = ctx.saved_values
        return operators.log_back(a, d_output)


# To implement.
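The `# To implement.` marker covers the remaining functions that `minitorch/scalar.py` imports (`Mul`, `Inv`, `Neg`, `Sigmoid`, `ReLU`, `Exp`, `LT`, `EQ`). As a sketch of the `Add`/`Log` pattern above, assuming `minitorch.operators` provides `mul` alongside `log`, `Mul` could look like:

```python
class Mul(ScalarFunction):
    """Multiplication function $f(x, y) = x * y$"""

    @staticmethod
    def forward(ctx: Context, a: float, b: float) -> float:
        # Save both inputs; each one is the other's partial derivative.
        ctx.save_for_backward(a, b)
        return operators.mul(a, b)

    @staticmethod
    def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
        # d/da (a * b) = b and d/db (a * b) = a, each scaled by d_output.
        a, b = ctx.saved_values
        return b * d_output, a * d_output
```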

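Taken together, the two files support a round trip like the following, a sketch that assumes `chain_rule` and the remaining `ScalarFunction`s (here `Mul` and `Inv`) have been filled in:

```python
from minitorch.scalar import Scalar, derivative_check

# Build a small expression; every operation records a ScalarHistory node.
x = Scalar(2.0)
y = Scalar(3.0)
out = x * y / 2.0   # Mul.apply, then Mul with Inv via __truediv__

# Backpropagate from the output; leaf variables accumulate .derivative.
out.backward()
print(x.derivative)  # expected: 3.0 / 2.0 = 1.5
print(y.derivative)  # expected: 2.0 / 2.0 = 1.0

# Or compare autodiff against central differences on fresh inputs.
derivative_check(lambda a, b: a * b / 2.0, Scalar(2.0), Scalar(3.0))
```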