diff --git a/learned_optimization/checkpoints.py b/learned_optimization/checkpoints.py index ce47d340..febb3a07 100644 --- a/learned_optimization/checkpoints.py +++ b/learned_optimization/checkpoints.py @@ -160,7 +160,7 @@ def periodically_save_checkpoint( elif step_interval is not None: do_save = (current_iteration % step_interval == 0) - if do_save: + if do_save: # pyrefly: ignore[unbound-name] # if a checkpoint exists already, delete it. # get the last step @@ -184,7 +184,7 @@ def do_save_fn(): path = save_checkpoint(train_log_dir, prefix, value, step, keep=keep) paths[prefix] = path - _last_checkpoint_time[prefix] = time.time() + _last_checkpoint_time[prefix] = time.time() # pyrefly: ignore[unsupported-operation] paths = hk.data_structures.to_immutable_dict(paths) return paths diff --git a/learned_optimization/distributed.py b/learned_optimization/distributed.py index 803b222e..a99d0db3 100644 --- a/learned_optimization/distributed.py +++ b/learned_optimization/distributed.py @@ -127,7 +127,7 @@ def start_server(self): self._server.Start() def _is_step_valid(self, step: int) -> bool: - step = onp.asarray(step) + step = onp.asarray(step) # pyrefly: ignore[bad-assignment] return (self._current_iteration >= step and # pytype: disable=bad-return-type # typed-numpy (self._current_iteration - step) <= self._staleness) @@ -224,7 +224,7 @@ def set_weights(self, """ with self._lock: self._weights = weights - self._current_iteration = onp.asarray(current_iteration) + self._current_iteration = onp.asarray(current_iteration) # pyrefly: ignore[bad-assignment] before = len(self._outer_gradients) @@ -350,7 +350,7 @@ def put_grads(self, worker_id: Any, step: int, value: T): self._lock.acquire(blocking=True) assert worker_id < self._num_workers if step == self._current_iteration: - self._outer_gradients[worker_id] = (step, value) + self._outer_gradients[worker_id] = (step, value) # pyrefly: ignore[unsupported-operation] self._lock.release() self._cv.notify_all() @@ -425,7 +425,7 @@ def set_weights(self, del clear_buffer with self._lock, self._cv: self._weights = weights - self._current_iteration = onp.asarray(current_iteration) + self._current_iteration = onp.asarray(current_iteration) # pyrefly: ignore[bad-assignment] self._outer_gradients = {k: None for k in self._outer_gradients.keys()} self._cv.notify_all() diff --git a/learned_optimization/eval_training.py b/learned_optimization/eval_training.py index a8bda402..74a8d9ab 100644 --- a/learned_optimization/eval_training.py +++ b/learned_optimization/eval_training.py @@ -183,7 +183,7 @@ def single_task_training_curves( opt, opt_state, key1, - task.datasets.split(s) if use_data else (), + task.datasets.split(s) if use_data else (), # pyrefly: ignore[bad-argument-type, missing-attribute] eval_batches if not on_last else last_eval_batches, device=device) m[f"eval/{s}/loss"] = loss @@ -197,7 +197,7 @@ def single_task_training_curves( eval_xs.append(i) with profile.Profile("get_batch"): - batch = next(task.datasets.train) if use_data else () + batch = next(task.datasets.train) if use_data else () # pyrefly: ignore[missing-attribute] with profile.Profile("put_batch_and_split"): batch = jax.device_put(batch, device=device) diff --git a/learned_optimization/optimizers/base.py b/learned_optimization/optimizers/base.py index 738a7ab9..e73c10f1 100644 --- a/learned_optimization/optimizers/base.py +++ b/learned_optimization/optimizers/base.py @@ -130,11 +130,11 @@ def __init__(self, magnitude_opt: Optimizer, direction_opt: Optimizer): self.magnitude_opt = magnitude_opt self.direction_opt = direction_opt - def init(self, params, model_state=None, num_steps=None, **kwargs): + def init(self, params, model_state=None, num_steps=None, **kwargs): # pyrefly: ignore[bad-override] return GraftedOptimizerState( iteration=jnp.asarray(0, dtype=jnp.int32), params=params, - state=model_state, + state=model_state, # pyrefly: ignore[bad-argument-type] mag_opt_state=self.magnitude_opt.init( params, model_state=model_state, num_steps=num_steps, **kwargs), dir_opt_state=self.direction_opt.init( @@ -165,7 +165,7 @@ def update(self, opt_state, grad, model_state=None, **kwargs): # pytype: disabl return GraftedOptimizerState( iteration=opt_state.iteration + 1, params=next_params, - state=model_state, + state=model_state, # pyrefly: ignore[bad-argument-type] mag_opt_state=next_mag_opt_state, dir_opt_state=next_dir_opt_state, ) diff --git a/learned_optimization/optimizers/gradient_accumulator.py b/learned_optimization/optimizers/gradient_accumulator.py index 2059c23c..719ceb3d 100644 --- a/learned_optimization/optimizers/gradient_accumulator.py +++ b/learned_optimization/optimizers/gradient_accumulator.py @@ -62,7 +62,7 @@ def set_params(self, state, params): def get_state(self, state): return state.model_state - def init(self, p, model_state=None, num_steps=None, **kwargs): + def init(self, p, model_state=None, num_steps=None, **kwargs): # pyrefly: ignore[bad-override] if num_steps is not None: rescale_num_steps = num_steps // self.num_average else: @@ -76,7 +76,7 @@ def init(self, p, model_state=None, num_steps=None, **kwargs): grad_accum, loss_accum, inner_opt_state, - model_state=model_state, + model_state=model_state, # pyrefly: ignore[bad-argument-type] iteration=jnp.asarray(0, dtype=jnp.int64)) def update(self, @@ -116,5 +116,5 @@ def do_update(args): new_grad_accum, new_loss_accum, new_inner_opt_state, - model_state=model_state, + model_state=model_state, # pyrefly: ignore[bad-argument-type] iteration=opt_state.iteration + 1) diff --git a/learned_optimization/optimizers/learning_rate_schedules.py b/learned_optimization/optimizers/learning_rate_schedules.py index 32d7487e..ea688341 100644 --- a/learned_optimization/optimizers/learning_rate_schedules.py +++ b/learned_optimization/optimizers/learning_rate_schedules.py @@ -191,4 +191,4 @@ def __init__(self, base_lr: float, decay_amount: float): def __call__(self, global_step: chex.Array, max_steps: Optional[chex.Array] = None) -> chex.Array: - return self.base_lr * (1 - self.decay_amount)**global_step + return self.base_lr * (1 - self.decay_amount)**global_step # pyrefly: ignore[bad-return] diff --git a/learned_optimization/optimizers/nadamw.py b/learned_optimization/optimizers/nadamw.py index b71cda98..66e66c27 100644 --- a/learned_optimization/optimizers/nadamw.py +++ b/learned_optimization/optimizers/nadamw.py @@ -168,7 +168,7 @@ def __init__( "use_bias_correction": use_bias_correction } - def init(self, params, model_state=None, num_steps=None): + def init(self, params, model_state=None, num_steps=None): # pyrefly: ignore[bad-override] return NAdamWState( iteration=jnp.asarray(0, dtype=jnp.int64), params=params, @@ -177,7 +177,7 @@ def init(self, params, model_state=None, num_steps=None): num_steps=jnp.asarray(num_steps, dtype=jnp.int64), state=model_state) - def update(self, + def update(self, # pyrefly: ignore[bad-override] opt_state: NAdamWState, grads: Params, model_state: Optional[ModelState] = None, diff --git a/learned_optimization/optimizers/opt_list.py b/learned_optimization/optimizers/opt_list.py index 4f685882..a8a4c8eb 100644 --- a/learned_optimization/optimizers/opt_list.py +++ b/learned_optimization/optimizers/opt_list.py @@ -90,7 +90,7 @@ def __init__(self, idx): # We write init and update by constructing new instances of NAdamW to allow # for vmap-ing over different idx and to prevent jax tracer leaks. - def init(self, + def init(self, # pyrefly: ignore[bad-override] params: Params, model_state: Optional[ModelState] = None, *, @@ -98,7 +98,7 @@ def init(self, return nadamw.NAdamW(**_get_optimizer_config(self.idx)).init( params, model_state, num_steps=num_steps) - def update(self, + def update(self, # pyrefly: ignore[bad-override] opt_state: nadamw.NAdamWState, grads: Params, model_state: Optional[ModelState] = None, diff --git a/learned_optimization/optimizers/opt_to_optax.py b/learned_optimization/optimizers/opt_to_optax.py index 410a0876..ffe8ccbd 100644 --- a/learned_optimization/optimizers/opt_to_optax.py +++ b/learned_optimization/optimizers/opt_to_optax.py @@ -76,4 +76,4 @@ def update_fn( return step, next_state - return optax.GradientTransformationExtraArgs(init_fn, update_fn) + return optax.GradientTransformationExtraArgs(init_fn, update_fn) # pyrefly: ignore[bad-argument-type] diff --git a/learned_optimization/optimizers/optax_opts.py b/learned_optimization/optimizers/optax_opts.py index 6578b900..a23fb452 100644 --- a/learned_optimization/optimizers/optax_opts.py +++ b/learned_optimization/optimizers/optax_opts.py @@ -57,9 +57,9 @@ def init(self, key: Optional[chex.PRNGKey] = None): return OptaxState( # pytype: disable=wrong-arg-types # jax-ndarray params=params, - optax_opt_state=self.opt.init(params), - state=model_state, - iteration=0, + optax_opt_state=self.opt.init(params), # pyrefly: ignore[bad-argument-type] + state=model_state, # pyrefly: ignore[bad-argument-type] + iteration=0, # pyrefly: ignore[bad-argument-type] ) @functools.partial(jax.jit, static_argnums=(0,)) @@ -74,9 +74,9 @@ def update(self, update, new_opt_state = self.opt.update(grad, opt_state.optax_opt_state, opt_state.params) return OptaxState( - state=model_state, - params=optax.apply_updates(opt_state.params, update), - optax_opt_state=new_opt_state, + state=model_state, # pyrefly: ignore[bad-argument-type] + params=optax.apply_updates(opt_state.params, update), # pyrefly: ignore[bad-argument-type] + optax_opt_state=new_opt_state, # pyrefly: ignore[bad-argument-type] iteration=opt_state.iteration + 1, ) @@ -143,8 +143,8 @@ def name(self): def piecewise_linear(times: Sequence[float], vals: Sequence[float]) -> Callable[[float], float]: """Returns a function which interpolates piecewise values.""" - times = jnp.asarray(times) - vals = jnp.asarray(vals) + times = jnp.asarray(times) # pyrefly: ignore[bad-assignment] + vals = jnp.asarray(vals) # pyrefly: ignore[bad-assignment] def fn(x): if len(times) <= 1: diff --git a/learned_optimization/optimizers/optimizer_wrappers.py b/learned_optimization/optimizers/optimizer_wrappers.py index 224bacc5..31689bee 100644 --- a/learned_optimization/optimizers/optimizer_wrappers.py +++ b/learned_optimization/optimizers/optimizer_wrappers.py @@ -36,13 +36,13 @@ def __init__(self, opt, warp_fn): self._opt = opt self._warp_fn = warp_fn - def init(self, params, model_state=None, *, num_steps): + def init(self, params, model_state=None, *, num_steps): # pyrefly: ignore[bad-override] num_steps = jnp.asarray(self._warp_fn(num_steps), jnp.int32) inner_opt_state = self._opt.init(params, model_state, num_steps=num_steps) return ExtendTimeState(jnp.asarray(0, jnp.int32), inner_opt_state) - def update(self, opt_state, grad, loss=None, **kwargs): + def update(self, opt_state, grad, loss=None, **kwargs): # pyrefly: ignore[bad-override] inner_state = opt_state.inner_opt_state inner_state = inner_state.replace( iteration=self._warp_fn(opt_state.iteration)) @@ -77,7 +77,7 @@ def set_params(self, state, params): def get_state(self, opt_state): return self.opt.get_state(opt_state) - def init(self, params, model_state=None, **kwargs): + def init(self, params, model_state=None, **kwargs): # pyrefly: ignore[bad-override] return self.opt.init(params, model_state=model_state, **kwargs) def update(self, opt_state, grads, model_state=None, loss=None, **kwargs): @@ -85,7 +85,7 @@ def update(self, opt_state, grads, model_state=None, loss=None, **kwargs): if self.add_to_loss: l2 = [jnp.sum(p**2) for p in jax.tree_util.tree_leaves(ps)] - loss = loss + sum([x * self.weight_decay for x in l2]) + loss = loss + sum([x * self.weight_decay for x in l2]) # pyrefly: ignore[unsupported-operation] grad_l2 = jax.tree_util.tree_map(lambda p: self.weight_decay * p, ps) grads = jax.tree_util.tree_map(lambda g, g_l2: g + g_l2, grads, grad_l2) diff --git a/learned_optimization/outer_train.py b/learned_optimization/outer_train.py index 6eddae10..30eaaa41 100644 --- a/learned_optimization/outer_train.py +++ b/learned_optimization/outer_train.py @@ -75,15 +75,15 @@ def iter_group_amount(it, n): @gin.configurable def build_gradient_estimators( *, - learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED, + learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] sample_task_family_fn: Callable[[PRNGKey], - tasks_base.TaskFamily] = gin.REQUIRED, + tasks_base.TaskFamily] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] gradient_estimator_fn: Callable[ [truncated_step_mod.VectorizedTruncatedStep], - gradient_learner.GradientLearner] = gin.REQUIRED, + gradient_learner.GradientLearner] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] truncated_step_fn: Callable[ [tasks_base.TaskFamily, lopt_base.LearnedOptimizer], - truncated_step_mod.VectorizedTruncatedStep] = lopt_truncated_step + truncated_step_mod.VectorizedTruncatedStep] = lopt_truncated_step # pyrefly: ignore[bad-function-definition] .VectorizedLOptTruncatedStep, key: PRNGKey, num_gradient_estimators: int, @@ -129,14 +129,14 @@ def build_gradient_estimators( @gin.configurable def build_gradient_estimators_fixed( *, - learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED, - list_of_task_family_per_machine: ListListTaskFamilyFn = gin.REQUIRED, + learned_opt: lopt_base.LearnedOptimizer = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] + list_of_task_family_per_machine: ListListTaskFamilyFn = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] gradient_estimator_fn: Callable[ [truncated_step_mod.VectorizedTruncatedStep], - gradient_learner.GradientLearner] = gin.REQUIRED, + gradient_learner.GradientLearner] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] truncated_step_fn: Callable[ [tasks_base.TaskFamily, lopt_base.LearnedOptimizer], - truncated_step_mod.VectorizedTruncatedStep] = lopt_truncated_step + truncated_step_mod.VectorizedTruncatedStep] = lopt_truncated_step # pyrefly: ignore[bad-function-definition] .VectorizedLOptTruncatedStep, key: PRNGKey, num_gradient_estimators: int, @@ -285,7 +285,7 @@ def maybe_resample_gradient_estimators( worker_id=worker_id) gradient_estimators[j] = ests[0] unroll_states[j] = gradient_estimators[j].init_worker_state( - worker_weights, key2) + worker_weights, key2) # pyrefly: ignore[bad-argument-type] return gradient_estimators, unroll_states @@ -350,7 +350,7 @@ def build_static_and_init_unroll_state( return estimators, unroll_states distributed_worker = distributed.DistributedWorker( - train_log_dir, worker_id, learner_address=learner_address) + train_log_dir, worker_id, learner_address=learner_address) # pyrefly: ignore[bad-argument-type] last_outer_cfg = None grad_estimators = None worker_weights = None @@ -390,12 +390,12 @@ def build_static_and_init_unroll_state( gradient_worker_out = gradient_learner.gradient_worker_compute( worker_weights=worker_weights, gradient_estimators=[grad_estimators[gidx] for gidx in idxs], - unroll_states=[unroll_states[gidx] for gidx in idxs], + unroll_states=[unroll_states[gidx] for gidx in idxs], # pyrefly: ignore[unbound-name] key=next(rng), with_metrics=with_m, device=device) - unroll_states = list(unroll_states) + unroll_states = list(unroll_states) # pyrefly: ignore[unbound-name] for oidx, gidx in enumerate(idxs): unroll_states[gidx] = gradient_worker_out.unroll_states[oidx] @@ -405,8 +405,8 @@ def build_static_and_init_unroll_state( with profile.Profile("grads_to_onp"): to_put_grads = GradientsFromWorker( # pytype: disable=wrong-arg-types # jax-ndarray metrics=gradient_worker_out.metrics, - worker_id=worker_id, - total_inner_steps=total_inner_steps, + worker_id=worker_id, # pyrefly: ignore[bad-argument-type] + total_inner_steps=total_inner_steps, # pyrefly: ignore[bad-argument-type] gen_id=dist_data.gen_id, outer_trainer_grads=gradient_worker_out.to_put, ) @@ -683,7 +683,7 @@ def _load_checkpoint(checkpoint_path): experiment_name=train_log_dir, weights=DataForWorker(worker_weights, gen_id, outer_cfg), current_iteration=step, - num_workers=num_workers, + num_workers=num_workers, # pyrefly: ignore[bad-argument-type] start_server=False, port=learner_port) @@ -737,7 +737,7 @@ def _load_checkpoint(checkpoint_path): with profile.Profile("checkpoints"): opt_checkpoint = gradient_learner.OptCheckpoint( gradient_learner_state, jnp.asarray(elapsed_time, dtype=jnp.float64), - total_inner_steps) + total_inner_steps) # pyrefly: ignore[bad-argument-type] param_checkpoint = gradient_learner.ParameterCheckpoint( outer_learner.get_meta_params(gradient_learner_state), gen_id, step) paths = checkpoints.periodically_save_checkpoint( @@ -818,7 +818,7 @@ def _load_checkpoint(checkpoint_path): delta_inner_steps=applied_inner_steps, ) - to_write = dict(**to_write, **summarize_outer_cfg(outer_cfg)) + to_write = dict(**to_write, **summarize_outer_cfg(outer_cfg)) # pyrefly: ignore[bad-argument-type] if i % 5 == 0: elapsed_time = elapsed_time + time.time() - train_start_time @@ -998,7 +998,7 @@ def build_static_and_init_unroll_state(worker_weights, key): to_put_grads = GradientsFromWorker( # pytype: disable=wrong-arg-types # jax-ndarray metrics=gradient_worker_out.metrics, worker_id=0, - total_inner_steps=total_inner_steps, + total_inner_steps=total_inner_steps, # pyrefly: ignore[bad-argument-type] gen_id="no_gen_id", outer_trainer_grads=gradient_worker_out.to_put, ) @@ -1021,7 +1021,7 @@ def build_static_and_init_unroll_state(worker_weights, key): lopt, gradient_estimators, unroll_states, - worker_weights=worker_weights, + worker_weights=worker_weights, # pyrefly: ignore[bad-argument-type] key=next(rng), stochastic_resample_frequency=stochastic_resample_frequency, sample_estimators_fn=sample_estimators_fn, @@ -1115,9 +1115,9 @@ def _move_all_gin_config_to_default_scope(): @gin.configurable def run_train( train_log_dir: str, - lopt: Union[GinRequired, lopt_base.LearnedOptimizer] = gin.REQUIRED, + lopt: Union[GinRequired, lopt_base.LearnedOptimizer] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] outer_learner_fn: Union[GinRequired, Callable[ - [], gradient_learner.GradientLearner]] = gin.REQUIRED, + [], gradient_learner.GradientLearner]] = gin.REQUIRED, # pyrefly: ignore[bad-function-definition] num_estimators: int = 2, is_trainer: bool = True, is_worker: bool = True, @@ -1182,7 +1182,7 @@ def run_train( if is_trainer and is_worker: local_train( train_log_dir=train_log_dir, - outer_learner=outer_learner_fn(), + outer_learner=outer_learner_fn(), # pyrefly: ignore[not-callable] lopt=lopt, num_estimators=num_estimators, summary_every_n=summary_every_n, @@ -1222,7 +1222,7 @@ def run_train( try: train_worker( worker_id=worker_id, - lopt=lopt, + lopt=lopt, # pyrefly: ignore[bad-argument-type] num_estimators=num_estimators, summary_every_n=summary_every_n, stochastic_resample_frequency=stochastic_resample_frequency, diff --git a/learned_optimization/outer_trainers/common.py b/learned_optimization/outer_trainers/common.py index 99313a0b..c0d1ec16 100644 --- a/learned_optimization/outer_trainers/common.py +++ b/learned_optimization/outer_trainers/common.py @@ -88,7 +88,7 @@ def _stack(a, b, axis=0): static_argnames=("truncated_step", "with_summary", "unroll_length", "theta_is_vector", "wrap_step_fn"), ) -@functools.partial(summary.add_with_summary, static_argnums=(0, 1, 2, 3, 9)) +@functools.partial(summary.add_with_summary, static_argnums=(0, 1, 2, 3, 9)) # pyrefly: ignore[bad-specialization] def truncated_unroll( truncated_step: truncated_step_mod.VectorizedTruncatedStep, unroll_length: int, diff --git a/learned_optimization/outer_trainers/full_es.py b/learned_optimization/outer_trainers/full_es.py index 198ecf6d..aea4104c 100644 --- a/learned_optimization/outer_trainers/full_es.py +++ b/learned_optimization/outer_trainers/full_es.py @@ -419,7 +419,7 @@ def compute_gradient_estimate( if datas_list is not None: raise NotImplementedError() - num_tasks = self.truncated_step.num_tasks + num_tasks = self.truncated_step.num_tasks # pyrefly: ignore[missing-attribute] rng = hk.PRNGSequence(key) theta = worker_weights.theta @@ -444,15 +444,15 @@ def compute_gradient_estimate( vec_p_theta, worker_weights.outer_state, key, - theta_is_vector=True, - num_steps_override=length) + theta_is_vector=True, # pyrefly: ignore[unexpected-keyword] + num_steps_override=length) # pyrefly: ignore[unexpected-keyword] n_state = self.truncated_step.init_step_state( vec_n_theta, worker_weights.outer_state, key, - theta_is_vector=True, - num_steps_override=length) + theta_is_vector=True, # pyrefly: ignore[unexpected-keyword] + num_steps_override=length) # pyrefly: ignore[unexpected-keyword] if not hasattr(trunc_state, "length"): raise AttributeError("Please specify a truncation schedule whose state" @@ -470,7 +470,7 @@ def compute_gradient_estimate( p_state, n_state = tree_utils.strip_weak_type((p_state, n_state)) p_state, n_state, p_ys, n_ys, m = common.maybe_stacked_es_unroll( - self.truncated_step, + self.truncated_step, # pyrefly: ignore[bad-argument-type] self.steps_per_jit, self.stack_antithetic_samples, vec_p_theta, @@ -587,7 +587,7 @@ def __init__(self, functools.partial( common.vector_sample_perturbations, std=self.std, - num_samples=self.truncated_step.num_tasks), + num_samples=self.truncated_step.num_tasks), # pyrefly: ignore[missing-attribute] in_axes=(None, 0), ) @@ -596,8 +596,8 @@ def init(theta, outer_state, key, override): theta, outer_state, key, - theta_is_vector=True, - num_steps_override=override) + theta_is_vector=True, # pyrefly: ignore[unexpected-keyword] + num_steps_override=override) # pyrefly: ignore[unexpected-keyword] self.pmap_init_step_state = jax.pmap(init, in_axes=(0, None, 0, None)) @@ -613,7 +613,7 @@ def pmap_maybe_stacked_es_unroll(self, with_summary, vec_p_theta, vec_n_theta, length): key1, key2 = jax.random.split(key) p_state, n_state, p_ys, n_ys, m = common.maybe_stacked_es_unroll( - self.truncated_step, + self.truncated_step, # pyrefly: ignore[bad-argument-type] self.steps_per_jit, self.stack_antithetic_samples, vec_p_theta, diff --git a/learned_optimization/outer_trainers/gradient_learner.py b/learned_optimization/outer_trainers/gradient_learner.py index a259aa3c..850986e0 100644 --- a/learned_optimization/outer_trainers/gradient_learner.py +++ b/learned_optimization/outer_trainers/gradient_learner.py @@ -266,7 +266,7 @@ def update( min_loss = jnp.min(losses) fn = _get_theta_update_fn(self._theta_opt) - key1, key2 = jax.random.split(key) + key1, key2 = jax.random.split(key) # pyrefly: ignore[bad-argument-type] theta_opt_state, theta_update_metrics = fn( theta_opt_state, grads, @@ -432,7 +432,7 @@ def extract_one(idx, x): "loss": estimator_out.unroll_info.loss[idx, :], "task_param": jax.tree_util.tree_map(fn, onp_task_params), "iteration": iteration, - "outer_iteration": worker_weights.outer_state.outer_iteration, + "outer_iteration": worker_weights.outer_state.outer_iteration, # pyrefly: ignore[missing-attribute] }) else: logging.warn("No out specified by learner. " @@ -593,7 +593,7 @@ def update( next_theta_state, metrics = self.gradient_learner.update( state.gradient_learner_state, [worker_compute_out.to_put], key=key2, - with_metrics=with_metrics) + with_metrics=with_metrics) # pyrefly: ignore[bad-argument-type] metrics = summary.aggregate_metric_list( [worker_compute_out.metrics, metrics]) diff --git a/learned_optimization/outer_trainers/lopt_truncated_step.py b/learned_optimization/outer_trainers/lopt_truncated_step.py index 6ce84278..47f66b6b 100644 --- a/learned_optimization/outer_trainers/lopt_truncated_step.py +++ b/learned_optimization/outer_trainers/lopt_truncated_step.py @@ -85,10 +85,10 @@ def train(unroll_state): unroll_state = opt.update(unroll_state, grad, loss=loss) out = truncated_step.TruncatedUnrollOut( # pytype: disable=wrong-arg-types # jax-ndarray loss=loss, - is_done=False, + is_done=False, # pyrefly: ignore[bad-argument-type] task_param=None, iteration=unroll_state.iteration, - mask=True, + mask=True, # pyrefly: ignore[bad-argument-type] ) return unroll_state, out @@ -96,11 +96,11 @@ def reset(unroll_state): params = self.task.init(key) unroll_state = self.lopt.opt_fn(theta).init(params) out = truncated_step.TruncatedUnrollOut( # pytype: disable=wrong-arg-types # jax-ndarray - loss=0.0, - is_done=True, + loss=0.0, # pyrefly: ignore[bad-argument-type] + is_done=True, # pyrefly: ignore[bad-argument-type] task_param=None, iteration=unroll_state.iteration, - mask=False, + mask=False, # pyrefly: ignore[bad-argument-type] ) return unroll_state, out @@ -201,7 +201,7 @@ def _init_truncation_state( inner_step=jnp.asarray(0, dtype=jnp.int32), truncation_state=trunc_state, task_param=task_param, - is_done=False, + is_done=False, # pyrefly: ignore[bad-argument-type] ) diff --git a/learned_optimization/outer_trainers/truncated_es.py b/learned_optimization/outer_trainers/truncated_es.py index d54e0e28..a0fa3d76 100644 --- a/learned_optimization/outer_trainers/truncated_es.py +++ b/learned_optimization/outer_trainers/truncated_es.py @@ -142,7 +142,7 @@ def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights, worker_weights.theta, worker_weights.outer_state, key, - theta_is_vector=False) + theta_is_vector=False) # pyrefly: ignore[unexpected-keyword] @profile.wrap() def get_datas(self): @@ -195,7 +195,7 @@ def compute_gradient_estimate( key = next(rng) p_state, n_state, p_ys, n_ys, m = common.maybe_stacked_es_unroll( - self.truncated_step, + self.truncated_step, # pyrefly: ignore[bad-argument-type] self.steps_per_jit, self.stack_antithetic_samples, vec_p_theta, diff --git a/learned_optimization/outer_trainers/truncated_grad.py b/learned_optimization/outer_trainers/truncated_grad.py index 2f46ab5d..abebf1a8 100644 --- a/learned_optimization/outer_trainers/truncated_grad.py +++ b/learned_optimization/outer_trainers/truncated_grad.py @@ -90,7 +90,7 @@ def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights, worker_weights.theta, worker_weights.outer_state, key, - theta_is_vector=False) + theta_is_vector=False) # pyrefly: ignore[unexpected-keyword] @profile.wrap() def get_datas(self): @@ -145,7 +145,7 @@ def flat_first(x): ys = jax.tree_util.tree_map(flat_first, tree_utils.tree_zip_jnp(outputs)) assert ys.loss.shape == (self.unroll_length, - self.truncated_step.num_tasks) + self.truncated_step.num_tasks) # pyrefly: ignore[missing-attribute] vec_mean_loss = jnp.sum( ys.mask * ys.loss, axis=0) / jnp.sum( diff --git a/learned_optimization/outer_trainers/truncated_pes.py b/learned_optimization/outer_trainers/truncated_pes.py index 02b2c8e5..fe5b2c80 100644 --- a/learned_optimization/outer_trainers/truncated_pes.py +++ b/learned_optimization/outer_trainers/truncated_pes.py @@ -130,7 +130,7 @@ def _switch_one_accum(a, b): pos_loss = jnp.sum(p_ys.loss * p_ys.mask, axis=0) / jnp.sum(p_ys.mask, axis=0) neg_loss = jnp.sum(n_ys.loss * n_ys.mask, axis=0) / jnp.sum(n_ys.mask, axis=0) - return ( + return ( # pyrefly: ignore[bad-return] jnp.mean((pos_loss + neg_loss) / 2.0), es_grad, new_accumulator, @@ -182,11 +182,11 @@ def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights, theta = worker_weights.theta pos_unroll_state = self.truncated_step.init_step_state( - theta, worker_weights.outer_state, key, theta_is_vector=False) + theta, worker_weights.outer_state, key, theta_is_vector=False) # pyrefly: ignore[unexpected-keyword] neg_unroll_state = pos_unroll_state accumulator = jax.tree_util.tree_map( - lambda x: jnp.zeros([self.truncated_step.num_tasks] + list(x.shape)), + lambda x: jnp.zeros([self.truncated_step.num_tasks] + list(x.shape)), # pyrefly: ignore[missing-attribute] theta) return PESWorkerState( @@ -218,7 +218,7 @@ def compute_gradient_estimate( # pytype: disable=signature-mismatch # overridi theta = worker_weights.theta vec_pos, vec_p_theta, vec_n_theta = common.vector_sample_perturbations( - theta, next(rng), self.std, self.truncated_step.num_tasks) + theta, next(rng), self.std, self.truncated_step.num_tasks) # pyrefly: ignore[missing-attribute] p_yses = [] n_yses = [] @@ -243,7 +243,7 @@ def compute_gradient_estimate( # pytype: disable=signature-mismatch # overridi key = next(rng) p_state, n_state, p_ys, n_ys, m = common.maybe_stacked_es_unroll( - self.truncated_step, + self.truncated_step, # pyrefly: ignore[bad-argument-type] self.steps_per_jit, self.stack_antithetic_samples, vec_p_theta, @@ -335,7 +335,7 @@ def __init__(self, functools.partial( common.vector_sample_perturbations, std=self.std, - num_samples=self.truncated_step.num_tasks), + num_samples=self.truncated_step.num_tasks), # pyrefly: ignore[missing-attribute] in_axes=(None, 0), ) @@ -380,7 +380,7 @@ def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights, neg_unroll_state = pos_unroll_state accumulator = jax.tree_util.tree_map( - lambda x: jnp.zeros([self.truncated_step.num_tasks] + list(x.shape)), + lambda x: jnp.zeros([self.truncated_step.num_tasks] + list(x.shape)), # pyrefly: ignore[missing-attribute] theta) accumulator = flax_jax_utils.replicate(accumulator) diff --git a/learned_optimization/outer_trainers/truncation_schedule.py b/learned_optimization/outer_trainers/truncation_schedule.py index ee4a271f..62e9682c 100644 --- a/learned_optimization/outer_trainers/truncation_schedule.py +++ b/learned_optimization/outer_trainers/truncation_schedule.py @@ -137,7 +137,7 @@ def init(self, key: PRNGKey, outer_state: Any) -> ConstantTruncationState: shift = jnp.asarray(jax.random.normal(key) * self.std, dtype=jnp.int32) length = jnp.maximum(length + shift, self.min_length) - length = summary.summary("length", length) + length = summary.summary("length", length) # pyrefly: ignore[bad-argument-type] return ConstantTruncationState(length=jnp.asarray(length, dtype=jnp.int32)) def next_state( @@ -147,4 +147,4 @@ def next_state( is_done = step >= state.length state = lax.cond(is_done, lambda ss: self.init(*ss), lambda ss: state, (key, outer_state)) - return state, is_done + return state, is_done # pyrefly: ignore[bad-return] diff --git a/learned_optimization/profile.py b/learned_optimization/profile.py index 494992fa..f22c4d5a 100644 --- a/learned_optimization/profile.py +++ b/learned_optimization/profile.py @@ -55,12 +55,12 @@ def wrap(): def _wrapper(fn: T) -> T: - @functools.wraps(fn) + @functools.wraps(fn) # pyrefly: ignore[bad-argument-type] def _fn(*args, **kwargs): - with Profile(fn.__name__): - return fn(*args, **kwargs) + with Profile(fn.__name__): # pyrefly: ignore[missing-attribute] + return fn(*args, **kwargs) # pyrefly: ignore[not-callable] - return _fn + return _fn # pyrefly: ignore[bad-return] return _wrapper diff --git a/learned_optimization/setup_experiment.py b/learned_optimization/setup_experiment.py index cc9482e7..c78f1be4 100644 --- a/learned_optimization/setup_experiment.py +++ b/learned_optimization/setup_experiment.py @@ -63,7 +63,7 @@ def parse_and_set_gin_config(finalize: bool, skip_unknown: bool): assert imp.endswith(".*") prefix = imp[0:-2] path = importlib.import_module(prefix).__file__ - for p in glob.glob(os.path.join(os.path.dirname(path), "*.py")): + for p in glob.glob(os.path.join(os.path.dirname(path), "*.py")): # pyrefly: ignore[no-matching-overload] p = p.split("/")[-1].replace(".py", "") to_import = prefix + "." + p logging.info("Gin is importing %s from glob", to_import) diff --git a/learned_optimization/summary.py b/learned_optimization/summary.py index 5bc51919..355e406c 100644 --- a/learned_optimization/summary.py +++ b/learned_optimization/summary.py @@ -83,9 +83,9 @@ def _fn(*args, **kwargs): with summary_scope(name): return to_wrap(*args, **kwargs) - return _fn + return _fn # pyrefly: ignore[bad-return] - return ff + return ff # pyrefly: ignore[bad-return] count_per_tags = {} @@ -190,14 +190,14 @@ def aggregate_metric(k: str, assert "||" in k, f"bad summary -- {k}" agg, _ = k.split("||") # summaries don't have to be the same length. lets ensure there all xnp though - vs = [xnp.asarray(v) for v in vs] + vs = [xnp.asarray(v) for v in vs] # pyrefly: ignore[bad-assignment] if agg == AggregationType.mean: # size is known at compile time. size = onp.sum([onp.prod(v.shape) for v in vs]) return xnp.sum(xnp.asarray([xnp.sum(v) / size for v in vs])) # pytype: disable=bad-return-type # jnp-type elif agg == AggregationType.sample: - vs = xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0) + vs = xnp.concatenate([xnp.asarray(v).ravel() for v in vs], axis=0) # pyrefly: ignore[bad-assignment] if use_jnp: assert key is not None i = jax.random.randint(key, [], 0, len(vs)) @@ -307,7 +307,7 @@ def out_fn(unused_in, *args): return outs, metrics - return _fn + return _fn # pyrefly: ignore[bad-return] def add_with_summary(fn: F, static_argnums=()) -> G: @@ -351,8 +351,8 @@ def _fn(*args, **kwargs): params.append(inspect.Parameter( "with_summary", inspect.Parameter.KEYWORD_ONLY, default=False)) sig = sig.replace(parameters=tuple(params)) - _fn.__signature__ = sig - return _fn + _fn.__signature__ = sig # pyrefly: ignore[missing-attribute] + return _fn # pyrefly: ignore[bad-return] def tree_scalar_mean(prefix, values): diff --git a/learned_optimization/tasks/fixed/conv.py b/learned_optimization/tasks/fixed/conv.py index 3ffbc46d..8c79f7be 100644 --- a/learned_optimization/tasks/fixed/conv.py +++ b/learned_optimization/tasks/fixed/conv.py @@ -39,7 +39,7 @@ def _cross_entropy_pool_loss( num_classes: int = 10): """Haiku function for a conv net with pooling and cross entropy loss.""" if not initializers: - initializers = {} + initializers = {} # pyrefly: ignore[bad-assignment] def _fn(batch): net = batch["image"] @@ -80,7 +80,7 @@ def init(self, key) -> Params: def init_with_state(self, key) -> Tuple[Params, ModelState]: batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] return self._mod.init(key, batch) def loss(self, params, key, data): @@ -97,7 +97,7 @@ def loss_with_state_and_aux(self, params, state, key, data): def normalizer(self, loss): return jnp.clip(loss, 0, - 1.5 * jnp.log(self.datasets.extra_info["num_classes"])) + 1.5 * jnp.log(self.datasets.extra_info["num_classes"])) # pyrefly: ignore[missing-attribute] @gin.configurable diff --git a/learned_optimization/tasks/fixed/image_mlp.py b/learned_optimization/tasks/fixed/image_mlp.py index b33aa29c..0659a9fe 100644 --- a/learned_optimization/tasks/fixed/image_mlp.py +++ b/learned_optimization/tasks/fixed/image_mlp.py @@ -54,18 +54,18 @@ def _forward(inp): def init(self, key: PRNGKey) -> Any: batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] return self._mod.init(key, batch["image"]) def loss(self, params: Params, key: PRNGKey, data: Any) -> jnp.ndarray: # pytype: disable=signature-mismatch # jax-ndarray - num_classes = self.datasets.extra_info["num_classes"] + num_classes = self.datasets.extra_info["num_classes"] # pyrefly: ignore[missing-attribute] logits = self._mod.apply(params, key, data["image"]) labels = jax.nn.one_hot(data["label"], num_classes) vec_loss = base.softmax_cross_entropy(logits=logits, labels=labels) return jnp.mean(vec_loss) def normalizer(self, loss): - num_classes = self.datasets.extra_info["num_classes"] + num_classes = self.datasets.extra_info["num_classes"] # pyrefly: ignore[missing-attribute] maxval = 1.5 * onp.log(num_classes) loss = jnp.clip(loss, 0, maxval) return jnp.nan_to_num(loss, nan=maxval, posinf=maxval, neginf=maxval) @@ -197,7 +197,7 @@ class _MLPImageTaskMSE(_MLPImageTask): """Image model with a Mean squared error loss.""" def loss(self, params: Params, key: PRNGKey, data: Any) -> jnp.ndarray: - num_classes = self.datasets.extra_info["num_classes"] + num_classes = self.datasets.extra_info["num_classes"] # pyrefly: ignore[missing-attribute] logits = self._mod.apply(params, key, data["image"]) labels = jax.nn.one_hot(data["label"], num_classes) return jnp.mean(jnp.square(logits - labels)) @@ -270,13 +270,13 @@ def _forward(inp): def init_with_state(self, key: PRNGKey) -> Any: batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] params, state = self._mod.init(key, batch["image"]) return params, state def loss_with_state(self, params: Params, state: ModelState, key: PRNGKey, data: Any) -> Tuple[jnp.ndarray, ModelState]: - num_classes = self.datasets.extra_info["num_classes"] + num_classes = self.datasets.extra_info["num_classes"] # pyrefly: ignore[missing-attribute] logits, state = self._mod.apply(params, state, key, data["image"]) labels = jax.nn.one_hot(data["label"], num_classes) vec_loss = base.softmax_cross_entropy(logits=logits, labels=labels) @@ -289,7 +289,7 @@ def loss_with_state_and_aux( return loss, state, {} def normalizer(self, loss): - num_classes = self.datasets.extra_info["num_classes"] + num_classes = self.datasets.extra_info["num_classes"] # pyrefly: ignore[missing-attribute] maxval = 1.5 * onp.log(num_classes) loss = jnp.clip(loss, 0, maxval) return jnp.nan_to_num(loss, nan=maxval, posinf=maxval, neginf=maxval) diff --git a/learned_optimization/tasks/fixed/lopt.py b/learned_optimization/tasks/fixed/lopt.py index 61571c8e..ac51f953 100644 --- a/learned_optimization/tasks/fixed/lopt.py +++ b/learned_optimization/tasks/fixed/lopt.py @@ -297,7 +297,7 @@ def loss(self, params, key, data): raise ValueError("Use loss_with_state instead!") def loss_with_state_and_aux(self, params, model_state, key, datas): - l, s = self.loss_with_state(params, model_state, key, datas) + l, s = self.loss_with_state(params, model_state, key, datas) # pyrefly: ignore[bad-argument-type] return l, s, {} diff --git a/learned_optimization/tasks/fixed/mlp_mixer.py b/learned_optimization/tasks/fixed/mlp_mixer.py index ac53eccc..12708e47 100644 --- a/learned_optimization/tasks/fixed/mlp_mixer.py +++ b/learned_optimization/tasks/fixed/mlp_mixer.py @@ -40,7 +40,7 @@ def __init__(self, cfg, datasets): def init(self, key: chex.PRNGKey): batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] return self.flax_module.init({ "params": key, "dropout": key @@ -51,13 +51,13 @@ def init(self, key: chex.PRNGKey): def loss(self, params: Any, key: chex.PRNGKey, data: Any): logits = self.flax_module.apply( params, data["image"], train=True, rngs={"dropout": key}) - labels_onehot = jax.nn.one_hot(data["label"], logits.shape[1]) - loss_vec = base.softmax_cross_entropy(logits=logits, labels=labels_onehot) + labels_onehot = jax.nn.one_hot(data["label"], logits.shape[1]) # pyrefly: ignore[missing-attribute] + loss_vec = base.softmax_cross_entropy(logits=logits, labels=labels_onehot) # pyrefly: ignore[bad-argument-type] return jnp.mean(loss_vec) def normalizer(self, loss): # TODO(lmetz) This normalizer is shared a great many places. De-dup! - max_class = onp.log(2 * self.datasets.extra_info["num_classes"]) + max_class = onp.log(2 * self.datasets.extra_info["num_classes"]) # pyrefly: ignore[missing-attribute] loss = jnp.nan_to_num( loss, nan=max_class, neginf=max_class, posinf=max_class) # shift to [0, 10] then clip. diff --git a/learned_optimization/tasks/fixed/resnet.py b/learned_optimization/tasks/fixed/resnet.py index 0004d85f..13cb6106 100644 --- a/learned_optimization/tasks/fixed/resnet.py +++ b/learned_optimization/tasks/fixed/resnet.py @@ -51,7 +51,7 @@ def _hk_forward(self, batch): 'initial_conv_kernel_size', 'initial_conv_stride', 'max_pool', 'resnet_v2' ] - num_classes = self.datasets.extra_info['num_classes'] + num_classes = self.datasets.extra_info['num_classes'] # pyrefly: ignore[missing-attribute] mod = resnet.ResNet( num_classes=num_classes, **{k: self._cfg[k] for k in args}) logits = mod(batch['image'], is_training=True) @@ -61,7 +61,7 @@ def _hk_forward(self, batch): def init_with_state(self, key: chex.PRNGKey) -> base.Params: batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] return self._net.init(key, batch) def loss_with_state(self, params, state, key, data): diff --git a/learned_optimization/tasks/fixed/rnn_lm.py b/learned_optimization/tasks/fixed/rnn_lm.py index c9afb2b6..6acc4ec4 100644 --- a/learned_optimization/tasks/fixed/rnn_lm.py +++ b/learned_optimization/tasks/fixed/rnn_lm.py @@ -78,14 +78,14 @@ def get_param_like(name: str, val: jnp.ndarray) -> jnp.ndarray: def init(self, key: PRNGKey) -> base.Params: batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] return self._mod.init(key, batch["obs"]) def loss(self, params: Params, key: PRNGKey, data: Any) -> jnp.ndarray: # pytype: disable=signature-mismatch # jax-ndarray obs = data["obs"] target = data["target"] - max_vocab_size = self.datasets.extra_info["vocab_size"] + max_vocab_size = self.datasets.extra_info["vocab_size"] # pyrefly: ignore[missing-attribute] vocab_size = self._vocab_size if vocab_size < max_vocab_size: # if the target vocab is smaller, we use a mod to keep all diff --git a/learned_optimization/tasks/fixed/transformer_lm.py b/learned_optimization/tasks/fixed/transformer_lm.py index 247e0c52..7a0955f4 100644 --- a/learned_optimization/tasks/fixed/transformer_lm.py +++ b/learned_optimization/tasks/fixed/transformer_lm.py @@ -42,7 +42,7 @@ def name(self): return self._name def _hk_forward(self, batch): - vocab_size = self.datasets.extra_info['vocab_size'] + vocab_size = self.datasets.extra_info['vocab_size'] # pyrefly: ignore[missing-attribute] mod = transformer.Transformer( num_heads=self._cfg['num_heads'], num_layers=self._cfg['num_layers'], @@ -57,7 +57,7 @@ def _hk_forward(self, batch): def init(self, key: chex.PRNGKey) -> base.Params: batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] return self._net.init(key, batch) def loss(self, params, key, data): diff --git a/learned_optimization/tasks/fixed/vit.py b/learned_optimization/tasks/fixed/vit.py index 10704dfc..606373cd 100644 --- a/learned_optimization/tasks/fixed/vit.py +++ b/learned_optimization/tasks/fixed/vit.py @@ -40,7 +40,7 @@ def __init__(self, cfg, datasets): def init(self, key: chex.PRNGKey): batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] return self.flax_module.init({ "params": key, "dropout": key @@ -51,12 +51,12 @@ def init(self, key: chex.PRNGKey): def loss(self, params: Any, key: chex.PRNGKey, data: Any): logits = self.flax_module.apply( params, data["image"], train=True, rngs={"dropout": key}) - labels_onehot = jax.nn.one_hot(data["label"], logits.shape[1]) - loss_vec = base.softmax_cross_entropy(logits=logits, labels=labels_onehot) + labels_onehot = jax.nn.one_hot(data["label"], logits.shape[1]) # pyrefly: ignore[missing-attribute] + loss_vec = base.softmax_cross_entropy(logits=logits, labels=labels_onehot) # pyrefly: ignore[bad-argument-type] return jnp.mean(loss_vec) def normalizer(self, loss): - max_class = onp.log(2 * self.datasets.extra_info["num_classes"]) + max_class = onp.log(2 * self.datasets.extra_info["num_classes"]) # pyrefly: ignore[missing-attribute] loss = jnp.nan_to_num( loss, nan=max_class, neginf=max_class, posinf=max_class) # shift to [0, 10] then clip. diff --git a/learned_optimization/tasks/parametric/cfgobject.py b/learned_optimization/tasks/parametric/cfgobject.py index a972139a..b4769ff7 100644 --- a/learned_optimization/tasks/parametric/cfgobject.py +++ b/learned_optimization/tasks/parametric/cfgobject.py @@ -163,7 +163,7 @@ def flatten_cfg(cfg: CFGObject, features_for: str) -> Mapping[str, Any]: for k2, v in a.items(): to_process.append((k + "/" + k2, v)) elif isinstance(a, CFGObject): - to_process.append((k + "/" + a.obj, a.kwargs)) + to_process.append((k + "/" + a.obj, a.kwargs)) # pyrefly: ignore[bad-argument-type] elif isinstance(a, CFGNamed): to_process.append((k + "/" + a.name, a.values)) elif isinstance(a, DoNotFeaturize): diff --git a/learned_optimization/tasks/parametric/image_conv.py b/learned_optimization/tasks/parametric/image_conv.py index ccc98409..568052f1 100644 --- a/learned_optimization/tasks/parametric/image_conv.py +++ b/learned_optimization/tasks/parametric/image_conv.py @@ -88,7 +88,7 @@ def __init__(self): def init(self, rng: PRNGKey) -> Params: init_net, unused_apply_net = hk.without_apply_rng( hk.transform(_forward)) - image = next(self.datasets.train)["image"] + image = next(self.datasets.train)["image"] # pyrefly: ignore[missing-attribute] return init_net(rng, image) def loss(self, params: Params, rng: PRNGKey, data: Batch) -> jnp.ndarray: # pytype: disable=signature-mismatch # jax-ndarray diff --git a/learned_optimization/tasks/parametric/image_mlp.py b/learned_optimization/tasks/parametric/image_mlp.py index 41e7fb01..8824b407 100644 --- a/learned_optimization/tasks/parametric/image_mlp.py +++ b/learned_optimization/tasks/parametric/image_mlp.py @@ -83,7 +83,7 @@ def __init__(self): def init(self, rng: PRNGKey) -> Params: init_net, unused_apply_net = hk.without_apply_rng( hk.transform(_forward)) - image = next(self.datasets.train)["image"] + image = next(self.datasets.train)["image"] # pyrefly: ignore[missing-attribute] return init_net(rng, image) def loss(self, params: Params, rng: PRNGKey, data: Batch) -> jnp.ndarray: # pytype: disable=signature-mismatch # jax-ndarray diff --git a/learned_optimization/tasks/parametric/image_mlp_ae.py b/learned_optimization/tasks/parametric/image_mlp_ae.py index 1bd94ea2..c5026edb 100644 --- a/learned_optimization/tasks/parametric/image_mlp_ae.py +++ b/learned_optimization/tasks/parametric/image_mlp_ae.py @@ -92,7 +92,7 @@ def __init__(self): def init(self, key: PRNGKey) -> Params: init_net, unused_apply_net = hk.without_apply_rng( hk.transform(_forward)) - image = next(self.datasets.train)["image"] + image = next(self.datasets.train)["image"] # pyrefly: ignore[missing-attribute] return init_net(key, image) def loss(self, params: Params, key: PRNGKey, data: Batch) -> jnp.ndarray: # pytype: disable=signature-mismatch # jax-ndarray diff --git a/learned_optimization/tasks/parametric/image_mlp_vae.py b/learned_optimization/tasks/parametric/image_mlp_vae.py index 4b52b5d6..4fc12b9b 100644 --- a/learned_optimization/tasks/parametric/image_mlp_vae.py +++ b/learned_optimization/tasks/parametric/image_mlp_vae.py @@ -123,7 +123,7 @@ def __init__(self): self.datasets = datasets def init(self, key: PRNGKey) -> Params: - image = next(self.datasets.train)["image"] + image = next(self.datasets.train)["image"] # pyrefly: ignore[missing-attribute] return hk.transform(_forward).init(key, image) def loss(self, params: Params, key: PRNGKey, data: Batch) -> jnp.ndarray: # pytype: disable=signature-mismatch # jax-ndarray @@ -138,7 +138,7 @@ def normalizer(self, loss): # loss is from a mix of p(x|z) and kl. # p(x|z) is the biggest component so let's ignore kl. # This is the sum over pixels, so we normalize by dividing by # pixels. - n_elements = onp.prod(next(datasets.train)["image"].shape[1:]) + n_elements = onp.prod(next(datasets.train)["image"].shape[1:]) # pyrefly: ignore[missing-attribute] out = jax.lax.cond(task_params["per_dim_loss"], lambda x: x, lambda x: x / n_elements, loss) out = jnp.nan_to_num(out, nan=10, neginf=10, posinf=10) diff --git a/learned_optimization/tasks/parametric/image_resnet.py b/learned_optimization/tasks/parametric/image_resnet.py index a2828c21..4ea9e347 100644 --- a/learned_optimization/tasks/parametric/image_resnet.py +++ b/learned_optimization/tasks/parametric/image_resnet.py @@ -60,7 +60,7 @@ def sample(self, key: PRNGKey) -> cfgobject.CFGNamed: }) def task_fn(self, task_params) -> base.Task: - num_classes = self.datasets.extra_info["num_classes"] + num_classes = self.datasets.extra_info["num_classes"] # pyrefly: ignore[missing-attribute] datasets = self.datasets def _forward(inp): @@ -86,7 +86,7 @@ def __init__(self): def init_with_state(self, key: PRNGKey) -> Tuple[Params, ModelState]: init_net, unused_apply_net = hk.transform_with_state(_forward) - image = next(self.datasets.train)["image"] + image = next(self.datasets.train)["image"] # pyrefly: ignore[missing-attribute] params, state = init_net(key, image) return params, state diff --git a/learned_optimization/tasks/parametric/lm_rnn.py b/learned_optimization/tasks/parametric/lm_rnn.py index 9ab882c8..8729001f 100644 --- a/learned_optimization/tasks/parametric/lm_rnn.py +++ b/learned_optimization/tasks/parametric/lm_rnn.py @@ -72,7 +72,7 @@ def sample(self, key: PRNGKey) -> cfgobject.CFGNamed: }) def task_fn(self, task_params) -> base.Task: - max_vocab_size = self.datasets.extra_info["vocab_size"] + max_vocab_size = self.datasets.extra_info["vocab_size"] # pyrefly: ignore[missing-attribute] if self.vocab_size is None: vocab_size = max_vocab_size else: @@ -120,7 +120,7 @@ def init(self, rng: PRNGKey) -> Params: init_net, unused_apply_net = hk.without_apply_rng( hk.transform(_forward)) batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] seq = batch["obs"] return init_net(rng, seq) diff --git a/learned_optimization/tasks/parametric/lm_transformer.py b/learned_optimization/tasks/parametric/lm_transformer.py index c3314e8e..9d154375 100644 --- a/learned_optimization/tasks/parametric/lm_transformer.py +++ b/learned_optimization/tasks/parametric/lm_transformer.py @@ -61,7 +61,7 @@ def sample(self, key: chex.PRNGKey) -> cfgobject.CFGNamed: return cfgobject.CFGNamed("ParametricLMTransformer", {}) def task_fn(self, task_params) -> base.Task: - max_vocab_size = self.datasets.extra_info["vocab_size"] + max_vocab_size = self.datasets.extra_info["vocab_size"] # pyrefly: ignore[missing-attribute] if self.vocab_size is None: vocab_size = max_vocab_size else: @@ -103,7 +103,7 @@ def _hk_forward(self, batch): def init(self, key: chex.PRNGKey) -> base.Params: batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), - self.datasets.abstract_batch) + self.datasets.abstract_batch) # pyrefly: ignore[missing-attribute] return self._net.init(key, batch) def loss(self, params, key, data): diff --git a/learned_optimization/tasks/parametric/parametric_utils.py b/learned_optimization/tasks/parametric/parametric_utils.py index 39da0611..3289697a 100644 --- a/learned_optimization/tasks/parametric/parametric_utils.py +++ b/learned_optimization/tasks/parametric/parametric_utils.py @@ -129,7 +129,7 @@ def orth_init(shape, dtype, key, scale=1.0, axis=-1): def uniform_scale_init(shape, dtype, key, scale=1.0): """uniform scale init.""" - input_size = onp.product(shape[:-1]) + input_size = onp.product(shape[:-1]) # pyrefly: ignore[missing-attribute] max_val = onp.sqrt(3 / input_size) * scale return jax.random.uniform(key, shape, dtype, -max_val, max_val) @@ -233,7 +233,7 @@ def sample(cls, key): def get_dynamic(cls, cfg): """Get the initializer for the given config.""" - class _SwitchedInitializer(hk.initializers.Initializer): + class _SwitchedInitializer(hk.initializers.Initializer): # pyrefly: ignore[invalid-inheritance] """A haiku initializer which dynamically switches amoung initializers.""" def __init__(self): diff --git a/learned_optimization/tree_utils.py b/learned_optimization/tree_utils.py index cd75fd47..5467c60f 100644 --- a/learned_optimization/tree_utils.py +++ b/learned_optimization/tree_utils.py @@ -142,12 +142,12 @@ def map_named(function: Callable[[str, Any], Any], Struct with the same pytree. """ if isinstance(val, Mapping): - return type(val)( - **{k: map_named(function, v, key + "/" + k) for k, v in val.items()}) + return type(val)( # pyrefly: ignore[bad-instantiation] + **{k: map_named(function, v, key + "/" + k) for k, v in val.items()}) # pyrefly: ignore[unsupported-operation] elif isinstance(val, tuple) or isinstance(val, list): return type(val)( * - [map_named(function, v, key + "/" + str(i)) for i, v in enumerate(val)]) + [map_named(function, v, key + "/" + str(i)) for i, v in enumerate(val)]) # pyrefly: ignore[unsupported-operation] # check if it's a flax dataclass elif hasattr(val, "__dataclass_fields__"): classname = repr(val).split("(")[0] @@ -156,7 +156,7 @@ def map_named(function: Callable[[str, Any], Any], for k, v in val.__dataclass_fields__.items() }) else: - return function(key, val) + return function(key, val) # pyrefly: ignore[bad-argument-type] def strip_weak_type(pytree):