diff --git a/learned_optimization/baselines/normalizers.py b/learned_optimization/baselines/normalizers.py index f8064900..866f868c 100644 --- a/learned_optimization/baselines/normalizers.py +++ b/learned_optimization/baselines/normalizers.py @@ -45,7 +45,7 @@ def ema(data: chex.Array, alpha: float, ignore_nan=False): """Exponential moving average.""" # TODO(lmetz) dedup with notebook_utils! - if len(data) == 0: # pylint: disable=g-explicit-length-test + if len(data) == 0: # pylint: disable=g-explicit-length-test # pyrefly: ignore[bad-argument-type] return data data = onp.asarray(data) x = onp.zeros_like(data) diff --git a/learned_optimization/baselines/run_archive.py b/learned_optimization/baselines/run_archive.py index 2747d43b..ef278ce7 100644 --- a/learned_optimization/baselines/run_archive.py +++ b/learned_optimization/baselines/run_archive.py @@ -98,8 +98,8 @@ def maybe_archive_hparam_set(task_name: str, hparam_set_name: str) -> bool: @gin.configurable -def wait_until_ready_then_archive_task(task_name: str = gin.REQUIRED, - hparam_set_name: str = gin.REQUIRED): +def wait_until_ready_then_archive_task(task_name: str = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + hparam_set_name: str = gin.REQUIRED): # pyrefly: ignore[bad-function-definition] """Continually try to create and save an archive of hparam set + task_name. This function is designed to be run while the baselines are being computed diff --git a/learned_optimization/baselines/run_time_task.py b/learned_optimization/baselines/run_time_task.py index 8a85ecf0..44979900 100644 --- a/learned_optimization/baselines/run_time_task.py +++ b/learned_optimization/baselines/run_time_task.py @@ -30,8 +30,8 @@ @gin.configurable -def run_many_eval_and_save(task: tasks_base.Task = gin.REQUIRED, - save_dir: str = gin.REQUIRED): +def run_many_eval_and_save(task: tasks_base.Task = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + save_dir: str = gin.REQUIRED): # pyrefly: ignore[bad-function-definition] """Compute and save `num_to_run` runtime statistics.""" dev = jax.devices()[0] diff --git a/learned_optimization/baselines/run_trainer.py b/learned_optimization/baselines/run_trainer.py index b1387e88..ab40e22b 100644 --- a/learned_optimization/baselines/run_trainer.py +++ b/learned_optimization/baselines/run_trainer.py @@ -49,7 +49,7 @@ def _get_gin_name(gin_arg_name: str, fallback: str) -> str: got_config = False if got_config: - return configurable.selector + return configurable.selector # pyrefly: ignore[unbound-name] else: return fallback @@ -57,9 +57,9 @@ def _get_gin_name(gin_arg_name: str, fallback: str) -> str: @profile.wrap() @gin.configurable def inner_train_task( - task: tasks_base.Task = gin.REQUIRED, - opt: opt_base.Optimizer = gin.REQUIRED, - num_steps: int = gin.REQUIRED, + task: tasks_base.Task = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + opt: opt_base.Optimizer = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + num_steps: int = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] eval_every: int = 10, eval_batches: int = 5, last_eval_batches: int = 10, diff --git a/learned_optimization/continuous_eval/run_eval_chief.py b/learned_optimization/continuous_eval/run_eval_chief.py index 85e7e08a..74ded07b 100644 --- a/learned_optimization/continuous_eval/run_eval_chief.py +++ b/learned_optimization/continuous_eval/run_eval_chief.py @@ -443,7 +443,7 @@ def write_results_thread_main( ] for fn in values_to_metrics_fns: - metric = fn(task_group, values, tasks) + metric = fn(task_group, values, tasks) # pyrefly: ignore[not-callable] for k, v in metric.items(): if k in metrics: raise ValueError(f"Duplicate metric key found! [[{k}]]") @@ -472,14 +472,14 @@ def write_results_thread_main( maybe_finished, metrics[log_to_population_tag], population_server_name=population_server_name, - population_worker_id=population_worker_id) + population_worker_id=population_worker_id) # pyrefly: ignore[bad-argument-type] @gin.configurable -def eval_chief_config(chief_name: str = gin.REQUIRED, - num_workers: int = gin.REQUIRED, - learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED): +def eval_chief_config(chief_name: str = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + num_workers: int = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED): # pyrefly: ignore[bad-function-definition] """Parameters of the evaluation. To be set with gin.""" if chief_name == gin.REQUIRED or num_workers == gin.REQUIRED: raise ValueError("Must set chief_name and num_workers with gin!") @@ -579,7 +579,7 @@ def main(_): logging.info("Waiting on %s", train_log_dir) i = 0 - while not filesystem.exists(train_log_dir): + while not filesystem.exists(train_log_dir): # pyrefly: ignore[bad-argument-type] time.sleep(1) i += 1 if i % 20 == 0: diff --git a/learned_optimization/continuous_eval/run_eval_worker.py b/learned_optimization/continuous_eval/run_eval_worker.py index 58c3a278..85a4d8ad 100644 --- a/learned_optimization/continuous_eval/run_eval_worker.py +++ b/learned_optimization/continuous_eval/run_eval_worker.py @@ -276,7 +276,7 @@ def connect_to_server_and_do_tasks(train_log_dir: str): def main(_): train_log_dir = setup_experiment.setup_experiment(gin_finalize=False) - connect_to_server_and_do_tasks(train_log_dir) + connect_to_server_and_do_tasks(train_log_dir) # pyrefly: ignore[bad-argument-type] if __name__ == "__main__": diff --git a/learned_optimization/learned_optimizers/adafac_mlp_lopt.py b/learned_optimization/learned_optimizers/adafac_mlp_lopt.py index eec6a31b..8733a001 100644 --- a/learned_optimization/learned_optimizers/adafac_mlp_lopt.py +++ b/learned_optimization/learned_optimizers/adafac_mlp_lopt.py @@ -404,7 +404,7 @@ def init( iteration=jnp.asarray(0, dtype=jnp.int32), num_steps=jnp.asarray(num_steps)) - def update( + def update( # pyrefly: ignore[bad-override] self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks opt_state: AdafacMLPLOptState, grad: opt_base.Gradient, diff --git a/learned_optimization/learned_optimizers/adafac_nominal.py b/learned_optimization/learned_optimizers/adafac_nominal.py index cd6aca31..da304996 100644 --- a/learned_optimization/learned_optimizers/adafac_nominal.py +++ b/learned_optimization/learned_optimizers/adafac_nominal.py @@ -477,7 +477,7 @@ def init( iteration=jnp.asarray(0, dtype=jnp.int32), num_steps=jnp.asarray(num_steps)) - def update( + def update( # pyrefly: ignore[bad-override] self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks opt_state: AdafacMLPLOptState, grad: opt_base.Gradient, diff --git a/learned_optimization/learned_optimizers/base.py b/learned_optimization/learned_optimizers/base.py index 9a676dcc..8df331e5 100644 --- a/learned_optimization/learned_optimizers/base.py +++ b/learned_optimization/learned_optimizers/base.py @@ -209,7 +209,7 @@ def __init__(self, opts: Sequence[opt_base.Optimizer]): if len(opts) != 2: raise ValueError("Only 2 opts are supported for now!") - def init(self, params, model_state=None, num_steps=None, **kwargs): + def init(self, params, model_state=None, num_steps=None, **kwargs): # pyrefly: ignore[bad-override] opt_states = tuple([ opt.init(params, model_state, num_steps=num_steps, **kwargs) for opt in self.opts @@ -240,7 +240,7 @@ def update(self, opt_state, grad, model_state=None, **kwargs): # pytype: disabl return SumOptimizerState( iteration=opt_state.iteration + 1, params=new_params, - state=model_state, + state=model_state, # pyrefly: ignore[bad-argument-type] inner_opt_states=tuple(new_opt_states), ) diff --git a/learned_optimization/learned_optimizers/mlp_lopt.py b/learned_optimization/learned_optimizers/mlp_lopt.py index 64f75ff5..72790e97 100644 --- a/learned_optimization/learned_optimizers/mlp_lopt.py +++ b/learned_optimization/learned_optimizers/mlp_lopt.py @@ -115,7 +115,7 @@ def init(self, rolling_features=common.vec_rolling_mom(decays).init(params), iteration=jnp.asarray(0, dtype=jnp.int32)) - def update( + def update( # pyrefly: ignore[bad-override] self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks opt_state: MLPLOptState, grad: Any, diff --git a/learned_optimization/learned_optimizers/nn_adam.py b/learned_optimization/learned_optimizers/nn_adam.py index 1402765b..dad022f5 100644 --- a/learned_optimization/learned_optimizers/nn_adam.py +++ b/learned_optimization/learned_optimizers/nn_adam.py @@ -272,7 +272,7 @@ def init( return NNAdamState( params=params, rolling_features=rolling.init(params), - iteration=jnp.asarray(0, dtype=jnp.int32), + iteration=jnp.asarray(0, dtype=jnp.int32), # pyrefly: ignore[bad-argument-type] state=model_state, lstm_hidden_state=lstm_hidden_state, per_layer_lr=jax.tree_util.tree_map( diff --git a/learned_optimization/learned_optimizers/opt_from_checkpoint.py b/learned_optimization/learned_optimizers/opt_from_checkpoint.py index 5de85306..bba60931 100644 --- a/learned_optimization/learned_optimizers/opt_from_checkpoint.py +++ b/learned_optimization/learned_optimizers/opt_from_checkpoint.py @@ -151,5 +151,5 @@ def maybe_add_scope(c): wrapped = _GinScopeClass(opt, scope) # For now, just add the lopt to the returned class. # TODO(lmetz) change this api to return a more structured class? - wrapped.lopt = lopt + wrapped.lopt = lopt # pyrefly: ignore[missing-attribute] return wrapped # type: ignore diff --git a/learned_optimization/learned_optimizers/rnn_mlp_lopt.py b/learned_optimization/learned_optimizers/rnn_mlp_lopt.py index 29688134..25bf7b23 100644 --- a/learned_optimization/learned_optimizers/rnn_mlp_lopt.py +++ b/learned_optimization/learned_optimizers/rnn_mlp_lopt.py @@ -92,7 +92,7 @@ def corrected_mean(self, state: _LossNormalizerState) -> jnp.ndarray: def _avg_square_mean(tree: Any) -> jnp.ndarray: - return sum([jnp.mean(jnp.square(x)) for x in jax.tree_util.tree_leaves(tree) + return sum([jnp.mean(jnp.square(x)) for x in jax.tree_util.tree_leaves(tree) # pyrefly: ignore[bad-return] ]) / len(jax.tree_util.tree_leaves(tree)) @@ -456,7 +456,7 @@ def mlp_features_for_tensor(self, m: jnp.ndarray, rms: jnp.ndarray, ], axis=1) - def update(self, + def update(self, # pyrefly: ignore[bad-override] opt_state: RNNMLPLOptState, grads, loss: Optional[jnp.ndarray] = None, @@ -560,7 +560,7 @@ def to_map_get_mlp_features(m, rms, g, v, ff_inputs): v, ff_inputs, # pytype: disable=wrong-arg-types # jax-ndarray opt_state.iteration, - num_tensors, + num_tensors, # pyrefly: ignore[bad-argument-type] ) # Prep the features diff --git a/learned_optimization/population/examples/complex_cnn/train_threads.py b/learned_optimization/population/examples/complex_cnn/train_threads.py index 6b973f28..d5243503 100644 --- a/learned_optimization/population/examples/complex_cnn/train_threads.py +++ b/learned_optimization/population/examples/complex_cnn/train_threads.py @@ -81,7 +81,7 @@ def train_one(worker_id: int): for _ in range(20): batch = next(te_iterator) key, key1 = jax.random.split(key) - l = common.loss(params, key1, batch, meta_params, False) + l = common.loss(params, key1, batch, meta_params, False) # pyrefly: ignore[unbound-name] te_ls.append(l) batch = next(tr_iterator) @@ -92,7 +92,7 @@ def train_one(worker_id: int): tr_mean_l = onp.mean(tr_ls) # save to disk - model_state = (params, opt_state) + model_state = (params, opt_state) # pyrefly: ignore[unbound-name] state_path = os.path.join(train_log_dir, f"{step}__{gen_id}.model") common.save_state(state_path, model_state) diff --git a/learned_optimization/population/examples/simple_cnn/train.py b/learned_optimization/population/examples/simple_cnn/train.py index 83f99c7f..1c720654 100644 --- a/learned_optimization/population/examples/simple_cnn/train.py +++ b/learned_optimization/population/examples/simple_cnn/train.py @@ -94,25 +94,25 @@ def mutate_fn(meta_params, direction, _): for _ in range(5): batch = next(te_iterator) key, key1 = jax.random.split(key) - l = common.loss(params, key1, batch) + l = common.loss(params, key1, batch) # pyrefly: ignore[unbound-name] ls.append(l) mean_l = onp.mean(ls) print(f"step={step}, loss={l} path={state_path}") summary_writer.scalar("loss", l, step=step) summary_writer.scalar( - "learning_rate", meta_params["learning_rate"], step=step) + "learning_rate", meta_params["learning_rate"], step=step) # pyrefly: ignore[unsupported-operation] summary_writer.scalar( - "log_learning_rate", onp.log(meta_params["learning_rate"]), step=step) + "log_learning_rate", onp.log(meta_params["learning_rate"]), step=step) # pyrefly: ignore[unsupported-operation] summary_writer.flush() # save to disk # we don't send back raw parameter values, instead we send checkpoint # locations. - model_state = (params, opt_state) + model_state = (params, opt_state) # pyrefly: ignore[unbound-name] state_path = os.path.join(train_log_dir, f"{step}__{gen_id}.model") common.save_state(state_path, model_state) - population.set_eval(worker_id, gen_id, step, state_path, mean_l) + population.set_eval(worker_id, gen_id, step, state_path, mean_l) # pyrefly: ignore[bad-argument-type] # Actually update the params to train one step. batch = next(tr_iterator) diff --git a/learned_optimization/population/examples/simple_cnn/train_threads.py b/learned_optimization/population/examples/simple_cnn/train_threads.py index b7d1c3bc..9eaf30fd 100644 --- a/learned_optimization/population/examples/simple_cnn/train_threads.py +++ b/learned_optimization/population/examples/simple_cnn/train_threads.py @@ -78,21 +78,21 @@ def train_one(worker_id): for _ in range(5): batch = next(te_iterator) key, key1 = jax.random.split(key) - l = common.loss(params, key1, batch) + l = common.loss(params, key1, batch) # pyrefly: ignore[unbound-name] ls.append(l) mean_l = onp.mean(ls) # save to disk - model_state = (params, opt_state) + model_state = (params, opt_state) # pyrefly: ignore[unbound-name] state_path = os.path.join(train_log_dir, f"{step}__{gen_id}.model") common.save_state(state_path, model_state) population.set_eval(worker_id, gen_id, step, state_path, mean_l) print(f"{worker_id} ]] step={step}, loss={l} path={state_path}") summary_writer.scalar("loss", l, step=step) summary_writer.scalar( - "learning_rate", meta_params["learning_rate"], step=step) + "learning_rate", meta_params["learning_rate"], step=step) # pyrefly: ignore[unsupported-operation] summary_writer.scalar( - "log_learning_rate", onp.log(meta_params["learning_rate"]), step=step) + "log_learning_rate", onp.log(meta_params["learning_rate"]), step=step) # pyrefly: ignore[unsupported-operation] summary_writer.flush() params, opt_state, l = common.update(params, key, opt_state, batch, diff --git a/learned_optimization/population/examples/synthetic/train.py b/learned_optimization/population/examples/synthetic/train.py index f065dae2..ef16047a 100644 --- a/learned_optimization/population/examples/synthetic/train.py +++ b/learned_optimization/population/examples/synthetic/train.py @@ -78,7 +78,7 @@ def train(steps=10000): for i in range(5): if i % 5 == 0: l = loss(params) - population.set_eval(0, gen_id, step, params, l) + population.set_eval(0, gen_id, step, params, l) # pyrefly: ignore[bad-argument-type] print(f"\t {l}, params: {params}, meta_params:{meta_params}") params = update(params, meta_params) diff --git a/learned_optimization/population/mutators/single_worker_explore.py b/learned_optimization/population/mutators/single_worker_explore.py index d46bfcc1..c8507ab6 100644 --- a/learned_optimization/population/mutators/single_worker_explore.py +++ b/learned_optimization/population/mutators/single_worker_explore.py @@ -195,7 +195,7 @@ def add_worker_to_cache(from_checkpoint: population.Checkpoint, scores = [center_score, neg_score, pos_score] idx = onp.nanargmin(scores) - best_checkpoint = [center_steps, neg_steps, + best_checkpoint = [center_steps, neg_steps, # pyrefly: ignore[bad-index] pos_steps][idx].values()[-1] meta_params = best_checkpoint.meta_params diff --git a/learned_optimization/population/mutators/winner_take_all_genetic.py b/learned_optimization/population/mutators/winner_take_all_genetic.py index 41f79a11..de634246 100644 --- a/learned_optimization/population/mutators/winner_take_all_genetic.py +++ b/learned_optimization/population/mutators/winner_take_all_genetic.py @@ -97,7 +97,7 @@ def update( # We assume that the values here are all floating loss values. # grab the highest performing checkpoint data best_idx = onp.argmin(values) - genid = current_workers[best_idx].generation_id + genid = current_workers[best_idx].generation_id # pyrefly: ignore[bad-index] # sort by time or step. These should always be the same if self._steps_per_exploit: diff --git a/learned_optimization/population/population.py b/learned_optimization/population/population.py index 0329d18b..d33fede7 100644 --- a/learned_optimization/population/population.py +++ b/learned_optimization/population/population.py @@ -170,8 +170,8 @@ def maybe_get_worker_data( old_state = self.serialized_state() # also potentially mutate the cache - self._active_workers = self.mutate.get_worker_data( - self._active_workers, self._cached, worker_id, generation_id, step, + self._active_workers = self.mutate.get_worker_data( # pyrefly: ignore[bad-assignment] + self._active_workers, self._cached, worker_id, generation_id, step, # pyrefly: ignore[bad-argument-type] params, meta_params) new_state = self.serialized_state() @@ -204,8 +204,8 @@ def maybe_get_worker_data( generation_id=generation_id, params=params, meta_params=meta_params, - parent=(generation_id, step), - step=step, + parent=(generation_id, step), # pyrefly: ignore[bad-argument-type] + step=step, # pyrefly: ignore[bad-argument-type] value=None, time=time.time(), ) @@ -254,14 +254,14 @@ def set_eval(self, worker_id: int, generation_id: GenerationID, step: int, # "cast" to a mutable sequence here to make pytype happy. mut_active_workers = list( self._active_workers) # type: MutableSequence[ActiveWorker] - mut_active_workers[worker_id] = self._active_workers[worker_id].replace( + mut_active_workers[worker_id] = self._active_workers[worker_id].replace( # pyrefly: ignore[missing-attribute] step=step) - mut_active_workers[worker_id] = mut_active_workers[worker_id].replace( + mut_active_workers[worker_id] = mut_active_workers[worker_id].replace( # pyrefly: ignore[missing-attribute] params=params) self._active_workers = mut_active_workers # in light of this new value, run the mutator - self._mutate_state, self._active_workers = self.mutate.update( + self._mutate_state, self._active_workers = self.mutate.update( # pyrefly: ignore[bad-assignment] self._mutate_state, self._active_workers, self._cached) self.save_state() diff --git a/learned_optimization/research/brax/brax_env_truncated_step.py b/learned_optimization/research/brax/brax_env_truncated_step.py index 702d9b8b..b0c8eaad 100644 --- a/learned_optimization/research/brax/brax_env_truncated_step.py +++ b/learned_optimization/research/brax/brax_env_truncated_step.py @@ -153,11 +153,11 @@ def do_step(unroll_state): next_env_state = self.env.step(unroll_state.env_state, action) out = truncated_step.TruncatedUnrollOut( # pytype: disable=wrong-arg-types # jax-ndarray - loss=-next_env_state.reward, - is_done=False, + loss=-next_env_state.reward, # pyrefly: ignore[bad-argument-type] + is_done=False, # pyrefly: ignore[bad-argument-type] task_param=None, iteration=unroll_state.iteration, - mask=True, + mask=True, # pyrefly: ignore[bad-argument-type] ) return BraxEnvState( env_state=next_env_state, @@ -166,7 +166,7 @@ def do_step(unroll_state): def reset(unroll_state): out = truncated_step.TruncatedUnrollOut( # pytype: disable=wrong-arg-types # jax-ndarray - loss=0.0, is_done=True, task_param=None, iteration=0, mask=False + loss=0.0, is_done=True, task_param=None, iteration=0, mask=False # pyrefly: ignore[bad-argument-type] ) unroll_state = BraxEnvState( env_state=self.env.reset(key), diff --git a/learned_optimization/research/data_driven/data.py b/learned_optimization/research/data_driven/data.py index ce3c4212..24d44f6b 100644 --- a/learned_optimization/research/data_driven/data.py +++ b/learned_optimization/research/data_driven/data.py @@ -152,9 +152,9 @@ def _generate_task(self, key): def _generate_tasks(self, key): key_tasks, key_mask = jax.random.split(key) del key - key_tasks = jax.random.split(key_tasks, self._dataset_size) + key_tasks = jax.random.split(key_tasks, self._dataset_size) # pyrefly: ignore[bad-argument-type] mask = jax.random.bernoulli( - key_mask, p=self._bias_prob, shape=(self._dataset_size,)) + key_mask, p=self._bias_prob, shape=(self._dataset_size,)) # pyrefly: ignore[bad-argument-type] key_tasks = jnp.where(mask[:, None], self._bias_key[None], key_tasks) return jax.vmap(self._generate_task)(key_tasks) diff --git a/learned_optimization/research/data_driven/mnist_projections.py b/learned_optimization/research/data_driven/mnist_projections.py index 18c1b6d4..7a4bd584 100644 --- a/learned_optimization/research/data_driven/mnist_projections.py +++ b/learned_optimization/research/data_driven/mnist_projections.py @@ -421,7 +421,7 @@ def meta_test(params, test_batch, permute_labels: bool): if self.num_tasks > 0: test_value, meta_test_value, id_test_value = tree_util.tree_map( functools.partial(jnp.mean, axis=0), - jnp.split(metric_value, [num_within_tasks, test_tasks.shape[0]])) + jnp.split(metric_value, [num_within_tasks, test_tasks.shape[0]])) # pyrefly: ignore[unbound-name] else: test_value = meta_test_value = id_test_value = metric_value log_dict.update({ diff --git a/learned_optimization/research/data_driven/model_components.py b/learned_optimization/research/data_driven/model_components.py index fa5cdb4c..23205eb1 100644 --- a/learned_optimization/research/data_driven/model_components.py +++ b/learned_optimization/research/data_driven/model_components.py @@ -86,5 +86,5 @@ def add_batch(nest, batch_size: Optional[int]): def split_axis(x: jnp.ndarray, shape=Tuple[int], axis=-1): - new_shape = x.shape[:axis] + shape + x.shape[axis:][1:] + new_shape = x.shape[:axis] + shape + x.shape[axis:][1:] # pyrefly: ignore[unsupported-operation] return x.reshape(new_shape) diff --git a/learned_optimization/research/data_driven/models.py b/learned_optimization/research/data_driven/models.py index 9aa20316..5693abf7 100644 --- a/learned_optimization/research/data_driven/models.py +++ b/learned_optimization/research/data_driven/models.py @@ -284,9 +284,9 @@ def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, class LayerState(NamedTuple): - lstm_state: hk.LSTMState = None - fwd_msg: jnp.ndarray = None - bwd_msg: jnp.ndarray = None + lstm_state: hk.LSTMState = None # pyrefly: ignore[bad-assignment] + fwd_msg: jnp.ndarray = None # pyrefly: ignore[bad-assignment] + bwd_msg: jnp.ndarray = None # pyrefly: ignore[bad-assignment] @gin.configurable() @@ -362,7 +362,7 @@ def __call__(self, out, lstm_state = self._tick(lstm_state, fwd_msg, bwd_msg, aux) # Update forward messages - out_fwd_msg = self._fwd_messenger(out).mean(axis=0) + out_fwd_msg = self._fwd_messenger(out).mean(axis=0) # pyrefly: ignore[unbound-name] # Update backward messages out_bwd_msg = self._bwd_messenger(out).mean(axis=1) @@ -414,7 +414,7 @@ def __call__(self, layer = self.layers[i] _, new_states[i + 1] = layer(state, prev_s.fwd_msg, next_s.bwd_msg, aux) - return out, new_states[1:-1] + return out, new_states[1:-1] # pyrefly: ignore[unbound-name] @gin.configurable() @@ -591,7 +591,7 @@ def __init__(self, self._use_maml = use_maml self._grad_func = jax.grad(self._loss, has_aux=True) - self._network = hk.without_apply_rng(hk.transform(self._network)) + self._network = hk.without_apply_rng(hk.transform(self._network)) # pyrefly: ignore[bad-assignment] self._opt = getattr(optax, optimizer)(learning_rate) def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: @@ -618,7 +618,7 @@ def _network(self, x: jnp.ndarray): return x def _loss(self, params, x, labels): - logits = self._network.apply(params, x) + logits = self._network.apply(params, x) # pyrefly: ignore[missing-attribute] loss = optax.softmax_cross_entropy(logits, labels) return loss, logits @@ -636,10 +636,10 @@ def hk_forward(self, inputs: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray: dummy_inp = inputs[0] if self._use_maml: key = hk.next_rng_key() if hk.running_init() else None - params = hk.lift(self._network.init, name='maml_lift')(key, dummy_inp) + params = hk.lift(self._network.init, name='maml_lift')(key, dummy_inp) # pyrefly: ignore[missing-attribute] else: key = hk.next_rng_key() - params = self._network.init(key, dummy_inp) + params = self._network.init(key, dummy_inp) # pyrefly: ignore[missing-attribute] opt_state = self._opt.init(params) def scan_tick(carry, x): diff --git a/learned_optimization/research/data_driven/run_mnist_projections.py b/learned_optimization/research/data_driven/run_mnist_projections.py index 6a27f80c..ab4d2b87 100644 --- a/learned_optimization/research/data_driven/run_mnist_projections.py +++ b/learned_optimization/research/data_driven/run_mnist_projections.py @@ -27,7 +27,7 @@ def main(_) -> None: rank = jax.process_index() train_log_dir = setup_experiment.setup_experiment(make_dir=(rank == 0)) - train(train_log_dir) + train(train_log_dir) # pyrefly: ignore[bad-argument-type] def train(training_log_directory: str): @@ -37,7 +37,7 @@ def train(training_log_directory: str): training_log_directory: Directory to store log data to. """ - experiment = mnist_projections.ProjectionExperiment(training_log_directory) + experiment = mnist_projections.ProjectionExperiment(training_log_directory) # pyrefly: ignore[missing-argument] log_dict = experiment.run() if jax.process_index() == 0: diff --git a/learned_optimization/research/data_driven/transformer.py b/learned_optimization/research/data_driven/transformer.py index df3adfa4..d396a57b 100644 --- a/learned_optimization/research/data_driven/transformer.py +++ b/learned_optimization/research/data_driven/transformer.py @@ -55,7 +55,7 @@ def __call__( seq_len = query.shape[1] causal_mask = np.tril(np.ones((1, 1, seq_len, seq_len))) - mask = mask * causal_mask if mask is not None else causal_mask + mask = mask * causal_mask if mask is not None else causal_mask # pyrefly: ignore[bad-assignment] return super().__call__(query, key, value, mask) diff --git a/learned_optimization/research/distill/truncated_distill.py b/learned_optimization/research/distill/truncated_distill.py index 4964a3a3..00892454 100644 --- a/learned_optimization/research/distill/truncated_distill.py +++ b/learned_optimization/research/distill/truncated_distill.py @@ -94,7 +94,7 @@ def state_compare_mse_params(prev_src_state: UnrollState, "state_compare_loss", "with_summary", "unroll_length", "outer_param_noise"), ) -@functools.partial(summary.add_with_summary, static_argnums=(0, 1, 2, 3, 4)) +@functools.partial(summary.add_with_summary, static_argnums=(0, 1, 2, 3, 4)) # pyrefly: ignore[bad-specialization] @functools.partial(jax.value_and_grad, has_aux=True, argnums=5) def distill_truncated_unroll( truncated_step: truncated_step_mod.VectorizedTruncatedStep, @@ -226,7 +226,7 @@ def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights, if self.outer_param_noise > 0.0: theta = _multi_perturb(worker_weights.theta, key3, self.outer_param_noise, - self.truncated_step.num_tasks) + self.truncated_step.num_tasks) # pyrefly: ignore[missing-attribute] theta_is_vector = True else: theta = worker_weights.theta @@ -236,7 +236,7 @@ def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights, theta, worker_weights.outer_state, key1, - theta_is_vector=theta_is_vector) + theta_is_vector=theta_is_vector) # pyrefly: ignore[unexpected-keyword] src_unroll_state = self.src_truncated_step.init_step_state( theta, diff --git a/learned_optimization/research/general_lopt/hyper_v2.py b/learned_optimization/research/general_lopt/hyper_v2.py index a4e61e9b..c877d6f9 100644 --- a/learned_optimization/research/general_lopt/hyper_v2.py +++ b/learned_optimization/research/general_lopt/hyper_v2.py @@ -384,7 +384,7 @@ def _ff_mod(self, global_feat, extra_step_mult, p, g, m, rms, fac_g, rsqrt = lax.rsqrt(rms + 1e-6) if self.with_rms_norm_g: - norm_g = m * rsqrt + norm_g = m * rsqrt # pyrefly: ignore[unbound-name] inps.append(norm_g) if self.with_rsqrt_rms: @@ -502,8 +502,8 @@ def _ff_mod(self, global_feat, extra_step_mult, p, g, m, rms, fac_g, o = jax.nn.relu(o) # extract outputs from MLP to construct a step. - direction = o[..., 0] - magnitude_param = o[..., 1] + direction = o[..., 0] # pyrefly: ignore[bad-index] + magnitude_param = o[..., 1] # pyrefly: ignore[bad-index] mag_param = jnp.exp(magnitude_param * self.exp_mult) param_scale = jnp.sqrt(jnp.mean(jnp.square(p)) + 1e-9) @@ -691,7 +691,7 @@ def init(self, return State( params=params, - state=model_state, + state=model_state, # pyrefly: ignore[bad-argument-type] rms_rolling=rms_roll.init(params), mom_rolling=mom_roll.init(params), fac_rolling=adafac_roll.init(params), @@ -711,7 +711,7 @@ def update(self, if parent.constant_loss: loss = 1.0 assert loss is not None - summary.summary("validation_mode", parent.validation_mode) + summary.summary("validation_mode", parent.validation_mode) # pyrefly: ignore[bad-argument-type] next_loss_buffer = parent.buffer_loss_fns.update( opt_state.loss_buffer, loss) @@ -805,7 +805,7 @@ def interpolate_theta(ff_p): return next_p l, struct = jax.tree_util.tree_flatten(control_params) - key, key1 = jax.random.split(key) + key, key1 = jax.random.split(key) # pyrefly: ignore[bad-argument-type] keys = struct.unflatten([k for k in jax.random.split(key1, len(l))]) next_params = jax.tree_util.tree_map( apply_one, control_params, keys, lr_mult, opt_state.params, grads, @@ -815,7 +815,7 @@ def interpolate_theta(ff_p): ss = State( params=next_params, - state=model_state, + state=model_state, # pyrefly: ignore[bad-argument-type] mom_rolling=next_mom_rolling, rms_rolling=next_rms_rolling, fac_rolling=next_adafac_rolling, diff --git a/learned_optimization/research/general_lopt/prefab.py b/learned_optimization/research/general_lopt/prefab.py index a1cd8017..af3c8015 100644 --- a/learned_optimization/research/general_lopt/prefab.py +++ b/learned_optimization/research/general_lopt/prefab.py @@ -94,7 +94,7 @@ def get_state(self, opt_state): def set_params(self, opt_state, params): return self.opt.set_params(opt_state, params) - def name(self): + def name(self): # pyrefly: ignore[bad-override] return "LearnedOptimizer" diff --git a/learned_optimization/research/hysteresis/truncated_es_shared_noise.py b/learned_optimization/research/hysteresis/truncated_es_shared_noise.py index 471f8fdf..4f216fc2 100644 --- a/learned_optimization/research/hysteresis/truncated_es_shared_noise.py +++ b/learned_optimization/research/hysteresis/truncated_es_shared_noise.py @@ -83,11 +83,11 @@ def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights, worker_weights.theta, worker_weights.outer_state, key, - theta_is_vector=False) + theta_is_vector=False) # pyrefly: ignore[unexpected-keyword] # we use sample_perturbations instead of vector_sample_perturbations # as we don't need the positively/negatively perturbed thetas - keys = jax.random.split(key, self.truncated_step.num_tasks) + keys = jax.random.split(key, self.truncated_step.num_tasks) # pyrefly: ignore[missing-attribute] epsilons = sample_multiple_perturbations(worker_weights.theta, keys, self.std) @@ -120,7 +120,7 @@ def compute_gradient_estimate( # pytype: disable=signature-mismatch # overridi total_count = 0. for i in range(self.unroll_length): - data = self.truncated_step.get_batch() + data = self.truncated_step.get_batch() # pyrefly: ignore[missing-argument] curr_key = next(rng) state, loss_sum_step, g_sum_step, count =\ @@ -148,7 +148,7 @@ def compute_gradient_estimate( # pytype: disable=signature-mismatch # overridi # is_done=p_ys.is_done) output = gradient_learner.GradientEstimatorOut( - mean_loss=mean_loss, grad=g, unroll_state=state, unroll_info=None) + mean_loss=mean_loss, grad=g, unroll_state=state, unroll_info=None) # pyrefly: ignore[bad-argument-type] return output, {} @@ -183,7 +183,7 @@ def shared_es_unroll( key=key1, data=data, outer_state=outer_state, - theta_is_vector=True) + theta_is_vector=True) # pyrefly: ignore[unexpected-keyword] neg_unroll_states, neg_outs = \ self.truncated_step.unroll_step( theta=neg_perturbed_thetas, @@ -193,7 +193,7 @@ def shared_es_unroll( # and also ensures we get the same loss evaluation (if it takes randomness) data=data, # also use the same data outer_state=outer_state, - theta_is_vector=True) + theta_is_vector=True) # pyrefly: ignore[unexpected-keyword] # keep track of sum of losses for logging # pos_outs.loss is an array of losses (one for each trajectory/particle) @@ -218,13 +218,13 @@ def shared_es_unroll( count = jnp.sum(pos_outs.mask) # for the particle that resets, we sample a new epsilon - keys = jax.random.split(key2, self.truncated_step.num_tasks) + keys = jax.random.split(key2, self.truncated_step.num_tasks) # pyrefly: ignore[missing-attribute] new_epsilons = sample_multiple_perturbations(theta, keys, self.std) # replace epsilon of the trajectory that has finished with a new epsilon def update_eps(eps, new_eps): reshape_isdone = jnp.reshape(pos_outs.is_done, - [self.truncated_step.num_tasks] + [1] * + [self.truncated_step.num_tasks] + [1] * # pyrefly: ignore[missing-attribute] (len(eps.shape) - 1)) return eps * (1 - reshape_isdone) + new_eps * (reshape_isdone) diff --git a/learned_optimization/research/jaxnerf/jaxnerf.py b/learned_optimization/research/jaxnerf/jaxnerf.py index 93a0ad5e..9414891b 100644 --- a/learned_optimization/research/jaxnerf/jaxnerf.py +++ b/learned_optimization/research/jaxnerf/jaxnerf.py @@ -109,7 +109,7 @@ def __init__(self, jaxnerf_cfg, lopt_datasets): def init(self, key): key1, key2, key3 = jax.random.split(key, num=3) - rays = next(self.datasets.train)["rays"] + rays = next(self.datasets.train)["rays"] # pyrefly: ignore[missing-attribute] init_variables = self.model.init( key1, rng_0=key2, rng_1=key3, rays=rays, randomized=self.cfg.randomized) return init_variables @@ -171,7 +171,7 @@ def tree_sum_fn(fn): @gin.configurable def _create_jaxnerf_config(cfg: ml_collections.ConfigDict, data_dir: str): """Create the JaxNeRF config.""" - base_cfg = ml_collections.ConfigDict(DEFAULT_JAXNERF_CONFIG, type_safe=False) + base_cfg = ml_collections.ConfigDict(DEFAULT_JAXNERF_CONFIG, type_safe=False) # pyrefly: ignore[bad-argument-type] base_cfg.update(cfg) # Set data dir @@ -212,21 +212,21 @@ def _create_jaxnerf_config(cfg: ml_collections.ConfigDict, data_dir: str): @gin.configurable def JAXNeRF_LegoBlenderTask(): - cfg = _create_jaxnerf_config(LEGO_CONFIG, os.path.join(DATA_DIR, "lego")) + cfg = _create_jaxnerf_config(LEGO_CONFIG, os.path.join(DATA_DIR, "lego")) # pyrefly: ignore[no-matching-overload] ds = datasets.load_jaxnerf_datasets(cfg) return JaxNeRFTask(cfg, ds) @gin.configurable def JAXNeRF_ShipBlenderTask(): - cfg = _create_jaxnerf_config(LEGO_CONFIG, os.path.join(DATA_DIR, "ship")) + cfg = _create_jaxnerf_config(LEGO_CONFIG, os.path.join(DATA_DIR, "ship")) # pyrefly: ignore[no-matching-overload] ds = datasets.load_jaxnerf_datasets(cfg) return JaxNeRFTask(cfg, ds) @gin.configurable def JAXNeRF_HotdogBlenderTask(): - cfg = _create_jaxnerf_config(LEGO_CONFIG, os.path.join(DATA_DIR, "hotdog")) + cfg = _create_jaxnerf_config(LEGO_CONFIG, os.path.join(DATA_DIR, "hotdog")) # pyrefly: ignore[no-matching-overload] ds = datasets.load_jaxnerf_datasets(cfg) return JaxNeRFTask(cfg, ds) diff --git a/learned_optimization/tasks/base.py b/learned_optimization/tasks/base.py index 1d73fd1c..c794c93b 100644 --- a/learned_optimization/tasks/base.py +++ b/learned_optimization/tasks/base.py @@ -87,7 +87,7 @@ def task_fn(self, cfg: TaskCfg) -> Task: raise NotImplementedError() def eval_task_fn(self, cfg: TaskCfg) -> Task: - raise self.task_fn(cfg) + raise self.task_fn(cfg) # pyrefly: ignore[bad-raise] def sample_task(self, key): params = self.sample(key) @@ -128,7 +128,7 @@ class _TaskFamily(TaskFamily, Generic[T]): eval_datasets = eval_task.datasets def sample(self, key: PRNGKey) -> T: - return jnp.asarray(0) + return jnp.asarray(0) # pyrefly: ignore[bad-return] def task_fn(self, _: T) -> Task: return task diff --git a/learned_optimization/tasks/datasets/base.py b/learned_optimization/tasks/datasets/base.py index 450516b7..13856da5 100644 --- a/learned_optimization/tasks/datasets/base.py +++ b/learned_optimization/tasks/datasets/base.py @@ -114,27 +114,27 @@ def __init__(self, dataset_fn: Callable[[], Datasets]): # pylint: disable=super self._fn = functools.lru_cache(None)(dataset_fn) @property - def train(self): + def train(self): # pyrefly: ignore[bad-override] return self._fn().train @property - def inner_valid(self): + def inner_valid(self): # pyrefly: ignore[bad-override] return self._fn().inner_valid @property - def outer_valid(self): + def outer_valid(self): # pyrefly: ignore[bad-override] return self._fn().outer_valid @property - def test(self): + def test(self): # pyrefly: ignore[bad-override] return self._fn().test @property - def extra_info(self): + def extra_info(self): # pyrefly: ignore[bad-override] return self._fn().extra_info @property - def abstract_batch(self): + def abstract_batch(self): # pyrefly: ignore[bad-override] return self._fn().abstract_batch @@ -460,9 +460,9 @@ def make_python_iter(split: str) -> Iterator[Batch]: filenames = _tfrecord_filenames_from_dataset_name(datasetname, split) filenames = [tf.convert_to_tensor(filename) for filename in filenames] - filenames = tf.data.Dataset.from_tensor_slices(filenames).repeat( + filenames = tf.data.Dataset.from_tensor_slices(filenames).repeat( # pyrefly: ignore[bad-argument-type] -1).shuffle(len(filenames) * 2) - ds = tf.data.TFRecordDataset( + ds = tf.data.TFRecordDataset( # pyrefly: ignore[bad-instantiation] filenames, compression_type="GZIP", num_parallel_reads=4) features = { diff --git a/learned_optimization/tasks/es_wrapper.py b/learned_optimization/tasks/es_wrapper.py index 044512d2..870366d6 100644 --- a/learned_optimization/tasks/es_wrapper.py +++ b/learned_optimization/tasks/es_wrapper.py @@ -116,7 +116,7 @@ def fn(theta, *args, es_key=None, **kwargs): lambda e: e * (pos_loss - neg_loss) / (2 * std**2), pos) if has_aux: - return (jnp.mean(losses), aux), es_grad + return (jnp.mean(losses), aux), es_grad # pyrefly: ignore[unbound-name] else: return jnp.mean(losses), es_grad @@ -166,7 +166,7 @@ def new_vmap(key): value = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), value) if has_aux: - return (value, aux), grad + return (value, aux), grad # pyrefly: ignore[unbound-name] else: return value, grad @@ -212,15 +212,15 @@ def loss_fn(params, state): self.loss_with_state_and_aux.defvjp(f_fwd, f_bwd) def loss(self, params, key, data): - loss, _, _ = self.loss_with_state_and_aux(params, None, key, data) + loss, _, _ = self.loss_with_state_and_aux(params, None, key, data) # pyrefly: ignore[bad-argument-type] return loss def loss_with_state(self, params, state, key, data): - loss, state, _ = self.loss_with_state_and_aux(params, state, key, data) + loss, state, _ = self.loss_with_state_and_aux(params, state, key, data) # pyrefly: ignore[bad-argument-type] return loss, state def loss_with_aux(self, params, key, data): - loss, _, aux = self.loss_with_state_and_aux(params, None, key, data) + loss, _, aux = self.loss_with_state_and_aux(params, None, key, data) # pyrefly: ignore[bad-argument-type] return loss, aux diff --git a/learned_optimization/tasks/task_augmentation.py b/learned_optimization/tasks/task_augmentation.py index 68ac8996..4b1e3a1c 100644 --- a/learned_optimization/tasks/task_augmentation.py +++ b/learned_optimization/tasks/task_augmentation.py @@ -151,9 +151,9 @@ def single(p, key): param_scale = jax.tree_util.tree_map(single, abstract_params, keys) - task = self.task_family.task_fn(sub_config) + task = self.task_family.task_fn(sub_config) # pyrefly: ignore[unbound-name] - if isinstance(param_scale, LogFeat): + if isinstance(param_scale, LogFeat): # pyrefly: ignore[unbound-name] param_scale = param_scale.value return ReparamWeights(task, param_scale) @@ -231,9 +231,9 @@ def reduce_abstract_bs(x): return core.ShapedArray((bs,) + x.shape[1:], dtype=x.dtype) abstract_batch = jax.tree_util.tree_map(reduce_abstract_bs, - self.task.datasets.abstract_batch) + self.task.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] self.datasets = datasets_base.datasets_map( - functools.partial(jax.tree_util.tree_map, reduce_bs), task.datasets, + functools.partial(jax.tree_util.tree_map, reduce_bs), task.datasets, # pyrefly: ignore[bad-argument-type] abstract_batch) @@ -258,10 +258,10 @@ def reduce_abstract_bs(x): return core.ShapedArray((bs,) + x.shape[1:], dtype=x.dtype) abstract_batch = jax.tree_util.tree_map( - reduce_abstract_bs, self.task_family.datasets.abstract_batch) + reduce_abstract_bs, self.task_family.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] self.datasets = datasets_base.datasets_map( functools.partial(jax.tree_util.tree_map, reduce_bs), - task_family.datasets, + task_family.datasets, # pyrefly: ignore[bad-argument-type] abstract_batch=abstract_batch) self.task_fn = task_family.task_fn @@ -370,15 +370,15 @@ def f_bwd(args, g): self.loss_with_state_and_aux.defvjp(f_fwd, f_bwd) def loss(self, params, key, data): - loss, _, _ = self.loss_with_state_and_aux(params, None, key, data) + loss, _, _ = self.loss_with_state_and_aux(params, None, key, data) # pyrefly: ignore[bad-argument-type] return loss def loss_with_state(self, params, state, key, data): - loss, state, _ = self.loss_with_state_and_aux(params, state, key, data) + loss, state, _ = self.loss_with_state_and_aux(params, state, key, data) # pyrefly: ignore[bad-argument-type] return loss, state def loss_with_aux(self, params, key, data): - loss, _, aux = self.loss_with_state_and_aux(params, None, key, data) + loss, _, aux = self.loss_with_state_and_aux(params, None, key, data) # pyrefly: ignore[bad-argument-type] return loss, aux diff --git a/learned_optimization/time_filter/run_sample_and_time.py b/learned_optimization/time_filter/run_sample_and_time.py index aa238e39..f18d985f 100644 --- a/learned_optimization/time_filter/run_sample_and_time.py +++ b/learned_optimization/time_filter/run_sample_and_time.py @@ -64,9 +64,9 @@ def eval_and_save_one_timing( @gin.configurable def run_many_eval_and_save( sample_task_family_cfg_fn: Callable[[PRNGKey], - cfgobject.CFGObject] = gin.REQUIRED, - save_dir: str = gin.REQUIRED, - num_to_run: int = gin.REQUIRED): + cfgobject.CFGObject] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + save_dir: str = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + num_to_run: int = gin.REQUIRED): # pyrefly: ignore[bad-function-definition] """Compute and save `num_to_run` runtime statistics.""" dev = jax.devices()[0] diff --git a/learned_optimization/time_filter/run_train_time_model.py b/learned_optimization/time_filter/run_train_time_model.py index 901a9f13..aa73a035 100644 --- a/learned_optimization/time_filter/run_train_time_model.py +++ b/learned_optimization/time_filter/run_train_time_model.py @@ -140,8 +140,8 @@ def save_model(model: Any, sample_fn_name: str, hardware_name: str, @gin.configurable -def train_and_save_timing_model(sample_fn_name: str = gin.REQUIRED, - hardware_name: str = gin.REQUIRED, +def train_and_save_timing_model(sample_fn_name: str = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + hardware_name: str = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] min_samples: int = 10000, num_train_iterations: int = 1000, test_samples: int = 2000, diff --git a/learned_optimization/time_filter/timings.py b/learned_optimization/time_filter/timings.py index ac50e85c..1e672037 100644 --- a/learned_optimization/time_filter/timings.py +++ b/learned_optimization/time_filter/timings.py @@ -156,7 +156,7 @@ def init(task_param, key): else: opt_state = jax.jit( jax.vmap( - lambda pp, ss: opt.init(pp, ss, num_steps=inner_traj_num_steps)))(p, + lambda pp, ss: opt.init(pp, ss, num_steps=inner_traj_num_steps)))(p, # pyrefly: ignore[missing-attribute] s) def meta_loss(opt_state, task_params, key, datas, theta): @@ -252,7 +252,7 @@ def timing_for_iterator(it: Iterator[Any], break dtimes = onp.diff(times) - return onp.mean(dtimes), onp.std(dtimes) / onp.sqrt(len(dtimes)) + return onp.mean(dtimes), onp.std(dtimes) / onp.sqrt(len(dtimes)) # pyrefly: ignore[bad-return] def task_family_runtime_stats(