From efabc4828d3f1e4d71f4b633910e7f8c897ec248 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Thu, 6 Nov 2025 07:43:49 +0000 Subject: [PATCH 1/3] slice big tensor --- paddle/fluid/pybind/slice_utils.h | 2 +- paddle/phi/kernels/funcs/gather.cu.h | 6 ++++-- paddle/phi/kernels/funcs/index_elementwise.cu.h | 6 +++--- paddle/phi/kernels/funcs/index_impl.cu.h | 7 ++++--- paddle/phi/kernels/funcs/index_put_utils.h | 2 +- paddle/phi/kernels/funcs/select_impl.cu.h | 4 ++-- 6 files changed, 15 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h index 73af402de7b31e..e6e4eefd5ef579 100644 --- a/paddle/fluid/pybind/slice_utils.h +++ b/paddle/fluid/pybind/slice_utils.h @@ -740,7 +740,7 @@ static std::vector PrepareIndices( const paddle::Tensor& bool_2_idx, const paddle::Tensor& bool_index) { std::vector indices; - for (int j = 0; j < bool_2_idx.shape()[1]; ++j) { + for (int64_t j = 0; j < bool_2_idx.shape()[1]; ++j) { paddle::Tensor sliced_tensor = slice_ad_func(bool_2_idx, {1}, {j}, {j + 1}, {1}, {}); paddle::Tensor sliced_tensor_c = sliced_tensor.contiguous(); diff --git a/paddle/phi/kernels/funcs/gather.cu.h b/paddle/phi/kernels/funcs/gather.cu.h index 99dfdf1f3e1fec..54d8118510ae9a 100644 --- a/paddle/phi/kernels/funcs/gather.cu.h +++ b/paddle/phi/kernels/funcs/gather.cu.h @@ -151,9 +151,11 @@ __global__ void GatherGPUKernel(const T* input, int64_t input_index_dim_size, int64_t size) { int64_t block_size = blockDim.x; - int64_t idx = (blockIdx.x * block_size + threadIdx.x) * VecSize; + int64_t idx = + (static_cast(blockIdx.x) * block_size + threadIdx.x) * VecSize; int64_t outer_size = outer_dim_size * out_index_dim_size; - for (; idx < size; idx += gridDim.x * block_size * VecSize) { + for (; idx < size; + idx += static_cast(gridDim.x) * block_size * VecSize) { int64_t inner_dim_index = idx / outer_size; int64_t next_idx = idx % outer_size; int64_t index_dim_index = next_idx / outer_dim_size; diff --git a/paddle/phi/kernels/funcs/index_elementwise.cu.h b/paddle/phi/kernels/funcs/index_elementwise.cu.h index 9efbbef704a5e8..6f2c924d23ed0d 100644 --- a/paddle/phi/kernels/funcs/index_elementwise.cu.h +++ b/paddle/phi/kernels/funcs/index_elementwise.cu.h @@ -38,7 +38,7 @@ __global__ void index_elementwise_with_tensor_kernel(const int64_t N, const func_t f) { const auto tid = threadIdx.x; const auto nv = nt * vt; - auto idx = nv * blockIdx.x + tid; + int64_t idx = static_cast(nv) * blockIdx.x + tid; #pragma unroll for (int i = 0; i < vt; i++) { if (idx < N) { @@ -54,7 +54,7 @@ __global__ void index_elementwise_kernel(const int64_t N, const func_t f) { const auto tid = threadIdx.x; const auto nv = nt * vt; - auto idx = nv * blockIdx.x + tid; + int64_t idx = static_cast(nv) * blockIdx.x + tid; #pragma unroll for (int i = 0; i < vt; i++) { if (idx < N) { @@ -70,7 +70,7 @@ __global__ void index_put_kernel(const int64_t N, const func_t f) { const auto tid = threadIdx.x; const auto nv = nt * vt; - auto idx = nv * blockIdx.x + tid; + int64_t idx = static_cast(nv) * blockIdx.x + tid; #pragma unroll for (int i = 0; i < vt; i++) { if (idx < N) { diff --git a/paddle/phi/kernels/funcs/index_impl.cu.h b/paddle/phi/kernels/funcs/index_impl.cu.h index 1ef3699ab2c86e..bafd91b613fd0b 100644 --- a/paddle/phi/kernels/funcs/index_impl.cu.h +++ b/paddle/phi/kernels/funcs/index_impl.cu.h @@ -31,8 +31,8 @@ __global__ void VectorizedIndexKernel(T *out, size_t numel, size_t main_offset, Functor func) { - size_t data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; - 
size_t stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; + size_t data_offset = static_cast(BLOCK_ID_X) * BLOCK_NUM_X * VecSize; + size_t stride = static_cast(BLOCK_NUM_X) * GRID_NUM_X * VecSize; size_t args[VecSize]; T result[VecSize]; for (; data_offset < main_offset; data_offset += stride) { @@ -69,7 +69,8 @@ void IndexKernel(const KPDevice &dev_ctx, DenseTensor *out, Functor func) { int block = config.thread_per_block.x; auto stream = dev_ctx.stream(); #endif - size_t main_offset = (numel / (vec_size * block)) * vec_size * block; + size_t main_offset = + (numel / (vec_size * static_cast(block))) * vec_size * block; switch (vec_size) { case 4: VectorizedIndexKernel diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h index 0e14e613468109..3839cf6edb3eda 100644 --- a/paddle/phi/kernels/funcs/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -310,7 +310,7 @@ static void CalCompressedDimsWith1AndWithout1( #if defined(__NVCC__) || defined(__HIPCC__) template __global__ void range_cuda_kernel(int64_t N, T* out) { - int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t idx = threadIdx.x + static_cast(blockDim.x) * blockIdx.x; if (idx >= N) { return; diff --git a/paddle/phi/kernels/funcs/select_impl.cu.h b/paddle/phi/kernels/funcs/select_impl.cu.h index 5e13b78d996eda..b73aa057a1828b 100644 --- a/paddle/phi/kernels/funcs/select_impl.cu.h +++ b/paddle/phi/kernels/funcs/select_impl.cu.h @@ -156,7 +156,7 @@ __global__ void CumsumOneBlock(const InT *in, int64_t numel, int64_t main_offset, Functor func) { - int64_t stride = BLOCK_NUM_X * VecSize; + int64_t stride = static_cast(BLOCK_NUM_X) * VecSize; int64_t offset = 0; OutT pre_cumsum = static_cast(0); for (; offset < main_offset; offset += stride) { @@ -164,7 +164,7 @@ __global__ void CumsumOneBlock(const InT *in, in + offset, out + offset, &pre_cumsum, stride, func); } - int num = numel - offset; + int64_t num = numel - offset; if (num > 0) { CumsumImpl( in + offset, out + offset, &pre_cumsum, num, func); From b0b7939b68d10745e5b7634baa0e7693eacc4f15 Mon Sep 17 00:00:00 2001 From: DanielSun11 <1395924413@qq.com> Date: Thu, 6 Nov 2025 12:13:43 +0000 Subject: [PATCH 2/3] enhance the slice kernels robust for big tensor --- .../phi/kernels/funcs/index_elementwise.cu.h | 6 ++++++ .../gpu/index_elementwise_get_grad_kernel.cu | 18 +++++++++++++----- .../gpu/index_elementwise_get_kernel.cu | 11 +++++------ .../gpu/index_elementwise_put_grad_kernel.cu | 10 ++++++++++ .../phi/kernels/gpu/masked_fill_grad_kernel.cu | 6 +++--- paddle/phi/kernels/gpu/masked_fill_kernel.cu | 4 ++-- 6 files changed, 39 insertions(+), 16 deletions(-) diff --git a/paddle/phi/kernels/funcs/index_elementwise.cu.h b/paddle/phi/kernels/funcs/index_elementwise.cu.h index 6f2c924d23ed0d..1142b96a90de05 100644 --- a/paddle/phi/kernels/funcs/index_elementwise.cu.h +++ b/paddle/phi/kernels/funcs/index_elementwise.cu.h @@ -227,6 +227,12 @@ static OffsetCalculator make_offset_calculator( return OffsetCalculator( iter.ndim(), iter.shape().data(), strides.data()); } +constexpr bool IsInUint32Range(int64_t value) { + return value >= 0 && value <= std::numeric_limits::max(); +} +constexpr bool IsInUint32Range(int64_t v1, int64_t v2) { + return IsInUint32Range(v1) && IsInUint32Range(v2); +} } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu b/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu index 6c3e077d21a8e6..8f9f61ca9e1cb1 100644 --- 
a/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu @@ -40,7 +40,7 @@ __global__ void IndexEleGetGradAccKernel( offset_calc_t offset_calc) { const int tid = threadIdx.x; const int nv = nt * vt; - int idx = nv * blockIdx.x + tid; + int64_t idx = nv * static_cast(blockIdx.x) + tid; #pragma unroll for (int i = 0; i < vt; i++) { if (idx < N) { @@ -111,10 +111,17 @@ void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx, funcs::make_offset_calculator_put<3>(desired_shape, strides_array); const int64_t N = numel; + + PADDLE_ENFORCE_EQ(true, + funcs::IsInUint32Range(N, value.numel()), + common::errors::PreconditionNotMet( + "the numel of input or output should be in [0, " + "std::numeric_limits::max()]")); constexpr int nt = 128; constexpr int vt = 4; const dim3 block(nt); - const dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + const dim3 grid((N + static_cast(block.x) * vt - 1) / + (static_cast(block.x) * vt)); auto stream = dev_ctx.stream(); using dtype = funcs::OpaqueType; @@ -171,11 +178,12 @@ __global__ void IndexingBackwardKernel(const int64_t* sorted_indices, using opmath_t = typename phi::dtype::MPTypeTrait::Type; for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) { - int64_t idx = blockIdx.x * blockDim.y + threadIdx.y; + int64_t idx = static_cast(blockIdx.x) * blockDim.y + threadIdx.y; if (idx < numel && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])) { do { - int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ; + int64_t start_feature = + threadIdx.x + static_cast(blockIdx.y) * blockDim.x * SZ; if (!accumulate && (idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) { idx++; @@ -221,7 +229,7 @@ __global__ void IndexingBackwardKernel(const int64_t* sorted_indices, static_cast(weight[ii]); } } - start_feature += gridDim.y * blockDim.x * SZ; + start_feature += static_cast(gridDim.y) * blockDim.x * SZ; } idx++; } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]); diff --git a/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu b/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu index 2bb2df1c82bf9e..dbdb1cd29949c8 100644 --- a/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu +++ b/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu @@ -67,12 +67,11 @@ void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx, funcs::make_offset_calculator_put<3>(desired_shape, strides_array); const int64_t N = output->numel(); - PADDLE_ENFORCE_GE( - N, 0, common::errors::InvalidArgument("Output numel must >= 0")); - PADDLE_ENFORCE_LE( - N, - std::numeric_limits::max(), - common::errors::InvalidArgument("Output numel must <= INT32_MAX")); + PADDLE_ENFORCE_EQ(true, + funcs::IsInUint32Range(N, input.numel()), + common::errors::PreconditionNotMet( + "the numel of input or output should be in [0, " + "std::numeric_limits::max()]")); constexpr int nt = 128; constexpr int vt = 4; const dim3 block(nt); diff --git a/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu index d9867709b55379..04a303faaf29ca 100644 --- a/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu @@ -131,6 +131,11 @@ void GPUIndexElementwisePutGradKernel( auto index_ptrs = funcs::GetIndexDataPtrs(index); const char* out_ptr = reinterpret_cast(out_grad.data()); char* value_ptr = reinterpret_cast(value_grad->data()); + 
PADDLE_ENFORCE_EQ(true, + funcs::IsInUint32Range(value_grad->numel()), + common::errors::PreconditionNotMet( + "the numel of input or output should be in [0, " + "std::numeric_limits::max()]")); funcs::index_elementwise_with_tensor_kernel <<>>(N, [=] __device__(int idx) { const auto offsets = offset_calc.get(idx); @@ -153,6 +158,11 @@ void GPUIndexElementwisePutGradKernel( } else { auto index_ptrs = funcs::GetIndexDataPtrs(index); char* out_ptr = reinterpret_cast(x_grad->data()); + PADDLE_ENFORCE_EQ(true, + funcs::IsInUint32Range(value_grad->numel()), + common::errors::PreconditionNotMet( + "the numel of input or output should be in [0, " + "std::numeric_limits::max()]")); char* value_ptr = reinterpret_cast(value_grad->data()); funcs::index_elementwise_with_tensor_kernel <<>>(N, [=] __device__(int idx) { diff --git a/paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu b/paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu index 2034b339a0b775..4a9d8c79deb917 100644 --- a/paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu @@ -42,7 +42,7 @@ __global__ void GPUMaskedFillXGradKernel(const T* out_grad, const int64_t input_len, const int64_t batch_size, T* x_grad) { - int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x); + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; if (idx >= (input_len / VecSize)) { return; @@ -73,7 +73,7 @@ __global__ void GPUMaskedFillValueGradKernel(const T* out_grad, const int64_t input_len, const int64_t batch_size, T* value_grad) { - int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x); + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; if (idx >= (input_len / VecSize)) { return; @@ -243,7 +243,7 @@ void GPUMaskedFillGrad(const phi::GPUContext& dev_ctx, int64_t input_len = out_grad.numel(); int64_t mask_len = mask.numel(); - int batch_size = input_len / mask_len; + int64_t batch_size = input_len / mask_len; int vec_size = 8; vec_size = std::min(phi::GetVectorizedSize(out_grad_data), vec_size); diff --git a/paddle/phi/kernels/gpu/masked_fill_kernel.cu b/paddle/phi/kernels/gpu/masked_fill_kernel.cu index c8573826b787ca..a094e4e4c6f03f 100644 --- a/paddle/phi/kernels/gpu/masked_fill_kernel.cu +++ b/paddle/phi/kernels/gpu/masked_fill_kernel.cu @@ -69,7 +69,7 @@ __global__ void GPUMaskedFillKernel(const T* input, const int64_t input_len, const int64_t batch_size, T* output) { - int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x); + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; if (idx >= (input_len / VecSize)) { return; @@ -161,7 +161,7 @@ void GPUMaskedFill(const phi::GPUContext& dev_ctx, const T* value_data = value.data(); int64_t input_len = input.numel(); int64_t mask_len = mask.numel(); - int batch_size = input_len / mask_len; + int64_t batch_size = input_len / mask_len; int vec_size = 8; vec_size = std::min(phi::GetVectorizedSize(input_data), vec_size); From 732f27d3261b06042a4fbfe68ee0e8fb139e3999 Mon Sep 17 00:00:00 2001 From: DanielSun11 <1395924413@qq.com> Date: Tue, 11 Nov 2025 12:33:50 +0000 Subject: [PATCH 3/3] fix offset_calculator --- .../phi/kernels/funcs/index_elementwise.cu.h | 18 +- .../gpu/index_elementwise_get_grad_kernel.cu | 48 +++--- .../gpu/index_elementwise_get_kernel.cu | 36 ++-- .../gpu/index_elementwise_put_grad_kernel.cu | 163 ++++++++++-------- .../gpu/index_elementwise_put_kernel.cu | 85 +++++---- paddle/phi/kernels/stride/indexing_kernel.cu | 20 ++- 6 files changed, 222 insertions(+), 148 deletions(-) diff --git 
a/paddle/phi/kernels/funcs/index_elementwise.cu.h b/paddle/phi/kernels/funcs/index_elementwise.cu.h index 1142b96a90de05..af95b94588d6d2 100644 --- a/paddle/phi/kernels/funcs/index_elementwise.cu.h +++ b/paddle/phi/kernels/funcs/index_elementwise.cu.h @@ -192,15 +192,15 @@ struct OffsetCalculator { stride_t strides_[MAX_DIMS][std::max(NARGS, 1)]; }; -template -static OffsetCalculator make_offset_calculator_put( +template +static OffsetCalculator make_offset_calculator_put( std::vector desired_shape, std::array strides_array) { - return OffsetCalculator( + return OffsetCalculator( desired_shape.size(), desired_shape.data(), strides_array.data()); } -template -static OffsetCalculator make_offset_calculator( +template +static OffsetCalculator make_offset_calculator( int ndim, const int64_t* shape, const std::vector>& strides) { @@ -209,12 +209,12 @@ static OffsetCalculator make_offset_calculator( strides_array[i] = strides[i].data(); } - return OffsetCalculator( + return OffsetCalculator( ndim, shape, strides_array.data()); } -template -static OffsetCalculator make_offset_calculator( +template +static OffsetCalculator make_offset_calculator( const phi::DenseTensorIteratorBase& iter) { PADDLE_ENFORCE_LE(N, iter.ntensors(), @@ -224,7 +224,7 @@ static OffsetCalculator make_offset_calculator( for (int i = 0; i < N; i++) { strides[i] = iter.operands_[i].stride_bytes.data(); } - return OffsetCalculator( + return OffsetCalculator( iter.ndim(), iter.shape().data(), strides.data()); } constexpr bool IsInUint32Range(int64_t value) { diff --git a/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu b/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu index 8f9f61ca9e1cb1..45c16873a249f0 100644 --- a/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu @@ -63,7 +63,7 @@ __global__ void IndexEleGetGradAccKernel( } } -template +template void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx, const DenseTensor& input, const DenseTensor& value, @@ -107,16 +107,11 @@ void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx, &strides_array, &numel, strides_vec); - auto offset_calc = - funcs::make_offset_calculator_put<3>(desired_shape, strides_array); + auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>( + desired_shape, strides_array); const int64_t N = numel; - PADDLE_ENFORCE_EQ(true, - funcs::IsInUint32Range(N, value.numel()), - common::errors::PreconditionNotMet( - "the numel of input or output should be in [0, " - "std::numeric_limits::max()]")); constexpr int nt = 128; constexpr int vt = 4; const dim3 block(nt); @@ -425,18 +420,31 @@ void IndexElementwiseGetGradKernel(const Context& dev_ctx, return; #endif } - - GPUIndexElementwiseGetGrad(dev_ctx, - x, - out_grad, - index, - input_dims, - input_strides, - index_dims, - index_strides, - slice_offset, - accumulate, - x_grad); + if (funcs::IsInUint32Range(x_grad->numel(), out_grad.numel())) { + GPUIndexElementwiseGetGrad(dev_ctx, + x, + out_grad, + index, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + accumulate, + x_grad); + } else { + GPUIndexElementwiseGetGrad(dev_ctx, + x, + out_grad, + index, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + accumulate, + x_grad); + } } } // namespace phi diff --git a/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu b/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu index dbdb1cd29949c8..b741d51afe95d9 100644 --- 
a/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu +++ b/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu @@ -20,7 +20,7 @@ #include "paddle/phi/kernels/funcs/stride_utils.h" namespace phi { -template +template void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx, const DenseTensor& input, const std::vector& index, @@ -63,8 +63,8 @@ void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx, &strides_array, &numel, strides_vec); - auto offset_calc = - funcs::make_offset_calculator_put<3>(desired_shape, strides_array); + auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>( + desired_shape, strides_array); const int64_t N = output->numel(); PADDLE_ENFORCE_EQ(true, @@ -135,15 +135,27 @@ void IndexElementwiseGetKernel(const Context& dev_ctx, dev_ctx.template Alloc(out); if (out->numel() == 0) return; - GPUIndexElementwiseGetKernel(dev_ctx, - x, - index, - input_dims, - input_strides, - index_dims, - index_stride, - slice_offset, - out); + if (funcs::IsInUint32Range(out->numel())) { + GPUIndexElementwiseGetKernel(dev_ctx, + x, + index, + input_dims, + input_strides, + index_dims, + index_stride, + slice_offset, + out); + } else { + GPUIndexElementwiseGetKernel(dev_ctx, + x, + index, + input_dims, + input_strides, + index_dims, + index_stride, + slice_offset, + out); + } } } // namespace phi diff --git a/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu index 04a303faaf29ca..5b7b94e781431e 100644 --- a/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu @@ -26,7 +26,7 @@ namespace phi { -template +template void GPUIndexElementwisePutGradKernel( const phi::GPUContext& dev_ctx, const DenseTensor& out_grad, @@ -78,14 +78,10 @@ void GPUIndexElementwisePutGradKernel( &strides_array, &numel, strides_vec); - auto offset_calc = - funcs::make_offset_calculator_put<3>(desired_shape, strides_array); + auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>( + desired_shape, strides_array); const int64_t N = numel; - PADDLE_ENFORCE_EQ(true, - (N >= 0 && N <= std::numeric_limits::max()), - common::errors::PreconditionNotMet( - "the value of N should be in [0, " - "std::numeric_limits::max()]")); + constexpr int nt = 128; constexpr int vt = 4; const dim3 block(nt); @@ -189,7 +185,7 @@ void GPUIndexElementwisePutGradKernel( } } -template +template void LaunchIndexElementwisePutWithTensorGradCudaKernel( const Context& dev_ctx, const std::vector& indices, @@ -204,16 +200,16 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel( if (x_grad && !value_grad) { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); - GPUIndexElementwisePutGradKernel(dev_ctx, - out_grad, - indices, - input_dims, - input_strides, - index_dims, - index_strides, - slice_offset, - x_grad, - value_grad); + GPUIndexElementwisePutGradKernel(dev_ctx, + out_grad, + indices, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + x_grad, + value_grad); } else if (value_grad) { if (x_grad) { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); @@ -223,16 +219,16 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel( tmp_value_grad.Resize(common::make_ddim(input_dims)); dev_ctx.template Alloc(&tmp_value_grad); - GPUIndexElementwisePutGradKernel(dev_ctx, - out_grad, - indices, - input_dims, - input_strides, - index_dims, - index_strides, - slice_offset, - x_grad, - &tmp_value_grad); + 
GPUIndexElementwisePutGradKernel(dev_ctx, + out_grad, + indices, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + x_grad, + &tmp_value_grad); std::vector v_dims(tmp_value_grad.dims().size()); std::iota(v_dims.begin(), v_dims.end(), 0); @@ -245,31 +241,31 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel( value_grad); } else if (value_grad->dims() == common::make_ddim(input_dims)) { dev_ctx.template Alloc(value_grad); - GPUIndexElementwisePutGradKernel(dev_ctx, - out_grad, - indices, - input_dims, - input_strides, - index_dims, - index_strides, - slice_offset, - x_grad, - value_grad); + GPUIndexElementwisePutGradKernel(dev_ctx, + out_grad, + indices, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + x_grad, + value_grad); } else { DenseTensor tmp_value_grad(value_grad->dtype()); tmp_value_grad.Resize(common::make_ddim(input_dims)); dev_ctx.template Alloc(&tmp_value_grad); - GPUIndexElementwisePutGradKernel(dev_ctx, - out_grad, - indices, - input_dims, - input_strides, - index_dims, - index_strides, - slice_offset, - x_grad, - &tmp_value_grad); + GPUIndexElementwisePutGradKernel(dev_ctx, + out_grad, + indices, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + x_grad, + &tmp_value_grad); std::vector after_dims = common::vectorize(tmp_value_grad.dims()); @@ -307,17 +303,29 @@ void LaunchIndexElementwisePutGradCudaKernel( DenseTensor* x_grad) { if (x_grad) { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); - - GPUIndexElementwisePutGradKernel(dev_ctx, - out_grad, - indices, - input_dims, - input_strides, - index_dims, - index_strides, - slice_offset, - x_grad, - nullptr); + if (funcs::IsInUint32Range(x_grad->numel())) { + GPUIndexElementwisePutGradKernel(dev_ctx, + out_grad, + indices, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + x_grad, + nullptr); + } else { + GPUIndexElementwisePutGradKernel(dev_ctx, + out_grad, + indices, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + x_grad, + nullptr); + } } } @@ -399,17 +407,30 @@ void IndexElementwisePutWithTensorGradKernel( } return; } - - LaunchIndexElementwisePutWithTensorGradCudaKernel(dev_ctx, - indices, - out_grad, - input_dims, - input_strides, - index_dims, - index_strides, - slice_offset, - value_grad, - x_grad); + if (x_grad && funcs::IsInUint32Range(x_grad->numel())) { + LaunchIndexElementwisePutWithTensorGradCudaKernel(dev_ctx, + indices, + out_grad, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + value_grad, + x_grad); + } else { + LaunchIndexElementwisePutWithTensorGradCudaKernel( + dev_ctx, + indices, + out_grad, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + value_grad, + x_grad); + } } } // namespace phi diff --git a/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu b/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu index 1f195a06276267..844b8dd29da8f0 100644 --- a/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu @@ -21,7 +21,7 @@ namespace phi { -template +template void GPUIndexElementwisePutKernel(const phi::GPUContext& dev_ctx, const DenseTensor& input, const Scalar& value, @@ -65,8 +65,8 @@ void GPUIndexElementwisePutKernel(const phi::GPUContext& dev_ctx, &numel, strides_vec); - auto offset_calc = - funcs::make_offset_calculator_put<3>(desired_shape, strides_array); + auto offset_calc = 
funcs::make_offset_calculator_put<3, false, OffsetT>( + desired_shape, strides_array); const int64_t N = numel; PADDLE_ENFORCE_EQ(true, @@ -113,7 +113,7 @@ void GPUIndexElementwisePutKernel(const phi::GPUContext& dev_ctx, } } -template +template void GPUIndexElementwisePutWithTensorKernel( const phi::GPUContext& dev_ctx, const DenseTensor& input, @@ -157,15 +157,11 @@ void GPUIndexElementwisePutWithTensorKernel( &numel, strides_vec); - auto offset_calc = - funcs::make_offset_calculator_put<3>(desired_shape, strides_array); + auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>( + desired_shape, strides_array); const int64_t N = numel; - PADDLE_ENFORCE_EQ(true, - (N >= 0 && N <= std::numeric_limits::max()), - common::errors::PreconditionNotMet( - "the value of N should be in [0, " - "std::numeric_limits::max()]")); + constexpr int nt = 128; constexpr int vt = 4; const dim3 block(nt); @@ -221,16 +217,29 @@ void IndexElementwisePutKernel(const Context& dev_ctx, phi::DataType::INT64)); dev_ctx.template Alloc(out); if (out->numel() == 0) return; - GPUIndexElementwisePutKernel(dev_ctx, - x, - value, - index, - input_dims, - input_strides, - index_dims, - index_strides, - slice_offset, - out); + if (funcs::IsInUint32Range(out->numel())) { + GPUIndexElementwisePutKernel(dev_ctx, + x, + value, + index, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + out); + } else { + GPUIndexElementwisePutKernel(dev_ctx, + x, + value, + index, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + out); + } } template @@ -256,16 +265,30 @@ void IndexElementwisePutWithTensorKernel( dev_ctx.template Alloc(out); if (out->numel() == 0) return; - GPUIndexElementwisePutWithTensorKernel(dev_ctx, - x, - value, - index, - input_dims, - input_strides, - index_dims, - index_strides, - slice_offset, - out); + if (funcs::IsInUint32Range(out->numel())) { + GPUIndexElementwisePutWithTensorKernel(dev_ctx, + x, + value, + index, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + out); + + } else { + GPUIndexElementwisePutWithTensorKernel(dev_ctx, + x, + value, + index, + input_dims, + input_strides, + index_dims, + index_strides, + slice_offset, + out); + } } } // namespace phi diff --git a/paddle/phi/kernels/stride/indexing_kernel.cu b/paddle/phi/kernels/stride/indexing_kernel.cu index ec44b2c531f953..6c0e4fc61e468b 100644 --- a/paddle/phi/kernels/stride/indexing_kernel.cu +++ b/paddle/phi/kernels/stride/indexing_kernel.cu @@ -77,7 +77,7 @@ inline bool CheckIsDimsMatchBool(const DDim& first, const DDim& second) { return false; } -template +template void LaunchIndexPutKernel_V2(const Context& dev_ctx, const DenseTensor& x, const std::vector& indices, @@ -248,8 +248,13 @@ void IndexPutKernel_V2(const Context& dev_ctx, "Kernel using DenseTensorIterator " "be called, something wrong has happened!")); } - LaunchIndexPutKernel_V2( - dev_ctx, x_, indices, value_, accumulate, out); + if (out && !funcs::IsInUint32Range(out->numel(), value_.numel())) { + LaunchIndexPutKernel_V2( + dev_ctx, x_, indices, value_, accumulate, out); + } else { + LaunchIndexPutKernel_V2( + dev_ctx, x_, indices, value_, accumulate, out); + } } template @@ -327,8 +332,13 @@ void IndexPutGradKernel_V2(const Context& dev_ctx, phi::IntArray(common::vectorize(value.dims())), 0, &value_zero); - LaunchIndexPutKernel_V2( - dev_ctx, out_grad, indices, value_zero, false, x_grad); + if (funcs::IsInUint32Range(x_grad->numel(), value.numel())) { + LaunchIndexPutKernel_V2( 
+ dev_ctx, out_grad, indices, value_zero, false, x_grad); + } else { + LaunchIndexPutKernel_V2( + dev_ctx, out_grad, indices, value_zero, false, x_grad); + } } } }
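
Note on the pattern in PATCH 1: the changes promote CUDA thread/block index arithmetic to 64-bit before any multiplication, so the kernels stay correct once a tensor's numel exceeds INT32_MAX. A minimal standalone sketch of that pattern follows; the kernel and constant names (FillIota, kVecSize) are illustrative only, not Paddle APIs, and the launch shape is simplified.

#include <cstdint>
#include <cuda_runtime.h>

constexpr int kVecSize = 4;

template <typename T>
__global__ void FillIota(T* out, int64_t numel) {
  // Promote to int64_t before multiplying; blockIdx.x * blockDim.x alone is a
  // 32-bit product and can wrap once the launch covers more than 2^31 elements.
  int64_t idx =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * kVecSize;
  const int64_t stride =
      static_cast<int64_t>(gridDim.x) * blockDim.x * kVecSize;
  for (; idx < numel; idx += stride) {
#pragma unroll
    for (int v = 0; v < kVecSize; ++v) {
      if (idx + v < numel) {
        out[idx + v] = static_cast<T>(idx + v);
      }
    }
  }
}

int main() {
  const int64_t numel = int64_t{1} << 20;  // small here; the casts matter past 2^31
  float* d = nullptr;
  cudaMalloc(&d, numel * sizeof(float));
  const int block = 128;
  const int64_t grid =
      (numel + static_cast<int64_t>(block) * kVecSize - 1) / (block * kVecSize);
  FillIota<float><<<static_cast<unsigned>(grid), block>>>(d, numel);
  cudaDeviceSynchronize();
  cudaFree(d);
  return 0;
}

The same promotion is what the static_cast changes in gather.cu.h, index_impl.cu.h, index_put_utils.h, and select_impl.cu.h accomplish.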
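
Note on the pattern in PATCH 2 and PATCH 3: a range check (IsInUint32Range) is paired with kernels templated on the offset type, so the common case keeps cheap 32-bit offset arithmetic and only genuinely large tensors pay for 64-bit offsets. Below is a condensed sketch of that dispatch under simplified assumptions: ScaleKernel and the Scale() launcher are hypothetical stand-ins for the real index_elementwise kernels and their offset calculators.

#include <algorithm>
#include <cstdint>
#include <limits>
#include <cuda_runtime.h>

constexpr bool IsInUint32Range(int64_t v) {
  return v >= 0 && v <= std::numeric_limits<uint32_t>::max();
}

template <typename OffsetT>
__global__ void ScaleKernel(const float* in, float* out, int64_t numel,
                            float alpha) {
  // The linear index is always computed in 64-bit; only the memory offset is
  // narrowed to OffsetT (uint32_t on the fast path, int64_t otherwise).
  for (int64_t linear =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       linear < numel;
       linear += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    const OffsetT offset = static_cast<OffsetT>(linear);
    out[offset] = alpha * in[offset];
  }
}

void Scale(const float* in, float* out, int64_t numel, float alpha,
           cudaStream_t stream) {
  if (numel <= 0) return;
  const int block = 256;
  const unsigned grid = static_cast<unsigned>(
      std::min<int64_t>((numel + block - 1) / block, 65535));
  if (IsInUint32Range(numel)) {
    ScaleKernel<uint32_t><<<grid, block, 0, stream>>>(in, out, numel, alpha);
  } else {
    ScaleKernel<int64_t><<<grid, block, 0, stream>>>(in, out, numel, alpha);
  }
}

This mirrors how the index_elementwise get/put kernels and LaunchIndexPutKernel_V2 now branch on IsInUint32Range before instantiating the templated offset calculator with a wider offset type.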