Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.1'
- 'min'
- '1'
# - 'nightly'
os:
Expand All @@ -23,7 +22,7 @@ jobs:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RegisterDriver"
uuid = "935ac36e-2656-11e9-1e3b-cbaa636797af"
authors = ["Tim Holy <tim.holy@gmail.com>"]
version = "0.2.3"
version = "0.2.4"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -22,9 +22,9 @@ ImageCore = "0.8.1, 0.9, 0.10"
ImageMetadata = "0.9"
JLD = "0.9, 0.10, 0.11, 0.12, 0.13"
RegisterCore = "0.2"
RegisterWorkerShell = "0.2"
RegisterWorkerShell = "0.2, 1.0"
StaticArrays = "0.11, 0.12, 1"
julia = "1"
julia = "1.10"

[extras]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Expand Down
200 changes: 97 additions & 103 deletions src/RegisterDriver.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module RegisterDriver

using ImageCore, ImageMetadata, JLD, HDF5, StaticArrays, Formatting, SharedArrays, Distributed
using RegisterCore
using RegisterWorkerShell
using RegisterCore, RegisterWorkerShell
using Base.Threads

if isdefined(HDF5, :BitsType)
const BitsType = HDF5.BitsType
Expand All @@ -16,7 +16,7 @@ if !isdefined(HDF5, :create_group)
const create_group = g_create
end

export driver, mm_package_loader
export driver, mm_package_loader, threadids

"""
`driver(outfile, algorithm, img, mon)` performs registration of the
Expand Down Expand Up @@ -50,98 +50,91 @@ worker has been written to look for such settings:

which will save `extra` only if `:extra` is a key in `mon`.
"""
function driver(outfile::AbstractString, algorithm::Vector, img, mon::Vector)
nworkers = length(algorithm)
length(mon) == nworkers || error("Number of monitors must equal number of workers")
use_workerprocs = nworkers > 1 || workerpid(algorithm[1]) != myid()
rralgorithm = Array{RemoteChannel}(undef, nworkers)
if use_workerprocs
# Push the algorithm objects to the worker processes. This eliminates
# per-iteration serialization penalties, and ensures that any
# initialization state is retained.
for i = 1:nworkers
alg = algorithm[i]
rralgorithm[i] = put!(RemoteChannel(workerpid(alg)), alg)
end
# Perform any needed worker initialization
@sync for i = 1:nworkers
p = workerpid(algorithm[i])
@async remotecall_fetch(init!, p, rralgorithm[i])
end
else
init!(algorithm[1])
end
try
n = nimages(img)
fs = FormatSpec("0$(ndigits(n))d") # group names of unpackable objects
jldopen(outfile, "w") do file
dsets = Dict{Symbol,Any}()
firstsave = SharedArray{Bool}(1)
firstsave[1] = true
have_unpackable = SharedArray{Bool}(1)
have_unpackable[1] = false
# Run the jobs
nextidx = 0
getnextidx() = nextidx += 1
writing_mutex = RemoteChannel()
@sync begin
for i = 1:nworkers
alg = algorithm[i]
@async begin
while (idx = getnextidx()) <= n
if use_workerprocs
remotecall_fetch(println, workerpid(alg), "Worker ", workerpid(alg), " is working on ", idx)
# See https://github.com/JuliaLang/julia/issues/22139
tmp = remotecall_fetch(worker, workerpid(alg), rralgorithm[i], img, idx, mon[i])
copy_all_but_shared!(mon[i], tmp)
else
println("Working on ", idx)
mon[1] = worker(algorithm[1], img, idx, mon[1])
end
# Save the results
put!(writing_mutex, true) # grab the lock
try
local g
if firstsave[]
firstsave[] = false
have_unpackable[] = initialize_jld!(dsets, file, mon[i], fs, n)
end
if fetch(have_unpackable[])
g = file[string("stack", fmt(fs, idx))]
end
for (k,v) in mon[i]
if isa(v, Number)
dsets[k][idx] = v
continue
elseif isa(v, Array) || isa(v, SharedArray)
vw = nicehdf5(v)
if eltype(vw) <: BitsType
colons = [Colon() for i = 1:ndims(vw)]
dsets[k][colons..., idx] = vw
continue
end
end
g[string(k)] = v
end
finally
take!(writing_mutex) # release the lock
end
# Threaded implementation of `driver(outfile, algorithms, img, mon)`.
#
# Arguments
# - `outfile`:    path of the JLD file to create (opened with mode "w").
# - `algorithms`: worker objects; `mon[i]` is the monitor passed to `algorithms[i]`.
#                 When more than one algorithm is supplied, each algorithm is looked
#                 up by its `workertid` field, i.e. the id of the thread it serves.
# - `img`:        image sequence; `nimages(img)` frames are processed.
# - `mon`:        one monitor `Dict` per algorithm.
#
# Worker results are handed through a bounded channel to a single writer task, so
# only that task ever touches `file`.
#
# Returns `nothing`. Throws if the number of monitors differs from the number of
# algorithms, or if no algorithm is supplied.
function driver(outfile::AbstractString, algorithms::Vector, img, mon::Vector)
    nalgs = length(algorithms)
    length(mon) == nalgs || error("Number of monitors must equal number of workers")
    nalgs > 0 || throw(ArgumentError("at least one algorithm is required"))

    # Any run with more than one algorithm takes the threaded path; otherwise the
    # extra algorithms would be ignored and every frame handled by the first one.
    usethreads = nalgs > 1
    n = nimages(img)
    fs = FormatSpec("0$(ndigits(n))d")  # group names of unpackable objects

    println("Initializing algorithm")
    # Every algorithm may be used below, so every one is initialized (and closed).
    foreach(init!, algorithms)
    try
        println("Working on algorithm and saving the result")
        jldopen(outfile, "w") do file
            dsets = Dict{Symbol,Any}()
            firstsave = Ref(true)
            have_unpackable = Ref(false)

            # Channel for passing results from the workers to the writer
            results_ch = Channel{Tuple{Int,Dict}}(32)

            # Writer task: the only code that writes to `file`
            writer_task = @async begin
                for (movidx, monres) in results_ch
                    # Create the datasets from the first result that arrives
                    if firstsave[]
                        firstsave[] = false
                        have_unpackable[] = initialize_jld!(dsets, file, monres, fs, n)
                    end

                    g = have_unpackable[] ? file[string("stack", fmt(fs, movidx))] : nothing

                    # Write all values into the file
                    for (k, v) in monres
                        if isa(v, Number)
                            dsets[k][movidx] = v
                        elseif isa(v, Array) || isa(v, SharedArray)
                            vw = nicehdf5(v)
                            if eltype(vw) <: BitsType
                                colons = [Colon() for _ in 1:ndims(vw)]
                                dsets[k][colons..., movidx] = vw
                            else
                                g[string(k)] = v
                            end
                        else
                            g[string(k)] = v
                        end
                    end
                    yield()
                end
            end
            # If the writer dies, the channel is closed with its exception, so the
            # producers fail on `put!` instead of blocking forever on a full channel.
            bind(results_ch, writer_task)

            try
                if usethreads
                    # Thread id => index into `algorithms`/`mon`
                    aindices = Dict(alg.workertid => aidx
                                    for (aidx, alg) in enumerate(algorithms))
                    # writer_task shares the first thread, making static scheduling inefficient
                    @threads :dynamic for movidx in 1:n
                        tid = threadid()
                        aidx = get(aindices, tid, nothing)
                        if aidx === nothing
                            # Previously such frames were dropped without any notice.
                            @warn "No algorithm is assigned to thread $tid; frame $movidx was not processed"
                        else
                            println("thread $tid processing $movidx")
                            tmp = worker(algorithms[aidx], img, movidx, mon[aidx])
                            # Copy, since the same monitor is reused for the next frame
                            # while the writer may still be holding this result.
                            put!(results_ch, (movidx, deepcopy(tmp)))
                        end
                        yield()
                    end
                else
                    for movidx in 1:n
                        println("processing $movidx")
                        tmp = worker(algorithms[1], img, movidx, mon[1])
                        put!(results_ch, (movidx, tmp))
                        yield()
                    end
                end
            finally
                # Close channel and wait for writer to finish (rethrows writer errors)
                close(results_ch)
                wait(writer_task)
            end
        end
    finally
        # Perform worker cleanup even when processing or writing failed
        println("Closing algorithm")
        foreach(close!, algorithms)
    end

    return nothing
end

driver(outfile::AbstractString, algorithm::AbstractWorker, img, mon::Dict) = driver(outfile, [algorithm], img, [mon])
Expand Down Expand Up @@ -212,24 +205,25 @@ function copy_all_but_shared!(dest, src)
dest
end

mm_package_loader(algorithm::AbstractWorker) = mm_package_loader([algorithm])
function mm_package_loader(algorithms::Vector)
nworkers = length(algorithms)
use_workerprocs = nworkers > 1 || workerpid(algorithms[1]) != myid()
rrdev = Array{RemoteChannel}(undef, nworkers)
if use_workerprocs
for i = 1:nworkers
dev = algorithms[i].dev
rrdev[i] = put!(RemoteChannel(workerpid(algorithms[i])), dev)
end
@sync for i = 1:nworkers
p = workerpid(algorithms[i])
@async remotecall_fetch(load_mm_package, p, rrdev[i])
end
else
load_mm_package(algorithms[1].dev)
end
# Load the mismatch package for a collection of workers.
#
# Accepts any `AbstractVector` (including `Vector{Any}`), as the earlier
# `mm_package_loader(algorithms::Vector)` method did. Only the first worker is
# consulted.
# NOTE(review): presumably loading once is sufficient because all workers now run
# in the same process — confirm for workers with differing `dev`.
mm_package_loader(algorithms::AbstractVector) = mm_package_loader(first(algorithms))

# Load the mismatch package selected by `algorithm.dev`; returns `nothing`.
function mm_package_loader(algorithm::AbstractWorker)
    load_mm_package(algorithm.dev)
    return nothing
end

"""
    threadids()

Return a sorted vector of the distinct thread ids on which tasks were observed
to run.

The ids are sampled empirically: one `@threads` loop with `nthreads()`
iterations, followed by `1000 * nthreads()` spawned tasks, each reporting its
`threadid()`.
"""
function threadids()
    nthr = nthreads()
    nspawned = nthr * 1000
    # Room for every report (nthr + nspawned == nthr * 1001), so `put!` never blocks.
    seen = Channel{Int}(nthr + nspawned)

    @threads for _ in 1:nthr
        put!(seen, threadid())
    end
    @sync for _ in 1:nspawned
        Threads.@spawn put!(seen, threadid())
    end

    # Closing lets `collect` drain the buffered ids and then stop.
    close(seen)
    return sort(unique(collect(seen)))
end
end # module
18 changes: 9 additions & 9 deletions test/WorkerDummy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,28 @@ abstract type Alg <: AbstractWorker end
mutable struct Alg1{A<:AbstractArray} <: Alg
fixed::A
λ::Float64
workerpid::Int
workertid::Int
end
function Alg1(fixed, λ; pid=1)
Alg1(maybe_sharedarray(fixed, pid), λ, pid)
function Alg1(fixed, λ; tid=1)
Alg1(fixed, λ, tid)
end

mutable struct Alg2{A<:AbstractArray,V<:AbstractVector,M<:AbstractMatrix} <: Alg
fixed::A
tform::V
u0::M
workerpid::Int
workertid::Int
end
function Alg2(fixed, ::Type{T}, sz; pid=1) where T
Alg2(maybe_sharedarray(fixed, pid), maybe_sharedarray(T, (12,), pid), maybe_sharedarray(T, sz, pid), pid)
function Alg2(fixed, ::Type{T}, sz; tid=1) where T
Alg2(fixed, Vector{T}(undef,12), Matrix{T}(undef, sz), tid)
end

mutable struct Alg3 <: Alg
string::String
workerpid::Int
workertid::Int
end
function Alg3(s::String; pid=1)
Alg3(s, pid)
function Alg3(s::String; tid=1)
Alg3(s, tid)
end

# Here are the "registration algorithms"
Expand Down
36 changes: 19 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,17 @@ using Test, Distributed, SharedArrays
using ImageCore, JLD
using RegisterDriver, RegisterWorkerShell
using AxisArrays: AxisArray
using Base.Threads

driverprocs = addprocs(2)
push!(LOAD_PATH, pwd())
@sync for p in driverprocs
@spawnat p push!(LOAD_PATH, pwd())
end
using WorkerDummy

workdir = tempname()
mkdir(workdir)

img = AxisArray(SharedArray{Float32}((100,100,7)), :y, :x, :time)

# Single-process tests
# Single-process tests : mon::Dict
# Simple operation & passing back scalars
alg = Alg1(rand(3,3), 3.2)
mon = monitor(alg, (:λ,))
Expand All @@ -40,6 +37,7 @@ rm(fn)
# one per stack.
alg = Alg3("Hello")
mon = monitor(alg, (:string,))
@show typeof(mon)
mon[:extra] = ""
fn = joinpath(workdir, "file3.jld")
driver(fn, alg, img, mon)
Expand All @@ -50,18 +48,22 @@ jldopen(fn) do file
end
rm(fn)

# Multi-process
nw = length(driverprocs)
alg = Vector{Any}(undef, nw)
mon = Vector{Any}(undef, nw)
for i = 1:nw
alg[i] = Alg2(rand(100,100), Float32, (3,3), pid=driverprocs[i])
mon[i] = monitor(alg[i], (:tform,:u0,:workerpid))
# Multi-thread : mon::Vector{Dict}
tids = threadids()
nt = length(tids)
alg = Vector{Any}(undef, nt)
mon = Vector{Any}(undef, nt)
for i = 1:nt
alg[i] = Alg2(rand(100,100), Float32, (3,3), tid=tids[i])
mon[i] = monitor(alg[i], (:tform,:u0,:workertid))
end
fn = joinpath(workdir, "file4.jld")
driver(fn, alg, img, mon)
wpid = JLD.load(fn, "workerpid")
indx = unique(indexin(wpid, driverprocs))
@test length(indx) == length(driverprocs) && all(indx .> 0)

rmprocs(driverprocs, waitfor=1.0)
tform = JLD.load(fn, "tform")
u0 = JLD.load(fn, "u0")
@test tform[:,4] == collect(range(1, stop=12, length=12).+4)
@test u0[:,:,2] == fill(-2,(3,3))
tid = JLD.load(fn, "workertid")
indx = unique(indexin(tid, tids))
@test length(indx) == length(tids) && all(indx .> 0)
rm(fn)
Loading