Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 53 additions & 24 deletions tools/python_api/test/benchmark_arrow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import argparse
import contextlib
import os
import subprocess
import sys
Expand All @@ -27,9 +28,14 @@ def parse_args() -> argparse.Namespace:
"and validate deterministic results."
)
)
parser.add_argument("--target-gb", type=float, default=8.0, help="Target Arrow table size in GiB.")
parser.add_argument(
"--chunk-rows", type=int, default=1_000_000, help="Rows per generated Arrow record batch."
"--target-gb", type=float, default=8.0, help="Target Arrow table size in GiB."
)
parser.add_argument(
"--chunk-rows",
type=int,
default=1_000_000,
help="Rows per generated Arrow record batch.",
)
parser.add_argument(
"--filter-cutoff",
Expand All @@ -38,9 +44,14 @@ def parse_args() -> argparse.Namespace:
help="Filter predicate uses n.filter_key < cutoff where filter_key is in [0, 999].",
)
parser.add_argument(
"--threads", type=int, default=max(2, os.cpu_count() or 2), help="Ladybug query worker threads."
"--threads",
type=int,
default=max(2, os.cpu_count() or 2),
help="Ladybug query worker threads.",
)
parser.add_argument(
"--db-path", type=str, default="", help="Optional database path."
)
parser.add_argument("--db-path", type=str, default="", help="Optional database path.")
parser.add_argument(
"--query-runs",
type=int,
Expand All @@ -51,7 +62,9 @@ def parse_args() -> argparse.Namespace:


def read_process_cpu_percent(pid: int) -> float:
    """Return the CPU percentage that ``ps`` reports for process *pid*.

    Invokes ``ps -o %cpu= -p <pid>`` and parses the whitespace-stripped
    output as a float. When ``ps`` prints nothing, ``0.0`` is returned.

    Raises:
        subprocess.CalledProcessError: if ``ps`` exits with a non-zero status.
        ValueError: if the output is not parseable as a float.
    """
    command = ["ps", "-o", "%cpu=", "-p", str(pid)]
    cpu_text = subprocess.check_output(command, text=True).strip()
    return float(cpu_text) if cpu_text else 0.0
Expand Down Expand Up @@ -125,17 +138,17 @@ def build_large_arrow_table(
return pa.Table.from_batches(batches), expected_count, expected_checksum


def measure_query_once(conn: lb.Connection, query: str) -> tuple[float, int, int, float, float]:
def measure_query_once(
conn: lb.Connection, query: str
) -> tuple[float, int, int, float, float]:
samples: list[float] = []
stop_event = threading.Event()

def cpu_sampler() -> None:
    """Record this process's CPU usage into ``samples`` until stopped.

    Loops until ``stop_event`` (from the enclosing scope) is set. Each
    iteration appends one reading from ``read_process_cpu_percent`` to
    ``samples`` and then sleeps for 0.2 seconds.
    """
    own_pid = os.getpid()
    while True:
        if stop_event.is_set():
            return
        # Sampling is best-effort: a failed reading is dropped and the loop
        # carries on with the next one.
        with contextlib.suppress(Exception):
            samples.append(read_process_cpu_percent(own_pid))
        time.sleep(0.2)

sampler_thread = threading.Thread(target=cpu_sampler, daemon=True)
Expand All @@ -160,11 +173,14 @@ def cpu_sampler() -> None:
def main() -> int:
args = parse_args()
if not (0 < args.filter_cutoff <= 1000):
raise ValueError("--filter-cutoff must be in [1, 1000].")
msg = "--filter-cutoff must be in [1, 1000]."
raise ValueError(msg)
if args.chunk_rows <= 0:
raise ValueError("--chunk-rows must be positive.")
msg = "--chunk-rows must be positive."
raise ValueError(msg)
if args.query_runs <= 0:
raise ValueError("--query-runs must be positive.")
msg = "--query-runs must be positive."
raise ValueError(msg)

target_bytes = int(args.target_gb * (1024**3))
db_path_value = args.db_path
Expand All @@ -174,15 +190,23 @@ def main() -> int:
temp_dir = tempfile.TemporaryDirectory(prefix="ladybug_arrow_bench_")
db_path_value = str(Path(temp_dir.name) / "bench.lbdb")

print(f"Building Arrow table (target ~{args.target_gb:.2f} GiB)... and {args.threads} query threads")
print(
f"Building Arrow table (target ~{args.target_gb:.2f} GiB)... and {args.threads} query threads"
)
build_start = time.perf_counter()
table, expected_count, expected_checksum = build_large_arrow_table(
target_bytes=target_bytes, chunk_rows=args.chunk_rows, filter_cutoff=args.filter_cutoff
target_bytes=target_bytes,
chunk_rows=args.chunk_rows,
filter_cutoff=args.filter_cutoff,
)
build_secs = time.perf_counter() - build_start
print(f"Built table with {table.num_rows:,} rows, {table.nbytes / (1024**3):.2f} GiB in {build_secs:.2f}s")
print(
f"Built table with {table.num_rows:,} rows, {table.nbytes / (1024**3):.2f} GiB in {build_secs:.2f}s"
)

db = lb.Database(database_path=db_path_value, buffer_pool_size=256 * 1024 * 1024, read_only=False)
db = lb.Database(
database_path=db_path_value, buffer_pool_size=256 * 1024 * 1024, read_only=False
)
conn = lb.Connection(db, num_threads=args.threads)

table_name = "arrow_cpu_bench"
Expand All @@ -192,8 +216,7 @@ def main() -> int:
conn.create_arrow_table(table_name, table)
else:
print(f"Creating node table '{table_name}' and loading from Arrow...")
conn.execute(
f"""
conn.execute(f"""
CREATE NODE TABLE {table_name}(
id INT64,
filter_key INT32,
Expand All @@ -211,8 +234,7 @@ def main() -> int:
x11 INT64,
PRIMARY KEY(id)
)
"""
)
""")
conn.execute(f"COPY {table_name} FROM $df", {"df": table})

query = f"""
Expand All @@ -236,11 +258,15 @@ def main() -> int:
run_stats: list[tuple[float, float, float]] = []
for run_idx in range(1, args.query_runs + 1):
print(f"Running CPU-intensive Cypher query (run {run_idx})...")
elapsed, actual_count, actual_checksum, avg_cpu, max_cpu = measure_query_once(conn, query)
elapsed, actual_count, actual_checksum, avg_cpu, max_cpu = measure_query_once(
conn, query
)
print(f"Query time: {elapsed:.2f}s")
print(f"CPU usage during query: avg={avg_cpu:.1f}% max={max_cpu:.1f}%")
print(f"Expected cnt={expected_count:,}, actual cnt={actual_count:,}")
print(f"Expected checksum={expected_checksum:,}, actual checksum={actual_checksum:,}")
print(
f"Expected checksum={expected_checksum:,}, actual checksum={actual_checksum:,}"
)

if actual_count != expected_count or actual_checksum != expected_checksum:
if using_arrow_memory_table:
Expand All @@ -250,7 +276,8 @@ def main() -> int:
conn.close()
if temp_dir:
temp_dir.cleanup()
raise AssertionError("Query result validation failed.")
msg = "Query result validation failed."
raise AssertionError(msg)

run_stats.append((elapsed, avg_cpu, max_cpu))

Expand All @@ -261,7 +288,9 @@ def main() -> int:

# >100% indicates more than one core on ps-based accounting.
if max_cpu_overall <= 100.0:
print("Warning: max CPU did not exceed 100%; try larger target-gb/chunk-rows or more threads.")
print(
"Warning: max CPU did not exceed 100%; try larger target-gb/chunk-rows or more threads."
)
else:
print("Observed CPU > 100%, indicating multi-core usage.")

Expand Down
Loading