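cc @wsmoses Nearest-neighbor interpolation seems off: `InterpolateArray` with `InterpolationType.Nearest` disagrees with a CPU nearest-neighbor reference on 44544 of 115200 elements (see the output below).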
# Ask XLA to expose 4 host (CPU) devices so a 2×2 mesh can be built below. This runs
# before `using Reactant`, presumably because the flag is read when the XLA client is
# created — keep it above the imports. Any existing XLA_FLAGS are preserved.
ENV["XLA_FLAGS"] = string(get(ENV, "XLA_FLAGS", ""), " --xla_force_host_platform_device_count=4")
using Random, Reactant
using Reactant: Sharding, InterpolateArray, InterpolationType
# 2×2 mesh over the first four devices, with axes named :x and :y.
mesh = let first_four = Reactant.devices()[1:4]
    Sharding.Mesh(reshape(first_four, 2, 2), (:x, :y))
end
# The partition spec maps array dims 1 and 2 to mesh axes "x" and "y"; dim 3 gets
# `nothing` (presumably replicated / unsharded — confirm against Reactant docs).
sharding = Sharding.NamedSharding(mesh, ("x", "y", nothing))

# Reproducible 32×16×8 source field of values 1 + rand(Float32).
Random.seed!(42)
ρ_src = rand(Float32, 32, 16, 8) .+ 1.0f0

# Target grid size for the interpolation.
(Nx, Ny, Nz) = (120, 120, 8)
@info "Source" size=size(ρ_src) range=extrema(ρ_src)
# Reactant: nearest-neighbor interpolation of ρ_src onto an (Nx, Ny, Nz) grid using
# `sharding`, converted to a plain `Array` so it can be compared element-wise below.
r = InterpolateArray(ρ_src, (Nx, Ny, Nz), sharding, InterpolationType.Nearest) |> Array
"""
    nearest_neighbor_reference(src, dims)

CPU nearest-neighbor resampling of `src` onto an array of size `dims`.

The result is a new `Array` with the same element type as `src`. Along every dimension
`d`, destination index `i` copies source index `cld(i * size(src, d), dims[d])`, i.e.
`ceil(i / (dims[d] / size(src, d)))`. This is intended to follow the index mapping of
the `_nearest_neighbor_data_copy!` kernel (not verified here).

Throws `ArgumentError` if `src` is empty but `dims` is not, because there is nothing
to sample from.
"""
function nearest_neighbor_reference(src::AbstractArray{T,N},
                                    dims::NTuple{N,Integer}) where {T,N}
    # Source indices below are computed as 1..size(src, d).
    Base.require_one_based_indexing(src)
    src_dims = size(src)
    dst = Array{T,N}(undef, dims)
    if isempty(src) && !isempty(dst)
        throw(ArgumentError("cannot resample an empty source array to size $dims"))
    end
    for I in CartesianIndices(dst)
        # Integer ceiling division is exact for any size ratio. The float form
        # ceil(Int, i / (n_dst / n_src)) can land one index too high when the ratio is
        # not exactly representable (e.g. 10 / 7). For 1 ≤ i ≤ n_dst the result always
        # lies in 1:n_src.
        J = CartesianIndex(map((i, ns, nd) -> cld(i * ns, nd), Tuple(I), src_dims, dims))
        dst[I] = src[J]
    end
    return dst
end
v = nearest_neighbor_reference(ρ_src, (Nx, Ny, Nz))
# Summaries of both results, then their element-wise disagreement. The mismatch count
# uses exact `!=`, since a nearest-neighbor result should consist only of values copied
# from ρ_src.
for (label, arr) in ("Reactant" => r, "CPU" => v)
    @info label size=size(arr) range=extrema(arr)
end
let Δ = r .- v
    @info "Diff" max_abs=maximum(abs, Δ) n_mismatch=sum(r .!= v) n_total=length(r)
end

# Print one x-line through the middle of the (y, z) plane from each result.
mid_y = div(Ny, 2)
mid_z = div(Nz, 2)
for (header, arr) in (("\n— Reactant [:, $mid_y, $mid_z] —", r),
                      ("\n— CPU [:, $mid_y, $mid_z] —", v))
    println(header)
    println(arr[:, mid_y, mid_z])
end
┌ Info: Source
│ size = (32, 16, 8)
└ range = (1.0000849f0, 1.999742f0)
┌ Info: Reactant
│ size = (120, 120, 8)
└ range = (1.0000849f0, 1.999742f0)
┌ Info: CPU
│ size = (120, 120, 8)
└ range = (1.0000849f0, 1.999742f0)
┌ Info: Diff
│ max_abs = 0.99734044f0
│ n_mismatch = 44544
└ n_total = 115200
cc @wsmoses Nearest-neighbor interpolation seems off