Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 117 additions & 87 deletions activation/activation/activation_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
// Required for Microsoft math constants and must be defined before including <cmath>
// https://learn.microsoft.com/en-us/cpp/c-runtime-library/math-constants?view=msvc-170
#define _USE_MATH_DEFINES

#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/Exception.h>

// The shim's stream accessor is guarded by USE_CUDA, so declare it here.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(
int32_t device_index, void** ret_stream);

#include <cuda_runtime.h>

#include <algorithm>
#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

using torch::stable::Tensor;

namespace vllm {

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
Expand Down Expand Up @@ -65,54 +80,59 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Launch activation and gating kernel.
// ACT_FIRST (bool) indicates whether the activation function is applied
// first.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
// Expands inside a void host function that has tensors `out` and `input` in
// scope (the macro executes `return;` on degenerate shapes).
//   KERNEL    - activation function template, instantiated as KERNEL<scalar_t>.
//   ACT_FIRST - compile-time bool forwarded to vllm::act_and_mul_kernel.
// Launch shape: one block per token (numel / last-dim size) with
// min(d, 1024) threads, where d is half the last dimension, on the current
// CUDA stream of `input`'s device. Throws via STD_TORCH_CHECK if the launch is
// rejected.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                      \
  /* Empty input: return before dividing by input.size(-1), which may be 0. */ \
  if (input.numel() == 0) {                                                   \
    return;                                                                   \
  }                                                                           \
  int d = input.size(-1) / 2;                                                 \
  /* A last dimension of 1 gives d == 0; a zero-thread block is invalid. */   \
  if (d == 0) {                                                               \
    return;                                                                   \
  }                                                                           \
  int64_t num_tokens = input.numel() / input.size(-1);                        \
  dim3 grid(num_tokens);                                                      \
  dim3 block(std::min(d, 1024));                                              \
  const torch::stable::accelerator::DeviceGuard device_guard(                 \
      input.get_device_index());                                              \
  void* stream_ptr = nullptr;                                                 \
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(                  \
      input.get_device_index(), &stream_ptr));                                \
  const cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);          \
  VLLM_STABLE_DISPATCH_FLOATING_TYPES(                                        \
      input.scalar_type(), "act_and_mul_kernel", [&] {                        \
        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>       \
            <<<grid, block, 0, stream>>>(                                     \
                static_cast<scalar_t*>(out.data_ptr()),                       \
                static_cast<const scalar_t*>(input.data_ptr()), d);           \
      });                                                                     \
  /* Kernel launches report configuration errors only via the runtime. */     \
  const cudaError_t launch_status = cudaGetLastError();                       \
  STD_TORCH_CHECK(launch_status == cudaSuccess,                               \
                  "act_and_mul_kernel launch failed: ",                       \
                  cudaGetErrorString(launch_status))

// Host entry point for the gated activation using vllm::silu_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_GATE_KERNEL with ACT_FIRST = true.
// Shapes, per the parameter annotations: out is [..., d], input is [..., 2 * d].
void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void silu_and_mul(Tensor& out, // [..., d]
Tensor const& input) // [..., 2 * d]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
}

// Host entry point for the gated activation using vllm::silu_kernel with
// ACT_FIRST = false. Checks that `input` and `out` are contiguous before
// expanding LAUNCH_ACTIVATION_GATE_KERNEL.
// Shapes, per the parameter annotations: out is [..., d], input is [..., 2 * d].
void mul_and_silu(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void mul_and_silu(Tensor& out, // [..., d]
Tensor const& input) // [..., 2 * d]
{
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
// applies the silu to the latter half of the input.
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
}

// Host entry point for the gated activation using vllm::gelu_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_GATE_KERNEL with ACT_FIRST = true.
// Shapes, per the parameter annotations: out is [..., d], input is [..., 2 * d].
void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void gelu_and_mul(Tensor& out, // [..., d]
Tensor const& input) // [..., 2 * d]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
}

// Host entry point for the gated activation using vllm::gelu_tanh_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_GATE_KERNEL with ACT_FIRST = true.
// Shapes, per the parameter annotations: out is [..., d], input is [..., 2 * d].
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void gelu_tanh_and_mul(Tensor& out, // [..., d]
Tensor const& input) // [..., 2 * d]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
}

Expand All @@ -138,26 +158,30 @@ __global__ void act_and_mul_kernel_with_param(

} // namespace vllm

#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \
PARAM); \
// Launch the parameterized activation-and-gating kernel.
// Expands inside a void host function that has tensors `out` and `input` in
// scope (the macro executes `return;` on degenerate shapes).
//   KERNEL - activation function template, instantiated as KERNEL<scalar_t>.
//   PARAM  - extra value forwarded as the kernel's final argument.
// Launch shape: one block per token (numel / last-dim size) with
// min(d, 1024) threads, where d is half the last dimension, on the current
// CUDA stream of `input`'s device. Throws via STD_TORCH_CHECK if the launch is
// rejected.
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM)               \
  /* Empty input: return before dividing by input.size(-1), which may be 0, */ \
  /* and before launching a zero-sized grid.                                */ \
  if (input.numel() == 0) {                                                   \
    return;                                                                   \
  }                                                                           \
  int d = input.size(-1) / 2;                                                 \
  /* A last dimension of 1 gives d == 0; a zero-thread block is invalid. */   \
  if (d == 0) {                                                               \
    return;                                                                   \
  }                                                                           \
  int64_t num_tokens = input.numel() / input.size(-1);                        \
  dim3 grid(num_tokens);                                                      \
  dim3 block(std::min(d, 1024));                                              \
  const torch::stable::accelerator::DeviceGuard device_guard(                 \
      input.get_device_index());                                              \
  void* stream_ptr = nullptr;                                                 \
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(                  \
      input.get_device_index(), &stream_ptr));                                \
  const cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);          \
  VLLM_STABLE_DISPATCH_FLOATING_TYPES(                                        \
      input.scalar_type(), "act_and_mul_kernel_with_param", [&] {             \
        vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>>       \
            <<<grid, block, 0, stream>>>(                                     \
                static_cast<scalar_t*>(out.data_ptr()),                       \
                static_cast<const scalar_t*>(input.data_ptr()), d, PARAM);    \
      });                                                                     \
  /* Kernel launches report configuration errors only via the runtime. */     \
  const cudaError_t launch_status = cudaGetLastError();                       \
  STD_TORCH_CHECK(launch_status == cudaSuccess,                               \
                  "act_and_mul_kernel_with_param launch failed: ",            \
                  cudaGetErrorString(launch_status))

// Host entry point for the gated activation using vllm::fatrelu_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM, forwarding `threshold` as the
// kernel's extra parameter.
// Shapes, per the parameter annotations: out is [..., d], input is [..., 2 * d].
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
void fatrelu_and_mul(Tensor& out, // [..., d],
Tensor const& input, // [..., 2 * d]
double threshold) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
namespace vllm {
Expand All @@ -178,18 +202,24 @@ __global__ void activation_kernel(
} // namespace vllm

// Launch element-wise activation kernel.
// Expects tensors `out` and `input` in the enclosing scope. Uses the last
// dimension of `input` as d, launches numel / d blocks of min(d, 1024)
// threads on the current CUDA stream, and dispatches on input.scalar_type().
// KERNEL is instantiated as KERNEL<scalar_t>.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
// Expands inside a void host function that has tensors `out` and `input` in
// scope (the macro executes `return;` when `input` is empty).
//   KERNEL - activation function template, instantiated as KERNEL<scalar_t>.
// Launch shape: numel / d blocks of min(d, 1024) threads, where d is the last
// dimension of `input`, on the current CUDA stream of `input`'s device.
// Throws via STD_TORCH_CHECK if the launch is rejected.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                      \
  /* Empty input: d may be 0 (division by zero below) or the grid would be */ \
  /* zero-sized, so there is nothing to launch.                            */ \
  if (input.numel() == 0) {                                                   \
    return;                                                                   \
  }                                                                           \
  int d = input.size(-1);                                                     \
  int64_t num_tokens = input.numel() / d;                                     \
  dim3 grid(num_tokens);                                                      \
  dim3 block(std::min(d, 1024));                                              \
  const torch::stable::accelerator::DeviceGuard device_guard(                 \
      input.get_device_index());                                              \
  void* stream_ptr = nullptr;                                                 \
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(                  \
      input.get_device_index(), &stream_ptr));                                \
  const cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);          \
  VLLM_STABLE_DISPATCH_FLOATING_TYPES(                                        \
      input.scalar_type(), "activation_kernel", [&] {                         \
        vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                   \
            <<<grid, block, 0, stream>>>(                                     \
                static_cast<scalar_t*>(out.data_ptr()),                       \
                static_cast<const scalar_t*>(input.data_ptr()), d);           \
      });                                                                     \
  /* Kernel launches report configuration errors only via the runtime. */     \
  const cudaError_t launch_status = cudaGetLastError();                       \
  STD_TORCH_CHECK(launch_status == cudaSuccess,                               \
                  "activation_kernel launch failed: ",                        \
                  cudaGetErrorString(launch_status))

namespace vllm {

Expand Down Expand Up @@ -217,50 +247,50 @@ __device__ __forceinline__ T gelu_quick_kernel(const T& x) {

} // namespace vllm

// Host entry point for the element-wise activation vllm::gelu_new_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_KERNEL.
// Shapes, per the parameter annotations: out and input are both [..., d].
void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu_new(Tensor& out, // [..., d]
Tensor const& input) // [..., d]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

// Host entry point for the element-wise activation vllm::gelu_fast_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_KERNEL.
// Shapes, per the parameter annotations: out and input are both [..., d].
void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu_fast(Tensor& out, // [..., d]
Tensor const& input) // [..., d]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}

// Host entry point for the element-wise activation vllm::gelu_quick_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_KERNEL.
// Shapes, per the parameter annotations: out and input are both [..., d].
void gelu_quick(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu_quick(Tensor& out, // [..., d]
Tensor const& input) // [..., d]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}

// Host entry point for the element-wise activation vllm::gelu_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_KERNEL.
// Shapes, per the parameter annotations: out and input are both [..., d].
void gelu(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu(Tensor& out, // [..., d]
Tensor const& input) // [..., d]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_kernel);
}

// Host entry point for the element-wise activation vllm::gelu_tanh_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_KERNEL.
// Shapes, per the parameter annotations: out and input are both [..., d].
void gelu_tanh(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu_tanh(Tensor& out, // [..., d]
Tensor const& input) // [..., d]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_tanh_kernel);
}

// Host entry point for the element-wise activation vllm::silu_kernel.
// Checks that `input` and `out` are contiguous, then expands
// LAUNCH_ACTIVATION_KERNEL.
// Shapes, per the parameter annotations: out and input are both [..., d].
void silu(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void silu(Tensor& out, // [..., d]
Tensor const& input) // [..., d]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
LAUNCH_ACTIVATION_KERNEL(vllm::silu_kernel);
}
Loading