// (c) Microsoft Corporation 2005-2007.  

#light

module Microsoft.FSharp.Math.Primitives.FFT 

    open Microsoft.FSharp.Core
    open Microsoft.FSharp.Core.LanguagePrimitives.IntrinsicOperators
    open Microsoft.FSharp.Core.Operators
    open Microsoft.FSharp.Collections
    open Microsoft.FSharp.Primitives.Basics

    //-------------------------------------------------------------------------
    // misc
    //-----------------------------------------------------------------------

    #if SELFTEST
    let check b =
        // Self-test assertion helper: does nothing when the condition holds and
        // raises Failure otherwise.
        if b then () else failwith "assertion failwith"
    #endif

    let pow32 x n =
        // Computes x^n on 32-bit ints by binary exponentiation (square and multiply).
        // The arithmetic is plain int arithmetic: no overflow checking is performed.
        // The exponent is consumed one binary digit at a time; n = 0 yields 1.
        let mutable acc      = 1
        let mutable square   = x
        let mutable exponent = n
        while exponent <> 0 do
            // An odd remaining exponent contributes the current square to the result.
            if exponent % 2 <> 0 then acc <- acc * square
            square   <- square * square
            exponent <- exponent / 2
        acc

    let leastBounding2Power b =
        // Returns the pair (tp, i) where tp = 2^i is the least power of two with b <= tp.
        // Any b <= 1 (including zero and negative values) yields (1, 0).
        //
        // Fails when b exceeds 2^30, the largest power of two a 32-bit int can hold:
        // doubling beyond it wraps to a negative value and then to 0, so the search
        // below would otherwise never terminate.
        let largestTwoPower = 0x40000000
        let rec findBounding2Power tp i =
            if b <= tp then tp,i else findBounding2Power (tp*2) (i+1)
        if b > largestTwoPower then
            failwith "leastBounding2Power: bound exceeds 2^30"
        else
            findBounding2Power 1 0

    //-------------------------------------------------------------------------
    // p = 2^k.m + 1 prime and w primitive 2^k root of 1 mod p
    //-----------------------------------------------------------------------

    // Given p = 2^k.m + 1 prime and w a primitive 2^k root of unity (mod p).
    // Required to define arithmetic ops for Fp = field modulo p.
    // The following are possible choices for p.
       
    //                            p,  k,  m,       g,  w 
    //  let p,k,m,g,w =         97L,  4,  6,       5,   8           // p is  7 bit 
    //  let p,k,m,g,w =        769L,  8,  3,       7,   7           // p is 10 bit 
    //  let p,k,m,g,w =       7681L,  8, 30,      13, 198           // p is 13 bit 
    //  let p,k,m,g,w =      12289L, 10, 12,      11,  49           // p is 14 bit 
    //  let p,k,m,g,w =  167772161L, 25,  5,  557092, 39162105      // p is 28 bit 
    //  let p,k,m,g,w =  469762049L, 26,  7, 1226571, 288772249     // p is 29 bit 
    

    // Selected field parameters, with p = 2^k.m + 1 (p is 31 bit).
    let p = 2013265921L     // the prime modulus: 2^27 * 15 + 1
    let k = 27              // exponent of the power-of-two factor of p - 1
    let m = 15              // odd cofactor of p - 1
    let g = 31              // the "g" column of the parameter table (presumably a generator mod p - TODO confirm)
    let w = 440564289       // the "w" column: primitive 2^k root of unity mod p
    let primeP = p  

    // Every value of at most this many bits is strictly below p (2^30 < p).
    let max_bits_inside_p = 30    


    //-------------------------------------------------------------------------
    // Fp = finite field mod p - rep is uint32
    //-----------------------------------------------------------------------


    type fp = uint32
    // Operations in Fp, the finite field of size p. Elements are represented as
    // uint32 values; the arithmetic below assumes its operands are already
    // reduced, i.e. lie in the range [0, p-1].
    module Fp = 
        // The modulus, as a field-sized value and as a 64-bit value for products.
        let p   = 2013265921ul : fp
        let p64 = 2013265921UL : uint64
        // Conversions to and from int. No reduction modulo p is performed here;
        // callers are presumably expected to pass values in [0, p-1] - TODO confirm.
        let to_int   (x:fp) : int = int32 x
        let of_int   (x:int) : fp = uint32 x

        let mzero : fp = 0ul
        let mone  : fp = 1ul
        let mtwo  : fp = 2ul
        // For reduced operands x + y < 2p < 2^32, so the sum cannot wrap.
        let inline madd (x:fp) (y:fp) : fp = (x + y) % p
        // Adding p first keeps the intermediate value non-negative.
        let inline msub (x:fp) (y:fp) : fp = (x + p - y) % p
        // The product is formed in 64 bits so that it cannot overflow before reduction.
        let inline mmul (x:fp) (y:fp) : fp = uint32 ((uint64 x * uint64 y) % p64)

        // x^n in Fp by binary exponentiation, for an int exponent n >= 0.
        let rec mpow x n =
            if n=0       then mone
            elif n % 2=0 then mpow (mmul x x) (n / 2)
            else              mmul x (mpow (mmul x x) (n / 2))
                
        // x^n in Fp by binary exponentiation, for an int64 exponent n >= 0.
        let rec mpowL x n =
            if   n = 0L      then mone
            elif n % 2L = 0L then mpowL (mmul x x) (n / 2L)
            else                  mmul x (mpowL (mmul x x) (n / 2L))
                
        // Given that w is a primitive (2^k)th root of 1 in Zp,
        // returns a primitive (2^n)th root of 1, for 0 <= n <= k.
        let m2PowNthRoot n =
            // Find x s.t. x is (2^n)th root of unity.
            //
            //   pow w (pow 2 k) = 1 primitively.
            // = pow w (pow 2 ((k-n)+n))
            // = pow w (pow 2 (k-n) * pow 2 n)
            // = pow (pow w (pow 2 (k-n))) (pow 2 n)
            //
            // Take wn = pow w (pow 2 (k-n))
            //
            // Outside 0..k no such root can be derived from w: a negative k-n would
            // silently produce an unrelated power of w, so reject it explicitly.
            if n < 0 || n > k then failwith "m2PowNthRoot: n must be in the range 0..k"
            else mpow (uint32 w) (pow32 2 (k-n))
            
        // Multiplicative inverse computed as x^(p-2) (Fermat's little theorem, p prime).
        // Note that x = 0 has no inverse; the computation returns 0 for it.
        let minv x = mpowL x (primeP - 2L)


    //-------------------------------------------------------------------------
    // FFT - in place low garbage
    //-----------------------------------------------------------------------

    open Fp
    let rec computeFFT lambda mu n w u res offset =
        // Given n a 2-power (n >= 1),
        //       w an nth root of 1 in Fp, and
        //       lambda, mu and u(x) defining
        //       poly(lambda,mu,x) = sum(i<n) u.[lambda.i + mu].x^i
        //
        //       Note, "lambda.i + mu" for i=0...(n-1) selects the coefficients of the
        //       odd/even sub polynomials of u(x) without copying them.
        // 
        // Writes res.[offset+j] = poly(lambda,mu,w^j) for j<n. Returns unit; u is only read.
        // The recursion bottoms out only at n=1, so n must be an exact power of two.
        // ---
        // poly(lambda,mu,x) = sum(i<n/2) u.[lambda.2i + mu] * x^2i  + x.sum(i<n/2) u.[lambda.(2i+1) + mu] * x^2i
        //                   = poly(2.lambda,mu,x^2)                 + x.poly(2.lambda,lambda+mu,x^2)
        // ---
        // Recursively call s.t.
        // For j<n/2,
        //   res.[offset+j    ] = poly(2.lambda,mu       ,(w^2)^j)
        //   res.[offset+j+n/2] = poly(2.lambda,lambda+mu,(w^2)^j)
        // Then, for j<n/2, combine in place:
        //   even = res.[offset+j]
        //   odd  = res.[offset+j+n/2]
        //   res.[offset+j]     = even + w^j * odd
        //   res.[offset+j+n/2] = even - w^j * odd
         
        if n=1 then
            res.[offset] <- u.[mu]
        else
            let halfN       = n/2 
            let ww          = mmul w w 
            let offsetHalfN = offset + halfN 
            // Transform the even-indexed coefficients into the lower half of the
            // window and the odd-indexed ones into the upper half.
            computeFFT (lambda*2) mu            halfN ww u res offset;
            computeFFT (lambda*2) (lambda + mu) halfN ww u res offsetHalfN;
            // wj tracks w^j across the butterfly loop.
            let mutable wj  = mone 
            for j = 0 to halfN-1 do
                let even  = res.[offset+j]      
                let odd   = res.[offsetHalfN+j] 
                // The twiddled odd term is shared by both outputs; compute it once.
                let wjOdd = mmul wj odd 
                res.[offset+j]      <- madd even wjOdd;
                res.[offsetHalfN+j] <- msub even wjOdd;
                wj <- mmul w wj

    let computFftInPlace n w u =
        // Evaluates u(x) = sum(i<n) u.[i] * x^i at the powers of w.
        //   n : a power of 2
        //   w : a primitive nth root of unity in Fp
        // Returns a freshly allocated array of length n whose jth entry is u(w^j);
        // u itself is only read.
        let transformed = Array.create n mzero 
        // Start from the whole polynomial: stride 1, first coefficient 0, output offset 0.
        computeFFT 1 0 n w u transformed 0;
        transformed

    let computeInverseFftInPlace n w uT =
        // Inverse transform: runs the forward transform with the inverse root (minv w)
        // and then scales every entry by the inverse of n in Fp.
        let scale = minv (uint32 n) 
        computFftInPlace n (minv w) uT
        |> Array.map (fun coeff -> mmul scale coeff)

    //-------------------------------------------------------------------------
    // FFT - polynomial product
    //-----------------------------------------------------------------------

    // Largest exponent k for which the table below holds 2^k.
    let maxTwoPower   = 29
    // twoPowerTable.[i] = 2^i for every i in 0..maxTwoPower.
    // The table must include index maxTwoPower itself, because the self-test
    // admits any k <= maxTwoPower and then reads twoPowerTable.[k].
    let twoPowerTable = Array.init (maxTwoPower+1) (fun i -> pow32 2 i)

    let computeFftPaddedPolynomialProduct bigK k u v =
        // Product of the polynomials u and v via FFT, for inputs that are already
        // padded to the transform size.
        // REQUIRES: bigK = 2^k
        // REQUIRES: Array lengths of u and v = bigK.
        // REQUIRES: degree(uv) <= bigK-1
        // For the result to be correct, every coefficient of the true product must
        // lie in the range [0,p-1], for p defining Fp above.
         
    #if SELFTEST
        check ( k <= maxTwoPower );
        check ( bigK = twoPowerTable.[k] );
        check ( Array.length u = bigK );
        check ( Array.length v = bigK );
    #endif
        // A primitive (2^k)th root of 1 drives both the forward and inverse transforms.
        let root = m2PowNthRoot k 
        // Forward transforms of both operands.
        let uHat = computFftInPlace bigK root u 
        let vHat = computFftInPlace bigK root v 
        // Multiply pointwise, then transform the products back to coefficients.
        let pointwise j = mmul uHat.[j] vHat.[j] 
        computeInverseFftInPlace bigK root (Array.init bigK pointwise)

    let padTo n u =
        // Builds an Fp array of length n from the int array u: entry i is
        // Fp.of_int u.[i] while i is inside u, and Fp.mzero beyond its end.
        let available = Array.length u 
        let coeffAt i = if i >= available then Fp.mzero else Fp.of_int u.[i] 
        Array.init n coeffAt

    let computeFftPolynomialProduct degu u degv v =
        // Product of the polynomials u and v (int coefficient arrays of degrees
        // degu and degv) computed by FFT; returns the product's coefficients as ints.
        // The result array has length bigK, the least power of two that is
        // >= degu + degv + 1.
        // For correctness,
        //   require the result coeff to be in range [0,p-1], for p defining Fp above.
         
        let deguv  = degu + degv 
        let bound  = deguv + 1   
        let bigK,k = leastBounding2Power bound 
        // Pad both operands to the transform size and reuse the padded product,
        // rather than repeating its transform / multiply / inverse steps here.
        // (In SELFTEST builds this also runs the padded product's checks.)
        let uPadded = padTo bigK u 
        let vPadded = padTo bigK v 
        computeFftPaddedPolynomialProduct bigK k uPadded vPadded
        |> Array.map Fp.to_int


    //-------------------------------------------------------------------------
    // fp exports
    //-----------------------------------------------------------------------

    open Fp
    // Re-export the Fp basics at the level of the enclosing module.
    let mzero              = Fp.mzero
    let mone               = Fp.mone
    // Largest field element, p - 1, obtained as (p - 1) reduced mod p.
    let max_fp             = Fp.msub Fp.p Fp.mone
    let fp_of_int x        = Fp.of_int x
    let int_of_fp x        = Fp.to_int x
    let max_bits_inside_fp = max_bits_inside_p

    //-------------------------------------------------------------------------
    // FFT - reference implementation
    //-----------------------------------------------------------------------
        
    #if SELFTEST
    open Fp
    let rec computeFftReference n w u =
        // Straightforward (allocating) reference FFT.
        // Given n a 2-power,
        //       w an nth root of 1 in Fp, and
        //       u(x) = sum(i<n) u.[i].x^i
        // Returns an array res of length n with res.[j] = u(w^j)
        // ---
        // u(x) = sum(i<n/2) u.[2i] * x^2i  +  x . sum(i<n/2) u.[2i+1] * x^2i
        //      = ueven(x^2)                +  x . uodd(x^2)
        // ---
        // u(w^j)         = ueven(w^2j) + w^j . uodd(w^2j)
        // u(w^(halfN+j)) = ueven(w^2j) - w^j . uodd(w^2j)
        if n=1 then
            [| u.[0] |]
        else
            let halfN   = n/2 
            let ww      = mmul w w 
            // Split into even- and odd-indexed coefficients and transform each half.
            let ueven   = Array.init halfN (fun i -> u.[2*i])   
            let uodd    = Array.init halfN (fun i -> u.[2*i+1]) 
            let uevenFT = computeFftReference halfN ww ueven 
            let uoddFT  = computeFftReference halfN ww uodd  
            // Combine: the lower half adds the twiddled odd term, the upper half subtracts it.
            Array.init n
                (fun j ->
                    if j < halfN then
                        madd
                            (uevenFT.[j])
                            (mmul (mpow w j) (uoddFT.[j]))
                    else
                        let j = j - halfN 
                        msub
                            (uevenFT.[j])
                            (mmul (mpow w j) (uoddFT.[j])))
    #endif

