diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 7dd59d10..ee09dc7a 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -5,6 +5,8 @@ Select precision with ``data_format="fp4"|"fp8"|"a8w4"``. """ +import functools + import flydsl.compiler as flyc import flydsl.expr as fx from flydsl._mlir import ir @@ -37,6 +39,7 @@ LDS_PAD_D_BYTES = 16 +@functools.lru_cache(maxsize=256) def compile_mxscale_gemm( *, data_format: str = "fp4", @@ -61,6 +64,8 @@ def compile_mxscale_gemm( use_scale_opsel: bool = False, expert_sched_mode: bool = True, atomic_barrier_enable: bool = False, + b_streaming: bool = False, + scale_load_path: str = "tdm", ): """Compile an MXFP4 or MXFP8 GEMM kernel with TDM async copy. @@ -85,6 +90,11 @@ def compile_mxscale_gemm( if out_dtype not in ("f32", "bf16", "f16"): raise ValueError(f"out_dtype must be 'f32', 'bf16', or 'f16', got {out_dtype!r}") elem_bytes_d = 2 if out_dtype in ("bf16", "f16") else 4 + scale_load_paths = ("tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split") + if scale_load_path not in scale_load_paths: + raise ValueError(f"scale_load_path must be one of {scale_load_paths}, got {scale_load_path!r}") + use_scale_buffer_load = scale_load_path != "tdm" + use_ab_split_scale_buffer_load = scale_load_path == "buffer_lds_stage_ab_split" if num_buffers not in (2, 3, 4): raise ValueError(f"num_buffers must be 2, 3, or 4, got {num_buffers}") @@ -96,16 +106,16 @@ def compile_mxscale_gemm( if cluster_m * cluster_n > 16: raise ValueError(f"cluster_m * cluster_n must be <= 16, got {cluster_m}*{cluster_n}") effective_waves_per_eu = waves_per_eu - if use_cluster and effective_waves_per_eu is None: - effective_waves_per_eu = 2 num_warps = m_warp * n_warp block_threads = num_warps * WAVE_SIZE if block_threads > 1024: raise ValueError(f"block_threads must be <= 1024, got {block_threads}") - if wave_specialized_tdm and num_warps != 4: - raise ValueError(f"wave_specialized_tdm requires exactly 4 waves, got {num_warps}") + if wave_specialized_tdm and num_warps < 4: + raise ValueError(f"wave_specialized_tdm requires at least 4 waves, got {num_warps}") + if use_ab_split_scale_buffer_load and not wave_specialized_tdm: + raise ValueError("scale_load_path='buffer_lds_stage_ab_split' requires wave_specialized_tdm=True") # ── Format-dependent compile-time constants ── # A8W4: activation is FP8 (PACK_FACTOR_A=1), weight is FP4 (PACK_FACTOR_B=2) @@ -176,17 +186,39 @@ def compile_mxscale_gemm( b_scale_load_rep = warp_tile_n // WMMA_M if is_fp4 else wmma_n_rep _b_frag_loads_per_wn = 2 if is_a8w4 else 4 - _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + (b_scale_load_rep + 3) // 4 + (wmma_m_rep + 3) // 4 + _a_frag_loads_per_wm = 2 if is_fp4 else 4 + _scale_ds_loads = (wmma_m_rep + 3) // 4 + (b_scale_load_rep + 3) // 4 + _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _scale_ds_loads + _as_ds_loads = wmma_m_rep * _a_frag_loads_per_wm + _scale_ds_loads lds_a_stride_bytes = packed_tile_k_a + LDS_PAD_A_BYTES + if use_ab_split_scale_buffer_load: + if tile_m % 2 != 0: + raise ValueError(f"buffer_lds_stage_ab_split requires even tile_m, got {tile_m}") + if tile_n % 32 != 0: + raise ValueError(f"buffer_lds_stage_ab_split requires tile_n divisible by 32, got {tile_n}") lds_a_data_bytes = tile_m * lds_a_stride_bytes lds_b_data_bytes = tile_n * packed_tile_k_b + ab_split_a_rows = tile_m // 2 + ab_split_b_groups = tile_n // 32 _scale_guard_bytes = 16 lds_a_scale_bytes = tile_m * scale_k_per_tile + _scale_guard_bytes lds_b_scale_bytes = tile_n * scale_k_per_tile + _scale_guard_bytes interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile interleaved_scale_cols_b = b_scale_load_rep * scale_k_per_tile + _scale_dma_bytes = 16 + if use_scale_buffer_load: + if interleaved_scale_cols_a % _scale_dma_bytes != 0: + raise ValueError( + "buffer_lds_stage scale loads require A scale rows to be 16-byte aligned, " + f"got interleaved_scale_cols_a={interleaved_scale_cols_a}" + ) + if interleaved_scale_cols_b % _scale_dma_bytes != 0: + raise ValueError( + "buffer_lds_stage scale loads require B scale rows to be 16-byte aligned, " + f"got interleaved_scale_cols_b={interleaved_scale_cols_b}" + ) def _align_up(value: int, align: int) -> int: if value % align == 0: @@ -263,10 +295,12 @@ def _align_up(value: int, align: int) -> int: arena_alloc.ptr = total_d_bytes check_smem_capacity(arena_total_bytes, gpu_arch) - # TENSORcnt is tracked per-wave in hardware. The regular path issues four - # tensor ops per wave per K-stage, while the wave-specialized path issues - # only one tensor op from each dedicated loader wave. - TDM_LOADS_PER_STEP = 1 if wave_specialized_tdm else 4 + # TENSORcnt is tracked per-wave in hardware. When scale is loaded through + # buffer_load_lds, TDM only carries A/B data. + if wave_specialized_tdm: + TDM_LOADS_PER_STEP = 1 + else: + TDM_LOADS_PER_STEP = 2 if use_scale_buffer_load else 4 tail_plan = [(ls, cs, o * TDM_LOADS_PER_STEP // 2 if o > 0 else o) for ls, cs, o in _base_tail_plan] # Pre-compute epilogue sub-tile layout (unified for FP4 vec16 and FP8 vec8) @@ -290,23 +324,27 @@ def _align_up(value: int, align: int) -> int: COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING = "row_major_streaming" COMPUTE_SCHEDULE_FP4_COL_BAND = "fp4_col_band" + COMPUTE_SCHEDULE_FP8_QUADRANT = "fp8_quadrant" + COMPUTE_SCHEDULE_B_STREAMING = "b_streaming" def _pick_compute_schedule_kind(): - # The FP4 col-band (quadrant) schedule reduces VGPR bank conflicts by - # splitting B loads into left/right halves and processing four quadrants - # (top-left, bottom-left, top-right, bottom-right). This distributes - # accumulator writes across different VGPR bank groups and overlaps - # B-right loading with quadrant-1 WMMA compute. - if not is_fp4: - return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING - if wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0: + if b_streaming: + return COMPUTE_SCHEDULE_B_STREAMING + if wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8: return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING - if n_accs < 8: - return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING - return COMPUTE_SCHEDULE_FP4_COL_BAND + # Quadrant schedules split B into left/right halves and compute + # top-left, bottom-left, top-right, bottom-right. FP4 additionally + # changes accumulator layout for bank friendliness; FP8 keeps row-major + # accumulators and uses the split to increase LDS-load-to-WMMA distance. + if is_fp4: + return COMPUTE_SCHEDULE_FP4_COL_BAND + if data_format == "fp8": + return COMPUTE_SCHEDULE_FP8_QUADRANT + return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING compute_schedule_kind = _pick_compute_schedule_kind() use_fp4_bank_friendly_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND + use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT if use_fp4_bank_friendly_schedule: _bank_half_wm = wmma_m_rep // 2 @@ -327,6 +365,12 @@ def _pick_compute_schedule_kind(): for _wn in range(_bank_half_wn, wmma_n_rep): _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) + if use_fp8_quadrant_schedule: + _fp8_half_wm = wmma_m_rep // 2 + _fp8_half_wn = wmma_n_rep // 2 + _fp8_group_size = _fp8_half_wm * _fp8_half_wn + _fp8_b_scale_loads = (b_scale_load_rep + 3) // 4 + @flyc.kernel(known_block_size=[block_threads, 1, 1]) def kernel_mxscale_gemm( arg_c: fx.Tensor, @@ -412,6 +456,44 @@ def make_desc_b(memref, k_base): atomic_barrier_enable=atomic_barrier_enable, ) + def make_desc_a_half(memref, k_base, m_half: int): + row_start = m_half * ab_split_a_rows + k_packed_off = k_base / arith.index(PACK_FACTOR_A) + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a, + lds_memref=memref, + global_offset=(blk_m + arith.index(row_start), k_packed_off), + tensor_shape=(tile_m, packed_tile_k_a), + strides=(K_packed_a, 1), + tile_shape=(ab_split_a_rows, packed_tile_k_a), + elem_bytes=1, + pad_interval=packed_tile_k_a, + pad_amount=LDS_PAD_A_BYTES, + num_warps=1, + workgroup_mask=a_mcast_mask, + lds_byte_offset=arith.index(row_start * lds_a_stride_bytes), + atomic_barrier_enable=atomic_barrier_enable, + ) + + def make_desc_b_half(memref, k_base, n_half: int): + group_start = n_half * ab_split_b_groups + k_packed_off = k_base / arith.index(PACK_FACTOR_B) + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b, + lds_memref=memref, + global_offset=(blk_n / arith.index(16) + arith.index(group_start), k_packed_off * arith.index(16)), + tensor_shape=(N // 16, K_packed_b * 16), + strides=(K_packed_b * 16, 1), + tile_shape=(ab_split_b_groups, packed_tile_k_b * 16), + elem_bytes=1, + pad_interval=0, + pad_amount=0, + num_warps=1, + workgroup_mask=b_mcast_mask, + lds_byte_offset=arith.index(group_start * packed_tile_k_b * 16), + atomic_barrier_enable=atomic_barrier_enable, + ) + def make_desc_as(memref, k_base): k_scale_off = k_base / arith.index(SCALE_BLOCK) outer_off = blk_m / arith.index(wmma_m_rep) @@ -617,29 +699,32 @@ def load_scale_slice_b128(lds_buffer, scale_base, full_reps, rep_start, rep_coun results.append(vecs[i // 4][i % 4]) return results + def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): + """Load both scale tensors and apply op_sel downsampling per format. + + FP4 BScale has no op_sel (scaleAType=0 fixed); only AScale halves. + FP8/A8W4 16x16 supports op_sel on both. + """ + a_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + b_all = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) + if const_expr(use_scale_opsel): + a = a_all[::2] + b = b_all if const_expr(is_fp4) else b_all[::2] + else: + a, b = a_all, b_all + return a, b + def _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, ks): - """Load B frags + all scales for one K-subtile.""" b_frags = [load_b_frag(b_buf, b_bases, wn, ks) for wn in range_constexpr(wmma_n_rep)] - b_scales_all = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) - a_scales_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) - if const_expr(is_fp4): - # FP4 32x16: scaleAType=0 fixed (no op_sel on BScale) - b_scales = b_scales_all - if const_expr(use_scale_opsel): - a_scales = a_scales_all[::2] - else: - a_scales = a_scales_all - else: - # FP8/A8W4 16x16: both scales support op_sel - if const_expr(use_scale_opsel): - b_scales = b_scales_all[::2] - a_scales = a_scales_all[::2] - else: - b_scales = b_scales_all - a_scales = a_scales_all + a_scales, b_scales = _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks) return b_frags, b_scales, a_scales - def _emit_wmma(accs, wm, wn, a_frag, b_frags, a_scales, b_scales): + def _load_a_and_scales(a_buf, a_bases, as_buf, as_bases, bs_buf, bs_bases, ks): + a_frags = [load_a_frag(a_buf, a_bases[wm], ks) for wm in range_constexpr(wmma_m_rep)] + a_scales, b_scales = _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks) + return a_frags, a_scales, b_scales + + def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): """Emit one WMMA instruction (format-specific).""" idx = wm * wmma_n_rep + wn if const_expr(use_scale_opsel): @@ -653,7 +738,7 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frags, a_scales, b_scales): # 32x16 WMMA with A/B swap: SRC0=B, SRC1=A accs[idx] = rocdl.wmma_scale_f32_32x16x128_f4( T.vec(16, T.f32), - b_frags[wn], + b_frag, a_frag, accs[idx], b_scales[wn * 2], @@ -671,7 +756,7 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frags, a_scales, b_scales): b_opsel = 0 accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( T.vec(8, T.f32), - b_frags[wn], + b_frag, a_frag, accs[idx], b_scales[b_scale_idx], @@ -713,7 +798,7 @@ def _emit_rows(start_wm, a_frags): emit_filler() for wn_raw in range_constexpr(wmma_n_rep): wn = (wmma_n_rep - 1 - wn_raw) if (wm % 2 == 1) else wn_raw - _emit_wmma(accs, wm, wn, a_frags[frag_i], b_frags, a_scales, b_scales) + _emit_wmma(accs, wm, wn, a_frags[frag_i], b_frags[wn], a_scales, b_scales) a_frags_front = [load_a_frag(a_buf, a_bases[wm], ks) for wm in range_constexpr(_front_wm)] @@ -746,6 +831,59 @@ def _emit_rows(start_wm, a_frags): return accs, next_result return accs + def _b_streaming_compute( + accs, + b_buf, + b_bases, + a_frags, + a_scales, + b_scales, + ks, + emit_filler=None, + next_info=None, + mid_compute_callback=None, + ): + """B-streaming counterpart to _a_streaming_compute (A held, B streamed).""" + next_result = None + _front_wn = (wmma_n_rep + 1) // 2 + _back_wn = wmma_n_rep - _front_wn + + def _emit_cols(start_wn, b_frags_chunk): + for frag_i in range_constexpr(len(b_frags_chunk)): + wn = start_wn + frag_i + if const_expr(wn == wmma_n_rep - 1 and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + for wm_raw in range_constexpr(wmma_m_rep): + wm = (wmma_m_rep - 1 - wm_raw) if (wn % 2 == 1) else wm_raw + _emit_wmma(accs, wm, wn, a_frags[wm], b_frags_chunk[frag_i], a_scales, b_scales) + + b_frags_front = [load_b_frag(b_buf, b_bases, wn, ks) for wn in range_constexpr(_front_wn)] + _use_partial_drain = next_info is not None and _front_wn * wmma_m_rep >= 4 + + if const_expr(_use_partial_drain): + next_result = _load_a_and_scales(*next_info) + rocdl.s_wait_dscnt(_as_ds_loads) + else: + rocdl.s_wait_dscnt(0) + + _emit_cols(0, b_frags_front) + + if const_expr(mid_compute_callback is not None): + rocdl.sched_barrier(0) + mid_compute_callback() + + if const_expr(_back_wn > 0): + b_frags_back = [load_b_frag(b_buf, b_bases, _front_wn + h, ks) for h in range_constexpr(_back_wn)] + rocdl.s_wait_dscnt(_as_ds_loads if _use_partial_drain else 0) + _emit_cols(_front_wn, b_frags_back) + + if const_expr(_use_partial_drain): + return accs, next_result + if const_expr(next_info is not None): + return accs, _load_a_and_scales(*next_info) + return accs + # ── Compute on one LDS buffer ── def compute_tile(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None, mid_compute_callback=None): current_accs = list(accs_in) @@ -938,6 +1076,167 @@ def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_ return current_accs + def compute_tile_fp8_quadrant( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=None, + mid_compute_callback=None, + ): + current_accs = list(accs_in) + a_buf, a_bases = _precompute_a_lane_bases(lds_a) + b_buf, b_bases = _precompute_b_lane_bases(lds_b) + as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b + ) + _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn + _b_left_bundle_loads = _b_half_loads + _fp8_b_scale_loads + + def _load_a_group(wm_base, wm_count, ks): + return [load_a_frag(a_buf, a_bases[wm_base + wm_local], ks) for wm_local in range_constexpr(wm_count)] + + def _load_b_half(wn_base, ks): + return [ + load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) for wn_local in range_constexpr(_fp8_half_wn) + ] + + def _load_a_scales(ks): + a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + if const_expr(use_scale_opsel): + return a_scales[::2] + return a_scales + + def _load_b_scales(ks): + b_scales = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) + if const_expr(use_scale_opsel): + return b_scales[::2] + return b_scales + + def _load_b_left_bundle(ks): + return _load_b_half(0, ks), _load_b_scales(ks) + + def _emit_group(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, emit_filler_now=False): + if const_expr(emit_filler_now and emit_filler is not None): + rocdl.sched_barrier(0) + emit_filler() + for wm_local in range_constexpr(_fp8_half_wm): + global_wm = wm_base + wm_local + for wn_local in range_constexpr(_fp8_half_wn): + global_wn = wn_base + wn_local + _emit_wmma( + current_accs, + global_wm, + global_wn, + a_frags[wm_local], + b_frags[wn_local], + a_scales, + b_scales, + ) + + b_left_frags, b_scales = _load_b_left_bundle(0) + + for ks in range_constexpr(k_wmma_steps): + is_last_ks = ks == k_wmma_steps - 1 + a_scales = _load_a_scales(ks) + + a_top_frags = _load_a_group(0, _fp8_half_wm, ks) + a_bottom_frags = _load_a_group(_fp8_half_wm, _fp8_half_wm, ks) + + # Keep bottom A outstanding while the first quadrant consumes top A. + rocdl.s_wait_dscnt(_fp8_half_wm * DS_LOADS_PER_A_FRAG) + + _emit_group(0, 0, a_top_frags, b_left_frags, a_scales, b_scales) + b_right_frags = _load_b_half(_fp8_half_wn, ks) + + # Keep the newly issued right-half B loads outstanding while + # bottom A becomes ready for the second quadrant. + rocdl.s_wait_dscnt(_b_half_loads) + + _emit_group(_fp8_half_wm, 0, a_bottom_frags, b_left_frags, a_scales, b_scales) + + if const_expr(ks == 0 and mid_compute_callback is not None): + rocdl.sched_barrier(0) + mid_compute_callback() + + if const_expr(not is_last_ks): + next_left_frags, next_b_scales = _load_b_left_bundle(ks + 1) + # Current right-half B must be ready before Q2/Q3, while + # the next ks left-half bundle stays in flight. + rocdl.s_wait_dscnt(_b_left_bundle_loads) + else: + rocdl.s_wait_dscnt(0) + + _emit_group(0, _fp8_half_wn, a_top_frags, b_right_frags, a_scales, b_scales) + _emit_group( + _fp8_half_wm, + _fp8_half_wn, + a_bottom_frags, + b_right_frags, + a_scales, + b_scales, + emit_filler_now=is_last_ks, + ) + + if const_expr(not is_last_ks): + b_left_frags = next_left_frags + b_scales = next_b_scales + + return current_accs + + def compute_tile_b_streaming( + accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None, mid_compute_callback=None + ): + """compute_tile counterpart with A held and B streamed.""" + current_accs = list(accs_in) + a_buf, a_bases = _precompute_a_lane_bases(lds_a) + b_buf, b_bases = _precompute_b_lane_bases(lds_b) + as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b + ) + load_args = (a_buf, a_bases, as_buf, as_bases, bs_buf, bs_bases) + + if const_expr(k_wmma_steps == 1): + a_frags, a_scales, b_scales = _load_a_and_scales(*load_args, 0) + return _b_streaming_compute( + current_accs, + b_buf, + b_bases, + a_frags, + a_scales, + b_scales, + 0, + emit_filler=emit_filler, + mid_compute_callback=mid_compute_callback, + ) + + prev_a, prev_as, prev_bs = _load_a_and_scales(*load_args, 0) + for ks in range_constexpr(k_wmma_steps - 1): + current_accs, (prev_a, prev_as, prev_bs) = _b_streaming_compute( + current_accs, + b_buf, + b_bases, + prev_a, + prev_as, + prev_bs, + ks, + next_info=load_args + (ks + 1,), + mid_compute_callback=mid_compute_callback if ks == 0 else None, + ) + return _b_streaming_compute( + current_accs, + b_buf, + b_bases, + prev_a, + prev_as, + prev_bs, + k_wmma_steps - 1, + emit_filler=emit_filler, + ) + def hot_loop_scheduler(): _half_wm = wmma_m_rep // 2 _half_wmma = _half_wm * wmma_n_rep @@ -977,7 +1276,38 @@ def hot_loop_scheduler_fp4_bank_friendly(): rocdl.sched_mfma(_group_wmma) rocdl.sched_barrier(0) + def hot_loop_scheduler_fp8_quadrant(): + _a_all_loads = wmma_m_rep * DS_LOADS_PER_A_FRAG + _a_scale_loads = (wmma_m_rep + 3) // 4 + _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn + _b_left_bundle_loads = _b_half_loads + _fp8_b_scale_loads + _group_wmma = _fp8_group_size + + for _ks in range_constexpr(k_wmma_steps): + if const_expr(_ks == 0): + rocdl.sched_dsrd(_b_left_bundle_loads + _a_scale_loads + _a_all_loads) + else: + rocdl.sched_dsrd(_a_scale_loads + _a_all_loads) + rocdl.sched_mfma(_group_wmma) + rocdl.sched_dsrd(_b_half_loads) + rocdl.sched_mfma(_group_wmma) + if const_expr(_ks < k_wmma_steps - 1): + rocdl.sched_dsrd(_b_left_bundle_loads) + rocdl.sched_mfma(_group_wmma) + rocdl.sched_mfma(_group_wmma) + rocdl.sched_barrier(0) + def compute_tile_scheduled(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None, mid_compute_callback=None): + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): + return compute_tile_b_streaming( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=emit_filler, + mid_compute_callback=mid_compute_callback, + ) if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): return compute_tile_fp4_bank_friendly( accs_in, @@ -988,6 +1318,16 @@ def compute_tile_scheduled(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=No emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, ) + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT): + return compute_tile_fp8_quadrant( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=emit_filler, + mid_compute_callback=mid_compute_callback, + ) return compute_tile( accs_in, lds_a, @@ -998,9 +1338,35 @@ def compute_tile_scheduled(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=No mid_compute_callback=mid_compute_callback, ) + def hot_loop_scheduler_b_streaming(): + """hot_loop_scheduler counterpart for B-streaming.""" + _front_wn = (wmma_n_rep + 1) // 2 + _back_wn = wmma_n_rep - _front_wn + _a_loads_total = wmma_m_rep * DS_LOADS_PER_A_FRAG + _front_b_loads = _front_wn * _b_frag_loads_per_wn + _back_b_loads = _back_wn * _b_frag_loads_per_wn + _next_ks_loads = _a_loads_total + _scale_ds_loads + + for _ks in range_constexpr(k_wmma_steps): + if const_expr(_ks == 0): + rocdl.sched_dsrd(_next_ks_loads + _front_b_loads) + else: + rocdl.sched_dsrd(_front_b_loads) + rocdl.sched_mfma(_front_wn * wmma_m_rep) + if const_expr(_back_wn > 0): + rocdl.sched_dsrd(_back_b_loads) + rocdl.sched_mfma(_back_wn * wmma_m_rep) + if const_expr(_ks < k_wmma_steps - 1): + rocdl.sched_dsrd(_next_ks_loads) + rocdl.sched_barrier(0) + def hot_loop_scheduler_scheduled(): - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): + hot_loop_scheduler_b_streaming() + elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): hot_loop_scheduler_fp4_bank_friendly() + elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT): + hot_loop_scheduler_fp8_quadrant() else: hot_loop_scheduler() @@ -1204,6 +1570,21 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): desc_b_init = make_desc_b(stages_b_mem[0], split_k_base) desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) + if const_expr(use_ab_split_scale_buffer_load): + stages_a0_lds_addr = [] + stages_b0_lds_addr = [] + stages_a1_lds_addr = [] + stages_b1_lds_addr = [] + for i in range_constexpr(num_buffers): + stages_a0_lds_addr.append(_dg0_lane(make_desc_a_half(stages_a_mem[i], arith.index(0), 0), 1)) + stages_b0_lds_addr.append(_dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 0), 1)) + stages_a1_lds_addr.append(_dg0_lane(make_desc_a_half(stages_a_mem[i], arith.index(0), 1), 1)) + stages_b1_lds_addr.append(_dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 1), 1)) + + desc_a0_init = make_desc_a_half(stages_a_mem[0], split_k_base, 0) + desc_b0_init = make_desc_b_half(stages_b_mem[0], split_k_base, 0) + desc_a1_init = make_desc_a_half(stages_a_mem[0], split_k_base, 1) + desc_b1_init = make_desc_b_half(stages_b_mem[0], split_k_base, 1) adv_a_i32 = fx.Int32(tile_k // PACK_FACTOR_A) adv_b_i32 = fx.Int32(packed_tile_k_b * 16) @@ -1211,36 +1592,48 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): adv_bs_i32 = fx.Int32(tile_k // SCALE_BLOCK * b_scale_load_rep) pred_const = fx.Int32(1) - if const_expr(wave_specialized_tdm): - active_stage_lds_addr = [ - _select_wave_tdm_value( - stages_a_lds_addr[i], - stages_b_lds_addr[i], - stages_as_lds_addr[i], - stages_bs_lds_addr[i], + active_pred_const = arith.select(tdm_wave_id < fx.Int32(4), fx.Int32(1), fx.Int32(0)) + + def _select4(values): + return _select_wave_tdm_value(values[0], values[1], values[2], values[3]) + + def _desc_lanes(descs, lane): + return [_dg0_lane(desc, lane) for desc in descs] + + def _select_active_tdm(stage_lds_addrs, descs, advs): + active_stages = [ + _select_wave_tdm_value( + stage_lds_addrs[0][i], + stage_lds_addrs[1][i], + stage_lds_addrs[2][i], + stage_lds_addrs[3][i], + ) + for i in range_constexpr(num_buffers) + ] + return ( + active_stages, + _select4(_desc_lanes(descs, 2)), + _select4(_desc_lanes(descs, 3)), + _select4([desc.dgroup1 for desc in descs]), + _select4(advs), ) - for i in range_constexpr(num_buffers) - ] - active_addr_lo = _select_wave_tdm_value( - _dg0_lane(desc_a_init, 2), - _dg0_lane(desc_b_init, 2), - _dg0_lane(desc_as_init, 2), - _dg0_lane(desc_bs_init, 2), - ) - active_addr_hi = _select_wave_tdm_value( - _dg0_lane(desc_a_init, 3), - _dg0_lane(desc_b_init, 3), - _dg0_lane(desc_as_init, 3), - _dg0_lane(desc_bs_init, 3), + + else: + active_pred_const = pred_const + + if const_expr(wave_specialized_tdm and not use_scale_buffer_load): + active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( + (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), + (desc_a_init, desc_b_init, desc_as_init, desc_bs_init), + (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32), ) - active_dgroup1 = _select_wave_tdm_value( - desc_a_init.dgroup1, - desc_b_init.dgroup1, - desc_as_init.dgroup1, - desc_bs_init.dgroup1, + elif const_expr(use_ab_split_scale_buffer_load): + active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( + (stages_a0_lds_addr, stages_b0_lds_addr, stages_a1_lds_addr, stages_b1_lds_addr), + (desc_a0_init, desc_b0_init, desc_a1_init, desc_b1_init), + (adv_a_i32, adv_b_i32, adv_a_i32, adv_b_i32), ) - active_adv_i32 = _select_wave_tdm_value(adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32) else: addr_lo_a = _dg0_lane(desc_a_init, 2) addr_hi_a = _dg0_lane(desc_a_init, 3) @@ -1256,40 +1649,149 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): dgroup1_as = desc_as_init.dgroup1 dgroup1_bs = desc_bs_init.dgroup1 + if const_expr(use_scale_buffer_load): + scale_a_base = buffer_ops.extract_base_index(arg_a_scale) + scale_b_base = buffer_ops.extract_base_index(arg_b_scale) + scale_async_offset = fx.Int32(0) + scale_async_aux = fx.Int32(0) + + def _dma_scale_tile_to_lds( + global_base, + lds_mem, + global_row_base, + global_col_base, + row_stride, + row_bytes: int, + total_bytes: int, + ): + from flydsl._mlir.dialects import memref as memref_dialect + from flydsl._mlir.dialects import rocdl as rocdl_dialect + + for batch in range_constexpr( + (total_bytes + block_threads * _scale_dma_bytes - 1) // (block_threads * _scale_dma_bytes) + ): + batch_byte = batch * block_threads * _scale_dma_bytes + copy_byte = arith.index(batch_byte) + tx * arith.index(_scale_dma_bytes) + if copy_byte < arith.index(total_bytes): + row = copy_byte / arith.index(row_bytes) + col = copy_byte % arith.index(row_bytes) + global_byte = (global_row_base + row) * arith.index(row_stride) + global_col_base + col + global_ptr = buffer_ops.create_llvm_ptr(global_base + global_byte, address_space=1) + lds_ptr = buffer_ops.create_llvm_ptr( + memref_dialect.extract_aligned_pointer_as_index(lds_mem) + copy_byte, + address_space=3, + ) + rocdl_dialect.global_load_async_to_lds_b128( + global_ptr, + lds_ptr, + scale_async_offset, + scale_async_aux, + ) + + def _issue_scale_buffer_loads(stage_idx, k_base): + k_scale_off = k_base / arith.index(SCALE_BLOCK) + _dma_scale_tile_to_lds( + scale_a_base, + stages_as_mem[stage_idx], + blk_m / arith.index(wmma_m_rep), + k_scale_off * arith.index(wmma_m_rep), + wmma_m_rep * K_scale, + interleaved_scale_cols_a, + tile_m * scale_k_per_tile, + ) + _dma_scale_tile_to_lds( + scale_b_base, + stages_bs_mem[stage_idx], + blk_n / arith.index(b_scale_load_rep), + k_scale_off * arith.index(b_scale_load_rep), + b_scale_load_rep * K_scale, + interleaved_scale_cols_b, + tile_n * scale_k_per_tile, + ) + + def _wait_scale_buffer_loads(): + if const_expr(use_scale_buffer_load): + rocdl.s_wait_asynccnt(0) + + def _pipeline_fence(outstanding=0): + _wait_scale_buffer_loads() + pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) + + def _pipeline_fence_signal(outstanding=0): + _wait_scale_buffer_loads() + pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) + + def _issue_ab_tdm(load_stage, addr_a, addr_b): + dg0_a = _pack_dg0(pred_const, stages_a_lds_addr[load_stage], addr_a, addr_hi_a) + dg0_b = _pack_dg0(pred_const, stages_b_lds_addr[load_stage], addr_b, addr_hi_b) + issue_tdm_loads( + tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), + tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), + wave_specialized=wave_specialized_tdm, + ) + + if const_expr(wave_specialized_tdm and (not use_scale_buffer_load or use_ab_split_scale_buffer_load)): + + def _issue_active_tdm(load_stage, addr_box, scale_k_box=None, k_prefetch=None): + dg0 = _pack_dg0(active_pred_const, active_stage_lds_addr[load_stage], addr_box[0], active_addr_hi) + tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) + addr_box[0] = addr_box[0] + active_adv_i32 + if scale_k_box is not None: + _issue_scale_buffer_loads(load_stage, scale_k_box[0]) + scale_k_box[0] = scale_k_box[0] + arith.index(tile_k) + if k_prefetch is not None: + _l2_prefetch(k_prefetch) + # Prologue - if const_expr(wave_specialized_tdm): + if const_expr(wave_specialized_tdm and not use_scale_buffer_load): for i in range_constexpr(pre_loaded): - dg0 = _pack_dg0(pred_const, active_stage_lds_addr[i], active_addr_lo, active_addr_hi) - tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) - active_addr_lo = active_addr_lo + active_adv_i32 + addr_box = [active_addr_lo] + _issue_active_tdm(i, addr_box) + active_addr_lo = addr_box[0] + elif const_expr(use_ab_split_scale_buffer_load): + for i in range_constexpr(pre_loaded): + addr_box = [active_addr_lo] + scale_k_box = [split_k_base + arith.index(i * tile_k)] + _issue_active_tdm(i, addr_box, scale_k_box=scale_k_box) + active_addr_lo = addr_box[0] else: for i in range_constexpr(pre_loaded): dg0_a = _pack_dg0(pred_const, stages_a_lds_addr[i], addr_lo_a, addr_hi_a) dg0_b = _pack_dg0(pred_const, stages_b_lds_addr[i], addr_lo_b, addr_hi_b) - dg0_as = _pack_dg0(pred_const, stages_as_lds_addr[i], addr_lo_as, addr_hi_as) - dg0_bs = _pack_dg0(pred_const, stages_bs_lds_addr[i], addr_lo_bs, addr_hi_bs) - - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), - tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), - wave_specialized=wave_specialized_tdm, - ) + if const_expr(use_scale_buffer_load): + issue_tdm_loads( + tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), + tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), + wave_specialized=wave_specialized_tdm, + ) + _issue_scale_buffer_loads(i, split_k_base + arith.index(i * tile_k)) + else: + dg0_as = _pack_dg0(pred_const, stages_as_lds_addr[i], addr_lo_as, addr_hi_as) + dg0_bs = _pack_dg0(pred_const, stages_bs_lds_addr[i], addr_lo_bs, addr_hi_bs) + issue_tdm_loads( + tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), + tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), + tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), + tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), + wave_specialized=wave_specialized_tdm, + ) addr_lo_a = addr_lo_a + adv_a_i32 addr_lo_b = addr_lo_b + adv_b_i32 - addr_lo_as = addr_lo_as + adv_as_i32 - addr_lo_bs = addr_lo_bs + adv_bs_i32 + if const_expr(not use_scale_buffer_load): + addr_lo_as = addr_lo_as + adv_as_i32 + addr_lo_bs = addr_lo_bs + adv_bs_i32 + if const_expr(use_scale_buffer_load): + scale_next_k_base = split_k_base + arith.index(pre_loaded * tile_k) - pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2), use_cluster=use_cluster) + _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) # Main loop — acc_mixed style: fence at top, TDM_load mid-compute. # This overlaps TDM DMA with the remaining WMMA instructions, _fence_outstanding = TDM_LOADS_PER_STEP * (num_buffers - 2) if const_expr(loop_iters > 0): - if const_expr(wave_specialized_tdm): + if const_expr(wave_specialized_tdm and not use_scale_buffer_load): init_args = list(accs) + [active_addr_lo] for loop_iter, state in range(0, loop_iters, 1, init=init_args): @@ -1299,7 +1801,7 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): for buf_idx in range_constexpr(num_buffers): load_stage = (buf_idx + num_buffers - 1) % num_buffers - pipeline_fence_signal(outstanding=_fence_outstanding, use_cluster=use_cluster) + _pipeline_fence_signal(outstanding=_fence_outstanding) pipeline_fence_wait(use_cluster=use_cluster) addr_box = [cur_addr_lo] @@ -1313,10 +1815,7 @@ def _mid_tdm_ws( + arith.index(buf_idx * tile_k) ), ): - dg0 = _pack_dg0(pred_const, active_stage_lds_addr[_ls], _ab[0], active_addr_hi) - tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) - _ab[0] = _ab[0] + active_adv_i32 - _l2_prefetch(_k_off) + _issue_active_tdm(_ls, _ab, k_prefetch=_k_off) rocdl.sched_barrier(0) accs_in = compute_tile_scheduled( @@ -1334,6 +1833,108 @@ def _mid_tdm_ws( accs = list(results[:n_accs]) active_addr_lo = results[n_accs] + elif const_expr(use_ab_split_scale_buffer_load): + init_args = list(accs) + [active_addr_lo, scale_next_k_base] + + for loop_iter, state in range(0, loop_iters, 1, init=init_args): + accs_in = list(state[:n_accs]) + cur_addr_lo = state[n_accs] + cur_scale_k = state[n_accs + 1] + + for buf_idx in range_constexpr(num_buffers): + load_stage = (buf_idx + num_buffers - 1) % num_buffers + + _pipeline_fence_signal(outstanding=_fence_outstanding) + pipeline_fence_wait(use_cluster=use_cluster) + + addr_box = [cur_addr_lo] + scale_k_box = [cur_scale_k] + + def _mid_tdm_split_scale_dma( + _ls=load_stage, + _ab=addr_box, + _scale_k=scale_k_box, + _k_off=( + split_k_base + + loop_iter * arith.index(num_buffers * tile_k) + + arith.index(buf_idx * tile_k) + ), + ): + _issue_active_tdm(_ls, _ab, scale_k_box=_scale_k, k_prefetch=_k_off) + + rocdl.sched_barrier(0) + accs_in = compute_tile_scheduled( + accs_in, + stages_a_idx[buf_idx], + stages_b_idx[buf_idx], + stages_as_idx[buf_idx], + stages_bs_idx[buf_idx], + mid_compute_callback=_mid_tdm_split_scale_dma, + ) + cur_addr_lo = addr_box[0] + cur_scale_k = scale_k_box[0] + hot_loop_scheduler_scheduled() + + results = yield list(accs_in) + [cur_addr_lo, cur_scale_k] + + accs = list(results[:n_accs]) + active_addr_lo = results[n_accs] + scale_next_k_base = results[n_accs + 1] + elif const_expr(use_scale_buffer_load): + init_args = list(accs) + [addr_lo_a, addr_lo_b, scale_next_k_base] + + for loop_iter, state in range(0, loop_iters, 1, init=init_args): + accs_in = list(state[:n_accs]) + cur_lo_a = state[n_accs] + cur_lo_b = state[n_accs + 1] + cur_scale_k = state[n_accs + 2] + + for buf_idx in range_constexpr(num_buffers): + load_stage = (buf_idx + num_buffers - 1) % num_buffers + + _pipeline_fence_signal(outstanding=_fence_outstanding) + pipeline_fence_wait(use_cluster=use_cluster) + + addr_boxes = [[cur_lo_a], [cur_lo_b]] + scale_k_box = [cur_scale_k] + + def _mid_tdm_scale_dma( + _ls=load_stage, + _ab=addr_boxes, + _scale_k=scale_k_box, + _k_off=( + split_k_base + + loop_iter * arith.index(num_buffers * tile_k) + + arith.index(buf_idx * tile_k) + ), + ): + _issue_ab_tdm(_ls, _ab[0][0], _ab[1][0]) + _ab[0][0] = _ab[0][0] + adv_a_i32 + _ab[1][0] = _ab[1][0] + adv_b_i32 + _issue_scale_buffer_loads(_ls, _scale_k[0]) + _scale_k[0] = _scale_k[0] + arith.index(tile_k) + _l2_prefetch(_k_off) + + rocdl.sched_barrier(0) + accs_in = compute_tile_scheduled( + accs_in, + stages_a_idx[buf_idx], + stages_b_idx[buf_idx], + stages_as_idx[buf_idx], + stages_bs_idx[buf_idx], + mid_compute_callback=_mid_tdm_scale_dma, + ) + cur_lo_a = addr_boxes[0][0] + cur_lo_b = addr_boxes[1][0] + cur_scale_k = scale_k_box[0] + hot_loop_scheduler_scheduled() + + results = yield list(accs_in) + [cur_lo_a, cur_lo_b, cur_scale_k] + + accs = list(results[:n_accs]) + addr_lo_a = results[n_accs] + addr_lo_b = results[n_accs + 1] + scale_next_k_base = results[n_accs + 2] else: init_args = list(accs) + [addr_lo_a, addr_lo_b, addr_lo_as, addr_lo_bs] @@ -1347,7 +1948,7 @@ def _mid_tdm_ws( for buf_idx in range_constexpr(num_buffers): load_stage = (buf_idx + num_buffers - 1) % num_buffers - pipeline_fence_signal(outstanding=_fence_outstanding, use_cluster=use_cluster) + _pipeline_fence_signal(outstanding=_fence_outstanding) pipeline_fence_wait(use_cluster=use_cluster) addr_boxes = [[cur_lo_a], [cur_lo_b], [cur_lo_as], [cur_lo_bs]] @@ -1403,7 +2004,7 @@ def _mid_tdm_nws( # Tail — same acc_mixed pattern: fence at top, TDM mid-compute. if const_expr(loop_iters > 0): - pipeline_fence(outstanding=0, use_cluster=use_cluster) + _pipeline_fence(outstanding=0) elif const_expr(use_cluster): cluster.cluster_barrier() epi_addrs_box = [None] @@ -1411,7 +2012,7 @@ def _mid_tdm_nws( for _load_stage, _compute_stage, _outstanding in tail_plan: if const_expr(_outstanding == -1): if const_expr(_tail_had_load): - pipeline_fence(outstanding=0, use_cluster=use_cluster) + _pipeline_fence(outstanding=0) if const_expr(use_tdm_store): accs = compute_tile_scheduled( accs, @@ -1434,19 +2035,37 @@ def _emit_epi_addrs(): emit_filler=_emit_epi_addrs, ) else: - pipeline_fence_signal(outstanding=_outstanding, use_cluster=use_cluster) + _pipeline_fence_signal(outstanding=_outstanding) pipeline_fence_wait(use_cluster=use_cluster) _tail_mid_cb = None if const_expr(_load_stage is not None): _tail_had_load = True - if const_expr(wave_specialized_tdm): + if const_expr(use_ab_split_scale_buffer_load): + _tail_addr_box = [active_addr_lo] + _tail_scale_k = [scale_next_k_base] + + def _tail_mid_split_scale_dma(_ls=_load_stage, _ab=_tail_addr_box, _scale_k=_tail_scale_k): + _issue_active_tdm(_ls, _ab, scale_k_box=_scale_k) + + _tail_mid_cb = _tail_mid_split_scale_dma + elif const_expr(use_scale_buffer_load): + _tail_ab = [[addr_lo_a], [addr_lo_b]] + _tail_scale_k = [scale_next_k_base] + + def _tail_mid_scale_dma(_ls=_load_stage, _ab=_tail_ab, _scale_k=_tail_scale_k): + _issue_ab_tdm(_ls, _ab[0][0], _ab[1][0]) + _ab[0][0] = _ab[0][0] + adv_a_i32 + _ab[1][0] = _ab[1][0] + adv_b_i32 + _issue_scale_buffer_loads(_ls, _scale_k[0]) + _scale_k[0] = _scale_k[0] + arith.index(tile_k) + + _tail_mid_cb = _tail_mid_scale_dma + elif const_expr(wave_specialized_tdm): _tail_addr_box = [active_addr_lo] def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box): - dg0 = _pack_dg0(pred_const, active_stage_lds_addr[_ls], _ab[0], active_addr_hi) - tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) - _ab[0] = _ab[0] + active_adv_i32 + _issue_active_tdm(_ls, _ab) _tail_mid_cb = _tail_mid_ws else: @@ -1482,7 +2101,14 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): ) if const_expr(_load_stage is not None): - if const_expr(wave_specialized_tdm): + if const_expr(use_ab_split_scale_buffer_load): + active_addr_lo = _tail_addr_box[0] + scale_next_k_base = _tail_scale_k[0] + elif const_expr(use_scale_buffer_load): + addr_lo_a = _tail_ab[0][0] + addr_lo_b = _tail_ab[1][0] + scale_next_k_base = _tail_scale_k[0] + elif const_expr(wave_specialized_tdm): active_addr_lo = _tail_addr_box[0] else: addr_lo_a = _tail_ab[0][0] @@ -1496,7 +2122,7 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): if const_expr(use_tdm_store): if const_expr(d_need_epilogue_fence): - pipeline_fence(outstanding=0, use_cluster=use_cluster) + _pipeline_fence(outstanding=0) rocdl.sched_barrier(0) epilogue_lds_stores(accs, d_lds_buffer, d_lane_base) rocdl.s_wait_dscnt(0) @@ -1533,6 +2159,8 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): use_scale_opsel, expert_sched_mode, atomic_barrier_enable, + b_streaming, + scale_load_path, ) @flyc.jit diff --git a/python/flydsl/__init__.py b/python/flydsl/__init__.py index 671d173d..8ab40b56 100644 --- a/python/flydsl/__init__.py +++ b/python/flydsl/__init__.py @@ -4,9 +4,4 @@ __version__ = "0.1.7" -# FFM simulator compatibility shim (no-op outside simulator sessions). -from ._compat import _maybe_preload_system_comgr # noqa: E402 - -_maybe_preload_system_comgr() - from .autotune import Config as Config, autotune as autotune # noqa: E402 diff --git a/python/flydsl/_compat.py b/python/flydsl/_compat.py deleted file mode 100644 index 1f4ac750..00000000 --- a/python/flydsl/_compat.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 FlyDSL Project Contributors - -"""Runtime compatibility shims loaded at import time. - -Kept separate from ``__init__.py`` so the workaround logic is isolated and -easy to find / disable. -""" - -import ctypes -import os - - -def _maybe_preload_system_comgr() -> None: - """Pre-load system ``libamd_comgr`` to avoid duplicate-option LLVM errors. - - The FFM simulator ships its own ``libamd_comgr`` that registers the same - LLVM command-line options as the system copy. If both are loaded the - process aborts with *"Option 'greedy' already exists!"*. Loading the - system copy first (with ``RTLD_GLOBAL``) makes the simulator copy a - harmless no-op. - - This function is a no-op outside FFM simulator sessions. - """ - disable = os.environ.get("FLYDSL_DISABLE_COMGR_PRELOAD", "").strip().lower() - if disable in {"1", "true", "yes", "on"}: - return - - model_path = os.environ.get("GFX1250_MODEL_PATH", "") - hsa_model_lib = os.environ.get("HSA_MODEL_LIB", "") - in_ffm_session = ("ffm-lite" in hsa_model_lib) or ("ffmlite" in model_path) - if not in_ffm_session: - return - - system_comgr = os.environ.get("FLYDSL_COMGR_PRELOAD_PATH", "/opt/rocm/lib/libamd_comgr.so.3") - sim_comgr = os.path.join(model_path, "rocm", "libamd_comgr.so.3") - if not (os.path.exists(system_comgr) and os.path.exists(sim_comgr)): - return - - mode = getattr(os, "RTLD_NOW", 0) | getattr(os, "RTLD_GLOBAL", 0) - try: - ctypes.CDLL(system_comgr, mode=mode) - except OSError: - # Keep import robust if the host ROCm stack differs. - pass diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 70cd4d15..3515faff 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -15,14 +15,13 @@ if _PYFLIR_SRC not in sys.path: sys.path.insert(0, _PYFLIR_SRC) -# workaround for simulator import pytest # noqa: E402 import torch # noqa: E402 -import flydsl # noqa: E402,F401 -- preload system comgr before torch/HIP loads LLVM - pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] +import flydsl.compiler as flyc # noqa: E402,I001 + from flydsl.runtime.device import get_rocm_arch # noqa: E402 from kernels.gemm_fp8fp4_gfx1250 import compile_mxscale_gemm # noqa: E402 from tests.kernels.utils import fp4_utils # noqa: E402 @@ -209,6 +208,8 @@ def _run_mxscale_gemm_test( waves_per_eu=None, expert_sched_mode=True, split_k=1, + b_streaming=False, + scale_load_path="tdm", return_launch_fn=False, ): """Unified test body for FP4 and FP8.""" @@ -247,11 +248,12 @@ def _run_mxscale_gemm_test( fmt_name = "A8W4" if is_a8w4 else ("MXFP4" if is_fp4 else "MXFP8") mcast_str = f", cluster=({cluster_m},{cluster_n})" if cluster_m > 1 or cluster_n > 1 else "" tdm_str = ", tdm_store" if use_tdm_store else ", buffer_store" + scale_load_str = "" if scale_load_path == "tdm" else f", scale_load={scale_load_path}" pad_str = _format_kernel_pad(M, N, K, padded_shape) print( f"\nRunning {fmt_name} GEMM: M={M}, N={N}, K={K}{pad_str}, " f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}" - f"{mcast_str}{tdm_str}, preshuffle, out={out_dtype}" + f"{mcast_str}{tdm_str}{scale_load_str}, preshuffle, out={out_dtype}" ) # Generate data @@ -321,13 +323,30 @@ def _run_mxscale_gemm_test( split_k=split_k, use_scale_opsel=use_scale_opsel, expert_sched_mode=expert_sched_mode, + b_streaming=b_streaming, + scale_load_path=scale_load_path, ) - launch_fn( - c_gpu.contiguous().view(-1), - a_gpu.contiguous().view(-1), - b_gpu.contiguous().view(-1), - as_gpu.contiguous().view(-1), - bs_gpu.contiguous().view(-1), + + # Pre-bind via flyc.compile so the launch goes through the CompiledFunction + # ctypes fast path (matches test_blockscale_preshuffle_gemm.py and any + # production caller that bench-times this kernel). The slow JitFunction + # path adds ~17us of inspect.Signature.bind + _make_cache_key per call, + # which would mask genuine kernel timing differences in the bench path. + # flyc.compile() launches the kernel once internally to trigger + # compilation, so no separate eager call is needed for correctness. + c_flat = c_gpu.contiguous().view(-1) + a_flat = a_gpu.contiguous().view(-1) + b_flat = b_gpu.contiguous().view(-1) + as_flat = as_gpu.contiguous().view(-1) + bs_flat = bs_gpu.contiguous().view(-1) + + flyc.compile( + launch_fn, + c_flat, + a_flat, + b_flat, + as_flat, + bs_flat, padded_m, padded_n, torch.cuda.current_stream(), @@ -491,8 +510,21 @@ def test_mxfp4_metadata_and_spill_regression(out_dtype): @pytest.mark.parametrize("use_tdm_store", [True, False]) @pytest.mark.parametrize("use_scale_opsel", [True, False]) @pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) +@pytest.mark.parametrize("scale_load_path", ["tdm", "buffer_lds_stage"]) def test_mxfp8_gemm( - M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, use_tdm_store, out_dtype, use_scale_opsel + M, + N, + K, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + num_buffers, + use_tdm_store, + out_dtype, + use_scale_opsel, + scale_load_path, ): _run_mxscale_gemm_test( "fp8", @@ -509,6 +541,7 @@ def test_mxfp8_gemm( out_dtype, l2_prefetch_distance=2, use_scale_opsel=use_scale_opsel, + scale_load_path=scale_load_path, ) @@ -572,6 +605,151 @@ def test_a8w4_gemm_irregular_m_tile16(M, N, K, use_tdm_store): ) +@pytest.mark.parametrize( + "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", + [ + ("fp4", 128, 512, 7168, 128, 128, 256, 2, 2), + ("fp8", 128, 256, 256, 128, 256, 128, 2, 4), + ("a8w4", 128, 256, 256, 128, 256, 128, 2, 4), + ], +) +def test_b_streaming_correctness(data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp): + _run_mxscale_gemm_test( + data_format, + M, + N, + K, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + num_buffers=2, + use_tdm_store=True, + out_dtype="bf16", + l2_prefetch_distance=2, + b_streaming=True, + ) + + +@pytest.mark.parametrize( + "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", + [ + ("fp4", 128, 256, 512, 128, 128, 256, 2, 2), + ("fp8", 128, 256, 256, 128, 256, 128, 2, 2), + ("a8w4", 128, 256, 256, 128, 256, 128, 2, 2), + ], +) +def test_b_streaming_with_wave_spec_tdm(data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp): + _run_mxscale_gemm_test( + data_format, + M, + N, + K, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + num_buffers=2, + use_tdm_store=True, + out_dtype="bf16", + l2_prefetch_distance=2, + b_streaming=True, + wave_specialized_tdm=True, + ) + + +@pytest.mark.parametrize("scale_load_path", ["tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split"]) +@pytest.mark.parametrize("num_buffers", [2, 3]) +@pytest.mark.parametrize("use_tdm_store", [True, False]) +@pytest.mark.parametrize("use_scale_opsel", [False, True]) +def test_mxfp8_wave_spec_scale_load_paths(scale_load_path, num_buffers, use_tdm_store, use_scale_opsel): + _run_mxscale_gemm_test( + "fp8", + 128, + 256, + 384, + 128, + 256, + 128, + 2, + 2, + num_buffers=num_buffers, + use_tdm_store=use_tdm_store, + out_dtype="bf16", + l2_prefetch_distance=2, + wave_specialized_tdm=True, + use_scale_opsel=use_scale_opsel, + scale_load_path=scale_load_path, + ) + + +def test_mxfp8_ab_split_scale_load_allows_extra_waves(): + _run_mxscale_gemm_test( + "fp8", + 128, + 256, + 384, + 128, + 256, + 128, + 2, + 4, + num_buffers=3, + use_tdm_store=True, + out_dtype="bf16", + l2_prefetch_distance=2, + wave_specialized_tdm=True, + use_scale_opsel=True, + scale_load_path="buffer_lds_stage_ab_split", + ) + + +@pytest.mark.parametrize( + "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", + [ + ("fp4", 256, 512, 256, 128, 256, 128, 2, 2, 2, 2), + ("fp8", 256, 512, 256, 128, 256, 128, 2, 2, 2, 2), + ], +) +def test_b_streaming_with_cluster_mcast( + data_format, + M, + N, + K, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + cluster_m, + cluster_n, +): + if str(get_rocm_arch()) != "gfx1250": + pytest.skip("requires gfx1250") + if "FFMLITE_TOPOLOGY" in os.environ or "AM_TOPOLOGY" in os.environ: + pytest.skip("cluster multicast not supported on simulator") + _run_mxscale_gemm_test( + data_format, + M, + N, + K, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + num_buffers=2, + use_tdm_store=True, + out_dtype="bf16", + l2_prefetch_distance=2, + b_streaming=True, + cluster_m=cluster_m, + cluster_n=cluster_n, + ) + + @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", [ @@ -610,6 +788,181 @@ def test_mxfp4_gemm_mcast( ) +@pytest.mark.parametrize( + "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", + [ + ("fp8", 128, 256, 256, 128, 256, 128, 2, 2), + ("fp4", 128, 256, 256, 128, 256, 128, 2, 2), + ], + ids=["fp8-128x256x256", "fp4-128x256x256"], +) +def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp): + """Verify that the gfx1250 MX-scale GEMM kernel works inside a hipGraph. + + Captures one launch, replays once, and checks the replay output is + bit-equivalent to an eager launch with the same inputs. Catches kernel + regressions that would break graph capture / replay (accidental host + syncs, allocator allocations on the kernel path, stream-event API misuse). + """ + arch = str(get_rocm_arch()) + if arch != "gfx1250": + pytest.skip(f"WMMA_SCALE requires gfx1250, got {arch}") + if "FFMLITE_TOPOLOGY" in os.environ or "AM_TOPOLOGY" in os.environ: + pytest.skip("hipGraph capture/replay not supported on simulator") + + is_fp4 = data_format == "fp4" + + # Build inputs (mirrors _run_mxscale_gemm_test, but no padding needed + # because we pick a clean shape). + torch.manual_seed(0) + if is_fp4: + a = fp4_utils.random_fp4_packed(M, K) + b = fp4_utils.random_fp4_packed(N, K) + else: + a = random_fp8_data(M, K) + b = random_fp8_data(N, K) + a_scale = fp4_utils.random_e8m0(M, K // SCALE_BLOCK) + b_scale = fp4_utils.random_e8m0(N, K // SCALE_BLOCK) + + skt = tile_k // SCALE_BLOCK + warp_tile_m = tile_m // m_warp + warp_tile_n = tile_n // n_warp + a_scale_ps = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) + b_scale_ps = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) + pack_b = 2 if is_fp4 else 1 + b_ps = fp4_utils.preshuffle_b_16x16(b, N, K // pack_b) + + a_gpu = a.cuda() + b_gpu = b_ps.cuda() + as_gpu = a_scale_ps.cuda() + bs_gpu = b_scale_ps.cuda() + c_gpu = torch.zeros(M, N, dtype=torch.bfloat16, device="cuda") + + launch_fn = compile_mxscale_gemm( + data_format=data_format, + M=M, + N=N, + K=K, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=2, + use_tdm_store=True, + out_dtype="bf16", + wave_specialized_tdm=False, + split_k=1, + ) + + c_flat = c_gpu.contiguous().view(-1) + a_flat = a_gpu.contiguous().view(-1) + b_flat = b_gpu.contiguous().view(-1) + as_flat = as_gpu.contiguous().view(-1) + bs_flat = bs_gpu.contiguous().view(-1) + compiled_exe = flyc.compile( + launch_fn, + c_flat, + a_flat, + b_flat, + as_flat, + bs_flat, + M, + N, + torch.cuda.current_stream(), + ) + + # Resolve stream lazily inside the launch closure so graph capture sees + # the active capture stream rather than a stream bound before capture. + def launch(): + compiled_exe(c_flat, a_flat, b_flat, as_flat, bs_flat, M, N, torch.cuda.current_stream()) + + # ── Eager run (reference) ── + c_gpu.zero_() + launch() + torch.cuda.synchronize() + eager_result = c_gpu.clone() + assert eager_result.abs().max().item() > 0, "Eager run produced all zeros — kernel did not execute properly." + + # ── hipGraph capture ── + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + # Warmup on the capture stream so allocator state is stable + with torch.cuda.stream(s): + launch() + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + + c_gpu.zero_() + with torch.cuda.graph(g, stream=s): + launch() + torch.cuda.synchronize() + + # ── Replay ── + c_gpu.zero_() + g.replay() + torch.cuda.synchronize() + graph_result = c_gpu.clone() + + # ── Verify ── + assert graph_result.abs().max().item() > 0, "hipGraph replay produced all zeros — kernel was NOT captured." + # Same inputs + same kernel + same stream-order = bit-exact equality + assert torch.equal(eager_result, graph_result), ( + f"Eager vs hipGraph result mismatch: max abs diff = " + f"{(eager_result.float() - graph_result.float()).abs().max().item():.6f}" + ) + + +def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None): + """Per-launch timer that strips host launch overhead via hipGraph. + + How it works: + - Capture a single kernel launch into a hipGraph + - Replay it N times in one stream submission burst + - Per-launch time = total_burst_time / N + + NB: stream-ordered execution guarantees the N replays serialise — each + g.replay() is one submission to the stream and the next one cannot + start until the previous one finishes. + + NB: no L2 flush between replays. The whole point of the graph is to + measure the kernel in a hot, back-to-back launch scenario (which is + what production serving looks like). For cold-cache numbers use the + regular _bench_kernel_us with flush_l2=True. + """ + capture_stream = torch.cuda.Stream() + capture_stream.wait_stream(torch.cuda.current_stream()) + + # Warmup on the capture stream so the allocator / JIT cache is settled. + with torch.cuda.stream(capture_stream): + for _ in range(warmup): + if prep_fn is not None: + prep_fn() + run_fn() + torch.cuda.current_stream().wait_stream(capture_stream) + torch.cuda.synchronize() + + # Capture exactly one kernel launch into the graph. + g = torch.cuda.CUDAGraph() + if prep_fn is not None: + prep_fn() + with torch.cuda.graph(g, stream=capture_stream): + run_fn() + torch.cuda.synchronize() + + # Time iters replays as one batch. + start_ev = torch.cuda.Event(enable_timing=True) + end_ev = torch.cuda.Event(enable_timing=True) + start_ev.record() + for _ in range(iters): + g.replay() + end_ev.record() + torch.cuda.synchronize() + total_us = start_ev.elapsed_time(end_ev) * 1e3 + return total_us / iters + + def _bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): """Per-iteration CUDA events timer with L2 flush, IQR outlier removal, median. @@ -699,7 +1052,8 @@ def _run_benchmark(args): print(f" Tile: ({tile_m}, {tile_n}, {tile_k}), warps=({args.m_warp}x{args.n_warp})") print( f" Buffers={args.num_buffers}, out={args.out_dtype}, " - f"opsel={args.use_scale_opsel}, inst_prefetch={args.inst_prefetch}" + f"opsel={args.use_scale_opsel}, inst_prefetch={args.inst_prefetch}, " + f"scale_load={args.scale_load_path}" ) if args.split_k > 1: print(f" Split-K={args.split_k} (atomic accumulate, buffer-store epilogue)") @@ -768,20 +1122,40 @@ def _run_benchmark(args): use_scale_opsel=args.use_scale_opsel, expert_sched_mode=args.expert_sched_mode, atomic_barrier_enable=args.atomic_barrier_enable, + b_streaming=args.b_streaming, + scale_load_path=args.scale_load_path, ) - stream = torch.cuda.current_stream() c_flat = c_gpu.view(-1) a_flat = a_gpu.view(-1) b_flat = b_gpu.view(-1) as_flat = as_gpu.view(-1) bs_flat = bs_gpu.view(-1) + # Pre-bind via flyc.compile so the bench loop calls go through the + # CompiledFunction ctypes fast path. The slow JitFunction path adds + # ~17us of inspect.Signature.bind + _make_cache_key per call, which + # would dominate per-launch latency for short kernels. + compiled_exe = flyc.compile( + launch_fn, + c_flat, + a_flat, + b_flat, + as_flat, + bs_flat, + padded_m, + padded_n, + torch.cuda.current_stream(), + ) + def prep_kernel(): c_gpu.zero_() + # Resolve the stream lazily inside the closure so the graph-bench path + # captures on the active capture stream rather than the stream bound + # before capture. Same value on the eager path. def run_kernel(): - launch_fn( + compiled_exe( c_flat, a_flat, b_flat, @@ -789,7 +1163,7 @@ def run_kernel(): bs_flat, padded_m, padded_n, - stream, + torch.cuda.current_stream(), ) prep_kernel() @@ -798,10 +1172,20 @@ def run_kernel(): compile_ms = (time.perf_counter() - t0) * 1e3 print(f" Compile + first launch: {compile_ms:.0f} ms") - print(f"[2/3] Warming up ({args.warmup} iters) + benchmarking ({args.iters} iters)...") - us = _bench_kernel_us( - run_kernel, warmup=args.warmup, iters=args.iters, flush_l2=not args.no_flush_l2, prep_fn=prep_kernel - ) + use_graph = getattr(args, "use_graph", False) + if use_graph: + print(f"[2/3] Warming up ({args.warmup} iters) + bench via hipGraph ({args.iters} replays)...") + # Graph mode prep: don't zero c_gpu inside the captured kernel + # (zero would be baked into the graph and runs every replay, but + # that's also fine — it would add a trivial memset per replay). + # We omit prep_fn here because the c_gpu state across replays + # doesn't matter for timing. + us = _bench_kernel_us_cudagraph(run_kernel, warmup=args.warmup, iters=args.iters) + else: + print(f"[2/3] Warming up ({args.warmup} iters) + benchmarking ({args.iters} iters)...") + us = _bench_kernel_us( + run_kernel, warmup=args.warmup, iters=args.iters, flush_l2=not args.no_flush_l2, prep_fn=prep_kernel + ) logical_flops = 2.0 * M * N * K kernel_flops = 2.0 * padded_m * padded_n * padded_k @@ -870,9 +1254,9 @@ def run_kernel(): parser = argparse.ArgumentParser() parser.add_argument("--data-format", type=str, default="fp4", choices=["fp4", "fp8", "a8w4"]) - parser.add_argument("-M", type=int, default=8192) - parser.add_argument("-N", type=int, default=8192) - parser.add_argument("-K", type=int, default=8192) + parser.add_argument("-M", type=int, default=1024) + parser.add_argument("-N", type=int, default=1024) + parser.add_argument("-K", type=int, default=2048) parser.add_argument("--tile-m", type=int, default=128) parser.add_argument("--tile-n", type=int, default=256) parser.add_argument("--tile-k", type=int, default=256) @@ -889,7 +1273,14 @@ def run_kernel(): parser.add_argument("--wave-spec-tdm", action="store_true", default=False) parser.add_argument("--waves-per-eu", type=int, default=None) parser.add_argument("--use-scale-opsel", action="store_true", default=False) + parser.add_argument( + "--scale-load-path", + type=str, + default="tdm", + choices=["tdm", "buffer_lds_stage", "buffer_lds_stage_ab_split"], + ) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) + parser.add_argument("--b-streaming", action="store_true", default=False) parser.add_argument( "--atomic-barrier-enable", action="store_true", @@ -903,6 +1294,15 @@ def run_kernel(): parser.add_argument("--warmup", type=int, default=5) parser.add_argument("--iters", type=int, default=20) parser.add_argument("--no-flush-l2", action="store_true", default=False) + parser.add_argument( + "--use-graph", + action="store_true", + default=False, + help="Time via hipGraph capture+replay to strip " + "host launch overhead from per-launch latency. " + "Implicitly disables L2 flush (graph replays " + "are back-to-back, hot-cache).", + ) args = parser.parse_args() if args.benchmark: @@ -931,4 +1331,6 @@ def run_kernel(): inst_prefetch=args.inst_prefetch, waves_per_eu=args.waves_per_eu, expert_sched_mode=args.expert_sched_mode, + b_streaming=args.b_streaming, + scale_load_path=args.scale_load_path, )