diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 65a8e0e..a5c6a36 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -13,8 +13,7 @@ jobs: fail-fast: false matrix: version: - - '1.6' - - '1.1' + - 'min' - '1' # - 'nightly' os: @@ -23,7 +22,7 @@ jobs: - x64 steps: - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} diff --git a/Project.toml b/Project.toml index 0008d2a..f1eb4ff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RegisterDriver" uuid = "935ac36e-2656-11e9-1e3b-cbaa636797af" authors = ["Tim Holy "] -version = "0.2.3" +version = "0.2.4" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -22,9 +22,9 @@ ImageCore = "0.8.1, 0.9, 0.10" ImageMetadata = "0.9" JLD = "0.9, 0.10, 0.11, 0.12, 0.13" RegisterCore = "0.2" -RegisterWorkerShell = "0.2" +RegisterWorkerShell = "0.2, 1.0" StaticArrays = "0.11, 0.12, 1" -julia = "1" +julia = "1.10" [extras] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" diff --git a/src/RegisterDriver.jl b/src/RegisterDriver.jl index fe47616..0a62b86 100644 --- a/src/RegisterDriver.jl +++ b/src/RegisterDriver.jl @@ -1,8 +1,8 @@ module RegisterDriver using ImageCore, ImageMetadata, JLD, HDF5, StaticArrays, Formatting, SharedArrays, Distributed -using RegisterCore -using RegisterWorkerShell +using RegisterCore, RegisterWorkerShell +using Base.Threads if isdefined(HDF5, :BitsType) const BitsType = HDF5.BitsType @@ -16,7 +16,7 @@ if !isdefined(HDF5, :create_group) const create_group = g_create end -export driver, mm_package_loader +export driver, mm_package_loader, threadids """ `driver(outfile, algorithm, img, mon)` performs registration of the @@ -50,98 +50,91 @@ worker has been written to look for such settings: which will save `extra` only if `:extra` is a key in `mon`. """ -function driver(outfile::AbstractString, algorithm::Vector, img, mon::Vector) - nworkers = length(algorithm) - length(mon) == nworkers || error("Number of monitors must equal number of workers") - use_workerprocs = nworkers > 1 || workerpid(algorithm[1]) != myid() - rralgorithm = Array{RemoteChannel}(undef, nworkers) - if use_workerprocs - # Push the algorithm objects to the worker processes. This elminates - # per-iteration serialization penalties, and ensures that any - # initalization state is retained. - for i = 1:nworkers - alg = algorithm[i] - rralgorithm[i] = put!(RemoteChannel(workerpid(alg)), alg) - end - # Perform any needed worker initialization - @sync for i = 1:nworkers - p = workerpid(algorithm[i]) - @async remotecall_fetch(init!, p, rralgorithm[i]) - end - else - init!(algorithm[1]) - end - try - n = nimages(img) - fs = FormatSpec("0$(ndigits(n))d") # group names of unpackable objects - jldopen(outfile, "w") do file - dsets = Dict{Symbol,Any}() - firstsave = SharedArray{Bool}(1) - firstsave[1] = true - have_unpackable = SharedArray{Bool}(1) - have_unpackable[1] = false - # Run the jobs - nextidx = 0 - getnextidx() = nextidx += 1 - writing_mutex = RemoteChannel() - @sync begin - for i = 1:nworkers - alg = algorithm[i] - @async begin - while (idx = getnextidx()) <= n - if use_workerprocs - remotecall_fetch(println, workerpid(alg), "Worker ", workerpid(alg), " is working on ", idx) - # See https://github.com/JuliaLang/julia/issues/22139 - tmp = remotecall_fetch(worker, workerpid(alg), rralgorithm[i], img, idx, mon[i]) - copy_all_but_shared!(mon[i], tmp) - else - println("Working on ", idx) - mon[1] = worker(algorithm[1], img, idx, mon[1]) - end - # Save the results - put!(writing_mutex, true) # grab the lock - try - local g - if firstsave[] - firstsave[] = false - have_unpackable[] = initialize_jld!(dsets, file, mon[i], fs, n) - end - if fetch(have_unpackable[]) - g = file[string("stack", fmt(fs, idx))] - end - for (k,v) in mon[i] - if isa(v, Number) - dsets[k][idx] = v - continue - elseif isa(v, Array) || isa(v, SharedArray) - vw = nicehdf5(v) - if eltype(vw) <: BitsType - colons = [Colon() for i = 1:ndims(vw)] - dsets[k][colons..., idx] = vw - continue - end - end - g[string(k)] = v - end - finally - take!(writing_mutex) # release the lock - end +function driver(outfile::AbstractString, algorithms::Vector, img, mon::Vector) + nalgs = length(algorithms) + nummon = length(mon) + nummon == nalgs || error("Number of monitors must equal number of workers") + usethreads = nummon > 2 + numthreads = nthreads() + tpool = map(alg->alg.workertid, algorithms) + aindices = usethreads ? Dict(map((alg,aidx)->(alg.workertid=>aidx), algorithms, 1:length(algorithms))...) : + Dict(threadid()=>1) + n = nimages(img) + fs = FormatSpec("0$(ndigits(n))d") + + println("Initializing algorithm") + init!(algorithms[1]) + + println("Working on algorithm and saving the result") + jldopen(outfile, "w") do file + dsets = Dict{Symbol,Any}() + firstsave = Ref(true) + have_unpackable = Ref(false) + + # Channel for passing results from threads to writer + results_ch = Channel{Tuple{Int,Dict}}(32) + + # Writer task (runs on main thread) + writer_task = @async begin + for (movidx, monres) in results_ch + + # Initialize datasets on first save + if firstsave[] + firstsave[] = false + have_unpackable[] = initialize_jld!(dsets, file, monres, fs, n) + end + + g = have_unpackable[] ? file[string("stack", fmt(fs, movidx))] : nothing + + # Write all values into the file + for (k,v) in monres + if isa(v, Number) + dsets[k][movidx] = v + elseif isa(v, Array) || isa(v, SharedArray) + vw = nicehdf5(v) + if eltype(vw) <: BitsType + colons = [Colon() for _=1:ndims(vw)] + dsets[k][colons..., movidx] = vw + else + g[string(k)] = v end + else + g[string(k)] = v end end + yield() end end - finally - # Perform any needed worker cleanup - if use_workerprocs - @sync for i = 1:nworkers - p = workerpid(algorithm[i]) - @async remotecall_fetch(close!, p, rralgorithm[i]) + + if usethreads + # writer_task shares the first thread, making static scheduling inefficient + @threads :dynamic for movidx in 1:n + tid = threadid() + if tid in tpool + println("thread $tid processing $movidx") + tmp = worker(algorithms[aindices[tid]], img, movidx, mon[aindices[tid]]) + put!(results_ch, (movidx, deepcopy(tmp))) + end + yield() end else - close!(algorithm[1]) + for movidx in 1:n + println("processing $movidx") + tmp = worker(algorithms[1], img, movidx, mon[1]) + put!(results_ch, (movidx, tmp)) + yield() + end end + + # Close channel and wait for writer to finish + close(results_ch) + wait(writer_task) end + + println("Closing algorithm") + close!(algorithms[1]) + + return nothing end driver(outfile::AbstractString, algorithm::AbstractWorker, img, mon::Dict) = driver(outfile, [algorithm], img, [mon]) @@ -212,24 +205,25 @@ function copy_all_but_shared!(dest, src) dest end -mm_package_loader(algorithm::AbstractWorker) = mm_package_loader([algorithm]) -function mm_package_loader(algorithms::Vector) - nworkers = length(algorithms) - use_workerprocs = nworkers > 1 || workerpid(algorithms[1]) != myid() - rrdev = Array{RemoteChannel}(undef, nworkers) - if use_workerprocs - for i = 1:nworkers - dev = algorithms[i].dev - rrdev[i] = put!(RemoteChannel(workerpid(algorithms[i])), dev) - end - @sync for i = 1:nworkers - p = workerpid(algorithms[i]) - @async remotecall_fetch(load_mm_package, p, rrdev[i]) - end - else - load_mm_package(algorithms[1].dev) - end +mm_package_loader(algorithms::Vector{W}) where {W<:AbstractWorker} = mm_package_loader(algorithms[1]) +function mm_package_loader(algorithm::AbstractWorker) + load_mm_package(algorithm.dev) nothing end +function threadids() + nt = nthreads() + ch = Channel{Int}(nt*1001) + + @threads for i in 1:nt + put!(ch, threadid()) + end + @sync for i in 1:(nt*1000) + Threads.@spawn put!(ch, threadid()) + end + + close(ch) + tids = unique(collect(ch)) + sort(tids) +end end # module diff --git a/test/WorkerDummy.jl b/test/WorkerDummy.jl index cfa55b7..d4b6044 100644 --- a/test/WorkerDummy.jl +++ b/test/WorkerDummy.jl @@ -14,28 +14,28 @@ abstract type Alg <: AbstractWorker end mutable struct Alg1{A<:AbstractArray} <: Alg fixed::A λ::Float64 - workerpid::Int + workertid::Int end -function Alg1(fixed, λ; pid=1) - Alg1(maybe_sharedarray(fixed, pid), λ, pid) +function Alg1(fixed, λ; tid=1) + Alg1(fixed, λ, tid) end mutable struct Alg2{A<:AbstractArray,V<:AbstractVector,M<:AbstractMatrix} <: Alg fixed::A tform::V u0::M - workerpid::Int + workertid::Int end -function Alg2(fixed, ::Type{T}, sz; pid=1) where T - Alg2(maybe_sharedarray(fixed, pid), maybe_sharedarray(T, (12,), pid), maybe_sharedarray(T, sz, pid), pid) +function Alg2(fixed, ::Type{T}, sz; tid=1) where T + Alg2(fixed, Vector{T}(undef,12), Matrix{T}(undef, sz), tid) end mutable struct Alg3 <: Alg string::String - workerpid::Int + workertid::Int end -function Alg3(s::String; pid=1) - Alg3(s, pid) +function Alg3(s::String; tid=1) + Alg3(s, tid) end # Here are the "registration algorithms" diff --git a/test/runtests.jl b/test/runtests.jl index f0bcb0a..0f9afe6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,12 +2,9 @@ using Test, Distributed, SharedArrays using ImageCore, JLD using RegisterDriver, RegisterWorkerShell using AxisArrays: AxisArray +using Base.Threads -driverprocs = addprocs(2) push!(LOAD_PATH, pwd()) -@sync for p in driverprocs - @spawnat p push!(LOAD_PATH, pwd()) -end using WorkerDummy workdir = tempname() @@ -15,7 +12,7 @@ mkdir(workdir) img = AxisArray(SharedArray{Float32}((100,100,7)), :y, :x, :time) -# Single-process tests +# Single-process tests : mon::Dict # Simple operation & passing back scalars alg = Alg1(rand(3,3), 3.2) mon = monitor(alg, (:λ,)) @@ -40,6 +37,7 @@ rm(fn) # one per stack. alg = Alg3("Hello") mon = monitor(alg, (:string,)) +@show typeof(mon) mon[:extra] = "" fn = joinpath(workdir, "file3.jld") driver(fn, alg, img, mon) @@ -50,18 +48,22 @@ jldopen(fn) do file end rm(fn) -# Multi-process -nw = length(driverprocs) -alg = Vector{Any}(undef, nw) -mon = Vector{Any}(undef, nw) -for i = 1:nw - alg[i] = Alg2(rand(100,100), Float32, (3,3), pid=driverprocs[i]) - mon[i] = monitor(alg[i], (:tform,:u0,:workerpid)) +# Multi-thread : mon::Vector{Dict} +tids = threadids() +nt = length(tids) +alg = Vector{Any}(undef, nt) +mon = Vector{Any}(undef, nt) +for i = 1:nt + alg[i] = Alg2(rand(100,100), Float32, (3,3), tid=tids[i]) + mon[i] = monitor(alg[i], (:tform,:u0,:workertid)) end fn = joinpath(workdir, "file4.jld") driver(fn, alg, img, mon) -wpid = JLD.load(fn, "workerpid") -indx = unique(indexin(wpid, driverprocs)) -@test length(indx) == length(driverprocs) && all(indx .> 0) - -rmprocs(driverprocs, waitfor=1.0) +tform = JLD.load(fn, "tform") +u0 = JLD.load(fn, "u0") +@test tform[:,4] == collect(range(1, stop=12, length=12).+4) +@test u0[:,:,2] == fill(-2,(3,3)) +tid = JLD.load(fn, "workertid") +indx = unique(indexin(tid, tids)) +@test length(indx) == length(tids) && all(indx .> 0) +rm(fn)