@@ -35,10 +35,12 @@
   CHECK_TH_CUDA(x); \
   CHECK_CONTIGUOUS(x)
 
-#define FINAL_MASK 0xffffffff
+#ifdef USE_ROCM
+#define FINAL_MASK 0xffffffffffffffffULL
+#else
+#define FINAL_MASK 0xffffffff
+#endif
 
-// TODO: suport for AMD ROCM platform
-#ifndef USE_ROCM
 namespace tensorrt_llm::common {
 template <typename T, int num>
 struct packed_as;
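Why the mask widens under ROCm: CUDA's `__shfl_*_sync` builtins take a 32-bit lane mask (one bit per lane of a 32-lane warp), while AMD wavefronts are 64 lanes wide and HIP's sync-shuffle builtins take a 64-bit mask. A minimal sketch of the distinction; the `lane_mask_t` alias is illustrative, not part of the patch:

```cuda
// Illustrative only: the mask type each platform's shuffle builtins expect.
#ifdef USE_ROCM
using lane_mask_t = unsigned long long;  // 64 lanes per AMD wavefront
#else
using lane_mask_t = unsigned int;        // 32 lanes per NVIDIA warp
#endif

// FINAL_MASK sets one bit per hardware lane on either platform:
// 64 one-bits under ROCm, 32 one-bits under CUDA.
constexpr lane_mask_t kAllLanes = FINAL_MASK;
```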
@@ -60,7 +62,7 @@ struct packed_as<uint, 4> {
 
 template <typename T>
 __inline__ __device__ T warpReduceSum(T val) {
-#pragma unroll
+#pragma unroll
   for (int mask = 16; mask > 0; mask >>= 1)
     val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
   return val;
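The loop is a butterfly all-reduce: XOR-shuffling with strides 16, 8, 4, 2, 1 sums across all 32 lanes in five steps, leaving the total in every lane. A toy usage sketch (the demo kernel is assumed, not part of the patch):

```cuda
// Each lane contributes its lane id; after the five exchange steps every
// lane holds 0 + 1 + ... + 31 = 496.
__global__ void warpReduceSumDemo(int* out) {
  int const lane = threadIdx.x % 32;
  int const total = tensorrt_llm::common::warpReduceSum(lane);
  if (threadIdx.x == 0) *out = total;  // 496
}
```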
@@ -97,12 +99,12 @@ __global__ void fusedQKNormRopeKernel(
     int64_t const* position_ids,  // Position IDs for RoPE
     int const num_tokens          // Number of tokens
 ) {
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
                 std::is_same_v<scalar_t_cache, c10::BFloat16>) {
     return;
   } else {
-#endif
+#endif
 
   using Converter = vllm::_typeConvert<scalar_t_in>;
   static_assert(Converter::exists,
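The widened condition keeps the kernel body alive under ROCm: hip-clang does not define `__CUDA_ARCH__`, so without `!defined(USE_ROCM)` the ROCm device pass would take the early-return stub even on bf16-capable GPUs. The pattern in miniature (toy function, assuming the file's `c10::BFloat16` and `<type_traits>` includes):

```cuda
template <typename T>
__device__ void bodyIfBf16Supported(T* p) {
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  if constexpr (std::is_same_v<T, c10::BFloat16>) {
    return;  // pre-SM80 CUDA: no bf16 math, compile a no-op instead
  } else {
#endif
    p[0] = p[0];  // stand-in for the real kernel body
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  }
#endif
}
```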
@@ -179,7 +181,7 @@ __global__ void fusedQKNormRopeKernel(
   {
     vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
     constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
-#pragma unroll
+#pragma unroll
     for (int i = 0; i < num_packed_elems; i++) {
       // Interpret the generic vector chunk as the specific packed type
       T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
@@ -200,7 +202,7 @@ __global__ void fusedQKNormRopeKernel(
   float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
 
   // Normalize elements
-#pragma unroll
+#pragma unroll
   for (int i = 0; i < numElemsPerThread; i++) {
     int dim = laneId * numElemsPerThread + i;
     float weight = isQ ? Converter::convert(q_weight[dim])
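For reference, this is standard RMSNorm per head: x_i ← x_i · w_i / sqrt(mean(x²) + eps), with the warp-reduced `sumOfSquares` supplying mean(x²). A scalar reference of the same math (assumed helper, single-threaded for clarity):

```cuda
__device__ void rmsNormRef(float* x, float const* w, int head_dim,
                           float eps) {
  float sumsq = 0.f;
  for (int i = 0; i < head_dim; ++i) sumsq += x[i] * x[i];
  float const rms_rcp = rsqrtf(sumsq / static_cast<float>(head_dim) + eps);
  for (int i = 0; i < head_dim; ++i) x[i] *= rms_rcp * w[i];
}
```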
@@ -222,7 +224,7 @@ __global__ void fusedQKNormRopeKernel(
 
   if constexpr (interleave) {
     // Perform interleaving. Use pre-computed cos/sin values.
-#pragma unroll
+#pragma unroll
     for (int i = 0; i < numElemsPerThread / 2; ++i) {
       int const idx0 = 2 * i;
       int const idx1 = 2 * i + 1;
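Interleaved (GPT-J style) RoPE rotates each adjacent pair (x_{2i}, x_{2i+1}) by its per-dimension angle. A scalar reference of the rotation the unrolled loop performs (assumed helper, not in the patch):

```cuda
__device__ void ropeInterleavedRef(float* x, float const* cos_v,
                                   float const* sin_v, int head_dim) {
  for (int i = 0; i < head_dim / 2; ++i) {
    float const x0 = x[2 * i];
    float const x1 = x[2 * i + 1];
    x[2 * i]     = x0 * cos_v[i] - x1 * sin_v[i];  // rotate the pair
    x[2 * i + 1] = x0 * sin_v[i] + x1 * cos_v[i];
  }
}
```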
@@ -245,9 +247,9 @@ __global__ void fusedQKNormRopeKernel(
     __syncwarp();
     // Get the data from the other half of the warp. Use pre-computed cos/sin
     // values.
-#pragma unroll
+#pragma unroll
     for (int i = 0; i < numElemsPerThread; i++) {
-      elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
+      elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
       if (laneId < 16) {
         elements2[i] = -elements2[i];
       }
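In the non-interleaved (GPT-NeoX style) layout, element i pairs with element i + head_dim/2, which lives in the opposite half of the warp: `__shfl_xor_sync(..., 16)` swaps values between lanes 0-15 and 16-31, and the negation on lanes < 16 supplies the −sin term of the rotation. The switch from the literal `0xffffffff` to `FINAL_MASK` is the only functional change here, for 64-lane ROCm wavefronts. A scalar reference of the same rotation (assumed helper):

```cuda
__device__ void ropeNeoxRef(float* x, float const* cos_v,
                            float const* sin_v, int head_dim) {
  int const half = head_dim / 2;
  for (int i = 0; i < half; ++i) {
    float const lo = x[i];
    float const hi = x[i + half];
    x[i]        = lo * cos_v[i] - hi * sin_v[i];  // partner negated
    x[i + half] = lo * sin_v[i] + hi * cos_v[i];  // partner as-is
  }
}
```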
@@ -269,7 +271,7 @@ __global__ void fusedQKNormRopeKernel(
   {
     vec_T vec;
     constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
-#pragma unroll
+#pragma unroll
     for (int i = 0; i < num_packed_elems; i++) {
       // Convert from float2 back to the specific packed type
       T2_in packed_val = Converter::convert(
@@ -280,21 +282,21 @@ __global__ void fusedQKNormRopeKernel(
     *reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
   }
 
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   }
-#endif
+#endif
 }
 
-// Borrowed from
-// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
-#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
-  if (interleave) {                                      \
-    const bool INTERLEAVE = true;                        \
-    __VA_ARGS__                                          \
-  } else {                                               \
-    const bool INTERLEAVE = false;                       \
-    __VA_ARGS__                                          \
-  }
+// Borrowed from
+// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
+#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
+  if (interleave) {                                      \
+    const bool INTERLEAVE = true;                        \
+    __VA_ARGS__                                          \
+  } else {                                               \
+    const bool INTERLEAVE = false;                       \
+    __VA_ARGS__                                          \
+  }
 
 template <typename scalar_t_in, typename scalar_t_cache>
 void launchFusedQKNormRope(void* qkv, int const num_tokens,
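The macro converts the runtime `interleave` flag into the compile-time constant `INTERLEAVE` (a `const bool` initialized from a literal is a valid non-type template argument), so each branch instantiates its own kernel specialization. A hypothetical call site, with the kernel's template arguments abbreviated for illustration:

```cuda
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
  fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, INTERLEAVE>
      <<<grid, block, 0, stream>>>(/* ...kernel arguments... */);
});
```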
@@ -413,6 +415,4 @@ void fused_qk_norm_rope(
         stream);
       });
     });
-}
-
-#endif  // not USE_ROCM
+}