64 changes: 64 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -2273,6 +2273,70 @@ void FusedRMSNormGradInferMeta(const MetaTensor& x,
}
}

PADDLE_API void FastLayerNormGradInfermeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& mean,
const MetaTensor& invvar,
const MetaTensor& y_grad,
float epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad,
MetaTensor* bias_grad) {
PADDLE_ENFORCE_EQ(
x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::FLOAT16 ||
x.dtype() == DataType::BFLOAT16,
true,
common::errors::InvalidArgument(
"The dtype of x must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
x.dtype()));
PADDLE_ENFORCE_EQ(
scale.dtype() == DataType::FLOAT32 ||
scale.dtype() == DataType::FLOAT16 ||
scale.dtype() == DataType::BFLOAT16,
true,
common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
"FLOAT16 or BFLOAT16, but got [%s]",
scale.dtype()));
if (x_grad && x) {
x_grad->share_meta(x);
}
if (scale_grad && scale) {
scale_grad->share_meta(scale);
}
if (bias_grad) {
bias_grad->share_meta(scale);
}
}

PADDLE_API void FastRMSNormGradInfermeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& invvar,
const MetaTensor& y_grad,
float epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad) {
PADDLE_ENFORCE_EQ(
x.dtype() == DataType::FLOAT32 || x.dtype() == DataType::FLOAT16 ||
x.dtype() == DataType::BFLOAT16,
true,
common::errors::InvalidArgument(
"The dtype of x must be FLOAT32, FLOAT16 or BFLOAT16, but got [%s]",
x.dtype()));
PADDLE_ENFORCE_EQ(
scale.dtype() == DataType::FLOAT32 ||
scale.dtype() == DataType::FLOAT16 ||
scale.dtype() == DataType::BFLOAT16,
true,
common::errors::InvalidArgument("The dtype of scale must be FLOAT32, "
"FLOAT16 or BFLOAT16, but got [%s]",
scale.dtype()));
if (x_grad && x) {
x_grad->share_meta(x);
}
if (scale_grad && scale) {
scale_grad->share_meta(scale);
}
}

void IndexElementwiseGetGradInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& index,
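Reviewer note on the shape/dtype propagation above: each share_meta call simply mirrors the meta of the tensor being differentiated onto its gradient, with bias_grad following scale because bias and scale both span the normalized axis. Below is a minimal, self-contained sketch of that contract; Meta and ShareMeta are illustrative stand-ins, not the real phi::MetaTensor API.

// Illustrative stand-ins only; phi::MetaTensor behaves analogously via share_meta.
#include <cassert>
#include <cstdint>
#include <vector>

struct Meta {
  std::vector<int64_t> dims;
  int dtype;  // simplified dtype tag
};

void ShareMeta(const Meta& src, Meta* dst) {
  dst->dims = src.dims;    // gradient inherits the shape...
  dst->dtype = src.dtype;  // ...and the dtype of its source tensor
}

int main() {
  Meta x{{8, 16, 1024}, /*FLOAT16*/ 1};
  Meta scale{{1024}, /*FLOAT16*/ 1};
  Meta x_grad, scale_grad, bias_grad;
  ShareMeta(x, &x_grad);          // mirrors x_grad->share_meta(x)
  ShareMeta(scale, &scale_grad);  // mirrors scale_grad->share_meta(scale)
  ShareMeta(scale, &bias_grad);   // bias_grad follows scale, as in the diff
  assert(x_grad.dims == x.dims && bias_grad.dims == scale.dims);
  return 0;
}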
18 changes: 18 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -842,4 +842,22 @@ PADDLE_API void IndexElementwiseGetGradInferMeta(
const bool accumulate,
const bool is_combined,
MetaTensor* x_grad);

PADDLE_API void FastLayerNormGradInfermeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& mean,
const MetaTensor& invvar,
const MetaTensor& y_grad,
float epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad,
MetaTensor* bias_grad);

PADDLE_API void FastRMSNormGradInfermeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& invvar,
const MetaTensor& y_grad,
float epsilon,
MetaTensor* x_grad,
MetaTensor* scale_grad);
} // namespace phi
50 changes: 50 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -1850,6 +1850,56 @@ void ExpandAsInferMeta(const MetaTensor& x,
#undef MAX_RANK_SUPPORTED
}

void FastRMSNormInfermeta(const MetaTensor& x,
const MetaTensor& scale,
float epsilon,
MetaTensor* y,
MetaTensor* invvar) {
auto x_dim = x.dims();
auto x_ndim = x_dim.size();

auto matrix_dim = common::flatten_to_2d(x_dim, x_ndim - 1);

int64_t right = matrix_dim[1];
  if (scale) {
    PADDLE_ENFORCE_EQ(scale.dims().size(),
                      1,
                      common::errors::InvalidArgument(
                          "The rank of Input(Scale) must be 1, but received "
                          "rank [%d].",
                          scale.dims().size()));
    PADDLE_ENFORCE_EQ(
        scale.dims()[0],
        right,
        common::errors::InvalidArgument(
            "The first dimension of Input(Scale) must be equal to the "
            "second dimension of the flattened 2D matrix of Input(X), "
            "but received [%d] for Input(Scale) and [%d] for Input(X).",
            scale.dims()[0],
            right));
  }

  PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
                    true,
                    common::errors::InvalidArgument(
                        "'epsilon' in Op(FastRMSNorm) should be between "
                        "0.0 and 0.001, but received [%f].",
                        epsilon));

  phi::DataType x_dtype = x.dtype();
  y->set_dims(x_dim);
  y->set_dtype(scale ? scale.dtype() : x_dtype);

  auto row_shape = slice_ddim(x_dim, 0, x_dim.size() - 1);
  invvar->set_dims(row_shape);
  invvar->set_dtype(paddle::DataType::FLOAT32);
}

void FakeDequantizeMaxAbsInferMeta(const MetaTensor& x,
const MetaTensor& scale,
float max_range,
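The shape logic above hinges on common::flatten_to_2d(x_dim, x_ndim - 1): every axis except the last is collapsed into rows, and the last axis becomes the "right" extent that scale must match, while invvar keeps one FP32 statistic per row. A small sketch of that arithmetic; the helper below is a stand-in, not the real common::flatten_to_2d.

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Stand-in for common::flatten_to_2d(x_dim, x_ndim - 1): collapse every axis
// but the last into rows, keep the last axis as cols.
std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t>& dims) {
  int64_t rows = std::accumulate(dims.begin(), dims.end() - 1, int64_t{1},
                                 std::multiplies<int64_t>());
  return {rows, dims.back()};
}

int main() {
  std::vector<int64_t> x_dim{4, 32, 768};  // e.g. [batch, seq, hidden]
  auto [rows, right] = FlattenTo2D(x_dim);
  assert(rows == 4 * 32);  // invvar has shape {4, 32}: one value per row
  assert(right == 768);    // scale must be a rank-1 tensor of size 768
  return 0;
}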
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -359,6 +359,12 @@ PADDLE_API void ExpandAsInferMeta(const MetaTensor& x,
const std::vector<int64_t>& target_shape,
MetaTensor* out);

PADDLE_API void FastRMSNormInfermeta(const MetaTensor& x,
const MetaTensor& scale,
float epsilon,
MetaTensor* y,
MetaTensor* invvar);

PADDLE_API void FakeDequantizeMaxAbsInferMeta(const MetaTensor& x,
const MetaTensor& scale,
float max_range,
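A hedged usage sketch for the declaration above, modeled on how phi infermeta functions are typically driven in unit tests. The MetaTensor-over-DenseTensor wiring and the common::make_ddim spelling are assumptions on my part, not taken from this PR.

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h"

void SketchFastRMSNormInfermeta() {
  phi::DenseTensor x, scale, y, invvar;
  phi::MetaTensor meta_x(&x), meta_scale(&scale);
  phi::MetaTensor meta_y(&y), meta_invvar(&invvar);

  meta_x.set_dims(common::make_ddim({4, 32, 768}));
  meta_x.set_dtype(phi::DataType::FLOAT16);
  meta_scale.set_dims(common::make_ddim({768}));
  meta_scale.set_dtype(phi::DataType::FLOAT16);

  // After this call, y carries x's dims and scale's dtype;
  // invvar is FP32 with shape {4, 32} (one statistic per row).
  phi::FastRMSNormInfermeta(meta_x, meta_scale, /*epsilon=*/1e-5f,
                            &meta_y, &meta_invvar);
}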
73 changes: 73 additions & 0 deletions paddle/phi/infermeta/ternary.cc
@@ -604,6 +604,79 @@ void DpsgdInferMeta(const MetaTensor& param,
param_out->set_dims(param_dims);
}

void FastLayerNormInfermeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
float epsilon,
MetaTensor* y,
MetaTensor* mean,
MetaTensor* invvar) {
auto x_dim = x.dims();
auto x_ndim = x_dim.size();

auto matrix_dim = common::flatten_to_2d(x_dim, x_ndim - 1);

int64_t right = matrix_dim[1];
  if (scale) {
    PADDLE_ENFORCE_EQ(scale.dims().size(),
                      1,
                      common::errors::InvalidArgument(
                          "The rank of Input(Scale) must be 1, but received "
                          "rank [%d].",
                          scale.dims().size()));
    PADDLE_ENFORCE_EQ(
        scale.dims()[0],
        right,
        common::errors::InvalidArgument(
            "The first dimension of Input(Scale) must be equal to the "
            "second dimension of the flattened 2D matrix of Input(X), "
            "but received [%d] for Input(Scale) and [%d] for Input(X).",
            scale.dims()[0],
            right));
  }
  if (bias) {
    PADDLE_ENFORCE_EQ(bias.dims().size(),
                      1,
                      common::errors::InvalidArgument(
                          "The rank of Input(Bias) must be 1, but received "
                          "rank [%d].",
                          bias.dims().size()));
    PADDLE_ENFORCE_EQ(
        bias.dims()[0],
        right,
        common::errors::InvalidArgument(
            "The first dimension of Input(Bias) must be equal to the "
            "second dimension of the flattened 2D matrix of Input(X), "
            "but received [%d] for Input(Bias) and [%d] for Input(X).",
            bias.dims()[0],
            right));
  }

  PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
                    true,
                    common::errors::InvalidArgument(
                        "'epsilon' in Op(FastLayerNorm) should be between "
                        "0.0 and 0.001, but received [%f].",
                        epsilon));

  phi::DataType x_dtype = x.dtype();
  y->set_dims(x_dim);
  y->set_dtype(scale ? scale.dtype() : x_dtype);

  auto row_shape = slice_ddim(x_dim, 0, x_dim.size() - 1);
  mean->set_dims(row_shape);
  mean->set_dtype(paddle::DataType::FLOAT32);
  invvar->set_dims(row_shape);
  invvar->set_dtype(paddle::DataType::FLOAT32);
}

void FakeQuantizeRangeAbsMaxInferMeta(const MetaTensor& x,
const MetaTensor& in_scale,
const MetaTensor& iter,
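The infermeta above only fixes shapes and dtypes of mean and invvar; it does not define their math. For the reader, the conventional semantics of the two per-row FP32 statistics — assuming this op follows the usual Apex fast-LayerNorm convention, which the diff itself does not spell out — are a per-row mean and invvar = 1/sqrt(var + epsilon):

#include <cassert>
#include <cmath>
#include <vector>

int main() {
  std::vector<float> row{1.f, 2.f, 3.f, 4.f};
  float eps = 1e-5f;

  float mean = 0.f;
  for (float v : row) mean += v;
  mean /= row.size();  // per-row mean -> the `mean` output

  float var = 0.f;
  for (float v : row) var += (v - mean) * (v - mean);
  var /= row.size();  // biased per-row variance

  float invvar = 1.f / std::sqrt(var + eps);  // per-row value -> `invvar`
  assert(std::fabs(mean - 2.5f) < 1e-6f);
  assert(invvar > 0.f);
  return 0;
}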
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/ternary.h
@@ -179,6 +179,14 @@ PADDLE_API void FlashAttnV3InferMeta(const MetaTensor& q,
MetaTensor* out,
MetaTensor* softmax_lse);

PADDLE_API void FastLayerNormInfermeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
float epsilon,
MetaTensor* y,
MetaTensor* mean,
MetaTensor* invvar);

PADDLE_API void FlashAttnV3VarlenInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
7 changes: 7 additions & 0 deletions paddle/phi/kernels/CMakeLists.txt
@@ -76,6 +76,13 @@ if(((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0))
"legacy/gpu/int_bincount.cu"
"legacy/gpu/fp8_gemm_blockwise_kernel.cu"
"legacy/gpu/fp8_quant_blockwise_kernel.cu"
"legacy/gpu/fast_layernorm_kernel.cu"
"legacy/gpu/fast_layernorm_grad_kernel.cu"
"legacy/gpu/fast_rmsnorm_kernel.cu"
"legacy/gpu/fast_rmsnorm_grad_kernel.cu"
"legacy/gpu/ln.cu"
"legacy/gpu/ln_bwd_semi_cuda_kernel.cu"
"legacy/gpu/ln_fwd_cuda_kernel.cu"
"fusion/gpu/fused_act_dequant_kernel.cu"
"fusion/gpu/fused_stack_transpose_quant_kernel.cu"
"fusion/gpu/fused_transpose_split_quant_kernel.cu"
99 changes: 99 additions & 0 deletions paddle/phi/kernels/legacy/gpu/fast_layernorm_grad_kernel.cu
@@ -0,0 +1,99 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */

/* This code is copied from NVIDIA apex:
 * https://github.com/NVIDIA/apex
 * with minor changes. */

#include "ln.h" // NOLINT
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void LnBwdKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &scale,
const DenseTensor &mean,
const DenseTensor &invvar,
const DenseTensor &y_grad,
float epsilon,
DenseTensor *x_grad,
DenseTensor *scale_grad,
DenseTensor *bias_grad) {
  auto input_type = x.dtype();
  auto weight_type = scale.dtype();
  auto output_type = weight_type;
  auto compute_type = paddle::DataType::FLOAT32;

PD_CHECK(y_grad.dtype() == output_type);

auto sizes = x.dims();
PD_CHECK(sizes.size() >= 2);
PD_CHECK(y_grad.dims() == sizes);

  // Flatten every leading axis into rows; the last axis is the normalized one.
  int64_t rows = 1;
  for (int i = 0; i + 1 < sizes.size(); ++i) {
    rows *= sizes[i];
  }
  auto cols = sizes[sizes.size() - 1];

auto hidden_size = scale.numel();

PD_CHECK(mean.numel() == rows);

PD_CHECK(mean.dims() == invvar.dims());

PD_CHECK(scale.numel() == cols);

  // Only materialize the gradients that were actually requested; the launch
  // below already passes nullptr for outputs that are absent.
  if (x_grad) dev_ctx.template Alloc<T>(x_grad);
  if (scale_grad) dev_ctx.template Alloc<T>(scale_grad);
  if (bias_grad) dev_ctx.template Alloc<T>(bias_grad);

auto place = x.place();

LaunchNormBwd<T, Context>(
dev_ctx,
dev_ctx.stream(),
place,
/* x_ptr */ x.data(),
/* scale_ptr */ scale.data(),
/* mean_ptr */ mean.data(),
/* invvar_ptr */ invvar.data(),
/* y_grad_ptr */ y_grad.data(),
/* x_grad_ptr */ x_grad ? x_grad->data() : nullptr,
/* scale_grad_ptr */ scale_grad ? scale_grad->data() : nullptr,
/* bias_grad_ptr */ bias_grad ? bias_grad->data() : nullptr,
weight_type,
input_type,
output_type,
compute_type,
hidden_size,
rows,
cols,
epsilon);
}
} // namespace phi

PD_REGISTER_KERNEL(fast_ln_grad,
GPU,
ALL_LAYOUT,
phi::LnBwdKernel,
float,
double,
phi::float16,
phi::bfloat16) {}
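The PD_REGISTER_KERNEL call above instantiates LnBwdKernel once per listed dtype and records each instantiation in phi's kernel registry. A much-simplified sketch of that mechanism — the registry, key format, and registrar below are illustrative stand-ins, not the real macro expansion:

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Toy registry keyed by "op_name/dtype"; phi's real registry also keys on
// backend (GPU) and layout (ALL_LAYOUT).
std::map<std::string, std::function<void()>>& Registry() {
  static std::map<std::string, std::function<void()>> registry;
  return registry;
}

template <typename T>
void LnBwdKernelSketch() {
  std::cout << "running a fast_ln_grad specialization\n";
}

// Simplified stand-in for PD_REGISTER_KERNEL(fast_ln_grad, ...): one registry
// entry per dtype listed in the macro, populated at static-init time.
struct RegisterFastLnGrad {
  RegisterFastLnGrad() {
    Registry()["fast_ln_grad/float"] = LnBwdKernelSketch<float>;
    Registry()["fast_ln_grad/double"] = LnBwdKernelSketch<double>;
  }
} register_fast_ln_grad;

int main() {
  Registry().at("fast_ln_grad/float")();  // dispatch by (name, dtype)
  return 0;
}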