From 260dc69f029dba40faac93f664bf4d2116583000 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Thu, 21 May 2026 11:51:01 +0100 Subject: [PATCH 1/5] [KB] Enable dynamo plugin for KB tests This PR enables a new mode in kernel_bench that allows one to use the Torch Dynamo compiler plugin (CPU only for now). This is similar to the previous imported module: it runs the same pipeline, compiles the same MLIR modules and yields the same results. Major changes: * Factoring some code around import & pipeline - import_model extracted in lh.ingress.torch - make_function_callable an external helper now * Make kernel_bench import the Torch model first - Run torch eager to gather the reference output - import to MLIR or run with torch.compile - reuse the same pipeline in both cases - Compare outputs instead of printing output * Allowing empty input/output shapes for torch-compile - For Dynamo this can be fetched from model and used - Import can probably also, but for a later PR Torch Compile mode enabled in CI with the previous tests. Note: This is not the default yet because benchmarking is not enabled in this mode yet. A follow-up PR will add support and enable it by default. 
--- .../KernelBench/test_kernel_bench.py | 35 +-- lighthouse/ingress/torch/__init__.py | 3 +- lighthouse/ingress/torch/importer.py | 231 +++++++++------- lighthouse/pipeline/driver.py | 27 +- tools/kernel_bench | 248 +++++++++++++----- tools/lh-run | 4 +- 6 files changed, 358 insertions(+), 190 deletions(-) diff --git a/examples/end-to-end/KernelBench/test_kernel_bench.py b/examples/end-to-end/KernelBench/test_kernel_bench.py index c12e8483..3f8f76b1 100755 --- a/examples/end-to-end/KernelBench/test_kernel_bench.py +++ b/examples/end-to-end/KernelBench/test_kernel_bench.py @@ -1,4 +1,5 @@ # RUN: python %s --ci | FileCheck %s +# RUN: python %s --ci --torch-compile | FileCheck %s # REQUIRES: torch # REQUIRES: kernel_bench @@ -115,6 +116,11 @@ def get_flops_per_second(stdout: str, gflops: float) -> float: action=argparse.BooleanOptionalAction, help="Enable CI mode (faster run, fewer kernels). Incompatible with --smoke-test.", ) + Parser.add_argument( + "--torch-compile", + action=argparse.BooleanOptionalAction, + help="Enable TorchScript compilation. Default is False.", + ) Parser.add_argument( "--test", type=str, @@ -157,12 +163,14 @@ def get_flops_per_second(stdout: str, gflops: float) -> float: test["output_shape"], "--pipeline", test["pipeline"], - "--print-tensor=1", + "--print-output", "--seed=42", ] benchmark = args.benchmark and test.get("gflops") is not None if benchmark: command_line += ["--benchmark"] + if args.torch_compile: + command_line += ["--torch-compile"] if args.print_mlir_after_all: command_line += ["--print-mlir-after-all"] if test.get("warning"): @@ -200,22 +208,17 @@ def get_flops_per_second(stdout: str, gflops: float) -> float: if not args.smoke_test: assert result.returncode == 0, "Execution failed" -# CHECK: 1_Square_matrix_multiplication_.mlir -# CHECK: 0.3745{{.*}} 0.9507{{.*}} 0.7319{{.*}} ... 0.2973{{.*}} 0.9243{{.*}} 0.9710{{.*}} -# CHECK: 0.7201{{.*}} 0.9926{{.*}} 0.1208{{.*}} ... 
0.1742{{.*}} 0.3485{{.*}} 0.6436{{.*}} +# CHECK: 1_Square_matrix_multiplication_.py +# CHECK: Success: The output of the compiled model matches the reference output. -# CHECK: 2_Standard_matrix_multiplication_.mlir -# CHECK: 249.78{{.*}} 260.13{{.*}} 249.36{{.*}} ... 261.10{{.*}} 260.49{{.*}} 257.09{{.*}} -# CHECK: 243.56{{.*}} 250.91{{.*}} 252.38{{.*}} ... 260.40{{.*}} 261.56{{.*}} 256.24{{.*}} +# CHECK: 2_Standard_matrix_multiplication_.py +# CHECK: Success: The output of the compiled model matches the reference output. -# CHECK: 3_Batched_matrix_multiplication.mlir -# CHECK: 5.2403{{.*}} 7.7905{{.*}} 6.0769{{.*}} ... 7.8579{{.*}} 6.8890{{.*}} 6.6193{{.*}} -# CHECK: 9.0407{{.*}} 6.3299{{.*}} 5.2003{{.*}} ... 6.2594{{.*}} 6.2980{{.*}} 5.9807{{.*}} +# CHECK: 3_Batched_matrix_multiplication.py +# CHECK: Success: The output of the compiled model matches the reference output. -# CHECK: 4_Matrix_vector_multiplication_.mlir -# CHECK: 264.86{{.*}} -# CHECK: 265.12{{.*}} +# CHECK: 4_Matrix_vector_multiplication_.py +# CHECK: Success: The output of the compiled model matches the reference output. -# CHECK: 5_Matrix_scalar_multiplication.mlir -# CHECK: 0.1750{{.*}} 0.4442{{.*}} 0.3420{{.*}} ... 0.1389{{.*}} 0.4319{{.*}} 0.4538{{.*}} -# CHECK: 0.3365{{.*}} 0.4638{{.*}} 0.0564{{.*}} ... 0.0814{{.*}} 0.1628{{.*}} 0.3007{{.*}} +# CHECK: 5_Matrix_scalar_multiplication.py +# CHECK: Success: The output of the compiled model matches the reference output. 
diff --git a/lighthouse/ingress/torch/__init__.py b/lighthouse/ingress/torch/__init__.py index ea758f9e..7e1382c5 100644 --- a/lighthouse/ingress/torch/__init__.py +++ b/lighthouse/ingress/torch/__init__.py @@ -1,6 +1,6 @@ """Provides functions to convert PyTorch models to MLIR.""" -from .importer import import_from_file, import_from_model +from .importer import import_from_file, import_from_model, import_model from .compile import cpu_backend from .compile import gpu_backend from .compile import TargetDialect @@ -11,4 +11,5 @@ "gpu_backend", "import_from_file", "import_from_model", + "import_model", ] diff --git a/lighthouse/ingress/torch/importer.py b/lighthouse/ingress/torch/importer.py index 26943b11..eb80cdb2 100644 --- a/lighthouse/ingress/torch/importer.py +++ b/lighthouse/ingress/torch/importer.py @@ -29,6 +29,134 @@ from mlir import ir +def import_model( + filepath: str | Path, + model_class_name: str = "Model", + init_args_fn_name: str | None = "get_init_inputs", + init_kwargs_fn_name: str | None = None, + model_init_args: Iterable | None = None, + sample_args_fn_name: str = "get_inputs", + sample_kwargs_fn_name: str | None = None, + sample_args: Iterable | None = None, + state_path: str | Path | None = None, + **kwargs, +) -> str | ir.Module: + """Load a PyTorch nn.Module from a file. + + The function takes a `filepath` to a Python file containing the model definition, + along with the names of functions to get model init arguments and sample inputs. + The function imports the model class on its own and instantiates it. + + Args: + filepath (str | Path): Path to the Python file containing the model definition. + model_class_name (str, optional): The name of the model class in the file. + Defaults to "Model". + init_args_fn_name (str | None, optional): The name of the function in the file + that returns the arguments for initializing the model. If None, the model + is initialized without arguments. Defaults to "get_init_inputs". 
+ init_kwargs_fn_name (str | None, optional): The name of the function in the file + that returns the keyword arguments for initializing the model. If None, the model + is initialized without keyword arguments. + model_init_args (Iterable | None, optional): If provided, these are used directly as + initialization arguments instead of calling ``init_args_fn_name`` from the file. + Useful for overriding hard-coded sizes in the model file. Defaults to None. + sample_args_fn_name (str, optional): The name of the function in the file that + returns the sample input arguments for the model. Defaults to "get_inputs". + sample_kwargs_fn_name (str, optional): The name of the function in the file that + returns the sample keyword input arguments for the model. Defaults to None. + sample_args (Iterable | None, optional): If provided, these are used directly as + sample inputs instead of calling ``sample_args_fn_name`` from the file. + Useful for overriding hard-coded sizes in the model file. Defaults to None. + state_path (str | Path | None, optional): Optional path to a file containing + the model's ``state_dict``. Defaults to None. + **kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function. + + Returns: + torch.nn.Module: The imported PyTorch model. + sample_args: The sample input arguments for the model. + sample_kwargs: The sample keyword input arguments for the model. + + Examples: + Given a file `path/to/model_file.py` with the following content: + ```python + import torch + import torch.nn as nn + + + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(10, 5) + + def forward(self, x): + return self.fc(x) + + + def get_inputs(): + return (torch.randn(1, 10),) + ``` + + The import script would look like: + >>> from lighthouse.ingress.torch_import import import_model + >>> # option 1: get MLIR module as a string + >>> model: nn.Module = import_model( + ... "path/to/model_file.py", + ... 
model_class_name="MyModel", + ... init_args_fn_name=None, + ... ) + """ + if isinstance(filepath, str): + filepath = Path(filepath) + module_name = filepath.stem + + spec = importlib.util.spec_from_file_location(module_name, filepath) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + model = getattr(module, model_class_name, None) + if model is None: + raise ValueError(f"Model class '{model_class_name}' not found in {filepath}") + + model_init_args = ( + maybe_load_and_run_callable( + module, + init_args_fn_name, + default=tuple(), + error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}", + ) + if model_init_args is None + else model_init_args + ) + model_init_kwargs = maybe_load_and_run_callable( + module, + init_kwargs_fn_name, + default={}, + error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}", + ) + sample_args = ( + load_and_run_callable( + module, + sample_args_fn_name, + f"Sample args function '{sample_args_fn_name}' not found in {filepath}", + ) + if sample_args is None + else sample_args + ) + sample_kwargs = maybe_load_and_run_callable( + module, + sample_kwargs_fn_name, + default={}, + error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}", + ) + + nn_model: nn.Module = model(*model_init_args, **model_init_kwargs) + if state_path is not None: + state_dict = torch.load(state_path) + nn_model.load_state_dict(state_dict) + + return nn_model, sample_args, sample_kwargs + + def import_from_model( model: nn.Module, sample_args: Iterable, @@ -118,10 +246,8 @@ def import_from_file( ) -> str | ir.Module: """Load a PyTorch nn.Module from a file and import it into MLIR. - The function takes a `filepath` to a Python file containing the model definition, - along with the names of functions to get model init arguments and sample inputs. 
- The function imports the model class on its own, instantiates it, and passes - it to ``torch_mlir`` to get a MLIR module in the specified `dialect`. + The function calls ``import_model`` to load the model from the given file + and then calls ``import_from_model`` to convert it into an MLIR module. Args: filepath (str | Path): Path to the Python file containing the model definition. @@ -156,94 +282,21 @@ def import_from_file( str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided. Examples: - Given a file `path/to/model_file.py` with the following content: - ```python - import torch - import torch.nn as nn - - - class MyModel(nn.Module): - def __init__(self): - super().__init__() - self.fc = nn.Linear(10, 5) - - def forward(self, x): - return self.fc(x) - - - def get_inputs(): - return (torch.randn(1, 10),) - ``` - - The import script would look like: - >>> from lighthouse.ingress.torch_import import import_from_file - >>> # option 1: get MLIR module as a string - >>> mlir_module: str = import_from_file( - ... "path/to/model_file.py", - ... model_class_name="MyModel", - ... init_args_fn_name=None, - ... dialect="linalg-on-tensors", - ... ) - >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect - >>> # option 2: get MLIR module as an ir.Module - >>> ir_context = ir.Context() - >>> mlir_module_ir: ir.Module = import_from_file( - ... "path/to/model_file.py", - ... model_class_name="MyModel", - ... init_args_fn_name=None, - ... dialect="linalg-on-tensors", - ... ir_context=ir_context, - ... ) + See ``import_model`` and ``import_from_model`` for examples + of the expected content of the model file and how to call this function. 
""" - if isinstance(filepath, str): - filepath = Path(filepath) - module_name = filepath.stem - - spec = importlib.util.spec_from_file_location(module_name, filepath) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - model = getattr(module, model_class_name, None) - if model is None: - raise ValueError(f"Model class '{model_class_name}' not found in {filepath}") - - model_init_args = ( - maybe_load_and_run_callable( - module, - init_args_fn_name, - default=tuple(), - error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}", - ) - if model_init_args is None - else model_init_args - ) - model_init_kwargs = maybe_load_and_run_callable( - module, - init_kwargs_fn_name, - default={}, - error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}", - ) - sample_args = ( - load_and_run_callable( - module, - sample_args_fn_name, - f"Sample args function '{sample_args_fn_name}' not found in {filepath}", - ) - if sample_args is None - else sample_args - ) - sample_kwargs = maybe_load_and_run_callable( - module, - sample_kwargs_fn_name, - default={}, - error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}", + nn_model, sample_args, sample_kwargs = import_model( + filepath=filepath, + model_class_name=model_class_name, + init_args_fn_name=init_args_fn_name, + init_kwargs_fn_name=init_kwargs_fn_name, + model_init_args=model_init_args, + sample_args_fn_name=sample_args_fn_name, + sample_kwargs_fn_name=sample_kwargs_fn_name, + sample_args=sample_args, + state_path=state_path, ) - nn_model: nn.Module = model(*model_init_args, **model_init_kwargs) - if state_path is not None: - state_dict = torch.load(state_path) - nn_model.load_state_dict(state_dict) - return import_from_model( nn_model, sample_args=sample_args, diff --git a/lighthouse/pipeline/driver.py b/lighthouse/pipeline/driver.py index 47fba9e7..1eb32f72 100644 --- a/lighthouse/pipeline/driver.py +++ 
b/lighthouse/pipeline/driver.py @@ -7,6 +7,20 @@ import lighthouse.dialects as lh_dialects +def make_function_callable(module: ir.Module, func_name: str) -> None: + """ + Set the 'llvm.emit_c_interface' attribute of the given function in the module. + This is required to make the function callable from the execution engine. + It has to be called on a @func.func (not an @llvm.func), so should be called + before the LLVM lowering stages are added to the pipeline. + """ + with module.context: + for func in module.body.operations: + if func.sym_name.value == func_name: + func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + break + + class PipelineDriver: """ A simple driver that runs the optimization pipeline on a given workload. @@ -160,19 +174,6 @@ def add_module_stage(self, stage_module: ir.Module) -> None: raise ValueError("Pipeline is fixed. Reset to start again.") self.pipeline.add_transform(stage_module) - def make_function_callable(self, func_name: str) -> None: - """ - Set the 'llvm.emit_c_interface' attribute of the given function in the module. - This is required to make the function callable from the execution engine. - It has to be called on a @func.func (not an @llvm.func), so should be called - before the LLVM lowering stages are added to the pipeline. 
- """ - with self.context: - for func in self.module.body.operations: - if func.sym_name.value == func_name: - func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - break - def reset(self) -> None: """Reset the pipeline to an empty state, allowing for new stages to be added.""" self.pipeline.reset() diff --git a/tools/kernel_bench b/tools/kernel_bench index b6be576a..8ae6ba5d 100755 --- a/tools/kernel_bench +++ b/tools/kernel_bench @@ -3,88 +3,111 @@ import argparse from datetime import datetime import sys -import tempfile import ml_dtypes import numpy as np from pathlib import Path import torch +from mlir import ir from lighthouse.execution.runner import Runner from lighthouse.execution.init import KernelArgumentParser -from lighthouse.pipeline.driver import CompilerDriver +from lighthouse.pipeline.descriptor import Descriptor +from lighthouse.pipeline.driver import PipelineDriver, make_function_callable from lighthouse.schedule import convert_function_results from lighthouse import dialects as lh_dialects from lighthouse import ingress as lh_ingress +from lighthouse.ingress.torch import cpu_backend -def import_kb_module( +def import_torch( filepath: str | Path, model_class_name: str = "Model", - dialect: str = "linalg-on-tensors", sample_args=None, model_init_args=None, -) -> Path: +) -> torch.nn.Module: """ - Imports a PyTorch model from a KernelBench file and converts it to an MLIR module in the specified dialect. - The MLIR module is saved into a file to be passed to the CompilerDrviver for optimization and execution. + Imports a PyTorch model from a KernelBench file and returns the PyTorch module. The Python file must define a Kernel Bench module file. - Return: The path to the generated MLIR module file. 
""" try: - mlir_module = lh_ingress.torch.import_from_file( + model, sample_args, sample_kwargs = lh_ingress.torch.import_model( filepath=filepath, model_class_name=model_class_name, - dialect=dialect, sample_args=sample_args, model_init_args=model_init_args, ) - assert isinstance(mlir_module, str) - temp_dir = tempfile.TemporaryDirectory(delete=False) - mlir_file = Path(temp_dir.name) / (Path(filepath).stem + ".mlir") - mlir_file.write_text(mlir_module) - print(f"Successfully imported {filepath} to MLIR module at {mlir_file}") - return mlir_file + assert isinstance(model, torch.nn.Module) + return model, sample_args, sample_kwargs + except Exception as e: + print( + f"ERROR: got an error converting {filepath} to a PyTorch module:", + file=sys.stderr, + ) + raise e + + +def import_mlir( + model: torch.nn.Module, + sample_args=None, +) -> ir.Module: + """ + Imports a PyTorch model from a KernelBench file and converts it to an MLIR module in the specified dialect. + The MLIR module is saved into a file to be passed to the CompilerDrviver for optimization and execution. + The Python file must define a Kernel Bench module file. + Return: The MLIR Module. + """ + try: + mlir_module = lh_ingress.torch.import_from_model( + model=model, + sample_args=sample_args, + ir_context=ir.Context(), + ) + assert isinstance(mlir_module, ir.Module) + return mlir_module except Exception as e: print( - f"ERROR: got an error converting {filepath} to MLIR:", + f"ERROR: got an error converting {model} to MLIR:", file=sys.stderr, ) raise e def define_compiler_pipeline( - mlir_file: Path, + module: ir.Module, pipeline_yaml: str | None, benchmark: bool = False, + convert_results: bool = False, ) -> None: """ Defines the compiler pipeline by adding stages to the driver. The stages can be defined in a YAML file, or added manually in this function. 
""" - driver = CompilerDriver(mlir_file) + driver = PipelineDriver(module.context) - with driver.context: + with module.context: lh_dialects.register_and_load() - driver.add_module_stage(convert_function_results(args.entry_point)) + if convert_results: + driver.add_transform(convert_function_results(args.entry_point)) - if benchmark: - # Calling the benchmark wrapper, not the entry point. - # FIXME: Eliminate this cross-dependency between the Runner and the Driver. - with driver.context: - lh_dialects.register_and_load() + if benchmark: + # Calling the benchmark wrapper, not the entry point. + # FIXME: Eliminate this cross-dependency between the Runner and the Driver. bench_wrapper = Runner.get_bench_wrapper_schedule(args.entry_point) - driver.add_module_stage(bench_wrapper) - else: - # Calling the entry point directly, so set the attribute on the entry point function. - driver.make_function_callable(args.entry_point) + driver.add_transform(bench_wrapper) + else: + # Calling the entry point directly, so set the attribute on the entry point function. + make_function_callable(module, args.entry_point) - if pipeline_yaml: - # Add stages defined by the user. - driver.add_stages(pipeline_yaml) - else: - # Search for the yaml file in the script directory. - path = Path(__file__).parent / "kernel_bench.yaml" - driver.add_stage(str(path)) + if not pipeline_yaml: + # Search for the yaml file in the script directory. + pipeline_yaml = Path(__file__).parent / "kernel_bench.yaml" + if isinstance(pipeline_yaml, list): + for yaml in pipeline_yaml: + desc = Descriptor(yaml) + driver.add_stage(desc) + else: + desc = Descriptor(str(pipeline_yaml)) + driver.add_stage(desc) return driver @@ -95,6 +118,10 @@ def parse_inputs_and_outputs( """ Parses the input and output shape strings to create sample tensors for the inputs and outputs. 
""" + # Empty buffers shortcut (either both or none, ignore either) + if input_shapes_str is None or output_shape_str is None: + return None, None + # Parse input shapes first, to create sample tensors for the inputs only buffers = [arg.arg for arg in KernelArgumentParser.parse_all(input_shapes_str)] # Build sample torch tensors from input shapes to override hard-coded sizes in get_inputs(). @@ -110,11 +137,80 @@ def parse_inputs_and_outputs( else torch.from_numpy(buf) for buf in buffers ] + # Now include the output shape in the buffers buffers = [KernelArgumentParser.parse(output_shape_str).arg, *buffers] return buffers, sample_tensors +def torch_compile(args, buffers: list, sample_tensors: list): + """ + Compiles the model using torch.compile with a custom MLIR backend. + """ + + # Create the compiler pipeline for the torch compiler + def compiler_pipeline(module: ir.Module) -> ir.Module: + # The pipeline is defined as a function that takes an MLIR module and returns an optimized MLIR module. + # This is a simplified version of the define_compiler_pipeline function, which is used for the torch.compile backend. + # The pipeline is fixed and cannot be modified after the first run, to avoid accidentally modifying the pipeline after it has been run. + if args.print_original_module: + print(module) + driver = define_compiler_pipeline(module, args.pipeline, args.benchmark) + driver.apply(module, print_after_all=args.print_mlir_after_all) + if args.print_optimized_module: + print(module) + return module + + # TODO: Implement benchmarking + if args.benchmark: + print("Benchmarking is not yet implemented for torch.compile backend.") + exit(1) + else: + # Reconfigure the model to be compiled using torch.compile, take the compiled output. 
+ model.compile(dynamic=False, backend=cpu_backend(compiler_pipeline)) + out = model(*sample_tensors, **sample_kwargs) + + return out + + +def torch_import(args, model: torch.nn.Module, buffers: list, sample_tensors: list): + """ + Imports the model to MLIR and compiles with a custom MLIR backend. + """ + # Import the model with the custom sample inputs from the command line. + mlir_module = import_mlir(model, sample_args=sample_tensors) + if args.print_original_module: + print(mlir_module) + + # Create the driver and run the pipeline. + driver = define_compiler_pipeline( + mlir_module, args.pipeline, args.benchmark, convert_results=True + ) + + # Run the pipeline to get the optimized module. + optimized_module = driver.apply( + mlir_module, print_after_all=args.print_mlir_after_all + ) + if args.print_optimized_module: + print(optimized_module) + + # Create the runner and execute/benchmark the module. + runner = Runner(optimized_module, opt_level=args.O) + + if args.benchmark: + print("Running the benchmark...") + time_array = runner.benchmark(host_input_buffers=buffers) + print(f"{len(time_array)} runs: {np.mean(time_array)} seconds") + else: + print("Executing the module...") + runner.execute( + payload_function_name=args.entry_point, + host_input_buffers=buffers, + ) + + return torch.from_numpy(buffers[0]) + + if __name__ == "__main__": Parser = argparse.ArgumentParser( description="""Kernel Bench importer, optimization and runner.""" @@ -126,14 +222,12 @@ if __name__ == "__main__": ) Parser.add_argument( "--input-shapes", - required=True, help="Shape of the input tensors in format: \ DIMS(MxNx...)xTYPE(f16/f32/f64/bf16/i8)xINIT(0/1/rnd/id). \ For multiple inputs, separate by comma.", ) Parser.add_argument( "--output-shape", - required=True, help="Shape of the output tensor in format: \ DIMS(MxNx...)xTYPE(f16/f32/f64/bf16/i8)xINIT(0/1/rnd/id). 
\ Single output shape supported.", @@ -161,10 +255,20 @@ if __name__ == "__main__": action=argparse.BooleanOptionalAction, help="Whether to run the benchmark. Default is False.", ) + Parser.add_argument( + "--torch-compile", + action=argparse.BooleanOptionalAction, + help="Whether to use TorchScript compilation. Default is False.", + ) + Parser.add_argument( + "--print-original-module", + action=argparse.BooleanOptionalAction, + help="Whether to print the original MLIR module. Default is False.", + ) Parser.add_argument( "--print-optimized-module", action=argparse.BooleanOptionalAction, - help="Whether to print the optimized module. Default is False.", + help="Whether to print the optimized MLIR module. Default is False.", ) Parser.add_argument( "--entry-point", @@ -173,10 +277,9 @@ if __name__ == "__main__": help="Name of the function to execute or benchmark. Default is 'main'.", ) Parser.add_argument( - "--print-tensor", - type=int, - default=0, - help="Print the Nth tensor. Default is 0 (no print).", + "--print-output", + action=argparse.BooleanOptionalAction, + help="Whether to print the output tensor after execution. Default is False.", ) Parser.add_argument( "--print-mlir-after-all", @@ -191,37 +294,44 @@ if __name__ == "__main__": else: np.random.seed(int(datetime.now().timestamp())) + if ( + args.input_shapes is None or args.output_shape is None + ) and not args.torch_compile: + print( + "ERROR: --input-shapes and --output-shape must be provided together for non-torch-compile mode.", + file=sys.stderr, + ) + exit(1) + # Initialize the device data buffers, sample_tensors = parse_inputs_and_outputs( args.input_shapes, args.output_shape ) - # Import the model with the custom sample inputs from the command line. - mlir_file = import_kb_module(args.kernel_bench_model, sample_args=sample_tensors) - - # Create the driver and run the pipeline. 
- driver = define_compiler_pipeline(mlir_file, args.pipeline, args.benchmark) + # First, import the model and run the reference numbers + model, sample_args, sample_kwargs = import_torch( + args.kernel_bench_model, sample_args=sample_tensors + ) + out_ref = model(*sample_args, **sample_kwargs) + if args.print_output: + print(f"Reference output:\n{out_ref}") - # Run the pipeline to get the optimized module. - optimized_module = driver.run(print_after_all=args.print_mlir_after_all) - if args.print_optimized_module: - print(optimized_module) + # Now run the MLIR compiler on either imported or Dynamo plugin + if sample_tensors is None: + sample_tensors = sample_args + if args.torch_compile: + out = torch_compile(args, model, sample_tensors) + else: + out = torch_import(args, model, buffers, sample_tensors) - # Create the runner and execute/benchmark the module. - runner = Runner(optimized_module, opt_level=args.O) + # Optionally print the output tensors after execution. + if args.print_output: + print(f"MLIR Output:\n{out}") - if args.benchmark: - print("Running the benchmark...") - time_array = runner.benchmark(host_input_buffers=buffers) - print(f"{len(time_array)} runs: {np.mean(time_array)} seconds") - else: - print("Executing the module...") - runner.execute( - payload_function_name=args.entry_point, - host_input_buffers=buffers, + # Compare the outputs + if not torch.allclose(out_ref, out, rtol=0.01, atol=0.01): + print( + "ERROR: The output of the compiled model does not match the reference output." ) - - # Optionally print the output tensor after execution. 
- if args.print_tensor > 0: - idx = args.print_tensor - 1 - print(f"Output: {buffers[idx]}") + exit(1) + print("Success: The output of the compiled model matches the reference output.") diff --git a/tools/lh-run b/tools/lh-run index dca8e756..1153a0b8 100755 --- a/tools/lh-run +++ b/tools/lh-run @@ -7,7 +7,7 @@ import numpy as np from lighthouse.execution.runner import Runner from lighthouse.execution.init import KernelArgumentParser from lighthouse.pipeline.descriptor import Descriptor -from lighthouse.pipeline.driver import CompilerDriver +from lighthouse.pipeline.driver import CompilerDriver, make_function_callable from lighthouse import dialects as lh_dialects @@ -94,7 +94,7 @@ if __name__ == "__main__": driver.add_module_stage(bench_wrapper) else: # Calling the entry point directly, so set the attribute on the entry point function. - driver.make_function_callable(args.entry_point) + make_function_callable(driver.module, args.entry_point) # Add the remaining stages defined by the user. 
driver.add_stages(args.stage) From 201d37b30149991ba8f45825faba100ea41f9f7a Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Thu, 21 May 2026 18:28:03 +0100 Subject: [PATCH 2/5] flush stdout to help with long running processes in the background --- examples/end-to-end/KernelBench/test_kernel_bench.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/end-to-end/KernelBench/test_kernel_bench.py b/examples/end-to-end/KernelBench/test_kernel_bench.py index 3f8f76b1..a911cec2 100755 --- a/examples/end-to-end/KernelBench/test_kernel_bench.py +++ b/examples/end-to-end/KernelBench/test_kernel_bench.py @@ -175,7 +175,7 @@ def get_flops_per_second(stdout: str, gflops: float) -> float: command_line += ["--print-mlir-after-all"] if test.get("warning"): print(f"WARNING: {test['warning']}") - print(f"Running command: {' '.join(command_line)}") + print(f"Running command: {' '.join(command_line)}", flush=True) # While debugging kernels, it's useful to see the output as it comes. # Note: GFLOPS can't be shown if the output is not captured. @@ -201,7 +201,7 @@ def get_flops_per_second(stdout: str, gflops: float) -> float: print("STDERR:") print(result.stderr) - print(f"Return code: {result.returncode}") + print(f"Return code: {result.returncode}", flush=True) # Only stop on failure on normal runs. # Smoke tests try to run as much as possible. 
From fb8cf45ecd0421f81cc0ad741d53922ea044ec95 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Thu, 21 May 2026 18:33:55 +0100 Subject: [PATCH 3/5] allowing torch.compile to pick its own inputs so that we know they're always correct you can still override in kernel_bench, this just makes the smoke test much easier CI wants to be fast, so we allow passing the shapes on the command line --- .../end-to-end/KernelBench/test_kernel_bench.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/end-to-end/KernelBench/test_kernel_bench.py b/examples/end-to-end/KernelBench/test_kernel_bench.py index a911cec2..2b688b2e 100755 --- a/examples/end-to-end/KernelBench/test_kernel_bench.py +++ b/examples/end-to-end/KernelBench/test_kernel_bench.py @@ -157,10 +157,6 @@ def get_flops_per_second(stdout: str, gflops: float) -> float: command_line = [ str(kb_program), str(kb_kernel), - "--input-shapes", - test["input_shapes"], - "--output-shape", - test["output_shape"], "--pipeline", test["pipeline"], "--print-output", @@ -169,8 +165,20 @@ def get_flops_per_second(stdout: str, gflops: float) -> float: benchmark = args.benchmark and test.get("gflops") is not None if benchmark: command_line += ["--benchmark"] + + # We allow toch.compile to pick its own shapes (unless it's CI) if args.torch_compile: command_line += ["--torch-compile"] + + # TODO: Implement auto-shapes for non-compile mode as well. 
+ if args.ci or not args.torch_compile: + command_line += [ + "--input-shapes", + test["input_shapes"], + "--output-shape", + test["output_shape"], + ] + if args.print_mlir_after_all: command_line += ["--print-mlir-after-all"] if test.get("warning"): From f7470928b07552b9f2ca13e099422a30d9bd47c0 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Fri, 22 May 2026 09:17:08 +0100 Subject: [PATCH 4/5] don't print output on smoke-tests or CI (plus some comments) --- examples/end-to-end/KernelBench/test_kernel_bench.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/end-to-end/KernelBench/test_kernel_bench.py b/examples/end-to-end/KernelBench/test_kernel_bench.py index 2b688b2e..ec8b2787 100755 --- a/examples/end-to-end/KernelBench/test_kernel_bench.py +++ b/examples/end-to-end/KernelBench/test_kernel_bench.py @@ -159,14 +159,14 @@ def get_flops_per_second(stdout: str, gflops: float) -> float: str(kb_kernel), "--pipeline", test["pipeline"], - "--print-output", "--seed=42", ] + # Benchmarks only if there's data to calculate FLOPS. benchmark = args.benchmark and test.get("gflops") is not None if benchmark: command_line += ["--benchmark"] - # We allow toch.compile to pick its own shapes (unless it's CI) + # We allow torch.compile to pick its own shapes (unless it's CI). if args.torch_compile: command_line += ["--torch-compile"] @@ -179,8 +179,15 @@ def get_flops_per_second(stdout: str, gflops: float) -> float: test["output_shape"], ] + # Smoke tests / CI don't print outputs. + if not args.smoke_test and not args.ci: + command_line += ["--print-output"] + + # For debugging, prefer not to capture output. if args.print_mlir_after_all: command_line += ["--print-mlir-after-all"] + + # Print out before we run the test. 
if test.get("warning"): print(f"WARNING: {test['warning']}") print(f"Running command: {' '.join(command_line)}", flush=True) From 3b8f7f1bb78c8df372e78ce9d29c65afd5ede1dd Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Fri, 22 May 2026 11:00:10 +0100 Subject: [PATCH 5/5] tolerance checking for output analysis --- tools/kernel_bench | 60 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/tools/kernel_bench b/tools/kernel_bench index 8ae6ba5d..33ef99e5 100755 --- a/tools/kernel_bench +++ b/tools/kernel_bench @@ -163,8 +163,9 @@ def torch_compile(args, buffers: list, sample_tensors: list): # TODO: Implement benchmarking if args.benchmark: - print("Benchmarking is not yet implemented for torch.compile backend.") - exit(1) + raise NotImplementedError( + "Benchmarking is not yet implemented for torch.compile backend." + ) else: # Reconfigure the model to be compiled using torch.compile, take the compiled output. model.compile(dynamic=False, backend=cpu_backend(compiler_pipeline)) @@ -211,6 +212,30 @@ def torch_import(args, model: torch.nn.Module, buffers: list, sample_tensors: li return torch.from_numpy(buffers[0]) +def compare_outputs(out_ref, out, tolerance, max_tolerance): + """ + Compares the reference output with the output from the compiled model. + On mismatch, widens the tolerance 10x per step (up to max_tolerance) to + report the widest tested tolerance at which the outputs still differ. + """ + if torch.allclose(out_ref, out, rtol=tolerance, atol=tolerance): + return True, tolerance + + # From here on, we know the outputs don't match at the initial tolerance. + max_tol = max_tolerance + if tolerance >= max_tol: + return False, tolerance + + # Now try to find the widest tolerance at which the outputs still differ. 
+ tol = tolerance * 10 + while tol <= max_tol: + # Matches with this but not previous, return previous tolerance + if torch.allclose(out_ref, out, rtol=tol, atol=tol): + return False, tol / 10 + tol *= 10 + return False, max_tol + + if __name__ == "__main__": Parser = argparse.ArgumentParser( description="""Kernel Bench importer, optimization and runner.""" @@ -238,6 +263,18 @@ if __name__ == "__main__": default=0, help="Random seed for initializing input tensors.", ) + Parser.add_argument( + "--tolerance", + type=float, + default=0.01, + help="Expected tolerance for comparing outputs.", + ) + Parser.add_argument( + "--max-tolerance", + type=float, + default=1.0, + help="Maximum tolerance to check if outputs don't match.", + ) Parser.add_argument( "--pipeline", type=str, @@ -297,11 +334,9 @@ if __name__ == "__main__": if ( args.input_shapes is None or args.output_shape is None ) and not args.torch_compile: - print( - "ERROR: --input-shapes and --output-shape must be provided together for non-torch-compile mode.", - file=sys.stderr, + raise ValueError( + "--input-shapes and --output-shape must be provided together for non-torch-compile mode." ) - exit(1) # Initialize the device data buffers, sample_tensors = parse_inputs_and_outputs( @@ -329,9 +364,14 @@ if __name__ == "__main__": print(f"MLIR Output:\n{out}") # Compare the outputs - if not torch.allclose(out_ref, out, rtol=0.01, atol=0.01): - print( - "ERROR: The output of the compiled model does not match the reference output." + match, tol = compare_outputs( + out_ref, out, tolerance=args.tolerance, max_tolerance=args.max_tolerance + ) + if not match: + raise RuntimeError( + f"The output of the compiled model does't match the reference output.\n" + f"Tolerance requested: [{args.tolerance}, {args.max_tolerance}], got at or beyond: {tol}" ) - exit(1) + print("Success: The output of the compiled model matches the reference output.") + print(f"Tolerance: {tol}")