diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index 54daa5ca..2438a068 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -920,39 +920,21 @@ def scf_ifexp_dispatch(cond, then_fn, else_fn): if not isinstance(cond_i1, ir.Value): raise TypeError(f"dynamic ifexp condition must lower to ir.Value, got {type(cond_i1).__name__}") - sandbox = scf.ExecuteRegionOp(result=[]) - sandbox.region.blocks.append() - with ir.InsertionPoint(sandbox.region.blocks[0]): - probe_then = then_fn() - probe_then_raw = _unwrap_value(probe_then) - probe_else = else_fn() - probe_else_raw = _unwrap_value(probe_else) - if not isinstance(probe_then_raw, ir.Value): - raise TypeError( - f"dynamic ifexp then-branch must produce an MLIR Value, " f"got {type(probe_then_raw).__name__}" - ) - if not isinstance(probe_else_raw, ir.Value): - raise TypeError( - f"dynamic ifexp else-branch must produce an MLIR Value, " f"got {type(probe_else_raw).__name__}" - ) - if probe_then_raw.type != probe_else_raw.type: - raise TypeError( - f"dynamic ifexp type mismatch: " - f"then-branch produces {probe_then_raw.type}, " - f"else-branch produces {probe_else_raw.type}" - ) - yield_type = probe_then_raw.type - - op = scf.IfOp(cond_i1, [yield_type], has_else=True, loc=ir.Location.unknown()) - with ir.InsertionPoint(op.regions[0].blocks[0]): - scf.YieldOp([_unwrap_value(then_fn())]) - if len(op.regions[1].blocks) == 0: - op.regions[1].blocks.append() - with ir.InsertionPoint(op.regions[1].blocks[0]): - scf.YieldOp([_unwrap_value(else_fn())]) - - sandbox.operation.erase() - return _wrap_like(op.results[0], probe_then) + then_val = then_fn() + else_val = else_fn() + then_raw = _unwrap_value(then_val) + else_raw = _unwrap_value(else_val) + if not isinstance(then_raw, ir.Value): + raise TypeError(f"dynamic ifexp then-branch must produce an MLIR Value, got {type(then_raw).__name__}") + if not isinstance(else_raw, ir.Value): + raise TypeError(f"dynamic ifexp else-branch must produce an MLIR Value, got {type(else_raw).__name__}") + if then_raw.type != else_raw.type: + raise TypeError( + f"dynamic ifexp type mismatch: then-branch produces {then_raw.type}, else-branch produces {else_raw.type}" + ) + + result = arith.SelectOp(cond_i1, then_raw, else_raw).result + return _wrap_like(result, then_val) @ASTRewriter.register diff --git a/tests/unit/test_ifexp_dispatch.py b/tests/unit/test_ifexp_dispatch.py index db5dda08..1218ef32 100644 --- a/tests/unit/test_ifexp_dispatch.py +++ b/tests/unit/test_ifexp_dispatch.py @@ -53,7 +53,7 @@ def test_ifexp_static_false_no_scf_if(): assert "scf.if" not in str(module) -def test_ifexp_dynamic_builds_scf_if(): +def test_ifexp_dynamic_builds_arith_select(): with Context(), Location.unknown(): module = Module.create() i1 = IntegerType.get_signless(1) @@ -73,8 +73,8 @@ def test_ifexp_dynamic_builds_scf_if(): assert module.operation.verify() ir_text = str(module) - assert "scf.if" in ir_text - assert "-> (i32)" in ir_text + assert "arith.select" in ir_text + assert "scf.if" not in ir_text def test_ifexp_dynamic_type_mismatch_raises(): @@ -122,7 +122,8 @@ def test_ifexp_nested_condition(): assert module.operation.verify() ir_text = str(module) - assert ir_text.count("scf.if") == 2 + assert ir_text.count("arith.select") == 2 + assert "scf.if" not in ir_text def test_ifexp_dynamic_float32(): @@ -145,8 +146,8 @@ def test_ifexp_dynamic_float32(): assert module.operation.verify() ir_text = str(module) - assert "scf.if" in ir_text - assert "-> (f32)" in ir_text + assert "arith.select" in ir_text + assert "scf.if" not in ir_text def test_ifexp_dynamic_float16(): @@ -169,5 +170,5 @@ def test_ifexp_dynamic_float16(): assert module.operation.verify() ir_text = str(module) - assert "scf.if" in ir_text - assert "-> (f16)" in ir_text + assert "arith.select" in ir_text + assert "scf.if" not in ir_text