Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__pycache__
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ assert type(params[model.weight.name]) is jaxlib.xla_extension.DeviceArray
assert model.weight.name == 'weight'

def loss(params, key):
cx = jaxtorch.Context(params, key)
cx = jaxtorch.Context(px=params, key=key)
x = jnp.array([1.0,2.0,3.0])
y = jnp.array([4.0,5.0,6.0])
return jnp.mean((model(cx, x) - y)**2)
Expand Down
6 changes: 3 additions & 3 deletions jaxtorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from jaxtorch.core import *
import jaxtorch.nn
import jaxtorch.cbor
import jaxtorch.image
import jaxtorch.init
import jaxtorch.nn
import jaxtorch.pt
from jaxtorch.core import *
50 changes: 0 additions & 50 deletions jaxtorch/cbor.py

This file was deleted.

136 changes: 96 additions & 40 deletions jaxtorch/core.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,92 @@
from abc import abstractmethod
from typing import Callable, Any
import jax
import jax.numpy as jnp
import jaxlib
import numpy as np
import functools
import jaxtorch.monkeypatches
import sys
import jmp


def fn_wrap_policy_cast(f, cast_output=True):
    """Wrap a context-first function so its inputs/outputs follow ``cx.policy``.

    The returned callable takes the same arguments as ``f`` (context first).
    Before calling ``f`` it passes the positional-argument tuple and the
    keyword-argument dict through ``cx.policy.cast_to_compute``; afterwards,
    when ``cast_output`` is true, it passes the result through
    ``cx.policy.cast_to_output``.

    Args:
        f: Callable invoked as ``f(cx, *args, **kwargs)``.
        cast_output: If false, the result of ``f`` is returned untouched.

    Returns:
        The wrapping callable. ``functools.wraps`` keeps the name and
        docstring of ``f`` on it.
    """

    @functools.wraps(f)
    def wrapped(cx, *args, **kwargs):
        # NOTE(review): cx.policy is presumably a jmp.Policy; only its
        # cast_to_compute / cast_to_output methods are relied on here.
        policy = cx.policy
        args = policy.cast_to_compute(args)
        kwargs = policy.cast_to_compute(kwargs)
        out = f(cx, *args, **kwargs)
        if cast_output:
            # Re-read cx.policy: f receives cx and may have changed it.
            out = cx.policy.cast_to_output(out)
        return out

    return wrapped

def method_wrap_policy_cast(method, cast_output=True):
    """Wrap a ``(self, cx, ...)`` method so it runs under the owner's policy.

    If the instance has a ``policy`` attribute, that policy is pushed onto
    the context with ``cx.push_policy`` for the duration of the call and
    removed again with ``cx.pop_policy`` afterwards. While it is active, the
    positional-argument tuple and keyword-argument dict are passed through
    ``cx.policy.cast_to_compute`` and, when ``cast_output`` is true, the
    result is passed through ``cx.policy.cast_to_output``.

    Args:
        method: Callable invoked as ``method(self, cx, *args, **kwargs)``.
        cast_output: If false, the method's result is returned untouched.

    Returns:
        The wrapping callable. ``functools.wraps`` keeps the name and
        docstring of ``method`` on it.
    """

    @functools.wraps(method)
    def wrapped(self, cx, *args, **kwargs):
        # Decide once whether to push, so the pop below always pairs with it
        # even if the method adds or removes ``self.policy`` while running.
        has_policy = hasattr(self, "policy")
        if has_policy:
            cx.push_policy(self.policy)
        try:
            args = cx.policy.cast_to_compute(args)
            kwargs = cx.policy.cast_to_compute(kwargs)
            out = method(self, cx, *args, **kwargs)
            if cast_output:
                # Cast before popping so the module's own policy is used.
                out = cx.policy.cast_to_output(out)
        finally:
            # Restore the caller's policy even when the method raises;
            # otherwise the context would keep this module's policy.
            if has_policy:
                cx.pop_policy()
        return out

    return wrapped


def _addindent(s_, numSpaces):
s = s_.split('\n')
s = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
s = [(numSpaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s


class Param(object):
    """Represents a parameter of a Module, and specifies its shape and initialization."""

    def __init__(self, shape, initializer):
        """Store the parameter's shape and initializer; it starts unnamed.

        Args:
            shape: Shape specification, stored as given.
            initializer: Initializer object, stored as given.
        """
        self.shape = shape
        self.initializer = initializer
        # Stays None until a name is assigned to this parameter.
        self.name = None

    def __repr__(self):
        """Return ``<Param at NAME>`` once named, else the default object repr."""
        if self.name is None:
            return super().__repr__()
        return f"<Param at {self.name}>"


class PRNG(object):
    """Just a stateful wrapper for a jax.random.PRNGKey."""

    def __init__(self, key):
        """Store ``key`` as the current state.

        Args:
            key: The initial key, kept on ``self.key``.
        """
        self.key = key

    def split(self):
        """Advance the stored key and return a fresh subkey.

        Calls ``jax.random.split`` on the current key, keeps the first
        result as the new ``self.key`` and returns the second.
        """
        next_key, subkey = jax.random.split(self.key)
        self.key = next_key
        return subkey

class ContextRandom(object):
"""Lives inside a Context and provides convenience functions for
random number generation that use the Context's stateful PRNG.

class ContextRandom(object):
"""
Lives inside a Context and provides convenience functions for
random number generation that use the Context's stateful PRNG.
"""

def __init__(self, rng):
self.rng = rng

def _wrap(f):
def _wrap(f: Callable) -> Callable: # type: ignore
return lambda self, *args, **kwargs: f(self.rng.split(), *args, **kwargs)

bernoulli = _wrap(jax.random.bernoulli)
Expand Down Expand Up @@ -75,38 +115,52 @@ def _wrap(f):
uniform = _wrap(jax.random.uniform)
weibull_min = _wrap(jax.random.weibull_min)


@jax.tree_util.register_pytree_node_class
class Context(object):
"""Wraps a parameter dictionary and a PRNG."""
def __init__(self, px, key, mode='train'):

def __init__(self, *, px, key, mode="train", policy=jmp.get_policy("float32")):
self.px = px
self.rng = PRNG(key)
self.random = ContextRandom(self.rng)
self.mode = mode
self.policy = policy
self.pstack = []

def push_policy(self, p):
self.pstack.append(self.policy)
self.policy = p

def pop_policy(self):
p = self.policy
self.policy = self.pstack.pop()
return p

def train_mode_(self):
self.mode = 'train'
self.mode = "train"
return self

def eval_mode_(self):
self.mode = 'eval'
self.mode = "eval"
return self

def __getitem__(self, par):
if isinstance(par, Param):
return self.px[par.name]
return self.policy.cast_to_compute(self.px[par.name])
elif isinstance(par, str):
return self.px[par]
return self.policy.cast_to_compute(self.px[par])
else:
raise TypeError('Expected a Param for indexing into Context')
raise TypeError("Expected a Param for indexing into Context")

def __setitem__(self, par, value):
value = self.policy.cast_to_param(value)
if isinstance(par, Param):
self.px[par.name] = value
elif isinstance(par, str):
self.px[par] = value
else:
raise TypeError('Expected a Param for indexing into Context')
raise TypeError("Expected a Param for indexing into Context")

# TODO: having this might be a bad idea if it breaks future
# features, might need a dedicated wrapper for transforming cx
Expand All @@ -118,29 +172,31 @@ def tree_flatten(self):
def tree_unflatten(aux, values):
(px, key) = values
(mode,) = aux
return Context(px, key, mode=mode)
return Context(px=px, key=key, mode=mode)


class Module(object):
@method_wrap_policy_cast
def __call__(self, cx: Context, *args, **kwargs):
return self.forward(cx, *args, **kwargs)
out = self.forward(cx, *args, **kwargs) # type: ignore
return out

def forward(self, cx: Context, *args, **kwargs):
"""Implements the forward pass. Must take Context as the first argument."""
raise NotImplementedError
# @abstractmethod
# def forward(self, cx: Context, *args, **kwargs): # type: ignore
# """Implements the forward pass. Must take Context as the first argument."""
# raise NotImplementedError

def self_named_modules(self):
"""Yields a sequence of (str, Module) for direct children of this
module. May be overridden.

module. May be overridden.
"""
for (name, val) in self.__dict__.items():
if isinstance(val, Module):
yield (name, val)

def self_named_parameters(self):
"""Yields a sequence of (str, Param) for direct children of this
module. May be overridden.
module. May be overridden.

"""
for (name, val) in self.__dict__.items():
Expand All @@ -149,8 +205,8 @@ def self_named_parameters(self):

def self_init_weights(self, cx):
"""Initializes weights for this network's parameters. May be overriden
for custom initialization. Child modules are initialized
before parents.
for custom initialization. Child modules are initialized
before parents.

"""
for (name, par) in self.self_named_parameters():
Expand All @@ -159,11 +215,11 @@ def self_init_weights(self, cx):

def init_weights(self, key):
"""Attaches names to parameters and returns initialized dict of
parameters by name.
parameters by name.

"""
self.labeled_parameters_()
cx = Context({}, key)
cx = Context(px={}, key=key)
for module in self.gen_postorder_modules():
module.self_init_weights(cx)
self.self_init_weights(cx)
Expand All @@ -179,7 +235,7 @@ def gen_named_modules(self):
for (name, val) in self.self_named_modules():
yield (name, val)
for (k, v) in val.gen_named_modules():
yield (name+'.'+k, v)
yield (name + "." + k, v)

def gen_postorder_modules(self):
"Yields Module for all descendants of this module (postorder traversal)."
Expand All @@ -195,7 +251,7 @@ def gen_named_parameters(self):

for (name, mod) in self.self_named_modules():
for (k, v) in mod.gen_named_parameters():
yield (name+'.'+k, v)
yield (name + "." + k, v)

def named_parameters(self):
return list(self.gen_named_parameters())
Expand All @@ -207,20 +263,20 @@ def parameters(self):
return [p for (k, p) in self.gen_named_parameters()]

def state_dict(self, px):
return {name:px[par.name] for (name, par) in self.gen_named_parameters()}
return {name: px[par.name] for (name, par) in self.gen_named_parameters()}

def load_state_dict(self, px, state, strict=True):
"""Load a previously saved state_dict into px. Returns px."""
for (k, p) in self.gen_named_parameters():
if k not in state:
if strict:
raise ValueError(f'Not loading missing parameter: {k}')
raise ValueError(f"Not loading missing parameter: {k}")
else:
print(f'Not loading missing parameter: {k}', file=sys.stderr)
print(f"Not loading missing parameter: {k}", file=sys.stderr)
continue

if px[p.name].shape != state[k].shape:
msg = f'Not loading parameter from incompatible shape: {k} ({px[p.name].shape} vs {state[k].shape})'
msg = f"Not loading parameter from incompatible shape: {k} ({px[p.name].shape} vs {state[k].shape})"
if strict:
raise ValueError(msg)
else:
Expand All @@ -240,30 +296,30 @@ def extra_repr(self) -> str:
this method in your own modules. Both single-line and multi-line
strings are acceptable.
"""
return ''
return ""

def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
extra_lines = extra_repr.split("\n")
child_lines = []
for key, module in self.__dict__.items():
if isinstance(module, Module):
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
child_lines.append("(" + key + "): " + mod_str)
lines = extra_lines + child_lines

main_str = self._get_name() + '('
main_str = self._get_name() + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += "\n " + "\n ".join(lines) + "\n"

main_str += ')'
main_str += ")"
return main_str
Loading