Skip to content
Merged

Devel #208

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 52 additions & 25 deletions dmff/torch_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,37 +41,64 @@ def j2t_pytree(v):

def wrap_torch_potential_kernel(potential_t):

@partial(jax.custom_jvp, nondiff_argnums=(2,))
# jvp, good for push-forward mode
# @partial(jax.custom_jvp, nondiff_argnums=(2,))
# def potential(positions, box, pairs, params):
# res = potential_t(j2t_pytree(positions), \
# j2t_pytree(box), \
# np.array(pairs), \
# j2t_pytree(params))
# return res

# @potential.defjvp
# def potential_jvp(pairs, primals, tangents):
# positions, box, params = primals
# dpositions, dbox, dparams = tangents
# # convert inputs to torch
# positions_t = j2t_pytree(positions)
# box_t = j2t_pytree(box)
# params_t = j2t_pytree(params)
# # do fwd and bwd in torch
# primal_out_torch = potential_t(positions_t, box_t, np.array(pairs), params_t)
# primal_out_torch.backward()
# # read gradient in torch
# g_positions = t2j_extract_grad(positions_t)
# g_box = t2j_extract_grad(box_t)
# g_params = t2j_extract_grad(params_t)
# # prepare output
# primal_out = t2j(primal_out_torch)
# tangent_out = jnp.sum(g_positions * dpositions) + jnp.sum(g_box * box)
# tangents_leaves = jax.tree.leaves(dparams)
# grad_leaves = jax.tree.leaves(g_params)
# for x, y in zip(tangents_leaves, grad_leaves):
# tangent_out += jnp.sum(x * y)
# return primal_out, tangent_out

# vjp: good for backward
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def potential(positions, box, pairs, params):
res = potential_t(j2t_pytree(positions), \
j2t_pytree(box), \
np.array(pairs), \
res = potential_t(j2t_pytree(positions),
j2t_pytree(box),
np.array(pairs),
j2t_pytree(params))
return res

@potential.defjvp
def potential_jvp(pairs, primals, tangents):
positions, box, params = primals
dpositions, dbox, dparams = tangents
# convert inputs to torch
positions_t = j2t_pytree(positions)
def potential_fwd(positions, box, pairs, params):
pos_t = j2t_pytree(positions)
box_t = j2t_pytree(box)
pairs = np.array(pairs)
params_t = j2t_pytree(params)
# do fwd and bwd in torch
primal_out_torch = potential_t(positions_t, box_t, np.array(pairs), params_t)
primal_out_torch.backward()
# read gradient in torch
g_positions = t2j_extract_grad(positions_t)
g_box = t2j_extract_grad(box_t)
g_params = t2j_extract_grad(params_t)
# prepare output
primal_out = t2j(primal_out_torch)
tangent_out = jnp.sum(g_positions * dpositions) + jnp.sum(g_box * box)
tangents_leaves = jax.tree.leaves(dparams)
grad_leaves = jax.tree.leaves(g_params)
for x, y in zip(tangents_leaves, grad_leaves):
tangent_out += jnp.sum(x * y)
return primal_out, tangent_out
energy = potential_t(pos_t, box_t, pairs, params_t)
energy.backward()
grads = (t2j_extract_grad(pos_t),
t2j_extract_grad(box_t),
t2j_extract_grad(params_t))
return t2j(energy), grads

def potential_bwd(pairs, res, g):
return res[0]*g, res[1]*g, jax.tree.map(lambda x: x*g, res[2])

potential.defvjp(potential_fwd, potential_bwd)

return potential

Expand Down
Binary file modified examples/eann/eann_model.pickle
Binary file not shown.
Binary file modified tests/data/eann_model.pickle
Binary file not shown.
Binary file modified tests/data/water_eann.pickle
Binary file not shown.