From aa3c6b6374073a81fe9901c84aaf4ff4ffe85a1b Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 28 Apr 2026 14:27:02 +0200 Subject: [PATCH 01/11] Add background and perturbations initialization problems to CosmologyProblem --- src/solve.jl | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 5d2aa058..95a9bcd8 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -16,12 +16,15 @@ import NonlinearSolve.BracketingNonlinearSolve: AbstractBracketingAlgorithm background(sys) = transform((sys, _) -> filter_system(isbackground, sys), sys) perturbations(sys) = transform((sys, _) -> filter_system(isperturbation, sys), sys) -struct CosmologyProblem{Tbg <: ODEProblem, Tpt <: Union{ODEProblem, Nothing}} +struct CosmologyProblem{Tbg <: ODEProblem, Tpt <: Union{ODEProblem, Nothing}, Tbginit <: NonlinearProblem, Tptinit <: Union{NonlinearProblem, Nothing}} M::System bg::Tbg pt::Tpt + bginit::Tbginit + ptinit::Tptinit + pars::Vector{Symbolics.SymbolicT} shoot::Dict conditions::Vector{Equation} @@ -153,6 +156,9 @@ function CosmologyProblem( bg = debug_system(bg) end + # Background initialization problem + bginit = InitializationProblem(bg, first(ivspan), parsk) + # Set up callback for today # TODO: specify callbacks symbolically? iv = ModelingToolkit.get_iv(M) if Symbol(iv) == :τ @@ -192,8 +198,10 @@ function CosmologyProblem( ) bg = ODEProblem(bg, parsk, ivspan; fully_determined, callback, jac, bgopts..., kwargs...) # never sparse because small # TODO: hangs with jac = true, sparse = true; try without tearing state as in pt? + bg = remake(bg; u0 = fill(NaN, length(bg.u0)), build_initializeprob = Val{false}) else bg = nothing + bginit = nothing end if pt @@ -211,17 +219,23 @@ function CosmologyProblem( if debug pt = debug_system(pt) end + + # Perturbations initialization problem + ptinit = InitializationProblem(pt, first(ivspan), parsk) + ts = ModelingToolkit.get_tearing_state(pt) @set! pt.tearing_state = nothing # additional pass in mtkcompile_spline modifies variable ordering and leads to an incorrect Jacobian; reset tearing state to nothing to trigger "manual" computation of the Jacobian pt = ODEProblem(pt, parsk, ivspan; fully_determined, jac, sparse, ptopts..., kwargs...) @set! pt.f.sys.tearing_state = ts # restore + pt = remake(pt; u0 = fill(NaN, length(pt.u0)), build_initializeprob = Val{false}) else + ptinit = nothing pt = nothing end pars = [unwrap(par) for (par, val) in pars] shoot_conditions = convert(Vector{Equation}, shoot_conditions) - return CosmologyProblem(M, bg, pt, pars, shoot_pars, shoot_conditions) + return CosmologyProblem(M, bg, pt, bginit, ptinit, pars, shoot_pars, shoot_conditions) end """ From d841b0ad5daf9521b2d5c35d2ee080b198f8c522 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 28 Apr 2026 14:34:14 +0200 Subject: [PATCH 02/11] Get initial conditions from backgrond/perturbations initialization problems --- docs/src/benchmarks.md | 2 +- src/solve.jl | 43 +++++++++++++++++++++++++++++------------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/docs/src/benchmarks.md b/docs/src/benchmarks.md index 1d07dfe4..c7d4b55d 100644 --- a/docs/src/benchmarks.md +++ b/docs/src/benchmarks.md @@ -73,7 +73,7 @@ The points on each curve correspond to a sequence of tolerances. # TODO: test different nlsolve # hide # TODO: add AdaptiveRadau/RadauIIA5 when they support sparse J: https://github.com/SciML/OrdinaryDiffEq.jl/issues/2892 # hide ptalgs = [algtype(linsolve = KLUFactorization()) for algtype in [TRBDF2, KenCarp4, KenCarp47, KenCarp5, Kvaerno5, Rodas4P, Rodas5P, Rodas6P, QNDF, FBDF]] -ptprobgen = SymBoltz.setuppt(prob.pt, bgsol) +ptprobgen = SymBoltz.setuppt(prob.pt, prob.ptinit, bgsol) setups = [Dict(:alg => alg) for alg in ptalgs] refalg = Rodas5P(linsolve = KLUFactorization()) abstols = 1 ./ 10 .^ (5:9) diff --git a/src/solve.jl b/src/solve.jl index 95a9bcd8..09d52a2f 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -379,10 +379,12 @@ function solve( shootopts = (alg = shootalg(prob), abstol = 1e-5), thread = true, verbose = false, kwargs... ) + u0 = initbg(prob.bg, prob.bginit) # TODO: do inside solvebg? + bgprob = remake(prob.bg; u0) if !isempty(prob.shoot) - bgsol = solvebg(prob.bg, prob.shoot, prob.conditions; shootopts, verbose, bgopts..., bgextraopts..., kwargs...) + bgsol = solvebg(bgprob, prob.shoot, prob.conditions; shootopts, verbose, bgopts..., bgextraopts..., kwargs...) else - bgsol = solvebg(prob.bg; verbose, bgopts..., bgextraopts..., kwargs...) + bgsol = solvebg(bgprob; verbose, bgopts..., bgextraopts..., kwargs...) end if isnothing(ks) || isempty(ks) @@ -390,7 +392,7 @@ function solve( ptsol = nothing else ks = k_dimensionless.(ks, Ref(bgsol)) - ptsol = solvept(prob.pt, bgsol, ks, ptivini; thread, verbose, ptopts..., ptextraopts..., kwargs...) + ptsol = solvept(prob.pt, prob.ptinit, bgsol, ks, ptivini; thread, verbose, ptopts..., ptextraopts..., kwargs...) end return CosmologySolution(prob, bgsol, ks, ptsol) @@ -413,6 +415,24 @@ function warning_failed_solution(sol::ODESolution, name = "ODE"; verbose = false return msg end +function init(init::NonlinearProblem, unknowns; kwargs...) + sol = solve(init; kwargs...) + return sol[unknowns] +end +function initbg(bgprob::ODEProblem, bginit::NonlinearProblem) + # TODO: handle shooting ICs + return init(bginit, unknowns(bgprob.f.sys)) +end +function initpt(ptprob::ODEProblem, ptinit::NonlinearProblem, k, bgsol::ODESolution) + p = ptinit.p + SciMLStructures.replace!(Tunable(), p, canonicalize(Tunable(), parameter_values(bgsol))[1]) # copy background parameters + @set! p.nonnumeric = ([spline(bgsol)],) # add background spline parameter + p[ModelingToolkit.parameter_index(ptinit, :k)] = k # set k value + + ptinit = remake(ptinit; p) + return init(ptinit, unknowns(ptprob.f.sys)) +end + """ solvebg(bgprob::ODEProblem[, vars, conditions]; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, shootopts = (alg = shootalg(), reltol = 1e-3), verbose = false, build_initializeprob = Val{false}, kwargs...) @@ -494,15 +514,14 @@ function solvebg(bgprob::ODEProblem, vars, conditions; alg = bgalg(bgprob), relt return solvebg(bgprob; alg, reltol, abstol, kwargs...) end -function setuppt(ptprob::ODEProblem, bgsol::ODESolution, ptivini::Function) +function setuppt(ptprob::ODEProblem, ptinit::NonlinearProblem, bgsol::ODESolution, ptivini::Function) ivspanbg = (bgsol.t[begin], bgsol.t[end]) - bgspline = spline(bgsol) newp = ptprob.p # has abstractly typed nonnumeric spline parameter SciMLStructures.replace!(Tunable(), newp, canonicalize(Tunable(), parameter_values(bgsol))[1]) # copy parameters from background solution to perturbations problem (e.g. τ0 and κ0) hasspline = !isempty(ptprob.p.nonnumeric) if hasspline - @set! newp.nonnumeric = ([bgspline],) # reset field to make MTKParameters' nonnumeric spline parameter concrete + @set! newp.nonnumeric = ([spline(bgsol)],) # reset field to make MTKParameters' nonnumeric spline parameter concrete end kset! = ModelingToolkit.setp(ptprob, k) @@ -511,14 +530,12 @@ function setuppt(ptprob::ODEProblem, bgsol::ODESolution, ptivini::Function) kset!(p, k) ivi = clamp(ptivini(k), ivspanbg[begin], ivspanbg[end]) # clamp to background timespan ivspan = (ivi, ivspanbg[end]) - newptprob = remake(ptprob; u0 = ptprob.u0, p = p, tspan = ivspan, build_initializeprob = true) # solve for u0 # TODO: separate function? - if hasspline # need to do this to get solving with splined background type-stable - newptprob = remake(newptprob; u0 = newptprob.u0, p = p, build_initializeprob = false) # remake again with build_initializeprob = false makes following solve type-stable; https://github.com/SciML/ModelingToolkit.jl/issues/3715 - end + u0 = initpt(ptprob, ptinit, k, bgsol) + newptprob = remake(ptprob; u0, p, tspan = ivspan, build_initializeprob = false) # solve for u0 # TODO: separate function? return newptprob end end -setuppt(ptprob::ODEProblem, bgsol::ODESolution, ptivini::Number = -Inf) = setuppt(ptprob, bgsol, k -> ptivini) +setuppt(ptprob::ODEProblem, ptinit::NonlinearProblem, bgsol::ODESolution, ptivini::Number = -Inf) = setuppt(ptprob, ptinit, bgsol, k -> ptivini) """ solvept(ptprob::ODEProblem, bgsol::ODESolution, ks::AbstractArray, ptivini = -Inf; alg = ptalg(ptprob), reltol = 1e-5, abstol = 1e-5, output_func = (sol, i) -> sol, thread = true, verbose = false, kwargs...) @@ -528,7 +545,7 @@ If `thread` and Julia is running with multiple threads, the solution of independ `ptivini` is a number or a function of ``k`` that sets the initial time of integration for each perturbation mode, but is always clamped to the background timespan. The return value is a vector with one `ODESolution` per wavenumber, or its mapping through `output_func` if a custom transformation is passed. """ -function solvept(ptprob::ODEProblem, bgsol::ODESolution, ks::AbstractArray, ptivini = -Inf; alg = ptalg(ptprob), reltol = 1e-5, abstol = 1e-5, output_func = (sol, i) -> sol, thread = true, verbose = false, kwargs...) +function solvept(ptprob::ODEProblem, ptinit::NonlinearProblem, bgsol::ODESolution, ks::AbstractArray, ptivini = -Inf; alg = ptalg(ptprob), reltol = 1e-5, abstol = 1e-5, output_func = (sol, i) -> sol, thread = true, verbose = false, kwargs...) check_solve_args(ptprob, alg) !issorted(ks) && throw(error("ks = $ks are not sorted in ascending order")) @@ -541,7 +558,7 @@ function solvept(ptprob::ODEProblem, bgsol::ODESolution, ks::AbstractArray, ptiv end # TODO: can I exploit that the structure of the perturbation ODEs is ẏ = J * y with "constant" J? - ptprobgen = setuppt(ptprob, bgsol, ptivini) + ptprobgen = setuppt(ptprob, ptinit, bgsol, ptivini) function output_func_warn(sol, i) if !successful_retcode(sol) From 7c10a263362bbaf7cf467fa0634d0ab33c2506fe Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 28 Apr 2026 15:13:03 +0200 Subject: [PATCH 03/11] Get initial conditions inside solvebg and solvept --- src/observables/fourier.jl | 8 ++-- src/solve.jl | 80 +++++++++++++++++++++----------------- test/runtests.jl | 14 +++---- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/src/observables/fourier.jl b/src/observables/fourier.jl index 7a46db72..815f1814 100644 --- a/src/observables/fourier.jl +++ b/src/observables/fourier.jl @@ -246,7 +246,7 @@ Compute and evaluate source functions ``S(τ,k)`` with symbolic expressions `Ss` The options `bgopts` and `ptopts` are passed to the background and perturbation solves. """ function source_grid(prob::CosmologyProblem, Ss::AbstractArray, τs, ks; bgopts = (), ptopts = (), thread = true, verbose = false) - bgsol = solvebg(prob.bg; bgopts..., verbose) + bgsol = solvebg(prob.bg, prob.bginit; bgopts..., verbose) getSs = map(S -> getsym(prob.pt, S), Ss) Ss = similar(bgsol, length(Ss), length(τs), length(ks)) minimum(τs) ≥ bgsol.t[begin] && maximum(τs) ≤ bgsol.t[end] || error("input τs and computed background solution have different timespans") @@ -256,7 +256,7 @@ function source_grid(prob::CosmologyProblem, Ss::AbstractArray, τs, ks; bgopts end return nothing end - solvept(prob.pt, bgsol, ks; output_func, saveat = τs, ptopts..., thread, verbose) + solvept(prob.pt, prob.ptinit, bgsol, ks; output_func, saveat = τs, ptopts..., thread, verbose) return Ss end @@ -289,7 +289,7 @@ function source_grid_adaptive(prob::CosmologyProblem, Ss::AbstractVector, τs, k ptsaveopts = (saveat = τs,) end - ptprobgen = setuppt(prob.pt, bgsol) + ptprobgen = setuppt(prob.pt, prob.ptinit, bgsol) getSs = map(S -> getsym(prob.pt, S), Ss) function sourcek!(k, ik, Ss) @@ -379,7 +379,7 @@ end # Dispatch without background solution function source_grid_adaptive(prob::CosmologyProblem, Ss::AbstractVector, τs, ks; bgopts = (), kwargs...) - bgsol = solvebg(prob.bg; bgopts...) + bgsol = solvebg(prob.bg, prob.bginit; bgopts...) return source_grid_adaptive(prob, Ss, τs, ks, bgsol; kwargs...) end diff --git a/src/solve.jl b/src/solve.jl index 09d52a2f..bdc49ad9 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -379,12 +379,10 @@ function solve( shootopts = (alg = shootalg(prob), abstol = 1e-5), thread = true, verbose = false, kwargs... ) - u0 = initbg(prob.bg, prob.bginit) # TODO: do inside solvebg? - bgprob = remake(prob.bg; u0) if !isempty(prob.shoot) - bgsol = solvebg(bgprob, prob.shoot, prob.conditions; shootopts, verbose, bgopts..., bgextraopts..., kwargs...) + bgsol = solvebg(prob.bg, prob.bginit, prob.shoot, prob.conditions; shootopts, verbose, bgopts..., bgextraopts..., kwargs...) else - bgsol = solvebg(bgprob; verbose, bgopts..., bgextraopts..., kwargs...) + bgsol = solvebg(prob.bg, prob.bginit; verbose, bgopts..., bgextraopts..., kwargs...) end if isnothing(ks) || isempty(ks) @@ -415,37 +413,27 @@ function warning_failed_solution(sol::ODESolution, name = "ODE"; verbose = false return msg end -function init(init::NonlinearProblem, unknowns; kwargs...) - sol = solve(init; kwargs...) - return sol[unknowns] -end -function initbg(bgprob::ODEProblem, bginit::NonlinearProblem) - # TODO: handle shooting ICs - return init(bginit, unknowns(bgprob.f.sys)) -end -function initpt(ptprob::ODEProblem, ptinit::NonlinearProblem, k, bgsol::ODESolution) - p = ptinit.p - SciMLStructures.replace!(Tunable(), p, canonicalize(Tunable(), parameter_values(bgsol))[1]) # copy background parameters - @set! p.nonnumeric = ([spline(bgsol)],) # add background spline parameter - p[ModelingToolkit.parameter_index(ptinit, :k)] = k # set k value - - ptinit = remake(ptinit; p) - return init(ptinit, unknowns(ptprob.f.sys)) -end - """ solvebg(bgprob::ODEProblem[, vars, conditions]; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, shootopts = (alg = shootalg(), reltol = 1e-3), verbose = false, build_initializeprob = Val{false}, kwargs...) Solve the background cosmology problem `bgprob`. If the background requires shooting, `vars` is a dictionary with variables to shoot for and their initial guesses, and `conditions` is and an array of equations that should hold at the final integration time (usually today). """ -function solvebg(bgprob::ODEProblem; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, verbose = false, kwargs...) +function solvebg(bgprob::ODEProblem, bginit::NonlinearProblem; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, verbose = false, kwargs...) check_solve_args(bgprob, alg) + + # 1) Get and set initial conditions + initsol = solve(bginit) + u0 = initsol[unknowns(bgprob.f.sys)] + bgprob = remake(bgprob; u0) + + # 2) Solve ODE bgsol = solve(bgprob, alg; verbose, reltol, abstol, kwargs...) if !successful_retcode(bgsol) @warn warning_failed_solution(bgsol, "Background"; verbose) end + # 3) Post-process τrecidx = ModelingToolkit.parameter_index(bgprob, :τrec) if !isnothing(τrecidx) bgsol.ps[τrecidx] = bgsol[:τ][argmax(bgsol[bgprob.f.sys.b.v])] @@ -515,23 +503,43 @@ function solvebg(bgprob::ODEProblem, vars, conditions; alg = bgalg(bgprob), relt end function setuppt(ptprob::ODEProblem, ptinit::NonlinearProblem, bgsol::ODESolution, ptivini::Function) - ivspanbg = (bgsol.t[begin], bgsol.t[end]) - newp = ptprob.p # has abstractly typed nonnumeric spline parameter - SciMLStructures.replace!(Tunable(), newp, canonicalize(Tunable(), parameter_values(bgsol))[1]) # copy parameters from background solution to perturbations problem (e.g. τ0 and κ0) + tspanbg = (bgsol.t[begin], bgsol.t[end]) + + probp = ptprob.p # has abstractly typed nonnumeric spline parameter + initp = ptinit.p - hasspline = !isempty(ptprob.p.nonnumeric) - if hasspline - @set! newp.nonnumeric = ([spline(bgsol)],) # reset field to make MTKParameters' nonnumeric spline parameter concrete + # copy parameters from background solution to perturbations problem (e.g. τ0 and κ0) + bgp = parameter_values(bgsol)[1] # i.e. .tunable + size(bgp) != size(probp.tunable) && error("Incompatible size of tunable parameters in background and perturbations") + SciMLStructures.replace!(Tunable(), probp, canonicalize(Tunable(), bgp)[1]) + + if !isempty(ptprob.p.nonnumeric) # do we spline the background? + bgspline = spline(bgsol) + @set! probp.nonnumeric = ([bgspline],) # reset field to make MTKParameters' nonnumeric spline parameter concrete + @set! initp.nonnumeric = ([bgspline],) end - kset! = ModelingToolkit.setp(ptprob, k) + # Getters and setters + ksetprob! = ModelingToolkit.setp(ptprob, :k) + ksetinit! = ModelingToolkit.setp(ptinit, :k) + getu0 = ModelingToolkit.getsym(ptinit, unknowns(ptprob.f.sys)) + return k -> begin - p = copy(newp) # newp specializes on spline types, while ptprob0.p does not; see https://github.com/SciML/ModelingToolkit.jl/issues/3715 - kset!(p, k) - ivi = clamp(ptivini(k), ivspanbg[begin], ivspanbg[end]) # clamp to background timespan - ivspan = (ivi, ivspanbg[end]) - u0 = initpt(ptprob, ptinit, k, bgsol) - newptprob = remake(ptprob; u0, p, tspan = ivspan, build_initializeprob = false) # solve for u0 # TODO: separate function? + tini = clamp(ptivini(k), tspanbg[begin], tspanbg[end]) # clamp to background + tend = tspanbg[end] + tspan = (tini, tend) + + # 1) Get initial conditions + p = copy(initp) + ksetinit!(p, k) + newptinit = remake(ptinit; p, tspan) + ptinitsol = solve(newptinit) + u0 = getu0(ptinitsol) + + # 2) Set up ODE with initial conditions + p = copy(probp) # newp specializes on spline types, while ptprob0.p does not; see https://github.com/SciML/ModelingToolkit.jl/issues/3715 + ksetprob!(p, k) + newptprob = remake(ptprob; u0, p, tspan) return newptprob end end diff --git a/test/runtests.jl b/test/runtests.jl index 4943a278..9aab3dc3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,12 +22,12 @@ prob_sparse = prob @testset "Solve failure warnings" begin Ωc0 = prob.bg.ps[M.c.Ω₀] prob.bg.ps[M.c.Ω₀] = NaN # bad - bgsol = @test_warn "Background solution failed" solvebg(prob.bg) + bgsol = @test_warn "Background solution failed" solvebg(prob.bg, prob.bginit) prob.bg.ps[M.c.Ω₀] = Ωc0 # restore good - bgsol = @test_nowarn solvebg(prob.bg) + bgsol = @test_nowarn solvebg(prob.bg, prob.bginit) - @test_warn "Perturbation (mode k = NaN) solution failed" ptsol = solvept(prob.pt, bgsol, [NaN]; thread = false) - @test_nowarn ptsol = solvept(prob.pt, bgsol, [1.0]; thread = false) + @test_warn "Perturbation (mode k = NaN) solution failed" ptsol = solvept(prob.pt, prob.ptinit, bgsol, [NaN]; thread = false) + @test_nowarn ptsol = solvept(prob.pt, prob.ptinit, bgsol, [1.0]; thread = false) end @testset "Solution accessing" begin @@ -390,11 +390,11 @@ end end @testset "Dedicated background/perturbation solvers" begin - bgsol = solvebg(prob.bg) # TODO: @inferred + bgsol = solvebg(prob.bg, prob.bginit) # TODO: @inferred @test bgsol isa SymBoltz.ODESolution ks = 1.0:1.0:10.0 - ptsol = solvept(prob.pt, bgsol, ks) # TODO: @inferred + ptsol = solvept(prob.pt, prob.ptinit, bgsol, ks) # TODO: @inferred @test ptsol isa Vector{<:SymBoltz.ODESolution} # custom output_func for e.g. source function @@ -403,7 +403,7 @@ end τ0 = bgsol.t[end] τs = range(τi, τ0, length = 768) ks = range(1.0, 1000.0, length = 1000) - Ss = solvept(prob.pt, bgsol, ks; saveat = τs, output_func = (ptsol, _) -> getS(ptsol)) + Ss = solvept(prob.pt, prob.ptinit, bgsol, ks; saveat = τs, output_func = (ptsol, _) -> getS(ptsol)) Ss = stack(Ss) @test size(Ss) == (length(τs), length(ks)) end From 007456e9d9a5cc649cdbfdbeb54aef140e4d5b64 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 28 Apr 2026 17:56:46 +0200 Subject: [PATCH 04/11] Fix remake to work with both ODE and initialization problems --- src/solve.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index bdc49ad9..5d0a2852 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -253,18 +253,16 @@ function remake( bg = true, pt = true, shoot = true, kwargs... ) - vars, pars = split_vars_pars(prob.M, pars) - vars = isempty(vars) ? missing : vars - pars = isempty(pars) ? missing : pars - bg = bg && !isnothing(prob.bg) ? remake(prob.bg; u0 = vars, p = pars, build_initializeprob = Val{!isnothing(prob.bg.f.initialization_data)}, kwargs...) : nothing - if !ismissing(vars) - remove_background_initial_conditions!(vars) # must filter ICs in remake, too - end - pt = pt && !isnothing(prob.pt) ? remake(prob.pt; u0 = vars, p = pars, build_initializeprob = Val{!isnothing(prob.pt.f.initialization_data)}, kwargs...) : nothing + u0, p = split_vars_pars(prob.M, pars) + bg = isnothing(prob.bg) ? nothing : remake(prob.bg; p, kwargs...) + bginit = isnothing(prob.bginit) ? nothing : remake(prob.bginit; p, kwargs...) + pt = isnothing(prob.pt) ? nothing : remake(prob.pt; p, kwargs...) + ptinit = isnothing(prob.ptinit) ? nothing : remake(prob.ptinit; p, kwargs...) shoot_pars = shoot ? prob.shoot : Dict() shoot_conditions = shoot ? prob.conditions : [] - return CosmologyProblem(prob.M, bg, pt, prob.pars, shoot_pars, shoot_conditions) + return CosmologyProblem(prob.M, bg, pt, bginit, ptinit, prob.pars, shoot_pars, shoot_conditions) end +remake(prob::CosmologyProblem, pars::AbstractArray; kwargs...) = remake(prob, Dict(pars); kwargs...) """ parameter_updater(prob::CosmologyProblem, idxs; kwargs...) From 56f8453e8cda4e5662125f7a38794873549da134 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 28 Apr 2026 20:13:43 +0200 Subject: [PATCH 05/11] Fix parameter_updater to work with both ODE and initialization problems --- src/solve.jl | 44 ++++++++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 5d0a2852..8e4131cb 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -272,32 +272,40 @@ The returned function is called with numerical values (in the same order as `idx """ function parameter_updater(prob::CosmologyProblem, idxs; kwargs...) # define a closure based on https://docs.sciml.ai/ModelingToolkit/dev/examples/remake/#replace-and-remake - # TODO: remove M, etc. for efficiency? - - @unpack bg, pt = prob + bgprob = prob.bg + ptprob = prob.pt + bginit = prob.bginit + ptinit = prob.ptinit - bgsetsym = SymbolicIndexingInterface.setsym_oop(bg, idxs) # TODO: define setsym(::CosmologyProblem)? - bgdiffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(bg))[1])) + bgprobsetsym = SymbolicIndexingInterface.setsym_oop(bgprob, idxs) + bginitsetsym = SymbolicIndexingInterface.setsym_oop(bginit, idxs) + bgprobdiffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(bgprob))[1])) + bginitdiffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(bginit))[1])) - if !isnothing(pt) - ptsetsym = setsym_oop(pt, idxs) - ptdiffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(pt))[1])) + if !isnothing(ptprob) && !isnothing(ptinit) + ptprobsetsym = SymbolicIndexingInterface.setsym_oop(ptprob, idxs) + ptinitsetsym = SymbolicIndexingInterface.setsym_oop(ptinit, idxs) + ptprobdiffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(ptprob))[1])) + ptinitdiffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(ptinit))[1])) end function updater(p) - # Update background problem - newu0, newp = bgsetsym(bg, p) # set new parameters - bg_new = remake(bg; u0 = newu0, p = newp, kwargs...) # create updated problem (don't overwrite old) - - # Update perturbation problem - if isnothing(pt) - pt_new = pt + bgprobu0, bgprobp = bgprobsetsym(bgprob, p) + bginitu0, bginitp = bginitsetsym(bginit, p) + newbgprob = remake(bgprob; u0 = bgprobu0, p = bgprobp, kwargs...) + newbginit = remake(bginit; u0 = bginitu0, p = bginitp, kwargs...) + + if !isnothing(ptprob) && !isnothing(ptinit) + ptprobu0, ptprobp = ptprobsetsym(ptprob, p) + ptinitu0, ptinitp = ptinitsetsym(ptinit, p) + newptprob = remake(ptprob; u0 = ptprobu0, p = ptprobp, kwargs...) + newptinit = remake(ptinit; u0 = ptinitu0, p = ptinitp, kwargs...) else - newu0, newp = ptsetsym(pt, p) - pt_new = remake(pt; u0 = newu0, p = newp, kwargs...) # create updated problem (don't overwrite old) + newptprob = nothing + newptinit = nothing end - return CosmologyProblem(prob.M, bg_new, pt_new, prob.pars, prob.shoot, prob.conditions) + return CosmologyProblem(prob.M, newbgprob, newptprob, newbginit, newptinit, prob.pars, prob.shoot, prob.conditions) end function updater(p::Dict) p = [p[var] for var in idxs] From 68df8d0ba15d6857150632729398444f3b205ea0 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 28 Apr 2026 20:24:14 +0200 Subject: [PATCH 06/11] Pass original CosmologyProblem to parameter_updater (avoid closure capture) --- docs/src/automatic_differentiation.md | 2 +- docs/src/comparison.md | 4 ++-- docs/src/forecasting.md | 2 +- docs/src/plot.md | 4 ++-- docs/src/solve.md | 2 +- src/solve.jl | 25 +++++++++++-------------- test/runtests.jl | 14 +++++++------- 7 files changed, 25 insertions(+), 28 deletions(-) diff --git a/docs/src/automatic_differentiation.md b/docs/src/automatic_differentiation.md index 8ba5048a..6d0fa959 100644 --- a/docs/src/automatic_differentiation.md +++ b/docs/src/automatic_differentiation.md @@ -19,7 +19,7 @@ pars = [M.γ.T₀, M.c.Ω₀, M.b.Ω₀, M.ν.Neff, M.g.h, M.b.YHe, M.h.m_eV, M. prob0 = CosmologyProblem(M, Dict(pars .=> NaN); sparse = false) # dense Jacobian faster for AD probgen = parameter_updater(prob0, pars) -P(k, θ) = spectrum_matter(probgen(θ), k; ptopts = (reltol = 1e-3, abstol = 1e-3)) +P(k, θ) = spectrum_matter(probgen(prob0, θ), k; ptopts = (reltol = 1e-3, abstol = 1e-3)) ``` It is now easy to evaluate the power spectrum: ```@example ad diff --git a/docs/src/comparison.md b/docs/src/comparison.md index 2f303523..ace2cc6a 100644 --- a/docs/src/comparison.md +++ b/docs/src/comparison.md @@ -354,7 +354,7 @@ function P_class(k, pars) return P end function P_symboltz(k, pars) - prob′ = parameter_updater(prob, collect(keys(pars)))(collect(values(pars))) # TODO: move outside; common for Pk and Cl + prob′ = parameter_updater(prob, collect(keys(pars)))(prob, collect(values(pars))) # TODO: move outside; common for Pk and Cl P = spectrum_matter(prob′, k / u"Mpc") / u"Mpc^3" return P end @@ -406,7 +406,7 @@ function Dl_class(modes, l, pars) return stack(Dl) end function Dl_symboltz(modes, jl, pars; kwargs...) - prob′ = parameter_updater(prob, collect(keys(pars)))(collect(values(pars))) + prob′ = parameter_updater(prob, collect(keys(pars)))(prob, collect(values(pars))) return spectrum_cmb(modes, prob′, jl; normalization = :Dl, kwargs...) end diff --git a/docs/src/forecasting.md b/docs/src/forecasting.md index 543ab0b9..30d51af4 100644 --- a/docs/src/forecasting.md +++ b/docs/src/forecasting.md @@ -33,7 +33,7 @@ Since $Cₗ$ is an expensive but smooth function of $l$, we make one function fo probgen = parameter_updater(prob0, pars_varying) jl = SphericalBesselCache(40:20:1000) ls = 40:1:1000 -Cl(θ) = spectrum_cmb(:TT, probgen(θ), jl, ls) +Cl(θ) = spectrum_cmb(:TT, probgen(prob0, θ), jl, ls) ``` We can now compute $Cₗ$ and the cosmic variance uncertainties ```math diff --git a/docs/src/plot.md b/docs/src/plot.md index 65b3f82f..abe94c56 100644 --- a/docs/src/plot.md +++ b/docs/src/plot.md @@ -98,7 +98,7 @@ function plot_interactive(prob::CosmologyProblem, xvar::SymBoltz.Num, yvar::SymB end probgen = parameter_updater(prob, [par for (par, _) in obspars]) function xyfunc(θ) - prob = probgen(θ) + prob = probgen(prob, θ) sol = solve(prob) τ = τs(sol) xs = sol(xvar, τ) @@ -135,7 +135,7 @@ obspars = [ ] probgen = parameter_updater(prob, [par for (par, _) in obspars]) function xyfunc(θ) - prob = probgen(θ) + prob = probgen(prob, θ) lgks = unique([-4:0.5:-3; -3:0.2:-2; -2:0.05:0]) # as few points as possible ks = 10 .^ lgks / u"Mpc" Ps = spectrum_matter(prob, ks; ptopts = (alg = SymBoltz.TRBDF2(), reltol = 1e-4, abstol = 1e-4)) diff --git a/docs/src/solve.md b/docs/src/solve.md index 29fa26a0..8ce967d1 100644 --- a/docs/src/solve.md +++ b/docs/src/solve.md @@ -29,7 +29,7 @@ To do so, use the function `parameter_updater` that returns a function that quic ```@example sol probmaker = parameter_updater(prob, [M.g.h, M.c.Ω₀]) # fast factory function -prob = probmaker([0.70, 0.27]) # create updated problem +prob = probmaker(prob, [0.70, 0.27]) # create updated problem ``` ```@docs diff --git a/src/solve.jl b/src/solve.jl index 8e4131cb..3e5739f6 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -289,25 +289,22 @@ function parameter_updater(prob::CosmologyProblem, idxs; kwargs...) ptinitdiffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(ptinit))[1])) end - function updater(p) - bgprobu0, bgprobp = bgprobsetsym(bgprob, p) - bginitu0, bginitp = bginitsetsym(bginit, p) - newbgprob = remake(bgprob; u0 = bgprobu0, p = bgprobp, kwargs...) - newbginit = remake(bginit; u0 = bginitu0, p = bginitp, kwargs...) + function updater(prob::CosmologyProblem, newp) + bgprob, bginit = prob.bg, prob.bginit + if !isnothing(bgprob) && !isnothing(bginit) + u0, p = bgprobsetsym(bgprob, newp); bgprob = remake(bgprob; u0, p, kwargs...) + u0, p = bginitsetsym(bginit, newp); bginit = remake(bginit; u0, p, kwargs...) + end + ptprob, ptinit = prob.pt, prob.ptinit if !isnothing(ptprob) && !isnothing(ptinit) - ptprobu0, ptprobp = ptprobsetsym(ptprob, p) - ptinitu0, ptinitp = ptinitsetsym(ptinit, p) - newptprob = remake(ptprob; u0 = ptprobu0, p = ptprobp, kwargs...) - newptinit = remake(ptinit; u0 = ptinitu0, p = ptinitp, kwargs...) - else - newptprob = nothing - newptinit = nothing + u0, p = ptprobsetsym(ptprob, newp); ptprob = remake(ptprob; u0, p, kwargs...) + u0, p = ptinitsetsym(ptinit, newp); ptinit = remake(ptinit; u0, p, kwargs...) end - return CosmologyProblem(prob.M, newbgprob, newptprob, newbginit, newptinit, prob.pars, prob.shoot, prob.conditions) + return CosmologyProblem(prob.M, bgprob, ptprob, bginit, ptinit, prob.pars, prob.shoot, prob.conditions) end - function updater(p::Dict) + function updater(prob::CosmologyProblem, p::Dict) p = [p[var] for var in idxs] return updater(p) end diff --git a/test/runtests.jl b/test/runtests.jl index 9aab3dc3..885cb7fc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -297,7 +297,7 @@ end probgen = parameter_updater(prob, diffpars) function logP(logθ) θ = exp.(logθ) - prob′ = probgen(θ) + prob′ = probgen(prob, θ) P = spectrum_matter(prob′, k) return log.(P / u"Mpc^3") end @@ -327,7 +327,7 @@ end probgen = parameter_updater(prob, diffpars) function logDlTT(logθ) θ = exp.(logθ) - prob′ = probgen(θ) + prob′ = probgen(prob, θ) DlTT = spectrum_cmb(:TT, prob′, jl; normalization = :Dl) return log.(DlTT) end @@ -359,7 +359,7 @@ end @test all(isnan.(getter(prob0))) probgen = parameter_updater(prob0, M.γ.T₀) - prob1 = probgen(2.73) + prob1 = probgen(prob0, 2.73) vals = getter(prob1) @test vals[1] == 2.73 @test isfinite(vals[2]) @@ -370,7 +370,7 @@ end @testset "Parameter updater and remake" begin probgen = parameter_updater(prob, [M.c.Ω₀]) - newprob = probgen([0.3]) + newprob = probgen(prob, [0.3]) @test newprob.bg.ps[M.c.Ω₀] == newprob.pt.ps[M.c.Ω₀] == 0.3 @test newprob.bg.ps[M.γ.Ω₀ + M.ν.Ω₀ + M.h.Ω₀ + M.b.Ω₀ + M.c.Ω₀ + M.Λ.Ω₀] == newprob.pt.ps[M.γ.Ω₀ + M.ν.Ω₀ + M.h.Ω₀ + M.b.Ω₀ + M.c.Ω₀ + M.Λ.Ω₀] ≈ 1.0 @@ -379,7 +379,7 @@ end @test all(map(SymBoltz.successful_retcode, sol.pts)) function Pk(Ωc0) - newprob = probgen([Ωc0]) + newprob = probgen(prob, [Ωc0]) return spectrum_matter(newprob, ks) end isnonzero(x) = isfinite(x) && !iszero(x) @@ -442,7 +442,7 @@ end diffpars = [M.g.h, M.c.Ω₀, M.b.Ω₀, M.γ.T₀, M.ν.Neff, M.h.m_eV, M.b.YHe, M.I.ln_As1e10, M.I.ns] probgen = parameter_updater(prob, diffpars) getτ0 = SymBoltz.getsym(prob, M.τ0) - τ0(θ) = getτ0(solve(probgen(θ))) + τ0(θ) = getτ0(solve(probgen(prob, θ))) θ0 = [pars[par] for par in diffpars] dτ0_ad = ForwardDiff.gradient(τ0, θ0) dτ0_fd = FiniteDiff.finite_difference_gradient(τ0, θ0) @@ -479,7 +479,7 @@ function stability(M::System, ks, vary::Dict, nsamples; verbose = false, kwargs. println("Solving for wavenumbers ", ks) end for sample in eachcol(samples) - prob = probgen(sample) + prob = probgen(prob0, sample) sol = solve(prob, ks; kwargs...) if issuccess(sol) nsuccess += 1 From d9803254decb50f67a45ca15c3cf716a451172a6 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 28 Apr 2026 20:55:57 +0200 Subject: [PATCH 07/11] Pass keyword arguments to the returned parameter updater function only --- src/solve.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 3e5739f6..bc90f307 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -265,12 +265,12 @@ end remake(prob::CosmologyProblem, pars::AbstractArray; kwargs...) = remake(prob, Dict(pars); kwargs...) """ - parameter_updater(prob::CosmologyProblem, idxs; kwargs...) + parameter_updater(prob::CosmologyProblem, idxs) Create and return a function that updates the symbolic parameters `idxs` of the cosmological problem `prob`. The returned function is called with numerical values (in the same order as `idxs`) and returns a new problem with the updated parameters. """ -function parameter_updater(prob::CosmologyProblem, idxs; kwargs...) +function parameter_updater(prob::CosmologyProblem, idxs) # define a closure based on https://docs.sciml.ai/ModelingToolkit/dev/examples/remake/#replace-and-remake bgprob = prob.bg ptprob = prob.pt @@ -289,7 +289,7 @@ function parameter_updater(prob::CosmologyProblem, idxs; kwargs...) ptinitdiffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(ptinit))[1])) end - function updater(prob::CosmologyProblem, newp) + function updater(prob::CosmologyProblem, newp; kwargs...) bgprob, bginit = prob.bg, prob.bginit if !isnothing(bgprob) && !isnothing(bginit) u0, p = bgprobsetsym(bgprob, newp); bgprob = remake(bgprob; u0, p, kwargs...) From a8427f8512418db13dc453ed4e3ef1f926eae238 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 28 Apr 2026 21:35:22 +0200 Subject: [PATCH 08/11] Handle shooting with parameters --- src/solve.jl | 48 +++++++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index bc90f307..a5efc5ca 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -157,7 +157,8 @@ function CosmologyProblem( end # Background initialization problem - bginit = InitializationProblem(bg, first(ivspan), parsk) + parsk_numeric = Dict(par => issymbolic(val) ? NaN : val for (par, val) in parsk) # TODO: prevent adding e.g. ΦBD ~ 1-1/(1+ωBD) as equation if it is a shooting guess + bginit = InitializationProblem(bg, first(ivspan), parsk_numeric; fully_determined) # Set up callback for today # TODO: specify callbacks symbolically? iv = ModelingToolkit.get_iv(M) @@ -417,19 +418,14 @@ function warning_failed_solution(sol::ODESolution, name = "ODE"; verbose = false end """ - solvebg(bgprob::ODEProblem[, vars, conditions]; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, shootopts = (alg = shootalg(), reltol = 1e-3), verbose = false, build_initializeprob = Val{false}, kwargs...) + solvebg(bgprob::ODEProblem[, vars, conditions]; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, shootopts = (alg = shootalg(), reltol = 1e-3), verbose = false, kwargs...) Solve the background cosmology problem `bgprob`. If the background requires shooting, `vars` is a dictionary with variables to shoot for and their initial guesses, and `conditions` is and an array of equations that should hold at the final integration time (usually today). """ -function solvebg(bgprob::ODEProblem, bginit::NonlinearProblem; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, verbose = false, kwargs...) +function solvebg(bgprob::ODEProblem; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, verbose = false, kwargs...) check_solve_args(bgprob, alg) - # 1) Get and set initial conditions - initsol = solve(bginit) - u0 = initsol[unknowns(bgprob.f.sys)] - bgprob = remake(bgprob; u0) - # 2) Solve ODE bgsol = solve(bgprob, alg; verbose, reltol, abstol, kwargs...) if !successful_retcode(bgsol) @@ -444,8 +440,14 @@ function solvebg(bgprob::ODEProblem, bginit::NonlinearProblem; alg = bgalg(bgpro return bgsol end +function solvebg(bgprob::ODEProblem, bginit::NonlinearProblem; kwargs...) + initsol = solve(bginit) + u0 = initsol[unknowns(bgprob.f.sys)] + bgprob = remake(bgprob; u0) + return solvebg(bgprob; kwargs...) +end # TODO: more generic shooting method that can do anything (e.g. S8) -function solvebg(bgprob::ODEProblem, vars, conditions; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, shootopts = (alg = shootalg(), reltol = 1e-3), verbose = false, build_initializeprob = Val{false}, kwargs...) +function solvebg(bgprob::ODEProblem, bginit::NonlinearProblem, vars, conditions; alg = bgalg(bgprob), reltol = 1e-7, abstol = 1e-7, shootopts = (alg = shootalg(), reltol = 1e-3), verbose = false, kwargs...) length(vars) == length(conditions) || error("Different number of shooting parameters and conditions") guess = collect(values(vars)) @@ -461,19 +463,23 @@ function solvebg(bgprob::ODEProblem, vars, conditions; alg = bgalg(bgprob), relt vars = only(vars) conditions = only(conditions) end - setvars = SymbolicIndexingInterface.setsym_oop(bgprob, vars) # efficient setter + vars = (var -> ModelingToolkit.is_parameter(bginit, var) ? var : Initial(var)).(vars) # turn time-dependent variables into Initial(...) (map fails for scalar case) + setinitvars = SymbolicIndexingInterface.setsym_oop(bginit, vars) # efficient setter + setprobvars = SymbolicIndexingInterface.setsym_oop(bgprob, vars) + getics = SymbolicIndexingInterface.getsym(bginit, unknowns(bgprob.f.sys)) getfuns = getsym(bgprob, conditions) # efficient getter - function f(vals, (oldbgprob, setvars, getfuns, build_initializeprob, verbose, varstrs, constrs)) - # slow but "safe" - #u0, p = SymBoltz.split_vars_pars(oldbgprob.f.sys, Dict(keys(vars) .=> vals)) - #newbgprob = remake(oldbgprob; u0, p) + function f(vals, (oldbgprob, oldbginit, setinitvars, setprobvars, getfuns, verbose, varstrs, constrs)) + u0, p = setinitvars(oldbginit, vals) + bginit = remake(oldbginit; u0, p) - # fast but "unsafe" - newu0, newp = setvars(oldbgprob, vals) - newbgprob = remake(oldbgprob; u0 = newu0, p = newp, build_initializeprob) + # TODO: use DiffCache by refactoring parameter_updater into one function for bg and another for pt + initsol = solve(bginit) + u0 = getics(initsol) + _, p = setprobvars(oldbgprob, vals) + bgprob = remake(oldbgprob; u0, p) - bgsol = solvebg(newbgprob; alg, reltol, abstol, kwargs..., save_everystep = false, save_start = false, save_end = true, verbose) + bgsol = solvebg(bgprob; alg, reltol, abstol, kwargs..., save_everystep = false, save_start = false, save_end = true, maxiters = 2000, verbose) !successful_retcode(bgsol) && error("Shooting failed when solving background with $(varvalstr(varstrs, vals)). Run with `verbose = true` for more output. Change the initial shooting guesses.") conditions = only(getfuns(bgsol)) # get final values verbose && !(eltype(vals) <: ForwardDiff.Dual) && println("Shooting: ", varvalstr(varstrs, vals), " -> ", varvalstr(constrs, conditions)) @@ -493,15 +499,15 @@ function solvebg(bgprob::ODEProblem, vars, conditions; alg = bgalg(bgprob), relt NonlinearProblemT = NonlinearProblem end end - prob = NonlinearProblemT(f, guess, (bgprob, setvars, getfuns, build_initializeprob, verbose, varstrs, constrs)) + prob = NonlinearProblemT(f, guess, (bgprob, bginit, setinitvars, setprobvars, getfuns, verbose, varstrs, constrs)) sol = solve(prob; shootopts...) if !successful_retcode(sol) error("Shooting failed to converge. Last result was $(varvalstr(varstrs, sol.u)). Run with `verbose = true` for more output. Change the initial shooting guesses.") end - u0, p = setvars(bgprob, sol.u) - bgprob = remake(bgprob; u0, p, build_initializeprob) + u0, p = setprobvars(bgprob, sol.u) + bgprob = remake(bgprob; u0, p) return solvebg(bgprob; alg, reltol, abstol, kwargs...) end From d38ccf37e53caa600b2fbfa215e3956a8286a49e Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Thu, 30 Apr 2026 21:25:40 +0200 Subject: [PATCH 09/11] Add setupbg function; fix benchmarks --- docs/src/benchmarks.md | 5 +++-- src/solve.jl | 10 +++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/src/benchmarks.md b/docs/src/benchmarks.md index c7d4b55d..addc32af 100644 --- a/docs/src/benchmarks.md +++ b/docs/src/benchmarks.md @@ -31,13 +31,14 @@ The points on each curve correspond to a sequence of tolerances. using DiffEqDevTools refalg = Rodas5P(linsolve = RFLUFactorization()) -bgsol = solve(prob.bg, refalg; abstol = 1e-12, reltol = 1e-12) # reference solution (results are similar compared to Rodas4/4P/5P/FBDF) +bgprob = SymBoltz.setupbg(prob.bg, prob.bginit) +bgsol = solve(bgprob, refalg; abstol = 1e-12, reltol = 1e-12) # reference solution (results are similar compared to Rodas4/4P/5P/FBDF) abstols = 1 ./ 10 .^ (7:11) reltols = 1 ./ 10 .^ (7:11) bgalgs = [Rodas4(), Rodas5(), Rodas4P(), Rodas5P(), Rodas6P(), FBDF(), QNDF()] # FBDF/QNDF unstable for some tolerances setups = [Dict(:alg => alg) for alg in bgalgs] -wp = WorkPrecisionSet(prob.bg, abstols, reltols, setups; appxsol = bgsol, save_everystep = false, error_estimate = :l2) +wp = WorkPrecisionSet(bgprob, abstols, reltols, setups; appxsol = bgsol, save_everystep = false, error_estimate = :l2) plot(wp; title = "Reference: $(SymBoltz.algname(refalg))", size = (800, 400), margin = 5*Plots.mm) ``` diff --git a/src/solve.jl b/src/solve.jl index a5efc5ca..da6a3df8 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -441,9 +441,7 @@ function solvebg(bgprob::ODEProblem; alg = bgalg(bgprob), reltol = 1e-7, abstol return bgsol end function solvebg(bgprob::ODEProblem, bginit::NonlinearProblem; kwargs...) - initsol = solve(bginit) - u0 = initsol[unknowns(bgprob.f.sys)] - bgprob = remake(bgprob; u0) + bgprob = setupbg(bgprob, bginit) return solvebg(bgprob; kwargs...) end # TODO: more generic shooting method that can do anything (e.g. S8) @@ -511,6 +509,12 @@ function solvebg(bgprob::ODEProblem, bginit::NonlinearProblem, vars, conditions; return solvebg(bgprob; alg, reltol, abstol, kwargs...) end +function setupbg(bgprob::ODEProblem, bginit::NonlinearProblem) + initsol = solve(bginit) + u0 = initsol[unknowns(bgprob.f.sys)] + return remake(bgprob; u0) +end + function setuppt(ptprob::ODEProblem, ptinit::NonlinearProblem, bgsol::ODESolution, ptivini::Function) tspanbg = (bgsol.t[begin], bgsol.t[end]) From d701755b3d5269f0a1c86185a0c9fd3c9dcb2793 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Fri, 1 May 2026 11:50:17 +0200 Subject: [PATCH 10/11] Fix distance_luminosity_function --- src/observables/distances.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/observables/distances.jl b/src/observables/distances.jl index 3d455413..ae290525 100644 --- a/src/observables/distances.jl +++ b/src/observables/distances.jl @@ -25,16 +25,15 @@ function distance_luminosity_function(M::System, pars_fixed, pars_varying, zs; b pars = merge(pars_fixed, Dict(pars_varying .=> NaN)) as = @. 1 / (zs + 1) prob = CosmologyProblem(M, pars; pt = false, ivspan = (minimum(as), 1.0)) - probgen = parameter_updater(prob, pars_varying; build_initializeprob = Val{false}) + probgen = parameter_updater(prob, pars_varying) geta = getsym(prob, M.g.a) getτ = getsym(prob, M.τ) geth = getsym(prob, M.g.h) getΩk0 = getsym(prob, M.K.Ω₀) - return p -> begin - prob = probgen(p) - sol = solve(prob; bgopts, saveat = as, save_end = true) + return (p) -> begin + sol = solve(probgen(prob, p); bgopts, saveat = as, save_end = true) a = geta(sol) τ = getτ(sol) h = geth(sol) From 49be0baaf35fa22da46ecb2a93b8bafa303764f9 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Fri, 1 May 2026 15:22:28 +0200 Subject: [PATCH 11/11] Fix plot_interactive --- docs/src/plot.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/src/plot.md b/docs/src/plot.md index abe94c56..54ba47b0 100644 --- a/docs/src/plot.md +++ b/docs/src/plot.md @@ -98,8 +98,7 @@ function plot_interactive(prob::CosmologyProblem, xvar::SymBoltz.Num, yvar::SymB end probgen = parameter_updater(prob, [par for (par, _) in obspars]) function xyfunc(θ) - prob = probgen(prob, θ) - sol = solve(prob) + sol = solve(probgen(prob, θ)) τ = τs(sol) xs = sol(xvar, τ) ys = sol(yvar, τ) @@ -135,10 +134,9 @@ obspars = [ ] probgen = parameter_updater(prob, [par for (par, _) in obspars]) function xyfunc(θ) - prob = probgen(prob, θ) lgks = unique([-4:0.5:-3; -3:0.2:-2; -2:0.05:0]) # as few points as possible ks = 10 .^ lgks / u"Mpc" - Ps = spectrum_matter(prob, ks; ptopts = (alg = SymBoltz.TRBDF2(), reltol = 1e-4, abstol = 1e-4)) + Ps = spectrum_matter(probgen(prob, θ), ks; ptopts = (alg = SymBoltz.TRBDF2(), reltol = 1e-4, abstol = 1e-4)) lgPs = log10.(Ps/u"Mpc^3") # smoothen with spline and sample more densely