diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc
index ec6cf2ec4661ce..f92a0c7291fdab 100644
--- a/paddle/fluid/pybind/op_function_common.cc
+++ b/paddle/fluid/pybind/op_function_common.cc
@@ -98,26 +98,36 @@ bool PyObject_CheckLong(PyObject* obj) {
 }
 
 int32_t PyObject_ToInt32(PyObject* obj) {
-  int32_t res = 0;
+  int64_t res = 0;
   if ((PyLong_Check(obj) && !PyBool_Check(obj)) ||  // NOLINT
       PyObject_CheckVarType(obj) ||                 // NOLINT
       PyObject_CheckDataType(obj) ||                // NOLINT
      (PyObject_CheckTensor(obj) &&
       reinterpret_cast<TensorObject*>(obj)->tensor.numel() == 1)) {
-    res = static_cast<int32_t>(PyLong_AsLong(obj));
-    return res;
-  }
-  std::string type_name =
-      std::string(reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name);
-  if (type_name.find("numpy.int") != std::string::npos) {
-    auto num_obj = PyNumber_Long(obj);
-    res = static_cast<int32_t>(PyLong_AsLong(num_obj));
-    Py_DECREF(num_obj);
+    res = PyLong_AsLongLong(obj);
   } else {
-    PADDLE_THROW(
-        common::errors::InvalidType("Cannot convert %s to long", type_name));
+    std::string type_name =
+        std::string(reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name);
+    if (type_name.find("numpy.int") != std::string::npos) {
+      auto num_obj = PyNumber_Long(obj);
+      res = PyLong_AsLongLong(num_obj);
+      Py_DECREF(num_obj);
+    } else {
+      PADDLE_THROW(
+          common::errors::InvalidType("Cannot convert %s to int32", type_name));
+    }
   }
-  return res;
+
+  if (res > std::numeric_limits<int32_t>::max() ||
+      res < std::numeric_limits<int32_t>::min()) {
+    PADDLE_THROW(common::errors::OutOfRange(
+        "Integer value %ld exceeds int32 range [%d, %d]",
+        res,
+        std::numeric_limits<int32_t>::min(),
+        std::numeric_limits<int32_t>::max()));
+  }
+
+  return static_cast<int32_t>(res);
 }
 
 uint32_t PyObject_ToUInt32(PyObject* obj) {
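Note on the `PyObject_ToInt32` change: the old code narrowed with a bare `static_cast`, so a Python int such as `2**31` silently wrapped to a negative value; the new code widens to `int64_t` first and range-checks before narrowing. A minimal standalone sketch of that pattern (plain C++; the names are illustrative, not Paddle APIs):

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>

// Widen to int64_t first, then reject anything outside
// [INT32_MIN, INT32_MAX] instead of silently truncating.
int32_t CheckedToInt32(int64_t res) {
  if (res > std::numeric_limits<int32_t>::max() ||
      res < std::numeric_limits<int32_t>::min()) {
    throw std::out_of_range("Integer value " + std::to_string(res) +
                            " exceeds int32 range");
  }
  return static_cast<int32_t>(res);
}

int main() {
  std::cout << CheckedToInt32(42) << "\n";  // prints 42
  try {
    CheckedToInt32(int64_t{1} << 31);  // 2147483648 > INT32_MAX
  } catch (const std::out_of_range& e) {
    std::cout << e.what() << "\n";  // reports the out-of-range value
  }
}
```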
diff --git a/paddle/phi/kernels/funcs/index_impl.cu.h b/paddle/phi/kernels/funcs/index_impl.cu.h
index bafd91b613fd0b..fce2e690c8e358 100644
--- a/paddle/phi/kernels/funcs/index_impl.cu.h
+++ b/paddle/phi/kernels/funcs/index_impl.cu.h
@@ -57,16 +57,16 @@ void IndexKernel(const KPDevice &dev_ctx, DenseTensor *out, Functor func) {
   int64_t numel = out->numel();
   T *out_data = dev_ctx.template Alloc<T>(out);
   if (numel <= 0) return;
-  int vec_size = std::min(4, phi::GetVectorizedSize<T>(out_data));
+  size_t vec_size = std::min(4, phi::GetVectorizedSize<T>(out_data));
 #ifdef PADDLE_WITH_XPU_KP
-  int block = 64;
-  int grid = 8;
+  size_t block = 64;
+  size_t grid = 8;
   auto stream = dev_ctx.x_context()->xpu_stream;
 #else
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
-  int grid = config.block_per_grid.x;
-  int block = config.thread_per_block.x;
+  size_t grid = config.block_per_grid.x;
+  size_t block = config.thread_per_block.x;
   auto stream = dev_ctx.stream();
 #endif
   size_t main_offset =
diff --git a/paddle/phi/kernels/funcs/softmax_impl.h b/paddle/phi/kernels/funcs/softmax_impl.h
index 361936305cc820..9f12293c0f6643 100644
--- a/paddle/phi/kernels/funcs/softmax_impl.h
+++ b/paddle/phi/kernels/funcs/softmax_impl.h
@@ -45,25 +45,26 @@ class SoftmaxEigen {
                   const int axis_dim,
                   const phi::DenseTensor* X,
                   phi::DenseTensor* Y) {
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
-    constexpr int kAxisDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
+    constexpr int64_t kAxisDim = 1;
 
     auto logits = EigenMatrix<T>::From(*X);
     auto softmax = EigenMatrix<T>::From(*Y);
 
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = logits.dimension(kBatchDim);
+    const int64_t num_classes = logits.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_axis(kAxisDim);
-    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
-    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 1> along_axis(kAxisDim);
+    Eigen::DSizes<int64_t, 2> batch_classes(batch_size, num_classes);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_one_remain(batch_size, 1, num_remain);
+    Eigen::DSizes<int64_t, 3> one_axis_one(1, axis_dim, 1);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
 
     // For numerical stability, logits should be shifted by maximum number along
     // axis, calculate shifted_logits into softmax tensor for memory reuse.
@@ -106,25 +107,26 @@ class SoftmaxEigen<DeviceContext, phi::dtype::float16> {
                   const int axis_dim,
                   const phi::DenseTensor* X,
                   phi::DenseTensor* Y) {
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
-    constexpr int kAxisDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
+    constexpr int64_t kAxisDim = 1;
 
     auto logits = EigenMatrix<phi::dtype::float16>::From(*X);
     auto softmax = EigenMatrix<phi::dtype::float16>::From(*Y);
 
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = logits.dimension(kBatchDim);
+    const int64_t num_classes = logits.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_axis(kAxisDim);
-    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
-    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 1> along_axis(kAxisDim);
+    Eigen::DSizes<int64_t, 2> batch_classes(batch_size, num_classes);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_one_remain(batch_size, 1, num_remain);
+    Eigen::DSizes<int64_t, 3> one_axis_one(1, axis_dim, 1);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
 
     // For numerical stability, logits should be shifted by maximum number along
     // axis, calculate shifted_logits into softmax tensor for memory reuse.
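Why the `int` → `int64_t` widening matters here: once a tensor crosses 2^31 - 1 elements, products such as `batch_size * num_classes` no longer fit in `int`. A small demonstration with a hypothetical big-tensor shape (illustrative values, not from the PR):

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // A plausible "big tensor" slice: 49152 x 65536 = 3,221,225,472 elements,
  // which is larger than INT32_MAX (2,147,483,647).
  const int64_t batch_size = 49152, num_classes = 65536;
  const int64_t numel = batch_size * num_classes;  // exact in 64 bits

  std::printf("numel          = %lld\n", static_cast<long long>(numel));
  std::printf("INT32_MAX      = %d\n", std::numeric_limits<int32_t>::max());
  // Narrowing back to the old `int` width wraps on two's-complement
  // targets (modular conversion, well-defined since C++20):
  std::printf("(int32_t)numel = %d\n", static_cast<int32_t>(numel));
}
```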
@@ -164,25 +166,26 @@ class SoftmaxEigen<DeviceContext, phi::dtype::bfloat16> {
                   const int axis_dim,
                   const phi::DenseTensor* X,
                   phi::DenseTensor* Y) {
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
-    constexpr int kAxisDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
+    constexpr int64_t kAxisDim = 1;
 
     auto logits = EigenMatrix<phi::dtype::bfloat16>::From(*X);
     auto softmax = EigenMatrix<phi::dtype::bfloat16>::From(*Y);
 
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = logits.dimension(kBatchDim);
+    const int64_t num_classes = logits.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_axis(kAxisDim);
-    Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
-    Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 1> along_axis(kAxisDim);
+    Eigen::DSizes<int64_t, 2> batch_classes(batch_size, num_classes);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_one_remain(batch_size, 1, num_remain);
+    Eigen::DSizes<int64_t, 3> one_axis_one(1, axis_dim, 1);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
 
     // For numerical stability, logits should be shifted by maximum number along
     // axis, calculate shifted_logits into softmax tensor for memory reuse.
@@ -236,18 +239,18 @@ class SoftmaxFunctor<DeviceContext, T, is_test, enable_if_CPU<DeviceContext>> {
                   const phi::DenseTensor* X,
                   phi::DenseTensor* Y) {
     const auto& in_dims = X->dims();
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
 
-    const int num_classes = in_dims[kClassDim];
-    const int batch_size = in_dims[kBatchDim];
-    const int num_remain = num_classes / axis_dim;
+    const int64_t num_classes = in_dims[kClassDim];
+    const int64_t batch_size = in_dims[kBatchDim];
+    const int64_t num_remain = num_classes / axis_dim;
 
     if (num_remain == 1 &&
         phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) {
       const T* in_data = X->data<T>();
       T* out_data = Y->data<T>();
-      for (int bs = 0; bs < batch_size; ++bs) {
+      for (int64_t bs = 0; bs < batch_size; ++bs) {
         T max_val = *std::max_element(in_data, in_data + num_classes);
         max_val *= static_cast<T>(-1);
         vec_add_bias<T, backends::cpu::avx>(
@@ -283,18 +286,19 @@ class SoftmaxGradEigen {
     auto softmax_grad = EigenMatrix<T>::From(*y_grad);
     auto logits_grad = EigenMatrix<T>::From(*x_grad);
 
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
 
-    const int batch_size = softmax.dimension(kBatchDim);
-    const int num_classes = softmax.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = softmax.dimension(kBatchDim);
+    const int64_t num_classes = softmax.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 1> along_class(kClassDim);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
 
     auto dot = (softmax * softmax_grad)
                    .reshape(batch_axis_remain)
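For reference, the expressions in `SoftmaxGradEigen` implement the softmax Jacobian-vector product dx = (dy - sum(dy * y)) * y per slice, with the sum taken over the softmax axis. A scalar sketch of the same computation (illustrative code, not Paddle's API):

```cpp
#include <cstddef>
#include <vector>

// Scalar reference for one (batch, remain) slice:
// dx[i] = (dy[i] - dot) * y[i], where dot = sum_i(dy[i] * y[i]).
std::vector<double> SoftmaxGradRef(const std::vector<double>& y,
                                   const std::vector<double>& dy) {
  double dot = 0.0;  // the `dot` reduction in the diff above
  for (std::size_t i = 0; i < y.size(); ++i) dot += y[i] * dy[i];
  std::vector<double> dx(y.size());
  for (std::size_t i = 0; i < y.size(); ++i) dx[i] = (dy[i] - dot) * y[i];
  return dx;
}
```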
@@ -318,18 +322,19 @@ class SoftmaxGradEigen<DeviceContext, phi::dtype::float16> {
     auto softmax_grad = EigenMatrix<phi::dtype::float16>::From(*y_grad);
     auto logits_grad = EigenMatrix<phi::dtype::float16>::From(*x_grad);
 
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
 
-    const int batch_size = softmax.dimension(kBatchDim);
-    const int num_classes = softmax.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = softmax.dimension(kBatchDim);
+    const int64_t num_classes = softmax.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 1> along_class(kClassDim);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
 
     auto dot = (softmax * softmax_grad)
                    .reshape(batch_axis_remain)
@@ -352,18 +357,19 @@ class SoftmaxGradEigen<DeviceContext, phi::dtype::bfloat16> {
     auto softmax_grad = EigenMatrix<phi::dtype::bfloat16>::From(*y_grad);
     auto logits_grad = EigenMatrix<phi::dtype::bfloat16>::From(*x_grad);
 
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
 
-    const int batch_size = softmax.dimension(kBatchDim);
-    const int num_classes = softmax.dimension(kClassDim);
-    const int num_remain = num_classes / axis_dim;
+    const int64_t batch_size = softmax.dimension(kBatchDim);
+    const int64_t num_classes = softmax.dimension(kClassDim);
+    const int64_t num_remain = num_classes / axis_dim;
 
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-    Eigen::DSizes<int, 2> one_axis(1, axis_dim);
+    Eigen::DSizes<int64_t, 1> along_class(kClassDim);
+    Eigen::DSizes<int64_t, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int64_t, 2> one_by_class(1, num_classes);
+    Eigen::DSizes<int64_t, 3> batch_axis_remain(
+        batch_size, axis_dim, num_remain);
+    Eigen::DSizes<int64_t, 2> one_axis(1, axis_dim);
 
     auto dot = (softmax * softmax_grad)
                    .reshape(batch_axis_remain)
@@ -393,18 +399,18 @@ class SoftmaxGradFunctor<DeviceContext, T, enable_if_CPU<DeviceContext>> {
                   const phi::DenseTensor* y_grad,
                   phi::DenseTensor* x_grad) {
     const auto& out_dims = y->dims();
-    constexpr int kBatchDim = 0;
-    constexpr int kClassDim = 1;
-    const int num_classes = out_dims[kClassDim];
-    const int batch_size = out_dims[kBatchDim];
-    const int num_remain = num_classes / axis_dim;
+    constexpr int64_t kBatchDim = 0;
+    constexpr int64_t kClassDim = 1;
+    const int64_t num_classes = out_dims[kClassDim];
+    const int64_t batch_size = out_dims[kBatchDim];
+    const int64_t num_remain = num_classes / axis_dim;
 
     if (num_remain == 1 &&
         phi::backends::cpu::MayIUse(phi::backends::cpu::avx)) {
       const T* out_data = y->data<T>();
       const T* out_grad = y_grad->data<T>();
       T* in_grad = x_grad->data<T>();
-      for (int bs = 0; bs < batch_size; ++bs) {
+      for (int64_t bs = 0; bs < batch_size; ++bs) {
         T scalar;
         vec_mul_reduce<T, backends::cpu::avx>(
             num_classes, out_grad, out_data, &scalar);
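The "for numerical stability" comments in the hunks above refer to the standard max-shift trick: softmax(x) == softmax(x - max(x)), and the shifted form keeps exp() from overflowing for large logits. A scalar sketch, assuming a non-empty input (illustrative, not Paddle's API):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Subtract the running maximum before exponentiating, then normalize.
std::vector<double> StableSoftmax(std::vector<double> x) {
  const double m = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double& v : x) sum += (v = std::exp(v - m));  // shifted exponentials
  for (double& v : x) v /= sum;                      // normalize to sum 1
  return x;
}
```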
diff --git a/paddle/phi/kernels/funcs/stack_and_unstack.h b/paddle/phi/kernels/funcs/stack_and_unstack.h
index 25c10e947efd20..f8fab92c1007d1 100644
--- a/paddle/phi/kernels/funcs/stack_and_unstack.h
+++ b/paddle/phi/kernels/funcs/stack_and_unstack.h
@@ -265,7 +265,7 @@ void UnStackRawKernel(const Context& dev_ctx,
 
   // zero sized tensor case
   if (x.numel() == 0) {
-    for (int i = 0; i < split_dim; i++) {
+    for (int64_t i = 0; i < split_dim; i++) {
       dev_ctx.template Alloc<T>((*outs)[i]);
       auto x_grad_dim = (*outs)[i]->dims();
       (*outs)[i]->Resize(x_grad_dim);
diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
index 7706299a92d92c..39805ca85f3df6 100644
--- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
+++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
@@ -387,7 +387,7 @@ __device__ __forceinline__ void ThreadVecWriteVec(T* out,
     out += blockDim.x;
   }
 
-  const int last = dim_size % (VecSize * blockDim.x);
+  const IndexType last = dim_size % (VecSize * blockDim.x);
 
   T in_v[VecSize];
   VecT* in_value = reinterpret_cast<VecT*>(&in_v);
@@ -420,7 +420,7 @@ __device__ __forceinline__ void ThreadVecWrite(T* out,
                                                T* input,
                                                IndexType dim_size,
                                                Reduction functor) {
-  const int last = dim_size % (VecSize * blockDim.x);
+  const IndexType last = dim_size % (VecSize * blockDim.x);
 
   for (IndexType offset = threadIdx.x; offset < dim_size - last;
        offset += blockDim.x * VecSize) {
@@ -446,7 +446,7 @@ __global__ void KeMatrixSoftmaxForward(T* softmax,
       MaxWithOne::kValue;
   using VecT = phi::AlignedVector<T, kVecSize>;
 
-  int bid = blockIdx.x;
+  uint64_t bid = blockIdx.x;
   T* batch_input = const_cast<T*>(src) + (uint64_t)bid * dim_size;
   T* batch_output = softmax + (uint64_t)bid * dim_size;
@@ -1213,7 +1213,7 @@ template <typename T, typename IndexType>
 void LaunchKeMatrixSoftmaxForwardKernel(const GPUContext& dev_ctx,
                                         T* out,
                                         const T* input,
-                                        int N,
+                                        int64_t N,
                                         IndexType dim_size) {
   using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
   constexpr int kVecSize =
@@ -1286,7 +1286,10 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
   IndexType D = tensor_dims[2];
 
   if (D == 1) {
-    if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
+    if (!UseCudnnSoftmax<T>(dev_ctx, dim, true) ||
+        N > std::numeric_limits<int32_t>::max() ||
+        dim > std::numeric_limits<int32_t>::max() ||
+        D > std::numeric_limits<int32_t>::max()) {
       int dim_log2 = static_cast<int>(Log2Ceil(dim));
       IndexType dim_ceil = 1 << dim_log2;
       int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
@@ -1381,7 +1384,9 @@ void SoftmaxBackwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
 
   if (D == 1) {
     if (!UseCudnnSoftmax<T>(dev_ctx, dim, true) ||
-        dim > std::numeric_limits<int32_t>::max()) {
+        N > std::numeric_limits<int32_t>::max() ||
+        dim > std::numeric_limits<int32_t>::max() ||
+        D > std::numeric_limits<int32_t>::max()) {
       int dim_log2 = Log2Ceil(dim);
       IndexType dim_ceil = 1 << dim_log2;
       int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
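The driver changes route oversized problems away from cuDNN: when N, dim, or D no longer fits in `int32_t`, the native kernels take over, presumably because the cuDNN tensor descriptors accept `int` extents. A sketch of that guard (`FitsInInt32` is a hypothetical helper, not a Paddle function):

```cpp
#include <cstdint>
#include <limits>

// Returns true only if every extent fits in a 32-bit signed int,
// i.e. the shape is still safe to hand to an int-based API.
template <typename IndexType>
bool FitsInInt32(IndexType n, IndexType dim, IndexType d) {
  const auto kMax =
      static_cast<IndexType>(std::numeric_limits<int32_t>::max());
  return n <= kMax && dim <= kMax && d <= kMax;
}
```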
diff --git a/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h b/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h
index e01165d57b135d..04996ff5e51eb1 100644
--- a/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h
@@ -41,10 +41,10 @@ void SoftmaxGradKernel(const Context& dev_ctx,
   }
 
   const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
-  int axis_dim = x_grad->dims()[calc_axis];
+  int64_t axis_dim = x_grad->dims()[calc_axis];
 
-  const int n = phi::funcs::SizeToAxis(calc_axis, x_grad->dims());
-  const int d = phi::funcs::SizeFromAxis(calc_axis, x_grad->dims());
+  const int64_t n = phi::funcs::SizeToAxis(calc_axis, x_grad->dims());
+  const int64_t d = phi::funcs::SizeFromAxis(calc_axis, x_grad->dims());
   DenseTensor dX_2d, Out_2d, dOut_2d;
   dX_2d.ShareDataWith(*x_grad).Resize({n, d});
   Out_2d.ShareDataWith(out).Resize({n, d});
diff --git a/paddle/phi/kernels/impl/softmax_kernel_impl.h b/paddle/phi/kernels/impl/softmax_kernel_impl.h
index 25674a5f2fb36e..65639b323ffe9f 100644
--- a/paddle/phi/kernels/impl/softmax_kernel_impl.h
+++ b/paddle/phi/kernels/impl/softmax_kernel_impl.h
@@ -28,7 +28,7 @@ void SoftmaxKernel(const Context& dev_ctx,
                    DenseTensor* out) {
   const int rank = x.dims().size();
   const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
-  int axis_dim = x.dims()[calc_axis];
+  int64_t axis_dim = x.dims()[calc_axis];
 
   // allocate memory on device.
   dev_ctx.template Alloc<T>(out);
@@ -42,8 +42,8 @@ void SoftmaxKernel(const Context& dev_ctx,
     return;
   }
 
-  const int n = phi::funcs::SizeToAxis(calc_axis, x.dims());
-  const int d = phi::funcs::SizeFromAxis(calc_axis, x.dims());
+  const int64_t n = phi::funcs::SizeToAxis(calc_axis, x.dims());
+  const int64_t d = phi::funcs::SizeFromAxis(calc_axis, x.dims());
   DenseTensor X_2d, Out_2d;
   X_2d.ShareDataWith(x).Resize({n, d});
   Out_2d.ShareDataWith(*out).Resize({n, d});
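Both kernel impls flatten the input to a 2-D `[n, d]` view around `calc_axis`; with the old `const int`, either product could wrap once the tensor passes 2^31 - 1 elements. A sketch that mirrors what `phi::funcs::SizeToAxis` / `SizeFromAxis` compute (hypothetical helper, not the Paddle implementation):

```cpp
#include <cstdint>
#include <vector>

// n = product of extents before `axis`, d = product from `axis` onward,
// both accumulated in 64 bits so large shapes stay exact.
void FlattenAt(const std::vector<int64_t>& dims, int axis,
               int64_t* n, int64_t* d) {
  *n = 1;
  *d = 1;
  for (int i = 0; i < axis; ++i) *n *= dims[i];
  for (int i = axis; i < static_cast<int>(dims.size()); ++i) *d *= dims[i];
}
```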