Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion learned_optimization/baselines/normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
def ema(data: chex.Array, alpha: float, ignore_nan=False):
"""Exponential moving average."""
# TODO(lmetz) dedup with notebook_utils!
if len(data) == 0: # pylint: disable=g-explicit-length-test
if len(data) == 0: # pylint: disable=g-explicit-length-test # pyrefly: ignore[bad-argument-type]
return data
data = onp.asarray(data)
x = onp.zeros_like(data)
Expand Down
4 changes: 2 additions & 2 deletions learned_optimization/baselines/run_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def maybe_archive_hparam_set(task_name: str, hparam_set_name: str) -> bool:


@gin.configurable
def wait_until_ready_then_archive_task(task_name: str = gin.REQUIRED,
hparam_set_name: str = gin.REQUIRED):
def wait_until_ready_then_archive_task(task_name: str = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
hparam_set_name: str = gin.REQUIRED): # pyrefly: ignore[bad-function-definition]
"""Continually try to create and save an archive of hparam set + task_name.

This function is designed to be run while the baselines are being computed
Expand Down
4 changes: 2 additions & 2 deletions learned_optimization/baselines/run_time_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@


@gin.configurable
def run_many_eval_and_save(task: tasks_base.Task = gin.REQUIRED,
save_dir: str = gin.REQUIRED):
def run_many_eval_and_save(task: tasks_base.Task = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
save_dir: str = gin.REQUIRED): # pyrefly: ignore[bad-function-definition]
"""Compute and save `num_to_run` runtime statistics."""

dev = jax.devices()[0]
Expand Down
8 changes: 4 additions & 4 deletions learned_optimization/baselines/run_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ def _get_gin_name(gin_arg_name: str, fallback: str) -> str:
got_config = False

if got_config:
return configurable.selector
return configurable.selector # pyrefly: ignore[unbound-name]
else:
return fallback


@profile.wrap()
@gin.configurable
def inner_train_task(
task: tasks_base.Task = gin.REQUIRED,
opt: opt_base.Optimizer = gin.REQUIRED,
num_steps: int = gin.REQUIRED,
task: tasks_base.Task = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
opt: opt_base.Optimizer = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
num_steps: int = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
eval_every: int = 10,
eval_batches: int = 5,
last_eval_batches: int = 10,
Expand Down
12 changes: 6 additions & 6 deletions learned_optimization/continuous_eval/run_eval_chief.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def write_results_thread_main(
]

for fn in values_to_metrics_fns:
metric = fn(task_group, values, tasks)
metric = fn(task_group, values, tasks) # pyrefly: ignore[not-callable]
for k, v in metric.items():
if k in metrics:
raise ValueError(f"Duplicate metric key found! [[{k}]]")
Expand Down Expand Up @@ -472,14 +472,14 @@ def write_results_thread_main(
maybe_finished,
metrics[log_to_population_tag],
population_server_name=population_server_name,
population_worker_id=population_worker_id)
population_worker_id=population_worker_id) # pyrefly: ignore[bad-argument-type]



@gin.configurable
def eval_chief_config(chief_name: str = gin.REQUIRED,
num_workers: int = gin.REQUIRED,
learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED):
def eval_chief_config(chief_name: str = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
num_workers: int = gin.REQUIRED, # pyrefly: ignore[bad-function-definition]
learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED): # pyrefly: ignore[bad-function-definition]
"""Parameters of the evaluation. To be set with gin."""
if chief_name == gin.REQUIRED or num_workers == gin.REQUIRED:
raise ValueError("Must set chief_name and num_workers with gin!")
Expand Down Expand Up @@ -579,7 +579,7 @@ def main(_):
logging.info("Waiting on %s", train_log_dir)

i = 0
while not filesystem.exists(train_log_dir):
while not filesystem.exists(train_log_dir): # pyrefly: ignore[bad-argument-type]
time.sleep(1)
i += 1
if i % 20 == 0:
Expand Down
2 changes: 1 addition & 1 deletion learned_optimization/continuous_eval/run_eval_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def connect_to_server_and_do_tasks(train_log_dir: str):
def main(_):
train_log_dir = setup_experiment.setup_experiment(gin_finalize=False)

connect_to_server_and_do_tasks(train_log_dir)
connect_to_server_and_do_tasks(train_log_dir) # pyrefly: ignore[bad-argument-type]


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion learned_optimization/learned_optimizers/adafac_mlp_lopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def init(
iteration=jnp.asarray(0, dtype=jnp.int32),
num_steps=jnp.asarray(num_steps))

def update(
def update( # pyrefly: ignore[bad-override]
self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
opt_state: AdafacMLPLOptState,
grad: opt_base.Gradient,
Expand Down
2 changes: 1 addition & 1 deletion learned_optimization/learned_optimizers/adafac_nominal.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def init(
iteration=jnp.asarray(0, dtype=jnp.int32),
num_steps=jnp.asarray(num_steps))

def update(
def update( # pyrefly: ignore[bad-override]
self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
opt_state: AdafacMLPLOptState,
grad: opt_base.Gradient,
Expand Down
4 changes: 2 additions & 2 deletions learned_optimization/learned_optimizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(self, opts: Sequence[opt_base.Optimizer]):
if len(opts) != 2:
raise ValueError("Only 2 opts are supported for now!")

def init(self, params, model_state=None, num_steps=None, **kwargs):
def init(self, params, model_state=None, num_steps=None, **kwargs): # pyrefly: ignore[bad-override]
opt_states = tuple([
opt.init(params, model_state, num_steps=num_steps, **kwargs)
for opt in self.opts
Expand Down Expand Up @@ -240,7 +240,7 @@ def update(self, opt_state, grad, model_state=None, **kwargs): # pytype: disabl
return SumOptimizerState(
iteration=opt_state.iteration + 1,
params=new_params,
state=model_state,
state=model_state, # pyrefly: ignore[bad-argument-type]
inner_opt_states=tuple(new_opt_states),
)

Expand Down
2 changes: 1 addition & 1 deletion learned_optimization/learned_optimizers/mlp_lopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def init(self,
rolling_features=common.vec_rolling_mom(decays).init(params),
iteration=jnp.asarray(0, dtype=jnp.int32))

def update(
def update( # pyrefly: ignore[bad-override]
self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
opt_state: MLPLOptState,
grad: Any,
Expand Down
2 changes: 1 addition & 1 deletion learned_optimization/learned_optimizers/nn_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def init(
return NNAdamState(
params=params,
rolling_features=rolling.init(params),
iteration=jnp.asarray(0, dtype=jnp.int32),
iteration=jnp.asarray(0, dtype=jnp.int32), # pyrefly: ignore[bad-argument-type]
state=model_state,
lstm_hidden_state=lstm_hidden_state,
per_layer_lr=jax.tree_util.tree_map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,5 @@ def maybe_add_scope(c):
wrapped = _GinScopeClass(opt, scope)
# For now, just add the lopt to the returned class.
# TODO(lmetz) change this api to return a more structured class?
wrapped.lopt = lopt
wrapped.lopt = lopt # pyrefly: ignore[missing-attribute]
return wrapped # type: ignore
6 changes: 3 additions & 3 deletions learned_optimization/learned_optimizers/rnn_mlp_lopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def corrected_mean(self, state: _LossNormalizerState) -> jnp.ndarray:


def _avg_square_mean(tree: Any) -> jnp.ndarray:
return sum([jnp.mean(jnp.square(x)) for x in jax.tree_util.tree_leaves(tree)
return sum([jnp.mean(jnp.square(x)) for x in jax.tree_util.tree_leaves(tree) # pyrefly: ignore[bad-return]
]) / len(jax.tree_util.tree_leaves(tree))


Expand Down Expand Up @@ -456,7 +456,7 @@ def mlp_features_for_tensor(self, m: jnp.ndarray, rms: jnp.ndarray,
],
axis=1)

def update(self,
def update(self, # pyrefly: ignore[bad-override]
opt_state: RNNMLPLOptState,
grads,
loss: Optional[jnp.ndarray] = None,
Expand Down Expand Up @@ -560,7 +560,7 @@ def to_map_get_mlp_features(m, rms, g, v, ff_inputs):
v,
ff_inputs, # pytype: disable=wrong-arg-types # jax-ndarray
opt_state.iteration,
num_tensors,
num_tensors, # pyrefly: ignore[bad-argument-type]
)

# Prep the features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def train_one(worker_id: int):
for _ in range(20):
batch = next(te_iterator)
key, key1 = jax.random.split(key)
l = common.loss(params, key1, batch, meta_params, False)
l = common.loss(params, key1, batch, meta_params, False) # pyrefly: ignore[unbound-name]
te_ls.append(l)

batch = next(tr_iterator)
Expand All @@ -92,7 +92,7 @@ def train_one(worker_id: int):
tr_mean_l = onp.mean(tr_ls)

# save to disk
model_state = (params, opt_state)
model_state = (params, opt_state) # pyrefly: ignore[unbound-name]
state_path = os.path.join(train_log_dir, f"{step}__{gen_id}.model")
common.save_state(state_path, model_state)

Expand Down
10 changes: 5 additions & 5 deletions learned_optimization/population/examples/simple_cnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,25 @@ def mutate_fn(meta_params, direction, _):
for _ in range(5):
batch = next(te_iterator)
key, key1 = jax.random.split(key)
l = common.loss(params, key1, batch)
l = common.loss(params, key1, batch) # pyrefly: ignore[unbound-name]
ls.append(l)
mean_l = onp.mean(ls)

print(f"step={step}, loss={l} path={state_path}")
summary_writer.scalar("loss", l, step=step)
summary_writer.scalar(
"learning_rate", meta_params["learning_rate"], step=step)
"learning_rate", meta_params["learning_rate"], step=step) # pyrefly: ignore[unsupported-operation]
summary_writer.scalar(
"log_learning_rate", onp.log(meta_params["learning_rate"]), step=step)
"log_learning_rate", onp.log(meta_params["learning_rate"]), step=step) # pyrefly: ignore[unsupported-operation]
summary_writer.flush()

# save to disk
# we don't send back raw parameter values, instead we send checkpoint
# locations.
model_state = (params, opt_state)
model_state = (params, opt_state) # pyrefly: ignore[unbound-name]
state_path = os.path.join(train_log_dir, f"{step}__{gen_id}.model")
common.save_state(state_path, model_state)
population.set_eval(worker_id, gen_id, step, state_path, mean_l)
population.set_eval(worker_id, gen_id, step, state_path, mean_l) # pyrefly: ignore[bad-argument-type]

# Actually update the params to train one step.
batch = next(tr_iterator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,21 @@ def train_one(worker_id):
for _ in range(5):
batch = next(te_iterator)
key, key1 = jax.random.split(key)
l = common.loss(params, key1, batch)
l = common.loss(params, key1, batch) # pyrefly: ignore[unbound-name]
ls.append(l)
mean_l = onp.mean(ls)

# save to disk
model_state = (params, opt_state)
model_state = (params, opt_state) # pyrefly: ignore[unbound-name]
state_path = os.path.join(train_log_dir, f"{step}__{gen_id}.model")
common.save_state(state_path, model_state)
population.set_eval(worker_id, gen_id, step, state_path, mean_l)
print(f"{worker_id} ]] step={step}, loss={l} path={state_path}")
summary_writer.scalar("loss", l, step=step)
summary_writer.scalar(
"learning_rate", meta_params["learning_rate"], step=step)
"learning_rate", meta_params["learning_rate"], step=step) # pyrefly: ignore[unsupported-operation]
summary_writer.scalar(
"log_learning_rate", onp.log(meta_params["learning_rate"]), step=step)
"log_learning_rate", onp.log(meta_params["learning_rate"]), step=step) # pyrefly: ignore[unsupported-operation]
summary_writer.flush()

params, opt_state, l = common.update(params, key, opt_state, batch,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def train(steps=10000):
for i in range(5):
if i % 5 == 0:
l = loss(params)
population.set_eval(0, gen_id, step, params, l)
population.set_eval(0, gen_id, step, params, l) # pyrefly: ignore[bad-argument-type]
print(f"\t {l}, params: {params}, meta_params:{meta_params}")

params = update(params, meta_params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def add_worker_to_cache(from_checkpoint: population.Checkpoint,

scores = [center_score, neg_score, pos_score]
idx = onp.nanargmin(scores)
best_checkpoint = [center_steps, neg_steps,
best_checkpoint = [center_steps, neg_steps, # pyrefly: ignore[bad-index]
pos_steps][idx].values()[-1]

meta_params = best_checkpoint.meta_params
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def update(
# We assume that the values here are all floating loss values.
# grab the highest performing checkpoint data
best_idx = onp.argmin(values)
genid = current_workers[best_idx].generation_id
genid = current_workers[best_idx].generation_id # pyrefly: ignore[bad-index]

# sort by time or step. These should always be the same
if self._steps_per_exploit:
Expand Down
14 changes: 7 additions & 7 deletions learned_optimization/population/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def maybe_get_worker_data(
old_state = self.serialized_state()

# also potentially mutate the cache
self._active_workers = self.mutate.get_worker_data(
self._active_workers, self._cached, worker_id, generation_id, step,
self._active_workers = self.mutate.get_worker_data( # pyrefly: ignore[bad-assignment]
self._active_workers, self._cached, worker_id, generation_id, step, # pyrefly: ignore[bad-argument-type]
params, meta_params)

new_state = self.serialized_state()
Expand Down Expand Up @@ -204,8 +204,8 @@ def maybe_get_worker_data(
generation_id=generation_id,
params=params,
meta_params=meta_params,
parent=(generation_id, step),
step=step,
parent=(generation_id, step), # pyrefly: ignore[bad-argument-type]
step=step, # pyrefly: ignore[bad-argument-type]
value=None,
time=time.time(),
)
Expand Down Expand Up @@ -254,14 +254,14 @@ def set_eval(self, worker_id: int, generation_id: GenerationID, step: int,
# "cast" to a mutable sequence here to make pytype happy.
mut_active_workers = list(
self._active_workers) # type: MutableSequence[ActiveWorker]
mut_active_workers[worker_id] = self._active_workers[worker_id].replace(
mut_active_workers[worker_id] = self._active_workers[worker_id].replace( # pyrefly: ignore[missing-attribute]
step=step)
mut_active_workers[worker_id] = mut_active_workers[worker_id].replace(
mut_active_workers[worker_id] = mut_active_workers[worker_id].replace( # pyrefly: ignore[missing-attribute]
params=params)
self._active_workers = mut_active_workers

# in light of this new value, run the mutator
self._mutate_state, self._active_workers = self.mutate.update(
self._mutate_state, self._active_workers = self.mutate.update( # pyrefly: ignore[bad-assignment]
self._mutate_state, self._active_workers, self._cached)

self.save_state()
Expand Down
8 changes: 4 additions & 4 deletions learned_optimization/research/brax/brax_env_truncated_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ def do_step(unroll_state):
next_env_state = self.env.step(unroll_state.env_state, action)

out = truncated_step.TruncatedUnrollOut( # pytype: disable=wrong-arg-types # jax-ndarray
loss=-next_env_state.reward,
is_done=False,
loss=-next_env_state.reward, # pyrefly: ignore[bad-argument-type]
is_done=False, # pyrefly: ignore[bad-argument-type]
task_param=None,
iteration=unroll_state.iteration,
mask=True,
mask=True, # pyrefly: ignore[bad-argument-type]
)
return BraxEnvState(
env_state=next_env_state,
Expand All @@ -166,7 +166,7 @@ def do_step(unroll_state):

def reset(unroll_state):
out = truncated_step.TruncatedUnrollOut( # pytype: disable=wrong-arg-types # jax-ndarray
loss=0.0, is_done=True, task_param=None, iteration=0, mask=False
loss=0.0, is_done=True, task_param=None, iteration=0, mask=False # pyrefly: ignore[bad-argument-type]
)
unroll_state = BraxEnvState(
env_state=self.env.reset(key),
Expand Down
4 changes: 2 additions & 2 deletions learned_optimization/research/data_driven/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def _generate_task(self, key):
def _generate_tasks(self, key):
key_tasks, key_mask = jax.random.split(key)
del key
key_tasks = jax.random.split(key_tasks, self._dataset_size)
key_tasks = jax.random.split(key_tasks, self._dataset_size) # pyrefly: ignore[bad-argument-type]
mask = jax.random.bernoulli(
key_mask, p=self._bias_prob, shape=(self._dataset_size,))
key_mask, p=self._bias_prob, shape=(self._dataset_size,)) # pyrefly: ignore[bad-argument-type]
key_tasks = jnp.where(mask[:, None], self._bias_key[None], key_tasks)
return jax.vmap(self._generate_task)(key_tasks)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def meta_test(params, test_batch, permute_labels: bool):
if self.num_tasks > 0:
test_value, meta_test_value, id_test_value = tree_util.tree_map(
functools.partial(jnp.mean, axis=0),
jnp.split(metric_value, [num_within_tasks, test_tasks.shape[0]]))
jnp.split(metric_value, [num_within_tasks, test_tasks.shape[0]])) # pyrefly: ignore[unbound-name]
else:
test_value = meta_test_value = id_test_value = metric_value
log_dict.update({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,5 @@ def add_batch(nest, batch_size: Optional[int]):


def split_axis(x: jnp.ndarray, shape=Tuple[int], axis=-1):
new_shape = x.shape[:axis] + shape + x.shape[axis:][1:]
new_shape = x.shape[:axis] + shape + x.shape[axis:][1:] # pyrefly: ignore[unsupported-operation]
return x.reshape(new_shape)
Loading