###
### "THE BEER-WARE LICENSE":
### Alberto Ramos and Carlos Pena wrote this file. As long as you retain this  
### notice you can do whatever you want with this stuff. If we meet some 
### day, and you think this stuff is worth it, you can buy us a beer in 
### return. <alberto.ramos@cern.ch> <carlos.pena@uam.es>
###
### file:    Dirac.jl
### created: Thu Nov 18 17:20:24 2021
###                               


module Dirac

using CUDA, TimerOutputs
using ..Space
using ..Groups
using ..Fields
using ..YM
using ..Spinors

"""
    DiracParam{T}

Parameters of the Wilson-Dirac operator implemented in this module.

# Fields
- `rep`: representation of the fermion field; `pfrandomize!` only acts when it is `SU3fund`.
- `m0`: mass parameter; the diagonal term of the operator is `(4 + m0)`.
- `csw`: clover coefficient. NOTE(review): stored, but not read by any operator in this
  module (no clover term is implemented here).
- `th`: one complex factor per direction; `th[μ]` multiplies the forward hopping term in
  direction `μ` and `conj(th[μ])` the backward one.
- `ct`: boundary coefficient (c̃_t) for Schrödinger-functional boundary conditions; the
  operator gains an extra `(ct - 1)` times its input on time slices `2` and `lp.iL[4]`.
"""
struct DiracParam{T}
    rep
    m0::T
    csw::T
    th::NTuple{4,Complex{T}}
    ct::T
end


"""
    DiracWorkspace(G, T, lp::SpaceParm{4,6,B,D})

Bundle of four spinor fields of type `Spinor{4,G}` on the lattice `lp` (`sr`, `sp`, `sAp`,
`st`), meant as scratch storage. NOTE(review): the names suggest residual, search direction,
`A*p` and a temporary of a CG-type solver — confirm against the callers. `T` only sets the
type parameter of the result; it does not affect the allocated fields.
"""
struct DiracWorkspace{T}
    sr
    sp
    sAp
    st

    function DiracWorkspace(::Type{G}, ::Type{T}, lp::SpaceParm{4,6,B,D}) where {G,T <: AbstractFloat, B,D}

        sr  = scalar_field(Spinor{4,G}, lp)
        sp  = scalar_field(Spinor{4,G}, lp)
        sAp = scalar_field(Spinor{4,G}, lp)
        st  = scalar_field(Spinor{4,G}, lp)
        return new{T}(sr,sp,sAp,st)
    end
end
export DiracWorkspace, DiracParam

# True for the Schrödinger-functional boundary conditions that need the boundary kernels below.
_is_sf(B) = B == BC_SF_AFWB || B == BC_SF_ORBI

# Launch `krnl(args...)` with `lp.bsz` threads per block and `lp.rsz` blocks, and block until
# the device has finished. Every kernel in this module reads its site as
# `(b, r) = (threadIdx().x, blockIdx().x)`, which is what this launch layout provides.
# `F` / `Vararg{Any,N}` force specialization on the kernel and its argument types.
function _launch!(krnl::F, lp::SpaceParm, args::Vararg{Any,N}) where {F,N}
    CUDA.@sync begin
        CUDA.@cuda threads=lp.bsz blocks=lp.rsz krnl(args...)
    end
    return nothing
end

"""
    Dw!(so, U, si, dpar::DiracParam, lp::SpaceParm{4,6,B,D})

Apply the Wilson-Dirac operator with gauge field `U` and parameters `dpar`: `so = Dw si`.

For Schrödinger-functional boundary conditions (`B == BC_SF_AFWB` or `B == BC_SF_ORBI`):
- `si` is modified in place: its first time slice is set to zero before `Dw` is applied;
- `so` is set to zero on the first time slice and receives an extra `(dpar.ct - 1) * si`
  on time slices `2` and `lp.iL[4]`.

Only the operator kernel itself is recorded under the timer `"Dw"`.
"""
function Dw!(so, U, si, dpar::DiracParam, lp::SpaceParm{4,6,B,D}) where {B,D}

    sf = _is_sf(B)

    if sf
        _launch!(krnl_sfbndfix!, lp, si, si, 1.0, lp)
    end
    @timeit "Dw" _launch!(krnl_Dw!, lp, so, U, si, dpar.m0, dpar.th, lp)
    if sf
        _launch!(krnl_sfbndfix!, lp, so, si, dpar.ct, lp)
    end

    return nothing
end

"""
    DwdagDw!(so, U, si, dpar::DiracParam, st, lp::SpaceParm{4,6,B,D})

Compute `so = γ5 Dw γ5 Dw si` (i.e. `Dw† Dw si` when `Dw` is γ5-hermitian), using `st` as a
temporary that ends up holding `γ5 Dw si`.

For Schrödinger-functional boundary conditions, `si` is zeroed in place on its first time
slice. Each application of `Dw` includes its `(dpar.ct - 1)` boundary term *before* the
multiplication by γ5, so that term is rotated by γ5 together with the rest of the operator.
"""
function DwdagDw!(so, U, si, dpar::DiracParam, st, lp::SpaceParm{4,6,B,D}) where {B,D}

    if _is_sf(B)
        _launch!(krnl_sfbndfix!, lp, si, si, 1.0, lp)
        @timeit "DwdagDw" begin
            # st = γ5 Dw si  (boundary term multiplied by γ5 as well)
            _launch!(krnl_g5Dw!, lp, st, U, si, dpar.m0, dpar.th, lp)
            _launch!(krnl_g5sfbndfix!, lp, st, si, dpar.ct, lp)
            # so = γ5 Dw st; st is already zero on the first time slice
            _launch!(krnl_g5Dw!, lp, so, U, st, dpar.m0, dpar.th, lp)
            _launch!(krnl_g5sfbndfix!, lp, so, st, dpar.ct, lp)
        end
    else
        @timeit "DwdagDw" begin
            _launch!(krnl_g5Dw!, lp, st, U, si, dpar.m0, dpar.th, lp)
            _launch!(krnl_g5Dw!, lp, so, U, st, dpar.m0, dpar.th, lp)
        end
    end

    return nothing
end

"""
    g5Dw!(so, U, si, dpar, lp::SpaceParm{4,6,B,D})

Compute `so = γ5 Dw si`, with `Dw` as applied by `Dw!`.

For Schrödinger-functional boundary conditions, `si` is zeroed in place on its first time
slice. The `(dpar.ct - 1) * si` boundary term of `Dw` is multiplied by γ5 together with the
rest of the operator. Only the operator kernel is recorded under the timer `"g5Dw"`.
"""
function g5Dw!(so, U, si, dpar, lp::SpaceParm{4,6,B,D}) where {B,D}

    sf = _is_sf(B)

    if sf
        _launch!(krnl_sfbndfix!, lp, si, si, 1.0, lp)
    end
    @timeit "g5Dw" _launch!(krnl_g5Dw!, lp, so, U, si, dpar.m0, dpar.th, lp)
    if sf
        # The boundary term belongs to Dw, so it must be rotated by γ5 too.
        _launch!(krnl_g5sfbndfix!, lp, so, si, dpar.ct, lp)
    end

    return nothing
end

# Kernel: SF boundary fix-up of an operator output `so` computed from `si`.
# Zeroes `so` on the first time slice and adds `(ct - 1) * si` on time slices 2 and `lp.iL[4]`.
# Called with `so === si` and `ct = 1.0`, it just zeroes the first time slice of `si`.
function krnl_sfbndfix!(so, si, ct, lp::SpaceParm)
    b = Int64(CUDA.threadIdx().x)
    r = Int64(CUDA.blockIdx().x)

    it = point_time((b, r), lp)
    if it == 1
        so[b, r] = 0.0 * so[b, r]
    elseif it == 2 || it == lp.iL[4]
        so[b, r] += (ct - 1.0) * si[b, r]
    end
    return nothing
end

# Kernel: same as `krnl_sfbndfix!`, for an output that was already multiplied by γ5.
# The term added on time slices 2 and `lp.iL[4]` is `(ct - 1) * γ5 * si`.
function krnl_g5sfbndfix!(so, si, ct, lp::SpaceParm)
    b = Int64(CUDA.threadIdx().x)
    r = Int64(CUDA.blockIdx().x)

    it = point_time((b, r), lp)
    if it == 1
        so[b, r] = 0.0 * so[b, r]
    elseif it == 2 || it == lp.iL[4]
        so[b, r] += (ct - 1.0) * dmul(Gamma{5}, si[b, r])
    end
    return nothing
end

# Value of `Dw si` at site `(b, r)`, shared by `krnl_Dw!` and `krnl_g5Dw!`.
# It is the diagonal term `(4 + m0) * si` minus one half of the hopping terms to the forward
# (weighted by `th[μ]`) and backward (weighted by `conj(th[μ])`) neighbours in the four
# directions.
# For SF boundary conditions:
#  - c̃_t affects the mass term at x0 = a, T-a (added afterwards by the boundary kernels);
#  - the spinor can be periodic as long as it is 0 at x0 = 0.
@inline function _dw_site(U, si, m0, th, b, r, lp::SpaceParm)
    psi = @inbounds begin
        bu1, ru1 = up((b,r), 1, lp)
        bd1, rd1 = dw((b,r), 1, lp)
        bu2, ru2 = up((b,r), 2, lp)
        bd2, rd2 = dw((b,r), 2, lp)
        bu3, ru3 = up((b,r), 3, lp)
        bd3, rd3 = dw((b,r), 3, lp)
        bu4, ru4 = up((b,r), 4, lp)
        bd4, rd4 = dw((b,r), 4, lp)

        (4+m0)*si[b,r] - 0.5*( th[1]*gpmul(Pgamma{1,-1},U[b,1,r],si[bu1,ru1]) + conj(th[1])*gdagpmul(Pgamma{1,+1},U[bd1,1,rd1],si[bd1,rd1]) +
                               th[2]*gpmul(Pgamma{2,-1},U[b,2,r],si[bu2,ru2]) + conj(th[2])*gdagpmul(Pgamma{2,+1},U[bd2,2,rd2],si[bd2,rd2]) +
                               th[3]*gpmul(Pgamma{3,-1},U[b,3,r],si[bu3,ru3]) + conj(th[3])*gdagpmul(Pgamma{3,+1},U[bd3,3,rd3],si[bd3,rd3]) +
                               th[4]*gpmul(Pgamma{4,-1},U[b,4,r],si[bu4,ru4]) + conj(th[4])*gdagpmul(Pgamma{4,+1},U[bd4,4,rd4],si[bd4,rd4]) )
    end
    return psi
end

# Kernel: `so = Dw si`, one site `(b, r) = (threadIdx().x, blockIdx().x)` per thread.
# `so` must not alias `si`, because other threads read the neighbouring sites of `si`.
function krnl_Dw!(so, U, si, m0, th, lp::SpaceParm{4,6,B,D}) where {B,D}
    b = Int64(CUDA.threadIdx().x);  r = Int64(CUDA.blockIdx().x)

    @inbounds so[b,r] = _dw_site(U, si, m0, th, b, r, lp)
    return nothing
end

# Kernel: `so = γ5 Dw si`, same site layout and aliasing restriction as `krnl_Dw!`.
function krnl_g5Dw!(so, U, si, m0, th, lp::SpaceParm{4,6,B,D}) where {B,D}
    b = Int64(CUDA.threadIdx().x);  r = Int64(CUDA.blockIdx().x)

    @inbounds so[b,r] = dmul(Gamma{5}, _dw_site(U, si, m0, th, b, r, lp))
    return nothing
end



###############################   HMC for fermions   ###################################



"""
    pfrandomize!(f, lp::SpaceParm, dpar::DiracParam, t::Int64=0)

Fill the pseudofermion field `f` with Gaussian random numbers: the real and imaginary parts
of every colour component are drawn independently with `CUDA.randn`. With `t == 0` every
site is filled; otherwise only sites on time slice `t` are written, and the rest of `f` is
left untouched.

Only `dpar.rep == SU3fund` with `lp.ndim == 4` is implemented. For any other case the
function does nothing. NOTE(review): `dpar` is currently only used to select the
representation.
"""
function pfrandomize!(f,lp::SpaceParm,dpar::DiracParam,t::Int64=0)

    if dpar.rep == SU3fund && lp.ndim == 4
        @timeit "Randomize pseudofermion field" begin
            # One real array per spinor component, indexed [b, colour, r, re/im].
            # Real normals are drawn because complex generation was not supported in Julia 1.5.4.
            p = ntuple(i -> CUDA.randn(Float64, lp.bsz, 3, lp.rsz, 2), 4)
            _launch!(krnl_assign_pf!, lp, f, p, lp, t)
        end
    end

    return nothing
end

# Kernel: build the SU3fund spinor at `(b, r)` from the random arrays `p` and store it in `f`
# (every site when `t == 0`, otherwise only sites on time slice `t`).
# TODO(review): only valid for SU3fund; consider generating the tuple inside the kernel if
# performance requires it.
function krnl_assign_pf!(f::AbstractArray{T}, p ,lp::SpaceParm, t::Int64) where {T}

    @inbounds begin
        b = Int64(CUDA.threadIdx().x)
        r = Int64(CUDA.blockIdx().x)

        if t == 0 || point_time((b, r), lp) == t
            f[b,r] = Spinor(map(x -> SU3fund(x[b,1,r,1] + im*x[b,1,r,2],
                                             x[b,2,r,1] + im*x[b,2,r,2],
                                             x[b,3,r,1] + im*x[b,3,r,2]), p))
        end
    end

    return nothing
end

export Dw!, DwdagDw!, g5Dw!, pfrandomize!

end