# TODO: include rw; use different plateaux depending on the observable; print chi2; compute t0; compute m_pi
using juobs, ADerrors, DelimitedFiles, PyPlot, LaTeXStrings
# Data directory. Not referenced in this file, so presumably read by a helper such as
# `read_dat` (TODO confirm).
const path = "/home/javier/Lattice/charm/production_2"
# Plateau file. Not referenced in this file; presumably used by `comp_meff`/`comp_f`
# to choose fit ranges (TODO confirm).
const path_plat = "/home/javier/Lattice/juobs/analysis/plat.txt"
# Directory where all PDF figures of this script are saved.
const path_plot = "/home/javier/Lattice/juobs/analysis/plots"

# Ensemble definitions. All vectors below are indexed in the same ensemble order.
const ensembles = ["H400", "N200", "N203", "N300", "J303"]
# When true, the strange-heavy (sh) results of that ensemble are set equal to the
# light-heavy (lh) ones instead of being computed separately. Presumably this marks
# ensembles with degenerate light and strange quarks (TODO confirm).
const deg = [true, false, false, true, false]
# Lattice size passed as the `L=` keyword to `corr_obs`.
const L = [32, 48, 48, 48, 64]
# beta value of each ensemble; used as the argument of t0, a and zm_tm.
const beta = [3.46 , 3.55, 3.55, 3.70, 3.70]
# Replica name(s) passed to `read_dat`; ensembles with several replicas list all of them.
const R = ["H400r001", ["N200r000", "N200r001"], ["N203r000", "N203r001"], "N300r002", "J303r003"]
# Project constants and helpers. Keep these before `phi2`, which calls `t0`
# (presumably defined in juobs_const.jl together with a, zm_tm, M, hc, t0_ph — TODO confirm).
include("/home/javier/Lattice/juobs/constants/juobs_const.jl")
include("/home/javier/Lattice/juobs/analysis/functions.jl")
# phi2 = 8 t0 m_pi^2 per ensemble. The hard-coded numbers are presumably a*m_pi, and
# t0(beta) presumably returns t0/a^2, making phi2 dimensionless (TODO confirm).
const phi2 = 8 .* t0.(beta) .* [0.16345, 0.09222, 0.11224, 0.10630, 0.06514].^2 #8t0 m_pi^2


# Per-ensemble containers filled by the computation loop. Entry iens holds the values
# of ensemble iens, one per heavy-quark mass (they are plotted against muh_pp/muh_aa):
#   m_lh, m_sh           : effective masses from the G5-G5 correlator (lh / sh sector)
#   m_lh_star, m_sh_star : effective masses from the G1G5-G1G5 correlator
#   f_lh, f_sh           : `comp_f` results from the G5-G5 correlator
#   mu_pp, mu_aa         : the `mu` field of every G5-G5 / G1G5-G1G5 correlator
m_lh, m_sh, m_lh_star, m_sh_star =
    (Vector{Vector{uwreal}}(undef, length(ensembles)) for _ in 1:4)

f_lh, f_sh = (Vector{Vector{uwreal}}(undef, length(ensembles)) for _ in 1:2)

mu_pp, mu_aa = (Vector{Vector{Vector{Float64}}}(undef, length(ensembles)) for _ in 1:2)

###########################
###### COMPUTATION ########
###########################
# Per-ensemble analysis:
# - read the G5-G5 (pp) and G1G5-G1G5 (aa1) correlators of all replicas in R[iens];
# - build effective masses and decay constants;
# - split them into the light-heavy (lh) and strange-heavy (sh) sectors.
# On ensembles with deg[iens] == true, the sh entries reuse the lh ones instead of
# calling get_sh.
for iens in eachindex(ensembles)
    ens = ensembles[iens]
    isdeg = deg[iens]

    pp = read_dat(R[iens], "G5", "G5")
    aa1 = read_dat(R[iens], "G1G5", "G1G5")
    pp_obs = corr_obs.(pp, L=L[iens])
    aa1_obs = corr_obs.(aa1, L=L[iens])
    mu_pp[iens] = getfield.(pp_obs, :mu)
    mu_aa[iens] = getfield.(aa1_obs, :mu)

    m = comp_meff(pp_obs, isdeg, ens)        # G5-G5 effective masses
    m_star = comp_meff(aa1_obs, isdeg, ens)  # G1G5-G1G5 effective masses
    f = comp_f(pp_obs, m, isdeg, ens)        # decay constants from pp_obs and m

    # One conditional expression per assignment (instead of `c ? a = x : a = y`).
    m_lh[iens] = get_lh(mu_pp[iens], m, isdeg)
    m_sh[iens] = isdeg ? m_lh[iens] : get_sh(mu_pp[iens], m, isdeg)

    m_lh_star[iens] = get_lh(mu_aa[iens], m_star, isdeg)
    m_sh_star[iens] = isdeg ? m_lh_star[iens] : get_sh(mu_aa[iens], m_star, isdeg)

    f_lh[iens] = get_lh(mu_pp[iens], f, isdeg)
    f_sh[iens] = isdeg ? f_lh[iens] : get_sh(mu_pp[iens], f, isdeg)
end

# Split each ensemble's `mu` values with `get_mu` and keep elements 1, 2 and 3 of its
# result (named l/s/h: presumably light, strange and heavy masses). This is done for
# the G5-G5 (pp) and G1G5-G1G5 (aa) correlators; `mm` ends up holding the aa result.
mm = get_mu.(mu_pp, deg)
mul_pp, mus_pp, muh_pp = (getindex.(mm, j) for j in 1:3)

mm = get_mu.(mu_aa, deg)
mul_aa, mus_aa, muh_aa = (getindex.(mm, j) for j in 1:3)

# One value per ensemble, evaluated at the tuned heavy-quark mass muh_target:
#   m_*_match   : masses from the G5-G5 correlator (lh / sh)
#   m_*_v_match : masses from the G1G5-G1G5 correlator (lh / sh)
#   f_*_match   : decay constants (lh / sh)
m_lh_match, m_sh_match, m_lh_v_match, m_sh_v_match, f_lh_match, f_sh_match, muh_target =
    (Vector{uwreal}(undef, length(ensembles)) for _ in 1:7)

###########################
###### MATCHING ###########
###########################
# Fit `ys` against `xs` with `lin_fit` (presumably a straight-line fit) and evaluate
# the fit at `x0` with `y_lin_fit`. The second value returned by `lin_fit`
# (presumably its expected chi2) is not used.
function lin_interp_at(xs, ys, x0)
    par, _ = lin_fit(xs, ys)
    return y_lin_fit(par, x0)
end

# Tune the heavy-quark mass on each ensemble, then interpolate every observable to it.
# `uwerr` is called on the inputs and results so their errors are available for plots.
for iens in eachindex(ensembles)
    # Matching target in lattice units. Presumably M = [M_D, M_D*, M_Ds, M_Ds*] in MeV
    # and hc = hbar*c in MeV fm. Then this is a * (2(M_D + 3M_D*) + (M_Ds + 3M_Ds*))/12,
    # the spin- and flavour-averaged charm-meson mass (TODO confirm).
    target = a(beta[iens]) * (2*M[1] + 6*M[2] + M[3] + 3*M[4]) / (12*hc)
    muh_target[iens] = if deg[iens]
        match_muc(muh_pp[iens], m_lh[iens], m_lh_star[iens], target)
    else
        match_muc(muh_pp[iens], m_lh[iens], m_sh[iens], m_lh_star[iens], m_sh_star[iens], target)
    end
    uwerr(muh_target[iens])
    mu_t = muh_target[iens]

    # lh masses (G5-G5 and G1G5-G1G5) at the tuned heavy-quark mass
    m_lh_match[iens] = lin_interp_at(muh_pp[iens], m_lh[iens], mu_t)
    m_lh_v_match[iens] = lin_interp_at(muh_aa[iens], m_lh_star[iens], mu_t)
    uwerr.(m_lh[iens])
    uwerr.(m_lh_star[iens])
    uwerr(m_lh_match[iens])
    uwerr(m_lh_v_match[iens])

    # sh masses; degenerate ensembles reuse the lh results
    if deg[iens]
        m_sh_match[iens] = m_lh_match[iens]
        m_sh_v_match[iens] = m_lh_v_match[iens]
    else
        m_sh_match[iens] = lin_interp_at(muh_pp[iens], m_sh[iens], mu_t)
        m_sh_v_match[iens] = lin_interp_at(muh_aa[iens], m_sh_star[iens], mu_t)
        uwerr.(m_sh[iens])
        uwerr.(m_sh_star[iens])
        uwerr(m_sh_match[iens])
        uwerr(m_sh_v_match[iens])
    end

    # decay constants (G5-G5 heavy masses as abscissa)
    f_lh_match[iens] = lin_interp_at(muh_pp[iens], f_lh[iens], mu_t)
    uwerr.(f_lh[iens])
    uwerr(f_lh_match[iens])
    if deg[iens]
        f_sh_match[iens] = f_lh_match[iens]
    else
        f_sh_match[iens] = lin_interp_at(muh_pp[iens], f_sh[iens], mu_t)
        uwerr.(f_sh[iens])
        uwerr(f_sh_match[iens])
    end
end
###########################
###### PLOTS ##############
###########################
# Diagnostic plots. For every ensemble, each observable is plotted against the heavy-quark
# mass together with its value at muh_target. Figures are displayed and saved as PDFs in
# path_plot. Ensembles with deg == true get no separate D_s mass plot.
for (iens, ens) in enumerate(ensembles)
    mu_t = muh_target[iens]

    # D and D* masses (lh sector)
    figure()
    title(ens)
    xlabel(L"$a\mu$")
    ylabel(L"$aM$")
    errorbar(muh_pp[iens], value.(m_lh[iens]), err.(m_lh[iens]), fmt="x")
    errorbar(value(mu_t), value(m_lh_match[iens]), err(m_lh_match[iens]), err(mu_t), fmt="x")
    errorbar(muh_aa[iens], value.(m_lh_star[iens]), err.(m_lh_star[iens]), fmt="x")
    errorbar(value(mu_t), value(m_lh_v_match[iens]), err(m_lh_v_match[iens]), err(mu_t), fmt="x")
    legend([L"$m_D$", L"$m_D (\mathrm{int})$", L"$m_{D^*}$", L"$m_{D^*} (\mathrm{int})$"])
    display(gcf())
    savefig(joinpath(path_plot, "$(ens)_mD.pdf"))
    close()

    # D_s and D_s* masses (sh sector), only on non-degenerate ensembles
    if !deg[iens]
        figure()
        title(ens)
        xlabel(L"$a\mu$")
        ylabel(L"$aM$")
        errorbar(muh_pp[iens], value.(m_sh[iens]), err.(m_sh[iens]), fmt="x")
        errorbar(value(mu_t), value(m_sh_match[iens]), err(m_sh_match[iens]), err(mu_t), fmt="x")
        errorbar(muh_aa[iens], value.(m_sh_star[iens]), err.(m_sh_star[iens]), fmt="x")
        errorbar(value(mu_t), value(m_sh_v_match[iens]), err(m_sh_v_match[iens]), err(mu_t), fmt="x")
        legend([L"$m_{D_s}$", L"$m_{D_s} (\mathrm{int})$", L"$m_{D^*_s}$", L"$m_{D^*_s} (\mathrm{int})$"])
        display(gcf())
        savefig(joinpath(path_plot, "$(ens)_mDs.pdf"))
        close()
    end

    # Decay constants: f_D always, f_Ds only on non-degenerate ensembles
    figure()
    title(ens)
    xlabel(L"$a\mu$")
    ylabel(L"$af$")
    errorbar(muh_pp[iens], value.(f_lh[iens]), err.(f_lh[iens]), fmt="x")
    errorbar(value(mu_t), value(f_lh_match[iens]), err(f_lh_match[iens]), err(mu_t), fmt="x")
    f_labels = [L"$f_D$", L"$f_D (\mathrm{int})$"]
    if !deg[iens]
        errorbar(muh_pp[iens], value.(f_sh[iens]), err.(f_sh[iens]), fmt="x")
        errorbar(value(mu_t), value(f_sh_match[iens]), err(f_sh_match[iens]), err(mu_t), fmt="x")
        append!(f_labels, [L"$f_{D_s}$", L"$f_{D_s} (\mathrm{int})$"])
    end
    legend(f_labels)
    display(gcf())
    savefig(joinpath(path_plot, "$(ens)_fD.pdf"))
    close()
end
###########################
###### RESULTS ###########
###########################
# Express the matched lattice quantities in units of sqrt(8t0). t0(beta) presumably
# returns t0/a^2 (cf. phi2 = 8 t0 m_pi^2), so sqrt(8 t0/a^2) * (a X) = sqrt(8t0) X.
# The factor was previously recomputed for every observable; compute it once here.
sqrt8t0 = sqrt.(8 * t0.(beta))

#Quark mass, multiplied by zm_tm (Z^{tm}_M in the printed label)
muc = zm_tm.(beta) .* muh_target .* sqrt8t0
uwerr.(muc)
#Meson masses
m_D = m_lh_match .* sqrt8t0
m_Ds = m_sh_match .* sqrt8t0
m_D_star = m_lh_v_match .* sqrt8t0
m_Ds_star = m_sh_v_match .* sqrt8t0
#Decay const
f_D = f_lh_match .* sqrt8t0
f_Ds = f_sh_match .* sqrt8t0
for v in (m_D, m_Ds, m_D_star, m_Ds_star, f_D, f_Ds)
    uwerr.(v)
end

# Print every result per ensemble, prefixed with "(<ensemble>)".
for iens in eachindex(ensembles)
    ens = ensembles[iens]
    println("(", ens, ")", L"$Z^{tm}_M \mu_c \sqrt{8t_0} = $", muc[iens])
    println("(", ens, ")", L"$M_D \sqrt{8t_0}= $", m_D[iens])
    println("(", ens, ")", L"$M_Ds \sqrt{8t_0}= $", m_Ds[iens])
    println("(", ens, ")", L"$M_D* \sqrt{8t_0}= $", m_D_star[iens])
    println("(", ens, ")", L"$M_Ds* \sqrt{8t_0}= $", m_Ds_star[iens])
    println("(", ens, ")", L"$f_D \sqrt{8t_0}= $", f_D[iens])
    println("(", ens, ")", L"$f_Ds \sqrt{8t_0}= $", f_Ds[iens])
end

###########################
###### FITS ###############
###########################
# Fit abscissae, one row per ensemble:
#   column 1 = a^2/(8t0), computed as 1/(8 t0(beta))
#   column 2 = phi2 = 8t0 m_pi^2
x = hcat(1 ./ (8 .* t0.(beta)), phi2)
uwerr.(x)
# Physical point phi2 = 8t0 m_pi^2 using the hard-coded pion mass below. Presumably the
# mass is in MeV, t0_ph[1] is sqrt(8t0) in fm and hc is hbar*c in MeV fm (TODO confirm).
phi2_ph = let mpi_phys = 139.57039
    (t0_ph[1] * mpi_phys / hc)^2
end
uwerr(phi2_ph)
# Fit model for the sqrt(8t0)-scaled observables, evaluated row-wise on x:
#   x[:, 1] = a^2/(8t0), x[:, 2] = phi2
#   f(a^2/8t0, phi2) = p[1] + p[2] a^2/8t0 + (p[3] + p[4] a^2/8t0) * phi2
# Returns one model value per row of x.
function model(x, p)
    a2 = x[:, 1]
    ph = x[:, 2]
    return @. (p[1] + p[2] * a2) + (p[3] + p[4] * a2) * ph
end
# Fit every sqrt(8t0)-scaled observable with `model`, using the abscissae `x`
# (column 1: a^2/(8t0), column 2: phi2). Evaluate the fit in the continuum limit (CL)
# at the physical phi2_ph, then plot the data per beta together with the CL point.
obs = [muc, m_D, m_Ds, m_D_star, m_Ds_star, f_D, f_Ds] #sqrt(8t0) obs
ttl_obs = ["muc", "m_D", "m_Ds", "m_D_star", "m_Ds_star", "f_D", "f_Ds"]
ylbl = [L"$Z^{tm}_M \mu_c \sqrt{8t_0}$", L"$M_D \sqrt{8t_0}$", L"$M_{D_s} \sqrt{8t_0}$",
L"$M_{D^*} \sqrt{8t_0}$", L"$M_{D^*_s} \sqrt{8t_0}$", L"$f_D \sqrt{8t_0}$", L"$f_{D_s} \sqrt{8t_0}$"]
xlbl = L"$\phi_2$"
obs_t0 = Vector{uwreal}(undef, length(obs)) #sqrt(8t0)obs @ CL & phi2_phys
for k in eachindex(obs)
    println("OBS ", ttl_obs[k])
    # presumably the last argument is the number of parameters of `model` (TODO confirm)
    par = fit_routine(model, value.(x), obs[k], 4)
    # model at a^2/(8t0) = 0 and phi2 = phi2_ph: the p[2] and p[4] terms vanish
    obs_t0[k] = par[1] + par[3] * phi2_ph
    uwerr(obs_t0[k])

    figure()
    # One data series per distinct beta. `==(b)` replaces the lambda `x -> x == b`,
    # which shadowed the global fit-abscissa matrix `x` used on the next lines.
    for b in unique(beta)
        nn = findall(==(b), beta)
        lgnd = string(L"$\beta = $", b)
        errorbar(value.(x[nn, 2]), value.(obs[k][nn]), err.(obs[k][nn]), err.(x[nn, 2]), fmt="x", label=lgnd)
    end
    lgnd = L"$\mathrm{CL}$"
    errorbar(value(phi2_ph), value(obs_t0[k]), err(obs_t0[k]), err(phi2_ph), fmt="x", zorder=2, label=lgnd)
    axvline(value(phi2_ph), ls="--", color="black", zorder=1, lw=0.6, label="")
    xlabel(xlbl)
    ylabel(ylbl[k])

    legend()
    display(gcf())
    savefig(joinpath(path_plot, string("fit_", ttl_obs[k], ".pdf")))
    close()
end

# Convert the continuum results obs_t0 (sqrt(8t0)-scaled) to the values printed as MeV,
# and print both. Presumably hc is hbar*c in MeV fm and t0_ph[1] the physical sqrt(8t0)
# in fm (TODO confirm).
obs_ph = Vector{uwreal}(undef, length(obs))
foreach(eachindex(obs)) do k
    println(ylbl[k], " = ", obs_t0[k])
    obs_ph[k] = obs_t0[k] * hc / t0_ph[1]
    uwerr(obs_ph[k])
    println(ttl_obs[k], "(MeV) = ", obs_ph[k])
end