Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 32 additions & 32 deletions src/CenterIndexedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,86 +14,86 @@ A `CenterIndexedArray` is one for which the array center has indexes
CenterIndexedArray(A) "converts" `A` into a CenterIndexedArray. All
the sizes of `A` must be odd.
"""
struct CenterIndexedArray{T,N,A<:AbstractArray} <: AbstractArray{T,N}
struct CenterIndexedArray{T, N, A <: AbstractArray} <: AbstractArray{T, N}
data::A
halfsize::NTuple{N,Int}
halfsize::NTuple{N, Int}

function CenterIndexedArray{T,N,A}(data::A) where {T,N,A<:AbstractArray}
new{T,N,A}(data, _halfsize(data))
function CenterIndexedArray{T, N, A}(data::A) where {T, N, A <: AbstractArray}
return new{T, N, A}(data, _halfsize(data))
end
end

CenterIndexedArray(A::AbstractArray{T,N}) where {T,N} = CenterIndexedArray{T,N,typeof(A)}(A)
CenterIndexedArray{T,N}(::UndefInitializer, sz::Vararg{Integer,N}) where {T,N} =
CenterIndexedArray(Array{T,N}(undef, sz...))
CenterIndexedArray{T,N}(::UndefInitializer, sz::NTuple{N,Integer}) where {T,N} =
CenterIndexedArray(Array{T,N}(undef, sz))
CenterIndexedArray{T}(::UndefInitializer, sz::Vararg{Integer,N}) where {T,N} =
CenterIndexedArray{T,N}(undef, sz...)
CenterIndexedArray{T}(::UndefInitializer, sz::NTuple{N,Integer}) where {T,N} =
CenterIndexedArray{T,N}(undef, sz)
CenterIndexedArray(A::AbstractArray{T, N}) where {T, N} = CenterIndexedArray{T, N, typeof(A)}(A)
CenterIndexedArray{T, N}(::UndefInitializer, sz::Vararg{Integer, N}) where {T, N} =
CenterIndexedArray(Array{T, N}(undef, sz...))
CenterIndexedArray{T, N}(::UndefInitializer, sz::NTuple{N, Integer}) where {T, N} =
CenterIndexedArray(Array{T, N}(undef, sz))
CenterIndexedArray{T}(::UndefInitializer, sz::Vararg{Integer, N}) where {T, N} =
CenterIndexedArray{T, N}(undef, sz...)
CenterIndexedArray{T}(::UndefInitializer, sz::NTuple{N, Integer}) where {T, N} =
CenterIndexedArray{T, N}(undef, sz)

# This is the AbstractArray default, but do this just to be sure
Base.IndexStyle(::Type{A}) where {A<:CenterIndexedArray} = IndexCartesian()
Base.IndexStyle(::Type{A}) where {A <: CenterIndexedArray} = IndexCartesian()

Base.size(A::CenterIndexedArray) = size(A.data)
Base.axes(A::CenterIndexedArray) = map(SymRange, A.halfsize)

const SymAx = Union{SymRange, Base.Slice{SymRange}}
Base.axes(r::Base.Slice{SymRange}) = (r.indices,)

function Base.similar(A::CenterIndexedArray, ::Type{T}, inds::Tuple{SymAx,Vararg{SymAx}}) where T
function Base.similar(A::CenterIndexedArray, ::Type{T}, inds::Tuple{SymAx, Vararg{SymAx}}) where {T}
data = Array{T}(undef, map(length, inds))
CenterIndexedArray(data)
return CenterIndexedArray(data)
end
function Base.similar(::Type{T}, inds::Tuple{SymAx, Vararg{SymAx}}) where T<:AbstractArray
function Base.similar(::Type{T}, inds::Tuple{SymAx, Vararg{SymAx}}) where {T <: AbstractArray}
data = Array{eltype(T)}(undef, map(length, inds))
CenterIndexedArray(data)
return CenterIndexedArray(data)
end

# This is incomplete: ideally we wouldn't need SymAx in the first slot
# as long as there was at least one SymAx.
function Base.similar(A::CenterIndexedArray, ::Type{T}, inds::Tuple{SymAx,Vararg{Union{Int,<:IdentityUnitRange,SymAx}}}) where T
function Base.similar(A::CenterIndexedArray, ::Type{T}, inds::Tuple{SymAx, Vararg{Union{Int, <:IdentityUnitRange, SymAx}}}) where {T}
torange(n) = isa(n, Int) ? Base.OneTo(n) : n
return OffsetArray{T}(undef, map(torange, inds))
end


function _halfsize(A::AbstractArray)
all(isodd, size(A)) || error("Must have all-odd sizes")
map(n->n>>UInt(1), size(A))
return map(n -> n >> UInt(1), size(A))
end

@inline function Base.getindex(A::CenterIndexedArray{T,N}, i::Vararg{Int,N}) where {T,N}
@inline function Base.getindex(A::CenterIndexedArray{T, N}, i::Vararg{Int, N}) where {T, N}
@boundscheck checkbounds(A, i...)
@inbounds val = A.data[map(offset, A.halfsize, i)...]
val
return val
end

Base.@propagate_inbounds Base.getindex(A::CenterIndexedArray{T,N,I}, i::Vararg{Int,N}) where {T,N,I<:AbstractInterpolation} =
Base.@propagate_inbounds Base.getindex(A::CenterIndexedArray{T, N, I}, i::Vararg{Int, N}) where {T, N, I <: AbstractInterpolation} =
_getindex(A, i...)
Base.@propagate_inbounds Base.getindex(A::CenterIndexedArray{T,N,I}, i::Vararg{Number,N}) where {T,N,I<:AbstractInterpolation} =
Base.@propagate_inbounds Base.getindex(A::CenterIndexedArray{T, N, I}, i::Vararg{Number, N}) where {T, N, I <: AbstractInterpolation} =
_getindex(A, i...)

@inline function _getindex(A::CenterIndexedArray{T,N,I}, i::Vararg{Number,N}) where {T,N,I<:AbstractInterpolation}
@inline function _getindex(A::CenterIndexedArray{T, N, I}, i::Vararg{Number, N}) where {T, N, I <: AbstractInterpolation}
@boundscheck checkbounds(A, i...)
@inbounds val = A.data(map(offset, A.halfsize, i)...)
val
return val
end
Base.throw_boundserror(A::CenterIndexedArray, I) = (Base.@_noinline_meta; throw(BoundsError(A, I)))

offset(off, i) = off+i+1
offset(off, i) = off + i + 1

@inline function Base.setindex!(A::CenterIndexedArray{T,N}, v, i::Vararg{Int,N}) where {T,N}
@inline function Base.setindex!(A::CenterIndexedArray{T, N}, v, i::Vararg{Int, N}) where {T, N}
@boundscheck checkbounds(A, i...)
@inbounds A.data[map(offset, A.halfsize, i)...] = v
v
return v
end


Base.BroadcastStyle(::Type{<:CenterIndexedArray}) = Broadcast.ArrayStyle{CenterIndexedArray}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{CenterIndexedArray}}, ::Type{ElType}) where ElType
similar(Array{ElType}, axes(bc))
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{CenterIndexedArray}}, ::Type{ElType}) where {ElType}
return similar(Array{ElType}, axes(bc))
end

Base.parent(A::CenterIndexedArray) = A.data
Expand All @@ -102,7 +102,7 @@ function Base.showarg(io::IO, A::CenterIndexedArray, toplevel)
print(io, "CenterIndexedArray(")
Base.showarg(io, parent(A), false)
print(io, ')')
toplevel && print(io, " with eltype ", eltype(A))
return toplevel && print(io, " with eltype ", eltype(A))
end

include("deprecated.jl")
Expand Down
10 changes: 5 additions & 5 deletions src/symrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ SymRange(n::Integer) = SymRange(Int(n))

function SymRange(r::AbstractUnitRange)
first(r) == -last(r) || error("cannot convert $r to a SymRange")
SymRange(last(r))
return SymRange(last(r))
end

Base.first(r::SymRange) = -r.n
Expand All @@ -19,12 +19,12 @@ Base.axes(r::SymRange) = (r,)

function iterate(r::SymRange)
r.n == 0 && return nothing
first(r), first(r)
return first(r), first(r)
end

function iterate(r::SymRange, s)
s == last(r) && return nothing
copy(s+1), s+1
return copy(s + 1), s + 1
end

@inline function Base.getindex(v::SymRange, i::Int)
Expand All @@ -36,7 +36,7 @@ Base.intersect(r::SymRange, s::SymRange) = SymRange(min(last(r), last(s)))

@inline function Base.getindex(r::SymRange, s::SymRange)
@boundscheck checkbounds(r, s)
s
return s
end

@inline function Base.getindex(r::SymRange, s::AbstractUnitRange{<:Integer})
Expand All @@ -46,7 +46,7 @@ end

# TODO: should we be worried about the mismatch in axes?
# And should `convert(SymRange, r)` fail if axes(r) isn't the same as the result?
Base.promote_rule(::Type{SymRange}, ::Type{UR}) where {UR<:AbstractUnitRange} =
Base.promote_rule(::Type{SymRange}, ::Type{UR}) where {UR <: AbstractUnitRange} =
UR
Base.promote_rule(::Type{UnitRange{T2}}, ::Type{SymRange}) where {T2} =
UnitRange{promote_type(T2, Int)}
Expand Down
68 changes: 34 additions & 34 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
@testset "SymRange" begin
r = SymRange(3)
@test first(r) == -3
@test last(r) == 3
@test last(r) == 3
@test axes(r) == (r,)
@test axes(r, 1) == r
@test size(r) == (7,)
Expand Down Expand Up @@ -49,15 +49,15 @@ end
end

@testset "Uninitialized" begin
@test isa(CenterIndexedArray{Float32,2}(undef, 3, 5), CenterIndexedArray)
@test isa(CenterIndexedArray{Float32,2}(undef, (3, 5)), CenterIndexedArray)
@test isa(CenterIndexedArray{Float32, 2}(undef, 3, 5), CenterIndexedArray)
@test isa(CenterIndexedArray{Float32, 2}(undef, (3, 5)), CenterIndexedArray)
@test isa(CenterIndexedArray{Float32}(undef, 3, 5), CenterIndexedArray)
@test isa(CenterIndexedArray{Float32}(undef, (3, 5)), CenterIndexedArray)
@test_throws ErrorException CenterIndexedArray{Float32}(undef, 4, 5)
end

@testset "Construction & traits" begin
dat = rand(3,5)
dat = rand(3, 5)
A = CenterIndexedArray(dat)
@test size(A) == size(dat)
@test axes(A) === (SymRange(1), SymRange(2))
Expand All @@ -67,29 +67,29 @@ end
end

@testset "Indexing & iteration" begin
dat = rand(3,5)
dat = rand(3, 5)
A = CenterIndexedArray(dat)
@test A[0,0] == dat[2,3]
@test A[0, 0] == dat[2, 3]
k = 0
for j = -2:2, i = -1:1
@test @inferred(A[i,j]) == dat[k+=1]
for j in -2:2, i in -1:1
@test @inferred(A[i, j]) == dat[k += 1]
end
@test_throws BoundsError A[3,5]
@test @inferred(A[:,:]) == A
@test @inferred(A[:,SymRange(1)]) == CenterIndexedArray(dat[:,2:4])
@test @inferred(A[SymRange(1),:]) == A
@test @inferred(A[:,-2:0]) == OffsetArray(dat[:,1:3], -1:1, 1:3) # axes-of-the-axes
@test @inferred(A[:,IdentityUnitRange(-2:0)]) == OffsetArray(dat[:,1:3], -1:1, -2:0)
@test_throws BoundsError A[3, 5]
@test @inferred(A[:, :]) == A
@test @inferred(A[:, SymRange(1)]) == CenterIndexedArray(dat[:, 2:4])
@test @inferred(A[SymRange(1), :]) == A
@test @inferred(A[:, -2:0]) == OffsetArray(dat[:, 1:3], -1:1, 1:3) # axes-of-the-axes
@test @inferred(A[:, IdentityUnitRange(-2:0)]) == OffsetArray(dat[:, 1:3], -1:1, -2:0)
k = 0
for j = -2:2, i = -1:1
A[i,j] = (k+=1)
for j in -2:2, i in -1:1
A[i, j] = (k += 1)
end
@test dat == reshape(1:15, 3, 5)
@test_throws BoundsError A[3,5] = 15
@test_throws BoundsError A[3, 5] = 15

rand!(dat)
iall = (-1:1).*ones(Int, 5)'
jall = ones(Int,3).*(-2:2)'
iall = (-1:1) .* ones(Int, 5)'
jall = ones(Int, 3) .* (-2:2)'
k = 0
for I in eachindex(A)
k += 1
Expand All @@ -112,7 +112,7 @@ end
end

@testset "Operations" begin
dat = rand(3,5)
dat = rand(3, 5)
A = CenterIndexedArray(dat)
# Standard julia operations
B = copy(A)
Expand All @@ -125,8 +125,8 @@ end

@test minimum(A) == minimum(dat)
@test maximum(A) == maximum(dat)
@test minimum(A,dims=1) == CenterIndexedArray(minimum(dat,dims=1))
@test maximum(A,dims=2) == CenterIndexedArray(maximum(dat,dims=2))
@test minimum(A, dims = 1) == CenterIndexedArray(minimum(dat, dims = 1))
@test maximum(A, dims = 2) == CenterIndexedArray(maximum(dat, dims = 2))

amin, iamin = findmin(A)
dmin, idmin = findmin(dat)
Expand All @@ -141,14 +141,14 @@ end
@test amax == dat[idmax]

fill!(A, 2)
@test all(x->x==2, A)
@test all(x -> x == 2, A)

ii, jj = begin
II = findall(!iszero, A)
(getindex.(II, 1), getindex.(II, 2))
end
iall = (-1:1).*ones(Int, 5)'
jall = ones(Int,3).*(-2:2)'
iall = (-1:1) .* ones(Int, 5)'
jall = ones(Int, 3) .* (-2:2)'
@test vec(ii) == vec(iall)
@test vec(jj) == vec(jall)

Expand All @@ -157,20 +157,20 @@ end
# @test cat(1, A, dat) == cat(1, dat, dat)
# @test cat(2, A, dat) == cat(2, dat, dat)

@test permutedims(A, (2,1)) == CenterIndexedArray(permutedims(dat, (2,1)))
@test permutedims(A, (2, 1)) == CenterIndexedArray(permutedims(dat, (2, 1)))
# @test ipermutedims(A, (2,1)) == CenterIndexedArray(ipermutedims(dat, (2,1)))

@test cumsum(A, dims=1) == CenterIndexedArray(cumsum(dat, dims=1))
@test cumsum(A, dims=2) == CenterIndexedArray(cumsum(dat, dims=2))
@test cumsum(A, dims = 1) == CenterIndexedArray(cumsum(dat, dims = 1))
@test cumsum(A, dims = 2) == CenterIndexedArray(cumsum(dat, dims = 2))

@test mapslices(v->sort(v), A, dims=1) == CenterIndexedArray(mapslices(v->sort(v), dat, dims=1))
@test mapslices(v->sort(v), A, dims=2) == CenterIndexedArray(mapslices(v->sort(v), dat, dims=2))
@test mapslices(v -> sort(v), A, dims = 1) == CenterIndexedArray(mapslices(v -> sort(v), dat, dims = 1))
@test mapslices(v -> sort(v), A, dims = 2) == CenterIndexedArray(mapslices(v -> sort(v), dat, dims = 2))

@test reverse(A, dims=1) == CenterIndexedArray(reverse(dat, dims=1))
@test reverse(A, dims=2) == CenterIndexedArray(reverse(dat, dims=2))
@test reverse(A, dims = 1) == CenterIndexedArray(reverse(dat, dims = 1))
@test reverse(A, dims = 2) == CenterIndexedArray(reverse(dat, dims = 2))

@test A .+ 1 == CenterIndexedArray(dat .+ 1)
@test 2*A == CenterIndexedArray(2*dat)
@test A+A == CenterIndexedArray(dat+dat)
@test 2 * A == CenterIndexedArray(2 * dat)
@test A + A == CenterIndexedArray(dat + dat)
@test isa(A .+ 1, CenterIndexedArray)
end
Loading