Add block-spin update sampler #44
Conversation
jackraymond
left a comment
Great job. I'll come back to this when you add tests, per your summary. Thanks for getting this draft up quickly.
It would be nice to have tidy, public GPU code as soon as possible that we can use and cite for the speedup over our CPU implementation in higher-throughput applications.
@jackraymond I updated the PR:
VolodyaCO
left a comment
I left some comments about form.
Overall looks fantastic. Thank you.
Can the user compile a method of an existing
Addressed @thisac's and @VolodyaCO's reviews. Briefly,
jackraymond
left a comment
Looks good; the tests ran for me.
I'd recommend allowing x to be set as an initial condition. As tests, perhaps check that it doesn't move when num_sweeps is None, and that it oscillates under BlockMetropolis when the temperature is infinite.
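A rough sketch of what those tests could look like (the sample() entry point, its num_sweeps argument, and how the temperature reaches BlockMetropolis are assumptions about this PR's API, not its actual signatures):

```python
import torch

def check_state_fixed_when_no_sweeps(sampler):
    # A sampler constructed from a user-supplied x with num_sweeps=None
    # should leave that state untouched.
    x0 = sampler._x.detach().clone()
    sampler.sample(num_sweeps=None)          # assumed sampling entry point
    assert torch.equal(sampler._x, x0)

def check_oscillation_at_infinite_temperature(sampler):
    # Under BlockMetropolis at infinite temperature every proposal is
    # accepted, so the state must differ from x0 after one sweep (and
    # return to x0 after two, i.e. it oscillates).
    x0 = sampler._x.detach().clone()
    sampler.sample(num_sweeps=1)
    assert not torch.equal(sampler._x, x0)
```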
In the latest commit, I made the following changes:
VolodyaCO
left a comment
All comments have been addressed. Thank you @kevinchern
```python
# (GRBM, crayon) test fixtures; a crayon maps each node to its color (block) index.
ZEPHYR = dnx.zephyr_graph(1, coordinates=True)
GRBM_ZEPHYR = GRBM(ZEPHYR.nodes, ZEPHYR.edges)
CRAYON_ZEPHYR = dnx.zephyr_four_color

BIPARTITE = nx.complete_bipartite_graph(5, 3)
GRBM_BIPARTITE = GRBM(BIPARTITE.nodes, BIPARTITE.edges)
def CRAYON_BIPARTITE(b): return b < 5

GRBM_SINGLE = GRBM([0], [])
def CRAYON_SINGLE(s): return 0

GRBM_CRAYON_TEST_CASES = [(GRBM_ZEPHYR, CRAYON_ZEPHYR),
                          (GRBM_BIPARTITE, CRAYON_BIPARTITE),
                          (GRBM_SINGLE, CRAYON_SINGLE)]
```
Better to put this in a setUpClass() method.
If I understood correctly, that would be incompatible with @parameterized.expand.
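For context, @parameterized.expand consumes its case list while the class body is being executed, before setUpClass() would ever run, which is why the (GRBM, crayon) fixtures above live at module level. A minimal sketch (the test body is illustrative):

```python
import unittest
from parameterized import parameterized

class TestGRBMCrayon(unittest.TestCase):
    # The decorator reads GRBM_CRAYON_TEST_CASES at class-definition time,
    # so these fixtures cannot be deferred to setUpClass().
    @parameterized.expand(GRBM_CRAYON_TEST_CASES)
    def test_crayon_is_callable(self, grbm, crayon):
        self.assertTrue(callable(crayon))
```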
```python
        self._x = nn.Parameter(initial_states.float(), requires_grad=False)
        self._zeros = nn.Parameter(torch.zeros((num_chains, 1)), requires_grad=False)

    def to(self, device: DeviceLikeType) -> BlockSpinSampler:
```
Encourage my laziness by setting device to torch.device("cuda") by default?
No changes necessary
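For reference, with the explicit signature the caller picks the device; a usage sketch (only the to() method is taken from the diff above, the rest is illustrative):

```python
import torch

# Choose CUDA when available and fall back to CPU otherwise, rather than
# having the sampler default to torch.device("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sampler = sampler.to(device)
```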
Tests ran fine for me, and I got it working for my benchmarking straightforwardly.
So far as I can tell, on an L4:
156 ns per spin update for a single sample on a Zephyr[m=12] problem.
This drops to ~156/58 ≈ 2.7 ns per spin update with 58-fold parallelism, which is good.
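A rough sketch of how a per-spin-update figure like this can be measured (the sample() call and its argument are assumptions about this PR's API):

```python
import torch

def ns_per_spin_update(sampler, num_sweeps, num_spins, num_chains):
    # CUDA kernels launch asynchronously, so bracket the run with CUDA
    # events and synchronize before reading the elapsed time.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    sampler.sample(num_sweeps=num_sweeps)      # assumed sampling entry point
    end.record()
    torch.cuda.synchronize()
    elapsed_ns = start.elapsed_time(end) * 1e6   # elapsed_time() returns milliseconds
    return elapsed_ns / (num_sweeps * num_spins * num_chains)
```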
thisac
left a comment
Apart from a few minor comments, LGTM.
Force-pushed from 1d67e64 to 74666ca.
This PR introduces a block-spin update sampler.
It includes duplicate code from #40.