@@ -41,7 +41,7 @@ __global__ void IndexEleGetGradAccKernel(
4141 offset_calc_t offset_calc) {
4242 const int tid = threadIdx .x ;
4343 const int nv = nt * vt;
44- int idx = nv * blockIdx .x + tid;
44+ int64_t idx = nv * static_cast < int64_t >( blockIdx .x ) + tid;
4545#pragma unroll
4646 for (int i = 0 ; i < vt; i++) {
4747 if (idx < N) {
@@ -112,10 +112,17 @@ void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx,
112112 funcs::make_offset_calculator_put<3 >(desired_shape, strides_array);
113113
114114 const int64_t N = numel;
115+
116+ PADDLE_ENFORCE_EQ (true ,
117+ funcs::IsInUint32Range (N, value.numel ()),
118+ common::errors::PreconditionNotMet (
119+ " the numel of input or output should be in [0, "
120+ " std::numeric_limits<int32_t>::max()]" ));
115121 constexpr int nt = 128 ;
116122 constexpr int vt = 4 ;
117123 const dim3 block (nt);
118- const dim3 grid ((N + block.x * vt - 1 ) / (block.x * vt));
124+ const dim3 grid ((N + static_cast <int64_t >(block.x ) * vt - 1 ) /
125+ (static_cast <int64_t >(block.x ) * vt));
119126 auto stream = dev_ctx.stream ();
120127
121128 using dtype = funcs::OpaqueType<sizeof (T)>;
@@ -172,11 +179,12 @@ __global__ void IndexingBackwardKernel(const int64_t* sorted_indices,
172179 using opmath_t = typename phi::dtype::MPTypeTrait<scalar_t >::Type;
173180
174181 for (int64_t z = blockIdx .z ; z < outer_dim; z += gridDim .z ) {
175- int64_t idx = blockIdx .x * blockDim .y + threadIdx .y ;
182+ int64_t idx = static_cast < int64_t >( blockIdx .x ) * blockDim .y + threadIdx .y ;
176183 if (idx < numel &&
177184 (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1 ])) {
178185 do {
179- int64_t start_feature = threadIdx .x + blockIdx .y * blockDim .x * SZ;
186+ int64_t start_feature =
187+ threadIdx .x + static_cast <int64_t >(blockIdx .y ) * blockDim .x * SZ;
180188 if (!accumulate && (idx < numel - 1 ) &&
181189 sorted_indices[idx] == sorted_indices[idx + 1 ]) {
182190 idx++;
@@ -222,7 +230,7 @@ __global__ void IndexingBackwardKernel(const int64_t* sorted_indices,
222230 static_cast <scalar_t >(weight[ii]);
223231 }
224232 }
225- start_feature += gridDim .y * blockDim .x * SZ;
233+ start_feature += static_cast < int64_t >( gridDim .y ) * blockDim .x * SZ;
226234 }
227235 idx++;
228236 } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1 ]);
0 commit comments