Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING] Multi-threading of CommonSolve functions #17

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9a027c0
initial revamp specialization and add executors
lxvm Jan 2, 2025
077c5a7
Merge branch 'main' into thread
lxvm Jan 3, 2025
0463610
add CommonSolveFourier executor
lxvm Jan 3, 2025
4e7f77a
add executor docs and missing variable
lxvm Jan 3, 2025
5ef5d2f
refactor of nestedquad with reversed initialization
lxvm Jan 16, 2025
81e17d4
remove unused keywords
lxvm Jan 17, 2025
05ea9db
composable refactor
lxvm Jan 30, 2025
6f67698
Breaking: replace (up!, post) with solve!
lxvm Feb 11, 2025
6a82beb
multithreading fix
lxvm Feb 11, 2025
031f049
pass interface tests
lxvm Feb 13, 2025
f1cde42
some fourier fixes, not yet autosymptr
lxvm Feb 14, 2025
b036773
bump polyhedra compat
lxvm Feb 14, 2025
d9a2ab4
quick fix for autosymptr
lxvm Feb 15, 2025
ffe4c9c
refactor algorithms to init integrand consistently
lxvm Feb 16, 2025
33eb572
update auxquadgk to refactor
lxvm Feb 16, 2025
b1a975a
refactor quadgk_integrand and fix fourier
lxvm Feb 16, 2025
6137eb1
remove nthreads keyword from autosymptr algorithms
lxvm Feb 16, 2025
1baa966
fix how tolerances are passed from AutoBZProblem to IntegralProblem a…
lxvm Feb 16, 2025
8653bd4
parallelize ptr no syms
lxvm Feb 17, 2025
4508bd9
restore previous symptr fourier evaluation
lxvm Feb 17, 2025
96b2152
ci version fix
lxvm Feb 17, 2025
1f494b6
docs updates
lxvm Feb 17, 2025
bea0b8a
dos fixes
lxvm Feb 17, 2025
3ea50fc
update version in docs CI
lxvm Feb 17, 2025
1a44480
fix in fourier nested quad
lxvm Feb 17, 2025
98fe48c
important autobzproblem fixes
lxvm Feb 17, 2025
30ed98b
fix unitful ptr bug
lxvm Feb 18, 2025
5073026
consolidate autoptr
lxvm Mar 12, 2025
d6e664d
create logger alg
lxvm Mar 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
some fourier fixes, not yet autosymptr
lxvm committed Feb 14, 2025
commit f1cde42940e1aa8810361921ca381f8f0f45ec9b
255 changes: 102 additions & 153 deletions src/fourier.jl
Original file line number Diff line number Diff line change
@@ -40,10 +40,10 @@ struct FourierIntegralFunction{F,S,P} <: AbstractFourierIntegralFunction
end
FourierIntegralFunction(f, s, p=nothing; alias=false) = FourierIntegralFunction(f, s, p, alias)

function get_prototype(f::FourierIntegralFunction, x, ws, p)
function _get_prototype(f::FourierIntegralFunction, x, ws, p)
f.prototype === nothing ? f.f(x, ws(x), p) : f.prototype
end
get_prototype(f::FourierIntegralFunction, x, p) = get_prototype(f, x, f.s, p)
get_prototype(f::FourierIntegralFunction, x, p) = _get_prototype(f, x, f.s, p)

function get_fourierworkspace(f::AbstractFourierIntegralFunction)
f.s isa FourierWorkspace ? f.s : FourierSeriesEvaluators.workspace_allocate(f.alias ? f.s : deepcopy(f.s), FourierSeriesEvaluators.period(f.s))
@@ -79,15 +79,15 @@ end
function do_solve!(solver, f::CommonSolveFourierIntegralFunction, x, s, p)
return f.solve!(solver, x, s, p)
end
function get_prototype(f::CommonSolveFourierIntegralFunction, x, ws, p)
function _get_prototype(f::CommonSolveFourierIntegralFunction, x, ws, p, _solver=nothing)
if isnothing(f.prototype)
solver = init(f.prob, f.alg; f.kwargs...)
solver = isnothing(_solver) ? init(f.prob, f.alg; f.kwargs...) : _solver
do_solve!(solver, f, x, ws(x), p)
else
f.prototype
end
end
get_prototype(f::CommonSolveFourierIntegralFunction, x, p) = get_prototype(f, x, f.s, p)
get_prototype(f::CommonSolveFourierIntegralFunction, x, p, _solver=nothing) = _get_prototype(f, x, f.s, p, _solver)

function init_specialized_fourierintegrand(solver, f, dom, p; x=get_prototype(dom), ws=f.s, s = ws(x), prototype=f.prototype)
proto = prototype === nothing ? do_solve!(solver, f, x, s, p) : prototype
@@ -109,32 +109,32 @@ end

# TODO implement CommonSolveFourierInplaceIntegrand CommonSolveFourierInplaceBatchIntegrand

# similar to workspace_allocate, but more type-stable because of loop unrolling and vector types
function workspace_allocate_vec(s::AbstractFourierSeries{N}, x::NTuple{N,Any}, len::NTuple{N,Integer}=ntuple(one,Val(N))) where {N}
# Only the top-level workspace has an AbstractFourierSeries in the series field
# In the lower level workspaces the series field has a cache that can be contract!-ed
# into a series
dim = Val(N)
if N == 1
c = FourierSeriesEvaluators.allocate(s, x[N], dim)
ws = Vector{typeof(c)}(undef, len[N])
ws[1] = c
for n in 2:len[N]
ws[n] = FourierSeriesEvaluators.allocate(s, x[N], dim)
end
else
c = FourierSeriesEvaluators.allocate(s, x[N], dim)
t = FourierSeriesEvaluators.contract!(c, s, x[N], dim)
c_ = FourierWorkspace(c, FourierSeriesEvaluators.workspace_allocate(t, x[1:N-1], len[1:N-1]).cache)
ws = Vector{typeof(c_)}(undef, len[N])
ws[1] = c_
for n in 2:len[N]
_c = FourierSeriesEvaluators.allocate(s, x[N], dim)
_t = FourierSeriesEvaluators.contract!(_c, s, x[N], dim)
ws[n] = FourierWorkspace(_c, FourierSeriesEvaluators.workspace_allocate(_t, x[1:N-1], len[1:N-1]).cache)
end
function fourier_to_standard(func::FourierIntegralFunction, x0, p)
(; f, s, prototype, alias) = func
prob = FourierEvaluationProblem(s, x0, alias)
alg = FourierEvaluationAlgorithm()
_solve! = (solver, x, p) -> begin
solver.x = x
sol = solve!(solver)
return f(x, sol, p)
end
CommonSolveIntegralFunction(_solve!, prob, alg, prototype)
end
function fourier_to_standard(func::CommonSolveFourierIntegralFunction, x0, p)
(; prob, alg, s, kwargs, prototype, specialize, executor, alias) = func
fourierprob = FourierEvaluationProblem(s, x0, alias)
fourieralg = FourierEvaluationAlgorithm()
fullprob = ComposedCommonSolveProblem((; x=x0, p=p), fourierprob, prob) do (; x, p), fouriersolver, probsolver
fouriersolver.x = x
fx = solve!(fouriersolver)
return func.solve!(probsolver, x, fx, p)
end
fullalg = ComposedCommonSolveAlgorithm(fourieralg, alg)
_solve! = (solver, x, p) -> begin
solver.input = (; x, p)
return solve!(solver)
end
return FourierWorkspace(s, ws)
CommonSolveIntegralFunction(_solve!, fullprob, fullalg, prototype, specialize, executor; kwargs...)
end

struct FourierValue{X,S}
@@ -144,70 +144,48 @@ end
@inline AutoSymPTR.mymul(w, x::FourierValue) = FourierValue(AutoSymPTR.mymul(w, x.x), x.s)
@inline AutoSymPTR.mymul(::AutoSymPTR.One, x::FourierValue) = x

function init_cacheval(f::FourierIntegralFunction, dom, p, alg::QuadGKJL; kws...)
segs = PuncturedInterval(dom)
ws = get_fourierworkspace(f)
prototype = get_prototype(f, get_prototype(segs), ws, p)
return init_segbuf(prototype, segs, alg), ws
end
function init_cacheval(f::CommonSolveFourierIntegralFunction, dom, p, alg::QuadGKJL; kws...)
segs = PuncturedInterval(dom)
x = get_prototype(segs)
ws = get_fourierworkspace(f)
cache, integrand, prototype = _init_commonsolvefourierfunction(f, dom, p; x, ws)
return init_segbuf(prototype, segs, alg), ws, cache, integrand
end
function call_quadgk(f::FourierIntegralFunction, p, u, usegs, cacheval; kws...)
segbuf, ws = cacheval
quadgk(x -> (ux =
u*x; f.f(ux, ws(ux), p)), usegs...; kws..., segbuf)
function init_cacheval(f::AbstractFourierIntegralFunction, dom, p, alg::QuadGKJL; kws...)
g = fourier_to_standard(f, get_prototype(dom), p)
return g, init_cacheval(g, dom, p, alg; kws...)
end
function call_quadgk(f::CommonSolveFourierIntegralFunction, p, u, usegs, cacheval; kws...)
segbuf, ws, _, integrand = cacheval
quadgk(x -> (ux = u*x; integrand(ux, ws(ux), p)), usegs...; kws..., segbuf)
function call_quadgk(f::AbstractFourierIntegralFunction, p, u, usegs, (g, cacheval); kws...)
return call_quadgk(g, p, u, usegs, cacheval; kws...)
end

function init_cacheval(f::FourierIntegralFunction, dom, p, ::HCubatureJL; kws...)
# TODO utilize hcubature_buffer
ws = get_fourierworkspace(f)
return ws
end
init_cacheval(f::CommonSolveFourierIntegralFunction, dom, p, alg::HCubatureJL; kws...) = init_cacheval_cs(f.executor, f, dom, p, alg; kws...)
function init_cacheval_cs(::SerialExecutor, f::CommonSolveFourierIntegralFunction, dom, p, alg::HCubatureJL; kws...)
ws = get_fourierworkspace(f)
solver, integrand, = init_commonsolvefunction(f, dom, p)
return (; ws, solver, integrand)
end
function init_cacheval_cs(exec::ThreadedExecutor, f::CommonSolveFourierIntegralFunction, dom, p, alg::HCubatureJL; kws...)
throw(ArgumentError("HCubatureJL does not support threaded execution because it does not support batched integrands"))
end
function hcubature_integrand(f::FourierIntegralFunction, p, a, b, ws)
x -> f.f(x, ws(x), p)
function init_cacheval(f::AbstractFourierIntegralFunction, dom, p, alg::HCubatureJL; kws...)
g = fourier_to_standard(f, get_prototype(dom), p)
return g, init_cacheval(g, dom, p, alg; kws...)
end
function hcubature_integrand(f::CommonSolveFourierIntegralFunction, p, a, b, cacheval)
integrand = cacheval.integrand
ws = cacheval.ws
x -> integrand(x, ws(x), p)
function hcubature_integrand(f::AbstractFourierIntegralFunction, p, a, b, (g, cacheval))
return hcubature_integrand(g, p, a, b, cacheval)
end

function init_autosymptr_cache(f::FourierIntegralFunction, dom, p, bufsize; kws...)
ws = get_fourierworkspace(f)
return (; buffer=nothing, ws)
end
function init_autosymptr_cache(f::CommonSolveFourierIntegralFunction, dom, p, bufsize; kws...)
ws = get_fourierworkspace(f)
cache, integrand, = _init_commonsolvefourierfunction(f, dom, p; ws)
return (; buffer=nothing, ws, cache, integrand)
end
function autosymptr_integrand(f::FourierIntegralFunction, p, segs, cacheval)
ws = cacheval.ws
x -> x isa FourierValue ? f.f(x.x, x.s, p) : f.f(x, ws(x), p)

function init_autosymptr_cache(f::AbstractFourierIntegralFunction, dom, p, bufsize; kws...)
g = fourier_to_standard(f, get_prototype(dom), p)
return (; g, init_autosymptr_cache(g, dom, p, bufsize; kws...)...)
end
function autosymptr_integrand(f::CommonSolveFourierIntegralFunction, p, segs, cacheval)
integrand = cacheval.integrand
ws = cacheval.ws
return x -> x isa FourierValue ? integrand(x.x, x.s, p) : integrand(x, ws(x), p)
# function init_autosymptr_cache(f::FourierIntegralFunction, dom, p, bufsize; kws...)
# ws = get_fourierworkspace(f)
# return (; buffer=nothing, ws)
# end
# function init_autosymptr_cache(f::CommonSolveFourierIntegralFunction, dom, p, bufsize; kws...)
# ws = get_fourierworkspace(f)
# cache, integrand, = _init_commonsolvefourierfunction(f, dom, p; ws)
# return (; buffer=nothing, ws, cache, integrand)
# end
function autosymptr_integrand(f::AbstractFourierIntegralFunction, p, segs, cacheval)
return autosymptr_integrand(cacheval.g, p, segs, cacheval)
end
# function autosymptr_integrand(f::FourierIntegralFunction, p, segs, cacheval)
# ws = cacheval.ws
# x -> x isa FourierValue ? f.f(x.x, x.s, p) : f.f(x, ws(x), p)
# end
# function autosymptr_integrand(f::CommonSolveFourierIntegralFunction, p, segs, cacheval)
# integrand = cacheval.integrand
# ws = cacheval.ws
# return x -> x isa FourierValue ? integrand(x.x, x.s, p) : integrand(x, ws(x), p)
# end


function init_cacheval(f::AbstractFourierIntegralFunction, dom, p, alg::AuxQuadGKJL; kws...)
@@ -454,6 +432,15 @@ function init_cacheval(f::AbstractFourierIntegralFunction, dom, p, alg::AutoSymP
end


function insert_counter(f::FourierIntegralFunction, x, p, channel)
prob = CounterProblem(; channel)
alg = SingleCount()
proto = get_prototype(f, x, p)
CommonSolveFourierIntegralFunction(prob, alg, f.s, proto) do solver, x, s, p
step!(solver)
return f.f(x, s, p)
end
end
function insert_counter(f::CommonSolveFourierIntegralFunction, x, p, channel)
input = (; x, p, s=f.s(x))
prob = ComposedCommonSolveProblem(input, CounterProblem(; channel), f.prob) do (; x, s, p), countersolver, probsolver
@@ -494,14 +481,6 @@ function solve!(solver::FourierEliminationSolver)
return FourierSeriesEvaluators.contract!(solver.cacheval, solver.f, solver.x, solver.dim)
end

function nested_innerintegralfunction(f::FourierIntegralFunction, x0, p)
proto = get_prototype(f, x0, p)
func = IntegralFunction(proto) do x, (; p, state, fouriercache, fourierseries)
f.f(SVector(promote(x, state...)), FourierSeriesEvaluators.evaluate!(fouriercache, fourierseries, x), p)
end
return func
end

struct FourierEvaluationProblem{F<:AbstractFourierSeries,X,K}
f::F
x::X
@@ -519,68 +498,38 @@ mutable struct FourierEvaluationSolver{F,X,A,C,K}
end
function init(prob::FourierEvaluationProblem, alg::FourierEvaluationAlgorithm; kws...)
kwargs = (; prob.kwargs..., kws...)
cacheval = FourierSeriesEvaluators.allocate(prob.f, prob.x, Val(1))#prob.dim)
return FourierEvaluationSolver(prob.f, prob.x, alg, cacheval, kwargs)
end
function solve!(solver::FourierEvaluationSolver)
return FourierSeriesEvaluators.evaluate!(solver.cacheval, solver.f, solver.x)
end

function fourier_to_standard(func::FourierIntegralFunction, x0, p)
(; f, s, prototype, alias) = func
prob = FourierEvaluationProblem(s, x0, alias)
alg = FourierEvaluationAlgorithm()
_solve! = (solver, x, p) -> begin
solver.x = x
sol = solve!(solver)
sol = FourierSeriesEvaluators.evaluate!(solver.cacheval, solver.f, x)
return f(x, sol, p)
end
CommonSolveIntegralFunction(_solve!, prob, alg, prototype)
cacheval = init_fourierevalcache(prob.f |> (prob.alias ? identity : deepcopy), Tuple(prob.x))
return FourierEvaluationSolver(prob.f, prob.x, alg, cacheval, kwargs)
end
function fourier_to_standard(func::CommonSolveFourierIntegralFunction, x0, p)
(; prob, alg, s, kwargs, prototype, specialize, executor, alias) = func
fourierprob = FourierEvaluationProblem(s, x0, alias)
fourieralg = FourierEvaluationAlgorithm()
fullprob = ComposedCommonSolveProblem((; x=x0, p=p), fourierprob, prob) do ((; x, p), fouriersolver, probsolver)
fouriersolver.x = x
fx = solve!(fouriersolver)
return func.solve!(probsolver, x, fx, p)
end
fullalg = ComposedCommonSolveAlgorithm(fourieralg, alg)
_solve! = (solver, x, p) -> begin
solver.input = (; x, p)
return solve!(solver)
function init_fourierevalcache(f::AbstractFourierSeries, x::Tuple)
nd = ndims(f)
vd = Val(nd)
xd = x[nd]
if nd == 1
return (FourierSeriesEvaluators.allocate(f, xd, vd),)
else
solver = init(FourierEliminationProblem(f, xd, vd), FourierEliminationAlgorithm())
return (init_fourierevalcache(solve!(solver), x[1:nd-1])..., solver)
end
CommonSolveIntegralFunction(_solve!, fullprob, fullalg, prototype, specialize, executor; kwargs...)
end

struct ComposedCommonSolveProblem{P,S,I,K}
problems::P
solve!::S
input::I
kwargs::K
ComposedCommonSolveProblem(solve!, input, probs...; kws...) = new{typeof(probs),typeof(solve!),typeof(input),typeof(kws)}(probs, solve!, input, kws)
end

struct ComposedCommonSolveAlgorithm{A}
algorithms::A
ComposedCommonSolveAlgorithm(algs...) = new{typeof(algs)}(algs)
end

mutable struct ComposedCommonSolveSolver{S,SS,I,K}
solvers::S
solve!::SS
input::I
kwargs::K
end
function init(prob::ComposedCommonSolveProblem, alg::ComposedCommonSolveAlgorithm; kws...)
kwargs = (; prob.kwargs..., kws...)
solvers = map(init, prob.problems, alg.algorithms)
return ComposedCommonSolveSolver(solvers, prob.solve!, prob.input, kwargs)
end
function solve!(solver::ComposedCommonSolveSolver)
return solver.solve!(solver.input, solver.solvers...; solver.kwargs...)
function solve!(solver::FourierEvaluationSolver)
return solve_fourierevalcache!(solver.cacheval, solver.f, Tuple(solver.x))
end
function solve_fourierevalcache!(cacheval, f::AbstractFourierSeries, x::Tuple)
nd = ndims(f)
vd = Val(nd)
xd = x[nd]
cache = cacheval[nd]
if nd == 1
return FourierSeriesEvaluators.evaluate!(cache, f, xd)
else
cache.f = f
cache.x = xd
return solve_fourierevalcache!(cacheval[1:nd-1], solve!(cache), x[1:nd-1])
end
end

function init_cacheval(f::AbstractFourierIntegralFunction, dom, p, alg::NestedQuad; kws...)
@@ -671,12 +620,12 @@ function nested_innerfourierintegralfunction(f::FourierIntegralFunction, x0, ser
alg = FourierEvaluationAlgorithm()

_f = f.f
func = CommonSolveIntegralFunction(prob, alg, proto) do solver, x, p
func = CommonSolveIntegralFunction(prob, alg, proto) do solver, x, (; series, p, state)
# solver.x = x
# solver.f = p.series
# sol = solve!(solver)
## using out-of-place semantics can be faster
sol = FourierSeriesEvaluators.evaluate!(solver.cacheval, p.series, x)
sol = solve_fourierevalcache!(solver.cacheval, series, Tuple(x))
return _f(SVector(promote(x, state...)), sol, p)
end
return func
@@ -710,4 +659,4 @@ end
function nested_innerfourierintegralfunction_cs(exec::ThreadedExecutor, f, x0, series, x1, p)
_f = nested_innerfourierintegralfunction_cs(SerialExecutor(), f, x0, series, x1, p)
return nested_integralfunction_cs(exec, _f, x1, p)
end
end
32 changes: 31 additions & 1 deletion src/interfaces.jl
Original file line number Diff line number Diff line change
@@ -158,8 +158,9 @@ end
Base.@nospecializeinfer function do_solve_nsp!(@nospecialize(solver), f::CommonSolveIntegralFunction, x, p)
return do_solve!(solver, f, x, p)
end
function get_prototype(f::CommonSolveIntegralFunction, x, p, solver=init(f.prob, f.alg; f.kwargs...))
function get_prototype(f::CommonSolveIntegralFunction, x, p, _solver=nothing)
if isnothing(f.prototype)
solver = isnothing(_solver) ? init(f.prob, f.alg; f.kwargs...) : _solver
do_solve!(solver, f, x, p)
else
f.prototype
@@ -332,3 +333,32 @@ struct IntegralSolution{T,S}
retcode::ReturnCode
stats::S
end


struct ComposedCommonSolveProblem{P,S,I,K}
problems::P
solve!::S
input::I
kwargs::K
ComposedCommonSolveProblem(solve!, input, probs...; kws...) = new{typeof(probs),typeof(solve!),typeof(input),typeof(kws)}(probs, solve!, input, kws)
end

struct ComposedCommonSolveAlgorithm{A}
algorithms::A
ComposedCommonSolveAlgorithm(algs...) = new{typeof(algs)}(algs)
end

mutable struct ComposedCommonSolveSolver{S,SS,I,K}
solvers::S
solve!::SS
input::I
kwargs::K
end
function init(prob::ComposedCommonSolveProblem, alg::ComposedCommonSolveAlgorithm; kws...)
kwargs = (; prob.kwargs..., kws...)
solvers = map(init, prob.problems, alg.algorithms)
return ComposedCommonSolveSolver(solvers, prob.solve!, prob.input, kwargs)
end
function solve!(solver::ComposedCommonSolveSolver)
return solver.solve!(solver.input, solver.solvers...; solver.kwargs...)
end
64 changes: 8 additions & 56 deletions test/fourier.jl
Original file line number Diff line number Diff line change
@@ -26,13 +26,16 @@ using AutoBZCore: PuncturedInterval, HyperCube, segments, endpoints
b = 1
p = 0.0
t = 1.0
update! = (cache, x, s, p) -> cache.p = (x, s, p)
postsolve = (sol, x, s, p) -> sol.value
_solve! = (solver, x, s, p) -> begin
solver.p = (x, s, p)
sol = solve!(solver)
return sol.value
end
s = FourierSeries([1, 0, 1]/2; period=t, offset=-2)
f = (x, (y, s, p)) -> x * s + p + y
subprob = IntegralProblem(f, (a, b), ((a+b)/2, s((a+b)/2), 1.0))
abstol = 1e-5
prob = IntegralProblem(CommonSolveFourierIntegralFunction(subprob, QuadGKJL(), update!, postsolve, s), (a, b), p; abstol)
prob = IntegralProblem(CommonSolveFourierIntegralFunction(_solve!, subprob, QuadGKJL(), s), (a, b), p; abstol)
for alg in (QuadGKJL(), HCubatureJL(), QuadratureFunction(), AuxQuadGKJL())
cache = init(prob, alg)
for p in [3.0, 4.0]
@@ -45,7 +48,7 @@ using AutoBZCore: PuncturedInterval, HyperCube, segments, endpoints
f = (x, (y, s, p)) -> x * s + p
subprob = IntegralProblem(f, (a, b), ([(a+b)/2], s((a+b)/2), 1.0))
abstol = 1e-5
prob = IntegralProblem(CommonSolveFourierIntegralFunction(subprob, QuadGKJL(), update!, postsolve, s), AutoBZCore.Basis(t*I(1)), p; abstol)
prob = IntegralProblem(CommonSolveFourierIntegralFunction(_solve!, subprob, QuadGKJL(), s), AutoBZCore.Basis(t*I(1)), p; abstol)
for alg in (MonkhorstPack(), AutoSymPTRJL(),)
cache = init(prob, alg)
for p in [3.0, 4.0]
@@ -104,7 +107,7 @@ using AutoBZCore: PuncturedInterval, HyperCube, segments, endpoints
# EvalCounter
@testset "evalcounter" for prob in (
IntegralProblem(FourierIntegralFunction((x, s, p) -> x * s + p, FourierSeries([1, 0, 1]/2; period=1.0, offset=-2)), (0.0, 1.0), 0.0; abstol=1e-3),
IntegralProblem(CommonSolveFourierIntegralFunction(IntegralProblem((x, (y, s, p)) -> x * s + p + y, (0.0, 1.0), (0.5, 0.0, 1.0)), QuadGKJL(), (cache, x, s, p) -> (cache.p = (x, s, p)), (sol, x, s, p) -> sol.value, FourierSeries([1, 0, 1]/2; period=1.0, offset=-2)), (0.0, 1.0), 1.0),
IntegralProblem(CommonSolveFourierIntegralFunction((solver, x, s, p) -> (solver.p = (x, s, p); solve!(solver).value), IntegralProblem((x, (y, s, p)) -> x * s + p + y, (0.0, 1.0), (0.5, 0.0, 1.0)), QuadGKJL(), FourierSeries([1, 0, 1]/2; period=1.0, offset=-2)), (0.0, 1.0), 1.0),
)
# constant integrand should always use the same number of evaluations as the
# base quadrature rule
@@ -117,54 +120,3 @@ using AutoBZCore: PuncturedInterval, HyperCube, segments, endpoints
end
end
end

#=
@testset "FourierIntegrand" begin
for dims in 1:3
s = FourierSeries(integer_lattice(dims), period=1)
# AutoBZ interface user function: f(x, args...; kwargs...) where args & kwargs
# stored in MixedParameters
# a FourierIntegrand should expect a FourierValue in the first argument
# a FourierIntegrand is just a wrapper around an integrand
f(x::FourierValue, a; b) = a*x.s*x.x .+ b
# IntegralSolver will accept args & kwargs for a FourierIntegrand
prob = IntegralProblem(FourierIntegrand(f, s, 1.3, b=4.2), zeros(dims), ones(dims))
u = IntegralSolver(prob, HCubatureJL())()
v = IntegralSolver(FourierIntegrand(f, s), zeros(dims), ones(dims), HCubatureJL())(1.3, b=4.2)
w = IntegralSolver(FourierIntegrand(f, s, b=4.2), zeros(dims), ones(dims), HCubatureJL())(1.3)
@test u == v == w
# tests for the nested integrand
nouter = 3
ws = FourierSeriesEvaluators.workspace_allocate(s, FourierSeriesEvaluators.period(s), ntuple(n -> n == dims ? nouter : 1,dims))
p = ParameterIntegrand(f, 1.3, b=4.2)
nest = NestedBatchIntegrand(ntuple(n -> deepcopy(p), nouter), SVector{dims,ComplexF64})
for (alg, dom) in (
(HCubatureJL(), HyperCube(zeros(dims), ones(dims))),
(NestedQuad(AuxQuadGKJL()), CubicLimits(zeros(dims), ones(dims))),
(MonkhorstPack(), Basis(one(SMatrix{dims,dims}))),
)
prob1 = IntegralProblem(FourierIntegrand(p, s), dom)
prob2 = IntegralProblem(FourierIntegrand(p, ws, nest), dom)
@test solve(prob1, alg).u ≈ solve(prob2, alg).u
end
end
end
@testset "algorithms" begin
f(x::FourierValue, a; b) = a*x.s+b
for dims in 1:3
vol = (2pi)^dims
A = I(dims)
s = FourierSeries(integer_lattice(dims), period=1)
for bz in (load_bz(FBZ(), A), load_bz(InversionSymIBZ(), A))
integrand = FourierIntegrand(f, s, 1.3, b=1.0)
prob = IntegralProblem(integrand, bz)
for alg in (IAI(), PTR(), AutoPTR(), TAI()), counter in (false, true)
new_alg = counter ? EvalCounter(alg) : alg
solver = IntegralSolver(prob, new_alg, reltol=0, abstol=1e-6)
@test solver() ≈ vol atol=1e-6
end
end
end
end
=#