Commit 8b12875

[large tensor] Use int64_t for CUDA indexing to avoid overflow (#76303) (#76348)
1 parent 5fe294d commit 8b12875
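
For context: once a tensor holds more than 2^31 - 1 elements, the usual global-thread-index expression is evaluated in 32-bit arithmetic and wraps, so the upper part of the tensor is addressed with a mangled (often negative) index. The sketch below is not part of the commit; it only illustrates the failure mode and the widened form the patch applies throughout, using hypothetical fill kernels.

#include <cstdint>

// Illustrative only; not from the Paddle sources.
__global__ void fill_32bit_index(float* data, int64_t n, float val) {
  // blockIdx.x, blockDim.x and threadIdx.x are 32-bit, so this expression is
  // evaluated in 32 bits and stored into a 32-bit int; indices past
  // 2^31 - 1 are mangled and no longer address the intended element.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = val;
}

__global__ void fill_64bit_index(float* data, int64_t n, float val) {
  // Widening an operand before the multiply forces the whole expression into
  // 64-bit arithmetic, so indices beyond INT_MAX stay correct.
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) data[i] = val;
}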

67 files changed: 448 additions, 151 deletions

Some content is hidden: large commits hide part of the diff by default, so not all 67 changed files are shown below.

paddle/phi/kernels/funcs/broadcast_function.h

Lines changed: 3 additions & 1 deletion
@@ -212,7 +212,9 @@ struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
   using VecType = phi::kps::details::VectorType<Type, VecSize>;
   VecType vec_temp;
 
-  int thread_offset = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t thread_offset =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   const VecType *__restrict__ vec_input =
       reinterpret_cast<const VecType *__restrict__>(ins[Index]);
   vec_temp = vec_input[thread_offset];

paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h

Lines changed: 6 additions & 2 deletions
@@ -128,7 +128,9 @@ __global__ void KeFastCollectiveGruGate(T *gate_value,
   T c0 = 0.0f;
   T b0[Tiled_size];
 
-  int COL = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t COL =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   int Tiled_mask = ((1 << Tiled_size) - 1);
   // Tiled matrix multiply using register shift, faster than sm.
   if (prev_output_value) {
@@ -185,7 +187,9 @@ __global__ void KeFastCollectiveGruOut(const T *gate_weight,
                                        int frame_size,
                                        ActivationType act_node,
                                        bool origin_mode) {
-  int COL = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t COL =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
 
   T a0 = 0.0f;
   T b0[Tiled_size];

paddle/phi/kernels/funcs/fake_quantize_functor.cu

Lines changed: 9 additions & 3 deletions
@@ -29,7 +29,9 @@ struct QuantizeDataType<phi::float16> {
 
 template <typename T>
 __global__ void FindAbsMaxKernel(const T *in, const int64_t n, T *out) {
-  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t bid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   int tid = threadIdx.x;
 
   extern __shared__ char *shared_max_data_tmp[];
@@ -70,7 +72,9 @@ __global__ void ClipAndQuantKernel(const T *in,
                                    const int round_type,
                                    const int64_t n,
                                    T *out) {
-  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t bid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   int tid = threadIdx.x;
 
   using ComputeDataType = typename QuantizeDataType<T>::type;
@@ -155,7 +159,9 @@ __global__ void ClipAndQuantDequantKernel(const T *in,
                                           const int round_type,
                                           const int64_t n,
                                           T *out) {
-  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t bid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   int tid = threadIdx.x;
 
   using ComputeDataType = typename QuantizeDataType<T>::type;
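
Note that only the global index `bid` is widened in these kernels; `int tid = threadIdx.x`, which drives the shared-memory reduction, stays 32-bit because threadIdx.x is bounded by the block size (at most 1024) and can never overflow. The hypothetical kernel below, not taken from the commit, shows the same split between a 64-bit global index and a 32-bit block-local index.

#include <cstdint>

// Sketch only: 64-bit global index, 32-bit block-local index.
// Assumes a power-of-two blockDim.x and arithmetic T.
template <typename T>
__global__ void BlockMaxSketch(const T* in, const int64_t n, T* block_max) {
  extern __shared__ unsigned char smem_raw[];
  T* smem = reinterpret_cast<T*>(smem_raw);

  int64_t bid = static_cast<int64_t>(threadIdx.x) +
                static_cast<int64_t>(blockIdx.x) * blockDim.x;  // may exceed INT_MAX
  int tid = threadIdx.x;                                        // always < blockDim.x

  smem[tid] = (bid < n) ? in[bid] : T(0);
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s && smem[tid + s] > smem[tid]) smem[tid] = smem[tid + s];
    __syncthreads();
  }
  if (tid == 0) block_max[blockIdx.x] = smem[0];
}

A launch of this sketch would pass blockDim.x * sizeof(T) bytes of dynamic shared memory; the point is only that the index spanning the whole tensor needs 64 bits while the block-local one does not.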

paddle/phi/kernels/funcs/fc_functor.cu

Lines changed: 3 additions & 1 deletion
@@ -63,7 +63,9 @@ struct FcTypeTraits<float16> {
 
 template <typename T, bool DoRelu>
 __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t tid =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   if (tid < num) {
     int bias_idx = tid % K;
     const T bias_ptr = bias[bias_idx];
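
Here `tid` becomes 64-bit while `bias_idx` stays an int: for positive K, `tid % K` is always in [0, K), so narrowing the remainder back to int cannot lose information even when tid is far beyond INT_MAX. A small host-side illustration with hypothetical values (not from the commit):

#include <cassert>
#include <cstdint>

int main() {
  const int K = 4096;                              // hypothetical bias length
  const int64_t tid = 3'000'000'000LL;             // global index past INT_MAX
  const int bias_idx = static_cast<int>(tid % K);  // remainder fits in an int
  assert(bias_idx >= 0 && bias_idx < K);
  return 0;
}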

paddle/phi/kernels/funcs/math_function.cu

Lines changed: 4 additions & 1 deletion
@@ -202,7 +202,10 @@ DEFINE_GPU_TRANS(6);
 
 template <typename T>
 __global__ void FillConstantKernel(const int N, T* a, const T val) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+  for (int64_t i =
+           static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+           static_cast<int64_t>(threadIdx.x);
+       i < N;
        i += blockDim.x * gridDim.x) {
     a[i] = val;
   }
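
FillConstantKernel uses the grid-stride idiom, so only the loop variable needs widening: the 32-bit stride is promoted to int64_t at each addition. A generic sketch of the same pattern, not the Paddle kernel, assuming an int64_t element count:

#include <cstdint>

// Grid-stride fill; sketch only.
__global__ void FillSketch(float* a, const int64_t n, const float val) {
  const int64_t start =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = start; i < n; i += stride) {
    a[i] = val;
  }
}

In the patched kernel N itself is still an int, so the 64-bit loop index mainly guards the initial index computation when the launch provides more threads than elements.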

paddle/phi/kernels/funcs/norm_utils.cu.h

Lines changed: 6 additions & 2 deletions
@@ -370,7 +370,9 @@ __global__ void DoubleGradComputeDXWithGlobal(const T *dy,
                                               const int sample_size,
                                               const int64_t num,
                                               T *dx) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t gid =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   int stride = blockDim.x * gridDim.x;
   if (ddscale != nullptr) {
     for (int64_t i = gid; i < num; i += stride) {
@@ -397,7 +399,9 @@ __global__ void DoubleGradComputeDDYWithGlobal(const T *ddx,
                                                const int sample_size,
                                                const int64_t num,
                                                T *ddy) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t gid =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   int stride = blockDim.x * gridDim.x;
 
   if (ddx != nullptr) {

paddle/phi/kernels/funcs/quant_dequant.h

Lines changed: 42 additions & 12 deletions
@@ -94,8 +94,13 @@ __global__ void QuantKernel(const T* input,
                             const int round_type,
                             const float max_bound,
                             const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x))
+      << 2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -121,8 +126,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x))
+      << 2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -148,8 +158,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 3;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      3;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -173,8 +188,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -196,8 +216,12 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x);
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x));
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -323,7 +347,10 @@ __global__ void DequantKernel(T* output,
                               const float* dequant_out_scale_data) {
   int numel = m * n;
   int stride = blockDim.x * gridDim.x * VecSize;
-  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t idx =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      VecSize;
   int col_id = idx % n;
 
   phi::AlignedVector<int32_t, VecSize> in_vec;
@@ -369,7 +396,10 @@ __global__ void DequantKernelWithScaleOfInputAndWeight(
     float quant_max_bound) {
   int numel = m * n;
   int stride = blockDim.x * gridDim.x * VecSize;
-  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t idx =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      VecSize;
   int col_id = idx % n;
 
   phi::AlignedVector<int32_t, VecSize> in_vec;
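
The quantize/dequantize kernels scale the raw thread index by the vector width (`<< 2`, `* 3`, `* 2`, `* VecSize`), so the widening has to happen before the scaling: a raw index that still fits in 32 bits can produce a scaled offset that does not. The y-direction index gets the same treatment. A tiny host-side illustration with hypothetical values (not from the commit):

#include <cstdint>
#include <cstdio>

int main() {
  const unsigned int raw = 1'200'000'000u;  // hypothetical raw thread index, fits in 32 bits
  const int64_t late = static_cast<int64_t>(raw << 2);   // shift wraps in 32 bits first: 505032704
  const int64_t early = static_cast<int64_t>(raw) << 2;  // widened first: 4800000000
  std::printf("late=%lld early=%lld\n", static_cast<long long>(late),
              static_cast<long long>(early));
  return 0;
}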

paddle/phi/kernels/funcs/scatter.cu.h

Lines changed: 2 additions & 1 deletion
@@ -402,7 +402,8 @@ inline DenseTensor restride_dim(const phi::DenseTensor& src,
 template <int nt, int vt, typename func_t>
 __global__ void scatter_gather_elementwise_kernel(int N, func_t f) {
   constexpr int nv = nt * vt;
-  int idx = nv * blockIdx.x + threadIdx.x;
+  int64_t idx =
+      nv * static_cast<int64_t>(blockIdx.x) + static_cast<int64_t>(threadIdx.x);
 
 #pragma unroll
   for (int i = 0; i < vt; ++i) {
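
This hunk casts only one operand per term (`nv * static_cast<int64_t>(blockIdx.x)`): under the usual arithmetic conversions, a single int64_t operand pulls the whole multiplication into 64-bit, so the `constexpr int nv` needs no cast of its own; the other files simply cast every operand for uniformity. A short host-side check with hypothetical values (not from the commit):

#include <cstdint>

int main() {
  constexpr int nv = 4 * 64;               // hypothetical nt * vt
  const unsigned int block = 40'000'000u;  // hypothetical blockIdx.x
  const int64_t idx = nv * static_cast<int64_t>(block);
  // idx == 10'240'000'000: the int operand was promoted to int64_t,
  // so the product never passed through 32-bit arithmetic.
  return idx == 10'240'000'000LL ? 0 : 1;
}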

paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h

Lines changed: 6 additions & 2 deletions
@@ -26,7 +26,9 @@ __global__ void FlattenIndicesKernel(const IntT* indices,
                                      const int64_t non_zero_num,
                                      const int64_t sparse_dim,
                                      IntT* out) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   phi::funcs::sparse::FlattenIndices<IntT>(indices,
                                            sparse_offsets,
                                            non_zero_num,
@@ -42,7 +44,9 @@ __global__ void IndexToCoordinateKernel(const IntT* index,
                                         const int64_t non_zero_num,
                                         const int64_t sparse_dim,
                                         IntT* indices) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   IndexToCoordinate(index,
                     dims,
                     non_zero_num,

paddle/phi/kernels/funcs/sparse/scatter.cu.h

Lines changed: 6 additions & 2 deletions
@@ -41,7 +41,9 @@ __global__ void ScatterKernel(const T* input,
                               const int rulebook_len,
                               const int channels,
                               T* out) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   const int vec_channels = channels / VecSize;
   using LoadT = phi::AlignedVector<T, VecSize>;
   using StoreT = phi::AlignedVector<T, VecSize>;
@@ -82,7 +84,9 @@ __global__ void ScatterKernelV2(const T* input,
                                 const int channels,
                                 const int buffer_counts,
                                 T* out) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   const int vec_channels = channels / VecSize;
   using LoadT = phi::AlignedVector<T, VecSize>;
   using StoreT = phi::AlignedVector<T, VecSize>;
