diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..bee8a64
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+__pycache__
diff --git a/README.md b/README.md
index a6e8264..6598254 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ assert type(params[model.weight.name]) is jaxlib.xla_extension.DeviceArray
assert model.weight.name == 'weight'
def loss(params, key):
- cx = jaxtorch.Context(params, key)
+ cx = jaxtorch.Context(px=params, key=key)
x = jnp.array([1.0,2.0,3.0])
y = jnp.array([4.0,5.0,6.0])
return jnp.mean((model(cx, x) - y)**2)
diff --git a/jaxtorch/__init__.py b/jaxtorch/__init__.py
index 17b4b23..e9dd482 100644
--- a/jaxtorch/__init__.py
+++ b/jaxtorch/__init__.py
@@ -1,5 +1,5 @@
-from jaxtorch.core import *
-import jaxtorch.nn
-import jaxtorch.cbor
import jaxtorch.image
+import jaxtorch.init
+import jaxtorch.nn
import jaxtorch.pt
+from jaxtorch.core import *
diff --git a/jaxtorch/cbor.py b/jaxtorch/cbor.py
deleted file mode 100644
index 48126de..0000000
--- a/jaxtorch/cbor.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""Wraps cbor2 with hooks for encoding and decoding tensors."""
-import jax
-import cbor2
-import numpy as np
-import functools
-
-from cbor2 import CBORTag
-
-# Standard tags for multidimensional arrays from RFC8746
-# (little-endian, row-major).
-TAG_FLOAT32 = 85
-TAG_FLOAT64 = 86
-TAG_INT32 = 78
-TAG_INT64 = 79
-TAG_ARRAY = 40
-
-def encode_flat(arr):
- if arr.dtype == np.float32:
- return CBORTag(TAG_FLOAT32, arr.tobytes())
- if arr.dtype == np.int32:
- return CBORTag(TAG_INT32, arr.tobytes())
- else:
- raise NotImplemented
-
-def default_encoder(encoder, value):
- if isinstance(value, jax.numpy.DeviceArray):
- encoder.encode(np.array(value))
- elif isinstance(value, np.ndarray):
- encoder.encode(CBORTag(TAG_ARRAY, [list(value.shape), encode_flat(value)]))
- else:
- raise NotImplemented
-
-def tag_hook(decoder, tag, shareable_index=None):
- if tag.tag == TAG_ARRAY:
- [shape, value] = tag.value
- return value.reshape(shape)
- elif tag.tag == TAG_FLOAT32:
- return np.frombuffer(tag.value, dtype=np.float32)
- elif tag.tag == TAG_INT32:
- return np.frombuffer(tag.value, dtype=np.int32)
- elif tag.tag == TAG_INT64:
- return np.frombuffer(tag.value, dtype=np.int64)
- else:
- return tag
-
-dumps = functools.partial(cbor2.dumps, default=default_encoder)
-dump = functools.partial(cbor2.dump, default=default_encoder)
-
-loads = functools.partial(cbor2.loads, tag_hook=tag_hook)
-load = functools.partial(cbor2.load, tag_hook=tag_hook)
diff --git a/jaxtorch/core.py b/jaxtorch/core.py
index 0c6fe39..1c84a89 100644
--- a/jaxtorch/core.py
+++ b/jaxtorch/core.py
@@ -1,3 +1,5 @@
+from abc import abstractmethod
+from typing import Callable, Any
import jax
import jax.numpy as jnp
import jaxlib
@@ -5,20 +7,53 @@
import functools
import jaxtorch.monkeypatches
import sys
+import jmp
+
+
+def fn_wrap_policy_cast(f, cast_output=True):
+ def wrapped(cx, *args, **kwargs):
+ args = cx.policy.cast_to_compute(args)
+ kwargs = cx.policy.cast_to_compute(kwargs)
+ out = f(cx, *args, **kwargs)
+ if cast_output:
+ out = cx.policy.cast_to_output(out)
+ return out
+ return wrapped
+
+def method_wrap_policy_cast(method, cast_output=True):
+ def wrapped(self, cx, *args, **kwargs):
+ if hasattr(self, "policy"):
+ cx.push_policy(self.policy)
+
+ args = cx.policy.cast_to_compute(args)
+ kwargs = cx.policy.cast_to_compute(kwargs)
+ out = method(self, cx, *args, **kwargs)
+
+ if cast_output:
+ out = cx.policy.cast_to_output(out)
+
+ if hasattr(self, "policy"):
+ cx.pop_policy()
+
+ return out
+ return wrapped
+
def _addindent(s_, numSpaces):
- s = s_.split('\n')
+ s = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
- s = [(numSpaces * ' ') + line for line in s]
- s = '\n'.join(s)
- s = first + '\n' + s
+ s = [(numSpaces * " ") + line for line in s]
+ s = "\n".join(s)
+ s = first + "\n" + s
return s
+
class Param(object):
"""Represents a parameter of a Module, and specifies its shape and initialization."""
+
def __init__(self, shape, initializer):
self.shape = shape
self.initializer = initializer
@@ -26,27 +61,32 @@ def __init__(self, shape, initializer):
def __repr__(self):
if self.name is not None:
- return f''
+ return f""
else:
return super().__repr__()
+
class PRNG(object):
"""Just a stateful wrapper for a jax.random.PRNGKey."""
+
def __init__(self, key):
self.key = key
+
def split(self):
(self.key, subkey) = jax.random.split(self.key)
return subkey
-class ContextRandom(object):
- """Lives inside a Context and provides convenience functions for
-random number generation that use the Context's stateful PRNG.
+class ContextRandom(object):
+ """
+ Lives inside a Context and provides convenience functions for
+ random number generation that use the Context's stateful PRNG.
"""
+
def __init__(self, rng):
self.rng = rng
- def _wrap(f):
+ def _wrap(f: Callable) -> Callable: # type: ignore
return lambda self, *args, **kwargs: f(self.rng.split(), *args, **kwargs)
bernoulli = _wrap(jax.random.bernoulli)
@@ -75,38 +115,52 @@ def _wrap(f):
uniform = _wrap(jax.random.uniform)
weibull_min = _wrap(jax.random.weibull_min)
+
@jax.tree_util.register_pytree_node_class
class Context(object):
"""Wraps a parameter dictionary and a PRNG."""
- def __init__(self, px, key, mode='train'):
+
+ def __init__(self, *, px, key, mode="train", policy=jmp.get_policy("float32")):
self.px = px
self.rng = PRNG(key)
self.random = ContextRandom(self.rng)
self.mode = mode
+ self.policy = policy
+ self.pstack = []
+
+ def push_policy(self, p):
+ self.pstack.append(self.policy)
+ self.policy = p
+
+ def pop_policy(self):
+ p = self.policy
+ self.policy = self.pstack.pop()
+ return p
def train_mode_(self):
- self.mode = 'train'
+ self.mode = "train"
return self
def eval_mode_(self):
- self.mode = 'eval'
+ self.mode = "eval"
return self
def __getitem__(self, par):
if isinstance(par, Param):
- return self.px[par.name]
+ return self.policy.cast_to_compute(self.px[par.name])
elif isinstance(par, str):
- return self.px[par]
+ return self.policy.cast_to_compute(self.px[par])
else:
- raise TypeError('Expected a Param for indexing into Context')
+ raise TypeError("Expected a Param for indexing into Context")
def __setitem__(self, par, value):
+ value = self.policy.cast_to_param(value)
if isinstance(par, Param):
self.px[par.name] = value
elif isinstance(par, str):
self.px[par] = value
else:
- raise TypeError('Expected a Param for indexing into Context')
+ raise TypeError("Expected a Param for indexing into Context")
# TODO: having this might be a bad idea if it breaks future
# features, might need a dedicated wrapper for transforming cx
@@ -118,21 +172,23 @@ def tree_flatten(self):
def tree_unflatten(aux, values):
(px, key) = values
(mode,) = aux
- return Context(px, key, mode=mode)
+ return Context(px=px, key=key, mode=mode)
class Module(object):
+ @method_wrap_policy_cast
def __call__(self, cx: Context, *args, **kwargs):
- return self.forward(cx, *args, **kwargs)
+ out = self.forward(cx, *args, **kwargs) # type: ignore
+ return out
- def forward(self, cx: Context, *args, **kwargs):
- """Implements the forward pass. Must take Context as the first argument."""
- raise NotImplementedError
+ # @abstractmethod
+ # def forward(self, cx: Context, *args, **kwargs): # type: ignore
+ # """Implements the forward pass. Must take Context as the first argument."""
+ # raise NotImplementedError
def self_named_modules(self):
"""Yields a sequence of (str, Module) for direct children of this
- module. May be overridden.
-
+ module. May be overridden.
"""
for (name, val) in self.__dict__.items():
if isinstance(val, Module):
@@ -140,7 +196,7 @@ def self_named_modules(self):
def self_named_parameters(self):
"""Yields a sequence of (str, Param) for direct children of this
- module. May be overridden.
+ module. May be overridden.
"""
for (name, val) in self.__dict__.items():
@@ -149,8 +205,8 @@ def self_named_parameters(self):
def self_init_weights(self, cx):
"""Initializes weights for this network's parameters. May be overriden
- for custom initialization. Child modules are initialized
- before parents.
+ for custom initialization. Child modules are initialized
+ before parents.
"""
for (name, par) in self.self_named_parameters():
@@ -159,11 +215,11 @@ def self_init_weights(self, cx):
def init_weights(self, key):
"""Attaches names to parameters and returns initialized dict of
- parameters by name.
+ parameters by name.
"""
self.labeled_parameters_()
- cx = Context({}, key)
+ cx = Context(px={}, key=key)
for module in self.gen_postorder_modules():
module.self_init_weights(cx)
self.self_init_weights(cx)
@@ -179,7 +235,7 @@ def gen_named_modules(self):
for (name, val) in self.self_named_modules():
yield (name, val)
for (k, v) in val.gen_named_modules():
- yield (name+'.'+k, v)
+ yield (name + "." + k, v)
def gen_postorder_modules(self):
"Yields Module for all descendants of this module (postorder traversal)."
@@ -195,7 +251,7 @@ def gen_named_parameters(self):
for (name, mod) in self.self_named_modules():
for (k, v) in mod.gen_named_parameters():
- yield (name+'.'+k, v)
+ yield (name + "." + k, v)
def named_parameters(self):
return list(self.gen_named_parameters())
@@ -207,20 +263,20 @@ def parameters(self):
return [p for (k, p) in self.gen_named_parameters()]
def state_dict(self, px):
- return {name:px[par.name] for (name, par) in self.gen_named_parameters()}
+ return {name: px[par.name] for (name, par) in self.gen_named_parameters()}
def load_state_dict(self, px, state, strict=True):
"""Load a previously saved state_dict into px. Returns px."""
for (k, p) in self.gen_named_parameters():
if k not in state:
if strict:
- raise ValueError(f'Not loading missing parameter: {k}')
+ raise ValueError(f"Not loading missing parameter: {k}")
else:
- print(f'Not loading missing parameter: {k}', file=sys.stderr)
+ print(f"Not loading missing parameter: {k}", file=sys.stderr)
continue
if px[p.name].shape != state[k].shape:
- msg = f'Not loading parameter from incompatible shape: {k} ({px[p.name].shape} vs {state[k].shape})'
+ msg = f"Not loading parameter from incompatible shape: {k} ({px[p.name].shape} vs {state[k].shape})"
if strict:
raise ValueError(msg)
else:
@@ -240,7 +296,7 @@ def extra_repr(self) -> str:
this method in your own modules. Both single-line and multi-line
strings are acceptable.
"""
- return ''
+ return ""
def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
@@ -248,22 +304,22 @@ def __repr__(self):
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
- extra_lines = extra_repr.split('\n')
+ extra_lines = extra_repr.split("\n")
child_lines = []
for key, module in self.__dict__.items():
if isinstance(module, Module):
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
- child_lines.append('(' + key + '): ' + mod_str)
+ child_lines.append("(" + key + "): " + mod_str)
lines = extra_lines + child_lines
- main_str = self._get_name() + '('
+ main_str = self._get_name() + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
- main_str += '\n ' + '\n '.join(lines) + '\n'
+ main_str += "\n " + "\n ".join(lines) + "\n"
- main_str += ')'
+ main_str += ")"
return main_str
diff --git a/jaxtorch/image.py b/jaxtorch/image.py
index 342620e..dda4aff 100644
--- a/jaxtorch/image.py
+++ b/jaxtorch/image.py
@@ -1,80 +1,215 @@
+import math
+from typing import Tuple
+
import jax
import jax.numpy as jnp
+import numpy as np
+from einops import repeat
+
+
+def factor_int(n: int) -> Tuple[int, int]:
+ f1 = int(math.ceil(math.sqrt(n)))
+ while n % f1:
+ f1 -= 1
+ f2 = n // f1
+ return min(f1, f2), max(f1, f2)
+
+
+def compute_channel_change_mat(c_in: int, c_out: int) -> np.ndarray:
+ assert max(c_in, c_out) % min(c_in, c_out) == 0
+ io_ratio = max(c_in, c_out) // min(c_in, c_out)
+ base = np.eye(min(c_in, c_out))
+ if c_in < c_out:
+ return repeat(base, "d1 d2 -> (d1 r) d2", r=io_ratio)
+ elif c_out < c_in:
+ # decreasing channel count, average nearby channels
+ return repeat(base, "d1 d2 -> d1 (d2 r)", r=io_ratio) / io_ratio
+ else:
+ return base
+
+
+upsample_arrays = dict(
+ lanczos3=np.array(
+ [
+ 0.0073782638646662235,
+ 0.030112292617559433,
+ -0.06799723953008652,
+ -0.13327467441558838,
+ 0.2710106074810028,
+ 0.8927707076072693,
+ 0.8927707672119141,
+ 0.2710106074810028,
+ -0.13327467441558838,
+ -0.06799724698066711,
+ 0.03011229634284973,
+ 0.007378263399004936,
+ ],
+ ),
+ cubic=np.array(
+ [
+ -0.0234375,
+ -0.0703125,
+ 0.2265625,
+ 0.8671875,
+ 0.8671875,
+ 0.2265625,
+ -0.0703125,
+ -0.0234375,
+ ],
+ ),
+ linear=np.array([0.25, 0.75, 0.75, 0.25]),
+)
+
+
+downsample_arrays = dict(
+ lanczos3=np.array(
+ [
+ 0.003689131001010537,
+ 0.015056144446134567,
+ -0.03399861603975296,
+ -0.066637322306633,
+ 0.13550527393817902,
+ 0.44638532400131226,
+ 0.44638532400131226,
+ 0.13550527393817902,
+ -0.066637322306633,
+ -0.03399861603975296,
+ 0.015056144446134567,
+ 0.003689131001010537,
+ ]
+ ),
+ cubic=np.array(
+ [
+ -0.01171875,
+ -0.03515625,
+ 0.11328125,
+ 0.43359375,
+ 0.43359375,
+ 0.11328125,
+ -0.03515625,
+ -0.01171875,
+ ]
+ ),
+ linear=np.array([0.125, 0.375, 0.375, 0.125]),
+)
+
+
+def upsample_kernel(
+ c_in: int,
+ c_out: int,
+ method: str = "linear",
+) -> np.ndarray:
+ cmat = compute_channel_change_mat(c_in, c_out)
+ kernel = upsample_arrays[method]
+ weight = np.einsum("oi,h,w->oihw", cmat, kernel, kernel)
+ return weight
+
+
+def downsample_kernel(
+ c_in: int,
+ c_out: int,
+ method="linear",
+) -> np.ndarray:
+ cmat = compute_channel_change_mat(c_in, c_out)
+ kernel = downsample_arrays[method]
+ weight = np.einsum("oi,h,w->oihw", cmat, kernel, kernel)
+ return weight
+
+
+def upsample2x_base(
+ img: jnp.ndarray,
+ kern: jnp.ndarray,
+ format: str = "NCHW",
+ norm:bool=True,
+):
+ ksize = kern.shape[-1]
+ kern = jax.lax.convert_element_type(kern, img.dtype)
+ out = jax.lax.conv_general_dilated(
+ img,
+ kern,
+ window_strides=[1, 1],
+ padding=[(ksize // 2, ksize // 2), (ksize // 2, ksize // 2)],
+ lhs_dilation=[2, 2],
+ rhs_dilation=None,
+ dimension_numbers=(format, "OIHW", format),
+ )
+
+ if norm:
+ # normalization for parts that touch the zero-padding
+ norm = jax.lax.conv_general_dilated(
+ jnp.ones([1, *img.shape[-3:]], dtype=img.dtype),
+ kern,
+ window_strides=[1, 1],
+ padding=[(ksize // 2, ksize // 2), (ksize // 2, ksize // 2)],
+ lhs_dilation=[2, 2],
+ rhs_dilation=None,
+ dimension_numbers=(format, "OIHW", format),
+ )
+ out = out / norm
+
+ return out
+
+
+def downsample2x_base(
+ x: jnp.ndarray,
+ kern: jnp.ndarray,
+ format: str = "NCHW",
+ norm:bool=True,
+):
+ ksize = kern.shape[-1]
+ kern = jax.lax.convert_element_type(kern, x.dtype)
+ out = jax.lax.conv_general_dilated(
+ x,
+ kern,
+ window_strides=[2, 2],
+ padding=[(ksize // 2 - 1, ksize // 2 - 1), (ksize // 2 - 1, ksize // 2 - 1)],
+ lhs_dilation=[1, 1],
+ rhs_dilation=None,
+ dimension_numbers=(format, "OIHW", format),
+ )
+
+ if norm:
+ # normalization for parts that touch the zero-padding
+ norm = jax.lax.conv_general_dilated(
+ jnp.ones([1, *x.shape[-3:]], dtype=x.dtype),
+ kern,
+ window_strides=[2, 2],
+ padding=[
+ (ksize // 2 - 1, ksize // 2 - 1),
+ (ksize // 2 - 1, ksize // 2 - 1),
+ ],
+ lhs_dilation=[1, 1],
+ rhs_dilation=None,
+ dimension_numbers=(format, "OIHW", format),
+ )
+ out = out / norm
+
+ return out
+
+
+def upsample2x(
+ img: jnp.ndarray,
+ c_out: int = None,
+ method: str = "linear",
+ format: str = "NCHW",
+) -> jnp.ndarray:
+ c_in = img.shape[-3]
+ if c_out is None:
+ c_out = c_in
+ kern = upsample_kernel(c_in, c_out, method=method)
+ kern = jnp.array(kern, dtype=img.dtype)
+ return upsample2x_base(img, kern, format)
+
-def upsample2x_base(image, kernel):
- ksize = kernel.shape[0]
- (n, c, h, w) = image.shape
- out = jax.lax.conv_general_dilated(image.reshape(n*c,1,h,w),
- kernel.reshape(1,1,ksize,ksize),
- window_strides=[1,1],
- padding=[(ksize//2,ksize//2),(ksize//2,ksize//2)],
- lhs_dilation=[2,2],
- rhs_dilation=None,
- dimension_numbers=('NCHW',
- 'IOHW', 'NCHW'))
-
- # normalization for parts that touch the zero-padding
- norm = jax.lax.conv_general_dilated(jnp.ones((1,1,h,w)),
- kernel.reshape(1,1,ksize,ksize),
- window_strides=[1,1],
- padding=[(ksize//2,ksize//2),(ksize//2,ksize//2)],
- lhs_dilation=[2,2],
- rhs_dilation=None,
- dimension_numbers=('NCHW',
- 'IOHW', 'NCHW'))
- return (out / norm).reshape(n, c, 2*h,2*w)
-
-def upsample2x(image, method='linear'):
- if method == 'lanczos3':
- # extracted from the gradients of jax.image.resize(method='lanczos3')
- kernel = jnp.array([0.0073782638646662235, 0.030112292617559433,
- -0.06799723953008652, -0.13327467441558838,
- 0.2710106074810028, 0.8927707076072693,
- 0.8927707672119141, 0.2710106074810028,
- -0.13327467441558838, -0.06799724698066711,
- 0.03011229634284973, 0.007378263399004936])
- elif method == 'cubic':
- # extracted from the gradients of jax.image.resize(method='cubic')
- kernel = jnp.array([-0.0234375, -0.0703125, 0.2265625, 0.8671875, 0.8671875, 0.2265625, -0.0703125, -0.0234375])
- elif method == 'linear':
- # extracted from the gradients of jax.image.resize(method='linear')
- kernel = jnp.array([0.25, 0.75, 0.75, 0.25])
- kernel = kernel.reshape(-1,1) * kernel.reshape(1,-1)
-
- return upsample2x_base(image, kernel)
-
-def downsample2x_base(image, kernel):
- ksize = kernel.shape[0]
- (n, c, h, w) = image.shape
- out = jax.lax.conv_general_dilated(image.reshape(n*c,1,h,w),
- kernel.reshape(1,1,ksize,ksize),
- window_strides=[2,2],
- padding=[(ksize//2-1,ksize//2-1),(ksize//2-1,ksize//2-1)],
- lhs_dilation=[1,1],
- rhs_dilation=None,
- dimension_numbers=('NCHW',
- 'IOHW', 'NCHW'))
-
- # normalization for parts that touch the zero-padding
- norm = jax.lax.conv_general_dilated(jnp.ones((1,1,h,w)),
- kernel.reshape(1,1,ksize,ksize),
- window_strides=[2,2],
- padding=[(ksize//2-1,ksize//2-1),(ksize//2-1,ksize//2-1)],
- lhs_dilation=[1,1],
- rhs_dilation=None,
- dimension_numbers=('NCHW',
- 'IOHW', 'NCHW'))
- return (out / norm).reshape(n, c, h//2,w//2)
-
-def downsample2x(image, method='linear'):
- if method == 'linear':
- kernel = jnp.array([0.125, 0.375, 0.375, 0.125])
- elif method == 'cubic':
- kernel = jnp.array([-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875])
- elif method == 'lanczos3':
- kernel = jnp.array([0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
- -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
- 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
- -0.03399861603975296, 0.015056144446134567, 0.003689131001010537])
- kernel = kernel.reshape(-1,1) * kernel.reshape(1,-1)
- return downsample2x_base(image, kernel)
+def downsample2x(
+ img: jnp.ndarray,
+ c_out: int = None,
+ method: str = "linear",
+ format: str = "NCHW",
+) -> jnp.ndarray:
+ c_in = img.shape[-3]
+ if c_out is None:
+ c_out = c_in
+ kern = downsample_kernel(c_in, c_out, method=method)
+ kern = jax.lax.convert_element_type(kern, img.dtype)
+ return downsample2x_base(img, kern, format)
diff --git a/jaxtorch/init.py b/jaxtorch/init.py
index c06bccc..4872df8 100644
--- a/jaxtorch/init.py
+++ b/jaxtorch/init.py
@@ -4,36 +4,102 @@
import numpy as np
from jaxtorch import core
+
def zeros(*shape):
shape = jax.core.canonicalize_shape(shape)
return core.Param(shape, lambda key: jnp.zeros(shape))
+
def ones(*shape):
shape = jax.core.canonicalize_shape(shape)
return core.Param(shape, lambda key: jnp.ones(shape))
-def normal(*shape, stddev=1.0):
+
+def normal(*shape, mean=0.0, stddev=1.0):
shape = jax.core.canonicalize_shape(shape)
- return core.Param(shape, lambda key: stddev * jax.random.normal(key, shape))
+ return core.Param(shape, lambda key: mean + stddev * jax.random.normal(key, shape))
+
def const(tensor):
shape = jax.core.canonicalize_shape(tensor.shape)
return core.Param(shape, lambda key: tensor)
-def glorot_normal(*shape):
+
+def full(*shape, value=1.0):
+ shape = jax.core.canonicalize_shape(shape)
+ return core.Param(shape, lambda key: jnp.full(shape, value))
+
+
+def glorot_normal_t(key, *shape, gain=1.0):
+ shape = jax.core.canonicalize_shape(shape)
+ fan_out = shape[0] * np.prod(shape[2:])
+ fan_in = shape[1] * np.prod(shape[2:])
+ stddev = gain * np.sqrt(2.0 / (fan_in + fan_out))
+ return stddev * jax.random.normal(key, shape)
+
+
+def glorot_normal(*shape, gain=1.0):
+ return core.Param(shape, lambda key: glorot_normal_t(key, *shape, gain=gain))
+
+
+def glorot_uniform(*shape, gain=1.0):
shape = jax.core.canonicalize_shape(shape)
fan_out = shape[0] * np.prod(shape[2:])
fan_in = shape[1] * np.prod(shape[2:])
- stddev = np.sqrt(2.0 / (fan_in + fan_out))
- return core.Param(shape, lambda key: stddev * jax.random.normal(key, shape))
+ stddev = gain * np.sqrt(6.0 / (fan_in + fan_out))
+ return normal(*shape, stddev=stddev)
+
def uniform(*shape, min=-1.0, max=1.0):
shape = jax.core.canonicalize_shape(shape)
- return core.Param(shape, lambda key: jax.random.uniform(key, shape, minval=min, maxval=max))
+ return core.Param(
+ shape, lambda key: jax.random.uniform(key, shape, minval=min, maxval=max)
+ )
-def kaiming_uniform(*shape, a=0):
+
+def kaiming_uniform(*shape, a=0, scale=1.0):
shape = jax.core.canonicalize_shape(shape)
fan_in = np.prod(shape[1:])
- gain = math.sqrt(2.0 / (1 + a ** 2))
- bound = gain * math.sqrt(3.0 / fan_in)
+ gain = math.sqrt(2.0 / (1 + a**2))
+ bound = scale * gain * math.sqrt(3.0 / fan_in)
return uniform(*shape, min=-bound, max=bound)
+
+
+def mup_input_init(*shape, mean=0.0, std=1.0):
+ shape = jax.core.canonicalize_shape(shape)
+ fan_in = np.prod(shape[1:])
+ stddev = std / fan_in
+ return normal(*shape, mean=mean, stddev=stddev)
+
+
+def mup_output_init(*shape, mean=0.0, std=1.0):
+ shape = jax.core.canonicalize_shape(shape)
+ fan_in = np.prod(shape[1:])
+ stddev = std / fan_in**2
+ return normal(*shape, mean=mean, stddev=stddev)
+
+
+def mup_hidden_init(*shape, mean=0.0, std=1.0):
+ shape = jax.core.canonicalize_shape(shape)
+ fan_in = np.prod(shape[1:])
+ stddev = std / fan_in
+ return normal(*shape, mean=mean, stddev=stddev)
+
+
+def sum_init(*inits):
+ def init(*shape):
+ ps = [i(*shape).initializer for i in inits]
+
+ def _init(key):
+ ks = jax.random.split(key, len(ps))
+ vs = [p(k) for p, k in zip(ps, ks)]
+ return sum(vs)
+
+ return core.Param(shape, _init)
+
+ return init
+
+
+def scale_init(scale, init, *shape):
+ base = init(*shape).initializer
+ return core.Param(shape, lambda key: scale * base(key))
diff --git a/jaxtorch/monkeypatches.py b/jaxtorch/monkeypatches.py
index d1cec72..fcb39c2 100644
--- a/jaxtorch/monkeypatches.py
+++ b/jaxtorch/monkeypatches.py
@@ -15,8 +15,11 @@ def register(**kwargs):
if hasattr(jnp.zeros([]), attr):
print(f'Not monkeypatching DeviceArray and Tracer with `{attr}`, because that method is already implemented.', file=sys.stderr)
continue
- setattr(jaxlib.xla_extension.DeviceArrayBase, attr, fun)
- setattr(jax.interpreters.xla.DeviceArray, attr, fun)
+ if hasattr(jaxlib.xla_extension, "ArrayImpl"):
+ setattr(jaxlib.xla_extension.ArrayImpl, attr, fun)
+ if hasattr(jaxlib.xla_extension, "DeviceArrayBase"):
+ setattr(jaxlib.xla_extension.DeviceArrayBase, attr, fun)
+ setattr(jax.interpreters.xla.DeviceArray, attr, fun)
setattr(jax.core.Tracer, attr, fun)
def broadcast_to(arr, shape):
diff --git a/jaxtorch/nn/modules.py b/jaxtorch/nn/modules.py
index 5ac92f9..7302046 100644
--- a/jaxtorch/nn/modules.py
+++ b/jaxtorch/nn/modules.py
@@ -1,4 +1,5 @@
import math
+from typing import Callable, OrderedDict
import jax
import jax.numpy as jnp
import jaxtorch
@@ -6,11 +7,43 @@
from jaxtorch.core import Module, PRNG, Context
from jaxtorch import init
+
class Identity(Module):
def forward(self, cx, x):
return x
+class Lambda(Module):
+ def __init__(self, f: Callable, use_cx=False):
+ super().__init__()
+ self.f = f
+ self.use_cx = use_cx
+
+ def forward(self, cx, *args, **kwargs):
+ if self.use_cx:
+ return self.f(cx, *args, **kwargs)
+ else:
+ return self.f(*args, **kwargs)
+
+
+class SequentialDict(Module):
+ def __init__(self, modules: OrderedDict[str, Module]):
+ super().__init__()
+ self.mods = modules
+
+ def self_named_modules(self):
+ for k, m in self.mods.items():
+ yield k, m
+
+ def forward(self, cx, x, *args, **kwargs):
+ for k, m in self.mods.items():
+ x = m(cx, x, *args, **kwargs)
+ return x
+
+ def __getitem__(self, key):
+ return self.mods[key]
+
+
class ModuleList(Module):
def __init__(self, *modules):
self.modules = []
@@ -33,35 +66,71 @@ def forward(self, cx, x):
def self_named_modules(self):
for (i, m) in enumerate(self.modules):
- yield (f'{i}', m)
+ yield (f"{i}", m)
+
+ def __getitem__(self, key):
+ return self.modules[key]
class Sequential(ModuleList):
- def forward(self, cx, x):
+ def forward(self, cx, x, *args, **kwargs):
for module in self.modules:
- x = module(cx, x)
+ x = module(cx, x, *args, **kwargs)
return x
+
+
+
+def ignore_kwargs(mod):
+ class IgnoreKwargs(mod):
+ def forward(self, cx, *args, **kwargs):
+ return super().forward(cx, *args)
+
+ return IgnoreKwargs
+
+
+def ignore_non_kwargs(mod):
+ class IgnoreNonKwargs(mod):
+ def forward(self, cx, *args, **kwargs):
+ return super().forward(cx, **kwargs)
+
+ return IgnoreNonKwargs
class Linear(Module):
- def __init__(self, in_features: int, out_features: int, bias: bool = True):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ weight_init=None,
+ bias_init=None,
+ ):
super().__init__()
self.in_features = in_features
self.out_features = out_features
- self.weight = init.glorot_normal(out_features, in_features)
+ k = math.sqrt(1 / in_features)
+ if weight_init:
+ self.weight = weight_init(out_features, in_features)
+ else:
+ self.weight = init.uniform(out_features, in_features, min=-k, max=k)
if bias:
- self.bias = init.zeros(out_features)
+ if bias_init:
+ self.bias = bias_init(out_features)
+ else:
+ self.bias = init.uniform(out_features, min=-k, max=k)
else:
self.bias = None
def forward(self, cx, x):
- y = x @ jnp.transpose(cx[self.weight])
+ x, weight = cx.policy.cast_to_compute((x, cx[self.weight]))
+ bias = cx.policy.cast_to_compute(cx[self.bias]) if self.bias else None
+ y = x @ jnp.transpose(weight)
if self.bias:
- y = y + cx[self.bias]
+ y = y + bias
return y
def extra_repr(self) -> str:
- return 'in_features={}, out_features={}, bias={}'.format(
+ return "in_features={}, out_features={}, bias={}".format(
self.in_features, self.out_features, self.bias is not None
)
@@ -77,7 +146,7 @@ def forward(self, cx, x):
return cx[self.weight][x]
def extra_repr(self) -> str:
- s = '{num_embeddings}, {embedding_dim}'
+ s = "{num_embeddings}, {embedding_dim}"
# if self.padding_idx is not None:
# s += ', padding_idx={padding_idx}'
# if self.max_norm is not None:
@@ -91,7 +160,6 @@ def extra_repr(self) -> str:
return s.format(**self.__dict__)
-
class Tanh(Module):
def forward(self, cx, x):
return jnp.tanh(x)
@@ -102,34 +170,39 @@ def __init__(self, p=0.5):
self.rate = p
def forward(self, cx, x):
- if cx.mode == 'eval':
+ if cx.mode == "eval":
return x
mask = cx.random.bernoulli(1.0 - self.rate, shape=x.shape)
return x * mask / (1.0 - self.rate)
+
class Dropout2d(Module):
def __init__(self, p=0.5):
self.rate = p
def forward(self, cx, x):
- if cx.mode == 'eval':
+ if cx.mode == "eval":
return x
drop_shape = x.shape[:2] + (1,) * len(x.shape[2:])
mask = cx.random.bernoulli(1.0 - self.rate, shape=drop_shape)
return x * mask / (1.0 - self.rate)
+
class Sigmoid(Module):
def forward(self, cx, x):
return jax.nn.sigmoid(x)
+
class GELU(Module):
def forward(self, cx, x):
return jax.nn.gelu(x)
+
class ReLU(Module):
def forward(self, cx, x):
return jax.nn.relu(x)
+
class LeakyReLU(Module):
def __init__(self, negative_slope=0.01):
self.negative_slope = negative_slope
@@ -151,77 +224,178 @@ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True):
else:
self.weight = None
self.bias = None
- self.axes = tuple(-i for i in range(1, len(normalized_shape)+1))
+ self.axes = tuple(-i for i in range(1, len(normalized_shape) + 1))
def forward(self, cx, x):
- mu = x.mean(axis=self.axes, keepdims=True)
- sigma = jnp.sqrt((x - mu).square().mean(axis=self.axes, keepdims=True) + self.eps)
+ x = cx.policy.cast_to_compute(x)
+ if self.weight:
+ weight = cx[self.weight]
+ else:
+ weight = 1
+ if self.bias:
+ bias = cx[self.bias]
+ else:
+ bias = 0
+ mu = jnp.mean(x, axis=self.axes, keepdims=True)
+ sigma = jnp.std(x, axis=self.axes, keepdims=True)
normed = (x - mu) / sigma
- return cx[self.weight] * normed + cx[self.bias]
+ return weight * normed + bias
class Conv1d(Module):
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, zero_init=False):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ zero_init=False,
+ weight_init=None,
+ bias_init=None,
+ ):
assert in_channels % groups == 0
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
- self.weight = init.kaiming_uniform(out_channels, in_channels//groups, kernel_size, a=math.sqrt(5.0))
+ k = math.sqrt(groups / (in_channels * kernel_size))
if zero_init:
- self.weight = init.zeros(out_channels, in_channels//groups, kernel_size)
+ self.weight = init.zeros(
+ out_channels,
+ in_channels // groups,
+ kernel_size,
+ )
+ elif weight_init:
+ self.weight = weight_init(
+ out_channels,
+ in_channels // groups,
+ kernel_size,
+ )
+ else:
+ self.weight = init.uniform(
+ out_channels,
+ in_channels // groups,
+ kernel_size,
+ min=-k,
+ max=k,
+ )
self.use_bias = bias
if self.use_bias:
- self.bias = init.zeros(out_channels)
+ if bias_init:
+ self.bias = bias_init(out_channels)
+ else:
+ self.bias = init.uniform(out_channels, min=-k, max=k)
else:
self.bias = None
def forward(self, cx, x):
- return jaxtorch.nn.functional.conv1d(x, cx[self.weight], cx[self.bias] if self.use_bias else None,
- stride=self.stride,
- padding=self.padding,
- dilation=self.dilation,
- groups=self.groups)
+ return jaxtorch.nn.functional.conv1d(
+ x,
+ cx[self.weight],
+ self.bias and cx[self.bias],
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ )
class Conv2d(Module):
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, zero_init=False):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ zero_init=False,
+ weight_init=None,
+ bias_init=None,
+ ):
assert in_channels % groups == 0
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
- self.weight = init.kaiming_uniform(out_channels, in_channels//groups, kernel_size, kernel_size, a=math.sqrt(5.0))
+ k = math.sqrt(groups / (in_channels * kernel_size * kernel_size))
if zero_init:
- self.weight = init.zeros(out_channels, in_channels//groups, kernel_size, kernel_size)
+ self.weight = init.zeros(
+ out_channels, in_channels // groups, kernel_size, kernel_size
+ )
+ elif weight_init:
+ self.weight = weight_init(
+ out_channels,
+ in_channels // groups,
+ kernel_size,
+ kernel_size,
+ )
+ else:
+ self.weight = init.uniform(
+ out_channels,
+ in_channels // groups,
+ kernel_size,
+ kernel_size,
+ min=-k,
+ max=k,
+ )
self.use_bias = bias
if self.use_bias:
- self.bias = init.zeros(out_channels)
+ if bias_init:
+ self.bias = bias_init(out_channels)
+ else:
+ self.bias = init.uniform(out_channels, min=-k, max=k)
else:
self.bias = None
def forward(self, cx, x):
- return jaxtorch.nn.functional.conv2d(x, cx[self.weight], cx[self.bias] if self.use_bias else None,
- stride=self.stride,
- padding=self.padding,
- dilation=self.dilation,
- groups=self.groups)
+ out = jaxtorch.nn.functional.conv2d(
+ x,
+ cx[self.weight],
+ self.bias and cx[self.bias],
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+ return out
class SiLU(Module):
def forward(self, cx, x):
return jax.nn.silu(x)
+
class GroupNorm(Module):
- def __init__(self, num_groups, num_channels, eps=1e-05, affine=True):
+ def __init__(
+ self,
+ num_groups,
+ num_channels,
+ eps=1e-05,
+ affine=True,
+ weight_init=None,
+ bias_init=None,
+ ):
self.num_groups = num_groups
self.num_channels = num_channels
assert self.num_channels % self.num_groups == 0
self.eps = eps
self.affine = affine
if self.affine:
- self.weight = init.ones(num_channels)
- self.bias = init.zeros(num_channels)
+ if weight_init:
+ self.weight = weight_init(num_channels)
+ else:
+ self.weight = init.ones(num_channels)
+ if bias_init:
+ self.bias = bias_init(num_channels)
+ else:
+ self.bias = init.zeros(num_channels)
else:
self.weight = None
self.bias = None
@@ -229,10 +403,10 @@ def __init__(self, num_groups, num_channels, eps=1e-05, affine=True):
def forward(self, cx, x):
B, C, *rest = x.shape
assert C == self.num_channels
- x = x.reshape([B, self.num_groups, C//self.num_groups, *rest])
- mu = x.mean(axis=tuple(range(2,len(x.shape))), keepdims=True)
- var = x.var(axis=tuple(range(2,len(x.shape))), keepdims=True)
- y = (x - mu) / jnp.sqrt(var + self.eps)
+ x = x.reshape([B, self.num_groups, C // self.num_groups, *rest])
+ mu = jnp.mean(x, axis=tuple(range(2, len(x.shape))), keepdims=True)
+ std = jnp.std(x, axis=tuple(range(2, len(x.shape))), keepdims=True)
+ y = (x - mu) / std
y = y.reshape([B, C, *rest])
if self.affine:
broadcast_shape = [self.num_channels] + [1] * len(rest)
@@ -241,14 +415,26 @@ def forward(self, cx, x):
y = y * weight + bias
return y
+
class PixelUnshuffle(Module):
def __init__(self, downscale_factor):
self.downscale_factor = downscale_factor
+
def forward(self, cx, x):
- return x.rearrange('... c (h r) (w s) -> ... (c r s) h w', r = self.downscale_factor, s = self.downscale_factor)
+ return x.rearrange(
+ "... c (h r) (w s) -> ... (c r s) h w",
+ r=self.downscale_factor,
+ s=self.downscale_factor,
+ )
+
class PixelShuffle(Module):
def __init__(self, upscale_factor):
self.upscale_factor = upscale_factor
+
def forward(self, cx, x):
- return x.rearrange('... (c r s) h w -> ... c (h r) (w s)', r = self.upscale_factor, s = self.upscale_factor)
\ No newline at end of file
+ return x.rearrange(
+ "... (c r s) h w -> ... c (h r) (w s)",
+ r=self.upscale_factor,
+ s=self.upscale_factor,
+ )
diff --git a/jaxtorch/pt.py b/jaxtorch/pt.py
index d8b02c4..8fe5af1 100644
--- a/jaxtorch/pt.py
+++ b/jaxtorch/pt.py
@@ -9,22 +9,36 @@
import numpy as np
import torch
+
@torch.no_grad()
-def load(f):
- """Converts torch.Tensor back to jax arrays after loading."""
+def torch_to_jax(torch_dict):
def from_torch(x):
if isinstance(x, torch.Tensor):
return jnp.asarray(x)
return x
- torch_dict = torch.load(f, map_location='cpu')
+
return jax.tree_util.tree_map(from_torch, torch_dict)
+
@torch.no_grad()
-def save(obj, f):
- """Converts jax arrays (anything under jaxlib.xla_extension.DeviceArrayBase) to torch.Tensor before saving."""
+def jax_to_torch(obj):
def to_torch(x):
if isinstance(x, jaxlib.xla_extension.DeviceArrayBase):
return torch.as_tensor(np.array(x))
return x
- torch_dict = jax.tree_util.tree_map(to_torch, obj)
+
+ return jax.tree_util.tree_map(to_torch, obj)
+
+
+@torch.no_grad()
+def load(f):
+ """Converts torch.Tensor back to jax arrays after loading."""
+ torch_dict = torch.load(f, map_location="cpu")
+ return torch_to_jax(torch_dict)
+
+
+@torch.no_grad()
+def save(obj, f):
+ """Converts jax arrays (anything under jaxlib.xla_extension.DeviceArrayBase) to torch.Tensor before saving."""
+ torch_dict = jax_to_torch(obj)
torch.save(torch_dict, f)
diff --git a/poetry.lock b/poetry.lock
new file mode 100644
index 0000000..fa3bf64
--- /dev/null
+++ b/poetry.lock
@@ -0,0 +1,1144 @@
+# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+
+[[package]]
+name = "bleach"
+version = "6.0.0"
+description = "An easy safelist-based HTML-sanitizing tool."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "bleach-6.0.0-py3-none-any.whl", hash = "sha256:33c16e3353dbd13028ab4799a0f89a83f113405c766e9c122df8a06f5b85b3f4"},
+ {file = "bleach-6.0.0.tar.gz", hash = "sha256:1a1a85c1595e07d8db14c5f09f09e6433502c51c595970edc090551f0db99414"},
+]
+
+[package.dependencies]
+six = ">=1.9.0"
+webencodings = "*"
+
+[package.extras]
+css = ["tinycss2 (>=1.1.0,<1.2)"]
+
+[[package]]
+name = "cachetools"
+version = "5.3.1"
+description = "Extensible memoizing collections and decorators"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "cachetools-5.3.1-py3-none-any.whl", hash = "sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590"},
+ {file = "cachetools-5.3.1.tar.gz", hash = "sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b"},
+]
+
+[[package]]
+name = "certifi"
+version = "2023.5.7"
+description = "Python package for providing Mozilla's CA Bundle."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"},
+ {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"},
+]
+
+[[package]]
+name = "cffi"
+version = "1.15.1"
+description = "Foreign Function Interface for Python calling C code."
+optional = false
+python-versions = "*"
+files = [
+ {file = "cffi-1.15.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a66d3508133af6e8548451b25058d5812812ec3798c886bf38ed24a98216fab2"},
+ {file = "cffi-1.15.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:470c103ae716238bbe698d67ad020e1db9d9dba34fa5a899b5e21577e6d52ed2"},
+ {file = "cffi-1.15.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:9ad5db27f9cabae298d151c85cf2bad1d359a1b9c686a275df03385758e2f914"},
+ {file = "cffi-1.15.1-cp27-cp27m-win32.whl", hash = "sha256:b3bbeb01c2b273cca1e1e0c5df57f12dce9a4dd331b4fa1635b8bec26350bde3"},
+ {file = "cffi-1.15.1-cp27-cp27m-win_amd64.whl", hash = "sha256:e00b098126fd45523dd056d2efba6c5a63b71ffe9f2bbe1a4fe1716e1d0c331e"},
+ {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:d61f4695e6c866a23a21acab0509af1cdfd2c013cf256bbf5b6b5e2695827162"},
+ {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:ed9cb427ba5504c1dc15ede7d516b84757c3e3d7868ccc85121d9310d27eed0b"},
+ {file = "cffi-1.15.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d39875251ca8f612b6f33e6b1195af86d1b3e60086068be9cc053aa4376e21"},
+ {file = "cffi-1.15.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:285d29981935eb726a4399badae8f0ffdff4f5050eaa6d0cfc3f64b857b77185"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3eb6971dcff08619f8d91607cfc726518b6fa2a9eba42856be181c6d0d9515fd"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21157295583fe8943475029ed5abdcf71eb3911894724e360acff1d61c1d54bc"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5635bd9cb9731e6d4a1132a498dd34f764034a8ce60cef4f5319c0541159392f"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2012c72d854c2d03e45d06ae57f40d78e5770d252f195b93f581acf3ba44496e"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd86c085fae2efd48ac91dd7ccffcfc0571387fe1193d33b6394db7ef31fe2a4"},
+ {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:fa6693661a4c91757f4412306191b6dc88c1703f780c8234035eac011922bc01"},
+ {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59c0b02d0a6c384d453fece7566d1c7e6b7bae4fc5874ef2ef46d56776d61c9e"},
+ {file = "cffi-1.15.1-cp310-cp310-win32.whl", hash = "sha256:cba9d6b9a7d64d4bd46167096fc9d2f835e25d7e4c121fb2ddfc6528fb0413b2"},
+ {file = "cffi-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:ce4bcc037df4fc5e3d184794f27bdaab018943698f4ca31630bc7f84a7b69c6d"},
+ {file = "cffi-1.15.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d08afd128ddaa624a48cf2b859afef385b720bb4b43df214f85616922e6a5ac"},
+ {file = "cffi-1.15.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3799aecf2e17cf585d977b780ce79ff0dc9b78d799fc694221ce814c2c19db83"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a591fe9e525846e4d154205572a029f653ada1a78b93697f3b5a8f1f2bc055b9"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3548db281cd7d2561c9ad9984681c95f7b0e38881201e157833a2342c30d5e8c"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91fc98adde3d7881af9b59ed0294046f3806221863722ba7d8d120c575314325"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94411f22c3985acaec6f83c6df553f2dbe17b698cc7f8ae751ff2237d96b9e3c"},
+ {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:03425bdae262c76aad70202debd780501fabeaca237cdfddc008987c0e0f59ef"},
+ {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cc4d65aeeaa04136a12677d3dd0b1c0c94dc43abac5860ab33cceb42b801c1e8"},
+ {file = "cffi-1.15.1-cp311-cp311-win32.whl", hash = "sha256:a0f100c8912c114ff53e1202d0078b425bee3649ae34d7b070e9697f93c5d52d"},
+ {file = "cffi-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:04ed324bda3cda42b9b695d51bb7d54b680b9719cfab04227cdd1e04e5de3104"},
+ {file = "cffi-1.15.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50a74364d85fd319352182ef59c5c790484a336f6db772c1a9231f1c3ed0cbd7"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e263d77ee3dd201c3a142934a086a4450861778baaeeb45db4591ef65550b0a6"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cec7d9412a9102bdc577382c3929b337320c4c4c4849f2c5cdd14d7368c5562d"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4289fc34b2f5316fbb762d75362931e351941fa95fa18789191b33fc4cf9504a"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:173379135477dc8cac4bc58f45db08ab45d228b3363adb7af79436135d028405"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6975a3fac6bc83c4a65c9f9fcab9e47019a11d3d2cf7f3c0d03431bf145a941e"},
+ {file = "cffi-1.15.1-cp36-cp36m-win32.whl", hash = "sha256:2470043b93ff09bf8fb1d46d1cb756ce6132c54826661a32d4e4d132e1977adf"},
+ {file = "cffi-1.15.1-cp36-cp36m-win_amd64.whl", hash = "sha256:30d78fbc8ebf9c92c9b7823ee18eb92f2e6ef79b45ac84db507f52fbe3ec4497"},
+ {file = "cffi-1.15.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:198caafb44239b60e252492445da556afafc7d1e3ab7a1fb3f0584ef6d742375"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5ef34d190326c3b1f822a5b7a45f6c4535e2f47ed06fec77d3d799c450b2651e"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8102eaf27e1e448db915d08afa8b41d6c7ca7a04b7d73af6514df10a3e74bd82"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5df2768244d19ab7f60546d0c7c63ce1581f7af8b5de3eb3004b9b6fc8a9f84b"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8c4917bd7ad33e8eb21e9a5bbba979b49d9a97acb3a803092cbc1133e20343c"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2642fe3142e4cc4af0799748233ad6da94c62a8bec3a6648bf8ee68b1c7426"},
+ {file = "cffi-1.15.1-cp37-cp37m-win32.whl", hash = "sha256:e229a521186c75c8ad9490854fd8bbdd9a0c9aa3a524326b55be83b54d4e0ad9"},
+ {file = "cffi-1.15.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a0b71b1b8fbf2b96e41c4d990244165e2c9be83d54962a9a1d118fd8657d2045"},
+ {file = "cffi-1.15.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:320dab6e7cb2eacdf0e658569d2575c4dad258c0fcc794f46215e1e39f90f2c3"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e74c6b51a9ed6589199c787bf5f9875612ca4a8a0785fb2d4a84429badaf22a"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5c84c68147988265e60416b57fc83425a78058853509c1b0629c180094904a5"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b926aa83d1edb5aa5b427b4053dc420ec295a08e40911296b9eb1b6170f6cca"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:87c450779d0914f2861b8526e035c5e6da0a3199d8f1add1a665e1cbc6fc6d02"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f2c9f67e9821cad2e5f480bc8d83b8742896f1242dba247911072d4fa94c192"},
+ {file = "cffi-1.15.1-cp38-cp38-win32.whl", hash = "sha256:8b7ee99e510d7b66cdb6c593f21c043c248537a32e0bedf02e01e9553a172314"},
+ {file = "cffi-1.15.1-cp38-cp38-win_amd64.whl", hash = "sha256:00a9ed42e88df81ffae7a8ab6d9356b371399b91dbdf0c3cb1e84c03a13aceb5"},
+ {file = "cffi-1.15.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:54a2db7b78338edd780e7ef7f9f6c442500fb0d41a5a4ea24fff1c929d5af585"},
+ {file = "cffi-1.15.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fcd131dd944808b5bdb38e6f5b53013c5aa4f334c5cad0c72742f6eba4b73db0"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7473e861101c9e72452f9bf8acb984947aa1661a7704553a9f6e4baa5ba64415"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c9a799e985904922a4d207a94eae35c78ebae90e128f0c4e521ce339396be9d"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3bcde07039e586f91b45c88f8583ea7cf7a0770df3a1649627bf598332cb6984"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33ab79603146aace82c2427da5ca6e58f2b3f2fb5da893ceac0c42218a40be35"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d598b938678ebf3c67377cdd45e09d431369c3b1a5b331058c338e201f12b27"},
+ {file = "cffi-1.15.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db0fbb9c62743ce59a9ff687eb5f4afbe77e5e8403d6697f7446e5f609976f76"},
+ {file = "cffi-1.15.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:98d85c6a2bef81588d9227dde12db8a7f47f639f4a17c9ae08e773aa9c697bf3"},
+ {file = "cffi-1.15.1-cp39-cp39-win32.whl", hash = "sha256:40f4774f5a9d4f5e344f31a32b5096977b5d48560c5592e2f3d2c4374bd543ee"},
+ {file = "cffi-1.15.1-cp39-cp39-win_amd64.whl", hash = "sha256:70df4e3b545a17496c9b3f41f5115e69a4f2e77e94e1d2a8e1070bc0c38c8a3c"},
+ {file = "cffi-1.15.1.tar.gz", hash = "sha256:d400bfb9a37b1351253cb402671cea7e89bdecc294e8016a707f6d1d8ac934f9"},
+]
+
+[package.dependencies]
+pycparser = "*"
+
+[[package]]
+name = "charset-normalizer"
+version = "3.1.0"
+description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+optional = false
+python-versions = ">=3.7.0"
+files = [
+ {file = "charset-normalizer-3.1.0.tar.gz", hash = "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:04eefcee095f58eaabe6dc3cc2262f3bcd776d2c67005880894f447b3f2cb9c1"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20064ead0717cf9a73a6d1e779b23d149b53daf971169289ed2ed43a71e8d3b0"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1435ae15108b1cb6fffbcea2af3d468683b7afed0169ad718451f8db5d1aff6f"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c84132a54c750fda57729d1e2599bb598f5fa0344085dbde5003ba429a4798c0"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75f2568b4189dda1c567339b48cba4ac7384accb9c2a7ed655cd86b04055c795"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11d3bcb7be35e7b1bba2c23beedac81ee893ac9871d0ba79effc7fc01167db6c"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:891cf9b48776b5c61c700b55a598621fdb7b1e301a550365571e9624f270c203"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5f008525e02908b20e04707a4f704cd286d94718f48bb33edddc7d7b584dddc1"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:b06f0d3bf045158d2fb8837c5785fe9ff9b8c93358be64461a1089f5da983137"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:49919f8400b5e49e961f320c735388ee686a62327e773fa5b3ce6721f7e785ce"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:22908891a380d50738e1f978667536f6c6b526a2064156203d418f4856d6e86a"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-win32.whl", hash = "sha256:12d1a39aa6b8c6f6248bb54550efcc1c38ce0d8096a146638fd4738e42284448"},
+ {file = "charset_normalizer-3.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:65ed923f84a6844de5fd29726b888e58c62820e0769b76565480e1fdc3d062f8"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9a3267620866c9d17b959a84dd0bd2d45719b817245e49371ead79ed4f710d19"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6734e606355834f13445b6adc38b53c0fd45f1a56a9ba06c2058f86893ae8017"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f8303414c7b03f794347ad062c0516cee0e15f7a612abd0ce1e25caf6ceb47df"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf53a6cebad0eae578f062c7d462155eada9c172bd8c4d250b8c1d8eb7f916a"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dc5b6a8ecfdc5748a7e429782598e4f17ef378e3e272eeb1340ea57c9109f41"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e1b25e3ad6c909f398df8921780d6a3d120d8c09466720226fc621605b6f92b1"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca564606d2caafb0abe6d1b5311c2649e8071eb241b2d64e75a0d0065107e62"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b82fab78e0b1329e183a65260581de4375f619167478dddab510c6c6fb04d9b6"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bd7163182133c0c7701b25e604cf1611c0d87712e56e88e7ee5d72deab3e76b5"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:11d117e6c63e8f495412d37e7dc2e2fff09c34b2d09dbe2bee3c6229577818be"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:cf6511efa4801b9b38dc5546d7547d5b5c6ef4b081c60b23e4d941d0eba9cbeb"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:abc1185d79f47c0a7aaf7e2412a0eb2c03b724581139193d2d82b3ad8cbb00ac"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cb7b2ab0188829593b9de646545175547a70d9a6e2b63bf2cd87a0a391599324"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-win32.whl", hash = "sha256:c36bcbc0d5174a80d6cccf43a0ecaca44e81d25be4b7f90f0ed7bcfbb5a00909"},
+ {file = "charset_normalizer-3.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:cca4def576f47a09a943666b8f829606bcb17e2bc2d5911a46c8f8da45f56755"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:0c95f12b74681e9ae127728f7e5409cbbef9cd914d5896ef238cc779b8152373"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fca62a8301b605b954ad2e9c3666f9d97f63872aa4efcae5492baca2056b74ab"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac0aa6cd53ab9a31d397f8303f92c42f534693528fafbdb997c82bae6e477ad9"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3af8e0f07399d3176b179f2e2634c3ce9c1301379a6b8c9c9aeecd481da494f"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a5fc78f9e3f501a1614a98f7c54d3969f3ad9bba8ba3d9b438c3bc5d047dd28"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:628c985afb2c7d27a4800bfb609e03985aaecb42f955049957814e0491d4006d"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:74db0052d985cf37fa111828d0dd230776ac99c740e1a758ad99094be4f1803d"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1e8fcdd8f672a1c4fc8d0bd3a2b576b152d2a349782d1eb0f6b8e52e9954731d"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:04afa6387e2b282cf78ff3dbce20f0cc071c12dc8f685bd40960cc68644cfea6"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:dd5653e67b149503c68c4018bf07e42eeed6b4e956b24c00ccdf93ac79cdff84"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d2686f91611f9e17f4548dbf050e75b079bbc2a82be565832bc8ea9047b61c8c"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-win32.whl", hash = "sha256:4155b51ae05ed47199dc5b2a4e62abccb274cee6b01da5b895099b61b1982974"},
+ {file = "charset_normalizer-3.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:322102cdf1ab682ecc7d9b1c5eed4ec59657a65e1c146a0da342b78f4112db23"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e633940f28c1e913615fd624fcdd72fdba807bf53ea6925d6a588e84e1151531"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3a06f32c9634a8705f4ca9946d667609f52cf130d5548881401f1eb2c39b1e2c"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7381c66e0561c5757ffe616af869b916c8b4e42b367ab29fedc98481d1e74e14"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3573d376454d956553c356df45bb824262c397c6e26ce43e8203c4c540ee0acb"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e89df2958e5159b811af9ff0f92614dabf4ff617c03a4c1c6ff53bf1c399e0e1"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:78cacd03e79d009d95635e7d6ff12c21eb89b894c354bd2b2ed0b4763373693b"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de5695a6f1d8340b12a5d6d4484290ee74d61e467c39ff03b39e30df62cf83a0"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c60b9c202d00052183c9be85e5eaf18a4ada0a47d188a83c8f5c5b23252f649"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f645caaf0008bacf349875a974220f1f1da349c5dbe7c4ec93048cdc785a3326"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ea9f9c6034ea2d93d9147818f17c2a0860d41b71c38b9ce4d55f21b6f9165a11"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:80d1543d58bd3d6c271b66abf454d437a438dff01c3e62fdbcd68f2a11310d4b"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:73dc03a6a7e30b7edc5b01b601e53e7fc924b04e1835e8e407c12c037e81adbd"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6f5c2e7bc8a4bf7c426599765b1bd33217ec84023033672c1e9a8b35eaeaaaf8"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-win32.whl", hash = "sha256:12a2b561af122e3d94cdb97fe6fb2bb2b82cef0cdca131646fdb940a1eda04f0"},
+ {file = "charset_normalizer-3.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3160a0fd9754aab7d47f95a6b63ab355388d890163eb03b2d2b87ab0a30cfa59"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38e812a197bf8e71a59fe55b757a84c1f946d0ac114acafaafaf21667a7e169e"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6baf0baf0d5d265fa7944feb9f7451cc316bfe30e8df1a61b1bb08577c554f31"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8f25e17ab3039b05f762b0a55ae0b3632b2e073d9c8fc88e89aca31a6198e88f"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3747443b6a904001473370d7810aa19c3a180ccd52a7157aacc264a5ac79265e"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b116502087ce8a6b7a5f1814568ccbd0e9f6cfd99948aa59b0e241dc57cf739f"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d16fd5252f883eb074ca55cb622bc0bee49b979ae4e8639fff6ca3ff44f9f854"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fa558996782fc226b529fdd2ed7866c2c6ec91cee82735c98a197fae39f706"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f6c7a8a57e9405cad7485f4c9d3172ae486cfef1344b5ddd8e5239582d7355e"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ac3775e3311661d4adace3697a52ac0bab17edd166087d493b52d4f4f553f9f0"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:10c93628d7497c81686e8e5e557aafa78f230cd9e77dd0c40032ef90c18f2230"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:6f4f4668e1831850ebcc2fd0b1cd11721947b6dc7c00bf1c6bd3c929ae14f2c7"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:0be65ccf618c1e7ac9b849c315cc2e8a8751d9cfdaa43027d4f6624bd587ab7e"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:53d0a3fa5f8af98a1e261de6a3943ca631c526635eb5817a87a59d9a57ebf48f"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-win32.whl", hash = "sha256:a04f86f41a8916fe45ac5024ec477f41f886b3c435da2d4e3d2709b22ab02af1"},
+ {file = "charset_normalizer-3.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b"},
+ {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"},
+]
+
+[[package]]
+name = "cryptography"
+version = "41.0.1"
+description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "cryptography-41.0.1-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:f73bff05db2a3e5974a6fd248af2566134d8981fd7ab012e5dd4ddb1d9a70699"},
+ {file = "cryptography-41.0.1-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1a5472d40c8f8e91ff7a3d8ac6dfa363d8e3138b961529c996f3e2df0c7a411a"},
+ {file = "cryptography-41.0.1-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fa01527046ca5facdf973eef2535a27fec4cb651e4daec4d043ef63f6ecd4ca"},
+ {file = "cryptography-41.0.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b46e37db3cc267b4dea1f56da7346c9727e1209aa98487179ee8ebed09d21e43"},
+ {file = "cryptography-41.0.1-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d198820aba55660b4d74f7b5fd1f17db3aa5eb3e6893b0a41b75e84e4f9e0e4b"},
+ {file = "cryptography-41.0.1-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:948224d76c4b6457349d47c0c98657557f429b4e93057cf5a2f71d603e2fc3a3"},
+ {file = "cryptography-41.0.1-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:059e348f9a3c1950937e1b5d7ba1f8e968508ab181e75fc32b879452f08356db"},
+ {file = "cryptography-41.0.1-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:b4ceb5324b998ce2003bc17d519080b4ec8d5b7b70794cbd2836101406a9be31"},
+ {file = "cryptography-41.0.1-cp37-abi3-win32.whl", hash = "sha256:8f4ab7021127a9b4323537300a2acfb450124b2def3756f64dc3a3d2160ee4b5"},
+ {file = "cryptography-41.0.1-cp37-abi3-win_amd64.whl", hash = "sha256:1fee5aacc7367487b4e22484d3c7e547992ed726d14864ee33c0176ae43b0d7c"},
+ {file = "cryptography-41.0.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:9a6c7a3c87d595608a39980ebaa04d5a37f94024c9f24eb7d10262b92f739ddb"},
+ {file = "cryptography-41.0.1-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:5d092fdfedaec4cbbffbf98cddc915ba145313a6fdaab83c6e67f4e6c218e6f3"},
+ {file = "cryptography-41.0.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a8e6c2de6fbbcc5e14fd27fb24414507cb3333198ea9ab1258d916f00bc3039"},
+ {file = "cryptography-41.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cb33ccf15e89f7ed89b235cff9d49e2e62c6c981a6061c9c8bb47ed7951190bc"},
+ {file = "cryptography-41.0.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5f0ff6e18d13a3de56f609dd1fd11470918f770c6bd5d00d632076c727d35485"},
+ {file = "cryptography-41.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:7bfc55a5eae8b86a287747053140ba221afc65eb06207bedf6e019b8934b477c"},
+ {file = "cryptography-41.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:eb8163f5e549a22888c18b0d53d6bb62a20510060a22fd5a995ec8a05268df8a"},
+ {file = "cryptography-41.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8dde71c4169ec5ccc1087bb7521d54251c016f126f922ab2dfe6649170a3b8c5"},
+ {file = "cryptography-41.0.1.tar.gz", hash = "sha256:d34579085401d3f49762d2f7d6634d6b6c2ae1242202e860f4d26b046e3a1006"},
+]
+
+[package.dependencies]
+cffi = ">=1.12"
+
+[package.extras]
+docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"]
+docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"]
+nox = ["nox"]
+pep8test = ["black", "check-sdist", "mypy", "ruff"]
+sdist = ["build"]
+ssh = ["bcrypt (>=3.1.5)"]
+test = ["pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
+test-randomorder = ["pytest-randomly"]
+
+[[package]]
+name = "docutils"
+version = "0.20.1"
+description = "Docutils -- Python Documentation Utilities"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "docutils-0.20.1-py3-none-any.whl", hash = "sha256:96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6"},
+ {file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"},
+]
+
+[[package]]
+name = "einops"
+version = "0.6.1"
+description = "A new flavour of deep learning operations"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "einops-0.6.1-py3-none-any.whl", hash = "sha256:99149e46cc808956b174932fe563d920db4d6e5dadb8c6ecdaa7483b7ef7cfc3"},
+ {file = "einops-0.6.1.tar.gz", hash = "sha256:f95f8d00f4ded90dbc4b19b6f98b177332614b0357dde66997f3ae5d474dc8c8"},
+]
+
+[[package]]
+name = "filelock"
+version = "3.12.0"
+description = "A platform independent file lock."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "filelock-3.12.0-py3-none-any.whl", hash = "sha256:ad98852315c2ab702aeb628412cbf7e95b7ce8c3bf9565670b4eaecf1db370a9"},
+ {file = "filelock-3.12.0.tar.gz", hash = "sha256:fc03ae43288c013d2ea83c8597001b1129db351aad9c57fe2409327916b8e718"},
+]
+
+[package.extras]
+docs = ["furo (>=2023.3.27)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
+
+[[package]]
+name = "google-auth"
+version = "2.17.3"
+description = "Google Authentication Library"
+optional = false
+python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
+files = [
+ {file = "google-auth-2.17.3.tar.gz", hash = "sha256:ce311e2bc58b130fddf316df57c9b3943c2a7b4f6ec31de9663a9333e4064efc"},
+ {file = "google_auth-2.17.3-py2.py3-none-any.whl", hash = "sha256:f586b274d3eb7bd932ea424b1c702a30e0393a2e2bc4ca3eae8263ffd8be229f"},
+]
+
+[package.dependencies]
+cachetools = ">=2.0.0,<6.0"
+pyasn1-modules = ">=0.2.1"
+rsa = {version = ">=3.1.4,<5", markers = "python_version >= \"3.6\""}
+six = ">=1.9.0"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)", "requests (>=2.20.0,<3.0.0dev)"]
+enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
+pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
+reauth = ["pyu2f (>=0.1.5)"]
+requests = ["requests (>=2.20.0,<3.0.0dev)"]
+
+[[package]]
+name = "idna"
+version = "3.4"
+description = "Internationalized Domain Names in Applications (IDNA)"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
+ {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
+]
+
+[[package]]
+name = "importlib-metadata"
+version = "6.6.0"
+description = "Read metadata from Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "importlib_metadata-6.6.0-py3-none-any.whl", hash = "sha256:43dd286a2cd8995d5eaef7fee2066340423b818ed3fd70adf0bad5f1fac53fed"},
+ {file = "importlib_metadata-6.6.0.tar.gz", hash = "sha256:92501cdf9cc66ebd3e612f1b4f0c0765dfa42f0fa38ffb319b6bd84dd675d705"},
+]
+
+[package.dependencies]
+zipp = ">=0.5"
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+perf = ["ipython"]
+testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"]
+
+[[package]]
+name = "importlib-resources"
+version = "5.12.0"
+description = "Read resources from Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "importlib_resources-5.12.0-py3-none-any.whl", hash = "sha256:7b1deeebbf351c7578e09bf2f63fa2ce8b5ffec296e0d349139d43cca061a81a"},
+ {file = "importlib_resources-5.12.0.tar.gz", hash = "sha256:4be82589bf5c1d7999aedf2a45159d10cb3ca4f19b2271f8792bc8e6da7b22f6"},
+]
+
+[package.dependencies]
+zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+
+[[package]]
+name = "jaraco-classes"
+version = "3.2.3"
+description = "Utility functions for Python class constructs"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "jaraco.classes-3.2.3-py3-none-any.whl", hash = "sha256:2353de3288bc6b82120752201c6b1c1a14b058267fa424ed5ce5984e3b922158"},
+ {file = "jaraco.classes-3.2.3.tar.gz", hash = "sha256:89559fa5c1d3c34eff6f631ad80bb21f378dbcbb35dd161fd2c6b93f5be2f98a"},
+]
+
+[package.dependencies]
+more-itertools = "*"
+
+[package.extras]
+docs = ["jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)"]
+testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+
+[[package]]
+name = "jax"
+version = "0.4.11"
+description = "Differentiate, compile, and transform Numpy code."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jax-0.4.11.tar.gz", hash = "sha256:8b1cd443b698339df8d8807578ee141e5b67e36125b3945b146f600177d60d79"},
+]
+
+[package.dependencies]
+importlib_metadata = {version = ">=4.6", markers = "python_version < \"3.10\""}
+ml_dtypes = ">=0.1.0"
+numpy = ">=1.21"
+opt_einsum = "*"
+scipy = ">=1.7"
+
+[package.extras]
+australis = ["protobuf (>=3.13,<4)"]
+ci = ["jaxlib (==0.4.10)"]
+cpu = ["jaxlib (==0.4.11)"]
+cuda = ["jaxlib (==0.4.11+cuda11.cudnn86)"]
+cuda11-cudnn82 = ["jaxlib (==0.4.11+cuda11.cudnn82)"]
+cuda11-cudnn86 = ["jaxlib (==0.4.11+cuda11.cudnn86)"]
+cuda11-local = ["jaxlib (==0.4.11+cuda11.cudnn86)"]
+cuda11-pip = ["jaxlib (==0.4.11+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"]
+cuda12-local = ["jaxlib (==0.4.11+cuda12.cudnn88)"]
+cuda12-pip = ["jaxlib (==0.4.11+cuda12.cudnn88)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"]
+minimum-jaxlib = ["jaxlib (==0.4.7)"]
+tpu = ["jaxlib (==0.4.11)", "libtpu-nightly (==0.1.dev20230531)"]
+
+[[package]]
+name = "jaxlib"
+version = "0.4.11"
+description = "XLA library for JAX"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jaxlib-0.4.11-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:50e4c169f85eb4ee41e13431f6bb0a47465aaa168c6d87dc458f411b8cd5eec0"},
+ {file = "jaxlib-0.4.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:27d82d31d5c14b0ccdfe6a783837198859cff598eceeb30a657e87ea964a8824"},
+ {file = "jaxlib-0.4.11-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:5b6ca8a31eb7d4a64bc2d2f8c70250c561b981f4cc380161b0e82ff3c8d58a01"},
+ {file = "jaxlib-0.4.11-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:07cc8ff9be4e609b02dbb8690228dff2467e3ca935b8a2d8588f3c0406c3fe6c"},
+ {file = "jaxlib-0.4.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a8ddca478e6103edf06595e24385a34192da2c842b20956f5f84b207f4ada9b5"},
+ {file = "jaxlib-0.4.11-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:5563eac77eb49dc3936f0313b0eeb528274017453e251d6f72ea10410bb2dd34"},
+ {file = "jaxlib-0.4.11-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:791429847181804e9c60e991c5a524268d0989ab73937f7b08b576c3f1af8289"},
+ {file = "jaxlib-0.4.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d81678d7a417a48fef2ebaa42f032bbe351d9cd4de75bd8e2f93a378daeb4acd"},
+ {file = "jaxlib-0.4.11-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:c54c54f7f56a9b42283acfb212fc3feaf953eac03c8a4cc688c4be7abb678d50"},
+ {file = "jaxlib-0.4.11-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:794158edf18b0eb2abb255d4ba7bd6cc0478c2011c3a8df5c168f5a4667757c7"},
+ {file = "jaxlib-0.4.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ec873efe22b7e0b85ed097b146b8639eebd9c65d668c18a1439e54c43c631560"},
+ {file = "jaxlib-0.4.11-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:db264df20bed933d29244f6ef2051c6d421515dd1c041ea095ade35152e7cc2c"},
+]
+
+[package.dependencies]
+ml-dtypes = ">=0.1.0"
+numpy = ">=1.21"
+scipy = ">=1.7"
+
+[package.extras]
+cuda11-pip = ["nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"]
+cuda12-pip = ["nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"]
+
+[[package]]
+name = "jeepney"
+version = "0.8.0"
+description = "Low-level, pure Python DBus protocol wrapper."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "jeepney-0.8.0-py3-none-any.whl", hash = "sha256:c0a454ad016ca575060802ee4d590dd912e35c122fa04e70306de3d076cce755"},
+ {file = "jeepney-0.8.0.tar.gz", hash = "sha256:5efe48d255973902f6badc3ce55e2aa6c5c3b3bc642059ef3a91247bcfcc5806"},
+]
+
+[package.extras]
+test = ["async-timeout", "pytest", "pytest-asyncio (>=0.17)", "pytest-trio", "testpath", "trio"]
+trio = ["async_generator", "trio"]
+
+[[package]]
+name = "jinja2"
+version = "3.1.2"
+description = "A very fast and expressive template engine."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"},
+ {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"},
+]
+
+[package.dependencies]
+MarkupSafe = ">=2.0"
+
+[package.extras]
+i18n = ["Babel (>=2.7)"]
+
+[[package]]
+name = "jmp"
+version = "0.0.4"
+description = "JMP is a Mixed Precision library for JAX."
+optional = false
+python-versions = "*"
+files = [
+ {file = "jmp-0.0.4-py3-none-any.whl", hash = "sha256:6aa7adbddf2bd574b28c7faf6e81a735eb11f53386447896909c6968dc36807d"},
+ {file = "jmp-0.0.4.tar.gz", hash = "sha256:5dfeb0fd7c7a9f72a70fff0aab9d0cbfae32a809c02f4037ff3485ceb33e1730"},
+]
+
+[package.dependencies]
+numpy = ">=1.19.5"
+
+[package.extras]
+jax = ["jax (>=0.2.20)", "jaxlib (>=0.1.71)"]
+
+[[package]]
+name = "keyring"
+version = "23.13.1"
+description = "Store and access your passwords safely."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "keyring-23.13.1-py3-none-any.whl", hash = "sha256:771ed2a91909389ed6148631de678f82ddc73737d85a927f382a8a1b157898cd"},
+ {file = "keyring-23.13.1.tar.gz", hash = "sha256:ba2e15a9b35e21908d0aaf4e0a47acc52d6ae33444df0da2b49d41a46ef6d678"},
+]
+
+[package.dependencies]
+importlib-metadata = {version = ">=4.11.4", markers = "python_version < \"3.12\""}
+importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
+"jaraco.classes" = "*"
+jeepney = {version = ">=0.4.2", markers = "sys_platform == \"linux\""}
+pywin32-ctypes = {version = ">=0.2.0", markers = "sys_platform == \"win32\""}
+SecretStorage = {version = ">=3.2", markers = "sys_platform == \"linux\""}
+
+[package.extras]
+completion = ["shtab"]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)"]
+testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+
+[[package]]
+name = "keyrings-google-artifactregistry-auth"
+version = "1.1.2"
+description = "Keyring backend for Google Auth tokens"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "keyrings.google-artifactregistry-auth-1.1.2.tar.gz", hash = "sha256:bd6abb72740d2dfeb4a5c03c3b105c6f7dba169caa29dee3959694f1f02c77de"},
+ {file = "keyrings.google_artifactregistry_auth-1.1.2-py3-none-any.whl", hash = "sha256:e3f18b50fa945c786593014dc225810d191671d4f5f8e12d9259e39bad3605a3"},
+]
+
+[package.dependencies]
+google-auth = "*"
+keyring = "*"
+pluggy = "*"
+requests = "*"
+
+[package.extras]
+testing = ["pytest (>=3.5,!=3.7.3)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=1.2.3)", "pytest-cov", "pytest-flake8", "pytest-mypy"]
+tox = ["tox"]
+
+[[package]]
+name = "markdown-it-py"
+version = "2.2.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "markdown-it-py-2.2.0.tar.gz", hash = "sha256:7c9a5e412688bc771c67432cbfebcdd686c93ce6484913dccf06cb5a0bea35a1"},
+ {file = "markdown_it_py-2.2.0-py3-none-any.whl", hash = "sha256:5a35f8d1870171d9acc47b99612dc146129b631baf04970128b568f190d0cc30"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["attrs", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
+[[package]]
+name = "markupsafe"
+version = "2.1.3"
+description = "Safely add untrusted strings to HTML/XML markup."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"},
+ {file = "MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"},
+ {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"},
+ {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"},
+ {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"},
+ {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"},
+ {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"},
+]
+
+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
+[[package]]
+name = "ml-dtypes"
+version = "0.2.0"
+description = ""
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "ml_dtypes-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:df6a76e1c8adf484feb138ed323f9f40a7b6c21788f120f7c78bec20ac37ee81"},
+ {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc29a0524ef5e23a7fbb8d881bdecabeb3fc1d19d9db61785d077a86cb94fab2"},
+ {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08c391c2794f2aad358e6f4c70785a9a7b1df980ef4c232b3ccd4f6fe39f719"},
+ {file = "ml_dtypes-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:75015818a7fccf99a5e8ed18720cb430f3e71a8838388840f4cdf225c036c983"},
+ {file = "ml_dtypes-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e70047ec2c83eaee01afdfdabee2c5b0c133804d90d0f7db4dd903360fcc537c"},
+ {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d28b8861a8931695e5a31176cad5ae85f6504906650dea5598fbec06c94606"},
+ {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e85ba8e24cf48d456e564688e981cf379d4c8e644db0a2f719b78de281bac2ca"},
+ {file = "ml_dtypes-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:832a019a1b6db5c4422032ca9940a990fa104eee420f643713241b3a518977fa"},
+ {file = "ml_dtypes-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8faaf0897942c8253dd126662776ba45f0a5861968cf0f06d6d465f8a7bc298a"},
+ {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35b984cddbe8173b545a0e3334fe56ea1a5c3eb67c507f60d0cfde1d3fa8f8c2"},
+ {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:022d5a4ee6be14569c2a9d1549e16f1ec87ca949681d0dca59995445d5fcdd5b"},
+ {file = "ml_dtypes-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:50845af3e9a601810751b55091dee6c2562403fa1cb4e0123675cf3a4fc2c17a"},
+ {file = "ml_dtypes-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f00c71c8c63e03aff313bc6a7aeaac9a4f1483a921a6ffefa6d4404efd1af3d0"},
+ {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80d304c836d73f10605c58ccf7789c171cc229bfb678748adfb7cea2510dfd0e"},
+ {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32107e7fa9f62db9a5281de923861325211dfff87bd23faefb27b303314635ab"},
+ {file = "ml_dtypes-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:1749b60348da71fd3c2ab303fdbc1965958dc50775ead41f5669c932a341cafd"},
+ {file = "ml_dtypes-0.2.0.tar.gz", hash = "sha256:6488eb642acaaf08d8020f6de0a38acee7ac324c1e6e92ee0c0fea42422cb797"},
+]
+
+[package.dependencies]
+numpy = [
+ {version = ">1.20", markers = "python_version <= \"3.9\""},
+ {version = ">=1.23.3", markers = "python_version > \"3.10\""},
+ {version = ">=1.21.2", markers = "python_version > \"3.9\""},
+]
+
+[package.extras]
+dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"]
+
+[[package]]
+name = "more-itertools"
+version = "9.1.0"
+description = "More routines for operating on iterables, beyond itertools"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "more-itertools-9.1.0.tar.gz", hash = "sha256:cabaa341ad0389ea83c17a94566a53ae4c9d07349861ecb14dc6d0345cf9ac5d"},
+ {file = "more_itertools-9.1.0-py3-none-any.whl", hash = "sha256:d2bc7f02446e86a68911e58ded76d6561eea00cddfb2a91e7019bbb586c799f3"},
+]
+
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+description = "Python library for arbitrary-precision floating-point arithmetic"
+optional = false
+python-versions = "*"
+files = [
+ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"},
+ {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"},
+]
+
+[package.extras]
+develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"]
+docs = ["sphinx"]
+gmpy = ["gmpy2 (>=2.1.0a4)"]
+tests = ["pytest (>=4.6)"]
+
+[[package]]
+name = "networkx"
+version = "3.1"
+description = "Python package for creating and manipulating graphs and networks"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"},
+ {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"},
+]
+
+[package.extras]
+default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"]
+developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"]
+doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"]
+extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"]
+test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"]
+
+[[package]]
+name = "numpy"
+version = "1.24.3"
+description = "Fundamental package for array computing in Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "numpy-1.24.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3c1104d3c036fb81ab923f507536daedc718d0ad5a8707c6061cdfd6d184e570"},
+ {file = "numpy-1.24.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:202de8f38fc4a45a3eea4b63e2f376e5f2dc64ef0fa692838e31a808520efaf7"},
+ {file = "numpy-1.24.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8535303847b89aa6b0f00aa1dc62867b5a32923e4d1681a35b5eef2d9591a463"},
+ {file = "numpy-1.24.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d926b52ba1367f9acb76b0df6ed21f0b16a1ad87c6720a1121674e5cf63e2b6"},
+ {file = "numpy-1.24.3-cp310-cp310-win32.whl", hash = "sha256:f21c442fdd2805e91799fbe044a7b999b8571bb0ab0f7850d0cb9641a687092b"},
+ {file = "numpy-1.24.3-cp310-cp310-win_amd64.whl", hash = "sha256:ab5f23af8c16022663a652d3b25dcdc272ac3f83c3af4c02eb8b824e6b3ab9d7"},
+ {file = "numpy-1.24.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9a7721ec204d3a237225db3e194c25268faf92e19338a35f3a224469cb6039a3"},
+ {file = "numpy-1.24.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d6cc757de514c00b24ae8cf5c876af2a7c3df189028d68c0cb4eaa9cd5afc2bf"},
+ {file = "numpy-1.24.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76e3f4e85fc5d4fd311f6e9b794d0c00e7002ec122be271f2019d63376f1d385"},
+ {file = "numpy-1.24.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1d3c026f57ceaad42f8231305d4653d5f05dc6332a730ae5c0bea3513de0950"},
+ {file = "numpy-1.24.3-cp311-cp311-win32.whl", hash = "sha256:c91c4afd8abc3908e00a44b2672718905b8611503f7ff87390cc0ac3423fb096"},
+ {file = "numpy-1.24.3-cp311-cp311-win_amd64.whl", hash = "sha256:5342cf6aad47943286afa6f1609cad9b4266a05e7f2ec408e2cf7aea7ff69d80"},
+ {file = "numpy-1.24.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7776ea65423ca6a15255ba1872d82d207bd1e09f6d0894ee4a64678dd2204078"},
+ {file = "numpy-1.24.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ae8d0be48d1b6ed82588934aaaa179875e7dc4f3d84da18d7eae6eb3f06c242c"},
+ {file = "numpy-1.24.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecde0f8adef7dfdec993fd54b0f78183051b6580f606111a6d789cd14c61ea0c"},
+ {file = "numpy-1.24.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4749e053a29364d3452c034827102ee100986903263e89884922ef01a0a6fd2f"},
+ {file = "numpy-1.24.3-cp38-cp38-win32.whl", hash = "sha256:d933fabd8f6a319e8530d0de4fcc2e6a61917e0b0c271fded460032db42a0fe4"},
+ {file = "numpy-1.24.3-cp38-cp38-win_amd64.whl", hash = "sha256:56e48aec79ae238f6e4395886b5eaed058abb7231fb3361ddd7bfdf4eed54289"},
+ {file = "numpy-1.24.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4719d5aefb5189f50887773699eaf94e7d1e02bf36c1a9d353d9f46703758ca4"},
+ {file = "numpy-1.24.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0ec87a7084caa559c36e0a2309e4ecb1baa03b687201d0a847c8b0ed476a7187"},
+ {file = "numpy-1.24.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea8282b9bcfe2b5e7d491d0bf7f3e2da29700cec05b49e64d6246923329f2b02"},
+ {file = "numpy-1.24.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210461d87fb02a84ef243cac5e814aad2b7f4be953b32cb53327bb49fd77fbb4"},
+ {file = "numpy-1.24.3-cp39-cp39-win32.whl", hash = "sha256:784c6da1a07818491b0ffd63c6bbe5a33deaa0e25a20e1b3ea20cf0e43f8046c"},
+ {file = "numpy-1.24.3-cp39-cp39-win_amd64.whl", hash = "sha256:d5036197ecae68d7f491fcdb4df90082b0d4960ca6599ba2659957aafced7c17"},
+ {file = "numpy-1.24.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:352ee00c7f8387b44d19f4cada524586f07379c0d49270f87233983bc5087ca0"},
+ {file = "numpy-1.24.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7d6acc2e7524c9955e5c903160aa4ea083736fde7e91276b0e5d98e6332812"},
+ {file = "numpy-1.24.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:35400e6a8d102fd07c71ed7dcadd9eb62ee9a6e84ec159bd48c28235bbb0f8e4"},
+ {file = "numpy-1.24.3.tar.gz", hash = "sha256:ab344f1bf21f140adab8e47fdbc7c35a477dc01408791f8ba00d018dd0bc5155"},
+]
+
+[[package]]
+name = "opt-einsum"
+version = "3.3.0"
+description = "Optimizing numpys einsum function"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"},
+ {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"},
+]
+
+[package.dependencies]
+numpy = ">=1.7"
+
+[package.extras]
+docs = ["numpydoc", "sphinx (==1.2.3)", "sphinx-rtd-theme", "sphinxcontrib-napoleon"]
+tests = ["pytest", "pytest-cov", "pytest-pep8"]
+
+[[package]]
+name = "pkginfo"
+version = "1.9.6"
+description = "Query metadata from sdists / bdists / installed packages."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "pkginfo-1.9.6-py3-none-any.whl", hash = "sha256:4b7a555a6d5a22169fcc9cf7bfd78d296b0361adad412a346c1226849af5e546"},
+ {file = "pkginfo-1.9.6.tar.gz", hash = "sha256:8fd5896e8718a4372f0ea9cc9d96f6417c9b986e23a4d116dda26b62cc29d046"},
+]
+
+[package.extras]
+testing = ["pytest", "pytest-cov"]
+
+[[package]]
+name = "pluggy"
+version = "1.0.0"
+description = "plugin and hook calling mechanisms for python"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
+ {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
+]
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["pytest", "pytest-benchmark"]
+
+[[package]]
+name = "pyasn1"
+version = "0.5.0"
+description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+files = [
+ {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"},
+ {file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"},
+]
+
+[[package]]
+name = "pyasn1-modules"
+version = "0.3.0"
+description = "A collection of ASN.1-based protocols modules"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+files = [
+ {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"},
+ {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.4.6,<0.6.0"
+
+[[package]]
+name = "pycparser"
+version = "2.21"
+description = "C parser in Python"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
+ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
+]
+
+[[package]]
+name = "pygments"
+version = "2.15.1"
+description = "Pygments is a syntax highlighting package written in Python."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "Pygments-2.15.1-py3-none-any.whl", hash = "sha256:db2db3deb4b4179f399a09054b023b6a586b76499d36965813c71aa8ed7b5fd1"},
+ {file = "Pygments-2.15.1.tar.gz", hash = "sha256:8ace4d3c1dd481894b2005f560ead0f9f19ee64fe983366be1a21e171d12775c"},
+]
+
+[package.extras]
+plugins = ["importlib-metadata"]
+
+[[package]]
+name = "pywin32-ctypes"
+version = "0.2.0"
+description = ""
+optional = false
+python-versions = "*"
+files = [
+ {file = "pywin32-ctypes-0.2.0.tar.gz", hash = "sha256:24ffc3b341d457d48e8922352130cf2644024a4ff09762a2261fd34c36ee5942"},
+ {file = "pywin32_ctypes-0.2.0-py2.py3-none-any.whl", hash = "sha256:9dc2d991b3479cc2df15930958b674a48a227d5361d413827a4cfd0b5876fc98"},
+]
+
+[[package]]
+name = "readme-renderer"
+version = "37.3"
+description = "readme_renderer is a library for rendering \"readme\" descriptions for Warehouse"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "readme_renderer-37.3-py3-none-any.whl", hash = "sha256:f67a16caedfa71eef48a31b39708637a6f4664c4394801a7b0d6432d13907343"},
+ {file = "readme_renderer-37.3.tar.gz", hash = "sha256:cd653186dfc73055656f090f227f5cb22a046d7f71a841dfa305f55c9a513273"},
+]
+
+[package.dependencies]
+bleach = ">=2.1.0"
+docutils = ">=0.13.1"
+Pygments = ">=2.5.1"
+
+[package.extras]
+md = ["cmarkgfm (>=0.8.0)"]
+
+[[package]]
+name = "requests"
+version = "2.31.0"
+description = "Python HTTP for Humans."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
+ {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
+]
+
+[package.dependencies]
+certifi = ">=2017.4.17"
+charset-normalizer = ">=2,<4"
+idna = ">=2.5,<4"
+urllib3 = ">=1.21.1,<3"
+
+[package.extras]
+socks = ["PySocks (>=1.5.6,!=1.5.7)"]
+use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+
+[[package]]
+name = "requests-toolbelt"
+version = "1.0.0"
+description = "A utility belt for advanced users of python-requests"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"},
+ {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"},
+]
+
+[package.dependencies]
+requests = ">=2.0.1,<3.0.0"
+
+[[package]]
+name = "rfc3986"
+version = "2.0.0"
+description = "Validating URI References per RFC 3986"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "rfc3986-2.0.0-py2.py3-none-any.whl", hash = "sha256:50b1502b60e289cb37883f3dfd34532b8873c7de9f49bb546641ce9cbd256ebd"},
+ {file = "rfc3986-2.0.0.tar.gz", hash = "sha256:97aacf9dbd4bfd829baad6e6309fa6573aaf1be3f6fa735c8ab05e46cecb261c"},
+]
+
+[package.extras]
+idna2008 = ["idna"]
+
+[[package]]
+name = "rich"
+version = "13.4.1"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+ {file = "rich-13.4.1-py3-none-any.whl", hash = "sha256:d204aadb50b936bf6b1a695385429d192bc1fdaf3e8b907e8e26f4c4e4b5bf75"},
+ {file = "rich-13.4.1.tar.gz", hash = "sha256:76f6b65ea7e5c5d924ba80e322231d7cb5b5981aa60bfc1e694f1bc097fe6fe1"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0,<3.0.0"
+pygments = ">=2.13.0,<3.0.0"
+typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
+[[package]]
+name = "rsa"
+version = "4.9"
+description = "Pure-Python RSA implementation"
+optional = false
+python-versions = ">=3.6,<4"
+files = [
+ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
+ {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.1.3"
+
+[[package]]
+name = "scipy"
+version = "1.10.1"
+description = "Fundamental algorithms for scientific computing in Python"
+optional = false
+python-versions = "<3.12,>=3.8"
+files = [
+ {file = "scipy-1.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e7354fd7527a4b0377ce55f286805b34e8c54b91be865bac273f527e1b839019"},
+ {file = "scipy-1.10.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4b3f429188c66603a1a5c549fb414e4d3bdc2a24792e061ffbd607d3d75fd84e"},
+ {file = "scipy-1.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1553b5dcddd64ba9a0d95355e63fe6c3fc303a8fd77c7bc91e77d61363f7433f"},
+ {file = "scipy-1.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c0ff64b06b10e35215abce517252b375e580a6125fd5fdf6421b98efbefb2d2"},
+ {file = "scipy-1.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:fae8a7b898c42dffe3f7361c40d5952b6bf32d10c4569098d276b4c547905ee1"},
+ {file = "scipy-1.10.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f1564ea217e82c1bbe75ddf7285ba0709ecd503f048cb1236ae9995f64217bd"},
+ {file = "scipy-1.10.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d925fa1c81b772882aa55bcc10bf88324dadb66ff85d548c71515f6689c6dac5"},
+ {file = "scipy-1.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaea0a6be54462ec027de54fca511540980d1e9eea68b2d5c1dbfe084797be35"},
+ {file = "scipy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15a35c4242ec5f292c3dd364a7c71a61be87a3d4ddcc693372813c0b73c9af1d"},
+ {file = "scipy-1.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:43b8e0bcb877faf0abfb613d51026cd5cc78918e9530e375727bf0625c82788f"},
+ {file = "scipy-1.10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5678f88c68ea866ed9ebe3a989091088553ba12c6090244fdae3e467b1139c35"},
+ {file = "scipy-1.10.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:39becb03541f9e58243f4197584286e339029e8908c46f7221abeea4b749fa88"},
+ {file = "scipy-1.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bce5869c8d68cf383ce240e44c1d9ae7c06078a9396df68ce88a1230f93a30c1"},
+ {file = "scipy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07c3457ce0b3ad5124f98a86533106b643dd811dd61b548e78cf4c8786652f6f"},
+ {file = "scipy-1.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:049a8bbf0ad95277ffba9b3b7d23e5369cc39e66406d60422c8cfef40ccc8415"},
+ {file = "scipy-1.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cd9f1027ff30d90618914a64ca9b1a77a431159df0e2a195d8a9e8a04c78abf9"},
+ {file = "scipy-1.10.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:79c8e5a6c6ffaf3a2262ef1be1e108a035cf4f05c14df56057b64acc5bebffb6"},
+ {file = "scipy-1.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51af417a000d2dbe1ec6c372dfe688e041a7084da4fdd350aeb139bd3fb55353"},
+ {file = "scipy-1.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b4735d6c28aad3cdcf52117e0e91d6b39acd4272f3f5cd9907c24ee931ad601"},
+ {file = "scipy-1.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:7ff7f37b1bf4417baca958d254e8e2875d0cc23aaadbe65b3d5b3077b0eb23ea"},
+ {file = "scipy-1.10.1.tar.gz", hash = "sha256:2cf9dfb80a7b4589ba4c40ce7588986d6d5cebc5457cad2c2880f6bc2d42f3a5"},
+]
+
+[package.dependencies]
+numpy = ">=1.19.5,<1.27.0"
+
+[package.extras]
+dev = ["click", "doit (>=0.36.0)", "flake8", "mypy", "pycodestyle", "pydevtool", "rich-click", "typing_extensions"]
+doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
+test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+
+[[package]]
+name = "secretstorage"
+version = "3.3.3"
+description = "Python bindings to FreeDesktop.org Secret Service API"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "SecretStorage-3.3.3-py3-none-any.whl", hash = "sha256:f356e6628222568e3af06f2eba8df495efa13b3b63081dafd4f7d9a7b7bc9f99"},
+ {file = "SecretStorage-3.3.3.tar.gz", hash = "sha256:2403533ef369eca6d2ba81718576c5e0f564d5cca1b58f73a8b23e7d4eeebd77"},
+]
+
+[package.dependencies]
+cryptography = ">=2.0"
+jeepney = ">=0.6"
+
+[[package]]
+name = "six"
+version = "1.16.0"
+description = "Python 2 and 3 compatibility utilities"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
+files = [
+ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
+ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
+]
+
+[[package]]
+name = "sympy"
+version = "1.12"
+description = "Computer algebra system (CAS) in Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
+ {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
+]
+
+[package.dependencies]
+mpmath = ">=0.19"
+
+[[package]]
+name = "torch"
+version = "2.0.1"
+description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
+optional = false
+python-versions = ">=3.8.0"
+files = [
+ {file = "torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8ced00b3ba471856b993822508f77c98f48a458623596a4c43136158781e306a"},
+ {file = "torch-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:359bfaad94d1cda02ab775dc1cc386d585712329bb47b8741607ef6ef4950747"},
+ {file = "torch-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:7c84e44d9002182edd859f3400deaa7410f5ec948a519cc7ef512c2f9b34d2c4"},
+ {file = "torch-2.0.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:567f84d657edc5582d716900543e6e62353dbe275e61cdc36eda4929e46df9e7"},
+ {file = "torch-2.0.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:787b5a78aa7917465e9b96399b883920c88a08f4eb63b5a5d2d1a16e27d2f89b"},
+ {file = "torch-2.0.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e617b1d0abaf6ced02dbb9486803abfef0d581609b09641b34fa315c9c40766d"},
+ {file = "torch-2.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b6019b1de4978e96daa21d6a3ebb41e88a0b474898fe251fd96189587408873e"},
+ {file = "torch-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:dbd68cbd1cd9da32fe5d294dd3411509b3d841baecb780b38b3b7b06c7754434"},
+ {file = "torch-2.0.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:ef654427d91600129864644e35deea761fb1fe131710180b952a6f2e2207075e"},
+ {file = "torch-2.0.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:25aa43ca80dcdf32f13da04c503ec7afdf8e77e3a0183dd85cd3e53b2842e527"},
+ {file = "torch-2.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5ef3ea3d25441d3957348f7e99c7824d33798258a2bf5f0f0277cbcadad2e20d"},
+ {file = "torch-2.0.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:0882243755ff28895e8e6dc6bc26ebcf5aa0911ed81b2a12f241fc4b09075b13"},
+ {file = "torch-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:f66aa6b9580a22b04d0af54fcd042f52406a8479e2b6a550e3d9f95963e168c8"},
+ {file = "torch-2.0.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:1adb60d369f2650cac8e9a95b1d5758e25d526a34808f7448d0bd599e4ae9072"},
+ {file = "torch-2.0.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:1bcffc16b89e296826b33b98db5166f990e3b72654a2b90673e817b16c50e32b"},
+ {file = "torch-2.0.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:e10e1597f2175365285db1b24019eb6f04d53dcd626c735fc502f1e8b6be9875"},
+ {file = "torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:423e0ae257b756bb45a4b49072046772d1ad0c592265c5080070e0767da4e490"},
+ {file = "torch-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:8742bdc62946c93f75ff92da00e3803216c6cce9b132fbca69664ca38cfb3e18"},
+ {file = "torch-2.0.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:c62df99352bd6ee5a5a8d1832452110435d178b5164de450831a3a8cc14dc680"},
+ {file = "torch-2.0.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:671a2565e3f63b8fe8e42ae3e36ad249fe5e567435ea27b94edaa672a7d0c416"},
+]
+
+[package.dependencies]
+filelock = "*"
+jinja2 = "*"
+networkx = "*"
+sympy = "*"
+typing-extensions = "*"
+
+[package.extras]
+opt-einsum = ["opt-einsum (>=3.3)"]
+
+[[package]]
+name = "twine"
+version = "4.0.2"
+description = "Collection of utilities for publishing packages on PyPI"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "twine-4.0.2-py3-none-any.whl", hash = "sha256:929bc3c280033347a00f847236564d1c52a3e61b1ac2516c97c48f3ceab756d8"},
+ {file = "twine-4.0.2.tar.gz", hash = "sha256:9e102ef5fdd5a20661eb88fad46338806c3bd32cf1db729603fe3697b1bc83c8"},
+]
+
+[package.dependencies]
+importlib-metadata = ">=3.6"
+keyring = ">=15.1"
+pkginfo = ">=1.8.1"
+readme-renderer = ">=35.0"
+requests = ">=2.20"
+requests-toolbelt = ">=0.8.0,<0.9.0 || >0.9.0"
+rfc3986 = ">=1.4.0"
+rich = ">=12.0.0"
+urllib3 = ">=1.26.0"
+
+[[package]]
+name = "typing-extensions"
+version = "4.6.3"
+description = "Backported and Experimental Type Hints for Python 3.7+"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "typing_extensions-4.6.3-py3-none-any.whl", hash = "sha256:88a4153d8505aabbb4e13aacb7c486c2b4a33ca3b3f807914a9b4c844c471c26"},
+ {file = "typing_extensions-4.6.3.tar.gz", hash = "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"},
+]
+
+[[package]]
+name = "urllib3"
+version = "2.0.3"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "urllib3-2.0.3-py3-none-any.whl", hash = "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1"},
+ {file = "urllib3-2.0.3.tar.gz", hash = "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825"},
+]
+
+[package.extras]
+brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
+secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
+socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
+zstd = ["zstandard (>=0.18.0)"]
+
+[[package]]
+name = "webencodings"
+version = "0.5.1"
+description = "Character encoding aliases for legacy web content"
+optional = false
+python-versions = "*"
+files = [
+ {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"},
+ {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"},
+]
+
+[[package]]
+name = "zipp"
+version = "3.15.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "zipp-3.15.0-py3-none-any.whl", hash = "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"},
+ {file = "zipp-3.15.0.tar.gz", hash = "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+
+[metadata]
+lock-version = "2.0"
+python-versions = ">=3.8,<3.12"
+content-hash = "2b458d54f55d68afae1597f94a31b15616f882a8c973acd59a09a97aabf4da6b"
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..fe6cbbd
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,23 @@
+[tool.poetry]
+name = "jaxtorch"
+version = "0.1.1"
+description = ""
+authors = ["Your Name "]
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.12"
+numpy = ">=1.22.3"
+scipy = ">=1.8.0"
+jax = ">=0.3.8,<=0.5"
+einops = ">=0.4.1"
+jaxlib = ">=0.3.7,<=0.5"
+torch = ">=1.11.0"
+jmp = ">=0.0.2"
+
+[tool.poetry.dev-dependencies]
+twine = "^4.0.0"
+"keyrings.google-artifactregistry-auth" = "^1.0.0"
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
diff --git a/scripts/demo.py b/scripts/demo.py
index 1398d2c..5b9dd72 100644
--- a/scripts/demo.py
+++ b/scripts/demo.py
@@ -61,7 +61,7 @@ def forward(self, cx, x):
def loss(params, key):
# Context wraps params and a PRNG key.
- cx = jaxtorch.Context(params, key)
+ cx = jaxtorch.Context(px=params, key=key)
x = jnp.array([1.0,2.0,3.0])
y = jnp.array([4.0,5.0,6.0])
return jnp.mean((model(cx, x) - y)**2)
diff --git a/scripts/main.py b/scripts/main.py
index c35b445..90fb095 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -52,7 +52,7 @@ def forward(self, cx, x):
print(model.state_dict(px))
def loss(px, x, y, key):
- cx = Context(px, key)
+ cx = Context(px=px, key=key)
return square(model(cx, x) - y).mean()
loss_grad = jax.jit(jax.value_and_grad(loss))
diff --git a/scripts/test.py b/scripts/test.py
index 5759fc8..67ef0ef 100644
--- a/scripts/test.py
+++ b/scripts/test.py
@@ -9,13 +9,13 @@
import gpt
def test_layernorm():
- cx = Context(ParamState(), jax.random.PRNGKey(0))
+ cx = Context(px=ParamState(), key=jax.random.PRNGKey(0))
ln = nn.LayerNorm(cx, 5)
x = jax.random.normal(shape=[2, 5], key=jax.random.PRNGKey(1))
print(ln(cx, x))
def test_gpt():
- cx = Context(ParamState(), jax.random.PRNGKey(0))
+ cx = Context(px=ParamState(), key=jax.random.PRNGKey(0))
mconf = gpt.GPT1Config(10, 10)
model = gpt.GPTLM(cx, mconf)
# with open('mod.cbor', 'wb') as fp:
diff --git a/scripts/train_gpt.py b/scripts/train_gpt.py
index e6b14bb..6514059 100644
--- a/scripts/train_gpt.py
+++ b/scripts/train_gpt.py
@@ -33,7 +33,7 @@ def main():
data = fp.read()
def loss(px, seq, key):
- cx = Context(px, key)
+ cx = Context(px=px, key=key)
return model.loss(cx, seq)
f_grad = jax.jit(jax.value_and_grad(loss))
@@ -47,7 +47,7 @@ def loss(px, seq, key):
if counter % 100 == 0:
idx = jnp.array([[-1] * 64])
- idx = model.generate(Context(px, rng.split()), idx)
+ idx = model.generate(Context(px=px, key=rng.split()), idx)
print(bytes(idx.squeeze().tolist()).decode('utf-8', errors='replace'))
counter += 1
diff --git a/test/test_nn.py b/test/test_nn.py
index 3c5c88a..a4a61b4 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -105,7 +105,7 @@ def test_groupnorm():
x = jax.random.normal(key=rng.split(), shape=[2, 32, 2])
x_torch = totorch(x)
- cx = Context(px, rng.split())
+ cx = Context(px=px, key=rng.split())
new_result = new(cx, x)
old_result = old(x_torch)
check(old_result, new_result)
@@ -118,7 +118,7 @@ def test_dropout():
px = module.init_weights(rng.split())
x = jax.random.normal(key=rng.split(), shape=[1, 32])
- cx = Context(px, rng.split())
+ cx = Context(px=px, key=rng.split())
out_train = module(cx.train_mode_(), x)
assert (out_train != 0).sum() < 20