Skip to content

Commit

Permalink
Merge pull request #764 from SciML/fgfghoptf
Browse files Browse the repository at this point in the history
Add `fg` and `fgh` fields to `OptimizationFunction`
  • Loading branch information
Vaibhavdixit02 authored Aug 14, 2024
2 parents a14ebec + 99546fc commit eb3758c
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.48.1"
version = "2.49.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,6 @@ Returns a copy of the interpolation stripped of its function, to accommodate ser
If the interpolation object has no function, returns the interpolation object as is.
"""
strip_interpolation(id::AbstractDiffEqInterpolation) = id
strip_interpolation(id::HermiteInterpolation) = id
strip_interpolation(id::HermiteInterpolation) = id
strip_interpolation(id::LinearInterpolation) = id
strip_interpolation(id::ConstantInterpolation) = id
8 changes: 6 additions & 2 deletions src/problems/optimization_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ struct OptimizationProblem{iip, F, uType, P, LB, UB, I, LC, UC, S, K} <:
ucons::UC
sense::S
kwargs::K
@add_kwonly function OptimizationProblem{iip}(f::Union{OptimizationFunction{iip}, MultiObjectiveOptimizationFunction{iip}}, u0,
@add_kwonly function OptimizationProblem{iip}(
f::Union{OptimizationFunction{iip}, MultiObjectiveOptimizationFunction{iip}},
u0,
p = NullParameters();
lb = nothing, ub = nothing, int = nothing,
lcons = nothing, ucons = nothing,
Expand All @@ -119,7 +121,9 @@ struct OptimizationProblem{iip, F, uType, P, LB, UB, I, LC, UC, S, K} <:
end
end

function OptimizationProblem(f::Union{OptimizationFunction, MultiObjectiveOptimizationFunction}, args...; kwargs...)
function OptimizationProblem(
f::Union{OptimizationFunction, MultiObjectiveOptimizationFunction},
args...; kwargs...)
OptimizationProblem{isinplace(f)}(f, args...; kwargs...)
end
function OptimizationProblem(f, args...; kwargs...)
Expand Down
28 changes: 18 additions & 10 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1897,13 +1897,16 @@ For more details on this argument, see the ODEFunction documentation.
The fields of the OptimizationFunction type directly match the names of the inputs.
"""
struct OptimizationFunction{iip, AD, F, G, H, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
struct OptimizationFunction{
iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV} <:
AbstractOptimizationFunction{iip}
f::F
adtype::AD
grad::G
fg::FG
hess::H
fgh::FGH
hv::HV
cons::C
cons_j::CJ
Expand All @@ -1929,13 +1932,14 @@ end
$(TYPEDEF)
"""

struct MultiObjectiveOptimizationFunction{iip, AD, F, J, H, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
struct MultiObjectiveOptimizationFunction{
iip, AD, F, J, H, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV} <:
AbstractOptimizationFunction{iip}
f::F
adtype::AD
jac::J # Replacing grad with jac for the Jacobian
hess::Vector{H} # Hess will be a vector of type H
jac::J
hess::H
hv::HV
cons::C
cons_j::CJ
Expand Down Expand Up @@ -3809,7 +3813,7 @@ struct NoAD <: AbstractADType end
OptimizationFunction(args...; kwargs...) = OptimizationFunction{true}(args...; kwargs...)

function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
grad = nothing, hess = nothing, hv = nothing,
grad = nothing, fg = nothing, hess = nothing, hv = nothing, fgh = nothing,
cons = nothing, cons_j = nothing, cons_jvp = nothing,
cons_vjp = nothing, cons_h = nothing,
hess_prototype = nothing,
Expand All @@ -3831,8 +3835,9 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
lag_hess_colorvec = nothing) where {iip}
isinplace(f, 2; has_two_dispatches = false, isoptimization = true)
sys = sys_or_symbolcache(sys, syms, paramsyms)
OptimizationFunction{iip, typeof(adtype), typeof(f), typeof(grad), typeof(hess),
typeof(hv),
OptimizationFunction{
iip, typeof(adtype), typeof(f), typeof(grad), typeof(fg), typeof(hess),
typeof(fgh), typeof(hv),
typeof(cons), typeof(cons_j), typeof(cons_jvp),
typeof(cons_vjp), typeof(cons_h),
typeof(hess_prototype),
Expand All @@ -3842,7 +3847,7 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
typeof(lag_hess_prototype), typeof(hess_colorvec),
typeof(cons_jac_colorvec), typeof(cons_hess_colorvec),
typeof(lag_hess_colorvec)
}(f, adtype, grad, hess,
}(f, adtype, grad, fg, hess, fgh,
hv, cons, cons_j, cons_jvp,
cons_vjp, cons_h,
hess_prototype, cons_jac_prototype,
Expand All @@ -3855,7 +3860,9 @@ end
(f::MultiObjectiveOptimizationFunction)(args...) = f.f(args...)

# Convenience constructor
MultiObjectiveOptimizationFunction(args...; kwargs...) = MultiObjectiveOptimizationFunction{true}(args...; kwargs...)
function MultiObjectiveOptimizationFunction(args...; kwargs...)
MultiObjectiveOptimizationFunction{true}(args...; kwargs...)
end

# Constructor with keyword arguments
function MultiObjectiveOptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
Expand All @@ -3881,7 +3888,8 @@ function MultiObjectiveOptimizationFunction{iip}(f, adtype::AbstractADType = NoA
lag_hess_colorvec = nothing) where {iip}
isinplace(f, 2; has_two_dispatches = false, isoptimization = true)
sys = sys_or_symbolcache(sys, syms, paramsyms)
MultiObjectiveOptimizationFunction{iip, typeof(adtype), typeof(f), typeof(jac), typeof(hess),
MultiObjectiveOptimizationFunction{
iip, typeof(adtype), typeof(f), typeof(jac), typeof(hess),
typeof(hv),
typeof(cons), typeof(cons_j), typeof(cons_jvp),
typeof(cons_vjp), typeof(cons_h),
Expand Down
1 change: 0 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,3 @@ the arity of a function is computed with `numargs`
See also: `prepare_initial_state`.
"""
prepare_function(f) = f

0 comments on commit eb3758c

Please sign in to comment.