diff --git a/jraph/_src/models.py b/jraph/_src/models.py index e6a2562..7135642 100644 --- a/jraph/_src/models.py +++ b/jraph/_src/models.py @@ -22,6 +22,7 @@ import jax.tree_util as tree from jraph._src import graph as gn_graph from jraph._src import utils +from contextlib import nullcontext # As of 04/2020 pytype doesn't support recursive types. # pytype: disable=not-supported-yet @@ -548,18 +549,18 @@ def _ApplyGCN(graph): # Equivalent to jnp.sum(n_node), but jittable total_num_nodes = tree.tree_leaves(nodes)[0].shape[0] if add_self_edges: - # We add self edges to the senders and receivers so that each node - # includes itself in aggregation. - # In principle, a `GraphsTuple` should partition by n_edge, but in - # this case it is not required since a GCN is agnostic to whether - # the `GraphsTuple` is a batch of graphs or a single large graph. - conv_receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), + # We add self edges to the senders and receivers so that each node + # includes itself in aggregation. + # In principle, a `GraphsTuple` should partition by n_edge, but in + # this case it is not required since a GCN is agnostic to whether + # the `GraphsTuple` is a batch of graphs or a single large graph. + conv_receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0) - conv_senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), + conv_senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0) else: - conv_senders = senders - conv_receivers = receivers + conv_senders = senders + conv_receivers = receivers # pylint: disable=g-long-lambda if symmetric_normalization: @@ -594,3 +595,17 @@ def _ApplyGCN(graph): return graph._replace(nodes=nodes) return _ApplyGCN + + +def random_graph(device=None): + """Returns a random graph with 10 nodes and 20 edges. + + Args: + device: Optional device to place the arrays on. If None, uses current device. + """ + n_node = 10 + n_edge = 20 + with jax.device(device) if device else nullcontext(): + senders = jnp.random.randint(0, n_node, size=n_edge) + receivers = jnp.random.randint(0, n_node, size=n_edge) + # ...