Skip to content

Commit 8cfbbc1

Browse files
authored
fix offset_calculator (#76372)
1 parent 995050e commit 8cfbbc1

File tree

6 files changed

+213
-140
lines changed

6 files changed

+213
-140
lines changed

paddle/phi/kernels/funcs/index_elementwise.cu.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,15 @@ struct OffsetCalculator {
192192
stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
193193
};
194194

195-
template <int N, bool signed_strides = false>
196-
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator_put(
195+
template <int N, bool signed_strides = false, typename OffsetT = uint32_t>
196+
static OffsetCalculator<N, OffsetT, signed_strides> make_offset_calculator_put(
197197
std::vector<int64_t> desired_shape, std::array<int64_t*, N> strides_array) {
198-
return OffsetCalculator<N, uint32_t, signed_strides>(
198+
return OffsetCalculator<N, OffsetT, signed_strides>(
199199
desired_shape.size(), desired_shape.data(), strides_array.data());
200200
}
201201

202-
template <int N, bool signed_strides = false>
203-
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
202+
template <int N, bool signed_strides = false, typename OffsetT = uint32_t>
203+
static OffsetCalculator<N, OffsetT, signed_strides> make_offset_calculator(
204204
int ndim,
205205
const int64_t* shape,
206206
const std::vector<std::vector<int64_t>>& strides) {
@@ -209,12 +209,12 @@ static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
209209
strides_array[i] = strides[i].data();
210210
}
211211

212-
return OffsetCalculator<N, uint32_t, signed_strides>(
212+
return OffsetCalculator<N, OffsetT, signed_strides>(
213213
ndim, shape, strides_array.data());
214214
}
215215

216-
template <int N, bool signed_strides = false>
217-
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
216+
template <int N, bool signed_strides = false, typename OffsetT = uint32_t>
217+
static OffsetCalculator<N, OffsetT, signed_strides> make_offset_calculator(
218218
const phi::DenseTensorIteratorBase& iter) {
219219
PADDLE_ENFORCE_LE(N,
220220
iter.ntensors(),
@@ -224,7 +224,7 @@ static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
224224
for (int i = 0; i < N; i++) {
225225
strides[i] = iter.operands_[i].stride_bytes.data();
226226
}
227-
return OffsetCalculator<N, uint32_t, signed_strides>(
227+
return OffsetCalculator<N, OffsetT, signed_strides>(
228228
iter.ndim(), iter.shape().data(), strides.data());
229229
}
230230
constexpr bool IsInUint32Range(int64_t value) {

paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ __global__ void IndexEleGetGradAccKernel(
6464
}
6565
}
6666

67-
template <typename T, typename IndexT>
67+
template <typename T, typename IndexT, typename OffsetT = uint32_t>
6868
void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx,
6969
const DenseTensor& input,
7070
const DenseTensor& value,
@@ -108,16 +108,11 @@ void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx,
108108
&strides_array,
109109
&numel,
110110
strides_vec);
111-
auto offset_calc =
112-
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
111+
auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>(
112+
desired_shape, strides_array);
113113

114114
const int64_t N = numel;
115115

116-
PADDLE_ENFORCE_EQ(true,
117-
funcs::IsInUint32Range(N, value.numel()),
118-
common::errors::PreconditionNotMet(
119-
"the numel of input or output should be in [0, "
120-
"std::numeric_limits<int32_t>::max()]"));
121116
constexpr int nt = 128;
122117
constexpr int vt = 4;
123118
const dim3 block(nt);
@@ -426,18 +421,31 @@ void IndexElementwiseGetGradKernel(const Context& dev_ctx,
426421
return;
427422
#endif
428423
}
429-
430-
GPUIndexElementwiseGetGrad<T, int64_t>(dev_ctx,
431-
x,
432-
out_grad,
433-
index,
434-
input_dims,
435-
input_strides,
436-
index_dims,
437-
index_strides,
438-
slice_offset,
439-
accumulate,
440-
x_grad);
424+
if (funcs::IsInUint32Range(x_grad->numel(), out_grad.numel())) {
425+
GPUIndexElementwiseGetGrad<T, int64_t>(dev_ctx,
426+
x,
427+
out_grad,
428+
index,
429+
input_dims,
430+
input_strides,
431+
index_dims,
432+
index_strides,
433+
slice_offset,
434+
accumulate,
435+
x_grad);
436+
} else {
437+
GPUIndexElementwiseGetGrad<T, int64_t, uint64_t>(dev_ctx,
438+
x,
439+
out_grad,
440+
index,
441+
input_dims,
442+
input_strides,
443+
index_dims,
444+
index_strides,
445+
slice_offset,
446+
accumulate,
447+
x_grad);
448+
}
441449
}
442450

443451
} // namespace phi

paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "paddle/phi/kernels/funcs/stride_utils.h"
2222

2323
namespace phi {
24-
template <typename T, typename IndexT = int>
24+
template <typename T, typename IndexT = int, typename OffsetT = uint32_t>
2525
void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx,
2626
const DenseTensor& input,
2727
const std::vector<const DenseTensor*> index,
@@ -64,8 +64,8 @@ void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx,
6464
&strides_array,
6565
&numel,
6666
strides_vec);
67-
auto offset_calc =
68-
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
67+
auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>(
68+
desired_shape, strides_array);
6969

7070
const int64_t N = output->numel();
7171
PADDLE_ENFORCE_EQ(true,
@@ -136,15 +136,27 @@ void IndexElementwiseGetKernel(const Context& dev_ctx,
136136
dev_ctx.template Alloc<T>(out);
137137
if (out->numel() == 0) return;
138138

139-
GPUIndexElementwiseGetKernel<T, int64_t>(dev_ctx,
140-
x,
141-
index,
142-
input_dims,
143-
input_strides,
144-
index_dims,
145-
index_stride,
146-
slice_offset,
147-
out);
139+
if (funcs::IsInUint32Range(out->numel())) {
140+
GPUIndexElementwiseGetKernel<T, int64_t>(dev_ctx,
141+
x,
142+
index,
143+
input_dims,
144+
input_strides,
145+
index_dims,
146+
index_stride,
147+
slice_offset,
148+
out);
149+
} else {
150+
GPUIndexElementwiseGetKernel<T, int64_t, uint64_t>(dev_ctx,
151+
x,
152+
index,
153+
input_dims,
154+
input_strides,
155+
index_dims,
156+
index_stride,
157+
slice_offset,
158+
out);
159+
}
148160
}
149161

150162
} // namespace phi

paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu

Lines changed: 91 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
namespace phi {
2929

30-
template <typename T, typename IndexT = int>
30+
template <typename T, typename IndexT = int, typename OffsetT = uint32_t>
3131
void GPUIndexElementwisePutGradKernel(
3232
const phi::GPUContext& dev_ctx,
3333
const DenseTensor& out_grad,
@@ -79,11 +79,9 @@ void GPUIndexElementwisePutGradKernel(
7979
&strides_array,
8080
&numel,
8181
strides_vec);
82-
auto offset_calc =
83-
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
82+
auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>(
83+
desired_shape, strides_array);
8484
const int64_t N = numel;
85-
PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
86-
"N >= 0 && N <= std::numeric_limits<int32_t>::max()");
8785
constexpr int nt = 128;
8886
constexpr int vt = 4;
8987
const dim3 block(nt);
@@ -187,7 +185,7 @@ void GPUIndexElementwisePutGradKernel(
187185
}
188186
}
189187

190-
template <typename T, typename Context>
188+
template <typename T, typename Context, typename OffsetT = uint32_t>
191189
void LaunchIndexElementwisePutWithTensorGradCudaKernel(
192190
const Context& dev_ctx,
193191
const std::vector<const DenseTensor*>& indices,
@@ -202,16 +200,16 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel(
202200
if (x_grad && !value_grad) {
203201
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
204202

205-
GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
206-
out_grad,
207-
indices,
208-
input_dims,
209-
input_strides,
210-
index_dims,
211-
index_strides,
212-
slice_offset,
213-
x_grad,
214-
value_grad);
203+
GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
204+
out_grad,
205+
indices,
206+
input_dims,
207+
input_strides,
208+
index_dims,
209+
index_strides,
210+
slice_offset,
211+
x_grad,
212+
value_grad);
215213
} else if (value_grad) {
216214
if (x_grad) {
217215
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
@@ -221,16 +219,16 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel(
221219
tmp_value_grad.Resize(common::make_ddim(input_dims));
222220
dev_ctx.template Alloc<T>(&tmp_value_grad);
223221

224-
GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
225-
out_grad,
226-
indices,
227-
input_dims,
228-
input_strides,
229-
index_dims,
230-
index_strides,
231-
slice_offset,
232-
x_grad,
233-
&tmp_value_grad);
222+
GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
223+
out_grad,
224+
indices,
225+
input_dims,
226+
input_strides,
227+
index_dims,
228+
index_strides,
229+
slice_offset,
230+
x_grad,
231+
&tmp_value_grad);
234232

235233
std::vector<int> v_dims(tmp_value_grad.dims().size());
236234
std::iota(v_dims.begin(), v_dims.end(), 0);
@@ -243,31 +241,31 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel(
243241
value_grad);
244242
} else if (value_grad->dims() == common::make_ddim(input_dims)) {
245243
dev_ctx.template Alloc<T>(value_grad);
246-
GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
247-
out_grad,
248-
indices,
249-
input_dims,
250-
input_strides,
251-
index_dims,
252-
index_strides,
253-
slice_offset,
254-
x_grad,
255-
value_grad);
244+
GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
245+
out_grad,
246+
indices,
247+
input_dims,
248+
input_strides,
249+
index_dims,
250+
index_strides,
251+
slice_offset,
252+
x_grad,
253+
value_grad);
256254
} else {
257255
DenseTensor tmp_value_grad(value_grad->dtype());
258256
tmp_value_grad.Resize(common::make_ddim(input_dims));
259257
dev_ctx.template Alloc<T>(&tmp_value_grad);
260258

261-
GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
262-
out_grad,
263-
indices,
264-
input_dims,
265-
input_strides,
266-
index_dims,
267-
index_strides,
268-
slice_offset,
269-
x_grad,
270-
&tmp_value_grad);
259+
GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
260+
out_grad,
261+
indices,
262+
input_dims,
263+
input_strides,
264+
index_dims,
265+
index_strides,
266+
slice_offset,
267+
x_grad,
268+
&tmp_value_grad);
271269

272270
std::vector<int64_t> after_dims =
273271
common::vectorize(tmp_value_grad.dims());
@@ -305,17 +303,29 @@ void LaunchIndexElementwisePutGradCudaKernel(
305303
DenseTensor* x_grad) {
306304
if (x_grad) {
307305
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
308-
309-
GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
310-
out_grad,
311-
indices,
312-
input_dims,
313-
input_strides,
314-
index_dims,
315-
index_strides,
316-
slice_offset,
317-
x_grad,
318-
nullptr);
306+
if (funcs::IsInUint32Range(x_grad->numel())) {
307+
GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
308+
out_grad,
309+
indices,
310+
input_dims,
311+
input_strides,
312+
index_dims,
313+
index_strides,
314+
slice_offset,
315+
x_grad,
316+
nullptr);
317+
} else {
318+
GPUIndexElementwisePutGradKernel<T, int64_t, uint64_t>(dev_ctx,
319+
out_grad,
320+
indices,
321+
input_dims,
322+
input_strides,
323+
index_dims,
324+
index_strides,
325+
slice_offset,
326+
x_grad,
327+
nullptr);
328+
}
319329
}
320330
}
321331

@@ -397,17 +407,30 @@ void IndexElementwisePutWithTensorGradKernel(
397407
}
398408
return;
399409
}
400-
401-
LaunchIndexElementwisePutWithTensorGradCudaKernel<T, Context>(dev_ctx,
402-
indices,
403-
out_grad,
404-
input_dims,
405-
input_strides,
406-
index_dims,
407-
index_strides,
408-
slice_offset,
409-
value_grad,
410-
x_grad);
410+
if (x_grad && funcs::IsInUint32Range(x_grad->numel())) {
411+
LaunchIndexElementwisePutWithTensorGradCudaKernel<T, Context>(dev_ctx,
412+
indices,
413+
out_grad,
414+
input_dims,
415+
input_strides,
416+
index_dims,
417+
index_strides,
418+
slice_offset,
419+
value_grad,
420+
x_grad);
421+
} else {
422+
LaunchIndexElementwisePutWithTensorGradCudaKernel<T, Context, uint64_t>(
423+
dev_ctx,
424+
indices,
425+
out_grad,
426+
input_dims,
427+
input_strides,
428+
index_dims,
429+
index_strides,
430+
slice_offset,
431+
value_grad,
432+
x_grad);
433+
}
411434
}
412435

413436
} // namespace phi

0 commit comments

Comments (0)