Plotting convergence of alpha weights in One-Shot Optimizers #154
Conversation
naslib/defaults/trainer.py
Outdated
| # logger.info("param size = %fMB", n_parameters) | ||
| self.search_trajectory = utils.AttrDict( | ||
| logger.info("param size = %fMB", n_parameters) | ||
| self.errors_dict = utils.AttrDict( |
There was a problem hiding this comment.
I would keep the name as "search_trajectory", since it also contains non-error metrics such as accuracies, train_time and params
| ) | ||
| ) | ||
|
|
||
| return top1.avg |
There was a problem hiding this comment.
Why has this line been removed?
| metric=metric, dataset=self.config.dataset, dataset_api=dataset_api | ||
| ) | ||
| logger.info("Queried results ({}): {}".format(metric, result)) | ||
| return result |
There was a problem hiding this comment.
Why has this line been removed?
naslib/defaults/trainer.py
Outdated
| "Epoch {} done. Train accuracy (top1, top5): {:.5f}, {:.5f}, Validation accuracy: {:.5f}, {:.5f}".format( | ||
| epoch, | ||
| self.train_top1.avg, | ||
| self.train_top5.avg, |
There was a problem hiding this comment.
This is useful, but it doesn't belong with the viz PR
naslib/defaults/trainer.py
Outdated
| train from scratch. | ||
| """ | ||
| logger.info("Beginning search") | ||
| logger.info("Start training") |
There was a problem hiding this comment.
Try and keep the changes in a PR relevant to the functionality it addresses. Reversing the commits in a PR should only affect the functionality it introduces.
naslib/defaults/trainer.py
Outdated
| start_time = time.time() | ||
| self.optimizer.new_epoch(e) | ||
|
|
||
| arch_weights_lst = [] |
There was a problem hiding this comment.
Generally, as a programming practice, don't mention the data-structure in the name of the variable. arch_weights.append(...) makes it clear that it is a list, not a dict.
naslib/defaults/trainer.py
Outdated
| for e in range(start_epoch, self.epochs): | ||
|
|
||
| # create the arch directory (without overwriting) | ||
| if self.config.save_arch_weights: |
There was a problem hiding this comment.
It is better for readability to write this as if self.config.save_arch_weights is True
naslib/defaults/trainer.py
Outdated
|
|
||
| best_arch = self.optimizer.get_final_architecture() | ||
| logger.info(f"Final architecture hash: {best_arch.get_hash()}") | ||
| logger.info("Final architecture:\n" + best_arch.modules_str()) |
There was a problem hiding this comment.
Does not belong in this PR
| seed = config.search.seed | ||
| batch_size = config.batch_size | ||
| train_portion = config.train_portion | ||
| batch_size = config.batch_size if hasattr(config, "batch_size") else config.search.batch_size |
There was a problem hiding this comment.
This looks like a bug fix independent of the visualization code. If so, create a new PR, or push the fix directly to Develop.
naslib/defaults/trainer.py
Outdated
| import numpy as np | ||
|
|
||
| import matplotlib.pyplot as plt | ||
| import seaborn as sns |
There was a problem hiding this comment.
- Neither matplotlib.pyplot nor seaborn are used in this file. Remove.
- Seaborn is missing in requirements.txt.
| from naslib.optimizers import DARTSOptimizer, GDASOptimizer, DrNASOptimizer | ||
| from naslib.search_spaces import NasBench101SearchSpace, NasBench201SearchSpace, NasBench301SearchSpace | ||
|
|
||
| from naslib.utils import set_seed, setup_logger, get_config_from_args, create_exp_dir |
There was a problem hiding this comment.
I tried running this file and it crashed because create_exp_dir is not imported in the __init__.py of utils
|
|
||
| # save and possibly plot architectural weights | ||
| logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt") | ||
| if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights: |
There was a problem hiding this comment.
self.config.save_arch_weights is True for better readability
| logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt") | ||
| if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights: | ||
| torch.save(arch_weights, f'{self.config.save}/arch_weights.pt') | ||
| if hasattr(self.config, "plot_arch_weights") and self.config.plot_arch_weights: |
There was a problem hiding this comment.
self.config.plot_arch_weights is True
| all_weights = torch.load(f'{config.save}/arch_weights.pt') # load alphas | ||
|
|
||
| # unpack search space information | ||
| alpha_dict = {} |
There was a problem hiding this comment.
Avoid data-structure name in var name
| import numpy as np | ||
|
|
||
| import matplotlib.pyplot as plt | ||
| from matplotlib.cm import ScalarMappable |
| alpha_dict = {} | ||
| min_soft, max_soft = np.inf, -np.inf | ||
| for graph in optimizer.graph._get_child_graphs(single_instances=True): | ||
| for edge_weights, (u, v, edge_data) in zip(all_weights, graph.edges.data()): |
There was a problem hiding this comment.
Neonkraft
left a comment
There was a problem hiding this comment.
Please address the comments :)
There seems to be a small bug in the plotting code.

Finalized Lukas' idea for a heat map to visualize the convergence of alpha weights for One-Shot optimizers, via two new configuration parameters:
`config.save_arch_weights` and `config.plot_arch_weights`. Plotting was made to be extensible to larger search spaces by limiting the number of edges to 4, but this could also be parameterized if the user wants to be able to control this too.