From 0efcd51ab91a81804c581a23e16ddaf3bb1017ff Mon Sep 17 00:00:00 2001 From: Ingo Scholtes Date: Thu, 2 Oct 2025 08:46:26 +0000 Subject: [PATCH 1/3] new semantics of degree calculation --- src/pathpyG/core/graph.py | 83 ++++++++++++++++----------- src/pathpyG/core/multi_order_model.py | 2 +- tests/core/test_graph.py | 6 +- 3 files changed, 55 insertions(+), 36 deletions(-) diff --git a/src/pathpyG/core/graph.py b/src/pathpyG/core/graph.py index e8a438224..1857eabbe 100644 --- a/src/pathpyG/core/graph.py +++ b/src/pathpyG/core/graph.py @@ -440,7 +440,7 @@ def sparse_adj_matrix(self, edge_attr: Any = None) -> Any: @property def in_degrees(self) -> Dict[str, float]: - """Return in-degrees of nodes in directed network. + """Return unweighted in-degrees of nodes in directed network. Returns: dict: dictionary containing in-degrees of nodes @@ -449,60 +449,79 @@ def in_degrees(self) -> Dict[str, float]: @property def out_degrees(self) -> Dict[str, float]: - """Return out-degrees of nodes in directed network. + """Return unweighted out-degrees of nodes in directed network. Returns: dict: dictionary containing out-degrees of nodes """ return self.degrees(mode="out") - def degrees(self, mode: str = "in") -> Dict[str, float]: + def degrees(self, mode: str = "in", edge_attr: Any = None, return_tensor: bool = False) -> Dict[str, float]: """ - Return degrees of nodes. + Return (weighted) degrees of nodes. Args: - mode: `in` or `out` to calculate the in- or out-degree for + mode: `in` or `out` to calculate in- or out-degree for directed networks. - + edge_attr: Optional numerical edge attribute that will + be used to compute weighted degrees + return_tensor: if True the function returns a degree tensor, if False (default) + a dictionary will be returned that can be indexed by nodes Returns: - dict: dictionary containing degrees of nodes + dict: dictionary containing node degrees """ if mode == "in": - d = torch_geometric.utils.degree(self.data.edge_index[1], num_nodes=self.n, dtype=torch.int) + if not edge_attr: + d = torch_geometric.utils.degree(self.data.edge_index[1], num_nodes=self.n, dtype=torch.int) + else: + edge_weight = getattr(self.data, edge_attr, None) + d = scatter(edge_weight, self.data.edge_index[1], dim=0, dim_size=self.data.num_nodes, reduce="sum") + elif mode == "out": + if not edge_attr: + d = torch_geometric.utils.degree(self.data.edge_index[0], num_nodes=self.n, dtype=torch.int) + else: + edge_weight = getattr(self.data, edge_attr, None) + d = scatter(edge_weight, self.data.edge_index[0], dim=0, dim_size=self.data.num_nodes, reduce="sum") + if return_tensor: + return d else: - d = torch_geometric.utils.degree(self.data.edge_index[0], num_nodes=self.n, dtype=torch.int) - return {self.mapping.to_id(i): d[i].item() for i in range(self.n)} + return {self.mapping.to_id(i): d[i].item() for i in range(self.n)} - def weighted_outdegrees(self) -> torch.Tensor: - """ - Compute the weighted outdegrees of each node in the graph. + # def weighted_outdegrees(self) -> torch.Tensor: + # """ + # Compute the weighted outdegrees of each node in the graph. - Args: - graph (Graph): pathpy graph object. + # Args: + # graph (Graph): pathpy graph object. - Returns: - tensor: Weighted outdegrees of nodes. - """ - edge_weight = getattr(self.data, "edge_weight", None) - if edge_weight is None: - edge_weight = torch.ones(self.data.num_edges, device=self.data.edge_index.device) - weighted_outdegree = scatter( - edge_weight, self.data.edge_index[0], dim=0, dim_size=self.data.num_nodes, reduce="sum" - ) - return weighted_outdegree + # Returns: + # tensor: Weighted outdegrees of nodes. + # """ + # edge_weight = getattr(self.data, "edge_weight", None) + # if edge_weight is None: + # edge_weight = torch.ones(self.data.num_edges, device=self.data.edge_index.device) + # weighted_outdegree = scatter( + # edge_weight, self.data.edge_index[0], dim=0, dim_size=self.data.num_nodes, reduce="sum" + # ) + # return weighted_outdegree - def transition_probabilities(self) -> torch.Tensor: + def transition_probabilities(self, edge_attr: Any = None) -> torch.Tensor: """ - Compute transition probabilities based on weighted outdegrees. + Compute transition probabilities based on (weighted) outdegrees. + + Args: + edge_attr: Optional name of numerical edge attribute that will + will be used to calculate weighted out-degrees for the + visitation probabilities. Returns: tensor: Transition probabilities. """ - weighted_outdegree = self.weighted_outdegrees() - source_ids = self.data.edge_index[0] - edge_weight = getattr(self.data, "edge_weight", None) - if edge_weight is None: - edge_weight = torch.ones(self.data.num_edges, device=self.data.edge_index.device) + weighted_outdegree = self.degrees(mode="out", edge_attr=edge_attr, return_tensor=True) + source_ids = self.data.edge_index[0] + edge_weight = torch.ones(self.data.num_edges, device=self.data.edge_index.device) + if edge_attr: + edge_weight = getattr(self.data, edge_attr, None) return edge_weight / weighted_outdegree[source_ids] def laplacian(self, normalization: Any = None, edge_attr: Any = None) -> Any: diff --git a/src/pathpyG/core/multi_order_model.py b/src/pathpyG/core/multi_order_model.py index c432a3205..d0c70ec07 100644 --- a/src/pathpyG/core/multi_order_model.py +++ b/src/pathpyG/core/multi_order_model.py @@ -357,7 +357,7 @@ def get_mon_log_likelihood(self, dag_graph: Data, max_order: int = 1) -> float: # Adding the likelihood of highest/stationary order if max_order > 0: - transition_probabilities = self.layers[max_order].transition_probabilities() + transition_probabilities = self.layers[max_order].transition_probabilities(edge_attr="edge_weight") log_transition_probabilities = torch.log(transition_probabilities) llh_by_subpath = log_transition_probabilities * self.layers[max_order].data.edge_weight llh += llh_by_subpath.sum().item() diff --git a/tests/core/test_graph.py b/tests/core/test_graph.py index e1e1352dc..9b05c74b4 100644 --- a/tests/core/test_graph.py +++ b/tests/core/test_graph.py @@ -243,12 +243,12 @@ def test_out_degrees(simple_graph): def test_weighted_outdegrees(simple_graph): # Test on graph without defined weights - out_degrees = simple_graph.weighted_outdegrees() + out_degrees = simple_graph.degrees(mode="out", return_tensor=True) assert out_degrees.equal(torch.tensor([2, 1, 0])) # Test on graph with defined weights simple_graph.data["edge_weight"] = torch.tensor([1, 3, 2]) - out_degrees = simple_graph.weighted_outdegrees() + out_degrees = simple_graph.degrees(mode="out", edge_attr="edge_weight", return_tensor=True) assert out_degrees.equal(torch.tensor([4, 2, 0])) @@ -259,7 +259,7 @@ def test_transition_probabilities(simple_graph): # Test on graph with defined weights simple_graph.data["edge_weight"] = torch.tensor([1, 3, 2]) - transition_probs = simple_graph.transition_probabilities() + transition_probs = simple_graph.transition_probabilities(edge_attr="edge_weight") assert transition_probs.equal(torch.tensor([0.25, 0.75, 1])) From 32798bbf4d460cff9db8cf1a72e725710a545622 Mon Sep 17 00:00:00 2001 From: Ingo Scholtes Date: Thu, 2 Oct 2025 08:56:57 +0000 Subject: [PATCH 2/3] updted tutorial, fixing str type --- docs/tutorial/basic_concepts.ipynb | 209 ++++++++++++++++++++++++----- src/pathpyG/core/graph.py | 2 +- 2 files changed, 175 insertions(+), 36 deletions(-) diff --git a/docs/tutorial/basic_concepts.ipynb b/docs/tutorial/basic_concepts.ipynb index 624319f72..e7dfb8125 100644 --- a/docs/tutorial/basic_concepts.ipynb +++ b/docs/tutorial/basic_concepts.ipynb @@ -171,9 +171,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "EdgeIndex([[0, 0, 1],\n", + " [2, 1, 2]], sparse_size=(4, 4), nnz=3, sort_order=row)\n" + ] + } + ], "source": [ "print(g.data.edge_index)" ] @@ -194,7 +203,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -228,7 +237,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -257,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -274,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -298,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -314,20 +323,20 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "a\n", - "b\n", - "c\n", - "d\n", - "('a', 'c')\n", - "('a', 'b')\n", - "('b', 'c')\n" + "0\n", + "1\n", + "2\n", + "3\n", + "(0, 2)\n", + "(0, 1)\n", + "(1, 2)\n" ] } ], @@ -348,7 +357,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -364,7 +373,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -398,7 +407,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -435,7 +444,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -444,7 +453,7 @@ "tensor([2, 1])" ] }, - "execution_count": 22, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -609,7 +618,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -627,6 +636,136 @@ " print(f\"{v} -> {g.in_degrees[v]}\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `in_degree` and `out_degree` properties are shortcuts to a general `degree` function that can be used to calculate (weighted) in- and outdegrees. " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'a': 0, 'c': 2, 'b': 1}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.degrees(mode='in')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'a': 2, 'c': 0, 'b': 1}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.degrees(mode='out')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Degrees can be alternatively returned as torch.tensors." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 2, 1], dtype=torch.int32)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.degrees(mode='in', return_tensor=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use arbitrary numerical edge attributes that will be used for a weighted (in- or out) degree calculation." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "g.data.edge_weight=torch.tensor([1.0, 2.0, 3.0])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0., 5., 1.])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.degrees(mode='in', edge_attr='edge_weight', return_tensor=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([3., 0., 3.])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g.degrees(mode='out', edge_attr='edge_weight', return_tensor=True)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -636,16 +775,16 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Data(edge_index=[2, 3], num_nodes=3, node_sequence=[3, 1])" + "Data(edge_index=[2, 3], num_nodes=3, node_sequence=[3, 1], edge_weight=[3])" ] }, - "execution_count": 35, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -690,7 +829,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -700,8 +839,8 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[37], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m g \u001b[38;5;241m=\u001b[39m \u001b[43mpp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mGraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_edge_list\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43ma\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mc\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43ma\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mc\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcuda\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m g\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mis_cuda\n", - "File \u001b[0;32m/workspaces/pathpyG/src/pathpyG/core/graph.py:180\u001b[0m, in \u001b[0;36mGraph.from_edge_list\u001b[0;34m(edge_list, is_undirected, mapping, num_nodes, device)\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_nodes \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 177\u001b[0m num_nodes \u001b[38;5;241m=\u001b[39m mapping\u001b[38;5;241m.\u001b[39mnum_ids()\n\u001b[1;32m 179\u001b[0m edge_index \u001b[38;5;241m=\u001b[39m EdgeIndex(\n\u001b[0;32m--> 180\u001b[0m \u001b[43mmapping\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_idxs\u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mT\u001b[38;5;241m.\u001b[39mcontiguous(),\n\u001b[1;32m 181\u001b[0m sparse_size\u001b[38;5;241m=\u001b[39m(num_nodes, num_nodes),\n\u001b[1;32m 182\u001b[0m is_undirected\u001b[38;5;241m=\u001b[39mis_undirected,\n\u001b[1;32m 183\u001b[0m )\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Graph(Data(edge_index\u001b[38;5;241m=\u001b[39medge_index, num_nodes\u001b[38;5;241m=\u001b[39mnum_nodes), mapping\u001b[38;5;241m=\u001b[39mmapping)\n", + "Cell \u001b[0;32mIn[16], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m g \u001b[38;5;241m=\u001b[39m \u001b[43mpp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mGraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_edge_list\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43ma\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mc\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43ma\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mc\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcuda\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m g\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mis_cuda\n", + "File \u001b[0;32m/workspaces/pathpyG/src/pathpyG/core/graph.py:179\u001b[0m, in \u001b[0;36mGraph.from_edge_list\u001b[0;34m(edge_list, is_undirected, mapping, device)\u001b[0m\n\u001b[1;32m 174\u001b[0m mapping \u001b[38;5;241m=\u001b[39m IndexMap(node_ids)\n\u001b[1;32m 176\u001b[0m num_nodes \u001b[38;5;241m=\u001b[39m mapping\u001b[38;5;241m.\u001b[39mnum_ids()\n\u001b[1;32m 178\u001b[0m edge_index \u001b[38;5;241m=\u001b[39m EdgeIndex(\n\u001b[0;32m--> 179\u001b[0m \u001b[43mmapping\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_idxs\u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mT\u001b[38;5;241m.\u001b[39mcontiguous(),\n\u001b[1;32m 180\u001b[0m sparse_size\u001b[38;5;241m=\u001b[39m(num_nodes, num_nodes),\n\u001b[1;32m 181\u001b[0m is_undirected\u001b[38;5;241m=\u001b[39mis_undirected,\n\u001b[1;32m 182\u001b[0m )\n\u001b[1;32m 183\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Graph(Data(edge_index\u001b[38;5;241m=\u001b[39medge_index, num_nodes\u001b[38;5;241m=\u001b[39mnum_nodes), mapping\u001b[38;5;241m=\u001b[39mmapping)\n", "File \u001b[0;32m/workspaces/pathpyG/src/pathpyG/core/index_map.py:361\u001b[0m, in \u001b[0;36mIndexMap.to_idxs\u001b[0;34m(self, nodes, device)\u001b[0m\n\u001b[1;32m 359\u001b[0m shape \u001b[38;5;241m=\u001b[39m nodes\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 360\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid_shape \u001b[38;5;241m==\u001b[39m (\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,):\n\u001b[0;32m--> 361\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid_to_idx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnode\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mnodes\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mflatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mreshape(shape)\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 363\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid_to_idx[\u001b[38;5;28mtuple\u001b[39m(node)] \u001b[38;5;28;01mfor\u001b[39;00m node \u001b[38;5;129;01min\u001b[39;00m nodes\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid_shape)], device\u001b[38;5;241m=\u001b[39mdevice)\u001b[38;5;241m.\u001b[39mreshape(\n\u001b[1;32m 364\u001b[0m shape[: \u001b[38;5;241m-\u001b[39m\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid_shape) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 365\u001b[0m )\n", "File \u001b[0;32m/opt/conda/lib/python3.11/site-packages/torch/cuda/__init__.py:314\u001b[0m, in \u001b[0;36m_lazy_init\u001b[0;34m()\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCUDA_MODULE_LOADING\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m os\u001b[38;5;241m.\u001b[39menviron:\n\u001b[1;32m 313\u001b[0m os\u001b[38;5;241m.\u001b[39menviron[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCUDA_MODULE_LOADING\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLAZY\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 314\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cuda_init\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;66;03m# Some of the queued calls may reentrantly call _lazy_init();\u001b[39;00m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;66;03m# we need to just return without initializing in that case.\u001b[39;00m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;66;03m# However, we must not let any *other* threads in!\u001b[39;00m\n\u001b[1;32m 318\u001b[0m _tls\u001b[38;5;241m.\u001b[39mis_initializing \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n", "\u001b[0;31mRuntimeError\u001b[0m: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 500: named symbol not found" @@ -1024,7 +1163,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -1052,7 +1191,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -1084,7 +1223,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -1118,7 +1257,7 @@ "\n", "\n", "\n", - "
\n", + "
\n", "\n", "