18#ifndef REALM_CUDA_REDOP_H
19#define REALM_CUDA_REDOP_H
33 namespace ReductionKernels {
// Generic element-wise iteration helper used by the apply/fold kernels.
// Walks 'count' elements of two strided arrays with a grid-stride loop and
// invokes 'func' on each pair.  'func' receives a mutable LHS element, a
// const RHS element, and the opaque 'context' pointer.
//
// lhs_base/rhs_base are raw byte addresses and lhs_stride/rhs_stride are
// byte distances between consecutive elements, so the arrays need not be
// densely packed.
template <typename LHS, typename RHS, typename F>
__device__ void iter_cuda_kernel(uintptr_t lhs_base, uintptr_t lhs_stride,
                                 uintptr_t rhs_base, uintptr_t rhs_stride,
                                 size_t count, F func, void *context = nullptr)
{
  // promote to size_t BEFORE multiplying: blockIdx.x * blockDim.x and
  // blockDim.x * gridDim.x are 32-bit unsigned products that can silently
  // overflow when 'count' exceeds 2^32 elements on very large launches
  const size_t tid = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = size_t(blockDim.x) * gridDim.x;
  for(size_t idx = tid; idx < count; idx += stride) {
    (*func)(*reinterpret_cast<LHS *>(lhs_base + idx * lhs_stride),
            *reinterpret_cast<const RHS *>(rhs_base + idx * rhs_stride),
            context);
  }
}
// Trampoline adapting REDOP::apply_cuda<EXCL> to the generic per-element
// callback signature expected by iter_cuda_kernel.  'context' carries the
// address of the REDOP instance supplied by the launching kernel.
template <typename REDOP, bool EXCL>
__device__ void redop_apply_wrapper(typename REDOP::LHS &lhs,
                                    const typename REDOP::RHS &rhs,
                                    void *context)
{
  REDOP *op = static_cast<REDOP *>(context);
  op->template apply_cuda<EXCL>(lhs, rhs);
}
// Trampoline adapting REDOP::fold_cuda<EXCL> to the generic per-element
// callback signature expected by iter_cuda_kernel.  'context' carries the
// address of the REDOP instance supplied by the launching kernel.
template <typename REDOP, bool EXCL>
__device__ void redop_fold_wrapper(typename REDOP::RHS &rhs1,
                                   const typename REDOP::RHS &rhs2,
                                   void *context)
{
  REDOP *op = static_cast<REDOP *>(context);
  op->template fold_cuda<EXCL>(rhs1, rhs2);
}
// Kernel entry point for "apply" reductions: lhs[i] <- apply(lhs[i], rhs[i])
// for i in [0, count).  EXCL selects the exclusive (non-atomic) variant of
// REDOP::apply_cuda.  'redop' is passed by value, so taking its address
// yields a per-thread copy living for the duration of the kernel.
// Any 1-D grid/block configuration is valid (grid-stride iteration).
template <typename REDOP, bool EXCL>
__global__ void apply_cuda_kernel(uintptr_t lhs_base, uintptr_t lhs_stride,
                                  uintptr_t rhs_base, uintptr_t rhs_stride,
                                  size_t count, REDOP redop)
{
  void *ctx = static_cast<void *>(&redop);
  iter_cuda_kernel<typename REDOP::LHS, typename REDOP::RHS>(
      lhs_base, lhs_stride, rhs_base, rhs_stride, count,
      redop_apply_wrapper<REDOP, EXCL>, ctx);
}
// Kernel entry point for "fold" reductions: rhs1[i] <- fold(rhs1[i], rhs2[i])
// for i in [0, count).  EXCL selects the exclusive (non-atomic) variant of
// REDOP::fold_cuda.  'redop' is passed by value, so taking its address
// yields a per-thread copy living for the duration of the kernel.
// Any 1-D grid/block configuration is valid (grid-stride iteration).
template <typename REDOP, bool EXCL>
__global__ void fold_cuda_kernel(uintptr_t rhs1_base, uintptr_t rhs1_stride,
                                 uintptr_t rhs2_base, uintptr_t rhs2_stride,
                                 size_t count, REDOP redop)
{
  void *ctx = static_cast<void *>(&redop);
  iter_cuda_kernel<typename REDOP::RHS, typename REDOP::RHS>(
      rhs1_base, rhs1_stride, rhs2_base, rhs2_stride, count,
      redop_fold_wrapper<REDOP, EXCL>, ctx);
}
// Populates a reduction-op descriptor with the host-side addresses of the
// four apply/fold kernel instantiations plus the CUDA runtime entry points
// needed to launch them later.  T is any type exposing the cuda_*_fn /
// cudaLaunchKernel_fn / cudaGetFuncBySymbol_fn pointer fields written below
// (presumably the runtime translates these host proxy addresses at launch
// time -- NOTE(review): confirm against the consumer of these fields).
template <typename REDOP, typename T>
void add_cuda_redop_kernels(T *redop)
{
  // host proxy addresses of each kernel instantiation
  redop->cuda_apply_excl_fn =
      reinterpret_cast<void *>(&ReductionKernels::apply_cuda_kernel<REDOP, true>);
  redop->cuda_apply_nonexcl_fn =
      reinterpret_cast<void *>(&ReductionKernels::apply_cuda_kernel<REDOP, false>);
  redop->cuda_fold_excl_fn =
      reinterpret_cast<void *>(&ReductionKernels::fold_cuda_kernel<REDOP, true>);
  redop->cuda_fold_nonexcl_fn =
      reinterpret_cast<void *>(&ReductionKernels::fold_cuda_kernel<REDOP, false>);

  // capture the runtime's launch entry point through a typed function
  // pointer so a signature mismatch with the linked runtime is a compile
  // error rather than a silent bad call
  typedef cudaError_t (*PFN_cudaLaunchKernel)(const void *func, dim3 gridDim,
                                              dim3 blockDim, void **args,
                                              size_t sharedMem,
                                              cudaStream_t stream);
  PFN_cudaLaunchKernel launch_fn =
      static_cast<PFN_cudaLaunchKernel>(cudaLaunchKernel);
  redop->cudaLaunchKernel_fn = reinterpret_cast<void *>(launch_fn);
#if CUDART_VERSION >= 11000
  typedef cudaError_t (*PFN_cudaGetFuncBySymbol)(cudaFunction_t *functionPtr,
                                                 const void *symbolPtr);
  PFN_cudaGetFuncBySymbol symbol_fn =
      static_cast<PFN_cudaGetFuncBySymbol>(cudaGetFuncBySymbol);
  redop->cudaGetFuncBySymbol_fn = reinterpret_cast<void *>(symbol_fn);
#else
  // cudaGetFuncBySymbol does not exist before CUDA 11 - store an explicit
  // null instead of leaving the field unwritten so callers can safely test
  // for its availability
  redop->cudaGetFuncBySymbol_fn = nullptr;
#endif
}
// NOTE: doxygen cross-reference residue removed here (cudaStream_t /
// cudaError_t are remapped via macros in hip_cuda.h for the HIP build).