Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions learned_optimization/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def periodically_save_checkpoint(
elif step_interval is not None:
do_save = (current_iteration % step_interval == 0)

if do_save:
if do_save: # pyrefly: ignore[unbound-name]
# if a checkpoint exists already, delete it.

# get the last step
Expand All @@ -184,7 +184,7 @@ def do_save_fn():

path = save_checkpoint(train_log_dir, prefix, value, step, keep=keep)
paths[prefix] = path
_last_checkpoint_time[prefix] = time.time()
_last_checkpoint_time[prefix] = time.time() # pyrefly: ignore[unsupported-operation]

paths = hk.data_structures.to_immutable_dict(paths)
return paths
Expand Down
8 changes: 4 additions & 4 deletions learned_optimization/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def start_server(self):
self._server.Start()

def _is_step_valid(self, step: int) -> bool:
step = onp.asarray(step)
step = onp.asarray(step) # pyrefly: ignore[bad-assignment]
return (self._current_iteration >= step and # pytype: disable=bad-return-type # typed-numpy
(self._current_iteration - step) <= self._staleness)

Expand Down Expand Up @@ -224,7 +224,7 @@ def set_weights(self,
"""
with self._lock:
self._weights = weights
self._current_iteration = onp.asarray(current_iteration)
self._current_iteration = onp.asarray(current_iteration) # pyrefly: ignore[bad-assignment]

before = len(self._outer_gradients)

Expand Down Expand Up @@ -350,7 +350,7 @@ def put_grads(self, worker_id: Any, step: int, value: T):
self._lock.acquire(blocking=True)
assert worker_id < self._num_workers
if step == self._current_iteration:
self._outer_gradients[worker_id] = (step, value)
self._outer_gradients[worker_id] = (step, value) # pyrefly: ignore[unsupported-operation]
self._lock.release()
self._cv.notify_all()

Expand Down Expand Up @@ -425,7 +425,7 @@ def set_weights(self,
del clear_buffer
with self._lock, self._cv:
self._weights = weights
self._current_iteration = onp.asarray(current_iteration)
self._current_iteration = onp.asarray(current_iteration) # pyrefly: ignore[bad-assignment]
self._outer_gradients = {k: None for k in self._outer_gradients.keys()}
self._cv.notify_all()

Expand Down
4 changes: 2 additions & 2 deletions learned_optimization/eval_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def single_task_training_curves(
opt,
opt_state,
key1,
task.datasets.split(s) if use_data else (),
task.datasets.split(s) if use_data else (), # pyrefly: ignore[bad-argument-type, missing-attribute]
eval_batches if not on_last else last_eval_batches,
device=device)
m[f"eval/{s}/loss"] = loss
Expand All @@ -197,7 +197,7 @@ def single_task_training_curves(
eval_xs.append(i)

with profile.Profile("get_batch"):
batch = next(task.datasets.train) if use_data else ()
batch = next(task.datasets.train) if use_data else () # pyrefly: ignore[missing-attribute]
with profile.Profile("put_batch_and_split"):
batch = jax.device_put(batch, device=device)

Expand Down
6 changes: 3 additions & 3 deletions learned_optimization/optimizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ def __init__(self, magnitude_opt: Optimizer, direction_opt: Optimizer):
self.magnitude_opt = magnitude_opt
self.direction_opt = direction_opt

def init(self, params, model_state=None, num_steps=None, **kwargs):
def init(self, params, model_state=None, num_steps=None, **kwargs): # pyrefly: ignore[bad-override]
return GraftedOptimizerState(
iteration=jnp.asarray(0, dtype=jnp.int32),
params=params,
state=model_state,
state=model_state, # pyrefly: ignore[bad-argument-type]
mag_opt_state=self.magnitude_opt.init(
params, model_state=model_state, num_steps=num_steps, **kwargs),
dir_opt_state=self.direction_opt.init(
Expand Down Expand Up @@ -165,7 +165,7 @@ def update(self, opt_state, grad, model_state=None, **kwargs): # pytype: disabl
return GraftedOptimizerState(
iteration=opt_state.iteration + 1,
params=next_params,
state=model_state,
state=model_state, # pyrefly: ignore[bad-argument-type]
mag_opt_state=next_mag_opt_state,
dir_opt_state=next_dir_opt_state,
)
Expand Down
6 changes: 3 additions & 3 deletions learned_optimization/optimizers/gradient_accumulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def set_params(self, state, params):
def get_state(self, state):
return state.model_state

def init(self, p, model_state=None, num_steps=None, **kwargs):
def init(self, p, model_state=None, num_steps=None, **kwargs): # pyrefly: ignore[bad-override]
if num_steps is not None:
rescale_num_steps = num_steps // self.num_average
else:
Expand All @@ -76,7 +76,7 @@ def init(self, p, model_state=None, num_steps=None, **kwargs):
grad_accum,
loss_accum,
inner_opt_state,
model_state=model_state,
model_state=model_state, # pyrefly: ignore[bad-argument-type]
iteration=jnp.asarray(0, dtype=jnp.int64))

def update(self,
Expand Down Expand Up @@ -116,5 +116,5 @@ def do_update(args):
new_grad_accum,
new_loss_accum,
new_inner_opt_state,
model_state=model_state,
model_state=model_state, # pyrefly: ignore[bad-argument-type]
iteration=opt_state.iteration + 1)
2 changes: 1 addition & 1 deletion learned_optimization/optimizers/learning_rate_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,4 @@ def __init__(self, base_lr: float, decay_amount: float):
def __call__(self,
global_step: chex.Array,
max_steps: Optional[chex.Array] = None) -> chex.Array:
return self.base_lr * (1 - self.decay_amount)**global_step
return self.base_lr * (1 - self.decay_amount)**global_step # pyrefly: ignore[bad-return]
4 changes: 2 additions & 2 deletions learned_optimization/optimizers/nadamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def __init__(
"use_bias_correction": use_bias_correction
}

def init(self, params, model_state=None, num_steps=None):
def init(self, params, model_state=None, num_steps=None): # pyrefly: ignore[bad-override]
return NAdamWState(
iteration=jnp.asarray(0, dtype=jnp.int64),
params=params,
Expand All @@ -177,7 +177,7 @@ def init(self, params, model_state=None, num_steps=None):
num_steps=jnp.asarray(num_steps, dtype=jnp.int64),
state=model_state)

def update(self,
def update(self, # pyrefly: ignore[bad-override]
opt_state: NAdamWState,
grads: Params,
model_state: Optional[ModelState] = None,
Expand Down
4 changes: 2 additions & 2 deletions learned_optimization/optimizers/opt_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ def __init__(self, idx):
# We write init and update by constructing new instances of NAdamW to allow
# for vmap-ing over different idx and to prevent jax tracer leaks.

def init(self,
def init(self, # pyrefly: ignore[bad-override]
params: Params,
model_state: Optional[ModelState] = None,
*,
num_steps: int) -> nadamw.NAdamWState:
return nadamw.NAdamW(**_get_optimizer_config(self.idx)).init(
params, model_state, num_steps=num_steps)

def update(self,
def update(self, # pyrefly: ignore[bad-override]
opt_state: nadamw.NAdamWState,
grads: Params,
model_state: Optional[ModelState] = None,
Expand Down
2 changes: 1 addition & 1 deletion learned_optimization/optimizers/opt_to_optax.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,4 @@ def update_fn(

return step, next_state

return optax.GradientTransformationExtraArgs(init_fn, update_fn)
return optax.GradientTransformationExtraArgs(init_fn, update_fn) # pyrefly: ignore[bad-argument-type]
16 changes: 8 additions & 8 deletions learned_optimization/optimizers/optax_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def init(self,
key: Optional[chex.PRNGKey] = None):
return OptaxState( # pytype: disable=wrong-arg-types # jax-ndarray
params=params,
optax_opt_state=self.opt.init(params),
state=model_state,
iteration=0,
optax_opt_state=self.opt.init(params), # pyrefly: ignore[bad-argument-type]
state=model_state, # pyrefly: ignore[bad-argument-type]
iteration=0, # pyrefly: ignore[bad-argument-type]
)

@functools.partial(jax.jit, static_argnums=(0,))
Expand All @@ -74,9 +74,9 @@ def update(self,
update, new_opt_state = self.opt.update(grad, opt_state.optax_opt_state,
opt_state.params)
return OptaxState(
state=model_state,
params=optax.apply_updates(opt_state.params, update),
optax_opt_state=new_opt_state,
state=model_state, # pyrefly: ignore[bad-argument-type]
params=optax.apply_updates(opt_state.params, update), # pyrefly: ignore[bad-argument-type]
optax_opt_state=new_opt_state, # pyrefly: ignore[bad-argument-type]
iteration=opt_state.iteration + 1,
)

Expand Down Expand Up @@ -143,8 +143,8 @@ def name(self):
def piecewise_linear(times: Sequence[float],
vals: Sequence[float]) -> Callable[[float], float]:
"""Returns a function which interpolates piecewise values."""
times = jnp.asarray(times)
vals = jnp.asarray(vals)
times = jnp.asarray(times) # pyrefly: ignore[bad-assignment]
vals = jnp.asarray(vals) # pyrefly: ignore[bad-assignment]

def fn(x):
if len(times) <= 1:
Expand Down
8 changes: 4 additions & 4 deletions learned_optimization/optimizers/optimizer_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def __init__(self, opt, warp_fn):
self._opt = opt
self._warp_fn = warp_fn

def init(self, params, model_state=None, *, num_steps):
def init(self, params, model_state=None, *, num_steps): # pyrefly: ignore[bad-override]
num_steps = jnp.asarray(self._warp_fn(num_steps), jnp.int32)

inner_opt_state = self._opt.init(params, model_state, num_steps=num_steps)
return ExtendTimeState(jnp.asarray(0, jnp.int32), inner_opt_state)

def update(self, opt_state, grad, loss=None, **kwargs):
def update(self, opt_state, grad, loss=None, **kwargs): # pyrefly: ignore[bad-override]
inner_state = opt_state.inner_opt_state
inner_state = inner_state.replace(
iteration=self._warp_fn(opt_state.iteration))
Expand Down Expand Up @@ -77,15 +77,15 @@ def set_params(self, state, params):
def get_state(self, opt_state):
return self.opt.get_state(opt_state)

def init(self, params, model_state=None, **kwargs):
def init(self, params, model_state=None, **kwargs): # pyrefly: ignore[bad-override]
return self.opt.init(params, model_state=model_state, **kwargs)

def update(self, opt_state, grads, model_state=None, loss=None, **kwargs):
ps = self.opt.get_params(opt_state)

if self.add_to_loss:
l2 = [jnp.sum(p**2) for p in jax.tree_util.tree_leaves(ps)]
loss = loss + sum([x * self.weight_decay for x in l2])
loss = loss + sum([x * self.weight_decay for x in l2]) # pyrefly: ignore[unsupported-operation]

grad_l2 = jax.tree_util.tree_map(lambda p: self.weight_decay * p, ps)
grads = jax.tree_util.tree_map(lambda g, g_l2: g + g_l2, grads, grad_l2)
Expand Down
46 changes: 23 additions & 23 deletions learned_optimization/outer_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ def iter_group_amount(it, n):
@gin.configurable
def build_gradient_estimators(
*,
learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED,
learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
sample_task_family_fn: Callable[[PRNGKey],
tasks_base.TaskFamily] = gin.REQUIRED,
tasks_base.TaskFamily] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
gradient_estimator_fn: Callable[
[truncated_step_mod.VectorizedTruncatedStep],
gradient_learner.GradientLearner] = gin.REQUIRED,
gradient_learner.GradientLearner] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
truncated_step_fn: Callable[
[tasks_base.TaskFamily, lopt_base.LearnedOptimizer],
truncated_step_mod.VectorizedTruncatedStep] = lopt_truncated_step
truncated_step_mod.VectorizedTruncatedStep] = lopt_truncated_step # pyrefly: ignore[bad-function-definition]
.VectorizedLOptTruncatedStep,
key: PRNGKey,
num_gradient_estimators: int,
Expand Down Expand Up @@ -129,14 +129,14 @@ def build_gradient_estimators(
@gin.configurable
def build_gradient_estimators_fixed(
*,
learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED,
list_of_task_family_per_machine: ListListTaskFamilyFn = gin.REQUIRED,
learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
list_of_task_family_per_machine: ListListTaskFamilyFn = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
gradient_estimator_fn: Callable[
[truncated_step_mod.VectorizedTruncatedStep],
gradient_learner.GradientLearner] = gin.REQUIRED,
gradient_learner.GradientLearner] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
truncated_step_fn: Callable[
[tasks_base.TaskFamily, lopt_base.LearnedOptimizer],
truncated_step_mod.VectorizedTruncatedStep] = lopt_truncated_step
truncated_step_mod.VectorizedTruncatedStep] = lopt_truncated_step # pyrefly: ignore[bad-function-definition]
.VectorizedLOptTruncatedStep,
key: PRNGKey,
num_gradient_estimators: int,
Expand Down Expand Up @@ -285,7 +285,7 @@ def maybe_resample_gradient_estimators(
worker_id=worker_id)
gradient_estimators[j] = ests[0]
unroll_states[j] = gradient_estimators[j].init_worker_state(
worker_weights, key2)
worker_weights, key2) # pyrefly: ignore[bad-argument-type]

return gradient_estimators, unroll_states

Expand Down Expand Up @@ -350,7 +350,7 @@ def build_static_and_init_unroll_state(
return estimators, unroll_states

distributed_worker = distributed.DistributedWorker(
train_log_dir, worker_id, learner_address=learner_address)
train_log_dir, worker_id, learner_address=learner_address) # pyrefly: ignore[bad-argument-type]
last_outer_cfg = None
grad_estimators = None
worker_weights = None
Expand Down Expand Up @@ -390,12 +390,12 @@ def build_static_and_init_unroll_state(
gradient_worker_out = gradient_learner.gradient_worker_compute(
worker_weights=worker_weights,
gradient_estimators=[grad_estimators[gidx] for gidx in idxs],
unroll_states=[unroll_states[gidx] for gidx in idxs],
unroll_states=[unroll_states[gidx] for gidx in idxs], # pyrefly: ignore[unbound-name]
key=next(rng),
with_metrics=with_m,
device=device)

unroll_states = list(unroll_states)
unroll_states = list(unroll_states) # pyrefly: ignore[unbound-name]
for oidx, gidx in enumerate(idxs):
unroll_states[gidx] = gradient_worker_out.unroll_states[oidx]

Expand All @@ -405,8 +405,8 @@ def build_static_and_init_unroll_state(
with profile.Profile("grads_to_onp"):
to_put_grads = GradientsFromWorker( # pytype: disable=wrong-arg-types # jax-ndarray
metrics=gradient_worker_out.metrics,
worker_id=worker_id,
total_inner_steps=total_inner_steps,
worker_id=worker_id, # pyrefly: ignore[bad-argument-type]
total_inner_steps=total_inner_steps, # pyrefly: ignore[bad-argument-type]
gen_id=dist_data.gen_id,
outer_trainer_grads=gradient_worker_out.to_put,
)
Expand Down Expand Up @@ -683,7 +683,7 @@ def _load_checkpoint(checkpoint_path):
experiment_name=train_log_dir,
weights=DataForWorker(worker_weights, gen_id, outer_cfg),
current_iteration=step,
num_workers=num_workers,
num_workers=num_workers, # pyrefly: ignore[bad-argument-type]
start_server=False,
port=learner_port)

Expand Down Expand Up @@ -737,7 +737,7 @@ def _load_checkpoint(checkpoint_path):
with profile.Profile("checkpoints"):
opt_checkpoint = gradient_learner.OptCheckpoint(
gradient_learner_state, jnp.asarray(elapsed_time, dtype=jnp.float64),
total_inner_steps)
total_inner_steps) # pyrefly: ignore[bad-argument-type]
param_checkpoint = gradient_learner.ParameterCheckpoint(
outer_learner.get_meta_params(gradient_learner_state), gen_id, step)
paths = checkpoints.periodically_save_checkpoint(
Expand Down Expand Up @@ -818,7 +818,7 @@ def _load_checkpoint(checkpoint_path):
delta_inner_steps=applied_inner_steps,
)

to_write = dict(**to_write, **summarize_outer_cfg(outer_cfg))
to_write = dict(**to_write, **summarize_outer_cfg(outer_cfg)) # pyrefly: ignore[bad-argument-type]

if i % 5 == 0:
elapsed_time = elapsed_time + time.time() - train_start_time
Expand Down Expand Up @@ -998,7 +998,7 @@ def build_static_and_init_unroll_state(worker_weights, key):
to_put_grads = GradientsFromWorker( # pytype: disable=wrong-arg-types # jax-ndarray
metrics=gradient_worker_out.metrics,
worker_id=0,
total_inner_steps=total_inner_steps,
total_inner_steps=total_inner_steps, # pyrefly: ignore[bad-argument-type]
gen_id="no_gen_id",
outer_trainer_grads=gradient_worker_out.to_put,
)
Expand All @@ -1021,7 +1021,7 @@ def build_static_and_init_unroll_state(worker_weights, key):
lopt,
gradient_estimators,
unroll_states,
worker_weights=worker_weights,
worker_weights=worker_weights, # pyrefly: ignore[bad-argument-type]
key=next(rng),
stochastic_resample_frequency=stochastic_resample_frequency,
sample_estimators_fn=sample_estimators_fn,
Expand Down Expand Up @@ -1115,9 +1115,9 @@ def _move_all_gin_config_to_default_scope():
@gin.configurable
def run_train(
train_log_dir: str,
lopt: Union[GinRequired, lopt_base.LearnedOptimizer] = gin.REQUIRED,
lopt: Union[GinRequired, lopt_base.LearnedOptimizer] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
outer_learner_fn: Union[GinRequired, Callable[
[], gradient_learner.GradientLearner]] = gin.REQUIRED,
[], gradient_learner.GradientLearner]] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
num_estimators: int = 2,
is_trainer: bool = True,
is_worker: bool = True,
Expand Down Expand Up @@ -1182,7 +1182,7 @@ def run_train(
if is_trainer and is_worker:
local_train(
train_log_dir=train_log_dir,
outer_learner=outer_learner_fn(),
outer_learner=outer_learner_fn(), # pyrefly: ignore[not-callable]
lopt=lopt,
num_estimators=num_estimators,
summary_every_n=summary_every_n,
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def run_train(
try:
train_worker(
worker_id=worker_id,
lopt=lopt,
lopt=lopt, # pyrefly: ignore[bad-argument-type]
num_estimators=num_estimators,
summary_every_n=summary_every_n,
stochastic_resample_frequency=stochastic_resample_frequency,
Expand Down
Loading