Commit edb59a9

[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)
Signed-off-by: tjtanaa <[email protected]>
1 parent c5f10cc commit edb59a9

6 files changed, 37 insertions(+), 38 deletions(-)

csrc/fused_qknorm_rope_kernel.cu

Lines changed: 27 additions & 27 deletions

@@ -35,10 +35,12 @@
   CHECK_TH_CUDA(x); \
   CHECK_CONTIGUOUS(x)
 
-#define FINAL_MASK 0xffffffff
+#ifdef USE_ROCM
+  #define FINAL_MASK 0xffffffffffffffffULL
+#else
+  #define FINAL_MASK 0xffffffff
+#endif
 
-// TODO: suport for AMD ROCM platform
-#ifndef USE_ROCM
 namespace tensorrt_llm::common {
 template <typename T, int num>
 struct packed_as;
@@ -60,7 +62,7 @@ struct packed_as<uint, 4> {
 
 template <typename T>
 __inline__ __device__ T warpReduceSum(T val) {
-  #pragma unroll
+#pragma unroll
   for (int mask = 16; mask > 0; mask >>= 1)
     val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
   return val;
@@ -97,12 +99,12 @@ __global__ void fusedQKNormRopeKernel(
     int64_t const* position_ids,  // Position IDs for RoPE
     int const num_tokens          // Number of tokens
 ) {
-  #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
                 std::is_same_v<scalar_t_cache, c10::BFloat16>) {
     return;
   } else {
-  #endif
+#endif
 
   using Converter = vllm::_typeConvert<scalar_t_in>;
   static_assert(Converter::exists,
@@ -179,7 +181,7 @@ __global__ void fusedQKNormRopeKernel(
   {
     vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
     constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
-  #pragma unroll
+#pragma unroll
     for (int i = 0; i < num_packed_elems; i++) {
       // Interpret the generic vector chunk as the specific packed type
       T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
@@ -200,7 +202,7 @@ __global__ void fusedQKNormRopeKernel(
   float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
 
   // Normalize elements
-  #pragma unroll
+#pragma unroll
   for (int i = 0; i < numElemsPerThread; i++) {
     int dim = laneId * numElemsPerThread + i;
     float weight = isQ ? Converter::convert(q_weight[dim])
@@ -222,7 +224,7 @@ __global__ void fusedQKNormRopeKernel(
 
   if constexpr (interleave) {
     // Perform interleaving. Use pre-computed cos/sin values.
-  #pragma unroll
+#pragma unroll
     for (int i = 0; i < numElemsPerThread / 2; ++i) {
       int const idx0 = 2 * i;
       int const idx1 = 2 * i + 1;
@@ -245,9 +247,9 @@ __global__ void fusedQKNormRopeKernel(
     __syncwarp();
     // Get the data from the other half of the warp. Use pre-computed cos/sin
    // values.
-  #pragma unroll
+#pragma unroll
     for (int i = 0; i < numElemsPerThread; i++) {
-      elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
+      elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
       if (laneId < 16) {
         elements2[i] = -elements2[i];
       }
@@ -269,7 +271,7 @@ __global__ void fusedQKNormRopeKernel(
   {
     vec_T vec;
     constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
-  #pragma unroll
+#pragma unroll
     for (int i = 0; i < num_packed_elems; i++) {
       // Convert from float2 back to the specific packed type
       T2_in packed_val = Converter::convert(
@@ -280,21 +282,21 @@ __global__ void fusedQKNormRopeKernel(
     *reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
   }
 
-  #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   }
-  #endif
+#endif
 }
 
-  // Borrowed from
-  // https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
-  #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
-    if (interleave) {                                      \
-      const bool INTERLEAVE = true;                        \
-      __VA_ARGS__                                          \
-    } else {                                               \
-      const bool INTERLEAVE = false;                       \
-      __VA_ARGS__                                          \
-    }
+// Borrowed from
+// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
+#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
+  if (interleave) {                                      \
+    const bool INTERLEAVE = true;                        \
+    __VA_ARGS__                                          \
+  } else {                                               \
+    const bool INTERLEAVE = false;                       \
+    __VA_ARGS__                                          \
+  }
 
 template <typename scalar_t_in, typename scalar_t_cache>
 void launchFusedQKNormRope(void* qkv, int const num_tokens,
@@ -413,6 +415,4 @@ void fused_qk_norm_rope(
           stream);
     });
   });
-}
-
-#endif // not USE_ROCM
+}
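
The core portability point in this file is the width of the "all lanes" shuffle mask: an NVIDIA warp has 32 lanes, so 0xffffffff names every lane, while the CDNA-class GPUs that ROCm builds typically target execute 64-lane wavefronts, where the full mask needs 64 bits. That is also why the hard-coded 0xffffffff in the kernel's __shfl_xor_sync call above becomes FINAL_MASK. Below is a minimal sketch of the pattern, not the kernel itself: warp_reduce_sum is a hypothetical stand-in for warpReduceSum, and it assumes the translation unit is compiled by nvcc, or by hipcc with USE_ROCM defined and the __shfl_xor_sync builtin available, which is exactly what the patched kernel already assumes.

// Sketch only: platform-sized full-lane mask, mirroring the patch above.
#ifdef USE_ROCM
  // ROCm targets commonly run 64-lane wavefronts, so the mask that means
  // "every lane" must be 64 bits wide.
  #define FINAL_MASK 0xffffffffffffffffULL
#else
  // NVIDIA warps have 32 lanes; a 32-bit mask already covers all of them.
  #define FINAL_MASK 0xffffffff
#endif

// Hypothetical helper in the spirit of the kernel's warpReduceSum:
// butterfly reduction across a 32-lane group on either platform.
template <typename T>
__inline__ __device__ T warp_reduce_sum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    // width = 32 keeps the exchange pattern identical on both platforms,
    // even though a ROCm wavefront physically holds 64 lanes.
    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
  }
  return val;
}

On ROCm the mask parameter is 64 bits wide, so the old 32-bit literal would describe only half of a wavefront; routing every shuffle through FINAL_MASK keeps the "all lanes" intent correct on both platforms.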

csrc/torch_bindings.cpp

Lines changed: 0 additions & 2 deletions

@@ -175,15 +175,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
-#ifndef USE_ROCM
   // Function for fused QK Norm and RoPE
   ops.def(
       "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
       "int num_heads_k, int num_heads_v, int head_dim, float eps, "
       "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
       "bool is_neox, Tensor position_ids) -> ()");
   ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
-#endif
 
   // Apply repetition penalties to logits in-place
   ops.def(

csrc/type_convert.cuh

Lines changed: 4 additions & 3 deletions

@@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
   }
 };
 
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
 // CUDA_ARCH < 800 does not have BF16 support
-// TODO: Add in ROCm support once public headers handle bf16 maturely
+// ROCm 7.0+ supports bfloat16
 template <>
 struct _typeConvert<c10::BFloat16> {
   static constexpr bool exists = true;
@@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
     return __float22bfloat162_rn(x);
   }
 };
-#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#endif  // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
+        // defined(USE_ROCM)
 #endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
         // 12000))
 
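
The relaxed guard above lets the packed bfloat16 converter compile on ROCm in addition to CUDA arch 800+. As a rough illustration of what that enables (a sketch, not vLLM code): the specialization's job, as the kernel uses it via Converter::convert, is to move between packed bf16 storage and float2 so the math runs in fp32, using the __bfloat1622float2 / __float22bfloat162_rn conversions visible in the diff. The bf16x2_t alias and scale_bf16x2 helper below are hypothetical names, and the ROCm branch assumes hip_bf16.h exposes the CUDA-style conversion functions, which is what the new guard implies for ROCm 7.0+.

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  using bf16x2_t = __hip_bfloat162;  // assumed packed bf16 pair type on ROCm
#else
  #include <cuda_bf16.h>
  using bf16x2_t = __nv_bfloat162;   // packed bf16 pair type on CUDA
#endif

// Hypothetical round trip in the spirit of _typeConvert<c10::BFloat16>:
// unpack two bf16 lanes into float2, do the math in fp32, then repack
// with round-to-nearest-even.
__device__ bf16x2_t scale_bf16x2(bf16x2_t packed, float scale) {
  float2 f = __bfloat1622float2(packed);  // bf16x2 -> float2
  f.x *= scale;
  f.y *= scale;
  return __float22bfloat162_rn(f);        // float2 -> bf16x2
}

In the real header this path still sits behind the arch/ROCm guard shown in the hunk above, so hosts without bf16 hardware support keep the previous behavior.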

tests/compile/test_qk_norm_rope_fusion.py

Lines changed: 2 additions & 2 deletions

@@ -113,8 +113,8 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
 @pytest.mark.parametrize("enable_rope_custom_op", [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.skipif(
-    not current_platform.is_cuda(),
-    reason="Only test on cuda platform",
+    not current_platform.is_cuda_alike(),
+    reason="Only test on cuda and rocm platform",
 )
 def test_qk_norm_rope_fusion(
     eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype

tests/kernels/core/test_fused_qk_norm_rope.py

Lines changed: 2 additions & 2 deletions

@@ -44,8 +44,8 @@ def _apply_qk_norm_rope(
 
 
 @pytest.mark.skipif(
-    not current_platform.is_cuda(),
-    reason="fused_qk_norm_rope custom op requires cuda platform",
+    not current_platform.is_cuda_alike(),
+    reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
 )
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("dtype", DTYPES)

vllm/config/compilation.py

Lines changed: 2 additions & 2 deletions

@@ -184,10 +184,10 @@ def __post_init__(self) -> None:
                 "Fusion enabled but reshape elimination disabled. "
                 "Allreduce + rms norm + quant (fp8) fusion might not work"
             )
-        if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
+        if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
             logger.warning_once(
                 "QK Norm + RoPE fusion enabled but the current platform is not "
-                "CUDA. The fusion will be disabled."
+                "CUDA or ROCm. The fusion will be disabled."
             )
             self.enable_qk_norm_rope_fusion = False
 
