diff --git a/graphgen/bases/base_partitioner.py b/graphgen/bases/base_partitioner.py
index d948e3a7..384c9e4e 100644
--- a/graphgen/bases/base_partitioner.py
+++ b/graphgen/bases/base_partitioner.py
@@ -39,14 +39,17 @@ def community2batch(
             if node_data:
                 nodes_data.append((node, node_data))
         edges_data = []
-        for u, v in edges:
-            edge_data = g.get_edge(u, v)
+        for edge in edges:
+            # Filter out self-loops and invalid edges
+            if not isinstance(edge, tuple) or len(edge) != 2:
+                continue
+            u, v = edge
+            if u == v:
+                continue
+
+            edge_data = g.get_edge(u, v) or g.get_edge(v, u)
             if edge_data:
                 edges_data.append((u, v, edge_data))
-            else:
-                edge_data = g.get_edge(v, u)
-                if edge_data:
-                    edges_data.append((v, u, edge_data))
         return nodes_data, edges_data
 
     @staticmethod
@@ -61,9 +64,11 @@ def _build_adjacency_list(
         """
         adj: dict[str, List[str]] = {n[0]: [] for n in nodes}
         edge_set: set[tuple[str, str]] = set()
-        for e in edges:
-            adj[e[0]].append(e[1])
-            adj[e[1]].append(e[0])
-            edge_set.add((e[0], e[1]))
-            edge_set.add((e[1], e[0]))
+        for u, v, _ in edges:
+            if u == v:
+                continue
+            adj[u].append(v)
+            adj[v].append(u)
+            edge_set.add((u, v))
+            edge_set.add((v, u))
         return adj, edge_set
diff --git a/graphgen/models/partitioner/bfs_partitioner.py b/graphgen/models/partitioner/bfs_partitioner.py
index 994e08e8..a00ad76d 100644
--- a/graphgen/models/partitioner/bfs_partitioner.py
+++ b/graphgen/models/partitioner/bfs_partitioner.py
@@ -63,9 +63,7 @@ def partition(
                 if it in used_e:
                     continue
                 used_e.add(it)
-
-                u, v = it
-                comm_e.append((u, v))
+                comm_e.append(tuple(sorted(it)))
                 cnt += 1
                 # push nodes that are not visited
                 for n in it:
diff --git a/graphgen/models/partitioner/dfs_partitioner.py b/graphgen/models/partitioner/dfs_partitioner.py
index 4d93ad7f..fa2786e6 100644
--- a/graphgen/models/partitioner/dfs_partitioner.py
+++ b/graphgen/models/partitioner/dfs_partitioner.py
@@ -1,6 +1,6 @@
 import random
 from collections.abc import Iterable
-from typing import Any
+from typing import Any, List
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
 from graphgen.bases.datatypes import Community
@@ -42,7 +42,8 @@ def partition(
         ):
             continue
 
-        comm_n, comm_e = [], []
+        comm_n: List[str] = []
+        comm_e: List[tuple[str, str]] = []
         stack = [(kind, seed)]
 
         cnt = 0
@@ -63,7 +64,7 @@ def partition(
                 if it in used_e:
                     continue
                 used_e.add(it)
-                comm_e.append(tuple(it))
+                comm_e.append(tuple(sorted(it)))
                 cnt += 1
                 # push neighboring nodes
                 for n in it:
diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py
index af3af7c7..c2611be3 100644
--- a/graphgen/models/partitioner/ece_partitioner.py
+++ b/graphgen/models/partitioner/ece_partitioner.py
@@ -142,7 +142,7 @@ def _add_unit(u):
         return Community(
             id=seed_unit[1],
             nodes=list(community_nodes.keys()),
-            edges=[tuple(edge) for edge in community_edges if isinstance(edge, frozenset) and len(edge)==2],
+            edges=[tuple(sorted(e)) for e in community_edges]
         )
 
     for unit in tqdm(all_units, desc="ECE partition"):
diff --git a/graphgen/utils/help_nltk.py b/graphgen/utils/help_nltk.py
index c7d5e301..86d55e5f 100644
--- a/graphgen/utils/help_nltk.py
+++ b/graphgen/utils/help_nltk.py
@@ -3,13 +3,14 @@
 from typing import Dict, List, Final, Optional
 import warnings
 
 import nltk
-import jieba
-
 warnings.filterwarnings(
     "ignore",
     category=UserWarning,
     module=r"jieba\._compat"
 )
 
+# pylint: disable=wrong-import-position
+import jieba
+
 class NLTKHelper:
diff --git a/tests/integration_tests/test_engine.py b/tests/integration_tests/test_engine.py
deleted file mode 100644
index 6a389e42..00000000
--- a/tests/integration_tests/test_engine.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import pytest
-
-from graphgen.engine import Context, Engine, op
-
-engine = Engine(max_workers=2)
-
-
-def test_simple_dag(capsys):
-    """Verify the DAG A->B/C->D execution results and print order."""
-    ctx = Context()
-
-    @op("A")
-    def op_a(self, ctx):
-        print("Running A")
-        ctx.set("A", 1)
-
-    @op("B", deps=["A"])
-    def op_b(self, ctx):
-        print("Running B")
-        ctx.set("B", ctx.get("A") + 1)
-
-    @op("C", deps=["A"])
-    def op_c(self, ctx):
-        print("Running C")
-        ctx.set("C", ctx.get("A") + 2)
-
-    @op("D", deps=["B", "C"])
-    def op_d(self, ctx):
-        print("Running D")
-        ctx.set("D", ctx.get("B") + ctx.get("C"))
-
-    # Explicitly list the nodes to run; avoid relying on globals().
-    ops = [op_a, op_b, op_c, op_d]
-    engine.run(ops, ctx)
-
-    # Assert final results.
-    assert ctx["A"] == 1
-    assert ctx["B"] == 2
-    assert ctx["C"] == 3
-    assert ctx["D"] == 5
-
-    # Assert print order: A must run before B and C; D must run after B and C.
-    captured = capsys.readouterr().out.strip().splitlines()
-    assert "Running A" in captured
-    assert "Running B" in captured
-    assert "Running C" in captured
-    assert "Running D" in captured
-
-    a_idx = next(i for i, line in enumerate(captured) if "Running A" in line)
-    b_idx = next(i for i, line in enumerate(captured) if "Running B" in line)
-    c_idx = next(i for i, line in enumerate(captured) if "Running C" in line)
-    d_idx = next(i for i, line in enumerate(captured) if "Running D" in line)
-
-    assert a_idx < b_idx
-    assert a_idx < c_idx
-    assert d_idx > b_idx
-    assert d_idx > c_idx
-
-
-def test_cyclic_detection():
-    """A cyclic dependency should raise ValueError."""
-    ctx = Context()
-
-    @op("X", deps=["Y"])
-    def op_x(self, ctx):
-        pass
-
-    @op("Y", deps=["X"])
-    def op_y(self, ctx):
-        pass
-
-    ops = [op_x, op_y]
-    with pytest.raises(ValueError, match="Cyclic dependencies"):
-        engine.run(ops, ctx)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__, "-v"])