Skip to content
42 changes: 13 additions & 29 deletions examples/flopscomp.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Metal, GPUArrays, LinearAlgebra, Printf#, AppleAccelerate
using Metal, GPUArrays, LinearAlgebra, Printf, ScopedValues#, AppleAccelerate

testing = (@isdefined TESTING) && TESTING

Expand Down Expand Up @@ -54,42 +54,26 @@ function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)

return n_batch*2*Float64(n)^3 / minimum(t)
end
"""
    gpuarrpeakflops(; n_batch=1, kwargs...)

Benchmark matrix multiplication with `Metal.matmul_alg` scoped to `:GPUArrays`,
forwarding the remaining keyword arguments to `defaultpeakflops`.

A warning is emitted when `n_batch` is not 1, and the benchmark is then run with
`n_batch=1`, as the warning states.
"""
function gpuarrpeakflops(; n_batch::Integer=1,
                           kwargs...)
    n_batch == 1 || @warn "n_batch > 1 not supported for `GPUArrays.generic_matmatmul!`, running with n_batch=1"
    # Pin the batch count to 1 so the run matches what the warning promises;
    # forwarding the caller's `n_batch` would silently benchmark the batched case.
    return @with (Metal.matmul_alg => :GPUArrays) defaultpeakflops(; n_batch=1, kwargs...)
end
function defaultpeakflops(; n::Integer=4096,
n_batch::Integer=1,
inT::DataType=Float32,
outT::DataType=inT,
ntrials::Integer=3,
verify=true)
_peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b
_peakflops(n, n_batch, inT, outT, ntrials; verify) do c, a, b
LinearAlgebra.generic_matmatmul!(c, 'N', 'N', a, b, 1, 0)
end
end
function mpspeakflops(; n::Integer=4096,
n_batch::Integer=1,
inT::DataType=Float32,
outT::DataType=inT,
ntrials::Integer=3,
verify=true)
_peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify)
function mpspeakflops(; kwargs...)
@with (Metal.matmul_alg => :MPS) defaultpeakflops(; kwargs...)
end
function graphpeakflops(; n::Integer=4096,
n_batch::Integer=1,
inT::DataType=Float32,
outT::DataType=inT,
ntrials::Integer=3,
verify=true)
_peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify)
function graphpeakflops(; kwargs...)
@with (Metal.matmul_alg => :MPSGraph) defaultpeakflops(; kwargs...)
end
function anepeakflops(; kwargs...)
# VERY HACKY
Expand Down Expand Up @@ -139,9 +123,9 @@ DEFAULT_FS = [
(mpspeakflops, "MPS"),
(graphpeakflops, "MPSGraph"),
(defaultpeakflops, "Default"),
# (anepeakflops, "MPSGraph (ANE)"),
# (gpuarrpeakflops, "GPUArrays"),
(gpuarrpeakflops, "GPUArrays"),
# (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
# (anepeakflops, "MPSGraph (ANE)"), # Run last to prevent different line colours
]

function runcomparison(; Ns=DEFAULT_NS, Fs=DEFAULT_FS, n_batch=1, ntrials=5, verbose=true)
Expand All @@ -152,7 +136,7 @@ function runcomparison(; Ns=DEFAULT_NS, Fs=DEFAULT_FS, n_batch=1, ntrials=5, ver
return res
end

function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, fileext="svg", plt_title=PLOT_TITLE)
Fs = get.(Fs, 2, "You shouldn't be reading this")
ylim_upper = 9e12
resplts = []
Expand Down Expand Up @@ -184,7 +168,7 @@ function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, outtype="svg", plt_ti
bottommargin=15pt,
size=(2000,1200))
if !isnothing(outpath)
savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype"))
savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$fileext"))
end
return finalplot
end
Expand Down
121 changes: 98 additions & 23 deletions lib/mpsgraphs/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,63 +30,138 @@ else
end


@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb},
alpha::Number, beta::Number,
transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
graph = MPSGraph()
#=
MPSGraph caching infrastructure.

placeA = placeholderTensor(graph, size(a), Tab)
placeB = placeholderTensor(graph, size(b), Tab)
placeC = placeholderTensor(graph, size(c), Tc)
The overhead of creating an MPSGraph dominates matmul time for small-medium matrices.
By caching graphs keyed by their structural parameters (shapes, types, flags), we
achieve significant speedup for repeated operations with the same configuration.

feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
placeA => MPSGraphTensorData(a),
placeB => MPSGraphTensorData(b),
placeC => MPSGraphTensorData(c)
The cache key includes all parameters that affect graph structure:
- Input/output shapes and element types
- Transpose flags
- Alpha/beta values (baked into graph as constants)
=#

# Cache key for matmul graphs - includes all structural parameters.
#
# Two keys that compare equal describe the same graph, so a graph built for one
# can be reused for the other. The outer constructor that takes the arrays fills
# the fields from `size`, the array type parameters and the scalar arguments.
struct MatmulGraphKey
# Sizes of the `a`, `b` and `c` operands, as returned by `size`.
size_a::Tuple{Vararg{Int}}
size_b::Tuple{Vararg{Int}}
size_c::Tuple{Vararg{Int}}
# Element type shared by both inputs, and element type of the output.
eltype_ab::DataType
eltype_c::DataType
# Number of dimensions of `a` and `b`.
ndims_a::Int
ndims_b::Int
# Scaling factors, stored as Float64; they are baked into the graph as constants,
# so a different value needs a different graph.
alpha::Float64
beta::Float64
# Whether `a` / `b` is transposed inside the graph.
transpose_a::Bool
transpose_b::Bool
end
"""
    MatmulGraphKey(a, b, c, alpha, beta, transpose_a, transpose_b)

Derive the graph cache key for one matmul call from its operands: the sizes of
`a`, `b` and `c`, the input and output element types, the dimensionality of the
inputs, the scaling factors converted to `Float64`, and the transpose flags.
"""
function MatmulGraphKey(a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, c::MtlArray{Tc},
                        alpha::Number, beta::Number,
                        transpose_a::Bool, transpose_b::Bool) where {Tc, Tab, Na, Nb}
    shape_a, shape_b, shape_c = size(a), size(b), size(c)
    # The key stores both scalars as Float64, whatever number type was passed in.
    scale_ab = Float64(alpha)
    scale_c = Float64(beta)
    return MatmulGraphKey(shape_a, shape_b, shape_c,
                          Tab, Tc,
                          Na, Nb,
                          scale_ab, scale_c,
                          transpose_a, transpose_b)
end

# Cached graph with all tensors needed for execution.
#
# Holds the graph itself plus the tensor handles a caller needs to run it again
# with fresh data: the three placeholders to feed and the tensor to read back.
# Field order matters: the positional constructor is called as
# `CachedMatmulGraph(graph, placeC, placeA, placeB, castC)`.
struct CachedMatmulGraph
graph::MPSGraph
# Placeholder fed with the current contents of the output array `c`.
place_c::MPSGraphTensor
# Placeholders fed with the input arrays `a` and `b`.
place_a::MPSGraphTensor
place_b::MPSGraphTensor
# Final tensor of the graph (already cast to the output element type).
result::MPSGraphTensor
end
# Build a new matmul graph (called only on cache miss)
function CachedMatmulGraph(key::MatmulGraphKey)
graph = MPSGraph()

placeA = placeholderTensor(graph, key.size_a, key.eltype_ab)
placeB = placeholderTensor(graph, key.size_b, key.eltype_ab)
placeC = placeholderTensor(graph, key.size_c, key.eltype_c)

# cast to output eltype if input type is an integer type
castT = Tab <: Integer ? Tc : Tab
castT = key.eltype_ab <: Integer ? key.eltype_c : key.eltype_ab
castA = castTensor(graph, placeA, castT, "castA")
castB = castTensor(graph, placeB, castT, "castB")

transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA
transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB
transA = key.transpose_a ? transposeTensor(graph, castA, key.ndims_a - 2, key.ndims_a - 1, "transpose_a") : castA
transB = key.transpose_b ? transposeTensor(graph, castB, key.ndims_b - 2, key.ndims_b - 1, "transpose_b") : castB

nBatchA = Na == 2 ? 1 : size(transA)[1]
nBatchB = Nb == 2 ? 1 : size(transB)[1]
nBatchA = key.ndims_a == 2 ? 1 : key.size_a[1]
nBatchB = key.ndims_b == 2 ? 1 : key.size_b[1]

# for batched matmul between different sized tensors
broadcastA, broadcastB = if nBatchA == nBatchB
transA, transB
elseif Na == 1
elseif key.ndims_a == 1
broadcastTensor(graph, transA, convert(MPSShape, [nBatchB, size(transA)[2:end]...])), transB
elseif Nb == 1
elseif key.ndims_b == 1
transA, broadcastTensor(graph, transB, convert(MPSShape, [nBatchA, size(transB)[2:end]...]))
else
transA, transB
end

matmul = matrixMultiplicationWithPrimaryTensor(graph, broadcastB, broadcastA)

afteralpha = let alphatensor = constantWithScalar(graph, alpha, castT)
afteralpha = let alphatensor = constantWithScalar(graph, key.alpha, castT)
multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
end

afterbeta = let betatensor = constantWithScalar(graph, beta, castT)
afterbeta = let betatensor = constantWithScalar(graph, key.beta, castT)
castplaceC = castTensor(graph, placeC, castT, "castplaceC")
betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC)
additionWithPrimaryTensor(graph, afteralpha, betaC)
end

castC = castTensor(graph, afterbeta, Tc, "castC")
castC = castTensor(graph, afterbeta, key.eltype_c, "castC")

CachedMatmulGraph(graph, placeC, placeA, placeB, castC)
end

"""
    _get_cached_graph!(graph_cache_lock, graph_cache, key::MatmulGraphKey)

Return the `CachedMatmulGraph` stored in `graph_cache` under `key`, building and
inserting it first when no entry exists yet.

The whole lookup runs while holding `graph_cache_lock`. A `Dict` must not be read
while another task is inserting into it (an insertion may rehash the table), so
there is deliberately no unlocked fast path: a hit costs one uncontended lock
acquisition, and a miss builds the graph exactly once per key.
"""
function _get_cached_graph!(graph_cache_lock, graph_cache, key::MatmulGraphKey)
    return lock(graph_cache_lock) do
        get!(graph_cache, key) do
            # Only reached on a cache miss.
            CachedMatmulGraph(key)
        end
    end
end

# Thread-safe graph cache with lock.
# Module-level store of built matmul graphs, keyed by `MatmulGraphKey`. Insertions
# must happen while holding `_matmul_graph_cache_lock`.
# NOTE(review): nothing visible here evicts entries, so the cache presumably grows
# with every distinct shape/type/alpha/beta combination - confirm this is intended.
const _matmul_graph_cache = Dict{MatmulGraphKey, CachedMatmulGraph}()
const _matmul_graph_cache_lock = ReentrantLock()
@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb},
alpha::Number, beta::Number,
transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
# Get or create cached graph
key = MatmulGraphKey(a, b, c, alpha, beta, transpose_a, transpose_b)
cached = _get_cached_graph!(_matmul_graph_cache_lock, _matmul_graph_cache, key)

# Build feed and result dictionaries with current data
feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
cached.place_a => MPSGraphTensorData(a),
cached.place_b => MPSGraphTensorData(b),
cached.place_c => MPSGraphTensorData(c)
)

resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
castC => feeds[placeC]
cached.result => feeds[cached.place_c]
)

cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc())
encode!(cmdbuf, cached.graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc())
commit!(cmdbuf)
wait_completed(cmdbuf)

Expand Down
11 changes: 1 addition & 10 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,6 @@ end
C.offset == 0
end

"""
    should_use_MPS(A, _, C)

Heuristic deciding whether MPS should be preferred for a multiplication writing
into `C`, based on the element type of `A` and the first two dimensions of `C`.

Returns `true` for integer inputs when `C` is at most 2000 x 2000. For
floating-point inputs with `C` at most 6000 x 6000 it returns whether the device
of `C` supports the `MTLGPUFamilyApple9` family. Returns `false` otherwise.

Assumes support for MPS matrix multiplication has been verified elsewhere.
"""
@inline function should_use_MPS(A, _, C)
    nrows, ncols = size(C, 1), size(C, 2)
    T = eltype(A)
    # TODO: matvecmul different?
    if T <: Integer && nrows <= 2000 && ncols <= 2000
        return true
    end
    # The device is only queried once the cheap type and size checks have passed.
    within_float_limit = T <: AbstractFloat && nrows <= 6000 && ncols <= 6000
    return within_float_limit && Metal.supports_family(device(C), MTL.MTLGPUFamilyApple9)
end

# Supported values are :auto, :MPS, :MPSGraph, and :GPUArrays
const matmul_alg = ScopedValue(:auto)
"""
    matmul_alg_error(alg, inT, outT, vec)

Throw an `ErrorException` reporting that algorithm `alg` cannot multiply inputs of
element type `inT` into an output of element type `outT`. `vec` selects the wording:
`true` for a matrix-vector product, `false` for a matrix-matrix product.
"""
function matmul_alg_error(alg, inT, outT, vec)
    operand = vec ? "Vector" : "Matrix"
    return error("Matrix-$operand multiplication algorithm `:$alg` is not supported for input eltype $inT and output eltype $outT.")
end
Expand Down Expand Up @@ -63,7 +54,7 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
mps_supported = supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES)
mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
# If possible, dispatch to MPSGraphs, then performance shaders
if alg === :MPSGraph || (alg === :auto && mpsgraph_supported && !should_use_MPS(A, B, C))
if alg === :MPSGraph || (alg === :auto && mpsgraph_supported)
mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C), false)
graph_matmul!(C, A, B, alpha, beta, transA, transB)
elseif alg === :MPS || (alg === :auto && mps_supported)
Expand Down