@@ -94,8 +94,13 @@ __global__ void QuantKernel(const T* input,
                             const int round_type,
                             const float max_bound,
                             const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x))
+      << 2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -121,8 +126,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x))
+      << 2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -148,8 +158,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 3;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      3;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -173,8 +188,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -196,8 +216,12 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x);
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x));
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
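For context on why these indices are widened: the quant kernels address a row-major m x n tensor through a flattened offset built from m_id and n_id, and once m * n exceeds INT_MAX (roughly 2.1 billion elements) a 32-bit index wraps around before the (m_id < m) && (n_id < n) check, producing out-of-bounds accesses. Promoting one operand of each product to int64_t is already enough for the usual arithmetic conversions to evaluate the whole expression in 64-bit; the diff casts every operand, which is equivalent but more explicit. A minimal standalone sketch of the pattern, assuming a one-element-per-thread copy kernel (the name CopyKernel64 and the rows/cols parameters are illustrative, not part of this change):

#include <cstdint>

// Sketch of the 2-D, 64-bit-safe indexing used by the quant kernels above.
// One element per thread; the real kernels handle 2, 3, or 4 columns per thread.
__global__ void CopyKernel64(const float* in,
                             float* out,
                             int64_t rows,
                             int64_t cols) {
  // Promote before multiplying so the products are evaluated as int64_t.
  int64_t col = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t row = static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
  if (row < rows && col < cols) {
    // The flattened offset stays in 64-bit even when rows * cols > INT_MAX.
    int64_t offset = row * cols + col;
    out[offset] = in[offset];
  }
}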
@@ -323,7 +347,10 @@ __global__ void DequantKernel(T* output,
                               const float* dequant_out_scale_data) {
   int numel = m * n;
   int stride = blockDim.x * gridDim.x * VecSize;
-  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t idx =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      VecSize;
   int col_id = idx % n;
 
   phi::AlignedVector<int32_t, VecSize> in_vec;
@@ -369,7 +396,10 @@ __global__ void DequantKernelWithScaleOfInputAndWeight(
     float quant_max_bound) {
   int numel = m * n;
   int stride = blockDim.x * gridDim.x * VecSize;
-  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t idx =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      VecSize;
   int col_id = idx % n;
 
   phi::AlignedVector<int32_t, VecSize> in_vec;
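The two dequant kernels apply the same promotion to the starting index of what is effectively a grid-stride traversal over m * n elements. A simplified sketch of that pattern, with the phi::AlignedVector load/scale/store replaced by a plain per-element conversion; the kernel name, the scale argument, and the assumption that numel is a multiple of VecSize are illustrative, not taken from this change:

#include <cstdint>

// Sketch of the 64-bit grid-stride indexing used by the dequant kernels above.
template <int VecSize>
__global__ void DequantSketch(const int32_t* in,
                              float* out,
                              int64_t numel,
                              float scale) {
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  int64_t idx =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  for (; idx < numel; idx += stride) {
    // The real kernels load VecSize values at once via phi::AlignedVector;
    // a scalar inner loop keeps the sketch self-contained.
    for (int i = 0; i < VecSize; ++i) {
      out[idx + i] = static_cast<float>(in[idx + i]) * scale;
    }
  }
}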