diff --git a/graph_net/paddle/test_compiler.py b/graph_net/paddle/test_compiler.py index 76c3d5610..8beea67fb 100644 --- a/graph_net/paddle/test_compiler.py +++ b/graph_net/paddle/test_compiler.py @@ -43,7 +43,7 @@ def init_env(args): paddle.set_flags({"FLAGS_cudnn_exhaustive_search": 1}) -def get_hardward_name(args): +def get_hardware_name(args): hardware = "unknown" if test_compiler_util.is_gpu_device(args.device): hardware = paddle.device.cuda.get_device_name(0) @@ -149,7 +149,7 @@ def measure_performance(model_call, args, compiler, profile=False): min_trials = int(100 / np.mean(warmup_e2e_times[1:])) trials = max(args.trials, min_trials) - hardware_name = get_hardward_name(args) + hardware_name = get_hardware_name(args) print( f"[Profiling] Using device: {args.device} {hardware_name}, warm up {args.warmup}, trials {trials}", file=sys.stderr, @@ -327,7 +327,7 @@ def test_single_model(args): model.eval() test_compiler_util.print_basic_config( - args, get_hardward_name(args), get_compile_framework_version(args) + args, get_hardware_name(args), get_compile_framework_version(args) ) # Run on eager mode diff --git a/graph_net/paddle/test_reference_device.py b/graph_net/paddle/test_reference_device.py index f1db9bc0f..4c7c60b5b 100644 --- a/graph_net/paddle/test_reference_device.py +++ b/graph_net/paddle/test_reference_device.py @@ -49,7 +49,7 @@ def test_single_model(args): test_compiler_util.print_basic_config( args, - test_compiler.get_hardward_name(args), + test_compiler.get_hardware_name(args), test_compiler.get_compile_framework_version(args), ) diff --git a/graph_net/paddle/test_target_device.py b/graph_net/paddle/test_target_device.py index 9697aea5d..08176680d 100644 --- a/graph_net/paddle/test_target_device.py +++ b/graph_net/paddle/test_target_device.py @@ -89,7 +89,7 @@ def test_single_model(args): test_compiler_util.print_basic_config( args, - test_compiler.get_hardward_name(args), + test_compiler.get_hardware_name(args), test_compiler.get_compile_framework_version(args), ) diff --git a/graph_net/torch/test_target_device.py b/graph_net/torch/test_target_device.py index ee46ceee6..88cc9a650 100644 --- a/graph_net/torch/test_target_device.py +++ b/graph_net/torch/test_target_device.py @@ -67,14 +67,12 @@ def test_single_model(args): ref_dump = utils.get_output_path(args.reference_dir, args.model_path) ref_out = torch.load(str(ref_dump)) ref_log = utils.get_log_path(args.reference_dir, args.model_path) - ref_time_stats = eval_backend_diff.parse_time_stats_from_reference_log(ref_log) + ref_time_stats = test_compiler_util.parse_performance_stats(str(ref_log)) target_dump = utils.get_output_path(target_dir, args.model_path) target_out = torch.load(str(target_dump)) target_log = utils.get_log_path(target_dir, args.model_path) - target_time_stats = eval_backend_diff.parse_time_stats_from_reference_log( - target_log - ) + target_time_stats = test_compiler_util.parse_performance_stats(str(target_log)) eval_backend_diff.compare_correctness(ref_out, target_out, eval_args) test_compiler_util.print_times_and_speedup(args, ref_time_stats, target_time_stats) diff --git a/graph_net_bench/test_compiler_util.py b/graph_net_bench/test_compiler_util.py index 44ccc703e..a83f55994 100644 --- a/graph_net_bench/test_compiler_util.py +++ b/graph_net_bench/test_compiler_util.py @@ -7,6 +7,7 @@ import shutil import base64 import numpy as np +from typing import Dict, Any from dataclasses import dataclass from contextlib import contextmanager @@ -381,3 +382,64 @@ def convert_to_dict(config_str): config = json.loads(config_str) assert isinstance(config, dict), f"config should be a dict. {config_str=}" return config + + +def convert_to_base64(config_dict): + """Convert a dict to base64 encoded JSON string.""" + if config_dict is None: + return "" + config_str = json.dumps(config_dict) + return base64.b64encode(config_str.encode("utf-8")).decode("utf-8") + + +def parse_performance_stats(log_path: str) -> Dict[str, Any]: + """Parse performance statistics from log file. + + Args: + log_path: Path to the log file + + Returns: + Dictionary containing time statistics + + Raises: + FileNotFoundError: If log_path does not exist + ValueError: If performance data cannot be parsed + """ + if not os.path.isfile(log_path): + raise FileNotFoundError(f"Log file not found: {log_path}") + + with open(log_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + # Search backwards for performance data + for line in reversed(lines): + if "[Performance][eager]" in line: + start = line.find("{") + end = line.rfind("}") + if start != -1 and end != -1: + try: + time_stats = json.loads(line[start : end + 1]) + return time_stats + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse performance stats: {e}") + + raise ValueError("No performance statistics found in log file") + + +def extract_log_content(log_path: str) -> str: + """Extract and return the entire content of a log file. + + Args: + log_path: Path to the log file + + Returns: + String containing the log content + + Raises: + FileNotFoundError: If log_path does not exist + """ + if not os.path.isfile(log_path): + raise FileNotFoundError(f"Log file not found: {log_path}") + + with open(log_path, "r", encoding="utf-8") as f: + return f.read() diff --git a/graph_net_bench/torch/eval_backend_diff.py b/graph_net_bench/torch/eval_backend_diff.py index cfa171dc6..ce9f27b2f 100755 --- a/graph_net_bench/torch/eval_backend_diff.py +++ b/graph_net_bench/torch/eval_backend_diff.py @@ -1,247 +1,361 @@ -from . import utils +"""Backend Performance Difference Evaluation Script. + +Compares outputs and performance between reference and target compiler backends. +""" + import argparse -import torch -import sys import os -import os.path +import sys import traceback -import json import types -from graph_net_bench import test_compiler_util +from typing import Any, List, Optional, Tuple + +import torch + from graph_net_bench import path_utils -from .eval_backend_perf import eval_single_model_with_single_backend +from graph_net_bench import test_compiler_util +from .runner import RunnerConfig, RunResult, create_runner +_DEFAULT_REF_DIR = "/tmp/eval_perf_diff/reference" +_DEFAULT_TARGET_DIR = "/tmp/eval_perf_diff/target" -def compare_correctness(expected_out, compiled_out, args): - eager_dtypes = [ - ( - str(x.dtype).replace("torch.", "") - if isinstance(x, torch.Tensor) - else type(x).__name__ - ) - for x in expected_out - ] - compiled_dtypes = [ - ( - str(x.dtype).replace("torch.", "") - if isinstance(x, torch.Tensor) - else type(x).__name__ - ) - for x in compiled_out - ] - # datatype check +def _get_dtype_name(value: Any) -> str: + """Extract dtype name from tensor or type name from other objects.""" + if isinstance(value, torch.Tensor): + return str(value.dtype).replace("torch.", "") + return type(value).__name__ + + +def _extract_dtypes(outputs: List[Any]) -> List[str]: + """Extract dtype/type names from a list of outputs.""" + return [_get_dtype_name(x) for x in outputs] + + +def compare_correctness( + expected_out: List[torch.Tensor], + compiled_out: List[torch.Tensor], + args, +) -> None: + """Compare correctness between expected and compiled outputs. + + Args: + expected_out: List of expected output tensors. + compiled_out: List of compiled output tensors. + args: Arguments containing log_prompt and other settings. + """ + eager_dtypes = _extract_dtypes(expected_out) + compiled_dtypes = _extract_dtypes(compiled_out) + type_match = test_compiler_util.check_output_datatype( args, eager_dtypes, compiled_dtypes ) - - if type_match: - test_compiler_util.check_equal( - args, - expected_out, - compiled_out, - cmp_equal_func=get_cmp_equal, - ) - - test_compiler_util.check_allclose( - args, - expected_out, - compiled_out, - cmp_all_close_func=get_cmp_all_close, - cmp_max_diff_func=get_cmp_max_diff, - cmp_mean_diff_func=get_cmp_mean_diff, - ) + if not type_match: + return + + test_compiler_util.check_equal( + args, + expected_out, + compiled_out, + cmp_equal_func=get_cmp_equal, + ) + test_compiler_util.check_allclose( + args, + expected_out, + compiled_out, + cmp_all_close_func=get_cmp_all_close, + cmp_max_diff_func=get_cmp_max_diff, + cmp_mean_diff_func=get_cmp_mean_diff, + ) -def get_cmp_equal(expected_out, compiled_out): +def get_cmp_equal( + expected_out: List[torch.Tensor], compiled_out: List[torch.Tensor] +) -> str: + """Get space-separated string of equality check results (1=equal, 0=not).""" return " ".join( str(int(torch.equal(a, b))) for a, b in zip(expected_out, compiled_out) ) -def get_cmp_all_close(expected_out, compiled_out, atol, rtol): +def get_cmp_all_close( + expected_out: List[torch.Tensor], + compiled_out: List[torch.Tensor], + atol: float, + rtol: float, +) -> str: + """Get space-separated string of allclose check results.""" return " ".join( str(int(torch.allclose(a, b, atol=atol, rtol=rtol))) for a, b in zip(compiled_out, expected_out) ) -def get_cmp_max_diff(expected_out, compiled_out): +def _compute_abs_diff(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Compute absolute difference, converting to float for LongTensor compatibility.""" + return torch.abs(a.float() - b.float()) + + +def get_cmp_max_diff( + expected_out: List[torch.Tensor], compiled_out: List[torch.Tensor] +) -> str: + """Get space-separated string of max absolute differences.""" + return " ".join( + str(torch.max(_compute_abs_diff(a, b)).item()) + for a, b in zip(expected_out, compiled_out) + ) + + +def get_cmp_mean_diff( + expected_out: List[torch.Tensor], compiled_out: List[torch.Tensor] +) -> str: + """Get space-separated string of mean absolute differences.""" return " ".join( - # Transform to float to handle LongTensor output of some models, which cannnot be processed with torch.max(). - str(torch.max(torch.abs(a.float() - b.float())).item()) + str(torch.mean(_compute_abs_diff(a, b)).item()) for a, b in zip(expected_out, compiled_out) ) -def get_cmp_mean_diff(expected_out, compiled_out): +def _count_diff_elements( + a: torch.Tensor, b: torch.Tensor, atol: float, rtol: float +) -> int: + """Count number of differing elements between two tensors.""" + if a.is_floating_point() and b.is_floating_point(): + return torch.sum(~torch.isclose(a, b, atol=atol, rtol=rtol)).item() + return torch.sum(a != b).item() + + +def get_cmp_diff_count( + expected_out: List[torch.Tensor], + compiled_out: List[torch.Tensor], + atol: float, + rtol: float, +) -> str: + """Get space-separated string of element difference counts.""" return " ".join( - # To handle LongTensor - str(torch.mean(torch.abs(a.float() - b.float())).item()) + str(_count_diff_elements(a, b, atol, rtol)) for a, b in zip(expected_out, compiled_out) ) -def get_cmp_diff_count(expected_out, compiled_out, atol, rtol): - results = [] - for a, b in zip(expected_out, compiled_out): - # To handle LongTensor - if a.is_floating_point() and b.is_floating_point(): - diff_count = torch.sum(~torch.isclose(a, b, atol=atol, rtol=rtol)).item() - else: - diff_count = torch.sum(a != b).item() - results.append(str(diff_count)) - return " ".join(results) +def _has_model_file(path: str) -> bool: + """Check if directory contains model.py.""" + return os.path.exists(os.path.join(path, "model.py")) -def parse_time_stats_from_reference_log(log_path): - assert os.path.isfile( - log_path - ), f"{log_path} does not exist or is not a regular file." +def _get_model_paths_from_list( + model_path_list: str, model_path_prefix: str +) -> List[str]: + """Get model paths from a list file with prefix.""" + assert os.path.isdir(model_path_prefix), f"Not a directory: {model_path_prefix}" + assert os.path.isfile(model_path_list), f"Not a file: {model_path_list}" + + test_samples = test_compiler_util.get_allow_samples( + model_path_list, model_path_prefix + ) + return [ + os.path.join(model_path_prefix, rel_path) + for rel_path in test_samples + if _has_model_file(os.path.join(model_path_prefix, rel_path)) + ] - with open(log_path, "r", encoding="utf-8") as f: - lines = f.readlines() - for line in reversed(lines): - if "[Performance][eager]" in line: - start = line.find("{") - end = line.rfind("}") - time_stats = json.loads(line[start : end + 1]) - return time_stats +def _get_model_paths_from_dir( + model_path: str, model_path_list: Optional[str], model_path_prefix: Optional[str] +) -> List[str]: + """Get model paths by recursively scanning a directory.""" + assert os.path.isdir(model_path), f"Not a directory: {model_path}" + + test_samples = test_compiler_util.get_allow_samples( + model_path_list, model_path_prefix + ) + all_paths = path_utils.get_recursively_model_path(model_path) -def _get_model_paths(args, model_path_prefix, use_model_list): + if test_samples is None: + return list(all_paths) + return [p for p in all_paths if os.path.abspath(p) in test_samples] + + +def _get_model_paths( + args, model_path_prefix: Optional[str], use_model_list: bool +) -> List[str]: + """Get list of model paths based on configuration.""" if use_model_list: - assert os.path.isdir(model_path_prefix) and os.path.isfile(args.model_path_list) + return _get_model_paths_from_list(args.model_path_list, model_path_prefix) + return _get_model_paths_from_dir( + args.model_path, args.model_path_list, model_path_prefix + ) - test_samples = test_compiler_util.get_allow_samples( - args.model_path_list, model_path_prefix - ) - model_paths = [ - os.path.join(model_path_prefix, rel_model_path) - for rel_model_path in test_samples - if os.path.exists( - os.path.join(model_path_prefix, rel_model_path, "model.py") - ) - ] - else: - assert os.path.isdir(args.model_path) - - test_samples = test_compiler_util.get_allow_samples( - args.model_path_list, model_path_prefix - ) - model_paths = [ - model_path - for model_path in path_utils.get_recursively_model_path(args.model_path) - if test_samples is None or os.path.abspath(model_path) in test_samples - ] - return model_paths +def _create_model_args( + model_path: str, reference_config: str, target_config: str +) -> argparse.Namespace: + """Create namespace for single model evaluation.""" + return argparse.Namespace( + model_path=model_path, + model_path_list=None, + reference_config=reference_config, + target_config=target_config, + ) -def _create_model_args(model_path, reference_config, target_config): - args = argparse.Namespace() - args.model_path = model_path - args.model_path_list = None - args.reference_config = reference_config - args.target_config = target_config - return args +def _eval_single_model_safe(model_args: argparse.Namespace) -> bool: + """Evaluate single model with exception handling. + Returns: + True if evaluation succeeded, False otherwise. + """ + try: + eval_single_model(model_args) + return True + except KeyboardInterrupt: + print("KeyboardInterrupt") + sys.exit(1) + except Exception: + print("\n--- Full Traceback ---") + traceback.print_exc() + return False + + +def _print_evaluation_summary(total_count: int, failed_samples: List[str]) -> None: + """Print summary of multi-model evaluation.""" + print( + f"Totally {total_count} verified samples, failed {len(failed_samples)} samples.", + file=sys.stderr, + flush=True, + ) + for model_path in failed_samples: + print(f"- {model_path}", file=sys.stderr, flush=True) -def eval_multi_models(args, model_path_prefix=None, use_model_list=False): - module_name = os.path.splitext(os.path.basename(__file__))[0] +def eval_multi_models( + args, + model_path_prefix: Optional[str] = None, + use_model_list: bool = False, +) -> None: + """Evaluate multiple models and collect results.""" + module_name = os.path.splitext(os.path.basename(__file__))[0] model_paths = _get_model_paths(args, model_path_prefix, use_model_list) - failed_samples = [] + failed_samples: List[str] = [] + for sample_idx, model_path in enumerate(model_paths): print( f"[{sample_idx}] {module_name}, model_path: {model_path}", file=sys.stderr, flush=True, ) - - model_args = argparse.Namespace() - model_args.model_path = model_path - model_args.model_path_list = None - model_args.reference_config = args.reference_config - model_args.target_config = args.target_config - - try: - eval_single_model(model_args) - success = True - except KeyboardInterrupt: - print("KeyboardInterrupt") - sys.exit(1) - except Exception: - print("\n--- Full Traceback ---") - traceback.print_exc() - success = False - + model_args = _create_model_args( + model_path, args.reference_config, args.target_config + ) + success = _eval_single_model_safe(model_args) if not success: failed_samples.append(model_path) - print( - f"Totally {len(model_paths)} verified samples, failed {len(failed_samples)} samples.", - file=sys.stderr, - flush=True, + _print_evaluation_summary(len(model_paths), failed_samples) + + +def _parse_runner_configs(args) -> Tuple[RunnerConfig, RunnerConfig]: + """Parse reference and target runner configurations.""" + return ( + RunnerConfig.from_dict( + test_compiler_util.convert_to_dict(args.reference_config) + ), + RunnerConfig.from_dict(test_compiler_util.convert_to_dict(args.target_config)), ) - if failed_samples: - for model_path in failed_samples: - print(f"- {model_path}", file=sys.stderr, flush=True) -def eval_single_model(args): - ref_dir = "/tmp/eval_perf_diff/reference" - target_dir = "/tmp/eval_perf_diff/target" +def _log_runner_info(ref_config: RunnerConfig, target_config: RunnerConfig) -> None: + """Log runner type information.""" + for label, cfg in [("Reference", ref_config), ("Target", target_config)]: + print( + f"[eval_backend_diff] {label} runner: {cfg.strategy.runner_type.value}", + file=sys.stderr, + flush=True, + ) + + +def _run_and_validate( + runner, model_path: str, output_dir: str, label: str +) -> RunResult: + """Run model and validate result.""" + result = runner.run(model_path, output_dir) + if not result.success: + raise RuntimeError(f"{label} run failed: {result.error_message}") + return result + - ref_args = types.SimpleNamespace( - model_path=args.model_path, - output_path=ref_dir, - **test_compiler_util.convert_to_dict(args.reference_config), +def eval_single_model(args) -> None: + """Evaluate single model using Runner abstraction. + + Supports local, process, and remote execution via runner_type in config. + """ + ref_runner_config, target_runner_config = _parse_runner_configs(args) + _log_runner_info(ref_runner_config, target_runner_config) + + ref_runner = create_runner(ref_runner_config) + target_runner = create_runner(target_runner_config) + + ref_result = _run_and_validate( + ref_runner, args.model_path, _DEFAULT_REF_DIR, "Reference" ) - target_args = types.SimpleNamespace( - model_path=args.model_path, - output_path=target_dir, - **test_compiler_util.convert_to_dict(args.target_config), + target_result = _run_and_validate( + target_runner, args.model_path, _DEFAULT_TARGET_DIR, "Target" ) - eval_single_model_with_single_backend(ref_args) - eval_single_model_with_single_backend(target_args) + compare_results(ref_result, target_result, ref_runner_config) - # compare_perf_diff - # A - ref_dump_path = utils.get_output_path(ref_dir, args.model_path) - ref_out = torch.load(str(ref_dump_path)) - ref_log_path = utils.get_log_path(ref_dir, args.model_path) - ref_time_stats = parse_time_stats_from_reference_log(ref_log_path) +def compare_results( + ref_result: RunResult, target_result: RunResult, config: RunnerConfig +) -> None: + """Compare outputs and performance between reference and target results. - # B - target_dump_path = utils.get_output_path(target_dir, args.model_path) - target_out = torch.load(str(target_dump_path)) + Args: + ref_result: Result from reference runner. + target_result: Result from target runner. + config: Runner configuration for logging settings. + """ + if ref_result.outputs is None or target_result.outputs is None: + print("[Warning] Cannot compare: missing outputs", file=sys.stderr) + return - target_log_path = utils.get_log_path(target_dir, args.model_path) - target_time_stats = parse_time_stats_from_reference_log(target_log_path) + dummy_args = types.SimpleNamespace( + log_prompt=config.execution.log_prompt, + compiler=config.execution.compiler, + device=config.execution.device, + ) - compare_correctness(ref_out, target_out, ref_args) + compare_correctness(ref_result.outputs, target_result.outputs, dummy_args) test_compiler_util.print_times_and_speedup( - ref_args, ref_time_stats, target_time_stats + dummy_args, ref_result.time_stats, target_result.time_stats ) -def main(args): +def main(args: argparse.Namespace) -> None: + """Main entry point for backend difference evaluation. + + Args: + args: Parsed command-line arguments. + + Raises: + ValueError: If model_path is invalid. + """ ref_config = test_compiler_util.convert_to_dict(args.reference_config) model_path_prefix = ref_config.get("model_path_prefix") if args.model_path_list and model_path_prefix: eval_multi_models(args, model_path_prefix, use_model_list=True) - elif os.path.isdir(args.model_path): - if path_utils.is_single_model_dir(args.model_path): - eval_single_model(args) - else: - eval_multi_models(args, model_path_prefix, use_model_list=False) - else: + return + + if not os.path.isdir(args.model_path): raise ValueError(f"Invalid model path: {args.model_path}") + if path_utils.is_single_model_dir(args.model_path): + eval_single_model(args) + return + + eval_multi_models(args, model_path_prefix, use_model_list=False) + if __name__ == "__main__": parser = argparse.ArgumentParser( diff --git a/graph_net_bench/torch/eval_backend_perf.py b/graph_net_bench/torch/eval_backend_perf.py index 5c8586f30..30bf2dacb 100644 --- a/graph_net_bench/torch/eval_backend_perf.py +++ b/graph_net_bench/torch/eval_backend_perf.py @@ -1,142 +1,288 @@ -from . import utils +"""Single Backend Performance Evaluation Script.""" + import argparse import importlib.util -import torch -from pathlib import Path -from typing import Type -import sys -import os -import traceback import json -import random -import numpy as np +import os import platform +import random +import sys +import traceback import types from contextlib import redirect_stdout, redirect_stderr -from graph_net_bench.torch.backend.graph_compiler_backend import GraphCompilerBackend +from pathlib import Path +from typing import Callable, Dict, Any, List, Tuple, Type, Optional + +import numpy as np +import torch + +from . import utils from graph_net_bench import test_compiler_util +from graph_net_bench.torch.backend.graph_compiler_backend import GraphCompilerBackend +_ARG_DEFAULTS: Dict[str, Any] = { + "model_path": None, + "output_path": None, + "seed": 123, + "compiler": "inductor", + "device": "cuda", + "op_lib": None, + "warmup": 3, + "trials": 5, + "log_prompt": "graph-net-bench-log", + "model_path_prefix": None, + "backend_config": None, +} -def register_op_lib(op_lib): - if op_lib == "flaggems": - import flag_gems - flag_gems.enable() - else: - pass +def register_op_lib(op_lib: Optional[str]) -> None: + """Register operator library if specified.""" + if op_lib != "flaggems": + return + import flag_gems + flag_gems.enable() -def set_seed(random_seed): + +def set_seed(random_seed: int) -> None: + """Set random seed for reproducibility across all frameworks.""" random.seed(random_seed) np.random.seed(random_seed) torch.manual_seed(random_seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(random_seed) - torch.cuda.manual_seed_all(random_seed) + if not torch.cuda.is_available(): + return + torch.cuda.manual_seed(random_seed) + torch.cuda.manual_seed_all(random_seed) -def get_hardward_name(device): - hardware_name = "unknown" +def get_hardware_name(device: str) -> str: + """Get hardware name based on device type.""" if "cuda" in device: - hardware_name = torch.cuda.get_device_name(device) - elif device == "cpu": - hardware_name = platform.processor() - return hardware_name + return torch.cuda.get_device_name(device) + if device == "cpu": + return platform.processor() + return "unknown" + + +def get_compiler_version(compiler_name: str) -> str: + """Get version string for the given compiler. + Args: + compiler_name: Name of the compiler (e.g., 'inductor', 'tvm'). -def get_compiler_version(compiler): - if compiler in ["inductor", "nope", "unstable_to_stable"]: + Returns: + Version string or 'unknown' if not determinable. + """ + torch_based_compilers = {"inductor", "nope", "unstable_to_stable"} + if compiler_name in torch_based_compilers: return torch.__version__ - elif compiler in ["tvm", "xla", "tensorrt", "bladedisc"]: - # Assuming compiler object has a version attribute - return f"{compiler.capitalize()} {compiler.version}" + # TODO: For external compilers, version detection would require runtime introspection + # which is not reliably available here. Return a placeholder. return "unknown" -def load_class_from_file( - model_path: str, class_name: str, device: str -) -> Type[torch.nn.Module]: - file_path = f"{model_path}/model.py" - file = Path(file_path).resolve() - module_name = file.stem - +def _read_and_modify_model_code(file_path: str, device: str) -> str: + """Read model file and modify code for target device.""" with open(file_path, "r", encoding="utf-8") as f: model_code = f.read() - model_code = utils.modify_code_by_device(model_code, device) + return utils.modify_code_by_device(model_code, device) + + +def _create_module_from_code( + module_name: str, code: str, file_path: Path +) -> types.ModuleType: + """Create a module by executing code.""" spec = importlib.util.spec_from_loader(module_name, loader=None) module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module - compiled_code = compile(model_code, filename=file, mode="exec") + compiled_code = compile(code, filename=file_path, mode="exec") exec(compiled_code, module.__dict__) + return module - model_class = getattr(module, class_name, None) - setattr(model_class, "__graph_net_file_path__", file_path) - setattr(model_class, "__graph_net_device__", device) - return model_class +def load_class_from_file( + model_path: str, class_name: str, device: str +) -> Type[torch.nn.Module]: + """Dynamically load a model class from file. -def get_compiler_backend(args) -> GraphCompilerBackend: - """ - Dynamically load backend class based on args.compiler + Args: + model_path: Directory containing model.py. + class_name: Name of the class to load. + device: Target device for code modification. + + Returns: + The loaded model class with metadata attributes set. + + Raises: + AttributeError: If class_name not found in module. """ - compiler_name = args.compiler.lower() - module_name = f"graph_net_bench.torch.backend.{compiler_name}_backend" + file_path = f"{model_path}/model.py" + resolved_path = Path(file_path).resolve() + module_name = resolved_path.stem - try: - module = __import__(module_name, fromlist=[f"{compiler_name.title()}Backend"]) + model_code = _read_and_modify_model_code(file_path, device) + module = _create_module_from_code(module_name, model_code, resolved_path) - class_name = ( - f"{''.join(part.title() for part in compiler_name.split('_'))}Backend" - ) + model_class = getattr(module, class_name) + model_class.__graph_net_file_path__ = file_path + model_class.__graph_net_device__ = device + return model_class - backend_class = None - if hasattr(module, class_name): - backend_class = getattr(module, class_name) - else: - raise ImportError(f"No valid backend class found in {module_name}") - except ImportError as e: - raise ImportError(f"Failed to import backend module for '{compiler_name}': {e}") +def _build_backend_class_name(compiler_name: str) -> str: + """Convert compiler name to PascalCase backend class name.""" + return "".join(part.title() for part in compiler_name.split("_")) + "Backend" - backend_config = ( - test_compiler_util.convert_to_dict(args.backend_config) - if args.backend_config is not None - else {} - ) - return backend_class(backend_config) +def _load_backend_class(compiler_name: str) -> Type[GraphCompilerBackend]: + """Load backend class by compiler name.""" + module_name = f"graph_net_bench.torch.backend.{compiler_name}_backend" + class_name = _build_backend_class_name(compiler_name) + + module = __import__(module_name, fromlist=[class_name]) + if not hasattr(module, class_name): + raise ImportError( + f"No valid backend class '{class_name}' found in {module_name}" + ) + return getattr(module, class_name) -def get_model(args): - device = "xla" if args.compiler == "xla" else args.device - # device: Torch device object specifying the target device for model loading (e.g., 'cuda', 'cpu', 'xla') - model_class = load_class_from_file( - args.model_path, class_name="GraphModule", device=device - ) - model = model_class().to(torch.device(args.device)) - return model +def get_compiler_backend(args) -> GraphCompilerBackend: + """Dynamically load and instantiate backend class based on args.compiler.""" + backend_class = _load_backend_class(args.compiler.lower()) + backend_config = test_compiler_util.convert_to_dict(args.backend_config) or {} + return backend_class(backend_config) -def get_input_dict(args): - inputs_params = utils.load_converted_from_text(f"{args.model_path}") - params = inputs_params["weight_info"] +def get_model(args) -> torch.nn.Module: + """Load and prepare model for evaluation.""" + load_device = "xla" if args.compiler == "xla" else args.device + model_class = load_class_from_file(args.model_path, "GraphModule", load_device) + return model_class().to(torch.device(args.device)) + + +def _update_tensor_device(params: Dict[str, Any], device: str) -> None: + """Update device info in tensor metadata in-place.""" for tensor_meta in params.values(): if "device" in tensor_meta["info"]: - tensor_meta["info"]["device"] = args.device + tensor_meta["info"]["device"] = device + + +def get_input_dict(args) -> Dict[str, torch.Tensor]: + """Load and prepare input tensors for model evaluation. + + Args: + args: Arguments containing model_path and device settings. + + Returns: + Dictionary mapping parameter names to tensors on target device. + """ + inputs_params = utils.load_converted_from_text(args.model_path) + params = inputs_params["weight_info"] + _update_tensor_device(params, args.device) + + target_device = torch.device(args.device) + return {k: utils.replay_tensor(v).to(target_device) for k, v in params.items()} + + +def _run_warmup(model_call: Callable, warmup_count: int, sync_fn: Callable) -> None: + """Execute warmup runs.""" + for _ in range(warmup_count): + model_call() + sync_fn() + + +def _measure_single_trial_cuda( + model_call: Callable, sync_fn: Callable +) -> Tuple[float, float]: + """Measure a single trial on CUDA device. + + Returns: + Tuple of (e2e_time_ms, gpu_time_ms). + """ + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + duration_box = test_compiler_util.DurationBox(-1) + + with test_compiler_util.naive_timer(duration_box, sync_fn): + start_event.record() + model_call() + end_event.record() + sync_fn() + + gpu_time_ms = start_event.elapsed_time(end_event) + return duration_box.value, gpu_time_ms + + +def _measure_single_trial_cpu(model_call: Callable, sync_fn: Callable) -> float: + """Measure a single trial on CPU or other devices. + + Returns: + End-to-end time in milliseconds. + """ + duration_box = test_compiler_util.DurationBox(-1) + with test_compiler_util.naive_timer(duration_box, sync_fn): + model_call() + return duration_box.value + + +def _run_cuda_trials( + model_call: Callable, trials: int, sync_fn: Callable +) -> Dict[str, Any]: + """Run multiple timing trials on CUDA device.""" + torch.cuda.empty_cache() + e2e_times: List[float] = [] + gpu_times: List[float] = [] + + for i in range(trials): + e2e_time, gpu_time = _measure_single_trial_cuda(model_call, sync_fn) + e2e_times.append(e2e_time) + gpu_times.append(gpu_time) + print( + f"Trial {i + 1}: e2e={e2e_time:.5f} ms, gpu={gpu_time:.5f} ms", + file=sys.stderr, + flush=True, + ) + return { - k: utils.replay_tensor(v).to(torch.device(args.device)) - for k, v in params.items() + "e2e": test_compiler_util.get_timing_stats(e2e_times), + "gpu": test_compiler_util.get_timing_stats(gpu_times), } -def measure_performance(model_call, args, compiler): - stats = {} - outs = model_call() +def _run_cpu_trials( + model_call: Callable, trials: int, sync_fn: Callable +) -> Dict[str, Any]: + """Run multiple timing trials on CPU or other devices.""" + e2e_times: List[float] = [] + + for i in range(trials): + e2e_time = _measure_single_trial_cpu(model_call, sync_fn) + e2e_times.append(e2e_time) + print( + f"Trial {i + 1}: e2e={e2e_time:.5f} ms", + file=sys.stderr, + flush=True, + ) + + return {"e2e": test_compiler_util.get_timing_stats(e2e_times)} - # Warmup runs - for _ in range(args.warmup): - model_call() - compiler.synchronize() + +def measure_performance( + model_call: Callable, args, compiler +) -> Tuple[Any, Dict[str, Any]]: + """Measure model inference performance. + + Args: + model_call: Callable that executes the model. + args: Arguments containing device, warmup, and trials settings. + compiler: Compiler backend with synchronize method. + + Returns: + Tuple of (model_outputs, timing_stats). + """ + outs = model_call() + _run_warmup(model_call, args.warmup, compiler.synchronize) print( f"[Profiling] Warm up {args.warmup}, Trials {args.trials}", @@ -144,58 +290,83 @@ def measure_performance(model_call, args, compiler): flush=True, ) - if "cuda" in args.device: - torch.cuda.empty_cache() - e2e_times = [] - gpu_times = [] - - for i in range(args.trials): - # End-to-end timing (naive_timer) - duration_box = test_compiler_util.DurationBox(-1) - with test_compiler_util.naive_timer(duration_box, compiler.synchronize): - # GPU-only timing (CUDA Events) - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - - model_call() - - end_event.record() - compiler.synchronize() - - gpu_time_ms = start_event.elapsed_time(end_event) - e2e_times.append(duration_box.value) - gpu_times.append(gpu_time_ms) - print( - f"Trial {i + 1}: e2e={duration_box.value:.5f} ms, gpu={gpu_time_ms:.5f} ms", - file=sys.stderr, - flush=True, - ) - - stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times) - stats["gpu"] = test_compiler_util.get_timing_stats(gpu_times) - - else: # CPU or other devices - e2e_times = [] - for i in range(args.trials): - duration_box = test_compiler_util.DurationBox(-1) - with test_compiler_util.naive_timer(duration_box, compiler.synchronize): - model_call() - print( - f"Trial {i + 1}: e2e={duration_box.value:.5f} ms", - file=sys.stderr, - flush=True, - ) - e2e_times.append(duration_box.value) - stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times) + is_cuda = "cuda" in args.device + if is_cuda: + stats = _run_cuda_trials(model_call, args.trials, compiler.synchronize) + else: + stats = _run_cpu_trials(model_call, args.trials, compiler.synchronize) return outs, stats -def eval_single_model_with_single_backend(args): +def _compile_and_benchmark( + args, compiler: GraphCompilerBackend, model: torch.nn.Module, input_dict: Dict +) -> Tuple[bool, Any, Dict[str, Any]]: + """Compile model and run performance benchmark. + + Returns: + Tuple of (success, outputs, time_stats). + """ + try: + compiled_model = compiler(model) + + def model_call(): + return compiled_model(**input_dict) + + outputs, time_stats = measure_performance(model_call, args, compiler) + return True, outputs, time_stats + except Exception as e: + print( + f"Run model failed: {str(e)}\n{traceback.format_exc()}", + file=sys.stderr, + flush=True, + ) + return False, None, {} + + +def _run_evaluation_core(args) -> Tuple[bool, Any, Dict[str, Any]]: + """Core evaluation logic: load model, compile, and benchmark.""" + compiler = get_compiler_backend(args) + input_dict = get_input_dict(args) + model = get_model(args) + model.eval() + + test_compiler_util.print_config( + args, + get_hardware_name(args.device), + get_compiler_version(args.compiler), + ) + + return _compile_and_benchmark(args, compiler, model, input_dict) + + +def _finalize_evaluation( + args, + success: bool, + outputs: Any, + time_stats: Dict[str, Any], + output_dump_path: Path, +) -> None: + """Finalize evaluation: save outputs and print status.""" + test_compiler_util.print_running_status(args, success) + if success: + torch.save(outputs, str(output_dump_path)) + test_compiler_util.print_with_log_prompt( + "[Performance][eager]:", json.dumps(time_stats), args.log_prompt + ) + + +def _print_log_file(log_path: Path) -> None: + """Read and print log file content to stderr.""" + print(Path(log_path).read_text(encoding="utf-8"), file=sys.stderr, flush=True) + + +def eval_single_model_with_single_backend(args) -> None: + """Evaluate a single model with a single compiler backend.""" check_and_complete_args(args) set_seed(args.seed) os.makedirs(args.output_path, exist_ok=True) + log_path = utils.get_log_path(args.output_path, args.model_path) output_dump_path = utils.get_output_path(args.output_path, args.model_path) print(f"Log path: {log_path}", file=sys.stderr, flush=True) @@ -203,66 +374,19 @@ def eval_single_model_with_single_backend(args): with open(log_path, "w", encoding="utf-8") as log_f: with redirect_stdout(log_f), redirect_stderr(log_f): - compiler = get_compiler_backend(args) - - input_dict = get_input_dict(args) - model = get_model(args) - model.eval() - - test_compiler_util.print_config( - args, - get_hardward_name(args.device), - get_compiler_version(args.compiler), - ) - - success = False - time_stats = {} - try: - compiled_model = compiler(model) - - def model_call(): - return compiled_model(**input_dict) - - outputs, time_stats = measure_performance(model_call, args, compiler) - success = True - except Exception as e: - print( - f"Run model failed: {str(e)}\n{traceback.format_exc()}", - file=sys.stderr, - flush=True, - ) - - test_compiler_util.print_running_status(args, success) - if success: - torch.save(outputs, str(output_dump_path)) - test_compiler_util.print_with_log_prompt( - "[Performance][eager]:", json.dumps(time_stats), args.log_prompt - ) - - with open(log_path, "r", encoding="utf-8") as f: - content = f.read() - print(content, file=sys.stderr, flush=True) - - -def check_and_complete_args(args): - """ - Ensure all required arguments are present with default values if missing - """ - defaults = { - "model_path": None, # Model path - "output_path": None, # Log and output directory - "seed": 123, # Random seed - "compiler": "inductor", # Compiler name - "device": "cuda", # Device for testing the compiler (e.g., 'cpu' or 'cuda') - "op_lib": None, # Operator library - "warmup": 3, # Number of warmup steps - "trials": 5, # Number of timing trials - "log_prompt": "graph-net-bench-log", # Log prompt for performance log filtering - "model_path_prefix": None, # Prefix path to model path in args.model-path - "backend_config": None, # backend configuration json - } + success, outputs, time_stats = _run_evaluation_core(args) + _finalize_evaluation(args, success, outputs, time_stats, output_dump_path) + + _print_log_file(log_path) - for key, default in defaults.items(): + +def check_and_complete_args(args) -> None: + """Ensure all required arguments are present with default values if missing. + + Args: + args: Namespace object to be validated and completed in-place. + """ + for key, default in _ARG_DEFAULTS.items(): if not hasattr(args, key): setattr(args, key, default) diff --git a/graph_net_bench/torch/runner/__init__.py b/graph_net_bench/torch/runner/__init__.py new file mode 100644 index 000000000..643f28f91 --- /dev/null +++ b/graph_net_bench/torch/runner/__init__.py @@ -0,0 +1,14 @@ +from .base_runner import BaseRunner, RunResult, RunnerConfig, create_runner +from .local_runner import LocalRunner +from .process_runner import ProcessRunner +from .remote_runner import RemoteRunner + +__all__ = [ + "BaseRunner", + "RunResult", + "RunnerConfig", + "LocalRunner", + "ProcessRunner", + "RemoteRunner", + "create_runner", +] diff --git a/graph_net_bench/torch/runner/base_runner.py b/graph_net_bench/torch/runner/base_runner.py new file mode 100644 index 000000000..809f5f281 --- /dev/null +++ b/graph_net_bench/torch/runner/base_runner.py @@ -0,0 +1,152 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Tuple +from pathlib import Path +from enum import Enum + + +class RunnerType(Enum): + LOCAL = "local" + PROCESS = "process" + REMOTE = "remote" + + +@dataclass +class ExecutionConfig: + """Configuration specific to model execution.""" + + compiler: str = "inductor" + device: str = "cuda" + op_lib: str = "default" + warmup: int = 5 + trials: int = 10 + seed: int = 123 + log_prompt: str = "graph-net-runner-log" + backend_config: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + k: v + for k, v in self.__dict__.items() + if v is not None and not k.startswith("_") + } + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "ExecutionConfig": + return cls(**{k: v for k, v in d.items() if hasattr(cls, k)}) + + +@dataclass +class RunnerStrategyConfig: + """Configuration for runner strategy selection.""" + + runner_type: RunnerType = RunnerType.LOCAL + remote_machine: str = "localhost" + remote_port: int = 50052 + subprocess_timeout: int = 600 + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "RunnerStrategyConfig": + runner_type_str = d.get("runner_type", "local") + try: + runner_type = RunnerType(runner_type_str.lower()) + except ValueError: + runner_type = RunnerType.LOCAL + + return cls( + runner_type=runner_type, + remote_machine=d.get("machine", "localhost"), + remote_port=d.get("port", 50052), + subprocess_timeout=d.get("subprocess_timeout", 600), + ) + + +@dataclass +class RunnerConfig: + """Unified configuration combining execution and strategy configs.""" + + execution: ExecutionConfig + strategy: RunnerStrategyConfig + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "RunnerConfig": + execution_config = ExecutionConfig.from_dict(d) + strategy_config = RunnerStrategyConfig.from_dict(d) + return cls(execution=execution_config, strategy=strategy_config) + + def to_dict(self) -> Dict[str, Any]: + return { + **self.execution.to_dict(), + "runner_type": self.strategy.runner_type.value, + "machine": self.strategy.remote_machine, + "port": self.strategy.remote_port, + "subprocess_timeout": self.strategy.subprocess_timeout, + } + + +@dataclass +class RunResult: + """Result of a single backend run.""" + + success: bool = False + outputs: Optional[Tuple[Any, ...]] = None + time_stats: Dict[str, Any] = field(default_factory=dict) + log_content: str = "" + error_message: str = "" + + output_path: Optional[Path] = None + log_path: Optional[Path] = None + + +class BaseRunner(ABC): + """Abstract base class for model execution runners.""" + + def __init__(self, config: RunnerConfig): + self.config = config + + @abstractmethod + def run(self, model_path: str, output_dir: str) -> RunResult: + """ + Execute model evaluation and return results. + + Args: + model_path: Path to model directory (containing model.py, graph_net.json, etc.) + output_dir: Directory to store outputs and logs + + Returns: + RunResult containing outputs, timing stats, and logs + """ + pass + + def _get_output_path(self, output_dir: str, model_path: str) -> Path: + from graph_net_bench.torch import utils + + return Path(utils.get_output_path(output_dir, model_path)) + + def _get_log_path(self, output_dir: str, model_path: str) -> Path: + from graph_net_bench.torch import utils + + return Path(utils.get_log_path(output_dir, model_path)) + + +def _get_runner_class(runner_type: RunnerType) -> type: + """Get runner class by type with lazy imports.""" + if runner_type == RunnerType.LOCAL: + from .local_runner import LocalRunner + + return LocalRunner + if runner_type == RunnerType.PROCESS: + from .process_runner import ProcessRunner + + return ProcessRunner + if runner_type == RunnerType.REMOTE: + from .remote_runner import RemoteRunner + + return RemoteRunner + raise ValueError(f"Unknown runner_type: {runner_type}") + + +def create_runner(config: RunnerConfig) -> BaseRunner: + """Factory function to create appropriate runner based on config.""" + runner_class = _get_runner_class(config.strategy.runner_type) + return runner_class(config) diff --git a/graph_net_bench/torch/runner/local_runner.py b/graph_net_bench/torch/runner/local_runner.py new file mode 100644 index 000000000..1f21dddc0 --- /dev/null +++ b/graph_net_bench/torch/runner/local_runner.py @@ -0,0 +1,144 @@ +"""Local runner for in-process model evaluation.""" + +import json +import os +import sys +import traceback +import types +from io import StringIO +from contextlib import redirect_stdout, redirect_stderr +from pathlib import Path +from typing import Any + +import torch + +from .base_runner import BaseRunner, RunResult, RunnerConfig + + +def _write_log_file(log_path: Path, content: str) -> None: + """Write log content to file.""" + with open(log_path, "w", encoding="utf-8") as f: + f.write(content) + + +def _create_eval_args( + model_path: str, output_dir: str, config: RunnerConfig +) -> types.SimpleNamespace: + """Create evaluation arguments from config.""" + return types.SimpleNamespace( + model_path=model_path, + output_path=output_dir, + seed=config.execution.seed, + compiler=config.execution.compiler, + device=config.execution.device, + op_lib=config.execution.op_lib, + warmup=config.execution.warmup, + trials=config.execution.trials, + log_prompt=config.execution.log_prompt, + backend_config=config.execution.backend_config, + ) + + +class LocalRunner(BaseRunner): + """Execute model evaluation in the current process.""" + + def run(self, model_path: str, output_dir: str) -> RunResult: + os.makedirs(output_dir, exist_ok=True) + + log_path = self._get_log_path(output_dir, model_path) + output_path = self._get_output_path(output_dir, model_path) + eval_args = _create_eval_args(model_path, output_dir, self.config) + + log_buffer = StringIO() + result = RunResult(output_path=output_path, log_path=log_path) + + self._execute_with_logging(eval_args, result, log_buffer) + self._finalize_result(result, log_buffer, log_path) + + return result + + def _execute_with_logging( + self, + eval_args: types.SimpleNamespace, + result: RunResult, + log_buffer: StringIO, + ) -> None: + """Execute evaluation with output redirection.""" + from graph_net_bench.torch import eval_backend_perf + + try: + eval_backend_perf.register_op_lib(self.config.execution.op_lib) + eval_backend_perf.set_seed(self.config.execution.seed) + with redirect_stdout(log_buffer), redirect_stderr(log_buffer): + self._run_evaluation(eval_args, result) + except Exception as e: + result.success = False + result.error_message = f"{str(e)}\n{traceback.format_exc()}" + log_buffer.write(f"\n[ERROR] {result.error_message}\n") + + def _finalize_result( + self, result: RunResult, log_buffer: StringIO, log_path: Path + ) -> None: + """Finalize result: save log and print to stderr.""" + result.log_content = log_buffer.getvalue() + _write_log_file(log_path, result.log_content) + print(result.log_content, file=sys.stderr, flush=True) + + def _run_evaluation(self, args: types.SimpleNamespace, result: RunResult) -> None: + """Run model evaluation and populate result.""" + from graph_net_bench.torch import eval_backend_perf + + compiler, model, input_dict = self._prepare_model(args) + self._log_config(args) + + compiled_model = compiler(model) + + def model_call(): + return compiled_model(**input_dict) + + outputs, time_stats = eval_backend_perf.measure_performance( + model_call, args, compiler + ) + + self._populate_result(result, outputs, time_stats) + self._log_completion(args, time_stats) + + def _prepare_model(self, args: types.SimpleNamespace) -> tuple: + """Prepare compiler, model, and inputs.""" + from graph_net_bench.torch import eval_backend_perf + + compiler = eval_backend_perf.get_compiler_backend(args) + input_dict = eval_backend_perf.get_input_dict(args) + model = eval_backend_perf.get_model(args) + model.eval() + return compiler, model, input_dict + + def _log_config(self, args: types.SimpleNamespace) -> None: + """Log configuration information.""" + from graph_net_bench.torch import eval_backend_perf + from graph_net_bench import test_compiler_util + + test_compiler_util.print_config( + args, + eval_backend_perf.get_hardware_name(args.device), + eval_backend_perf.get_compiler_version(args.compiler), + ) + + def _populate_result( + self, result: RunResult, outputs: Any, time_stats: dict + ) -> None: + """Populate result with outputs and stats.""" + result.success = True + result.outputs = outputs + result.time_stats = time_stats + if result.output_path: + torch.save(outputs, str(result.output_path)) + + def _log_completion(self, args: types.SimpleNamespace, time_stats: dict) -> None: + """Log completion status and performance stats.""" + from graph_net_bench import test_compiler_util + + test_compiler_util.print_running_status(args, True) + test_compiler_util.print_with_log_prompt( + "[Performance][eager]:", json.dumps(time_stats), args.log_prompt + ) diff --git a/graph_net_bench/torch/runner/process_runner.py b/graph_net_bench/torch/runner/process_runner.py new file mode 100644 index 000000000..1a48fec8e --- /dev/null +++ b/graph_net_bench/torch/runner/process_runner.py @@ -0,0 +1,131 @@ +"""Process runner for subprocess-based model evaluation.""" + +import os +import subprocess +import sys +from pathlib import Path +from typing import Dict + +import torch + +from .base_runner import BaseRunner, RunResult + + +def _get_env_with_pythonpath() -> Dict[str, str]: + """Get environment with PYTHONPATH set to repo root.""" + env = os.environ.copy() + repo_root = Path(__file__).resolve().parents[3] + env["PYTHONPATH"] = f"{repo_root}:{env.get('PYTHONPATH', '')}" + return env + + +class ProcessRunner(BaseRunner): + """Execute model evaluation in a separate subprocess on the local machine.""" + + def run(self, model_path: str, output_dir: str) -> RunResult: + os.makedirs(output_dir, exist_ok=True) + + result = RunResult( + output_path=self._get_output_path(output_dir, model_path), + log_path=self._get_log_path(output_dir, model_path), + ) + + cmd = self._build_command(model_path, output_dir) + print(f"[ProcessRunner] Executing: {cmd}", file=sys.stderr, flush=True) + + self._execute_subprocess(cmd, result, output_dir, model_path) + print(result.log_content, file=sys.stderr, flush=True) + + return result + + def _execute_subprocess( + self, cmd: str, result: RunResult, output_dir: str, model_path: str + ) -> None: + """Execute subprocess and handle results.""" + try: + proc = self._run_process(cmd) + result.log_content = proc.stderr or "" + self._handle_process_result(proc, result, output_dir, model_path) + except subprocess.TimeoutExpired as e: + result.success = False + result.error_message = f"Process timed out: {e}" + except Exception as e: + result.success = False + result.error_message = f"Process execution failed: {e}" + + def _run_process(self, cmd: str) -> subprocess.CompletedProcess: + """Run subprocess with configured timeout.""" + return subprocess.run( + cmd, + shell=True, + env=_get_env_with_pythonpath(), + capture_output=True, + text=True, + timeout=self.config.strategy.subprocess_timeout, + ) + + def _handle_process_result( + self, + proc: subprocess.CompletedProcess, + result: RunResult, + output_dir: str, + model_path: str, + ) -> None: + """Handle subprocess completion result.""" + if proc.returncode != 0: + result.success = False + result.error_message = ( + f"Process exited with code {proc.returncode}\n" + f"stdout: {proc.stdout}\n" + f"stderr: {proc.stderr}" + ) + return + result.success = True + self._parse_result(result, output_dir, model_path) + + def _build_command(self, model_path: str, output_dir: str) -> str: + """Build subprocess command string.""" + from graph_net_bench import test_compiler_util + + config_str = test_compiler_util.convert_to_base64(self.config.to_dict()) + cmd_parts = [ + sys.executable, + "-m", + "graph_net_bench.torch.eval_backend_perf", + "--model-path", + model_path, + "--output-path", + output_dir, + "--config", + config_str, + ] + return " ".join(cmd_parts) + + def _parse_result( + self, result: RunResult, output_dir: str, model_path: str + ) -> None: + """Parse outputs and logs from subprocess result.""" + self._load_outputs(result) + self._parse_log(result) + + def _load_outputs(self, result: RunResult) -> None: + """Load model outputs from file.""" + if not result.output_path or not result.output_path.exists(): + return + try: + result.outputs = torch.load(str(result.output_path)) + except Exception as e: + result.error_message += f"\nFailed to load outputs: {e}" + + def _parse_log(self, result: RunResult) -> None: + """Parse log file for content and timing stats.""" + if not result.log_path or not result.log_path.exists(): + return + from graph_net_bench import test_compiler_util + + try: + log_path_str = str(result.log_path) + result.log_content = test_compiler_util.extract_log_content(log_path_str) + result.time_stats = test_compiler_util.parse_performance_stats(log_path_str) + except Exception as e: + result.error_message += f"\nFailed to parse log: {e}" diff --git a/graph_net_bench/torch/runner/remote_runner.py b/graph_net_bench/torch/runner/remote_runner.py new file mode 100644 index 000000000..c7c371278 --- /dev/null +++ b/graph_net_bench/torch/runner/remote_runner.py @@ -0,0 +1,181 @@ +"""Remote runner for gRPC-based model evaluation.""" + +import os +import sys +import traceback +from pathlib import Path +from typing import Dict, Optional + +import torch + +from .base_runner import BaseRunner, RunResult + + +def _find_file_by_extension( + files_dict: Dict[str, bytes], expected_name: Optional[str], extension: str +) -> Optional[str]: + """Find file in dict by expected name or by extension if only one exists.""" + if expected_name and expected_name in files_dict: + return expected_name + available = sorted(k for k in files_dict.keys() if k.endswith(extension)) + if len(available) == 1: + return available[0] + return None + + +def _save_bytes_to_file(path: Path, content: bytes) -> None: + """Save bytes content to file.""" + with open(path, "wb") as f: + f.write(content) + + +class RemoteRunner(BaseRunner): + """Execute model evaluation on a remote machine via gRPC.""" + + def run(self, model_path: str, output_dir: str) -> RunResult: + os.makedirs(output_dir, exist_ok=True) + + result = RunResult( + output_path=self._get_output_path(output_dir, model_path), + log_path=self._get_log_path(output_dir, model_path), + ) + + self._execute_remote(model_path, result) + return result + + def _execute_remote(self, model_path: str, result: RunResult) -> None: + """Execute model on remote machine.""" + from graph_net_rpc.sample_remote_executor import SampleRemoteExecutor + + executor = SampleRemoteExecutor( + machine=self.config.strategy.remote_machine, + port=self.config.strategy.remote_port, + ) + + try: + self._log_execution_start() + rpc_cmd = self._build_rpc_command() + print(f"[RemoteRunner] rpc_cmd: {rpc_cmd}", file=sys.stderr, flush=True) + + files_dict = executor.execute(model_path, rpc_cmd) + self._process_remote_output(result, files_dict) + result.success = True + except Exception as e: + result.success = False + result.error_message = ( + f"Remote execution failed: {e}\n{traceback.format_exc()}" + ) + print(result.error_message, file=sys.stderr, flush=True) + finally: + executor.close() + + def _log_execution_start(self) -> None: + """Log remote execution start.""" + machine = self.config.strategy.remote_machine + port = self.config.strategy.remote_port + print( + f"[RemoteRunner] Sending to {machine}:{port}", file=sys.stderr, flush=True + ) + + def _build_rpc_command(self) -> str: + """Build remote execution command string.""" + exec_cfg = self.config.execution + cmd_parts = [ + "python3 -m graph_net.torch.test_reference_device", + '--model-path "$INPUT_WORKSPACE"', + '--reference-dir "$OUTPUT_WORKSPACE"', + f"--compiler {exec_cfg.compiler}", + f"--device {exec_cfg.device}", + f"--op-lib {exec_cfg.op_lib}", + f"--warmup {exec_cfg.warmup}", + f"--trials {exec_cfg.trials}", + f"--seed {exec_cfg.seed}", + ] + if exec_cfg.log_prompt: + cmd_parts.append(f"--log-prompt {exec_cfg.log_prompt}") + if exec_cfg.backend_config: + cmd_parts.append(f"--config {exec_cfg.backend_config}") + return " ".join(cmd_parts) + + def _process_remote_output( + self, result: RunResult, files_dict: Dict[str, bytes] + ) -> None: + """Process files received from remote execution.""" + self._process_log_file(result, files_dict) + self._process_output_file(result, files_dict) + + def _process_log_file( + self, result: RunResult, files_dict: Dict[str, bytes] + ) -> None: + """Process log file from remote output.""" + expected_name = result.log_path.name if result.log_path else None + log_filename = _find_file_by_extension(files_dict, expected_name, ".log") + + if not log_filename: + available = [k for k in files_dict.keys() if k.endswith(".log")] + print( + f"Warning: log not found. expected={expected_name}, available={available}", + file=sys.stderr, + ) + return + + log_bytes = files_dict[log_filename] + self._save_and_parse_log(result, log_bytes) + + def _save_and_parse_log(self, result: RunResult, log_bytes: bytes) -> None: + """Save log file and parse timing stats.""" + + if result.log_path: + _save_bytes_to_file(result.log_path, log_bytes) + + result.log_content = self._decode_log_content(log_bytes) + print(result.log_content, file=sys.stderr, flush=True) + + self._parse_time_stats(result) + + def _decode_log_content(self, log_bytes: bytes) -> str: + """Decode log bytes to string.""" + try: + return log_bytes.decode("utf-8") + except Exception: + return f"[Binary log, {len(log_bytes)} bytes]" + + def _parse_time_stats(self, result: RunResult) -> None: + """Parse performance stats from log file.""" + if not result.log_path: + return + from graph_net_bench import test_compiler_util + + try: + result.time_stats = test_compiler_util.parse_performance_stats( + str(result.log_path) + ) + except Exception as e: + print(f"Warning: Failed to parse time stats: {e}", file=sys.stderr) + + def _process_output_file( + self, result: RunResult, files_dict: Dict[str, bytes] + ) -> None: + """Process output .pth file from remote output.""" + expected_name = result.output_path.name if result.output_path else None + pth_filename = _find_file_by_extension(files_dict, expected_name, ".pth") + + if not pth_filename: + available = [k for k in files_dict.keys() if k.endswith(".pth")] + print( + f"Warning: output not found. expected={expected_name}, available={available}", + file=sys.stderr, + ) + return + + pth_bytes = files_dict[pth_filename] + self._save_and_load_outputs(result, pth_bytes) + + def _save_and_load_outputs(self, result: RunResult, pth_bytes: bytes) -> None: + """Save output file and load tensors.""" + if result.output_path: + _save_bytes_to_file(result.output_path, pth_bytes) + try: + result.outputs = torch.load(str(result.output_path)) + except Exception as e: + print(f"Warning: Failed to load outputs: {e}", file=sys.stderr) diff --git a/graph_net_bench/torch/test_compiler.py b/graph_net_bench/torch/test_compiler.py index 8ee670fd2..52e027c65 100755 --- a/graph_net_bench/torch/test_compiler.py +++ b/graph_net_bench/torch/test_compiler.py @@ -58,7 +58,7 @@ def set_seed(random_seed): torch.cuda.manual_seed_all(random_seed) -def get_hardward_name(args): +def get_hardware_name(args): hardware_name = "unknown" if "cuda" in args.device: hardware_name = torch.cuda.get_device_name(args.device) @@ -146,7 +146,7 @@ def measure_performance(model_call, args, compiler): model_call() compiler.synchronize() - hardware_name = get_hardward_name(args) + hardware_name = get_hardware_name(args) print( f"[Profiling] Using device: {args.device} {hardware_name}, warm up {args.warmup}, trials {args.trials}", file=sys.stderr, @@ -214,7 +214,7 @@ def test_single_model(args): "[Processing]", model_path, args.log_prompt ) test_compiler_util.print_basic_config( - args, get_hardward_name(args), get_compile_framework_version(args) + args, get_hardware_name(args), get_compile_framework_version(args) ) runtime_seed = 1024 diff --git a/test/eval_device_diff_test.sh b/test/eval_device_diff_test.sh new file mode 100755 index 000000000..6840b53a7 --- /dev/null +++ b/test/eval_device_diff_test.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +AI4C_ROOT=$(python3 -c "import graph_net_bench; import os; print(os.path.dirname(os.path.dirname(graph_net_bench.__file__)))") +OUTPUT_PATH=/tmp/workspace_eval_device_diff_test + +mkdir -p "$OUTPUT_PATH" +model_list="$AI4C_ROOT/test/workspace_eval_backend_diff/sample_list.txt" + +# Default remote server settings (can be overridden by environment variables) +REMOTE_MACHINE="${REMOTE_MACHINE:-localhost}" +REMOTE_PORT="${REMOTE_PORT:-50052}" + +python3 -m graph_net_bench.torch.eval_backend_diff \ + --model-path-list $model_list \ + --reference-config $(base64 -w 0 <&1 | tee "$OUTPUT_PATH/validation.log" \ No newline at end of file