From ed2fe2b117d27b7e2b4c9f48a98724d9d5c0b996 Mon Sep 17 00:00:00 2001 From: Kaan Kesgin Date: Fri, 5 Dec 2025 16:20:41 +0100 Subject: [PATCH 01/10] Cache MPSGraph instances for matmul to reduce overhead MPSGraph construction takes ~2ms per call, which dominated matmul latency for the MPSGraph path. This adds a thread-safe cache keyed by structural parameters (shapes, types, transpose flags, alpha/beta). Performance impact by use case: FASTER (3-7x improvement on subsequent calls): - Large matrices (>6000x6000 Float32, >2000x2000 Integer) - Mixed-precision matmul (Int8->Float32, Float16->Float32) - Matrix-vector multiplication with supported types - Explicit `Metal.@with Metal.matmul_alg => :MPSGraph` usage - Batched matrix multiplication (3D+ arrays) UNCHANGED (uses MPS path, not affected): - Small/medium Float32 matrices (<=6000x6000 on Apple9+ GPUs) - Small Integer matrices (<=2000x2000) - Most typical ML inference workloads SLIGHTLY SLOWER on first call only: - First matmul of each unique shape/type adds cache lookup overhead - Negligible compared to the ~2ms saved on all subsequent calls The cache is process-global and grows with unique configurations. Typical ML workloads use few distinct shapes, so memory overhead is minimal (each cached graph is ~1-2KB). --- lib/mpsgraphs/matmul.jl | 153 +++++++++++++++++++++++++++++++++++----- 1 file changed, 134 insertions(+), 19 deletions(-) diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl index 2bbc1d2e1..d546088b0 100644 --- a/lib/mpsgraphs/matmul.jl +++ b/lib/mpsgraphs/matmul.jl @@ -30,20 +30,71 @@ else end -@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, - alpha::Number, beta::Number, - transpose_a, transpose_b) where {Tc, Tab, Na, Nb} - graph = MPSGraph() +#= +MPSGraph caching infrastructure. - placeA = placeholderTensor(graph, size(a), Tab) - placeB = placeholderTensor(graph, size(b), Tab) - placeC = placeholderTensor(graph, size(c), Tc) +Creating an MPSGraph takes ~2ms per call, which dominates matmul time for small-medium +matrices. By caching graphs keyed by their structural parameters (shapes, types, flags), +we achieve 3-7x speedup for repeated operations with the same configuration. - feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( - placeA => MPSGraphTensorData(a), - placeB => MPSGraphTensorData(b), - placeC => MPSGraphTensorData(c) +The cache key includes all parameters that affect graph structure: +- Input/output shapes and element types +- Transpose flags +- Alpha/beta values (baked into graph as constants) +=# + +# Cache key for matmul graphs - includes all structural parameters +struct MatmulGraphKey + size_a::Tuple{Vararg{Int}} + size_b::Tuple{Vararg{Int}} + size_c::Tuple{Vararg{Int}} + eltype_ab::DataType + eltype_c::DataType + ndims_a::Int + ndims_b::Int + transpose_a::Bool + transpose_b::Bool + alpha::Float64 # Normalized to Float64 for hashing + beta::Float64 +end + +# Cached graph with all tensors needed for execution +struct CachedMatmulGraph + graph::MPSGraph + place_a::MPSGraphTensor + place_b::MPSGraphTensor + place_c::MPSGraphTensor + result::MPSGraphTensor +end + +# Thread-safe graph cache with lock +const _matmul_graph_cache = Dict{MatmulGraphKey, CachedMatmulGraph}() +const _matmul_graph_cache_lock = ReentrantLock() + +# Build graph key from matmul parameters +function _make_matmul_key(a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, c::MtlArray{Tc}, + alpha::Number, beta::Number, + transpose_a::Bool, transpose_b::Bool) where {Tc, Tab, Na, Nb} + MatmulGraphKey( + size(a), size(b), size(c), + Tab, Tc, + Na, Nb, + transpose_a, transpose_b, + Float64(alpha), Float64(beta) ) +end + +# Build a new matmul graph (called only on cache miss) +function _build_matmul_graph(size_a::Tuple, size_b::Tuple, size_c::Tuple, + Tab::DataType, Tc::DataType, + Na::Int, Nb::Int, + transpose_a::Bool, transpose_b::Bool, + alpha::Number, beta::Number) + graph = MPSGraph() + + placeA = placeholderTensor(graph, size_a, Tab) + placeB = placeholderTensor(graph, size_b, Tab) + placeC = placeholderTensor(graph, size_c, Tc) # cast to output eltype if input type is an integer type castT = Tab <: Integer ? Tc : Tab @@ -53,16 +104,34 @@ end transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB - nBatchA = Na == 2 ? 1 : size(transA)[1] - nBatchB = Nb == 2 ? 1 : size(transB)[1] + # Compute batch sizes for broadcasting + # For transposed tensors, we need to compute the shape after transpose + function get_batch_size(tensor, ndims, transposed) + if ndims == 2 + return 1 + else + # For N-dimensional arrays, batch is first dimension + # The placeholder has the original shape, transpose swaps last two dims + return size_a[1] # Batch dimension doesn't change with transpose + end + end + + nBatchA = Na == 2 ? 1 : size_a[1] + nBatchB = Nb == 2 ? 1 : size_b[1] # for batched matmul between different sized tensors broadcastA, broadcastB = if nBatchA == nBatchB transA, transB - elseif Na == 1 - broadcastTensor(graph, transA, convert(MPSShape, [nBatchB, size(transA)[2:end]...])), transB - elseif Nb == 1 - transA, broadcastTensor(graph, transB, convert(MPSShape, [nBatchA, size(transB)[2:end]...])) + elseif nBatchA == 1 + # Need to broadcast A to match B's batch size + # After transpose, shape is (batch, rows, cols) or (rows, cols) + trans_shape_a = transpose_a ? (size_a[1:end-2]..., size_a[end], size_a[end-1]) : size_a + new_shape = (nBatchB, trans_shape_a[max(1,end-1):end]...) + broadcastTensor(graph, transA, convert(MPSShape, collect(new_shape))), transB + elseif nBatchB == 1 + trans_shape_b = transpose_b ? (size_b[1:end-2]..., size_b[end], size_b[end-1]) : size_b + new_shape = (nBatchA, trans_shape_b[max(1,end-1):end]...) + transA, broadcastTensor(graph, transB, convert(MPSShape, collect(new_shape))) else transA, transB end @@ -81,12 +150,58 @@ end castC = castTensor(graph, afterbeta, Tc, "castC") + CachedMatmulGraph(graph, placeA, placeB, placeC, castC) +end + +# Get or create cached graph +function _get_cached_graph(key::MatmulGraphKey) + # Fast path: check cache without lock (safe for reads) + cached = get(_matmul_graph_cache, key, nothing) + if cached !== nothing + return cached + end + + # Slow path: acquire lock and build graph + lock(_matmul_graph_cache_lock) do + # Double-check after acquiring lock + cached = get(_matmul_graph_cache, key, nothing) + if cached !== nothing + return cached + end + + # Build new graph + cached = _build_matmul_graph( + key.size_a, key.size_b, key.size_c, + key.eltype_ab, key.eltype_c, + key.ndims_a, key.ndims_b, + key.transpose_a, key.transpose_b, + key.alpha, key.beta + ) + _matmul_graph_cache[key] = cached + return cached + end +end + +@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, + alpha::Number, beta::Number, + transpose_a, transpose_b) where {Tc, Tab, Na, Nb} + # Get or create cached graph + key = _make_matmul_key(a, b, c, alpha, beta, transpose_a, transpose_b) + cached = _get_cached_graph(key) + + # Build feed and result dictionaries with current data + feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( + cached.place_a => MPSGraphTensorData(a), + cached.place_b => MPSGraphTensorData(b), + cached.place_c => MPSGraphTensorData(c) + ) + resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}( - castC => feeds[placeC] + cached.result => MPSGraphTensorData(c) ) cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) - encode!(cmdbuf, graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc()) + encode!(cmdbuf, cached.graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc()) commit!(cmdbuf) wait_completed(cmdbuf) From de58ed86f4884e4477e2078e236693968224bca3 Mon Sep 17 00:00:00 2001 From: Kaan Kesgin Date: Mon, 8 Dec 2025 08:25:27 +0100 Subject: [PATCH 02/10] Address review feedback: cleanup and consistency improvements - Reorder struct fields for consistency (alpha/beta before transpose, place_c before place_a) - Remove dead code (unused get_batch_size helper function) - Revert inadvertent change to broadcast logic (Na==1 vs nBatchA==1) - Update speedup claim in comment to be less specific --- lib/mpsgraphs/matmul.jl | 40 +++++++++++----------------------------- 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl index d546088b0..55720066e 100644 --- a/lib/mpsgraphs/matmul.jl +++ b/lib/mpsgraphs/matmul.jl @@ -35,7 +35,7 @@ MPSGraph caching infrastructure. Creating an MPSGraph takes ~2ms per call, which dominates matmul time for small-medium matrices. By caching graphs keyed by their structural parameters (shapes, types, flags), -we achieve 3-7x speedup for repeated operations with the same configuration. +we achieve significant speedup for repeated operations with the same configuration. The cache key includes all parameters that affect graph structure: - Input/output shapes and element types @@ -52,18 +52,18 @@ struct MatmulGraphKey eltype_c::DataType ndims_a::Int ndims_b::Int - transpose_a::Bool - transpose_b::Bool alpha::Float64 # Normalized to Float64 for hashing beta::Float64 + transpose_a::Bool + transpose_b::Bool end # Cached graph with all tensors needed for execution struct CachedMatmulGraph graph::MPSGraph + place_c::MPSGraphTensor place_a::MPSGraphTensor place_b::MPSGraphTensor - place_c::MPSGraphTensor result::MPSGraphTensor end @@ -79,8 +79,8 @@ function _make_matmul_key(a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, c::MtlArra size(a), size(b), size(c), Tab, Tc, Na, Nb, - transpose_a, transpose_b, - Float64(alpha), Float64(beta) + Float64(alpha), Float64(beta), + transpose_a, transpose_b ) end @@ -104,34 +104,16 @@ function _build_matmul_graph(size_a::Tuple, size_b::Tuple, size_c::Tuple, transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB - # Compute batch sizes for broadcasting - # For transposed tensors, we need to compute the shape after transpose - function get_batch_size(tensor, ndims, transposed) - if ndims == 2 - return 1 - else - # For N-dimensional arrays, batch is first dimension - # The placeholder has the original shape, transpose swaps last two dims - return size_a[1] # Batch dimension doesn't change with transpose - end - end - nBatchA = Na == 2 ? 1 : size_a[1] nBatchB = Nb == 2 ? 1 : size_b[1] # for batched matmul between different sized tensors broadcastA, broadcastB = if nBatchA == nBatchB transA, transB - elseif nBatchA == 1 - # Need to broadcast A to match B's batch size - # After transpose, shape is (batch, rows, cols) or (rows, cols) - trans_shape_a = transpose_a ? (size_a[1:end-2]..., size_a[end], size_a[end-1]) : size_a - new_shape = (nBatchB, trans_shape_a[max(1,end-1):end]...) - broadcastTensor(graph, transA, convert(MPSShape, collect(new_shape))), transB - elseif nBatchB == 1 - trans_shape_b = transpose_b ? (size_b[1:end-2]..., size_b[end], size_b[end-1]) : size_b - new_shape = (nBatchA, trans_shape_b[max(1,end-1):end]...) - transA, broadcastTensor(graph, transB, convert(MPSShape, collect(new_shape))) + elseif Na == 1 + broadcastTensor(graph, transA, convert(MPSShape, [nBatchB, size(transA)[2:end]...])), transB + elseif Nb == 1 + transA, broadcastTensor(graph, transB, convert(MPSShape, [nBatchA, size(transB)[2:end]...])) else transA, transB end @@ -150,7 +132,7 @@ function _build_matmul_graph(size_a::Tuple, size_b::Tuple, size_c::Tuple, castC = castTensor(graph, afterbeta, Tc, "castC") - CachedMatmulGraph(graph, placeA, placeB, placeC, castC) + CachedMatmulGraph(graph, placeC, placeA, placeB, castC) end # Get or create cached graph From 946ac0fa5f77eb8d44c8bfdc82ff330c44616038 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 15:18:05 -0400 Subject: [PATCH 03/10] Simplify lock --- lib/mpsgraphs/matmul.jl | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl index 55720066e..aae471cb2 100644 --- a/lib/mpsgraphs/matmul.jl +++ b/lib/mpsgraphs/matmul.jl @@ -144,23 +144,14 @@ function _get_cached_graph(key::MatmulGraphKey) end # Slow path: acquire lock and build graph - lock(_matmul_graph_cache_lock) do - # Double-check after acquiring lock - cached = get(_matmul_graph_cache, key, nothing) - if cached !== nothing - return cached - end - - # Build new graph - cached = _build_matmul_graph( + @lock _matmul_graph_cache_lock get!(_matmul_graph_cache, key) do + _build_matmul_graph( key.size_a, key.size_b, key.size_c, key.eltype_ab, key.eltype_c, key.ndims_a, key.ndims_b, key.transpose_a, key.transpose_b, key.alpha, key.beta ) - _matmul_graph_cache[key] = cached - return cached end end From c4b8fed07aa253024ccbde667e00872b679c7576 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 15:20:53 -0400 Subject: [PATCH 04/10] Reuse MPSGraphTensorData for c and result --- lib/mpsgraphs/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl index aae471cb2..003be3e1c 100644 --- a/lib/mpsgraphs/matmul.jl +++ b/lib/mpsgraphs/matmul.jl @@ -170,7 +170,7 @@ end ) resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}( - cached.result => MPSGraphTensorData(c) + cached.result => feeds[cached.place_c] ) cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) From 60b7aebfa1916bcf3e74df7bb8e4a09213646a55 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 15:37:55 -0400 Subject: [PATCH 05/10] Change `_build_matmul_graph` arguments --- lib/mpsgraphs/matmul.jl | 40 +++++++++++++++------------------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl index 003be3e1c..04587ce12 100644 --- a/lib/mpsgraphs/matmul.jl +++ b/lib/mpsgraphs/matmul.jl @@ -85,34 +85,30 @@ function _make_matmul_key(a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, c::MtlArra end # Build a new matmul graph (called only on cache miss) -function _build_matmul_graph(size_a::Tuple, size_b::Tuple, size_c::Tuple, - Tab::DataType, Tc::DataType, - Na::Int, Nb::Int, - transpose_a::Bool, transpose_b::Bool, - alpha::Number, beta::Number) +function _build_matmul_graph(key::MatmulGraphKey) graph = MPSGraph() - placeA = placeholderTensor(graph, size_a, Tab) - placeB = placeholderTensor(graph, size_b, Tab) - placeC = placeholderTensor(graph, size_c, Tc) + placeA = placeholderTensor(graph, key.size_a, key.eltype_ab) + placeB = placeholderTensor(graph, key.size_b, key.eltype_ab) + placeC = placeholderTensor(graph, key.size_c, key.eltype_c) # cast to output eltype if input type is an integer type - castT = Tab <: Integer ? Tc : Tab + castT = key.eltype_ab <: Integer ? key.eltype_c : key.eltype_ab castA = castTensor(graph, placeA, castT, "castA") castB = castTensor(graph, placeB, castT, "castB") - transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA - transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB + transA = key.transpose_a ? transposeTensor(graph, castA, key.ndims_a - 2, key.ndims_a - 1, "transpose_a") : castA + transB = key.transpose_b ? transposeTensor(graph, castB, key.ndims_b - 2, key.ndims_b - 1, "transpose_b") : castB - nBatchA = Na == 2 ? 1 : size_a[1] - nBatchB = Nb == 2 ? 1 : size_b[1] + nBatchA = key.ndims_a == 2 ? 1 : key.size_a[1] + nBatchB = key.ndims_b == 2 ? 1 : key.size_b[1] # for batched matmul between different sized tensors broadcastA, broadcastB = if nBatchA == nBatchB transA, transB - elseif Na == 1 + elseif key.ndims_a == 1 broadcastTensor(graph, transA, convert(MPSShape, [nBatchB, size(transA)[2:end]...])), transB - elseif Nb == 1 + elseif key.ndims_b == 1 transA, broadcastTensor(graph, transB, convert(MPSShape, [nBatchA, size(transB)[2:end]...])) else transA, transB @@ -120,17 +116,17 @@ function _build_matmul_graph(size_a::Tuple, size_b::Tuple, size_c::Tuple, matmul = matrixMultiplicationWithPrimaryTensor(graph, broadcastB, broadcastA) - afteralpha = let alphatensor = constantWithScalar(graph, alpha, castT) + afteralpha = let alphatensor = constantWithScalar(graph, key.alpha, castT) multiplicationWithPrimaryTensor(graph, alphatensor, matmul) end - afterbeta = let betatensor = constantWithScalar(graph, beta, castT) + afterbeta = let betatensor = constantWithScalar(graph, key.beta, castT) castplaceC = castTensor(graph, placeC, castT, "castplaceC") betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC) additionWithPrimaryTensor(graph, afteralpha, betaC) end - castC = castTensor(graph, afterbeta, Tc, "castC") + castC = castTensor(graph, afterbeta, key.eltype_c, "castC") CachedMatmulGraph(graph, placeC, placeA, placeB, castC) end @@ -145,13 +141,7 @@ function _get_cached_graph(key::MatmulGraphKey) # Slow path: acquire lock and build graph @lock _matmul_graph_cache_lock get!(_matmul_graph_cache, key) do - _build_matmul_graph( - key.size_a, key.size_b, key.size_c, - key.eltype_ab, key.eltype_c, - key.ndims_a, key.ndims_b, - key.transpose_a, key.transpose_b, - key.alpha, key.beta - ) + _build_matmul_graph(key) end end From b269dba3e69caad4eab831cf22f933cb6bfa9db5 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 15:42:56 -0400 Subject: [PATCH 06/10] Constructors --- lib/mpsgraphs/matmul.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl index 04587ce12..6bb9cfc31 100644 --- a/lib/mpsgraphs/matmul.jl +++ b/lib/mpsgraphs/matmul.jl @@ -57,6 +57,18 @@ struct MatmulGraphKey transpose_a::Bool transpose_b::Bool end +# Build graph key from matmul parameters +function MatmulGraphKey(a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, c::MtlArray{Tc}, + alpha::Number, beta::Number, + transpose_a::Bool, transpose_b::Bool) where {Tc, Tab, Na, Nb} + MatmulGraphKey( + size(a), size(b), size(c), + Tab, Tc, + Na, Nb, + Float64(alpha), Float64(beta), + transpose_a, transpose_b + ) +end # Cached graph with all tensors needed for execution struct CachedMatmulGraph @@ -71,21 +83,9 @@ end const _matmul_graph_cache = Dict{MatmulGraphKey, CachedMatmulGraph}() const _matmul_graph_cache_lock = ReentrantLock() -# Build graph key from matmul parameters -function _make_matmul_key(a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, c::MtlArray{Tc}, - alpha::Number, beta::Number, - transpose_a::Bool, transpose_b::Bool) where {Tc, Tab, Na, Nb} - MatmulGraphKey( - size(a), size(b), size(c), - Tab, Tc, - Na, Nb, - Float64(alpha), Float64(beta), - transpose_a, transpose_b - ) -end # Build a new matmul graph (called only on cache miss) -function _build_matmul_graph(key::MatmulGraphKey) +function CachedMatmulGraph(key::MatmulGraphKey) graph = MPSGraph() placeA = placeholderTensor(graph, key.size_a, key.eltype_ab) @@ -141,7 +141,7 @@ function _get_cached_graph(key::MatmulGraphKey) # Slow path: acquire lock and build graph @lock _matmul_graph_cache_lock get!(_matmul_graph_cache, key) do - _build_matmul_graph(key) + CachedMatmulGraph(key) end end @@ -149,7 +149,7 @@ end alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab, Na, Nb} # Get or create cached graph - key = _make_matmul_key(a, b, c, alpha, beta, transpose_a, transpose_b) + key = MatmulGraphKey(a, b, c, alpha, beta, transpose_a, transpose_b) cached = _get_cached_graph(key) # Build feed and result dictionaries with current data From 6baaff0525a165b30ebe294177839478e07abe1a Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 15:57:29 -0400 Subject: [PATCH 07/10] Use less global state --- lib/mpsgraphs/matmul.jl | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl index 6bb9cfc31..0230e4b30 100644 --- a/lib/mpsgraphs/matmul.jl +++ b/lib/mpsgraphs/matmul.jl @@ -52,7 +52,7 @@ struct MatmulGraphKey eltype_c::DataType ndims_a::Int ndims_b::Int - alpha::Float64 # Normalized to Float64 for hashing + alpha::Float64 beta::Float64 transpose_a::Bool transpose_b::Bool @@ -78,12 +78,6 @@ struct CachedMatmulGraph place_b::MPSGraphTensor result::MPSGraphTensor end - -# Thread-safe graph cache with lock -const _matmul_graph_cache = Dict{MatmulGraphKey, CachedMatmulGraph}() -const _matmul_graph_cache_lock = ReentrantLock() - - # Build a new matmul graph (called only on cache miss) function CachedMatmulGraph(key::MatmulGraphKey) graph = MPSGraph() @@ -132,25 +126,28 @@ function CachedMatmulGraph(key::MatmulGraphKey) end # Get or create cached graph -function _get_cached_graph(key::MatmulGraphKey) +function _get_cached_graph!(graph_cache_lock, graph_cache, key::MatmulGraphKey) # Fast path: check cache without lock (safe for reads) - cached = get(_matmul_graph_cache, key, nothing) + cached = get(graph_cache, key, nothing) if cached !== nothing return cached end # Slow path: acquire lock and build graph - @lock _matmul_graph_cache_lock get!(_matmul_graph_cache, key) do + @lock graph_cache_lock get!(graph_cache, key) do CachedMatmulGraph(key) end end +# Thread-safe graph cache with lock +const _matmul_graph_cache = Dict{MatmulGraphKey, CachedMatmulGraph}() +const _matmul_graph_cache_lock = ReentrantLock() @autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab, Na, Nb} # Get or create cached graph key = MatmulGraphKey(a, b, c, alpha, beta, transpose_a, transpose_b) - cached = _get_cached_graph(key) + cached = _get_cached_graph!(_matmul_graph_cache_lock, _matmul_graph_cache, key) # Build feed and result dictionaries with current data feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( From 3adc518586800578e96d81a74dde8c0a3c59feb8 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 16:37:43 -0400 Subject: [PATCH 08/10] Fixup flopscomp.jl script --- examples/flopscomp.jl | 42 +++++++++++++----------------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/examples/flopscomp.jl b/examples/flopscomp.jl index 3fda1292f..cb87da3b6 100644 --- a/examples/flopscomp.jl +++ b/examples/flopscomp.jl @@ -1,4 +1,4 @@ -using Metal, GPUArrays, LinearAlgebra, Printf#, AppleAccelerate +using Metal, GPUArrays, LinearAlgebra, Printf, ScopedValues#, AppleAccelerate testing = (@isdefined TESTING) && TESTING @@ -54,16 +54,10 @@ function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true) return n_batch*2*Float64(n)^3 / minimum(t) end -function gpuarrpeakflops(; n::Integer=4096, - n_batch::Integer=1, - inT::DataType=Float32, - outT::DataType=inT, - ntrials::Integer=3, - verify=true) +function gpuarrpeakflops(; n_batch::Integer=1, + kwargs...) n_batch == 1 || @warn "n_batch > 1 not supported for `GPUArrays.generic_matmatmul!`, running with n_batch=1" - _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b - GPUArrays.generic_matmatmul!(c, LinearAlgebra.wrap(a, 'N'), LinearAlgebra.wrap(b, 'N'), 1, 0) - end + @with (Metal.matmul_alg => :GPUArrays) defaultpeakflops(; n_batch, kwargs...) end function defaultpeakflops(; n::Integer=4096, n_batch::Integer=1, @@ -71,25 +65,15 @@ function defaultpeakflops(; n::Integer=4096, outT::DataType=inT, ntrials::Integer=3, verify=true) - _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b + _peakflops(n, n_batch, inT, outT, ntrials; verify) do c, a, b LinearAlgebra.generic_matmatmul!(c, 'N', 'N', a, b, 1, 0) end end -function mpspeakflops(; n::Integer=4096, - n_batch::Integer=1, - inT::DataType=Float32, - outT::DataType=inT, - ntrials::Integer=3, - verify=true) - _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify) +function mpspeakflops(; kwargs...) + @with (Metal.matmul_alg => :MPS) defaultpeakflops(; kwargs...) end -function graphpeakflops(; n::Integer=4096, - n_batch::Integer=1, - inT::DataType=Float32, - outT::DataType=inT, - ntrials::Integer=3, - verify=true) - _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify) +function graphpeakflops(; kwargs...) + @with (Metal.matmul_alg => :MPSGraph) defaultpeakflops(; kwargs...) end function anepeakflops(; kwargs...) # VERY HACKY @@ -139,9 +123,9 @@ DEFAULT_FS = [ (mpspeakflops, "MPS"), (graphpeakflops, "MPSGraph"), (defaultpeakflops, "Default"), - # (anepeakflops, "MPSGraph (ANE)"), - # (gpuarrpeakflops, "GPUArrays"), + (gpuarrpeakflops, "GPUArrays"), # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance + # (anepeakflops, "MPSGraph (ANE)"), # Run last to prevent different line colours ] function runcomparison(; Ns=DEFAULT_NS, Fs=DEFAULT_FS, n_batch=1, ntrials=5, verbose=true) @@ -152,7 +136,7 @@ function runcomparison(; Ns=DEFAULT_NS, Fs=DEFAULT_FS, n_batch=1, ntrials=5, ver return res end -function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE) +function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, fileext="svg", plt_title=PLOT_TITLE) Fs = get.(Fs, 2, "You shouldn't be reading this") ylim_upper = 9e12 resplts = [] @@ -184,7 +168,7 @@ function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, outtype="svg", plt_ti bottommargin=15pt, size=(2000,1200)) if !isnothing(outpath) - savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype")) + savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$fileext")) end return finalplot end From b461c8f20f8fc7457ca79a6b422a1df80b35546e Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 17:54:37 -0400 Subject: [PATCH 09/10] Always use MPSGraphs matmul by default --- src/linalg.jl | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 017b21772..e232aa9ae 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -21,15 +21,6 @@ end C.offset == 0 end -# Assumes support for MPS matrix multiplication has been verified elsewhere -@inline function should_use_MPS(A, _, C) - rows = size(C,1) - cols = size(C,2) - # TODO: matvecmul different? - (eltype(A) <: Integer && rows <= 2000 && cols <= 2000 ) || - eltype(A) <: AbstractFloat && rows <= 6000 && cols <= 6000 && Metal.supports_family(device(C), MTL.MTLGPUFamilyApple9) -end - # Supported values are :auto, :MPS, :MPSGraph, and :GPUArrays const matmul_alg = ScopedValue(:auto) matmul_alg_error(alg, inT, outT, vec) = error("Matrix-$(vec ? "Vector" : "Matrix") multiplication algorithm `:$alg` is not supported for input eltype $inT and output eltype $outT.") @@ -63,7 +54,7 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri mps_supported = supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) # If possible, dispatch to MPSGraphs, then performance shaders - if alg === :MPSGraph || (alg === :auto && mpsgraph_supported && !should_use_MPS(A, B, C)) + if alg === :MPSGraph || (alg === :auto && mpsgraph_supported) mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C), false) graph_matmul!(C, A, B, alpha, beta, transA, transB) elseif alg === :MPS || (alg === :auto && mps_supported) From ed2615caa1331bf3210c1d515e4550a31e5fec03 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 18:15:48 -0400 Subject: [PATCH 10/10] Tweak comment --- lib/mpsgraphs/matmul.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl index 0230e4b30..3fb30b45a 100644 --- a/lib/mpsgraphs/matmul.jl +++ b/lib/mpsgraphs/matmul.jl @@ -33,9 +33,9 @@ end #= MPSGraph caching infrastructure. -Creating an MPSGraph takes ~2ms per call, which dominates matmul time for small-medium -matrices. By caching graphs keyed by their structural parameters (shapes, types, flags), -we achieve significant speedup for repeated operations with the same configuration. +The overhead of creating an MPSGraph dominates matmul time for small-medium matrices. +By caching graphs keyed by their structural parameters (shapes, types, flags), we +achieve significant speedup for repeated operations with the same configuration. The cache key includes all parameters that affect graph structure: - Input/output shapes and element types