Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions graphgen/bases/base_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,17 @@ def community2batch(
if node_data:
nodes_data.append((node, node_data))
edges_data = []
for u, v in edges:
edge_data = g.get_edge(u, v)
for edge in edges:
# Filter out self-loops and invalid edges
if not isinstance(edge, tuple) or len(edge) != 2:
continue
u, v = edge
if u == v:
continue

edge_data = g.get_edge(u, v) or g.get_edge(v, u)
if edge_data:
edges_data.append((u, v, edge_data))
else:
edge_data = g.get_edge(v, u)
if edge_data:
edges_data.append((v, u, edge_data))
return nodes_data, edges_data

@staticmethod
Expand All @@ -61,9 +64,11 @@ def _build_adjacency_list(
"""
adj: dict[str, List[str]] = {n[0]: [] for n in nodes}
edge_set: set[tuple[str, str]] = set()
for e in edges:
adj[e[0]].append(e[1])
adj[e[1]].append(e[0])
edge_set.add((e[0], e[1]))
edge_set.add((e[1], e[0]))
for u, v, _ in edges:
if u == v:
continue
adj[u].append(v)
adj[v].append(u)
edge_set.add((u, v))
edge_set.add((v, u))
return adj, edge_set
4 changes: 1 addition & 3 deletions graphgen/models/partitioner/bfs_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def partition(
if it in used_e:
continue
used_e.add(it)

u, v = it
comm_e.append((u, v))
comm_e.append(tuple(sorted(it)))
cnt += 1
# push nodes that are not visited
for n in it:
Expand Down
7 changes: 4 additions & 3 deletions graphgen/models/partitioner/dfs_partitioner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from collections.abc import Iterable
from typing import Any
from typing import Any, List

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community
Expand Down Expand Up @@ -42,7 +42,8 @@ def partition(
):
continue

comm_n, comm_e = [], []
comm_n: List[str] = []
comm_e: List[tuple[str, str]] = []
stack = [(kind, seed)]
cnt = 0

Expand All @@ -63,7 +64,7 @@ def partition(
if it in used_e:
continue
used_e.add(it)
comm_e.append(tuple(it))
comm_e.append(tuple(sorted(it)))
cnt += 1
# push neighboring nodes
for n in it:
Expand Down
2 changes: 1 addition & 1 deletion graphgen/models/partitioner/ece_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _add_unit(u):
return Community(
id=seed_unit[1],
nodes=list(community_nodes.keys()),
edges=[tuple(edge) for edge in community_edges if isinstance(edge, frozenset) and len(edge)==2],
edges=[tuple(sorted(e)) for e in community_edges]
)

for unit in tqdm(all_units, desc="ECE partition"):
Expand Down
5 changes: 3 additions & 2 deletions graphgen/utils/help_nltk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from typing import Dict, List, Final, Optional
import warnings
import nltk
import jieba

warnings.filterwarnings(
"ignore",
category=UserWarning,
module=r"jieba\._compat"
)
# pylint: disable=wrong-import-position
import jieba


class NLTKHelper:
"""
Expand Down
78 changes: 0 additions & 78 deletions tests/integration_tests/test_engine.py

This file was deleted.