From bfe39b8812eac8afa440508b7f70ef38eeafd4de Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Wed, 20 May 2026 12:01:10 +0300 Subject: [PATCH 1/4] Add xegpu matmul cost model --- examples/xegpu/matmul.py | 40 +- examples/xegpu/mlp.py | 12 +- examples/xegpu/torch_matmul.py | 44 +-- examples/xegpu/tune_matmul_gridsearch.py | 181 +-------- lighthouse/schedule/xegpu/__init__.py | 6 + .../schedule/xegpu/matmul_constraints.py | 246 +++++++++++++ lighthouse/schedule/xegpu/matmul_costmodel.py | 344 ++++++++++++++++++ lighthouse/schedule/xegpu/mlp_schedule.py | 63 +++- .../xegpu/xegpu_parameter_selector.py | 65 ++-- lighthouse/schedule/xegpu/xegpu_specs.py | 78 ++++ 10 files changed, 803 insertions(+), 276 deletions(-) create mode 100644 lighthouse/schedule/xegpu/matmul_constraints.py create mode 100644 lighthouse/schedule/xegpu/matmul_costmodel.py create mode 100644 lighthouse/schedule/xegpu/xegpu_specs.py diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index e91fd790..a2a7a36d 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -29,7 +29,6 @@ from lighthouse.schedule.xegpu import mlp_schedule, xegpu_to_binary from lighthouse.utils.numpy import mlir_to_numpy_dtype from lighthouse.ingress.mlir_gen import generate_gpu_matmul_payload, get_mlir_elem_type -from lighthouse.schedule.xegpu import xegpu_parameter_selector def matmul_complexity( @@ -345,6 +344,11 @@ def parse_cli_args(description): "--json", help="Read problem sizes and tile parameters from a JSON file.", ) + parser.add_argument( + "--target", + choices=["B70", "B50"], + help="Target GPU device, e.g., B70.", + ) parser.add_argument( "--verbose", "-v", @@ -370,31 +374,15 @@ def parse_cli_args(description): # Problem size m, n, k = args.sizes if args.sizes else (4096, 4096, 4096) - # Get default parameters from the database - try: - params = xegpu_parameter_selector.get_matmul_parameters(m, n, k) - except ValueError: - # Initialize with a stub and assume the rest will be populated - params = { - "m": m, - "n": n, - "k": k, - "wg_m": None, - "wg_n": None, - "sg_m": None, - "sg_n": None, - "k_tile": None, - "load_a_m": None, - "load_a_k": None, - "load_b_k": None, - "load_b_n": None, - "prefetch_a_m": None, - "prefetch_a_k": None, - "prefetch_b_k": None, - "prefetch_b_n": None, - "prefetch_a_nb": None, - "prefetch_b_nb": None, - } + # Set required parameters + params = { + "m": m, + "n": n, + "k": k, + } + if args.target: + params["device"] = args.target + if args.json: # Override parameters with values from JSON file if provided with open(args.json, "r") as f: diff --git a/examples/xegpu/mlp.py b/examples/xegpu/mlp.py index 9ed205b2..d19441b5 100644 --- a/examples/xegpu/mlp.py +++ b/examples/xegpu/mlp.py @@ -36,7 +36,6 @@ generate_gpu_mlp_payload, get_mlir_elem_type, ) -from lighthouse.schedule.xegpu import xegpu_parameter_selector from matmul import matmul_complexity @@ -332,6 +331,11 @@ def parse_cli(): action="store_true", help="Dump transform schedule.", ) + parser.add_argument( + "--target", + choices=["B70", "B50"], + help="Target GPU device, e.g., B70.", + ) parser.add_argument( "--verbose", "-v", @@ -371,7 +375,11 @@ def parse_cli(): ab_type = wload.ab_type acc_type = wload.acc_type - params = xegpu_parameter_selector.get_parameters_for_layers(matmuls) + # Initialize layer parameters + params = [{"m": M, "n": N, "k": K} for M, N, K in matmuls] + if args.target: + for layer_params in params: + layer_params["device"] = args.target if args.dump_kernel or args.dump_schedule: pipeline = TransformDriver( diff --git a/examples/xegpu/torch_matmul.py b/examples/xegpu/torch_matmul.py index 9442cadb..7a629a00 100644 --- a/examples/xegpu/torch_matmul.py +++ b/examples/xegpu/torch_matmul.py @@ -20,11 +20,12 @@ from lighthouse import schedule as lh_schedule from lighthouse.pipeline.driver import TransformDriver from lighthouse.utils.mlir import get_mlir_library_path -from lighthouse.schedule.xegpu import mlp_schedule, xegpu_to_binary +from lighthouse.schedule.xegpu import ( + mlp_schedule, + xegpu_to_binary, +) from lighthouse.ingress.torch import gpu_backend, TargetDialect -import parameter_selector - class Model(nn.Module): def __init__(self): @@ -164,6 +165,11 @@ def parse_cli_args(description): "--json", help="Read problem sizes and tile parameters from a JSON file.", ) + parser.add_argument( + "--target", + choices=["B70", "B50"], + help="Target GPU device, e.g., B70.", + ) args = parser.parse_args() return args @@ -182,30 +188,14 @@ def parse_cli_args(description): # Problem size m, n, k = args.sizes if args.sizes else (4096, 4096, 4096) - # Get default parameters from the database - try: - params = parameter_selector.get_matmul_parameters(m, n, k) - except ValueError: - # Initialize with a stub and assume the rest will be populated - params = { - "m": m, - "n": n, - "k": k, - "wg_m": None, - "wg_n": None, - "sg_m": None, - "sg_n": None, - "k_tile": None, - "load_a_m": None, - "load_a_k": None, - "load_b_k": None, - "load_b_n": None, - "prefetch_a_m": None, - "prefetch_a_k": None, - "prefetch_b_k": None, - "prefetch_b_n": None, - "prefetch_nb": None, - } + # Set required parameters + params = { + "m": m, + "n": n, + "k": k, + } + if args.target: + params["device"] = args.target if args.json: # Override parameters with values from JSON file if provided with open(args.json, "r") as f: diff --git a/examples/xegpu/tune_matmul_gridsearch.py b/examples/xegpu/tune_matmul_gridsearch.py index 6ad35791..5accd46d 100644 --- a/examples/xegpu/tune_matmul_gridsearch.py +++ b/examples/xegpu/tune_matmul_gridsearch.py @@ -16,6 +16,8 @@ from lighthouse.execution.runner import Runner from lighthouse.schedule.xegpu.mlp_schedule import DPAS from lighthouse.pipeline.driver import TransformDriver +from lighthouse.schedule.xegpu import check_constraints +from lighthouse.schedule.xegpu import XeGPUSpecs from matmul import XeGPUMatMul, check_results, cli_parser from genetic_algorithm import ( @@ -23,16 +25,6 @@ VariableSet, ) from tune_utils import dump_configs_json, execute_and_log -from lighthouse.schedule.xegpu.mlp_schedule import ( - MAX_NB_SG_THREADS, - LOAD_MAX_COLS, - LOAD_MAX_ROWS, - PFETCH_MAX_COLS, - PFETCH_MAX_ROWS, - PFETCH_MIN_COLS, - PFETCH_MIN_ROWS, - MIN_NB_THREADS, -) def run_experiment( @@ -109,156 +101,6 @@ def run_experiment( return elapsed, gflops -def check_constraints(params: dict, verbose: bool = False) -> bool: - def print_reason(msg): - if verbose: - print(f" Invalid: {msg}") - - M = params["m"] - N = params["n"] - wg_tile_m = params["wg_m"] - wg_tile_n = params["wg_n"] - sg_tile_m = params["sg_m"] - sg_tile_n = params["sg_n"] - load_tile_a_m = params["load_a_m"] - load_tile_a_k = params["load_a_k"] - load_tile_b_k = params["load_b_k"] - load_tile_b_n = params["load_b_n"] - prefetch_tile_a_m = params["prefetch_a_m"] - prefetch_tile_a_k = params["prefetch_a_k"] - prefetch_tile_b_k = params["prefetch_b_k"] - prefetch_tile_b_n = params["prefetch_b_n"] - k_tile = params["k_tile"] - - if M % wg_tile_m != 0: - print_reason("wg_tile_m does not divide M") - return False - if N % wg_tile_n != 0: - print_reason("wg_tile_n does not divide N") - return False - if wg_tile_m % sg_tile_m != 0: - print_reason("sg_tile_m does not divide wg_tile_m") - return False - if wg_tile_n % sg_tile_n != 0: - print_reason("sg_tile_n does not divide wg_tile_n") - return False - if sg_tile_m % DPAS.M != 0: - print_reason("sg_tile_m not multiple of dpas_m") - return False - if sg_tile_n % DPAS.N != 0: - print_reason("sg_tile_n not multiple of dpas_n") - return False - if k_tile % DPAS.K != 0: - print_reason("k_tile not multiple of dpas_k") - return False - - # SG level thread layout: [nb_sg_threads_m, nb_sg_threads_n] - nb_sg_threads_m = wg_tile_m // sg_tile_m - nb_sg_threads_n = wg_tile_n // sg_tile_n - nb_sg_threads = nb_sg_threads_m * nb_sg_threads_n - if nb_sg_threads > MAX_NB_SG_THREADS: - print_reason("too many sg threads") - return False - if nb_sg_threads < MIN_NB_THREADS: - print_reason("too few sg threads") - return False - - if sg_tile_m % load_tile_a_m != 0: - print_reason("load_tile_a_m does not divide sg_tile_m") - return False - if k_tile % load_tile_a_k != 0: - print_reason("load_tile_a_k does not divide k_tile") - return False - if k_tile % load_tile_b_k != 0: - print_reason("load_tile_b_k does not divide k_tile") - return False - if sg_tile_n % load_tile_b_n != 0: - print_reason("load_tile_b_n does not divide sg_tile_n") - return False - if load_tile_a_m > LOAD_MAX_ROWS: - print_reason("too large load_tile_a_m") - return False - if load_tile_a_k > LOAD_MAX_COLS: - print_reason("too large load_tile_a_k") - return False - if load_tile_b_k > LOAD_MAX_ROWS: - print_reason("too large load_tile_b_k") - return False - if load_tile_b_n > LOAD_MAX_COLS: - print_reason("too large load_tile_b_n") - return False - if sg_tile_m % prefetch_tile_a_m != 0: - print_reason("prefetch_tile_a_m does not divide sg_tile_m") - return False - if k_tile % prefetch_tile_a_k != 0: - print_reason("prefetch_tile_a_k does not divide k_tile") - return False - if k_tile % prefetch_tile_b_k != 0: - print_reason("prefetch_tile_b_k does not divide k_tile") - return False - if sg_tile_n % prefetch_tile_b_n != 0: - print_reason("prefetch_tile_b_n does not divide sg_tile_n") - return False - if prefetch_tile_a_m > PFETCH_MAX_ROWS: - print_reason("too large prefetch_tile_a_m") - return False - if prefetch_tile_a_k > PFETCH_MAX_COLS: - print_reason("too large prefetch_tile_a_k") - return False - if prefetch_tile_b_k > PFETCH_MAX_ROWS: - print_reason("too large prefetch_tile_b_k") - return False - if prefetch_tile_b_n > PFETCH_MAX_COLS: - print_reason("too large prefetch_tile_b_n") - return False - if prefetch_tile_a_m < PFETCH_MIN_ROWS: - print_reason("too small prefetch_tile_a_m") - return False - if prefetch_tile_a_k < PFETCH_MIN_COLS: - print_reason("too small prefetch_tile_a_k") - return False - if prefetch_tile_b_k < PFETCH_MIN_ROWS: - print_reason("too small prefetch_tile_b_k") - return False - if prefetch_tile_b_n < PFETCH_MIN_COLS: - print_reason("too small prefetch_tile_b_n") - return False - if load_tile_a_m % DPAS.M != 0: - print_reason("load_tile_a_m not multiple of dpas_m") - return False - if load_tile_a_k % DPAS.K != 0: - print_reason("load_tile_a_k not multiple of dpas_k") - return False - if load_tile_b_k % DPAS.K != 0: - print_reason("load_tile_b_k not multiple of dpas_k") - return False - if load_tile_b_n % DPAS.N != 0: - print_reason("load_tile_b_n not multiple of dpas_n") - return False - - # prefetch A layout - nb_prefetch_a_m = wg_tile_m // prefetch_tile_a_m - nb_prefetch_a_k = k_tile // prefetch_tile_a_k - if nb_prefetch_a_m * nb_prefetch_a_k > MAX_NB_SG_THREADS: - print_reason("too many prefetch A tiles") - return False - if nb_prefetch_a_m * nb_prefetch_a_k < MIN_NB_THREADS: - print_reason("too few prefetch A threads") - return False - - # prefetch B layout - nb_prefetch_b_k = k_tile // prefetch_tile_b_k - nb_prefetch_b_n = wg_tile_n // prefetch_tile_b_n - if nb_prefetch_b_k * nb_prefetch_b_n > MAX_NB_SG_THREADS: - print_reason("too many prefetch B tiles") - return False - if nb_prefetch_b_k * nb_prefetch_b_n < MIN_NB_THREADS: - print_reason("too few prefetch B threads") - return False - - return True - - def get_divisors(n: int, min_tile: int = 32, max_tile: int = 256) -> list[int]: p = np.ceil(n / max_tile) q = n // min_tile @@ -271,7 +113,9 @@ def divisible_by(a_list: list, b: int) -> list: return [a for a in a_list if a % b == 0] -def construct_search_space(M: int, N: int, K: int): +def construct_search_space( + M: int, N: int, K: int, gpu_specs: XeGPUSpecs +) -> tuple[VariableSet, callable]: wg_tile_lim_m = min(max(M // 4, 16), 64), min(M, 256) wg_tile_lim_n = min(max(N // 4, 16), 64), min(N, 256) sg_tile_lim_m = min(max(M // 8, 16), 32), min(M, 128) @@ -288,7 +132,7 @@ def construct_search_space(M: int, N: int, K: int): def sample_is_valid(sample_params, verbose=False): params = {"m": M, "n": N, "k": K} params.update(sample_params) - return check_constraints(params, verbose=verbose) + return check_constraints(params, gpu_specs, verbose=verbose) var_set = VariableSet( [ @@ -328,6 +172,12 @@ def sample_to_dict(sample: list) -> dict: action="store_true", help="Check validity of combinations but do not execute kernels.", ) + parser.add_argument( + "--target", + choices=["B70", "B580"], + default="B70", + help="Target GPU device.", + ) parser.add_argument( "--max-iters", type=int, @@ -368,8 +218,11 @@ def sample_to_dict(sample: list) -> dict: csv_file = "out_gridsearch.csv" csv_logger = CSVLogger(csv_file) - var_set, sample_to_dict = construct_search_space(*sizes) + gpu_specs = XeGPUSpecs.get(args.target) + + var_set, sample_to_dict = construct_search_space(*sizes, gpu_specs=gpu_specs) print(f"Matmul problem size: {sizes}") + print(f"device={gpu_specs.name}") print(f"{ab_type=}") print(f"{c_type=}") print(f"{has_bias=}") @@ -383,7 +236,7 @@ def sample_to_dict(sample: list) -> dict: tic = perf_counter() for sample in product(*var_set.iterables()): params = sample_to_dict(sample) - if not check_constraints(params, verbose=False): + if not check_constraints(params, gpu_specs, verbose=False): continue i += 1 diff --git a/lighthouse/schedule/xegpu/__init__.py b/lighthouse/schedule/xegpu/__init__.py index 23d9ef0c..ee77f0c2 100644 --- a/lighthouse/schedule/xegpu/__init__.py +++ b/lighthouse/schedule/xegpu/__init__.py @@ -1,8 +1,14 @@ from .xegpu_to_binary import xegpu_to_binary from .mlp_schedule import mlp_schedule from .softmax_schedule import softmax_schedule +from .xegpu_parameter_selector import XeGPUParameterSelector +from .matmul_constraints import check_constraints +from .xegpu_specs import XeGPUSpecs __all__ = [ + "XeGPUParameterSelector", + "XeGPUSpecs", + "check_constraints", "mlp_schedule", "softmax_schedule", "xegpu_to_binary", diff --git a/lighthouse/schedule/xegpu/matmul_constraints.py b/lighthouse/schedule/xegpu/matmul_constraints.py new file mode 100644 index 00000000..a95b9f00 --- /dev/null +++ b/lighthouse/schedule/xegpu/matmul_constraints.py @@ -0,0 +1,246 @@ +from collections import namedtuple + +from .xegpu_specs import XeGPUSpecs + +# hardware constraints +DPAS = namedtuple("DPAS", ["M", "N", "K", "A_TILE", "B_TILE", "C_TILE"])( + 8, 16, 16, (8, 16), (16, 16), (8, 16) +) +PREFETCH_INST_DATA = [8, 16] +NB_WORKITEMS = 16 # workitems in subgroup +LOAD_MAX_ROWS = 32 +LOAD_MAX_COLS = 32 +PFETCH_MIN_ROWS = 8 +PFETCH_MAX_ROWS = 32 +PFETCH_MIN_COLS = 16 +PFETCH_MAX_COLS = 32 +MAX_NB_SG_THREADS = 32 # 32 for large register file, 16 otherwise +# heuristics: skip likely suboptimal configurations +MIN_NB_THREADS = 16 + + +def check_wg_tile(M: int, N: int, wg_tile: tuple[int, int]) -> tuple[int, int]: + if M % wg_tile[0] != 0: + raise ValueError("wg_tile_m does not divide M") + if N % wg_tile[1] != 0: + raise ValueError("wg_tile_n does not divide N") + wg_grid = (M // wg_tile[0], N // wg_tile[1]) + return wg_grid + + +def check_sg_tile( + wg_tile: tuple[int, int], + sg_tile: tuple[int, int], + gpu_specs: XeGPUSpecs, + min_nb_threads: int | None = None, +) -> tuple[int, int]: + if wg_tile[0] % sg_tile[0] != 0: + raise ValueError("sg_tile_m does not divide wg_tile_m") + if wg_tile[1] % sg_tile[1] != 0: + raise ValueError("sg_tile_n does not divide wg_tile_n") + if sg_tile[0] % DPAS.M != 0: + raise ValueError("sg_tile_m not multiple of dpas_m") + if sg_tile[1] % DPAS.N != 0: + raise ValueError("sg_tile_n not multiple of dpas_n") + nb_sg_threads_m = wg_tile[0] // sg_tile[0] + nb_sg_threads_n = wg_tile[1] // sg_tile[1] + nb_sg_threads = nb_sg_threads_m * nb_sg_threads_n + if nb_sg_threads > gpu_specs.max_nb_threads: + raise ValueError("too many sg threads") + if min_nb_threads is not None and nb_sg_threads < min_nb_threads: + raise ValueError("too few sg threads") + return nb_sg_threads_m, nb_sg_threads_n + + +def check_k_tile(K: int, k_tile: int): + if K % k_tile != 0: + raise ValueError("k_tile does not divide K") + if k_tile % DPAS.K != 0: + raise ValueError("k_tile not multiple of dpas_k") + + +def check_load_tile( + tile: tuple[int, int], + parent_shape: tuple[int, int], + child_shape: tuple[int, int], + name: str = "A", +): + if parent_shape[0] % tile[0] != 0 or parent_shape[1] % tile[1] != 0: + raise ValueError( + f"Load tile {name} {tile} does not divide the parent shape {parent_shape}." + ) + if tile[0] % child_shape[0] != 0 or tile[1] % child_shape[1] != 0: + raise ValueError( + f"Load tile {name} {tile} does not divide the child shape {child_shape}." + ) + if tile[0] < child_shape[0]: + raise ValueError(f"Load tile {name} {tile} has too few rows.") + if tile[1] < child_shape[1]: + raise ValueError(f"Load tile {name} {tile} has too few cols.") + if tile[0] > LOAD_MAX_ROWS: + raise ValueError(f"Load tile {name} {tile} has too many rows.") + if tile[1] > LOAD_MAX_COLS: + raise ValueError(f"Load tile {name} {tile} has too many cols.") + + +def check_load_tile_a( + tile: tuple[int, int], + sg_tile: tuple[int, int], + k_tile: int, +): + data_shape = (sg_tile[0], k_tile) + child_shape = DPAS.A_TILE + check_load_tile(tile, data_shape, child_shape, name="A") + + +def check_load_tile_b( + tile: tuple[int, int], + sg_tile: tuple[int, int], + k_tile: int, +): + data_shape = (k_tile, sg_tile[1]) + child_shape = DPAS.B_TILE + check_load_tile(tile, data_shape, child_shape, name="B") + + +def check_prefetch_tile( + tile: tuple[int, int], + data_shape: tuple[int, int], + gpu_specs: XeGPUSpecs, + name: str = "A", + min_nb_threads: int | None = None, + verbose: bool = False, +) -> tuple[int, int]: + if tile[0] < PFETCH_MIN_ROWS: + raise ValueError( + f"Prefetch tile {name} {tile} has too few rows (min {PFETCH_MIN_ROWS})." + ) + if tile[0] > PFETCH_MAX_ROWS: + raise ValueError( + f"Prefetch tile {name} {tile} has too many rows (max {PFETCH_MAX_ROWS})." + ) + if tile[1] < PFETCH_MIN_COLS: + raise ValueError( + f"Prefetch tile {name} {tile} has too few cols (min {PFETCH_MIN_COLS})." + ) + if tile[1] > PFETCH_MAX_COLS: + raise ValueError( + f"Prefetch tile {name} {tile} has too many cols (max {PFETCH_MAX_COLS})." + ) + if data_shape[0] % tile[0] != 0 or data_shape[1] % tile[1] != 0: + raise ValueError( + f"Prefetch tile {name} {tile} does not divide the parent shape {data_shape}." + ) + rows = int(data_shape[0] / tile[0]) + cols = int(data_shape[1] / tile[1]) + nb_threads = int(rows * cols) + if verbose: + print(f"=== Prefetch {name} ===") + print(f"tile size {tile}, grid size ({rows}, {cols}), {nb_threads} threads") + if nb_threads > gpu_specs.max_nb_threads: + raise ValueError( + f"Number of threads for {name} prefetch ({nb_threads}) exceeds max threads ({gpu_specs.max_nb_threads})." + ) + if min_nb_threads is not None and nb_threads < min_nb_threads: + raise ValueError( + f"Number of threads for {name} prefetch ({nb_threads}) is less than minimum threads ({min_nb_threads})." + ) + return rows, cols + + +def check_prefetch_tile_a( + tile: tuple[int, int], + wg_tile: tuple[int, int], + k_tile: int, + gpu_specs: XeGPUSpecs, + min_nb_threads: int | None = None, + verbose: bool = False, +) -> tuple[int, int]: + data_shape = (wg_tile[0], k_tile) + return check_prefetch_tile( + tile, + data_shape, + gpu_specs, + name="A", + min_nb_threads=min_nb_threads, + verbose=verbose, + ) + + +def check_prefetch_tile_b( + tile: tuple[int, int], + wg_tile: tuple[int, int], + k_tile: int, + gpu_specs: XeGPUSpecs, + min_nb_threads: int | None = None, + verbose: bool = False, +) -> tuple[int, int]: + data_shape = (k_tile, wg_tile[1]) + return check_prefetch_tile( + tile, + data_shape, + gpu_specs, + name="B", + min_nb_threads=min_nb_threads, + verbose=verbose, + ) + + +def check_constraints( + params: dict[str, int], + gpu_specs: XeGPUSpecs, + verbose: bool = False, +) -> bool: + """Check that the given tile size configuration is valid.""" + + M = params["m"] + N = params["n"] + K = params["k"] + wg_tile_m = params["wg_m"] + wg_tile_n = params["wg_n"] + sg_tile_m = params["sg_m"] + sg_tile_n = params["sg_n"] + load_tile_a_m = params["load_a_m"] + load_tile_a_k = params["load_a_k"] + load_tile_b_k = params["load_b_k"] + load_tile_b_n = params["load_b_n"] + prefetch_tile_a_m = params["prefetch_a_m"] + prefetch_tile_a_k = params["prefetch_a_k"] + prefetch_tile_b_k = params["prefetch_b_k"] + prefetch_tile_b_n = params["prefetch_b_n"] + k_tile = params["k_tile"] + + wg_tile = (wg_tile_m, wg_tile_n) + sg_tile = (sg_tile_m, sg_tile_n) + load_tile_a = (load_tile_a_m, load_tile_a_k) + load_tile_b = (load_tile_b_k, load_tile_b_n) + prefetch_tile_a = (prefetch_tile_a_m, prefetch_tile_a_k) + prefetch_tile_b = (prefetch_tile_b_k, prefetch_tile_b_n) + + try: + check_wg_tile(M, N, wg_tile) + check_sg_tile(wg_tile, sg_tile, gpu_specs, min_nb_threads=MIN_NB_THREADS) + check_k_tile(K, k_tile) + check_load_tile_a(load_tile_a, sg_tile, k_tile) + check_load_tile_b(load_tile_b, sg_tile, k_tile) + check_prefetch_tile_a( + prefetch_tile_a, + wg_tile, + k_tile, + gpu_specs, + min_nb_threads=MIN_NB_THREADS, + verbose=verbose, + ) + check_prefetch_tile_b( + prefetch_tile_b, + wg_tile, + k_tile, + gpu_specs, + min_nb_threads=MIN_NB_THREADS, + verbose=verbose, + ) + except ValueError as e: + if verbose: + print(f"Invalid configuration: {e}") + return False + return True diff --git a/lighthouse/schedule/xegpu/matmul_costmodel.py b/lighthouse/schedule/xegpu/matmul_costmodel.py new file mode 100644 index 00000000..77af9c77 --- /dev/null +++ b/lighthouse/schedule/xegpu/matmul_costmodel.py @@ -0,0 +1,344 @@ +""" +Utilities for matrix multiplication tile size selection and performance +estimation for XeGPU targets. +""" + +from itertools import product +from typing import Callable + +from .xegpu_specs import XeGPUSpecs +from .matmul_constraints import ( + check_constraints, + check_wg_tile, + check_sg_tile, + check_k_tile, + check_prefetch_tile_a, + check_prefetch_tile_b, +) +from .matmul_constraints import ( + DPAS, + PFETCH_MIN_ROWS, + PFETCH_MAX_ROWS, + PFETCH_MIN_COLS, + PFETCH_MAX_COLS, +) + + +def generate_configs( + M: int, + N: int, + K: int, + gpu_specs: XeGPUSpecs, + perf_threshold: float | None = None, + load_strategy: str = "dpas", + pf_strategy: str = "first", + max_nb_configs: int | None = None, +) -> list[tuple[float, dict[str, int]]]: + """Generate valid tile size configurations for (M, N, K) matrix multiplication. + + gpu_specs: XeGPUSpecs object containing the target GPU specifications. + + perf_threshold: if set, only return configurations with + estimated_perf >= perf_threshold * max_found_estimated_perf. + + load_strategy: sets the load tile selection strategy + - "dpas": use dpas op A/B tile size as load tile + + pf_strategy: sets the prefetch tile selection strategy + - "first": take the first prefetch tile for A and B + - "all": append all valid prefetch tiles for A and B + + Returns: + A list of (perf_estimate, params_dict) tuples sorted by perf_estimate (descending). + """ + # TODO add data types as variables + + assert load_strategy == "dpas", "Only 'dpas' load strategy is supported" + + def tuple_to_param_dict( + M: int, + N: int, + K: int, + config: tuple[ + tuple[int, int], + tuple[int, int], + int, + tuple[int, int], + tuple[int, int], + tuple[int, int], + tuple[int, int], + ], + ) -> dict[str, int]: + wg_tile, sg_tile, k_tile, ld_a, ld_b, pf_a, pf_b = config + return { + "m": M, + "n": N, + "k": K, + "wg_m": wg_tile[0], + "wg_n": wg_tile[1], + "sg_m": sg_tile[0], + "sg_n": sg_tile[1], + "k_tile": k_tile, + "load_a_m": ld_a[0], + "load_a_k": ld_a[1], + "load_b_k": ld_b[0], + "load_b_n": ld_b[1], + "prefetch_a_m": pf_a[0], + "prefetch_a_k": pf_a[1], + "prefetch_b_k": pf_b[0], + "prefetch_b_n": pf_b[1], + "prefetch_a_nb": 1, + "prefetch_b_nb": 1, + } + + # define search space + wg_options = [64, 128, 256] + sg_options = [32, 64, 128] + k_tile_options = [16, 32, 64] + + wg_tiles = product(wg_options, wg_options) + sg_tiles = product(sg_options, sg_options) + + # grid search + valid_configs = [] + for config in product(wg_tiles, sg_tiles, k_tile_options): + wg_tile, sg_tile, k_tile = config + try: + perf = estimate_performance( + M, N, K, wg_tile, sg_tile, k_tile, gpu_specs, verbose=False + ) + if pf_strategy == "first": + pf_a, pf_b = generate_prefetch_tiles(wg_tile, k_tile, gpu_specs, n=1) + pf_a_list = [pf_a] + pf_b_list = [pf_b] + else: + pf_a_list, pf_b_list = generate_prefetch_tiles( + wg_tile, k_tile, gpu_specs + ) + # load_strategy = "dpas" + load_a_list = [DPAS.A_TILE] + load_b_list = [DPAS.B_TILE] + for la, lb, pa, pb in product( + load_a_list, load_b_list, pf_a_list, pf_b_list + ): + c = (wg_tile, sg_tile, k_tile, la, lb, pa, pb) + params = tuple_to_param_dict(M, N, K, c) + if check_constraints(params, gpu_specs, verbose=False): + valid_configs.append((perf, params)) + except ValueError: + pass + + # sort by performance (descending) + valid_configs.sort(key=lambda x: x[0], reverse=True) + + if perf_threshold is not None: + assert 0 < perf_threshold <= 1, "perf_threshold must be in (0, 1]" + max_perf = valid_configs[0][0] + valid_configs = [c for c in valid_configs if c[0] >= perf_threshold * max_perf] + + if max_nb_configs is not None: + valid_configs = valid_configs[:max_nb_configs] + + return valid_configs + + +def generate_prefetch_tiles( + wg_tile: tuple[int, int], + k_tile: int, + gpu_specs: XeGPUSpecs, + n: int | None = None, +) -> tuple[ + list[tuple[int, int]] | tuple[int, int], + list[tuple[int, int]] | tuple[int, int], +]: + """Generates valid prefetch tile sizes for A and B. + + Candidates are sorted by number of threads (descending) and then by how + balanced the thread grid is (descending). + """ + + def gridsearch( + check_fn: Callable[ + [tuple[int, int], tuple[int, int], int, XeGPUSpecs], + tuple[int, int], + ], + ) -> list[tuple[int, int]]: + tiles = [] + for rows in range(PFETCH_MIN_ROWS, PFETCH_MAX_ROWS + 1): + for cols in range(PFETCH_MIN_COLS, PFETCH_MAX_COLS + 1): + tile = (rows, cols) + try: + grid = check_fn(tile, wg_tile, k_tile, gpu_specs) + nb_threads = int(grid[0] * grid[1]) + tiles.append((tile, nb_threads, grid)) + except ValueError: + pass + # sort by number of threads and then by how balanced the thread grid is + tiles.sort(key=lambda x: (x[1], -abs(x[2][0] - x[2][1])), reverse=True) + tiles = [t[0] for t in tiles] + return tiles + + prefetch_tiles_a = gridsearch(check_prefetch_tile_a) + prefetch_tiles_b = gridsearch(check_prefetch_tile_b) + if n is not None: + if n == 1: + prefetch_tiles_a = prefetch_tiles_a[0] + prefetch_tiles_b = prefetch_tiles_b[0] + else: + prefetch_tiles_a = prefetch_tiles_a[:n] + prefetch_tiles_b = prefetch_tiles_b[:n] + + return prefetch_tiles_a, prefetch_tiles_b + + +def estimate_performance( + M: int, + N: int, + K: int, + wg_tile: tuple[int, int], + sg_tile: tuple[int, int], + k_tile: int, + gpu_specs: XeGPUSpecs, + prefetch_tile_a: tuple[int, int] | None = None, + prefetch_tile_b: tuple[int, int] | None = None, + verbose: bool = True, +) -> float: + """ + Estimate the performance of the given tile size configuration for (M,N,K) + matrix multiplication on the target GPU. + + The performance estimate is based on a simple roofline model using the + workgroup and K tile sizes and the GPU's peak FLOPS and memory bandwidth. + + If `verbose` is True, prints a summary of the configuration. + + Returns the estimated performance in FLOPS. + + Raises ValueError if the given configuration is invalid. + """ + if verbose: + print("=== Global Level ===") + print(f"Matrix sizes: M={M}, N={N}, K={K}") + + # TODO generalize + ab_dtype_size = 2 # bytes for f16 + c_dtype_size = 4 # bytes for f32 + + # WG + if verbose: + print("=== Workgroup Level ===") + roofline_threshold = gpu_specs.peak_flops / gpu_specs.bw_global_mem # in FLOPs/Byte + + wg_grid = check_wg_tile(M, N, wg_tile) + check_k_tile(K, k_tile) + nb_wgs = wg_grid[0] * wg_grid[1] + if verbose: + print(f"Workgroup tile size: {wg_tile}, grid size: {wg_grid}, nb WGs: {nb_wgs}") + print(f"K tile size: {k_tile}") + + A_wg_shape = (wg_tile[0], k_tile) + B_wg_shape = (k_tile, wg_tile[1]) + C_wg_shape = (wg_tile[0], wg_tile[1]) + + A_footprint = A_wg_shape[0] * A_wg_shape[1] * ab_dtype_size + B_footprint = B_wg_shape[0] * B_wg_shape[1] * ab_dtype_size + C_footprint = C_wg_shape[0] * C_wg_shape[1] * c_dtype_size + + if verbose: + print(f"A: shape={A_wg_shape}, footprint={A_footprint / 1024:.2f} KB") + print(f"B: shape={B_wg_shape}, footprint={B_footprint / 1024:.2f} KB") + print(f"C: shape={C_wg_shape}, footprint={C_footprint / 1024:.2f} KB") + + total_footprint = A_footprint + B_footprint + if verbose: + print(f"Total SLM footprint: {total_footprint / 1024:.1f} KB") + # TODO check that A,B,C fit in shared local memory + + # arithmetic intensity + f = (wg_tile[0] * wg_tile[1]) / (wg_tile[0] + wg_tile[1]) + ai = f * ab_dtype_size + if verbose: + print(f"Arithmetic intensity: {ai:.2f} FLOPs/Byte") + print(f"Roofline threshold: {roofline_threshold:.2f} FLOPs/Byte") + + if verbose: + if ai < roofline_threshold: + print(" => Bandwidth-bound regime") + else: + print(" => Compute-bound regime") + + xe_core_utilization = min(nb_wgs / gpu_specs.nb_xe_cores, 1.0) + if verbose: + print(f"XE core utilization: {xe_core_utilization:.2f}") + + # predict flops + peak_flops = ( + gpu_specs.peak_flops * xe_core_utilization + ) # possible under-utilization + predicted_throughput = min(peak_flops, ai * gpu_specs.bw_global_mem) + if verbose: + print(f"Predicted throughput: {predicted_throughput / 1e12:.2f} TFLOPS") + + # SG + if verbose: + print("=== Subgroup Level ===") + + sg_grid = check_sg_tile(wg_tile, sg_tile, gpu_specs) + nb_sgs = sg_grid[0] * sg_grid[1] + if verbose: + print( + f"Subgroup tile size: {sg_tile}, grid size: {sg_grid}, nb SGs per WG: {nb_sgs}" + ) + + A_sg_shape = (sg_tile[0], k_tile) + B_sg_shape = (k_tile, sg_tile[1]) + C_sg_shape = (sg_tile[0], sg_tile[1]) + + A_footprint = A_sg_shape[0] * A_sg_shape[1] * ab_dtype_size + B_footprint = B_sg_shape[0] * B_sg_shape[1] * ab_dtype_size + C_footprint = C_sg_shape[0] * C_sg_shape[1] * c_dtype_size + + total_footprint = A_footprint + B_footprint + C_footprint + if verbose: + print(f"A: shape={A_sg_shape}, footprint={A_footprint / 1024:.2f} KB") + print(f"B: shape={B_sg_shape}, footprint={B_footprint / 1024:.2f} KB") + print(f"C: shape={C_sg_shape}, footprint={C_footprint / 1024:.2f} KB") + print(f"Total register footprint: {total_footprint / 1024:.2f} KB") + + nb_parallel_dpas = (sg_tile[0] // DPAS.M) * (sg_tile[1] // DPAS.N) + if verbose: + print(f"Number of DPAS threads: {nb_parallel_dpas}") + nb_dpas_ops = nb_parallel_dpas * (k_tile // DPAS.K) + if verbose: + print(f"Number of total DPAS ops: {nb_dpas_ops}") + + # FIXME move remaining checks to util funcs + if nb_parallel_dpas > gpu_specs.dpas_exec_size: + raise ValueError( + f"Number of parallel DPAS ops ({nb_parallel_dpas}) exceeds hardware execution size ({gpu_specs.dpas_exec_size})." + ) + + # estimate number of used registers + reg_size = 64 # bytes per register + nb_reg = int((A_footprint + B_footprint + C_footprint) / reg_size) + if verbose: + print(f"Number of registers: {nb_reg}") + + if nb_reg > gpu_specs.nb_registers: + raise ValueError( + f"Number of registers ({nb_reg}) exceeds hardware register file size ({gpu_specs.nb_registers})." + ) + + if prefetch_tile_a: + # check that prefetch tile is suitable for WG-k tile + check_prefetch_tile_a( + prefetch_tile_a, wg_tile, k_tile, gpu_specs, verbose=verbose + ) + + if prefetch_tile_b: + # check that prefetch tile is suitable for WG-k tile + check_prefetch_tile_b( + prefetch_tile_b, wg_tile, k_tile, gpu_specs, verbose=verbose + ) + + return predicted_throughput diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index 46d68724..0a0067dd 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -1,5 +1,3 @@ -from collections import namedtuple - from mlir import ir from mlir.dialects.transform import loop from mlir.dialects.transform import bufferization @@ -21,22 +19,20 @@ from lighthouse.dialects import smt_ext from lighthouse.dialects.transform import smt_ext as td_smt_ext from lighthouse.dialects.transform.tune_ext import knob, KnobValue - -# hardware constraints -DPAS = namedtuple("DPAS", ["M", "N", "K", "A_TILE", "B_TILE", "C_TILE"])( - 8, 16, 16, (8, 16), (16, 16), (8, 16) +from lighthouse.schedule.xegpu.xegpu_parameter_selector import XeGPUParameterSelector +from lighthouse.schedule.xegpu.matmul_constraints import ( + DPAS, + PREFETCH_INST_DATA, + NB_WORKITEMS, + LOAD_MAX_ROWS, + LOAD_MAX_COLS, + PFETCH_MIN_ROWS, + PFETCH_MIN_COLS, + PFETCH_MAX_ROWS, + PFETCH_MAX_COLS, + MAX_NB_SG_THREADS, + MIN_NB_THREADS, ) -PREFETCH_INST_DATA = [8, 16] -NB_WORKITEMS = 16 # workitems in subgroup -LOAD_MAX_ROWS = 32 -LOAD_MAX_COLS = 32 -PFETCH_MIN_ROWS = 8 -PFETCH_MAX_ROWS = 32 -PFETCH_MIN_COLS = 16 -PFETCH_MAX_COLS = 32 -MAX_NB_SG_THREADS = 32 # 32 for large register file, 16 otherwise -# heuristics: skip likely suboptimal configurations -MIN_NB_THREADS = 16 @KnobValue.ast_rewrite(in_exprs=True) @@ -124,7 +120,40 @@ def mlp_schedule( op_name="builtin.module", deduplicate=True, ) + # preprocess layer parameters for i, layer_params in enumerate(params): + m = layer_params.get("m") + n = layer_params.get("n") + k = layer_params.get("k") + assert all(d is not None for d in (m, n, k)), ( + "m, n, k must be provided in params" + ) + + required_params = [ + "wg_m", + "wg_n", + "sg_m", + "sg_n", + "k_tile", + "load_a_m", + "load_a_k", + "load_b_k", + "load_b_n", + "prefetch_a_m", + "prefetch_a_k", + "prefetch_b_k", + "prefetch_b_n", + "prefetch_a_nb", + "prefetch_b_nb", + ] + if not all(p in layer_params for p in required_params): + # Some parameters are missing, use the parameter selector to fill + device = layer_params.get("device") + param_selector = XeGPUParameterSelector(device=device) + generated_params = param_selector.get_parameters(m, n, k) + # Overwrite original params to ensure consistent configuration + layer_params.update(generated_params) + layer_params |= params_with_constraints_imposed( layer_params, knob_name_prefix=f"layer_{i}_" ) diff --git a/lighthouse/schedule/xegpu/xegpu_parameter_selector.py b/lighthouse/schedule/xegpu/xegpu_parameter_selector.py index 87373fea..1b14148c 100644 --- a/lighthouse/schedule/xegpu/xegpu_parameter_selector.py +++ b/lighthouse/schedule/xegpu/xegpu_parameter_selector.py @@ -4,28 +4,11 @@ import json from pathlib import Path -from .mlp_schedule import DPAS +from .matmul_costmodel import generate_configs +from .xegpu_specs import XeGPUSpecs DEFAULT_JSON_FILE = str(Path(__file__).parent / "matmul_params.json") -DEFAULT_PARAMS = { - "wg_m": 128, - "wg_n": 128, - "sg_m": 32, - "sg_n": 32, - "k_tile": 32, - "load_a_m": DPAS.A_TILE[0], - "load_a_k": DPAS.A_TILE[1], - "load_b_k": DPAS.B_TILE[0], - "load_b_n": DPAS.B_TILE[1], - "prefetch_a_m": 16, - "prefetch_a_k": 16, - "prefetch_b_k": 16, - "prefetch_b_n": 32, - "prefetch_a_nb": 1, - "prefetch_b_nb": 1, -} - def load_param_database(json_file: str = DEFAULT_JSON_FILE) -> dict: matmul_param_db = {} @@ -39,24 +22,26 @@ def load_param_database(json_file: str = DEFAULT_JSON_FILE) -> dict: return matmul_param_db -matmul_param_db = load_param_database() - - -def get_matmul_parameters(m: int, n: int, k: int) -> list: - shape = (m, n, k) - if shape not in matmul_param_db: - if m >= 128 and n >= 256 and k >= 64: - params = DEFAULT_PARAMS.copy() - params["m"] = m - params["n"] = n - params["k"] = k - return params - else: - raise ValueError( - f"Parameter selector: No parameters found for matmul shape {shape}" - ) - return matmul_param_db[shape] - - -def get_parameters_for_layers(shapes: list[tuple[int, int, int]]) -> list: - return [get_matmul_parameters(*shape) for shape in shapes] +class XeGPUParameterSelector: + def __init__(self, device: str | None = None, json_file: str | None = None): + if json_file is None: + json_file = DEFAULT_JSON_FILE + self.device = device if device is not None else "B70" + self.matmul_param_db = load_param_database(json_file) + + def get_parameters(self, m: int, n: int, k: int) -> dict: + shape = (m, n, k) + if shape not in self.matmul_param_db: + try: + # Use cost model to generate tile sizes and take first config + gpu_specs = XeGPUSpecs.get(self.device) + configs = generate_configs(m, n, k, gpu_specs, max_nb_configs=1) + params = configs[0][1] + return params + except Exception as e: + msg = f"Error generating parameters for shape {shape} using cost model: {e}" + raise ValueError(msg) + return self.matmul_param_db[shape] + + def get_parameters_for_layers(self, shapes: list[tuple[int, int, int]]) -> list: + return [self.get_parameters(*shape) for shape in shapes] diff --git a/lighthouse/schedule/xegpu/xegpu_specs.py b/lighthouse/schedule/xegpu/xegpu_specs.py new file mode 100644 index 00000000..e008747e --- /dev/null +++ b/lighthouse/schedule/xegpu/xegpu_specs.py @@ -0,0 +1,78 @@ +""" +XeGPU hardware specifications for tile size selection. +""" + +from dataclasses import dataclass + +__all__ = ["XeGPUSpecs"] + + +gpu_bmg_common = { + "dpas_exec_size": 16, # number of parallel dpas ops in sg + # large register file + "max_nb_threads_lgrf": 32, + "nb_registers_lgrf": 256, + # small register file + "max_nb_threads_sgrf": 64, + "nb_registers_sgrf": 128, +} + +gpu_specs_db = { + "B70": { + "name": "Intel Arc B70", + "nb_xe_cores": 32, + "peak_flops": 155000e9, # float16 + "bw_global_mem": 608e9, # GB/s + **gpu_bmg_common, + }, + "B50": { + "name": "Intel Arc B50", + "nb_xe_cores": 16, + "peak_flops": 78000e9, # float16 + "bw_global_mem": 224e9, # GB/s + **gpu_bmg_common, + }, +} + + +@dataclass +class XeGPUSpecs: + """XeGPU hardware specification relevant for tile size selection.""" + + name: str + nb_xe_cores: int + peak_flops: float # in FLOPS + bw_global_mem: float # in bytes/s + dpas_exec_size: int # number of parallel dpas ops in subgroup + # large register file + max_nb_threads_lgrf: int # max number of threads per subgroup + nb_registers_lgrf: int # number of registers per thread + # small register file + max_nb_threads_sgrf: int # max number of threads per subgroup + nb_registers_sgrf: int # number of registers per thread + reg_file: str # "large" or "small" + + @property + def max_nb_threads(self): + if self.reg_file == "large": + return self.max_nb_threads_lgrf + else: + return self.max_nb_threads_sgrf + + @property + def nb_registers(self): + if self.reg_file == "large": + return self.nb_registers_lgrf + else: + return self.nb_registers_sgrf + + @classmethod + def get(cls, device_name: str, reg_file: str = "large") -> "XeGPUSpecs": + assert reg_file in ["large", "small"], "reg_file must be 'large' or 'small'" + if device_name not in gpu_specs_db: + raise ValueError( + f"Unknown device name: {device_name}. Available devices: {list(gpu_specs_db.keys())}" + ) + specs_dict = gpu_specs_db[device_name] + specs_dict["reg_file"] = reg_file + return cls(**specs_dict) From b23c7abc37eae4755cc7ee295e56ffcf69882727 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Wed, 20 May 2026 19:57:47 +0300 Subject: [PATCH 2/4] copilot comments --- examples/xegpu/torch_matmul.py | 15 +++++++--- examples/xegpu/tune_matmul_gridsearch.py | 2 +- .../schedule/xegpu/matmul_constraints.py | 1 - lighthouse/schedule/xegpu/matmul_costmodel.py | 2 +- lighthouse/schedule/xegpu/mlp_schedule.py | 28 +++++++++++++------ .../xegpu/xegpu_parameter_selector.py | 10 +++++-- lighthouse/schedule/xegpu/xegpu_specs.py | 2 +- 7 files changed, 40 insertions(+), 20 deletions(-) diff --git a/examples/xegpu/torch_matmul.py b/examples/xegpu/torch_matmul.py index 7a629a00..6c4bb3d3 100644 --- a/examples/xegpu/torch_matmul.py +++ b/examples/xegpu/torch_matmul.py @@ -140,9 +140,14 @@ def parse_cli_args(description): help="Tile size for cooperative prefetching of subgroup B matrix", ) parser.add_argument( - "--prefetch-nb", + "--prefetch-a-nb", type=int, - help="Number of initial prefetches.", + help="Number of initial prefetches for A matrix.", + ) + parser.add_argument( + "--prefetch-b-nb", + type=int, + help="Number of initial prefetches for B matrix.", ) parser.add_argument( "--check-result", @@ -217,8 +222,10 @@ def parse_cli_args(description): params["prefetch_a_m"], params["prefetch_a_k"] = args.prefetch_tile_a if args.prefetch_tile_b: params["prefetch_b_k"], params["prefetch_b_n"] = args.prefetch_tile_b - if args.prefetch_nb is not None: - params["prefetch_nb"] = args.prefetch_nb + if args.prefetch_a_nb is not None: + params["prefetch_a_nb"] = args.prefetch_a_nb + if args.prefetch_b_nb is not None: + params["prefetch_b_nb"] = args.prefetch_b_nb for param_key, v in params.items(): if v is None: diff --git a/examples/xegpu/tune_matmul_gridsearch.py b/examples/xegpu/tune_matmul_gridsearch.py index 5accd46d..41c98277 100644 --- a/examples/xegpu/tune_matmul_gridsearch.py +++ b/examples/xegpu/tune_matmul_gridsearch.py @@ -174,7 +174,7 @@ def sample_to_dict(sample: list) -> dict: ) parser.add_argument( "--target", - choices=["B70", "B580"], + choices=["B70", "B50"], default="B70", help="Target GPU device.", ) diff --git a/lighthouse/schedule/xegpu/matmul_constraints.py b/lighthouse/schedule/xegpu/matmul_constraints.py index a95b9f00..3e2bf8c1 100644 --- a/lighthouse/schedule/xegpu/matmul_constraints.py +++ b/lighthouse/schedule/xegpu/matmul_constraints.py @@ -14,7 +14,6 @@ PFETCH_MAX_ROWS = 32 PFETCH_MIN_COLS = 16 PFETCH_MAX_COLS = 32 -MAX_NB_SG_THREADS = 32 # 32 for large register file, 16 otherwise # heuristics: skip likely suboptimal configurations MIN_NB_THREADS = 16 diff --git a/lighthouse/schedule/xegpu/matmul_costmodel.py b/lighthouse/schedule/xegpu/matmul_costmodel.py index 77af9c77..726e08b3 100644 --- a/lighthouse/schedule/xegpu/matmul_costmodel.py +++ b/lighthouse/schedule/xegpu/matmul_costmodel.py @@ -131,7 +131,7 @@ def tuple_to_param_dict( # sort by performance (descending) valid_configs.sort(key=lambda x: x[0], reverse=True) - if perf_threshold is not None: + if perf_threshold is not None and len(valid_configs) > 0: assert 0 < perf_threshold <= 1, "perf_threshold must be in (0, 1]" max_perf = valid_configs[0][0] valid_configs = [c for c in valid_configs if c[0] >= perf_threshold * max_perf] diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index 0a0067dd..66a9f70d 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -19,8 +19,9 @@ from lighthouse.dialects import smt_ext from lighthouse.dialects.transform import smt_ext as td_smt_ext from lighthouse.dialects.transform.tune_ext import knob, KnobValue -from lighthouse.schedule.xegpu.xegpu_parameter_selector import XeGPUParameterSelector -from lighthouse.schedule.xegpu.matmul_constraints import ( +from .xegpu_specs import XeGPUSpecs +from .xegpu_parameter_selector import XeGPUParameterSelector +from .matmul_constraints import ( DPAS, PREFETCH_INST_DATA, NB_WORKITEMS, @@ -30,7 +31,6 @@ PFETCH_MIN_COLS, PFETCH_MAX_ROWS, PFETCH_MAX_COLS, - MAX_NB_SG_THREADS, MIN_NB_THREADS, ) @@ -110,6 +110,12 @@ def mlp_schedule( ) -> ir.Module: """Generate transform schedule module for MLP payload.""" assert params is not None and len(params) > 0, "params must be provided." + devices = {p.get("device") for p in params if "device" in p} + assert len(devices) <= 1, f"Multiple devices specified in params list: {devices}" + device = devices.pop() if devices else None + param_selector = XeGPUParameterSelector(device=device) + gpu_specs = param_selector.gpu_specs + with schedule_boilerplate() as (schedule, named_seq): # match the payload module anytype = transform.AnyOpType.get() @@ -148,8 +154,7 @@ def mlp_schedule( ] if not all(p in layer_params for p in required_params): # Some parameters are missing, use the parameter selector to fill - device = layer_params.get("device") - param_selector = XeGPUParameterSelector(device=device) + # NOTE None values are interpreted as knobs in the constraint function generated_params = param_selector.get_parameters(m, n, k) # Overwrite original params to ensure consistent configuration layer_params.update(generated_params) @@ -161,6 +166,7 @@ def mlp_schedule( try: bundle_xegpu_mlp_schedule( payload_mod, + gpu_specs=gpu_specs, params=params, stop_at_stage=stop_at_stage, ) @@ -174,6 +180,7 @@ def mlp_schedule( def bundle_xegpu_mlp_schedule( mod: ir.Value[transform.AnyOpType], + gpu_specs: XeGPUSpecs, params: list[dict[str, int | KnobValue]], stop_at_stage: str = "", ) -> ir.Value[transform.AnyOpType]: @@ -291,7 +298,9 @@ def constrain_wg_sg_and_calc_nb_threads( sg_m_threads = WG_M // SG_M sg_n_threads = WG_N // SG_N sg_threads = sg_m_threads * sg_n_threads - smt_ext.assert_(sg_threads <= MAX_NB_SG_THREADS, "too many SG threads") + smt_ext.assert_( + sg_threads <= gpu_specs.max_nb_threads, "too many SG threads" + ) smt_ext.assert_(sg_threads >= MIN_NB_THREADS, "too few SG threads") # number of threads collapsed to 1d layout @@ -332,7 +341,7 @@ def constrain_wg_sg_and_calc_nb_threads( ) for gpu_mod, layer_params in zip(gpu_mod_ops, params): gpu_func = match(gpu_mod, ops={"gpu.func"}) - xegpu_wg_annotation_for_mlp_layer(gpu_func, **layer_params) + xegpu_wg_annotation_for_mlp_layer(gpu_func, gpu_specs=gpu_specs, **layer_params) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() @@ -342,6 +351,7 @@ def constrain_wg_sg_and_calc_nb_threads( def xegpu_wg_annotation_for_mlp_layer( gpu_func: ir.Value, + gpu_specs: XeGPUSpecs, *, wg_m: int | KnobValue, wg_n: int | KnobValue, @@ -447,14 +457,14 @@ def constrain_and_calculate_load_and_prefetch_params( prefetch_th_a_m = WG_M // PFA_M prefetch_th_a_k = K_TILE // PFA_K prefetch_th_a = prefetch_th_a_m * prefetch_th_a_k - smt_ext.assert_(prefetch_th_a <= MAX_NB_SG_THREADS) + smt_ext.assert_(prefetch_th_a <= gpu_specs.max_nb_threads) smt_ext.assert_(prefetch_th_a_m * prefetch_th_a_k >= MIN_NB_THREADS) # prefetch B thread layout prefetch_th_b_k = K_TILE // PFB_K prefetch_th_b_n = WG_N // PFB_N prefetch_th_b = prefetch_th_b_k * prefetch_th_b_n - smt_ext.assert_(prefetch_th_b <= MAX_NB_SG_THREADS) + smt_ext.assert_(prefetch_th_b <= gpu_specs.max_nb_threads) if isinstance(prefetch_th_b, smt_ext.SMTIntValue): # NB: Constraint only enabled during tuning. smt_ext.assert_(prefetch_th_b_k * prefetch_th_b_n >= MIN_NB_THREADS) diff --git a/lighthouse/schedule/xegpu/xegpu_parameter_selector.py b/lighthouse/schedule/xegpu/xegpu_parameter_selector.py index 1b14148c..256fa6c9 100644 --- a/lighthouse/schedule/xegpu/xegpu_parameter_selector.py +++ b/lighthouse/schedule/xegpu/xegpu_parameter_selector.py @@ -27,6 +27,7 @@ def __init__(self, device: str | None = None, json_file: str | None = None): if json_file is None: json_file = DEFAULT_JSON_FILE self.device = device if device is not None else "B70" + self.gpu_specs = XeGPUSpecs.get(self.device) self.matmul_param_db = load_param_database(json_file) def get_parameters(self, m: int, n: int, k: int) -> dict: @@ -34,13 +35,16 @@ def get_parameters(self, m: int, n: int, k: int) -> dict: if shape not in self.matmul_param_db: try: # Use cost model to generate tile sizes and take first config - gpu_specs = XeGPUSpecs.get(self.device) - configs = generate_configs(m, n, k, gpu_specs, max_nb_configs=1) + configs = generate_configs(m, n, k, self.gpu_specs, max_nb_configs=1) + if not configs: + raise ValueError( + f"Cost model did not return any valid configurations for matmul {shape}." + ) params = configs[0][1] return params except Exception as e: msg = f"Error generating parameters for shape {shape} using cost model: {e}" - raise ValueError(msg) + raise ValueError(msg) from e return self.matmul_param_db[shape] def get_parameters_for_layers(self, shapes: list[tuple[int, int, int]]) -> list: diff --git a/lighthouse/schedule/xegpu/xegpu_specs.py b/lighthouse/schedule/xegpu/xegpu_specs.py index e008747e..ecb68e04 100644 --- a/lighthouse/schedule/xegpu/xegpu_specs.py +++ b/lighthouse/schedule/xegpu/xegpu_specs.py @@ -73,6 +73,6 @@ def get(cls, device_name: str, reg_file: str = "large") -> "XeGPUSpecs": raise ValueError( f"Unknown device name: {device_name}. Available devices: {list(gpu_specs_db.keys())}" ) - specs_dict = gpu_specs_db[device_name] + specs_dict = gpu_specs_db[device_name].copy() specs_dict["reg_file"] = reg_file return cls(**specs_dict) From 7fd5c21a682dd491eb831019bc187f6aaca62b4e Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Wed, 20 May 2026 20:02:03 +0300 Subject: [PATCH 3/4] cost model: simplify generate_configs --- lighthouse/schedule/xegpu/matmul_costmodel.py | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/lighthouse/schedule/xegpu/matmul_costmodel.py b/lighthouse/schedule/xegpu/matmul_costmodel.py index 726e08b3..e5a60a80 100644 --- a/lighthouse/schedule/xegpu/matmul_costmodel.py +++ b/lighthouse/schedule/xegpu/matmul_costmodel.py @@ -30,7 +30,6 @@ def generate_configs( K: int, gpu_specs: XeGPUSpecs, perf_threshold: float | None = None, - load_strategy: str = "dpas", pf_strategy: str = "first", max_nb_configs: int | None = None, ) -> list[tuple[float, dict[str, int]]]: @@ -41,20 +40,17 @@ def generate_configs( perf_threshold: if set, only return configurations with estimated_perf >= perf_threshold * max_found_estimated_perf. - load_strategy: sets the load tile selection strategy - - "dpas": use dpas op A/B tile size as load tile - pf_strategy: sets the prefetch tile selection strategy - "first": take the first prefetch tile for A and B - "all": append all valid prefetch tiles for A and B + Load tile sizes are currently fixed to DPAS tile sizes for A and B. + Returns: A list of (perf_estimate, params_dict) tuples sorted by perf_estimate (descending). """ # TODO add data types as variables - assert load_strategy == "dpas", "Only 'dpas' load strategy is supported" - def tuple_to_param_dict( M: int, N: int, @@ -107,15 +103,10 @@ def tuple_to_param_dict( perf = estimate_performance( M, N, K, wg_tile, sg_tile, k_tile, gpu_specs, verbose=False ) - if pf_strategy == "first": - pf_a, pf_b = generate_prefetch_tiles(wg_tile, k_tile, gpu_specs, n=1) - pf_a_list = [pf_a] - pf_b_list = [pf_b] - else: - pf_a_list, pf_b_list = generate_prefetch_tiles( - wg_tile, k_tile, gpu_specs - ) - # load_strategy = "dpas" + n_prefetch = 1 if pf_strategy == "first" else None + pf_a_list, pf_b_list = generate_prefetch_tiles( + wg_tile, k_tile, gpu_specs, n=n_prefetch + ) load_a_list = [DPAS.A_TILE] load_b_list = [DPAS.B_TILE] for la, lb, pa, pb in product( @@ -148,8 +139,8 @@ def generate_prefetch_tiles( gpu_specs: XeGPUSpecs, n: int | None = None, ) -> tuple[ - list[tuple[int, int]] | tuple[int, int], - list[tuple[int, int]] | tuple[int, int], + list[tuple[int, int]], + list[tuple[int, int]], ]: """Generates valid prefetch tile sizes for A and B. @@ -181,12 +172,8 @@ def gridsearch( prefetch_tiles_a = gridsearch(check_prefetch_tile_a) prefetch_tiles_b = gridsearch(check_prefetch_tile_b) if n is not None: - if n == 1: - prefetch_tiles_a = prefetch_tiles_a[0] - prefetch_tiles_b = prefetch_tiles_b[0] - else: - prefetch_tiles_a = prefetch_tiles_a[:n] - prefetch_tiles_b = prefetch_tiles_b[:n] + prefetch_tiles_a = prefetch_tiles_a[:n] + prefetch_tiles_b = prefetch_tiles_b[:n] return prefetch_tiles_a, prefetch_tiles_b From f91d3526911c6d299335cbb1e8bd2be7c01b417e Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 21 May 2026 09:00:25 +0300 Subject: [PATCH 4/4] matmul example: add test for custom not pre-tuned shape --- examples/xegpu/matmul.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index a2a7a36d..c73dfa4a 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -1,3 +1,4 @@ +# RUN: %PYTHON %s --sizes 512 1024 128 --dump-kernel=xegpu-wg | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --bias | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --relu | FileCheck %s