// Copyright 2025, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <nvhpc/openacc_scan.hpp>

// In each algorithm below, _Index is a signed integral type suitable for
// indexing into any of the sequences passed to that algorithm.  It is always
// the difference_type of one of the algorithm's iterator types.

namespace std { namespace __stdpar { namespace __openacc {

//========== exclusive_scan ==========

// (exclusive_scan is implemented in <nvhpc/openacc_scan.hpp>)

//========== inclusive_scan ==========

// Inclusive scan of [__first, __last) using the binary operation __f, with
// results written to the range that starts at __d_first.  Returns the
// iterator one past the last element written, or __d_first itself when the
// input is empty.
//
// The work is done in three passes over __nchunks chunks:
//   1. each chunk is scanned on its own and its final value is recorded;
//   2. the recorded chunk totals are scanned in a loop that carries no
//      OpenACC directive;
//   3. every chunk after the first has the running total of all earlier
//      chunks combined into each of its elements.
// Both iterators are accessed with operator[].  The chunk totals are kept
// in a new[]-allocated array of _ValT, so _ValT has to be default
// constructible.
template <class _FIt1, class _FIt2, class _BF>
_FIt2 inclusive_scan(_FIt1 __first, _FIt1 __last, _FIt2 __d_first, _BF __f) {
  using _Index = typename std::iterator_traits<_FIt1>::difference_type;
  using _ValT = typename std::iterator_traits<_FIt1>::value_type;
  _Index __n = std::distance(__first, __last);
  if (__n == 0) {
    // Empty input: nothing is allocated and nothing is written.
    return __d_first;
  }
  // __base_len and __extra are the quotient and remainder of the division
  // of the input among the chunks; the __detail helpers turn them into the
  // bounds of each chunk.
  _Index __nchunks = __detail::__iterations_for_reduce_or_scan(__n);
  _Index __base_len = __n / __nchunks;
  _Index __extra = __n % __nchunks;
  _ValT* __totals = new _ValT[__nchunks];
  // Pass 1: scan each chunk independently, remembering its last value.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 0; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    __d_first[__lo] = __first[__lo];
    for (_Index __k = __lo + 1; __k < __hi; ++__k) {
      __d_first[__k] = __f(__d_first[__k - 1], __first[__k]);
    }
    __totals[__c] = __d_first[__hi - 1];
  }
  // Pass 2: turn the chunk totals into running totals.  Pass 3 only ever
  // reads __totals[__c - 1], so the final slot is left untouched.
  for (_Index __c = 1; __c < __nchunks - 1; ++__c) {
    __totals[__c] = __f(__totals[__c - 1], std::move(__totals[__c]));
  }
  // Pass 3: fold the running total of the preceding chunks into every
  // element of chunks 1 .. __nchunks - 1.  Chunk 0 is already final.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 1; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    _Index __len = __hi - __lo;
    #pragma acc_stdpar loop vector
    for (_Index __k = 0; __k < __len; ++__k) {
      _Index __pos = __lo + __k;
      __d_first[__pos] = __f(__totals[__c - 1], std::move(__d_first[__pos]));
    }
  }
  delete[] __totals;
  return __d_first + __n;
}

// Inclusive scan of [__first, __last) using the binary operation __f and
// seeded with __init, with results written to the range that starts at
// __d_first.  Returns the iterator one past the last element written, or
// __d_first itself when the input is empty.
//
// The work is done in three passes over __nchunks chunks:
//   1. each chunk is scanned on its own (the first chunk starts from
//      __f(__init, __first[0])) and its final value is recorded;
//   2. the recorded chunk totals are scanned in a loop that carries no
//      OpenACC directive;
//   3. every chunk after the first has the running total of all earlier
//      chunks combined into each of its elements.
// Both iterators are accessed with operator[].  The chunk totals are kept
// in a new[]-allocated array of _T, so _T has to be default constructible.
template <class _FIt1, class _FIt2, class _BF, class _T>
_FIt2 inclusive_scan(_FIt1 __first, _FIt1 __last, _FIt2 __d_first, _BF __f,
                     _T __init) {
  using _Index = typename std::iterator_traits<_FIt1>::difference_type;
  using _ValT = _T;
  _Index __n = std::distance(__first, __last);
  if (__n == 0) {
    // Empty input: nothing is allocated and nothing is written.
    return __d_first;
  }
  // __base_len and __extra are the quotient and remainder of the division
  // of the input among the chunks; the __detail helpers turn them into the
  // bounds of each chunk.
  _Index __nchunks = __detail::__iterations_for_reduce_or_scan(__n);
  _Index __base_len = __n / __nchunks;
  _Index __extra = __n % __nchunks;
  _ValT* __totals = new _ValT[__nchunks];
  // Pass 1: scan each chunk independently, remembering its last value.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 0; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    if (__c != 0) {
      __d_first[__lo] = __first[__lo];
    } else {
      // Only the very first output element absorbs the seed value.
      __d_first[0] = __f(std::move(__init), __first[0]);
    }
    for (_Index __k = __lo + 1; __k < __hi; ++__k) {
      __d_first[__k] = __f(__d_first[__k - 1], __first[__k]);
    }
    __totals[__c] = __d_first[__hi - 1];
  }
  // Pass 2: turn the chunk totals into running totals.  Pass 3 only ever
  // reads __totals[__c - 1], so the final slot is left untouched.
  for (_Index __c = 1; __c < __nchunks - 1; ++__c) {
    __totals[__c] = __f(__totals[__c - 1], std::move(__totals[__c]));
  }
  // Pass 3: fold the running total of the preceding chunks into every
  // element of chunks 1 .. __nchunks - 1.  Chunk 0 is already final.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 1; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    _Index __len = __hi - __lo;
    #pragma acc_stdpar loop vector
    for (_Index __k = 0; __k < __len; ++__k) {
      _Index __pos = __lo + __k;
      __d_first[__pos] = __f(__totals[__c - 1], std::move(__d_first[__pos]));
    }
  }
  delete[] __totals;
  return __d_first + __n;
}

//========== reduce ==========

// reduce overload selected (through enable_if) when _T is a scalar type and
// __op_is_instance_of<_BF, std::plus, _T> holds.  Every element of
// [__first, __last) is added to __init with operator+ inside a single
// OpenACC "+" reduction loop, and the resulting sum is returned.  __f is
// not called; NOTE(review): presumably because the enable_if condition
// already establishes that it is std::plus -- confirm against the
// definition of __op_is_instance_of.
//
// The two branches contain the same loop; the only difference is that a
// deviceptr(__first) clause is added when _FIt is a raw pointer type.
template <class _FIt, class _T, class _BF>
typename std::enable_if<std::is_scalar<_T>::value &&
                            __op_is_instance_of<_BF, std::plus, _T>::value,
                        _T>::type
reduce(_FIt __first, _FIt __last, _T __init, _BF __f) {
  using _Index = typename std::iterator_traits<_FIt>::difference_type;
  constexpr bool __raw_ptr = std::is_pointer<_FIt>::value;
  _Index __n = std::distance(__first, __last);
  _T __total = std::move(__init);
  if constexpr (!__raw_ptr) {
    #pragma acc_stdpar parallel loop reduction(+ : __total)
    for (_Index __k = 0; __k < __n; ++__k) {
      __total = std::move(__total) + __first[__k];
    }
  } else {
    // Last alternative of the compile-time dispatch.
    static_assert(__raw_ptr,
                  "internal error: unhandled OpenACC parallel loop variant");
    #pragma acc_stdpar parallel loop reduction(+ : __total) deviceptr(__first)
    for (_Index __k = 0; __k < __n; ++__k) {
      __total = std::move(__total) + __first[__k];
    }
  }
  return __total;
}

// reduce overload selected (through enable_if) whenever _T is not a scalar
// type or __op_is_instance_of<_BF, std::plus, _T> does not hold.  Combines
// __init and every element of [__first, __last) with __f and returns the
// result; an empty input returns __init.
//
// Stages:
//   1. the input is cut into leaves of up to __leaf_len elements, and each
//      leaf is folded into one entry of __partials;
//   2. while more than __serial_cutoff groups would remain, up to __fan_in
//      surviving partials are folded into the first slot of their group, so
//      the survivors end up __stride entries apart;
//   3. the survivors are folded into __init in a loop that carries no
//      OpenACC directive.
// __first is accessed with operator[].  __partials is a new[]-allocated
// array of _T, so _T has to be default constructible.
template <class _FIt, class _T, class _BF>
typename std::enable_if<!std::is_scalar<_T>::value ||
                            !__op_is_instance_of<_BF, std::plus, _T>::value,
                        _T>::type
reduce(_FIt __first, _FIt __last, _T __init, _BF __f) {
  using _Index = typename std::iterator_traits<_FIt>::difference_type;
  using _ValT = _T;
  _Index __n = std::distance(__first, __last);
  if (__n == 0) {
    return std::move(__init);
  }
  // Stage 1: one partial result per leaf.
  constexpr _Index __leaf_len = 20;
  _Index __nleaves = __detail::__div_round_up(__n, __leaf_len);
  _ValT* __partials = new _ValT[__nleaves];
  #pragma acc_stdpar parallel loop vector_length(1024)
  for (_Index __c = 0; __c < __nleaves; ++__c) {
    _Index __lo = __c * __leaf_len;
    // The final leaf may be shorter than __leaf_len.
    _Index __hi = __detail::min(__lo + __leaf_len, __n);
    _ValT __acc = __first[__lo];
    #pragma acc_stdpar loop seq
    for (_Index __k = __lo + 1; __k < __hi; ++__k) {
      __acc = __f(std::move(__acc), __first[__k]);
    }
    __partials[__c] = std::move(__acc);
  }
  // Stage 2: repeatedly merge groups of __fan_in surviving partials.
  // __live is the number of survivors before a round and __next_live the
  // number after it; __stride and __next_stride are the matching distances
  // between survivors inside __partials.
  constexpr _Index __serial_cutoff = 100;
  _Index __fan_in = 4;
  _Index __live = __nleaves;
  _Index __next_live = __detail::__div_round_up(__nleaves, __fan_in);
  _Index __stride = 1;
  _Index __next_stride = __fan_in;
  while (__next_live > __serial_cutoff) {
    _Index __vlen = __detail::min(__next_live, 1024);
    #pragma acc_stdpar parallel loop vector_length(__vlen)
    for (_Index __c = 0; __c < __next_live; ++__c) {
      // __dst is the slot that receives the merged value of group __c.
      _Index __dst = __c * __next_stride;
      _Index __group_first = __c * __fan_in;
      // The final group may hold fewer than __fan_in survivors.
      _Index __group_len = __detail::min(__fan_in, __live - __group_first);
      _ValT __acc = std::move(__partials[__dst]);
      #pragma acc_stdpar loop seq
      for (_Index __k = 1; __k < __group_len; ++__k) {
        _Index __off = __k * __stride;
        __acc = __f(std::move(__acc), std::move(__partials[__dst + __off]));
      }
      __partials[__dst] = std::move(__acc);
    }
    __live = __next_live;
    __next_live = __detail::__div_round_up(__next_live, __fan_in);
    __stride = __next_stride;
    __next_stride *= __fan_in;
  }
  // Stage 3: fold the survivors, which sit __stride apart, into __init.
  _ValT __total = std::move(__init);
  for (_Index __c = 0; __c < __nleaves; __c += __stride) {
    __total = __f(std::move(__total), std::move(__partials[__c]));
  }
  delete[] __partials;
  return __total;
}

//========== transform_exclusive_scan ==========

// Exclusive scan of __ft applied to each element of [__first, __last),
// combined with __fsum and seeded with __init.  Results are written to the
// range that starts at __d_first.  Returns the iterator one past the last
// element written, or __d_first itself when the input is empty.
//
// The work is done in three passes over __nchunks chunks:
//   1. each chunk is scanned on its own, starting from a per-chunk seed, and
//      its final value is recorded;
//   2. the recorded chunk totals are scanned in a loop that carries no
//      OpenACC directive;
//   3. every chunk after the first has the running total of all earlier
//      chunks combined into each of its elements.
// Both iterators are accessed with operator[].  The per-chunk values are
// kept in a new[]-allocated array of _T, so _T has to be default
// constructible.
template <class _FIt1, class _FIt2, class _T, class _BF, class _UF>
_FIt2 transform_exclusive_scan(_FIt1 __first, _FIt1 __last, _FIt2 __d_first,
                               _T __init, _BF __fsum, _UF __ft) {
  using _Index = typename std::iterator_traits<_FIt1>::difference_type;
  using _ValT = _T;
  _Index __n = std::distance(__first, __last);
  if (__n == 0) {
    // Empty input: nothing is allocated and nothing is written.
    return __d_first;
  }
  // __base_len and __extra are the quotient and remainder of the division
  // of the input among the chunks; the __detail helpers turn them into the
  // bounds of each chunk.
  _Index __nchunks = __detail::__iterations_for_reduce_or_scan(__n);
  _Index __base_len = __n / __nchunks;
  _Index __extra = __n % __nchunks;
  _ValT* __carry = new _ValT[__nchunks];
  // Seed each chunk before any output is written.  Chunk 0 starts from
  // __init; every later chunk starts from the transformed element just
  // before its first one.  Capturing those elements up front keeps the scan
  // correct when __first and __d_first refer to the same storage, since the
  // preceding chunk overwrites that position.
  __carry[0] = std::move(__init);
  for (_Index __c = 1; __c < __nchunks; ++__c) {
    __carry[__c] =
        __ft(__first[__detail::__chunk_start(__c, __base_len, __extra) - 1]);
  }
  // Pass 1: exclusive scan of each chunk, starting from its seed.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 0; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    _ValT __run = std::move(__carry[__c]);
    for (_Index __k = __lo; __k < __hi - 1; ++__k) {
      // Read the input element before the output slot is overwritten.
      _ValT __after = __fsum(__run, __ft(__first[__k]));
      __d_first[__k] = std::move(__run);
      __run = std::move(__after);
    }
    // The chunk's last input element is not consumed here; it is the seed
    // of the following chunk.
    __d_first[__hi - 1] = __run;
    __carry[__c] = std::move(__run);
  }
  // Pass 2: turn the chunk totals into running totals.  Pass 3 only ever
  // reads __carry[__c - 1], so the final slot is left untouched.
  for (_Index __c = 1; __c < __nchunks - 1; ++__c) {
    __carry[__c] = __fsum(__carry[__c - 1], std::move(__carry[__c]));
  }
  // Pass 3: fold the running total of the preceding chunks into every
  // element of chunks 1 .. __nchunks - 1.  Chunk 0 is already final.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 1; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    _Index __len = __hi - __lo;
    #pragma acc_stdpar loop vector
    for (_Index __k = 0; __k < __len; ++__k) {
      _Index __pos = __lo + __k;
      __d_first[__pos] =
          __fsum(__carry[__c - 1], std::move(__d_first[__pos]));
    }
  }
  delete[] __carry;
  return __d_first + __n;
}

//========== transform_inclusive_scan ==========

// Inclusive scan of __ft applied to each element of [__first, __last),
// combined with __fsum.  Results are written to the range that starts at
// __d_first.  Returns the iterator one past the last element written, or
// __d_first itself when the input is empty.
//
// The work is done in three passes over __nchunks chunks:
//   1. each chunk is transformed and scanned on its own, and its final
//      value is recorded;
//   2. the recorded chunk totals are scanned in a loop that carries no
//      OpenACC directive;
//   3. every chunk after the first has the running total of all earlier
//      chunks combined into each of its elements.
// Both iterators are accessed with operator[].  _ValT is the decayed result
// type of __ft(*__first); the chunk totals are kept in a new[]-allocated
// array of _ValT, so that type has to be default constructible.
template <class _FIt1, class _FIt2, class _BF, class _UF>
_FIt2 transform_inclusive_scan(_FIt1 __first, _FIt1 __last, _FIt2 __d_first,
                               _BF __fsum, _UF __ft) {
  using _Index = typename std::iterator_traits<_FIt1>::difference_type;
  using _ValT = typename std::decay<decltype(__ft(*__first))>::type;
  _Index __n = std::distance(__first, __last);
  if (__n == 0) {
    // Empty input: nothing is allocated and nothing is written.
    return __d_first;
  }
  // __base_len and __extra are the quotient and remainder of the division
  // of the input among the chunks; the __detail helpers turn them into the
  // bounds of each chunk.
  _Index __nchunks = __detail::__iterations_for_reduce_or_scan(__n);
  _Index __base_len = __n / __nchunks;
  _Index __extra = __n % __nchunks;
  _ValT* __totals = new _ValT[__nchunks];
  // Pass 1: transform and scan each chunk independently, remembering its
  // last value.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 0; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    __d_first[__lo] = __ft(__first[__lo]);
    for (_Index __k = __lo + 1; __k < __hi; ++__k) {
      __d_first[__k] = __fsum(__d_first[__k - 1], __ft(__first[__k]));
    }
    __totals[__c] = __d_first[__hi - 1];
  }
  // Pass 2: turn the chunk totals into running totals.  Pass 3 only ever
  // reads __totals[__c - 1], so the final slot is left untouched.
  for (_Index __c = 1; __c < __nchunks - 1; ++__c) {
    __totals[__c] = __fsum(__totals[__c - 1], std::move(__totals[__c]));
  }
  // Pass 3: fold the running total of the preceding chunks into every
  // element of chunks 1 .. __nchunks - 1.  Chunk 0 is already final.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 1; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    _Index __len = __hi - __lo;
    #pragma acc_stdpar loop vector
    for (_Index __k = 0; __k < __len; ++__k) {
      _Index __pos = __lo + __k;
      __d_first[__pos] =
          __fsum(__totals[__c - 1], std::move(__d_first[__pos]));
    }
  }
  delete[] __totals;
  return __d_first + __n;
}

// Inclusive scan of __ft applied to each element of [__first, __last),
// combined with __fsum and seeded with __init.  Results are written to the
// range that starts at __d_first.  Returns the iterator one past the last
// element written, or __d_first itself when the input is empty.
//
// The work is done in three passes over __nchunks chunks:
//   1. each chunk is transformed and scanned on its own (the first chunk
//      starts from __fsum(__init, __ft(__first[0]))) and its final value is
//      recorded;
//   2. the recorded chunk totals are scanned in a loop that carries no
//      OpenACC directive;
//   3. every chunk after the first has the running total of all earlier
//      chunks combined into each of its elements.
// Both iterators are accessed with operator[].  The chunk totals are kept
// in a new[]-allocated array of _T, so _T has to be default constructible.
template <class _FIt1, class _FIt2, class _BF, class _UF, class _T>
_FIt2 transform_inclusive_scan(_FIt1 __first, _FIt1 __last, _FIt2 __d_first,
                               _BF __fsum, _UF __ft, _T __init) {
  using _Index = typename std::iterator_traits<_FIt1>::difference_type;
  using _ValT = _T;
  _Index __n = std::distance(__first, __last);
  if (__n == 0) {
    // Empty input: nothing is allocated and nothing is written.
    return __d_first;
  }
  // __base_len and __extra are the quotient and remainder of the division
  // of the input among the chunks; the __detail helpers turn them into the
  // bounds of each chunk.
  _Index __nchunks = __detail::__iterations_for_reduce_or_scan(__n);
  _Index __base_len = __n / __nchunks;
  _Index __extra = __n % __nchunks;
  _ValT* __totals = new _ValT[__nchunks];
  // Pass 1: transform and scan each chunk independently, remembering its
  // last value.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 0; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    if (__c != 0) {
      __d_first[__lo] = __ft(__first[__lo]);
    } else {
      // Only the very first output element absorbs the seed value.
      __d_first[0] = __fsum(std::move(__init), __ft(__first[0]));
    }
    for (_Index __k = __lo + 1; __k < __hi; ++__k) {
      __d_first[__k] = __fsum(__d_first[__k - 1], __ft(__first[__k]));
    }
    __totals[__c] = __d_first[__hi - 1];
  }
  // Pass 2: turn the chunk totals into running totals.  Pass 3 only ever
  // reads __totals[__c - 1], so the final slot is left untouched.
  for (_Index __c = 1; __c < __nchunks - 1; ++__c) {
    __totals[__c] = __fsum(__totals[__c - 1], std::move(__totals[__c]));
  }
  // Pass 3: fold the running total of the preceding chunks into every
  // element of chunks 1 .. __nchunks - 1.  Chunk 0 is already final.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 1; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    _Index __len = __hi - __lo;
    #pragma acc_stdpar loop vector
    for (_Index __k = 0; __k < __len; ++__k) {
      _Index __pos = __lo + __k;
      __d_first[__pos] =
          __fsum(__totals[__c - 1], std::move(__d_first[__pos]));
    }
  }
  delete[] __totals;
  return __d_first + __n;
}

//========== transform_reduce ==========

// Two-sequence transform_reduce overload selected (through enable_if) when
// _T is a scalar type and __op_is_instance_of<_BF1, std::plus, _T> holds.
// For every index in [0, __last1 - __first1), __ft(__first1[i],
// __first2[i]) is added to __init with operator+ inside a single OpenACC
// "+" reduction loop, and the resulting sum is returned.  __fsum is not
// called; NOTE(review): presumably because the enable_if condition already
// establishes that it is std::plus -- confirm against the definition of
// __op_is_instance_of.
//
// All four branches contain the same loop.  They differ only in the
// deviceptr clause, which names exactly those of __first1 / __first2 whose
// iterator type is a raw pointer.
template <class _FIt1, class _FIt2, class _T, class _BF1, class _BF2>
typename std::enable_if<std::is_scalar<_T>::value &&
                            __op_is_instance_of<_BF1, std::plus, _T>::value,
                        _T>::type
transform_reduce(_FIt1 __first1, _FIt1 __last1, _FIt2 __first2, _T __init,
                 _BF1 __fsum, _BF2 __ft) {
  using _Index = typename std::iterator_traits<_FIt1>::difference_type;
  constexpr bool __raw_ptr1 = std::is_pointer<_FIt1>::value;
  constexpr bool __raw_ptr2 = std::is_pointer<_FIt2>::value;
  _Index __n = std::distance(__first1, __last1);
  _T __total = std::move(__init);
  if constexpr (!__raw_ptr1 && !__raw_ptr2) {
    #pragma acc_stdpar parallel loop reduction(+ : __total)
    for (_Index __k = 0; __k < __n; ++__k) {
      __total = std::move(__total) + __ft(__first1[__k], __first2[__k]);
    }
  } else if constexpr (__raw_ptr1 && !__raw_ptr2) {
    #pragma acc_stdpar parallel loop reduction(+ : __total) deviceptr(__first1)
    for (_Index __k = 0; __k < __n; ++__k) {
      __total = std::move(__total) + __ft(__first1[__k], __first2[__k]);
    }
  } else if constexpr (!__raw_ptr1 && __raw_ptr2) {
    #pragma acc_stdpar parallel loop reduction(+ : __total) deviceptr(__first2)
    for (_Index __k = 0; __k < __n; ++__k) {
      __total = std::move(__total) + __ft(__first1[__k], __first2[__k]);
    }
  } else {
    // Last alternative of the compile-time dispatch: both are raw pointers.
    static_assert(__raw_ptr1 && __raw_ptr2,
                  "internal error: unhandled OpenACC parallel loop variant");
    #pragma acc_stdpar parallel loop reduction(+ : __total) \
            deviceptr(__first1, __first2)
    for (_Index __k = 0; __k < __n; ++__k) {
      __total = std::move(__total) + __ft(__first1[__k], __first2[__k]);
    }
  }
  return __total;
}

// Single-sequence transform_reduce overload selected (through enable_if)
// when _T is a scalar type and __op_is_instance_of<_BF, std::plus, _T>
// holds.  __ft applied to every element of [__first, __last) is added to
// __init with operator+ inside a single OpenACC "+" reduction loop, and the
// resulting sum is returned.  __fsum is not called; NOTE(review): presumably
// because the enable_if condition already establishes that it is std::plus
// -- confirm against the definition of __op_is_instance_of.
//
// The two branches contain the same loop; the only difference is that a
// deviceptr(__first) clause is added when _FIt is a raw pointer type.
template <class _FIt, class _T, class _BF, class _UF>
typename std::enable_if<std::is_scalar<_T>::value &&
                            __op_is_instance_of<_BF, std::plus, _T>::value,
                        _T>::type
transform_reduce(_FIt __first, _FIt __last, _T __init, _BF __fsum, _UF __ft) {
  using _Index = typename std::iterator_traits<_FIt>::difference_type;
  constexpr bool __raw_ptr = std::is_pointer<_FIt>::value;
  _Index __n = std::distance(__first, __last);
  _T __total = std::move(__init);
  if constexpr (!__raw_ptr) {
    #pragma acc_stdpar parallel loop reduction(+ : __total)
    for (_Index __k = 0; __k < __n; ++__k) {
      __total = std::move(__total) + __ft(__first[__k]);
    }
  } else {
    // Last alternative of the compile-time dispatch.
    static_assert(__raw_ptr,
                  "internal error: unhandled OpenACC parallel loop variant");
    #pragma acc_stdpar parallel loop reduction(+ : __total) deviceptr(__first)
    for (_Index __k = 0; __k < __n; ++__k) {
      __total = std::move(__total) + __ft(__first[__k]);
    }
  }
  return __total;
}

// Two-sequence transform_reduce overload selected (through enable_if)
// whenever _T is not a scalar type or
// __op_is_instance_of<_BF1, std::plus, _T> does not hold.  Combines __init
// with __ft(__first1[i], __first2[i]) for every index in
// [0, __last1 - __first1) using __fsum and returns the result; an empty
// input returns __init.
//
// The input is divided into __nchunks chunks.  Each chunk is folded into
// one entry of __partials inside an OpenACC parallel loop, and the entries
// are then folded into __init, in index order, by a loop that carries no
// OpenACC directive.  Both iterators are accessed with operator[].
// __partials is a new[]-allocated array of _T, so _T has to be default
// constructible.
template <class _FIt1, class _FIt2, class _T, class _BF1, class _BF2>
typename std::enable_if<!std::is_scalar<_T>::value ||
                            !__op_is_instance_of<_BF1, std::plus, _T>::value,
                        _T>::type
transform_reduce(_FIt1 __first1, _FIt1 __last1, _FIt2 __first2, _T __init,
                 _BF1 __fsum, _BF2 __ft) {
  using _Index = typename std::iterator_traits<_FIt1>::difference_type;
  _Index __n = std::distance(__first1, __last1);
  if (__n == 0) {
    return __init;
  }
  // __base_len and __extra are the quotient and remainder of the division
  // of the input among the chunks; the __detail helpers turn them into the
  // bounds of each chunk.
  _Index __nchunks = __detail::__iterations_for_reduce_or_scan(__n);
  _Index __base_len = __n / __nchunks;
  _Index __extra = __n % __nchunks;
  _T* __partials = new _T[__nchunks];
  // Fold every chunk into its own partial result.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 0; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    _T __acc = __ft(__first1[__lo], __first2[__lo]);
    #pragma acc_stdpar loop seq
    for (_Index __k = __lo + 1; __k < __hi; ++__k) {
      __acc = __fsum(std::move(__acc), __ft(__first1[__k], __first2[__k]));
    }
    __partials[__c] = std::move(__acc);
  }
  // Fold the per-chunk results into the seed value.
  _T __total = std::move(__init);
  for (_Index __c = 0; __c < __nchunks; ++__c) {
    __total = __fsum(std::move(__total), std::move(__partials[__c]));
  }
  delete[] __partials;
  return __total;
}

// Single-sequence transform_reduce overload selected (through enable_if)
// whenever _T is not a scalar type or
// __op_is_instance_of<_BF, std::plus, _T> does not hold.  Combines __init
// with __ft applied to every element of [__first, __last) using __fsum and
// returns the result; an empty input returns __init.
//
// The input is divided into __nchunks chunks.  Each chunk is folded into
// one entry of __partials inside an OpenACC parallel loop, and the entries
// are then folded into __init, in index order, by a loop that carries no
// OpenACC directive.  __first is accessed with operator[].  __partials is a
// new[]-allocated array of _T, so _T has to be default constructible.
template <class _FIt, class _T, class _BF, class _UF>
typename std::enable_if<!std::is_scalar<_T>::value ||
                            !__op_is_instance_of<_BF, std::plus, _T>::value,
                        _T>::type
transform_reduce(_FIt __first, _FIt __last, _T __init, _BF __fsum, _UF __ft) {
  using _Index = typename std::iterator_traits<_FIt>::difference_type;
  _Index __n = std::distance(__first, __last);
  if (__n == 0) {
    return __init;
  }
  // __base_len and __extra are the quotient and remainder of the division
  // of the input among the chunks; the __detail helpers turn them into the
  // bounds of each chunk.
  _Index __nchunks = __detail::__iterations_for_reduce_or_scan(__n);
  _Index __base_len = __n / __nchunks;
  _Index __extra = __n % __nchunks;
  _T* __partials = new _T[__nchunks];
  // Fold every chunk into its own partial result.
  #pragma acc_stdpar parallel loop
  for (_Index __c = 0; __c < __nchunks; ++__c) {
    _Index __lo = __detail::__chunk_start(__c, __base_len, __extra);
    _Index __hi = __detail::__chunk_end(__c, __base_len, __extra);
    _T __acc = __ft(__first[__lo]);
    #pragma acc_stdpar loop seq
    for (_Index __k = __lo + 1; __k < __hi; ++__k) {
      __acc = __fsum(std::move(__acc), __ft(__first[__k]));
    }
    __partials[__c] = std::move(__acc);
  }
  // Fold the per-chunk results into the seed value.
  _T __total = std::move(__init);
  for (_Index __c = 0; __c < __nchunks; ++__c) {
    __total = __fsum(std::move(__total), std::move(__partials[__c]));
  }
  delete[] __partials;
  return __total;
}

}}} // namespace std::__stdpar::__openacc
