Commit eb1e915f authored by Alberto Ramos's avatar Alberto Ramos

Working version of hyperd-hessian with ForwardDiff backend

parent 0f915039
......@@ -16,6 +16,9 @@ import ForwardDiff, Statistics, FFTW, LinearAlgebra, QuadGK, BDIO, Printf
# Include data types
include("ADerrorsTypes.jl")
# hyperd for hessian
include("ADerrorsHyperd.jl")
# Include computation of autoCF
include("ADerrorsCF.jl")
......
......@@ -9,99 +9,91 @@
### created: Mon Jul 6 19:58:53 2020
###
const LG2 = 0.6931471805599453094172321214581765680755001343602552
const LG10 = 2.3025850929940456840179914546843642076011014886287729
function Base.sqrt(h::hyperd)
v = sqrt(h.v)
d1 = 0.5/v
dd = -0.25/v^3
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.log(h::hyperd)
v = log(h.v)
d1 = 1.0/h.v
dd = -d1/h.v
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.log2(h::hyperd)
v = log2(h.v)
d1 = 1.0/(h.v*LG2)
dd = -d1/h.v
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.log10(h::hyperd)
v = log2(h.v)
d1 = 1.0/(h.v*LG10)
dd = -d1/h.v
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.exp(h::hyperd)
v = exp(h.v)
d1 = v
dd = v
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.exp2(h::hyperd)
v = exp2(h.v)
d1 = v * LG2
dd = d1 * LG2
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.exp10(h::hyperd)
v = exp10(h.v)
d1 = v * LG10
dd = d1 * LG10
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.sin(h::hyperd)
v = sin(h.v)
d1 = cos(h.v)
dd = -v
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.cos(h::hyperd)
v = cos(h.v)
d1 = -sin(h.v)
dd = -v
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.tan(h::hyperd)
v = tan(h.v)
d1 = 1.0 + v*v
dd = 2.0 * v*d1
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
end
function Base.sec(h::hyperd)
v = sec(h.v)
d1 = v*tan(h.v)
dd = v*(tan(h.v)^2 + v^2)
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
for op in (:sin, :cos, :tan, :log, :exp, :sqrt, :sind, :cosd, :tand, :sinpi, :cospi, :sinh, :cosh, :tanh, :asin, :acos, :atan, :asind, :acosd, :atand, :sec, :csc, :cot, :secd, :cscd, :cotd, :asec, :acsc, :acot, :asecd, :acscd, :acotd, :sech, :csch, :coth, :asinh, :acosh, :atanh, :asech, :acsch, :acoth, :sinc, :cosc, :deg2rad, :rad2deg, :log2, :log10, :log1p, :exp2, :exp10, :expm1, :-)
@eval function Base.$op(h::hyperd)
fvec(x::Vector) = Base.$op(x[1])
v = Base.$op(h.v)
d1 = ForwardDiff.derivative($op, h.v)
v2 = ForwardDiff.hessian(fvec, [h.v])
return hyperd(v, d1*h.d1, d1*h.d2, v2[1]*h.d1*h.d2 + d1*h.dd)
end
end
function Base.csc(h::hyperd)
v = csc(h.v)
d1 = -v * cot(h.v)
dd = v*(cot(h.v)^2 + v^2)
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
Base.:+(h1::hyperd, h2::hyperd) = hyperd(h1.v+h2.v, h1.d1+h2.d1, h1.d2+h2.d2, h1.dd+h2.dd)
Base.:+(h1::hyperd, h2::Number) = hyperd(h1.v+h2, h1.d1, h1.d2, h1.dd)
Base.:+(h1::Number, h2::hyperd) = hyperd(h1+h2.v, h2.d1, h2.d2, h2.dd)
Base.:-(h1::hyperd, h2::hyperd) = hyperd(h1.v-h2.v, h1.d1-h2.d1, h1.d2-h2.d2, h1.dd-h2.dd)
Base.:-(h1::hyperd, h2::Number) = hyperd(h1.v-h2, h1.d1, h1.d2, h1.dd)
Base.:-(h1::Number, h2::hyperd) = hyperd(h1-h2.v, h2.d1, h2.d2, h2.dd)
Base.:*(h1::hyperd, h2::hyperd) = hyperd(h1.v*h2.v,
h1.v*h2.d1+h1.d1*h2.v,
h1.v*h2.d2+h1.d2*h2.v,
h1.v*h2.dd+h1.dd*h2.v+h1.d2*h2.d1+h1.d1*h2.d2)
Base.:*(h1::hyperd, h2::Number) = hyperd(h1.v*h2, h2*h1.d1, h2*h1.d2, h2*h1.dd)
Base.:*(h1::Number, h2::hyperd) = hyperd(h1*h2.v, h1*h2.d1, h1*h2.d2, h1*h2.dd)
Base.:/(h1::hyperd, h2::hyperd) = hyperd(h1.v/h2.v,
h1.d1/h2.v - h2.d1*h1.v/h2.v^2,
h1.d2/h2.v - h2.d2*h1.v/h2.v^2,
h1.dd/h2.v - h2.d1*h1.d2/h2.v^2 -
h2.d2*h1.d1/h2.v^2 +
h1.v*(2.0*h2.d1*h2.d2/h2.v - h2.dd)/h2.v^2)
Base.:/(h1::hyperd, h2::Number) = hyperd(h1.v/h2, h1.d1/h2, h1.d2/h2, h1.dd/h2)
Base.:/(h1::Number, h2::hyperd) = hyperd(h1/h2.v, - h2.d1*h1.v/h2.v^2, - h2.d2*h1.v/h2.v^2,
h1*(2.0*h2.d1*h2.d2/h2.v - h2.dd)/h2.v^2)
Base.:^(h1::hyperd, h2::hyperd) = exp(h2*log(h1))
Base.:^(h1::hyperd, h2::Number) = exp(h2*log(h1))
function Base.:^(h1::hyperd, n::Integer)
v = h1.v^n
if (n == 0)
d1 = 0.0
dd = 0.0
elseif (n == 1)
d1 = 1.0
dd = 0.0
else
d1 = n*h1.v^(n-1)
dd = n*(n-1)*h1.v^(n-2)
end
return hyperd(v, d1*h1.d1, d1*h1.d2, dd*h1.d1*h1.d2 + d1*h1.dd)
end
function Base.cot(h::hyperd)
v = cot(h.v)
d1 = -(1 + v^2)
dd = -2.0*v*d1
return hyperd(v, d1*h.d1, d1*h.d2, dd*h.d1*h.d2 + d1*h.dd)
Base.:^(h1::Number, h2::hyperd) = exp(h2*log(h1))
# Missing atan, hypot
Base.zero(::Type{hyperd}) = hyperd(0.0, 0.0, 0.0, 0.0)
Base.one(::Type{hyperd}) = hyperd(1.0, 0.0, 0.0, 0.0)
Base.length(x::hyperd) = 1
Base.iterate(x::hyperd) = (x, nothing)
Base.iterate(x::hyperd, ::Nothing) = nothing
function hyperd_hessian!(hess::Array{Float64, 2}, f::Function, x::Vector{Float64})
n = length(x)
h = Vector{hyperd}(undef, n)
for i in 1:n
h[i] = hyperd(x[i], 0.0, 0.0, 0.0)
end
for i in 1:n
h[i].d1 = 1.0
for j in i:n
h[j].d2 = 1.0
res = f(h)
hess[i,j] = res.dd
if (j > i)
hess[j,i] = hess[i,j]
end
h[j].d2 = 0.0
end
h[i].d1 = 0.0
end
return nothing
end
......@@ -17,7 +17,6 @@ for op in (:sin, :cos, :tan, :log, :exp, :sqrt, :sind, :cosd, :tand, :sinpi, :co
end
end
for op in (:+, :-, :*, :/, :^, :atan, :hypot)
@eval function Base.$op(a::uwreal, b::uwreal)
......
......@@ -67,7 +67,6 @@ function chiexp(chisq::Function,
n = length(xp) # Number of fit parameters
m = length(data) # Number of data
ccsq(x::Vector) = chisq(x[1:n], x[n+1:end])
xav = zeros(Float64, n+m)
for i in 1:n
......@@ -76,7 +75,10 @@ function chiexp(chisq::Function,
for i in n+1:n+m
xav[i] = data[i-n].mean
end
hess = ForwardDiff.hessian(ccsq, xav)
ccsq(x::Vector) = chisq(view(x, 1:n), view(x, n+1:n+m))
hess = Array{Float64}(undef, n+m, n+m)
# @time ForwardDiff.hessian!(hess, ccsq, xav)
hyperd_hessian!(hess, ccsq, xav)
cse = 0.0
if (m-n > 0)
......@@ -117,14 +119,10 @@ function fit_error(chisq::Function,
xav[i] = data[i-n].mean
end
function cls(x0)
x1 = view(x0, 1:n)
x2 = view(x0, n+1:n+m)
return chisq(x1, x2)
end
ccsq(x::Vector) = chisq(view(x, 1:n), view(x, n+1:n+m))
hess = Array{Float64}(undef, n+m, n+m)
ForwardDiff.hessian!(hess, cls, xav)
# @time ForwardDiff.hessian!(hess, ccsq, xav)
hyperd_hessian!(hess, ccsq, xav)
hinv = LinearAlgebra.pinv(hess[1:n,1:n])
grad = - hinv[1:n,1:n] * hess[1:n,n+1:n+m]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment