
Commit 77f1588

DanielSun11 authored and wanghuancoder committed

Fix slice bigtensor (PaddlePaddle#76364)

* slice big tensor
* enhance the slice kernels' robustness for big tensors
* fix offset_calculator

Co-authored-by: Wang Huan <[email protected]>

1 parent 923ef85 commit 77f1588

File tree

6 files changed: +222 -149 lines (only four of the six file diffs are rendered below)

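What the commit does, as visible in the diffs below: the CUDA index/slice kernels computed all element offsets in uint32_t, so tensors whose element count (or byte extent) exceeds the 32-bit range overflowed and were previously rejected outright by hard PADDLE_ENFORCE_EQ range checks. The change threads a new OffsetT template parameter, defaulting to uint32_t so the common fast path is untouched, through the make_offset_calculator helpers and the four kernels shown, removes the hard range checks, and has each public entry point test funcs::IsInUint32Range at runtime so it can re-dispatch with OffsetT = uint64_t for big tensors.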

paddle/phi/kernels/funcs/index_elementwise.cu.h

Lines changed: 9 additions & 9 deletions
@@ -192,15 +192,15 @@ struct OffsetCalculator {
   stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
 };

-template <int N, bool signed_strides = false>
-static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator_put(
+template <int N, bool signed_strides = false, typename OffsetT = uint32_t>
+static OffsetCalculator<N, OffsetT, signed_strides> make_offset_calculator_put(
     std::vector<int64_t> desired_shape, std::array<int64_t*, N> strides_array) {
-  return OffsetCalculator<N, uint32_t, signed_strides>(
+  return OffsetCalculator<N, OffsetT, signed_strides>(
       desired_shape.size(), desired_shape.data(), strides_array.data());
 }

-template <int N, bool signed_strides = false>
-static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
+template <int N, bool signed_strides = false, typename OffsetT = uint32_t>
+static OffsetCalculator<N, OffsetT, signed_strides> make_offset_calculator(
     int ndim,
     const int64_t* shape,
     const std::vector<std::vector<int64_t>>& strides) {
@@ -209,12 +209,12 @@ static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
     strides_array[i] = strides[i].data();
   }

-  return OffsetCalculator<N, uint32_t, signed_strides>(
+  return OffsetCalculator<N, OffsetT, signed_strides>(
       ndim, shape, strides_array.data());
 }

-template <int N, bool signed_strides = false>
-static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
+template <int N, bool signed_strides = false, typename OffsetT = uint32_t>
+static OffsetCalculator<N, OffsetT, signed_strides> make_offset_calculator(
     const phi::DenseTensorIteratorBase& iter) {
   PADDLE_ENFORCE_LE(N,
                     iter.ntensors(),
@@ -224,7 +224,7 @@ static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
   for (int i = 0; i < N; i++) {
     strides[i] = iter.operands_[i].stride_bytes.data();
   }
-  return OffsetCalculator<N, uint32_t, signed_strides>(
+  return OffsetCalculator<N, OffsetT, signed_strides>(
       iter.ndim(), iter.shape().data(), strides.data());
 }
 constexpr bool IsInUint32Range(int64_t value) {
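Only the tail of OffsetCalculator is visible in the hunks above, but the role of OffsetT follows from how such calculators work: a linear element index is decomposed dimension by dimension with div/mod, and per-operand byte offsets are accumulated as index-times-stride products in OffsetT. Below is a minimal sketch of that computation, assuming a PyTorch-style calculator; the member names, the dimension cap, and the get() signature are illustrative, not Paddle's actual definition.

#include <cstdint>

// Illustrative sketch only. With OffsetT = uint32_t the accumulated
// index*stride products wrap past 2^32; uint64_t keeps them exact.
template <int NARGS, typename OffsetT = uint32_t>
struct SimpleOffsetCalculator {
  static constexpr int kMaxDims = 12;  // assumed cap, analogous to MAX_DIMS

  int ndim;
  OffsetT sizes[kMaxDims];           // extent of each dimension
  OffsetT strides[kMaxDims][NARGS];  // byte stride per dim, per operand

  __host__ __device__ void get(OffsetT linear_idx,
                               OffsetT offsets[NARGS]) const {
    for (int a = 0; a < NARGS; ++a) offsets[a] = 0;
    for (int d = 0; d < ndim; ++d) {
      OffsetT idx_d = linear_idx % sizes[d];  // coordinate along dim d
      linear_idx /= sizes[d];
      for (int a = 0; a < NARGS; ++a) {
        offsets[a] += idx_d * strides[d][a];  // wraps if OffsetT is too narrow
      }
    }
  }
};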

paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu

Lines changed: 28 additions & 21 deletions
@@ -63,7 +63,7 @@ __global__ void IndexEleGetGradAccKernel(
   }
 }

-template <typename T, typename IndexT>
+template <typename T, typename IndexT, typename OffsetT = uint32_t>
 void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx,
                                 const DenseTensor& input,
                                 const DenseTensor& value,
@@ -107,16 +107,10 @@ void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx,
                      &strides_array,
                      &numel,
                      strides_vec);
-  auto offset_calc =
-      funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
+  auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>(
+      desired_shape, strides_array);

   const int64_t N = numel;
-
-  PADDLE_ENFORCE_EQ(true,
-                    funcs::IsInUint32Range(N, value.numel()),
-                    common::errors::PreconditionNotMet(
-                        "the numel of input or output should be in [0, "
-                        "std::numeric_limits<int32_t>::max()]"));
   constexpr int nt = 128;
   constexpr int vt = 4;
   const dim3 block(nt);
@@ -425,18 +419,31 @@ void IndexElementwiseGetGradKernel(const Context& dev_ctx,
     return;
 #endif
   }
-
-  GPUIndexElementwiseGetGrad<T, int64_t>(dev_ctx,
-                                         x,
-                                         out_grad,
-                                         index,
-                                         input_dims,
-                                         input_strides,
-                                         index_dims,
-                                         index_strides,
-                                         slice_offset,
-                                         accumulate,
-                                         x_grad);
+  if (funcs::IsInUint32Range(x_grad->numel(), out_grad.numel())) {
+    GPUIndexElementwiseGetGrad<T, int64_t>(dev_ctx,
+                                           x,
+                                           out_grad,
+                                           index,
+                                           input_dims,
+                                           input_strides,
+                                           index_dims,
+                                           index_strides,
+                                           slice_offset,
+                                           accumulate,
+                                           x_grad);
+  } else {
+    GPUIndexElementwiseGetGrad<T, int64_t, uint64_t>(dev_ctx,
+                                                     x,
+                                                     out_grad,
+                                                     index,
+                                                     input_dims,
+                                                     input_strides,
+                                                     index_dims,
+                                                     index_strides,
+                                                     slice_offset,
+                                                     accumulate,
+                                                     x_grad);
+  }
 }

 }  // namespace phi
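The dispatch predicate is called with two arguments here (funcs::IsInUint32Range(x_grad->numel(), out_grad.numel())) and with one elsewhere; only the single-argument signature is visible as trailing context in the header diff above. Below is a self-contained sketch consistent with both call sites; the variadic overload is an assumption about the actual helper, not code from this commit.

#include <cstdint>
#include <limits>

// Every value must lie in [0, UINT32_MAX] for the 32-bit fast path.
constexpr bool IsInUint32Range(int64_t value) {
  return value >= 0 &&
         value <= static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
}

// Assumed variadic overload: all arguments must individually be in range.
template <typename... Rest>
constexpr bool IsInUint32Range(int64_t first, Rest... rest) {
  return IsInUint32Range(first) && IsInUint32Range(rest...);
}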

paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu

Lines changed: 24 additions & 12 deletions
@@ -20,7 +20,7 @@
 #include "paddle/phi/kernels/funcs/stride_utils.h"

 namespace phi {
-template <typename T, typename IndexT = int>
+template <typename T, typename IndexT = int, typename OffsetT = uint32_t>
 void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx,
                                   const DenseTensor& input,
                                   const std::vector<const DenseTensor*>& index,
@@ -63,8 +63,8 @@ void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx,
                      &strides_array,
                      &numel,
                      strides_vec);
-  auto offset_calc =
-      funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
+  auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>(
+      desired_shape, strides_array);

   const int64_t N = output->numel();
   PADDLE_ENFORCE_EQ(true,
@@ -135,15 +135,27 @@ void IndexElementwiseGetKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);
   if (out->numel() == 0) return;

-  GPUIndexElementwiseGetKernel<T, int64_t>(dev_ctx,
-                                           x,
-                                           index,
-                                           input_dims,
-                                           input_strides,
-                                           index_dims,
-                                           index_stride,
-                                           slice_offset,
-                                           out);
+  if (funcs::IsInUint32Range(out->numel())) {
+    GPUIndexElementwiseGetKernel<T, int64_t>(dev_ctx,
+                                             x,
+                                             index,
+                                             input_dims,
+                                             input_strides,
+                                             index_dims,
+                                             index_stride,
+                                             slice_offset,
+                                             out);
+  } else {
+    GPUIndexElementwiseGetKernel<T, int64_t, uint64_t>(dev_ctx,
+                                                       x,
+                                                       index,
+                                                       input_dims,
+                                                       input_strides,
+                                                       index_dims,
+                                                       index_stride,
+                                                       slice_offset,
+                                                       out);
+  }
 }

 }  // namespace phi
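For concreteness, this is the failure mode the uint64_t fallback guards against: a byte offset computed in 32-bit arithmetic silently wraps once an index passes 2^32. A standalone illustration in plain C++ (not Paddle code):

#include <cstdint>
#include <cstdio>

int main() {
  // An element index just past 2^32, as in a contiguous float tensor
  // with more than ~4.29 billion elements.
  const int64_t idx = (1LL << 32) + 7;

  // 32-bit math truncates the index before multiplying: 7 * 4 = 28.
  const uint32_t off32 = static_cast<uint32_t>(idx) * sizeof(float);

  // 64-bit math keeps the full product: (2^32 + 7) * 4.
  const uint64_t off64 = static_cast<uint64_t>(idx) * sizeof(float);

  std::printf("wrapped 32-bit offset: %u\n", off32);
  std::printf("correct 64-bit offset: %llu\n",
              static_cast<unsigned long long>(off64));
  return 0;
}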

paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu

Lines changed: 92 additions & 71 deletions
@@ -26,7 +26,7 @@

 namespace phi {

-template <typename T, typename IndexT = int>
+template <typename T, typename IndexT = int, typename OffsetT = uint32_t>
 void GPUIndexElementwisePutGradKernel(
     const phi::GPUContext& dev_ctx,
     const DenseTensor& out_grad,
@@ -78,14 +78,10 @@ void GPUIndexElementwisePutGradKernel(
                      &strides_array,
                      &numel,
                      strides_vec);
-  auto offset_calc =
-      funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
+  auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>(
+      desired_shape, strides_array);
   const int64_t N = numel;
-  PADDLE_ENFORCE_EQ(true,
-                    (N >= 0 && N <= std::numeric_limits<int32_t>::max()),
-                    common::errors::PreconditionNotMet(
-                        "the value of N should be in [0, "
-                        "std::numeric_limits<int32_t>::max()]"));
+
   constexpr int nt = 128;
   constexpr int vt = 4;
   const dim3 block(nt);
@@ -189,7 +185,7 @@ void GPUIndexElementwisePutGradKernel(
   }
 }

-template <typename T, typename Context>
+template <typename T, typename Context, typename OffsetT = uint32_t>
 void LaunchIndexElementwisePutWithTensorGradCudaKernel(
     const Context& dev_ctx,
     const std::vector<const DenseTensor*>& indices,
@@ -204,16 +200,16 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel(
   if (x_grad && !value_grad) {
     phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);

-    GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
-                                                 out_grad,
-                                                 indices,
-                                                 input_dims,
-                                                 input_strides,
-                                                 index_dims,
-                                                 index_strides,
-                                                 slice_offset,
-                                                 x_grad,
-                                                 value_grad);
+    GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
+                                                          out_grad,
+                                                          indices,
+                                                          input_dims,
+                                                          input_strides,
+                                                          index_dims,
+                                                          index_strides,
+                                                          slice_offset,
+                                                          x_grad,
+                                                          value_grad);
   } else if (value_grad) {
     if (x_grad) {
       phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
@@ -223,16 +219,16 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel(
       tmp_value_grad.Resize(common::make_ddim(input_dims));
       dev_ctx.template Alloc<T>(&tmp_value_grad);

-      GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
-                                                   out_grad,
-                                                   indices,
-                                                   input_dims,
-                                                   input_strides,
-                                                   index_dims,
-                                                   index_strides,
-                                                   slice_offset,
-                                                   x_grad,
-                                                   &tmp_value_grad);
+      GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
+                                                            out_grad,
+                                                            indices,
+                                                            input_dims,
+                                                            input_strides,
+                                                            index_dims,
+                                                            index_strides,
+                                                            slice_offset,
+                                                            x_grad,
+                                                            &tmp_value_grad);

       std::vector<int> v_dims(tmp_value_grad.dims().size());
       std::iota(v_dims.begin(), v_dims.end(), 0);
@@ -245,31 +241,31 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel(
                                  value_grad);
     } else if (value_grad->dims() == common::make_ddim(input_dims)) {
       dev_ctx.template Alloc<T>(value_grad);
-      GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
-                                                   out_grad,
-                                                   indices,
-                                                   input_dims,
-                                                   input_strides,
-                                                   index_dims,
-                                                   index_strides,
-                                                   slice_offset,
-                                                   x_grad,
-                                                   value_grad);
+      GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
+                                                            out_grad,
+                                                            indices,
+                                                            input_dims,
+                                                            input_strides,
+                                                            index_dims,
+                                                            index_strides,
+                                                            slice_offset,
+                                                            x_grad,
+                                                            value_grad);
     } else {
       DenseTensor tmp_value_grad(value_grad->dtype());
       tmp_value_grad.Resize(common::make_ddim(input_dims));
       dev_ctx.template Alloc<T>(&tmp_value_grad);

-      GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
-                                                   out_grad,
-                                                   indices,
-                                                   input_dims,
-                                                   input_strides,
-                                                   index_dims,
-                                                   index_strides,
-                                                   slice_offset,
-                                                   x_grad,
-                                                   &tmp_value_grad);
+      GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
+                                                            out_grad,
+                                                            indices,
+                                                            input_dims,
+                                                            input_strides,
+                                                            index_dims,
+                                                            index_strides,
+                                                            slice_offset,
+                                                            x_grad,
+                                                            &tmp_value_grad);

       std::vector<int64_t> after_dims =
           common::vectorize(tmp_value_grad.dims());
@@ -307,17 +303,29 @@ void LaunchIndexElementwisePutGradCudaKernel(
     DenseTensor* x_grad) {
   if (x_grad) {
     phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
-
-    GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
-                                                 out_grad,
-                                                 indices,
-                                                 input_dims,
-                                                 input_strides,
-                                                 index_dims,
-                                                 index_strides,
-                                                 slice_offset,
-                                                 x_grad,
-                                                 nullptr);
+    if (funcs::IsInUint32Range(x_grad->numel())) {
+      GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
+                                                   out_grad,
+                                                   indices,
+                                                   input_dims,
+                                                   input_strides,
+                                                   index_dims,
+                                                   index_strides,
+                                                   slice_offset,
+                                                   x_grad,
+                                                   nullptr);
+    } else {
+      GPUIndexElementwisePutGradKernel<T, int64_t, uint64_t>(dev_ctx,
+                                                             out_grad,
+                                                             indices,
+                                                             input_dims,
+                                                             input_strides,
+                                                             index_dims,
+                                                             index_strides,
+                                                             slice_offset,
+                                                             x_grad,
+                                                             nullptr);
+    }
   }
 }

@@ -399,17 +407,30 @@ void IndexElementwisePutWithTensorGradKernel(
     }
     return;
   }
-
-  LaunchIndexElementwisePutWithTensorGradCudaKernel<T, Context>(dev_ctx,
-                                                                indices,
-                                                                out_grad,
-                                                                input_dims,
-                                                                input_strides,
-                                                                index_dims,
-                                                                index_strides,
-                                                                slice_offset,
-                                                                value_grad,
-                                                                x_grad);
+  if (x_grad && funcs::IsInUint32Range(x_grad->numel())) {
+    LaunchIndexElementwisePutWithTensorGradCudaKernel<T, Context>(dev_ctx,
+                                                                  indices,
+                                                                  out_grad,
+                                                                  input_dims,
+                                                                  input_strides,
+                                                                  index_dims,
+                                                                  index_strides,
+                                                                  slice_offset,
+                                                                  value_grad,
+                                                                  x_grad);
+  } else {
+    LaunchIndexElementwisePutWithTensorGradCudaKernel<T, Context, uint64_t>(
+        dev_ctx,
+        indices,
+        out_grad,
+        input_dims,
+        input_strides,
+        index_dims,
+        index_strides,
+        slice_offset,
+        value_grad,
+        x_grad);
+  }
 }

 }  // namespace phi
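All four files follow the same layering: offset arithmetic happens in OffsetT at the innermost kernel, mid-level launchers such as LaunchIndexElementwisePutWithTensorGradCudaKernel forward OffsetT unchanged to every call they make, and the public entry point performs the IsInUint32Range check exactly once before picking uint32_t or uint64_t. A condensed, self-contained sketch of that pattern; the three function names and empty bodies below are illustrative stand-ins, not Paddle code:

#include <cstdint>
#include <limits>

// Stand-in for funcs::IsInUint32Range (see index_elementwise.cu.h above).
constexpr bool IsInUint32Range(int64_t v) {
  return v >= 0 &&
         v <= static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
}

// Innermost kernel wrapper: all offset arithmetic happens in OffsetT.
template <typename T, typename IndexT, typename OffsetT = uint32_t>
void InnerKernel() { /* build the offset calculator with OffsetT, launch */ }

// Mid-level launcher: forwards OffsetT unchanged, no re-checking inside.
template <typename T, typename OffsetT = uint32_t>
void Launcher() {
  InnerKernel<T, int64_t, OffsetT>();
}

// Public entry point: the range check happens exactly once, out here.
template <typename T>
void PublicKernel(int64_t numel) {
  if (IsInUint32Range(numel)) {
    Launcher<T>();            // OffsetT = uint32_t (fast path)
  } else {
    Launcher<T, uint64_t>();  // big-tensor path
  }
}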
