
ConvertMulConvert creates a correctness error on atmosphere init #2805


Description

@wsmoses

@Pangoraw @dkytezab

I have investigated the IR and the optimization passes as requested, using the StableHLO interpreter and `enzymexlamlir-opt` to trace the precision failure. Here is what I found:

1. Baseline with the interpreter

I ran the interpreter on `test_interpret.mlir` (which performs the multiplication in f64 and then converts the result to f32). The interpreter output for the first few elements (after index 0) is:

```
Index 1: 3.21428561
Index 2: 6.42857122
Index 3: 9.64285755
```
I verified these values using Python's `struct` and `numpy`:

The f64 calculation $3 \times (45.0/14.0)$, rounded to f32, yields exactly 9.642857551574707 (bit pattern 0x411A4925). This confirms that the interpreter is behaving correctly according to the IR, which specifies f64 operations.
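For reference, the step-1 verification can be reproduced with a short script; `struct` is used to inspect the raw f32 bit pattern:

```python
import struct
import numpy as np

# f64 multiply, then round the result to f32 (what the IR specifies).
p64 = np.float64(45.0) / np.float64(14.0) * 3.0  # ~9.642857142857143 in f64
c32 = np.float32(p64)                            # round f64 result to f32

print(float(c32))                    # 9.642857551574707
print(struct.pack("<f", c32).hex())  # little-endian bytes: 25491a41
print(struct.pack(">f", c32).hex())  # big-endian bit pattern: 411a4925
```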
2. Isolating the failure in optimization passes

I ran `enzymexlamlir-opt` on `test_interpret.mlir` with a restricted set of patterns to see at which point the result changes.

Case A: strict constant folding (without `convert_mul_convert`). I used the following patterns: `iota_simplify`, `convert_simplify`, `add_simplify`, `mul_simplify`, `add_const_prop`, `mul_const_prop`, and `broadcast_in_dim_simplify`. The resulting constant tensor had the bytes 25491A41 (little-endian for 0x411A4925) at index 3, which is 9.64285755. This matches the interpreter.

Case B: enabling `convert_mul_convert`. I added the `convert_mul_convert` pattern to the list above. The constant tensor entry at index 3 changed to 24491A41 (little-endian for 0x411A4924), which is 9.6428566. This value corresponds to performing the multiplication directly in f32: $\text{np.float32}(45.0/14.0) \times 3.0 = 9.6428566$.
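Case B's value can likewise be reproduced by forcing the multiplication itself into f32 (a sketch; `np.float32(3.0)` is written explicitly so the product stays in f32 regardless of NumPy's scalar-promotion rules):

```python
import struct
import numpy as np

# Round the operand to f32 first, then multiply in f32 -- what the
# rewritten graph computes after convert_mul_convert fires.
a32 = np.float32(45.0 / 14.0)        # ~3.2142856 in f32
p32 = np.float32(a32 * np.float32(3.0))

print(struct.pack("<f", p32).hex())  # little-endian bytes: 24491a41
print(float(p32))                    # ~9.6428566, one f32 step below Case A
```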

Conclusion

The `convert_mul_convert` pattern is responsible for the precision discrepancy. It transforms the graph to perform the multiplication in f32 instead of f64, likely to save computation/memory, but at the cost of precision.

If constant folding happens before this pass runs (e.g., when operands are statically known early), the fold happens in f64 and produces 9.64285755. If folding happens after the pass, or if the operations execute at runtime in f32 (because of this rewrite), the result is 9.6428566. This static-versus-dynamic, order-of-passes discrepancy likely explains the grid-size-dependent errors.
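Note that the two results differ by exactly one f32 ULP at this magnitude (2^-20, since 9.64 falls in the binade [8, 16)), which can be checked with `numpy.spacing`:

```python
import numpy as np

hi = 9.642857551574707  # f64-multiply-then-convert result (0x411A4925)
lo = 9.642856597900391  # f32-multiply result (0x411A4924)

# Spacing between adjacent f32 values near 9.64 is 2**-20.
ulp = float(np.spacing(np.float32(hi)))
print(ulp)              # 9.5367431640625e-07
print(hi - lo == ulp)   # True
```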

How would you like to proceed? Should we disable `convert_mul_convert`, or restrict it so that it maintains f64 precision when the IR specifies it?
