#include "cuda_utils.h"
#include <cusparse.h>
#include "cusparse_utils.h"
#include <cublas_v2.h>
#include "cublas_utils.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// scalar type used for all matrix and vector values in this file //
typedef float FloatType;

// Sparse matrix-vector product for a matrix A stored in CSR format:
//   y = alpha*A*x + beta*y
//
// handle               - cuSPARSE handle (currently unused)
// rowCount             - number of rows of A / entries of vectorY
// columnCount          - number of columns of A (currently unused)
// alpha, beta          - scalar factors of the operation above
// matrixDescription    - cuSPARSE matrix descriptor (currently unused)
// matrixAValues        - CSR values
// matrixARowPointers   - CSR row pointers, rowCount+1 entries
// matrixAColumnIndices - CSR column index of every stored value
// vectorX              - input vector
// vectorY              - vector that is updated in place
//
// The old content of vectorY is always read, even when beta is 0.
//
// NOTE(review): this placeholder walks all arrays on the host, while the
// solver in this file hands in device pointers - it still has to be replaced
// by a cuSPARSE call as the TODO below says.
void csrMatrixVectorMultiplication(cusparseHandle_t handle,
                                   int rowCount,int columnCount,
                                   FloatType alpha,
                                   cusparseMatDescr_t matrixDescription,
                                   FloatType* matrixAValues,
                                   int* matrixARowPointers,
                                   int* matrixAColumnIndices,
                                   FloatType* vectorX,
                                   FloatType beta,
                                   FloatType* vectorY) {
  // TODO: replace CPU code with call to cusparse //
  for (int row=0;row<rowCount;++row) {
    // stored entries of this row occupy [entry, entryLimit)
    const int entryLimit=matrixARowPointers[row+1];
    int entry=matrixARowPointers[row];
    FloatType dotProduct=0;
    while (entry<entryLimit) {
      dotProduct+=matrixAValues[entry]*vectorX[matrixAColumnIndices[entry]];
      ++entry;
    }
    vectorY[row]=alpha*dotProduct+beta*vectorY[row];
  }
}

// Computes y = alpha*x + y on the GPU with cuBLAS (cublasSaxpy).
//
// handle     - initialized cuBLAS handle
// vectorSize - number of elements of both vectors
// alpha      - factor applied to vectorX
// vectorX    - device pointer to the input vector (not modified)
// vectorY    - device pointer to the vector that is updated in place
//
// The former host loop dereferenced the pointers on the CPU, which is not
// valid for the device pointers that the solver in this file passes in.
// cublasSaxpy is the single-precision routine, so FloatType has to be float.
void axpy(cublasHandle_t handle,
          int vectorSize,FloatType alpha,FloatType* vectorX,
          FloatType* vectorY) {
  // alpha is handed over by host address; this assumes the handle uses
  // cuBLAS' default host pointer mode
  cublasVerify(cublasSaxpy(handle,vectorSize,&alpha,vectorX,1,vectorY,1));
}

// Computes x = alpha*x on the GPU with cuBLAS (cublasSscal).
//
// handle     - initialized cuBLAS handle
// vectorSize - number of elements of vectorX
// alpha      - scale factor
// vectorX    - device pointer to the vector that is scaled in place
//
// The former host loop dereferenced the pointer on the CPU, which is not
// valid for the device pointer that the solver in this file passes in.
// cublasSscal is the single-precision routine, so FloatType has to be float.
void scaleVector(cublasHandle_t handle,
                 int vectorSize,FloatType alpha,FloatType* vectorX) {
  // alpha is handed over by host address; this assumes the handle uses
  // cuBLAS' default host pointer mode
  cublasVerify(cublasSscal(handle,vectorSize,&alpha,vectorX,1));
}

// Returns the scalar product x.y of two device vectors, computed on the GPU
// with cuBLAS (cublasSdot, single precision).
//
// handle     - initialized cuBLAS handle
// vectorSize - number of elements of both vectors
// vectorX    - device pointer to the first vector
// vectorY    - device pointer to the second vector
//
// The former host loop dereferenced the pointers on the CPU, which is not
// valid for the device pointers that the solver in this file passes in.
// cublasSdot is the single-precision routine, so FloatType has to be float.
FloatType scalarProduct(cublasHandle_t handle,
                        int vectorSize,FloatType* vectorX,FloatType* vectorY) {
  // pre-set so that an empty vector yields 0 regardless of the library
  FloatType result=0;
  // the result is written to host memory; this assumes the handle uses
  // cuBLAS' default host pointer mode
  cublasVerify(cublasSdot(handle,vectorSize,vectorX,1,vectorY,1,&result));
  return result;
}

// Returns the squared 2-norm x.x of a device vector, computed on the GPU
// with cuBLAS (cublasSdot of the vector with itself, single precision).
//
// handle     - initialized cuBLAS handle
// vectorSize - number of elements of vectorX
// vectorX    - device pointer to the vector
//
// The former host loop dereferenced the pointer on the CPU, which is not
// valid for the device pointer that the solver in this file passes in.
// cublasSdot is the single-precision routine, so FloatType has to be float.
FloatType vector2NormSquare(cublasHandle_t handle,
                            int vectorSize,FloatType* vectorX) {
  // pre-set so that an empty vector yields 0 regardless of the library
  FloatType result=0;
  // the dot product is used instead of a norm routine because the callers
  // need the square of the norm; the result is written to host memory
  // (assumes the handle uses cuBLAS' default host pointer mode)
  cublasVerify(cublasSdot(handle,vectorSize,vectorX,1,vectorX,1,&result));
  return result;
}

// Solves the linear system A*x=b with the conjugate gradient method.
//
// All arguments are host arrays. They are copied into freshly allocated
// device buffers, the iteration is carried out on those buffers through the
// helper routines of this file, and the final iterate is copied back into
// vectorX. Progress and timing information is written to stderr.
//
// vectorSize           - number of unknowns; A is vectorSize x vectorSize
// matrixAValues        - CSR values of A
// matrixARowPointers   - CSR row pointers, vectorSize+1 entries; the last
//                        entry is taken as the number of stored values
// matrixAColumnIndices - CSR column indices of A
// vectorX              - start vector on entry, computed solution on return
// vectorB              - right-hand side
// maxResidual          - the iteration stops as soon as the SQUARED 2-norm
//                        of the residual r=b-A*x is no longer above this value
void conjugateGradientSolve(int vectorSize,
                            FloatType* matrixAValues,
                            int* matrixARowPointers,
                            int* matrixAColumnIndices,
                            FloatType* vectorX,
                            FloatType* vectorB,
                            FloatType maxResidual) {

  cublasHandle_t cublasHandle;
  cublasVerify(cublasCreate(&cublasHandle));
  cusparseHandle_t cusparseHandle;
  cusparseVerify(cusparseCreate(&cusparseHandle));

  // the entry behind the last row pointer holds the number of stored values
  int matrixEntryCount = matrixARowPointers[vectorSize];

  // buffer sizes; the index arrays hold int, not FloatType, so they are
  // sized with sizeof(int) to stay correct if FloatType is ever changed
  const size_t vectorBytes      = (size_t)vectorSize*sizeof(FloatType);
  const size_t valueBytes       = (size_t)matrixEntryCount*sizeof(FloatType);
  const size_t rowPointerBytes  = ((size_t)vectorSize+1)*sizeof(int);
  const size_t columnIndexBytes = (size_t)matrixEntryCount*sizeof(int);

  FloatType* matrixAValues_gpu;
  int*       matrixARowPointers_gpu;
  int*       matrixAColumnIndices_gpu;
  FloatType* vectorX_gpu;
  FloatType* vectorB_gpu;
  FloatType* vectorP_gpu;   // search direction
  FloatType* vectorV_gpu;   // matrix-vector product A*p (initially A*x)
  FloatType* vectorR_gpu;   // residual b-A*x

  cudaVerify(cudaMalloc((void**)&matrixAValues_gpu,valueBytes));
  cudaVerify(cudaMalloc((void**)&matrixARowPointers_gpu,rowPointerBytes));
  cudaVerify(cudaMalloc((void**)&matrixAColumnIndices_gpu,columnIndexBytes));

  cusparseMatDescr_t matrixDescription;
  cusparseVerify(cusparseCreateMatDescr(&matrixDescription));

  cudaVerify(cudaMalloc((void**)&vectorX_gpu,vectorBytes));
  cudaVerify(cudaMalloc((void**)&vectorB_gpu,vectorBytes));
  cudaVerify(cudaMalloc((void**)&vectorP_gpu,vectorBytes));
  cudaVerify(cudaMalloc((void**)&vectorV_gpu,vectorBytes));
  cudaVerify(cudaMalloc((void**)&vectorR_gpu,vectorBytes));

  cudaVerify(cudaMemcpy(matrixAValues_gpu,matrixAValues,
                        valueBytes,cudaMemcpyHostToDevice));
  cudaVerify(cudaMemcpy(matrixARowPointers_gpu,matrixARowPointers,
                        rowPointerBytes,cudaMemcpyHostToDevice));
  cudaVerify(cudaMemcpy(matrixAColumnIndices_gpu,matrixAColumnIndices,
                        columnIndexBytes,cudaMemcpyHostToDevice));

  cudaVerify(cudaMemcpy(vectorX_gpu,vectorX,
                        vectorBytes,cudaMemcpyHostToDevice));
  cudaVerify(cudaMemcpy(vectorB_gpu,vectorB,
                        vectorBytes,cudaMemcpyHostToDevice));

  // v = A*x
  csrMatrixVectorMultiplication(cusparseHandle,
                                vectorSize,vectorSize,1,
                                matrixDescription,
                                matrixAValues_gpu,
                                matrixARowPointers_gpu,
                                matrixAColumnIndices_gpu,
                                vectorX_gpu,0,vectorV_gpu);

  // r = b - v
  cudaVerify(cudaMemcpy(vectorR_gpu,vectorB_gpu,vectorBytes,
                        cudaMemcpyDeviceToDevice));
  axpy(cublasHandle,vectorSize,-1,vectorV_gpu,vectorR_gpu);

  // p = r
  cudaVerify(cudaMemcpy(vectorP_gpu,vectorR_gpu,vectorBytes,
                        cudaMemcpyDeviceToDevice));

  // alphaOld always holds |r|^2 of the current residual
  FloatType alphaOld = vector2NormSquare(cublasHandle,vectorSize,vectorR_gpu);

  int iteration=0;

  Timer timer;
  initTimer(&timer);

  while(alphaOld>maxResidual) {

    iteration++;
    // alphaOld already is |r|^2, so no extra reduction is needed for the log
    fprintf(stderr,"iteration=%i residual=%e\n",iteration,sqrt(alphaOld));

    // v = A*p
    csrMatrixVectorMultiplication(cusparseHandle,
                                  vectorSize,vectorSize,1,
                                  matrixDescription,
                                  matrixAValues_gpu,
                                  matrixARowPointers_gpu,
                                  matrixAColumnIndices_gpu,
                                  vectorP_gpu,0,vectorV_gpu);

    // step length along the search direction
    FloatType lambda
      = alphaOld/(scalarProduct(cublasHandle,
                                vectorSize,vectorV_gpu,vectorP_gpu));
    // x = x + lambda*p
    axpy(cublasHandle,vectorSize,lambda,vectorP_gpu,vectorX_gpu);
    // r = r - lambda*v
    axpy(cublasHandle,vectorSize,-lambda,vectorV_gpu,vectorR_gpu);
    FloatType alpha = vector2NormSquare(cublasHandle,vectorSize,vectorR_gpu);
    // p = r + (alpha/alphaOld)*p
    scaleVector(cublasHandle,vectorSize,alpha/alphaOld,vectorP_gpu);
    axpy(cublasHandle,vectorSize,1.,vectorR_gpu,vectorP_gpu);
    alphaOld=alpha;
  }

  // make sure all GPU work issued by the loop is finished before the
  // timer is read
  cudaVerify(cudaDeviceSynchronize());

  double duration = getTimer(&timer);
  fprintf(stderr,"duration=%e s, "
          "duration per field point per iteration = %e s\n",
          duration,duration/iteration/vectorSize);

  cudaVerify(cudaMemcpy(vectorX,vectorX_gpu,
                        vectorBytes,
                        cudaMemcpyDeviceToHost));

  cudaVerify(cudaFree(matrixAValues_gpu));
  cudaVerify(cudaFree(matrixARowPointers_gpu));
  cudaVerify(cudaFree(matrixAColumnIndices_gpu));
  cudaVerify(cudaFree(vectorX_gpu));
  cudaVerify(cudaFree(vectorB_gpu));
  // the search direction buffer was leaked before
  cudaVerify(cudaFree(vectorP_gpu));
  cudaVerify(cudaFree(vectorV_gpu));
  cudaVerify(cudaFree(vectorR_gpu));

  cusparseVerify(cusparseDestroyMatDescr(matrixDescription));
  cusparseVerify(cusparseDestroy(cusparseHandle));
  cublasVerify(cublasDestroy(cublasHandle));
}

// Stores one (value, column) pair at position *cursor of the CSR value and
// column arrays and moves the cursor to the next free position.
static void appendStencilEntry(FloatType* values,int* columns,int* cursor,
                               FloatType value,int column) {
  values[*cursor]=value;
  columns[*cursor]=column;
  (*cursor)++;
}

// Builds the CSR arrays of the 5-point stencil matrix for a grid of
// sizeX x sizeY cells: 4 on the diagonal and -1 for every existing neighbor
// (top, left, right, bottom). Cells are numbered row by row, i.e. cell (i,j)
// has index i*sizeX+j; cells on the boundary get fewer entries.
//
// matrixValues        - receives the values, at most 5 per cell
// matrixRowPointers   - receives the row start offsets; needs sizeX*sizeY+1
//                       entries, the last one is set to the total entry count
// matrixColumnIndices - receives the column of every value; within a row the
//                       columns are written in ascending order
void init2DHeatEquationStiffnessMatrix(int sizeX,int sizeY,
                                       FloatType* matrixValues,
                                       int* matrixRowPointers,
                                       int* matrixColumnIndices) {
  const FloatType centerValue=4.;
  const FloatType neighborValue=-1.;

  int nextEntry=0;
  for(int gridRow=0;gridRow<sizeY;++gridRow) {
    for(int gridColumn=0;gridColumn<sizeX;++gridColumn) {
      const int cell=gridRow*sizeX+gridColumn;
      matrixRowPointers[cell]=nextEntry;

      // neighbor on top
      if (gridRow>0)
        appendStencilEntry(matrixValues,matrixColumnIndices,&nextEntry,
                           neighborValue,cell-sizeX);
      // neighbor on left
      if (gridColumn>0)
        appendStencilEntry(matrixValues,matrixColumnIndices,&nextEntry,
                           neighborValue,cell-1);
      // center of stencil
      appendStencilEntry(matrixValues,matrixColumnIndices,&nextEntry,
                         centerValue,cell);
      // neighbor on right
      if (gridColumn<sizeX-1)
        appendStencilEntry(matrixValues,matrixColumnIndices,&nextEntry,
                           neighborValue,cell+1);
      // neighbor on bottom
      if (gridRow<sizeY-1)
        appendStencilEntry(matrixValues,matrixColumnIndices,&nextEntry,
                           neighborValue,cell+sizeX);
    }
  }
  // the extra row pointer behind the last row marks the end of that row
  matrixRowPointers[sizeX*sizeY]=nextEntry;
}

// Checks a solution of A*x=b on the host and reports the outcome on stderr.
//
// The squared 2-norm of A*x-b is accumulated row by row from the CSR arrays.
// The result counts as wrong if that value exceeds maxResidual or is NaN.
//
// vectorSize           - number of rows of A / entries of the vectors
// matrixAValues        - CSR values of A
// matrixARowPointers   - CSR row pointers, vectorSize+1 entries
// matrixAColumnIndices - CSR column indices of A
// vectorX              - solution to be checked
// vectorB              - right-hand side
// maxResidual          - upper limit for the squared residual norm
void verifySolution(int vectorSize,
                    FloatType* matrixAValues,
                    int* matrixARowPointers,
                    int* matrixAColumnIndices,
                    FloatType* vectorX,
                    FloatType* vectorB,
                    FloatType maxResidual) {
  double squaredResidual=0;
  for (int row=0;row<vectorSize;++row) {
    const int entryLimit=matrixARowPointers[row+1];
    // start from -b so that the row sum ends up as (A*x-b)[row]
    FloatType rowDefect=-vectorB[row];
    int entry=matrixARowPointers[row];
    while (entry<entryLimit) {
      rowDefect+=matrixAValues[entry]*vectorX[matrixAColumnIndices[entry]];
      ++entry;
    }
    squaredResidual+=rowDefect*rowDefect;
  }

  const bool withinLimit=!isnan(squaredResidual)
                         && !(squaredResidual>maxResidual);
  if (withinLimit) {
    fprintf(stderr,"computation result is correct.\n"
            "residual: %e\n",squaredResidual);
  } else {
    fprintf(stderr,"computation result is wrong.\n"
            "Maximum allowed residual: %e, actual residual: %e\n",
            maxResidual,squaredResidual);
  }
}

// Sets up the 2D stencil system for an 800x800 field with a single point
// source in the center, solves it with conjugateGradientSolve and checks the
// result with verifySolution.
//
// Returns 0 on completion and 1 if the host buffers cannot be allocated.
int main() {

  int fieldSizeX=800;
  int fieldSizeY=800;

  int vectorSize=fieldSizeX*fieldSizeY;
  // laplace operator stencil uses 4 neighbors + self                   //
  // = 5 entries per matrix row.                                        //
  // actually the matrix is somewhat smaller because of the boundaries, //
  // but for simplicity the allocations size will be field cells * 5.   //
  int matrixEntryCount=vectorSize*5;

  FloatType* matrixAValues;
  int*       matrixARowPointers;
  int*       matrixAColumnIndices;
  FloatType* vectorX;
  FloatType* vectorB;

  // Linear system to be solved: A*x=b //
  // allocate memory for A,x,b         //
  matrixAValues        = (FloatType*)malloc(matrixEntryCount*sizeof(FloatType));
  matrixARowPointers   = (int*)      malloc((vectorSize+1)*sizeof(int));
  matrixAColumnIndices = (int*)      malloc(matrixEntryCount*sizeof(int));
  vectorX              = (FloatType*)malloc(vectorSize*sizeof(FloatType));
  vectorB              = (FloatType*)malloc(vectorSize*sizeof(FloatType));

  // do not touch the buffers if any allocation failed; free(NULL) is a no-op
  if (matrixAValues==NULL || matrixARowPointers==NULL ||
      matrixAColumnIndices==NULL || vectorX==NULL || vectorB==NULL) {
    fprintf(stderr,"host memory allocation failed.\n");
    free(matrixAValues);
    free(matrixARowPointers);
    free(matrixAColumnIndices);
    free(vectorX);
    free(vectorB);
    return 1;
  }

  init2DHeatEquationStiffnessMatrix(fieldSizeX,fieldSizeY,matrixAValues,
                                    matrixARowPointers,matrixAColumnIndices);


  for(int i=0;i<vectorSize;i++) {
    vectorX[i]=0;
    vectorB[i]=0;
  }
  // define point source in center of field //
  vectorB[fieldSizeY/2*fieldSizeX+fieldSizeX/2]=1.;

  // both limits below are compared against squared residual norms //
  conjugateGradientSolve(vectorSize,
                         matrixAValues,
                         matrixARowPointers,
                         matrixAColumnIndices,
                         vectorX,
                         vectorB,
                         1e-6);

  verifySolution(vectorSize,
                 matrixAValues,
                 matrixARowPointers,
                 matrixAColumnIndices,
                 vectorX,
                 vectorB,
                 1e-5);
  /*
  // optional dump of the solution; cell (i,j) is stored at i*fieldSizeX+j
  for(int i=0;i<fieldSizeY;i++) {
    for(int j=0;j<fieldSizeX;j++) {
      printf("%i %i %e\n",i,j,vectorX[i*fieldSizeX+j]);
    }
    printf("\n");
  }
  */

  free(matrixAValues);
  free(matrixARowPointers);
  free(matrixAColumnIndices);
  free(vectorX);
  free(vectorB);
  return 0;
}
