Commit eb1e915f authored by Alberto Ramos's avatar Alberto Ramos

Working version of hyperd-hessian with ForwardDiff backend

parent 0f915039
...@@ -16,6 +16,9 @@ import ForwardDiff, Statistics, FFTW, LinearAlgebra, QuadGK, BDIO, Printf ...@@ -16,6 +16,9 @@ import ForwardDiff, Statistics, FFTW, LinearAlgebra, QuadGK, BDIO, Printf
# Include data types # Include data types
include("ADerrorsTypes.jl") include("ADerrorsTypes.jl")
# hyperd for hessian
include("ADerrorsHyperd.jl")
# Include computation of autoCF # Include computation of autoCF
include("ADerrorsCF.jl") include("ADerrorsCF.jl")
......
...@@ -9,99 +9,91 @@ ...@@ -9,99 +9,91 @@
### created: Mon Jul 6 19:58:53 2020 ### created: Mon Jul 6 19:58:53 2020
### ###
const LG2 = 0.6931471805599453094172321214581765680755001343602552 for op in (:sin, :cos, :tan, :log, :exp, :sqrt, :sind, :cosd, :tand, :sinpi, :cospi, :sinh, :cosh, :tanh, :asin, :acos, :atan, :asind, :acosd, :atand, :sec, :csc, :cot, :secd, :cscd, :cotd, :asec, :acsc, :acot, :asecd, :acscd, :acotd, :sech, :csch, :coth, :asinh, :acosh, :atanh, :asech, :acsch, :acoth, :sinc, :cosc, :deg2rad, :rad2deg, :log2, :log10, :log1p, :exp2, :exp10, :expm1, :-)
const LG10 = 2.3025850929940456840179914546843642076011014886287729 @eval function Base.$op(h::hyperd)
fvec(x::Vector) = Base.$op(x[1])
function Base.sqrt(h::hyperd) v = Base.$op(h.v)
v = sqrt(h.v) d1 = ForwardDiff.derivative($op, h.v)
d1 = 0.5/v v2 = ForwardDiff.hessian(fvec, [h.v])
dd = -0.25/v^3 return hyperd(v, d1*h.d1, d1*h.d2, v2[1]*h.d1*h.d2 + d1*h.dd)
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd) end
end end
function Base.log(h::hyperd) Base.:+(h1::hyperd, h2::hyperd) = hyperd(h1.v+h2.v, h1.d1+h2.d1, h1.d2+h2.d2, h1.dd+h2.dd)
v = log(h.v) Base.:+(h1::hyperd, h2::Number) = hyperd(h1.v+h2, h1.d1, h1.d2, h1.dd)
d1 = 1.0/h.v Base.:+(h1::Number, h2::hyperd) = hyperd(h1+h2.v, h2.d1, h2.d2, h2.dd)
dd = -d1/h.v
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd) Base.:-(h1::hyperd, h2::hyperd) = hyperd(h1.v-h2.v, h1.d1-h2.d1, h1.d2-h2.d2, h1.dd-h2.dd)
end Base.:-(h1::hyperd, h2::Number) = hyperd(h1.v-h2, h1.d1, h1.d2, h1.dd)
Base.:-(h1::Number, h2::hyperd) = hyperd(h1-h2.v, h2.d1, h2.d2, h2.dd)
function Base.log2(h::hyperd)
v = log2(h.v) Base.:*(h1::hyperd, h2::hyperd) = hyperd(h1.v*h2.v,
d1 = 1.0/(h.v*LG2) h1.v*h2.d1+h1.d1*h2.v,
dd = -d1/h.v h1.v*h2.d2+h1.d2*h2.v,
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd) h1.v*h2.dd+h1.dd*h2.v+h1.d2*h2.d1+h1.d1*h2.d2)
end Base.:*(h1::hyperd, h2::Number) = hyperd(h1.v*h2, h2*h1.d1, h2*h1.d2, h2*h1.dd)
Base.:*(h1::Number, h2::hyperd) = hyperd(h1*h2.v, h1*h2.d1, h1*h2.d2, h1*h2.dd)
function Base.log10(h::hyperd)
v = log2(h.v) Base.:/(h1::hyperd, h2::hyperd) = hyperd(h1.v/h2.v,
d1 = 1.0/(h.v*LG10) h1.d1/h2.v - h2.d1*h1.v/h2.v^2,
dd = -d1/h.v h1.d2/h2.v - h2.d2*h1.v/h2.v^2,
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd) h1.dd/h2.v - h2.d1*h1.d2/h2.v^2 -
h2.d2*h1.d1/h2.v^2 +
h1.v*(2.0*h2.d1*h2.d2/h2.v - h2.dd)/h2.v^2)
Base.:/(h1::hyperd, h2::Number) = hyperd(h1.v/h2, h1.d1/h2, h1.d2/h2, h1.dd/h2)
Base.:/(h1::Number, h2::hyperd) = hyperd(h1/h2.v, - h2.d1*h1.v/h2.v^2, - h2.d2*h1.v/h2.v^2,
h1*(2.0*h2.d1*h2.d2/h2.v - h2.dd)/h2.v^2)
Base.:^(h1::hyperd, h2::hyperd) = exp(h2*log(h1))
Base.:^(h1::hyperd, h2::Number) = exp(h2*log(h1))
function Base.:^(h1::hyperd, n::Integer)
v = h1.v^n
if (n == 0)
d1 = 0.0
dd = 0.0
elseif (n == 1)
d1 = 1.0
dd = 0.0
else
d1 = n*h1.v^(n-1)
dd = n*(n-1)*h1.v^(n-2)
end
return hyperd(v, d1*h1.d1, d1*h1.d2, dd*h1.d1*h1.d2 + d1*h1.dd)
end end
function Base.exp(h::hyperd) Base.:^(h1::Number, h2::hyperd) = exp(h2*log(h1))
v = exp(h.v)
d1 = v # Missing atan, hypot
dd = v
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end Base.zero(::Type{hyperd}) = hyperd(0.0, 0.0, 0.0, 0.0)
Base.one(::Type{hyperd}) = hyperd(1.0, 0.0, 0.0, 0.0)
function Base.exp2(h::hyperd) Base.length(x::hyperd) = 1
v = exp2(h.v) Base.iterate(x::hyperd) = (x, nothing)
d1 = v * LG2 Base.iterate(x::hyperd, ::Nothing) = nothing
dd = d1 * LG2
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd) function hyperd_hessian!(hess::Array{Float64, 2}, f::Function, x::Vector{Float64})
end
n = length(x)
function Base.exp10(h::hyperd) h = Vector{hyperd}(undef, n)
v = exp10(h.v) for i in 1:n
d1 = v * LG10 h[i] = hyperd(x[i], 0.0, 0.0, 0.0)
dd = d1 * LG10 end
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end for i in 1:n
h[i].d1 = 1.0
function Base.sin(h::hyperd) for j in i:n
v = sin(h.v) h[j].d2 = 1.0
d1 = cos(h.v) res = f(h)
dd = -v hess[i,j] = res.dd
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd) if (j > i)
hess[j,i] = hess[i,j]
end
h[j].d2 = 0.0
end
h[i].d1 = 0.0
end
return nothing
end end
function Base.cos(h::hyperd)
v = cos(h.v)
d1 = -sin(h.v)
dd = -v
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.tan(h::hyperd)
v = tan(h.v)
d1 = 1.0 + v*v
dd = 2.0 * v*d1
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.sec(h::hyperd)
v = sec(h.v)
d1 = v*tan(h.v)
dd = v*(tan(h.v)^2 + v^2)
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.csc(h::hyperd)
v = csc(h.v)
d1 = -v * cot(h.v)
dd = v*(cot(h.v)^2 + v^2)
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.cot(h::hyperd)
v = cot(h.v)
d1 = -(1 + v^2)
dd = -2.0*v*d1
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
...@@ -17,7 +17,6 @@ for op in (:sin, :cos, :tan, :log, :exp, :sqrt, :sind, :cosd, :tand, :sinpi, :co ...@@ -17,7 +17,6 @@ for op in (:sin, :cos, :tan, :log, :exp, :sqrt, :sind, :cosd, :tand, :sinpi, :co
end end
end end
for op in (:+, :-, :*, :/, :^, :atan, :hypot) for op in (:+, :-, :*, :/, :^, :atan, :hypot)
@eval function Base.$op(a::uwreal, b::uwreal) @eval function Base.$op(a::uwreal, b::uwreal)
......
...@@ -67,7 +67,6 @@ function chiexp(chisq::Function, ...@@ -67,7 +67,6 @@ function chiexp(chisq::Function,
n = length(xp) # Number of fit parameters n = length(xp) # Number of fit parameters
m = length(data) # Number of data m = length(data) # Number of data
ccsq(x::Vector) = chisq(x[1:n], x[n+1:end])
xav = zeros(Float64, n+m) xav = zeros(Float64, n+m)
for i in 1:n for i in 1:n
...@@ -76,7 +75,10 @@ function chiexp(chisq::Function, ...@@ -76,7 +75,10 @@ function chiexp(chisq::Function,
for i in n+1:n+m for i in n+1:n+m
xav[i] = data[i-n].mean xav[i] = data[i-n].mean
end end
hess = ForwardDiff.hessian(ccsq, xav) ccsq(x::Vector) = chisq(view(x, 1:n), view(x, n+1:n+m))
hess = Array{Float64}(undef, n+m, n+m)
# @time ForwardDiff.hessian!(hess, ccsq, xav)
hyperd_hessian!(hess, ccsq, xav)
cse = 0.0 cse = 0.0
if (m-n > 0) if (m-n > 0)
...@@ -117,14 +119,10 @@ function fit_error(chisq::Function, ...@@ -117,14 +119,10 @@ function fit_error(chisq::Function,
xav[i] = data[i-n].mean xav[i] = data[i-n].mean
end end
function cls(x0) ccsq(x::Vector) = chisq(view(x, 1:n), view(x, n+1:n+m))
x1 = view(x0, 1:n)
x2 = view(x0, n+1:n+m)
return chisq(x1, x2)
end
hess = Array{Float64}(undef, n+m, n+m) hess = Array{Float64}(undef, n+m, n+m)
ForwardDiff.hessian!(hess, cls, xav) # @time ForwardDiff.hessian!(hess, ccsq, xav)
hyperd_hessian!(hess, ccsq, xav)
hinv = LinearAlgebra.pinv(hess[1:n,1:n]) hinv = LinearAlgebra.pinv(hess[1:n,1:n])
grad = - hinv[1:n,1:n] * hess[1:n,n+1:n+m] grad = - hinv[1:n,1:n] * hess[1:n,n+1:n+m]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment