Skip to content

@Const can lead to wrong results on CPU in Julia 1.11/1.12 #652

@maximilian-gelbrecht

Description

@maximilian-gelbrecht

MWE below @vchuravy

Works in 1.10, but not in 1.11/1.12

"""
Minimal working example to demonstrate the difference between:
1. Using `array[I] +=` in a loop (BROKEN)
2. Using a local accumulator then `array[I] =` (WORKS)

This appears to be a bug in KernelAbstractions.jl, or a misunderstanding on my part of how it
handles compound assignment operators on array elements within loops.
"""

using KernelAbstractions
using Test

# Version 1: Using += directly on array element (BROKEN)
#
# For each global Cartesian index I = (i, j), adds input[i, k] for every k in j:n
# onto output[I], reading and writing output[I] on every loop iteration.
# The sum is added to whatever output[I] already holds, so the caller must
# pre-initialise `output`.
# `n` is passed through @Const; this kernel is reported to give wrong results on
# the CPU backend in Julia 1.11/1.12 while working in 1.10.
@kernel function accumulate_broken!(output, input, @Const(n))
    I = @index(Global, Cartesian)
    i = I[1]
    j = I[2]
    
    # Intended result: output[i, j] accumulates input[i, k] for k in j:n.
    # Reported to produce WRONG results. This loop is the reproducer, so its
    # exact form (compound assignment on the array element) must not be changed.
    for k in j:n
        output[I] += input[i, k]
    end
end

# Version 2: Using local accumulator (WORKS)
#
# Computes the same per-element sum as the += variant (input[i, k] for k in j:n
# at global Cartesian index I = (i, j)), but accumulates into a local variable
# and stores to output[I] exactly once. Unlike the += variant, this overwrites
# output[I] rather than adding to its previous contents.
@kernel function accumulate_fixed!(output, input, @Const(n))
    I = @index(Global, Cartesian)
    i = I[1]
    j = I[2]
    
    # Local accumulator, typed after the output's element type so the sum
    # does not change type inside the loop.
    sum_val = zero(eltype(output))
    for k in j:n
        sum_val += input[i, k]
    end
    # Single write to the output array, after the loop has finished.
    output[I] = sum_val
end

# Test function
"""
    test_kernel_bug()

Run both accumulation kernels on the `CPU()` backend against a plain-loop
reference and print the input, the reference and both kernel outputs.

The test data is a 5×8 `Float32` matrix with `input[i, k] = i + k`. Every output
array is zero-initialised, which the `+=` kernel relies on.

# Returns
A tuple `(is_broken_wrong, is_fixed_correct)` of `Bool`s:
- `is_broken_wrong`: `true` when the `+=` kernel's output differs from the reference
  (i.e. the bug reproduces).
- `is_fixed_correct`: `true` when the local-accumulator kernel matches the reference.
"""
function test_kernel_bug()
    println("="^70)
    println("Minimal Kernel Bug Demonstration")
    println("="^70)

    # Setup test data
    n = 8
    m = 5

    input = Float32[i + k for i in 1:m, k in 1:n]
    output_broken = zeros(Float32, m, n)
    output_fixed = zeros(Float32, m, n)
    output_cpu = zeros(Float32, m, n)

    println("\nInput matrix ($(m)×$(n)):")
    display(input)
    println("\n")

    # CPU reference implementation: output_cpu[i, j] = sum of input[i, k] for k in j:n
    println("Computing CPU reference...")
    for i in 1:m
        for j in 1:n
            for k in j:n
                output_cpu[i, j] += input[i, k]
            end
        end
    end

    # Version 1: Broken (using +=)
    println("Running BROKEN kernel (using output[I] += ...)...")
    backend = CPU()
    kernel_broken! = accumulate_broken!(backend)
    kernel_broken!(output_broken, input, n, ndrange=size(output_broken))
    # Wait for the launch to complete before the result is read.
    KernelAbstractions.synchronize(backend)

    # Version 2: Fixed (using local accumulator)
    println("Running FIXED kernel (using local accumulator)...")
    kernel_fixed! = accumulate_fixed!(backend)
    kernel_fixed!(output_fixed, input, n, ndrange=size(output_fixed))
    KernelAbstractions.synchronize(backend)

    # Compare results
    println("\n" * "="^70)
    println("RESULTS")
    println("="^70)

    println("\nCPU Reference:")
    display(output_cpu)
    println("\n")

    println("\nBROKEN Kernel (output[I] +=):")
    display(output_broken)
    println("\n")

    println("\nFIXED Kernel (local accumulator):")
    display(output_fixed)
    println("\n")

    # Check if broken version is actually broken.
    # The comparison operator was missing here (most likely a `≈` lost in
    # copy/paste), which made the file fail to parse. `isapprox` is the ASCII
    # spelling of `≈`, so it cannot be lost the same way.
    is_broken_wrong = !isapprox(output_broken, output_cpu)
    is_fixed_correct = isapprox(output_fixed, output_cpu)

    return (is_broken_wrong, is_fixed_correct)
end

# Script entry point: run the demonstration only when this file is the program
# being executed directly, not when it is `include`d from another file or session.
if (@__FILE__) == abspath(PROGRAM_FILE)
    test_kernel_bug()
end

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions