Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
*.jl.*.cov
*.jl.mem
deps/deps.jl
Manifest.toml
Manifest.toml
Manifest-v*.toml
56 changes: 28 additions & 28 deletions src/CachedInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ A single `CachedInterpolation` represents a "movable"
i_2)`. An array of these objects thus implements an array-of-arrays
interface. Create them with `cachedinterpolators`.
"""
mutable struct CachedInterpolation{T,N,M,O,K} <: AbstractInterpolation{T,N,BSpline{Quadratic{InPlace}}}
mutable struct CachedInterpolation{T, N, M, O, K} <: AbstractInterpolation{T, N, BSpline{Quadratic{InPlace}}}
# Note: M = N+K
coefs::Array{T,M} # tiled array of 3x3x... buffers
parent::Array{T,M} # the overall array (`P` in the documentation above)
center::NTuple{N,Int} # rounded (y_1, y_2) of prev. eval for this tile
coefs::Array{T, M} # tiled array of 3x3x... buffers
parent::Array{T, M} # the overall array (`P` in the documentation above)
center::NTuple{N, Int} # rounded (y_1, y_2) of prev. eval for this tile
tileindex::CartesianIndex{K}
end

const splitN = Base.IteratorsMD.split
Base.size(itp::CachedInterpolation{T,N}) where {T,N} = splitN(size(itp.parent), Val(N))[1]
Base.size(itp::CachedInterpolation{T,N}, d) where {T,N} = d <= N ? size(itp.parent, d) : 1
Base.size(itp::CachedInterpolation{T, N}) where {T, N} = splitN(size(itp.parent), Val(N))[1]
Base.size(itp::CachedInterpolation{T, N}, d) where {T, N} = d <= N ? size(itp.parent, d) : 1

Base.axes(itp::CachedInterpolation{T,N,M,O}) where {T,N,M,O} =
map((x,o)->x.-o, splitN(axes(itp.parent), Val(N))[1], O)
Base.axes(itp::CachedInterpolation{T,N,M,O}, d) where {T,N,M,O} =
Base.axes(itp::CachedInterpolation{T, N, M, O}) where {T, N, M, O} =
map((x, o) -> x .- o, splitN(axes(itp.parent), Val(N))[1], O)
Base.axes(itp::CachedInterpolation{T, N, M, O}, d) where {T, N, M, O} =
d <= N ? axes(itp.parent, d) .- O[d] : Base.OneTo(1)

"""
Expand All @@ -86,34 +86,34 @@ coordinates within a tile. For example, one can mimic a
`CenterIndexedArray` when `size(parent, d)` is odd for the first `N`
dimensions and you supply `origin = (div(size(parent,1)+1, 2), ...)`.
"""
function cachedinterpolators(parent::Array{T,M}, N, origin=ntuple(d->0,N)) where {T,M}
function cachedinterpolators(parent::Array{T, M}, N, origin = ntuple(d -> 0, N)) where {T, M}
0 <= N <= M || error("N must be between 0 and $M")
length(origin) == N || throw(DimensionMismatch("length(origin) = $(length(origin)) is inconsistent with $N interpolating dimensions"))
sz3 = ntuple(d->d<=N ? 3 : size(parent,d), Val(M))
sz3 = ntuple(d -> d <= N ? 3 : size(parent, d), Val(M))
buffer = Array{eltype(parent)}(undef, sz3)
sztiles = size(parent)[N+1:end] # the tiling dimensions of parent
sztiles = size(parent)[(N + 1):end] # the tiling dimensions of parent
# use an impossible initial value (post-offset by origin) to
# ensure the first access will result in a cache miss
center = ntuple(d->-1, N)
cachedinterpolators(buffer, parent, origin, center, sztiles)
center = ntuple(d -> -1, N)
return cachedinterpolators(buffer, parent, origin, center, sztiles)
end

# function-barriered to circumvent type-instability in sztiles
@noinline function cachedinterpolators(buffer::Array{T,M}, parent::Array{T,M}, origin::NTuple{N,Int}, center::NTuple{N,Int}, sztiles::NTuple{K,Int}) where {T,N,M,K}
itps = Array{CachedInterpolation{T,N,M,origin,K}}(undef, sztiles)
@noinline function cachedinterpolators(buffer::Array{T, M}, parent::Array{T, M}, origin::NTuple{N, Int}, center::NTuple{N, Int}, sztiles::NTuple{K, Int}) where {T, N, M, K}
itps = Array{CachedInterpolation{T, N, M, origin, K}}(undef, sztiles)
for tileindex in CartesianIndices(sztiles)
itps[tileindex] = CachedInterpolation{T,N,M,origin,K}(buffer, parent, center, tileindex)
itps[tileindex] = CachedInterpolation{T, N, M, origin, K}(buffer, parent, center, tileindex)
end
itps
return itps
end

@inline function (itp::CachedInterpolation{T,N,M,O,K})(xs::Vararg{Number,N}) where {T,N,M,O,K}
@inline function (itp::CachedInterpolation{T, N, M, O, K})(xs::Vararg{Number, N}) where {T, N, M, O, K}
coefs, parent, center, tileindex = itp.coefs, itp.parent, itp.center, itp.tileindex
ixs = round.(Int, xs)
fxs = xs .- ixs .+ 2
newcenter = ixs .+ O
sz3 = ntuple(d->3, Val(N))
itpinfo = (ntuple(d->BSpline(Quadratic(InPlace(OnCell()))), Val(N))..., ntuple(d->NoInterp(), Val(K))...)
sz3 = ntuple(d -> 3, Val(N))
itpinfo = (ntuple(d -> BSpline(Quadratic(InPlace(OnCell()))), Val(N))..., ntuple(d -> NoInterp(), Val(K))...)
if newcenter != center
# Copy the relevant portion from parent into buffer
offset = CartesianIndex(newcenter .- 2)
Expand All @@ -127,7 +127,7 @@ end
return icoefs[wis...]
end

struct CoefsWrapper{N,A}
struct CoefsWrapper{N, A}
coefs::A
end

Expand All @@ -138,27 +138,27 @@ Base.size(itp::CoefsWrapper{N}, d) where {N} = d <= N ? size(itp.coefs, d) : 1
# update the cache. This is equivalent to the assumption you've called
# getindex for the current `(x_1, x_2, ...)` location before calling
# gradient!. If this is not true, you'll get wrong answers.
@inline function Interpolations.gradient(itp::CachedInterpolation{T,N,M,O,K}, ys::Vararg{Number,N}) where {T,N,M,O,K}
@inline function Interpolations.gradient(itp::CachedInterpolation{T, N, M, O, K}, ys::Vararg{Number, N}) where {T, N, M, O, K}
coefs, tileindex = itp.coefs, itp.tileindex
xs = ys .- round.(Int, ys) .+ 2
itpinfo = (ntuple(d->BSpline(Quadratic(InPlace(OnCell()))), Val(N))..., ntuple(d->NoInterp(), Val(K))...)
itpinfo = (ntuple(d -> BSpline(Quadratic(InPlace(OnCell()))), Val(N))..., ntuple(d -> NoInterp(), Val(K))...)
wis = weightedindexes((value_weights, gradient_weights), itpinfo, axes(coefs), (xs..., Tuple(tileindex)...))
icoefs = InterpGetindex(coefs)
return SVector(map(inds->icoefs[inds...], wis))
return SVector(map(inds -> icoefs[inds...], wis))
end

@inline function Interpolations.gradient!(g::AbstractVector, itp::CachedInterpolation{T,N,M,O,K}, ys::Vararg{Number,N}) where {T,N,M,O,K}
@inline function Interpolations.gradient!(g::AbstractVector, itp::CachedInterpolation{T, N, M, O, K}, ys::Vararg{Number, N}) where {T, N, M, O, K}
gs = Interpolations.gradient(itp, ys...)
return copyto!(g, gs)
end

### Potential deprecations

# if AbstractInterpolation <: AbstractArray goes away, this can be deprecated
getindex(itp::CachedInterpolation{T,N,M,O,K}, xs::Vararg{Int,N}) where {T,N,M,O,K} = itp(xs...)
getindex(itp::CachedInterpolation{T, N, M, O, K}, xs::Vararg{Int, N}) where {T, N, M, O, K} = itp(xs...)

### Deprecations

@deprecate getindex(itp::CachedInterpolation{T,N,M,O,K}, xs::Vararg{Number,N}) where {T,N,M,O,K} itp(xs...)
@deprecate getindex(itp::CachedInterpolation{T, N, M, O, K}, xs::Vararg{Number, N}) where {T, N, M, O, K} itp(xs...)

end # module
66 changes: 34 additions & 32 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
import CachedInterpolations
using Interpolations, Test

A = reshape([0;1;0], (3,1))
C = CachedInterpolations.cachedinterpolators(A, 1)
@test C[1](2.2) ≈ 3/4-0.2^2
@test C[1](1.7) ≈ 3/4-0.3^2
@testset "CachedInterpolations" begin
A = reshape([0; 1; 0], (3, 1))
C = CachedInterpolations.cachedinterpolators(A, 1)
@test C[1](2.2) ≈ 3 / 4 - 0.2^2
@test C[1](1.7) ≈ 3 / 4 - 0.3^2

A = rand(7,7,2,2,3)
Q = BSpline(Quadratic(InPlace(OnCell())))
# Note the next line will modify A, and the modified A will be used for C.
# This is what we want to happen.
Ai = interpolate!(A, (Q, Q, NoInterp(), NoInterp(), NoInterp()))
C = CachedInterpolations.cachedinterpolators(A, 2)
@test size(C) == (2,2,3)
c = C[1,1,1]
@test size(c) == (7,7)
@test size(c,1) == 7
@test size(c,2) == 7
@test size(c,3) == 1
@test @inferred(C[1,2,2](3.2, 4.8)) == Ai(3.2,4.8,1,2,2)
@test C[1,2,2](3.2,4.9) == Ai(3.2,4.9,1,2,2)
@test C[1,2,2](3.2,3.8) == Ai(3.2,3.8,1,2,2)
A = rand(7, 7, 2, 2, 3)
Q = BSpline(Quadratic(InPlace(OnCell())))
# Note the next line will modify A, and the modified A will be used for C.
# This is what we want to happen.
Ai = interpolate!(A, (Q, Q, NoInterp(), NoInterp(), NoInterp()))
C = CachedInterpolations.cachedinterpolators(A, 2)
@test size(C) == (2, 2, 3)
c = C[1, 1, 1]
@test size(c) == (7, 7)
@test size(c, 1) == 7
@test size(c, 2) == 7
@test size(c, 3) == 1
@test @inferred(C[1, 2, 2](3.2, 4.8)) == Ai(3.2, 4.8, 1, 2, 2)
@test C[1, 2, 2](3.2, 4.9) == Ai(3.2, 4.9, 1, 2, 2)
@test C[1, 2, 2](3.2, 3.8) == Ai(3.2, 3.8, 1, 2, 2)

@test Interpolations.gradient(C[1,2,2], 3.2, 3.8) === Interpolations.gradient(Ai, 3.2, 3.8, 1, 2, 2)
@test Interpolations.gradient(C[1, 2, 2], 3.2, 3.8) === Interpolations.gradient(Ai, 3.2, 3.8, 1, 2, 2)

# With origin
C = CachedInterpolations.cachedinterpolators(A, 2, (4,4))
@test C[1,2,2](-0.8,0.8) ≈ Ai(3.2,4.8,1,2,2)
@test C[1,2,2](-0.8,0.9) ≈ Ai(3.2,4.9,1,2,2)
@test C[1,2,2](-0.8,-0.2) ≈ Ai(3.2,3.8,1,2,2)
@test Interpolations.gradient(C[1,2,2], -0.8, -0.2) ≈ Interpolations.gradient(Ai, 3.2, 3.8, 1, 2, 2)
# With origin
C = CachedInterpolations.cachedinterpolators(A, 2, (4, 4))
@test C[1, 2, 2](-0.8, 0.8) ≈ Ai(3.2, 4.8, 1, 2, 2)
@test C[1, 2, 2](-0.8, 0.9) ≈ Ai(3.2, 4.9, 1, 2, 2)
@test C[1, 2, 2](-0.8, -0.2) ≈ Ai(3.2, 3.8, 1, 2, 2)
@test Interpolations.gradient(C[1, 2, 2], -0.8, -0.2) ≈ Interpolations.gradient(Ai, 3.2, 3.8, 1, 2, 2)

# Check for Float32 with Float64 indexes, since that's the
# default mismatch case
A = rand(Float32,7,7,2,2,3)
Ai = interpolate!(A, (Q, Q, NoInterp(), NoInterp(), NoInterp()))
C = CachedInterpolations.cachedinterpolators(A, 2, (4,4))
@test @inferred(C[1,2,2](-0.8, 0.8)) ≈ Ai(3.2,4.8,1,2,2)
# Check for Float32 with Float64 indexes, since that's the
# default mismatch case
A = rand(Float32, 7, 7, 2, 2, 3)
Ai = interpolate!(A, (Q, Q, NoInterp(), NoInterp(), NoInterp()))
C = CachedInterpolations.cachedinterpolators(A, 2, (4, 4))
@test @inferred(C[1, 2, 2](-0.8, 0.8)) ≈ Ai(3.2, 4.8, 1, 2, 2)
end
Loading