#TODO: apply_rw with gaps
@doc raw"""
    apply_rw(data::Array{Float64}, W::Matrix{Float64})

Reweights the data of a single replica. With `n = size(data, 1)`, the first `n` entries of rows 1
and 2 of `W` are multiplied elementwise; each row of `data` is scaled by its product and the result
is divided by the mean of those products.
"""
function apply_rw(data::Array{Float64}, W::Matrix{Float64})
    ncfg = size(data, 1)
    w_first = W[1, 1:ncfg]
    w_second = W[2, 1:ncfg]
    # Same evaluation order as (data .* w_first) .* w_second, then divide by the mean product
    weighted = data .* w_first .* w_second
    w_mean = mean(w_first .* w_second)
    return weighted / w_mean
end

@doc raw"""
    apply_rw(data::Vector{<:Array{Float64}}, W::Vector{Matrix{Float64}})

Reweights the data of several replicas; `data[k]` and `W[k]` belong to replica `k`.

For replica `k`, the first `size(data[k], 1)` entries of rows 1 and 2 of `W[k]` are multiplied
elementwise, and each row of `data[k]` is scaled by the corresponding product. All replicas are
normalised by a single factor: the mean of these products taken over all replicas together.

Returns a vector with the reweighted data of each replica.
Throws an error if `data` and `W` have different lengths.
"""
function apply_rw(data::Vector{<:Array{Float64}}, W::Vector{Matrix{Float64}})
    if length(W) != length(data)
        error("Lengths must match")
    end
    nc = size.(data, 1)

    rw1 = [W[k][1, 1:nc[k]] for k in eachindex(W)]
    rw2 = [W[k][2, 1:nc[k]] for k in eachindex(W)]

    # Concatenate all replicas in one pass (instead of repeated `cat`) for the common normalisation
    rw_mean = mean(reduce(vcat, rw1) .* reduce(vcat, rw2))
    data_r = [data[k] .* rw1[k] .* rw2[k] / rw_mean for k in eachindex(data)]
    return data_r
end

@doc raw"""
    check_corr_der(obs::Corr, derm::Vector{Corr})

Checks that the derivative correlators `derm` are compatible with the correlator `obs`.

Returns `true` when both of the following hold:
- `obs` and every element of `derm` share the same `y0` and the same first two entries of `kappa`
  and `mu`;
- the two `gamma` labels of every element of `derm`, with their last three characters removed, equal
  the corresponding `gamma` labels of `obs`.

Returns `false` otherwise.
"""
function check_corr_der(obs::Corr, derm::Vector{Corr})
    # gamma labels of the derivative correlators without their 3-character suffix
    base1 = String[d.gamma[1][1:end-3] for d in derm]
    base2 = String[d.gamma[2][1:end-3] for d in derm]

    allcorr = push!(copy(derm), obs)
    ref = allcorr[1]

    any(getfield.(allcorr, :y0) .!= getfield(ref, :y0)) && return false

    for s in (:kappa, :mu), k in 1:2
        vals = getindex.(getfield.(allcorr, s), k)
        any(vals .!= getindex(getfield(ref, s), k)) && return false
    end

    # gamma check
    return all(base1 .== obs.gamma[1]) && all(base2 .== obs.gamma[2])
end

@doc raw"""
    corr_obs(cdata::CData; real::Bool=true, rw::Union{Array{Float64, 2}, Nothing}=nothing, L::Int64=1)

    corr_obs(cdata::Array{CData, 1}; real::Bool=true, rw::Union{Array{Array{Float64, 2}, 1}, Nothing}=nothing, L::Int64=1)

Creates a `Corr` struct from the `CData` struct `cdata` (see `read_mesons`) of a single replica.
An array of `CData` (all with the same ID) can be passed instead to combine several replicas.

- `real`: selects the real (`true`, default) or the imaginary (`false`) part of the correlator.
- `rw`: if specified, reweighting is applied. For a single replica it is the matrix of Float64
  returned by `read_ms1`; for several replicas it is a vector with one such matrix per replica.
- `L`: the data are divided by `L^3`, so fixing `L` normalizes the correlator with the volume factor.

Each column `x0` of the data becomes one `uwreal` entry of the resulting correlator.

```@example
#Single replica
data = read_mesons(path, "G5", "G5")
rw = read_ms1(path_rw)
corr_pp = corr_obs.(data)
corr_pp_r = corr_obs.(data, rw=rw)

#Two replicas
data = read_mesons([path_r1, path_r2], "G5", "G5")
rw1 = read_ms1(path_rw1)
rw2 = read_ms1(path_rw2)

corr_pp = corr_obs.(data)
corr_pp_r = corr_obs.(data, rw=[rw1, rw2])
```
"""
function corr_obs(cdata::CData; real::Bool=true, rw::Union{Array{Float64, 2}, Nothing}=nothing, L::Int64=1)
    raw = real ? cdata.re_data : cdata.im_data
    data = raw ./ L^3
    data_r = rw === nothing ? data : apply_rw(data, rw)

    nt = size(data)[2]
    obs = Vector{uwreal}(undef, nt)
    for x0 in 1:nt
        obs[x0] = uwreal(data_r[:, x0], cdata.id)
    end
    return Corr(obs, cdata)
end

#function corr_obs for R != 1
#TODO: vcfg with gaps
# Multi-replica variant: the data of all replicas (same ensemble ID) are stacked along the
# configuration axis and each column x0 becomes one `uwreal` with per-replica sizes `replica`.
function corr_obs(cdata::Array{CData, 1}; real::Bool=true, rw::Union{Array{Array{Float64, 2}, 1}, Nothing}=nothing, L::Int64=1)
    if isempty(cdata)
        error("cdata is empty")
    end
    id = getfield.(cdata, :id)
    vcfg = getfield.(cdata, :vcfg)
    # NOTE(review): the number of configurations of each replica is taken as the largest entry
    # of its vcfg, which assumes configurations numbered 1..N without gaps (see TODO above).
    replica = Int64.(maximum.(vcfg))

    if !all(id .== id[1])
        error("IDs are not equal")
    end

    data = real ? getfield.(cdata, :re_data) ./ L^3 : getfield.(cdata, :im_data) ./ L^3
    data_r = isnothing(rw) ? data : apply_rw(data, rw)

    # Stack all replicas along dim 1 in a single pass. The previous
    # `[tmp = cat(tmp, ...) for k ...]` mutated a captured variable from inside a
    # comprehension closure (boxing `tmp`) and copied the accumulated data on every step.
    tmp = reduce(vcat, data_r)

    nt = size(data[1])[2]
    obs = Vector{uwreal}(undef, nt)
    for x0 in 1:nt
        obs[x0] = uwreal(tmp[:, x0], id[1], replica)
    end

    return Corr(obs, cdata)
end
@doc raw"""
    corr_sym(corrL::Corr, corrR::Corr, parity::Int64=1)

Computes the symmetrized correlator from the left correlator `corrL` and the right correlator
`corrR`. With `T = length(corrL.obs)`, the source position of `corrR` must be `T - 1 - y0`, where
`y0` is the source position of `corrL`. Both correlators must also share `kappa`, `mu` and `gamma`,
and `parity` must be `+1` or `-1`; otherwise an error is thrown.

Entry `i` of the result is `(corrL.obs[i] + parity * corrR.obs[T + 1 - i]) / 2`, i.e. `corrR` is
time-reversed before being combined. The returned `Corr` carries the parameters of `corrL`.

```@example
pp_sym = corr_sym(ppL, ppR, +1)
a0p_sym = corr_sym(a0pL, a0pR, -1)
```
"""
function corr_sym(corrL::Corr, corrR::Corr, parity::Int64=1)
    T = length(corrL.obs)
    corrL.y0 == T - 1 - corrR.y0 || error("Corr: Parameter mismatch")
    for field in (:kappa, :mu, :gamma)
        getfield(corrL, field) == getfield(corrR, field) || error("Corr: Parameter mismatch")
    end
    abs(parity) == 1 || error("incorrect value of parity (+- 1)")

    reversed_R = reverse(corrR.obs)
    res = (corrL.obs[1:end] + parity * reversed_R) / 2
    return Corr(res, corrL.kappa, corrL.mu, corrL.gamma, corrL.y0)
end
#TODO: VECTORIZE, uwreal?
@doc raw"""
    md_sea(a::uwreal, md::Vector{Matrix{Float64}}, ws::ADerrors.wspace=ADerrors.wsg)

Computes the derivative of an observable A with respect to the sea quark masses.

``\frac{d <A>}{dm(sea)} = \sum_i \frac{\partial <A>}{\partial <O_i>}  \frac{d <O_i>}{d m(sea)}``


``\frac{d <O_i>}{dm(sea)} = <O_i> <\frac{\partial S}{\partial m}> - <O_i \frac{\partial S}{\partial m}> 
= - <(O_i - <O_i>) (\frac{\partial S}{\partial m} - <\frac{\partial S}{\partial m}>)>``

where ``O_i`` are primary observables.

`a` must depend on a single ensemble (`neid(a) == 1`).

`md` is a vector that contains the derivative of the action ``S`` with respect to the sea quark
masses for each replica: `md[irep][irw, icfg]`. `length(md)` must equal the number of replicas of
`a`, and each `md[irep]` is truncated to the number of configurations of replica `irep`.

`md_sea` returns a tuple of uwreal observables ``(dA/dm_l, dA/dm_s)|_{sea}``,
where ``m_l`` and ``m_s`` are the light and strange quark masses. The number of entries along the
first index of `md` (`irw`) determines the result:
- one entry: both elements of the tuple are the same observable;
- two entries: one derivative per entry;
- more than two: `nothing` is returned.

```@example
#Single replica
data = read_mesons(path, "G5", "G5")
md = read_md(path_md)
rw = read_ms1(path_rw)

corr_pp = corr_obs.(data, rw=rw)
m = meff(corr_pp[1], plat)
m_mdl, m_mds = md_sea(m, [md], ADerrors.wsg)
m_shifted = m + 2 * dml * m_mdl + dms * m_mds

#Two replicas
data = read_mesons([path_r1, path_r2], "G5", "G5")
md1 = read_md(path_md1)
md2 = read_md(path_md2)

corr_pp = corr_obs.(data)
m = meff(corr_pp[1], plat)
m_mdl, m_mds = md_sea(m, [md1, md2], ADerrors.wsg)
m_shifted = m + 2 * dml * m_mdl + dms * m_mds
```
"""
function md_sea(a::uwreal, md::Vector{Matrix{Float64}}, ws::ADerrors.wspace=ADerrors.wsg)
    nid = neid(a)
    # Workspace indices of the primary observables `a` depends on
    p = findall(t-> t==1, a.prop)

    if nid != 1
        error("Error: neid > 1")
    end

    # All primary observables must belong to the same ensemble
    nob_ids = ws.map_nob[p]
    if !all(nob_ids .== nob_ids[1])
        error("ids do not match")
    end
    id = ws.id2str[nob_ids[1]]

    # ... and share the same replica structure
    ivreps = getfield.(ws.fluc[p], :ivrep)
    if !all(iv == ivreps[1] for iv in ivreps)
        error("ivreps do not match")
    end
    ivrep = ivreps[1]

    if length(md) != length(ivrep)
        error("Nr obs != Nr md")
    end

    # md as a single matrix: each replica truncated to its configuration count, joined along dim 2
    md_aux = reduce(hcat, [md[k][:, 1:ivrep[k]] for k in eachindex(md)])
    fluc_md = md_aux .- mean(md_aux, dims=2)

    # NOTE(review): uwerr/mchist are called without `ws`; presumably they use ADerrors' default
    # workspace - confirm this matches `ws` when a non-default workspace is passed.
    uwerr(a)
    fluc_obs = mchist(a, id)

    nrw = size(fluc_md, 1)
    if nrw == 1 || nrw == 2
        der1 = uwreal(-fluc_md[1, :] .* fluc_obs, id, ivrep)
        # With a single row in md the same derivative is returned for both masses
        der2 = nrw == 1 ? der1 : uwreal(-fluc_md[2, :] .* fluc_obs, id, ivrep)
        return (der1, der2)
    end
    return nothing
end

@doc raw"""
    md_val(a::uwreal, obs::Corr, derm::Vector{Corr})

Computes the derivative of an observable A with respect to the valence quark masses.

``\frac{d <A>}{dm(val)} = \sum_i \frac{\partial <A>}{\partial <O_i>}  \frac{d <O_i>}{d m(val)}``

``\frac{d <O_i>}{dm(val)} = <\frac{\partial O_i}{\partial m(val)}>``

where ``O_i`` are primary observables.

- `a` must depend on a single ensemble (`neid(a) == 1`).
- `obs` is the correlator whose entries `a` depends on.
- `derm` must contain exactly two correlators with the derivatives of `obs` with respect to the
  two valence quark masses ``m_1`` and ``m_2``. Their parameters (`y0`, `kappa`, `mu`, `gamma`)
  are checked against those of `obs` with `check_corr_der`, and an error is thrown on mismatch.

`md_val` returns a tuple of `uwreal` observables ``(dA/dm_1, dA/dm_2)|_{val}``, 
where ``m_1`` and ``m_2`` are the correlator masses.

```@example
data = read_mesons(path, "G5", "G5", legacy=true)
data_d1 = read_mesons(path, "G5_d1", "G5_d1", legacy=true)
data_d2 = read_mesons(path, "G5_d2", "G5_d2", legacy=true)

rw = read_ms1(path_rw)

corr_pp = corr_obs.(data, rw=rw)
corr_pp_d1 = corr_obs.(data_d1, rw=rw)
corr_pp_d2 = corr_obs.(data_d2, rw=rw)
derm = [[corr_pp_d1[k], corr_pp_d2[k]] for k = 1:length(corr_pp_d1)]

m = meff(corr_pp[1], plat)
m_md1, m_md2 = md_val(m, corr_pp[1], derm[1])
m_shifted = m + 2 * dm1 * m_md1 + dm2 * m_md2
```
"""
function md_val(a::uwreal, obs::Corr, derm::Vector{Corr})
    nid = neid(a)
    if nid != 1
        error("Error: neid > 1")
    end
    if length(derm) != 2
        error("Error: length derm != 2")
    end
    if !check_corr_der(obs, derm)
        error("Corr parameters does not match")
    end

    corr = getfield(obs, :obs)

    # d<A>/d<O_i> for every entry O_i of the correlator
    der = [derivative(a, c) for c in corr]
    derm1, derm2 = derm
    return (sum(der .* derm1.obs), sum(der .* derm2.obs))
end

@doc raw"""
    plat_av(obs::Vector{uwreal}, plat::Vector{Int64}, wpm=nothing)

Weighted average of `obs[plat[1]:plat[2]]` with weights `1 / err^2`.

Errors are first computed for every element of `obs` with `uwerr`, using the summation windows
`wpm` when given.
"""
function plat_av(obs::Vector{uwreal}, plat::Vector{Int64}, wpm::Union{Dict{Int64,Vector{Float64}},Dict{String,Vector{Float64}}, Nothing}=nothing)
    if isnothing(wpm)
        uwerr.(obs)
    else
        for o in obs
            uwerr(o, wpm)
        end
    end
    window = plat[1]:plat[2]
    weights = 1 ./ err.(obs)[window] .^ 2
    return sum(weights .* obs[window]) / sum(weights)
end

@doc raw"""
    lin_fit(x::Vector{<:Real}, v::Vector{Float64}, e::Vector{Float64})

Weighted least-squares fit of the straight line `v = par[1] + par[2] * x`, each point weighted by
`1 / e^2`. Returns the central values `par = [intercept, slope]`.
"""
function lin_fit(x::Vector{<:Real}, v::Vector{Float64}, e::Vector{Float64})
    e2 = e .* e
    # Sum of q_i / e_i^2 over all points
    wsum(q) = sum(q ./ e2)

    S, Sx, Sy = wsum(1), wsum(x), wsum(v)
    Sxy, Sxx = wsum(v .* x), wsum(x .* x)

    denom = S * Sxx - Sx * Sx
    intercept = (Sxx * Sy - Sx * Sxy) / denom
    slope = (S * Sxy - Sx * Sy) / denom
    # Parameter covariance (not returned): [Sxx -Sx; -Sx S] / denom
    return [intercept, slope]
end
@doc raw"""
    lin_fit(x::Vector{<:Real}, y::Vector{uwreal}; wpm::Union{Dict{Int64,Vector{Float64}},Dict{String,Vector{Float64}}, Nothing}=nothing)

Computes a weighted linear fit `y = p[1] + p[2] * x` of the `uwreal` data points `y`.
Errors are computed with `uwerr`, using the summation windows `wpm` when given (also for the
error propagation of the fit). Prints the fit parameters and chi^2 information.

Returns `(fitp, csqexp)`: the `uwreal` fit parameters and the expected chi^2.

```@example
fitp, csqexp = lin_fit(phi2, m2)
m2_phys = fitp[1] + fitp[2] * phi2_phys
```
"""
function lin_fit(x::Vector{<:Real}, y::Vector{uwreal}; wpm::Union{Dict{Int64,Vector{Float64}},Dict{String,Vector{Float64}}, Nothing}=nothing)
    isnothing(wpm) ? uwerr.(y) : [uwerr(yaux, wpm) for yaux in y]
    yer = err.(y)
    par = lin_fit(x, value.(y), yer)
    chisq(p, d) = sum((d .- p[1] .- p[2].*x).^2 ./ yer .^2)
    # Propagate errors with the same summation windows used above (as fit_routine does)
    (fitp, csqexp) = isnothing(wpm) ? fit_error(chisq, par, y) : fit_error(chisq, par, y, wpm)
    for i in eachindex(fitp)
        isnothing(wpm) ? uwerr(fitp[i]) : uwerr(fitp[i], wpm)
        print("\n Fit parameter: ", i, ": ")
        details(fitp[i])
    end
    println("Chisq / chiexp: ", chisq(par, y), " / ", csqexp, " (dof: ", length(x)-length(par),")")
    return (fitp, csqexp)
end

@doc raw"""
    x_lin_fit(par::Vector{uwreal}, y::Union{uwreal, Float64})

Inverts the straight line `y = par[1] + par[2] * x`: returns the abscissa `(y - par[1]) / par[2]`
at which the line takes the value `y` (linear interpolation/extrapolation in the x axis).
"""
function x_lin_fit(par::Vector{uwreal}, y::Union{uwreal, Float64})
    intercept, slope = par[1], par[2]
    return (y - intercept) / slope
end
@doc raw"""
    y_lin_fit(par::Vector{uwreal}, x::Union{uwreal, Float64})

Evaluates the straight line `par[1] + par[2] * x` at `x` (linear interpolation/extrapolation in the
y axis).
"""
y_lin_fit(par::Vector{uwreal}, x::Union{uwreal, Float64}) = par[1] + par[2] * x

@doc raw"""
    fit_routine(model::Function, xdata::Array{<:Real}, ydata::Array{uwreal}, param::Int64=3; wpm::Union{Dict{Int64,Vector{Float64}},Dict{String,Vector{Float64}}, Nothing}=nothing)

    fit_routine(model::Function, xdata::Array{uwreal}, ydata::Array{uwreal}, param::Int64=3; wpm::Union{Dict{Int64,Vector{Float64}},Dict{String,Vector{Float64}}, Nothing}=nothing, covar::Bool=false)

Given a `model` function with `param` parameters and an array of `uwreal` data `ydata`,
this function fits `ydata` with `model`, prints the fit information and returns an array `upar`
with the best fit parameters and their errors.

When `xdata` is also an array of `uwreal`, the x values are treated as additional fit parameters
(errors in both x and y), and `upar` contains the `param` model parameters followed by the fitted
x values.

The optional argument `wpm` maps each ensemble ID to a `Vector{Float64}` `vp` of length 4. The
first three components specify the criteria used to determine the summation window:

- `vp[1]`: The autocorrelation function is summed up to ``t = round(vp[1])``.

- `vp[2]`: The summation window is determined using U. Wolff's proposal with ``S_\tau = vp[2]``.

- `vp[3]`: The autocorrelation function ``\Gamma(t)`` is summed up to the point where its error ``\delta\Gamma(t)`` is a factor `vp[3]` times larger than the signal.

An additional fourth component `vp[4]` tells ADerrors to add a tail to the error with ``\tau_{exp} = vp[4]``.
Negative values of `vp[1:4]` are ignored and only one component of `vp[1:3]` needs to be positive.

If the flag `covar` is set to true, `fit_routine` takes into account covariances between x and y for each data point.

Note that `param` is a positional argument.
```@example
@. model(x,p) = p[1] + p[2] * exp(-(p[3]-p[1])*x)
@. model2(x,p) = p[1] + p[2] * x[:, 1] + (p[3] + p[4] * x[:, 1]) * x[:, 2]
fit_routine(model, xdata, ydata, 3)
fit_routine(model, xdata, ydata, 3, covar=true)
```
"""
function fit_routine(model::Function, xdata::Array{<:Real}, ydata::Array{uwreal}, param::Int64=3; wpm::Union{Dict{Int64,Vector{Float64}},Dict{String,Vector{Float64}}, Nothing}=nothing)
    # Errors of the data, with the user's summation windows when given
    if isnothing(wpm)
        uwerr.(ydata)
    else
        for yaux in ydata
            uwerr(yaux, wpm)
        end
    end

    yval = value.(ydata)
    yer = err.(ydata)

    # chi^2 with exactly known x values, weighted by the y errors
    chisq = gen_chisq(model, xdata, yer)
    # Central values from a weighted least-squares fit starting from p = 0.5
    fit = curve_fit(model, xdata, yval, 1.0 ./ yer.^2, fill(0.5, param))
    pbest = coef(fit)
    # Error propagation of the best-fit parameters + expected chi^2
    (upar, chi_exp) = isnothing(wpm) ? fit_error(chisq, pbest, ydata) : fit_error(chisq, pbest, ydata, wpm)

    # Info
    for i in eachindex(upar)
        isnothing(wpm) ? uwerr(upar[i]) : uwerr(upar[i], wpm)
        print("\n Fit parameter: ", i, ": ")
        details(upar[i])
    end
    println("Chisq / chiexp: ", chisq(pbest, ydata), " / ", chi_exp, " (dof: ", length(yval) - param,")")
    return upar
end

# Fit with uncertainties in both x and y. The x values are treated as extra fit parameters:
# the minimisation runs over vcat(model parameters, fitted x values), and get_chi2/get_chi2_cov
# compare the fitted x values with the measured ones as well as the model with ydata.
# With covar=true, each data point's (x, y) covariance matrix from ADerrors.cov enters the chi^2;
# otherwise only the individual errors are used.
# Returns `upar`, the uwreal results of fit_error for every minimised parameter
# (the `param` model parameters followed by the fitted x values).
function fit_routine(model::Function, xdata::Array{uwreal}, ydata::Array{uwreal}, param::Int64=3; 
    wpm::Union{Dict{Int64,Vector{Float64}},Dict{String,Vector{Float64}}, Nothing}=nothing, covar::Bool=false)
    
    Nalpha = size(xdata, 2) # number of x-variables
    Ndata = size(xdata, 1) # number of datapoints
    # Errors of all inputs, with the user's summation windows when given
    if isnothing(wpm)
        uwerr.(ydata)
        uwerr.(xdata)
    else
        [uwerr(yaux, wpm) for yaux in ydata] 
        [uwerr(xaux, wpm) for xaux in xdata]
    end

    yval = value.(ydata)
    yer = err.(ydata)
    xval = value.(xdata)
    xer = err.(xdata)

    # Flattened layout shared with get_chi2/get_chi2_cov:
    # [x column 1 (Ndata entries), ..., x column Nalpha, y (Ndata entries)]
    # dat = central values, ddat = errors, data = the uwreal objects themselves
    dat = Vector{Float64}(undef, Ndata * (Nalpha+1))
    ddat = Vector{Float64}(undef, Ndata * (Nalpha+1))
    data = Vector{uwreal}(undef, Ndata * (Nalpha+1)) 

    for i = 1:Nalpha
        dat[(i-1)*Ndata+1:i*Ndata] = xval[:, i]
        ddat[(i-1)*Ndata+1:i*Ndata] = xer[:, i]
        data[(i-1)*Ndata+1:i*Ndata] = xdata[:, i]
    end
    dat[Nalpha*Ndata+1:end] = yval
    ddat[Nalpha*Ndata+1:end] = yer
    data[Nalpha*Ndata+1:end] = ydata

    # Guess
    # Initial model parameters from a y-only weighted fit (x errors ignored here)
    fit = curve_fit(model, xval, yval, 1.0 ./ yer.^2, fill(0.5, param))

    # Generate chi2 + solver
    # Both branches start the minimisation from the guessed model parameters and the measured x values.
    # Printed dof: Ndata*(Nalpha+1) data - (param + Ndata*Nalpha) parameters = Ndata - param
    # (for a vector ydata of length Ndata).
    if covar
        # aux[k] = (x_k1, ..., x_kNalpha, y_k): the uwreals of data point k, used for its covariance
        aux = Vector{Vector{uwreal}}(undef, Ndata)
        for k = 1:Ndata
            aux[k] = Vector{uwreal}(undef, Nalpha+1)
            for i = 1:Nalpha
                aux[k][i] = xdata[k, i]
            end
            aux[k][Nalpha+1] = ydata[k]
        end

        C = isnothing(wpm) ? [ADerrors.cov(aux[k]) for k = 1:Ndata] : [ADerrors.cov(aux[k], wpm) for k = 1:Ndata]
        chisq_full_cov(p, d) = get_chi2_cov(model, d, C, p, Nalpha)
        min_fun_cov(t) = chisq_full_cov(t, dat)
        sol = optimize(min_fun_cov, vcat(fit.param, dat[1:Nalpha*Ndata]), LevenbergMarquardt())
        
        (upar, chi2_exp) = isnothing(wpm) ? fit_error(chisq_full_cov, sol.minimizer, data) : fit_error(chisq_full_cov, sol.minimizer, data, wpm)
        println("Chisq / chiexp: ", min_fun_cov(sol.minimizer), " / ", chi2_exp, " (dof: ", length(ydata) - param,")")
    else
        # Uncorrelated case: diagonal weights from the individual x and y errors
        chisq_full(p, d) = get_chi2(model, d, ddat, p, Nalpha)
        min_fun(t) = chisq_full(t, dat)
        sol = optimize(min_fun, vcat(fit.param, dat[1:Nalpha*Ndata]), LevenbergMarquardt())
        
        (upar, chi2_exp) = isnothing(wpm) ? fit_error(chisq_full, sol.minimizer, data) : fit_error(chisq_full, sol.minimizer, data, wpm)
        println("Chisq / chiexp: ", min_fun(sol.minimizer), " / ", chi2_exp, " (dof: ", length(ydata) - param,")")

    end

    #### chisq_full, min_fun out of conditional ->
    #### COMPILER WARNING ** incremental compilation may be fatally broken for this module **
    #### (so the closures are intentionally defined separately inside each branch)

    # Info
    for i = 1:length(upar)
        isnothing(wpm) ? uwerr(upar[i]) : uwerr(upar[i], wpm)
        print("\n Fit parameter: ", i, ": ")
        details(upar[i])
    end
    return upar
    
end

@doc raw"""
    gen_chisq(f::Function, x::Array{<:Real}, err::Vector{Float64})

Builds the chi^2 function for fits with exactly known abscissae `x`. The returned function
`(par, dat) -> chi^2` sums the squared residuals `dat - f(x, par)`, each divided by `err^2`.
"""
function gen_chisq(f::Function, x::Array{<:Real}, err::Vector{Float64}) #constrained
    return (par, dat) -> sum((dat .- f(x, par)) .^ 2 ./ err .^ 2)
end

@doc raw"""
    get_chi2(f::Function, data, ddata, par, Nalpha)

Chi^2 for fits with uncertainties in both x and y, without x-y correlations.

`data`/`ddata` hold values/errors flattened as `[x column 1, ..., x column Nalpha, y]`, each part
with `Ndata` entries. `par` holds the model parameters followed by the fitted x values in the same
column layout. For every point, the residuals of the x values (measured - fitted) and of y
(measured - `f(xx', p)[1]`) are weighted by `1 / ddata^2`.
"""
function get_chi2(f::Function, data, ddata, par, Nalpha) #full
    chi2 = 0.0

    Ndata = div(length(data), Nalpha+1)
    Npar = length(par) - Ndata * Nalpha
    p = par[1:Npar]

    for k = 1:Ndata
        # Diagonal inverse covariance of point k built from the individual errors
        Cinv = zeros(Nalpha+1, Nalpha+1)
        for i = 1:Nalpha+1
            Cinv[i, i] = 1 / ddata[k + (i-1)*Ndata]^2
        end

        # Fitted x values of point k (previously computed twice per iteration)
        xx = [par[Npar + k + (i-1)*Ndata] for i = 1:Nalpha]
        delta = [data[k + (i-1)*Ndata] - xx[i] for i = 1:Nalpha]
        yy = f(xx', p)
        push!(delta, data[k + Nalpha*Ndata] - yy[1])

        chi2 += delta' * Cinv * delta
    end
    return chi2
end

@doc raw"""
    get_chi2_cov(f::Function, data, C, par, Nalpha)

Chi^2 for fits with uncertainties in both x and y, including the x-y covariance of each point.

`data` holds values flattened as `[x column 1, ..., x column Nalpha, y]` (`Ndata` entries each),
`C[k]` is the `(Nalpha+1) x (Nalpha+1)` covariance matrix of point `k`, and `par` holds the model
parameters followed by the fitted x values. When `det(C[k]) / prod(diag(C[k]))` is not above `1e-6`,
only the diagonal of `C[k]` is used for that point.
"""
function get_chi2_cov(f::Function, data, C, par, Nalpha) # full + cov
    Ndata = div(length(data), Nalpha + 1)
    Npar = length(par) - Ndata * Nalpha
    p = par[1:Npar]
    chi2 = 0.0

    for k in 1:Ndata
        Ck = C[k]
        well_conditioned = det(Ck) / prod(diag(Ck)) > 1e-6
        if well_conditioned
            Cinv = inv(Ck)
        else
            # Nearly singular covariance: keep only the variances
            Cinv = zeros(Nalpha + 1, Nalpha + 1)
            for i in 1:Nalpha+1
                Cinv[i, i] = 1 / Ck[i, i]
            end
        end

        xfit = [par[Npar + k + (i-1)*Ndata] for i in 1:Nalpha]
        resid = [data[k + (i-1)*Ndata] - xfit[i] for i in 1:Nalpha]
        yfit = f(xfit', p)
        push!(resid, data[k + Nalpha*Ndata] - yfit[1])

        chi2 += resid' * Cinv * resid
    end
    return chi2
end