Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/end-to-end/KernelBench/test_kernel_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def get_flops_per_second(stdout: str, gflops: float) -> float:
type=str,
help="Specify a particular test to run.",
)
Parser.add_argument(
"--print-output",
action=argparse.BooleanOptionalAction,
help="Whether to print the output of the kernel. Default is False.",
)
Parser.add_argument(
"--print-mlir-after-all",
action=argparse.BooleanOptionalAction,
Expand Down Expand Up @@ -180,7 +185,7 @@ def get_flops_per_second(stdout: str, gflops: float) -> float:
]

# Smoke tests / CI don't print outputs.
if not args.smoke_test and not args.ci:
if args.print_output:
command_line += ["--print-output"]

# For debugging, prefer not to capture output.
Expand Down
10 changes: 9 additions & 1 deletion tools/kernel_bench
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ from lighthouse.schedule import convert_function_results
from lighthouse import dialects as lh_dialects
from lighthouse import ingress as lh_ingress
from lighthouse.ingress.torch import cpu_backend
from lighthouse.utils.mlir import get_mlir_library_path
import os

lib_dir = get_mlir_library_path()
c_runner_lib = os.path.join(lib_dir, "libmlir_c_runner_utils.so")


def import_torch(
Expand Down Expand Up @@ -168,7 +173,10 @@ def torch_compile(args, buffers: list, sample_tensors: list):
)
else:
# Reconfigure the model to be compiled using torch.compile, take the compiled output.
model.compile(dynamic=False, backend=cpu_backend(compiler_pipeline))
model.compile(
dynamic=False,
backend=cpu_backend(compiler_pipeline, shared_libs=[c_runner_lib]),
)
out = model(*sample_tensors, **sample_kwargs)

return out
Expand Down
Loading