// Copyright (c) 2015-16 Tom Deakin, Simon McIntosh-Smith,
// University of Bristol HPC
//
// For full license terms please see the LICENSE file distributed with this
// source code

#include "RAJAStream.hpp"

using RAJA::forall;
using RAJA::RangeSegment;

#ifndef ALIGNMENT
#define ALIGNMENT (2*1024*1024) // 2MB
#endif

template <class T>
RAJAStream<T>::RAJAStream(const unsigned int ARRAY_SIZE, const int device_index)
    : array_size(ARRAY_SIZE)
{
  // device_index is part of the common constructor signature, but this
  // implementation does not use it to select a device.
  (void)device_index;

  // A single contiguous range [0, ARRAY_SIZE) drives every kernel.
  RangeSegment seg(0, ARRAY_SIZE);
  index_set.push_back(seg);

  // Bytes actually used by each array.
  const size_t bytes = sizeof(T) * array_size;

#ifdef RAJA_TARGET_CPU
  // aligned_alloc requires the requested size to be an integral multiple of
  // the alignment; some C libraries return NULL otherwise. Round the request
  // up so arbitrary array sizes work. Only the first `bytes` bytes are used.
  const size_t padded = ((bytes + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT;
  d_a = (T*)aligned_alloc(ALIGNMENT, padded);
  d_b = (T*)aligned_alloc(ALIGNMENT, padded);
  d_c = (T*)aligned_alloc(ALIGNMENT, padded);
#else
  // Managed allocations are dereferenced directly on the host by
  // read_arrays(), so no explicit device-to-host copies are needed.
  cudaMallocManaged((void**)&d_a, bytes, cudaMemAttachGlobal);
  cudaMallocManaged((void**)&d_b, bytes, cudaMemAttachGlobal);
  cudaMallocManaged((void**)&d_c, bytes, cudaMemAttachGlobal);
  cudaDeviceSynchronize();
#endif
}

template <class T>
RAJAStream<T>::~RAJAStream()
{
  // Release each array with the deallocator that matches the allocator the
  // constructor used for the active target.
#ifndef RAJA_TARGET_CPU
  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_c);
#else
  free(d_a);
  free(d_b);
  free(d_c);
#endif
}

template <class T>
void RAJAStream<T>::init_arrays(T initA, T initB, T initC)
{
  // Hoist the member pointers into restrict-qualified locals so the kernel
  // lambda captures raw pointers by value instead of reaching through `this`.
  T* RAJA_RESTRICT ptr_a = d_a;
  T* RAJA_RESTRICT ptr_b = d_b;
  T* RAJA_RESTRICT ptr_c = d_c;

  // Fill every element of the three arrays with its initial value.
  forall<policy>(index_set, [=] RAJA_DEVICE (RAJA::Index_type i)
  {
    ptr_a[i] = initA;
    ptr_b[i] = initB;
    ptr_c[i] = initC;
  });
}

template <class T>
void RAJAStream<T>::read_arrays(
        std::vector<T>& a, std::vector<T>& b, std::vector<T>& c)
{
  // Copy array_size elements of each benchmark array into the caller's
  // vectors; each vector must already hold at least array_size elements.
  std::copy_n(d_a, array_size, a.data());
  std::copy_n(d_b, array_size, b.data());
  std::copy_n(d_c, array_size, c.data());
}

template <class T>
float RAJAStream<T>::read()
{
  T* RAJA_RESTRICT src = d_a;
  T* RAJA_RESTRICT dst = d_c;

  // Read-only kernel: load every element of a. The conditional store of a
  // sentinel value keeps the load observable so it cannot be optimised away;
  // presumably the sentinel never occurs in the benchmark data.
  forall<policy>(index_set, [=] RAJA_DEVICE (RAJA::Index_type i)
  {
    const T value = src[i];
    if (value == 126789.)
      dst[i] = value;
  });

  return 0.0f;
}

template <class T>
float RAJAStream<T>::write()
{
  // Write-only kernel: store zero into every element of c. Only c is
  // touched, so only d_c is bound to a local.
  T* RAJA_RESTRICT c = d_c;
  forall<policy>(index_set, [=] RAJA_DEVICE (RAJA::Index_type index)
  {
    // T(0) avoids an implicit double->float conversion when T is float.
    c[index] = T(0);
  });
  return 0.;
}

template <class T>
float RAJAStream<T>::copy()
{
  // STREAM Copy kernel: c[i] = a[i].
  T* RAJA_RESTRICT src = d_a;
  T* RAJA_RESTRICT dst = d_c;
  forall<policy>(index_set, [=] RAJA_DEVICE (RAJA::Index_type i)
  {
    dst[i] = src[i];
  });
  return 0.0f;
}

template <class T>
float RAJAStream<T>::mul()
{
  // STREAM Mul (scale) kernel: b[i] = scalar * c[i].
  const T factor = startScalar;
  T* RAJA_RESTRICT dst = d_b;
  T* RAJA_RESTRICT src = d_c;
  forall<policy>(index_set, [=] RAJA_DEVICE (RAJA::Index_type i)
  {
    dst[i] = factor * src[i];
  });
  return 0.0f;
}

template <class T>
float RAJAStream<T>::add()
{
  // STREAM Add kernel: c[i] = a[i] + b[i].
  T* RAJA_RESTRICT lhs = d_a;
  T* RAJA_RESTRICT rhs = d_b;
  T* RAJA_RESTRICT out = d_c;
  forall<policy>(index_set, [=] RAJA_DEVICE (RAJA::Index_type i)
  {
    out[i] = lhs[i] + rhs[i];
  });
  return 0.0f;
}

template <class T>
float RAJAStream<T>::triad()
{
  // STREAM Triad kernel: a[i] = b[i] + scalar * c[i].
  const T factor = startScalar;
  T* RAJA_RESTRICT out = d_a;
  T* RAJA_RESTRICT addend = d_b;
  T* RAJA_RESTRICT scaled = d_c;
  forall<policy>(index_set, [=] RAJA_DEVICE (RAJA::Index_type i)
  {
    out[i] = addend[i] + factor * scaled[i];
  });
  return 0.0f;
}

template <class T>
T RAJAStream<T>::dot()
{
  // Dot product of a and b, accumulated with a RAJA reduction object so the
  // partial sums from all iterations are combined under reduce_policy.
  T* RAJA_RESTRICT lhs = d_a;
  T* RAJA_RESTRICT rhs = d_b;

  RAJA::ReduceSum<reduce_policy, T> total(0.0);

  forall<policy>(index_set, [=] RAJA_DEVICE (RAJA::Index_type i)
  {
    total += lhs[i] * rhs[i];
  });

  // Converting the reducer to T yields the combined result.
  return static_cast<T>(total);
}


// Prints the device listing for this backend. RAJA does not enumerate devices
// here, so a fixed message is printed. The message ends in a newline so that
// later output (or the shell prompt) starts on its own line.
void listDevices(void)
{
  std::cout << "This is not the device you are looking for." << '\n';
}


std::string getDeviceName(const int device)
{
  // Every device index reports the programming-model name; the index itself
  // is not consulted.
  static_cast<void>(device);
  return std::string("RAJA");
}


std::string getDeviceDriver(const int device)
{
  // No driver version is queried; every device index reports the model name.
  static_cast<void>(device);
  return std::string("RAJA");
}

// Explicitly instantiate the class for float and double. The template member
// definitions live in this translation unit, so other files can only link
// against these two element types.
template class RAJAStream<float>;
template class RAJAStream<double>;