Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,534 changes: 1,534 additions & 0 deletions kernels/dispatch_combine_intranode_kernel.py

Large diffs are not rendered by default.

638 changes: 638 additions & 0 deletions kernels/dispatch_combine_intranode_op.py

Large diffs are not rendered by default.

15 changes: 11 additions & 4 deletions python/flydsl/compiler/ast_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,14 @@ def _visit_stmt_block(self, stmts):
return new_stmts

def visit_FunctionDef(self, node: ast.FunctionDef):
if getattr(node, _ASTREWRITE_MARKER, False):
# ``_ASTREWRITE_MARKER`` is set by ReplaceIfWithDispatch /
# InsertEmptyYieldForSCFFor on the synthetic then/else/body functions
# they generate. It records *which* transformer created the node so
# only that transformer skips re-visiting -- other passes still need
# to recurse into the synthetic function body (e.g. so a ``for`` loop
# generated inside an if-then gets lowered to scf.for_dispatch).
marker = getattr(node, _ASTREWRITE_MARKER, False)
if marker is True or marker == type(self).__name__:
Comment thread
xudoyuan marked this conversation as resolved.
return node

with self.symbol_scopes.function_scope():
Expand Down Expand Up @@ -797,7 +804,7 @@ def _state_return_node():
decorator_list=[],
type_params=[],
)
setattr(then_func, _ASTREWRITE_MARKER, True)
setattr(then_func, _ASTREWRITE_MARKER, type(self).__name__)
then_func = ast.copy_location(then_func, node)
then_func = ast.fix_missing_locations(then_func)

Expand Down Expand Up @@ -839,7 +846,7 @@ def _state_return_node():
decorator_list=[],
type_params=[],
)
setattr(else_func, _ASTREWRITE_MARKER, True)
setattr(else_func, _ASTREWRITE_MARKER, type(self).__name__)
else_func = ast.copy_location(else_func, node)
else_func = ast.fix_missing_locations(else_func)
dispatch_args.append(ast.Name(else_name, ctx=ast.Load()))
Expand All @@ -861,7 +868,7 @@ def _state_return_node():
decorator_list=[],
type_params=[],
)
setattr(else_func, _ASTREWRITE_MARKER, True)
setattr(else_func, _ASTREWRITE_MARKER, type(self).__name__)
else_func = ast.copy_location(else_func, node)
else_func = ast.fix_missing_locations(else_func)
dispatch_args.append(ast.Name(else_name, ctx=ast.Load()))
Expand Down
71 changes: 71 additions & 0 deletions python/flydsl/expr/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,74 @@ def cmpf(predicate, lhs, rhs, **kwargs):





@traced_op
def divui(lhs, rhs, **kwargs):
"""Unsigned integer divide accepting DSL types and Python int constants.

Generates ``arith.divui`` (efficient ``udiv`` on AMD GPU).

Args:
lhs: Dividend (ArithValue, ir.Value, or DSL Numeric).
rhs: Divisor (ArithValue, ir.Value, DSL Numeric, or Python int).
"""
lhs_v = _to_raw(lhs)
if isinstance(rhs, int):
rhs_v = _to_raw(constant(rhs, type=lhs_v.type))
else:
rhs_v = _to_raw(rhs)
return _mlir_arith.DivUIOp(lhs_v, rhs_v, **kwargs).result


@traced_op
def remui(lhs, rhs, **kwargs):
"""Unsigned integer remainder accepting DSL types and Python int constants.

Generates ``arith.remui`` (efficient ``urem`` on AMD GPU).

Args:
lhs: Dividend (ArithValue, ir.Value, or DSL Numeric).
rhs: Divisor (ArithValue, ir.Value, DSL Numeric, or Python int).
"""
lhs_v = _to_raw(lhs)
if isinstance(rhs, int):
rhs_v = _to_raw(constant(rhs, type=lhs_v.type))
else:
rhs_v = _to_raw(rhs)
return _mlir_arith.RemUIOp(lhs_v, rhs_v, **kwargs).result


def zext_i64(val):
"""Zero-extend integer value to i64, idempotent if already i64.

Returns ArithValue for use in arithmetic expressions.
"""
from .._mlir.extras import types as T
v = _to_raw(val)
i64 = T.i64()
if v.type == i64:
return v
return _mlir_arith.ExtUIOp(i64, v).result


@traced_op
def select_by_index(index_val, values):
"""Select one of *values* by integer *index_val* via chained ``arith.select``.

Equivalent to a compile-time switch: returns ``values[index_val]``.

Args:
index_val: Integer index (i32 ``ir.Value``).
values: List of ``ir.Value`` to select from.

Returns:
The selected ``ir.Value``.
"""
out = values[0]
for i in range(1, len(values)):
pred = _mlir_arith.CmpIOp(
_mlir_arith.CmpIPredicate.eq, index_val, constant(i, type=index_val.type)
).result
out = _mlir_arith.SelectOp(pred, values[i], out).result
return out
28 changes: 28 additions & 0 deletions python/flydsl/expr/rocdl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,31 @@ def ds_bpermute(res, index, src, **kw):

def readfirstlane(res, src, **kw):
return _ods_readfirstlane(res=res, src=_to_ir(src), **kw)



def ballot_i64(cond, *, loc=None):
"""Warp ballot returning 64-bit lane mask, with auto i1 coercion."""
from ..._mlir.ir import IntegerType
from ..._mlir.dialects import llvm as _llvm, rocdl as _rocdl

pred = _to_ir(cond)
i1 = IntegerType.get_signless(1)
if pred.type != i1:
pred = _llvm.TruncOp(i1, pred).result
i64 = IntegerType.get_signless(64)
return _rocdl.BallotOp(i64, pred, loc=loc).result


def readlane(val, lane, *, loc=None):
"""Read a value from a specific warp lane, accepting Python int for *lane*."""
from ..._mlir.ir import IntegerType, IntegerAttr
from ..._mlir.dialects import rocdl as _rocdl, arith as _arith

src = _to_ir(val)
i32 = IntegerType.get_signless(32)
if isinstance(lane, int):
lane_v = _arith.ConstantOp(i32, IntegerAttr.get(i32, lane)).result
else:
lane_v = _to_ir(lane)
return _rocdl.ReadlaneOp(i32, src, lane_v, loc=loc).result
31 changes: 31 additions & 0 deletions python/flydsl/expr/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,34 @@ def bitcast(result_type, source, *, loc=None, ip=None):
loc=loc,
ip=ip,
).result



# Scalar <-> vector bitcast (requires llvm.BitcastOp).
# arith.bitcast and vector.BitCastOp do not support shape changes
# (e.g. i32 <-> vector<2xbf16>); llvm.BitcastOp is required.

def bitcast_i32_to_v2bf16(val, *, loc=None):
"""Bitcast i32 scalar to vector<2xbf16> (bit-identical reinterpretation).

Used to reinterpret a packed i32 load result as two bf16 elements.
"""
from . import arith as _arith_ext
from .._mlir.dialects import llvm as _llvm
from .._mlir.extras import types as _T

v2bf16 = _T.VectorType.get([2], _T.bf16())
return _llvm.BitcastOp(v2bf16, _arith_ext.unwrap(val, loc=loc), loc=loc).res


def bitcast_v2bf16_to_i32(val, *, loc=None):
"""Bitcast vector<2xbf16> to i32 (bit-identical reinterpretation).

Used to pack two bf16 accumulator results into an i32 for store.
"""
from . import arith as _arith_ext
from .._mlir.dialects import llvm as _llvm
from .._mlir.ir import IntegerType

i32 = IntegerType.get_signless(32)
return _llvm.BitcastOp(i32, _arith_ext.unwrap(val, loc=loc), loc=loc).res
Loading
Loading