GPU extensions (Metal, oneAPI): unify Float64 → Float32 conversion #280
Open
sshin23 wants to merge 2 commits into
The Metal and oneAPI backend extensions both need to recursively convert
Float64 host data to Float32 before upload, because
- Metal cannot compile Float64 IR at all,
- oneAPI Arc-class devices reject Float64 allocations outright (Iris Xe /
Data Center GPU Max accept them but the user usually wants fp32 anyway).
Two things were wrong before this change:
1. The Metal extension's `replace_float_64` only handled scalars,
Tuples, and NamedTuples — its catch-all `replace_float_64(x) = x`
silently let plain structs through. So Float64 fields inside structs
like ExaPowerIO's `BusData`/`GenData`/`BranchData` reached the
device, and the JIT either emitted invalid Float64 IR (best case) or
crashed Apple's Metal compiler with an XPC interrupt (typical case for
second-order AD kernels on OPF).
2. The oneAPI extension only defined `convert_array` behind
`if pkgversion(oneAPI) < v"2.6"`. On current oneAPI (≥ v2.6) the
branch was skipped, falling back to `adapt(backend, v)` which
preserved Float64 element type. Every fp32 OPF/LV/COPS run on Arc A770
bombed at host-to-device transfer with "Float64 is not supported on
this device".
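Both failure modes come down to a permissive fallback method winning dispatch. A minimal Julia sketch of that mechanism follows. Every name in it (`FakeBackend`, `Bus`, `old_replace`, `FAKE_PKG_VERSION`, and this local `convert_array`) is a hypothetical stand-in, not the real Metal, oneAPI, or ExaModels code:

```julia
# Hypothetical stand-ins; none of these are the real ExaModels, Metal.jl, or oneAPI.jl definitions.
struct FakeBackend end
struct Bus{T}
    id::Int
    pd::T
end

# Bug 1: a catch-all method returns structs untouched, Float64 fields included.
old_replace(x::Float64) = Float32(x)
old_replace(x) = x
leaked = old_replace(Bus(1, 2.5))          # still a Bus{Float64}

# Bug 2: a method defined only under a version guard is never created on newer
# versions, so calls fall through to the eltype-preserving fallback.
const FAKE_PKG_VERSION = v"2.6.0"
convert_array(v, backend) = v              # plays the role of the `adapt` fallback
if FAKE_PKG_VERSION < v"2.6"
    convert_array(v, ::FakeBackend) = Float32.(v)
end
uploaded = convert_array([1.0, 2.0], FakeBackend())   # still a Vector{Float64}
```

In neither case does Julia raise an error at definition time; the Float64 data only surfaces later, at upload or kernel compilation.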
This commit:
- Hoists `replace_float_64` into `src/templates.jl` (single
implementation, both extensions reuse it).
- Implements the struct recursion with pure multiple dispatch (no
try/catch, no `@generated`). Specialized methods cover Float64,
Tuple, NamedTuple, AbstractArray{Float64}, and AbstractArray; the
generic struct path uses `Val(fieldcount(T))` to keep `ntuple`
type-stable and `ConstructionBase.constructorof(T)` to rebuild the
struct with Float32-typed fields. Adds ConstructionBase as a dep.
- In the Metal extension: removes the bespoke (incomplete)
`replace_float_64` and adds the `MtlArray` identity overload to
`convert_array`. Both extensions now read
`convert_array(v::DeviceArray, ::Backend) = v` and
`convert_array(v, ::Backend) = Backend.Array(ExaModels.replace_float_64.(v))`.
- In the oneAPI extension: drops the `< v2.6` guard from the
`convert_array` definition so the dispatch is unconditional. Keeps
the older `sort!`/`findall` shims inside the version guard since
those remain version-specific.
- Adds a `replace_float_64` test block to UtilsTest covering: scalar
leaf, mixed Tuple/NamedTuple, Vector{Float64} eltype check, mixed-Any
array, parametric struct rebuild, and nested struct with array fields.
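The dispatch-based recursion listed above might look roughly like the sketch below. It is an approximation under stated assumptions, not the code in `src/templates.jl`: `to_f32`, `_rebuild`, `Gen`, and `Grid` are hypothetical names, and `Base.typename(T).wrapper` stands in for `ConstructionBase.constructorof(T)` so that the example runs without extra dependencies.

```julia
# Sketch only; `to_f32` is a hypothetical name for what the PR calls `replace_float_64`.
to_f32(x::Float64) = Float32(x)                    # scalar leaf
to_f32(t::Tuple) = map(to_f32, t)                  # recurse into tuples
to_f32(nt::NamedTuple) = map(to_f32, nt)           # recurse, keeping field names
to_f32(a::AbstractArray{Float64}) = Float32.(a)    # numeric arrays: convert eltype
to_f32(a::AbstractArray) = map(to_f32, a)          # other arrays: element-wise

# Generic path: dispatch on the field count so field-less values pass through.
to_f32(x::T) where {T} = _rebuild(x, Val(fieldcount(T)))
_rebuild(x, ::Val{0}) = x                          # Int, Symbol, Float32, ...
function _rebuild(x::T, ::Val{N}) where {T, N}
    fields = ntuple(i -> to_f32(getfield(x, i)), Val(N))
    # Assumption: the PR calls ConstructionBase.constructorof(T) here; the
    # parameter-free wrapper type is used as a dependency-free stand-in.
    ctor = Base.typename(T).wrapper
    return ctor(fields...)
end

# Hypothetical example structs (not ExaPowerIO types).
struct Gen{T}
    id::Int
    pmax::T
end
struct Grid{G, V}
    gen::G
    loads::V
end

grid32 = to_f32(Grid(Gen(1, 2.5), [1.0, 2.0]))
```

Calling the parameter-free constructor lets Julia re-infer the type parameters from the converted fields, which is why `Gen{Float64}` comes back as `Gen{Float32}`. This only works for structs whose default constructor accepts the fields positionally; types with restrictive inner constructors would need their own method.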
Validated manually:
- Apple M1 (Metal, local): OPF case118 polar with T=Float32 — all five
NLPModels callbacks succeed (obj, cons!, grad!, jac_coord!,
hess_coord!). LV.rosenrock with T=Float32 — same.
- Apple M2 Pro (Metal, pro): 231 of 240 cases pass in the OPF+LV+COPS
benchmark with T=Float32. The remaining 9 failures are rocket/glider
(Metal-compiler XPC crashes on hess_coord!) and
dirichlet/henon/lane_emden, which use a separate PDEProblem struct in
ExaModelsPower that needs its own parameterization on T (out of scope
here).
- Intel Arc A770 (oneAPI, shin-compute-002): was 0/240 before this
change; the smoke test on case118 polar now passes. The full 240-case
run is in progress as of this commit.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Contributor
Your PR requires formatting changes to meet the project's style guidelines. Please run:
julia --project=@runic -e 'using Pkg; Pkg.add("Runic")'
julia --project=@runic -e "using Runic; exit(Runic.main(ARGS))" -- --fix <files>
Note: the full diff is omitted because it can exceed GitHub Actions input limits.
The `convert_array` dispatch normalizes host data to Float32 before upload, so the default float type for a oneAPI backend should match. Without this, OracleTest's `_atol(backend)` helper still picks the Float64 tolerance (1e-10), which is tighter than Float32 machine epsilon (~1e-7), causing 15 false test failures on the oneapi CI runner. This mirrors the existing Metal extension, which already declares `default_T(::MetalBackend) = Float32`.
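A small, self-contained check shows why a 1e-10 absolute tolerance cannot hold for data that has passed through Float32. The value `0.1` and the looser `1.0e-6` tolerance are illustrative choices only, not the values used by OracleTest:

```julia
x = 0.1                                   # Float64 reference value
roundtrip = Float64(Float32(x))           # same value after a Float32 round-trip
err = abs(roundtrip - x)                  # rounding error introduced by Float32

# With an explicit `atol`, `isapprox` uses no relative tolerance by default.
tight = isapprox(roundtrip, x; atol = 1.0e-10)   # Float64-style tolerance
loose = isapprox(roundtrip, x; atol = 1.0e-6)    # looser tolerance, for illustration
```

Here `tight` is false and `loose` is true: a single Float32 rounding already produces an error larger than 1e-10, so the comparison fails regardless of whether the kernel is correct.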
Two related bugs in the GPU extensions:
1. `replace_float_64`'s catch-all `replace_float_64(x) = x` silently let plain structs through, so `BusData{Float64}` arrays reached the device. Result: Float64 IR was emitted, and Apple's compiler often crashed with an XPC interrupt on second-order AD kernels.
2. `convert_array(v, ::oneAPIBackend) = oneArray(v)` was gated by `if pkgversion(oneAPI) < v"2.6"`. On ≥ 2.6 the branch was skipped and the default `adapt(backend, v)` preserved Float64, killing all fp32 work on Intel Arc A770 with `Float64 is not supported on this device`.
Fix
- Hoist `replace_float_64` to `src/templates.jl` (single implementation, both extensions reuse it).
- Specialized methods cover `Float64`, `Tuple`, `NamedTuple`, `AbstractArray{Float64}`, and `AbstractArray`; the generic struct path uses `Val(fieldcount(T))` for a type-stable `ntuple` plus `ConstructionBase.constructorof(T)` for the rebuild. No `@generated`, no try/catch. Adds `ConstructionBase` as a dep.
- Both extensions define the same `convert_array`. oneAPI's version guard now wraps only the `sort!`/`findall` shims.
Testing
New `replace_float_64` block in `test/UtilsTest/UtilsTest.jl` covers: scalar Float64, scalar passthrough (Float32/Int/Symbol), Tuple/NamedTuple recursion, `Vector{Float64}` eltype check, mixed-Any vector, parametric struct rebuild, nested struct with array and tuple fields.
Stats
6 files, +84/−9.
🤖 Generated with Claude Code