
Commit fbef36b

cangtianhuang authored and LittleHeroZZZX committed
Add MixedPrecision AddGrad (PaddlePaddle#76178)
* add MixedPrecisionAddGrad
* revert
* refine
* refine
* refine
* refine
* fix kernel
* add IndexT, add test
* refine
1 parent 51c70b8 commit fbef36b

File tree

6 files changed: +316 -17 lines


paddle/phi/common/type_promotion.h

Lines changed: 2 additions & 10 deletions

@@ -205,18 +205,10 @@ inline bool NeedTypePromotion(
   // floating-point numbers and between complex and real numbers.
   if (x_dtype != y_dtype) {
     // TODO(Xi Zhao): we got special case for add now, should remove it in future.
-#ifdef PADDLE_WITH_CUDA
-    if ((op_name == "add" || op_name == "add_") &&
-        x_dtype == DataType::FLOAT32 &&
-        (y_dtype == phi::DataType::BFLOAT16 ||
-         y_dtype == phi::DataType::FLOAT16)) {
-      return false;
-    }
-#elif defined(PADDLE_WITH_XPU)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_XPU)
     if ((op_name == "add" || op_name == "add_") &&
         x_dtype == DataType::FLOAT32 &&
-        (y_dtype == phi::DataType::BFLOAT16 ||
-         y_dtype == phi::DataType::FLOAT16)) {
+        (y_dtype == DataType::FLOAT16 || y_dtype == DataType::BFLOAT16)) {
       return false;
     }
 #endif

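For orientation, the merged guard can be read as a small standalone predicate. The sketch below is illustrative only — the DType enum and SkipPromotionForAdd name are hypothetical stand-ins, not Paddle's API — and simply restates the dtype pairs for which NeedTypePromotion now returns false on both CUDA and XPU builds, leaving the fp32 + fp16/bf16 pair to the add kernel itself.

#include <string>

// Hypothetical stand-in for phi::DataType; for illustration only.
enum class DType { FLOAT32, FLOAT16, BFLOAT16 };

// Mirrors the condition above: an "add"/"add_" of fp32 with fp16 or bf16
// skips automatic type promotion, so no up-cast of y is inserted.
bool SkipPromotionForAdd(const std::string& op_name, DType x, DType y) {
  return (op_name == "add" || op_name == "add_") && x == DType::FLOAT32 &&
         (y == DType::FLOAT16 || y == DType::BFLOAT16);
}
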
paddle/phi/kernels/gpu/elementwise_grad.h

Lines changed: 139 additions & 0 deletions

@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/elementwise_grad_base.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
+
 namespace phi {
 
 template <typename T>
@@ -112,6 +113,144 @@ void GetGradXOrYOut(const GPUContext &dev_ctx,
 ******************************
 */
 
+template <typename T>
+struct alignas(sizeof(T) * 4) Pack4 {
+  T val[4];
+};
+
+template <typename T_dy, typename IndexT = int>
+static __global__ void MixedPrecisionElemwiseAddGradCUDAKernel(
+    const float *__restrict__ dout,
+    IndexT size,
+    float *__restrict__ dx,
+    T_dy *__restrict__ dy) {
+  IndexT tid = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
+  IndexT stride = static_cast<IndexT>(gridDim.x) * blockDim.x;
+
+  constexpr int vec_size = 4;
+  IndexT loop = size / vec_size;
+  IndexT remainder = size % vec_size;
+
+  const float4 *__restrict__ dout_vec = reinterpret_cast<const float4 *>(dout);
+  float4 *__restrict__ dx_vec = reinterpret_cast<float4 *>(dx);
+  Pack4<T_dy> *__restrict__ dy_vec = reinterpret_cast<Pack4<T_dy> *>(dy);
+
+  for (IndexT i = tid; i < loop; i += stride) {
+    float4 val = __ldg(dout_vec + i);
+    dx_vec[i] = val;
+
+    Pack4<T_dy> dy_pack;
+    dy_pack.val[0] = static_cast<T_dy>(val.x);
+    dy_pack.val[1] = static_cast<T_dy>(val.y);
+    dy_pack.val[2] = static_cast<T_dy>(val.z);
+    dy_pack.val[3] = static_cast<T_dy>(val.w);
+    dy_vec[i] = dy_pack;
+  }
+
+  if (remainder != 0) {
+    IndexT tail_start = loop * vec_size;
+    for (IndexT i = tail_start + tid; i < size; i += stride) {
+      float val = __ldg(dout + i);
+      dx[i] = val;
+      dy[i] = static_cast<T_dy>(val);
+    }
+  }
+}
+
+template <typename T_dy>
+void ElementwiseMixedPrecisionAddGrad(const GPUContext &dev_ctx,
+                                      const DenseTensor &dout,
+                                      DenseTensor *dx,
+                                      DenseTensor *dy) {
+  using T_dout = float;
+  using T_dx = float;
+
+  auto *dx_data = dev_ctx.template Alloc<T_dx>(dx);
+  T_dy *dy_data = dev_ctx.template Alloc<T_dy>(dy);
+  auto *dout_data = dout.data<T_dout>();
+
+  if (dx_data == dout_data) {
+    VLOG(7) << "Special case when dx_data is the same as dout_data, "
+               "need cast dout to dy.";
+    phi::CastKernel<T_dout>(dev_ctx, dout, dy->dtype(), dy);
+    return;
+  }
+
+  auto size = dout.numel();
+  if (size == 0) return;
+
+  constexpr int vec_size = 4;
+  const int64_t main_size = (size / vec_size) * vec_size;
+  const int block_size = PREDEFINED_BLOCK_SIZE;
+  const int grid_size =
+      std::min(static_cast<int>((main_size + block_size - 1) / block_size),
+               (dev_ctx.GetMaxPhysicalThreadCount() / block_size));
+
+  dim3 grid_dim(grid_size, 1, 1);
+  dim3 block_dim(block_size, 1, 1);
+
+  if (size < std::numeric_limits<int>::max()) {
+    MixedPrecisionElemwiseAddGradCUDAKernel<T_dy, int>
+        <<<grid_dim, block_dim, 0, dev_ctx.stream()>>>(
+            dout_data, static_cast<int>(size), dx_data, dy_data);
+  } else {
+    MixedPrecisionElemwiseAddGradCUDAKernel<T_dy, int64_t>
+        <<<grid_dim, block_dim, 0, dev_ctx.stream()>>>(
+            dout_data, static_cast<int64_t>(size), dx_data, dy_data);
+  }
+}
+
+template <typename T_dy>
+void DefaultMixedPrecisionAddGrad(const GPUContext &dev_ctx,
+                                  const DenseTensor &x,
+                                  const DenseTensor &y,
+                                  const DenseTensor &dout,
+                                  DenseTensor *dx,
+                                  DenseTensor *dy,
+                                  int axis = -1) {
+  using T_dout = float;
+  using T_dx = float;
+
+  auto *dout_data = dout.data<T_dout>();
+
+  // dx
+  if (dx != nullptr) {
+    auto *dx_data = dev_ctx.template Alloc<T_dx>(dx);
+    if (dx->dims() == dout.dims()) {
+      if (dx_data != dout_data) {
+        phi::Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+      }
+    } else {
+      if (dx->IsSharedBufferWith(dout)) {
+        dx->clear();
+        dx->Resize(x.dims());
+        dev_ctx.template Alloc<T_dx>(dx);
+      }
+      std::vector<int> reduce_dims =
+          funcs::GetReduceDim(x.dims(), dout.dims(), axis);
+      phi::SumKernel<T_dout, GPUContext>(
+          dev_ctx, dout, reduce_dims, dout.dtype(), false, dx);
+    }
+  }
+
+  // dy
+  if (dy != nullptr) {
+    auto *dy_data = dev_ctx.template Alloc<T_dy>(dy);
+    if (dy->dims() == dout.dims()) {
+      phi::CastKernel<T_dout>(dev_ctx, dout, dy->dtype(), dy);
+    } else {
+      DenseTensor dy_fp32;
+      dy_fp32.Resize(dout.dims());
+      dev_ctx.template Alloc<float>(&dy_fp32);
+      std::vector<int> reduce_dims =
+          funcs::GetReduceDim(y.dims(), dout.dims(), axis);
+      phi::SumKernel<float, GPUContext>(
+          dev_ctx, dout, reduce_dims, dout.dtype(), false, &dy_fp32);
+      phi::CastKernel<float>(dev_ctx, dy_fp32, dy->dtype(), dy);
+    }
+  }
+}
+
 template <typename T, typename IndexT = int>
 static __global__ void SimpleElemwiseAddGradCUDAKernel(
     const T *__restrict__ dout, IndexT size, int vec_size, T *dx, T *dy) {

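The kernel above relies on Pack4 giving the down-cast gradients the same vector-friendly layout that float4 gives the fp32 side: four 2-byte elements packed into one 8-byte, 8-byte-aligned object, so each main-loop iteration issues one 16-byte load of dout, one 16-byte store of dx, and one 8-byte store of dy. Below is a minimal host-side check of that layout property, using a hypothetical 2-byte Half16 stand-in for phi::float16 / phi::bfloat16 (illustration only, not Paddle code).

#include <cstdint>

// Hypothetical 2-byte stand-in for phi::float16 / phi::bfloat16.
struct Half16 { uint16_t bits; };

template <typename T>
struct alignas(sizeof(T) * 4) Pack4 {
  T val[4];
};

// One Pack4<Half16> store writes four casted gradients at once (8 bytes),
// matching the float4 (16-byte) accesses on the dout/dx side.
static_assert(sizeof(Pack4<Half16>) == 8 && alignof(Pack4<Half16>) == 8, "");
static_assert(sizeof(Pack4<float>) == 16 && alignof(Pack4<float>) == 16, "");

int main() { return 0; }

The grid-stride loops are also why the launch can clamp grid_size to GetMaxPhysicalThreadCount() / block_size: even when the clamped grid is smaller than ceil(main_size / block_size), each thread keeps striding until every element is covered.
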
paddle/phi/kernels/gpu/elementwise_grad_kernel.cu

Lines changed: 56 additions & 0 deletions

@@ -115,6 +115,54 @@ void DivideGradKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T>
+void MixedPrecisionAddGradFunc(const GPUContext& dev_ctx,
+                               const DenseTensor& x,
+                               const DenseTensor& y,
+                               const DenseTensor& out,
+                               const DenseTensor& dout,
+                               DenseTensor* dx,
+                               DenseTensor* dy,
+                               int axis = -1) {
+  const auto& x_dtype = x.dtype();
+  const auto& y_dtype = y.dtype();
+  bool no_broadcast =
+      (dx && dy && dx->dims() == dy->dims() && dx->dims() == dout.dims());
+  if (no_broadcast) {
+    // Dispatch to non-broadcast (elementwise) kernels
+    if (x_dtype == phi::DataType::FLOAT32 &&
+        y_dtype == phi::DataType::FLOAT16) {
+      ElementwiseMixedPrecisionAddGrad<phi::float16>(dev_ctx, dout, dx, dy);
+    } else if (x_dtype == phi::DataType::FLOAT32 &&
+               y_dtype == phi::DataType::BFLOAT16) {
+      ElementwiseMixedPrecisionAddGrad<phi::bfloat16>(dev_ctx, dout, dx, dy);
+    } else {
+      PADDLE_THROW(common::errors::Unimplemented(
+          "Unsupported mixed precision combination for AddGrad non-broadcast "
+          "path: x_dtype=%s, y_dtype=%s",
+          phi::DataTypeToString(x_dtype),
+          phi::DataTypeToString(y_dtype)));
+    }
+  } else {
+    // Dispatch to broadcast-aware kernels
+    if (x_dtype == phi::DataType::FLOAT32 &&
+        y_dtype == phi::DataType::FLOAT16) {
+      DefaultMixedPrecisionAddGrad<phi::float16>(
+          dev_ctx, x, y, dout, dx, dy, axis);
+    } else if (x_dtype == phi::DataType::FLOAT32 &&
+               y_dtype == phi::DataType::BFLOAT16) {
+      DefaultMixedPrecisionAddGrad<phi::bfloat16>(
+          dev_ctx, x, y, dout, dx, dy, axis);
+    } else {
+      PADDLE_THROW(common::errors::Unimplemented(
+          "Unsupported mixed precision combination for AddGrad broadcast path: "
+          "x_dtype=%s, y_dtype=%s",
+          phi::DataTypeToString(x_dtype),
+          phi::DataTypeToString(y_dtype)));
+    }
+  }
+}
+
 template <typename T>
 void AddGradFunc(const GPUContext& dev_ctx,
                  const DenseTensor& x,
@@ -139,6 +187,14 @@ void AddGradKernel(const Context& dev_ctx,
                    int axis,
                    DenseTensor* dx,
                    DenseTensor* dy) {
+#ifdef PADDLE_WITH_CUDA
+  if (x.dtype() == DataType::FLOAT32 &&
+      (y.dtype() == DataType::FLOAT16 || y.dtype() == DataType::BFLOAT16)) {
+    phi::MixedPrecisionAddGradImpl<float>(
+        dev_ctx, x, y, dout, axis, dx, dy, MixedPrecisionAddGradFunc<float>);
+    return;
+  }
+#endif
   phi::AddGradImpl<T>(dev_ctx, x, y, dout, axis, dx, dy, AddGradFunc<T>);
 }

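MixedPrecisionAddGradFunc only splits on whether both gradients share dout's shape; when they do not, it hands the work to DefaultMixedPrecisionAddGrad. The snippet below is a hypothetical worked example of that broadcast branch (shapes are illustrative, not from the commit): with x fp32 of shape [2, 3] and y fp16/bf16 broadcast from shape [3], dx is simply dout, while dy is the column-wise sum of dout accumulated in fp32 and only then cast down to y's dtype — the same order of operations as the SumKernel-into-dy_fp32-then-CastKernel sequence above, which keeps the reduction accurate before precision is dropped.

#include <cstdio>

// Hypothetical shapes: dout is fp32 [2][3]; y was broadcast from shape [3].
int main() {
  const float dout[2][3] = {{0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f}};
  float dy_fp32[3] = {0.0f, 0.0f, 0.0f};
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      dy_fp32[j] += dout[i][j];  // reduce over the broadcast axis in fp32
    }
  }
  for (int j = 0; j < 3; ++j) {
    std::printf("dy[%d] = %.2f (cast to fp16/bf16 afterwards)\n", j, dy_fp32[j]);
  }
  return 0;
}
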
paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h

Lines changed: 27 additions & 0 deletions

@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/expand_kernel.h"
 #include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
@@ -28,6 +29,31 @@ limitations under the License. */
 
 namespace phi {
 
+template <typename T, typename Context, typename GradFunc>
+void MixedPrecisionAddGradImpl(const Context& dev_ctx,
+                               const DenseTensor& x,
+                               const DenseTensor& y,
+                               const DenseTensor& out_grad,
+                               int axis,
+                               DenseTensor* x_grad,
+                               DenseTensor* y_grad,
+                               GradFunc grad_func) {
+  phi::funcs::ElementwiseGradPreProcess(out_grad, x_grad);
+  phi::funcs::ElementwiseGradPreProcess(out_grad, y_grad);
+  auto* out = &out_grad;
+  if (x_grad != nullptr && y_grad == nullptr &&
+      x_grad->dims() == out_grad.dims()) {
+    VLOG(4) << "Mixed precision: only x_grad needed, no reduce";
+    phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+  } else if (x_grad == nullptr && y_grad != nullptr &&
+             y_grad->dims() == out_grad.dims()) {
+    VLOG(4) << "Mixed precision: only y_grad needed, no reduce";
+    phi::CastKernel<T>(dev_ctx, out_grad, y.dtype(), y_grad);
+  } else {
+    grad_func(dev_ctx, x, y, *out, out_grad, x_grad, y_grad, axis);
+  }
+}
+
 template <typename T, typename Context, typename GradFunc>
 void AddGradImpl(const Context& dev_ctx,
                  const DenseTensor& x,
@@ -38,6 +64,7 @@ void AddGradImpl(const Context& dev_ctx,
                  DenseTensor* y_grad,
                  GradFunc grad_func) {
   phi::funcs::ElementwiseGradPreProcess(out_grad, x_grad);
+  phi::funcs::ElementwiseGradPreProcess(out_grad, y_grad);
   auto* out = &out_grad;
   // Special case when y_grad is not needed and x_grad doesn't reduce
   if (x_grad != nullptr && y_grad == nullptr &&

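MixedPrecisionAddGradImpl follows the same shape as the existing AddGradImpl: handle the two cheap single-output cases inline (copy dout into x_grad, or cast dout into y_grad) and defer everything else to the grad_func callback. A minimal sketch of that control flow, assuming a hypothetical Tensor type and print statements standing in for the real copy/cast/reduce work:

#include <cstdint>
#include <iostream>

struct Tensor { int64_t numel = 0; };  // hypothetical stand-in for DenseTensor

template <typename GradFunc>
void AddGradImplSketch(const Tensor& dout, Tensor* dx, Tensor* dy,
                       GradFunc grad_func) {
  if (dx != nullptr && dy == nullptr && dx->numel == dout.numel) {
    std::cout << "fast path: copy dout into dx (no reduce, no cast)\n";
  } else if (dx == nullptr && dy != nullptr && dy->numel == dout.numel) {
    std::cout << "fast path: cast dout into dy (mixed-precision case)\n";
  } else {
    grad_func(dout, dx, dy);  // general path: broadcast reduce and/or cast
  }
}

int main() {
  Tensor dout{8}, dx{8}, dy{8};
  AddGradImplSketch(dout, &dx, nullptr, [](const Tensor&, Tensor*, Tensor*) {});
  AddGradImplSketch(dout, &dx, &dy, [](const Tensor&, Tensor*, Tensor*) {
    std::cout << "general path: both gradients requested\n";
  });
  return 0;
}
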
paddle/phi/kernels/kps/elementwise_kernel.cu

Lines changed: 4 additions & 7 deletions

@@ -100,16 +100,13 @@ void AddKernel(const Context& dev_ctx,
     return;
   }
 #ifdef PADDLE_WITH_CUDA
-  if (x.dtype() == phi::DataType::FLOAT32 &&
-      (y.dtype() == phi::DataType::BFLOAT16 ||
-       y.dtype() == phi::DataType::FLOAT16)) {
+  if (x.dtype() == DataType::FLOAT32 &&
+      (y.dtype() == DataType::FLOAT16 || y.dtype() == DataType::BFLOAT16)) {
     MultiPrecisionAddKernelImpl<float, Context>(dev_ctx, x, y, out);
-  } else {
-#endif
-    phi::AddRawKernel<T, Context>(dev_ctx, x, y, -1, out);
-#ifdef PADDLE_WITH_CUDA
+    return;
   }
 #endif
+  phi::AddRawKernel<T, Context>(dev_ctx, x, y, -1, out);
 }
 
 template <typename T, typename Context>

0 commit comments
