/*******************************************************************************
* Copyright 2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
! Content:
!       Example of using the fftwf_plan_dft_r2c_2d/fftwf_plan_dft_c2r_2d
!       functions on a (GPU) device using the OpenMP target (offload) interface
!
!****************************************************************************/

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <float.h>
#include "fftw/fftw3.h"
#include "fftw/offload/fftw3_omp_offload.h"

/* Forward declarations of the data-initialization and result-verification
   helpers that are defined after main(). */
static void init_c(fftwf_complex *x, int N1, int N2, int H1, int H2);
static int verify_r(float *x, int N1, int N2, int H1, int H2);
static void init_r(float *x, int N1, int N2, int H1, int H2);
static int verify_c(fftwf_complex *x, int N1, int N2, int H1, int H2);

/*
 * Example driver: runs a 2D single-precision real-to-complex forward FFT and
 * then a complex-to-real backward FFT, both out-of-place, with the FFTW calls
 * dispatched to OpenMP device devNum.
 *
 * The input of each transform is built from the harmonic (H1, H2) so that the
 * output can be checked by verify_c()/verify_r().
 *
 * Returns 0 when both transforms verify, and 1 on an allocation failure, a
 * plan-creation failure, or a verification failure.
 */
int main(void)
{
    /* Sizes of 2D transform */
    int N1 = 7;
    int N2 = 7;

    /* Arbitrary harmonic used to verify FFT */
    int H1 = N1/2;
    int H2 = -2;

    /* FFTW plan handles */
    fftwf_plan forward_plan = 0;
    fftwf_plan backward_plan = 0;

    /* Pointers to input and output data */
    float *x       = NULL;
    fftwf_complex *y = NULL;

    /* Execution status */
    int statusf = 0, statusb = 0, status = 0;

    const int devNum = 0;

    printf("Example sp_plan_dft_real_2d_async\n");
    printf("2D real-to-complex out-of-place transform\n");
    printf("Configuration parameters:\n");
    printf(" N = {%d, %d}\n", N1, N2);
    printf(" H = {%d, %d}\n", H1, H2);

    printf("Allocate array for input data\n");
    x  = fftwf_malloc(sizeof(float)*N1*N2);
    y  = fftwf_malloc(sizeof(fftwf_complex)*N1*(N2/2+1));
    if (NULL == x || NULL == y) goto failed;
    printf("Initialize input for forward transform\n");
    init_r(x, N1, N2, H1, H2);

    printf("Create FFTW plan for 2D float-precision Real to Complex forward_plan out-of-place FFT\n");
#pragma omp target data map(tofrom:x[0:N1*N2],y[0:(N2/2+1)*N1]) device(devNum)
    {
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch use_device_ptr(x,y) device(devNum)
#endif
    forward_plan = fftwf_plan_dft_r2c_2d(N1, N2, x, y, FFTW_ESTIMATE);
    if (forward_plan == 0) {
        /* Never execute a null plan.  A goto may not leave the target data
           region, so the failure is recorded and handled after the region. */
        printf("Call to fftwf_plan_dft_r2c_2d has failed\n");
        statusf = 1;
    } else {
        printf("Compute forward_plan FFT\n");
        /* Launched with nowait; the taskwait below waits for completion */
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum) nowait
#else
#pragma omp target variant dispatch device(devNum) nowait
#endif
        fftwf_execute(forward_plan);
#pragma omp taskwait

        /* Bring the spectrum back from the same device the data was mapped to */
#pragma omp target update from(y[0:(N2/2+1)*N1]) device(devNum)
        printf("Verify the result of forward_plan FFT\n");
        statusf = verify_c(y, N1, N2, H1, H2);
    }

    /* The backward transform uses freshly initialized input, so it is run
       regardless of the outcome of the forward transform. */
    printf("Initialize input for backward transform\n");
    init_c(y, N1, N2, H1, H2);
#pragma omp target update to(y[0:(N2/2+1)*N1]) device(devNum)

    printf("Create FFTW plan for 2D float-precision Complex to Real backward transform\n");
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch use_device_ptr(y, x) device(devNum)
#endif
    backward_plan = fftwf_plan_dft_c2r_2d(N1, N2, y, x, FFTW_ESTIMATE);
    if (backward_plan == 0) {
        printf("Call to fftwf_plan_dft_c2r_2d has failed\n");
        statusb = 1;
    } else {
        printf("Compute backward FFT\n");
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum) nowait
#else
#pragma omp target variant dispatch device(devNum) nowait
#endif
        fftwf_execute(backward_plan);
#pragma omp taskwait
    }
    } // target data map

    /* x is mapped tofrom, so it is inspected on the host only after the
       target data region has ended.  Skip the check if nothing was computed. */
    if (backward_plan != 0) {
        printf("Verify the result of backward FFT\n");
        statusb = verify_r(x, N1, N2, H1, H2);
    }

    if(statusf != 0 || statusb != 0) goto failed;

 cleanup:

    printf("Destroy FFTW plan\n");
    fftwf_destroy_plan(forward_plan);
    fftwf_destroy_plan(backward_plan);

    printf("Free data array\n");
    fftwf_free(x);
    fftwf_free(y);

    printf("TEST %s\n",0==status ? "PASSED" : "FAILED");
    return status;

 failed:
    printf(" ERROR\n");
    status = 1;
    goto cleanup;
}

/*
 * Return (K*L) % M as a float.  The product is formed in long long so that it
 * is not limited to the range of int before the remainder is taken.
 */
static float moda(int K, int L, int M)
{
    const long long product = (long long)K * L;
    const long long remainder = product % M;

    return (float)remainder;
}

/*
 * Fill the complex array x, laid out row-major as N1 rows of N2/2+1 elements,
 * with the harmonic (H1, H2) scaled by 1/(N1*N2):
 *   x[n1][n2] = ( cos(2*pi*phase), -sin(2*pi*phase) ) / (N1*N2)
 * where phase = ((n1*H1) % N1)/N1 + ((n2*H2) % N2)/N2.
 * Used as the input of the backward transform, which verify_r() then checks
 * for a unit peak at H.
 */
static void init_c(fftwf_complex *x, int N1, int N2, int H1, int H2)
{
    const float two_pi = 6.2831853071795864769f;
    const int row_len = N2/2+1;     /* elements stored per row */
    int row, col;

    for (row = 0; row < N1; row++) {
        for (col = 0; col < row_len; col++) {
            const int pos = row*row_len + col;
            float phase;

            phase  = moda(row,H1,N1) / N1;
            phase += moda(col,H2,N2) / N2;

            x[pos][0] =  cosf( two_pi * phase ) / (N1*N2);
            x[pos][1] = -sinf( two_pi * phase ) / (N1*N2);
        }
    }
}

/* Verify that x has unit peak at H */
static int verify_r(float *x, int N1, int N2, int H1, int H2)
{
    float err, errthr, maxerr;
    int n1, n2, S1, S2, index;

    /* Generalized strides for row-major addressing of x */
    S2 = 1;
    S1 = N2;

    errthr = 2.5 * logf( (float)N1*N2 ) / logf(2.0) * FLT_EPSILON;
    printf(" Check if err is below errthr %.3lg\n", errthr);

    maxerr = 0;
    for (n1 = 0; n1 < N1; n1++) {
        for (n2 = 0; n2 < N2; n2++) {
            float re_exp = 0.0, re_got;

            if ((n1-H1)%N1==0 && (n2-H2)%N2==0) re_exp = 1.0f;

            index = n1*S1 + n2*S2;
            re_got = x[index];
            err  = fabs(re_got - re_exp);
            if (err > maxerr) maxerr = err;
            if (!(err < errthr)) {
                printf(" x[%i][%i]: ",n1,n2);
                printf(" expected %.17lg, ",re_exp);
                printf(" got %.17lg, ",re_got);
                printf(" err %.3lg\n", err);
                printf(" Verification FAILED\n");
                return 1;
            }
        }
    }
    printf(" Verified,  maximum error was %.3lg\n", maxerr);
    return 0;
}

/*
 * Fill the real N1 x N2 row-major array x with a cosine of harmonic (H1, H2):
 *   x[n1][n2] = factor * cos(2*pi*phase) / (N1*N2)
 * where phase = ((n1*H1) % N1)/N1 + ((n2*H2) % N2)/N2.
 * factor is 1 when 2*(N1-H1) is a multiple of N1 and 2*(N2-H2) is a multiple
 * of N2, and 2 otherwise.  Used as the input of the forward transform, which
 * verify_c() then checks for unit peaks at H and -H.
 */
static void init_r(float *x, int N1, int N2, int H1, int H2)
{
    const float two_pi = 6.2831853071795864769f;
    float scale = 2.0f;
    int row, col;

    if ((2*(N1-H1)%N1)==0 && (2*(N2-H2)%N2)==0) scale = 1.0f;

    for (row = 0; row < N1; row++) {
        for (col = 0; col < N2; col++) {
            float phase;

            phase  = moda(row,H1,N1) / N1;
            phase += moda(col,H2,N2) / N2;

            x[row*N2 + col] = scale * cosf( two_pi * phase ) / (N1*N2);
        }
    }
}

/*
 * Check the complex array x, laid out row-major as N1 rows of N2/2+1
 * elements, against the expected result: (1, 0) at every stored position
 * whose indices are congruent to (H1, H2) or to (-H1, -H2) modulo (N1, N2),
 * and (0, 0) everywhere else.
 *
 * The error of an element is the sum of the absolute deviations of its real
 * and imaginary parts.  Prints the error threshold, then either the first
 * offending element or the largest error found.  Returns 0 when every element
 * is within the threshold and 1 at the first element that is not.
 */
static int verify_c(fftwf_complex *x, int N1, int N2, int H1, int H2)
{
    /* Tolerance scales with log2 of the number of points */
    const float tolerance = 2.5 * logf( (float)N1*N2 ) / logf(2.0) * FLT_EPSILON;
    const int row_len = N2/2+1;     /* elements stored per row */
    float worst = 0;
    int row, col;

    printf(" Check if err is below errthr %.3lg\n", tolerance);

    for (row = 0; row < N1; row++) {
        for (col = 0; col < row_len; col++) {
            const int pos = row*row_len + col;
            float want_re = 0.0f, want_im = 0.0f;
            float got_re, got_im, deviation;

            /* Peak at the harmonic itself, or at its mirrored counterpart */
            if ((row-H1)%N1==0 && (col-H2)%N2==0) {
                want_re = 1.0f;
            } else if ((-row-H1)%N1==0 && (-col-H2)%N2==0) {
                want_re = 1.0f;
            }

            got_re = x[pos][0];
            got_im = x[pos][1];
            deviation = fabs(got_re - want_re) + fabs(got_im - want_im);
            if (deviation > worst) worst = deviation;

            /* Only a deviation strictly below the tolerance passes, so a NaN
               deviation falls through to the failure report. */
            if (deviation < tolerance) continue;

            printf(" x[%i][%i]: ",row,col);
            printf(" expected (%.17lg,%.17lg), ",want_re,want_im);
            printf(" got (%.17lg,%.17lg), ",got_re,got_im);
            printf(" err %.3lg\n", deviation);
            printf(" Verification FAILED\n");
            return 1;
        }
    }

    printf(" Verified,  maximum error was %.3lg\n", worst);
    return 0;
}
