Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions graph_net/paddle/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def init_env(args):
paddle.set_flags({"FLAGS_cudnn_exhaustive_search": 1})


def get_hardward_name(args):
def get_hardware_name(args):
hardware = "unknown"
if test_compiler_util.is_gpu_device(args.device):
hardware = paddle.device.cuda.get_device_name(0)
Expand Down Expand Up @@ -149,7 +149,7 @@ def measure_performance(model_call, args, compiler, profile=False):
min_trials = int(100 / np.mean(warmup_e2e_times[1:]))
trials = max(args.trials, min_trials)

hardware_name = get_hardward_name(args)
hardware_name = get_hardware_name(args)
print(
f"[Profiling] Using device: {args.device} {hardware_name}, warm up {args.warmup}, trials {trials}",
file=sys.stderr,
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_single_model(args):
model.eval()

test_compiler_util.print_basic_config(
args, get_hardward_name(args), get_compile_framework_version(args)
args, get_hardware_name(args), get_compile_framework_version(args)
)

# Run on eager mode
Expand Down
2 changes: 1 addition & 1 deletion graph_net/paddle/test_reference_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_single_model(args):

test_compiler_util.print_basic_config(
args,
test_compiler.get_hardward_name(args),
test_compiler.get_hardware_name(args),
test_compiler.get_compile_framework_version(args),
)

Expand Down
2 changes: 1 addition & 1 deletion graph_net/paddle/test_target_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_single_model(args):

test_compiler_util.print_basic_config(
args,
test_compiler.get_hardward_name(args),
test_compiler.get_hardware_name(args),
test_compiler.get_compile_framework_version(args),
)

Expand Down
6 changes: 2 additions & 4 deletions graph_net/torch/test_target_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,12 @@ def test_single_model(args):
ref_dump = utils.get_output_path(args.reference_dir, args.model_path)
ref_out = torch.load(str(ref_dump))
ref_log = utils.get_log_path(args.reference_dir, args.model_path)
ref_time_stats = eval_backend_diff.parse_time_stats_from_reference_log(ref_log)
ref_time_stats = test_compiler_util.parse_performance_stats(str(ref_log))

target_dump = utils.get_output_path(target_dir, args.model_path)
target_out = torch.load(str(target_dump))
target_log = utils.get_log_path(target_dir, args.model_path)
target_time_stats = eval_backend_diff.parse_time_stats_from_reference_log(
target_log
)
target_time_stats = test_compiler_util.parse_performance_stats(str(target_log))

eval_backend_diff.compare_correctness(ref_out, target_out, eval_args)
test_compiler_util.print_times_and_speedup(args, ref_time_stats, target_time_stats)
Expand Down
62 changes: 62 additions & 0 deletions graph_net_bench/test_compiler_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import shutil
import base64
import numpy as np
from typing import Dict, Any
from dataclasses import dataclass
from contextlib import contextmanager

Expand Down Expand Up @@ -381,3 +382,64 @@ def convert_to_dict(config_str):
config = json.loads(config_str)
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
return config


def convert_to_base64(config_dict):
"""Convert a dict to base64 encoded JSON string."""
if config_dict is None:
return ""
config_str = json.dumps(config_dict)
return base64.b64encode(config_str.encode("utf-8")).decode("utf-8")


def parse_performance_stats(log_path: str) -> Dict[str, Any]:
"""Parse performance statistics from log file.

Args:
log_path: Path to the log file

Returns:
Dictionary containing time statistics

Raises:
FileNotFoundError: If log_path does not exist
ValueError: If performance data cannot be parsed
"""
if not os.path.isfile(log_path):
raise FileNotFoundError(f"Log file not found: {log_path}")

with open(log_path, "r", encoding="utf-8") as f:
lines = f.readlines()

# Search backwards for performance data
for line in reversed(lines):
if "[Performance][eager]" in line:
start = line.find("{")
end = line.rfind("}")
if start != -1 and end != -1:
try:
time_stats = json.loads(line[start : end + 1])
return time_stats
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse performance stats: {e}")

raise ValueError("No performance statistics found in log file")


def extract_log_content(log_path: str) -> str:
"""Extract and return the entire content of a log file.

Args:
log_path: Path to the log file

Returns:
String containing the log content

Raises:
FileNotFoundError: If log_path does not exist
"""
if not os.path.isfile(log_path):
raise FileNotFoundError(f"Log file not found: {log_path}")

with open(log_path, "r", encoding="utf-8") as f:
return f.read()
Loading