<!DOCTYPE html>
<html>
<head>
<link rel="canonical" href="https://hardmath123.github.io/hybrid-hyperoptimization.html"/>
<link rel="stylesheet" type="text/css" href="/static/base.css"/>
<title>Hybrid Hyperoptimization - Comfortably Numbered</title>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8"/>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" />
<link rel="alternate" type="application/rss+xml" title="Comfortably Numbered" href="/feed.xml" />
<!--
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script>
MathJax.Hub.Config({
tex2jax: {inlineMath: [['$','$']]}
});
</script>
-->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.11/dist/katex.min.css" integrity="sha384-Um5gpz1odJg5Z4HAmzPtgZKdTBHZdw8S29IecapCSB31ligYPhHQZMIlWLYQGVoc" crossorigin="anonymous">
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.13.11/dist/katex.min.js" integrity="sha384-YNHdsYkH6gMx9y3mRkmcJ2mFUjTd0qNQQvY9VYZgQd7DcN7env35GzlmFaZ23JGp" crossorigin="anonymous"></script>
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.13.11/dist/contrib/auto-render.min.js" integrity="sha384-vZTG03m+2yp6N6BNi5iM4rW4oIwk5DfcNdFfxkk9ZWpDriOkXX8voJBFrAO7MpVl" crossorigin="anonymous"></script>
<script>
document.addEventListener("DOMContentLoaded", function() {
renderMathInElement(document.body, {
// customised options
// • auto-render specific keys, e.g.:
delimiters: [
{left: '$$', right: '$$', display: true},
{left: '$', right: '$', display: false},
{left: '\\begin{align}', right: '\\end{align}', display: true},
{left: '\\(', right: '\\)', display: false},
{left: '\\[', right: '\\]', display: true}
],
// • rendering keys, e.g.:
throwOnError : false
});
});
</script>
</head>
<body>
<header id="header">
<script src="static/main.js"></script>
<div>
<a href="/"><span class="left-word">Comfortably</span> <span class="right-word">Numbered</span></a>
</div>
</header>
<article id="postcontent" class="centered">
<section>
<h1>Hybrid Hyperoptimization</h1>
<center><em><p>The most ambitious crossover event in automatic differentiation history</p>
</em></center>
<h4>Friday, September 11, 2020 · 4 min read</h4>
<blockquote>
<p> Summary of this post: It turns out that by carefully mixing forward-mode and
reverse-mode automatic differentiation, you can greatly simplify certain
hyperparameter optimization algorithms.</p>
</blockquote>
<p>Here’s a fun recursive idea: just like how we optimize the parameters of
machine learning models by gradient descent, we can also optimize their
<em>hyperparameters</em> by gradient descent. This is by no means a new idea; you can
find 20-year-old papers that discuss (for example) optimizing gradient descent
step sizes by gradient descent itself.</p>
<p>Broadly, there are two ways to do this. The “short-term” way is to take a
<em>single</em> step of gradient descent, and then based on how that goes, adjust the
hyperparameters. That’s the premise of the 2018 paper <a href="https://arxiv.org/abs/1703.04782">Online Learning Rate
Adaptation with Hypergradient Descent</a>: it
turns out that you can do this in an extremely lightweight way. You would hope
that over time the hyperparameters converge to something optimal alongside the
parameters (and indeed that is the case, though you have to be careful).</p>
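<p>As a rough sketch of what “lightweight” means here: my reading of that paper’s rule for plain SGD is that the step size gets nudged by the product of the current and previous gradients, so the only extra state is one stored gradient. The toy loss $f(\theta) = \theta^2$, the function name <code>hypergradient_descent</code>, and all the constants below are assumptions for illustration, not taken from the paper.</p>
<pre><code class="lang-python">def hypergradient_descent(grad, theta, alpha=0.01, beta=0.001, steps=100):
    """SGD on a scalar theta where alpha itself is adapted at every step."""
    prev_grad = 0.0
    for _ in range(steps):
        g = grad(theta)
        # If successive gradients agree in sign, the last step was too timid,
        # so grow alpha; if they disagree, we overshot, so shrink it.
        alpha = alpha + beta * g * prev_grad
        theta = theta - alpha * g
        prev_grad = g
    return theta, alpha


# Toy problem (an assumption): f(theta) = theta**2, so the gradient is 2 * theta.
theta, alpha = hypergradient_descent(lambda t: 2.0 * t, theta=5.0)
print(theta, alpha)
</code></pre>
<p>On this toy problem the step size grows from its deliberately small starting value, which is the behavior you would hope for.</p>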
<p>The “long-term” way is to train <em>several</em> steps and <em>then</em> backpropagate
through the <em>entire training</em> to adjust the hyperparameters. The hope is that
this provides a stronger “signal” to the hyperparameter-gradient, which is
better for convergence. But, this comes at the (tremendous) expense of having
to store a copy of the entire computation graph for several steps of training,
and then backpropagate through it. If your model has 1GB of parameters, you
need to store ~1GB worth of numbers for <em>each step</em>. You can imagine that adds
up over the course of many epochs, and so the algorithm is severely limited by
how much RAM you have.</p>
<p>The 2015 paper <a href="https://arxiv.org/abs/1502.03492">Gradient-based Hyperparameter Optimization through Reversible
Learning</a> makes this work by throwing away
the intermediate steps of the computation graph and then JIT-re-computing it
“in reverse” during backpropagation. It’s a clever technique, but kind of hairy
and hard to implement (you have to be super careful about numerical precision).</p>
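<p>To make “re-computing in reverse” a little more concrete, here is a sketch in the spirit of that paper, assuming a momentum update of the form $v \leftarrow \gamma v - (1 - \gamma) \nabla L(x)$ followed by $x \leftarrow x + \alpha v$. Each step can be undone algebraically, so nothing needs to be stored. The toy loss and the constants are assumptions, and the paper’s actual algorithm does more work to control rounding error.</p>
<pre><code class="lang-python">def grad(x):
    return x  # gradient of the toy loss 0.5 * x**2 (an assumption)


def forward(x, v, alpha, gamma, steps):
    """Run momentum SGD, keeping only the final state."""
    for _ in range(steps):
        v = gamma * v - (1.0 - gamma) * grad(x)
        x = x + alpha * v
    return x, v


def backward(x, v, alpha, gamma, steps):
    """Replay the same steps in reverse, recovering earlier states."""
    for _ in range(steps):
        x = x - alpha * v                          # undo the position update
        v = (v + (1.0 - gamma) * grad(x)) / gamma  # undo the velocity update
    return x, v


x_end, v_end = forward(5.0, 0.0, alpha=0.1, gamma=0.9, steps=50)
x_start, v_start = backward(x_end, v_end, alpha=0.1, gamma=0.9, steps=50)
print(x_start, v_start)  # close to the initial (5.0, 0.0), up to rounding
</code></pre>
<p>The division by <code>gamma</code> in the reverse pass is where precision becomes a worry: every reversed step amplifies whatever rounding error has accumulated so far.</p>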
<p>Here’s a graph that visualizes these two techniques (<a href="https://arxiv.org/abs/1909.13371">from this
paper</a>). You’re looking at several parallel
loss curves ($\log(f)$ is the loss plotted on a log scale), pointing towards
you, arranged in order of the “step size” hyperparameter
(that’s $\log(\alpha)$). The orange curve represents a “short-term”
hyperparameter optimization, which is allowed to move along the “step size”
axis at each step. The “long-term” hyperparameter optimization instead
optimizes directly on the thick black dashed “U” — that is, after <em>several</em>
steps of training. You can see how the latter is smoother, but also much harder
to compute.</p>
<p><img src="static/hybrid-hyperoptimization/fig-metasurface.png" alt="The preceding paragraph explains this
figure."></p>
<p>To summarize: the short-term way is cheap but noisy. The long-term way is
expensive but less noisy. Can we get the best of both worlds?</p>
<hr>
<p>Well, let’s think more carefully about the source of the expense. The problem
is that we need to store (or be able to reconstruct) the full computation graph
in order to do backpropagation. Okay, but do we really <em>need</em> to do
backpropagation? The only reason we backpropagate is that it’s more efficient
in the case when you want derivatives with respect to <em>many</em> different
variables: one reverse pass yields all of them at once, whereas the much
simpler <a href="https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers">forward-mode (“dual numbers”)
automatic
differentiation</a> needs a separate pass for each variable. If you have
millions of model parameters, that makes backpropagation roughly millions of
times cheaper.</p>
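<p>A minimal sketch of forward-mode with dual numbers shows where that cost comes from. The class <code>Dual</code>, the test function <code>f</code>, and the helper <code>forward_gradient</code> are all hypothetical names made up for this illustration. Each pass seeds exactly one input, so a full gradient takes as many passes as there are inputs.</p>
<pre><code class="lang-python">class Dual:
    """A number val + eps * e, where e * e = 0; eps carries a derivative."""

    def __init__(self, val, eps=0.0):
        self.val, self.eps = val, eps

    def __add__(self, other):
        return Dual(self.val + other.val, self.eps + other.eps)

    def __mul__(self, other):
        # Product rule: (a + a'e)(b + b'e) = ab + (a b' + a' b)e
        return Dual(self.val * other.val,
                    self.val * other.eps + self.eps * other.val)


def f(x, y, z):
    return x * y + y * z * z


def forward_gradient(func, point):
    """Gradient of func at point, using one forward pass per input."""
    grad = []
    for i in range(len(point)):
        # Seed eps = 1 on input i only, so the output's eps is d(func)/d(input i).
        args = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(point)]
        grad.append(func(*args).eps)
    return grad


gradient = forward_gradient(f, [2.0, 3.0, 4.0])
print(gradient)  # [3.0, 18.0, 24.0]
</code></pre>
<p>Three inputs means three passes, which is fine. A million inputs means a million passes, which is why nobody trains networks this way.</p>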
<p>But we <em>don’t</em> have millions of <em>hyperparameters!</em> Step size, for example, is
just a single number. The Adam optimizer only has a total of 4 hyperparameters.
With this in mind, backpropagation isn’t even the right choice — we <em>should</em>
be using dual numbers for the hyperparameter optimization. On the other hand,
we should still be using backpropagation for the “inner loop” that optimizes
the (non-hyper-) parameters. That is, we want to do something like this:</p>
<pre><code class="lang-python">initialize hyperparameters
# loop to optimize hyperparameters
while True:
    initialize parameters
    # loop to optimize parameters
    for i in range(100):
        run model on data to get loss
        # using reverse-mode!
        compute d(loss) / d(parameters)
        update parameters using hyperparameters
    # using forward-mode
    compute d(loss) / d(hyperparameters)
    update hyperparameters
</code></pre>
<blockquote>
<p>(Update: I discovered that this was suggested in the 2017 paper <a href="https://arxiv.org/abs/1703.01785">Forward and
Reverse Gradient-Based Hyperparameter
Optimization</a>. But keep reading — while
the paper’s
<a href="https://github.com/lucfra/FAR-HO/blob/master/far_ho/hyper_gradients.py#L375">implementation</a>
needs a lot of math to be worked out manually, I’m going to show you how to
implement this in a way that makes all the math in the paper fall out “for
free”…)</p>
</blockquote>
<p>This proposal raises a logistical question: how do we reconcile these two
automatic differentiation algorithms in the same program? <strong>The “trick” is to
“thread” dual numbers through a backpropagation implementation.</strong> In other
words, implement backpropagation as usual, but rather than <code>float</code> type
numbers, exclusively use <code>dual_number</code> type numbers (even when doing derivative
calculations). Initialize the system such that the dual numbers track
derivatives with respect to the hyperparameters you care about. Then, your
final loss value’s attached $\epsilon$-value <em>immediately</em> gives you
<code>d(loss)/d(hyperparameters)</code>. No backpropagation needed — and so, it’s safe
to “forget” the computation graph.</p>
<p>That’s it! That’s all I wanted to share in this blog post! :)</p>
<p>Of course, it’s not obvious how to implement this in PyTorch, since you can’t
naïvely do <code>tensor(dtype=dual_number)</code>. Rather than hack in a custom numeric
data type that implements dual numbers, I wrote my own tiny implementations of
forward- and reverse-mode automatic differentiation. It’s just a couple dozen
(very-recognizable-to-automatic-differentiation-enthusiasts) lines of code. I
was careful to make each implementation generic in the kind of “number” it
accepts. That allowed me to run the reverse-mode algorithms using the
forward-mode data types.</p>
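<p>Here is a self-contained sketch of that idea. It is not the notebook’s code: the classes <code>Dual</code> and <code>Var</code>, the one-parameter quadratic loss, and every constant are assumptions chosen to keep the example small. The reverse-mode <code>Var</code> never inspects its values, so it runs happily on <code>Dual</code> numbers that track derivatives with respect to the step size.</p>
<pre><code class="lang-python">class Dual:
    """Forward-mode number: a value plus its derivative w.r.t. one hyperparameter."""

    def __init__(self, val, eps=0.0):
        self.val, self.eps = val, eps

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.eps + other.eps)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.eps - other.eps)

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val,
                    self.val * other.eps + self.eps * other.val)

    __rmul__ = __mul__


class Var:
    """Reverse-mode graph node; val and grad may be floats or Duals."""

    def __init__(self, val, parents=()):
        self.val, self.parents, self.grad = val, parents, 0.0

    def __sub__(self, other):
        return Var(self.val - other.val, ((self, 1.0), (other, -1.0)))

    def __mul__(self, other):
        return Var(self.val * other.val, ((self, other.val), (other, self.val)))

    def backward(self):
        order, seen = [], set()

        def visit(node):  # depth-first topological sort of the graph
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad = parent.grad + node.grad * local


def loss_fn(w):
    diff = w - Var(3.0)  # toy quadratic with its minimum at w = 3
    return diff * diff


def train(alpha_value, steps=10):
    """Return the final loss and d(final loss)/d(alpha) after an inner loop."""
    alpha = Dual(alpha_value, 1.0)  # seed: track derivatives w.r.t. alpha
    w = Dual(0.0)
    for _ in range(steps):
        node = Var(w)
        loss_fn(node).backward()   # reverse mode: d(loss)/dw, itself a Dual
        w = w - alpha * node.grad  # after this, the graph can be forgotten
    final = loss_fn(Var(w)).val
    return final.val, final.eps


loss_before, dloss_dalpha = train(0.1)

# Outer loop: plain gradient descent on the step size itself.
alpha = 0.1
for _ in range(5):
    _, hypergrad = train(alpha)
    alpha = alpha - 0.01 * hypergrad
loss_after, _ = train(alpha)
print(loss_before, dloss_dalpha, alpha, loss_after)
</code></pre>
<p>For this particular toy problem you can check the result by hand: after $n$ steps the loss is $9(1 - 2\alpha)^{2n}$, and differentiating that with respect to $\alpha$ should agree with the value that falls out of the $\epsilon$-part.</p>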
<p>Running it on a simple quadratic optimization problem, we can see that updating
the hyperparameter $\alpha$ yields better-looking loss curves — in this
GIF, we’re discovering that we should increase the step size. Yay!</p>
<p><img src="static/hybrid-hyperoptimization/loss.gif" alt="GIF of loss curve becoming better over
time"></p>
<p>I think this is a very neat trick, which surely has other cool applications
(for example, differentiating through long-running physics simulations!). If
you’re curious, check out this <a href="static/hybrid-hyperoptimization/hybrid-hyperoptimization.ipynb">IPython
notebook</a> for
the full source code. Feel free to adapt it to your ideas. Then, tell me what
you’ve built!</p>
</section>
<div id="comment-breaker">◊ ◊ ◊</div>
</article>
<footer id="footer">
<div>
<ul>
<li><a href="https://github.com/kach">
Github</a></li>
<li><a href="feed.xml">
Subscribe (RSS feed)</a></li>
<li><a href="https://twitter.com/hardmath123">
Twitter</a></li>
<li><a href="https://creativecommons.org/licenses/by-nc/3.0/deed.en_US">
CC BY-NC 3.0</a></li>
</ul>
</div>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','//www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-46120535-1', 'hardmath123.github.io');
ga('require', 'displayfeatures');
ga('send', 'pageview');
</script>
</footer>
</body>
</html>