diff --git a/examples/flopscomp.jl b/examples/flopscomp.jl index 3fda1292f..cb87da3b6 100644 --- a/examples/flopscomp.jl +++ b/examples/flopscomp.jl @@ -1,4 +1,4 @@ -using Metal, GPUArrays, LinearAlgebra, Printf#, AppleAccelerate +using Metal, GPUArrays, LinearAlgebra, Printf, ScopedValues#, AppleAccelerate testing = (@isdefined TESTING) && TESTING @@ -54,16 +54,10 @@ function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true) return n_batch*2*Float64(n)^3 / minimum(t) end -function gpuarrpeakflops(; n::Integer=4096, - n_batch::Integer=1, - inT::DataType=Float32, - outT::DataType=inT, - ntrials::Integer=3, - verify=true) +function gpuarrpeakflops(; n_batch::Integer=1, + kwargs...) n_batch == 1 || @warn "n_batch > 1 not supported for `GPUArrays.generic_matmatmul!`, running with n_batch=1" - _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b - GPUArrays.generic_matmatmul!(c, LinearAlgebra.wrap(a, 'N'), LinearAlgebra.wrap(b, 'N'), 1, 0) - end + @with (Metal.matmul_alg => :GPUArrays) defaultpeakflops(; n_batch, kwargs...) end function defaultpeakflops(; n::Integer=4096, n_batch::Integer=1, @@ -71,25 +65,15 @@ function defaultpeakflops(; n::Integer=4096, outT::DataType=inT, ntrials::Integer=3, verify=true) - _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b + _peakflops(n, n_batch, inT, outT, ntrials; verify) do c, a, b LinearAlgebra.generic_matmatmul!(c, 'N', 'N', a, b, 1, 0) end end -function mpspeakflops(; n::Integer=4096, - n_batch::Integer=1, - inT::DataType=Float32, - outT::DataType=inT, - ntrials::Integer=3, - verify=true) - _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify) +function mpspeakflops(; kwargs...) + @with (Metal.matmul_alg => :MPS) defaultpeakflops(; kwargs...) end -function graphpeakflops(; n::Integer=4096, - n_batch::Integer=1, - inT::DataType=Float32, - outT::DataType=inT, - ntrials::Integer=3, - verify=true) - _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify) +function graphpeakflops(; kwargs...) + @with (Metal.matmul_alg => :MPSGraph) defaultpeakflops(; kwargs...) end function anepeakflops(; kwargs...) # VERY HACKY @@ -139,9 +123,9 @@ DEFAULT_FS = [ (mpspeakflops, "MPS"), (graphpeakflops, "MPSGraph"), (defaultpeakflops, "Default"), - # (anepeakflops, "MPSGraph (ANE)"), - # (gpuarrpeakflops, "GPUArrays"), + (gpuarrpeakflops, "GPUArrays"), # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance + # (anepeakflops, "MPSGraph (ANE)"), # Run last to prevent different line colours ] function runcomparison(; Ns=DEFAULT_NS, Fs=DEFAULT_FS, n_batch=1, ntrials=5, verbose=true) @@ -152,7 +136,7 @@ function runcomparison(; Ns=DEFAULT_NS, Fs=DEFAULT_FS, n_batch=1, ntrials=5, ver return res end -function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE) +function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, fileext="svg", plt_title=PLOT_TITLE) Fs = get.(Fs, 2, "You shouldn't be reading this") ylim_upper = 9e12 resplts = [] @@ -184,7 +168,7 @@ function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, outtype="svg", plt_ti bottommargin=15pt, size=(2000,1200)) if !isnothing(outpath) - savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype")) + savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$fileext")) end return finalplot end diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl index 2bbc1d2e1..3fb30b45a 100644 --- a/lib/mpsgraphs/matmul.jl +++ b/lib/mpsgraphs/matmul.jl @@ -30,38 +30,79 @@ else end -@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, - alpha::Number, beta::Number, - transpose_a, transpose_b) where {Tc, Tab, Na, Nb} - graph = MPSGraph() +#= +MPSGraph caching infrastructure. - placeA = placeholderTensor(graph, size(a), Tab) - placeB = placeholderTensor(graph, size(b), Tab) - placeC = placeholderTensor(graph, size(c), Tc) +The overhead of creating an MPSGraph dominates matmul time for small-medium matrices. +By caching graphs keyed by their structural parameters (shapes, types, flags), we +achieve significant speedup for repeated operations with the same configuration. - feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( - placeA => MPSGraphTensorData(a), - placeB => MPSGraphTensorData(b), - placeC => MPSGraphTensorData(c) +The cache key includes all parameters that affect graph structure: +- Input/output shapes and element types +- Transpose flags +- Alpha/beta values (baked into graph as constants) +=# + +# Cache key for matmul graphs - includes all structural parameters +struct MatmulGraphKey + size_a::Tuple{Vararg{Int}} + size_b::Tuple{Vararg{Int}} + size_c::Tuple{Vararg{Int}} + eltype_ab::DataType + eltype_c::DataType + ndims_a::Int + ndims_b::Int + alpha::Float64 + beta::Float64 + transpose_a::Bool + transpose_b::Bool +end +# Build graph key from matmul parameters +function MatmulGraphKey(a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, c::MtlArray{Tc}, + alpha::Number, beta::Number, + transpose_a::Bool, transpose_b::Bool) where {Tc, Tab, Na, Nb} + MatmulGraphKey( + size(a), size(b), size(c), + Tab, Tc, + Na, Nb, + Float64(alpha), Float64(beta), + transpose_a, transpose_b ) +end + +# Cached graph with all tensors needed for execution +struct CachedMatmulGraph + graph::MPSGraph + place_c::MPSGraphTensor + place_a::MPSGraphTensor + place_b::MPSGraphTensor + result::MPSGraphTensor +end +# Build a new matmul graph (called only on cache miss) +function CachedMatmulGraph(key::MatmulGraphKey) + graph = MPSGraph() + + placeA = placeholderTensor(graph, key.size_a, key.eltype_ab) + placeB = placeholderTensor(graph, key.size_b, key.eltype_ab) + placeC = placeholderTensor(graph, key.size_c, key.eltype_c) # cast to output eltype if input type is an integer type - castT = Tab <: Integer ? Tc : Tab + castT = key.eltype_ab <: Integer ? key.eltype_c : key.eltype_ab castA = castTensor(graph, placeA, castT, "castA") castB = castTensor(graph, placeB, castT, "castB") - transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA - transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB + transA = key.transpose_a ? transposeTensor(graph, castA, key.ndims_a - 2, key.ndims_a - 1, "transpose_a") : castA + transB = key.transpose_b ? transposeTensor(graph, castB, key.ndims_b - 2, key.ndims_b - 1, "transpose_b") : castB - nBatchA = Na == 2 ? 1 : size(transA)[1] - nBatchB = Nb == 2 ? 1 : size(transB)[1] + nBatchA = key.ndims_a == 2 ? 1 : key.size_a[1] + nBatchB = key.ndims_b == 2 ? 1 : key.size_b[1] # for batched matmul between different sized tensors broadcastA, broadcastB = if nBatchA == nBatchB transA, transB - elseif Na == 1 + elseif key.ndims_a == 1 broadcastTensor(graph, transA, convert(MPSShape, [nBatchB, size(transA)[2:end]...])), transB - elseif Nb == 1 + elseif key.ndims_b == 1 transA, broadcastTensor(graph, transB, convert(MPSShape, [nBatchA, size(transB)[2:end]...])) else transA, transB @@ -69,24 +110,58 @@ end matmul = matrixMultiplicationWithPrimaryTensor(graph, broadcastB, broadcastA) - afteralpha = let alphatensor = constantWithScalar(graph, alpha, castT) + afteralpha = let alphatensor = constantWithScalar(graph, key.alpha, castT) multiplicationWithPrimaryTensor(graph, alphatensor, matmul) end - afterbeta = let betatensor = constantWithScalar(graph, beta, castT) + afterbeta = let betatensor = constantWithScalar(graph, key.beta, castT) castplaceC = castTensor(graph, placeC, castT, "castplaceC") betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC) additionWithPrimaryTensor(graph, afteralpha, betaC) end - castC = castTensor(graph, afterbeta, Tc, "castC") + castC = castTensor(graph, afterbeta, key.eltype_c, "castC") + + CachedMatmulGraph(graph, placeC, placeA, placeB, castC) +end + +# Get or create cached graph +function _get_cached_graph!(graph_cache_lock, graph_cache, key::MatmulGraphKey) + # Fast path: check cache without lock (safe for reads) + cached = get(graph_cache, key, nothing) + if cached !== nothing + return cached + end + + # Slow path: acquire lock and build graph + @lock graph_cache_lock get!(graph_cache, key) do + CachedMatmulGraph(key) + end +end + +# Thread-safe graph cache with lock +const _matmul_graph_cache = Dict{MatmulGraphKey, CachedMatmulGraph}() +const _matmul_graph_cache_lock = ReentrantLock() +@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, + alpha::Number, beta::Number, + transpose_a, transpose_b) where {Tc, Tab, Na, Nb} + # Get or create cached graph + key = MatmulGraphKey(a, b, c, alpha, beta, transpose_a, transpose_b) + cached = _get_cached_graph!(_matmul_graph_cache_lock, _matmul_graph_cache, key) + + # Build feed and result dictionaries with current data + feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( + cached.place_a => MPSGraphTensorData(a), + cached.place_b => MPSGraphTensorData(b), + cached.place_c => MPSGraphTensorData(c) + ) resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}( - castC => feeds[placeC] + cached.result => feeds[cached.place_c] ) cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) - encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc()) + encode!(cmdbuf, cached.graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc()) commit!(cmdbuf) wait_completed(cmdbuf) diff --git a/src/linalg.jl b/src/linalg.jl index 017b21772..e232aa9ae 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -21,15 +21,6 @@ end C.offset == 0 end -# Assumes support for MPS matrix multiplication has been verified elsewhere -@inline function should_use_MPS(A, _, C) - rows = size(C,1) - cols = size(C,2) - # TODO: matvecmul different? - (eltype(A) <: Integer && rows <= 2000 && cols <= 2000 ) || - eltype(A) <: AbstractFloat && rows <= 6000 && cols <= 6000 && Metal.supports_family(device(C), MTL.MTLGPUFamilyApple9) -end - # Supported values are :auto, :MPS, :MPSGraph, and :GPUArrays const matmul_alg = ScopedValue(:auto) matmul_alg_error(alg, inT, outT, vec) = error("Matrix-$(vec ? "Vector" : "Matrix") multiplication algorithm `:$alg` is not supported for input eltype $inT and output eltype $outT.") @@ -63,7 +54,7 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri mps_supported = supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) # If possible, dispatch to MPSGraphs, then performance shaders - if alg === :MPSGraph || (alg === :auto && mpsgraph_supported && !should_use_MPS(A, B, C)) + if alg === :MPSGraph || (alg === :auto && mpsgraph_supported) mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C), false) graph_matmul!(C, A, B, alpha, beta, transA, transB) elseif alg === :MPS || (alg === :auto && mps_supported)