Support division between two ComplexF32 numbers #738
Support division between two ComplexF32 numbers #738 — albertomercurio wants to merge 1 commit into JuliaGPU:main from
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl
index 1b5265fc..8c2aa399 100644
--- a/src/device/intrinsics/math.jl
+++ b/src/device/intrinsics/math.jl
@@ -53,14 +53,14 @@ end
a, b = reim(z) # Avoid using widen(z) as in Base
if (isinf(c) | isinf(d))
if isfinite(z)
- return complex(zero(Float32)*sign(real(z))*sign(real(w)), -zero(Float32)*sign(imag(z))*sign(imag(w)))
+ return complex(zero(Float32) * sign(real(z)) * sign(real(w)), -zero(Float32) * sign(imag(z)) * sign(imag(w)))
end
- return Float32(NaN)+Float32(NaN)*im
+ return Float32(NaN) + Float32(NaN) * im
end
mag = inv(muladd(c, c, d^2))
- re_part = muladd(a, c, b*d)
- im_part = muladd(b, c, -a*d)
- return oftype(z, Complex(re_part*mag, im_part*mag))
+ re_part = muladd(a, c, b * d)
+ im_part = muladd(b, c, -a * d)
+ return oftype(z, Complex(re_part * mag, im_part * mag))
end
@device_override FastMath.acos_fast(x::Float32) = ccall("extern air.fast_acos.f32", llvmcall, Cfloat, (Cfloat,), x)
diff --git a/test/device/intrinsics/math.jl b/test/device/intrinsics/math.jl
index 72fd4498..43138b68 100644
--- a/test/device/intrinsics/math.jl
+++ b/test/device/intrinsics/math.jl
@@ -181,21 +181,21 @@ end
@test Array(mtlout) == clamp.(in, minval, maxval)
end
- let
- N = 10
+ let
+ N = 10
- x = rand(ComplexF32, N)
- y = rand(ComplexF32, N)
+ x = rand(ComplexF32, N)
+ y = rand(ComplexF32, N)
- dx = MtlArray(x)
- dy = MtlArray(y)
+ dx = MtlArray(x)
+ dy = MtlArray(y)
- z = x ./ y
- dz = dx ./ dy
+ z = x ./ y
+ dz = dx ./ dy
- @test Array(dz) ≈ z
- end
+ @test Array(dz) ≈ z
+ end
let #pow
N = 4 |
Codecov Report ✅ All modified and coverable lines are covered by tests. Additional details and impacted files:
@@ Coverage Diff @@
## main #738 +/- ##
=======================================
Coverage 82.59% 82.59%
=======================================
Files 62 62
Lines 2862 2862
=======================================
Hits 2364 2364
Misses 498 498 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Metal Benchmarks
Details
| Benchmark suite | Current: 9fb073d | Previous: 043dbed | Ratio |
|---|---|---|---|
latency/precompile |
24953894292 ns |
25147544500 ns |
0.99 |
latency/ttfp |
2279042875 ns |
2280876000 ns |
1.00 |
latency/import |
1445356292 ns |
1448341375 ns |
1.00 |
integration/metaldevrt |
861395.5 ns |
856042 ns |
1.01 |
integration/byval/slices=1 |
1563562.5 ns |
1561437.5 ns |
1.00 |
integration/byval/slices=3 |
8507791 ns |
9985729 ns |
0.85 |
integration/byval/reference |
1542916 ns |
1550625 ns |
1.00 |
integration/byval/slices=2 |
2608479.5 ns |
2554125.5 ns |
1.02 |
kernel/indexing |
643687 ns |
621792 ns |
1.04 |
kernel/indexing_checked |
633792 ns |
630000 ns |
1.01 |
kernel/launch |
12458 ns |
11833 ns |
1.05 |
kernel/rand |
564354.5 ns |
569041 ns |
0.99 |
array/construct |
6417 ns |
6375 ns |
1.01 |
array/broadcast |
604750 ns |
594416 ns |
1.02 |
array/random/randn/Float32 |
1025708 ns |
1006834 ns |
1.02 |
array/random/randn!/Float32 |
748250 ns |
752083 ns |
0.99 |
array/random/rand!/Int64 |
551250 ns |
546541 ns |
1.01 |
array/random/rand!/Float32 |
590000 ns |
577979.5 ns |
1.02 |
array/random/rand/Int64 |
772125 ns |
773208.5 ns |
1.00 |
array/random/rand/Float32 |
584125 ns |
589791.5 ns |
0.99 |
array/accumulate/Int64/1d |
1254291.5 ns |
1262458 ns |
0.99 |
array/accumulate/Int64/dims=1 |
1837250.5 ns |
1837334 ns |
1.00 |
array/accumulate/Int64/dims=2 |
2171667 ns |
2166645.5 ns |
1.00 |
array/accumulate/Int64/dims=1L |
11428083 ns |
11676999.5 ns |
0.98 |
array/accumulate/Int64/dims=2L |
9713958 ns |
9763146 ns |
0.99 |
array/accumulate/Float32/1d |
1136812.5 ns |
1112833 ns |
1.02 |
array/accumulate/Float32/dims=1 |
1569645.5 ns |
1560562.5 ns |
1.01 |
array/accumulate/Float32/dims=2 |
1880500 ns |
1866625 ns |
1.01 |
array/accumulate/Float32/dims=1L |
9815334 ns |
9806292 ns |
1.00 |
array/accumulate/Float32/dims=2L |
7253979.5 ns |
7257291 ns |
1.00 |
array/reductions/reduce/Int64/1d |
1543896 ns |
1358250 ns |
1.14 |
array/reductions/reduce/Int64/dims=1 |
1104875 ns |
1089125 ns |
1.01 |
array/reductions/reduce/Int64/dims=2 |
1175250 ns |
1130958 ns |
1.04 |
array/reductions/reduce/Int64/dims=1L |
2004250 ns |
2002353.5 ns |
1.00 |
array/reductions/reduce/Int64/dims=2L |
4239062.5 ns |
4220375 ns |
1.00 |
array/reductions/reduce/Float32/1d |
1047416 ns |
1028042 ns |
1.02 |
array/reductions/reduce/Float32/dims=1 |
830041 ns |
831916 ns |
1.00 |
array/reductions/reduce/Float32/dims=2 |
854917 ns |
743167 ns |
1.15 |
array/reductions/reduce/Float32/dims=1L |
1300271 ns |
1311125.5 ns |
0.99 |
array/reductions/reduce/Float32/dims=2L |
1796583.5 ns |
1800750 ns |
1.00 |
array/reductions/mapreduce/Int64/1d |
1387958 ns |
1538667 ns |
0.90 |
array/reductions/mapreduce/Int64/dims=1 |
1102250 ns |
1095375 ns |
1.01 |
array/reductions/mapreduce/Int64/dims=2 |
1142458 ns |
1139729 ns |
1.00 |
array/reductions/mapreduce/Int64/dims=1L |
1986187.5 ns |
2011792 ns |
0.99 |
array/reductions/mapreduce/Int64/dims=2L |
3622459 ns |
3621583.5 ns |
1.00 |
array/reductions/mapreduce/Float32/1d |
1010416 ns |
1055750 ns |
0.96 |
array/reductions/mapreduce/Float32/dims=1 |
832729 ns |
819604 ns |
1.02 |
array/reductions/mapreduce/Float32/dims=2 |
858333 ns |
852417 ns |
1.01 |
array/reductions/mapreduce/Float32/dims=1L |
1327000 ns |
1315208.5 ns |
1.01 |
array/reductions/mapreduce/Float32/dims=2L |
1812771 ns |
1793125 ns |
1.01 |
array/private/copyto!/gpu_to_gpu |
644375 ns |
642208 ns |
1.00 |
array/private/copyto!/cpu_to_gpu |
784875 ns |
794167 ns |
0.99 |
array/private/copyto!/gpu_to_cpu |
803500 ns |
788500 ns |
1.02 |
array/private/iteration/findall/int |
1576084 ns |
1564479 ns |
1.01 |
array/private/iteration/findall/bool |
1403459 ns |
1408791.5 ns |
1.00 |
array/private/iteration/findfirst/int |
2100291.5 ns |
2072062 ns |
1.01 |
array/private/iteration/findfirst/bool |
2052292 ns |
2036375 ns |
1.01 |
array/private/iteration/scalar |
4051521 ns |
4806917 ns |
0.84 |
array/private/iteration/logical |
1816083 ns |
2579104 ns |
0.70 |
array/private/iteration/findmin/1d |
2518437.5 ns |
2506791 ns |
1.00 |
array/private/iteration/findmin/2d |
1786875 ns |
1788792 ns |
1.00 |
array/private/copy |
585667 ns |
576229 ns |
1.02 |
array/shared/copyto!/gpu_to_gpu |
84291 ns |
83125 ns |
1.01 |
array/shared/copyto!/cpu_to_gpu |
82250 ns |
82250 ns |
1 |
array/shared/copyto!/gpu_to_cpu |
83167 ns |
82375 ns |
1.01 |
array/shared/iteration/findall/int |
1577250 ns |
1574708 ns |
1.00 |
array/shared/iteration/findall/bool |
1425291 ns |
1415562.5 ns |
1.01 |
array/shared/iteration/findfirst/int |
1647792 ns |
1649917 ns |
1.00 |
array/shared/iteration/findfirst/bool |
1647541 ns |
1643167 ns |
1.00 |
array/shared/iteration/scalar |
210708 ns |
207208 ns |
1.02 |
array/shared/iteration/logical |
2461041 ns |
2487270.5 ns |
0.99 |
array/shared/iteration/findmin/1d |
2131792 ns |
2121917 ns |
1.00 |
array/shared/iteration/findmin/2d |
1796083 ns |
1791791 ns |
1.00 |
array/shared/copy |
246458 ns |
248833 ns |
0.99 |
array/permutedims/4d |
2391917 ns |
2395833 ns |
1.00 |
array/permutedims/2d |
1184354 ns |
1178750 ns |
1.00 |
array/permutedims/3d |
1688833 ns |
1686292 ns |
1.00 |
metal/synchronization/stream |
19125 ns |
19042 ns |
1.00 |
metal/synchronization/context |
20167 ns |
20042 ns |
1.01 |
This comment was automatically generated by workflow using github-action-benchmark.
|
@albertomercurio I asked on the Julia Slack, and it seems that if we cannot use a wider float for the intermediate calculations, a different algorithm should be used to get accurate results. @oscardssmith Is this implementation the one you were referring to on Slack? |
|
I was referring to https://github.com/JuliaLang/julia/blob/2cf16b10e40956671dde660de2e7914037e8b078/base/complex.jl#L390, but it might be the case that the performance hit is too large. |
Right. That seems like a lot of branching… What about the MLX implementation? Do you happen to know if it’s any worse than the generic Julia implementation? I’m thinking we use one of the two so we have something working until someone takes the time to benchmark the more accurate algorithm. |
|
the MLX version is the trivial one. It will suffer from overflow and underflow, but that's true of any algorithm that doesn't do some form of rescaling. |
|
Superseded by #762
Fixes #736