Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions src/RegisterDriver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ function driver(outfile::AbstractString, algorithms::Vector, img, mon::Vector)
nummon == nalgs || error("Number of monitors must equal number of workers")
usethreads = nummon > 2
numthreads = nthreads()
tpool = map(alg->alg.workertid, algorithms)
aindices = usethreads ? Dict(map((alg,aidx)->(alg.workertid=>aidx), algorithms, 1:length(algorithms))...) :
Dict(threadid()=>1)
tpool = map(alg -> alg.workertid, algorithms)
aindices = usethreads ? Dict(map((alg, aidx) -> (alg.workertid => aidx), algorithms, 1:length(algorithms))...) :
Dict(threadid() => 1)
n = nimages(img)
fs = FormatSpec("0$(ndigits(n))d")

Expand All @@ -67,12 +67,12 @@ function driver(outfile::AbstractString, algorithms::Vector, img, mon::Vector)

println("Working on algorithm and saving the result")
jldopen(outfile, "w") do file
dsets = Dict{Symbol,Any}()
dsets = Dict{Symbol, Any}()
firstsave = Ref(true)
have_unpackable = Ref(false)

# Channel for passing results from threads to writer
results_ch = Channel{Tuple{Int,Dict}}(32)
results_ch = Channel{Tuple{Int, Dict}}(32)

# Writer task (runs on main thread)
writer_task = @async begin
Expand All @@ -87,13 +87,13 @@ function driver(outfile::AbstractString, algorithms::Vector, img, mon::Vector)
g = have_unpackable[] ? file[string("stack", fmt(fs, movidx))] : nothing

# Write all values into the file
for (k,v) in monres
for (k, v) in monres
if isa(v, Number)
dsets[k][movidx] = v
elseif isa(v, Array) || isa(v, SharedArray)
vw = nicehdf5(v)
if eltype(vw) <: BitsType
colons = [Colon() for _=1:ndims(vw)]
colons = [Colon() for _ in 1:ndims(vw)]
dsets[k][colons..., movidx] = vw
else
g[string(k)] = v
Expand Down Expand Up @@ -148,15 +148,15 @@ function driver(algorithm::AbstractWorker, img, mon::Dict)
init!(algorithm)
worker(algorithm, img, 1, mon)
close!(algorithm)
mon
return mon
end

# Initialize the datasets in the output JLD file.
# We wait to do this until we get back one valid `mon` object,
# to get the sizes of any returned arrays.
function initialize_jld!(dsets, file, mon, fs, n)
have_unpackable = false
for (k,v) in mon
for (k, v) in mon
kstr = string(k)
if isa(v, Number)
write(file, kstr, Vector{typeof(v)}(undef, n))
Expand All @@ -178,19 +178,19 @@ function initialize_jld!(dsets, file, mon, fs, n)
end
end
if have_unpackable
for i = 1:n
for i in 1:n
create_group(file, string("stack", fmt(fs, i)))
end
end
have_unpackable
return have_unpackable
end

function nicehdf5(v::Union{Array{T},SharedArray{T}}) where T<:StaticArray
nicehdf5(reshape(reinterpret(eltype(T), vec(sdata(v))), (size(eltype(v))..., size(v)...)))
function nicehdf5(v::Union{Array{T}, SharedArray{T}}) where {T <: StaticArray}
return nicehdf5(reshape(reinterpret(eltype(T), vec(sdata(v))), (size(eltype(v))..., size(v)...)))
end

function nicehdf5(v::Union{Array{T},SharedArray{T}}) where T<:NumDenom
nicehdf5(reshape(reinterpret(eltype(T), vec(sdata(v))), (2, size(v)...)))
function nicehdf5(v::Union{Array{T}, SharedArray{T}}) where {T <: NumDenom}
return nicehdf5(reshape(reinterpret(eltype(T), vec(sdata(v))), (2, size(v)...)))
end

nicehdf5(v::SharedArray) = sdata(v)
Expand All @@ -202,28 +202,28 @@ function copy_all_but_shared!(dest, src)
dest[k] = v
end
end
dest
return dest
end

mm_package_loader(algorithms::Vector{W}) where {W<:AbstractWorker} = mm_package_loader(algorithms[1])
mm_package_loader(algorithms::Vector{W}) where {W <: AbstractWorker} = mm_package_loader(algorithms[1])
function mm_package_loader(algorithm::AbstractWorker)
load_mm_package(algorithm.dev)
nothing
return nothing
end

function threadids()
nt = nthreads()
ch = Channel{Int}(nt*1001)
ch = Channel{Int}(nt * 1001)

@threads for i in 1:nt
put!(ch, threadid())
end
@sync for i in 1:(nt*1000)
@sync for i in 1:(nt * 1000)
Threads.@spawn put!(ch, threadid())
end

close(ch)
tids = unique(collect(ch))
sort(tids)
return sort(tids)
end
end # module
24 changes: 12 additions & 12 deletions test/WorkerDummy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,53 +11,53 @@ export Alg1, Alg2, Alg3
# with the driver process
abstract type Alg <: AbstractWorker end

mutable struct Alg1{A<:AbstractArray} <: Alg
mutable struct Alg1{A <: AbstractArray} <: Alg
fixed::A
λ::Float64
workertid::Int
end
function Alg1(fixed, λ; tid=1)
Alg1(fixed, λ, tid)
function Alg1(fixed, λ; tid = 1)
return Alg1(fixed, λ, tid)
end

mutable struct Alg2{A<:AbstractArray,V<:AbstractVector,M<:AbstractMatrix} <: Alg
mutable struct Alg2{A <: AbstractArray, V <: AbstractVector, M <: AbstractMatrix} <: Alg
fixed::A
tform::V
u0::M
workertid::Int
end
function Alg2(fixed, ::Type{T}, sz; tid=1) where T
Alg2(fixed, Vector{T}(undef,12), Matrix{T}(undef, sz), tid)
function Alg2(fixed, ::Type{T}, sz; tid = 1) where {T}
return Alg2(fixed, Vector{T}(undef, 12), Matrix{T}(undef, sz), tid)
end

mutable struct Alg3 <: Alg
string::String
workertid::Int
end
function Alg3(s::String; tid=1)
Alg3(s, tid)
function Alg3(s::String; tid = 1)
return Alg3(s, tid)
end

# Here are the "registration algorithms"
function worker(algorithm::Alg1, moving, tindex, mon)
algorithm.λ = tindex
monitor!(mon, algorithm) # just dump output
return monitor!(mon, algorithm) # just dump output
end

function worker(algorithm::Alg2, moving, tindex, mon)
# Do stuff to set tform
tform = range(1, stop=12, length=12).+tindex
tform = range(1, stop = 12, length = 12) .+ tindex
monitor!(mon, :tform, tform)
# Do more computations...
monitor!(mon, :u0, zeros(size(algorithm.u0)).-tindex)
return monitor!(mon, :u0, zeros(size(algorithm.u0)) .- tindex)
end

function worker(algorithm::Alg3, moving, tindex, mon)
monitor!(mon, algorithm)
if haskey(mon, :extra)
mon[:extra] = "world"
end
mon
return mon
end

end # module
110 changes: 56 additions & 54 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,63 +7,65 @@ using Base.Threads
push!(LOAD_PATH, pwd())
using WorkerDummy

workdir = tempname()
mkdir(workdir)
@testset "RegisterDriver" begin
workdir = tempname()
mkdir(workdir)

img = AxisArray(SharedArray{Float32}((100,100,7)), :y, :x, :time)
img = AxisArray(SharedArray{Float32}((100, 100, 7)), :y, :x, :time)

# Single-process tests : mon::Dict
# Simple operation & passing back scalars
alg = Alg1(rand(3,3), 3.2)
mon = monitor(alg, (:λ,))
fn = joinpath(workdir, "file1.jld")
driver(fn, alg, img, mon)
λ = JLD.load(fn, "λ")
@test λ == Float64[1,2,3,4,5,6,7]
rm(fn)
# Single-process tests : mon::Dict
# Simple operation & passing back scalars
alg = Alg1(rand(3, 3), 3.2)
mon = monitor(alg, (:λ,))
fn = joinpath(workdir, "file1.jld")
driver(fn, alg, img, mon)
λ = JLD.load(fn, "λ")
@test λ == Float64[1, 2, 3, 4, 5, 6, 7]
rm(fn)

# Passing back arrays
alg = Alg2(rand(100,100), Float32, (3,3))
mon = monitor(alg, (:tform,:u0))
fn = joinpath(workdir, "file2.jld")
driver(fn, alg, img, mon)
tform = JLD.load(fn, "tform")
u0 = JLD.load(fn, "u0")
@test tform[:,4] == collect(range(1, stop=12, length=12).+4)
@test u0[:,:,2] == fill(-2,(3,3))
rm(fn)
# Passing back arrays
alg = Alg2(rand(100, 100), Float32, (3, 3))
mon = monitor(alg, (:tform, :u0))
fn = joinpath(workdir, "file2.jld")
driver(fn, alg, img, mon)
tform = JLD.load(fn, "tform")
u0 = JLD.load(fn, "u0")
@test tform[:, 4] == collect(range(1, stop = 12, length = 12) .+ 4)
@test u0[:, :, 2] == fill(-2, (3, 3))
rm(fn)

# Passing back strings. Anything not "packable" ends up in a group,
# one per stack.
alg = Alg3("Hello")
mon = monitor(alg, (:string,))
@show typeof(mon)
mon[:extra] = ""
fn = joinpath(workdir, "file3.jld")
driver(fn, alg, img, mon)
jldopen(fn) do file
g = file["stack5"]
@test read(g, "string") == "Hello"
@test read(g, "extra") == "world"
end
rm(fn)
# Passing back strings. Anything not "packable" ends up in a group,
# one per stack.
alg = Alg3("Hello")
mon = monitor(alg, (:string,))
@show typeof(mon)
mon[:extra] = ""
fn = joinpath(workdir, "file3.jld")
driver(fn, alg, img, mon)
jldopen(fn) do file
g = file["stack5"]
@test read(g, "string") == "Hello"
@test read(g, "extra") == "world"
end
rm(fn)

# Multi-thread : mon::Vector{Dict}
tids = threadids()
nt = length(tids)
alg = Vector{Any}(undef, nt)
mon = Vector{Any}(undef, nt)
for i = 1:nt
alg[i] = Alg2(rand(100,100), Float32, (3,3), tid=tids[i])
mon[i] = monitor(alg[i], (:tform,:u0,:workertid))
# Multi-thread : mon::Vector{Dict}
tids = threadids()
nt = length(tids)
alg = Vector{Any}(undef, nt)
mon = Vector{Any}(undef, nt)
for i in 1:nt
alg[i] = Alg2(rand(100, 100), Float32, (3, 3), tid = tids[i])
mon[i] = monitor(alg[i], (:tform, :u0, :workertid))
end
fn = joinpath(workdir, "file4.jld")
driver(fn, alg, img, mon)
tform = JLD.load(fn, "tform")
u0 = JLD.load(fn, "u0")
@test tform[:, 4] == collect(range(1, stop = 12, length = 12) .+ 4)
@test u0[:, :, 2] == fill(-2, (3, 3))
tid = JLD.load(fn, "workertid")
indx = unique(indexin(tid, tids))
@test length(indx) == length(tids) && all(indx .> 0)
rm(fn)
end
fn = joinpath(workdir, "file4.jld")
driver(fn, alg, img, mon)
tform = JLD.load(fn, "tform")
u0 = JLD.load(fn, "u0")
@test tform[:,4] == collect(range(1, stop=12, length=12).+4)
@test u0[:,:,2] == fill(-2,(3,3))
tid = JLD.load(fn, "workertid")
indx = unique(indexin(tid, tids))
@test length(indx) == length(tids) && all(indx .> 0)
rm(fn)
Loading