Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 57 additions & 55 deletions src/RegisterMismatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ The major types and functions exported are:
RegisterMismatch

FFTW.set_num_threads(min(Sys.CPU_THREADS, 8))
set_FFTPROD([2,3])
set_FFTPROD([2, 3])

mutable struct NanCorrFFTs{T<:AbstractFloat,N,RCType<:RCpair{T,N}}
mutable struct NanCorrFFTs{T <: AbstractFloat, N, RCType <: RCpair{T, N}}
I0::RCType
I1::RCType
I2::RCType
Expand All @@ -79,13 +79,13 @@ Prepare for FFT-based mismatch computations over domains of size `aperture_width
mismatch up to shifts of size `maxshift`. The keyword arguments allow you to control the planning
process for the FFTs.
"""
mutable struct CMStorage{T<:AbstractFloat,N,RCType<:RCpair{T,N},FFT<:Function,IFFT<:Function}
mutable struct CMStorage{T <: AbstractFloat, N, RCType <: RCpair{T, N}, FFT <: Function, IFFT <: Function}
aperture_width::Vector{Float64}
maxshift::Vector{Int}
getindices::Vector{UnitRange{Int}} # indices for pulling padded data, in source-coordinates
padded::Array{T,N}
fixed::NanCorrFFTs{T,N,RCType}
moving::NanCorrFFTs{T,N,RCType}
padded::Array{T, N}
fixed::NanCorrFFTs{T, N, RCType}
moving::NanCorrFFTs{T, N, RCType}
buf1::RCType
buf2::RCType
# the next two store the result of calling plan_fft! and plan_ifft!
Expand All @@ -94,14 +94,14 @@ mutable struct CMStorage{T<:AbstractFloat,N,RCType<:RCpair{T,N},FFT<:Function,IF
shiftindices::Vector{Vector{Int}} # indices for performing fftshift & snipping from -maxshift:maxshift
end

function CMStorage{T,N}(::UndefInitializer, aperture_width::NTuple{N,<:Real}, maxshift::Dims{N}; flags=FFTW.ESTIMATE, timelimit=Inf, display=true) where {T,N}
blocksize = map(x->ceil(Int,x), aperture_width)
function CMStorage{T, N}(::UndefInitializer, aperture_width::NTuple{N, <:Real}, maxshift::Dims{N}; flags = FFTW.ESTIMATE, timelimit = Inf, display = true) where {T, N}
blocksize = map(x -> ceil(Int, x), aperture_width)
padsz = padsize(blocksize, maxshift)
padded = Array{T}(undef, padsz)
getindices = padranges(blocksize, maxshift)
maxshiftv = [maxshift...]
region = findall(maxshiftv .> 0)
fixed = NanCorrFFTs(RCpair{T}(undef, padsz, region), RCpair{T}(undef, padsz, region), RCpair{T}(undef, padsz, region))
fixed = NanCorrFFTs(RCpair{T}(undef, padsz, region), RCpair{T}(undef, padsz, region), RCpair{T}(undef, padsz, region))
moving = NanCorrFFTs(RCpair{T}(undef, padsz, region), RCpair{T}(undef, padsz, region), RCpair{T}(undef, padsz, region))
buf1 = RCpair{T}(undef, padsz, region)
buf2 = RCpair{T}(undef, padsz, region)
Expand All @@ -111,21 +111,21 @@ function CMStorage{T,N}(::UndefInitializer, aperture_width::NTuple{N,<:Real}, ma
flush(stdout)
tcalib = time()
end
fftfunc = plan_rfft!(fixed.I0, flags=flags, timelimit=timelimit/2)
ifftfunc = plan_irfft!(fixed.I0, flags=flags, timelimit=timelimit/2)
fftfunc = plan_rfft!(fixed.I0, flags = flags, timelimit = timelimit / 2)
ifftfunc = plan_irfft!(fixed.I0, flags = flags, timelimit = timelimit / 2)
if display && flags != FFTW.ESTIMATE
dt = time()-tcalib
dt = time() - tcalib
@printf("done (%.2f seconds)\n", dt)
end
shiftindices = Vector{Int}[ [size(padded,i).+(-maxshift[i]+1:0); 1:maxshift[i]+1] for i = 1:length(maxshift) ]
CMStorage{T,N,typeof(buf1),typeof(fftfunc),typeof(ifftfunc)}(Float64[aperture_width...], maxshiftv, getindices, padded, fixed, moving, buf1, buf2, fftfunc, ifftfunc, shiftindices)
shiftindices = Vector{Int}[ [size(padded, i) .+ ((-maxshift[i] + 1):0); 1:(maxshift[i] + 1)] for i in 1:length(maxshift) ]
return CMStorage{T, N, typeof(buf1), typeof(fftfunc), typeof(ifftfunc)}(Float64[aperture_width...], maxshiftv, getindices, padded, fixed, moving, buf1, buf2, fftfunc, ifftfunc, shiftindices)
end

CMStorage{T}(::UndefInitializer, aperture_width::NTuple{N,<:Real}, maxshift::Dims{N}; kwargs...) where {T<:Real,N} =
CMStorage{T,N}(undef, aperture_width, maxshift; kwargs...)
CMStorage{T}(::UndefInitializer, aperture_width::NTuple{N, <:Real}, maxshift::Dims{N}; kwargs...) where {T <: Real, N} =
CMStorage{T, N}(undef, aperture_width, maxshift; kwargs...)

eltype(cms::CMStorage{T,N}) where {T,N} = T
ndims(cms::CMStorage{T,N}) where {T,N} = N
eltype(cms::CMStorage{T, N}) where {T, N} = T
ndims(cms::CMStorage{T, N}) where {T, N} = N

"""
mm = mismatch([T], fixed, moving, maxshift; normalization=:intensity)
Expand All @@ -139,14 +139,14 @@ normalization scheme (`:intensity` or `:pixels`).
`fixed` and `moving` must have the same size; you can pad with
`NaN`s as needed. See `nanpad`.
"""
function mismatch(::Type{T}, fixed::AbstractArray, moving::AbstractArray, maxshift::DimsLike; normalization = :intensity) where T<:Real
function mismatch(::Type{T}, fixed::AbstractArray, moving::AbstractArray, maxshift::DimsLike; normalization = :intensity) where {T <: Real}
msz = 2 .* maxshift .+ 1
mm = MismatchArray(T, msz...)
cms = CMStorage{T}(undef, size(fixed), maxshift)
fillfixed!(cms, fixed)
erng = shiftrange.((cms.getindices...,), first.(axes(fixed)) .- 1) # expanded rng
mpad = PaddedView(convert(T, NaN), of_eltype(T, moving), erng)
mismatch!(mm, cms, mpad, normalization=normalization)
mismatch!(mm, cms, mpad, normalization = normalization)
return mm
end

Expand All @@ -170,28 +170,28 @@ function mismatch!(mm::MismatchArray, cms::CMStorage, moving::AbstractArray; nor
m0 = complex(cms.moving.I0)
m1 = complex(cms.moving.I1)
m2 = complex(cms.moving.I2)
tnum = complex(cms.buf1)
tnum = complex(cms.buf1)
tdenom = complex(cms.buf2)
if normalization == :intensity
@inbounds @maybe_threads for i in eachindex(tnum)
c = 2*conj(f1[i])*m1[i]
q = conj(f2[i])*m0[i] + conj(f0[i])*m2[i]
c = 2 * conj(f1[i]) * m1[i]
q = conj(f2[i]) * m0[i] + conj(f0[i]) * m2[i]
tdenom[i] = q
tnum[i] = q - c
end
elseif normalization == :pixels
@inbounds @maybe_threads for i in eachindex(tnum)
f0i, m0i = f0[i], m0[i]
tdenom[i] = conj(f0i)*m0i
tnum[i] = conj(f2[i])*m0i - 2*conj(f1[i])*m1[i] + conj(f0i)*m2[i]
tdenom[i] = conj(f0i) * m0i
tnum[i] = conj(f2[i]) * m0i - 2 * conj(f1[i]) * m1[i] + conj(f0i) * m2[i]
end
else
error("normalization $normalization not recognized")
end
cms.ifftfunc!(cms.buf1)
cms.ifftfunc!(cms.buf2)
copyto!(mm, (view(real(cms.buf1), cms.shiftindices...), view(real(cms.buf2), cms.shiftindices...)))
mm
return mm
end

"""
Expand Down Expand Up @@ -219,20 +219,22 @@ in a rectangular grid, you can use an `N`-dimensional array-of-tuples
(or array-of-vectors) or an `N+1`-dimensional array with the center
positions specified along the first dimension. See `aperture_grid`.
"""
function mismatch_apertures(::Type{T},
fixed::AbstractArray,
moving::AbstractArray,
aperture_centers::AbstractArray,
aperture_width::WidthLike,
maxshift::DimsLike;
normalization = :pixels,
flags = FFTW.MEASURE,
kwargs...) where T
function mismatch_apertures(
::Type{T},
fixed::AbstractArray,
moving::AbstractArray,
aperture_centers::AbstractArray,
aperture_width::WidthLike,
maxshift::DimsLike;
normalization = :pixels,
flags = FFTW.MEASURE,
kwargs...
) where {T}
nd = sdims(fixed)
(length(aperture_width) == nd && length(maxshift) == nd) || error("Dimensionality mismatch")
mms = allocate_mmarrays(T, aperture_centers, maxshift)
cms = CMStorage{T}(undef, aperture_width, maxshift; flags=flags, kwargs...)
mismatch_apertures!(mms, fixed, moving, aperture_centers, cms, normalization=normalization)
cms = CMStorage{T}(undef, aperture_width, maxshift; flags = flags, kwargs...)
return mismatch_apertures!(mms, fixed, moving, aperture_centers, cms, normalization = normalization)
end

"""
Expand All @@ -244,51 +246,51 @@ in `cms`, a `CMStorage` object. The results are stored in `mms`, an
Array-of-MismatchArrays which must have length equal to the number of
aperture centers.
"""
function mismatch_apertures!(mms, fixed, moving, aperture_centers, cms::CMStorage{T}; normalization=:pixels) where T
function mismatch_apertures!(mms, fixed, moving, aperture_centers, cms::CMStorage{T}; normalization = :pixels) where {T}
N = ndims(cms)
fillvalue = convert(T, NaN)
getinds = (cms.getindices...,)::NTuple{ndims(fixed),UnitRange{Int}}
getinds = (cms.getindices...,)::NTuple{ndims(fixed), UnitRange{Int}}
fixedT, movingT = of_eltype(T, fixed), of_eltype(T, moving)
for (mm,center) in zip(mms, each_point(aperture_centers))
for (mm, center) in zip(mms, each_point(aperture_centers))
rng = aperture_range(center, cms.aperture_width)
fsnip = PaddedView(fillvalue, fixedT, rng)
erng = shiftrange.(getinds, first.(rng) .- 1) # expanded rng
msnip = PaddedView(fillvalue, movingT, erng)
# Perform the calculation
fillfixed!(cms, fsnip)
mismatch!(mm, cms, msnip; normalization=normalization)
mismatch!(mm, cms, msnip; normalization = normalization)
end
mms
return mms
end

# Calculate the components needed to "nancorrelate"
function fftnan!(out::NanCorrFFTs{T}, A::AbstractArray{T}, fftfunc!::Function) where T<:Real
function fftnan!(out::NanCorrFFTs{T}, A::AbstractArray{T}, fftfunc!::Function) where {T <: Real}
I0 = real(out.I0)
I1 = real(out.I1)
I2 = real(out.I2)
_fftnan!(parent(I0), parent(I1), parent(I2), A)
fftfunc!(out.I0)
fftfunc!(out.I1)
fftfunc!(out.I2)
out
return out
end

function _fftnan!(I0, I1, I2, A::AbstractArray{T}) where T<:Real
@inbounds @maybe_threads for i in CartesianIndices(size(A))
function _fftnan!(I0, I1, I2, A::AbstractArray{T}) where {T <: Real}
return @inbounds @maybe_threads for i in CartesianIndices(size(A))
a = A[i]
f = !isnan(a)
I0[i] = f
af = f ? a : zero(T)
I1[i] = af
I2[i] = af*af
I2[i] = af * af
end
end

function fillfixed!(cms::CMStorage{T}, fixed::AbstractArray) where T
function fillfixed!(cms::CMStorage{T}, fixed::AbstractArray) where {T}
fill!(cms.padded, NaN)
pinds = CartesianIndices(ntuple(d->(1:size(fixed,d)).+cms.maxshift[d], ndims(fixed)))
pinds = CartesianIndices(ntuple(d -> (1:size(fixed, d)) .+ cms.maxshift[d], ndims(fixed)))
copyto!(cms.padded, pinds, fixed, CartesianIndices(fixed))
fftnan!(cms.fixed, cms.padded, cms.fftfunc!)
return fftnan!(cms.fixed, cms.padded, cms.fftfunc!)
end

#### Utilities
Expand All @@ -298,23 +300,23 @@ function sumsq_finite(A)
s = 0.0
for a in A
if isfinite(a)
s += a*a
s += a * a
end
end
if s == 0
error("No finite values available")
end
s
return s
end

### Deprecations

function CMStorage{T}(::UndefInitializer, aperture_width::WidthLike, maxshift::WidthLike; kwargs...) where {T<:Real}
function CMStorage{T}(::UndefInitializer, aperture_width::WidthLike, maxshift::WidthLike; kwargs...) where {T <: Real}
Base.depwarn("CMStorage with aperture_width::$(typeof(aperture_width)) and maxshift::$(typeof(maxshift)) is deprecated, use tuples instead", :CMStorage)
(N = length(aperture_width)) == length(maxshift) || error("Dimensionality mismatch")
return CMStorage{T,N}(undef, (aperture_width...,), (maxshift...,); kwargs...)
return CMStorage{T, N}(undef, (aperture_width...,), (maxshift...,); kwargs...)
end

@deprecate CMStorage(::Type{T}, aperture_width, maxshift; kwargs...) where T CMStorage{T}(undef, aperture_width, maxshift; kwargs...)
@deprecate CMStorage(::Type{T}, aperture_width, maxshift; kwargs...) where {T} CMStorage{T}(undef, aperture_width, maxshift; kwargs...)

end
Loading
Loading