18 changes: 9 additions & 9 deletions paddle/phi/kernels/funcs/index_elementwise.cu.h
@@ -192,15 +192,15 @@ struct OffsetCalculator {
stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
};

template <int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator_put(
template <int N, bool signed_strides = false, typename OffsetT = uint32_t>
static OffsetCalculator<N, OffsetT, signed_strides> make_offset_calculator_put(
std::vector<int64_t> desired_shape, std::array<int64_t*, N> strides_array) {
return OffsetCalculator<N, uint32_t, signed_strides>(
return OffsetCalculator<N, OffsetT, signed_strides>(
desired_shape.size(), desired_shape.data(), strides_array.data());
}

template <int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
template <int N, bool signed_strides = false, typename OffsetT = uint32_t>
static OffsetCalculator<N, OffsetT, signed_strides> make_offset_calculator(
int ndim,
const int64_t* shape,
const std::vector<std::vector<int64_t>>& strides) {
@@ -209,12 +209,12 @@ static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
strides_array[i] = strides[i].data();
}

return OffsetCalculator<N, uint32_t, signed_strides>(
return OffsetCalculator<N, OffsetT, signed_strides>(
ndim, shape, strides_array.data());
}

template <int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
template <int N, bool signed_strides = false, typename OffsetT = uint32_t>
static OffsetCalculator<N, OffsetT, signed_strides> make_offset_calculator(
const phi::DenseTensorIteratorBase& iter) {
PADDLE_ENFORCE_LE(N,
iter.ntensors(),
@@ -224,7 +224,7 @@ static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
for (int i = 0; i < N; i++) {
strides[i] = iter.operands_[i].stride_bytes.data();
}
return OffsetCalculator<N, uint32_t, signed_strides>(
return OffsetCalculator<N, OffsetT, signed_strides>(
iter.ndim(), iter.shape().data(), strides.data());
}
constexpr bool IsInUint32Range(int64_t value) {
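
Note: the make_offset_calculator* helpers above gain a trailing OffsetT template parameter that defaults to uint32_t, so existing call sites keep their 32-bit offset arithmetic and only callers that explicitly pass uint64_t get wide offsets. A minimal call-site sketch (not from the PR; the shapes and byte strides below are made up for illustration):

    std::vector<int64_t> desired_shape = {8, 16};
    std::vector<int64_t> out_strides = {64, 4};  // per-operand byte strides (illustrative)
    std::vector<int64_t> in_strides = {64, 4};
    std::vector<int64_t> idx_strides = {8, 0};
    std::array<int64_t*, 3> strides_array = {
        out_strides.data(), in_strides.data(), idx_strides.data()};

    // Default OffsetT = uint32_t: behaviour of existing call sites is unchanged.
    auto calc32 = funcs::make_offset_calculator_put<3>(desired_shape, strides_array);

    // OffsetT = uint64_t: offsets stay exact once numel exceeds the uint32 range.
    auto calc64 = funcs::make_offset_calculator_put<3, false, uint64_t>(
        desired_shape, strides_array);
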
49 changes: 28 additions & 21 deletions paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu
@@ -63,7 +63,7 @@ __global__ void IndexEleGetGradAccKernel(
}
}

template <typename T, typename IndexT>
template <typename T, typename IndexT, typename OffsetT = uint32_t>
void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx,
const DenseTensor& input,
const DenseTensor& value,
@@ -107,16 +107,10 @@ void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx,
&strides_array,
&numel,
strides_vec);
auto offset_calc =
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>(
desired_shape, strides_array);

const int64_t N = numel;

PADDLE_ENFORCE_EQ(true,
funcs::IsInUint32Range(N, value.numel()),
common::errors::PreconditionNotMet(
"the numel of input or output should be in [0, "
"std::numeric_limits<int32_t>::max()]"));
constexpr int nt = 128;
constexpr int vt = 4;
const dim3 block(nt);
@@ -425,18 +419,31 @@ void IndexElementwiseGetGradKernel(const Context& dev_ctx,
return;
#endif
}

GPUIndexElementwiseGetGrad<T, int64_t>(dev_ctx,
x,
out_grad,
index,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
accumulate,
x_grad);
if (funcs::IsInUint32Range(x_grad->numel(), out_grad.numel())) {
GPUIndexElementwiseGetGrad<T, int64_t>(dev_ctx,
x,
out_grad,
index,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
accumulate,
x_grad);
} else {
GPUIndexElementwiseGetGrad<T, int64_t, uint64_t>(dev_ctx,
x,
out_grad,
index,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
accumulate,
x_grad);
}
}

} // namespace phi
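
Note: the public grad kernel now branches on tensor size: when both x_grad->numel() and out_grad.numel() fit in the uint32 range it keeps the default 32-bit offsets, otherwise it instantiates GPUIndexElementwiseGetGrad with OffsetT = uint64_t. Only the single-argument IsInUint32Range appears in this diff; a variadic overload along these lines would account for the two-argument call above, but it is an assumption rather than the actual helper (the removed error message mentions int32_t, so the real bound may differ):

    #include <cstdint>
    #include <limits>

    // Hypothetical sketch of a variadic IsInUint32Range; the real helper lives
    // in paddle/phi/kernels/funcs/index_elementwise.cu.h.
    constexpr bool IsInUint32Range(int64_t value) {
      return value >= 0 &&
             value <= static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
    }

    template <typename... Ts>
    constexpr bool IsInUint32Range(int64_t first, Ts... rest) {
      return IsInUint32Range(first) && (IsInUint32Range(rest) && ...);
    }
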
36 changes: 24 additions & 12 deletions paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
@@ -20,7 +20,7 @@
#include "paddle/phi/kernels/funcs/stride_utils.h"

namespace phi {
template <typename T, typename IndexT = int>
template <typename T, typename IndexT = int, typename OffsetT = uint32_t>
void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx,
const DenseTensor& input,
const std::vector<const DenseTensor*>& index,
@@ -63,8 +63,8 @@ void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx,
&strides_array,
&numel,
strides_vec);
auto offset_calc =
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>(
desired_shape, strides_array);

const int64_t N = output->numel();
PADDLE_ENFORCE_EQ(true,
@@ -135,15 +135,27 @@ void IndexElementwiseGetKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(out);
if (out->numel() == 0) return;

GPUIndexElementwiseGetKernel<T, int64_t>(dev_ctx,
x,
index,
input_dims,
input_strides,
index_dims,
index_stride,
slice_offset,
out);
if (funcs::IsInUint32Range(out->numel())) {
GPUIndexElementwiseGetKernel<T, int64_t>(dev_ctx,
x,
index,
input_dims,
input_strides,
index_dims,
index_stride,
slice_offset,
out);
} else {
GPUIndexElementwiseGetKernel<T, int64_t, uint64_t>(dev_ctx,
x,
index,
input_dims,
input_strides,
index_dims,
index_stride,
slice_offset,
out);
}
}

} // namespace phi
163 changes: 92 additions & 71 deletions paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu
@@ -26,7 +26,7 @@

namespace phi {

template <typename T, typename IndexT = int>
template <typename T, typename IndexT = int, typename OffsetT = uint32_t>
void GPUIndexElementwisePutGradKernel(
const phi::GPUContext& dev_ctx,
const DenseTensor& out_grad,
@@ -78,14 +78,10 @@ void GPUIndexElementwisePutGradKernel(
&strides_array,
&numel,
strides_vec);
auto offset_calc =
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);
auto offset_calc = funcs::make_offset_calculator_put<3, false, OffsetT>(
desired_shape, strides_array);
const int64_t N = numel;
PADDLE_ENFORCE_EQ(true,
(N >= 0 && N <= std::numeric_limits<int32_t>::max()),
common::errors::PreconditionNotMet(
"the value of N should be in [0, "
"std::numeric_limits<int32_t>::max()]"));

constexpr int nt = 128;
constexpr int vt = 4;
const dim3 block(nt);
@@ -189,7 +185,7 @@ void GPUIndexElementwisePutGradKernel(
}
}

template <typename T, typename Context>
template <typename T, typename Context, typename OffsetT = uint32_t>
void LaunchIndexElementwisePutWithTensorGradCudaKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& indices,
@@ -204,16 +200,16 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel(
if (x_grad && !value_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);

GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
value_grad);
GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
value_grad);
} else if (value_grad) {
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
@@ -223,16 +219,16 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel(
tmp_value_grad.Resize(common::make_ddim(input_dims));
dev_ctx.template Alloc<T>(&tmp_value_grad);

GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
&tmp_value_grad);
GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
&tmp_value_grad);

std::vector<int> v_dims(tmp_value_grad.dims().size());
std::iota(v_dims.begin(), v_dims.end(), 0);
@@ -245,31 +241,31 @@ void LaunchIndexElementwisePutWithTensorGradCudaKernel(
value_grad);
} else if (value_grad->dims() == common::make_ddim(input_dims)) {
dev_ctx.template Alloc<T>(value_grad);
GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
value_grad);
GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
value_grad);
} else {
DenseTensor tmp_value_grad(value_grad->dtype());
tmp_value_grad.Resize(common::make_ddim(input_dims));
dev_ctx.template Alloc<T>(&tmp_value_grad);

GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
&tmp_value_grad);
GPUIndexElementwisePutGradKernel<T, int64_t, OffsetT>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
&tmp_value_grad);

std::vector<int64_t> after_dims =
common::vectorize(tmp_value_grad.dims());
@@ -307,17 +303,29 @@ void LaunchIndexElementwisePutGradCudaKernel(
DenseTensor* x_grad) {
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);

GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
nullptr);
if (funcs::IsInUint32Range(x_grad->numel())) {
GPUIndexElementwisePutGradKernel<T, int64_t>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
nullptr);
} else {
GPUIndexElementwisePutGradKernel<T, int64_t, uint64_t>(dev_ctx,
out_grad,
indices,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
x_grad,
nullptr);
}
}
}

@@ -399,17 +407,30 @@ void IndexElementwisePutWithTensorGradKernel(
}
return;
}

LaunchIndexElementwisePutWithTensorGradCudaKernel<T, Context>(dev_ctx,
indices,
out_grad,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
value_grad,
x_grad);
if (x_grad && funcs::IsInUint32Range(x_grad->numel())) {
LaunchIndexElementwisePutWithTensorGradCudaKernel<T, Context>(dev_ctx,
indices,
out_grad,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
value_grad,
x_grad);
} else {
LaunchIndexElementwisePutWithTensorGradCudaKernel<T, Context, uint64_t>(
dev_ctx,
indices,
out_grad,
input_dims,
input_strides,
index_dims,
index_strides,
slice_offset,
value_grad,
x_grad);
}
}

} // namespace phi
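
Note: a tiny standalone demonstration (not part of the PR) of the overflow that the uint64_t fallback avoids; a linear index past 2^32 - 1 wraps when stored in a 32-bit offset and would silently address the wrong element:

    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t numel = (int64_t{1} << 32) + 7;             // > 4.29 billion elements
      uint32_t off32 = static_cast<uint32_t>(numel - 1);  // wraps around to 6
      uint64_t off64 = static_cast<uint64_t>(numel - 1);  // stays 4294967302
      std::cout << off32 << " vs " << off64 << "\n";      // prints "6 vs 4294967302"
      return 0;
    }
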