#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>

#include "cuda_utils.h"

// Edge length T of every square matrix (each matrix holds T*T floats).
// main also uses it for both thread-block dimensions of the kernel launch.
static const int matrix_size = 16;

// Batched product of square matrix_size x matrix_size matrices:
// c[m] = a[m] * b[m], with m = blockIdx.x. All matrices are stored
// contiguously and row-major, matrix_size*matrix_size floats per matrix.
//
// Launch requirements:
//   grid  : 1D, one block per matrix pair
//   block : (matrix_size, matrix_size); thread (x, y) computes c[m][y][x]
// Shared memory: static, one matrix_size x matrix_size tile for A and one
// matrix_size x (matrix_size+1) tile for B.
__global__ void MatrixProd(const float* __restrict__ a,
                           const float* __restrict__ b,
                           float* __restrict__ c) {

  __shared__ float shared_a[matrix_size][matrix_size];
  // B is stored transposed and then read with threadIdx.x selecting the row,
  // so threads of one warp touch different rows. The extra padding column
  // shifts consecutive rows onto different banks, avoiding bank conflicts.
  __shared__ float shared_b[matrix_size][matrix_size + 1];

  // size_t so the offset cannot overflow int for very large batches
  const size_t matrix_offset = (size_t)blockIdx.x * matrix_size * matrix_size;
  const int element = threadIdx.y * matrix_size + threadIdx.x;

  // read matrices into shared memory: A as-is, B transposed so that the
  // inner product below walks both tiles along a row
  shared_a[threadIdx.y][threadIdx.x] = a[matrix_offset + element];
  shared_b[threadIdx.x][threadIdx.y] = b[matrix_offset + element];
  __syncthreads();

  // compute product: c[y][x] = sum_l a[y][l] * b[l][x]
  float sum = 0.0f;
  for (int l = 0; l < matrix_size; l++) {
    sum += shared_a[threadIdx.y][l] * shared_b[threadIdx.x][l];
  }
  c[matrix_offset + element] = sum;
}


// CPU code for reference computation //
// CPU reference computation: for each of the matrix_count matrix pairs,
// computes c[i] = a[i] * b[i] with a plain triple loop. Matrices are stored
// contiguously and row-major, matrix_size*matrix_size floats per matrix.
// A non-positive matrix_count leaves c untouched.
void MatrixProdReference(float* a,float* b,float* c,int matrix_count) {
  const size_t elements_per_matrix = (size_t)matrix_size * matrix_size;
  for (int i = 0; i < matrix_count; i++) {
    // size_t arithmetic: i*matrix_size*matrix_size evaluated in int would
    // overflow for large matrix_count
    const size_t matrix_offset = (size_t)i * elements_per_matrix;
    const float* A = a + matrix_offset;
    const float* B = b + matrix_offset;
    float* C = c + matrix_offset;
    for (int j = 0; j < matrix_size; j++) {
      for (int k = 0; k < matrix_size; k++) {
        // accumulate in the same order (l ascending) as the GPU kernel
        float sum = 0;
        for (int l = 0; l < matrix_size; l++) {
          sum += A[j * matrix_size + l] * B[l * matrix_size + k];
        }
        C[j * matrix_size + k] = sum;
      }
    }
  }
}

// Compares the GPU result against the CPU reference element by element with
// the relative tolerance |gpu - cpu| <= 1e-5 * (|gpu| + |cpu|).
// Prints the first mismatching element to stderr and returns false; returns
// true when every element agrees. Both arrays hold matrix_count contiguous
// matrix_size x matrix_size matrices.
static bool VerifyResults(const float* result, const float* reference,
                          int matrix_count) {
  const int elements_per_matrix = matrix_size * matrix_size;
  for (int i = 0; i < matrix_count; i++) {
    const size_t matrix_offset = (size_t)i * elements_per_matrix;
    for (int j = 0; j < elements_per_matrix; j++) {
      const float gpu = result[matrix_offset + j];
      const float cpu = reference[matrix_offset + j];
      // Written as !(diff <= tolerance) instead of diff > tolerance so that
      // a NaN in either value counts as a mismatch rather than passing.
      if (!(fabs(gpu - cpu) <= 1e-5 * (fabs(gpu) + fabs(cpu)))) {
        fprintf(stderr,"error: matrix=%i position=%i,%i reference=%e "
                "cuda result=%e\n",
                i,j%matrix_size,j/matrix_size,
                cpu,
                gpu);
        return false;
      }
    }
  }
  return true;
}

// Benchmark driver: multiplies matrix_count random pairs of
// matrix_size x matrix_size matrices on the GPU (one block per pair),
// recomputes them with MatrixProdReference on the CPU, and compares results.
// Transfer, kernel and CPU timings are reported on stderr.
// Returns EXIT_SUCCESS when the GPU result matches the reference, and
// EXIT_FAILURE otherwise (including failure to allocate the reference buffer).
int main() {

  srand48(time(NULL));

  // define sizes (computed in size_t to avoid int overflow) //
  const int matrix_count = 60000;
  const size_t element_count = (size_t)matrix_count*matrix_size*matrix_size;
  const size_t data_size = element_count*sizeof(float);
  // total floating-point operations: one multiply + one add per term
  const double flop_count =
      2.*matrix_size*matrix_size*matrix_size*matrix_count;

  // allocate page-locked memory on cpu //
  float* matrices_a_cpu;
  float* matrices_b_cpu;
  float* matrices_c_cpu;
  cudaVerify(cudaMallocHost((void**)&matrices_a_cpu,data_size));
  cudaVerify(cudaMallocHost((void**)&matrices_b_cpu,data_size));
  cudaVerify(cudaMallocHost((void**)&matrices_c_cpu,data_size));

  // buffer for the CPU reference result; allocated up front so that a
  // failure is reported before any GPU work is done
  float* matrices_c_reference = (float*)malloc(data_size);
  if (matrices_c_reference == NULL) {
    fprintf(stderr,"error: could not allocate %zu bytes for the CPU "
            "reference\n",data_size);
    cudaVerify(cudaFreeHost(matrices_a_cpu));
    cudaVerify(cudaFreeHost(matrices_b_cpu));
    cudaVerify(cudaFreeHost(matrices_c_cpu));
    return EXIT_FAILURE;
  }

  // initialize inputs with drand48() values in [0,1); the output is set to
  // -1, which can never match the non-negative reference, so elements the
  // kernel does not write are detected //
  for(size_t e=0;e<element_count;e++) {
    matrices_a_cpu[e]=drand48();
    matrices_b_cpu[e]=drand48();
    matrices_c_cpu[e]=-1;
  }

  // allocate memory on gpu //
  float* matrices_a_gpu;
  float* matrices_b_gpu;
  float* matrices_c_gpu;
  cudaVerify(cudaMalloc((void**)&matrices_a_gpu,data_size));
  cudaVerify(cudaMalloc((void**)&matrices_b_gpu,data_size));
  cudaVerify(cudaMalloc((void**)&matrices_c_gpu,data_size));

  // initialize output matrix C on gpu //
  cudaVerify(cudaMemcpy(matrices_c_gpu,matrices_c_cpu,data_size,
                        cudaMemcpyHostToDevice));

  // copy input matrices A, B from cpu to gpu //

  Timer timer_transfer_in;
  initTimer(&timer_transfer_in);

  cudaVerify(cudaMemcpy(matrices_a_gpu,matrices_a_cpu,data_size,
                        cudaMemcpyHostToDevice));
  cudaVerify(cudaMemcpy(matrices_b_gpu,matrices_b_cpu,data_size,
                        cudaMemcpyHostToDevice));

  // one block of matrix_size x matrix_size threads per matrix pair
  dim3 threads(matrix_size,matrix_size);
  dim3 grid(matrix_count);

  Timer timer_kernel;
  initTimer(&timer_kernel);

  // kernel call; the synchronize makes the timing include kernel execution
  // and surfaces asynchronous kernel errors (cudaThreadSynchronize is
  // deprecated in favor of cudaDeviceSynchronize) //
  cudaVerifyKernel((MatrixProd<<<grid,threads>>>(matrices_a_gpu,matrices_b_gpu,
                                                 matrices_c_gpu)));
  cudaVerify(cudaDeviceSynchronize());

  Timer timer_transfer_out;
  initTimer(&timer_transfer_out);

  // copy output matrices from gpu to cpu //
  cudaVerify(cudaMemcpy(matrices_c_cpu,matrices_c_gpu,data_size,
                        cudaMemcpyDeviceToHost));

  Timer timer_end;
  initTimer(&timer_end);

  // calculate reference values //
  Timer timer_cpu_begin;
  initTimer(&timer_cpu_begin);

  MatrixProdReference(matrices_a_cpu,matrices_b_cpu,matrices_c_reference,
                      matrix_count);

  Timer timer_cpu_end;
  initTimer(&timer_cpu_end);

  // verify results //
  const bool success=VerifyResults(matrices_c_cpu,matrices_c_reference,
                                   matrix_count);

  free(matrices_c_reference);

  if (success) {
    printf("computation result is correct.\n");
  }

  double duration = getTimerDifference(&timer_transfer_in,&timer_kernel);
  fprintf(stderr,"duration (data transfer in) = %e s, bandwith = %e bytes/s\n",
          duration,data_size*2./duration);
  duration = getTimerDifference(&timer_transfer_out,&timer_end);
  fprintf(stderr,"duration (data transfer out) = %e s, bandwith = %e bytes/s\n",
          duration,data_size/duration);
  duration = getTimerDifference(&timer_kernel,&timer_transfer_out);
  fprintf(stderr,"duration (computation) = %e s, performance = %e FLOPS\n",
          duration,
          flop_count/duration);
  duration = getTimerDifference(&timer_transfer_in,&timer_end);
  fprintf(stderr,"duration (overall) = %e s, performance "
          "(including transfer time) = %e FLOPS\n",
          duration,
          flop_count/duration);
  duration = getTimerDifference(&timer_cpu_begin,&timer_cpu_end);
  fprintf(stderr,"duration (computation on CPU) = %e s, "
          "performance = %e FLOPS\n",
          duration,
          flop_count/duration);

  // free gpu memory //
  cudaVerify(cudaFree(matrices_a_gpu));
  cudaVerify(cudaFree(matrices_b_gpu));
  cudaVerify(cudaFree(matrices_c_gpu));

  // free cpu memory //
  cudaVerify(cudaFreeHost(matrices_a_cpu));
  cudaVerify(cudaFreeHost(matrices_b_cpu));
  cudaVerify(cudaFreeHost(matrices_c_cpu));

  // exit: non-zero status when verification failed //
  return success ? EXIT_SUCCESS : EXIT_FAILURE;
}
