Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions paddle/fluid/pybind/op_function_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,36 @@ bool PyObject_CheckLong(PyObject* obj) {
}

// Converts a Python object to an int32_t, rejecting values that do not fit.
//
// Accepted inputs, per the checks below:
//   * Python ints that are not bools,
//   * objects accepted by PyObject_CheckVarType / PyObject_CheckDataType,
//   * tensor objects whose tensor.numel() == 1,
//   * objects whose type name contains "numpy.int".
//
// The value is first read into a 64-bit integer so that an out-of-range
// input is detected and reported instead of being silently truncated by a
// narrowing cast.
//
// Throws (via PADDLE_THROW):
//   * InvalidType - obj is none of the accepted kinds, or the numpy integer
//                   could not be converted to a Python int.
//   * OutOfRange  - the value lies outside
//                   [numeric_limits<int32_t>::min(), ::max()].
int32_t PyObject_ToInt32(PyObject* obj) {
  int64_t res = 0;
  if ((PyLong_Check(obj) && !PyBool_Check(obj)) ||  // NOLINT
      PyObject_CheckVarType(obj) ||                 // NOLINT
      PyObject_CheckDataType(obj) ||                // NOLINT
      (PyObject_CheckTensor(obj) &&
       reinterpret_cast<TensorObject*>(obj)->tensor.numel() == 1)) {
    // NOTE(review): PyLong_AsLongLong reports failure by returning -1 with a
    // Python error set; that case is not distinguished from a real -1 here.
    // Confirm whether callers rely on the pending Python error.
    res = PyLong_AsLongLong(obj);
  } else {
    std::string type_name =
        std::string(reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name);
    if (type_name.find("numpy.int") != std::string::npos) {
      // PyNumber_Long returns a new reference that must be released once the
      // value has been read out of it.
      auto num_obj = PyNumber_Long(obj);
      if (num_obj == nullptr) {
        // Guard the Py_DECREF below: it must not be handed a null pointer.
        PADDLE_THROW(common::errors::InvalidType("Cannot convert %s to int32",
                                                 type_name));
      }
      res = PyLong_AsLongLong(num_obj);
      Py_DECREF(num_obj);
    } else {
      PADDLE_THROW(
          common::errors::InvalidType("Cannot convert %s to int32", type_name));
    }
  }

  // Range check happens on the 64-bit value, before narrowing.
  if (res > std::numeric_limits<int32_t>::max() ||
      res < std::numeric_limits<int32_t>::min()) {
    PADDLE_THROW(common::errors::OutOfRange(
        "Integer value %ld exceeds int32 range [%d, %d]",
        res,
        std::numeric_limits<int32_t>::min(),
        std::numeric_limits<int32_t>::max()));
  }

  return static_cast<int32_t>(res);
}

uint32_t PyObject_ToUInt32(PyObject* obj) {
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/kernels/funcs/index_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ void IndexKernel(const KPDevice &dev_ctx, DenseTensor *out, Functor func) {
int64_t numel = out->numel();
T *out_data = dev_ctx.template Alloc<T>(out);
if (numel <= 0) return;
int vec_size = std::min(4, phi::GetVectorizedSize(out_data));
size_t vec_size = std::min(4, phi::GetVectorizedSize(out_data));
#ifdef PADDLE_WITH_XPU_KP
int block = 64;
int grid = 8;
size_t block = 64;
size_t grid = 8;
auto stream = dev_ctx.x_context()->xpu_stream;
#else
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
int grid = config.block_per_grid.x;
int block = config.thread_per_block.x;
size_t grid = config.block_per_grid.x;
size_t block = config.thread_per_block.x;
auto stream = dev_ctx.stream();
#endif
size_t main_offset =
Expand Down
174 changes: 90 additions & 84 deletions paddle/phi/kernels/funcs/softmax_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,26 @@ class SoftmaxEigen {
const int axis_dim,
const phi::DenseTensor* X,
phi::DenseTensor* Y) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int kAxisDim = 1;
constexpr int64_t kBatchDim = 0;
constexpr int64_t kClassDim = 1;
constexpr int64_t kAxisDim = 1;

auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);

const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
const int64_t batch_size = logits.dimension(kBatchDim);
const int64_t num_classes = logits.dimension(kClassDim);
const int64_t num_remain = num_classes / axis_dim;

Eigen::DSizes<int, 1> along_axis(kAxisDim);
Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int64_t, 1> along_axis(kAxisDim);
Eigen::DSizes<int64_t, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
Eigen::DSizes<int64_t, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int64_t, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
Eigen::DSizes<int64_t, 3> batch_axis_remain(
batch_size, axis_dim, num_remain);

// For numerical stability, logits should be shifted by maximum number along
// axis, calculate shifted_logits into softmax tensor for memory reuse.
Expand Down Expand Up @@ -106,25 +107,26 @@ class SoftmaxEigen<DeviceContext, phi::float16> {
const int axis_dim,
const phi::DenseTensor* X,
phi::DenseTensor* Y) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int kAxisDim = 1;
constexpr int64_t kBatchDim = 0;
constexpr int64_t kClassDim = 1;
constexpr int64_t kAxisDim = 1;

auto logits = EigenMatrix<phi::float16>::From(*X);
auto softmax = EigenMatrix<phi::float16>::From(*Y);

const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
const int64_t batch_size = logits.dimension(kBatchDim);
const int64_t num_classes = logits.dimension(kClassDim);
const int64_t num_remain = num_classes / axis_dim;

Eigen::DSizes<int, 1> along_axis(kAxisDim);
Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int64_t, 1> along_axis(kAxisDim);
Eigen::DSizes<int64_t, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
Eigen::DSizes<int64_t, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int64_t, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
Eigen::DSizes<int64_t, 3> batch_axis_remain(
batch_size, axis_dim, num_remain);

// For numerical stability, logits should be shifted by maximum number along
// axis, calculate shifted_logits into softmax tensor for memory reuse.
Expand Down Expand Up @@ -164,25 +166,26 @@ class SoftmaxEigen<DeviceContext, phi::bfloat16> {
const int axis_dim,
const phi::DenseTensor* X,
phi::DenseTensor* Y) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int kAxisDim = 1;
constexpr int64_t kBatchDim = 0;
constexpr int64_t kClassDim = 1;
constexpr int64_t kAxisDim = 1;

auto logits = EigenMatrix<phi::bfloat16>::From(*X);
auto softmax = EigenMatrix<phi::bfloat16>::From(*Y);

const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
const int64_t batch_size = logits.dimension(kBatchDim);
const int64_t num_classes = logits.dimension(kClassDim);
const int64_t num_remain = num_classes / axis_dim;

Eigen::DSizes<int, 1> along_axis(kAxisDim);
Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int64_t, 1> along_axis(kAxisDim);
Eigen::DSizes<int64_t, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
Eigen::DSizes<int64_t, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int64_t, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
Eigen::DSizes<int64_t, 3> batch_axis_remain(
batch_size, axis_dim, num_remain);

// For numerical stability, logits should be shifted by maximum number along
// axis, calculate shifted_logits into softmax tensor for memory reuse.
Expand Down Expand Up @@ -236,18 +239,18 @@ class SoftmaxFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
const phi::DenseTensor* X,
phi::DenseTensor* Y) {
const auto& in_dims = X->dims();
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int64_t kBatchDim = 0;
constexpr int64_t kClassDim = 1;

const int num_classes = in_dims[kClassDim];
const int batch_size = in_dims[kBatchDim];
const int num_remain = num_classes / axis_dim;
const int64_t num_classes = in_dims[kClassDim];
const int64_t batch_size = in_dims[kBatchDim];
const int64_t num_remain = num_classes / axis_dim;

if (num_remain == 1 &&
phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) {
const T* in_data = X->data<T>();
T* out_data = Y->data<T>();
for (int bs = 0; bs < batch_size; ++bs) {
for (int64_t bs = 0; bs < batch_size; ++bs) {
T max_val = *std::max_element(in_data, in_data + num_classes);
max_val *= static_cast<T>(-1);
vec_add_bias<T, phi::backends::cpu::avx>(
Expand Down Expand Up @@ -283,18 +286,19 @@ class SoftmaxGradEigen {
auto softmax_grad = EigenMatrix<T>::From(*y_grad);
auto logits_grad = EigenMatrix<T>::From(*x_grad);

constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int64_t kBatchDim = 0;
constexpr int64_t kClassDim = 1;

const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
const int64_t batch_size = softmax.dimension(kBatchDim);
const int64_t num_classes = softmax.dimension(kClassDim);
const int64_t num_remain = num_classes / axis_dim;

Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int64_t, 1> along_class(kClassDim);
Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
Eigen::DSizes<int64_t, 3> batch_axis_remain(
batch_size, axis_dim, num_remain);
Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);

auto dot = (softmax * softmax_grad)
.reshape(batch_axis_remain)
Expand All @@ -318,18 +322,19 @@ class SoftmaxGradEigen<DeviceContext, phi::float16> {
auto softmax_grad = EigenMatrix<phi::float16>::From(*y_grad);
auto logits_grad = EigenMatrix<phi::float16>::From(*x_grad);

constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int64_t kBatchDim = 0;
constexpr int64_t kClassDim = 1;

const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
const int64_t batch_size = softmax.dimension(kBatchDim);
const int64_t num_classes = softmax.dimension(kClassDim);
const int64_t num_remain = num_classes / axis_dim;

Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int64_t, 1> along_class(kClassDim);
Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
Eigen::DSizes<int64_t, 3> batch_axis_remain(
batch_size, axis_dim, num_remain);
Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);

auto dot = (softmax * softmax_grad)
.reshape(batch_axis_remain)
Expand All @@ -352,18 +357,19 @@ class SoftmaxGradEigen<DeviceContext, phi::bfloat16> {
auto softmax_grad = EigenMatrix<phi::bfloat16>::From(*y_grad);
auto logits_grad = EigenMatrix<phi::bfloat16>::From(*x_grad);

constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int64_t kBatchDim = 0;
constexpr int64_t kClassDim = 1;

const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
const int64_t batch_size = softmax.dimension(kBatchDim);
const int64_t num_classes = softmax.dimension(kClassDim);
const int64_t num_remain = num_classes / axis_dim;

Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int64_t, 1> along_class(kClassDim);
Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
Eigen::DSizes<int64_t, 3> batch_axis_remain(
batch_size, axis_dim, num_remain);
Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);

auto dot = (softmax * softmax_grad)
.reshape(batch_axis_remain)
Expand Down Expand Up @@ -393,18 +399,18 @@ class SoftmaxGradFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
const phi::DenseTensor* y_grad,
phi::DenseTensor* x_grad) {
const auto& out_dims = y->dims();
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
const int num_classes = out_dims[kClassDim];
const int batch_size = out_dims[kBatchDim];
const int num_remain = num_classes / axis_dim;
constexpr int64_t kBatchDim = 0;
constexpr int64_t kClassDim = 1;
const int64_t num_classes = out_dims[kClassDim];
const int64_t batch_size = out_dims[kBatchDim];
const int64_t num_remain = num_classes / axis_dim;

if (num_remain == 1 &&
phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) {
const T* out_data = y->data<T>();
const T* out_grad = y_grad->data<T>();
T* in_grad = x_grad->data<T>();
for (int bs = 0; bs < batch_size; ++bs) {
for (int64_t bs = 0; bs < batch_size; ++bs) {
T scalar;
vec_mul_reduce<T, phi::backends::cpu::avx>(
num_classes, out_grad, out_data, &scalar);
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/stack_and_unstack.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ void UnStackRawKernel(const Context& dev_ctx,

// zero sized tensor case
if (x.numel() == 0) {
for (int i = 0; i < split_dim; i++) {
for (int64_t i = 0; i < split_dim; i++) {
dev_ctx.template Alloc<T>((*outs)[i]);
auto x_grad_dim = (*outs)[i]->dims();
(*outs)[i]->Resize(x_grad_dim);
Expand Down
Loading
Loading