diff --git a/.gitignore b/.gitignore index 6049c2cdb..63408699f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist .idea .vscode tmp/ +requirements-musa.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 573ff399c..e7e043a1f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,4 +10,4 @@ repos: rev: 6.1.0 hooks: - id: flake8 - args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231'] + args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231, F541'] diff --git a/docs/CN/source/getting_started/installation.rst b/docs/CN/source/getting_started/installation.rst index fb998b756..5fa0e304d 100755 --- a/docs/CN/source/getting_started/installation.rst +++ b/docs/CN/source/getting_started/installation.rst @@ -27,7 +27,7 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton $ # 前请确保你的docker设置中已经分配了足够的共享内存,否则可能导致 $ # 服务无法正常启动。 $ # 1.如果是纯文本服务,建议分配2GB以上的共享内存, 如果你的内存充足,建议分配16GB以上的共享内存. - $ # 2.如果是多模态服务,建议分配16GB以上的共享内存,具体可以根据实际情况进行调整. + $ # 2.如果是多模态服务,建议分配16GB以上的共享内存,具体可以根据实际情况进行调整. $ # 如果你没有足够的共享内存,可以尝试在启动服务的时候调低 --running_max_req_size 参数,这会降低 $ # 服务的并发请求数量,但可以减少共享内存的占用。如果是多模态服务,也可以通过降低 --cache_capacity $ # 参数来减少共享内存的占用。 @@ -38,7 +38,7 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton 你也可以使用源码手动构建镜像并运行,建议手动构建镜像,因为更新比较频繁: .. code-block:: console - + $ # 进入代码仓库的根目录 $ cd /lightllm $ # 手动构建镜像, docker 目录下有不同功能场景的镜像构建文件,按需构建。 @@ -52,7 +52,7 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton 或者你也可以直接使用脚本一键启动镜像并且运行: .. code-block:: console - + $ # 查看脚本参数 $ python tools/quick_launch_docker.py --help @@ -80,6 +80,10 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton $ # 安装lightllm的依赖 (cuda 12.4) $ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu124 $ + $ # 安装lightllm的依赖 (摩尔线程 GPU) + $ ./generate_requirements_musa.sh + $ pip install -r requirements-musa.txt + $ $ # 安装lightllm $ python setup.py install @@ -97,6 +101,6 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton .. code-block:: console $ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps - + 具体原因可以参考:`issue `_ 和 `fix PR `_ diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst index ce7a79ab9..5976fcb32 100755 --- a/docs/CN/source/tutorial/api_server_args_zh.rst +++ b/docs/CN/source/tutorial/api_server_args_zh.rst @@ -183,22 +183,6 @@ PD 分离模式参数 设置为 True 时,--nccl_host 必须等于 config_server_host,--nccl_port 对于 config_server 必须是唯一的, 不要为不同的推理节点使用相同的 nccl_port,这将是严重错误 -attention类型选择参数 ---------------------- - -.. option:: --mode - - 模型推理模式,可以指定多个值: - - * ``triton_int8kv``: 使用 int8 存储 kv cache,可增加 token 容量,使用 triton kernel - * ``ppl_int8kv``: 使用 int8 存储 kv cache,使用 ppl 快速 kernel - * ``ppl_fp16``: 使用 ppl 快速 fp16 解码注意力 kernel - * ``triton_flashdecoding``: 用于长上下文的 flashdecoding 模式,当前支持 llama llama2 qwen - * ``triton_gqa_attention``: 使用 GQA 的模型的快速 kernel - * ``triton_gqa_flashdecoding``: 使用 GQA 的模型的快速 flashdecoding kernel - * ``triton_fp8kv``: 使用 float8 存储 kv cache,目前仅用于 deepseek2 - - 需要阅读源代码以确认所有模型支持的具体模式 调度参数 -------- @@ -327,17 +311,9 @@ attention类型选择参数 推理后端将为解码使用微批次重叠模式 -.. option:: --enable_flashinfer_prefill - - 推理后端将为预填充使用 flashinfer 的注意力 kernel - -.. option:: --enable_flashinfer_decode - - 推理后端将为解码使用 flashinfer 的注意力 kernel - -.. option:: --enable_fa3 +.. 
option:: --llm_kv_type - 推理后端将为预填充和解码使用 fa3 注意力 kernel + 推理后端使用什么类型的数据存储kv cache, 可选值为 "None", "int8kv", "int4kv", "fp8kv" .. option:: --disable_cudagraph diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst index 071d9405a..5d57b137c 100644 --- a/docs/CN/source/tutorial/deepseek_deployment.rst +++ b/docs/CN/source/tutorial/deepseek_deployment.rst @@ -33,12 +33,14 @@ LightLLM 支持以下几种部署模式: LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **参数说明:** - `LOADWORKER=18`: 模型加载线程数,提高加载速度 - `--tp 8`: 张量并行度,使用8个GPU -- `--enable_fa3`: 启用 Flash Attention 3.0 +- `--llm_prefill_att_backend fa3`: 启用 Flash Attention 3.0 +- `--llm_decode_att_backend fa3`: 启用 Flash Attention 3.0 - `--port 8088`: 服务端口 1.2 单机 DP + EP 模式 (Data Parallel + Expert Parallel) @@ -55,13 +57,15 @@ LightLLM 支持以下几种部署模式: --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **参数说明:** - `MOE_MODE=EP`: 设置专家并行模式 - `--tp 8`: 张量并行度 - `--dp 8`: 数据并行度,通常设置为与 tp 相同的值 -- `--enable_fa3`: 启用 Flash Attention 3.0 +- `--llm_prefill_att_backend fa3`: 启用 Flash Attention 3.0 +- `--llm_decode_att_backend fa3`: 启用 Flash Attention 3.0 **可选优化参数:** - `--enable_prefill_microbatch_overlap`: 启用预填充微批次重叠 @@ -85,7 +89,8 @@ LightLLM 支持以下几种部署模式: LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -101,7 +106,8 @@ LightLLM 支持以下几种部署模式: LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -129,7 +135,8 @@ LightLLM 支持以下几种部署模式: --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -146,7 +153,8 @@ LightLLM 支持以下几种部署模式: --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -195,7 +203,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --host $host \ --port 8019 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 @@ -219,7 +228,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --host $host \ --port 8121 \ --nccl_port 12322 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 @@ -287,7 +297,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --tp 8 \ --dp 8 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ --config_server_port 60088 @@ -306,7 +317,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --nccl_port 12322 \ --tp 8 \ --dp 8 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --config_server_host $config_server_host \ --config_server_port 60088 # 如果需要启用微批次重叠,可以取消注释以下行 diff --git 
a/docs/CN/source/tutorial/multi_level_cache_deployment.rst b/docs/CN/source/tutorial/multi_level_cache_deployment.rst index 0446b0780..223b92dca 100644 --- a/docs/CN/source/tutorial/multi_level_cache_deployment.rst +++ b/docs/CN/source/tutorial/multi_level_cache_deployment.rst @@ -66,7 +66,8 @@ LightLLM 的多级缓存系统采用分层设计: --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ @@ -81,7 +82,7 @@ LightLLM 的多级缓存系统采用分层设计: - ``--model_dir``: 模型文件路径,支持本地路径或 HuggingFace 模型名称 - ``--tp 8``: 张量并行度,使用 8 个 GPU 进行模型推理 - ``--graph_max_batch_size 500``: CUDA Graph 最大批次大小,影响吞吐量和显存占用 -- ``--enable_fa3``: 启用 Flash Attention 3.0,提升注意力计算速度,也可以换成flashinfer后端性能更佳 +- ``--llm_prefill_att_backend fa3``: 启用 Flash Attention 3.0,提升注意力计算速度,也可以换成flashinfer后端性能更佳 - ``--mem_fraction 0.88``: GPU 显存使用比例,建议设置为 0.88及以下 CPU 缓存参数 @@ -130,7 +131,8 @@ CPU 缓存参数 --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ diff --git a/docs/CN/source/tutorial/reasoning_parser.rst b/docs/CN/source/tutorial/reasoning_parser.rst index 547eb05d1..a9a0d09fe 100644 --- a/docs/CN/source/tutorial/reasoning_parser.rst +++ b/docs/CN/source/tutorial/reasoning_parser.rst @@ -32,7 +32,8 @@ DeepSeek-R1 --model_dir /path/to/DeepSeek-R1 \ --reasoning_parser deepseek-r1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 DeepSeek-V3 ~~~~~~~~~~~ diff --git a/docs/EN/source/getting_started/installation.rst b/docs/EN/source/getting_started/installation.rst index 75fa71476..6439c48de 100755 --- a/docs/EN/source/getting_started/installation.rst +++ b/docs/EN/source/getting_started/installation.rst @@ -24,16 +24,16 @@ The easiest way to install Lightllm is using the official image. You can directl $ docker pull ghcr.io/modeltc/lightllm:main $ $ # Run,The current LightLLM service relies heavily on shared memory. - $ # Before starting, please make sure that you have allocated enough shared memory + $ # Before starting, please make sure that you have allocated enough shared memory $ # in your Docker settings; otherwise, the service may fail to start properly. $ # - $ # 1. For text-only services, it is recommended to allocate more than 2GB of shared memory. + $ # 1. For text-only services, it is recommended to allocate more than 2GB of shared memory. $ # If your system has sufficient RAM, allocating 16GB or more is recommended. - $ # 2.For multimodal services, it is recommended to allocate 16GB or more of shared memory. + $ # 2.For multimodal services, it is recommended to allocate 16GB or more of shared memory. $ # You can adjust this value according to your specific requirements. $ # - $ # If you do not have enough shared memory available, you can try lowering - $ # the --running_max_req_size parameter when starting the service. + $ # If you do not have enough shared memory available, you can try lowering + $ # the --running_max_req_size parameter when starting the service. $ # This will reduce the number of concurrent requests, but also decrease shared memory usage. $ docker run -it --gpus all -p 8080:8080 \ $ --shm-size 2g -v your_local_path:/data/ \ @@ -42,13 +42,13 @@ The easiest way to install Lightllm is using the official image. 
You can directl You can also manually build the image from source and run it: .. code-block:: console - + $ # move into lightllm root dir $ cd /lightllm $ # Manually build the image $ docker build -t -f ./docker/Dockerfile . $ - $ # Run, + $ # Run, $ docker run -it --gpus all -p 8080:8080 \ $ --shm-size 2g -v your_local_path:/data/ \ $ /bin/bash @@ -56,7 +56,7 @@ You can also manually build the image from source and run it: Or you can directly use the script to launch the image and run it with one click: .. code-block:: console - + $ # View script parameters $ python tools/quick_launch_docker.py --help @@ -84,6 +84,10 @@ You can also install Lightllm from source: $ # Install Lightllm dependencies (cuda 12.4) $ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu124 $ + $ # Install Lightllm dependencies (Moore Threads GPU) + $ ./generate_requirements_musa.sh + $ pip install -r requirements-musa.txt + $ $ # Install Lightllm $ python setup.py install @@ -101,5 +105,5 @@ You can also install Lightllm from source: .. code-block:: console $ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps - + For specific reasons, please refer to: `issue `_ and `fix PR `_ \ No newline at end of file diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst index 1644bbab5..0767ae7e3 100755 --- a/docs/EN/source/tutorial/api_server_args_zh.rst +++ b/docs/EN/source/tutorial/api_server_args_zh.rst @@ -183,23 +183,6 @@ Different Parallel Mode Setting Parameters When set to True, --nccl_host must equal config_server_host, --nccl_port must be unique for config_server, do not use the same nccl_port for different inference nodes, this will be a serious error -Attention Type Selection Parameters ------------------------------------- - -.. option:: --mode - - Model inference mode, can specify multiple values: - - * ``triton_int8kv``: Use int8 to store kv cache, can increase token capacity, uses triton kernel - * ``ppl_int8kv``: Use int8 to store kv cache, uses ppl fast kernel - * ``ppl_fp16``: Use ppl fast fp16 decode attention kernel - * ``triton_flashdecoding``: Flashdecoding mode for long context, currently supports llama llama2 qwen - * ``triton_gqa_attention``: Fast kernel for models using GQA - * ``triton_gqa_flashdecoding``: Fast flashdecoding kernel for models using GQA - * ``triton_fp8kv``: Use float8 to store kv cache, currently only used for deepseek2 - - Need to read source code to confirm specific modes supported by all models - Scheduling Parameters --------------------- @@ -325,18 +308,6 @@ Performance Optimization Parameters .. option:: --enable_decode_microbatch_overlap The inference backend will use microbatch overlap mode for decoding - -.. option:: --enable_flashinfer_prefill - - The inference backend will use flashinfer's attention kernel for prefill - -.. option:: --enable_flashinfer_decode - - The inference backend will use flashinfer's attention kernel for decoding - -.. option:: --enable_fa3 - - The inference backend will use fa3 attention kernel for prefill and decoding .. 
option:: --disable_cudagraph diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst index 6098411be..accdbc462 100755 --- a/docs/EN/source/tutorial/deepseek_deployment.rst +++ b/docs/EN/source/tutorial/deepseek_deployment.rst @@ -33,12 +33,13 @@ Suitable for deploying DeepSeek-R1 model on a single H200 node. LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **Parameter Description:** - `LOADWORKER=18`: Model loading thread count, improves loading speed - `--tp 8`: Tensor parallelism, using 8 GPUs -- `--enable_fa3`: Enable Flash Attention 3.0 +- `--llm_prefill_att_backend fa3`: Enable Flash Attention 3.0 - `--port 8088`: Service port 1.2 Single node DP + EP Mode (Data Parallel + Expert Parallel) @@ -55,13 +56,13 @@ Suitable for expert parallelism deployment of MoE models like DeepSeek-V2/V3. --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **Parameter Description:** - `MOE_MODE=EP`: Set expert parallelism mode - `--tp 8`: Tensor parallelism - `--dp 8`: Data parallelism, usually set to the same value as tp -- `--enable_fa3`: Enable Flash Attention 3.0 **Optional Optimization Parameters:** - `--enable_prefill_microbatch_overlap`: Enable prefill microbatch overlap @@ -85,7 +86,8 @@ Suitable for deployment across multiple H200/H100 nodes. LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -101,7 +103,8 @@ Suitable for deployment across multiple H200/H100 nodes. LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -129,7 +132,8 @@ Suitable for deploying MoE models across multiple nodes. --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -146,7 +150,8 @@ Suitable for deploying MoE models across multiple nodes. 
--model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -195,7 +200,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --host $host \ --port 8019 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip @@ -216,7 +222,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --host $host \ --port 8121 \ --nccl_port 12322 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 @@ -284,7 +291,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --tp 8 \ --dp 8 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ --config_server_port 60088 @@ -303,7 +311,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --nccl_port 12322 \ --tp 8 \ --dp 8 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --config_server_host $config_server_host \ --config_server_port 60088 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/docs/EN/source/tutorial/multi_level_cache_deployment.rst b/docs/EN/source/tutorial/multi_level_cache_deployment.rst index bb8d943b8..6c99c351f 100644 --- a/docs/EN/source/tutorial/multi_level_cache_deployment.rst +++ b/docs/EN/source/tutorial/multi_level_cache_deployment.rst @@ -66,7 +66,8 @@ Suitable for most scenarios, significantly increasing cache capacity while maint --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ @@ -81,7 +82,7 @@ Basic Parameters - ``--model_dir``: Model file path, supports local path or HuggingFace model name - ``--tp 8``: Tensor parallelism degree, using 8 GPUs for model inference - ``--graph_max_batch_size 500``: CUDA Graph maximum batch size, affects throughput and memory usage -- ``--enable_fa3``: Enable Flash Attention 3.0 to improve attention computation speed. You can also switch to flashinfer backend for better performance +- ``--llm_prefill_att_backend fa3``: Enable Flash Attention 3.0 to improve attention computation speed. 
You can also switch to flashinfer backend for better performance - ``--mem_fraction 0.88``: GPU memory usage ratio, recommended to set to 0.88 or below CPU Cache Parameters @@ -130,7 +131,8 @@ Suitable for ultra-long text or extremely high-concurrency scenarios, providing --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ diff --git a/docs/EN/source/tutorial/reasoning_parser.rst b/docs/EN/source/tutorial/reasoning_parser.rst index e76e093d6..56e61e6cd 100644 --- a/docs/EN/source/tutorial/reasoning_parser.rst +++ b/docs/EN/source/tutorial/reasoning_parser.rst @@ -32,7 +32,8 @@ DeepSeek-R1 --model_dir /path/to/DeepSeek-R1 \ --reasoning_parser deepseek-r1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 DeepSeek-V3 ~~~~~~~~~~~ diff --git a/generate_requirements_musa.sh b/generate_requirements_musa.sh new file mode 100755 index 000000000..f5bfb8ff8 --- /dev/null +++ b/generate_requirements_musa.sh @@ -0,0 +1,105 @@ +#!/bin/bash +# Script to generate requirements-musa.txt from requirements.txt +# MUSA is not compatible with CUDA packages, so they need to be removed +# Torch-related packages are pre-installed in the MUSA docker container + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +INPUT_FILE="${SCRIPT_DIR}/requirements.txt" +OUTPUT_FILE="${SCRIPT_DIR}/requirements-musa.txt" + +if [ ! -f "$INPUT_FILE" ]; then + echo "Error: requirements.txt not found at $INPUT_FILE" + exit 1 +fi + +echo "Generating requirements-musa.txt from requirements.txt..." + +# Define patterns to remove (CUDA-specific packages) +# These packages are not compatible with MUSA +CUDA_PACKAGES=( + "^cupy" # cupy-cuda12x and similar + "^cuda_bindings" # CUDA bindings + "^nixl" # NIXL (NVIDIA Inter-node eXchange Library) + "^flashinfer" # flashinfer-python (CUDA-specific attention kernel) + "^sgl-kernel" # SGL kernel (CUDA-specific) +) + +# Define torch-related packages (pre-installed in MUSA container, remove version pins) +TORCH_PACKAGES=( + "^torch==" + "^torchvision==" +) + +# Create the output file with a header comment +cat > "$OUTPUT_FILE" << 'EOF' +# Requirements for MUSA (Moore Threads GPU) +# Auto-generated from requirements.txt by generate_requirements_musa.sh +# CUDA-specific packages have been removed +# Torch-related packages have version pins removed (pre-installed in MUSA container) + +EOF + +# Process the requirements file +while IFS= read -r line || [ -n "$line" ]; do + # Skip empty lines and comments (but keep them in output) + if [[ -z "$line" || "$line" =~ ^[[:space:]]*# ]]; then + echo "$line" >> "$OUTPUT_FILE" + continue + fi + + # Extract package name (before ==, >=, <=, ~=, etc.) 
+ pkg_name=$(echo "$line" | sed -E 's/^([a-zA-Z0-9_-]+).*/\1/') + + # Check if this is a CUDA package to skip + skip=false + for pattern in "${CUDA_PACKAGES[@]}"; do + if [[ "$pkg_name" =~ $pattern ]]; then + echo " Removing CUDA package: $line" + skip=true + break + fi + done + + if $skip; then + continue + fi + + # Check if this is a torch-related package (remove version pin) + for pattern in "${TORCH_PACKAGES[@]}"; do + if [[ "$line" =~ $pattern ]]; then + # Remove version pin, keep just the package name + pkg_only=$(echo "$line" | sed -E 's/==.*//') + echo " Unpinning version for: $pkg_only (pre-installed in MUSA container)" + echo "$pkg_only" >> "$OUTPUT_FILE" + skip=true + break + fi + done + + if $skip; then + continue + fi + + # Keep the package as-is + echo "$line" >> "$OUTPUT_FILE" + +done < "$INPUT_FILE" + +# Add MUSA-specific packages at the end +cat >> "$OUTPUT_FILE" << 'EOF' + +# MUSA-specific packages +torch_musa +torchada +EOF + +echo "" +echo "Successfully generated: $OUTPUT_FILE" +echo "" +echo "Summary of changes:" +echo " - Removed CUDA-specific packages: cupy-cuda12x, cuda_bindings, nixl, flashinfer-python, sgl-kernel" +echo " - Unpinned torch-related packages: torch, torchvision (pre-installed in MUSA container)" +echo " - Added MUSA-specific packages: torch_musa, torchada" + diff --git a/lightllm/__init__.py b/lightllm/__init__.py index e69de29bb..e9ba6f304 100644 --- a/lightllm/__init__.py +++ b/lightllm/__init__.py @@ -0,0 +1,4 @@ +from lightllm.utils.device_utils import is_musa + +if is_musa(): + import torchada # noqa: F401 diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..67dd3852c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..555386ebd --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, 
"num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..6f92439c1 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..67dd3852c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..555386ebd --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": 
{"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..6f92439c1 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..7f69e86a8 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..7d8dc868c --- /dev/null +++ 
b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..3e543b2ea --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..7f69e86a8 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 
000000000..7d8dc868c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..3e543b2ea --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..a3b0edde6 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json 
b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..22d1ce6f6 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..328fcec83 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..a3b0edde6 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git 
a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..22d1ce6f6 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..328fcec83 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..4c4ae8624 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No 
newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..61884a937 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..037bfd291 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..4c4ae8624 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 64, 
"num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..61884a937 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..037bfd291 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..e2028e2d2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 
2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..7e99dc1be --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..5d6b46dda --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 4}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..e2028e2d2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, 
"num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..7e99dc1be --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..5d6b46dda --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 4}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..7795b47e7 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 
4}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..ff4d6efd4 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..1d8ca6967 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..7795b47e7 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, 
"num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..ff4d6efd4 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..1d8ca6967 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..2f1cd5dfd --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, 
"num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..a369088bf --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..cb4e6a0d3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..2f1cd5dfd --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}}, "8192": {"8": 
{"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..a369088bf --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..cb4e6a0d3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..60827b791 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 4, 
"num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..7b42cad46 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..9bb49d70b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..60827b791 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "256": 
{"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..7b42cad46 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..9bb49d70b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..bd3d1c418 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 
4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..f1b3539f5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..e12c05b96 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..bd3d1c418 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, 
"128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..f1b3539f5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..e12c05b96 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..c83dca52d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, 
"num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..59a3e1051 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..5b7c4eaa9 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..c83dca52d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, 
"num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..59a3e1051 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 000000000..5b7c4eaa9 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..abd760af0 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"32": {"8": {"BLOCK_N": 32, "num_warps": 16, "num_stages": 3}, "16": {"BLOCK_N": 16, 
"num_warps": 8, "num_stages": 5}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}}, "64": {"8": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "16": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}}, "128": {"8": {"BLOCK_N": 32, "num_warps": 16, "num_stages": 2}, "16": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 4}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "256": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "16": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 9}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..a560ce9e1 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"32": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "16": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 11}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "64": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 7}, "16": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 10}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "128": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "16": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 7}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "256": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 2}, "16": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 5}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/allocator_utils.py b/lightllm/common/allocator_utils.py new file mode 100644 index 000000000..803ed0a71 --- /dev/null +++ b/lightllm/common/allocator_utils.py @@ -0,0 +1,98 @@ +from typing import List, Union + +import torch + +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class TokenAllocator: + def __init__(self, size, shared_can_use_token_num_name: str): + self.size = size + + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._mem_state_return = torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + self.mark_start = 0 + self.mark_end = self.size + + self.can_use_mem_size = self.size + + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + self.shared_can_use_token_num = SharedInt(shared_can_use_token_num_name) + + 
self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.HOLD_TOKEN_MEMINDEX = self.size + + def alloc(self, need_size) -> torch.Tensor: + if need_size > self.mark_end - self.mark_start: + logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") + assert False, "error alloc state" + + start = self.mark_start + end = self.mark_start + need_size + self.mark_start += need_size + + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + # 利用缓冲区返回,避免异步情况下的内存竞争 + if self._return_start + need_size > self._mem_state_return.shape[0]: + self._return_start = 0 + ans = self._mem_state_return[self._return_start : self._return_start + need_size] + ans.copy_(self.mem_state[start:end]) + self._return_start += need_size + return ans + + def free(self, free_index: Union[torch.Tensor, List[int]]): + """_summary_ + + Args: + free_index (torch.Tensor): _description_ + """ + end = self.mark_start + start = self.mark_start - len(free_index) + assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + + if isinstance(free_index, list): + free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device) + self.mem_state[start:end] = free_index_tensor + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start -= len(free_index) + + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.can_use_mem_size == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") + return + + def free_all(self): + self.can_use_mem_size = len(self.mem_state) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) + self.mark_start = 0 + self.mark_end = len(self.mem_state) + + def resize_mem(self, new_size): + """ + just for test code + """ + self.size = new_size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_start = 0 + self.mark_end = self.size + self.can_use_mem_size = self.size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + return diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py new file mode 100644 index 000000000..80df54549 --- /dev/null +++ b/lightllm/common/basemodel/attention/__init__.py @@ -0,0 +1,18 @@ +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from .triton.fp import TritonAttBackend +from .triton.int4kv import Int4kvTritonAttBackend +from .triton.int8kv import Int8kvTritonAttBackend +from .triton.mla import MlaTritonAttBackend +from .fa3.fp import Fa3AttBackend +from .fa3.fp8 import Fp8Fa3AttBackend +from .fa3.mla import MlaFa3AttBackend +from .flashinfer.fp8 import Fp8FlashInferAttBackend +from .flashinfer.fp import FlashInferAttBackend +from .flashinfer.mla import MlaFlashInferAttBackend + +from .create_utils import ( + get_prefill_att_backend_class, + get_decode_att_backend_class, + get_mla_prefill_att_backend_class, + get_mla_decode_att_backend_class, +) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py new file mode 100644 index 000000000..859d97ca8 --- /dev/null +++ b/lightllm/common/basemodel/attention/base_att.py @@ -0,0 +1,117 @@ +import torch +from abc import ABC, 
abstractmethod +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING, Tuple, Union, Dict + +if TYPE_CHECKING: + from lightllm.common.basemodel.basemodel import TpPartBaseModel + from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class BaseAttBackend: + """ + 用于创建支持各种不同的AttBackend, 如 fa3, flashinfer, triton 实现等, + 这个是单列模式, 每种backend只有一个实例 + """ + + _instances = {} + + def __new__(cls, *args, **kwargs): + """ + 重写__new__方法实现单例模式 + """ + # 检查是否已经有该类的实例 + if cls not in cls._instances: + # 创建新实例并存储 + instance = super().__new__(cls) + cls._instances[cls] = instance + # 返回已有的实例 + return cls._instances[cls] + + def __init__(self, model: "TpPartBaseModel"): + self.model = model + + def create_att_prefill_state(self) -> "BasePrefillAttState": + raise NotImplementedError("not impl") + + def create_att_decode_state(self) -> "BaseDecodeAttState": + raise NotImplementedError("not impl") + + def _find_layer_index( + self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"] + ) -> int: + kv_buffer = att_state.infer_state.mem_manager.kv_buffer + layer_count = len(kv_buffer) + find_dict = {kv_buffer[i].data_ptr(): i for i in range(layer_count)} + key = min(k.data_ptr(), v.data_ptr()) + assert key in find_dict + return find_dict[key] + + +@dataclass +class AttControl: + """ + prefill_att 和 decode_att 的入参,用于控制att backend 内部的行为, 选择正确的att 实现。 + """ + + use_alibi: bool = False + tp_alibi: torch.Tensor = None + use_sliding_window: bool = False + sliding_window: Tuple[int, int] = (-1, -1) + use_att_sink: bool = False + sink_weight: torch.Tensor = None + # mla 专用传参项 + mla_prefill: bool = False + mla_prefill_dict: Dict = None + mla_decode: bool = False + mla_decode_dict: Dict = None + + +@dataclass +class BasePrefillAttState(ABC): + + backend: BaseAttBackend = None + infer_state: "InferStateInfo" = None + + @abstractmethod + def init_state(self): + pass + + @abstractmethod + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + raise NotImplementedError("not impl") + + +@dataclass +class BaseDecodeAttState(ABC): + backend: BaseAttBackend = None + infer_state: "InferStateInfo" = None + + @abstractmethod + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "BaseDecodeAttState"): + for attr_name, attr_value in vars(new_state).items(): + if isinstance(attr_value, torch.Tensor): + attr_ = getattr(self, attr_name, None) + if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr(): + attr_.copy_(attr_value, non_blocking=True) + + @abstractmethod + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + pass diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py new file mode 100644 index 000000000..39e32ac63 --- /dev/null +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -0,0 +1,80 @@ +from lightllm.utils.envs_utils import get_env_start_args +from .base_att import BaseAttBackend +from .triton.fp import TritonAttBackend +from .triton.int4kv import Int4kvTritonAttBackend +from .triton.int8kv import Int8kvTritonAttBackend +from .triton.mla import MlaTritonAttBackend +from .fa3.fp import Fa3AttBackend +from .fa3.fp8 import Fp8Fa3AttBackend +from .fa3.mla import MlaFa3AttBackend +from 
.flashinfer.fp8 import Fp8FlashInferAttBackend +from .flashinfer.fp import FlashInferAttBackend +from .flashinfer.mla import MlaFlashInferAttBackend + +data_type_to_backend = { + "None": { + "triton": TritonAttBackend, + "fa3": Fa3AttBackend, + "flashinfer": FlashInferAttBackend, + }, + "int4kv": { + "triton": Int4kvTritonAttBackend, + "fa3": Fp8Fa3AttBackend, + "flashinfer": Fp8FlashInferAttBackend, + }, + "int8kv": { + "triton": Int8kvTritonAttBackend, + "fa3": Fp8Fa3AttBackend, + "flashinfer": Fp8FlashInferAttBackend, + }, +} + +mla_data_type_to_backend = { + "None": { + "triton": MlaTritonAttBackend, + "fa3": MlaFa3AttBackend, + "flashinfer": MlaFlashInferAttBackend, + }, +} + + +def get_prefill_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_prefill_att_backend[index] + if backend_str != "None": + return data_type_to_backend[llm_dtype][backend_str] + else: + # 根据环境自动选择最好的 + raise NotImplementedError(f"error") + + +def get_decode_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_decode_att_backend[index] + if backend_str != "None": + return data_type_to_backend[llm_dtype][backend_str] + else: + # 根据环境自动选择最好的 + raise NotImplementedError(f"error") + + +def get_mla_prefill_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_prefill_att_backend[index] + if backend_str != "None": + return mla_data_type_to_backend[llm_dtype][backend_str] + else: + raise NotImplementedError(f"error") + + +def get_mla_decode_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_decode_att_backend[index] + if backend_str != "None": + return mla_data_type_to_backend[llm_dtype][backend_str] + else: + raise NotImplementedError(f"error") diff --git a/lightllm/models/bloom/triton_kernel/__init__.py b/lightllm/common/basemodel/attention/fa3/__init__.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/__init__.py rename to lightllm/common/basemodel/attention/fa3/__init__.py diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py new file mode 100644 index 000000000..952bb39d9 --- /dev/null +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -0,0 +1,243 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, TYPE_CHECKING +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor + + +class Fa3AttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + self.get_page_table_buffer() # init + + def get_page_table_buffer(self): + """ + 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. 
+ """ + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + self._shared_page_table_buffer = [ + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state) -> "Fa3PrefillAttState": + return Fa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fa3DecodeAttState": + return Fa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Fa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + self.page_table = torch.empty( + (self.infer_state.batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + self.page_table.copy_( + self.infer_state.req_manager.req_to_token_indexs[ + self.infer_state.b_req_idx, : self.infer_state.max_kv_seq_len + ] + ) + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_alibi is False + return self._nomarl_prefill_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _nomarl_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: Fa3AttBackend = self.backend # for typing + + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight: torch.Tensor = att_control.sink_weight + else: + sink_weight = None + + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), + v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + sinks=sink_weight, + ) + return o + + +@dataclasses.dataclass +class Fa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 + decode_max_q_seq_len: int = None + + def init_state(self): + self.backend: Fa3AttBackend = self.backend + args_mtp_step = get_env_start_args().mtp_step + + if args_mtp_step > 0: + # 修正 mtp 在 fa3 下的输入。 + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = 
gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + + model = self.backend.model + # 可以使用 cuda graph的时候从 buffer中申请 + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * model.graph_max_len_in_batch + ].reshape(att_batch_size, model.graph_max_len_in_batch) + else: + self.page_table = torch.empty( + (att_batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + return + + def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_alibi is False + return self._normal_decode_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _normal_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight: torch.Tensor = att_control.sink_weight + else: + sink_weight = None + + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), + v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + sinks=sink_weight, + ) + return o diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py new file mode 100644 index 000000000..3feed1ef4 --- /dev/null +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -0,0 +1,221 @@ +import dataclasses +import torch +from ..base_att import AttControl +from typing import Optional, TYPE_CHECKING +from lightllm.utils.sgl_utils import 
flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant +from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops +from typing import Union +from .fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState + +if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant +else: + scaled_fp8_quant = None + + +class Fp8Fa3AttBackend(Fa3AttBackend): + def __init__(self, model): + super().__init__(model=model) + + def create_att_prefill_state(self, infer_state) -> "Fp8Fa3PrefillAttState": + return Fp8Fa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fp8Fa3DecodeAttState": + return Fp8Fa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Fp8Fa3PrefillAttState(Fa3PrefillAttState): + # 临时共享变量 + mid_token_batch_ids: torch.Tensor = None + k_descale: torch.Tensor = None + v_descale: torch.Tensor = None + + def init_state(self): + super().init_state() + device = self.infer_state.input_ids.device + batch_size = self.infer_state.batch_size + mem_manager = self.backend.model.mem_manager + + offline_scales: torch.Tensor = mem_manager.scales + head_num = mem_manager.head_num + self.mid_token_batch_ids = torch.repeat_interleave( + torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len + ) + # 为了减少推理计算量,在推理外部初始化k_descale和v_descale + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._fp8_prefill_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _fp8_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: Fp8Fa3AttBackend = self.backend # for typing + + q, q_scale = q_per_head_fp8_quant( + q, + self.infer_state.b_seq_len, + self.cu_seqlens_q, + self.mid_token_batch_ids, + ) + k_head_num = k.shape[1] + k_head_dim = k.shape[2] + cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self) + o = flash_attn_with_kvcache( + q=q, + k_cache=cache_k, + v_cache=cache_v, + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + causal=True, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale, + k_descale=self.k_descale[layer_index], + v_descale=self.v_descale[layer_index], + return_softmax_lse=False, + ) + return o + + +@dataclasses.dataclass +class 
Fp8Fa3DecodeAttState(Fa3DecodeAttState): + k_descale: torch.Tensor = None + v_descale: torch.Tensor = None + + def init_state(self): + super().init_state() + self.backend: Fp8Fa3AttBackend = self.backend + + args_mtp_step = get_env_start_args().mtp_step + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + + device = self.infer_state.input_ids.device + batch_size = att_batch_size + mem_manager = self.backend.model.mem_manager + + offline_scales: torch.Tensor = mem_manager.scales + head_num = mem_manager.head_num + + # 为了减少推理计算量,在推理外部初始化k_descale和v_descale + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + return + + def copy_for_decode_cuda_graph(self, new_state: "Fp8Fa3DecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._fp8_decode_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _fp8_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + k_head_num = k.shape[1] + k_head_dim = k.shape[2] + + cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + + layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self) + + q_head_num = q.shape[1] + q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True) + o = flash_attn_with_kvcache( + q=q.view(-1, q_head_num, k_head_dim), + k_cache=cache_k, + v_cache=cache_v, + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + causal=False, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale.view(self.infer_state.batch_size, k_head_num), + k_descale=self.k_descale[layer_index], + v_descale=self.v_descale[layer_index], + return_softmax_lse=False, + ) + return o diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py new file mode 100644 index 000000000..9a10457b1 --- /dev/null +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -0,0 +1,229 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, TYPE_CHECKING, Tuple +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import 
gen_cumsum_pad0_tensor +from lightllm.utils.sgl_utils import flash_attn_varlen_func + + +class MlaFa3AttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + self.get_page_table_buffer() # init + + def get_page_table_buffer(self): + """ + 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. + """ + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + self._shared_page_table_buffer = [ + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state) -> "MlaFa3PrefillAttState": + return MlaFa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "MlaFa3DecodeAttState": + return MlaFa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class MlaFa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: MlaFa3AttBackend = self.backend # for typing + k_nope, k_rope = k + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + + assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3 + + assert att_control.mla_prefill + softmax_scale = att_control.mla_prefill_dict["softmax_scale"] + + o_tensor = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + softmax_scale=softmax_scale, + causal=True, + return_softmax_lse=False, + ) + return o_tensor + + +@dataclasses.dataclass +class MlaFa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 + decode_max_q_seq_len: int = None + + def init_state(self): + self.backend: MlaFa3AttBackend = self.backend + args_mtp_step = get_env_start_args().mtp_step + + if args_mtp_step > 0: + # 修正 mtp 在 fa3 下的输入。 + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = 
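Shape bookkeeping for the key assembly in _mla_prefill_att above, as a minimal sketch; the 128 nope and 64 rope dimensions are illustrative values, not read from any config:

    import torch

    total_tokens, q_head_num = 10, 16
    nope_dim, rope_dim = 128, 64          # illustrative, not from a real model config

    k_nope = torch.randn(total_tokens, q_head_num, nope_dim)
    k_rope = torch.randn(total_tokens, 1, rope_dim)   # single shared rope "head"

    # The shared rope part is repeated across query heads, then concatenated onto the nope part.
    k_full = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1)
    assert k_full.shape == (total_tokens, q_head_num, nope_dim + rope_dim)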
self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + + model = self.backend.model + # 可以使用 cuda graph的时候从 buffer中申请 + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * model.graph_max_len_in_batch + ].reshape(att_batch_size, model.graph_max_len_in_batch) + else: + self.page_table = torch.empty( + (att_batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + return + + def copy_for_decode_cuda_graph(self, new_state: "MlaFa3DecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + assert v is None + + return self._mla_decode_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + q_nope, q_rope = q + kv = k + qk_rope_head_dim = 64 + kv_lora_rank = kv.shape[-1] - qk_rope_head_dim + k_rope = kv[:, :, -qk_rope_head_dim:].view(-1, 1, 1, qk_rope_head_dim) + kv_nope = kv[:, :, :-qk_rope_head_dim].view(-1, 1, 1, kv_lora_rank) + k_descale, v_descale = None, None + assert att_control.mla_decode + softmax_scale = att_control.mla_decode_dict["softmax_scale"] + + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o_tensor diff --git a/lightllm/models/chatglm2/__init__.py b/lightllm/common/basemodel/attention/flashinfer/__init__.py similarity index 100% rename from lightllm/models/chatglm2/__init__.py rename to lightllm/common/basemodel/attention/flashinfer/__init__.py diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py new file mode 100644 index 000000000..4c6ec0efc --- 
/dev/null +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -0,0 +1,229 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import repack_kv_index + + +class FlashInferAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + tp_world_size = get_dp_world_size() + self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size + self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) + head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] + self.head_dim = model.config.get("head_dim", head_dim) + self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.kv_indices_buffer = [ + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + + def create_att_prefill_state(self, infer_state) -> "FlashInferPrefillAttState": + return FlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "FlashInferDecodeAttState": + return FlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class FlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: FlashInferAttBackend = self.backend + + import flashinfer + + batch_size = self.infer_state.batch_size + device = self.infer_state.input_ids.device + + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32, device=device) + kv_indices = torch.empty( + batch_size * self.backend.max_seq_length, + dtype=torch.int32, + device=device, + ) + repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_seq_len, + kv_starts[:-1], + self.infer_state.max_kv_seq_len, + kv_indices, + ) + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + qo_indptr_buf=q_starts, + paged_kv_indptr_buf=kv_starts, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len, + ) + self.prefill_wrapper.plan( + q_starts, + kv_starts, + kv_indices, + kv_last_page_len, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + 1, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + ) + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._nomarl_prefill_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _nomarl_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty + ) -> torch.Tensor: + 
self.backend: FlashInferAttBackend = self.backend # for typing + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + self.prefill_wrapper.run( + q, + (k.unsqueeze(1), v.unsqueeze(1)), + out=o_tensor, + ) + return o_tensor + + +@dataclasses.dataclass +class FlashInferDecodeAttState(BaseDecodeAttState): + kv_last_page_len_buffer: torch.Tensor = None + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: FlashInferAttBackend = self.backend + device = self.infer_state.input_ids.device + model = self.backend.model + self.kv_last_page_len_buffer = torch.full((self.infer_state.batch_size,), 1, dtype=torch.int32, device=device) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ + : self.infer_state.batch_size * self.backend.max_seq_length + ] + else: + self.kv_indices = torch.empty( + self.infer_state.batch_size * self.backend.max_seq_length, + dtype=torch.int32, + device=device, + ) + + repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_seq_len, + self.infer_state.b_kv_start_loc, + self.infer_state.max_kv_seq_len, + self.kv_indices, + ) + self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + assert self.decode_wrapper is None + self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=self.kv_starts, + paged_kv_indices_buffer=self.kv_indices, + paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + ) + self.decode_wrapper.plan( + self.kv_starts, + self.kv_indices, + self.kv_last_page_len_buffer, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + 1, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + non_blocking=True, + ) + return + + def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.kv_starts, + new_state.kv_indices, + new_state.kv_last_page_len_buffer, + new_state.backend.tp_q_head_num, + new_state.backend.tp_kv_head_num, + new_state.backend.head_dim, + 1, + q_data_type=new_state.backend.q_data_type, + kv_data_type=new_state.backend.kv_data_type, + non_blocking=True, + ) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._normal_decode_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _normal_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + o_tensor = alloc_func(q.shape, q.dtype) + self.decode_wrapper.run( + q, + (k.unsqueeze(1), v.unsqueeze(1)), + out=o_tensor, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/flashinfer/fp8.py b/lightllm/common/basemodel/attention/flashinfer/fp8.py new file mode 100644 index 000000000..115d6985a --- /dev/null +++ b/lightllm/common/basemodel/attention/flashinfer/fp8.py @@ -0,0 +1,121 @@ +import dataclasses +import torch +from ..base_att 
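repack_kv_index, used in both the prefill and decode state setup above, is a Triton kernel; the plain-PyTorch loop below is only a reference sketch of its observable effect (gathering each request's KV slot indices into one flat buffer at the cumulative offsets), not the actual implementation:

    import torch

    def repack_kv_index_ref(req_to_token_indexs, b_req_idx, b_seq_len, kv_starts, kv_indices):
        # Gather each request's KV slot indices into one flat buffer, request by request.
        for i in range(b_req_idx.shape[0]):
            req, n, start = int(b_req_idx[i]), int(b_seq_len[i]), int(kv_starts[i])
            kv_indices[start:start + n] = req_to_token_indexs[req, :n]
        return kv_indices

    req_to_token_indexs = torch.arange(32, dtype=torch.int32).reshape(4, 8)
    b_req_idx = torch.tensor([2, 0], dtype=torch.int32)
    b_seq_len = torch.tensor([3, 5], dtype=torch.int32)
    kv_starts = torch.tensor([0, 3], dtype=torch.int32)
    out = repack_kv_index_ref(req_to_token_indexs, b_req_idx, b_seq_len, kv_starts,
                              torch.zeros(8, dtype=torch.int32))
    # out == [16, 17, 18, 0, 1, 2, 3, 4]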
import AttControl +from .fp import FlashInferAttBackend, FlashInferPrefillAttState, FlashInferDecodeAttState + + +class Fp8FlashInferAttBackend(FlashInferAttBackend): + def __init__(self, model): + super().__init__(model=model) + self.kv_data_type = torch.float8_e4m3fn + + def create_att_prefill_state(self, infer_state) -> "Fp8FlashInferPrefillAttState": + return Fp8FlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fp8FlashInferDecodeAttState": + return Fp8FlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Fp8FlashInferPrefillAttState(FlashInferPrefillAttState): + offline_scales: torch.Tensor = None + + def init_state(self): + super().init_state() + self.offline_scales = self.infer_state.mem_manager.scales_list + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._fp8_prefill_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _fp8_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty + ) -> torch.Tensor: + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + k = k.unsqueeze(1).view(torch.float8_e4m3fn) + v = v.unsqueeze(1).view(torch.float8_e4m3fn) + layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self) + offline_scales = self.offline_scales + k_descale = offline_scales[layer_index][0] if offline_scales is not None else None + v_descale = offline_scales[layer_index][1] if offline_scales is not None else None + self.prefill_wrapper.run( + q, + (k, v), + k_scale=k_descale, + v_scale=v_descale, + out=o_tensor, + ) + return o_tensor + + +@dataclasses.dataclass +class Fp8FlashInferDecodeAttState(FlashInferDecodeAttState): + offline_scales: torch.Tensor = None + + def init_state(self): + super().init_state() + self.offline_scales = self.infer_state.mem_manager.scales_list + + def copy_for_decode_cuda_graph(self, new_state): + return super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._fp8_decode_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _fp8_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + + k = k.unsqueeze(1).view(torch.float8_e4m3fn) + v = v.unsqueeze(1).view(torch.float8_e4m3fn) + offline_scales = self.offline_scales + layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self) + + k_descale = offline_scales[layer_index][0] if offline_scales is not None else None + v_descale = offline_scales[layer_index][1] if offline_scales is not None else None + self.decode_wrapper.run( + q, + (k, v), + k_scale=k_descale, + v_scale=v_descale, + out=o_tensor, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py new file mode 100644 index 000000000..6e52203b4 --- /dev/null +++ 
b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -0,0 +1,233 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import repack_kv_index +from typing import Tuple + + +class MlaFlashInferAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + num_heads = model.config["num_attention_heads"] + self.tp_q_head_num = num_heads // get_dp_world_size() + self.qk_nope_head_dim = model.qk_nope_head_dim + self.qk_rope_head_dim = model.qk_rope_head_dim + self.kv_lora_rank = model.kv_lora_rank + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.kv_indices_buffer = [ + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] + + from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale + + if model.config["rope_scaling"] is not None: + rope_scaling = model.config["rope_scaling"] + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) + scaling_factor = rope_scaling["factor"] + if mscale_all_dim: + mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + return + + def create_att_prefill_state(self, infer_state) -> "MlaFlashInferPrefillAttState": + return MlaFlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "MlaFlashInferDecodeAttState": + return MlaFlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class MlaFlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: MlaFlashInferAttBackend = self.backend + + import flashinfer + + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.prefill_wrapper is None: + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + self.backend.workspace_buffer, "NHD" + ) + self.prefill_wrapper.plan( + qo_indptr=q_starts, + kv_indptr=kv_starts, + num_qo_heads=self.backend.tp_q_head_num, + num_kv_heads=self.backend.tp_q_head_num, + head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim, + head_dim_vo=self.backend.qk_nope_head_dim, + q_data_type=self.backend.q_data_type, + causal=True, + sm_scale=self.backend.softmax_scale, + ) + return + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _mla_prefill_att( + self, q: torch.Tensor, k: Tuple[torch.Tensor, torch.Tensor], v: torch.Tensor, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: MlaFlashInferAttBackend = self.backend 
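The softmax_scale setup in MlaFlashInferAttBackend.__init__ above can be summarised as follows; the mscale formula used here is assumed to be the usual DeepSeek yarn form (0.1 * mscale_all_dim * ln(factor) + 1.0) and may differ in detail from get_deepseek_mscale:

    import math

    def deepseek_mscale(scaling_factor: float, mscale_all_dim: float) -> float:
        # Assumed DeepSeek yarn mscale; returns 1.0 when no rope scaling is applied.
        if scaling_factor <= 1.0:
            return 1.0
        return 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0

    qk_nope_head_dim, qk_rope_head_dim = 128, 64            # illustrative dims
    softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
    rope_scaling = {"factor": 40.0, "mscale_all_dim": 1.0}  # illustrative config values
    m = deepseek_mscale(rope_scaling["factor"], rope_scaling["mscale_all_dim"])
    softmax_scale *= m * m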
# for typing + k_nope, k_rope = k + o_tensor = alloc_func((q.shape[0], q.shape[1], k_nope.shape[2]), q.dtype, device="cuda") + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + self.prefill_wrapper.run(q, k, v, out=o_tensor) + return o_tensor + + +@dataclasses.dataclass +class MlaFlashInferDecodeAttState(BaseDecodeAttState): + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: MlaFlashInferAttBackend = self.backend + model = self.backend.model + device = self.infer_state.input_ids.device + batch_size = self.infer_state.batch_size + + self.kv_starts = self.infer_state.b1_cu_kv_seq_len + + self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ + : batch_size * self.backend.max_seq_length + ] + else: + self.kv_indices = torch.empty( + batch_size * self.backend.max_seq_length, + dtype=torch.int32, + device=device, + ) + + repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_seq_len, + self.infer_state.b_kv_start_loc, + self.infer_state.max_kv_seq_len, + self.kv_indices, + ) + assert self.decode_wrapper is None + + self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.backend.workspace_buffer, + use_cuda_graph=True, + qo_indptr=self.q_indptr, + kv_indices=self.kv_indices, + kv_indptr=self.kv_starts, + kv_len_arr=self.infer_state.b_seq_len, + ) + self.decode_wrapper.plan( + self.q_indptr, + self.kv_starts, + self.kv_indices, + self.infer_state.b_seq_len, + self.backend.tp_q_head_num, + self.backend.kv_lora_rank, + self.backend.qk_rope_head_dim, + 1, + False, # causal + self.backend.softmax_scale, + self.backend.q_data_type, + self.backend.kv_data_type, + ) + return + + def copy_for_decode_cuda_graph(self, new_state: "MlaFlashInferDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.q_indptr, + new_state.kv_starts, + new_state.kv_indices, + new_state.infer_state.b_seq_len, + new_state.backend.tp_q_head_num, + new_state.backend.kv_lora_rank, + new_state.backend.qk_rope_head_dim, + 1, + False, # causal + new_state.backend.softmax_scale, + new_state.backend.q_data_type, + new_state.backend.kv_data_type, + ) + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + + assert v is None + + return self._mla_decode_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + qk_rope_head_dim = 64 + q_nope, q_rope = q + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q_nope.device) + assert att_control.mla_decode + + self.decode_wrapper.run( + q_nope, + q_rope, + k[:, :, :-qk_rope_head_dim], + k[:, :, -qk_rope_head_dim:], + out=o_tensor, + return_lse=False, + ) + return o_tensor diff --git 
a/lightllm/models/chatglm2/layer_infer/__init__.py b/lightllm/common/basemodel/attention/triton/__init__.py similarity index 100% rename from lightllm/models/chatglm2/layer_infer/__init__.py rename to lightllm/common/basemodel/attention/triton/__init__.py diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py new file mode 100644 index 000000000..d29f15ec3 --- /dev/null +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -0,0 +1,275 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional + + +class TritonAttBackend(BaseAttBackend): + def create_att_prefill_state(self, infer_state) -> "TritonPrefillAttState": + return TritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "TritonDecodeAttState": + return TritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class TritonPrefillAttState(BasePrefillAttState): + def init_state(self): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_sliding_window is False and att_control.use_att_sink is False + if att_control.use_alibi: + assert att_control.tp_alibi is not None + return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + else: + return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) + + def _alibi_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + out = alloc_func(q.shape, q.dtype) + + from ...triton_kernel.alibi_att.context_flashattention_nopad import context_attention_fwd + + context_attention_fwd( + q, + k, + v, + out, + self.infer_state.b_req_idx, + att_control.tp_alibi, + self.infer_state.b_q_start_loc, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_q_seq_len, + self.infer_state.req_manager.req_to_token_indexs, + ) + return out + + def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): + from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd + + out = alloc_func(q.shape, q.dtype) + context_attention_fwd( + q, + k, + v, + out, + self.infer_state.b_req_idx, + self.infer_state.b_q_start_loc, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_q_seq_len, + self.infer_state.req_manager.req_to_token_indexs, + ) + return out + + +@dataclasses.dataclass +class TritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "TritonDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_sliding_window is False and att_control.use_att_sink is False + if att_control.use_alibi: + assert att_control.tp_alibi is not None + return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + else: + q_head_num = q.shape[1] + k_head_num = k.shape[1] + if q_head_num == k_head_num: + return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, 
alloc_func=alloc_func) + elif q_head_num > k_head_num: + return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) + else: + raise NotImplementedError("error") + + def _alibi_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + from ...triton_kernel.alibi_att.token_flashattention_nopad import token_attention_fwd + + out = alloc_func(q.shape, q.dtype) + token_attention_fwd( + q, + k, + v, + out, + att_control.tp_alibi, + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_kv_start_loc, + self.infer_state.b_seq_len, + self.infer_state.max_kv_seq_len, + self.infer_state.total_token_num, + alloc_tensor_func=alloc_func, + ) + return out + + def _normal_decode_flash_decoding_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + from ...triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding import ( + token_decode_attention_flash_decoding, + ) + + out = alloc_func(q.shape, q.dtype) + + token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_v=v, + out=out, + alloc_tensor_func=alloc_func, + ) + return out + + def _normal_decode_gqa_flash_decoding_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import ( + gqa_token_decode_attention_flash_decoding, + ) + + out = alloc_func(q.shape, q.dtype) + + gqa_token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_v=v, + out=out, + alloc_tensor_func=alloc_func, + ) + + return out + + def _normal_decode_gqa_flash_decoding_att_vsm( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 + from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_vsm import ( + gqa_token_decode_attention_flash_decoding_vsm, + ) + + out = alloc_func(q.shape, q.dtype) + + gqa_token_decode_attention_flash_decoding_vsm( + q=q, + k=k, + v=v, + infer_state=self.infer_state, + out=out, + alloc_tensor_func=alloc_func, + ) + return out + + def _normal_decode_gqa_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 + from ...triton_kernel.att.decode_att.gqa.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + + out = alloc_func(q.shape, q.dtype) + + gqa_decode_attention_fwd( + q=q, + k=k, + v=v, + out=out, + req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + b_seq_len=self.infer_state.b_seq_len, + ) + return out + + def _normal_decode_stage3_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + total_token_num = self.infer_state.total_token_num + batch_size = self.infer_state.batch_size + q_head_num = q.shape[1] + head_dim = q.shape[2] + + calcu_shape1 = (batch_size, q_head_num, head_dim) + att_m_tensor = alloc_func((q_head_num, total_token_num), torch.float32) + + from ...triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_nopad_att1 import token_att_fwd + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + Req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, + 
B_req_idx=self.infer_state.b_req_idx, + B_Start_Loc=self.infer_state.b_kv_start_loc, + B_Seqlen=self.infer_state.b_seq_len, + max_len_in_batch=self.infer_state.max_kv_seq_len, + ) + + o_tensor = alloc_func(q.shape, q.dtype) + from ...triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o_tensor.view(calcu_shape1), + req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + b_start_loc=self.infer_state.b_kv_start_loc, + b_seq_len=self.infer_state.b_seq_len, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/triton/int4kv.py b/lightllm/common/basemodel/attention/triton/int4kv.py new file mode 100644 index 000000000..25199dc47 --- /dev/null +++ b/lightllm/common/basemodel/attention/triton/int4kv.py @@ -0,0 +1,170 @@ +import dataclasses +import torch +from lightllm.utils.envs_utils import get_env_start_args +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, Tuple + + +class Int4kvTritonAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model) + self.quant_group_size: int = get_env_start_args().llm_kv_quant_group_size + + def create_att_prefill_state(self, infer_state) -> "Int4kvTritonPrefillAttState": + return Int4kvTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Int4kvTritonDecodeAttState": + return Int4kvTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Int4kvTritonPrefillAttState(BasePrefillAttState): + + # 用于反量化的时候使用,可以减少反量化占用的显存数量。按需使用。 + b_kv_start_loc: torch.Tensor = None + + def init_state(self): + self.b_kv_start_loc = ( + torch.cumsum(self.infer_state.b_seq_len, dim=0, dtype=self.infer_state.b_seq_len.dtype) + - self.infer_state.b_seq_len + ) + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + + self.backend: Int4kvTritonAttBackend = self.backend # for typing + + k, k_scale = k + v, v_scale = v + o = self._groupsize_quant_prefill_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + alloc_func=alloc_func, + ) + return o + + def _groupsize_quant_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + # o_tensor = alloc_func(q.shape, q.dtype, device=q.device) + # batch_size = self.infer_state.b_seq_len.shape[0] + + assert k.untyped_storage().data_ptr() == v.untyped_storage().data_ptr() + assert k_scale.untyped_storage().data_ptr() == v_scale.untyped_storage().data_ptr() + + total_token_num = self.infer_state.total_token_num + head_dim = k.shape[2] * 2 # 2个4bit存储为一个int8, 所以维度需要翻倍,才是解量化后的精度 + k_dequant = alloc_func((total_token_num, k.shape[1], head_dim), dtype=q.dtype, device=q.device) + v_dequant = alloc_func((total_token_num, v.shape[1], head_dim), dtype=q.dtype, device=q.device) + o_tensor = alloc_func(q.shape, dtype=q.dtype, device=q.device) + + max_kv_seq_len = self.infer_state.max_kv_seq_len + + from ...triton_kernel.kv_copy.ppl_int4kv_copy_kv import 
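The head_dim doubling noted in _groupsize_quant_prefill_att (two 4-bit values are packed into one int8, so the dequantized head_dim is twice the stored one) can be illustrated with the sketch below; the nibble order and unsigned zero-point-free layout are assumptions, and the actual ppl_int4kv layout may differ:

    import torch

    def pack_int4(x_int4: torch.Tensor) -> torch.Tensor:
        # x_int4: integer values in [0, 15], even last dim; two nibbles per stored byte.
        lo, hi = x_int4[..., 0::2], x_int4[..., 1::2]
        return (lo | (hi << 4)).to(torch.uint8)

    def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
        lo = packed & 0x0F
        hi = (packed >> 4) & 0x0F
        return torch.stack([lo, hi], dim=-1).flatten(-2).to(torch.int32)

    x = torch.randint(0, 16, (2, 4, 8))   # logical head_dim 8
    packed = pack_int4(x)                 # stored head_dim 4, hence the "* 2" on dequant
    assert torch.equal(unpack_int4(packed), x.to(torch.int32))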
dequantize_int4kv + + dequantize_int4kv( + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs, + b_seq_len=self.infer_state.b_seq_len, + b_req_idx=self.infer_state.b_req_idx, + b_kv_start_loc=self.b_kv_start_loc, + k_out=k_dequant, + v_out=v_dequant, + max_len_in_batch=max_kv_seq_len, + quant_group_size=self.backend.quant_group_size, + ) + + from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv + + context_attention_fwd_contiguous_kv( + q=q, + k=k_dequant, + v=v_dequant, + o=o_tensor, + b_start_loc=self.infer_state.b_q_start_loc, + b_kv_start_loc=self.b_kv_start_loc, + b_seq_len=self.infer_state.b_seq_len, + max_q_input_len=self.infer_state.max_q_seq_len, + b_prompt_cache_len=self.infer_state.b_ready_cache_len, + ) + return o_tensor + + +@dataclasses.dataclass +class Int4kvTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "Int4kvTritonDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + k, k_scale = k + v, v_scale = v + + return self.ppl_int4kv_decode_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + alloc_func=alloc_func, + ) + + def ppl_int4kv_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + from ...triton_kernel.att.decode_att.int4kv.ppl_int4kv_flash_decoding import ( + token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) diff --git a/lightllm/common/basemodel/attention/triton/int8kv.py b/lightllm/common/basemodel/attention/triton/int8kv.py new file mode 100644 index 000000000..975d7b629 --- /dev/null +++ b/lightllm/common/basemodel/attention/triton/int8kv.py @@ -0,0 +1,196 @@ +import dataclasses +import torch +from lightllm.utils.envs_utils import get_env_start_args +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, Tuple +from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel + + +class Int8kvTritonAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model) + self.quant_group_size: int = get_env_start_args().llm_kv_quant_group_size + + def create_att_prefill_state(self, infer_state) -> "Int8kvTritonPrefillAttState": + return Int8kvTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Int8kvTritonDecodeAttState": + return Int8kvTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Int8kvTritonPrefillAttState(BasePrefillAttState): + + # 用于反量化的时候使用,可以减少反量化占用的显存数量。按需使用。 + b_kv_start_loc: torch.Tensor = None + + def init_state(self): + self.b_kv_start_loc = ( + torch.cumsum(self.infer_state.b_seq_len, dim=0, dtype=self.infer_state.b_seq_len.dtype) + - self.infer_state.b_seq_len + ) + + def prefill_att( + 
self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + + self.backend: Int8kvTritonAttBackend = self.backend # for typing + + k, k_scale = k + v, v_scale = v + o = self._groupsize_quant_prefill_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + alloc_func=alloc_func, + ) + return o + + def _groupsize_quant_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + # o_tensor = alloc_func(q.shape, q.dtype, device=q.device) + # batch_size = self.infer_state.b_seq_len.shape[0] + + assert k.untyped_storage().data_ptr() == v.untyped_storage().data_ptr() + assert k_scale.untyped_storage().data_ptr() == v_scale.untyped_storage().data_ptr() + + total_token_num = self.infer_state.total_token_num + k_dequant = alloc_func((total_token_num, k.shape[1], k.shape[2]), dtype=q.dtype, device=q.device) + v_dequant = alloc_func((total_token_num, v.shape[1], v.shape[2]), dtype=q.dtype, device=q.device) + o_tensor = alloc_func(q.shape, dtype=q.dtype, device=q.device) + + max_kv_seq_len = self.infer_state.max_kv_seq_len + + from ...triton_kernel.kv_copy.ppl_int8kv_copy_kv import dequantize_int8kv + + dequantize_int8kv( + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs, + b_seq_len=self.infer_state.b_seq_len, + b_req_idx=self.infer_state.b_req_idx, + b_kv_start_loc=self.b_kv_start_loc, + k_out=k_dequant, + v_out=v_dequant, + max_len_in_batch=max_kv_seq_len, + quant_group_size=self.backend.quant_group_size, + ) + + from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv + + context_attention_fwd_contiguous_kv( + q=q, + k=k_dequant, + v=v_dequant, + o=o_tensor, + b_start_loc=self.infer_state.b_q_start_loc, + b_kv_start_loc=self.b_kv_start_loc, + b_seq_len=self.infer_state.b_seq_len, + max_q_input_len=self.infer_state.max_q_seq_len, + b_prompt_cache_len=self.infer_state.b_ready_cache_len, + ) + return o_tensor + + +@dataclasses.dataclass +class Int8kvTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "Int8kvTritonDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + k, k_scale = k + v, v_scale = v + if enable_diverse_mode_gqa_decode_fast_kernel(): + return self.diverse_decode_att(q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, alloc_func=alloc_func) + else: + return self.ppl_mha_int8kv_decode_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + alloc_func=alloc_func, + ) + + def diverse_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + + from ...triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse import ( + 
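A compact sketch of the group-wise symmetric int8 scheme these int8kv kernels assume; per-group absmax scaling is the conventional choice, but the exact scale layout consumed by dequantize_int8kv is not reproduced here:

    import torch

    def quant_int8_groupwise(x: torch.Tensor, group_size: int = 8):
        # x: (tokens, heads, head_dim), with head_dim % group_size == 0.
        g = x.view(*x.shape[:-1], -1, group_size)
        scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / 127.0
        q = torch.clamp((g / scale).round(), -127, 127).to(torch.int8)
        return q.view(x.shape), scale.squeeze(-1)

    def dequant_int8_groupwise(q: torch.Tensor, scale: torch.Tensor, group_size: int = 8):
        g = q.view(*q.shape[:-1], -1, group_size).to(scale.dtype)
        return (g * scale.unsqueeze(-1)).view(q.shape)

    x = torch.randn(4, 2, 64)
    q, s = quant_int8_groupwise(x)
    assert (x - dequant_int8_groupwise(q, s)).abs().max().item() < 0.05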
token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) + + def ppl_mha_int8kv_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + from ...triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( + token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py new file mode 100644 index 000000000..8288193ad --- /dev/null +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -0,0 +1,125 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Tuple + + +class MlaTritonAttBackend(BaseAttBackend): + def create_att_prefill_state(self, infer_state) -> "MlaTritonPrefillAttState": + return MlaTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "MlaTritonDecodeAttState": + return MlaTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class MlaTritonPrefillAttState(BasePrefillAttState): + def init_state(self): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + from ...triton_kernel.mla_att.prefill_att import context_attention_fwd_with_v + + qk_rope_head_dim = 64 + q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q.device) + k_nope, k_rope = k + assert att_control.mla_prefill + softmax_scale = att_control.mla_prefill_dict["softmax_scale"] + context_attention_fwd_with_v( + q_nope, + q_rope, + k_nope, + k_rope, + v, + o_tensor, + self.infer_state.b_q_start_loc, + self.infer_state.b1_cu_kv_seq_len, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_q_seq_len, + softmax_scale, + ) + return o_tensor + + +@dataclasses.dataclass +class MlaTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "MlaTritonDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_sliding_window is False + and att_control.use_att_sink is False + and att_control.use_alibi is False + ) + assert v is None + + return self._mla_decode_att( + q=q, + k=k, + v=v, + 
att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + assert att_control.mla_decode + softmax_scale = att_control.mla_decode_dict["softmax_scale"] + + from ...triton_kernel.mla_att.decode_att import gqa_token_decode_attention_flash_decoding + + qk_rope_head_dim = 64 + q_nope, q_rope = q + kv = k + + out = gqa_token_decode_attention_flash_decoding( + q_nope=q_nope, + q_rope=q_rope, + kv_nope=kv[:, :, :-qk_rope_head_dim], + kv_rope=kv[:, :, -qk_rope_head_dim:], + infer_state=self.infer_state, + softmax_scale=softmax_scale, + alloc_tensor_func=alloc_func, + ) + return out diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 011f998fc..e50fdd4fc 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -32,6 +32,8 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache +from .attention import get_prefill_att_backend_class, get_decode_att_backend_class +from .attention import BaseAttBackend logger = init_logger(__name__) @@ -51,6 +53,16 @@ class TpPartBaseModel: # infer state class infer_state_class = InferStateInfo + @classmethod + def get_radix_cache_class(cls): + """Return the appropriate RadixCache class for this model type. + + Override in subclasses that need specialized cache (e.g., hybrid models). + """ + from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache + + return RadixCache + def __init__(self, kvargs): self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] @@ -58,7 +70,6 @@ def __init__(self, kvargs): self.max_total_token_num = kvargs["max_total_token_num"] self.batch_max_tokens = kvargs.get("batch_max_tokens", None) self.load_way = kvargs.get("load_way", "HF") - self.mode = kvargs.get("mode", []) self.weight_dict = kvargs.get("weight_dict", None) self.finetune_config = kvargs.get("finetune_config", None) self.max_req_num = kvargs.get("max_req_num", 1000) @@ -116,9 +127,18 @@ def __init__(self, kvargs): self._init_infer_layer() self._init_some_value() self._init_custom() - self._init_inferstate_cls() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() + + self._init_att_backend() + self._init_att_backend1() + + logger.info(f"use prefill att backend: {self.prefill_att_backend.__class__.__name__}") + logger.info(f"use decode att backend: {self.decode_att_backend.__class__.__name__}") + if self.prefill_att_backend1 is not None: + logger.info(f"use prefill att backend1: {self.prefill_att_backend1.__class__.__name__}") + logger.info(f"use decode att backend1: {self.decode_att_backend1.__class__.__name__}") + self._autotune_warmup() self._init_padded_req() self._init_cudagraph() @@ -144,9 +164,6 @@ def _init_config(self): self.config["vocab_size"] = self.finetune_config.vocab_size return - def _init_inferstate_cls(self): - pass - @final def _verify_must(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 @@ -162,15 +179,12 @@ def _init_quant(self): logger.info(f"Initial quantization. 
" f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self, start_layer_index=0): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ self.transformer_weight_class( i, self.data_type, network_config=self.config, - mode=self.mode, quant_cfg=self.quant_cfg, ) for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) @@ -220,10 +234,10 @@ def _init_req_manager(self): return def _init_infer_layer(self, start_layer_index=0): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) - self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) self.layers_infer = [ - self.transformer_layer_infer_class(i, network_config=self.config, mode=self.mode) + self.transformer_layer_infer_class(i, network_config=self.config) for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) ] return @@ -238,6 +252,17 @@ def _init_some_value(self): self.vocab_size = self.config["vocab_size"] return + def _init_att_backend(self): + self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0)(model=self) + return + + def _init_att_backend1(self): + # self.prefill_att_backend1 是给后续有模型支持不同层用不同的att模块时,保留的扩展。 + self.prefill_att_backend1: BaseAttBackend = None + self.decode_att_backend1: BaseAttBackend = None + return + def _init_cudagraph(self): self.graph = ( None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch) @@ -281,13 +306,13 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.return_all_prompt_logics = self.return_all_prompt_logics infer_state.batch_size = model_input.batch_size infer_state.total_token_num = model_input.total_token_num - infer_state.max_len_in_batch = model_input.max_len_in_batch infer_state.max_q_seq_len = model_input.max_q_seq_len infer_state.max_kv_seq_len = model_input.max_kv_seq_len infer_state.max_cache_len = model_input.max_cache_len infer_state.prefix_total_token_num = model_input.prefix_total_token_num assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0] infer_state.b_req_idx = model_input.b_req_idx + infer_state.b_mtp_index = model_input.b_mtp_index infer_state.b_seq_len = model_input.b_seq_len if model_input.is_prefill: if model_input.b_ready_cache_len is not None: @@ -311,6 +336,19 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) # 特殊模型,特殊模式的特定变量初始化操作。 infer_state.mtp_draft_input_hiddens = model_input.mtp_draft_input_hiddens + if infer_state.is_prefill: + infer_state.prefill_att_state = self.prefill_att_backend.create_att_prefill_state(infer_state=infer_state) + if self.prefill_att_backend1 is not None: + infer_state.prefill_att_state1 = self.prefill_att_backend1.create_att_prefill_state( + infer_state=infer_state + ) + else: + infer_state.decode_att_state = self.decode_att_backend.create_att_decode_state(infer_state=infer_state) + if self.decode_att_backend1 is not None: + infer_state.decode_att_state1 = self.decode_att_backend1.create_att_decode_state( + 
infer_state=infer_state + ) + return infer_state def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int): @@ -323,6 +361,7 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input = copy.copy(model_input) new_model_input.batch_size = new_batch_size new_model_input.total_token_num += padded_batch_size * 2 + new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) new_model_input.b_req_idx = F.pad( new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID @@ -366,7 +405,6 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle new_model_input = copy.copy(model_input) new_model_input.batch_size = model_input.batch_size + 1 new_model_input.total_token_num += padded_token_num - new_model_input.max_len_in_batch = max(padded_token_num, model_input.max_len_in_batch) new_model_input.max_q_seq_len = max(padded_token_num, model_input.max_q_seq_len) new_model_input.max_kv_seq_len = max(padded_token_num, model_input.max_kv_seq_len) new_model_input.max_cache_len = max(0, model_input.max_cache_len) @@ -464,6 +502,7 @@ def _prefill( prefill_mem_indexes_ready_event.record() infer_state.init_some_extra_state(self) + infer_state.init_att_state() model_output = self._context_forward(infer_state) if is_padded_model_input: model_output = self._create_unpad_prefill_model_output( @@ -484,7 +523,7 @@ def _decode( model_input.b_mtp_index, ) - if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch): + if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_kv_seq_len): find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) infer_state = self._create_inferstate(padded_model_input) @@ -495,6 +534,7 @@ def _decode( infer_state.mem_index, ) infer_state.init_some_extra_state(self) + infer_state.init_att_state() if self.graph.need_capture(find_graph_batch_size): infer_state.is_cuda_graph = True @@ -514,6 +554,7 @@ def _decode( infer_state.mem_index, ) infer_state.init_some_extra_state(self) + infer_state.init_att_state() model_output = self._token_forward(infer_state) return model_output @@ -622,6 +663,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod max_q_seq_len=infer_state0.max_q_seq_len, ) infer_state0.init_some_extra_state(self) + infer_state0.init_att_state() infer_state1 = self._create_inferstate(model_input1, 1) init_req_to_token_indexes( @@ -634,6 +676,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod max_q_seq_len=infer_state1.max_q_seq_len, ) infer_state1.init_some_extra_state(self) + infer_state1.init_att_state() prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() @@ -672,7 +715,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode assert model_input1.mem_indexes.is_cuda origin_batch_size = model_input0.batch_size - max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch) + max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): find_graph_batch_size = 
self.graph.find_closest_graph_batch_size(origin_batch_size) @@ -688,6 +731,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.mem_index, ) infer_state0.init_some_extra_state(self) + infer_state0.init_att_state() + infer_state1 = self._create_inferstate(padded_model_input1, 1) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -696,6 +741,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.mem_index, ) infer_state1.init_some_extra_state(self) + infer_state1.init_att_state() if self.graph.need_capture(find_graph_batch_size): infer_state0.is_cuda_graph = True @@ -724,6 +770,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.mem_index, ) infer_state0.init_some_extra_state(self) + infer_state0.init_att_state() + infer_state1 = self._create_inferstate(model_input1, 1) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -732,6 +780,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.mem_index, ) infer_state1.init_some_extra_state(self) + infer_state1.init_att_state() model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1) return model_output0, model_output1 @@ -818,7 +867,6 @@ def _check_max_len_infer(self): model_input = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=self.batch_max_tokens, max_q_seq_len=self.batch_max_tokens, max_kv_seq_len=self.batch_max_tokens, max_cache_len=0, @@ -895,7 +943,6 @@ def _autotune_warmup(self): model_input = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=input_len, max_q_seq_len=input_len, max_kv_seq_len=input_len, max_cache_len=0, @@ -958,7 +1005,6 @@ def _init_padded_req(self): model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=prefill_input_len, max_q_seq_len=prefill_input_len, max_kv_seq_len=prefill_input_len, max_cache_len=0, @@ -993,6 +1039,7 @@ def _gen_special_model_input(self, token_num: int): "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(self.__class__) or "MistralMTPModel" in str(self.__class__) + or "Qwen3NextMTPModel" in str(self.__class__) ) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 138f08427..758c0b519 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -11,10 +11,7 @@ class ModelInput: # 通用变量 batch_size: int total_token_num: int - max_len_in_batch: int - # 在 decode 阶段, 常规模式下, max_q_seq_len 必定是 1, - # 在 mtp 模式下,max_q_seq_len 统计的是一个请求考虑了 mtp 步数的 - # 最大长度,实际值是 max([(1 + req.mtp_step) for req in reqs]) + # 在 decode 阶段, max_q_seq_len 必定是 1, max_q_seq_len: int max_kv_seq_len: int max_cache_len: int = None diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 9eeab7270..516d8d686 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -3,6 +3,7 @@ import copy import bisect from typing import Optional +from tqdm import tqdm from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup @@ -67,7 +68,7 @@ def _capture_decode(self, decode_func, 
infer_state: InferStateInfo): graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids batch_size = input_ids.shape[0] - infer_state.max_len_in_batch = self.graph_max_len_in_batch + infer_state.max_kv_seq_len = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size # warmup # 因为有些推理过程的代码,会通过判断infer_state中是否存在某些属性来在一层上 @@ -77,10 +78,16 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): # 浅拷贝,不然后续传入到cuda graph捕获过程中后,infer_state因为提前拥有了这些属性, # 导致不会重新初始化,这样捕获过程中会不能捕获这些临时添加到 infer_state 管理对象 # 中的 tensor。 + for _ in range(1): + # 记录原始存在的变量 + pure_para_set = set(vars(infer_state).keys()) torch.cuda.synchronize() decode_func(copy.copy(infer_state)) torch.cuda.synchronize() + for param_name in set(vars(infer_state).keys()): + if param_name not in pure_para_set: + delattr(infer_state, param_name) with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): @@ -100,15 +107,25 @@ def _capture_decode_overlap( graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids batch_size = input_ids.shape[0] - infer_state.max_len_in_batch = self.graph_max_len_in_batch + infer_state.max_kv_seq_len = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size - infer_state1.max_len_in_batch = self.graph_max_len_in_batch + infer_state1.max_kv_seq_len = self.graph_max_len_in_batch infer_state1.total_token_num = self.graph_max_len_in_batch * batch_size # warmup for _ in range(1): + # 记录原始存在的变量 + pure_para_set = set(vars(infer_state).keys()) + pure_para_set1 = set(vars(infer_state1).keys()) torch.cuda.synchronize() decode_func(copy.copy(infer_state), copy.copy(infer_state1)) torch.cuda.synchronize() + for para_name in set(vars(infer_state).keys()): + if para_name not in pure_para_set: + delattr(infer_state, para_name) + for para_name in set(vars(infer_state1).keys()): + if para_name not in pure_para_set1: + delattr(infer_state1, para_name) + with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): @@ -180,7 +197,12 @@ def warmup(self, model): model: TpPartBaseModel = model # decode cuda graph init - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs") + for batch_size in progress_bar: + # Get available memory info + avail_mem, total_mem = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB") seq_len = 2 total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch @@ -196,8 +218,7 @@ def warmup(self, model): model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, - max_q_seq_len=self.mtp_step + 1, + max_q_seq_len=1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, mem_indexes=mem_indexes, @@ -236,7 +257,14 @@ def warmup_overlap(self, model): model: TpPartBaseModel = model - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs") + for batch_size in progress_bar: + # Get available memory info + avail_mem, total_mem = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description( + f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: 
{avail_mem_gb:.2f}GB" + ) decode_batches = [] for micro_batch_index in [0, 1]: # dummy decoding, capture the cudagraph @@ -256,8 +284,7 @@ def warmup_overlap(self, model): is_prefill=False, batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, - max_q_seq_len=self.mtp_step + 1, + max_q_seq_len=1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, b_mtp_index=b_mtp_index, diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 8e7174bb3..e12edd7bf 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -11,6 +11,7 @@ from .batch_objs import ModelInput from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_global_dp_rank +from .attention import BasePrefillAttState, BaseDecodeAttState class InferStateInfo: @@ -19,10 +20,19 @@ class InferStateInfo: """ def __init__(self): + # prefill 和 decode 使用的 att 状态对象 + self.prefill_att_state: BasePrefillAttState = None + self.decode_att_state: BaseDecodeAttState = None + + # 保留的扩展, 支持线性att与标准att混合使用时使用 + self.prefill_att_state1: BasePrefillAttState = None + self.decode_att_state1: BaseDecodeAttState = None + self.input_ids: torch.Tensor = None self.batch_size: int = None self.total_token_num: int = None self.b_req_idx: torch.Tensor = None + self.b_mtp_index: torch.Tensor = None # MTP index for each batch item (0: main, 1-mtp_step: candidates) self.b_start_loc: torch.Tensor = None self.b_ready_cache_len: torch.Tensor = None # only for prefill prompt cache used. @@ -30,10 +40,6 @@ def __init__(self): self.b_mark_shared_group: torch.Tensor = None # only for diverse mode used in decode phase. self.b_seq_len: torch.Tensor = None - # max_len_in_batch prefill 和 decode 阶段含义不同 - # prefill 阶段指每个req 输入token的长度(不包括已经cache的部分)最大值 - # decode 阶段指的是每个req的总长 最大值 - self.max_len_in_batch: int = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None # prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度 @@ -68,6 +74,11 @@ def __init__(self): self.max_q_seq_len: int = None self.max_kv_seq_len: int = None + # prefill 用 + self.b_q_start_loc: torch.Tensor = None + # decode 用 + self.b_kv_start_loc: torch.Tensor = None + # 一些特殊模型,特殊模式使用的输入变量,本身这些变量不适合放在 # inferstate的基类中,但是为了代码的简洁和方便,都放在基类中 # 进行管理。注意这些成员变量只会在特定的模型和模式下才会生效。 @@ -89,8 +100,10 @@ def __init__(self): self.dp_output_split_sizes: List[List[int]] = None self.dp_input_split_sizes: List[List[int]] = None - def init_some_extra_state(self, model): + # 专门用于管理混合注意力模型的buffer + self.buffer_indexes: torch.Tensor = None + def init_some_extra_state(self, model, input_ids: torch.Tensor = None): if self.is_prefill: ( self.b_q_seq_len, @@ -103,7 +116,7 @@ def init_some_extra_state(self, model): b_ready_cache_len=self.b_ready_cache_len, b_seq_len=self.b_seq_len, ) - self.b_start_loc = self.b1_cu_q_seq_len[0:-1] + self.b_q_start_loc = self.b1_cu_q_seq_len[0:-1] else: ( self.b_q_seq_len, @@ -112,16 +125,31 @@ def init_some_extra_state(self, model): self.b1_cu_kv_seq_len, self.position_ids, ) = gen_decode_params(self.b_seq_len) - # TODO: check the correctness - self.max_kv_seq_len = self.max_len_in_batch + self.b_kv_start_loc = self.b1_cu_kv_seq_len[0:-1] + # max_kv_seq_len is already set in _create_inferstate from model_input.max_kv_seq_len + self.max_q_seq_len = self.b_q_seq_len.max().item() if self.b_q_seq_len.numel() > 0 else 1 self.b_start_loc = self.b1_cu_kv_seq_len[0:-1] + def init_att_state(self): + if 
self.is_prefill: + self.prefill_att_state.init_state() + if self.prefill_att_state1 is not None: + self.prefill_att_state1.init_state() + else: + self.decode_att_state.init_state() + if self.decode_att_state1 is not None: + self.decode_att_state1.init_state() + def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): for attr_name, attr_value in vars(new_infer_state).items(): if isinstance(attr_value, torch.Tensor): attr_ = getattr(self, attr_name, None) - if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr(): + if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr() and attr_.shape == attr_value.shape: attr_.copy_(attr_value, non_blocking=True) + + self.decode_att_state.copy_for_decode_cuda_graph(new_infer_state.decode_att_state) + if self.decode_att_state1 is not None: + self.decode_att_state1.copy_for_decode_cuda_graph(new_infer_state.decode_att_state1) return def prefill_dp_balance(self, input_ids: torch.Tensor): diff --git a/lightllm/common/basemodel/layer_infer/post_layer_infer.py b/lightllm/common/basemodel/layer_infer/post_layer_infer.py index d254eb510..c7bae26ea 100644 --- a/lightllm/common/basemodel/layer_infer/post_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/post_layer_infer.py @@ -4,8 +4,7 @@ class PostLayerInfer(BaseLayerInfer): """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): super().__init__() self.network_config_ = network_config - self.mode = mode return diff --git a/lightllm/common/basemodel/layer_infer/pre_layer_infer.py b/lightllm/common/basemodel/layer_infer/pre_layer_infer.py index 3626346f2..e83fe8949 100644 --- a/lightllm/common/basemodel/layer_infer/pre_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/pre_layer_infer.py @@ -4,8 +4,7 @@ class PreLayerInfer(BaseLayerInfer): """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): super().__init__() self.network_config_ = network_config - self.mode = mode return diff --git a/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py index fa7e96a69..1b7813fca 100644 --- a/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py @@ -6,8 +6,8 @@ class PostLayerInferTpl(PostLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = 1e-5 self.vocab_size_ = network_config["vocab_size"] self.embed_dim_ = network_config["n_embed"] diff --git a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py index e7a084079..04f8cda16 100644 --- a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py @@ -5,8 +5,8 @@ class PreLayerInferTpl(PreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = 1e-5 return diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py deleted file mode 100755 index 27f71a17e..000000000 --- 
a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py +++ /dev/null @@ -1,136 +0,0 @@ -from functools import partial -from typing import Tuple - -import torch -import torch.distributed as dist - -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time - -from ...infer_struct import InferStateInfo -from ..transformer_layer_infer import TransformerLayerInfer -from lightllm.distributed.communication_op import all_reduce - - -class TransformerLayerCohereInferTpl(TransformerLayerInferTpl): - """ """ - - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) - - self.use_qk_norm_ = self.network_config_.get("use_qk_norm", False) - return - - def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _q_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _k_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _bind_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - self._att_norm = partial(TransformerLayerCohereInferTpl._q_norm, self) - self._q_norm = partial(TransformerLayerCohereInferTpl._k_norm, self) - self._k_norm = partial(TransformerLayerCohereInferTpl._att_norm, self) - - def _rotary_emb_fwd(self, q, kv, position_cos, position_sin): - raise Exception("need to impl") - - def _bind_rotary_emb_fwd(self): - raise Exception("need to impl") - - def _get_qkv( - self, input, infer_state: InferStateInfo, layer_weight - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)) - cache_kv = layer_weight.kv_proj.mm(input.view(-1, self.embed_dim_)).view( - -1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_ - ) - - if self.use_qk_norm_: - q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - k = cache_kv[:, 0 : self.tp_k_head_num_, :] - q = self._q_norm(q, infer_state, layer_weight) - cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k, infer_state, layer_weight) - self._rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: - raise Exception("need to impl") - - def _token_attention_kernel(self, q, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: - raise Exception("need to impl") - - def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_weight): - q, cache_kv = self._get_qkv(input_embding, infer_state, layer_weight) - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._attn_out = o - return - - def 
_context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight): - ffn_out = self._ffn(input_embdings, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(ffn_out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._ffn_out = ffn_out - return - - def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_weight): - q, cache_kv = self._get_qkv(input_embding, infer_state, layer_weight) - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._token_attention_kernel(q, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._attn_out = o - return - - def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight): - ffn_out = self._ffn(input_embdings, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(ffn_out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._ffn_out = ffn_out - return - - def _cohere_residual(self, input_embdings, infer_state: InferStateInfo): - # emb_addr = input_embdings.data_ptr() - # attn_out_addr = infer_state._attn_out.data_ptr() - # ffn_addr = infer_state._ffn_out.data_ptr() - # assert emb_addr != attn_out_addr - # assert emb_addr != ffn_addr - # assert attn_out_addr != ffn_addr - input_embdings.add_( - infer_state._attn_out.view(-1, self.embed_dim_) + infer_state._ffn_out.view(-1, self.embed_dim_) - ) - - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - self._context_attention(input1, infer_state, layer_weight=layer_weight) - self._context_ffn(input1, infer_state, layer_weight) - self._cohere_residual(input_embdings, infer_state) - return input_embdings - - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - self._token_attention(input1, infer_state, layer_weight=layer_weight) - self._token_ffn(input1, infer_state, layer_weight) - self._cohere_residual(input_embdings, infer_state) - return input_embdings diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 436ca77d8..646f99864 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -3,8 +3,6 @@ import torch.distributed as dist from ..transformer_layer_infer import TransformerLayerInfer from ...infer_struct import InferStateInfo -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv from lightllm.distributed import all_reduce from typing import Tuple from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor @@ -13,8 +11,8 @@ class TransformerLayerInferTpl(TransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) # need to set by subclass self.eps_ = 1e-5 self.tp_q_head_num_ = -1 @@ -39,11 +37,11 @@ def _tpsp_get_qkv(self, input, infer_state: InferStateInfo, layer_weight) 
-> Tup def _post_cache_kv(self, cache_kv, infer_state: InferStateInfo, layer_weight): mem_manager = infer_state.mem_manager - self._copy_kv_to_mem_cache(cache_kv, infer_state.mem_index, mem_manager) - return - - def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) + mem_manager.copy_kv_to_mem_manager( + layer_index=self.layer_num_, + mem_index=infer_state.mem_index, + kv=cache_kv, + ) return def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: @@ -64,20 +62,21 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) - input1 = None + def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) - q = None o = self._get_o(o, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.context_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -89,39 +88,42 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) - input1 = None + def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.token_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, 
cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight) - input1 = None + def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) - q = None o = self._tpsp_get_o(o, infer_state, layer_weight) + return o + + def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -131,14 +133,17 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight) - input1 = None + def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._tpsp_get_o(o, infer_state, layer_weight) + return o + + def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None diff --git a/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py b/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py index 7350531bb..53daffcdd 100644 --- a/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py @@ -4,9 +4,8 @@ class TransformerLayerInfer(BaseLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode): + def __init__(self, layer_num, network_config): super().__init__() self.layer_num_ = layer_num self.network_config_ = network_config - self.mode = mode return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 0fa02780c..32994ee4e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -7,7 +7,11 @@ ROWBMMWeight, ) from .norm_weight import NoTpGEMMANormWeight, TpVitPadNormWeight, NoTpNormWeight, TpHeadNormWeight + +# NormWeight is an alias for NoTpNormWeight for backward compatibility +NormWeight = NoTpNormWeight from .fused_moe_weight_tp import create_tp_moe_wegiht_obj from .fused_moe_weight_ep import FusedMoeWeightEP from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight +from .parameter_weight import ParameterWeight, TpParameterWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py 
b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py new file mode 100644 index 000000000..e4da4a23a --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py @@ -0,0 +1,93 @@ +import torch +from typing import Dict, Optional, Tuple +from .base_weight import BaseWeightTpl + + +class ParameterWeight(BaseWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + bias_name: Optional[str] = None, + weight_shape: Optional[Tuple[int, ...]] = None, + bias_shape: Optional[Tuple[int, ...]] = None, + ): + super().__init__() + self.weight_name = weight_name + self.bias_name = bias_name + self.data_type_ = data_type + self.weight_shape = weight_shape + self.bias_shape = bias_shape + self.weight: Optional[torch.Tensor] = None + self.bias: Optional[torch.Tensor] = None + # Create weights if shapes are provided + if weight_shape is not None: + self._create_weight() + + def _create_weight(self): + """Create weight and bias tensors with pre-allocated memory.""" + if self.weight_shape is not None: + self.weight = torch.empty(*self.weight_shape, dtype=self.data_type_, device=self.device_id_) + if self.bias_name is not None and self.bias_shape is not None: + self.bias = torch.empty(*self.bias_shape, dtype=self.data_type_, device=self.device_id_) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + if self.weight_name in weights: + t_weight = weights[self.weight_name] + if self.weight is None: + # If weight was not pre-created, create it now based on loaded shape + self.weight = torch.empty(*t_weight.shape, dtype=self.data_type_, device=self.device_id_) + self.weight.copy_(t_weight.to(self.data_type_)) + if self.bias_name is not None and self.bias_name in weights: + t_bias = weights[self.bias_name] + if self.bias is None: + # If bias was not pre-created, create it now based on loaded shape + self.bias = torch.empty(*t_bias.shape, dtype=self.data_type_, device=self.device_id_) + self.bias.copy_(t_bias.to(self.data_type_)) + + def verify_load(self): + load_ok = True + # Verify weight. The weight must be not None. + load_ok = load_ok and self.weight is not None + # Verify bias. If bias_name is set, it must be not None. 
+ if self.bias_name is not None: + load_ok = load_ok and self.bias is not None + return load_ok + + +class TpParameterWeight(ParameterWeight): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_shape: Optional[Tuple[int, ...]] = None, + bias_shape: Optional[Tuple[int, ...]] = None, + ): + self.split_n_embed = split_n_embed + # Calculate TP-split shapes if full shapes are provided + tp_weight_shape = None + tp_bias_shape = None + if weight_shape is not None: + tp_weight_shape = (split_n_embed,) + weight_shape[1:] + if bias_shape is not None: + tp_bias_shape = (split_n_embed,) + bias_shape[1:] + super().__init__(weight_name, data_type, bias_name, tp_weight_shape, tp_bias_shape) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + start = self.split_n_embed * self.tp_rank_ + end = self.split_n_embed * (self.tp_rank_ + 1) + + if self.weight_name in weights: + t_weight = weights[self.weight_name][start:end] + if self.weight is None: + # If weight was not pre-created, create it now based on loaded shape + self.weight = torch.empty(*t_weight.shape, dtype=self.data_type_, device=self.device_id_) + self.weight.copy_(t_weight.to(self.data_type_)) + if self.bias_name is not None and self.bias_name in weights: + t_bias = weights[self.bias_name][start:end] + if self.bias is None: + # If bias was not pre-created, create it now based on loaded shape + self.bias = torch.empty(*t_bias.shape, dtype=self.data_type_, device=self.device_id_) + self.bias.copy_(t_bias.to(self.data_type_)) diff --git a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py index 19eb67017..8c81fd5bc 100644 --- a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py @@ -3,11 +3,10 @@ class PreAndPostLayerWeight(BaseLayerWeight): - def __init__(self, data_type, network_config, mode): + def __init__(self, data_type, network_config): super().__init__() self.data_type_ = data_type self.network_config_ = network_config - self.mode = mode self.init_static_params() return diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 97bc76237..4bc58c76f 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -9,12 +9,11 @@ class TransformerLayerWeight(BaseLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): + def __init__(self, layer_num, data_type, network_config, quant_cfg): super().__init__() self.layer_num_ = layer_num self.data_type_ = data_type self.network_config_ = network_config - self.mode = mode self.quant_cfg = quant_cfg self._parse_config() self._init_weight_names() diff --git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index 3d77a3ae4..a8b261641 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -168,7 +168,6 @@ def warmup(self, model): model_input = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=total_token_num, max_q_seq_len=total_token_num, max_kv_seq_len=total_token_num, max_cache_len=0, @@ -229,7 +228,6 @@ def warmup_overlap(self, model): micro_batch = 
ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=total_token_num, max_q_seq_len=total_token_num, max_kv_seq_len=total_token_num, max_cache_len=0, diff --git a/lightllm/models/chatglm2/layer_weights/__init__.py b/lightllm/common/basemodel/triton_kernel/alibi_att/__init__.py similarity index 100% rename from lightllm/models/chatglm2/layer_weights/__init__.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/__init__.py diff --git a/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_reduceV.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_reduceV.py diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_softmax.py similarity index 77% rename from lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_softmax.py index 25af80fab..adf97735f 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py +++ b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_softmax.py @@ -6,11 +6,15 @@ @triton.jit def _fwd_kernel_token_softmax( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - stride_logic_h, stride_logic_bs, - stride_prob_h, stride_prob_bs, - BLOCK_SIZE: tl.constexpr + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -19,16 +23,22 @@ def _fwd_kernel_token_softmax( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32) + row = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator - tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) - * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len) + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) 
return @@ -44,10 +54,14 @@ def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): num_warps = 16 _fwd_kernel_token_softmax[(batch, head_num)]( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - Logics.stride(0), Logics.stride(1), - Prob_Out.stride(0), Prob_Out.stride(1), + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, ) @@ -59,7 +73,7 @@ def test1(): import torch B, N_CTX, H, D = 4, 1025, 12, 128 - + del D dtype = torch.float16 Logics = torch.empty((H, B * N_CTX), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) @@ -85,6 +99,7 @@ def test2(): import torch B, N_CTX, H, D = 3, 1025, 12, 128 + del D dtype = torch.float16 @@ -107,7 +122,7 @@ def test2(): start = 0 for i in range(B): end = start + b_seq_len[i] - torch_o = Logics[:, start: end].reshape(H * 1, -1).softmax(-1).reshape(H, 1 * b_seq_len[i]) + torch_o = Logics[:, start:end].reshape(H * 1, -1).softmax(-1).reshape(H, 1 * b_seq_len[i]) start = end torch_out.append(torch_o) torch_out = torch.cat(torch_out, dim=-1) diff --git a/lightllm/models/bloom/triton_kernel/token_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/token_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py diff --git a/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py b/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py new file mode 100644 index 000000000..b6444449b --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py @@ -0,0 +1,80 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def alloc_buffer_for_req_kernel( + req_index_ptr, # [num_reqs] - indices of requests to allocate buffers for + buffer_indexes_ptr, # [num_reqs * num_buffers_per_req] - buffer indices to assign (from CPU) + req_to_buffer_index_ptr, # [max_request_num + 1, num_buffers_per_req] - tensor mapping req_idx to buffer_idx + num_reqs, # number of requests to process + stride_buffer, # stride for req_to_buffer_index second dimension + NUM_BUFFERS_PER_REQ: tl.constexpr, # number of buffers per request (mtp_step + 1) + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Mask for valid indices + mask = offsets < num_reqs + + # Load request indices + req_indices = tl.load(req_index_ptr + offsets, mask=mask, other=0) + + # For each request, allocate NUM_BUFFERS_PER_REQ buffers + for buf_idx in tl.static_range(NUM_BUFFERS_PER_REQ): + # Load buffer index for this position + buffer_offset = offsets * NUM_BUFFERS_PER_REQ + buf_idx + buffer_indices = tl.load(buffer_indexes_ptr + buffer_offset, mask=mask, other=0) + + # Update req_to_buffer_index[req_indices, buf_idx] = buffer_indices + output_offset = req_indices * stride_buffer + buf_idx + tl.store(req_to_buffer_index_ptr + output_offset, buffer_indices, mask=mask) + + +def alloc_buffer_for_req_triton( + req_index: torch.Tensor, # [num_reqs] int32/int64 tensor on CUDA + buffer_indexes: torch.Tensor, # [num_reqs * (mtp_step + 1)] int32 tensor (can be CPU or CUDA) + req_to_buffer_index: torch.Tensor, # [max_request_num + 1, mtp_step + 1] int32 tensor on CUDA + mtp_step: int = 0, # number of additional buffers per request (default 0 for non-MTP mode) +): + num_reqs = 
req_index.shape[0] + num_buffers_per_req = mtp_step + 1 + + # Ensure inputs are on CUDA + if not req_index.is_cuda: + req_index = req_index.cuda() + if not buffer_indexes.is_cuda: + buffer_indexes = buffer_indexes.cuda() + + # Ensure correct dtypes + if req_index.dtype not in [torch.int32, torch.int64]: + req_index = req_index.to(torch.int32) + if buffer_indexes.dtype != torch.int32: + buffer_indexes = buffer_indexes.to(torch.int32) + + # Validate buffer_indexes size + expected_size = num_reqs * num_buffers_per_req + assert buffer_indexes.shape[0] == expected_size, ( + f"Expected {expected_size} buffer indices for {num_reqs} requests " + f"with mtp_step={mtp_step}, but got {buffer_indexes.shape[0]}" + ) + + # Get stride for the second dimension of req_to_buffer_index + stride_buffer = req_to_buffer_index.stride(0) + + # Launch kernel + BLOCK_SIZE = 256 + grid = (triton.cdiv(num_reqs, BLOCK_SIZE),) + + alloc_buffer_for_req_kernel[grid]( + req_index, + buffer_indexes, + req_to_buffer_index, + num_reqs, + stride_buffer, + NUM_BUFFERS_PER_REQ=num_buffers_per_req, + BLOCK_SIZE=BLOCK_SIZE, + ) diff --git a/lightllm/models/chatglm2/triton_kernel/__init__.py b/lightllm/common/basemodel/triton_kernel/att/__init__.py similarity index 100% rename from lightllm/models/chatglm2/triton_kernel/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/__init__.py diff --git a/lightllm/models/cohere/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/__init__.py similarity index 100% rename from lightllm/models/cohere/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/__init__.py diff --git a/lightllm/models/cohere/layer_infer/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/__init__.py similarity index 100% rename from lightllm/models/cohere/layer_infer/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/__init__.py diff --git a/lightllm/models/cohere/layer_weights/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/__init__.py similarity index 100% rename from lightllm/models/cohere/layer_weights/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/__init__.py diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py similarity index 65% rename from lightllm/models/llama/triton_kernel/gqa_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py index 67be7c968..26ec3ebd7 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py @@ -2,11 +2,12 @@ def gqa_token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty + q: torch.Tensor, infer_state, cache_k: torch.Tensor, cache_v: torch.Tensor, out=None, alloc_tensor_func=torch.empty ): BLOCK_SEQ = 128 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len + q_head_num, head_dim = q.shape[1], q.shape[2] calcu_shape1 = (batch_size, q_head_num, head_dim) from .gqa_flash_decoding_stage1 import flash_decode_stage1 @@ -15,10 +16,10 @@ def gqa_token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is 
None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" ) flash_decode_stage1( @@ -28,7 +29,7 @@ def gqa_token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, mid_o, mid_o_logexpsum, BLOCK_SEQ, diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py similarity index 96% rename from lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index 320c2cf79..2814ff44b 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -123,8 +123,18 @@ def _fwd_kernel_flash_decode_stage1( @torch.no_grad() def flash_decode_stage1( - q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq + q, + k: torch.Tensor, + v: torch.Tensor, + Req_to_tokens, + B_req_idx, + B_Seqlen, + max_len_in_batch, + mid_out, + mid_out_logsumexp, + block_seq, ): + assert k.stride() == v.stride() BLOCK_SEQ = block_seq BLOCK_N = 16 assert BLOCK_SEQ % BLOCK_N == 0 diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py similarity index 65% rename from lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py index 81227f967..101e99dde 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py @@ -6,14 +6,22 @@ @triton.jit def _fwd_kernel_flash_decode_stage2( B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - O, #[batch, head, head_dim] - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - stride_obs, stride_oh, stride_od, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + O, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr): + BLOCK_DMODEL: tl.constexpr, +): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -32,33 +40,43 @@ def _fwd_kernel_flash_decode_stage2( tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) new_max_logic = tl.maximum(tlogic, max_logic) - + old_scale = tl.exp(max_logic - new_max_logic) acc *= old_scale exp_logic = tl.exp(tlogic - 
new_max_logic) acc += exp_logic * tv sum_exp = sum_exp * old_scale + exp_logic max_logic = new_max_logic - + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) return @torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): Lk = mid_out.shape[-1] assert Lk in {16, 32, 64, 128} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) - + _fwd_kernel_flash_decode_stage2[grid]( - B_Seqlen, mid_out, mid_out_logexpsum, O, - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - O.stride(0), O.stride(1), O.stride(2), + B_Seqlen, + mid_out, + mid_out_logexpsum, + out, + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logexpsum.stride(0), + mid_out_logexpsum.stride(1), + mid_out_logexpsum.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), BLOCK_SEQ=block_seq, BLOCK_DMODEL=Lk, num_warps=4, num_stages=2, ) - return \ No newline at end of file + return diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py similarity index 99% rename from lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py index 850d4185c..6a9bb79c7 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py @@ -421,7 +421,7 @@ def gqa_token_decode_attention_flash_decoding_vsm( if not run_config: if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = infer_state.max_len_in_batch + avg_seq_len_in_batch = infer_state.max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size diff --git a/lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py similarity index 100% rename from lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py diff --git a/lightllm/models/cohere/triton_kernels/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/__init__.py similarity index 100% rename from lightllm/models/cohere/triton_kernels/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/__init__.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py new file mode 100644 index 000000000..212825a96 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py @@ -0,0 +1,200 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def int4_to_float(k_int8, k_scale, offs_d): + k_int8 = k_int8.to(tl.uint8, bitcast=True) + k_high = (k_int8 & 0xF0) >> 4 + k_low = k_int8 & 0x0F + k_high = k_high.to(tl.int8, bitcast=True) + k_low = k_low.to(tl.int8, bitcast=True) + k_high -= 7 + k_low -= 7 + k_int4 = tl.where( + offs_d[None, :] % 2 == 
0, + k_low, + k_high, + ) + k = k_int4.to(k_scale.dtype) * k_scale + return k + + +@triton.jit +def _fwd_kernel_flash_decode_stage1( + Q, + K, + K_scale, + V, + V_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size, + quant_group_size, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + cur_kv_head = cur_head // gqa_group_size + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + block_n_size = ( + tl.where( + cur_batch_end_index - cur_batch_start_index <= 0, + 0, + cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, + ) + // BLOCK_N + ) + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + + q = tl.load(Q + off_q) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + k_loc = k_loc.to(tl.int64) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] // 2 + off_k_scale = off_k // (quant_group_size // 2) + k_int8 = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) + k_scale = tl.load(K_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + k = int4_to_float(k_int8, k_scale, offs_d) + + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value = tl.where((offs_n_new < cur_batch_end_index), att_value, float("-inf")) + v_int8 = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) + v_scale = tl.load(V_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + v = int4_to_float(v_int8, v_scale, offs_d) + + cur_max_logic = tl.max(att_value, axis=0) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale + acc += tl.sum(exp_logic[:, None] * v, axis=0) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) + max_logic = new_max_logic + + need_store = tl.where(block_n_size == 0, 0, 1) + for _ in range(0, need_store, 1): + off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block + tl.store(Mid_O + off_mid_o, acc / sum_exp) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) + return + + +@torch.no_grad() +def int4kv_flash_decode_stage1( + q, + k, + k_scale, + v, + v_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + 
max_len_in_batch, + mid_out, + mid_out_logsumexp, + block_seq, +): + BLOCK_SEQ = block_seq + BLOCK_N = 16 + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] * 2 + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + batch, head_num = B_req_idx.shape[0], q.shape[1] + grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + quant_group_size = Lk // k_scale.shape[-1] + assert triton.next_power_of_2(quant_group_size) == quant_group_size + assert k.stride() == v.stride() + # TODO 优化为gqa使用tensor core的实现,速度更快。 + _fwd_kernel_flash_decode_stage1[grid]( + q, + k, + k_scale, + v, + v_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), + gqa_group_size, + quant_group_size, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py new file mode 100644 index 000000000..a5a054b93 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py @@ -0,0 +1,50 @@ +import torch + + +def token_decode_attention_flash_decoding( + q, + infer_state, + cache_k, + cache_k_scale, + cache_v, + cache_v_scale, + out=None, + alloc_tensor_func=torch.empty, +): + BLOCK_SEQ = 256 + batch_size = infer_state.batch_size + max_kv_seq_len = infer_state.max_kv_seq_len + q_head_num = q.shape[1] + head_dim = q.shape[2] + calcu_shape1 = (batch_size, q_head_num, head_dim) + + from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 + + o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out + + mid_o = alloc_tensor_func( + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" + ) + mid_o_logexpsum = alloc_tensor_func( + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" + ) + + from .int4kv_flash_decoding_stage1 import int4kv_flash_decode_stage1 + + int4kv_flash_decode_stage1( + q=q.view(calcu_shape1), + k=cache_k, + k_scale=cache_k_scale, + v=cache_v, + v_scale=cache_v_scale, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + B_Seqlen=infer_state.b_seq_len, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, + ) + + flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) + return o_tensor diff --git a/lightllm/models/mistral/triton_kernel/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/__init__.py similarity index 100% rename from lightllm/models/mistral/triton_kernel/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/__init__.py diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py 
b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py similarity index 67% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py index 84054bf86..ad6a8b5b3 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py @@ -2,16 +2,15 @@ import torch from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops from lightllm.common.basemodel.infer_struct import InferStateInfo -from .ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 -from .ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 +from .int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 +from .int8kv_flash_decoding_diverse_stage2 import flash_decode_stage2 +from .int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size def token_decode_attention_flash_decoding( q, infer_state: InferStateInfo, - q_head_num, - head_dim, cache_k, cache_k_scale, cache_v, @@ -28,18 +27,21 @@ def token_decode_attention_flash_decoding( stream1 = shared_streams_dict["stream1"] stream2 = shared_streams_dict["stream2"] + q_head_num = q.shape[1] + head_dim = q.shape[2] + BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 2, head_dim], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=torch.float32, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 2], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=torch.float32, device="cuda" ) current_stream = torch.cuda.current_stream() @@ -56,7 +58,7 @@ def token_decode_attention_flash_decoding( B_req_idx=infer_state.b_req_idx, b_shared_seq_len=infer_state.b_shared_seq_len, b_mark_shared_group=infer_state.b_mark_shared_group, - max_len_in_batch=infer_state.max_len_in_batch, + max_len_in_batch=infer_state.max_kv_seq_len, mid_out=mid_o, mid_out_logsumexp=mid_o_logexpsum, block_seq=BLOCK_SEQ, @@ -64,21 +66,20 @@ def token_decode_attention_flash_decoding( ) stream2.wait_stream(current_stream) with torch.cuda.stream(stream2): - light_ops.group8_int8kv_flashdecoding_diverse_stage2( - BLOCK_SEQ, - mid_o, - mid_o_logexpsum, - 1.0 / (head_dim ** 0.5), - q.view(calcu_shape1), - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.b_shared_seq_len, - infer_state.max_len_in_batch, + flash_decode_stage2( + q=q.view(calcu_shape1), + k=cache_k, + k_scale=cache_k_scale, + v=cache_v, + v_scale=cache_v_scale, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + B_Seqlen=infer_state.b_seq_len, + b_shared_seq_len=infer_state.b_shared_seq_len, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, ) 
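+ # sync point: the main stream must see the helper streams' mid_o / mid_o_logexpsum results before the stage3 merge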
current_stream.wait_stream(stream1) diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py similarity index 89% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py index 8b3423ce9..295ae66ab 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py @@ -9,7 +9,7 @@ class GQADiverseDecodeStage1KernelConfig(KernelConfigs): - kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage1:v1" + kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage1:v2" @classmethod @lru_cache(maxsize=200) @@ -113,6 +113,7 @@ def _fwd_kernel_flash_decode_diverse_stage1( BLOCK_N: tl.constexpr, BLOCK_BATCH: tl.constexpr, KV_QUANT_GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, ): cur_batch = tl.program_id(0) shared_batch_group_size = tl.load(b_mark_shared_group + cur_batch) @@ -128,6 +129,7 @@ def _fwd_kernel_flash_decode_diverse_stage1( cur_q_head_range = tl.where(cur_q_head_range < q_head_end_index, cur_q_head_range, cur_kv_head * gqa_group_size) offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_d_scale = tl.arange(0, NUM_GROUPS) cur_batch_seq_len = tl.load(b_shared_seq_len + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) cur_batch_start_index = seq_start_block * BLOCK_SEQ @@ -162,25 +164,37 @@ def _fwd_kernel_flash_decode_diverse_stage1( mask=n_mask, other=0, ).to(tl.int64) - off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] - off_k_scale = off_k // KV_QUANT_GROUP_SIZE + off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh + # (128, 16) + off_k = off_k_base[None, :] + offs_d[:, None] + # off_k_scale = off_k // KV_QUANT_GROUP_SIZE + # (16, 16) + off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] k = tl.load(K + off_k, mask=n_mask[None, :], other=0) + k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) + k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) k = k * k_scale + k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) att_value = tl.dot(q, k.to(q.dtype)) att_value *= sm_scale att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + off_v = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] v = tl.load( - V + off_k.T, + V + off_v, mask=n_mask[:, None], other=0, ) + v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) v_scale = tl.load( - V_scale + off_k_scale.T, - mask=n_mask[:, None], + V_scale + off_k_scale, + mask=n_mask[None, :], other=0.0, ) + v_scale = tl.trans(v_scale) + v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) v = v * v_scale + v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) cur_max_logic = tl.max(att_value, axis=1) new_max_logic = tl.maximum(cur_max_logic, max_logic) @@ -269,12 +283,16 @@ def flash_decode_stage1( gqa_group_size = q.shape[1] // k.shape[1] assert triton.next_power_of_2(Lk) == Lk KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] - assert KV_QUANT_GROUP_SIZE == 8 + assert triton.next_power_of_2(KV_QUANT_GROUP_SIZE) == KV_QUANT_GROUP_SIZE BLOCK_HEAD = triton.next_power_of_2(gqa_group_size) BLOCK_BATCH = 
triton.next_power_of_2(max_batch_group_size) if BLOCK_HEAD * BLOCK_BATCH < 16: BLOCK_BATCH = 16 // BLOCK_HEAD + assert k.stride() == v.stride() + NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE + assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS + assert k.stride() == v.stride() _fwd_kernel_flash_decode_diverse_stage1[grid]( Q=q, stride_qbs=q.stride(0), @@ -313,6 +331,7 @@ def flash_decode_stage1( BLOCK_N=BLOCK_N, BLOCK_BATCH=BLOCK_BATCH, KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, + NUM_GROUPS=NUM_GROUPS, num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py new file mode 100644 index 000000000..f5c0b9c39 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py @@ -0,0 +1,306 @@ +import torch +import triton +import triton.language as tl +from typing import Optional + +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Dict +from lightllm.common.triton_utils.autotuner import autotune, Autotuner + + +class GQADiverseDecodeStage2KernelConfig(KernelConfigs): + kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage2:v1" + + @classmethod + @lru_cache(maxsize=200) + def try_to_get_best_config( + cls, + batch_size: int, + avg_seq_len_in_batch: int, + gqa_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + ) -> dict: + key_params = { + "gqa_group_size": gqa_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + finded_config = cls.get_the_config(key_params) + + if finded_config: + batch_size_config: dict = finded_config[ + min( + finded_config.keys(), + key=lambda x: abs(int(x) - avg_seq_len_in_batch), + ) + ] + config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))] + + return config + else: + config = { + "BLOCK_N": 16, + "num_warps": 2, + "num_stages": 2, + } + return config + + @classmethod + def save_config( + cls, + gqa_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + config_json: Dict[int, Dict[int, Dict]], + ): + key_params = { + "gqa_group_size": gqa_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + return cls.store_config(key_params, config_json) + + +@triton.jit +def _fwd_kernel_flash_decode_diverse_stage2( + Q, + stride_qbs, + stride_qh, + stride_qd, + K, + K_scale, + stride_kbs, + stride_kh, + stride_kd, + V, + V_scale, + stride_vbs, + stride_vh, + stride_vd, + sm_scale, + Req_to_tokens, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + B_req_idx, + B_Seqlen, + b_shared_seq_len, + Mid_O, # [batch, head, seq_block_num, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_N: tl.constexpr, + KV_QUANT_GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_kv_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + + cur_q_head_range = cur_kv_head * gqa_group_size + tl.arange(0, 
gqa_group_size) + + offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_d_scale = tl.arange(0, NUM_GROUPS) + cur_batch_shared_len = tl.load(b_shared_seq_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_shared_len + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + store_seq_block = seq_start_block + tl.cdiv(cur_batch_shared_len, BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] + + block_n_size = tl.cdiv( + tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index), + BLOCK_N, + ) + + if block_n_size == 0: + return + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + q = tl.load(Q + off_q) + + sum_exp = tl.zeros([gqa_group_size], dtype=tl.float32) + max_logic = tl.zeros([gqa_group_size], dtype=tl.float32) - float("inf") + acc = tl.zeros([gqa_group_size, BLOCK_HEADDIM], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + n_mask = offs_n_new < cur_batch_end_index + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=n_mask, + other=0, + ).to(tl.int64) + off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh + # (128, 16) + off_k = off_k_base[None, :] + offs_d[:, None] + # off_k_scale = off_k // KV_QUANT_GROUP_SIZE + # (16, 16) + off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] + k = tl.load(K + off_k, mask=n_mask[None, :], other=0) + k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) + k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) + k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) + k = k * k_scale + k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) + # q (4, 128) k (128, BLOCK_N) + att_value = tl.dot(q, k.to(q.dtype)) + att_value *= sm_scale + att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + off_v = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] + v = tl.load( + V + off_v, + mask=n_mask[:, None], + other=0, + ) + v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) + v_scale = tl.load( + V_scale + off_k_scale, + mask=n_mask[None, :], + other=0.0, + ) + v_scale = tl.trans(v_scale) + v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) + v = v * v_scale + v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) + + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(q.dtype), v.to(q.dtype)) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic + + off_mid_o = ( + cur_batch * stride_mid_ob + + cur_q_head_range[:, None] * stride_mid_oh + + store_seq_block * stride_mid_os + + offs_d[None, :] + ) + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + store_seq_block + tl.store( + Mid_O + off_mid_o, + (acc / sum_exp[:, None]), + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + (max_logic + tl.log(sum_exp)), + ) + return + + +@torch.no_grad() +def flash_decode_stage2( + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + Req_to_tokens: torch.Tensor, + B_req_idx: 
torch.Tensor, + B_Seqlen: torch.Tensor, + b_shared_seq_len: torch.Tensor, + max_len_in_batch: int, + mid_out: torch.Tensor, + mid_out_logsumexp: torch.Tensor, + block_seq: int, + run_config: Optional[dict] = None, +): + if not run_config: + run_config = GQADiverseDecodeStage2KernelConfig.try_to_get_best_config( + batch_size=int(q.shape[0]), + avg_seq_len_in_batch=max_len_in_batch, + gqa_group_size=int(q.shape[1] // k.shape[1]), + q_head_dim=int(q.shape[2]), + block_seq=block_seq, + out_dtype=q.dtype, + ) + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3 + BLOCK_SEQ = block_seq + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + batch, kv_head_num = B_req_idx.shape[0], k.shape[1] + grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + assert triton.next_power_of_2(Lk) == Lk + KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] + assert KV_QUANT_GROUP_SIZE == 8 + NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE + assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS + + assert k.stride() == v.stride() + + _fwd_kernel_flash_decode_diverse_stage2[grid]( + Q=q, + stride_qbs=q.stride(0), + stride_qh=q.stride(1), + stride_qd=q.stride(2), + K=k, + K_scale=k_scale, + stride_kbs=k.stride(0), + stride_kh=k.stride(1), + stride_kd=k.stride(2), + V=v, + V_scale=v_scale, + stride_vbs=v.stride(0), + stride_vh=v.stride(1), + stride_vd=v.stride(2), + sm_scale=sm_scale, + Req_to_tokens=Req_to_tokens, + stride_req_to_tokens_b=Req_to_tokens.stride(0), + stride_req_to_tokens_s=Req_to_tokens.stride(1), + B_req_idx=B_req_idx, + B_Seqlen=B_Seqlen, + b_shared_seq_len=b_shared_seq_len, + Mid_O=mid_out, + stride_mid_ob=mid_out.stride(0), + stride_mid_oh=mid_out.stride(1), + stride_mid_os=mid_out.stride(2), + stride_mid_od=mid_out.stride(3), + Mid_O_LogExpSum=mid_out_logsumexp, # [batch, head, seq_block_num] + stride_mid_o_eb=mid_out_logsumexp.stride(0), + stride_mid_o_eh=mid_out_logsumexp.stride(1), + stride_mid_o_es=mid_out_logsumexp.stride(2), + gqa_group_size=gqa_group_size, + BLOCK_SEQ=block_seq, + BLOCK_HEADDIM=Lk, + BLOCK_N=BLOCK_N, + KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, + NUM_GROUPS=NUM_GROUPS, + num_warps=num_warps, + num_stages=num_stages, + ) + return diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage3.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py similarity index 100% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage3.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py similarity index 71% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py index 88e39b82f..f51d61166 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py @@ -5,8 +5,6 @@ def token_decode_attention_flash_decoding( q, infer_state, - 
q_head_num, - head_dim, cache_k, cache_k_scale, cache_v, @@ -15,19 +13,20 @@ def token_decode_attention_flash_decoding( alloc_tensor_func=torch.empty, ): BLOCK_SEQ = 256 + q_head_num, head_dim = q.shape[1], q.shape[2] batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) - from .flash_decoding_stage2 import flash_decode_stage2 + from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" ) light_ops.group8_int8kv_flashdecoding_stage1( @@ -43,7 +42,7 @@ def token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py similarity index 63% rename from lightllm/models/phi3/triton_kernel/flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py index e47e30886..6c50fc392 100644 --- a/lightllm/models/phi3/triton_kernel/flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py @@ -1,12 +1,11 @@ import torch -def token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): +def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty): BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len + q_head_num, head_dim = q.shape[1], q.shape[2] calcu_shape1 = (batch_size, q_head_num, head_dim) from .flash_decoding_stage1 import flash_decode_stage1 @@ -15,10 +14,10 @@ def token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + [batch_size, 
q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" ) flash_decode_stage1( @@ -28,7 +27,7 @@ def token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, mid_o, mid_o_logexpsum, BLOCK_SEQ, diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py similarity index 88% rename from lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py index f6d8b5abe..f41a5c8fd 100644 --- a/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py @@ -33,7 +33,6 @@ def _fwd_kernel_flash_decode_stage1( stride_mid_o_eh, stride_mid_o_es, gqa_group_size, - head_dim, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -62,7 +61,7 @@ def _fwd_kernel_flash_decode_stage1( offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - q = tl.load(Q + off_q, mask=offs_d < head_dim, other=0.0) + q = tl.load(Q + off_q) sum_exp = 0.0 max_logic = -float("inf") @@ -75,16 +74,13 @@ def _fwd_kernel_flash_decode_stage1( mask=offs_n_new < cur_batch_end_index, other=0, ) + k_loc = k_loc.to(tl.int64) off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] - k = tl.load( - K + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0 - ) + k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) att_value *= sm_scale att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) - v = tl.load( - V + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0 - ) + v = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) cur_max_logic = tl.max(att_value, axis=0) new_max_logic = tl.maximum(cur_max_logic, max_logic) @@ -101,7 +97,7 @@ def _fwd_kernel_flash_decode_stage1( for _ in range(0, need_store, 1): off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block - tl.store(Mid_O + off_mid_o, acc / sum_exp, mask=offs_d < head_dim) + tl.store(Mid_O + off_mid_o, acc / sum_exp) tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) return @@ -116,13 +112,12 @@ def flash_decode_stage1( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - head_dim = Lq - BLOCK_DMODEL = triton.next_power_of_2(head_dim) + assert Lk in {16, 32, 64, 128} sm_scale = 1.0 / (Lk ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) gqa_group_size = q.shape[1] // k.shape[1] - + assert k.stride() == v.stride() _fwd_kernel_flash_decode_stage1[grid]( q, k, @@ -152,9 +147,8 @@ def flash_decode_stage1( mid_out_logsumexp.stride(1), mid_out_logsumexp.stride(2), gqa_group_size, - head_dim, BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DMODEL=Lk, BLOCK_N=BLOCK_N, num_warps=1, num_stages=2, diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py 
b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py similarity index 81% rename from lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py index a06ee5454..101e99dde 100644 --- a/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py @@ -8,7 +8,7 @@ def _fwd_kernel_flash_decode_stage2( B_Seqlen, Mid_O, # [batch, head, seq_block_num, head_dim] Mid_O_LogExpSum, # [batch, head, seq_block_num] - Out, # [batch, head, head_dim] + O, # [batch, head, head_dim] stride_mid_ob, stride_mid_oh, stride_mid_os, @@ -19,7 +19,6 @@ def _fwd_kernel_flash_decode_stage2( stride_obs, stride_oh, stride_od, - head_dim, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ): @@ -38,7 +37,7 @@ def _fwd_kernel_flash_decode_stage2( offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh for block_seq_n in range(0, block_n_size, 1): - tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os, mask=offs_d < head_dim, other=0.0) + tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) new_max_logic = tl.maximum(tlogic, max_logic) @@ -49,23 +48,22 @@ def _fwd_kernel_flash_decode_stage2( sum_exp = sum_exp * old_scale + exp_logic max_logic = new_max_logic - tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp, mask=offs_d < head_dim) + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) return @torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq): +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): Lk = mid_out.shape[-1] - head_dim = Lk + assert Lk in {16, 32, 64, 128} batch, head_num = mid_out.shape[0], mid_out.shape[1] - BLOCK_DMODEL = triton.next_power_of_2(head_dim) grid = (batch, head_num) _fwd_kernel_flash_decode_stage2[grid]( B_Seqlen, mid_out, mid_out_logexpsum, - Out, + out, mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), @@ -73,12 +71,11 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq): mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - head_dim, + out.stride(0), + out.stride(1), + out.stride(2), BLOCK_SEQ=block_seq, - BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DMODEL=Lk, num_warps=4, num_stages=2, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py similarity index 63% rename from lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py index 9a8261132..9de2b8205 100644 --- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py +++ 
b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py @@ -14,8 +14,6 @@ def _fwd_kernel_token_att1( B_req_idx, B_Start_Loc, B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, Att_Out, stride_req_to_tokens_b, stride_req_to_tokens_s, @@ -28,7 +26,6 @@ def _fwd_kernel_token_att1( att_stride_h, att_stride_bs, kv_group_num, - sliding_window, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -38,38 +35,32 @@ def _fwd_kernel_token_att1( cur_kv_head = cur_head // kv_group_num - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] + offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) - # use new start index of k value - cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) + cur_batch_start_index = 0 cur_batch_end_index = cur_batch_seq_len - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D] + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32] + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - # use new value to decide block mask block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number + block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark - offs_n_new = cur_batch_start_index + offs_n # the latest window of token + q = tl.load(Q + off_q + start_mark) + offs_n_new = cur_batch_start_index + offs_n k_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0, - ) - off_k = ( - k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd - ) # [32, D], find token index + ).to(tl.int64) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32] - att_value = att_value.to(tl.float32) + att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32) att_value *= sm_scale off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) @@ -77,19 +68,17 @@ def _fwd_kernel_token_att1( @torch.no_grad() -def token_att_fwd( - q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window -): +def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch): BLOCK = 32 # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} sm_scale = 1.0 / (Lk ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, triton.cdiv(sliding_window, BLOCK)) + grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK)) kv_group_num = q.shape[1] // k.shape[1] if kv_group_num == 1: @@ -105,8 +94,6 @@ def token_att_fwd( B_req_idx, B_Start_Loc, B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, 
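+ # att_out stores the raw q*k logits, indexed by [q_head, B_Start_Loc[batch] + kv_pos]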
att_out, Req_to_tokens.stride(0), Req_to_tokens.stride(1), @@ -119,7 +106,6 @@ def token_att_fwd( att_out.stride(0), att_out.stride(1), kv_group_num=kv_group_num, - sliding_window=sliding_window, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py similarity index 59% rename from lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py index acf4923f8..96a5b26dd 100644 --- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py @@ -13,8 +13,6 @@ def _fwd_kernel_token_att2( B_req_idx, B_Start_Loc, B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, stride_req_to_tokens_b, stride_req_to_tokens_s, stride_ph, @@ -26,7 +24,6 @@ def _fwd_kernel_token_att2( stride_oh, stride_od, kv_group_num, - sliding_window, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -35,36 +32,30 @@ def _fwd_kernel_token_att2( cur_kv_head = cur_head // kv_group_num - offs_n = tl.arange(0, BLOCK_N) # [64] - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index - # cur_batch_end_index = cur_batch_seq_len - cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # new index + cur_batch_start_index = 0 + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # att length - v_loc_off = ( - cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - ) # the latest window of value [64] - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs # [64] - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D] + v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s + p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs + v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D] - for start_n in range(0, cur_att_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) # check - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_att_seq_len, other=0.0) # [64] + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) v_loc = tl.load( Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n + cur_batch_start_index) < cur_batch_seq_len, + mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0, - ) # [64] + ).to(tl.int64) v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None] + cur_batch_start_index) < cur_batch_seq_len, - other=0.0, - ) # [1, D] + [64, 1] = [64, D] - acc += tl.sum(p_value[:, None] * v_value, 0) # [64, 1] * [64, D] = 
[64, D] -> [D] + V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 + ) + acc += tl.sum(p_value[:, None] * v_value, 0) acc = acc.to(Out.dtype.element_ty) off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od @@ -74,9 +65,7 @@ def _fwd_kernel_token_att2( @torch.no_grad() -def token_att_fwd2( - prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window -): +def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen): BLOCK = 128 # BLOCK = 64 # for triton 2.0.0dev batch, head = B_req_idx.shape[0], prob.shape[0] @@ -94,8 +83,6 @@ def token_att_fwd2( B_req_idx, B_Start_Loc, B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, Req_to_tokens.stride(0), Req_to_tokens.stride(1), prob.stride(0), @@ -107,7 +94,6 @@ def token_att_fwd2( out.stride(1), out.stride(2), kv_group_num=kv_group_num, - siliding_window=sliding_window, BLOCK_DMODEL=dim, BLOCK_N=BLOCK, num_warps=num_warps, diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py similarity index 71% rename from lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py index 5e6040ac5..0bb6410e1 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py @@ -5,11 +5,15 @@ @triton.jit def _fwd_kernel_token_softmax( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - stride_logic_h, stride_logic_bs, - stride_prob_h, stride_prob_bs, - BLOCK_SIZE: tl.constexpr + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -18,18 +22,25 @@ def _fwd_kernel_token_softmax( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32) + row = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator - tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) - * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len) + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) return + @torch.no_grad() def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): BLOCK_SIZE = triton.next_power_of_2(max_input_len) @@ -42,20 +53,26 @@ def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): num_warps = 16 _fwd_kernel_token_softmax[(batch, head_num)]( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - Logics.stride(0), Logics.stride(1), - 
Prob_Out.stride(0), Prob_Out.stride(1), + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, ) return + def test1(): import torch B, N_CTX, H, D = 4, 1025, 12, 128 + del D dtype = torch.float16 diff --git a/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_softmax_and_reducev.py similarity index 100% rename from lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_softmax_and_reducev.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py similarity index 51% rename from lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py index 8fda08460..b0a9b6245 100644 --- a/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py @@ -1,27 +1,27 @@ import torch +from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -def token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): +def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty): BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + q_head_num = q.shape[1] + head_dim = q.shape[2] + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) - from lightllm_ppl_fp16_flashdecoding_kernel import fp16_flashdecoding_stage1 - from .flash_decoding_stage2 import flash_decode_stage2 + from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" ) - fp16_flashdecoding_stage1( + light_ops.fp16_flashdecoding_stage1( BLOCK_SEQ, mid_o, mid_o_logexpsum, @@ -32,7 +32,7 @@ def token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/__init__.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py similarity index 81% rename from lightllm/models/llama/triton_kernel/context_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index e36c51b39..5ba6d0beb 100644 --- a/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -336,25 +336,24 @@ def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, ma @triton.jit -def _fwd_kernel_int8kv( +def _fwd_kernel_contiguous_kv( Q, K, V, sm_scale, Out, B_Start_Loc, + B_kv_start_loc, B_Seqlen, b_prompt_cache_len, stride_qbs, stride_qh, stride_qd, - stride_kb, - stride_kh, stride_ks, + stride_kh, stride_kd, - stride_vb, - stride_vh, stride_vs, + stride_vh, stride_vd, stride_obs, stride_oh, @@ -374,6 +373,7 @@ def _fwd_kernel_int8kv( prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len + kv_start_loc = tl.load(B_kv_start_loc + cur_batch) block_start_loc = BLOCK_M * start_m @@ -393,6 +393,9 @@ def _fwd_kernel_int8kv( l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + stride_ks = tl.cast(stride_ks, tl.int64) + stride_vs = tl.cast(stride_vs, tl.int64) + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) # causal mask @@ -405,8 +408,7 @@ def _fwd_kernel_int8kv( # other=0, # ) off_k = ( - cur_batch * stride_kb - + (start_n + offs_n[None, :]) * stride_ks + (kv_start_loc + start_n + offs_n[None, :]) * stride_ks + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd ) @@ -432,8 +434,7 @@ def _fwd_kernel_int8kv( # other=0.0, # ) off_v = ( - cur_batch * stride_vb - + (start_n + offs_n[:, None]) * stride_vs + (kv_start_loc + start_n + offs_n[:, None]) * stride_vs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd ) @@ -455,7 +456,9 @@ def _fwd_kernel_int8kv( @torch.no_grad() -def context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_input_len, b_prompt_cache_len): +def context_attention_fwd_contiguous_kv( + q, k, v, o, b_start_loc, b_kv_start_loc, b_seq_len, max_q_input_len, b_prompt_cache_len +): BLOCK_M = 128 if not is_tesla() else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -468,34 +471,33 @@ def context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_inp batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] - grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + grid = lambda meta: (triton.cdiv(max_q_input_len, meta["BLOCK_M"]), batch * head, 1) BLOCK_N = BLOCK_M num_warps = 4 if Lk <= 64 else 8 num_stages = 1 - _fwd_kernel_int8kv[grid]( - q, - k, - v, - sm_scale, - o, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - o.stride(0), - o.stride(1), - o.stride(2), + _fwd_kernel_contiguous_kv[grid]( + Q=q, + K=k, + V=v, + sm_scale=sm_scale, + Out=o, + B_Start_Loc=b_start_loc, + 
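# b_kv_start_loc gives each request's starting offset into the contiguous kv buffer +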
B_kv_start_loc=b_kv_start_loc, + B_Seqlen=b_seq_len, + b_prompt_cache_len=b_prompt_cache_len, + stride_qbs=q.stride(0), + stride_qh=q.stride(1), + stride_qd=q.stride(2), + stride_ks=k.stride(0), + stride_kh=k.stride(1), + stride_kd=k.stride(2), + stride_vs=v.stride(0), + stride_vh=v.stride(1), + stride_vd=v.stride(2), + stride_obs=o.stride(0), + stride_oh=o.stride(1), + stride_od=o.stride(2), kv_group_num=kv_group_num, H=head, BLOCK_DMODEL=Lk, @@ -596,86 +598,5 @@ def test(): assert torch.allclose(torch_o, o, atol=1e-2, rtol=0) -def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_seq_len, b_prompt_cache_len): - - batch = b_start_loc.shape[0] - k = k.transpose(1, 2) - v = v.transpose(1, 2) - for i in range(batch): - start_loc = b_start_loc[i] - seq_len = b_seq_len[i] - prompt_cache_len = b_prompt_cache_len[i] - cur_q = q[start_loc : start_loc + seq_len - prompt_cache_len, :, :] - cur_q = cur_q.clone().to(torch.float32) - cur_k = k[i, :seq_len, :] - cur_k = cur_k.clone().to(torch.float32) - - cur_v = v[i, :seq_len, :] - cur_v = cur_v.clone().to(torch.float32) - - cur_q = cur_q.transpose(0, 1) - cur_k = cur_k.transpose(0, 1) - cur_v = cur_v.transpose(0, 1) - dk = cur_q.shape[-1] - - p = torch.matmul(cur_q, cur_k.transpose(-2, -1)) / torch.sqrt(torch.tensor(dk, dtype=torch.float32)) - - q_index = torch.arange(cur_q.shape[1]).unsqueeze(-1).to(p.device) - k_index = torch.arange(cur_k.shape[1]).unsqueeze(0).to(p.device) - mask = (q_index + prompt_cache_len >= k_index).int() - mask = mask.unsqueeze(0).expand(cur_q.shape[0], -1, -1) - - p = p.masked_fill(mask == 0, float("-inf")) - - s = F.softmax(p, dim=-1) - - o[start_loc : start_loc + seq_len - prompt_cache_len, :, :] = torch.matmul(s, cur_v).transpose(0, 1) - - -def test2(): - import torch - import numpy as np - - Z, H, N_CTX, D_HEAD = 16, 16, 2048, 128 - dtype = torch.float16 - prompt_cache_len = 0 - q = torch.empty((Z * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - kv = torch.empty((Z, 2 * H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - k = kv[:, :H] - v = kv[:, H:] - # v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - torch_o = torch.empty((Z * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_( - mean=0.3, std=0.2 - ) - max_input_len = N_CTX - b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_prompt_cache_len = torch.zeros(Z, dtype=torch.int32, device="cuda") - - for i in range(Z): - b_seq_len[i] = N_CTX - if i != 0: - b_start_loc[i] = b_start_loc[i - 1] + N_CTX - prompt_cache_len - b_prompt_cache_len[i] = prompt_cache_len - torch_context_attention_fwd2(q, k, v, torch_o, b_start_loc, b_seq_len, b_prompt_cache_len) - - import time - - torch.cuda.synchronize() - a = time.time() - for i in range(1000): - context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_input_len, b_prompt_cache_len) - torch.cuda.synchronize() - b = time.time() - # print(o.shape, torch_out.shape) - print((b - a)) - - print("max ", torch.max(torch.abs(torch_o - o))) - print("mean ", torch.mean(torch.abs(torch_o - o))) - assert torch.allclose(torch_o, o, atol=1e-2, rtol=0) - - if __name__ == "__main__": test() - test2() diff --git a/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py 
b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py index bd53de386..060a92bf7 100644 --- a/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py @@ -16,12 +16,13 @@ def _fwd_kernel_destindex_copy_kv( stride_o_h, stride_o_d, head_num, + head_dim, BLOCK_DMODEL: tl.constexpr, BLOCK_HEAD: tl.constexpr, ): cur_index = tl.program_id(0) offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d = (tl.arange(0, BLOCK_DMODEL)) % head_dim dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) @@ -54,133 +55,10 @@ def destindex_copy_kv(K, DestLoc, Out): Out.stride(1), Out.stride(2), head_num, - BLOCK_DMODEL=head_dim, + head_dim, + BLOCK_DMODEL=triton.next_power_of_2(head_dim), BLOCK_HEAD=BLOCK_HEAD, num_warps=num_warps, num_stages=1, ) return - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_d, - head_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, -): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - src_data = tl.load( - K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], - mask=offs_h[:, None] < head_num, - other=0.0, - ) - abs_data = tl.abs(src_data) - data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None] - q_src_data = (src_data / data_scale).to(tl.int8) - o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] - tl.store(o_ptrs, q_src_data, mask=offs_h[:, None] < head_num) - tl.store(os_ptrs, data_scale, mask=offs_h[:, None] < head_num) - - -@torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] - BLOCK_HEAD = triton.next_power_of_2(head_num) - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_quantize_kv[grid]( - K, - DestLoc, - Out, - Out_scale, - K.stride(0), - K.stride(1), - K.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out_scale.stride(0), - Out_scale.stride(1), - Out_scale.stride(2), - head_num, - BLOCK_DMODEL=head_dim, - BLOCK_HEAD=BLOCK_HEAD, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test1(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 128 - dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") - - for _ in range(10): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(dest - src))) - print("mean ", torch.mean(torch.abs(dest - src))) - assert torch.allclose(src, dest, atol=1e-2, rtol=0) - - -def test2(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 128 - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() - 
value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((B * N_CTX, H, 1), dtype=torch.float16).cuda() - - for _ in range(10): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(value_dest * scale_dest - src))) - print("mean ", torch.mean(torch.abs(value_dest * scale_dest - src))) - cos = torch.nn.CosineSimilarity(0) - print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -if __name__ == "__main__": - test1() - test2() diff --git a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py index 9804e4668..c8a6a850b 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py @@ -9,8 +9,6 @@ def gen_decode_params(b_seq_len: torch.Tensor): b_kv_seq_len = b_seq_len position_ids = b_seq_len - 1 - mtp_step = get_env_start_args().mtp_step - mtp_size = mtp_step + 1 - b_q_seq_len = torch.ones(b_seq_len.shape[0] // mtp_size, dtype=torch.int32, device=b_seq_len.device) * mtp_size - b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size]) + b_q_seq_len = torch.ones(b_seq_len.shape[0], dtype=torch.int32, device=b_seq_len.device) + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids diff --git a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py index e73b34299..8f9172b55 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py @@ -43,6 +43,7 @@ def _gen_cumsum_pad0_kernel( def gen_cumsum_pad0_tensor(b_q_seq_len: torch.Tensor, b_kv_seq_len: torch.Tensor): assert len(b_q_seq_len.shape) == 1 assert b_q_seq_len.shape == b_kv_seq_len.shape + assert b_q_seq_len.is_contiguous() b1_cu_q_seq_len = torch.empty((b_q_seq_len.shape[0] + 1,), dtype=torch.int32, device="cuda") b1_cu_kv_seq_len = torch.empty((b_kv_seq_len.shape[0] + 1,), dtype=torch.int32, device="cuda") diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/__init__.py b/lightllm/common/basemodel/triton_kernel/kv_copy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py similarity index 92% rename from lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py rename to lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py index 39deb1b6f..41a25877a 100644 --- a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py @@ -36,11 +36,11 @@ def _fwd_kernel_destindex_copy_kv( dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] - kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] + kv_nope_ptrs = KV_nope + cur_index * 
stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope + kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope - o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope[None, :] - o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope[None, :] + o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope + o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope kv_nope = tl.load(kv_nope_ptrs) kv_rope = tl.load(kv_rope_ptrs) @@ -60,6 +60,9 @@ def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope): assert KV_nope.shape[2] == O_nope.shape[2] assert KV_rope.shape[1] == O_rope.shape[1] assert KV_rope.shape[2] == O_rope.shape[2] + assert triton.next_power_of_2(kv_nope_head_dim) == kv_nope_head_dim + assert triton.next_power_of_2(kv_rope_head_dim) == kv_rope_head_dim + grid = (seq_len,) num_warps = 1 diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py new file mode 100644 index 000000000..53d1256ec --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py @@ -0,0 +1,374 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_quantize_int4_kv( + K, + Dest_loc, + Out, + Out_scale, + stride_k_bs, + stride_k_h, + stride_k_g, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_g, + stride_o_d, + stride_os_bs, + stride_os_h, + stride_os_g, + group_count, + token_num, + HEAD_NUM: tl.constexpr, + BLOCK_GROUP_COUNT: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + start_index = tl.program_id(0) + + for cur_index in range(start_index, token_num, step=tl.num_programs(axis=0)): + offs_g = tl.arange(0, BLOCK_GROUP_COUNT) % group_count + offs_d = tl.arange(0, BLOCK_GROUP_DIM // 2) + + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + for cur_head in tl.static_range(HEAD_NUM, step=1): + src_data_0 = tl.load( + K + + cur_index * stride_k_bs + + cur_head * stride_k_h + + offs_g[:, None] * stride_k_g + + offs_d[None, :] * 2, + ) + src_data_1 = tl.load( + K + + cur_index * stride_k_bs + + cur_head * stride_k_h + + offs_g[:, None] * stride_k_g + + offs_d[None, :] * 2 + + 1, + ) + + abs_data_0 = tl.abs(src_data_0) + abs_data_1 = tl.abs(src_data_1) + + data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to( + Out_scale.dtype.element_ty + ) + q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) + q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) + q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) + q_src_data_0 += 7 + q_src_data_0 = q_src_data_0.to(tl.uint8, bitcast=True) + + q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) + q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1) + q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1) + q_src_data_1 += 7 + q_src_data_1 = q_src_data_1.to(tl.uint8, bitcast=True) + + low_4 = q_src_data_0 & 0xF + high_4 = (q_src_data_1 & 0xF) << 4 + + out_data = (low_4 | high_4).to(Out.dtype.element_ty, bitcast=True) + + o_ptrs = ( + Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] + ) + os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g + tl.store(o_ptrs, out_data) + tl.store(os_ptrs, data_scale) + return + + +@torch.no_grad() +def destindex_copy_int4kv( + KV: 
torch.Tensor, + DestLoc: torch.Tensor, + KV_buffer: torch.Tensor, + KV_scale_buffer: torch.Tensor, + quant_group_size: int, +): + head_num = KV.shape[1] + head_dim = KV.shape[2] + + assert head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + + group_count = head_dim // quant_group_size + group_dim = quant_group_size + + assert triton.next_power_of_2(group_dim) == group_dim + + KV = KV.view((KV.shape[0], head_num, group_count, group_dim)) + KV_buffer = KV_buffer.view( + KV_buffer.shape[0], KV_buffer.shape[1], group_count, group_dim // 2 + ) # OUt 是 int8 类型, 两个int4组一个int8,所以 group_dim // 2 + KV_scale_buffer = KV_scale_buffer.view(KV_scale_buffer.shape[0], KV_scale_buffer.shape[1], group_count) + if len(DestLoc) < 1024: + grid = (len(DestLoc),) + else: + grid = (1024,) + + _fwd_kernel_destindex_copy_quantize_int4_kv[grid]( + K=KV, + Dest_loc=DestLoc, + Out=KV_buffer, + Out_scale=KV_scale_buffer, + stride_k_bs=KV.stride(0), + stride_k_h=KV.stride(1), + stride_k_g=KV.stride(2), + stride_k_d=KV.stride(3), + stride_o_bs=KV_buffer.stride(0), + stride_o_h=KV_buffer.stride(1), + stride_o_g=KV_buffer.stride(2), + stride_o_d=KV_buffer.stride(3), + stride_os_bs=KV_scale_buffer.stride(0), + stride_os_h=KV_scale_buffer.stride(1), + stride_os_g=KV_scale_buffer.stride(2), + group_count=group_count, + token_num=len(DestLoc), + HEAD_NUM=head_num, + BLOCK_GROUP_COUNT=triton.next_power_of_2(group_count), + BLOCK_GROUP_DIM=group_dim, + num_warps=4, + num_stages=1, + ) + return + + +@triton.jit +def int4_to_float(k_int8, offs_d): + k_int8 = k_int8.to(tl.uint8, bitcast=True) + k_high = (k_int8 & 0xF0) >> 4 + k_low = k_int8 & 0x0F + k_high = k_high.to(tl.int8, bitcast=True) + k_low = k_low.to(tl.int8, bitcast=True) + k_high -= 7 + k_low -= 7 + + k_int4 = tl.where( + offs_d[None, None, :] % 2 == 0, + k_low, + k_high, + ) + return k_int4 + + +@triton.jit +def _fwd_dequantize_int4kv( + k, + k_ss, + k_sh, + k_sg, + k_sd, + k_scale, + k_scale_ss, + k_scale_sh, + k_scale_sg, + k_scale_sd, + v, + v_ss, + v_sh, + v_sg, + v_sd, + v_scale, + v_scale_ss, + v_scale_sh, + v_scale_sg, + v_scale_sd, + req_to_token_indexs, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + b_seq_len, + b_req_idx, + b_kv_start_loc, + k_out, + k_out_ss, + k_out_sh, + k_out_sg, + k_out_sd, + v_out, + v_out_ss, + v_out_sh, + v_out_sg, + v_out_sd, + k_head_num, + v_head_num, + group_count, + group_dim, + SEQ_BLOCK_SIZE: tl.constexpr, + GROUP_COUNT_BLOCK_SIZE: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + start_block_index = tl.program_id(0) + cur_batch = tl.program_id(1) + cur_batch_req_idx = tl.load(b_req_idx + cur_batch) + cur_seq_len = tl.load(b_seq_len + cur_batch) + if start_block_index * SEQ_BLOCK_SIZE >= cur_seq_len: + return + + out_start_loc = tl.load(b_kv_start_loc + cur_batch) + + offs_kv_loc = (start_block_index * SEQ_BLOCK_SIZE + tl.arange(0, SEQ_BLOCK_SIZE)) % cur_seq_len + kv_loc = tl.load(req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc).to(tl.int64) + + offs_d = tl.arange(0, BLOCK_GROUP_DIM) + offs_scale_d = tl.arange(0, 1) + group_offs = tl.arange(0, GROUP_COUNT_BLOCK_SIZE) % group_count + + for k_head_index in tl.range(0, k_head_num, step=1, num_stages=3): + k_int8 = tl.load( + k + + kv_loc[:, None, None] * k_ss + + k_head_index * k_sh + + group_offs[None, :, None] * k_sg + + offs_d[None, None, :] // 2 + ) + k_int4 = int4_to_float(k_int8, offs_d) + + k_scale_data = tl.load( + k_scale + + kv_loc[:, None, None] * k_scale_ss + + k_head_index * k_scale_sh + 
+ group_offs[None, :, None] * k_scale_sg + + offs_scale_d[None, None, :] + ) + k_out_data = k_int4.to(k_out.dtype.element_ty) * k_scale_data + tl.store( + k_out + + (out_start_loc + offs_kv_loc[:, None, None]) * k_out_ss + + k_head_index * k_out_sh + + group_offs[None, :, None] * k_out_sg + + offs_d[None, None, :], + k_out_data, + ) + + for v_head_index in tl.range(0, v_head_num, step=1, num_stages=3): + v_int8 = tl.load( + v + + kv_loc[:, None, None] * v_ss + + v_head_index * v_sh + + group_offs[None, :, None] * v_sg + + offs_d[None, None, :] // 2 + ) + v_int4 = int4_to_float(v_int8, offs_d) + v_scale_data = tl.load( + v_scale + + kv_loc[:, None, None] * v_scale_ss + + v_head_index * v_scale_sh + + group_offs[None, :, None] * v_scale_sg + + offs_scale_d[None, None, :] + ) + v_out_data = v_int4.to(v_out.dtype.element_ty) * v_scale_data + tl.store( + v_out + + (out_start_loc + offs_kv_loc[:, None, None]) * v_out_ss + + v_head_index * v_out_sh + + group_offs[None, :, None] * v_out_sg + + offs_d[None, None, :], + v_out_data, + ) + return + + +@torch.no_grad() +def dequantize_int4kv( + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_seq_len: torch.Tensor, + b_req_idx: torch.Tensor, + b_kv_start_loc: torch.Tensor, + k_out: torch.Tensor, + v_out: torch.Tensor, + max_len_in_batch: int, + quant_group_size: int, +): + batch_size = b_seq_len.shape[0] + k_head_num = k.shape[1] + k_head_dim = k.shape[2] * 2 + v_head_num = v.shape[1] + v_head_dim = v.shape[2] * 2 + assert k_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert v_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert k_head_dim == v_head_dim, "error head dim, can not been supported to copy quant kv" + assert k_head_dim // v_scale.shape[2] == quant_group_size, "error head dim, can not been supported to copy quant kv" + assert k_head_dim in [64, 128, 256] + + group_count = k_head_dim // quant_group_size + group_dim = quant_group_size + + assert triton.next_power_of_2(group_dim) == group_dim + + k = k.view((k.shape[0], k.shape[1], group_count, group_dim // 2)) # int4kv 以 int8 存储的 + v = v.view((v.shape[0], v.shape[1], group_count, group_dim // 2)) + k_scale = k_scale.view((k_scale.shape[0], k_scale.shape[1], group_count, 1)) + v_scale = v_scale.view((v_scale.shape[0], v_scale.shape[1], group_count, 1)) + + # 使拆分的grid 具有足够的并行度 + SEQ_BLOCK_SIZE = 128 + while triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE) * batch_size < 512: + SEQ_BLOCK_SIZE = SEQ_BLOCK_SIZE // 2 + if SEQ_BLOCK_SIZE <= 1: + break + + if SEQ_BLOCK_SIZE <= 1: + SEQ_BLOCK_SIZE = 8 + + grid = (triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE), batch_size) + num_warps = 4 + k_out = k_out.view((k_out.shape[0], k_out.shape[1], group_count, group_dim)) + v_out = v_out.view((v_out.shape[0], v_out.shape[1], group_count, group_dim)) + + _fwd_dequantize_int4kv[grid]( + k=k, + k_ss=k.stride(0), + k_sh=k.stride(1), + k_sg=k.stride(2), + k_sd=k.stride(3), + k_scale=k_scale, + k_scale_ss=k_scale.stride(0), + k_scale_sh=k_scale.stride(1), + k_scale_sg=k_scale.stride(2), + k_scale_sd=k_scale.stride(2), + v=v, + v_ss=v.stride(0), + v_sh=v.stride(1), + v_sg=v.stride(2), + v_sd=v.stride(3), + v_scale=v_scale, + v_scale_ss=v_scale.stride(0), + v_scale_sh=v_scale.stride(1), + v_scale_sg=v_scale.stride(2), + v_scale_sd=v_scale.stride(3), + req_to_token_indexs=req_to_token_indexs, + stride_req_to_tokens_b=req_to_token_indexs.stride(0), + 
stride_req_to_tokens_s=req_to_token_indexs.stride(1), + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=k_out, + k_out_ss=k_out.stride(0), + k_out_sh=k_out.stride(1), + k_out_sg=k_out.stride(2), + k_out_sd=k_out.stride(3), + v_out=v_out, + v_out_ss=v_out.stride(0), + v_out_sh=v_out.stride(1), + v_out_sg=v_out.stride(2), + v_out_sd=v_out.stride(3), + k_head_num=k_head_num, + v_head_num=v_head_num, + group_count=group_count, + group_dim=group_dim, + SEQ_BLOCK_SIZE=SEQ_BLOCK_SIZE, + GROUP_COUNT_BLOCK_SIZE=triton.next_power_of_2(group_count), + BLOCK_GROUP_DIM=group_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py new file mode 100644 index 000000000..e5ee5cb8b --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py @@ -0,0 +1,330 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_quantize_kv( + K, + Dest_loc, + Out, + Out_scale, + stride_k_bs, + stride_k_h, + stride_k_g, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_g, + stride_o_d, + stride_os_bs, + stride_os_h, + stride_os_g, + group_size, + BLOCK_GROUP_NUM: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + cur_index = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_g = tl.arange(0, BLOCK_GROUP_NUM) + offs_d = tl.arange(0, BLOCK_GROUP_DIM) + + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + src_data = tl.load( + K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :], + mask=offs_g[:, None] < group_size, + other=0.0, + ) + abs_data = tl.abs(src_data) + data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty) + q_src_data = (src_data / data_scale[:, None]).to(tl.int8) + + o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] + os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g + tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size) + tl.store(os_ptrs, data_scale, mask=offs_g < group_size) + return + + +@torch.no_grad() +def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale, quant_group_dim): + seq_len = DestLoc.shape[0] + head_num = K.shape[1] + head_dim = K.shape[2] + assert triton.next_power_of_2(quant_group_dim) == quant_group_dim, "error quant group dim" + + assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" + grid = (seq_len, head_num) + num_warps = 1 + + group_size = head_dim // quant_group_dim + group_dim = quant_group_dim + + K = K.view((K.shape[0], K.shape[1], group_size, group_dim)) + Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim) + + _fwd_kernel_destindex_copy_quantize_kv[grid]( + K, + DestLoc, + Out, + Out_scale, + K.stride(0), + K.stride(1), + K.stride(2), + K.stride(3), + Out.stride(0), + Out.stride(1), + Out.stride(2), + Out.stride(3), + Out_scale.stride(0), + Out_scale.stride(1), + Out_scale.stride(2), + group_size, + BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), + BLOCK_GROUP_DIM=group_dim, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _fwd_dequantize_int8kv( + k, + k_ss, + k_sh, + k_sg, + k_sd, + k_scale, + k_scale_ss, + k_scale_sh, + k_scale_sg, + k_scale_sd, + v, + v_ss, + v_sh, + v_sg, + v_sd, + v_scale, + v_scale_ss, + v_scale_sh, + 
v_scale_sg, + v_scale_sd, + req_to_token_indexs, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + b_seq_len, + b_req_idx, + b_kv_start_loc, + k_out, + k_out_ss, + k_out_sh, + k_out_sg, + k_out_sd, + v_out, + v_out_ss, + v_out_sh, + v_out_sg, + v_out_sd, + k_head_num, + v_head_num, + group_count, + group_dim, + SEQ_BLOCK_SIZE: tl.constexpr, + GROUP_COUNT_BLOCK_SIZE: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + start_block_index = tl.program_id(0) + cur_batch = tl.program_id(1) + cur_batch_req_idx = tl.load(b_req_idx + cur_batch) + cur_seq_len = tl.load(b_seq_len + cur_batch) + if start_block_index * SEQ_BLOCK_SIZE >= cur_seq_len: + return + + out_start_loc = tl.load(b_kv_start_loc + cur_batch) + + offs_kv_loc = (start_block_index * SEQ_BLOCK_SIZE + tl.arange(0, SEQ_BLOCK_SIZE)) % cur_seq_len + kv_loc = tl.load(req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc).to(tl.int64) + + offs_d = tl.arange(0, BLOCK_GROUP_DIM) % group_dim + offs_scale_d = tl.arange(0, 1) + group_offs = tl.arange(0, GROUP_COUNT_BLOCK_SIZE) % group_count + + for k_head_index in tl.range(0, k_head_num, step=1, num_stages=3): + k_int8 = tl.load( + k + + kv_loc[:, None, None] * k_ss + + k_head_index * k_sh + + group_offs[None, :, None] * k_sg + + offs_d[None, None, :] + ) + k_scale_data = tl.load( + k_scale + + kv_loc[:, None, None] * k_scale_ss + + k_head_index * k_scale_sh + + group_offs[None, :, None] * k_scale_sg + + offs_scale_d[None, None, :] + ) + k_out_data = k_int8.to(k_out.dtype.element_ty) * k_scale_data + tl.store( + k_out + + (out_start_loc + offs_kv_loc[:, None, None]) * k_out_ss + + k_head_index * k_out_sh + + group_offs[None, :, None] * k_out_sg + + offs_d[None, None, :], + k_out_data, + ) + + for v_head_index in tl.range(0, v_head_num, step=1, num_stages=3): + v_int8 = tl.load( + v + + kv_loc[:, None, None] * v_ss + + v_head_index * v_sh + + group_offs[None, :, None] * v_sg + + offs_d[None, None, :] + ) + v_scale_data = tl.load( + v_scale + + kv_loc[:, None, None] * v_scale_ss + + v_head_index * v_scale_sh + + group_offs[None, :, None] * v_scale_sg + + offs_scale_d[None, None, :] + ) + v_out_data = v_int8.to(v_out.dtype.element_ty) * v_scale_data + tl.store( + v_out + + (out_start_loc + offs_kv_loc[:, None, None]) * v_out_ss + + v_head_index * v_out_sh + + group_offs[None, :, None] * v_out_sg + + offs_d[None, None, :], + v_out_data, + ) + return + + +@torch.no_grad() +def dequantize_int8kv( + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_seq_len: torch.Tensor, + b_req_idx: torch.Tensor, + b_kv_start_loc: torch.Tensor, + k_out: torch.Tensor, + v_out: torch.Tensor, + max_len_in_batch: int, + quant_group_size: int, +): + batch_size = b_seq_len.shape[0] + k_head_num = k.shape[1] + k_head_dim = k.shape[2] + v_head_num = v.shape[1] + v_head_dim = v.shape[2] + assert k_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert v_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert k_head_dim == v_head_dim, "error head dim, can not been supported to copy quant kv" + assert k_head_dim // v_scale.shape[2] == quant_group_size, "error head dim, can not been supported to copy quant kv" + assert k_head_dim in [64, 128, 256] + + group_count = k_head_dim // quant_group_size + group_dim = quant_group_size + + k = k.view((k.shape[0], k.shape[1], group_count, group_dim)) + v = v.view((v.shape[0], v.shape[1], 
group_count, group_dim))
+    k_scale = k_scale.view((k_scale.shape[0], k_scale.shape[1], group_count, 1))
+    v_scale = v_scale.view((v_scale.shape[0], v_scale.shape[1], group_count, 1))
+
+    # make the grid split provide enough parallelism
+    SEQ_BLOCK_SIZE = 128
+    while triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE) * batch_size < 512:
+        SEQ_BLOCK_SIZE = SEQ_BLOCK_SIZE // 2
+        if SEQ_BLOCK_SIZE <= 1:
+            break
+
+    if SEQ_BLOCK_SIZE <= 1:
+        SEQ_BLOCK_SIZE = 8
+
+    grid = (triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE), batch_size)
+    num_warps = 4
+    k_out = k_out.view((k_out.shape[0], k_out.shape[1], group_count, group_dim))
+    v_out = v_out.view((v_out.shape[0], v_out.shape[1], group_count, group_dim))
+
+    _fwd_dequantize_int8kv[grid](
+        k=k,
+        k_ss=k.stride(0),
+        k_sh=k.stride(1),
+        k_sg=k.stride(2),
+        k_sd=k.stride(3),
+        k_scale=k_scale,
+        k_scale_ss=k_scale.stride(0),
+        k_scale_sh=k_scale.stride(1),
+        k_scale_sg=k_scale.stride(2),
+        k_scale_sd=k_scale.stride(3),
+        v=v,
+        v_ss=v.stride(0),
+        v_sh=v.stride(1),
+        v_sg=v.stride(2),
+        v_sd=v.stride(3),
+        v_scale=v_scale,
+        v_scale_ss=v_scale.stride(0),
+        v_scale_sh=v_scale.stride(1),
+        v_scale_sg=v_scale.stride(2),
+        v_scale_sd=v_scale.stride(3),
+        req_to_token_indexs=req_to_token_indexs,
+        stride_req_to_tokens_b=req_to_token_indexs.stride(0),
+        stride_req_to_tokens_s=req_to_token_indexs.stride(1),
+        b_seq_len=b_seq_len,
+        b_req_idx=b_req_idx,
+        b_kv_start_loc=b_kv_start_loc,
+        k_out=k_out,
+        k_out_ss=k_out.stride(0),
+        k_out_sh=k_out.stride(1),
+        k_out_sg=k_out.stride(2),
+        k_out_sd=k_out.stride(3),
+        v_out=v_out,
+        v_out_ss=v_out.stride(0),
+        v_out_sh=v_out.stride(1),
+        v_out_sg=v_out.stride(2),
+        v_out_sd=v_out.stride(3),
+        k_head_num=k_head_num,
+        v_head_num=v_head_num,
+        group_count=group_count,
+        group_dim=group_dim,
+        SEQ_BLOCK_SIZE=SEQ_BLOCK_SIZE,
+        GROUP_COUNT_BLOCK_SIZE=triton.next_power_of_2(group_count),
+        BLOCK_GROUP_DIM=triton.next_power_of_2(group_dim),
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
+
+
+def test2():
+    import time
+
+    B, N_CTX, H, D = 1, 3, 12, 128
+    src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda()
+    dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda()
+    value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8)
+    scale_dest = torch.randn((B * N_CTX, H, D // 8), dtype=torch.float16).cuda()
+
+    for _ in range(10):
+        destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest, quant_group_dim=8)  # one scale per 8 elems
+    torch.cuda.synchronize()
+    t1 = time.time()
+    for _ in range(1000):
+        destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest, quant_group_dim=8)
+    torch.cuda.synchronize()
+    t2 = time.time()
+
+    print("Time cost ", t2 - t1)
+    value_dest = value_dest.view((B * N_CTX, H, D // 8, 8))
+    scale_dest = scale_dest.view((B * N_CTX, H, D // 8, 1))
+    print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src)))
+    print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src)))
+    cos = torch.nn.CosineSimilarity(0)
+    print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32)))
diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py
new file mode 100644
index 000000000..b4a91f786
--- /dev/null
+++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py
@@ -0,0 +1,961 @@
+"""
+Optimized Mamba Buffer Copy Kernels with Autotune Support
+
+This module provides auto-tuned Triton kernels for efficient buffer copying operations
+in Mamba-style models,
including support for MTP (Multi-Token Prediction) buffer broadcasting. +""" + +import torch +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _copy_buffer_p2p_1d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + pair_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d, + d_size, + BLOCK_D: tl.constexpr, +): + """ + Optimized kernel for 1D buffer copy. + + Grid: (num_pairs, layer_num, num_blocks_d) + Each program copies one block of dimension d for one (pair, layer) combination. + """ + pair_idx = tl.program_id(0) + pair_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_d_idx = tl.program_id(2) + + # Load source and destination indices for this pair + src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) + dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) + + # Calculate offsets for this block + d_start = block_d_idx * BLOCK_D + d_offsets = d_start + tl.arange(0, BLOCK_D) + + # Create mask for valid indices + mask = d_offsets < d_size + + # Calculate source and destination pointers for this layer and pair + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + + src_ptr = base_src + d_offsets * stride_d + dst_ptr = base_dst + d_offsets * stride_d + + # Load and store + data = tl.load(src_ptr, mask=mask, other=0.0) + tl.store(dst_ptr, data, mask=mask) + + +@triton.jit +def _copy_buffer_p2p_2d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + pair_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + d1_size, + d2_size, + num_blocks_d2, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, +): + """ + Kernel to copy 2D buffer from source indices to destination indices. + + Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2) + Each program copies one 2D block for one (pair, layer) combination. 
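+    The flattened block index is decomposed as block_d1_idx = block_idx // num_blocks_d2
+    and block_d2_idx = block_idx % num_blocks_d2.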
+ """ + pair_idx = tl.program_id(0) + pair_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx into d1 and d2 block indices + block_d1_idx = block_idx // num_blocks_d2 + block_d2_idx = block_idx % num_blocks_d2 + + # Load source and destination indices + src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) + dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) + + # Calculate offsets for this block + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + + # Create mask for valid indices + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + mask = d1_mask[:, None] & d2_mask[None, :] + + # Calculate base pointers + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + + # Calculate full offsets + offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 + + # Load and store + data = tl.load(base_src + offsets, mask=mask, other=0.0) + tl.store(base_dst + offsets, data, mask=mask) + + +@triton.jit +def _copy_buffer_broadcast_1d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + copy_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d, + d_size, + num_dst_per_src, + BLOCK_D: tl.constexpr, +): + """ + Broadcast kernel for 1D buffer copy (one source to multiple destinations). + + Grid: (num_src, layer_num, num_blocks_d) + """ + src_idx_in_batch = tl.program_id(0) + copy_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_d_idx = tl.program_id(2) + + # Load source index + src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + + # Calculate offsets for this block + d_start = block_d_idx * BLOCK_D + d_offsets = d_start + tl.arange(0, BLOCK_D) + mask = d_offsets < d_size + + # Calculate source pointer + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + src_ptr = base_src + d_offsets * stride_d + + # Load data once + data = tl.load(src_ptr, mask=mask, other=0.0) + + # Broadcast to all destinations for this source + for dst_offset in range(num_dst_per_src): + dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) + + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + dst_ptr = base_dst + d_offsets * stride_d + + tl.store(dst_ptr, data, mask=mask) + + +@triton.jit +def _copy_buffer_broadcast_2d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + copy_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + d1_size, + d2_size, + num_blocks_d2, + num_dst_per_src, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, +): + """ + Broadcast kernel for 2D buffer copy (one source to multiple destinations). 
+ + Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2) + """ + src_idx_in_batch = tl.program_id(0) + copy_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx + block_d1_idx = block_idx // num_blocks_d2 + block_d2_idx = block_idx % num_blocks_d2 + + # Load source index + src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + + # Calculate offsets + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + mask = d1_mask[:, None] & d2_mask[None, :] + + # Calculate source pointer and load data once + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + offsets = d1_offsets[:, None] * stride_d1 + d2_offsets[None, :] * stride_d2 + data = tl.load(base_src + offsets, mask=mask, other=0.0) + + # Broadcast to all destinations + for dst_offset in range(num_dst_per_src): + dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) + + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + tl.store(base_dst + offsets, data, mask=mask) + + +@triton.jit +def _copy_buffer_p2p_3d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + pair_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + stride_d3, + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, + BLOCK_D3: tl.constexpr, +): + """ + Optimized kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). + + Grid: (num_pairs, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) + Each program copies one 3D block for one (pair, layer) combination. 
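+    The flattened block index is decomposed into (d1, d2, d3) block indices via integer
+    division and modulo over num_blocks_d2 * num_blocks_d3 and num_blocks_d3.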
+ """ + pair_idx = tl.program_id(0) + pair_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx into d1, d2, d3 block indices + block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) + temp = block_idx % (num_blocks_d2 * num_blocks_d3) + block_d2_idx = temp // num_blocks_d3 + block_d3_idx = temp % num_blocks_d3 + + # Load source and destination indices for this pair + src_idx = tl.load(src_indexes_ptr + pair_idx).to(tl.int64) + dst_idx = tl.load(dst_indexes_ptr + pair_idx).to(tl.int64) + + # Calculate offsets for this block + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + d3_start = block_d3_idx * BLOCK_D3 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + d3_offsets = d3_start + tl.arange(0, BLOCK_D3) + + # Create masks for valid indices + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + d3_mask = d3_offsets < d3_size + + # 3D mask: [BLOCK_D1, BLOCK_D2, BLOCK_D3] + mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] + + # Calculate base pointers + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + + # Calculate full 3D offsets + offsets = ( + d1_offsets[:, None, None] * stride_d1 + + d2_offsets[None, :, None] * stride_d2 + + d3_offsets[None, None, :] * stride_d3 + ) + + # Load and store + data = tl.load(base_src + offsets, mask=mask, other=0.0) + tl.store(base_dst + offsets, data, mask=mask) + + +@triton.jit +def _copy_buffer_broadcast_3d_kernel( + src_buffer_ptr, + dst_buffer_ptr, + src_indexes_ptr, + dst_indexes_ptr, + copy_idx_offset, + layer_idx_offset, + stride_layer, + stride_index, + stride_d1, + stride_d2, + stride_d3, + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + num_dst_per_src, + BLOCK_D1: tl.constexpr, + BLOCK_D2: tl.constexpr, + BLOCK_D3: tl.constexpr, +): + """ + Broadcast kernel for 3D data buffer copy (5D tensor: layer, buffer, d1, d2, d3). + + Grid: (num_src, layer_num, num_blocks_d1 * num_blocks_d2 * num_blocks_d3) + Each program loads once from source and broadcasts to all destinations. 
+ """ + src_idx_in_batch = tl.program_id(0) + copy_idx_offset + layer_idx = tl.program_id(1) + layer_idx_offset + block_idx = tl.program_id(2) + + # Decompose block_idx into d1, d2, d3 block indices + block_d1_idx = block_idx // (num_blocks_d2 * num_blocks_d3) + temp = block_idx % (num_blocks_d2 * num_blocks_d3) + block_d2_idx = temp // num_blocks_d3 + block_d3_idx = temp % num_blocks_d3 + + # Load source index + src_idx = tl.load(src_indexes_ptr + src_idx_in_batch).to(tl.int64) + + # Calculate offsets for this block + d1_start = block_d1_idx * BLOCK_D1 + d2_start = block_d2_idx * BLOCK_D2 + d3_start = block_d3_idx * BLOCK_D3 + + d1_offsets = d1_start + tl.arange(0, BLOCK_D1) + d2_offsets = d2_start + tl.arange(0, BLOCK_D2) + d3_offsets = d3_start + tl.arange(0, BLOCK_D3) + + # Create masks + d1_mask = d1_offsets < d1_size + d2_mask = d2_offsets < d2_size + d3_mask = d3_offsets < d3_size + + mask = d1_mask[:, None, None] & d2_mask[None, :, None] & d3_mask[None, None, :] + + # Calculate source pointer and load data once + base_src = src_buffer_ptr + layer_idx * stride_layer + src_idx * stride_index + + offsets = ( + d1_offsets[:, None, None] * stride_d1 + + d2_offsets[None, :, None] * stride_d2 + + d3_offsets[None, None, :] * stride_d3 + ) + + data = tl.load(base_src + offsets, mask=mask, other=0.0) + + # Broadcast to all destinations for this source + for dst_offset in range(num_dst_per_src): + dst_idx_in_batch = src_idx_in_batch * num_dst_per_src + dst_offset + dst_idx = tl.load(dst_indexes_ptr + dst_idx_in_batch).to(tl.int64) + + base_dst = dst_buffer_ptr + layer_idx * stride_layer + dst_idx * stride_index + tl.store(base_dst + offsets, data, mask=mask) + + +# ==================== Config Generation Functions ==================== + + +def _get_buffer_copy_1d_configs(): + """Generate candidate configurations for 1D buffer copy.""" + configs = [] + for block_d in [32, 64, 128, 256, 512, 1024]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4]: + configs.append( + { + "BLOCK_D": block_d, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_buffer_copy_2d_configs(): + """Generate candidate configurations for 2D buffer copy.""" + configs = [] + for block_d1 in [16, 32, 64, 128]: + for block_d2 in [16, 32, 64, 128, 256]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4]: + configs.append( + { + "BLOCK_D1": block_d1, + "BLOCK_D2": block_d2, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_buffer_copy_3d_configs(): + """Generate candidate configurations for 3D buffer copy (5D tensor).""" + configs = [] + for block_d1 in [8, 16, 32]: + for block_d2 in [8, 16, 32, 64]: + for block_d3 in [8, 16, 32, 64, 128]: + for num_warps in [4, 8]: + for num_stages in [2, 3]: + # Skip configs that are too large for shared memory + if block_d1 * block_d2 * block_d3 > 32768: + continue + configs.append( + { + "BLOCK_D1": block_d1, + "BLOCK_D2": block_d2, + "BLOCK_D3": block_d3, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +# ==================== Static and Run Key Functions ==================== + + +def _get_buffer_copy_static_key(src_buffer: torch.Tensor): + """Static key based on buffer shape and dtype.""" + shape = src_buffer.shape + return { + "ndim": len(shape), + "layer_num": shape[0], + "d_sizes": str(shape[2:]), # Dimension sizes + "dtype": str(src_buffer.dtype), + } + + +def _get_buffer_copy_run_key(src_indexes: torch.Tensor): + """Run key based on number of copy 
pairs.""" + return src_indexes.shape[0] + + +# ==================== Auto-tuned Buffer Copy Functions ==================== + + +@autotune( + kernel_name="mamba_buffer_copy_p2p_1d:v1", + configs_gen_func=_get_buffer_copy_1d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_p2p_1d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 1D buffer copy.""" + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d_size = src_buffer.shape[2] + + if run_config is None: + # Default config if autotune is disabled + BLOCK_D = triton.next_power_of_2(min(d_size, 256)) + num_warps = 4 if BLOCK_D > 256 else 2 + num_stages = 2 + else: + BLOCK_D = run_config["BLOCK_D"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d = triton.cdiv(d_size, BLOCK_D) + + MAX_GRID_SIZE = 65535 + + for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): + pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + pair_chunk_size = pair_chunk_end - pair_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (pair_chunk_size, layer_chunk_size, num_blocks_d) + + _copy_buffer_p2p_1d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + pair_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + d_size, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_copy_p2p_2d:v1", + configs_gen_func=_get_buffer_copy_2d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_p2p_2d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 2D buffer copy.""" + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + + if run_config is None: + # Default config if autotune is disabled + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + BLOCK_D2 = run_config["BLOCK_D2"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_total = num_blocks_d1 * num_blocks_d2 + + MAX_GRID_SIZE = 65535 + + for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): + pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + pair_chunk_size = pair_chunk_end - pair_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_p2p_2d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + pair_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + 
src_buffer.stride(3), + d1_size, + d2_size, + num_blocks_d2, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_broadcast_1d:v1", + configs_gen_func=_get_buffer_copy_1d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_broadcast_1d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 1D buffer broadcast (one src to multiple dst).""" + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d_size = src_buffer.shape[2] + num_dst_per_src = dst_indexes.shape[0] // num_src + + if run_config is None: + BLOCK_D = triton.next_power_of_2(min(d_size, 256)) + num_warps = 4 if BLOCK_D > 256 else 2 + num_stages = 2 + else: + BLOCK_D = run_config["BLOCK_D"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d = triton.cdiv(d_size, BLOCK_D) + + MAX_GRID_SIZE = 65535 + + for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): + src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + src_chunk_size = src_chunk_end - src_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (src_chunk_size, layer_chunk_size, num_blocks_d) + + _copy_buffer_broadcast_1d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + src_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + d_size, + num_dst_per_src, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_broadcast_2d:v1", + configs_gen_func=_get_buffer_copy_2d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_broadcast_2d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 2D buffer broadcast (one src to multiple dst).""" + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + num_dst_per_src = dst_indexes.shape[0] // num_src + + if run_config is None: + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 64)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 128)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + BLOCK_D2 = run_config["BLOCK_D2"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_total = num_blocks_d1 * num_blocks_d2 + + MAX_GRID_SIZE = 65535 + + for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): + src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + src_chunk_size = src_chunk_end - src_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (src_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_broadcast_2d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, 
+ dst_indexes, + src_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + src_buffer.stride(3), + d1_size, + d2_size, + num_blocks_d2, + num_dst_per_src, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_copy_p2p_3d:v1", + configs_gen_func=_get_buffer_copy_3d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_p2p_3d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 3D data buffer copy (5D tensor).""" + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + d3_size = src_buffer.shape[4] + + if run_config is None: + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) + BLOCK_D3 = triton.next_power_of_2(min(d3_size, 64)) + num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + BLOCK_D2 = run_config["BLOCK_D2"] + BLOCK_D3 = run_config["BLOCK_D3"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) + num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 + + MAX_GRID_SIZE = 65535 + + for pair_chunk_start in range(0, num_pairs, MAX_GRID_SIZE): + pair_chunk_end = min(pair_chunk_start + MAX_GRID_SIZE, num_pairs) + pair_chunk_size = pair_chunk_end - pair_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (pair_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_p2p_3d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + pair_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + src_buffer.stride(3), + src_buffer.stride(4), + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + BLOCK_D3=BLOCK_D3, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@autotune( + kernel_name="mamba_buffer_broadcast_3d:v1", + configs_gen_func=_get_buffer_copy_3d_configs, + static_key_func=_get_buffer_copy_static_key, + run_key_func=_get_buffer_copy_run_key, + mutates_args=["dst_buffer"], +) +def _copy_buffer_broadcast_3d_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + """Auto-tuned 3D data buffer broadcast (5D tensor, one src to multiple dst).""" + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d1_size = src_buffer.shape[2] + d2_size = src_buffer.shape[3] + d3_size = src_buffer.shape[4] + num_dst_per_src = dst_indexes.shape[0] // num_src + + if run_config is None: + BLOCK_D1 = triton.next_power_of_2(min(d1_size, 16)) + BLOCK_D2 = triton.next_power_of_2(min(d2_size, 32)) + BLOCK_D3 = triton.next_power_of_2(min(d3_size, 64)) + num_warps = 4 if BLOCK_D1 * BLOCK_D2 * BLOCK_D3 > 4096 else 8 + num_stages = 2 + else: + BLOCK_D1 = run_config["BLOCK_D1"] + 
BLOCK_D2 = run_config["BLOCK_D2"] + BLOCK_D3 = run_config["BLOCK_D3"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_d1 = triton.cdiv(d1_size, BLOCK_D1) + num_blocks_d2 = triton.cdiv(d2_size, BLOCK_D2) + num_blocks_d3 = triton.cdiv(d3_size, BLOCK_D3) + num_blocks_total = num_blocks_d1 * num_blocks_d2 * num_blocks_d3 + + MAX_GRID_SIZE = 65535 + + for src_chunk_start in range(0, num_src, MAX_GRID_SIZE): + src_chunk_end = min(src_chunk_start + MAX_GRID_SIZE, num_src) + src_chunk_size = src_chunk_end - src_chunk_start + + for layer_chunk_start in range(0, layer_num, MAX_GRID_SIZE): + layer_chunk_end = min(layer_chunk_start + MAX_GRID_SIZE, layer_num) + layer_chunk_size = layer_chunk_end - layer_chunk_start + + grid = (src_chunk_size, layer_chunk_size, num_blocks_total) + + _copy_buffer_broadcast_3d_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + src_chunk_start, + layer_chunk_start, + src_buffer.stride(0), + src_buffer.stride(1), + src_buffer.stride(2), + src_buffer.stride(3), + src_buffer.stride(4), + d1_size, + d2_size, + d3_size, + num_blocks_d2, + num_blocks_d3, + num_dst_per_src, + BLOCK_D1=BLOCK_D1, + BLOCK_D2=BLOCK_D2, + BLOCK_D3=BLOCK_D3, + num_warps=num_warps, + num_stages=num_stages, + ) + + +# ==================== Unified Interface ==================== + + +def copy_buffer_p2p( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + """ + Copy buffers from source indices to destination indices with auto-tuning. + + Supports 3D (conv states), 4D (standard buffers), and 5D (SSM states) buffers. + + Args: + src_buffer: Source buffer tensor [layer_num, buffer_size, ...] + dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] + src_indexes: Source buffer indices [num_pairs] + dst_indexes: Destination buffer indices [num_pairs] + """ + assert src_buffer.shape == dst_buffer.shape + assert src_indexes.shape == dst_indexes.shape + assert len(src_indexes.shape) == 1 + + if len(src_buffer.shape) == 3: + # 1D case: (layer_num, buffer_size, d) + _copy_buffer_p2p_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) + + elif len(src_buffer.shape) == 4: + # 2D case: (layer_num, buffer_size, d1, d2) + _copy_buffer_p2p_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) + + elif len(src_buffer.shape) == 5: + # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory + _copy_buffer_p2p_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes) + + else: + raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") + + +def copy_buffer_broadcast( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + """ + Broadcast buffers from source indices to multiple destination indices (MTP use case). + + Each source buffer is copied to multiple destination buffers. + + Args: + src_buffer: Source buffer tensor [layer_num, buffer_size, ...] + dst_buffer: Destination buffer tensor [layer_num, buffer_size, ...] 
+ src_indexes: Source buffer indices [num_src] + dst_indexes: Destination buffer indices [num_src, num_dst_per_src] (2D tensor) + """ + assert src_buffer.shape == dst_buffer.shape + assert len(src_indexes.shape) == 1 + assert len(dst_indexes.shape) == 2, f"dst_indexes must be 2D, got shape {dst_indexes.shape}" + + num_src = src_indexes.shape[0] + + assert num_src == dst_indexes.shape[0], f"Mismatch: src_indexes {num_src} vs dst_indexes {dst_indexes.shape[0]}" + + # Flatten dst_indexes for kernel + dst_indexes_flat = dst_indexes.reshape(-1).contiguous() + + if len(src_buffer.shape) == 3: + # 1D case + _copy_buffer_broadcast_1d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) + + elif len(src_buffer.shape) == 4: + # 2D case + _copy_buffer_broadcast_2d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) + + elif len(src_buffer.shape) == 5: + # 5D case: (layer_num, buffer_size, d1, d2, d3) - Use Triton kernel for zero extra memory + _copy_buffer_broadcast_3d_autotuned(src_buffer, dst_buffer, src_indexes, dst_indexes_flat) + + else: + raise ValueError(f"Unsupported buffer shape: {src_buffer.shape}") diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/__init__.py b/lightllm/common/basemodel/triton_kernel/mla_att/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py new file mode 100644 index 000000000..fb0609a40 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py @@ -0,0 +1 @@ +from .gqa_flash_decoding import gqa_token_decode_attention_flash_decoding diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py similarity index 92% rename from lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py index 256dfce5a..28839b5f5 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py @@ -12,28 +12,21 @@ def gqa_token_decode_attention_flash_decoding( - q_nope, - q_rope, - kv_nope, - kv_rope, - infer_state, - q_head_num, - kv_lora_rank, - q_rope_dim, - qk_nope_head_dim, - softmax_scale, - out=None, - alloc_tensor_func=torch.empty, - **run_config + q_nope, q_rope, kv_nope, kv_rope, infer_state, softmax_scale, out=None, alloc_tensor_func=torch.empty, **run_config ): batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len + + q_head_num, kv_lora_rank = q_nope.shape[1], q_nope.shape[2] + q_rope_dim = q_rope.shape[2] + assert q_rope_dim == 64 + calcu_shape1 = (batch_size, q_head_num, kv_lora_rank) calcu_shape2 = (batch_size, q_head_num, q_rope_dim) if not run_config: if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = max_len_in_batch + avg_seq_len_in_batch = max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_config.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py rename to 
lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_config.py diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage1.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py rename to lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage1.py diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage2.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py rename to lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage2.py diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py new file mode 100644 index 000000000..5725bed2e --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py @@ -0,0 +1 @@ +from .context_flashattention_nopad_with_v import context_attention_fwd_with_v diff --git a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py rename to lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py diff --git a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/repack_kv_index.py rename to lightllm/common/basemodel/triton_kernel/repack_kv_index.py diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 66caf5d78..7d516e672 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -1,20 +1,16 @@ from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager -from .int8kv_mem_manager import INT8KVMemoryManager from .calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager from .export_calibration_mem_manager import ExportCalibrationMemoryManager from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager -from .deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager __all__ = [ "MemoryManager", "ReadOnlyStaticsMemoryManager", - "INT8KVMemoryManager", "CalibrationFP8KVMemoryManager", "ExportCalibrationMemoryManager", "PPLINT4KVMemoryManager", "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", - "Deepseek2FP8KVMemoryManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py deleted file mode 100644 index 00699f4b1..000000000 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch -from .deepseek2_mem_manager import Deepseek2MemoryManager - - -class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - # 
scale被追加到kv_buffer末尾, 因此加2, dtype统一改成uint8 - super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 771173460..3d93e1b07 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -3,13 +3,14 @@ import torch.distributed as dist from lightllm.server.pd_io_struct import KVMoveTask from .mem_manager import MemoryManager -from typing import List, Union +from typing import List, Union, Any from lightllm.utils.log_utils import init_logger from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io + logger = init_logger(__name__) @@ -17,6 +18,29 @@ class Deepseek2MemoryManager(MemoryManager): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from ..basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv + + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank + rope_dim == self.kv_buffer.shape[-1] + + destindex_copy_kv( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + self.kv_buffer[layer_index][:, :, :kv_lora_rank], + self.kv_buffer[layer_index][:, :, kv_lora_rank:], + ) + return + + def get_att_input_params(self, layer_index: int) -> Any: + kv = self.kv_buffer[layer_index] + return kv + def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) diff --git a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py index b2749176e..ffdc9b2c9 100755 --- a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py @@ -1,6 +1,28 @@ +import torch +from typing import Tuple, Any from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager class ExportCalibrationMemoryManager(OfflineFP8QuantMemManager): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=True) + + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 + + scales = self.scales + destindex_copy_kv_fp8( + kv, + mem_index, + scales[layer_index] if scales is not None else None, + self.kv_buffer[layer_index].view(torch.float8_e4m3fn), + ) + return + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + return k, v diff --git a/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py 
b/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py deleted file mode 100755 index 5725cdb7b..000000000 --- a/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch - -from .mem_manager import MemoryManager - - -class INT8KVMemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=0.9): - self.kv_dtype = torch.int8 - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=mem_fraction) - - def get_cell_size(self): - return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size( - self.kv_dtype - ) + 2 * self.head_num * self.layer_num * torch._utils._element_size(self.dtype) - - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") - self.scale_buffer = torch.empty((layer_num, size + 1, 2 * head_num, 1), dtype=dtype, device="cuda") - - def _free_buffers(self): - self.kv_buffer = None - self.scale_buffer = None - - def get_index_kv_buffer(self, index): - return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]} - - def load_index_kv_buffer(self, index, load_tensor_dict): - self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) - self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"]) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index d8fd93009..8d6fb48c2 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -3,7 +3,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from typing import List, Union +from typing import List, Union, Tuple, Any from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger @@ -18,13 +18,17 @@ from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.shm_utils import create_or_link_shm +from lightllm.common.allocator_utils import TokenAllocator from multiprocessing.reduction import ForkingPickler from filelock import FileLock + logger = init_logger(__name__) +KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_kv_cache_token_can_use_num" + -class MemoryManager: +class MemoryManager(TokenAllocator): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num @@ -35,27 +39,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False # profile the max total token num if the size is None self.profile_size(mem_fraction) - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._mem_state_return = torch.arange( - 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self._return_start = 0 - self.mark_start = 0 - self.mark_end = self.size - - self.can_use_mem_size = self.size + super().__init__(self.size, f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name - - 
rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) - - self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, @@ -63,7 +48,20 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False head_dim, layer_num, ) - self.HOLD_TOKEN_MEMINDEX = self.size + + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv + + destindex_copy_kv(kv, mem_index, self.kv_buffer[layer_index]) + return + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + return k, v def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) @@ -326,59 +324,13 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = None - def alloc(self, need_size) -> torch.Tensor: - if need_size > self.mark_end - self.mark_start: - logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") - assert False, "error alloc state" - - start = self.mark_start - end = self.mark_start + need_size - self.mark_start += need_size - - self.can_use_mem_size -= need_size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # 利用缓冲区返回,避免异步情况下的内存竞争 - if self._return_start + need_size > self._mem_state_return.shape[0]: - self._return_start = 0 - ans = self._mem_state_return[self._return_start : self._return_start + need_size] - ans.copy_(self.mem_state[start:end]) - self._return_start += need_size - return ans - - def free(self, free_index: Union[torch.Tensor, List[int]]): - """_summary_ - - Args: - free_index (torch.Tensor): _description_ - """ - - end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" - - if isinstance(free_index, list): - self.mem_state.numpy()[start:end] = free_index - else: - # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 - self.mem_state[start:end] = free_index - - self.mark_start -= len(free_index) - - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.can_use_mem_size == len(self.mem_state): - logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") - return + def get_index_kv_buffer(self, index): + return {"kv_buffer": self.kv_buffer[:, index]} - def free_all(self): - self.can_use_mem_size = len(self.mem_state) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) - self.mark_start = 0 - self.mark_end = len(self.mem_state) + def load_index_kv_buffer(self, index, load_tensor_dict): + self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + # 重写resize_mem方法,添加_free_buffers和_init_buffers调用 def resize_mem(self, new_size): """ just for test code @@ -389,24 +341,13 @@ def resize_mem(self, new_size): head_dim = self.head_dim layer_num = self.layer_num - self.size = new_size - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_start = 0 - self.mark_end 
= self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + # 调用父类的resize_mem + super().resize_mem(new_size) + self._free_buffers() self._init_buffers(size, dtype, head_num, head_dim, layer_num) return - def get_index_kv_buffer(self, index): - return {"kv_buffer": self.kv_buffer[:, index]} - - def load_index_kv_buffer(self, index, load_tensor_dict): - self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) - def copy_kv_from_other_dp_ranks( self, mem_managers: List["MemoryManager"], @@ -498,12 +439,12 @@ def __init__(self) -> None: self.dp_world_size = self.global_world_size // args.dp # 兼容多机 dp size=1 纯 tp 模式的情况 self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 - self.shared_tp_infos = [ - SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}") + self.shared_tp_can_use_token_nums = [ + SharedInt(f"{KVCACHE_TOKEN_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] def get_unrefed_token_num(self, dp_rank_in_node: int): if self.is_multinode_tp: - return self.shared_tp_infos[0].get_value() - return self.shared_tp_infos[dp_rank_in_node].get_value() + return self.shared_tp_can_use_token_nums[0].get_value() + return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value() diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 259c5a56f..1ff58b89a 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -1,12 +1,10 @@ from . import ( MemoryManager, - INT8KVMemoryManager, CalibrationFP8KVMemoryManager, ExportCalibrationMemoryManager, PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, - Deepseek2FP8KVMemoryManager, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -18,8 +16,6 @@ @lru_cache(maxsize=None) def select_mem_manager_class(): - mode = get_env_start_args().mode - # case 1 # 先判断是否是 deepseek 系列的模型 model_class = get_llm_model_class() @@ -27,38 +23,25 @@ def select_mem_manager_class(): if issubclass(model_class, Deepseek2TpPartModel): mem_class = Deepseek2MemoryManager - if "triton_fp8kv" in mode: - mem_class = Deepseek2FP8KVMemoryManager - - logger.info(f"Model kv cache using mode {mode}, mem_manager class: {mem_class}") + logger.info(f"Model kv cache using default, mem_manager class: {mem_class}") return mem_class # case normal - logger.info(f"mode setting params: {mode}") - if "ppl_int8kv" in mode or "ppl_int8kv_flashdecoding" in mode or "ppl_int8kv_flashdecoding_diverse" in mode: + logger.info(f"mode setting params: {get_env_start_args().llm_kv_type}") + if get_env_start_args().llm_kv_type == "int8kv": memory_manager_class = PPLINT8KVMemoryManager - logger.info(f"Model kv cache using mode {mode}") - elif "ppl_int4kv_flashdecoding" in mode: + elif get_env_start_args().llm_kv_type == "int4kv": memory_manager_class = PPLINT4KVMemoryManager - logger.info(f"Model kv cache using mode {mode}") - elif "triton_int8kv" in mode: - memory_manager_class = INT8KVMemoryManager - logger.info("Model kv cache using mode triton int8kv") - elif "triton_fp8kv" in mode: - raise Exception("currently only for deepseek") - elif "offline_calibration_fp8kv" in mode: - memory_manager_class = CalibrationFP8KVMemoryManager - logger.info("Model kv cache using mode offline calibration fp8kv") - elif "export_fp8kv_calibration" in mode: + elif 
get_env_start_args().llm_kv_type == "fp8kv": memory_manager_class = ExportCalibrationMemoryManager - logger.info("Using mode export fp8kv calibration") - else: + elif get_env_start_args().llm_kv_type == "None": memory_manager_class = MemoryManager - logger.info("Model kv cache using mode normal") + + logger.info(f"Model kv cache using mem_manager class: {memory_manager_class}") return memory_manager_class @lru_cache(maxsize=None) def used_mem_manager_has_scale() -> bool: mem_class = select_mem_manager_class() - return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, INT8KVMemoryManager] + return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager] diff --git a/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py b/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py index 5cc0b12d0..56a79a3b5 100755 --- a/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py @@ -31,8 +31,10 @@ def __init__( self.scales_list = None self.abs_max = None + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend + if is_export_mode: - scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2] + scales_shape = [layer_num, 2 * head_num] if enable_fa3 else [layer_num, 2] self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda") elif get_env_start_args().kv_quant_calibration_config_path is not None: logger.info( @@ -43,7 +45,7 @@ def __init__( self.scales_list = cfg["scales"] self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"]) - if not get_env_start_args().enable_fa3: + if not enable_fa3: self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1) elif cfg["num_head"] > self.total_head_num: factor = cfg["num_head"] // self.total_head_num @@ -51,7 +53,7 @@ def __init__( elif cfg["num_head"] < self.total_head_num: factor = self.total_head_num // cfg["num_head"] self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous() - if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1: + if enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1: half_head = self.total_head_num // 2 start_head = dist.get_rank() * head_num end_head = start_head + head_num @@ -65,6 +67,8 @@ def __init__( logger.warning("scales is None, no kv_quant_calibration_config_path be set, will use 1.0 as scales") def _load_and_check_config(self): + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend + if os.path.exists(get_env_start_args().kv_quant_calibration_config_path): with open(get_env_start_args().kv_quant_calibration_config_path, "r") as f: cfg = json.load(f) @@ -86,7 +90,7 @@ def _load_and_check_config(self): raise ValueError( f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}" ) - if get_env_start_args().enable_fa3: + if enable_fa3: if cfg["quant_type"] != "per_head": raise ValueError(f"quant type {cfg['num_head']} in config not match fa3 backend") else: @@ -100,6 +104,7 @@ def _load_and_check_config(self): ) def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend inference_counts = get_kv_quant_calibration_inference_count() warmup_counts = get_kv_quant_calibration_warmup_count() if not get_model_init_status() or self.count >= warmup_counts + 
inference_counts: @@ -109,7 +114,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): logger.info("kv cache calibration mode will collect kv cache data for quantization calibration") if self.abs_max is not None and self.count >= warmup_counts: - if get_env_start_args().enable_fa3: + if enable_fa3: kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32) else: k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32) @@ -119,7 +124,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1: final_abs_max = self.abs_max if dist.is_initialized() and dist.get_world_size() > 1: - if get_env_start_args().enable_fa3: + if enable_fa3: k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1) k_max = k_max.contiguous() v_max = v_max.contiguous() @@ -144,11 +149,13 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): self.count += 1 def _export_calibration_data(self): + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend + model_arch = get_model_architectures(get_env_start_args().model_dir) cfg = { "version": "1.0", "architectures": model_arch, - "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor", + "quant_type": "per_head" if enable_fa3 else "per_tensor", "qmin": self.qmin, "qmax": self.qmax, "num_layers": self.layer_num, diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py index f3218594d..559980dc1 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py @@ -1,5 +1,5 @@ import torch - +from typing import Tuple, Any from .mem_manager import MemoryManager @@ -9,6 +9,28 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, self.group_quant_size = 8 super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction) + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from ..basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv + + destindex_copy_int4kv( + kv, + mem_index, + self.kv_buffer[layer_index], + self.scale_buffer[layer_index], + quant_group_size=self.group_quant_size, + ) + return + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + k_scale = self.scale_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + v_scale = self.scale_buffer[layer_index][:, self.head_num :, :] + return (k, k_scale), (v, v_scale) + def get_cell_size(self): return 2 * self.head_num * self.head_dim // 2 * self.layer_num * torch._utils._element_size( self.kv_dtype diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py index 2a5aad7c8..951d72e2c 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py @@ -1,5 +1,5 @@ import torch - +from typing import Tuple, Any from .mem_manager import MemoryManager @@ -9,6 +9,28 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, self.group_quant_size = 8 
super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction) + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from ..basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import destindex_copy_quantize_kv + + destindex_copy_quantize_kv( + kv, + mem_index, + self.kv_buffer[layer_index], + self.scale_buffer[layer_index], + quant_group_dim=self.group_quant_size, + ) + return + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + k_scale = self.scale_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + v_scale = self.scale_buffer[layer_index][:, self.head_num :, :] + return (k, k_scale), (v, v_scale) + def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size( self.kv_dtype diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py new file mode 100644 index 000000000..348b14192 --- /dev/null +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -0,0 +1,188 @@ +from typing import List, Tuple, Union + +import torch +import numpy as np + +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.common.allocator_utils import TokenAllocator +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt + +logger = init_logger(__name__) + +MAMBA_CACHE_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_mamba_cache_can_use_num" + + +class LayerCache: + def __init__(self, size: int, dtype: torch.dtype, shape: Tuple[int, ...], layer_num: int): + self.size = size + self.dtype = dtype + self.shape = shape + self.layer_num = layer_num + + self.buffer = torch.zeros((self.layer_num, size + 1, *shape), dtype=dtype, device="cuda") + + def get_cell_size(self): + return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype) + + +class MambaCacheManager(TokenAllocator): + def __init__( + self, + size: int, + layer_num: int, + conv_state_dtype: torch.dtype, + conv_state_shape: Tuple[int, ...], + ssm_state_dtype: torch.dtype, + ssm_state_shape: Tuple[int, ...], + ): + super().__init__(size, f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") + self.conv_state_cache = LayerCache(size, conv_state_dtype, conv_state_shape, layer_num) + self.ssm_state_cache = LayerCache(size, ssm_state_dtype, ssm_state_shape, layer_num) + self.HOLD_BUFFER_INDEX = size + + logger.warning( + f"Linear attention state cache size: {size}\n" + f"Conv state use : " + f"{self.conv_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n" + f"Ssm state use : " + f"{self.ssm_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n" + ) + + def get_mamba_cache(self, layer_idx: int): + conv_state = self.conv_state_cache.buffer[layer_idx] + ssm_state = self.ssm_state_cache.buffer[layer_idx] + return conv_state, ssm_state + + def copy_buffer_p2p(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): + """ + Copy buffers from source indices to destination indices using optimized Triton kernel. 
+ + Args: + src_buffer_indexes: Source buffer indices (1D tensor) + dst_buffer_indexes: Destination buffer indices (1D tensor) + """ + assert src_buffer_indexes.dim() == 1 + assert dst_buffer_indexes.dim() == 1 + assert src_buffer_indexes.shape[0] == dst_buffer_indexes.shape[0] + + # Validate indices are within valid range [0, size] (size+1 is the buffer dim) + max_valid_idx = self.size # HOLD_BUFFER_INDEX = size is valid + src_max = src_buffer_indexes.max().item() if src_buffer_indexes.numel() > 0 else -1 + src_min = src_buffer_indexes.min().item() if src_buffer_indexes.numel() > 0 else -1 + dst_max = dst_buffer_indexes.max().item() if dst_buffer_indexes.numel() > 0 else -1 + dst_min = dst_buffer_indexes.min().item() if dst_buffer_indexes.numel() > 0 else -1 + + if src_min < 0 or src_max > max_valid_idx or dst_min < 0 or dst_max > max_valid_idx: + logger.error( + f"Invalid buffer indices: src=[{src_min}, {src_max}], dst=[{dst_min}, {dst_max}], " + f"valid range=[0, {max_valid_idx}], conv shape={self.conv_state_cache.buffer.shape}, " + f"ssm shape={self.ssm_state_cache.buffer.shape}" + ) + raise ValueError("Invalid buffer indices for copy_buffer_p2p") + + # Use PyTorch advanced indexing for buffer copy (safer than Triton for complex shapes) + # The buffer shape is [layer_num, buffer_size, *shape] + # We need to copy all layers for the given buffer indices + src_idx = src_buffer_indexes.long() + dst_idx = dst_buffer_indexes.long() + + # Copy conv_state: [layer_num, buffer_size, d1, d2] + self.conv_state_cache.buffer[:, dst_idx, ...] = self.conv_state_cache.buffer[:, src_idx, ...] + + # Copy ssm_state: [layer_num, buffer_size, d1, d2, d3] + self.ssm_state_cache.buffer[:, dst_idx, ...] = self.ssm_state_cache.buffer[:, src_idx, ...] + return + + def copy_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): + assert src_buffer_index.dim() == 1 + assert dst_buffer_indexes.dim() == 2 + assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0] + + # Use PyTorch advanced indexing for broadcast copy + # src_buffer_index: [num_src] + # dst_buffer_indexes: [num_src, num_dst_per_src] + src_idx = src_buffer_index.long() + dst_idx = dst_buffer_indexes.long() + + # Broadcast each source to all its destinations + # For each (src, dst_group), copy buffer[src] to buffer[dst1], buffer[dst2], ... + num_src, num_dst_per_src = dst_idx.shape + for i in range(num_src): + src = src_idx[i : i + 1] # Keep as 1D tensor with 1 element + dsts = dst_idx[i, :] # 1D tensor with num_dst_per_src elements + # Copy conv_state + self.conv_state_cache.buffer[:, dsts, ...] = self.conv_state_cache.buffer[:, src, ...] + # Copy ssm_state + self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...] + return + + def copy_ssm_buffer_broadcast(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): + """ + Broadcast ONLY SSM states (not conv states) from source indices to destination indices. + + This is used for MTP mode where each buffer maintains its own independent conv state, + but SSM states need to be synchronized. 
+ """ + assert src_buffer_index.dim() == 1 + assert dst_buffer_indexes.dim() == 2 + assert src_buffer_index.shape[0] == dst_buffer_indexes.shape[0] + + # Use PyTorch advanced indexing for SSM-only broadcast copy + src_idx = src_buffer_index.long() + dst_idx = dst_buffer_indexes.long() + + # Broadcast each source to all its destinations (SSM only) + num_src = dst_idx.shape[0] + for i in range(num_src): + src = src_idx[i : i + 1] + dsts = dst_idx[i, :] + # Only copy ssm_state, NOT conv_state + self.ssm_state_cache.buffer[:, dsts, ...] = self.ssm_state_cache.buffer[:, src, ...] + return + + def free(self, free_index: Union[torch.Tensor, List[int]]): + """ + Free the allocated cache buffers and clear them. + + Args: + free_index: Buffer indices to free (tensor or list of ints) + """ + # Convert to tensor if needed for indexing + if isinstance(free_index, list): + free_index_tensor = torch.tensor(free_index, dtype=torch.long, device="cuda") + else: + free_index_tensor = free_index.to(device="cuda", dtype=torch.long) + + # Clear the buffers for the freed indices + # Shape: [layer_num, buffer_index, *shape] + self.conv_state_cache.buffer[:, free_index_tensor, ...] = 0 + self.ssm_state_cache.buffer[:, free_index_tensor, ...] = 0 + + # Call parent's free method to update allocator state + super().free(free_index) + return + + +class ReadOnlyStaticsMambaCacheManager: + """ + 读取一些统计信息 + """ + + def __init__(self) -> None: + args = get_env_start_args() + self.global_world_size = args.tp + self.node_world_size = args.tp // args.nnodes + self.dp_world_size = self.global_world_size // args.dp + # 兼容多机 dp size=1 纯 tp 模式的情况 + self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 + self.shared_tp_can_use_token_nums = [ + SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") + for rank_in_node in range(0, self.node_world_size, self.dp_world_size) + ] + + def get_unrefed_token_num(self, dp_rank_in_node: int): + if self.is_multinode_tp: + return self.shared_tp_can_use_token_nums[0].get_value() + return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value() diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 40c8aa993..171ac27fa 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,8 +1,11 @@ import torch import collections +from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton from lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager from typing import List, Optional +from typing_extensions import override + from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args @@ -92,6 +95,18 @@ def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + def alloc_buffer_for_req(self, req_index: torch.Tensor): + """Allocate buffers for requests. No-op for standard models without linear attention.""" + pass + + def free_buffer(self, free_buffer_indexes): + """Free buffer memory. No-op for standard models without linear attention.""" + pass + + def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): + """Copy buffer state between requests. 
No-op for standard models without linear attention.""" + pass + class ReqSamplingParamsManager: """ @@ -236,3 +251,38 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List): p_token_counts_tensor.cuda(non_blocking=True), p_cumsum_seq_len_tensor.cuda(non_blocking=True), ) + + +class ReqManagerForMamba(ReqManager): + def __init__(self, max_request_num, max_sequence_length, mem_manager): + from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager + + super().__init__(max_request_num, max_sequence_length, mem_manager) + self.mtp_step = get_env_start_args().mtp_step + self.buffer_mem_manager: MambaCacheManager = self.mem_manager.mamba_cache_mem_manager + self.req_to_buffer_index = torch.zeros( + (self.max_request_num + 1, self.mtp_step + 1), dtype=torch.int32, device="cuda" + ) + self.req_to_buffer_index[self.HOLD_REQUEST_ID, :] = self.buffer_mem_manager.HOLD_BUFFER_INDEX + + @override + def free_buffer(self, free_buffer_indexes: List[int]): + self.buffer_mem_manager.free(free_buffer_indexes) + return + + @override + def alloc_buffer_for_req(self, req_index: torch.Tensor): + num_reqs = req_index.shape[0] + num_buffers_per_req = self.mtp_step + 1 + buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req) + alloc_buffer_for_req_triton(req_index, buffer_indexes, self.req_to_buffer_index, self.mtp_step) + + @override + def copy_buffer_from_another_buffer(self, src_buffer_index: torch.Tensor, tgt_req_index: torch.Tensor): + # 获取目标请求的所有 MTP buffer (从 buffer[0] 到 buffer[mtp_step]) + mtp_range = torch.arange(0, self.mtp_step + 1, dtype=torch.int32, device="cuda") + all_mtp_buffers = self.req_to_buffer_index[tgt_req_index[:, None], mtp_range[None, :]] + + # 将 shared buffer 广播到所有 MTP step + self.buffer_mem_manager.copy_buffer_broadcast(src_buffer_index, all_mtp_buffers) + return diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..4b002622a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..cc5c68eb7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..7421097fa --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ 
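
On the new Mamba-cache bookkeeping: ``ReqManagerForMamba`` keeps a ``req_to_buffer_index`` table of shape ``[max_request_num + 1, mtp_step + 1]`` and fans a shared source buffer out to every MTP slot of a request via ``MambaCacheManager.copy_buffer_broadcast``. The sketch below shows only that indexing pattern; the sequential allocator, the sizes, and the bare ``ssm_state`` tensor are stand-ins for the real ``TokenAllocator``, the ``alloc_buffer_for_req_triton`` kernel, and the CUDA-resident state caches.

.. code-block:: python

    import torch

    max_request_num, mtp_step = 4, 1          # illustrative sizes, not lightllm defaults
    layer_num, state_dim = 2, 3
    num_slots = max_request_num * (mtp_step + 1) + 1

    # One row of cache slots per request; column j is the slot used at MTP step j.
    req_to_buffer_index = torch.zeros((max_request_num + 1, mtp_step + 1), dtype=torch.int32)
    ssm_state = torch.zeros((layer_num, num_slots + 1, state_dim))

    def alloc_buffer_for_req(req_index: torch.Tensor, next_free: int) -> int:
        # Sequential stand-in for TokenAllocator.alloc + alloc_buffer_for_req_triton.
        n = req_index.numel() * (mtp_step + 1)
        slots = torch.arange(next_free, next_free + n, dtype=torch.int32)
        req_to_buffer_index[req_index.long()] = slots.view(-1, mtp_step + 1)
        return next_free + n

    def copy_buffer_broadcast(src_slots: torch.Tensor, dst_slots: torch.Tensor):
        # Same advanced-indexing pattern as MambaCacheManager.copy_buffer_broadcast:
        # fan each source slot's state out to all destination slots of its request.
        for i in range(src_slots.shape[0]):
            src = src_slots[i : i + 1].long()     # keep 1D so the copy broadcasts as [layer, 1, dim]
            ssm_state[:, dst_slots[i].long(), :] = ssm_state[:, src, :]

    next_free = alloc_buffer_for_req(torch.tensor([0, 1]), next_free=1)
    ssm_state[:, 0, :] = 1.0                      # slot 0 plays the role of a shared source buffer
    copy_buffer_broadcast(torch.tensor([0]), req_to_buffer_index[0].unsqueeze(0))
    assert torch.all(ssm_state[:, req_to_buffer_index[0].long(), :] == 1.0)
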
+ "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..d831f32c4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 000000000..354a6f93a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,41 @@ +{ + "1": { + "num_warps": 4 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 4 + }, + "128": { + "num_warps": 1 + }, + "16": { + "num_warps": 4 + }, + "164096": { + "num_warps": 1 + }, + "2048": { + "num_warps": 2 + }, + "256": { + "num_warps": 1 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 8 + }, + "8": { + "num_warps": 8 + }, + "8448": { + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 000000000..9fbae2414 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4": { + "BK": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..d00af04ca --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,54 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "100": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 8, + "num_warps": 1 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "16": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "164096": { + "BLK_HEADS": 8, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 32, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 8, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "64": { + 
"BLK_HEADS": 64, + "num_warps": 2 + }, + "8": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "8448": { + "BLK_HEADS": 32, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..84c47d348 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,54 @@ +{ + "1024": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "128": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "1312768": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2048": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "256": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "32768": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 512, + "num_warps": 4 + }, + "64": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "67584": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "8": { + "BLOCK_N": 512, + "num_warps": 8 + }, + "800": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "8192": { + "BLOCK_N": 128, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..3fd0050d7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 4096, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..3863d48e8 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 256, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..fde50e757 --- 
/dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "84480": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..612f2b51e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + 
"BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json new file mode 100644 index 000000000..5923f3164 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,54 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "8448": { + "BLOCK_SIZE": 128, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json new file mode 100644 index 000000000..4d6191579 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, 
+ "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 16 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..0b3aa1e36 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,152 @@ +{ + "1": { + "BLOCK_M": 32, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "10": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "1000": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "1280": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "164096": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "20480": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "40960": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "8448": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "84480": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at 
end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 000000000..9f44ee6c3 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1,22 @@ +{ + "16_8192": { + "BLOCK_N": 32, + "num_stages": 2, + "num_warps": 4 + }, + "32_8192": { + "BLOCK_N": 16, + "num_stages": 2, + "num_warps": 2 + }, + "64_8192": { + "BLOCK_N": 16, + "num_stages": 2, + "num_warps": 2 + }, + "8_8192": { + "BLOCK_N": 32, + "num_stages": 2, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 000000000..4fa2f949f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1,22 @@ +{ + "16_8192": { + "BLOCK_N": 32, + "num_stages": 2, + "num_warps": 4 + }, + "32_8192": { + "BLOCK_N": 32, + "num_stages": 4, + "num_warps": 4 + }, + "64_8192": { + "BLOCK_N": 16, + "num_stages": 3, + "num_warps": 2 + }, + "8_8192": { + "BLOCK_N": 16, + "num_stages": 2, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..cc5c68eb7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..fc9fc9a4a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..7421097fa --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..d831f32c4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 000000000..d033b820d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,47 @@ +{ + "1": { + "num_warps": 4 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 8 + }, + "128": { + "num_warps": 2 + }, + "131072": { + "num_warps": 1 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "16640": { + "num_warps": 1 + }, + "2048": { + "num_warps": 1 + }, + "256": { + "num_warps": 1 + }, + "32": { + "num_warps": 2 + }, + "4096": { + "num_warps": 1 + }, + "64": { + "num_warps": 4 + }, + "8": { + "num_warps": 8 + }, + "8448": { + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 000000000..9fbae2414 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4": { + "BK": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..2956daba2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ 
-0,0 +1,62 @@ +{ + "1": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "100": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "131072": { + "BLK_HEADS": 32, + "num_warps": 4 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "16384": { + "BLK_HEADS": 32, + "num_warps": 4 + }, + "16640": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "256": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "32": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "8448": { + "BLK_HEADS": 16, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..62ee466d2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,62 @@ +{ + "1024": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "1048576": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "128": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "131072": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "133120": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "16384": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "2048": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "256": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "32768": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 1024, + "num_warps": 2 + }, + "64": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "67584": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "8": { + "BLOCK_N": 1024, + "num_warps": 4 + }, + "800": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "8192": { + "BLOCK_N": 128, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..198a196df --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..356896ebe --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 256, + "num_stages": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..7a9c59625 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,137 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "163840": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "166400": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "327680": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "84480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ 
No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..07951d5fd --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,137 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json new file mode 100644 index 000000000..346c9f439 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,62 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16640": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "32768": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "8448": { + "BLOCK_SIZE": 256, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json new file mode 100644 index 000000000..5ad840128 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,92 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 32, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "16640": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32768": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..9cac56fe9 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,182 @@ +{ + "1": { + "BLOCK_M": 8, + "BLOCK_N": 32, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "10": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "1000": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "1280": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "131072": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "163840": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16640": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "166400": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "20480": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "327680": { + "BLOCK_M": 128, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "40960": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "8448": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "84480": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 4ee02f003..9cf5864a9 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -1,4 +1,3 @@ -from lightllm.models.cohere.model import CohereTpPartModel from lightllm.models.mixtral.model import MixtralTpPartModel from lightllm.models.bloom.model import BloomTpPartModel from lightllm.models.llama.model import LlamaTpPartModel @@ -8,7 +7,8 @@ from lightllm.models.qwen2.model import Qwen2TpPartModel from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.models.qwen3_moe.model import Qwen3MOEModel -from lightllm.models.chatglm2.model import ChatGlm2TpPartModel +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from 
lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel from lightllm.models.internlm2.model import Internlm2TpPartModel diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index 7938869f5..f4fff116c 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -10,9 +10,9 @@ class BloomPostLayerInfer(LlamaPostLayerInfer): """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): repair_config(config=network_config, same_names=["layer_norm_epsilon", "rms_norm_eps"]) - super().__init__(network_config, mode) + super().__init__(network_config) return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: diff --git a/lightllm/models/bloom/layer_infer/pre_layer_infer.py b/lightllm/models/bloom/layer_infer/pre_layer_infer.py index baf1d3084..dfe396ab5 100644 --- a/lightllm/models/bloom/layer_infer/pre_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/pre_layer_infer.py @@ -9,8 +9,8 @@ class BloomPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = network_config["layer_norm_epsilon"] return diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index d82a23d03..808788f71 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -2,16 +2,15 @@ from typing import Tuple from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight -from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd -from lightllm.models.bloom.triton_kernel.token_flashattention_nopad import token_attention_fwd from lightllm.common.basemodel import InferStateInfo +from lightllm.common.basemodel.attention.base_att import AttControl class BloomTransformerLayerInfer(TransformerLayerInferTpl): """ """ - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.eps_ = network_config["layer_norm_epsilon"] self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ self.tp_k_head_num_ = self.tp_q_head_num_ @@ -21,6 +20,40 @@ def __init__(self, layer_num, network_config, mode): self.embed_dim_ = network_config["n_embed"] return + def _context_attention_kernel( + self, + q: torch.Tensor, + kv: torch.Tensor, + infer_state: InferStateInfo, + layer_weight: BloomTransformerLayerWeight, + out=None, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl(use_alibi=True, tp_alibi=layer_weight.tp_alibi), + alloc_func=self.alloc_tensor, + ) + o_tensor = o_tensor.view(q.shape) + return o_tensor + + def _token_attention_kernel( + self, q: torch.Tensor, infer_state: 
InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl(use_alibi=True, tp_alibi=layer_weight.tp_alibi), + alloc_func=self.alloc_tensor, + ) + return o_tensor.view(q.shape) + def _att_norm( self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight ) -> torch.Tensor: @@ -42,47 +75,6 @@ def _get_qkv( cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) return q, cache_kv - def _context_attention_kernel( - self, q, kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_req_idx, - layer_weight.tp_alibi, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - ) - return o_tensor - - def _token_attention_kernel( - self, q, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - token_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - layer_weight.tp_alibi, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - infer_state.total_token_num, - alloc_tensor_func=self.alloc_tensor, - ) - return o_tensor - def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: o_tensor = layer_weight.o_proj.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_)) return o_tensor diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index afc8c9308..83f767453 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,8 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.pre_norm_weight_ = NoTpNormWeight( weight_name="word_embeddings_layernorm.weight", data_type=self.data_type_, diff --git a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py index 7b27ce6f2..599893655 100644 --- a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py 
@@ -48,8 +48,8 @@ def get_slopes_power_of_2(n): class BloomTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode, quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index 7e44ec2eb..925620bf9 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -5,6 +5,7 @@ from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight from lightllm.common.basemodel import InferStateInfo, TpPartBaseModel +from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend @ModelRegistry("bloom") @@ -35,3 +36,8 @@ def _init_config(self): def _reset_num_key_value_heads(self): self.config["num_key_value_heads"] = self.config["num_attention_heads"] return + + def _init_att_backend(self): + self.prefill_att_backend = TritonAttBackend(self) + self.decode_att_backend = TritonAttBackend(self) + return diff --git a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 07ffc4bea..000000000 --- a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight - - -class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer): - """ """ - - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) - return - - def swiglu(self, x): - x = torch.chunk(x, 2, dim=-1) - return torch.nn.functional.silu(x[0]) * x[1] - - def _ffn( - self, input, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2TransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.embed_dim_) - up_gate_out = layer_weight.gate_up_proj.mm(input) - input = None - ffn1_out = self.swiglu(up_gate_out) - up_gate_out = None - ffn2_out = layer_weight.down_proj.mm(ffn1_out) - ffn1_out = None - return ffn2_out diff --git a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index 0139eb883..000000000 --- a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,20 +0,0 @@ -from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight - - -class ChatGLM2PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) - - self.wte_weight_ = EmbeddingWeight( - weight_name="transformer.embedding.word_embeddings.weight", data_type=self.data_type_ - ) - self.lm_head_weight_ = LMHeadWeight( - weight_name="transformer.output_layer.weight", - data_type=self.data_type_, - ) - self.final_norm_weight_ = NoTpNormWeight( - 
weight_name="transformer.encoder.final_layernorm.weight", - data_type=self.data_type_, - bias_name=None, - ) diff --git a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py b/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py deleted file mode 100755 index d4dd1b7a2..000000000 --- a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py +++ /dev/null @@ -1,72 +0,0 @@ -from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight - - -class ChatGLM2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__( - layer_num, - data_type, - network_config, - mode, - quant_cfg, - ) - return - - def _preprocess_weight(self, weights): - n_kv_embed = self.head_dim * self.n_kv_head - qkv_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight" - if qkv_weight_name in weights: - qkv_weight_ = weights[qkv_weight_name] - weights[self._q_weight_name] = qkv_weight_[: self.n_embed, :] - weights[self._k_weight_name] = qkv_weight_[self.n_embed : self.n_embed + n_kv_embed, :] - weights[self._v_weight_name] = qkv_weight_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed, :] - del weights[qkv_weight_name] - - qkv_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.bias" - if qkv_bias_name in weights: - qkv_bias_ = weights[qkv_bias_name] - weights[self._q_bias_name] = qkv_bias_[: self.n_embed] - weights[self._k_bias_name] = qkv_bias_[self.n_embed : self.n_embed + n_kv_embed] - weights[self._v_bias_name] = qkv_bias_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed] - del weights[qkv_bias_name] - - gate_up_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_h_to_4h.weight" - if gate_up_weight_name in weights: - gate_up_weight_ = weights[gate_up_weight_name] - weights[self._gate_weight_name] = gate_up_weight_[: self.n_inter, :] - weights[self._up_weight_name] = gate_up_weight_[self.n_inter : 2 * self.n_inter, :] - del weights[gate_up_weight_name] - - def _parse_config(self): - self.n_embed = self.network_config_["hidden_size"] - self.n_head = self.network_config_["num_attention_heads"] - self.n_inter = self.network_config_["ffn_hidden_size"] - self.n_kv_head = self.network_config_["multi_query_group_num"] - self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head) - - def load_hf_weights(self, weights): - self._preprocess_weight(weights) - super().load_hf_weights(weights) - return - - def _init_weight_names(self): - self._q_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.q_proj.weight" - self._q_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.q_proj.bias" - self._k_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.k_proj.weight" - self._k_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.k_proj.bias" - self._v_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.v_proj.weight" - self._v_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.v_proj.bias" - self._o_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight" - self._o_bias_name = None - - self._gate_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.gate_proj.weight" - self._gate_bias_name = None - self._up_weight_name = 
f"transformer.encoder.layers.{self.layer_num_}.mlp.up_proj.weight" - self._up_bias_name = None - self._down_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight" - self._down_bias_name = None - - self._att_norm_weight_name = f"transformer.encoder.layers.{self.layer_num_}.input_layernorm.weight" - self._att_norm_bias_name = None - self._ffn_norm_weight_name = f"transformer.encoder.layers.{self.layer_num_}.post_attention_layernorm.weight" - self._ffn_norm_bias_name = None diff --git a/lightllm/models/chatglm2/model.py b/lightllm/models/chatglm2/model.py deleted file mode 100644 index e6aa39527..000000000 --- a/lightllm/models/chatglm2/model.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -import json -import torch - -from lightllm.models.registry import ModelRegistry -from lightllm.models.chatglm2.layer_infer.transformer_layer_infer import ChatGLM2TransformerLayerInfer -from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight -from lightllm.models.chatglm2.layer_weights.pre_and_post_layer_weight import ChatGLM2PreAndPostLayerWeight -from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.common.build_utils import repair_config -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -@ModelRegistry("chatglm") -class ChatGlm2TpPartModel(LlamaTpPartModel): - # Please use the fast tokenizer from: - # [THUDM/chatglm3-6b PR #12](https://huggingface.co/THUDM/chatglm3-6b/discussions/12). - - # weight class - pre_and_post_weight_class = ChatGLM2PreAndPostLayerWeight - transformer_weight_class = ChatGLM2TransformerLayerWeight - - # infer class - transformer_layer_infer_class = ChatGLM2TransformerLayerInfer - - def __init__(self, kvargs): - super().__init__(kvargs) - - def _init_config(self): - super()._init_config() - # rename key - # repair_config() - repair_config(self.config, same_names=["num_hidden_layers", "n_layer", "num_layers"]) - repair_config(self.config, same_names=["vocab_size", "padded_vocab_size"]) - repair_config(self.config, same_names=["rms_norm_eps", "layernorm_epsilon"]) - repair_config(self.config, same_names=["seq_length", "max_sequence_length"]) - return - - def _reset_num_key_value_heads(self): - self.config["num_key_value_heads"] = self.config["multi_query_group_num"] - return - - def _verify_params(self): - assert self.load_way == "HF", "ChatGLM only support HF format for now" - assert self.tp_world_size_ in [1, 2], "ChatGLM can only run in tp=1 or tp=2 for now" - - def _init_to_get_rotary(self, base=10000): - if self.config.get("rope_scaling", {}) is None: - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) - if "max_sequence_length" in self.config: - max_seq_len = self.config["max_sequence_length"] - else: - max_seq_len = self.config.get("max_position_embeddings", 2048) * rope_scaling_factor - - base = float(base) * self.config.get("rope_ratio", 1.0) - - # NTK - try: - ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) - assert ntk_alpha >= 1 - if ntk_alpha > 1: - logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula - except: - pass - n_elem = self.head_dim_ // 2 - inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / 
rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() - return diff --git a/lightllm/models/chatglm2/triton_kernel/rotary_emb.py b/lightllm/models/chatglm2/triton_kernel/rotary_emb.py deleted file mode 100755 index ad1d1c2cf..000000000 --- a/lightllm/models/chatglm2/triton_kernel/rotary_emb.py +++ /dev/null @@ -1,160 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - Q, - K, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - HEAD_Q, - HEAD_K, # N_CTX 代表要计算的上下文长度 - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = dim_range0 + 1 - - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - cos_range = tl.arange(0, BLOCK_DMODEL // 2) - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) - - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - k0 = tl.load( - K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - k1 = tl.load( - K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos - - tl.store( - K + off_k0, - out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - tl.store( 
- K + off_k1, - out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - return - - -@torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin): - total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] - head_dim = q.shape[2] // 2 - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" - - BLOCK_SEQ = 16 - BLOCK_HEAD = 4 - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - _rotary_kernel[grid]( - q, - k, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num_q, - head_num_k, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/lightllm/models/cohere/infer_struct.py b/lightllm/models/cohere/infer_struct.py deleted file mode 100644 index d9571af92..000000000 --- a/lightllm/models/cohere/infer_struct.py +++ /dev/null @@ -1,8 +0,0 @@ -from lightllm.models.llama.infer_struct import LlamaInferStateInfo - - -class CohereInferStateInfo(LlamaInferStateInfo): - def __init__(self): - super().__init__() - self._attn_out = None - self._ffn_out = None diff --git a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py deleted file mode 100644 index 67987a8d3..000000000 --- a/lightllm/models/cohere/layer_infer/post_layer_infer.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import numpy as np -from lightllm.models.cohere.infer_struct import CohereInferStateInfo -from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight -from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward -from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.common.build_utils import repair_config -from lightllm.distributed.communication_op import all_gather - - -class CoherePostLayerInfer(LlamaPostLayerInfer): - def __init__(self, network_config, mode): - repair_config(config=network_config, same_names=["layer_norm_eps", "rms_norm_eps"]) - super().__init__(network_config, mode) - self.eps_ = network_config["layer_norm_eps"] - self.logits_scale = network_config["logit_scale"] - return - - def _norm( - self, input: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ) -> torch.Tensor: - return layernorm_forward( - input.unsqueeze(1), layer_weight.final_norm_weight_.weight.unsqueeze(0), eps=self.eps_ - ).squeeze(1) - - def token_forward( - self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ): - last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) - input_embdings_dtype = input_embdings.dtype - input_embdings = None - last_input = self._norm(last_input, infer_state, layer_weight) - last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = layer_weight.lm_head_weight_.lm_head(input=last_input, alloc_func=self.alloc_tensor) - last_input = None - vocab_size = layer_weight.lm_head_weight_.vocab_size - if self.tp_world_size_ == 1: - gather_data = logic_batch - else: - gather_data = 
self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) - split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) - all_gather( - [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], - logic_batch, - group=infer_state.dist_group, - async_op=False, - ) - gather_data = gather_data * self.logits_scale - logic_batch = None - ans_logics = self.alloc_tensor( - (token_num, vocab_size), - dtype=torch.float32, - ) - ans_logics[:, :] = gather_data.permute(1, 0) - gather_data = None - return ans_logics - - def tpsp_token_forward( - self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ): - raise NotImplementedError("not impl") - - def overlap_tpsp_token_forward( - self, - input_embdings: torch.Tensor, - input_embdings1: torch.Tensor, - infer_state: CohereInferStateInfo, - infer_state1: CohereInferStateInfo, - layer_weight: CoherePreAndPostLayerWeight, - ): - raise NotImplementedError("not impl") diff --git a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py deleted file mode 100644 index 0cdd281a3..000000000 --- a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -from functools import partial - -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import ( - TransformerLayerCohereInferTpl, -) -from lightllm.models.cohere.infer_struct import CohereInferStateInfo -from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight -from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm -from lightllm.models.cohere.triton_kernels.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd - - -class CohereTransformerLayerInfer(TransformerLayerCohereInferTpl): - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ - self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_ - self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_ - self.tp_o_head_num_ = self.tp_q_head_num_ - self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] - self.embed_dim_ = network_config["hidden_size"] - self.eps_ = self.network_config_["layer_norm_eps"] - self.use_qk_norm_ = network_config.get("use_qk_norm", False) - self._bind_func() - - def _bind_func(self): - self._bind_rotary_emb_fwd() - self._bind_norm() - self._bind_attn() - - def _rotary_emb_fwd(self, q, kv, position_cos, position_sin): - return rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv, - position_cos, - position_sin, - ) - - def _bind_rotary_emb_fwd(self): - self._rotary_emb_fwd = partial(CohereTransformerLayerInfer._rotary_emb_fwd, self) - - def _att_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): - return layernorm_forward( - input.unsqueeze(1), layer_weight.att_norm_weight_.weight.unsqueeze(0), self.eps_ - ).squeeze(1) - - def _q_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): - return layernorm_forward(input, 
layer_weight.q_norm_weight_.weight, self.eps_) - - def _k_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): - return layernorm_forward(input, layer_weight.k_norm_weight_.weight, self.eps_) - - def _bind_norm(self): - self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self) - self._q_norm = partial(CohereTransformerLayerInfer._q_norm, self) - self._k_norm = partial(CohereTransformerLayerInfer._k_norm, self) - - def _bind_attn(self): - # no need to re-impl - LlamaTransformerLayerInfer._bind_attention(self) - - def _get_o( - self, input, infer_state: CohereInferStateInfo, layer_weight: CohereTransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - # o_tensor = layer_weight.mm_op.apply(input, layer_weight.o_weight_) - o_tensor = layer_weight.o_proj.mm(input) - return o_tensor - - def _ffn( - self, input, infer_state: CohereInferStateInfo, layer_weight: CohereTransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.embed_dim_) - up_gate_out = layer_weight.gate_up_proj.mm(input) - ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) - silu_and_mul_fwd(up_gate_out, ffn1_out) - input = None - up_gate_out = None - # ffn2_out = layer_weight.mm_op.apply(ffn1_out, layer_weight.down_proj) - ffn2_out = layer_weight.down_proj.mm(ffn1_out) - ffn1_out = None - return ffn2_out diff --git a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index f2e5f8547..000000000 --- a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,25 +0,0 @@ -from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight - - -class CoherePreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) - tie_weight = self.network_config_.get("tie_word_embeddings", True) - - self.wte_weight_ = EmbeddingWeight( - weight_name="model.embed_tokens.weight", - data_type=self.data_type_, - ) - if tie_weight: - self.lm_head_weight_ = self.wte_weight_ - else: - self.lm_head_weight_ = LMHeadWeight( - weight_name="model.lm_head.weight", - data_type=self.data_type_, - ) - self.final_norm_weight_ = NoTpNormWeight( - weight_name="model.norm.weight", - data_type=self.data_type_, - bias_name=None, - ) diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py deleted file mode 100644 index 9c446b49e..000000000 --- a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py +++ /dev/null @@ -1,25 +0,0 @@ -from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight, TpHeadNormWeight - - -class CohereTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) - return - - def _parse_config(self): - super()._parse_config() - self.use_qk_norm = self.network_config_.get("use_qk_norm", False) - - def _init_norm(self): - self.att_norm_weight_ = NoTpNormWeight(self._att_norm_weight_name, self.data_type_) - - if self.use_qk_norm: 
- self.q_norm_weight_ = TpHeadNormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_ - ) - self.k_norm_weight_ = TpHeadNormWeight( - f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_ - ) - - return diff --git a/lightllm/models/cohere/model.py b/lightllm/models/cohere/model.py deleted file mode 100644 index 5b317c133..000000000 --- a/lightllm/models/cohere/model.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import torch -from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import ( - TransformerLayerCohereInferTpl, -) -from lightllm.common.kv_cache_mem_manager import MemoryManager -from lightllm.models.registry import ModelRegistry -from lightllm.models.cohere.infer_struct import CohereInferStateInfo -from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer -from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer -from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight -from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight -from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -@ModelRegistry("cohere") -class CohereTpPartModel(LlamaTpPartModel): - pre_and_post_weight_class = CoherePreAndPostLayerWeight - transformer_weight_class = CohereTransformerLayerWeight - - pre_layer_infer_class = LlamaPreLayerInfer - transformer_layer_infer_class = CohereTransformerLayerInfer - post_layer_infer_class = CoherePostLayerInfer - - infer_state_class = CohereInferStateInfo - - def _init_to_get_rotary(self, default_base=10000): - partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) - if self.config.get("rope_scaling", {}) is None: - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) - - base = self.config.get("rope_theta", float(default_base)) - - if "max_sequence_length" in self.config: - max_seq_len = self.config["max_sequence_length"] - else: - max_position_embeddings = self.config.get( - "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384 - ) - max_seq_len = max_position_embeddings * rope_scaling_factor - - # NTK - try: - ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) - assert ntk_alpha >= 1 - if ntk_alpha > 1: - logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula - except: - pass - - inv_freq = 1.0 / ( - base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) - ) - t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - freqs = torch.repeat_interleave(freqs, 2, dim=-1) - - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() - return diff --git a/lightllm/models/cohere/triton_kernels/layernorm.py b/lightllm/models/cohere/triton_kernels/layernorm.py deleted file mode 100644 index c1d5ff4cd..000000000 --- a/lightllm/models/cohere/triton_kernels/layernorm.py +++ 
/dev/null @@ -1,131 +0,0 @@ -import torch -import triton -import triton.language as tl - -# LayerNorm adapted from triton tutorial, used for Cohere q, k norm -# X [N, head_num, head_dim] -# W [head_num, head_dim] -@triton.jit -def _layer_norm_fwd_kernel( - X, # pointer to the input - W, # pointer to the weights - Y, - stride_x_N, - stride_x_hn, - stride_x_hd, - stride_y_N, - stride_y_hn, - stride_y_hd, - stride_w_hn, - stride_w_hd, - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, -): - Seq = tl.program_id(0) - H = tl.program_id(1) - - X += Seq * stride_x_N + H * stride_x_hn - Y += Seq * stride_y_N + H * stride_y_hn - W += H * stride_w_hn - - _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - _mean += a - mean = tl.sum(_mean, axis=0) / N - - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.0) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = (x - mean) * rstd - y = x_hat * w - - tl.store(Y + cols, y.to(X.dtype.element_ty), mask=mask) - - -def layernorm_forward( - X, # pointer to the input - W, # pointer to the weights - eps, # epsilon to avoid division by zero -): - assert len(X.shape) == 3 - assert len(W.shape) == 2 - assert X.shape[-1] == W.shape[-1] - assert X.shape[-2] == W.shape[-2] - - y = torch.empty_like(X) - - stride_x_N = X.stride(0) - stride_x_hn = X.stride(1) - stride_x_hd = X.stride(2) - - stride_y_N = y.stride(0) - stride_y_hn = y.stride(1) - stride_y_hd = y.stride(2) - - stride_w_hn = W.stride(0) - stride_w_hd = W.stride(1) - - N = X.shape[-1] - BLOCK_SIZE = 128 - - grid = (X.shape[0], X.shape[1]) - _layer_norm_fwd_kernel[grid]( - X, - W, - y, - stride_x_N, - stride_x_hn, - stride_x_hd, - stride_y_N, - stride_y_hn, - stride_y_hd, - stride_w_hn, - stride_w_hd, - N, - eps, - BLOCK_SIZE, - ) - - return y - - -def torch_layernorm(x, weight, eps): - inp_dtype = x.dtype - x = x.to(torch.float32) - mean = x.mean(-1, keepdim=True) - variance = (x - mean).pow(2).mean(-1, keepdim=True) - x = (x - mean) * torch.rsqrt(variance + eps) - x = weight.to(torch.float32) * x - return x.to(inp_dtype) - - -def test_layernorm(eps=1e-5): - # create data - dtype = torch.float16 - x_shape = (5, 1, 128) - w_shape = (x_shape[-2], x_shape[-1]) - weight = torch.rand(w_shape, dtype=dtype, device="cuda") - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - # forward pass - y_ref = torch_layernorm(x, weight, eps).to(dtype) - y_out = layernorm_forward(x, weight, eps) - - # compare - print("type:", y_out.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_out - y_ref))) - assert torch.allclose(y_out, y_ref, atol=1e-2, rtol=0) - return diff --git a/lightllm/models/cohere/triton_kernels/rotary_emb.py b/lightllm/models/cohere/triton_kernels/rotary_emb.py deleted file mode 100644 index ac338e71e..000000000 --- a/lightllm/models/cohere/triton_kernels/rotary_emb.py +++ /dev/null @@ -1,199 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit 
-def _rotary_kernel( - Q, - K, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - HEAD_Q, - HEAD_K, # N_CTX 代表要计算的上下文长度 - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + 1 - - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - - cos0 = tl.load(Cos + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos0 - q1 * sin0 - out1 = q0 * sin1 + q1 * cos1 - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) - - off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd - - k0 = tl.load( - K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - k1 = tl.load( - K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - - cos0 = tl.load(Cos + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out_k0 = k0 * cos0 - k1 * sin0 - out_k1 = k0 * sin1 + k1 * cos1 - - tl.store( - K + off_k0, - out_k0, - mask=(cur_seq_range[:, 
None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - tl.store( - K + off_k1, - out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - return - - -def torch_cohere_rotary_emb(x, cos, sin): - dtype = x.dtype - seq_len, h, dim = x.shape - x = x.float() - x1 = x[:, :, ::2] - x2 = x[:, :, 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - cos = cos.view((seq_len, 1, dim)) - sin = sin.view((seq_len, 1, dim)) - o = (x * cos) + (rot_x * sin) - return o.to(dtype=dtype) - - -@torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0): - total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] - head_dim = int(q.shape[2] * partial_rotary_factor) - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" - - BLOCK_SEQ = 16 - BLOCK_HEAD = 4 - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - _rotary_kernel[grid]( - q, - k, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num_q, - head_num_k, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test_rotary_emb(SEQ_LEN, H, D, dtype, eps=1e-5, device="cuda"): - # create data - x_shape = (SEQ_LEN, H, D) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - y = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - cos_shape = (SEQ_LEN, D) - cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - # forward pass - y_tri = torch_cohere_rotary_emb(x, cos, sin) - rotary_emb_fwd(x, y, cos, sin) - y_ref = x - - # compare - print("type:", y_tri.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_tri - y_ref))) - assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py deleted file mode 100644 index 72ba8a43b..000000000 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.utils.dist_utils import get_current_device_id -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy - - -class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): - _shared_page_table_buffer = None - - def __init__(self): - super().__init__() - - @classmethod - def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): - if cls._shared_page_table_buffer is None: - cls._shared_page_table_buffer = [ - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - ] - return cls._shared_page_table_buffer - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - args_mtp_step = 
get_env_start_args().mtp_step - if self.is_prefill: - self.cu_seqlens_q = self.b1_cu_q_seq_len - self.cu_seqlens_k = self.b1_cu_kv_seq_len - self.has_prefix_kv = self.max_cache_len > 0 - if self.has_prefix_kv: - self.cu_seqlens_prefix_k = torch.nn.functional.pad( - torch.cumsum(self.b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0) - ) - self.prefix_k_max_len = self.max_cache_len - self.prefix_total_token_num = self.prefix_total_token_num - else: - # Meta information of flashattention for decoding - self.cu_seqlens_q = self.b1_cu_q_seq_len - self.cu_seqlens_k = self.b1_cu_kv_seq_len - max_seq_len_k = self.max_kv_seq_len - att_batch_size = self.batch_size // (args_mtp_step + 1) - if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch - ) - self.page_table = page_buffer[self.microbatch_index][ - : att_batch_size * model.graph_max_len_in_batch - ].view(att_batch_size, model.graph_max_len_in_batch) - else: - self.page_table = torch.empty((att_batch_size, self.max_len_in_batch), dtype=torch.int32).to( - self.input_ids.device - ) - page_table_copy( - page_table=self.page_table[:, :max_seq_len_k], - req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], - ) - if args_mtp_step > 0: - self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() - else: - self.b_att_seq_len = self.b_seq_len - return diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py deleted file mode 100644 index db6386f79..000000000 --- a/lightllm/models/deepseek2/flashinfer_struct.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index - - -class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo): - def __init__(self): - super().__init__() - self.prefill_wrapper = None - self.decode_wrapper = None - self.flashinfer_extra_state = None - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - self.flashinfer_extra_state = model.flashinfer_extra_state - - import flashinfer - - if not self.is_prefill: - if get_env_start_args().enable_flashinfer_decode: - self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(self.input_ids.device) - if self.batch_size <= model.graph_max_batch_size: - self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length - ] - else: - self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, - dtype=torch.int32, - device=self.input_ids.device, - ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) - if self.decode_wrapper is None: - self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - self.flashinfer_extra_state.workspace_buffer, - use_cuda_graph=True, - qo_indptr=self.q_indptr, - kv_indices=self.kv_indices, - kv_indptr=self.kv_starts, - kv_len_arr=self.b_seq_len, - ) - self.decode_wrapper.plan( - 
self.q_indptr, - self.kv_starts, - self.kv_indices, - self.b_seq_len, - self.flashinfer_extra_state.tp_q_head_num, - self.flashinfer_extra_state.kv_lora_rank, - self.flashinfer_extra_state.qk_rope_head_dim, - 1, - False, # causal - self.flashinfer_extra_state.softmax_scale, - self.flashinfer_extra_state.q_data_type, - self.flashinfer_extra_state.kv_data_type, - ) - else: - if get_env_start_args().enable_flashinfer_prefill: - q_starts = self.b1_cu_q_seq_len.int() - kv_starts = self.b1_kv_start_loc.int() - if self.prefill_wrapper is None: - self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - self.flashinfer_extra_state.workspace_buffer, "NHD" - ) - self.prefill_wrapper.plan( - qo_indptr=q_starts, - kv_indptr=kv_starts, - num_qo_heads=self.flashinfer_extra_state.tp_q_head_num, - num_kv_heads=self.flashinfer_extra_state.tp_q_head_num, - head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim - + self.flashinfer_extra_state.qk_rope_head_dim, - head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim, - q_data_type=self.flashinfer_extra_state.q_data_type, - causal=True, - sm_scale=self.flashinfer_extra_state.softmax_scale, - ) - return - - def copy_for_cuda_graph(self, new_infer_state): - super().copy_for_cuda_graph(new_infer_state) - if get_env_start_args().enable_flashinfer_decode and not self.is_prefill: - self.decode_wrapper.plan( - new_infer_state.q_indptr, - new_infer_state.kv_starts, - new_infer_state.kv_indices, - new_infer_state.b_seq_len, - new_infer_state.flashinfer_extra_state.tp_q_head_num, - new_infer_state.flashinfer_extra_state.kv_lora_rank, - new_infer_state.flashinfer_extra_state.qk_rope_head_dim, - 1, - False, # causal - new_infer_state.flashinfer_extra_state.softmax_scale, - new_infer_state.flashinfer_extra_state.q_data_type, - new_infer_state.flashinfer_extra_state.kv_data_type, - ) - return diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py index 0c2ef3048..4dd79305c 100644 --- a/lightllm/models/deepseek2/infer_struct.py +++ b/lightllm/models/deepseek2/infer_struct.py @@ -1,21 +1,6 @@ -import os -import torch -import numpy as np -import torch.distributed as dist from lightllm.models.llama.infer_struct import LlamaInferStateInfo class Deepseek2InferStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() - self.kv_starts = None - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - if not self.is_prefill: - self.kv_starts = self.b1_cu_kv_seq_len - - if self.is_prefill: - self.b1_kv_start_loc = self.b1_cu_kv_seq_len - self.max_value_in_b_seq_len = self.max_kv_seq_len - return diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ff20bc6ee..8695f2de8 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -1,40 +1,25 @@ import os import torch -import torch.functional as F import torch.distributed as dist -import numpy as np import triton -from typing import Tuple from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 -from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) 
-from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_fp8 import context_attention_fwd_fp8 -from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v +from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv -from lightllm.models.deepseek2.triton_kernel.repeat_rope import repeat_rope -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo -from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.log_utils import init_logger -from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2 logger = init_logger(__name__) class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.tp_k_head_num_ = 1 self.tp_v_head_num_ = 1 self.qk_nope_head_dim = network_config["qk_nope_head_dim"] @@ -66,7 +51,7 @@ def __init__(self, layer_num, network_config, mode=[]): mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"] - super().__init__(layer_num, network_config, mode) + super().__init__(layer_num, network_config) self.num_heads = network_config["num_attention_heads"] self.num_kv_heads = network_config["num_key_value_heads"] return @@ -89,58 +74,81 @@ def _bind_ffn(self): self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) self._tpsp_ffn = self._tpsp_ffn_tp - def _bind_attention(self): - if "triton_fp8kv" in self.mode: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self) - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self - ) - else: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - if get_env_start_args().enable_fa3: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self - ) - elif get_env_start_args().enable_flashinfer_decode: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self - ) - else: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self - ) - if self.enable_cc_method: - if "triton_fp8kv" in self.mode: - if 
get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self - ) - else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self - ) - else: - if get_env_start_args().enable_fa3: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self - ) - elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self - ) - else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self - ) - else: - if "triton_fp8kv" in self.mode: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_origin_fp8, self - ) - else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self - ) + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + k_nope, k_rope, v = self._decompress_kv( + infer_state=infer_state, + layer_weight=layer_weight, + ) + + o_tensor = infer_state.prefill_att_state.prefill_att( + q=q, + k=(k_nope, k_rope), + v=v, + att_control=AttControl(mla_prefill=True, mla_prefill_dict={"softmax_scale": self.softmax_scale}), + alloc_func=self.alloc_tensor, + ) + return o_tensor + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + out=None, + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + + out = infer_state.decode_att_state.decode_att( + q=(q_nope, q_rope), + k=kv, + v=None, + att_control=AttControl(mla_decode=True, mla_decode_dict={"softmax_scale": self.softmax_scale}), + alloc_func=self.alloc_tensor, + ) + return out + + def _decompress_kv( + self, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + ): + compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + + total_token_num = infer_state.total_token_num + sampled_compressed_kv_nope = self.alloc_tensor( + [total_token_num, 1, layer_weight.kv_lora_rank], dtype=compressed_kv.dtype + ) + sampled_k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=compressed_kv.dtype) + sample_kv( + all_compressed_kv=compressed_kv, + sampled_compressed_kv_nope=sampled_compressed_kv_nope, + sampled_k_rope=sampled_k_rope, + b_req_idx=infer_state.b_req_idx, + req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, + b_seq_len=infer_state.b_seq_len, + b_kv_start_loc=infer_state.b1_cu_kv_seq_len[:-1], + max_kv_seq_len=infer_state.max_kv_seq_len, + ) + # CC + sampled_compressed_kv_nope = sampled_compressed_kv_nope.view( + total_token_num, layer_weight.kv_lora_rank + ).contiguous() + sampled_kv_nope = self.alloc_tensor( + [total_token_num, self.tp_q_head_num_, (self.qk_nope_head_dim + self.v_head_dim)], + dtype=sampled_compressed_kv_nope.dtype, + ) + layer_weight.cc_kv_b_proj_.mm(sampled_compressed_kv_nope, 
out=sampled_kv_nope.view(total_token_num, -1)) + sampled_k_nope, sampled_v = torch.split(sampled_kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + return sampled_k_nope, sampled_k_rope, sampled_v def _get_qkv( self, @@ -297,423 +305,6 @@ def _tpsp_get_o( return o_tensor - def _decompress_kv( - self, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - is_fp8, - total_token_num, - b_seq_len, - max_seq_len, - b_kv_start_loc, - skip_sample=False, - ): - if not skip_sample: - if is_fp8: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) - k_scale = self.alloc_tensor([total_token_num, 1], dtype=kv_scale.dtype) - else: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv_scale = None - k_scale = None - - compressed_kv = self.alloc_tensor([total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype) - k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype) - sample_kv( - kv, - compressed_kv, - k_rope, - infer_state.b_req_idx, - max_seq_len, - b_seq_len, - infer_state.req_manager.req_to_token_indexs, - b_kv_start_loc, - kv_scale, - k_scale, - ) - if k_scale is not None: - compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1) - k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1) - else: - compressed_kv, k_rope = torch.split( # (b*s, 1, kv_lora + qk_r) - kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1 - ) - - # CC - compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank).contiguous() - kv_nope = self.alloc_tensor( - [compressed_kv.shape[0], self.tp_q_head_num_, (self.qk_nope_head_dim + self.v_head_dim)], - dtype=compressed_kv.dtype, - ) - layer_weight.cc_kv_b_proj_.mm(compressed_kv, out=kv_nope.reshape(compressed_kv.shape[0], -1)) - k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - return k_nope, k_rope, v - - # Adapted from: - # https://github.com/sgl-project/sglang/blob/c998d04b46920f06d945fbef9023884a768723fc/python/sglang/srt/models/deepseek_v2.py#L962 - def _context_attention_flashattention_kernel_with_CC( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2FlashAttentionStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - skip_sample=True, - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - o_tensor, lse, *rest = flash_attn_varlen_func( - q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - v=v.view(-1, self.tp_v_head_num_, self.v_head_dim), - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k=infer_state.cu_seqlens_q, - max_seqlen_q=infer_state.q_max_seq_len, - max_seqlen_k=infer_state.max_seq_len, - softmax_scale=self.softmax_scale, - causal=True, - return_softmax_lse=True, - ) - if infer_state.has_prefix_kv: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.prefix_total_token_num, - infer_state.b_ready_cache_len, - infer_state.prefix_k_max_len, - infer_state.cu_seqlens_prefix_k, - ) - k = 
torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - prefix_output, prefix_lse, *rest = flash_attn_varlen_func( - q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - v=v.view(-1, self.tp_v_head_num_, self.v_head_dim), - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k=infer_state.cu_seqlens_prefix_k, - max_seqlen_q=infer_state.q_max_seq_len, - max_seqlen_k=infer_state.prefix_k_max_len, - softmax_scale=self.softmax_scale, - causal=False, - return_softmax_lse=True, - ) - lse = torch.transpose(lse, 0, 1).contiguous() - prefix_lse = torch.transpose(prefix_lse, 0, 1).contiguous() - tmp_output = ( - self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) - if out is None - else out - ) - tmp_lse = torch.empty_like(lse) - merge_state_v2(prefix_output, prefix_lse, o_tensor, lse, tmp_output, tmp_lse) - o_tensor = tmp_output - return o_tensor - - def _context_attention_flashinfer_kernel_with_CC( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2FlashInferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - o_tensor = ( - self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) - return o_tensor - - def _context_attention_flashinfer_kernel_with_CC_fp8( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2FlashInferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - True, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - o_tensor = ( - self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) - return o_tensor - - def _context_attention_kernel_with_CC( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - context_attention_fwd_with_v( - q_nope, - q_rope, - k_nope, - k_rope, - v, - o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]), - infer_state.b_start_loc, - infer_state.b1_kv_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - self.softmax_scale, - ) - return o_tensor - - def _context_attention_kernel_with_CC_fp8( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: 
Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - True, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - context_attention_fwd_with_v( - q_nope, - q_rope, - k_nope, - k_rope, - v, - o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]), - infer_state.b_start_loc, - infer_state.b1_kv_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - self.softmax_scale, - ) - return o_tensor - - def _context_attention_kernel_origin( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - self.softmax_scale, - ) - return o_tensor - - def _context_attention_kernel_origin_fp8( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) - context_attention_fwd_fp8( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - kv_scale, - o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - self.softmax_scale, - ) - return o_tensor - - def _token_gqa_decode_attention_flashattention( - self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) - kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - k_descale, v_descale = None, None - o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=infer_state.page_table, - 
cache_seqlens=infer_state.b_att_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.max_q_seq_len, - softmax_scale=self.softmax_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o_tensor - - def _token_gqa_decode_attention_flashinfer( - self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) - - infer_state.decode_wrapper.run( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - out=o_tensor, - return_lse=False, - ) - return o_tensor - - def _token_gqa_decode_attention_flashdecoding( - self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - out = gqa_token_decode_attention_flash_decoding( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - infer_state, - self.tp_q_head_num_, - self.kv_lora_rank, - self.qk_rope_head_dim, - self.qk_nope_head_dim, - self.softmax_scale, - alloc_tensor_func=self.alloc_tensor, - ) - return out - - def _token_gqa_decode_attention_flashdecoding_fp8( - self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) - return gqa_token_decode_attention_flash_decoding_fp8( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - kv_scale, - infer_state, - self.tp_q_head_num_, - self.kv_lora_rank, - self.qk_rope_head_dim, - self.qk_nope_head_dim, - self.softmax_scale, - alloc_tensor_func=self.alloc_tensor, - ) - - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv( - buffer[:, :, : self.kv_lora_rank], - buffer[:, :, self.kv_lora_rank :], - mem_index, - mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank], - mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :], - ) - return - - def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager): - destindex_copy_kv_fp8( - buffer[:, :, : self.kv_lora_rank], - buffer[:, :, self.kv_lora_rank :], - mem_index, - mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn), - mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn), - mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype), - ) - return - def _moe_ffn( self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: diff --git 
a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 611878f9e..c5a2d3352 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -17,9 +17,9 @@ class Deepseek2TransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"] - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e4ce7c826..f0739a8a8 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -4,51 +4,16 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo -from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo -from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights - from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager -from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id - +from lightllm.common.basemodel.attention import get_mla_decode_att_backend_class, get_mla_prefill_att_backend_class logger = init_logger(__name__) -class DeepSeek2FlashInferStateExtraInfo: - def __init__(self, model): - num_heads = model.config["num_attention_heads"] - self.tp_q_head_num = num_heads // get_dp_world_size() - self.qk_nope_head_dim = model.qk_nope_head_dim - self.qk_rope_head_dim = model.qk_rope_head_dim - self.kv_lora_rank = model.kv_lora_rank - self.q_data_type = model.data_type - self.kv_data_type = model.data_type - self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) - self.max_seq_length = model.max_seq_length - self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) - self.kv_indices_buffer = [ - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - ] - if model.config["rope_scaling"] is not None: - rope_scaling = model.config["rope_scaling"] - mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) - scaling_factor = rope_scaling["factor"] - if mscale_all_dim: - mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) - self.softmax_scale = self.softmax_scale * mscale * mscale - - @ModelRegistry(["deepseek_v2", 
"deepseek_v3"]) class Deepseek2TpPartModel(LlamaTpPartModel): # weight class @@ -61,18 +26,13 @@ class Deepseek2TpPartModel(LlamaTpPartModel): infer_state_class = Deepseek2InferStateInfo def __init__(self, kvargs): - self.enable_flashinfer = ( - get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode - ) super().__init__(kvargs) return - def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: - self.infer_state_class = Deepseek2FlashAttentionStateInfo - elif self.enable_flashinfer: - self.infer_state_class = Deepseek2FlashInferStateInfo - self.flashinfer_extra_state = DeepSeek2FlashInferStateExtraInfo(self) + def _init_att_backend(self): + self.prefill_att_backend = get_mla_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend = get_mla_decode_att_backend_class(index=0)(model=self) + return def _init_some_value(self): super()._init_some_value() diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py index fd437c388..b9be73e27 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py @@ -5,7 +5,6 @@ import triton.language as tl from typing import List from lightllm.utils.log_utils import init_logger -from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig from lightllm.utils.device_utils import get_device_sm_count logger = init_logger(__name__) @@ -28,16 +27,18 @@ def gqa_token_decode_attention_flash_decoding_fp8( **run_config ): batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, kv_lora_rank) calcu_shape2 = (batch_size, q_head_num, q_rope_dim) if not run_config: if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = max_len_in_batch + avg_seq_len_in_batch = max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size + from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig + run_config = MlaDecodeAttentionKernelConfig.try_to_get_best_config( batch_size=batch_size, avg_seq_len_in_batch=avg_seq_len_in_batch, @@ -191,7 +192,7 @@ def _fwd_kernel_calcu_index_and_block_seq( infer_state = Deepseek2InferStateInfo() infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX + infer_state.max_kv_seq_len = N_CTX infer_state.total_token_num = Z * N_CTX infer_state.req_manager = ReqManager(Z, N_CTX, None) infer_state.req_manager.req_to_token_indexs = req_to_token_indexs diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index af0aaa2f6..53a0a60eb 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -8,111 +8,101 @@ @triton.jit def _sample_kv_kernel( - KV_input, - KV_scale, - KV_nope, - KV_rope, - K_scale, - B_start_loc, - B_Seqlen, - Req_to_tokens, - B_req_idx, - stride_input_dim, - stride_scale_dim, - stride_nope_dim, - stride_rope_dim, + all_compressed_kv, + stride_all_s, + stride_all_d, + sampled_compressed_kv_nope, + stride_nope_s, + stride_nope_d, + sampled_k_rope, + stride_rope_s, + stride_rope_d, + b_kv_start_loc, + b_seq_len, + req_to_token_indexs, stride_req_to_tokens_b, - HAS_SCALE: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_ROPE_DMODEL: tl.constexpr, + b_req_idx, + 
BLOCK_SEQ: tl.constexpr, + BLOCK_NOPE_DIM: tl.constexpr, + BLOCK_ROPE_DIM: tl.constexpr, ): cur_batch = tl.program_id(0) start_m = tl.program_id(1) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_loc = tl.load(B_start_loc + cur_batch) + cur_batch_seq_len = tl.load(b_seq_len + cur_batch) + cur_batch_req_idx = tl.load(b_req_idx + cur_batch) + cur_batch_start_loc = tl.load(b_kv_start_loc + cur_batch) - offs_nope_d = tl.arange(0, BLOCK_DMODEL) - offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_nope_d = tl.arange(0, BLOCK_NOPE_DIM) + offs_rope_d = tl.arange(0, BLOCK_ROPE_DIM) + offs_m = (start_m * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)) % cur_batch_seq_len - block_end_loc = tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len) + if start_m * BLOCK_SEQ > cur_batch_seq_len: + return kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, - mask=offs_m < block_end_loc, - other=0, + req_to_token_indexs + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, ).to(tl.int64) - off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :] - off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :] - kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0) - kv_rope = tl.load(KV_input + off_kv_rope, mask=offs_m[:, None] < block_end_loc, other=0.0) - off_nope = (offs_m + cur_batch_start_loc)[:, None] * stride_nope_dim + offs_nope_d[None, :] - off_rope = (offs_m + cur_batch_start_loc)[:, None] * stride_rope_dim + offs_rope_d[None, :] - nope_ptrs = KV_nope + off_nope - rope_ptrs = KV_rope + off_rope - tl.store(nope_ptrs, kv_nope, mask=offs_m[:, None] < block_end_loc) - tl.store(rope_ptrs, kv_rope, mask=offs_m[:, None] < block_end_loc) - if HAS_SCALE: - kv_scale = tl.load(KV_scale + kv_loc * stride_scale_dim, mask=offs_m < block_end_loc) - off_k_scale = cur_batch_start_loc + offs_m - k_scale_ptrs = K_scale + off_k_scale - tl.store(k_scale_ptrs, kv_scale, mask=offs_m < block_end_loc) + off_kv_nope = kv_loc[:, None] * stride_all_s + offs_nope_d[None, :] + off_kv_rope = kv_loc[:, None] * stride_all_s + (offs_rope_d + BLOCK_NOPE_DIM)[None, :] + kv_nope = tl.load(all_compressed_kv + off_kv_nope) + kv_rope = tl.load(all_compressed_kv + off_kv_rope) + off_nope = (offs_m + cur_batch_start_loc)[:, None] * stride_nope_s + offs_nope_d[None, :] + off_rope = (offs_m + cur_batch_start_loc)[:, None] * stride_rope_s + offs_rope_d[None, :] + nope_ptrs = sampled_compressed_kv_nope + off_nope + rope_ptrs = sampled_k_rope + off_rope + tl.store(nope_ptrs, kv_nope) + tl.store(rope_ptrs, kv_rope) return @torch.no_grad() def sample_kv( - kv_input, - kv_nope, - kv_rope, - b_req_idx, - max_value_in_b_seq_len, - b_seq_len, - req_to_token_indexs, - b_kv_start_loc, - kv_scale=None, - k_scale=None, + all_compressed_kv: torch.Tensor, + sampled_compressed_kv_nope: torch.Tensor, + sampled_k_rope: torch.Tensor, + b_req_idx: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_seq_len: torch.Tensor, + b_kv_start_loc: torch.Tensor, + max_kv_seq_len: int, ): - BLOCK = 128 if not is_tesla() else 64 - - nope_dim = kv_nope.shape[-1] - rope_dim = kv_rope.shape[-1] - if nope_dim >= 512: - BLOCK = 64 if not is_tesla() else 32 - else: - BLOCK = 128 if not is_tesla() else 64 - + nope_dim = sampled_compressed_kv_nope.shape[-1] + rope_dim = sampled_k_rope.shape[-1] + assert rope_dim == 64 batch = b_seq_len.shape[0] - 
max_input_len = max_value_in_b_seq_len + BLOCK = 64 if not is_tesla() else 32 + num_warps = 8 grid = ( batch, - triton.cdiv(max_input_len, BLOCK), + triton.cdiv(max_kv_seq_len, BLOCK), ) - num_warps = 4 if nope_dim <= 64 else 8 + + all_compressed_kv = all_compressed_kv.view(all_compressed_kv.shape[0], all_compressed_kv.shape[2]) + sampled_compressed_kv_nope = sampled_compressed_kv_nope.view(sampled_compressed_kv_nope.shape[0], nope_dim) + sampled_k_rope = sampled_k_rope.view(sampled_k_rope.shape[0], rope_dim) + assert triton.next_power_of_2(nope_dim) == nope_dim + assert triton.next_power_of_2(rope_dim) == rope_dim _sample_kv_kernel[grid]( - kv_input, - kv_scale, - kv_nope, - kv_rope, - k_scale, - b_kv_start_loc, - b_seq_len, - req_to_token_indexs, - b_req_idx, - kv_input.stride(0), - kv_scale.stride(0) if kv_scale is not None else 0, - kv_nope.stride(0), - kv_rope.stride(0), - req_to_token_indexs.stride(0), - HAS_SCALE=kv_scale is not None, - BLOCK_M=BLOCK, - BLOCK_DMODEL=nope_dim, - BLOCK_ROPE_DMODEL=rope_dim, + all_compressed_kv=all_compressed_kv, + stride_all_s=all_compressed_kv.stride(0), + stride_all_d=all_compressed_kv.stride(1), + sampled_compressed_kv_nope=sampled_compressed_kv_nope, + stride_nope_s=sampled_compressed_kv_nope.stride(0), + stride_nope_d=sampled_compressed_kv_nope.stride(1), + sampled_k_rope=sampled_k_rope, + stride_rope_s=sampled_k_rope.stride(0), + stride_rope_d=sampled_k_rope.stride(1), + b_kv_start_loc=b_kv_start_loc, + b_seq_len=b_seq_len, + req_to_token_indexs=req_to_token_indexs, + stride_req_to_tokens_b=req_to_token_indexs.stride(0), + b_req_idx=b_req_idx, + BLOCK_SEQ=BLOCK, + BLOCK_NOPE_DIM=nope_dim, + BLOCK_ROPE_DIM=rope_dim, num_warps=num_warps, num_stages=1, ) diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py index 26bfc865e..adb749c40 100644 --- a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -8,8 +8,8 @@ class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = network_config["rms_norm_eps"] self.hidden_size = network_config["hidden_size"] return diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index 4a5bf2e96..1f0815c3d 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -8,8 +8,8 @@ class Deepseek3MTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.eh_proj.weight", diff --git a/lightllm/models/gemma3/layer_infer/post_layer_infer.py b/lightllm/models/gemma3/layer_infer/post_layer_infer.py index 22dc59505..57b21844e 100644 --- a/lightllm/models/gemma3/layer_infer/post_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/post_layer_infer.py @@ -4,7 +4,7 @@ class Gemma3PostLayerInfer(LlamaPostLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + 
super().__init__(network_config) self.eps_ = 1e-6 return diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py index dc8a46ad9..3543786f6 100644 --- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py @@ -5,8 +5,8 @@ class Gemma3PreLayerInfer(LlamaMultimodalPreLayerInfer): - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.embed_scale = torch.tensor(network_config["hidden_size"] ** 0.5, dtype=torch.float32) self.boi_token_index: int = 255_999 self.eoi_token_index: int = 256_000 diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index d4bd8c3fa..1f386625b 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -1,12 +1,6 @@ import torch -import torch.functional as F import torch.distributed as dist import torch.nn as nn -import numpy as np -from typing import Tuple -from functools import partial -import triton - from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.distributed import all_reduce from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight @@ -18,8 +12,8 @@ class Gemma3TransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.tp_k_head_num_ = network_config["num_key_value_heads"] self.tp_v_head_num_ = network_config["num_key_value_heads"] self.eps_ = 1e-6 diff --git a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py index 17e65268c..858937d8c 100644 --- a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class Gemma3PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="language_model.model.embed_tokens.weight", diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 1e7ceeb42..e7808c412 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -9,10 +9,9 @@ def __init__( layer_num, data_type, network_config, - mode=[], quant_cfg=None, ): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/gemma3/model.py b/lightllm/models/gemma3/model.py index dc4f03b7e..9931c3171 100644 --- a/lightllm/models/gemma3/model.py +++ b/lightllm/models/gemma3/model.py @@ -1,7 +1,5 @@ import os -import re import json -import numpy as np import torch from lightllm.models.registry import ModelRegistry from 
lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer @@ -14,8 +12,6 @@ from lightllm.models.gemma3.layer_weights.pre_and_post_layer_weight import Gemma3PreAndPostLayerWeight from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer -from lightllm.models.llava.layer_weights.pre_and_post_layer_weight import LlavaPreAndPostLayerWeight from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem from lightllm.server.core.objs import SamplingParams from lightllm.common.build_utils import repair_config diff --git a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py index ce9737820..468d471d2 100644 --- a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py @@ -11,8 +11,8 @@ class Gemma_2bPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.tp_world_size_ + 1, dtype=np.int64) self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) self.normfactor = network_config["hidden_size"] ** 0.5 diff --git a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py index 35ddaef34..2ed325659 100644 --- a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py @@ -16,8 +16,8 @@ class Gemma_2bTransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.tp_k_head_num_ = network_config["num_key_value_heads"] # [SYM] always == 1 self.tp_v_head_num_ = network_config["num_key_value_heads"] return diff --git a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index d5d0438fa..6e052caa6 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class Gemma_2bPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 1916bd095..9102ce677 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -6,8 +6,8 @@ class Gemma_2bTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def 
__init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_qkv(self): diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 93cd7413b..d80eefd16 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -1,22 +1,15 @@ import torch -from torch import nn -from torch.nn import functional as F -import numpy as np -from functools import partial -from typing import Optional - from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer, LlamaInferStateInfo +from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) class GptOssTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.hidden_size = self.network_config_["hidden_size"] self.alpha = 1.702 self.limit = 7.0 @@ -24,22 +17,17 @@ def __init__(self, layer_num, network_config, mode=[]): self.sliding_window = network_config["sliding_window"] self.head_dim_ = network_config["head_dim"] - def _bind_attention(self): - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - self._context_attention_kernel = self._context_sliding_attention_flashattention - self._token_attention_kernel = self._token_sliding_attention_flashattention - def _bind_norm(self): self._att_norm = self._att_norm self._ffn_norm = self._ffn_norm return - def _att_norm(self, input, infer_state, layer_weight) -> torch.Tensor: + def _att_norm(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: out = self.alloc_tensor(input.shape, input.dtype) out = self._gpt_oss_rmsnorm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_) return out - def _ffn_norm(self, input, infer_state, layer_weight) -> torch.Tensor: + def _ffn_norm(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: out = self.alloc_tensor(input.shape, input.dtype) out = self._gpt_oss_rmsnorm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_) return out @@ -51,9 +39,7 @@ def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6): hidden_states = hidden_states * torch.rsqrt(variance + eps) return (weight * hidden_states).to(input_dtype) # main diff with Llama - def _ffn( - self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight - ) -> torch.Tensor: + def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -68,78 +54,61 @@ def _ffn( ) return hidden_states.view(num_tokens, hidden_dim) - def 
_context_sliding_attention_flashattention( - self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: LlamaInferStateInfo, + layer_weight: GptOssTransformerLayerWeight, + out=None, ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) + use_sliding_window = True else: window_size = (-1, -1) + use_sliding_window = False - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl( + use_sliding_window=use_sliding_window, + sliding_window=window_size, + use_att_sink=True, + sink_weight=layer_weight.attn_sinks.weight, + ), + alloc_func=self.alloc_tensor, ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=window_size, - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - sinks=layer_weight.attn_sinks.weight, - ) - return o + o_tensor = o_tensor.view(q.shape) + return o_tensor - def _token_sliding_attention_flashattention( - self, q, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None + def _token_attention_kernel( + self, q: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) + use_sliding_window = True else: window_size = (-1, -1) + use_sliding_window = False - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - softmax_scale=sm_scale, - causal=True, - window_size=window_size, - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - sinks=layer_weight.attn_sinks.weight, + _k, _v = 
infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl( + use_sliding_window=use_sliding_window, + sliding_window=window_size, + use_att_sink=True, + sink_weight=layer_weight.attn_sinks.weight, + ), + alloc_func=self.alloc_tensor, ) - return o + o_tensor = o_tensor.view(q.shape) + return o_tensor diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index f6a841b1a..c5c14b08e 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -17,10 +17,9 @@ def __init__( layer_num, data_type, network_config, - mode=[], quant_cfg=None, ): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_moe(self): diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index 34a017b31..dc5f2abdf 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -19,4 +19,9 @@ class GptOssTpPartModel(LlamaTpPartModel): def __init__(self, kvargs): super().__init__(kvargs) - assert get_env_start_args().enable_fa3, "For now GPT-OSS type model only support flashattention-3" + assert ( + get_env_start_args().llm_prefill_att_backend[0] == "fa3" + ), "For now GPT-OSS type model only support flashattention-3" + assert ( + get_env_start_args().llm_decode_att_backend[0] == "fa3" + ), "For now GPT-OSS type model only support flashattention-3" diff --git a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py index a2fc91dc4..6ef81122d 100755 --- a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py @@ -1,13 +1,9 @@ -import torch -import math -import numpy as np - from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight class InternlmTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/internlm/model.py b/lightllm/models/internlm/model.py index 78ac7117e..50adbb3f9 100644 --- a/lightllm/models/internlm/model.py +++ b/lightllm/models/internlm/model.py @@ -1,6 +1,3 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry from lightllm.models.internlm.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index b40330aa3..3ed7004c1 100644 --- a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class Internlm2PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, 
data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight(weight_name="model.tok_embeddings.weight", data_type=self.data_type_) self.lm_head_weight_ = LMHeadWeight(weight_name="output.weight", data_type=self.data_type_) diff --git a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py index a67555863..a05e977f1 100755 --- a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class Internlm2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def load_hf_weights(self, weights): diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index b20b9c495..59caf40d6 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -4,8 +4,8 @@ class Internlm2RewardPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.tok_embeddings.weight", data_type=self.data_type_, diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index 7d76d202a..21a4c2e6b 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -13,8 +13,8 @@ def rename_weight_keys(weights): class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): @@ -24,8 +24,8 @@ def load_hf_weights(self, weights): class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): @@ -35,8 +35,8 @@ def load_hf_weights(self, weights): class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index 6d264a426..ccb76d351 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -149,6 
+149,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("image token error") except ValueError: break + if multimodal_params: + image_cnt = len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") input_ids.extend(origin_ids[start_idx:]) # audio @@ -174,6 +178,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("audio token error") except ValueError: break + if multimodal_params: + audio_cnt = len(multimodal_params.audios) + if audio_cnt != audio_id: + raise ValueError(f"invalid audio tag num: {audio_cnt} vs {audio_id}!") input_ids.extend(origin_ids[start_idx:]) return input_ids diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py deleted file mode 100644 index 9f71cbbc5..000000000 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.dist_utils import get_current_device_id -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index -from lightllm.common.basemodel.batch_objs import ModelInput -from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy - - -class FlashAttentionStateInfo(LlamaInferStateInfo): - _shared_page_table_buffer = None - - def __init__(self): - super().__init__() - - @classmethod - def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): - if cls._shared_page_table_buffer is None: - cls._shared_page_table_buffer = [ - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - ] - return cls._shared_page_table_buffer - - def _init_flash_attention_state(self, model): - if self.is_prefill: - self.cu_seqlens_q = self.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - self.page_table = torch.empty( - (self.batch_size, self.max_seq_len), dtype=torch.int32, device=self.input_ids.device - ) - self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) - else: - # Meta information of flashattention for decoding - self.cu_seqlens_q = self.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - max_seq_len_k = self.max_kv_seq_len - args_mtp_step = get_env_start_args().mtp_step - att_batch_size = self.batch_size // (args_mtp_step + 1) - if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch - ) - self.page_table = page_buffer[self.microbatch_index][ - : att_batch_size * model.graph_max_len_in_batch - ].reshape(att_batch_size, model.graph_max_len_in_batch) - else: - self.page_table = torch.empty( - (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=self.input_ids.device - ) - - page_table_copy( - page_table=self.page_table[:, :max_seq_len_k], - req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], - ) - if 
args_mtp_step > 0: - self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() - else: - self.b_att_seq_len = self.b_seq_len - - if "offline_calibration_fp8kv" in model.mode: - if self.is_prefill: - device = self.input_ids.device - # q_scale和token_batch_ids在对q做per head量化使用,为了节省资源在推理外部初始化 - self.q_scale = torch.empty( - (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device - ) - self.token_batch_ids = torch.repeat_interleave( - torch.arange(self.batch_size, device=device), self.b_q_seq_len - ) - - offline_scales = self.mem_manager.scales - head_num = self.mem_manager.head_num - # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = ( - offline_scales[:, :head_num] - .view(-1, 1, head_num) - .expand(offline_scales.shape[0], self.batch_size, head_num) - if offline_scales is not None - else torch.ones( - (self.mem_manager.layer_num, self.batch_size, head_num), - dtype=torch.float32, - device=self.input_ids.device, - ) - ) - self.v_descale = ( - offline_scales[:, head_num:] - .view(-1, 1, head_num) - .expand(offline_scales.shape[0], self.batch_size, head_num) - if offline_scales is not None - else torch.ones( - (self.mem_manager.layer_num, self.batch_size, head_num), - dtype=torch.float32, - device=self.input_ids.device, - ) - ) - return - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - self._init_flash_attention_state(model) - return diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py deleted file mode 100644 index 7f9beac1d..000000000 --- a/lightllm/models/llama/flashinfer_struct.py +++ /dev/null @@ -1,127 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index - - -class LlamaFlashInferStateInfo(LlamaInferStateInfo): - def __init__(self): - super().__init__() - self.prefill_wrapper = None - self.decode_wrapper = None - self.flashinfer_extra_state = None - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - self.flashinfer_extra_state = model.flashinfer_extra_state - - import flashinfer - - if not self.is_prefill: - if get_env_start_args().enable_flashinfer_decode: - self.kv_last_page_len_buffer = torch.full( - (self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device - ) - if self.batch_size <= model.graph_max_batch_size: - self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length - ] - else: - self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, - dtype=torch.int32, - device=self.input_ids.device, - ) - - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) - self.kv_starts = self.b1_cu_kv_seq_len.int() - if self.decode_wrapper is None: - self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_extra_state.workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=True, - paged_kv_indptr_buffer=self.kv_starts, - paged_kv_indices_buffer=self.kv_indices, - paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, - ) - self.decode_wrapper.plan( - self.kv_starts, - self.kv_indices, - 
self.kv_last_page_len_buffer, - self.flashinfer_extra_state.tp_q_head_num, - self.flashinfer_extra_state.tp_kv_head_num, - self.flashinfer_extra_state.head_dim, - 1, - q_data_type=self.flashinfer_extra_state.q_data_type, - kv_data_type=self.flashinfer_extra_state.kv_data_type, - non_blocking=True, - ) - else: - if get_env_start_args().enable_flashinfer_prefill: - q_starts = self.b1_cu_q_seq_len.int() - kv_starts = self.b1_cu_kv_seq_len.int() - kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device) - kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, - dtype=torch.int32, - device=self.input_ids.device, - ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - kv_starts[:-1], - self.max_kv_seq_len, - kv_indices, - ) - self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_extra_state.workspace_buffer, - qo_indptr_buf=q_starts, - paged_kv_indptr_buf=kv_starts, - paged_kv_indices_buf=kv_indices, - paged_kv_last_page_len_buf=kv_last_page_len, - ) - self.prefill_wrapper.plan( - q_starts, - kv_starts, - kv_indices, - kv_last_page_len, - self.flashinfer_extra_state.tp_q_head_num, - self.flashinfer_extra_state.tp_kv_head_num, - self.flashinfer_extra_state.head_dim, - 1, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=0.0, - q_data_type=self.flashinfer_extra_state.q_data_type, - kv_data_type=self.flashinfer_extra_state.kv_data_type, - ) - return - - def copy_for_cuda_graph(self, new_infer_state): - super().copy_for_cuda_graph(new_infer_state) - if get_env_start_args().enable_flashinfer_decode and not self.is_prefill: - self.decode_wrapper.plan( - new_infer_state.kv_starts, - new_infer_state.kv_indices, - new_infer_state.kv_last_page_len_buffer, - new_infer_state.flashinfer_extra_state.tp_q_head_num, - new_infer_state.flashinfer_extra_state.tp_kv_head_num, - new_infer_state.flashinfer_extra_state.head_dim, - 1, - q_data_type=new_infer_state.flashinfer_extra_state.q_data_type, - kv_data_type=new_infer_state.flashinfer_extra_state.kv_data_type, - non_blocking=True, - ) - return diff --git a/lightllm/models/llama/infer_struct.py b/lightllm/models/llama/infer_struct.py index 3bba43976..fe6ca392a 100644 --- a/lightllm/models/llama/infer_struct.py +++ b/lightllm/models/llama/infer_struct.py @@ -14,7 +14,6 @@ def init_some_extra_state(self, model): super().init_some_extra_state(model) if self.is_prefill: self.max_seq_len = self.max_kv_seq_len - self.q_max_seq_len = self.max_q_seq_len position_ids = self.position_ids self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 7c7b0ea39..8bc10d623 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -13,8 +13,8 @@ class LlamaPostLayerInfer(PostLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = network_config["rms_norm_eps"] return diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index ddb99e262..f4f150b17 100644 
--- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -10,8 +10,8 @@ class LlamaPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py old mode 100755 new mode 100644 index b08b2aa1f..2a9a54319 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -1,51 +1,23 @@ import torch import triton -import torch.functional as F import torch.distributed as dist -import numpy as np -from typing import Tuple from functools import partial - from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, - context_attention_fwd_ppl_int8kv, -) -from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k -from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd -from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd - from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo -from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 from lightllm.common.basemodel import TransformerLayerInferTpl -from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - -if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant -else: - scaled_fp8_quant = None logger = init_logger(__name__) -from lightllm.utils.sgl_utils import flash_attn_with_kvcache - class LlamaTransformerLayerInfer(TransformerLayerInferTpl): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.eps_ = network_config["rms_norm_eps"] self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ self.tp_k_head_num_ = max(network_config["num_key_value_heads"] // self.tp_world_size_, 1) @@ -58,7 +30,6 @@ def __init__(self, layer_num, network_config, mode=[]): def _bind_func(self): self._bind_norm() - self._bind_attention() 
return def _bind_norm(self): @@ -66,125 +37,34 @@ def _bind_norm(self): self._ffn_norm = partial(LlamaTransformerLayerInfer._ffn_norm, self) return - def _bind_attention(self): - if get_env_start_args().enable_fa3: - if "offline_calibration_fp8kv" in self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashattention_fp8, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention_fp8, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) - elif "export_fp8kv_calibration" in self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashattention, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention, self - ) - self._copy_kv_to_mem_cache = partial( - LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self - ) - elif not self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashattention, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - else: - raise Exception(f"Unsupported mode for fa3 backend: {self.mode}") - return - elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self - ) - else: - self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) - if "ppl_int8kv" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - ) - elif "ppl_int8kv_flashdecoding_diverse" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding_diverse, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - ) - elif "ppl_int8kv_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - ) - elif "ppl_int4kv_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_int4kv_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int4kv, self) - elif "ppl_fp16" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "ppl_fp16_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - 
LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_int8kv" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_int8kv, self) - elif "offline_calibration_fp8kv" in self.mode: - assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel_fp8, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self - ) - elif "triton_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_gqa_attention" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_gqa_attention_normal, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_gqa_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_gqa_flashdecoding_vsm" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "export_fp8kv_calibration" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self) - self._copy_kv_to_mem_cache = partial( - LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self - ) - elif not self.mode: - if get_env_start_args().enable_flashinfer_decode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self - ) - else: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - else: - raise Exception(f"Unsupported mode: {self.mode}") + def _context_attention_kernel( + self, + q: torch.Tensor, + kv: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: LlamaTransformerLayerWeight, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + alloc_func=self.alloc_tensor, + ) + o_tensor = o_tensor.view(q.shape) + return o_tensor - return + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: LlamaTransformerLayerWeight, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + 
o_tensor = infer_state.decode_att_state.decode_att(q=_q, k=_k, v=_v, alloc_func=self.alloc_tensor) + return o_tensor.view(q.shape) def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight @@ -241,163 +121,6 @@ def _tpsp_get_qkv( return q, cache_kv - def _context_attention_flashinfer_kernel_fp8( - self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv = kv.unsqueeze(1) - k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) - v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) - offline_scales = infer_state.mem_manager.scales_list - k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None - v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None - infer_state.prefill_wrapper.run( - q.view(q.shape[0], -1, self.head_dim_), - (k, v), - k_scale=k_descale, - v_scale=v_descale, - out=o_tensor.view(q.shape[0], -1, self.head_dim_), - ) - return o_tensor - - def _context_attention_flashinfer_kernel( - self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv = kv.unsqueeze(1) - infer_state.prefill_wrapper.run( - q.view(q.shape[0], -1, self.head_dim_), - (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), - out=o_tensor.view(q.shape[0], -1, self.head_dim_), - ) - return o_tensor - - def _context_attention_kernel( - self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - ) - return o_tensor - - def _context_attention_kernel_ppl_int8kv( - self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - batch_size = infer_state.b_seq_len.shape[0] - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv_scale = infer_state.mem_manager.scale_buffer[self.layer_num_] - max_seq_len = infer_state.max_seq_len - kv_dequant = self.alloc_tensor( - (batch_size, kv.shape[1], max_seq_len, kv.shape[2]), device=q.device, dtype=q.dtype - ) - destindex_copy_dequantize_kv( - kv, - kv_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_seq_len, - infer_state.b_req_idx, - max_seq_len, - kv_dequant, - ) - context_attention_fwd_ppl_int8kv( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv_dequant[:, 0 : self.tp_k_head_num_, :, :], - kv_dequant[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - infer_state.b_ready_cache_len, - ) - return 
o_tensor - - def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o - - def _context_attention_flashattention_fp8( - self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None - ): - q, q_scale = q_per_head_fp8_quant( - q.view(q.shape[0], self.tp_k_head_num_, -1), - infer_state.b_seq_len, - infer_state.cu_seqlens_q, - infer_state.q_scale, - infer_state.token_batch_ids, - ) - cache_k = ( - (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) - .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - cache_v = ( - ( - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - ) - .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - o = flash_attn_with_kvcache( - q=q.view(-1, self.tp_q_head_num_, self.head_dim_), - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - causal=True, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale, - k_descale=infer_state.k_descale[self.layer_num_], - v_descale=infer_state.v_descale[self.layer_num_], - return_softmax_lse=False, - ) - return o - def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: @@ -486,453 +209,6 @@ def _tpsp_ffn( # gate_out, up_out = None, None # return ffn2_out - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) - return - - def _copy_kv_to_mem_cache_with_calibration(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) - mem_manager.update_calibration_data(buffer, self.layer_num_) - return - - def _copy_kv_to_mem_cache_int8kv(self, buffer, mem_index, mem_manager): - destindex_copy_quantize_kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] - ) - return - - def _copy_kv_to_mem_cache_fp8kv(self, buffer, mem_index, mem_manager): - scales = mem_manager.scales - destindex_copy_kv_fp8( - buffer, - mem_index, - scales[self.layer_num_] if scales is not None else None, - mem_manager.kv_buffer[self.layer_num_].view(torch.float8_e4m3fn), - ) - return - - def _copy_kv_to_mem_cache_ppl_int8kv(self, 
buffer, mem_index, mem_manager): - from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_quantize_kv - - destindex_copy_quantize_kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] - ) - return - - def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager): - from lightllm.models.llama.triton_kernel.ppl_int4kv_copy_kv import destindex_copy_int4kv - - destindex_copy_int4kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] - ) - return - - def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) - k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) - v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) - offline_scales = infer_state.mem_manager.scales_list - k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None - v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None - infer_state.decode_wrapper.run( - q.view(calcu_shape1), - (k, v), - k_scale=k_descale, - v_scale=v_descale, - out=o_tensor.view(calcu_shape1), - ) - return o_tensor - - def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) - infer_state.decode_wrapper.run( - q.view(calcu_shape1), - (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), - out=o_tensor.view(calcu_shape1), - ) - return o_tensor - - def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - - att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) - - token_att_fwd( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o_tensor - - def _token_decode_gqa_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - # 对 gqa模型进行推理优化的代码 - from ..triton_kernel.gqa_decode_flashattention_nopad 
import gqa_decode_attention_fwd - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - gqa_decode_attention_fwd( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - ) - return o_tensor - - def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), q.dtype) - token_att_fwd_int8k( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - prob = self.alloc_tensor(att_m_tensor.shape, att_m_tensor.dtype) - token_softmax_fwd( - att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch - ) - att_m_tensor = None - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - token_att_fwd2_int8v( - prob, - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - prob = None - return o_tensor - - def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - from lightllm.models.llama.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_gqa_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - # 对 gqa 模型进行推理优化的代码 - from ..triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return gqa_token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, 
self.tp_q_head_num_, self.head_dim_) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - - # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, - # at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) - light_ops.group8_int8kv_decode_attention( - o_tensor.view(calcu_shape1), - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - return o_tensor - - def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - from lightllm_ppl_fp16_kernel import fp16_decode_attention - - # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, - # at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) - fp16_decode_attention( - o_tensor.view(calcu_shape1), - 1.0 / (self.head_dim_ ** 0.5), - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - return o_tensor - - def _token_decode_attention_ppl_fp16_flashdecoding( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_fp16_flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_ppl_int8kv_flashdecoding( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_k_scale, - cache_v, - 
cache_v_scale, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_ppl_int8kv_flashdecoding_diverse( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse import ( - token_decode_attention_flash_decoding, - ) - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_ppl_int4kv_flashdecoding( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_int4kv_flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_gqa_flashdecoding_vsm( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import ( - gqa_token_decode_attention_flash_decoding_vsm, - ) - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - q_shape = (infer_state.batch_size, self.tp_q_head_num_, self.head_dim_) - return gqa_token_decode_attention_flash_decoding_vsm( - q.view(q_shape), - cache_k, - cache_v, - infer_state, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_att_seq_len, - 
cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.max_q_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o - - def _token_decode_attention_flashattention_fp8( - self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None - ): - cache_k = ( - (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) - .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - cache_v = ( - ( - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - ) - .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * self.tp_k_head_num_, -1), use_per_token_if_dynamic=True) - o = flash_attn_with_kvcache( - q=q.view(-1, self.tp_q_head_num_, self.head_dim_), - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - causal=False, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale.view(infer_state.batch_size, self.tp_k_head_num_), - k_descale=infer_state.k_descale[self.layer_num_], - v_descale=infer_state.v_descale[self.layer_num_], - return_softmax_lse=False, - ) - return o - def overlap_tpsp_token_forward( self, input_embdings: torch.Tensor, diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index ea59d24df..7e9ff4167 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 6b92272ee..197116d99 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -11,10 +11,9 @@ def __init__( layer_num, data_type, network_config, - mode=[], quant_cfg=None, ): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight(self): diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 95465a9e6..c104ebccc 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -9,39 +9,15 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo -from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo from lightllm.common.basemodel import 
TpPartBaseModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id logger = init_logger(__name__) -class LlamaFlashInferStateExtraInfo: - def __init__(self, model): - tp_world_size = get_dp_world_size() - self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size - self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) - head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] - self.head_dim = model.config.get("head_dim", head_dim) - self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) - self.max_seq_length = model.max_seq_length - self.kv_indices_buffer = [ - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - ] - self.q_data_type = model.data_type - self.kv_data_type = torch.float8_e4m3fn if "offline_calibration_fp8kv" in model.mode else model.data_type - - @ModelRegistry("llama") class LlamaTpPartModel(TpPartBaseModel): # weight class @@ -57,9 +33,6 @@ class LlamaTpPartModel(TpPartBaseModel): infer_state_class = LlamaInferStateInfo def __init__(self, kvargs): - self.enable_flashinfer = ( - get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode - ) super().__init__(kvargs) return @@ -94,13 +67,6 @@ def _init_mem_manager(self): ) return - def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: - self.infer_state_class = FlashAttentionStateInfo - elif self.enable_flashinfer: - self.infer_state_class = LlamaFlashInferStateInfo - self.flashinfer_extra_state = LlamaFlashInferStateExtraInfo(self) - def _init_custom(self): """ 模型特殊的一些初始化 diff --git a/lightllm/models/llama/triton_kernel/flash_decoding.py b/lightllm/models/llama/triton_kernel/flash_decoding.py deleted file mode 100644 index e47e30886..000000000 --- a/lightllm/models/llama/triton_kernel/flash_decoding.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch - - -def token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): - BLOCK_SEQ = 256 - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - calcu_shape1 = (batch_size, q_head_num, head_dim) - - from .flash_decoding_stage1 import flash_decode_stage1 - from .flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" - ) - - flash_decode_stage1( - q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - mid_o, - mid_o_logexpsum, - BLOCK_SEQ, - ) - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) - return o_tensor diff 
--git a/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py b/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py deleted file mode 100644 index 86a3af103..000000000 --- a/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -import triton -import triton.language as tl - -@triton.jit -def _fwd_kernel_flash_decode_stage1( - Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - gqa_group_size, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - seq_start_block = tl.program_id(2) - cur_kv_head = cur_head // gqa_group_size - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - - block_n_size = tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1) // BLOCK_N - - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - - q = tl.load(Q + off_q) - - sum_exp = 0.0 - max_logic = -float("inf") - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0) - k_loc = k_loc.to(tl.int64) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] - k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) - v = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - - cur_max_logic = tl.max(att_value, axis=0) - new_max_logic = tl.maximum(cur_max_logic, max_logic) - - exp_logic = tl.exp(att_value - new_max_logic) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale - acc += tl.sum(exp_logic[:, None] * v, axis=0) - - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) - max_logic = new_max_logic - - need_store = tl.where(block_n_size == 0, 0, 1) - for _ in range(0, need_store, 1): - off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block - tl.store(Mid_O + off_mid_o, acc / sum_exp) - tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) - return - - -@torch.no_grad() -def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq): - BLOCK_SEQ = block_seq - BLOCK_N = 16 - assert BLOCK_SEQ % BLOCK_N == 0 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / 
(Lk ** 0.5) - batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) - gqa_group_size = q.shape[1] // k.shape[1] - - _fwd_kernel_flash_decode_stage1[grid]( - q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, - mid_out, - mid_out_logsumexp, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logsumexp.stride(0), mid_out_logsumexp.stride(1), mid_out_logsumexp.stride(2), - gqa_group_size, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK_N, - num_warps=1, - num_stages=2, - ) - return \ No newline at end of file diff --git a/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py b/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py deleted file mode 100644 index 81227f967..000000000 --- a/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_flash_decode_stage2( - B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - O, #[batch, head, head_dim] - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - stride_obs, stride_oh, stride_od, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - - block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ - - sum_exp = 0.0 - max_logic = -float("inf") - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d - offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh - for block_seq_n in range(0, block_n_size, 1): - tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) - tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) - new_max_logic = tl.maximum(tlogic, max_logic) - - old_scale = tl.exp(max_logic - new_max_logic) - acc *= old_scale - exp_logic = tl.exp(tlogic - new_max_logic) - acc += exp_logic * tv - sum_exp = sum_exp * old_scale + exp_logic - max_logic = new_max_logic - - tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) - return - - -@torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): - Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128} - batch, head_num = mid_out.shape[0], mid_out.shape[1] - grid = (batch, head_num) - - _fwd_kernel_flash_decode_stage2[grid]( - B_Seqlen, mid_out, mid_out_logexpsum, O, - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - O.stride(0), O.stride(1), O.stride(2), - BLOCK_SEQ=block_seq, - BLOCK_DMODEL=Lk, - num_warps=4, - num_stages=2, - ) - return \ No newline at end of file diff --git a/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py b/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py deleted file mode 100644 index 7ba0f3b31..000000000 --- a/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py +++ /dev/null @@ -1,138 +0,0 @@ -import torch - -import triton -import 
triton.language as tl - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_int4_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_g, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_g, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_g, - group_size, - BLOCK_GROUP_NUM: tl.constexpr, - BLOCK_GROUP_DIM: tl.constexpr, -): - cur_index = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_g = tl.arange(0, BLOCK_GROUP_NUM) - offs_d = tl.arange(0, BLOCK_GROUP_DIM // 2) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - - src_data_0 = tl.load( - K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2, - mask=offs_g[:, None] < group_size, - other=0.0, - ) - src_data_1 = tl.load( - K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2 + 1, - mask=offs_g[:, None] < group_size, - other=0.0, - ) - - abs_data_0 = tl.abs(src_data_0) - abs_data_1 = tl.abs(src_data_1) - - data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(Out_scale.dtype.element_ty) - q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) - q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) - q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) - - q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) - q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1) - q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1) - - low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF) - high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4 - - # tl.device_print(low_4) - # tl.device_print(high_4) - - out_data = low_4 | high_4 - - o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g - tl.store(o_ptrs, out_data, mask=offs_g[:, None] < group_size) - tl.store(os_ptrs, data_scale, mask=offs_g < group_size) - return - - -@torch.no_grad() -def destindex_copy_int4kv(K, DestLoc, Out, Out_scale): - # seq_len = DestLoc.shape[0] - # head_num = K.shape[1] - head_dim = K.shape[2] - quant_group_dim = 8 - - assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" - # grid = (seq_len, head_num) - # num_warps = 1 - - group_size = head_dim // quant_group_dim - group_dim = quant_group_dim - - K = K.view((K.shape[0], K.shape[1], group_size, group_dim)) - Out = Out.view( - Out.shape[0], Out.shape[1], group_size, group_dim // 2 - ) # OUt 是 int8 类型, 两个int4组一个int8,所以 group_dim // 2 - - from lightllm_ppl_int4kv_flashdecoding_kernel import group8_copy_int4_kv - - group8_copy_int4_kv(Out, Out_scale, K, DestLoc, 4) - - # _fwd_kernel_destindex_copy_quantize_int4_kv[grid]( - # K, - # DestLoc, - # Out, - # Out_scale, - # K.stride(0), - # K.stride(1), - # K.stride(2), - # K.stride(3), - # Out.stride(0), - # Out.stride(1), - # Out.stride(2), - # Out.stride(3), - # Out_scale.stride(0), - # Out_scale.stride(1), - # Out_scale.stride(2), - # group_size, - # BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), - # BLOCK_GROUP_DIM=group_dim, - # num_warps=num_warps, - # num_stages=1, - # ) - return - - -def test2(): - import time - - src = torch.randn((1, 1, 8), dtype=torch.float16).cuda() - src[0, 0, :] = torch.tensor([1, -2, 2, 0, 4, 5, 6, 7]).cuda() - dest_loc = torch.arange(0, 1, dtype=torch.int32).cuda() - value_dest = torch.randn((1, 1, 4), 
dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((1, 1, 1), dtype=torch.float16).cuda() - - destindex_copy_int4kv(src, dest_loc, value_dest, scale_dest) - - print(value_dest) - print(scale_dest) - - -if __name__ == "__main__": - test2() diff --git a/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py b/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py deleted file mode 100644 index 1e324bcc0..000000000 --- a/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch - - -def token_decode_attention_flash_decoding( - q, - infer_state, - q_head_num, - head_dim, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=None, - alloc_tensor_func=torch.empty, -): - BLOCK_SEQ = 256 - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - calcu_shape1 = (batch_size, q_head_num, head_dim) - - from lightllm_ppl_int4kv_flashdecoding_kernel import group8_int4kv_flashdecoding_stage1 - from .flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" - ) - - group8_int4kv_flashdecoding_stage1( - BLOCK_SEQ, - mid_o, - mid_o_logexpsum, - 1.0 / (head_dim ** 0.5), - q.view(calcu_shape1), - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) - return o_tensor diff --git a/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py b/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py deleted file mode 100644 index 3d9a490f4..000000000 --- a/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py +++ /dev/null @@ -1,294 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_g, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_g, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_g, - group_size, - BLOCK_GROUP_NUM: tl.constexpr, - BLOCK_GROUP_DIM: tl.constexpr, -): - cur_index = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_g = tl.arange(0, BLOCK_GROUP_NUM) - offs_d = tl.arange(0, BLOCK_GROUP_DIM) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - - src_data = tl.load( - K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :], - mask=offs_g[:, None] < group_size, - other=0.0, - ) - abs_data = tl.abs(src_data) - data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty) - q_src_data = (src_data / data_scale[:, None]).to(tl.int8) - - o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g - tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size) - tl.store(os_ptrs, data_scale, mask=offs_g < group_size) - return - - -@torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): - 
seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - quant_group_dim = 8 - - assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" - grid = (seq_len, head_num) - num_warps = 1 - - group_size = head_dim // quant_group_dim - group_dim = quant_group_dim - - K = K.view((K.shape[0], K.shape[1], group_size, group_dim)) - Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim) - - _fwd_kernel_destindex_copy_quantize_kv[grid]( - K, - DestLoc, - Out, - Out_scale, - K.stride(0), - K.stride(1), - K.stride(2), - K.stride(3), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out.stride(3), - Out_scale.stride(0), - Out_scale.stride(1), - Out_scale.stride(2), - group_size, - BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), - BLOCK_GROUP_DIM=group_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_destindex_copy_dequantize_kv( - mem_kv_buffer, - mem_kv_scale, - req_to_token_indexs, - b_seq_len, - b_req_idx, - Out, - stride_kv_b, - stride_kv_h, - stride_kv_g, - stride_kv_d, - stride_o_bh, - stride_o_l, - stride_o_g, - stride_o_d, - stride_s_b, - stride_s_h, - stride_s_g, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - group_size, - head_num: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_GROUP_NUM: tl.constexpr, - BLOCK_GROUP_DIM: tl.constexpr, -): - cur_group = tl.program_id(0) - start_m = tl.program_id(1) - cur_bh = tl.program_id(2) - cur_batch = cur_bh // head_num - cur_head = cur_bh % head_num - - block_start_loc = BLOCK_SIZE * start_m - cur_batch_req_idx = tl.load(b_req_idx + cur_batch) - cur_seq_len = tl.load(b_seq_len + cur_batch) - - # initialize offsets - offs_kv_loc = block_start_loc + tl.arange(0, BLOCK_SIZE) - - # offs_g = tl.arange(0, BLOCK_GROUP_NUM) - offs_d = tl.arange(0, BLOCK_GROUP_DIM) - - kv_loc = tl.load( - req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc, mask=offs_kv_loc < cur_seq_len - ).to(tl.int64) - offs_kv = kv_loc[:, None] * stride_kv_b + cur_head * stride_kv_h + cur_group * stride_kv_g + offs_d[None, :] - - src_data = tl.load( - mem_kv_buffer + offs_kv, - mask=offs_kv_loc[:, None] < cur_seq_len, - other=0.0, - ).to(Out.dtype.element_ty) - - s_ptrs = mem_kv_scale + kv_loc * stride_s_b + cur_head * stride_s_h + cur_group * stride_s_g - data_scale = tl.load( - s_ptrs, - mask=offs_kv_loc < cur_seq_len, - ) - - out_data = src_data * data_scale[:, None] - o_ptrs = Out + cur_bh * stride_o_bh + offs_kv_loc[:, None] * stride_o_l + cur_group * stride_o_g + offs_d[None, :] - tl.store(o_ptrs, out_data, mask=offs_kv_loc[:, None] < cur_seq_len) - return - - -@torch.no_grad() -def destindex_copy_dequantize_kv( - mem_kv_buffer, mem_kv_scale, req_to_token_indexs, b_seq_len, b_req_idx, max_len_in_batch, Out -): - batch_size = b_seq_len.shape[0] - head_num = mem_kv_buffer.shape[1] - head_dim = mem_kv_buffer.shape[2] - quant_group_dim = 8 - BLOCK_SIZE = 128 - group_size = head_dim // quant_group_dim - group_dim = quant_group_dim - assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" - grid = (group_size, triton.cdiv(max_len_in_batch, BLOCK_SIZE), batch_size * head_num) - num_warps = 1 - mem_kv_buffer = mem_kv_buffer.view((mem_kv_buffer.shape[0], mem_kv_buffer.shape[1], group_size, group_dim)) - mem_kv_scale = mem_kv_scale.view((mem_kv_buffer.shape[0], mem_kv_buffer.shape[1], -1)) - Out = Out.view(Out.shape[0] * Out.shape[1], -1, group_size, group_dim) - - 
_fwd_kernel_destindex_copy_dequantize_kv[grid]( - mem_kv_buffer, - mem_kv_scale, - req_to_token_indexs, - b_seq_len, - b_req_idx, - Out, - mem_kv_buffer.stride(0), - mem_kv_buffer.stride(1), - mem_kv_buffer.stride(2), - mem_kv_buffer.stride(3), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out.stride(3), - mem_kv_scale.stride(0), - mem_kv_scale.stride(1), - mem_kv_scale.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), - group_size, - head_num=head_num, - BLOCK_SIZE=BLOCK_SIZE, - BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), - BLOCK_GROUP_DIM=group_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test2(): - import time - - B, N_CTX, H, D = 1, 3, 12, 128 - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() - value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((B * N_CTX, H, D // 8), dtype=torch.float16).cuda() - - for _ in range(10): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - value_dest = value_dest.view((B * N_CTX, H, D // 8, 8)) - scale_dest = scale_dest.view((B * N_CTX, H, D // 8, 1)) - print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - cos = torch.nn.CosineSimilarity(0) - print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -def torch_dequant(kv, kv_scale, o, b_req_idx, b_seq_len, req_to_token_indexs): - - batch = b_req_idx.shape[0] - for i in range(batch): - req_idx = b_req_idx[i] - seq_len = b_seq_len[i] - print(seq_len, b_seq_len) - kv_loc = req_to_token_indexs[req_idx, :seq_len] - head_num = kv.shape[1] - cur_kv = kv[kv_loc, :, :].reshape(seq_len, head_num, -1, 8).to(o.dtype) - cur_scale = kv_scale[kv_loc, :, :].reshape(seq_len, head_num, -1, 1) - out = cur_kv * cur_scale - o[i, :seq_len, :, :] = out.reshape(out.shape[0], out.shape[1], -1) - - -def test3(): - import time - import numpy as np - - Z, H, N_CTX, D_HEAD = 1, 16, 3, 128 - dtype = torch.bfloat16 - kv = torch.empty((Z * N_CTX + 100, 2 * H, D_HEAD), dtype=torch.int8, device="cuda") - kv_scale = torch.randn((Z * N_CTX + 100, 2 * H, D_HEAD // 8), dtype=dtype, device="cuda") - out = torch.empty((Z, 2 * H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - torch_out = torch.empty((Z, N_CTX, 2 * H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - req_to_token_indexs = torch.empty((1000, N_CTX + 7000), dtype=torch.int32, device="cuda") - max_input_len = N_CTX - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") - for i in range(Z): - seq_len = N_CTX - i * 100 - b_seq_len[i] = seq_len - b_req_idx[i] = i - req_to_token_indexs[i][:seq_len] = ( - torch.tensor(np.arange(seq_len), dtype=torch.int32).cuda() + b_seq_len[0:i].sum() - ) - print(b_seq_len) - destindex_copy_dequantize_kv(kv, kv_scale, req_to_token_indexs, b_seq_len, b_req_idx, max_input_len, out) - torch_dequant(kv, kv_scale, torch_out, b_req_idx, b_seq_len, req_to_token_indexs) - torch.cuda.synchronize() - t1 = time.time() - for _ in 
range(1000): - destindex_copy_dequantize_kv(kv, kv_scale, req_to_token_indexs, b_seq_len, b_req_idx, max_input_len, out) - torch.cuda.synchronize() - t2 = time.time() - print((t2 - t1)) - torch_out = torch_out.transpose(1, 2) - for i in range(Z): - print("max ", torch.max(torch.abs(torch_out - out)[i][:, : b_seq_len[i]])) - print("mean ", torch.mean(torch.abs(torch_out - out)[i][:, : b_seq_len[i]])) - assert torch.allclose(torch_out[i][:, : b_seq_len[i]], out[i][:, : b_seq_len[i]], atol=1e-2, rtol=0) - # print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - # print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - # cos = torch.nn.CosineSimilarity(0) - # print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -if __name__ == "__main__": - test3() diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py index eb5af6fec..45de83e98 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @@ -60,7 +60,8 @@ def _fwd_kernel_token_att1( ).to(tl.int64) off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32) + att_value = tl.sum(q[None, :] * k, 1) + att_value = att_value.to(tl.float32) att_value *= sm_scale off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py deleted file mode 100644 index 243a8d1f6..000000000 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py +++ /dev/null @@ -1,223 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_token_att2( - Prob, - V, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = 0 - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_loc = tl.load( - Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0.0, - 
).to(tl.int64) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 - ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(Out.dtype.element_ty) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen): - BLOCK = 128 - # BLOCK = 64 # for triton 2.0.0dev - batch, head = B_req_idx.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - kv_group_num = prob.shape[0] // v.shape[1] - - _fwd_kernel_token_att2[grid]( - prob, - v, - out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_token_att2_int8v( - Prob, - V, - V_scale, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_vsbs, - stride_vsh, - stride_vsd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = 0 - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - vs_offs = cur_kv_head * stride_vsh - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_loc = tl.load( - Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0.0, - ).to(tl.int64) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 - ) - vs_value = tl.load( - V_scale + vs_offs + v_loc[:, None] * stride_vsbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - acc += tl.sum(p_value[:, None] * v_value * vs_value, 0) - - acc = acc.to(Out.dtype.element_ty) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_att_fwd2_int8v(prob, v, v_scale, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch): - if max_len_in_batch < 512: - BLOCK = triton.next_power_of_2(max_len_in_batch) - else: - BLOCK = 512 - batch, head = B_req_idx.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - kv_group_num = prob.shape[0] // 
v.shape[1] - - _fwd_kernel_token_att2_int8v[grid]( - prob, - v, - v_scale, - out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - v_scale.stride(0), - v_scale.stride(1), - v_scale.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -def torch_att(V, P, bs, seqlen, num_head, head_dim): - V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) - P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) - out = torch.matmul(P, V) - - return out diff --git a/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py index b4c070a1e..3afcfb0a7 100644 --- a/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,3 @@ -import torch -import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight @@ -14,8 +12,8 @@ def rename_weight_keys(weights): class LlavaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py index 0952468d0..45023bdf8 100644 --- a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class MiniCPMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) hidden_size = self.network_config_["hidden_size"] dim_model_base = self.network_config_.get("dim_model_base", hidden_size) self.lm_head_scale = hidden_size / dim_model_base diff --git a/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py b/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py index 2bc507838..c37b524fd 100755 --- a/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class MiniCPMTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py index 59eef6daa..d115c30ec 100755 --- a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py @@ -4,7 +4,7 @@ class MistralTransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - 
super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.head_dim_ = network_config.get("head_dim", self.head_dim_) return diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index d32f51ae7..f09525c59 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -1,5 +1,3 @@ -import os -import json import torch from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel import TpPartBaseModel @@ -8,7 +6,6 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num @@ -43,10 +40,6 @@ def _init_custom(self): self._init_to_get_rotary() return - def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: - self.infer_state_class = FlashAttentionStateInfo - def _init_mem_manager(self): # Dealing with head_dim_!=n_embed // num_attention_heads scenarios, such as mistral 13B head_dim = self.config["hidden_size"] // self.config["num_attention_heads"] diff --git a/lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py b/lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py deleted file mode 100644 index abcaf02b5..000000000 --- a/lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch - -import triton -import triton.language as tl -import math -import torch.nn.functional as F - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - sliding_window, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 
0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - # [SYM] mask outside of windows - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - qk = tl.where((start_n + offs_n[None, :]) > (offs_m[:, None] - sliding_window), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - -@torch.no_grad() -def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, sliding_window): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - sliding_window=sliding_window, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.0] = -100000000.0 - mask = mask.repeat(bs, num_head, 1, 1) - keys = xk - values = xv - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) - scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output - - -def test(): - import torch - - Z, H, N_CTX, D_HEAD = 4, 6, 1024, 128 - dtype = torch.float16 - Z = 3 - q = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, 
device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - - max_input_len = N_CTX - Z = 4 - b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - - b_seq_len[0] = 512 - b_seq_len[1] = 1024 - b_seq_len[2] = 512 - b_seq_len[3] = 1024 - - for i in range(1, Z): - b_start_loc[i] = b_start_loc[i - 1] + b_seq_len[i - 1] - - torch_out = [] - start = 0 - for i in range(Z): - end = start + b_seq_len[i] - torch_o = torch_att(q[start:end], k[start:end], v[start:end], 1, b_seq_len[i], H, D_HEAD) - start = end - torch_out.append(torch_o) - torch_out = torch.cat(torch_out, dim=0) - context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, 10) - print(o.shape, torch_out.shape) - - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py b/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py deleted file mode 100644 index a60fe970b..000000000 --- a/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_init_att_window_info( - b_seq_len, - b_att_seq_len, - batch_size, - sliding_window, - BLOCK_SIZE: tl.constexpr, -): - cur_index = tl.program_id(0) - cur_start = cur_index * BLOCK_SIZE - offsets = cur_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < batch_size - - cur_seq_len = tl.load(b_seq_len + offsets, mask=mask) - b_att_seq_len_data = tl.minimum(cur_seq_len, sliding_window) - - tl.store(b_att_seq_len + offsets, b_att_seq_len_data, mask=mask) - return - - -@torch.no_grad() -def init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window): - # shape constraints - assert batch_size == b_seq_len.shape[0] == b_att_seq_len.shape[0] - - BLOCK_SIZE = 32 - num_warps = 1 - grid = (triton.cdiv(batch_size, BLOCK_SIZE),) - - _fwd_kernel_init_att_window_info[grid]( - b_seq_len, - b_att_seq_len, - batch_size=batch_size, - sliding_window=sliding_window, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py deleted file mode 100644 index bf9928f98..000000000 --- a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py +++ /dev/null @@ -1,132 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel( - Logics, - V, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, - stride_logic_h, - stride_logic_bs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_req_to_token_b, - stride_req_to_token_s, - other_kv_index, # 避免读取到nan的数据 - kv_group_num, - sliding_window, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + 
cur_batch) - cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch) # new index - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # new index - cur_cache_start_loc = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index - - offs_n = tl.arange(0, BLOCK_N) # [64] - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] - - off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D] - v_ptrs = V + off_v - - e_max = float("-inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D] - - for start_n in range(0, cur_att_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) # check - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s, - mask=(start_n + offs_n) < cur_att_seq_len, - other=other_kv_index, - ) # [64] - - qk = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, - mask=(start_n + offs_n) < cur_att_seq_len, - other=float("-inf"), - ) # [64] - - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) # [1, D] + [64, 1] = [64, D] - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) # [64, 1] * [64, D] = [64, D] -> [D] - e_max = n_e_max - - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_softmax_reducev_fwd( - logics, - v, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_att_start_loc, - b_att_seq_len, - sliding_window, -): - BLOCK = 64 - batch, head = b_seq_len.shape[0], logics.shape[0] - grid = (batch, head) - kv_group_num = logics.shape[0] // v.shape[1] - - num_warps = 1 - _fwd_kernel[grid]( - logics, - v, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_att_start_loc, - b_att_seq_len, - logics.stride(0), - logics.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_tokens.stride(0), - req_to_tokens.stride(1), - 0, - kv_group_num, - sliding_window, - BLOCK_DMODEL=v.shape[-1], - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=3, - ) - return diff --git a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py index 5eac249ba..f890fbf66 100644 --- a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py @@ -4,6 +4,6 @@ class MistralMTPPostLayerInfer(LlamaPostLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index 25bea1aa6..dbe9b61c8 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -7,8 +7,8 @@ class MistralMTPPreLayerInfer(LlamaPreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def _mtp_context_forward( diff --git 
a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py index 5724f32af..6d72ae2c3 100644 --- a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py @@ -10,8 +10,8 @@ class MistralMTPTransformerLayerInfer(MistralTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): diff --git a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py index 2fbc89cfd..c9032f6fe 100644 --- a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py @@ -8,8 +8,8 @@ class MistralMTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.eh_proj_weight_ = ROWMMWeight( weight_names="mtp.eh_proj.weight", diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py index 6607dbb70..08f280b06 100644 --- a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class MistralMTPTransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index a60375688..44e66cff2 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -3,14 +3,13 @@ import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight class MixtralTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.num_local_experts = network_config["num_local_experts"] self.num_experts_per_tok = network_config["num_experts_per_tok"] self.renormalize = True diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index 
f425ad08b..39e28d465 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -8,12 +8,11 @@ class MixtralTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): super().__init__( layer_num, data_type, network_config, - mode, quant_cfg=quant_cfg, ) return diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index 806c59365..fd3d05e42 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -1,11 +1,5 @@ -import torch -from functools import partial from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.phi3.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.phi3.triton_kernel.destindex_copy_kv import destindex_copy_kv from lightllm.models.phi3.layer_weights.transformer_layer_weight import Phi3TransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -13,14 +7,8 @@ class Phi3TransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) - return - - def _bind_attention(self): - self._context_attention_kernel = partial(Phi3TransformerLayerInfer._context_attention_kernel, self) - self._copy_kv_to_mem_cache = partial(Phi3TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - self._token_attention_kernel = partial(Phi3TransformerLayerInfer._token_decode_attention_flashdecoding, self) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight): @@ -35,44 +23,3 @@ def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Ph infer_state.position_sin, ) return q, cache_kv - - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) - return - - def _context_attention_kernel( - self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - ) - return o_tensor - - def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - from lightllm.models.phi3.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = 
infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) diff --git a/lightllm/models/phi3/layer_weights/transformer_layer_weight.py b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py index 91b273091..db4906c19 100755 --- a/lightllm/models/phi3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py @@ -6,8 +6,8 @@ class Phi3TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def load_hf_weights(self, weights): diff --git a/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py b/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py deleted file mode 100644 index ee04c3367..000000000 --- a/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py +++ /dev/null @@ -1,433 +0,0 @@ -import torch - -import triton -import triton.language as tl -import math -import torch.nn.functional as F - -from lightllm.utils.device_utils import is_tesla - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - Req_to_tokens, - B_req_idx, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - kv_group_num, - b_prompt_cache_len, - head_dim: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - - q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0) - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) - - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), - mask=(start_n + offs_n) < 
block_end_loc, - other=0, - ).to(tl.int64) - off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - k = tl.load( - K + off_k, mask=((start_n + offs_n[None, :]) < block_end_loc) & (offs_d[:, None] < head_dim), other=0.0 - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0) - acc = acc * acc_scale[:, None] - # update acc - off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - v = tl.load( - V + off_v, mask=((start_n + offs_n[:, None]) < block_end_loc) & (offs_d[None, :] < head_dim), other=0.0 - ) - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim)) - return - - -@torch.no_grad() -def context_attention_fwd( - q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs -): - BLOCK = 128 if not is_tesla() else 64 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - head_dim = Lq - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - - sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - req_to_token_indexs, - b_req_idx, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), - kv_group_num=kv_group_num, - b_prompt_cache_len=b_prompt_cache_len, - head_dim=head_dim, - BLOCK_M=BLOCK, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_no_prompt_cache( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - head_dim, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # 
initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (offs_d[:, None] < head_dim), - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (offs_d[None, :] < head_dim), - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim)) - return - - -@torch.no_grad() -def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - BLOCK = 128 if not is_tesla() else 64 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - head_dim = Lq - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel_no_prompt_cache[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - head_dim=head_dim, - BLOCK_M=BLOCK, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - num_warps=num_warps, - 
num_stages=1, - ) - return - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim, prompt_cache_len): - xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen + prompt_cache_len, num_head, head_dim) - xv = xv.view(bs, seqlen + prompt_cache_len, num_head, head_dim) - mask_cache = torch.ones((seqlen, prompt_cache_len)).cuda().unsqueeze(0).unsqueeze(0).cuda() - mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.0] = -100000000.0 - mask = torch.cat([mask_cache, mask], dim=-1) - mask = mask.repeat(bs, num_head, 1, 1) - keys = xk - values = xv - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) - scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output - - -def test(): - import torch - import numpy as np - - Z, H, N_CTX, D_HEAD = 10, 6, 500, 96 - dtype = torch.float16 - Z = 1 - q = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - req_to_token_indexs = torch.zeros((10, Z * N_CTX + 7000), dtype=torch.int32, device="cuda") - max_input_len = N_CTX - Z = 1 - b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_prompt_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") - b_prompt_cache_len[0] = 0 - prompt_cache_len = 0 - - b_seq_len[0] = 500 - b_req_idx[0] = 0 - req_to_token_indexs[0][: prompt_cache_len + N_CTX] = torch.tensor( - np.arange(prompt_cache_len + N_CTX), dtype=torch.int32 - ).cuda() - - torch_out = [] - start = 0 - for i in range(Z): - end = start + b_seq_len[i] - torch_o = torch_att( - q[start:end], - k[start : end + prompt_cache_len], - v[start : end + prompt_cache_len], - 1, - b_seq_len[i], - H, - D_HEAD, - prompt_cache_len, - ) - start = end - torch_out.append(torch_o) - - torch_out = torch.cat(torch_out, dim=0) - - context_attention_fwd( - q, - k, - v, - o, - b_req_idx, - b_start_loc, - b_seq_len + prompt_cache_len, - b_prompt_cache_len, - max_input_len, - req_to_token_indexs, - ) - - # context_attention_fwd_no_prompt_cache( - # q, k, v, o, b_start_loc, b_seq_len, max_input_len - # ) - - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py b/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py deleted file mode 100644 index 4f31895ae..000000000 --- a/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py +++ /dev/null @@ -1,192 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_destindex_copy_kv( - K, - Dest_loc, - Out, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - head_num, - head_dim, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, -): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, 
BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(Dest_loc + cur_index) - - k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] - o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - - k = tl.load(k_ptrs, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), other=0.0) - tl.store(o_ptrs, k, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim)) - return - - -@torch.no_grad() -def destindex_copy_kv(K, DestLoc, Out): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] - BLOCK_HEAD = triton.next_power_of_2(head_num) - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_kv[grid]( - K, - DestLoc, - Out, - K.stride(0), - K.stride(1), - K.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - head_num, - head_dim, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_HEAD=BLOCK_HEAD, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_d, - head_num, - head_dim, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, -): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(Dest_loc + cur_index) - src_data = tl.load( - K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], - mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), - other=0.0, - ) - abs_data = tl.abs(src_data) - data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None] - q_src_data = (src_data / data_scale).to(tl.int8) - o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] - tl.store(o_ptrs, q_src_data, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim)) - tl.store(os_ptrs, data_scale, mask=(offs_h[:, None] < head_num)) - - -@torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] - BLOCK_HEAD = triton.next_power_of_2(head_num) - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_quantize_kv[grid]( - K, - DestLoc, - Out, - Out_scale, - K.stride(0), - K.stride(1), - K.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out_scale.stride(0), - Out_scale.stride(1), - Out_scale.stride(2), - head_num, - head_dim, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_HEAD=BLOCK_HEAD, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test1(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 96 - dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") - - for _ in range(10): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t2 = 
time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(dest - src))) - print("mean ", torch.mean(torch.abs(dest - src))) - assert torch.allclose(src, dest, atol=1e-2, rtol=0) - - -def test2(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 96 - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() - value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((B * N_CTX, H, 1), dtype=torch.float16).cuda() - - for _ in range(10): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(value_dest * scale_dest - src))) - print("mean ", torch.mean(torch.abs(value_dest * scale_dest - src))) - cos = torch.nn.CosineSimilarity(0) - print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -if __name__ == "__main__": - test1() - test2() diff --git a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py index 7a4b2ca81..333870eb9 100755 --- a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py @@ -1,7 +1,4 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from typing import Tuple from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd @@ -12,8 +9,8 @@ class QwenTransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeight): diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index 00f68eee6..bf9282a97 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,8 @@ class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="transformer.wte.weight", data_type=self.data_type_, diff --git a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py index 9afb964ad..ac1bf91f4 100755 --- a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class QwenTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + 
super().__init__(layer_num, data_type, network_config, quant_cfg) def load_hf_weights(self, weights): qkv_weight_name = f"transformer.h.{self.layer_num_}.attn.c_attn.weight" diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index a8a57c02e..6449430d9 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -2,6 +2,6 @@ class Qwen2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 6962818c4..9c3e2cb3a 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class Qwen2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) def _init_weight_names(self): super()._init_weight_names() diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index d2f067c42..106610ff0 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -18,7 +18,7 @@ def __init__(self, kvargs): def _init_config(self): super()._init_config() - if self.config["sliding_window"] is None: + if self.config.get("sliding_window", None) is None: self.config["sliding_window"] = self.max_total_token_num # rename key [SYM: to be confirmed] return diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py index 7cf636622..3c974691d 100644 --- a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,8 @@ class Qwen2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) del self.lm_head_weight_ self.score_up_weight_ = ROWMMWeight( weight_names="score.0.weight", diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index 838590325..747be932d 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -4,13 +4,10 @@ from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.get_mrope_position_ids import get_mrope_position_triton -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.utils.envs_utils import get_env_start_args class Qwen2VLInferStateInfo(LlamaInferStateInfo): - init_flash_attention_state_func = FlashAttentionStateInfo._init_flash_attention_state - def __init__(self): super().__init__() 
self.position_cos = None @@ -35,10 +32,6 @@ def init_some_extra_state(self, model): self.position_ids = self.position_ids.contiguous() self.position_cos = model._cos_cached[self.position_ids] self.position_sin = model._sin_cached[self.position_ids] - if get_env_start_args().enable_fa3: - self.max_seq_len = self.max_kv_seq_len - self.q_max_seq_len = self.max_q_seq_len - self.init_flash_attention_state_func(model) return def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: @@ -85,6 +78,6 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: position_ids=position_ids, b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, - b_start_loc=self.b_start_loc, + b_start_loc=self.b_q_start_loc, ) return position_ids diff --git a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py index 19e17c36e..298a77044 100755 --- a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py @@ -5,8 +5,8 @@ class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) mrope_section = network_config["rope_scaling"]["mrope_section"] self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 61dd06773..237c4ad89 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -79,6 +79,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("image token error") except ValueError: break + if multimodal_params: + image_cnt = len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") + input_ids.extend(origin_ids) return input_ids @@ -95,9 +99,6 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - pass - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: self.config = json.load(json_file) diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 1c4f60794..f2cd38ec8 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -184,6 +184,8 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: return self._preprocess_bydevice(image, device="cpu") def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]: + if image.mode != "RGB": + image = image.convert("RGB") image_arr = np.asarray(image, dtype=np.uint8) image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 20f135e76..5f0c91287 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -13,8 +13,8 @@ class Qwen3TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode)
+ def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] return diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 86b9e172a..90b7810ad 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -5,8 +5,8 @@ class Qwen3TransformerLayerWeight(Qwen2TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/qwen3/model.py b/lightllm/models/qwen3/model.py index 21e71e0e0..e48b36e0f 100644 --- a/lightllm/models/qwen3/model.py +++ b/lightllm/models/qwen3/model.py @@ -1,5 +1,3 @@ -import torch -from typing import final from lightllm.models.registry import ModelRegistry from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 10a734e5c..c85c423c2 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -19,7 +19,7 @@ class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.n_routed_experts = network_config["num_experts"] self.is_moe = ( network_config["num_experts"] > 0 @@ -28,7 +28,7 @@ def __init__(self, layer_num, network_config, mode=[]): ) self.num_experts_per_tok = network_config["num_experts_per_tok"] self.norm_topk_prob = network_config["norm_topk_prob"] - super().__init__(layer_num, network_config, mode) + super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) self.tp_v_head_num_ = max(self.tp_v_head_num_, 1) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 72721f9d6..486f4d696 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -4,14 +4,14 @@ class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.n_routed_experts = network_config["num_experts"] self.is_moe = ( network_config["num_experts"] > 0 and layer_num not in network_config["mlp_only_layers"] and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 ) - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py 
b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py index d21917340..4e2b65d74 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py @@ -14,8 +14,8 @@ class Qwen3MOEMTPTransformerLayerInfer(Qwen3MOETransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index 6cc447a59..8ba95c138 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -9,8 +9,8 @@ class Qwen3MOEMTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.proj.weight", diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py index 22d4d1950..095afecd9 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -4,8 +4,8 @@ class Qwen3MOEMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight(self): diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py index 96e453ebe..c24166e13 100644 --- a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py @@ -8,8 +8,8 @@ class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer): - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def context_forward( diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index 175340a77..d1c51365a 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -1,20 +1,12 @@ import torch -import torch.functional as F import torch.distributed as dist -import numpy as np -from functools import partial from typing import Tuple from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused -from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight -from 
lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.distributed import all_reduce -from lightllm.utils.dist_utils import get_global_world_size from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward @@ -22,8 +14,9 @@ class Qwen3VLTransformerLayerInfer(Qwen2VLTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.head_dim_ = network_config["head_dim"] self.mrope_section = torch.tensor( network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" ) diff --git a/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py index 5d41d8551..8a380853d 100644 --- a/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py @@ -12,8 +12,8 @@ def rename_weight_keys(weights): class Qwen3VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen3_vl/model.py b/lightllm/models/qwen3_vl/model.py index 0d8a81f67..74aa33e3c 100644 --- a/lightllm/models/qwen3_vl/model.py +++ b/lightllm/models/qwen3_vl/model.py @@ -37,9 +37,6 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - pass - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: all_config = json.load(json_file) diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index b155f8b90..328cc0a62 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -7,13 +7,12 @@ from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward from lightllm.distributed import all_reduce -from lightllm.utils.dist_utils import get_global_world_size from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features class Qwen3VLMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.mrope_section = torch.tensor( network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" ) diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py 
b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index b1f5ee660..52a982f49 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -4,8 +4,8 @@ class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.language_model.embed_tokens.weight", data_type=self.data_type_, diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py index f4eef6e69..48ddf5208 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py @@ -4,8 +4,8 @@ class Qwen3VLMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) def load_hf_weights(self, weights): moe_prefix = f"model.layers.{self.layer_num_}.mlp.experts" diff --git a/lightllm/models/qwen3_vl_moe/model.py b/lightllm/models/qwen3_vl_moe/model.py index b11f22fdb..cc1201de2 100644 --- a/lightllm/models/qwen3_vl_moe/model.py +++ b/lightllm/models/qwen3_vl_moe/model.py @@ -25,9 +25,6 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - pass - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: all_config = json.load(json_file) diff --git a/lightllm/models/qwen3next/__init__.py b/lightllm/models/qwen3next/__init__.py new file mode 100644 index 000000000..a9d22c664 --- /dev/null +++ b/lightllm/models/qwen3next/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel + +__all__ = ["Qwen3NextTpPartModel"] diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py new file mode 100644 index 000000000..2883534a9 --- /dev/null +++ b/lightllm/models/qwen3next/infer_struct.py @@ -0,0 +1,62 @@ +import torch +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen3NextInferStateInfo(LlamaInferStateInfo): + """ + Inference state for Qwen3Next with: + - gate_value attribute for output gating in full attention layers + - MTP-aware batching for multi-token prediction + - Custom buffer management for hybrid attention (full + linear) + """ + + def __init__(self): + super().__init__() + # For output gating in full attention layers + self.gate_value = None + # MTP-aware attributes + self.b_att_seq_len = None + self.att_batch_size = None + self.real_req_idx = None + self.mtp_buffer_idx_list = None + self.b_buffer_idx = None + + def init_some_extra_state(self, model): + """Initialize Qwen3Next-specific state""" + super().init_some_extra_state(model) + + args_mtp_step = get_env_start_args().mtp_step + mtp_size = args_mtp_step + 1 + + if self.is_prefill: + # Prefill: Standard initialization + self.b_att_seq_len = self.b_seq_len + self.b_buffer_idx 
= model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() + else: + # Decode: MTP-aware handling + # In MTP mode, each request has (mtp_step + 1) tokens + # att_batch_size is the number of unique requests + self.att_batch_size = self.batch_size // mtp_size + + # Use only the sequence lengths for the last token of each MTP group + if args_mtp_step > 0: + self.b_att_seq_len = self.b_seq_len[args_mtp_step::mtp_size].contiguous() + self.real_req_idx = self.b_req_idx[args_mtp_step::mtp_size] + else: + self.b_att_seq_len = self.b_seq_len + self.real_req_idx = self.b_req_idx + + # Buffer indices for Mamba cache (conv and SSM states) + self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.real_req_idx, :].flatten().contiguous() + + # Create per-step buffer indices for MTP + if args_mtp_step > 0: + buffer_idx_list = [] + for step_id in range(mtp_size): + buffer_idx_list.append(self.b_buffer_idx[step_id::mtp_size].tolist()) + self.mtp_buffer_idx_list = torch.tensor( + buffer_idx_list, dtype=torch.int32, device=self.b_buffer_idx.device + ) + + return diff --git a/lightllm/models/qwen3next/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py new file mode 100644 index 000000000..9dcab4e6f --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py @@ -0,0 +1,12 @@ +import torch + +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextPostLayerInfer(LlamaPostLayerInfer): + def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_.weight, self.eps_, out=out) + return out diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..696b4705b --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -0,0 +1,755 @@ +import os +import torch +import torch.nn.functional as F +from torch.distributed import ReduceOp +from typing import Tuple +from typing_extensions import override +from functools import partial +from einops import rearrange + +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo +from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl +from lightllm.utils.log_utils import init_logger +from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager +from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.fla.ops 
import chunk_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.gdn_decode_mtp import ( + copy_conv_states, + copy_ssm_states, + copy_states_fused, +) +from lightllm.distributed import all_reduce +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward +from lightllm.utils.envs_utils import get_env_start_args + +logger = init_logger(__name__) + +# Module-level constant for MoE mode +MOE_MODE = os.environ.get("MOE_MODE", "TP") + + +def is_moe_layer(layer_num: int, network_config: dict) -> bool: + """Determine if a layer should use MoE based on network configuration.""" + return ( + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 + ) + + +class Qwen3NextFFNMixin: + """ + Mixin providing shared FFN implementations for Qwen3Next layers. + + Both full attention and GDN layers use identical FFN logic (standard FFN, + shared expert + MoE with TP or EP modes). This mixin eliminates duplication. + + Requires the using class to have: + - embed_dim_: int + - num_experts_per_tok: int + - norm_topk_prob: bool + - alloc_tensor(): method + """ + + def _standard_ffn(self, input, infer_state, layer_weight) -> torch.Tensor: + """Standard FFN using shared expert weights (for non-MoE layers).""" + input = input.view(-1, self.embed_dim_) + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + return ffn2_out + + def _ffn_with_shared_expert_tp(self, input, infer_state, layer_weight) -> torch.Tensor: + """FFN with shared expert + MoE (tensor parallelism mode).""" + input = input.view(-1, self.embed_dim_) + + # Shared expert + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + shared_expert_out = F.sigmoid(layer_weight.shared_expert_gate.mm(input)) * ffn2_out + + # MoE + moe_out = self._moe_ffn(input, infer_state, layer_weight) + + return shared_expert_out + moe_out + + def _ffn_with_shared_expert_ep(self, input, infer_state, layer_weight) -> torch.Tensor: + """FFN with shared expert + MoE (expert parallelism mode).""" + input = input.view(-1, self.embed_dim_) + + # Shared expert + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + shared_expert_out = F.sigmoid(layer_weight.shared_expert_gate.mm(input)) * ffn2_out + + # MoE (EP mode) + moe_out = self._moe_ffn_edp(input, infer_state, layer_weight) + + return shared_expert_out + moe_out + + def _moe_ffn(self, input, infer_state, layer_weight) -> torch.Tensor: + """MoE FFN with tensor parallelism.""" + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + 
router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + ) + return hidden_states.view(num_tokens, hidden_dim) + + def _moe_ffn_edp(self, input, infer_state, layer_weight) -> torch.Tensor: + """MoE FFN with expert parallelism.""" + hidden_states = input + token_num, hidden_dim = hidden_states.shape + + router_logits = layer_weight.moe_gate.mm(hidden_states) + ep_output = layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + ) + + return ep_output.view(token_num, hidden_dim) + + +class Qwen3NextFullAttentionTransformerLayerInfer(Qwen3NextFFNMixin, LlamaTransformerLayerInfer): + """ + Full attention layer for Qwen3Next. + Inherits from LlamaTransformerLayerInfer to get standard attention via abstraction. + """ + + def __init__(self, layer_num, network_config): + # Store Qwen3Next specific configs before calling super().__init__ + self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0) + self.n_routed_experts = network_config.get("num_experts", 0) + self.is_moe = is_moe_layer(layer_num, network_config) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1) + self.norm_topk_prob = network_config.get("norm_topk_prob", False) + + super().__init__(layer_num, network_config) + # Override head_dim which may be different in Qwen3Next + self.head_dim_ = network_config.get( + "head_dim", network_config["hidden_size"] // network_config["num_attention_heads"] + ) + + def _bind_func(self): + super()._bind_func() + self._bind_ffn() + + def _bind_norm(self): + """Use Gemma-style RMSNorm.""" + self._att_norm = partial(Qwen3NextFullAttentionTransformerLayerInfer._att_norm_impl, self) + self._ffn_norm = partial(Qwen3NextFullAttentionTransformerLayerInfer._ffn_norm_impl, self) + + def _att_norm_impl( + self, + input, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.att_norm_weight_.weight, self.eps_, out=out) + return out + + def _ffn_norm_impl( + self, + input, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.ffn_norm_weight_.weight, self.eps_, out=out) + return out + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """QKV projection with output gating, Q/K normalization, and partial rotary embedding.""" + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) + # Save gate value for output projection + infer_state.gate_value = torch.sigmoid(layer_weight.o_gate_proj.mm(input)) + cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + # Q normalization + gemma_rmsnorm_forward( + q.view(-1, self.head_dim_), + layer_weight.q_norm_weight_.weight, + eps=self.eps_, + out=q.view(-1, self.head_dim_), + ) + + # K normalization + cache_kv[:, : self.tp_k_head_num_, :] = gemma_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_, 
:].reshape(-1, cache_kv.shape[-1]), + layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + + # Rotary embedding with partial rotation support + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + partial_rotary_factor=self.partial_rotary_factor, + ) + return q, cache_kv + + def _get_o( + self, + input, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextFullAttentionTransformerLayerWeight, + ) -> torch.Tensor: + """Output projection with gating.""" + input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) + gated_input = input * infer_state.gate_value + infer_state.gate_value = None + o_tensor = layer_weight.o_proj.mm(gated_input) + return o_tensor + + def _bind_ffn(self): + """Bind FFN implementation (MoE or shared expert + MoE).""" + if self.is_moe: + if MOE_MODE == "EP": + self._ffn = partial(Qwen3NextFFNMixin._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(Qwen3NextFFNMixin._ffn_with_shared_expert_tp, self) + else: + self._ffn = partial(Qwen3NextFFNMixin._standard_ffn, self) + + +class Qwen3NextGatedDeltaNetTransformerLayerInfer(Qwen3NextFFNMixin, TransformerLayerInferTpl): + """ + Linear attention (Gated Delta Networks) layer for Qwen3Next. + Inherits from TransformerLayerInferTpl and overrides attention methods with custom GDN logic. + """ + + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.network_config_ = network_config + + # MoE configuration + self.n_routed_experts = network_config.get("num_experts", 0) + self.is_moe = is_moe_layer(layer_num, network_config) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1) + self.norm_topk_prob = network_config.get("norm_topk_prob", False) + + # Standard layer dimensions + self.eps_ = network_config["rms_norm_eps"] + self.embed_dim_ = network_config["hidden_size"] + + # Linear attention specific dimensions + self.num_v_heads = network_config["linear_num_value_heads"] + self.num_k_heads = network_config["linear_num_key_heads"] + self.head_k_dim = network_config["linear_key_head_dim"] + self.head_v_dim = network_config["linear_value_head_dim"] + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_dim = network_config["linear_conv_kernel_dim"] + self.activation = network_config["hidden_act"] + + # Tensor parallelism dimensions + self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_ + self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_ + self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_ + self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_ + self.tp_key_dim = self.key_dim // self.tp_world_size_ + self.tp_value_dim = self.value_dim // self.tp_world_size_ + + # Template required dimensions (not used for GDN but required by interface) + self.tp_q_head_num_ = self.tp_num_k_heads + self.tp_k_head_num_ = self.tp_num_k_heads + self.tp_v_head_num_ = self.tp_num_v_heads + self.tp_o_head_num_ = self.tp_num_v_heads + self.head_dim_ = self.head_v_dim + + assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" + self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads + + # MTP configuration + self.mtp_step = get_env_start_args().mtp_step + self.mtp_size = self.mtp_step + 1 + + self._bind_func() + + def 
_bind_func(self): + """Bind layer-specific implementations.""" + self._bind_norm() + self._bind_ffn() + + def _bind_norm(self): + """Use Gemma-style RMSNorm.""" + self._att_norm = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._att_norm_impl, self) + self._ffn_norm = partial(Qwen3NextGatedDeltaNetTransformerLayerInfer._ffn_norm_impl, self) + + def _bind_ffn(self): + """Bind FFN implementation (MoE or standard).""" + if self.is_moe: + if MOE_MODE == "EP": + self._ffn = partial(Qwen3NextFFNMixin._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(Qwen3NextFFNMixin._ffn_with_shared_expert_tp, self) + else: + self._ffn = partial(Qwen3NextFFNMixin._standard_ffn, self) + + def _att_norm_impl( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """Linear attention normalization.""" + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.att_norm_weight_.weight, self.eps_, out=out) + return out + + def _ffn_norm_impl( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """FFN normalization.""" + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.ffn_norm_weight_.weight, self.eps_, out=out) + return out + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Not used by GDN - implemented in gdn_forward.""" + raise NotImplementedError("GDN uses gdn_forward instead of _get_qkv") + + def _tpsp_get_qkv( + self, + input: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Not implemented for GDN.""" + raise NotImplementedError("TPSP mode not implemented for GDN layers") + + def _get_o( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """Not used by GDN - output computed in gdn_forward.""" + raise NotImplementedError("GDN uses gdn_forward instead of _get_o") + + def _tpsp_get_o( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """Not implemented for GDN.""" + raise NotImplementedError("TPSP mode not implemented for GDN layers") + + def _context_attention_kernel( + self, + q: torch.Tensor, + kv: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """Not used by GDN.""" + raise NotImplementedError("GDN uses gdn_forward instead of _context_attention_kernel") + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ) -> torch.Tensor: + """Not used by GDN.""" + raise NotImplementedError("GDN uses gdn_forward instead of _token_attention_kernel") + + def context_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Override context_forward to use GDN logic instead of standard attention flow.""" + # Attention + GDN processing + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + gdn_out = 
self.gdn_forward(input1, infer_state, layer_weight, is_prefill=True) + if self.tp_world_size_ > 1: + all_reduce(gdn_out, op=ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(gdn_out.view(-1, self.embed_dim_)) + gdn_out = None + + # FFN + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def token_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Override token_forward to use GDN logic instead of standard attention flow.""" + # Attention + GDN processing + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + gdn_out = self.gdn_forward(input1, infer_state, layer_weight, is_prefill=False) + if self.tp_world_size_ > 1: + all_reduce(gdn_out, op=ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(gdn_out.view(-1, self.embed_dim_)) + gdn_out = None + + # FFN + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + # ==================== GDN Helper Methods ==================== + + def _fix_query_key_value_ba_ordering(self, mixed_qkvzba): + """ + Derives `query`, `key`, `value`, `z`, `b`, `a` tensors from `mixed_qkvzba`. + Returns qkv already concatenated to avoid allocation in gdn_forward. 
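+
+ Illustrative shapes (hypothetical config, for example linear_num_key_heads=16, linear_num_value_heads=32,
+ linear_key_head_dim=linear_value_head_dim=128 and tp_world_size=1, so num_v_heads_per_k_head=2):
+ mixed_qkvzba splits into a qkvz part of width 2*key_dim + 2*value_dim = 12288 and a ba part of width
+ 2*num_v_heads = 64 per token; after the per-k-head split, query and key are reshaped to (tokens, 2048),
+ value to (tokens, 4096), so mixed_qkv = cat([q, k, v], dim=-1) has width 8192, while z becomes
+ (tokens, 32, 128) and b, a become (tokens, 32).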
+ """ + mixed_qkvz, mixed_ba = torch.split(mixed_qkvzba, [self.tp_qkvz_dim, self.tp_ba_dim], dim=-1) + + mixed_qkvz = mixed_qkvz.view( + -1, + self.tp_num_k_heads, + self.head_k_dim + self.head_k_dim + (self.head_v_dim + self.head_v_dim) * self.num_v_heads_per_k_head, + ) + mixed_ba = mixed_ba.view(-1, self.tp_num_k_heads, 2 * self.num_v_heads_per_k_head) + + qkvz_split_list = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads_per_k_head * self.head_v_dim), + (self.num_v_heads_per_k_head * self.head_v_dim), + ] + (query, key, value, z) = torch.split(mixed_qkvz, qkvz_split_list, dim=2) + (b, a) = torch.split(mixed_ba, [self.num_v_heads_per_k_head, self.num_v_heads_per_k_head], dim=2) + + # Reshape qkv components + query = query.reshape(-1, self.tp_num_k_heads * self.head_k_dim) + key = key.reshape(-1, self.tp_num_k_heads * self.head_k_dim) + value = value.reshape(-1, self.tp_num_v_heads * self.head_v_dim) + + # Concatenate qkv here instead of in gdn_forward (avoids extra allocation) + mixed_qkv = torch.cat([query, key, value], dim=-1) + + z = z.reshape(-1, self.tp_num_v_heads, self.head_v_dim) + b = b.reshape(-1, self.tp_num_v_heads) + a = a.reshape(-1, self.tp_num_v_heads) + + return mixed_qkv, z, b, a + + def _rearrange_mixed_qkv(self, mixed_qkv, decode=False): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + if decode: + batch_size = mixed_qkv.shape[0] + query = query.view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + key = key.view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + value = value.view(batch_size, 1, self.tp_num_v_heads, self.head_v_dim) + else: + query, key = map(lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), (query, key)) + value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) + return query, key, value + + @override + def context_attention_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True) + return gdn_out + + @override + def token_attention_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=False) + return gdn_out + + def _gdn_prefill_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Prefill kernel for GDN forward pass.""" + # Conv1D processing + mixed_qkv = mixed_qkv.transpose(0, 1) + out_tensor = causal_conv1d_fn( + mixed_qkv, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.mm_param.bias, + query_start_loc=infer_state.b1_cu_q_seq_len, + cache_indices=infer_state.b_buffer_idx, + has_initial_state=infer_state.b_ready_cache_len > 0, + conv_states=conv_states, + activation=self.activation, + ) + mixed_qkv = out_tensor.transpose(0, 1) + + # Recurrent processing + query, key, value = self._rearrange_mixed_qkv(mixed_qkv) + initial_state = ssm_states[infer_state.b_buffer_idx] + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=initial_state, + 
output_final_state=True, + cu_seqlens=infer_state.b1_cu_q_seq_len, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # Only convert dtype if necessary to avoid overhead + if last_recurrent_state.dtype != ssm_states.dtype: + ssm_states[infer_state.b_buffer_idx] = last_recurrent_state.to(ssm_states.dtype, copy=False) + else: + ssm_states[infer_state.b_buffer_idx] = last_recurrent_state + return core_attn_out + + def _gdn_decode_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """Decode kernel for GDN forward pass (single-token, non-MTP mode).""" + # Conv1D processing + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.mm_param.bias, + activation=self.activation, + conv_state_indices=infer_state.b_buffer_idx, + ) + + # Recurrent processing + query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) + # g and beta have shape (1, batch, num_heads), need to squeeze and unsqueeze to get (batch, 1, num_heads) + core_attn_out, _ = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + g=g.squeeze(0).unsqueeze(1), + beta=beta.squeeze(0).unsqueeze(1), + initial_state=ssm_states, + inplace_final_state=True, + ssm_state_indices=infer_state.b_buffer_idx, + use_qk_l2norm_in_kernel=True, + ) + return core_attn_out + + def _gdn_decode_mtp_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + ): + """ + Optimized decode kernel for GDN forward pass (MTP mode with multiple steps). + + Key optimizations: + 1. Uses pre-allocated work buffer to avoid per-step .contiguous() allocations + 2. Uses optimized flat Triton kernels for state copying + 3. Direct slice assignment for output instead of .copy_() + + Note: Sequential processing is required because each MTP step depends on + the previous step's final state (both conv and SSM states). 
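+
+ Illustrative layout (assuming mtp_step=1, so mtp_size=2, and 3 requests): decode tokens arrive grouped
+ per request as [r0_s0, r0_s1, r1_s0, r1_s1, r2_s0, r2_s1], so mixed_qkv[0::2] selects every request's
+ step-0 token and mixed_qkv[1::2] every request's step-1 token; after step 0 finishes, its updated
+ conv/SSM states are copied into step 1's buffer indices so step 1 continues from them.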
+ """ + total_tokens = mixed_qkv.shape[0] + batch_size = total_tokens // self.mtp_size + + g_squeezed = g.squeeze(0) # [total_tokens, num_heads] + beta_squeezed = beta.squeeze(0) + + # Pre-allocate output tensor + core_attn_out = torch.empty( + (total_tokens, 1, self.tp_num_v_heads, self.head_v_dim), + dtype=mixed_qkv.dtype, + device=mixed_qkv.device, + ) + + # Pre-allocate work buffer for conv1d input (avoids per-step .contiguous()) + qkv_work_buffer = torch.empty( + (batch_size, mixed_qkv.shape[-1]), + dtype=mixed_qkv.dtype, + device=mixed_qkv.device, + ) + + # Process each MTP step sequentially (required due to state dependencies) + for step_idx in range(self.mtp_size): + cur_buffer_idx = infer_state.mtp_buffer_idx_list[step_idx] + + # ========== Conv1D processing ========== + # Copy strided data to contiguous work buffer + qkv_work_buffer.copy_(mixed_qkv[step_idx :: self.mtp_size]) + + # causal_conv1d_update operates in-place on contiguous input + causal_conv1d_update( + qkv_work_buffer, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.mm_param.bias, + activation=self.activation, + conv_state_indices=cur_buffer_idx, + ) + + # ========== Recurrent processing ========== + query_i, key_i, value_i = self._rearrange_mixed_qkv(qkv_work_buffer, decode=True) + g_i = g_squeezed[step_idx :: self.mtp_size].unsqueeze(1) + beta_i = beta_squeezed[step_idx :: self.mtp_size].unsqueeze(1) + + core_attn_out_i, _ = fused_recurrent_gated_delta_rule( + q=query_i, + k=key_i, + v=value_i, + g=g_i, + beta=beta_i, + initial_state=ssm_states, + inplace_final_state=True, + ssm_state_indices=cur_buffer_idx, + use_qk_l2norm_in_kernel=True, + ) + + # Direct slice assignment (no .copy_() needed) + core_attn_out[step_idx :: self.mtp_size] = core_attn_out_i + + # ========== State propagation to next step ========== + if step_idx < self.mtp_step: + next_buffer_idx = infer_state.mtp_buffer_idx_list[step_idx + 1] + if conv_states.is_contiguous() and ssm_states.is_contiguous(): + copy_states_fused(conv_states, ssm_states, cur_buffer_idx, next_buffer_idx) + else: + copy_conv_states(conv_states, cur_buffer_idx, next_buffer_idx) + copy_ssm_states(ssm_states, cur_buffer_idx, next_buffer_idx) + + return core_attn_out + + def gdn_forward( + self, + input: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextGatedDeltaNetTransformerLayerWeight, + is_prefill: bool, + ): + assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager) + + # Common preprocessing + input = input.view(-1, self.embed_dim_) + conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) + + mixed_qkvzba = layer_weight.linear_in_proj.mm(input) + # mixed_qkv is now returned pre-concatenated (no torch.cat needed) + mixed_qkv, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba) + + # Compute g and beta for all modes + g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) + + # Dispatch to appropriate kernel + if is_prefill: + core_attn_out = self._gdn_prefill_kernel( + mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight + ) + elif self.mtp_step == 0: + core_attn_out = self._gdn_decode_kernel( + mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight + ) + else: + core_attn_out = self._gdn_decode_mtp_kernel( + mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight + ) + + # Common postprocessing + z_shape_og = z.shape + core_attn_out = 
core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype) + gated_rmsnorm_forward( + core_attn_out, + layer_weight.linear_norm.weight, + layer_weight.linear_norm.bias, + self.eps_, + z, + out=norm_out, + ) + core_attn_out = norm_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") + + output = layer_weight.linear_out_proj.mm(core_attn_out) + # Note: all_reduce is handled by context_forward/token_forward callers + return output diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..60ff5f1dc --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -0,0 +1,212 @@ +import torch +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + NormWeight, + TpParameterWeight, +) +from typing_extensions import override + + +class Qwen3NextFullAttentionTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + @override + def _init_weight(self): + super()._init_weight() + # Additional architecture + self._init_o_gate_proj_weight() + self._init_gate_shared_expert_weight() + return + + @override + def load_hf_weights(self, weights): + self._split_q_with_gate(weights) + super().load_hf_weights(weights) + + def _init_o_gate_proj_weight(self): + self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + self.o_gate_proj = ROWMMWeight( + weight_names=self._o_gate_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="o_gate_proj", + ) + + def _init_gate_shared_expert_weight(self): + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + self.shared_expert_gate_up_proj = ROWMMWeight( + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate_up_proj", + ) + self.shared_expert_down_proj = COLMMWeight( + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_down_proj", + ) + self.shared_expert_gate = ROWMMWeight( + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate", + tp_rank=0, + tp_world_size=1, + ) + + def _split_q_with_gate(self, weights): + if self._q_weight_name in weights: + weight = weights[self._q_weight_name] + num_heads = self.tp_q_head_num_ * self.tp_world_size_ + weight = weight.view(num_heads * 2, self.head_dim, -1) + _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) + _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) + weights[self._q_weight_name] = _q_proj + weights[self._o_gate_weight_name] = _gate_proj + + +class Qwen3NextGatedDeltaNetTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg): + self.is_moe = ( + 
network_config["num_experts"] > 0 + and layer_num not in network_config["mlp_only_layers"] + and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + ) + super().__init__(layer_num, data_type, network_config, quant_cfg) + + @override + def _parse_config(self): + self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] + self.linear_num_k_heads = self.network_config_["linear_num_key_heads"] + self.linear_k_head_dim = self.network_config_["linear_key_head_dim"] + self.linear_v_head_dim = self.network_config_["linear_value_head_dim"] + + @override + def _init_weight(self): + self.att_norm_weight_ = NormWeight( + self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + ) + self._init_gdn_weight() + self.ffn_norm_weight_ = NormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) + if self.is_moe: + self._init_moe() + else: + self._init_ffn() + self._init_gate_shared_expert_weight() + + def _init_gdn_weight(self): + prefix = f"model.layers.{self.layer_num_}.linear_attn" + self.linear_conv1d = COLMMWeight( + weight_names=f"{prefix}.conv1d.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="conv1d_weight", + ) + + self.linear_in_proj = ROWMMWeight( + weight_names=[f"{prefix}.in_proj_qkvz.weight", f"{prefix}.in_proj_ba.weight"], + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="in_proj_weight", + ) + + self.linear_out_proj = COLMMWeight( + weight_names=f"{prefix}.out_proj.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="out_proj_weight", + ) + + split_n_embed = self.linear_num_v_heads // self.tp_world_size_ + self.linear_dt_bias = TpParameterWeight( + weight_name=f"{prefix}.dt_bias", + data_type=torch.float32, + split_n_embed=split_n_embed, + bias_name=None, + weight_shape=(self.linear_num_v_heads,), # Full shape before TP split + bias_shape=None, + ) + + self.linear_A_log = TpParameterWeight( + weight_name=f"{prefix}.A_log", + data_type=torch.float32, + split_n_embed=split_n_embed, + bias_name=None, + weight_shape=(self.linear_num_v_heads,), # Full shape before TP split + bias_shape=None, + ) + + self.linear_norm = NormWeight( + weight_name=f"{prefix}.norm.weight", + data_type=self.data_type_, + ) + + @override + def load_hf_weights(self, weights): + self._preprocess_weight(weights) + return super().load_hf_weights(weights) + + def _preprocess_weight(self, weights): + linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight" + linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias" + if linear_conv1d_weight_name in weights: + weights[linear_conv1d_weight_name] = self._parse_linear_conv1d( + weights[linear_conv1d_weight_name].squeeze(1) + ).transpose(0, 1) + if linear_conv1d_bias_name in weights: + weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name]) + + def _parse_linear_conv1d(self, weight): + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + q_part, k_part, v_part = torch.split(weight, [qk_dim, qk_dim, v_dim], dim=0) + q_splits = q_part.chunk(self.tp_world_size_, dim=0) + k_splits = k_part.chunk(self.tp_world_size_, dim=0) + v_splits = v_part.chunk(self.tp_world_size_, dim=0) + new_weight = torch.cat( + [torch.cat([q_splits[i], k_splits[i], v_splits[i]], dim=0) for i in 
range(self.tp_world_size_)], dim=0 + ) + return new_weight + + def _init_gate_shared_expert_weight(self): + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + self.shared_expert_gate_up_proj = ROWMMWeight( + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate_up_proj", + ) + self.shared_expert_down_proj = COLMMWeight( + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_down_proj", + ) + self.shared_expert_gate = ROWMMWeight( + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate", + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py new file mode 100644 index 000000000..5bca8d9d1 --- /dev/null +++ b/lightllm/models/qwen3next/mem_manager.py @@ -0,0 +1,76 @@ +import torch +from typing import Tuple +from typing_extensions import override +from lightllm.utils.log_utils import init_logger +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager +from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager + +logger = init_logger(__name__) + + +class Qwen3NextHybridMemManager(MemoryManager): + def __init__( + self, + full_attn_cache_size, + linear_attn_cache_size, + dtype, + num_kv_heads, + head_dim, + layer_num, + mtp_layer_num, + full_attention_interval: int, + conv_state_dtype: torch.dtype, + conv_state_shape: Tuple[int, ...], + ssm_state_dtype: torch.dtype, + ssm_state_shape: Tuple[int, ...], + max_req_num: int, + always_copy=False, + mem_fraction=0.9, + ): + + self.full_attention_interval = full_attention_interval + assert layer_num % full_attention_interval == 0 + self.layer_num = layer_num + self.mtp_layer_num = mtp_layer_num + self.full_attn_layer_num = layer_num // full_attention_interval + self.linear_attn_layer_num = layer_num - self.full_attn_layer_num + + self.mamba_cache_mem_manager = MambaCacheManager( + linear_attn_cache_size, + self.linear_attn_layer_num, + conv_state_dtype, + conv_state_shape, + ssm_state_dtype, + ssm_state_shape, + ) + + super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction) + + @override + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ..., + # None, kv_cache, mtp_kv_cache, mtp_kv_cache] + # Only full attention layers and MTP layers have KV cache. 
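+        # A minimal sketch of the resulting layout (hypothetical numbers, assuming
+        # layer_num=48, full_attention_interval=4, mtp_layer_num=2): indices
+        # 3, 7, ..., 47 hold full-attention KV buffers, every other index stays
+        # None (those linear-attention layers are served by the MambaCacheManager
+        # instead), and mtp_layer_num extra KV buffers are appended at the end
+        # for the MTP layers (indices 48 and 49 in this example).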
+ self.kv_buffer = [None for _ in range(self.layer_num)] + for layer_id in range(self.full_attn_layer_num): + self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty( + (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda" + ) + for _ in range(self.mtp_layer_num): + self.kv_buffer.append(torch.empty((size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")) + + @override + def free_all(self): + super().free_all() + self.mamba_cache_mem_manager.free_all() + return + + @override + def get_cell_size(self): + # Only full attention layers and MTP layers have KV cache + kv_cache_layer_num = self.full_attn_layer_num + self.mtp_layer_num + return 2 * self.head_num * self.head_dim * kv_cache_layer_num * torch._utils._element_size(self.dtype) + + def get_mamba_cache(self, layer_idx: int): + layer_idx_in_linear = layer_idx - (layer_idx // self.full_attention_interval) + return self.mamba_cache_mem_manager.get_mamba_cache(layer_idx_in_linear) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py new file mode 100644 index 000000000..93cae72d0 --- /dev/null +++ b/lightllm/models/qwen3next/model.py @@ -0,0 +1,179 @@ +import torch +from typing import Optional +from typing_extensions import override +import triton +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextFullAttentionTransformerLayerWeight, + Qwen3NextGatedDeltaNetTransformerLayerWeight, +) +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( + Qwen3NextFullAttentionTransformerLayerInfer, + Qwen3NextGatedDeltaNetTransformerLayerInfer, +) +from lightllm.models.qwen3next.layer_infer.post_layer_infer import Qwen3NextPostLayerInfer +from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.common.req_manager import ReqManagerForMamba +from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights + +logger = init_logger(__name__) + + +@ModelRegistry("qwen3_next") +class Qwen3NextTpPartModel(Qwen3MOEModel): + + post_layer_infer_class = Qwen3NextPostLayerInfer + infer_state_class = Qwen3NextInferStateInfo + + is_hybrid_attention = True # Indicates model uses hybrid (full + linear) attention + use_buffer_manager = True # Indicates model needs per-request buffer management for linear attention states + + _triton_allocator_initialized = False # Class-level flag to ensure allocator is set only once + + @classmethod + def get_radix_cache_class(cls): + """Return HybridRadixCache for hybrid attention models.""" + from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache + + return HybridRadixCache + + @classmethod + def _init_triton_allocator(cls): + """Initialize Triton allocator once for all instances.""" + if cls._triton_allocator_initialized: + return + + def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor: + return torch.empty(size, device="cuda", dtype=torch.int8) + + # Set Triton allocator for TMA descriptors + # This is required for kernels in 
qwen3next/triton_kernel/fla/ops/solve_tril.py + triton.set_allocator(_triton_allocator) + cls._triton_allocator_initialized = True + logger.info("Triton allocator set for Qwen3Next model") + + def __init__(self, kvargs) -> None: + self.mem_manager: Qwen3NextHybridMemManager = None + self._init_triton_allocator() + super().__init__(kvargs) + + @override + def autotune_layers(self): + return self.config["full_attention_interval"] + + @override + def _init_config(self): + super()._init_config() + self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + @override + def _init_custom(self): + super()._init_custom() + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + + @override + def _init_mem_manager(self): + assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + + start_args: StartArgs = get_env_start_args() + mamba_cache_size = start_args.mamba_cache_size + if mamba_cache_size is not None: + assert ( + mamba_cache_size >= start_args.running_max_req_size + ), "mamba_cache_size must be greater than running_max_req_size" + + self.num_linear_k_heads = self.config["linear_num_key_heads"] + self.num_linear_v_heads = self.config["linear_num_value_heads"] + self.head_linear_k_dim = self.config["linear_key_head_dim"] + self.head_linear_v_dim = self.config["linear_value_head_dim"] + + conv_kernel_size = self.config["linear_conv_kernel_dim"] + conv_dim = ( + self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads + ) + + ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + if start_args.mamba_ssm_data_type not in ssm_dtype_dict: + raise ValueError( + f"Invalid mamba_ssm_data_type: {start_args.mamba_ssm_data_type}." 
+ f" Must be one of {list(ssm_dtype_dict.keys())}" + ) + + self.mem_manager = Qwen3NextHybridMemManager( + full_attn_cache_size=self.max_total_token_num, + linear_attn_cache_size=mamba_cache_size, + dtype=self.data_type, + num_kv_heads=self.num_kv_heads, + head_dim=self.config["head_dim"], + layer_num=self.config["n_layer"], + mtp_layer_num=start_args.mtp_step, + full_attention_interval=self.config["full_attention_interval"], + conv_state_dtype=self.data_type, + conv_state_shape=(conv_dim // self.tp_world_size_, conv_kernel_size - 1), + ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type], + ssm_state_shape=( + self.num_linear_v_heads // self.tp_world_size_, + self.head_linear_k_dim, + self.head_linear_v_dim, + ), + max_req_num=self.max_req_num, + mem_fraction=self.mem_fraction, + ) + + @override + def _init_req_manager(self): + create_max_seq_len = 0 + + if self.batch_max_tokens is not None: + create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens) + if self.max_seq_length is not None: + create_max_seq_len = max(create_max_seq_len, self.max_seq_length) + + self.req_manager = ReqManagerForMamba(self.max_req_num, create_max_seq_len, self.mem_manager) + + @override + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + self.trans_layers_weight = [ + Qwen3NextFullAttentionTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + if (i + 1) % num_full_attention_layers == 0 + else Qwen3NextGatedDeltaNetTransformerLayerWeight( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(self.config["n_layer"]) + ] + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + + def _init_infer_layer(self): + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + num_full_attention_layers = self.config["full_attention_interval"] + self.layers_infer = [ + Qwen3NextFullAttentionTransformerLayerInfer(i, network_config=self.config) + if (i + 1) % num_full_attention_layers == 0 + else Qwen3NextGatedDeltaNetTransformerLayerInfer(i, network_config=self.config) + for i in range(self.config["n_layer"]) + ] diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py new file mode 100644 index 000000000..c6d099a2d --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py @@ -0,0 +1,122 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/causal_conv1d.py + +from typing import Optional + +import torch + +from sgl_kernel import causal_conv1d_fwd +from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = -1, + **kwargs, 
+): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + causal_conv1d_fwd( + x, + weight, + bias, + conv_states, + query_start_loc, + cache_indices, + has_initial_state, + activation in ["silu", "swish"], + pad_slot_id, + ) + return x + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. 
+ pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError(f"activation must be None, silu, or swish, actual: {activation}") + activation_val = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + causal_conv1d_update_kernel( + x, + conv_state, + weight, + bias, + activation_val, + cache_seqlens, + conv_state_indices, + pad_slot_id, + ) + if unsqueeze: + x = x.squeeze(-1) + return x diff --git a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py new file mode 100644 index 000000000..2bde70bb9 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Adapted from +# https://github.com/vllm-project/vllm diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py new file mode 100644 index 000000000..cd3b0962a --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py new file mode 100644 index 000000000..7b3067bbf --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, chunk_size=64, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=64, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=64, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + ): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: torch.LongTensor | None = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. 
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2, + ) + q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." 
+ ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + ) + if head_first: + o = rearrange(o, "b t h ... -> b h t ...") + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py new file mode 100644 index 000000000..97933b2ac --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -0,0 +1,324 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp, safe_exp +from lightllm.common.triton_utils.autotuner import autotune + +NUM_WARPS = [2, 4, 8, 16] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += ((boh * H + i_h) * K * V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + if SAVE_NEW_VALUE: + v_new += ((bos * H + i_h) * V).to(tl.int64) + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + 
if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v = b_v * safe_exp(b_g_last - b_g)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k1, + mask=(o_k1 < K), + other=0.0, + ) + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k2, + mask=(o_k2 < K), + other=0.0, + ) + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k3, + mask=(o_k3 < K), + other=0.0, + ) + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k4, + mask=(o_k4 < K), + other=0.0, + ) + b_h4 *= exp(b_gk_last4)[:, None] + b_v = b_v.to(k.dtype.element_ty) + + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v) + if K > 64: + p_k = tl.make_block_ptr(k, (K, 
T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v) + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_delta_h_configs(): + return [ + {"BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ] + + +def _get_chunk_delta_h_static_key(k, u, chunk_size): + B, T, Hg, K = k.shape + V = u.shape[-1] + H = u.shape[-2] + return {"H": H, "K": K, "V": V, "BT": chunk_size} + + +def _get_chunk_delta_h_run_key(k, u): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] + + +@autotune( + kernel_name="chunk_gated_delta_rule_fwd_h", + configs_gen_func=_get_chunk_delta_h_configs, + static_key_func=_get_chunk_delta_h_static_key, + run_key_func=_get_chunk_delta_h_run_key, +) +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + run_config=None, +) -> tuple[torch.Tensor, torch.Tensor]: + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) + assert K <= 256, "current kernel does not support head dimension larger than 256." 
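+    # Note on blocking (derived from the kernel above): K is processed in fixed
+    # 64-wide register blocks (b_h1..b_h4), which is why K is asserted to be at
+    # most 256; V is tiled by the autotuned BV, and the launch grid below is
+    # (cdiv(V, BV), N * H), i.e. one program per (value tile, sequence, head).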
+ + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + # Extract config parameters + if run_config is None: + run_config = {"BV": 64, "num_warps": 2, "num_stages": 2} + + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, + ) + return h, v_new, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py new file mode 100644 index 000000000..fc49763ec --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp, safe_exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper +from lightllm.common.triton_utils.autotuner import autotune + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = 
tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_o_configs(): + return [ + {"BK": BK, "BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_o_static_key(q, v, chunk_size): + B, T, Hg, K = q.shape + V = v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + return {"H": H, "K": K, "V": V, "BT": BT} + + +def _get_chunk_o_run_key(q, v): + # Return batch * heads as run key + return q.shape[0] * q.shape[2] + + +@autotune( + kernel_name="chunk_fwd_o", + configs_gen_func=_get_chunk_o_configs, + static_key_func=_get_chunk_o_static_key, + run_key_func=_get_chunk_o_run_key, +) +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor | None = None, # cumsum of log decay + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + run_config=None, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "BV": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, + ) + return o diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py new file mode 100644 index 000000000..60a594c07 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import safe_exp +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * safe_exp(b_g_diff) + + b_A *= b_beta[:, None] + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_scaled_dot_kkt_configs(): + return [ + {"BK": BK, "num_warps": num_warps, "num_stages": num_stages} + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_scaled_dot_kkt_static_key(k, beta, chunk_size=64, cu_seqlens=None): + B, T, Hg, K = k.shape + H = beta.shape[-1] + IS_VARLEN = cu_seqlens is not None + return {"H": H, "K": K, "BT": chunk_size, "IS_VARLEN": IS_VARLEN} + + +def _get_chunk_scaled_dot_kkt_run_key(k, beta): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] + + +@autotune( + kernel_name="chunk_scaled_dot_kkt_fwd", + configs_gen_func=_get_chunk_scaled_dot_kkt_configs, + static_key_func=_get_chunk_scaled_dot_kkt_static_key, + run_key_func=_get_chunk_scaled_dot_kkt_run_key, +) +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + g: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, + run_config=None, +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. 
+ cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K = k.shape + H = beta.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + g=g, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + BK=BK, + num_warps=num_warps, + num_stages=num_stages, + ) + return A diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py new file mode 100644 index 000000000..6331e1602 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard +from lightllm.common.triton_utils.autotuner import autotune + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) + + if HEAD_FIRST: + p_s = tl.make_block_ptr( + s + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + else: + p_s = tl.make_block_ptr( + s + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_cumsum_scalar_configs(): + return [{"num_warps": num_warps} for 
num_warps in [1, 2, 4, 8]] + + +def _get_cumsum_scalar_static_key(g, chunk_size, reverse, cu_seqlens, head_first): + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} + + +def _get_cumsum_scalar_run_key(g): + # Return total number of elements as run key + return g.shape[0] * g.shape[1] + + +@autotune( + kernel_name="chunk_local_cumsum_scalar", + configs_gen_func=_get_cumsum_scalar_configs, + static_key_func=_get_cumsum_scalar_static_key, + run_key_func=_get_cumsum_scalar_run_key, +) +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + run_config=None, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + # Extract config parameters + if run_config is None: + run_config = {"num_warps": 2} + + num_warps = run_config.get("num_warps", 2) + + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=num_warps, + ) + return g + + +def _get_cumsum_vector_configs(): + return [{"BS": BS, "num_warps": num_warps} for BS in BS_LIST for num_warps in [2, 4, 8]] + + +def _get_cumsum_vector_static_key(g, chunk_size, reverse, cu_seqlens, head_first): + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "S": S, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} + + +def _get_cumsum_vector_run_key(g): + # Return batch * heads as run key + return g.shape[0] * g.shape[2] if len(g.shape) == 4 else g.shape[0] + + +@autotune( + kernel_name="chunk_local_cumsum_vector", + configs_gen_func=_get_cumsum_vector_configs, + static_key_func=_get_cumsum_vector_static_key, + run_key_func=_get_cumsum_vector_run_key, +) +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + run_config=None, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + # Extract config parameters + if run_config is None: + run_config = {"BS": 32, "num_warps": 2} + + BS = run_config.get("BS", 32) + num_warps = run_config.get("num_warps", 2) + + grid = (triton.cdiv(S, BS), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + 
H=H, + S=S, + BT=BT, + BS=BS, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=num_warps, + ) + return g + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + else: + raise ValueError( + f"Unsupported input shape {g.shape}. " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py new file mode 100644 index 000000000..fa68d7bfe --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -0,0 +1,409 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .op import exp + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + "HAS_SEPARATE_WRITE_INDICES": lambda args: args["ssm_state_write_indices"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + ssm_state_write_indices, # NEW: separate write indices for state propagation optimization + num_accepted_tokens, + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + stride_write_indices_seq: tl.constexpr, # NEW: stride for write indices + stride_write_indices_tok: tl.constexpr, # NEW: stride for write indices + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, + HAS_SEPARATE_WRITE_INDICES: tl.constexpr, # NEW: whether to use separate write indices +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + 
tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + + if not IS_KDA: + p_g = g + bos * HV + i_hv + else: + p_gk = g + (bos * HV + i_hv) * K + o_k + + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = ( + h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token + ) + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + if not IS_KDA: + b_g = tl.load(p_g).to(tl.float32) + b_h *= exp(b_g) + else: + b_gk = tl.load(p_gk).to(tl.float32) + b_h *= exp(b_gk[:, None]) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + # Use separate write indices if provided (for state propagation optimization) + # Otherwise fall back to read indices + if HAS_SEPARATE_WRITE_INDICES: + write_idx = tl.load(ssm_state_write_indices + i_n * stride_write_indices_seq + i_t).to(tl.int64) + else: + write_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) + p_ht = ht + write_idx * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + if not IS_KDA: + p_g += HV + else: + p_gk += HV * K + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, # NEW: separate write indices + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B 
if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + # Strides for read indices + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + # Strides for write indices (if provided) + if ssm_state_write_indices is None: + stride_write_indices_seq, stride_write_indices_tok = 1, 1 + elif ssm_state_write_indices.ndim == 1: + stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1 + else: + stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + ssm_state_write_indices=ssm_state_write_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + stride_write_indices_seq=stride_write_indices_seq, + stride_write_indices_tok=stride_write_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + IS_KDA=False, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + ssm_state_write_indices=ssm_state_write_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, # 
NEW: separate write indices for state propagation + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." 
+ ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + ssm_state_write_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py new file mode 100644 index 000000000..8b1d59fc6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py new file mode 100644 index 000000000..29f892ef2 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +import torch + +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def _get_l2norm_kernel1_configs(): + return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16, 32]] + + +def _get_l2norm_kernel1_static_key(x): + D = x.shape[-1] + return {"D": D} + + +def _get_l2norm_kernel1_run_key(x): + return x.shape[0] # T + + +@autotune( + kernel_name="l2norm_fwd_kernel1", + configs_gen_func=_get_l2norm_kernel1_configs, + static_key_func=_get_l2norm_kernel1_static_key, + run_key_func=_get_l2norm_kernel1_run_key, +) +def _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD, run_config=None): + if run_config is None: + run_config = {"num_warps": 4} + + num_warps = run_config.get("num_warps", 4) + T = x.shape[0] + + l2norm_fwd_kernel1[(T,)](x, y, eps=eps, D=D, BD=BD, num_warps=num_warps) + + +def _get_l2norm_kernel_configs(): + return [{"BT": BT, "num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST] + + +def _get_l2norm_kernel_static_key(x): + D = x.shape[-1] + return {"D": D} + + +def _get_l2norm_kernel_run_key(x): + return x.shape[0] # T + + +@autotune( + kernel_name="l2norm_fwd_kernel", + configs_gen_func=_get_l2norm_kernel_configs, + static_key_func=_get_l2norm_kernel_static_key, + run_key_func=_get_l2norm_kernel_run_key, +) +def _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB, run_config=None): + if run_config is None: + run_config = {"BT": 32, "num_warps": 4} + + BT = run_config.get("BT", 32) + num_warps = run_config.get("num_warps", 4) + + grid = (triton.cdiv(T, BT),) + l2norm_fwd_kernel[grid](x, y, eps, NB=NB, T=T, D=D, BT=BT, BD=BD, num_warps=num_warps) + + +def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, 
output_dtype: torch.dtype | None = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB) + else: + _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD) + + return y.view(x_shape_og) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py new file mode 100644 index 000000000..2f69aa981 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import triton +import triton.language as tl + +from .utils import is_gather_supported + +exp = tl.exp +log = tl.log +log2 = tl.log2 + + +@triton.jit +def safe_exp(x): + """ + Numerically stable exponential function. + Only applies exp to non-positive values, returns 0 for positive values. + This prevents numerical overflow and improves stability. + """ + return exp(tl.where(x <= 0, x, float("-inf"))) + + +if not is_gather_supported: + + @triton.jit + def gather(src, index, axis, _builder=None): + """ + Gather operation that works when tl.gather is not supported. + This is a fallback implementation that returns None. + Just to make triton compiler happy. + """ + return None + +else: + gather = tl.gather + +if hasattr(triton.language, "_experimental_make_tensor_descriptor"): + # For Triton 3.3.x + make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor +elif hasattr(triton.language, "make_tensor_descriptor"): + # For Triton 3.4.x and later + make_tensor_descriptor = triton.language.make_tensor_descriptor +else: + """ + Fallback implementation when TMA is not supported. + Returns None to indicate TMA descriptors are unavailable. + Just make triton compiler happy. + """ + + @triton.jit + def make_tensor_descriptor( + base, + shape, + strides, + block_shape, + _builder=None, + ): + return None diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py new file mode 100644 index 000000000..b5b6cfc36 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py @@ -0,0 +1,462 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import os +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import make_tensor_descriptor +from .utils import input_guard, is_amd, is_tma_supported + + +def _ensure_triton_allocator(): + """Ensure Triton has an allocator set for kernels requiring scratch memory.""" + + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + +FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee") +ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"] +assert ( + FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS +), f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}" + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + A = A + (bos * H + i_h) * BT + Ai = Ai + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + if not USE_TMA: + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + # [16, 16] + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16]) + b_A = desc.load([i_t * 16, offset]).to(tl.float32) + b_A = -tl.where(m_A, b_A, 0) + + for i in range(2, min(16, T - i_t * 16)): + # [16] + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + b_A = tl.where((o_i == i)[:, None], b_a, b_A) + b_A += m_I + if not USE_TMA: + p_Ai = tl.make_block_ptr(Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + tl.store( + p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos 
= i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + + b_Ai_11 += m_I + b_Ai_22 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = 
tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_A_33 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_A_44 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32) + b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32) + b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + b_Ai_33 = -tl.where(m_A, b_Ai_33, 0) + b_Ai_44 = -tl.where(m_A, b_Ai_44, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + for i in range(32 + 2, min(48, T - i_t * BT)): + b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32) + b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0) + b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33) + for i in range(48 + 2, min(64, T - i_t * BT)): + b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48) + b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0) + b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44) + b_Ai_11 += m_I + b_Ai_22 += m_I + b_Ai_33 += m_I + b_Ai_44 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32) + b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32) + b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32) + b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32) + b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, + ) + 
b_Ai_32 = -tl.dot( + tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION), + b_Ai_22, + input_precision=DOT_PRECISION, + ) + b_Ai_43 = -tl.dot( + tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION), + b_Ai_33, + input_precision=DOT_PRECISION, + ) + + b_Ai_31 = -tl.dot( + b_Ai_33, + tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION) + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_42 = -tl.dot( + b_Ai_44, + tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION) + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_41 = -tl.dot( + b_Ai_44, + tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION) + + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION) + + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_33, + b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_44, + b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_31, + b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_32, + b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_41, + b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_42, + b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_43, + b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 0], 
b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@input_guard +def solve_tril( + A: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + """ + Compute the inverse of the matrix I + A + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, BT], where BT should only be 16, 32, or 64. + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. Default: `None`. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float`. + If `None`, the output dtype will be the same as the input dtype. + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + output_dtype = A.dtype if output_dtype is None else output_dtype + + B, T, H, BT = A.shape + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + + Ai = torch.zeros_like(A, dtype=output_dtype) + if BT == 16: + merge_fn = solve_tril_16x16_kernel + elif BT == 32: + merge_fn = merge_16x16_to_32x32_inverse_kernel + elif BT == 64: + merge_fn = merge_16x16_to_64x64_inverse_kernel + + # Ensure Triton allocator is set for TMA kernels that require scratch memory + if is_tma_supported: + _ensure_triton_allocator() + + merge_fn[NT, B * H]( + A=A, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + USE_TMA=is_tma_supported, + DOT_PRECISION=FLA_TRIL_PRECISION, + ) + return Ai diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py new file mode 100644 index 000000000..cd7c2e3ae --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +import logging +import os +from collections.abc import Callable +from enum import Enum +from typing import Any, Literal + +import torch + +import triton + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" +FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" + +SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. 
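+
+    Note:
+        A cache hit requires the arguments to be the same objects (identity check via
+        `is`), not merely equal in value; equal-but-distinct tensors trigger a recompute.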
+ + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + + cache_entries: tuple[tuple | None, dict | None, Any] = [] + cache_size = 8 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if ( + len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)] + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.cuda.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + return "cpu" + + +@functools.cache +def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: + device = get_available_device() + mapping = { + "cuda": "nvidia", + "hip": "amd", + "xpu": "intel", + } + # return the mapped value, or the original if not found + return mapping.get(device, device) + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. 
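+# For example, on an NVIDIA GPU the active triton backend is "cuda", which
+# _check_platform() maps to "nvidia"; on ROCm it is "hip" -> "amd" and on XPU
+# "xpu" -> "intel"; unrecognized backends (e.g. "musa") are returned unchanged.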
+device = "cuda" +device_torch_lib = getattr(torch, device, None) +device_platform = _check_platform() + +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 +) +use_cuda_graph = True +is_gather_supported = hasattr(triton.language, "gather") +is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and ( + hasattr(triton.language, "_experimental_make_tensor_descriptor") + or hasattr(triton.language, "make_tensor_descriptor") +) + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py new file mode 100644 index 000000000..08bb00e64 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: torch.LongTensor | None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py new file mode 100644 index 000000000..574573781 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py @@ -0,0 +1,90 @@ 
+# Adapted from https://github.com/sgl-project/sglang/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from lightllm.common.triton_utils.autotuner import autotune + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +# beta_output = b.sigmoid() +@triton.jit +def fused_gdn_gating_kernel( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_b = tl.load(b + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask) + + +def _get_fused_gdn_gating_configs(): + return [{"BLK_HEADS": bh, "num_warps": nw} for bh in [4, 8, 16, 32, 64] for nw in [1, 2, 4]] + + +def _get_fused_gdn_gating_static_key(a: torch.Tensor): + # group by head size and input dtype + return {"NUM_HEADS": a.shape[1], "a_dtype": str(a.dtype)} + + +@autotune( + kernel_name="fused_gdn_gating:v1", + configs_gen_func=_get_fused_gdn_gating_configs, + static_key_func=_get_fused_gdn_gating_static_key, + run_key_func=lambda a: a.shape[0], +) +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, + run_config: Optional[dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + if run_config is None: + run_config = {"BLK_HEADS": 8, "num_warps": 1} + + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, run_config["BLK_HEADS"])) + g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) + beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device) + fused_gdn_gating_kernel[grid]( + g, + beta_output, + A_log, + a, + b, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + run_config["BLK_HEADS"], + num_warps=run_config["num_warps"], + ) + return g, beta_output diff --git a/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py new file mode 100644 index 000000000..f37d4911a --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py @@ -0,0 +1,163 @@ +""" +Fused QKV projection and GDN gating computation. + +This kernel fuses: +1. Linear projection (matmul with weight) +2. Output reorganization (split and reshape) +3. Gating computation (g and beta from a, b) + +This reduces kernel launches from 3 to 1 for the QKV+gating path. 
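+
+For reference, the gating computation performed by the kernel below is:
+
+    g    = -exp(A_log) * softplus(a + dt_bias)
+    beta = sigmoid(b)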
+""" + +import torch +import triton +import triton.language as tl +from typing import Tuple, Optional +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _fused_gdn_gating_only_kernel( + # Output pointers + g_ptr, + beta_ptr, + # Input pointers + a_ptr, + b_ptr, + A_log_ptr, + dt_bias_ptr, + # Dimensions + batch_size, + num_heads, + # Constants + beta_const: tl.constexpr, + threshold: tl.constexpr, + BLOCK_BATCH: tl.constexpr, + BLOCK_HEADS: tl.constexpr, +): + """ + Fused kernel for GDN gating computation with better memory access patterns. + + Computes: + - g = -exp(A_log) * softplus(a + dt_bias) + - beta = sigmoid(b) + """ + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + + batch_offs = pid_batch * BLOCK_BATCH + tl.arange(0, BLOCK_BATCH) + head_offs = pid_head * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS) + + batch_mask = batch_offs < batch_size + head_mask = head_offs < num_heads + mask = batch_mask[:, None] & head_mask[None, :] + + # Load A_log and dt_bias (broadcast across batch) + A_log = tl.load(A_log_ptr + head_offs, mask=head_mask, other=0.0) + dt_bias = tl.load(dt_bias_ptr + head_offs, mask=head_mask, other=0.0) + + # Load a and b + offs = batch_offs[:, None] * num_heads + head_offs[None, :] + a = tl.load(a_ptr + offs, mask=mask, other=0.0) + b = tl.load(b_ptr + offs, mask=mask, other=0.0) + + # Compute g = -exp(A_log) * softplus(a + dt_bias) + x = a.to(tl.float32) + dt_bias.to(tl.float32) + softplus_x = tl.where(beta_const * x <= threshold, (1.0 / beta_const) * tl.log(1.0 + tl.exp(beta_const * x)), x) + g = -tl.exp(A_log.to(tl.float32)) * softplus_x + + # Compute beta = sigmoid(b) + beta_out = tl.sigmoid(b.to(tl.float32)) + + # Store outputs with layout [1, batch, num_heads] + out_offs = batch_offs[:, None] * num_heads + head_offs[None, :] + tl.store(g_ptr + out_offs, g.to(g_ptr.dtype.element_ty), mask=mask) + tl.store(beta_ptr + out_offs, beta_out.to(beta_ptr.dtype.element_ty), mask=mask) + + +def _get_fused_gating_configs(): + """Generate autotuning configurations.""" + configs = [] + for block_batch in [1, 4, 8, 16]: + for block_heads in [8, 16, 32]: + for num_warps in [2, 4, 8]: + configs.append( + { + "BLOCK_BATCH": block_batch, + "BLOCK_HEADS": block_heads, + "num_warps": num_warps, + } + ) + return configs + + +def _get_fused_gating_static_key(a: torch.Tensor): + return {"dtype": str(a.dtype), "num_heads": a.shape[1]} + + +def _get_fused_gating_run_key(a: torch.Tensor): + return a.shape[0] + + +@autotune( + kernel_name="fused_gdn_gating_v2:v1", + configs_gen_func=_get_fused_gating_configs, + static_key_func=_get_fused_gating_static_key, + run_key_func=_get_fused_gating_run_key, + mutates_args=["g", "beta"], +) +def fused_gdn_gating_v2( + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + beta_const: float = 1.0, + threshold: float = 20.0, + run_config: Optional[dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Optimized GDN gating with pre-allocated output tensors. 
+ + Args: + a: Input tensor [batch, num_heads] + b: Input tensor [batch, num_heads] + A_log: Log of A parameter [num_heads] + dt_bias: Bias for dt [num_heads] + g: Output tensor [1, batch, num_heads] (pre-allocated) + beta: Output tensor [1, batch, num_heads] (pre-allocated) + beta_const: Beta constant for softplus (default: 1.0) + threshold: Threshold for softplus approximation (default: 20.0) + run_config: Optional autotuning configuration + + Returns: + Tuple of (g, beta) - same tensors passed in, now filled + """ + batch_size, num_heads = a.shape + + if run_config is None: + run_config = {"BLOCK_BATCH": 8, "BLOCK_HEADS": 16, "num_warps": 4} + + grid = ( + triton.cdiv(batch_size, run_config["BLOCK_BATCH"]), + triton.cdiv(num_heads, run_config["BLOCK_HEADS"]), + ) + + _fused_gdn_gating_only_kernel[grid]( + g, + beta, + a, + b, + A_log, + dt_bias, + batch_size, + num_heads, + beta_const, + threshold, + run_config["BLOCK_BATCH"], + run_config["BLOCK_HEADS"], + num_warps=run_config["num_warps"], + ) + + return g, beta diff --git a/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py new file mode 100644 index 000000000..89db5e00c --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py @@ -0,0 +1,174 @@ +import triton +import triton.language as tl +import torch +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + } +) +@triton.jit +def gated_rmsnorm_forward_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch (required, not optional) + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
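+    # Overall per-row computation (z is the gate branch, silu(z) = z * sigmoid(z)):
+    #   NORM_BEFORE_GATE=True : y = rmsnorm(x) * w (+ b), then multiplied by silu(z)
+    #   NORM_BEFORE_GATE=False: x is multiplied by silu(z) first, then rmsnorm is applied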
+ row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + Z += row * stride_z_row + group * N + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute variance (RMS norm doesn't use mean) + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + # RMS norm: compute variance directly without mean subtraction + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + # RMS norm: normalize without mean subtraction + x_hat = x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _get_gated_rmsnorm_configs(): + """Generate configurations for autotuning gated RMSNorm kernel.""" + configs = [] + # Different BLOCK_N sizes (powers of 2) + for block_n in [64, 128, 256, 512, 1024, 2048, 4096]: + # Different number of warps + for num_warps in [1, 2, 4, 8]: + # Skip configurations that are likely to be inefficient + if block_n >= 2048 and num_warps > 4: + continue + if block_n <= 128 and num_warps > 2: + continue + configs.append({"BLOCK_N": block_n, "num_warps": num_warps}) + return configs + + +def _get_gated_rmsnorm_static_key(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + M, N = x.shape + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(weight.dtype), + "N": N, + "has_bias": bias is not None, + } + + +@autotune( + kernel_name="gated_rmsnorm_forward:v1", + configs_gen_func=_get_gated_rmsnorm_configs, + static_key_func=_get_gated_rmsnorm_static_key, + run_key_func=lambda x: x.shape[0], +) +def gated_rmsnorm_forward( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + run_config: dict = None, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + # z is required for gated_rmsnorm + assert z is not None, "z cannot be None for gated_rmsnorm_forward" + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + # For RMS norm, we still need rstd for the kernel + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + 
run_config = {"BLOCK_N": BLOCK_N, "num_warps": num_warps} + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + + # Validate BLOCK_N against group_size + if group_size > BLOCK_N: + # Fall back to largest valid BLOCK_N + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + grid = (M, ngroups) + gated_rmsnorm_forward_kernel[grid]( + x, + out, + weight, + bias, + z, + rstd, + x.stride(0), + out.stride(0), + z.stride(0), + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + num_warps=num_warps, + ) + return out diff --git a/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py b/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py new file mode 100644 index 000000000..5a39debaa --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gdn_decode_mtp.py @@ -0,0 +1,1333 @@ +""" +Optimized GDN Decode MTP (Multi-Token Prediction) Kernel + +This module provides an optimized Triton kernel for GDN decode with MTP support, +eliminating the need for sequential Python loops and reducing memory operations. + +Key optimizations: +1. Fused data reorganization from interleaved to batched layout +2. Parallel processing of all batch items with proper state indexing +3. Auto-tuned configurations for different batch sizes and model dimensions +""" + +import torch +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _reorganize_mtp_data_kernel( + # Input pointers (interleaved layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...]) + src_ptr, + # Output pointers (batched layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...]) + dst_ptr, + # Dimensions + batch_size, + mtp_size, + dim_size, + # Strides + src_stride_token, + src_stride_dim, + dst_stride_token, + dst_stride_dim, + # Block sizes + BLOCK_DIM: tl.constexpr, +): + """ + Reorganize data from interleaved MTP layout to batched layout. + + Input layout: [step0_batch0, step0_batch1, ..., step0_batchN, step1_batch0, ...] + Output layout: [batch0_step0, batch0_step1, ..., batch0_stepM, batch1_step0, ...] + + This enables efficient processing with the recurrent kernel. 
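+
+    Index mapping: src token (step * batch_size + batch) -> dst token
+    (batch * mtp_size + step).  For example, with batch_size=2 and mtp_size=3,
+    src token 4 (step2_batch0) is written to dst token 2 (batch0_step2).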
+ """ + batch_idx = tl.program_id(0) + step_idx = tl.program_id(1) + block_dim_idx = tl.program_id(2) + + # Calculate source and destination token indices + src_token_idx = step_idx * batch_size + batch_idx + dst_token_idx = batch_idx * mtp_size + step_idx + + # Calculate dimension offsets + dim_start = block_dim_idx * BLOCK_DIM + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + mask = dim_offsets < dim_size + + # Load from source (interleaved layout) + src_offset = src_token_idx * src_stride_token + dim_offsets * src_stride_dim + data = tl.load(src_ptr + src_offset, mask=mask, other=0.0) + + # Store to destination (batched layout) + dst_offset = dst_token_idx * dst_stride_token + dim_offsets * dst_stride_dim + tl.store(dst_ptr + dst_offset, data, mask=mask) + + +@triton.jit +def _reorganize_mtp_data_back_kernel( + # Input pointers (batched layout): [batch_size, mtp_size, num_heads, head_dim] + src_ptr, + # Output pointers (interleaved layout): [total_tokens, 1, num_heads, head_dim] + dst_ptr, + # Dimensions + batch_size, + mtp_size, + num_heads, + head_dim, + # Strides for src: [batch_size, mtp_size, num_heads, head_dim] + src_stride_batch, + src_stride_mtp, + src_stride_head, + src_stride_dim, + # Strides for dst: [total_tokens, 1, num_heads, head_dim] + dst_stride_token, + dst_stride_seq, + dst_stride_head, + dst_stride_dim, + # Block sizes + BLOCK_HEAD: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """ + Reorganize output data from batched layout back to interleaved layout. + + Input shape: [batch_size, mtp_size, num_heads, head_dim] + Output shape: [batch_size * mtp_size, 1, num_heads, head_dim] (interleaved) + + Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] + """ + batch_idx = tl.program_id(0) + step_idx = tl.program_id(1) + block_idx = tl.program_id(2) + + # Decompose block_idx into head and dim blocks + num_dim_blocks = tl.cdiv(head_dim, BLOCK_DIM) + block_head_idx = block_idx // num_dim_blocks + block_dim_idx = block_idx % num_dim_blocks + + # Calculate destination token index (interleaved) + dst_token_idx = step_idx * batch_size + batch_idx + + # Calculate offsets + head_start = block_head_idx * BLOCK_HEAD + dim_start = block_dim_idx * BLOCK_DIM + + head_offsets = head_start + tl.arange(0, BLOCK_HEAD) + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + mask = head_mask[:, None] & dim_mask[None, :] + + # Load from source (batched layout): [batch_size, mtp_size, num_heads, head_dim] + src_base = src_ptr + batch_idx * src_stride_batch + step_idx * src_stride_mtp + src_offset = head_offsets[:, None] * src_stride_head + dim_offsets[None, :] * src_stride_dim + data = tl.load(src_base + src_offset, mask=mask, other=0.0) + + # Store to destination (interleaved layout): [total_tokens, 1, num_heads, head_dim] + # The seq dimension (1) is skipped since it's always 0 + dst_base = dst_ptr + dst_token_idx * dst_stride_token + dst_offset = head_offsets[:, None] * dst_stride_head + dim_offsets[None, :] * dst_stride_dim + tl.store(dst_base + dst_offset, data, mask=mask) + + +def _get_reorganize_mtp_configs(): + """Generate candidate configurations for MTP data reorganization.""" + configs = [] + for block_dim in [64, 128, 256, 512]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3, 4]: + configs.append( + { + "BLOCK_DIM": block_dim, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_reorganize_static_key(src: torch.Tensor, mtp_size: int): + """Static 
key based on tensor properties.""" + return { + "dtype": str(src.dtype), + "mtp_size": mtp_size, + } + + +def _get_reorganize_run_key(src: torch.Tensor, mtp_size: int): + """Run key based on batch size and dimension.""" + total_tokens = src.shape[0] + batch_size = total_tokens // mtp_size + dim_size = src.shape[-1] + return f"{batch_size}_{dim_size}" + + +@autotune( + kernel_name="gdn_decode_mtp_reorganize:v1", + configs_gen_func=_get_reorganize_mtp_configs, + static_key_func=_get_reorganize_static_key, + run_key_func=_get_reorganize_run_key, + mutates_args=["dst"], +) +def reorganize_mtp_to_batched( + src: torch.Tensor, + dst: torch.Tensor, + mtp_size: int, + run_config: dict = None, +): + """ + Reorganize data from interleaved MTP layout to batched layout. + + Args: + src: Input tensor with interleaved layout [total_tokens, dim] + Layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...] + dst: Output tensor with batched layout [total_tokens, dim] + Layout: [batch0_step0, batch0_step1, ..., batch1_step0, ...] + mtp_size: Number of MTP steps + run_config: Auto-tuned configuration + """ + total_tokens = src.shape[0] + batch_size = total_tokens // mtp_size + dim_size = src.shape[-1] + + if run_config is None: + BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_DIM = run_config["BLOCK_DIM"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) + + grid = (batch_size, mtp_size, num_blocks_dim) + + _reorganize_mtp_data_kernel[grid]( + src, + dst, + batch_size, + mtp_size, + dim_size, + src.stride(0), + src.stride(-1) if src.ndim > 1 else 1, + dst.stride(0), + dst.stride(-1) if dst.ndim > 1 else 1, + BLOCK_DIM=BLOCK_DIM, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_reorganize_back_configs(): + """Generate candidate configurations for MTP output reorganization.""" + configs = [] + for block_head in [4, 8, 16, 32]: + for block_dim in [32, 64, 128]: + for num_warps in [2, 4, 8]: + for num_stages in [2, 3]: + if block_head * block_dim <= 4096: # Limit shared memory + configs.append( + { + "BLOCK_HEAD": block_head, + "BLOCK_DIM": block_dim, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_reorganize_back_static_key( + src: torch.Tensor, + batch_size: int, + mtp_size: int, + num_heads: int, + head_dim: int, +): + """Static key for output reorganization.""" + return { + "dtype": str(src.dtype), + "mtp_size": mtp_size, + "num_heads": num_heads, + "head_dim": head_dim, + } + + +def _get_reorganize_back_run_key( + src: torch.Tensor, + batch_size: int, + mtp_size: int, + num_heads: int, + head_dim: int, +): + """Run key for output reorganization.""" + return batch_size + + +@autotune( + kernel_name="gdn_decode_mtp_reorganize_back:v1", + configs_gen_func=_get_reorganize_back_configs, + static_key_func=_get_reorganize_back_static_key, + run_key_func=_get_reorganize_back_run_key, + mutates_args=["dst"], +) +def reorganize_mtp_output_to_interleaved( + src: torch.Tensor, + dst: torch.Tensor, + batch_size: int, + mtp_size: int, + num_heads: int, + head_dim: int, + run_config: dict = None, +): + """ + Reorganize output from batched layout back to interleaved layout. 
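+
+    Token-order-wise this undoes reorganize_mtp_to_batched: element src[b, s, h, d]
+    ends up at dst[s * batch_size + b, 0, h, d], restoring the original interleaved
+    token order.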
+ + Args: + src: Input tensor [batch_size, mtp_size, num_heads, head_dim] (4D) + dst: Output tensor [batch_size * mtp_size, 1, num_heads, head_dim] (4D) + batch_size: Number of batch items + mtp_size: Number of MTP steps + num_heads: Number of attention heads + head_dim: Head dimension + run_config: Auto-tuned configuration + + Mapping: src[b, s, h, d] -> dst[s * batch_size + b, 0, h, d] + """ + if run_config is None: + BLOCK_HEAD = min(triton.next_power_of_2(num_heads), 16) + BLOCK_DIM = min(triton.next_power_of_2(head_dim), 64) + num_warps = 4 + num_stages = 2 + else: + BLOCK_HEAD = run_config["BLOCK_HEAD"] + BLOCK_DIM = run_config["BLOCK_DIM"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_head_blocks = triton.cdiv(num_heads, BLOCK_HEAD) + num_dim_blocks = triton.cdiv(head_dim, BLOCK_DIM) + num_blocks_total = num_head_blocks * num_dim_blocks + + grid = (batch_size, mtp_size, num_blocks_total) + + # src is 4D: [batch_size, mtp_size, num_heads, head_dim] + # dst is 4D: [total_tokens, 1, num_heads, head_dim] + _reorganize_mtp_data_back_kernel[grid]( + src, + dst, + batch_size, + mtp_size, + num_heads, + head_dim, + src.stride(0), # batch stride + src.stride(1), # mtp stride + src.stride(2), # head stride + src.stride(3), # dim stride + dst.stride(0), # token stride + dst.stride(1), # seq stride (=1) + dst.stride(2), # head stride + dst.stride(3), # dim stride + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_DIM=BLOCK_DIM, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@triton.jit +def _prepare_mtp_indices_kernel( + # Input indices (per-step buffer indices) + buffer_idx_ptr, + # Output 2D indices for recurrent kernel + output_idx_ptr, + # Dimensions + batch_size, + mtp_size, + # Strides + input_stride, + output_stride_batch, + output_stride_step, +): + """ + Prepare 2D indices for the fused recurrent kernel. + + Input: mtp_size tensors of shape [batch_size] (buffer indices for each step) + Output: 2D tensor [batch_size, mtp_size] for ssm_state_indices + """ + batch_idx = tl.program_id(0) + step_idx = tl.program_id(1) + + # Load the buffer index for this batch and step + buffer_idx = tl.load(buffer_idx_ptr + step_idx * input_stride + batch_idx) + + # Store to the 2D output + output_offset = batch_idx * output_stride_batch + step_idx * output_stride_step + tl.store(output_idx_ptr + output_offset, buffer_idx) + + +def prepare_mtp_state_indices( + mtp_buffer_idx_list: list, + batch_size: int, + device: torch.device, +) -> torch.Tensor: + """ + Prepare 2D state indices for the fused recurrent kernel. 
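+
+    This is a plain stack + transpose on the host side, e.g. with mtp_size=2 and
+    batch_size=3 (illustrative index values):
+        mtp_buffer_idx_list = [tensor([3, 4, 5]), tensor([9, 10, 11])]
+        -> tensor([[3, 9], [4, 10], [5, 11]])   # shape [batch_size, mtp_size]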
+ + Args: + mtp_buffer_idx_list: List of buffer index tensors, one per MTP step + batch_size: Number of batch items + device: Target device + + Returns: + 2D tensor of shape [batch_size, mtp_size] for ssm_state_indices + """ + + # Stack indices to create [mtp_size, batch_size] tensor + stacked_indices = torch.stack(mtp_buffer_idx_list, dim=0) + + # Transpose to get [batch_size, mtp_size] + return stacked_indices.T.contiguous() + + +@triton.jit +def _fused_conv1d_mtp_step_kernel( + # Input/output data + mixed_qkv_ptr, + # Conv state buffer + conv_states_ptr, + # Conv weight and bias + conv_weight_ptr, + conv_bias_ptr, + # Buffer indices (one per MTP step, each [batch_size]) + buffer_indices_ptr, + next_buffer_indices_ptr, + # Dimensions + batch_size, + dim_size, + conv_width, + # Step info + step_idx, + mtp_size, + is_last_step: tl.constexpr, + # Strides + qkv_stride_token, + qkv_stride_dim, + state_stride_buffer, + state_stride_dim, + state_stride_width, + weight_stride_dim, + weight_stride_width, + # Block sizes + BLOCK_DIM: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, +): + """ + Fused kernel for conv1d update in MTP decode. + + Handles one MTP step for all batch items: + 1. Reads current conv state + 2. Updates with new input + 3. Computes conv1d output + 4. Optionally copies state to next MTP step + """ + batch_idx = tl.program_id(0) + block_dim_idx = tl.program_id(1) + + # Calculate token index in interleaved layout + token_idx = step_idx * batch_size + batch_idx + + # Load buffer indices + cur_buffer_idx = tl.load(buffer_indices_ptr + batch_idx).to(tl.int64) + + # Calculate dimension offsets + dim_start = block_dim_idx * BLOCK_DIM + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + dim_mask = dim_offsets < dim_size + + # Load input value + input_offset = token_idx * qkv_stride_token + dim_offsets * qkv_stride_dim + input_val = tl.load(mixed_qkv_ptr + input_offset, mask=dim_mask, other=0.0) + + # Load conv bias + bias_val = tl.load(conv_bias_ptr + dim_offsets, mask=dim_mask, other=0.0) + + # Compute conv1d output and update state + output_val = bias_val + state_base = conv_states_ptr + cur_buffer_idx * state_stride_buffer + + # Process each position in the conv window + for w in range(conv_width): + # Load weight for this position + weight_offset = dim_offsets * weight_stride_dim + w * weight_stride_width + weight_val = tl.load(conv_weight_ptr + weight_offset, mask=dim_mask, other=0.0) + + if w < conv_width - 1: + # Load from state buffer + state_offset = dim_offsets * state_stride_dim + w * state_stride_width + state_val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0) + output_val += state_val * weight_val + else: + # Use current input for the last position + output_val += input_val * weight_val + + # Update conv state (shift and insert new value) + for w in range(conv_width - 2, -1, -1): + if w == conv_width - 2: + # Insert new input at the end + state_offset = dim_offsets * state_stride_dim + w * state_stride_width + tl.store(state_base + state_offset, input_val, mask=dim_mask) + else: + # Shift state + src_offset = dim_offsets * state_stride_dim + (w + 1) * state_stride_width + dst_offset = dim_offsets * state_stride_dim + w * state_stride_width + val = tl.load(state_base + src_offset, mask=dim_mask, other=0.0) + tl.store(state_base + dst_offset, val, mask=dim_mask) + + # Apply activation (SiLU) + if ACTIVATION_SILU: + output_val = output_val * tl.sigmoid(output_val) + + # Store output + tl.store(mixed_qkv_ptr + input_offset, output_val, mask=dim_mask) + + # 
Copy state to next step if not last + if not is_last_step: + next_buffer_idx = tl.load(next_buffer_indices_ptr + batch_idx).to(tl.int64) + next_state_base = conv_states_ptr + next_buffer_idx * state_stride_buffer + + for w in range(conv_width - 1): + state_offset = dim_offsets * state_stride_dim + w * state_stride_width + val = tl.load(state_base + state_offset, mask=dim_mask, other=0.0) + tl.store(next_state_base + state_offset, val, mask=dim_mask) + + +def _get_conv1d_mtp_configs(): + """Generate candidate configurations for conv1d MTP kernel.""" + configs = [] + for block_dim in [64, 128, 256, 512]: + for num_warps in [2, 4, 8]: + for num_stages in [1, 2, 3]: + configs.append( + { + "BLOCK_DIM": block_dim, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def _get_conv1d_mtp_static_key( + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + conv_weight: torch.Tensor, + mtp_size: int, +): + """Static key for conv1d MTP kernel.""" + return { + "dtype": str(mixed_qkv.dtype), + "dim_size": mixed_qkv.shape[-1], + "conv_width": conv_weight.shape[-1], + "mtp_size": mtp_size, + } + + +def _get_conv1d_mtp_run_key( + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + conv_weight: torch.Tensor, + mtp_size: int, +): + """Run key for conv1d MTP kernel.""" + total_tokens = mixed_qkv.shape[0] + batch_size = total_tokens // mtp_size + return batch_size + + +@autotune( + kernel_name="gdn_conv1d_mtp:v1", + configs_gen_func=_get_conv1d_mtp_configs, + static_key_func=_get_conv1d_mtp_static_key, + run_key_func=_get_conv1d_mtp_run_key, + mutates_args=["mixed_qkv", "conv_states"], +) +def fused_conv1d_mtp_update( + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + mtp_buffer_idx_list: list, + mtp_size: int, + activation_silu: bool = True, + run_config: dict = None, +): + """ + Fused conv1d update for all MTP steps. 
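+
+    Per step and per channel d, the output follows the causal depthwise conv
+        out[d] = bias[d] + sum_{w < W-1} state[d, w] * weight[d, w] + x[d] * weight[d, W-1]
+    (optionally followed by SiLU); the new input x is then shifted into the conv
+    state, and for all but the last step the state is copied to the next step's buffer.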
+ + Args: + mixed_qkv: Input tensor [batch_size * mtp_size, dim] (interleaved) + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] + conv_weight: Conv weights [dim, conv_width] + conv_bias: Conv bias [dim] + mtp_buffer_idx_list: List of buffer index tensors per step + mtp_size: Number of MTP steps + activation_silu: Whether to apply SiLU activation + run_config: Auto-tuned configuration + """ + total_tokens = mixed_qkv.shape[0] + batch_size = total_tokens // mtp_size + dim_size = mixed_qkv.shape[-1] + conv_width = conv_weight.shape[-1] + + if run_config is None: + BLOCK_DIM = triton.next_power_of_2(min(dim_size, 256)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_DIM = run_config["BLOCK_DIM"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks_dim = triton.cdiv(dim_size, BLOCK_DIM) + + for step_idx in range(mtp_size): + is_last_step = step_idx == mtp_size - 1 + cur_indices = mtp_buffer_idx_list[step_idx] + next_indices = mtp_buffer_idx_list[step_idx + 1] if not is_last_step else cur_indices + + grid = (batch_size, num_blocks_dim) + + _fused_conv1d_mtp_step_kernel[grid]( + mixed_qkv, + conv_states, + conv_weight, + conv_bias, + cur_indices, + next_indices, + batch_size, + dim_size, + conv_width, + step_idx, + mtp_size, + is_last_step, + mixed_qkv.stride(0), + mixed_qkv.stride(-1) if mixed_qkv.ndim > 1 else 1, + conv_states.stride(0), + conv_states.stride(1), + conv_states.stride(2), + conv_weight.stride(0), + conv_weight.stride(1), + BLOCK_DIM=BLOCK_DIM, + ACTIVATION_SILU=activation_silu, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@triton.jit +def _copy_ssm_state_kernel( + # SSM state buffer + ssm_states_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + num_heads, + key_dim, + value_dim, + # Strides + state_stride_buffer, + state_stride_head, + state_stride_key, + state_stride_value, + # Block sizes + BLOCK_KEY: tl.constexpr, + BLOCK_VALUE: tl.constexpr, +): + """ + Copy SSM states from source indices to destination indices. 
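+
+    ssm_states has shape [num_buffers, num_heads, key_dim, value_dim]; each program
+    instance copies one (key-block x value-block) tile of one head for one batch
+    item from ssm_states[src_idx] to ssm_states[dst_idx].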
+ """ + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + block_idx = tl.program_id(2) + + # Calculate block positions + num_value_blocks = tl.cdiv(value_dim, BLOCK_VALUE) + block_key_idx = block_idx // num_value_blocks + block_value_idx = block_idx % num_value_blocks + + key_start = block_key_idx * BLOCK_KEY + value_start = block_value_idx * BLOCK_VALUE + + key_offsets = key_start + tl.arange(0, BLOCK_KEY) + value_offsets = value_start + tl.arange(0, BLOCK_VALUE) + + key_mask = key_offsets < key_dim + value_mask = value_offsets < value_dim + mask = key_mask[:, None] & value_mask[None, :] + + # Load indices + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # Calculate offsets + src_base = ssm_states_ptr + src_idx * state_stride_buffer + head_idx * state_stride_head + dst_base = ssm_states_ptr + dst_idx * state_stride_buffer + head_idx * state_stride_head + + offsets = key_offsets[:, None] * state_stride_key + value_offsets[None, :] * state_stride_value + + # Copy data + data = tl.load(src_base + offsets, mask=mask, other=0.0) + tl.store(dst_base + offsets, data, mask=mask) + + +@triton.jit +def _copy_conv_state_kernel( + # Conv state buffer [num_buffers, dim, conv_width-1] + conv_states_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + dim_size, + width_size, + num_width_blocks, # Precomputed to avoid runtime division + # Strides + state_stride_buffer, + state_stride_dim, + state_stride_width, + # Block sizes + BLOCK_DIM: tl.constexpr, + BLOCK_WIDTH: tl.constexpr, +): + """ + Copy conv states from source indices to destination indices. + + Conv state shape: [num_buffers, dim, conv_width-1] + """ + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Calculate block positions using precomputed num_width_blocks + block_dim_idx = block_idx // num_width_blocks + block_width_idx = block_idx % num_width_blocks + + dim_start = block_dim_idx * BLOCK_DIM + width_start = block_width_idx * BLOCK_WIDTH + + dim_offsets = dim_start + tl.arange(0, BLOCK_DIM) + width_offsets = width_start + tl.arange(0, BLOCK_WIDTH) + + dim_mask = dim_offsets < dim_size + width_mask = width_offsets < width_size + mask = dim_mask[:, None] & width_mask[None, :] + + # Load indices + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # Calculate offsets + src_base = conv_states_ptr + src_idx * state_stride_buffer + dst_base = conv_states_ptr + dst_idx * state_stride_buffer + + offsets = dim_offsets[:, None] * state_stride_dim + width_offsets[None, :] * state_stride_width + + # Copy data + data = tl.load(src_base + offsets, mask=mask, other=0.0) + tl.store(dst_base + offsets, data, mask=mask) + + +def _get_conv_copy_configs(): + """Generate candidate configurations for conv state copy.""" + configs = [] + for block_dim in [64, 128, 256]: + for block_width in [2, 4, 8]: + for num_warps in [2, 4]: + configs.append( + { + "BLOCK_DIM": block_dim, + "BLOCK_WIDTH": block_width, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_conv_copy_static_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for conv copy.""" + return { + "dtype": str(conv_states.dtype), + "dim_size": conv_states.shape[1], + "width_size": conv_states.shape[2], + } + + +def _get_conv_copy_run_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for conv 
copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_conv_state_copy:v1", + configs_gen_func=_get_conv_copy_configs, + static_key_func=_get_conv_copy_static_key, + run_key_func=_get_conv_copy_run_key, + mutates_args=["conv_states"], +) +def copy_conv_states( + conv_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Copy conv states from source indices to destination indices. + + Args: + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + batch_size = src_indices.shape[0] + dim_size = conv_states.shape[1] + width_size = conv_states.shape[2] + + if run_config is None: + BLOCK_DIM = triton.next_power_of_2(min(dim_size, 128)) + BLOCK_WIDTH = triton.next_power_of_2(min(width_size, 4)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_DIM = run_config["BLOCK_DIM"] + BLOCK_WIDTH = run_config["BLOCK_WIDTH"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_dim_blocks = triton.cdiv(dim_size, BLOCK_DIM) + num_width_blocks = triton.cdiv(width_size, BLOCK_WIDTH) + num_blocks_total = num_dim_blocks * num_width_blocks + + grid = (batch_size, num_blocks_total) + + _copy_conv_state_kernel[grid]( + conv_states, + src_indices, + dst_indices, + batch_size, + dim_size, + width_size, + num_width_blocks, # Pass precomputed value + conv_states.stride(0), + conv_states.stride(1), + conv_states.stride(2), + BLOCK_DIM=BLOCK_DIM, + BLOCK_WIDTH=BLOCK_WIDTH, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_ssm_copy_configs(): + """Generate candidate configurations for SSM state copy.""" + configs = [] + for block_key in [16, 32, 64]: + for block_value in [16, 32, 64, 128]: + for num_warps in [2, 4, 8]: + if block_key * block_value <= 4096: + configs.append( + { + "BLOCK_KEY": block_key, + "BLOCK_VALUE": block_value, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_ssm_copy_static_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for SSM copy.""" + return { + "dtype": str(ssm_states.dtype), + "num_heads": ssm_states.shape[1], + "key_dim": ssm_states.shape[2], + "value_dim": ssm_states.shape[3], + } + + +def _get_ssm_copy_run_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for SSM copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_ssm_state_copy:v1", + configs_gen_func=_get_ssm_copy_configs, + static_key_func=_get_ssm_copy_static_key, + run_key_func=_get_ssm_copy_run_key, + mutates_args=["ssm_states"], +) +def copy_ssm_states( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Copy SSM states from source indices to destination indices. 
+ + Args: + ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + batch_size = src_indices.shape[0] + num_heads = ssm_states.shape[1] + key_dim = ssm_states.shape[2] + value_dim = ssm_states.shape[3] + + if run_config is None: + BLOCK_KEY = triton.next_power_of_2(min(key_dim, 32)) + BLOCK_VALUE = triton.next_power_of_2(min(value_dim, 64)) + num_warps = 4 + num_stages = 2 + else: + BLOCK_KEY = run_config["BLOCK_KEY"] + BLOCK_VALUE = run_config["BLOCK_VALUE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_key_blocks = triton.cdiv(key_dim, BLOCK_KEY) + num_value_blocks = triton.cdiv(value_dim, BLOCK_VALUE) + num_blocks_total = num_key_blocks * num_value_blocks + + grid = (batch_size, num_heads, num_blocks_total) + + _copy_ssm_state_kernel[grid]( + ssm_states, + src_indices, + dst_indices, + batch_size, + num_heads, + key_dim, + value_dim, + ssm_states.stride(0), + ssm_states.stride(1), + ssm_states.stride(2), + ssm_states.stride(3), + BLOCK_KEY=BLOCK_KEY, + BLOCK_VALUE=BLOCK_VALUE, + num_warps=num_warps, + num_stages=num_stages, + ) + + +# ============================================================================= +# Optimized Flat Copy Kernels (for contiguous memory) +# ============================================================================= +# These kernels leverage the fact that both conv_states and ssm_states are +# contiguous in memory, allowing us to flatten the inner dimensions and use +# efficient 1D vectorized copy patterns. + + +@triton.jit +def _copy_state_flat_kernel( + # State buffer pointer (flattened view) + state_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + flat_size, # Total elements per buffer entry (flattened inner dims) + # Strides + stride_buffer, # Stride to next buffer entry (in elements) + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Optimized flat copy kernel for contiguous state buffers. + + Instead of using 2D/3D block patterns with stride calculations, this kernel + treats each buffer entry as a flat 1D array and uses vectorized loads/stores + for efficient memory transfer. 
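+
+    Per batch item b this is simply
+        state[dst_idx[b]].view(-1)[:] = state[src_idx[b]].view(-1)
+    split across num_blocks programs of BLOCK_SIZE elements each.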
+ + Grid: (batch_size, num_blocks) where num_blocks = ceil(flat_size / BLOCK_SIZE) + """ + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Calculate element range for this block + elem_start = block_idx * BLOCK_SIZE + elem_offsets = elem_start + tl.arange(0, BLOCK_SIZE) + elem_mask = elem_offsets < flat_size + + # Load buffer indices for this batch item + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # Calculate source and destination base pointers + src_base = state_ptr + src_idx * stride_buffer + dst_base = state_ptr + dst_idx * stride_buffer + + # Vectorized copy + data = tl.load(src_base + elem_offsets, mask=elem_mask, other=0.0) + tl.store(dst_base + elem_offsets, data, mask=elem_mask) + + +@triton.jit +def _copy_states_fused_kernel( + # Conv state buffer (flattened view) + conv_state_ptr, + # SSM state buffer (flattened view) + ssm_state_ptr, + # Buffer indices + src_indices_ptr, + dst_indices_ptr, + # Dimensions + batch_size, + conv_flat_size, # Total elements per conv buffer entry + ssm_flat_size, # Total elements per ssm buffer entry + # Strides (in elements) + conv_stride_buffer, + ssm_stride_buffer, + # Block sizes + CONV_BLOCK_SIZE: tl.constexpr, + SSM_BLOCK_SIZE: tl.constexpr, +): + """ + Fused kernel to copy both conv_states and ssm_states in a single launch. + + This reduces kernel launch overhead by processing both state copies together. + Each thread block handles one batch item and copies both states sequentially. + + Grid: (batch_size, max(conv_blocks, ssm_blocks)) + """ + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + + # Load buffer indices (same for both conv and ssm) + src_idx = tl.load(src_indices_ptr + batch_idx).to(tl.int64) + dst_idx = tl.load(dst_indices_ptr + batch_idx).to(tl.int64) + + # ========== Copy Conv State ========== + conv_num_blocks = tl.cdiv(conv_flat_size, CONV_BLOCK_SIZE) + if block_idx < conv_num_blocks: + conv_elem_start = block_idx * CONV_BLOCK_SIZE + conv_elem_offsets = conv_elem_start + tl.arange(0, CONV_BLOCK_SIZE) + conv_mask = conv_elem_offsets < conv_flat_size + + conv_src_base = conv_state_ptr + src_idx * conv_stride_buffer + conv_dst_base = conv_state_ptr + dst_idx * conv_stride_buffer + + conv_data = tl.load(conv_src_base + conv_elem_offsets, mask=conv_mask, other=0.0) + tl.store(conv_dst_base + conv_elem_offsets, conv_data, mask=conv_mask) + + # ========== Copy SSM State ========== + ssm_num_blocks = tl.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) + if block_idx < ssm_num_blocks: + ssm_elem_start = block_idx * SSM_BLOCK_SIZE + ssm_elem_offsets = ssm_elem_start + tl.arange(0, SSM_BLOCK_SIZE) + ssm_mask = ssm_elem_offsets < ssm_flat_size + + ssm_src_base = ssm_state_ptr + src_idx * ssm_stride_buffer + ssm_dst_base = ssm_state_ptr + dst_idx * ssm_stride_buffer + + ssm_data = tl.load(ssm_src_base + ssm_elem_offsets, mask=ssm_mask, other=0.0) + tl.store(ssm_dst_base + ssm_elem_offsets, ssm_data, mask=ssm_mask) + + +def _get_flat_copy_configs(): + """Generate candidate configurations for flat copy kernel.""" + configs = [] + # Larger block sizes for better memory throughput on contiguous data + for block_size in [256, 512, 1024, 2048]: + for num_warps in [4, 8]: + configs.append( + { + "BLOCK_SIZE": block_size, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_conv_flat_copy_static_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for conv flat copy.""" + return { + 
"dtype": str(conv_states.dtype), + "flat_size": conv_states.shape[1] * conv_states.shape[2], + } + + +def _get_conv_flat_copy_run_key( + conv_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for conv flat copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_conv_state_flat_copy:v1", + configs_gen_func=_get_flat_copy_configs, + static_key_func=_get_conv_flat_copy_static_key, + run_key_func=_get_conv_flat_copy_run_key, + mutates_args=["conv_states"], +) +def copy_conv_states_flat( + conv_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Optimized flat copy for conv states leveraging contiguous memory. + + Args: + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + assert conv_states.is_contiguous(), "conv_states must be contiguous for flat copy" + + batch_size = src_indices.shape[0] + # Flatten inner dimensions + flat_size = conv_states.shape[1] * conv_states.shape[2] + stride_buffer = conv_states.stride(0) + + if run_config is None: + BLOCK_SIZE = 1024 + num_warps = 4 + num_stages = 2 + else: + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) + grid = (batch_size, num_blocks) + + _copy_state_flat_kernel[grid]( + conv_states, + src_indices, + dst_indices, + batch_size, + flat_size, + stride_buffer, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_ssm_flat_copy_static_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for ssm flat copy.""" + return { + "dtype": str(ssm_states.dtype), + "flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], + } + + +def _get_ssm_flat_copy_run_key( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for ssm flat copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_ssm_state_flat_copy:v1", + configs_gen_func=_get_flat_copy_configs, + static_key_func=_get_ssm_flat_copy_static_key, + run_key_func=_get_ssm_flat_copy_run_key, + mutates_args=["ssm_states"], +) +def copy_ssm_states_flat( + ssm_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Optimized flat copy for SSM states leveraging contiguous memory. 
+ + Args: + ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + assert ssm_states.is_contiguous(), "ssm_states must be contiguous for flat copy" + + batch_size = src_indices.shape[0] + # Flatten inner dimensions (num_heads * key_dim * value_dim) + flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] + stride_buffer = ssm_states.stride(0) + + if run_config is None: + BLOCK_SIZE = 1024 + num_warps = 4 + num_stages = 2 + else: + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + num_blocks = triton.cdiv(flat_size, BLOCK_SIZE) + grid = (batch_size, num_blocks) + + _copy_state_flat_kernel[grid]( + ssm_states, + src_indices, + dst_indices, + batch_size, + flat_size, + stride_buffer, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _get_fused_copy_configs(): + """Generate candidate configurations for fused copy kernel.""" + configs = [] + # Use power-of-2 block sizes for both conv and ssm + for conv_block in [256, 512, 1024]: + for ssm_block in [256, 512, 1024]: + for num_warps in [4, 8]: + configs.append( + { + "CONV_BLOCK_SIZE": conv_block, + "SSM_BLOCK_SIZE": ssm_block, + "num_warps": num_warps, + "num_stages": 2, + } + ) + return configs + + +def _get_fused_copy_static_key( + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Static key for fused copy.""" + return { + "conv_dtype": str(conv_states.dtype), + "ssm_dtype": str(ssm_states.dtype), + "conv_flat_size": conv_states.shape[1] * conv_states.shape[2], + "ssm_flat_size": ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3], + } + + +def _get_fused_copy_run_key( + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + src_indices: torch.Tensor, +): + """Run key for fused copy.""" + return src_indices.shape[0] + + +@autotune( + kernel_name="gdn_states_fused_copy:v1", + configs_gen_func=_get_fused_copy_configs, + static_key_func=_get_fused_copy_static_key, + run_key_func=_get_fused_copy_run_key, + mutates_args=["conv_states", "ssm_states"], +) +def copy_states_fused( + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + run_config: dict = None, +): + """ + Fused copy for both conv and SSM states in a single kernel launch. + + This reduces kernel launch overhead by processing both state copies together. 
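+
+    Intended as a single-launch replacement for calling copy_conv_states_flat and
+    copy_ssm_states_flat back to back with the same index tensors, e.g.:
+        copy_states_fused(conv_states, ssm_states, src_indices, dst_indices)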
+ + Args: + conv_states: Conv state buffer [num_buffers, dim, conv_width-1] (MUST be contiguous) + ssm_states: SSM state buffer [num_buffers, num_heads, key_dim, value_dim] (MUST be contiguous) + src_indices: Source buffer indices [batch_size] + dst_indices: Destination buffer indices [batch_size] + run_config: Auto-tuned configuration + """ + assert conv_states.is_contiguous(), "conv_states must be contiguous for fused copy" + assert ssm_states.is_contiguous(), "ssm_states must be contiguous for fused copy" + + batch_size = src_indices.shape[0] + + # Flatten inner dimensions + conv_flat_size = conv_states.shape[1] * conv_states.shape[2] + ssm_flat_size = ssm_states.shape[1] * ssm_states.shape[2] * ssm_states.shape[3] + + conv_stride_buffer = conv_states.stride(0) + ssm_stride_buffer = ssm_states.stride(0) + + if run_config is None: + CONV_BLOCK_SIZE = 512 + SSM_BLOCK_SIZE = 512 + num_warps = 4 + num_stages = 2 + else: + CONV_BLOCK_SIZE = run_config["CONV_BLOCK_SIZE"] + SSM_BLOCK_SIZE = run_config["SSM_BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + # Grid covers both conv and ssm blocks + conv_num_blocks = triton.cdiv(conv_flat_size, CONV_BLOCK_SIZE) + ssm_num_blocks = triton.cdiv(ssm_flat_size, SSM_BLOCK_SIZE) + max_blocks = max(conv_num_blocks, ssm_num_blocks) + grid = (batch_size, max_blocks) + + _copy_states_fused_kernel[grid]( + conv_states, + ssm_states, + src_indices, + dst_indices, + batch_size, + conv_flat_size, + ssm_flat_size, + conv_stride_buffer, + ssm_stride_buffer, + CONV_BLOCK_SIZE=CONV_BLOCK_SIZE, + SSM_BLOCK_SIZE=SSM_BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py new file mode 100644 index 000000000..ba1e87f76 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py @@ -0,0 +1,143 @@ +import torch + +import triton +import triton.language as tl + +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _gemma_rmsnorm_fwd_kernel( + x_ptr, + w_ptr, + y_ptr, + x_stride0, + x_stride1, + y_stride0, + y_stride1, + N: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + x_ptr = x_ptr + row * x_stride0 + y_ptr = y_ptr + row * y_stride0 + + _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + cols * x_stride1, mask=cols < N, other=0.0).to(tl.float32) + _sum += x * x + + var = tl.sum(_sum, axis=0) / N + rstd = 1 / tl.sqrt(var + EPS) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) + x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + w = w + 1.0 + y = x_hat * w + # Write output + tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask) + + +def _get_gemma_rmsnorm_configs(): + """Generate configurations for autotuning gemma RMSNorm kernel.""" + configs = [] + # Different BLOCK_SIZE values (powers of 2) + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]: + # Different number of warps + for num_warps in [1, 2, 4, 8]: + for num_stages in [1, 2, 3, 4, 5]: + configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": num_stages}) + return configs + + 
+def _get_gemma_rmsnorm_static_key(x: torch.Tensor, w: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + N = x.shape[-1] + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(w.dtype), + "N": N, + } + + +@autotune( + kernel_name="gemma_rmsnorm_forward:v1", + configs_gen_func=_get_gemma_rmsnorm_configs, + static_key_func=_get_gemma_rmsnorm_static_key, + run_key_func=lambda x: x.shape[-1], +) +def gemma_rmsnorm_forward(x, w, eps, out=None, run_config: dict = None): + # Inplace gemma RMS Norm + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + N = x.shape[-1] + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, N) + y_arg = y.view(-1, N) + + M, _ = x_arg.shape + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This gemma rmsnorm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1} + + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + _gemma_rmsnorm_fwd_kernel[(M,)]( + x_arg, + w, + y_arg, + x_stride0=x.stride(0), + x_stride1=x.stride(1), + y_stride0=y.stride(0), + y_stride1=y.stride(1), + N=N, + EPS=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + return y + + +def _gemma_rmsnorm_fwd_torch(x, weight, eps): + original_dtype = x.dtype + x = x.to(torch.float32) + x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + x = x * (1.0 + weight.float()) + return x.to(original_dtype) + + +def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + # forward pass + y_tri = gemma_rmsnorm_forward(x, weight, eps) + y_ref = _gemma_rmsnorm_fwd_torch(x, weight, eps) + + # compare + print("type:", y_tri.dtype, y_ref.dtype) + print("max delta:", torch.max(torch.abs(y_tri - y_ref))) + # Use appropriate tolerance based on dtype + atol = 1e-2 if dtype == torch.float32 else 5e-2 + assert torch.allclose(y_tri, y_ref, atol=atol, rtol=0) + return diff --git a/lightllm/models/qwen3next_mtp/__init__.py b/lightllm/models/qwen3next_mtp/__init__.py new file mode 100644 index 000000000..779237817 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel + +__all__ = ["Qwen3NextMTPModel"] diff --git a/lightllm/models/qwen3next_mtp/layer_infer/__init__.py b/lightllm/models/qwen3next_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py new file mode 100644 index 000000000..2918fca79 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_infer/post_layer_infer.py @@ -0,0 +1,16 @@ +import torch +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import 
Qwen3NextMTPPreAndPostLayerWeight +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextMTPPostLayerInfer(LlamaPostLayerInfer): + """ + Qwen3Next MTP Post Layer Inference. + Uses gemma_rmsnorm for normalization (same as Qwen3Next). + """ + + def _norm(self, input, infer_state, layer_weight: Qwen3NextMTPPreAndPostLayerWeight) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_.weight, self.eps_, out=out) + return out diff --git a/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..4fc207648 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,68 @@ +import torch + +from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextMTPPreLayerInfer(LlamaPreLayerInfer): + """ + Qwen3Next MTP Pre-Layer Inference. + Similar to DeepSeek MTP but with different weight structure. + + MTP forward flow: + 1. Get embedding from input_ids + 2. Get hidden state from main model (passed via infer_state) + 3. Normalize embedding with pre_fc_norm_embedding + 4. Normalize hidden with pre_fc_norm_hidden + 5. Concat normalized embedding and hidden + 6. Project through fc to get hidden_dim output + """ + + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + + # Normalize embedding + input_embdings_normed = self.alloc_tensor(input_embdings.shape, input_embdings.dtype) + gemma_rmsnorm_forward( + input_embdings, layer_weight.pre_fc_norm_embedding_weight_.weight, self.eps_, out=input_embdings_normed + ) + + # Normalize hidden state + tgt_embdings_normed = self.alloc_tensor(tgt_embdings.shape, tgt_embdings.dtype) + gemma_rmsnorm_forward( + tgt_embdings, layer_weight.pre_fc_norm_hidden_weight_.weight, self.eps_, out=tgt_embdings_normed + ) + + # Concat normalized embedding and hidden + cat_embdings = torch.cat((input_embdings_normed, tgt_embdings_normed), dim=-1) + + # Project to hidden_size + ans_logics = self.alloc_tensor( + (cat_embdings.shape[0], layer_weight.fc_weight_.shape[1]), dtype=cat_embdings.dtype + ) + torch.mm(cat_embdings, layer_weight.fc_weight_, out=ans_logics) + + return ans_logics + + def context_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_forward(input_embdings, infer_state, layer_weight) + + def token_forward( + self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextMTPPreAndPostLayerWeight + ): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_forward(input_embdings, infer_state, layer_weight) diff --git 
a/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..ab3fe6861 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_infer/transformer_layer_infer.py @@ -0,0 +1,83 @@ +import os +from functools import partial +from torch.distributed import ReduceOp + +from lightllm.models.qwen3next_mtp.layer_weights.transformer_layer_weight import Qwen3NextMTPTransformerLayerWeight +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( + Qwen3NextFullAttentionTransformerLayerInfer, + Qwen3NextFFNMixin, +) +from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo +from lightllm.distributed import all_reduce +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +# Module-level constant for MoE mode +MOE_MODE = os.environ.get("MOE_MODE", "TP") + + +class Qwen3NextMTPTransformerLayerInfer(Qwen3NextFullAttentionTransformerLayerInfer): + """ + Qwen3Next MTP Transformer Layer Inference. + MTP layers use full attention (not linear attention) with MoE FFN and shared expert. + Inherits shared methods from Qwen3NextFullAttentionTransformerLayerInfer. + """ + + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) + self.tp_v_head_num_ = max(self.tp_v_head_num_, 1) + + def _bind_ffn(self): + """MTP always uses shared expert + MoE.""" + if MOE_MODE == "EP": + self._ffn = partial(Qwen3NextFFNMixin._ffn_with_shared_expert_ep, self) + else: + self._ffn = partial(Qwen3NextFFNMixin._ffn_with_shared_expert_tp, self) + + def context_forward( + self, input_embdings, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextMTPTransformerLayerWeight + ): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) + input1 = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) + q = None + o = self._get_o(o, infer_state, layer_weight) + if self.tp_world_size_ > 1: + all_reduce(o, op=ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def token_forward( + self, input_embdings, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextMTPTransformerLayerWeight + ): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) + input1 = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._token_attention_kernel(q, infer_state, layer_weight) + q = None + o = self._get_o(o, infer_state, layer_weight) + if self.tp_world_size_ > 1: + all_reduce(o, op=ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=ReduceOp.SUM, 
group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings diff --git a/lightllm/models/qwen3next_mtp/layer_weights/__init__.py b/lightllm/models/qwen3next_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..55fd3f0b9 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,46 @@ +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import NoTpGEMMANormWeight + + +class Qwen3NextMTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + self.wte_weight_ = None + self.lm_head_weight_ = None + + # Use Gemma-style normalization for all MTP norm layers + self.final_norm_weight_ = NoTpGEMMANormWeight( + weight_name="mtp.norm.weight", + data_type=self.data_type_, + bias_name=None, + ) + self.pre_fc_norm_embedding_weight_ = NoTpGEMMANormWeight( + weight_name="mtp.pre_fc_norm_embedding.weight", + data_type=self.data_type_, + bias_name=None, + ) + self.pre_fc_norm_hidden_weight_ = NoTpGEMMANormWeight( + weight_name="mtp.pre_fc_norm_hidden.weight", + data_type=self.data_type_, + bias_name=None, + ) + return + + def load_hf_weights(self, weights): + if "mtp.fc.weight" in weights: + self.fc_weight_ = self._cuda(weights["mtp.fc.weight"]).t() + + # Load weights for norm weight objects + self.final_norm_weight_.load_hf_weights(weights) + self.pre_fc_norm_embedding_weight_.load_hf_weights(weights) + self.pre_fc_norm_hidden_weight_.load_hf_weights(weights) + + return + + def verify_load(self): + # Verify all norm weights loaded correctly + return ( + self.final_norm_weight_.verify_load() + and self.pre_fc_norm_embedding_weight_.verify_load() + and self.pre_fc_norm_hidden_weight_.verify_load() + ) diff --git a/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..cc65ed7a4 --- /dev/null +++ b/lightllm/models/qwen3next_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,99 @@ +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + NormWeight, +) +from typing_extensions import override + + +class Qwen3NextMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + + @override + def _init_weight_names(self): + self._q_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.q_proj.weight" + self._q_norm_name = f"mtp.layers.{self.layer_num_}.self_attn.q_norm.weight" + self._q_bias_name = None + self._k_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.k_proj.weight" + self._k_norm_name = f"mtp.layers.{self.layer_num_}.self_attn.k_norm.weight" + self._k_bias_name = None + self._v_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.v_proj.weight" + self._v_bias_name = None + self._kv_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.kv_proj.weight" 
+ self._kv_bias_name = None + self._o_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.o_proj.weight" + self._o_bias_name = None + self._att_norm_weight_name = f"mtp.layers.{self.layer_num_}.input_layernorm.weight" + self._att_norm_bias_name = None + self._ffn_norm_weight_name = f"mtp.layers.{self.layer_num_}.post_attention_layernorm.weight" + self._ffn_norm_bias_name = None + + @override + def _init_weight(self): + self._init_moe() + self._init_shared_expert_weight() + + self.att_norm_weight_ = NormWeight( + self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + ) + self.ffn_norm_weight_ = NormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) + + self._init_qkv() + self._init_o() + self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) + self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) + self._o_gate_weight_name = f"mtp.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + self.o_gate_proj = ROWMMWeight( + weight_names=self._o_gate_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="o_gate_proj", + ) + + @override + def load_hf_weights(self, weights): + self._split_q_with_gate(weights) + super().load_hf_weights(weights) + + def _init_shared_expert_weight(self): + prefix = f"mtp.layers.{self.layer_num_}.mlp.shared_expert" + self.shared_expert_gate_up_proj = ROWMMWeight( + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate_up_proj", + ) + self.shared_expert_down_proj = COLMMWeight( + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_down_proj", + ) + self.shared_expert_gate = ROWMMWeight( + weight_names=f"mtp.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate", + tp_rank=0, + tp_world_size=1, + ) + + def _split_q_with_gate(self, weights): + if self._q_weight_name in weights: + weight = weights[self._q_weight_name] + num_heads = self.tp_q_head_num_ * self.tp_world_size_ + weight = weight.view(num_heads * 2, self.head_dim, -1) + _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) + _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) + weights[self._q_weight_name] = _q_proj + weights[self._o_gate_weight_name] = _gate_proj diff --git a/lightllm/models/qwen3next_mtp/model.py b/lightllm/models/qwen3next_mtp/model.py new file mode 100644 index 000000000..92e4918be --- /dev/null +++ b/lightllm/models/qwen3next_mtp/model.py @@ -0,0 +1,101 @@ +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.models.qwen3next_mtp.layer_infer.pre_layer_infer import Qwen3NextMTPPreLayerInfer +from lightllm.models.qwen3next_mtp.layer_infer.transformer_layer_infer import Qwen3NextMTPTransformerLayerInfer +from lightllm.models.qwen3next_mtp.layer_weights.pre_and_post_layer_weight import Qwen3NextMTPPreAndPostLayerWeight +from lightllm.models.qwen3next_mtp.layer_weights.transformer_layer_weight import Qwen3NextMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel +from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights +from 
lightllm.models.registry import ModelRegistry + + +@ModelRegistry("qwen3next_mtp") +class Qwen3NextMTPModel(Qwen3NextTpPartModel): + + pre_and_post_weight_class = Qwen3NextMTPPreAndPostLayerWeight + pre_layer_infer_class = Qwen3NextMTPPreLayerInfer + transformer_weight_class = Qwen3NextMTPTransformerLayerWeight + transformer_layer_infer_class = Qwen3NextMTPTransformerLayerInfer + + def __init__(self, kvargs: dict): + self.mtp_n_layers = 1 + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + """Extract main model and memory layer start from kwargs.""" + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mem_layer_start = kvargs.pop("mem_layer_start") + return + + def autotune_layers(self): + return 1 + + def _init_some_value(self): + self.layers_num = self.mtp_n_layers + + def _init_config(self): + super()._init_config() + self.config["n_layers"] = self.mtp_n_layers + self.config["num_hidden_layers"] = self.mtp_n_layers + return + + def _init_custom(self): + """Initialize custom components, sharing cos/sin cache with main model.""" + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + """Share request manager with main model.""" + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + """Share memory manager with main model.""" + self.mem_manager = self.main_model.mem_manager + return + + def _check_mem_size(self): + """Skip mem size check for MTP models since they share memory with main model.""" + self.max_total_token_num = self.mem_manager.size + return + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(self.mtp_n_layers) + ] + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + return + + def _init_infer_layer(self): + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + self.layers_infer = [ + self.transformer_layer_infer_class( + i * self.config["full_attention_interval"] - 1, # Ensure full attention layer + network_config=self.config, + ) + for i in range(self.mtp_n_layers) + ] + # Ensure full attention layer + for i, layer in enumerate(self.layers_infer): + layer.layer_num_ = i + self.mem_layer_start + return diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index f43907307..939843a3e 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -24,8 +24,8 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def context_forward(self, input_ids, 
infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): diff --git a/lightllm/models/qwen_vl/model.py b/lightllm/models/qwen_vl/model.py index edebccf17..0c6fa31f4 100644 --- a/lightllm/models/qwen_vl/model.py +++ b/lightllm/models/qwen_vl/model.py @@ -1,5 +1,3 @@ -import json -import numpy as np import unicodedata from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer from lightllm.server.core.objs import SamplingParams @@ -88,7 +86,8 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None): input_ids.extend(origin_ids[end:]) if multimodal_params: image_cnt = len(multimodal_params.images) - assert image_cnt == image_id, "invalid image tag num: {} vs {}!".format(image_cnt, image_id) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") return input_ids diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index 395ed4ba1..f908dbdd3 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -8,8 +8,8 @@ class StablelmTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.partial_rotary_factor = self.network_config_.get("partial_rotary_factor", 1) return diff --git a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py index 0ad3e07df..3d044eeb5 100755 --- a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py @@ -2,8 +2,8 @@ class StableLMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.final_norm_weight_ = NoTpNormWeight( weight_name="model.norm.weight", data_type=self.data_type_, diff --git a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py index a1a73f674..03ee50feb 100755 --- a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class StablelmTransformerLayerWeight(Qwen2TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/stablelm/model.py b/lightllm/models/stablelm/model.py index 2ed710fd4..a3d295358 100644 --- a/lightllm/models/stablelm/model.py +++ b/lightllm/models/stablelm/model.py @@ -1,6 +1,3 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry from lightllm.models.stablelm.layer_infer.transformer_layer_infer import StablelmTransformerLayerInfer from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer
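# The partial_rotary_factor read by StablelmTransformerLayerInfer above means that only the
# first head_dim * partial_rotary_factor channels of each attention head are rotated by the
# position embedding, while the remaining channels pass through unchanged. A minimal
# standalone sketch of that idea (illustrative helper name and shapes, not the kernel the
# server actually dispatches to; cos/sin are assumed precomputed and broadcastable to the
# rotated slice, and rotary_dim is assumed even):
import torch

def apply_partial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, partial_rotary_factor: float = 0.25) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_dim]
    head_dim = x.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    # rotate-half formulation applied to the rotated slice only
    x1, x2 = x_rot.chunk(2, dim=-1)
    x_rot = x_rot * cos + torch.cat((-x2, x1), dim=-1) * sin
    # non-rotary channels are concatenated back untouched
    return torch.cat((x_rot, x_pass), dim=-1)
diff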
--git a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py index 52072a348..6b88c066e 100644 --- a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py @@ -9,8 +9,8 @@ class StarcoderPreLayerInfer(PreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.layer_norm_eps_ = network_config["layer_norm_epsilon"] def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): diff --git a/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py index 018816fcc..074f3411a 100644 --- a/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py @@ -1,17 +1,19 @@ from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from functools import partial class StarcoderTransformerLayerInfer(BloomTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.tp_k_head_num_ = 1 self.tp_v_head_num_ = 1 self._bind_func() return def _bind_func(self): - LlamaTransformerLayerInfer._bind_attention(self) + self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) + self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_attention_kernel, self) return diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index d5bdd79a7..329a0245f 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -8,8 +8,8 @@ class StarcoderPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="transformer.wte.weight", diff --git a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py index 2aa9dd9ef..41f24f79c 100644 --- a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class StarcoderTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg, layer_prefix="transformer.h") + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg, layer_prefix="transformer.h") assert network_config["num_attention_heads"] % self.tp_world_size_ == 0 def load_hf_weights(self, weights): diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py 
b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index 796a96bc4..09e3299eb 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -5,8 +5,8 @@ class Starcoder2TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index 28a26cb4b..6ee188537 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -4,8 +4,8 @@ class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", diff --git a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py index 6314fa0e5..53342e221 100644 --- a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class Starcoder2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py index b24fc0f0d..44e18c282 100644 --- a/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py @@ -17,8 +17,8 @@ def rename_weight_keys(weights): class Tarsier2Qwen2PreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): @@ -28,8 +28,8 @@ def load_hf_weights(self, weights): class Tarsier2LlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/tarsier2/model.py b/lightllm/models/tarsier2/model.py index dad252b97..10a7f368c 100644 --- a/lightllm/models/tarsier2/model.py +++ b/lightllm/models/tarsier2/model.py @@ -78,6 +78,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("image token error") except ValueError: break + if multimodal_params: + 
image_cnt = len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") input_ids.extend(origin_ids[start_idx:]) return input_ids diff --git a/lightllm/models/vit/layer_infer/post_layer_infer.py b/lightllm/models/vit/layer_infer/post_layer_infer.py index 613aec3fa..fa4a87f15 100644 --- a/lightllm/models/vit/layer_infer/post_layer_infer.py +++ b/lightllm/models/vit/layer_infer/post_layer_infer.py @@ -9,11 +9,10 @@ class ViTPostLayerInfer: """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.network_config_ = network_config - self.mode = mode self.llm_hidden_size = network_config["llm_hidden_size"] self.downsample_ratio = network_config["downsample_ratio"] return diff --git a/lightllm/models/vit/layer_infer/pre_layer_infer.py b/lightllm/models/vit/layer_infer/pre_layer_infer.py index 896e8e898..306bf9f0e 100644 --- a/lightllm/models/vit/layer_infer/pre_layer_infer.py +++ b/lightllm/models/vit/layer_infer/pre_layer_infer.py @@ -11,11 +11,10 @@ class ViTPreLayerInfer: """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.network_config_ = network_config - self.mode = mode return def forward(self, pixel_values, layer_weight: ViTPreAndPostLayerWeight): diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 0b89dca11..0d55d1b57 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -13,7 +13,7 @@ class ViTTransformerLayerInfer: """ """ - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.eps_ = network_config["layer_norm_eps"] @@ -25,7 +25,6 @@ def __init__(self, layer_num, network_config, mode=[]): self.tp_padding_embed_dim_ = self.tp_padding_head_num * self.head_dim_ self.network_config_ = network_config - self.mode = mode self.layer_num_ = layer_num return diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 276d4e5d0..e2bed1036 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -7,8 +7,8 @@ class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.embed_dim = self.network_config_["hidden_size"] self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index c6024594e..dffcc16fe 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -14,8 +14,8 @@ class ViTTransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): -
super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _cuda(self, cpu_tensor): diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 01bb69bdf..9c2bc4242 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -40,7 +40,6 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.weight_dir_ = kvargs["weight_dir"] self.load_way = kvargs.get("load_way", "HF") - self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] self.weight_dict = kvargs.get("weight_dict", None) self.data_type = kvargs.get("data_type", "float16") self.quant_type = kvargs.get("quant_type", None) @@ -112,15 +111,12 @@ def _padding_hidden_size(self): return def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ self.transformer_weight_class( i, self.data_type, network_config=self.config, - mode=self.mode, quant_cfg=self.quant_cfg, ) for i in range(self.config["num_hidden_layers"]) @@ -141,10 +137,10 @@ def _init_quant(self): logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _init_infer_layer(self): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) - self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) self.layers_infer = [ - self.transformer_layer_infer_class(i, network_config=self.config, mode=self.mode) + self.transformer_layer_infer_class(i, network_config=self.config) for i in range(self.config["num_hidden_layers"]) ] return @@ -185,7 +181,7 @@ def encode(self, images: List[ImageItem]): else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) - cur_num = img_tensors[-1].shape[0] + cur_num = img.token_num valid_ids.append([valid_id, valid_id + cur_num]) valid_id += cur_num @@ -195,7 +191,7 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) pixel_values = imgs.cuda().to(dtype=self.data_type) all_img_embeds = self.forward(pixel_values) - return all_img_embeds, uuids, valid_ids + return all_img_embeds.view(-1, all_img_embeds.shape[-1]), uuids, valid_ids def cuda(self): return self diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d193bab41..724a489ce 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -219,29 +219,6 @@ def make_argument_parser() -> argparse.ArgumentParser: the --nccl_host must equal to the config_server_host, and the --nccl_port must be unique for a config_server, dont use same nccl_port for different inference node, it will be critical error""", ) - - parser.add_argument( - "--mode", - type=str, - default=[], - nargs="+", - help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_int8kv_flashdecoding | ppl_int8kv_flashdecoding_diverse - | ppl_fp16 | triton_flashdecoding - | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv - | export_fp8kv_calibration - triton_flashdecoding mode 
is for long context, current support llama llama2 qwen; - triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; - triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; - triton_fp8kv mode use float8 to store kv cache, currently only for deepseek2; - offline_calibration_fp8kv mode use float8 to store kv cache, need fa3 or flashinfer backend, - currently only for llama and qwen model; - export_fp8kv_calibration record and export kv cache quant calibration results to a json file. - It can be used for llama and qwen model. - Calibration need to disable cudagraph and use fa3 or flashinfer backend. - ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; - ppl_fp16 mode use ppl fast fp16 decode attention kernel; - you need to read source code to make sure the supported detail mode for all models""", - ) parser.add_argument( "--trust_remote_code", action="store_true", @@ -337,21 +314,40 @@ def make_argument_parser() -> argparse.ArgumentParser: only deepseekv3 model supported now.""", ) parser.add_argument( - "--enable_flashinfer_prefill", - action="store_true", - help="""inference backend will use the attention kernel of flashinfer for prefill, - only deepseekv3 model supported now.""", + "--llm_prefill_att_backend", + type=str, + nargs="+", + choices=["None", "triton", "fa3", "flashinfer"], + default=["triton"], + help="""prefill attention kernel used in llm. + None: automatically select backend based on current GPU device, + not supported yet, will support in future""", ) parser.add_argument( - "--enable_flashinfer_decode", - action="store_true", - help="""inference backend will use the attention kernel of flashinfer for decode, - only deepseekv3 model supported now.""", + "--llm_decode_att_backend", + type=str, + nargs="+", + choices=["None", "triton", "fa3", "flashinfer"], + default=["triton"], + help="""decode attention kernel used in llm. + None: automatically select backend based on current GPU device, + not supported yet, will support in future""", ) parser.add_argument( - "--enable_fa3", - action="store_true", - help="""inference backend will use the fa3 attention kernel for prefill and decode""", + "--llm_kv_type", + type=str, + choices=["None", "int8kv", "int4kv"], + default="None", + help="""kv type used in llm, None for dtype that llm used in config.json. + fp8kv: not fully supported yet, will support in future""", + ) + parser.add_argument( + "--llm_kv_quant_group_size", + type=int, + default=8, + help="""kv quant group size used in llm kv, when llm_kv_type is quanted type,such as int8kv, + this params will be effective. + """, ) parser.add_argument( "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources" @@ -612,4 +608,12 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) + parser.add_argument("--mamba_cache_size", type=int, default=3000, help="""The size of linear attn cache. 
""") + parser.add_argument( + "--mamba_ssm_data_type", + type=str, + choices=["bfloat16", "float32"], + default="float32", + help="the data type of the model weight", + ) return parser diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 7b9cdd501..f30ecc55f 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -24,9 +24,10 @@ class Message(BaseModel): class Function(BaseModel): """Function descriptions.""" - description: Optional[str] = Field(default=None, examples=[None]) name: Optional[str] = None - parameters: Optional[object] = None + description: Optional[str] = Field(default=None, examples=[None]) + parameters: Optional[dict] = None + response: Optional[dict] = None class Tool(BaseModel): diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 6a8c232dc..d91bb1d94 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -81,7 +81,7 @@ def _process_tool_call_id( # SGLang sets call_item.tool_index to the *local* position inside that message. # Therefore, the index must be corrected by using # `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered. - tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}" + tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt + call_item.tool_index}" logger.debug( f"Process tool call idx, parser: {tool_call_parser}, \ tool_call_id: {tool_call_id}, \ diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 4ead3cbbf..3ae3789f4 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -122,21 +122,6 @@ def normal_or_p_d_start(args): if args.return_all_prompt_logprobs: assert args.disable_dynamic_prompt_cache is True, "need add --disable_dynamic_prompt_cache" assert args.disable_chunked_prefill is True, "need add --disable_chunked_prefill" - if "offline_calibration_fp8kv" in args.mode: - assert args.enable_fa3 is True or ( - args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True - ), ( - "offline_calibration_fp8kv mode need enable fa3 or flashinfer, add --enable_fa3 or " - "--enable_flashinfer_prefill and --enable_flashinfer_decode" - ) - if "export_fp8kv_calibration" in args.mode: - assert args.enable_fa3 is True or ( - args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True - ), ( - "export_fp8kv_calibration mode need enable fa3 or flashinfer, add --enable_fa3 or " - "--enable_flashinfer_prefill and --enable_flashinfer_decode" - ) - assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph" # 部分模式还不能支持与高级动态调度算法协同,to do. 
if args.diverse_mode: diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index f073319d7..d955aa6a8 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -7,7 +7,7 @@ _SAMPLING_EPS = 1e-5 DEFAULT_INPUT_PENALTY = os.getenv("INPUT_PENALTY", "False").upper() in ["ON", "TRUE", "1"] -SKIP_SPECIAL_TOKENS = os.getenv("SKIP_SPECIAL_TOKENS", "True").upper() in ["ON", "TRUE", "1"] +SKIP_SPECIAL_TOKENS = os.getenv("SKIP_SPECIAL_TOKENS", "False").upper() in ["ON", "TRUE", "1"] # 从环境变量获取最大长度限制 STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256)) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 5ebadaf16..65f467d16 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -63,7 +63,6 @@ class StartArgs: nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=28765) use_config_server_to_init_nccl: bool = field(default=False) - mode: List[str] = field(default_factory=list) trust_remote_code: bool = field(default=False) disable_log_stats: bool = field(default=False) log_stats_interval: int = field(default=10) @@ -116,8 +115,14 @@ class StartArgs: quant_cfg: Optional[str] = field(default=None) vit_quant_type: Optional[str] = field(default=None) vit_quant_cfg: Optional[str] = field(default=None) - enable_flashinfer_prefill: bool = field(default=False) - enable_flashinfer_decode: bool = field(default=False) + llm_prefill_att_backend: List[str] = field( + default=("None",), metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} + ) + llm_decode_att_backend: List[str] = field( + default=("None",), metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} + ) + llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]}) + llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) penalty_counter_mode: str = field( default="gpu_counter", metadata={"choices": ["cpu_counter", "pin_mem_counter", "gpu_counter"]} @@ -154,5 +159,6 @@ class StartArgs: enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) - # kernel setting - enable_fa3: bool = field(default=False) + # hybrid attention model (e.g., qwen3next) + mamba_cache_size: int = field(default=2000) + mamba_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]}) diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py new file mode 100644 index 000000000..db9968dd4 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -0,0 +1,205 @@ +from typing import Set, Protocol, List, Optional, Tuple + +import torch +from sortedcontainers import SortedSet + +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode +from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class HybridRadixCache(RadixCache): + def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager): + super().__init__(unique_name, total_token_num, rank_in_node, kv_cache_mem_manager) + assert hasattr(kv_cache_mem_manager, 
"mamba_cache_mem_manager") + self.buffer_mem_manager: MambaCacheManager = kv_cache_mem_manager.mamba_cache_mem_manager + self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,)) + + def free_radix_cache_to_get_enough_buffer(self, need_buffer_num): + if need_buffer_num > self.buffer_mem_manager.can_use_mem_size: + need_evict_buffer_num = need_buffer_num - self.buffer_mem_manager.can_use_mem_size + + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + release_buffers = [] + + def release_buffer(buffer_idx): + release_buffers.append(buffer_idx) + return + + self._evict_buffer(need_evict_buffer_num, release_buffer, release_mem) + self.buffer_mem_manager.free(release_buffers) + if len(release_mems) > 0: + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + return + + def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback): + while need_evict_buffer_num > 0: + node = self.evict_buffer_set.pop(0) + assert node.buffer_idx is not None + evict_buffer_callback(node.buffer_idx) + node.buffer_idx = None + need_evict_buffer_num -= 1 + # 当一个节点的buffer_idx变为None时,事实上无法在后续进行match, + # 但当该节点子节点或者引用数不为0时,仍然需要保留, 否则则应该被删除 + if node.is_leaf() and node.ref_counter == 0: + self.evict_tree_set.discard(node) + evict_token_callback(node.token_mem_index_value) + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + return + + def insert_for_hybrid_radix_cache(self, reqs): + from lightllm.server.router.model_infer.infer_batch import g_infer_context + + reqs_to_insert = [req for req in reqs if req.cur_kv_len < req.get_cur_total_len()] + + if len(reqs_to_insert) == 0: + return + + self.free_radix_cache_to_get_enough_buffer(len(reqs_to_insert)) + req_idxes = torch.tensor([req.req_idx for req in reqs_to_insert], dtype=torch.int64, device="cuda") + req_to_buffer_index = g_infer_context.req_manager.req_to_buffer_index + # Make contiguous and convert to int64 for Triton kernel compatibility + cur_buffer_indexes = req_to_buffer_index[req_idxes, 0].contiguous().to(torch.int64) + + new_buffer_indexes = self.buffer_mem_manager.alloc(len(reqs_to_insert)) + # Move to CUDA and convert to int64, ensure contiguous + new_buffer_indexes_cuda = new_buffer_indexes.to(device="cuda", dtype=torch.int64).contiguous() + + self.buffer_mem_manager.copy_buffer_p2p(cur_buffer_indexes, new_buffer_indexes_cuda) + + for i, req in enumerate(reqs_to_insert): + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu() + prefix_len, new_shared_kv_node = super().insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + self.dec_node_ref_counter(req.shared_kv_node) + self.add_node_ref_counter(new_shared_kv_node) + self.add_buffer_idx_to_node(new_shared_kv_node, new_buffer_indexes[i].item()) + req.extra_need_to_free_token_index.append( + g_infer_context.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len] + ) + req.shared_kv_node = new_shared_kv_node + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, 
update_refs=update_refs) + miss_prefix_len = 0 + evict_token_list = [] + while tree_node != self.root_node and tree_node.buffer_idx is None: + if tree_node.is_leaf(): + self.evict_tree_set.discard(tree_node) + + if tree_node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(tree_node.token_mem_index_value) + tree_node.ref_counter -= 1 # 只减少当前节点,不递归 + + if tree_node.is_leaf() and tree_node.ref_counter == 0: + evict_token_list.append(tree_node.token_mem_index_value) + self.tree_total_tokens_num.arr[0] -= len(tree_node.token_mem_index_value) + parent_node: TreeNode = tree_node.parent + parent_node.remove_child(tree_node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + tree_node = parent_node + else: + if tree_node.is_leaf(): + self.evict_tree_set.add(tree_node) + tree_node = tree_node.parent + miss_prefix_len += len(ans_value_list.pop()) + + if len(evict_token_list) > 0: + evict_token_value = torch.concat(evict_token_list) + self.mem_manager.free(evict_token_value) + + if tree_node == self.root_node: + self._inc_hit_rate(len(key), 0) + return None, miss_prefix_len, None + + update_node = tree_node + while update_node != self.root_node: + if update_node.buffer_idx is not None: + self.evict_buffer_set.discard(update_node) + update_node.update_buffer_time() + self.evict_buffer_set.add(update_node) + update_node = update_node.parent + + value = torch.concat(ans_value_list) + self._inc_hit_rate(len(key), len(value)) + return tree_node, miss_prefix_len, value + + def add_buffer_idx_to_node(self, node: TreeNode, buffer_idx: int): + """Set buffer_idx for a node and add it to evict_buffer_set.""" + self.evict_buffer_set.discard(node) + if node.is_leaf(): + self.evict_tree_set.discard(node) + if node.buffer_idx is not None: + self.buffer_mem_manager.free([node.buffer_idx]) + node.buffer_idx = buffer_idx + node.update_buffer_time() + self.evict_buffer_set.add(node) + if node.is_leaf(): + self.evict_tree_set.add(node) + return + + def free_radix_cache_to_get_enough_token(self, need_token_num): + assert self.mem_manager is not None + if need_token_num > self.mem_manager.can_use_mem_size: + need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + release_buffers = [] + + def release_buffer(buffer_idx): + release_buffers.append(buffer_idx) + return + + self.evict(need_evict_token_num, release_buffer, release_mem) + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + if len(release_buffers) > 0: + self.buffer_mem_manager.free(release_buffers) + return + + def evict(self, need_remove_tokens, evict_buffer_callback, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: TreeNode = self.evict_tree_set.pop(0) + assert ( + node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node + ), f"error evict tree node state: {node.ref_counter}, {len(node.children)}" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + if node.buffer_idx is not None: + self.evict_buffer_set.discard(node) + evict_buffer_callback(node.buffer_idx) + node.buffer_idx = None + # update total token num + 
self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index c51774898..ff90aadf9 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -5,6 +5,9 @@ from typing import Tuple, Dict, Set, List, Optional, Union from sortedcontainers import SortedSet from .shared_arr import SharedArray +from lightllm.utils.log_utils import init_logger, log_time_ready + +logger = init_logger(__name__) class UniqueTimeIdGenerator: @@ -31,6 +34,13 @@ def __init__(self): self.node_value_len = 0 self.node_prefix_total_len = 0 + # 专门用于管理混合注意力模型(例如 Qwen3Next), + # 该类模型每个请求需要管理一个唯一的buffer_idx, + # 放在这里让该类模型能够复用当前的radix_cache代码。 + # 纯注意力模型该 buffer_idx 始终保持为 None + self.buffer_idx = None + self.buffer_time = time_gen.generate_time_id() + def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) @@ -78,6 +88,9 @@ def remove_child(self, child_node: "TreeNode"): def update_time(self): self.time_id = time_gen.generate_time_id() + def update_buffer_time(self): + self.buffer_time = time_gen.generate_time_id() + def is_leaf(self): return len(self.children) == 0 @@ -125,13 +138,36 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) ) self.tree_total_tokens_num.arr[0] = 0 + self.total_query_tokens = SharedArray(f"{unique_name}_total_query_tokens_{rank_in_node}", (1,), dtype=np.int64) + self.total_query_tokens.arr[0] = 0 + self.total_hit_tokens = SharedArray(f"{unique_name}_total_hit_tokens_{rank_in_node}", (1,), dtype=np.int64) + self.total_hit_tokens.arr[0] = 0 + self.last_log_query_tokens = 0 + self.last_log_hit_tokens = 0 + + def _inc_hit_rate(self, query_len, hit_len): + self.total_query_tokens.arr[0] += query_len + self.total_hit_tokens.arr[0] += hit_len + if log_time_ready("radix_cache_hit_rate", time_count=10): + current_total_query = self.total_query_tokens.arr[0] + current_total_hit = self.total_hit_tokens.arr[0] + cumulative_hit_rate = current_total_hit / current_total_query if current_total_query > 0 else 0.0 + + label = self.__class__.__name__ + logger.warning( + f"{label} Hit Rate: " + f"Cumulative {cumulative_hit_rate:.2%} ({current_total_hit}/{current_total_query})" + ) + + self.last_log_query_tokens = current_total_query + self.last_log_hit_tokens = current_total_hit + def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: if value is None: value = key - assert len(key) == len(value) # and len(key) >= 1 - if len(key) == 0: - return 0, None + assert len(key) == len(value) and len(key) >= 1 + return self._insert_helper(self.root_node, key, value) def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]: @@ -239,9 +275,13 @@ def match_prefix(self, key, update_refs=False): value = torch.concat(ans_value_list) else: value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) - return tree_node, len(value), value + + matched_len = len(value) + self._inc_hit_rate(len(key), matched_len) + return tree_node, matched_len, value else: self.dec_node_ref_counter(self.root_node) + self._inc_hit_rate(len(key), 0) return None, 0, None def _match_prefix_helper( @@ -358,6 +398,7 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: or 
parent_node.ref_counter != 0 or len(parent_node.children) != 1 or child_node.ref_counter != 0 + or parent_node.buffer_idx is not None ): return None diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 89c46d9ed..ac5c1abee 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -58,7 +58,6 @@ def __init__(self, args: StartArgs): # 判断是否是保守调度,保守调度不会发生暂停 req 的情况,但是有些场景可能影响吞吐 self.is_safe_schedule = args.router_token_ratio == 0.0 self.load_way = args.load_way - self.mode = args.mode self.max_total_token_num = args.max_total_token_num self.shm_req_manager = ShmReqManager() # 用共享内存进行共享,router 模块读取进行精确的调度估计 @@ -155,7 +154,6 @@ async def wait_to_model_ready(self): "weight_dir": self.model_weightdir, "load_way": self.load_way, "max_total_token_num": self.max_total_token_num, - "mode": self.mode, "max_req_num": self.args.running_max_req_size + 8, "max_seq_length": self.args.max_req_total_len + 8, # 留一点余量 "nccl_host": self.args.nccl_host, diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4b8b3c538..57241de96 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -7,10 +7,11 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Callable, Any -from lightllm.common.req_manager import ReqManager +from lightllm.common.req_manager import ReqManager, ReqManagerForMamba from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.common.basemodel.infer_lock import g_infer_state_lock @@ -32,10 +33,13 @@ class InferenceContext: infer_req_ids = None vocab_size = None cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None + mtp_step: int = 0 overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream + use_mamba_model: bool = False + def register( self, backend, @@ -43,6 +47,7 @@ def register( radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int, + use_mamba_model: bool = False, ): self.args = get_env_start_args() from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend @@ -57,6 +62,14 @@ def register( self.infer_req_ids = [] self.vocab_size = vocab_size + + self.use_mamba_model = use_mamba_model + if self.use_mamba_model: + assert self.radix_cache is None or isinstance( + self.radix_cache, HybridRadixCache + ), "Mamba model only support HybridRadixCache" + assert isinstance(self.req_manager, ReqManagerForMamba), "Mamba model only support ReqManagerForMamba" + self.mtp_step = get_env_start_args().mtp_step return def init_cpu_embed_cache_client(self): @@ -73,6 +86,27 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: self.cpu_kv_cache_stream = torch.cuda.Stream() return self.cpu_kv_cache_stream + def _alloc_and_copy_req_buffers(self, req_objs: List["InferReq"]) -> None: + """Allocate and copy buffers for requests. 
Delegates to req_manager which handles model-specific logic.""" + if not req_objs: + return + + if self.radix_cache is not None and hasattr(self.radix_cache, "free_radix_cache_to_get_enough_buffer"): + self.radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs) * (self.mtp_step + 1)) + + request_indices_gpu = torch.tensor([r.req_idx for r in req_objs], device="cuda", dtype=torch.int64) + self.req_manager.alloc_buffer_for_req(request_indices_gpu) + + if self.radix_cache is None: + return + + copy_data = [(r.req_idx, r.shared_kv_node.buffer_idx) for r in req_objs if r.shared_kv_node is not None] + if copy_data: + copy_indices, copy_buffers = zip(*copy_data) + copy_indices_tensor = torch.tensor(copy_indices, device="cuda", dtype=torch.int64) + copy_buffers_tensor = torch.tensor(copy_buffers, device="cuda", dtype=torch.int64) + self.req_manager.copy_buffer_from_another_buffer(copy_buffers_tensor, copy_indices_tensor) + def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: req_objs = [] request_ids = [] @@ -111,9 +145,15 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: slave_req: InferReq = slave_req slave_req.related_master_req = master_req + self._alloc_and_copy_req_buffers(req_objs) + return req_objs def free_a_req_mem(self, free_token_index: List, req: "InferReq"): + # If no KV cache has been allocated yet, there's nothing to free + if req.cur_kv_len == 0: + return + if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) else: @@ -122,7 +162,8 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): # .cpu() 是 流内阻塞操作 value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len, _ = self.radix_cache.insert(key, value) + prefix_len, node = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -130,6 +171,50 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> bool: + # 返回该请求的 mamba buffer 是否需要手动释放 + if req.cur_kv_len == 0: + return True + + if self.radix_cache is None: + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) + else: + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() + + prefix_len, node = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) + if req.shared_kv_node is not None: + assert req.shared_kv_node.node_prefix_total_len <= prefix_len + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + + if len(req.extra_need_to_free_token_index) > 0: + free_token_index.extend(req.extra_need_to_free_token_index) + req.extra_need_to_free_token_index = [] + + if node.buffer_idx is None: + req_to_buffer_index = 
self.req_manager.req_to_buffer_index + buffer_idx = req_to_buffer_index[req.req_idx, 0].item() + self.radix_cache.add_buffer_idx_to_node(node, buffer_idx) + # 该请求的 buffer 已经被插入到 radix cache 中,不需要手动释放 + return False + return True + + def _free_req_mem_and_buffers(self, free_token_index: List, free_buffer_index: List, req: "InferReq"): + """释放请求的 KV cache 和 buffer 内存""" + if self.use_mamba_model: + need_free_base_buffer = self.free_a_req_mem_for_mamba(free_token_index, req) + req_to_buffer_index = self.req_manager.req_to_buffer_index + if need_free_base_buffer: + free_buffer_index.extend(req_to_buffer_index[req.req_idx, :].tolist()) + elif self.mtp_step > 0: + free_buffer_index.extend(req_to_buffer_index[req.req_idx, 1:].tolist()) + else: + self.free_a_req_mem(free_token_index, req) + def _save_promptcache_kvbuffer(self): """ save prompt cache kv buffer @@ -151,19 +236,23 @@ def _filter(self, finished_request_ids: List[int]): free_req_index = [] free_token_index = [] + free_buffer_index = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() - self.free_a_req_mem(free_token_index, req) - + self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req) free_req_index.append(req.req_idx) # logger.info(f"infer release req id {req.shm_req.request_id}") req.shm_req.shm_infer_released = True self.shm_req_manager.put_back_req_obj(req.shm_req) - free_token_index = custom_cat(free_token_index) - self.req_manager.free(free_req_index, free_token_index) + if len(free_token_index) != 0: + free_token_index = custom_cat(free_token_index) + self.req_manager.free(free_req_index, free_token_index) + + if self.use_mamba_model and len(free_buffer_index) != 0: + self.req_manager.free_buffer(free_buffer_index) finished_req_ids_set = set(finished_request_ids) self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set] @@ -191,12 +280,15 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): if pause_reqs: g_infer_state_lock.acquire() + pause_req_indices = [] free_token_index = [] + free_buffer_index = [] for req in pause_reqs: + pause_req_indices.append(req.req_idx) if self.args.diverse_mode: # 发生暂停的时候,需要清除 diverse 模式下的主从关系 req.clear_master_slave_state() - self.free_a_req_mem(free_token_index, req) + self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req) req.cur_kv_len = 0 req.shm_req.shm_cur_kv_len = req.cur_kv_len assert req.wait_pause is True @@ -209,13 +301,16 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): free_token_index = custom_cat(free_token_index) self.req_manager.free_token(free_token_index) + if self.use_mamba_model and len(free_buffer_index) != 0: + self.req_manager.free_buffer(free_buffer_index) + g_infer_state_lock.release() return self def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int): if paused_reqs: g_infer_state_lock.acquire() - + revovered_reqs = [] for req in paused_reqs: prefill_need_token_num = req.get_cur_total_len() if prefill_need_token_num > can_alloc_token_num: @@ -226,7 +321,9 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo if is_master_in_dp: req.shm_req.is_paused = False can_alloc_token_num -= prefill_need_token_num + revovered_reqs.append(req) + self._alloc_and_copy_req_buffers(revovered_reqs) g_infer_state_lock.release() return @@ -351,6 +448,11 @@ def __init__( 
self.nixl_pd_task_failed_num: int = 0 self.nixl_trans_device_id: int = -1 + # 在开启radix cache的情况下,用于标记命中情况,用于插入算法 + self.mamba_model_match_len = 0 + self.mamba_buffer_insert_len = 0 + self.extra_need_to_free_token_index = [] + # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED @@ -402,7 +504,7 @@ def _match_radix_cache(self): input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 - share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) + share_node, miss_prefix_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) if share_node is not None: self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len @@ -411,6 +513,13 @@ def _match_radix_cache(self): self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 + if g_infer_context.use_mamba_model: + MAMBA_PREFILL_BLOCK_SIZE = 128 + MAMBA_MIN_INSERT_LEN = 1024 + miss_prefix_len = miss_prefix_len - miss_prefix_len % MAMBA_PREFILL_BLOCK_SIZE + if miss_prefix_len > MAMBA_MIN_INSERT_LEN: + self.mamba_buffer_insert_len = miss_prefix_len + self.shm_req.shm_cur_kv_len = self.cur_kv_len return @@ -458,13 +567,18 @@ def get_input_token_ids(self): return self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] def get_chuncked_input_token_ids(self): - chunked_start = self.cur_kv_len - chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) + # 复用 get_chuncked_input_token_len 的逻辑,保持一致性 + chunked_end = self.get_chuncked_input_token_len() return self.shm_req.shm_prompt_ids.arr[0:chunked_end] def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) + + if self.mamba_buffer_insert_len > 0: + chunked_end = min(self.get_cur_total_len(), chunked_start + self.mamba_buffer_insert_len) + self.mamba_buffer_insert_len = 0 + return chunked_end def set_next_gen_token_id(self, next_token_id: int, logprob: float, output_len: int): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 92653bc0c..eb519f06c 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -9,7 +9,6 @@ from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger from lightllm.models import get_model -from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -37,6 +36,7 @@ from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel from lightllm.models.mistral_mtp.model import MistralMTPModel +from lightllm.models.qwen3next_mtp.model import Qwen3NextMTPModel from 
lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet @@ -88,7 +88,6 @@ def init_model(self, kvargs): # dp_size_in_node 计算兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.dp_size_in_node = max(1, self.dp_size // self.nnodes) self.load_way = kvargs["load_way"] - self.mode = kvargs["mode"] self.disable_chunked_prefill = self.args.disable_chunked_prefill self.chunked_prefill_size = self.args.chunked_prefill_size self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs @@ -148,7 +147,6 @@ def init_model(self, kvargs): "weight_dir": self.weight_dir, "max_total_token_num": max_total_token_num, "load_way": self.load_way, - "mode": self.mode, "max_req_num": kvargs.get("max_req_num", 1000), "max_seq_length": kvargs.get("max_seq_length", 1024 * 5), "is_token_healing": kvargs.get("is_token_healing", False), @@ -168,12 +166,16 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + + self.use_buffer_manager = getattr(self.model, "use_buffer_manager", False) + + radix_cache_class = self.model.get_radix_cache_class() self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, - mem_manager=self.model.mem_manager, + kv_cache_mem_manager=self.model.mem_manager, ) if self.use_dynamic_prompt_cache else None @@ -185,12 +187,18 @@ def init_model(self, kvargs): self.logger.info(f"loaded model class {self.model.__class__}") + # Check if the model uses Mamba (linear attention) layers + from lightllm.common.req_manager import ReqManagerForMamba + + use_mamba_model = isinstance(self.model.req_manager, ReqManagerForMamba) + g_infer_context.register( backend=self, req_manager=self.model.req_manager, radix_cache=self.radix_cache, shm_req_manager=self.shm_req_manager, vocab_size=self.model.vocab_size, + use_mamba_model=use_mamba_model, ) # 初始化 dp 模式使用的通信 tensor, 对于非dp模式,不会使用到 @@ -283,26 +291,37 @@ def decode(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): raise NotImplementedError() def init_mtp_draft_model(self, main_kvargs: dict): - # 当前只支持 deepseekv3 模式的 mtp + # Support deepseekv3 and qwen3_next MTP modes self.mtp_step = self.args.mtp_step - self.draft_models: List[Deepseek3MTPModel] = [] + self.draft_models = [] os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: + if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att", "deepseekv3_vanilla", "qwen3next_vanilla"]: num_mtp_modules = self.args.mtp_step - elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: + elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att", "deepseekv3_eagle", "qwen3next_eagle"]: num_mtp_modules = 1 else: assert False, f"error mtp mode {self.args.mtp_mode}" for i in range(num_mtp_modules): + # Get MTP model config first to calculate mem_layer_start mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) + + # Calculate mem_layer_start: main model layers + previous MTP model layers + # For models with integrated MTP (like qwen3_next), each MTP module has 1 layer + # For models with separate MTP configs, use the config's num_hidden_layers + model_type = mtp_model_cfg.get("model_type", "") + if model_type == "qwen3_next": + # Qwen3Next 
has integrated MTP with 1 layer per module + mtp_layers_per_module = 1 + else: + mtp_layers_per_module = mtp_model_cfg["num_hidden_layers"] + mem_layer_start = self.model.config["num_hidden_layers"] + i * mtp_layers_per_module mtp_model_kvargs = { "weight_dir": self.args.mtp_draft_model_dir[i], "max_total_token_num": self.model.mem_manager.size, "load_way": main_kvargs["load_way"], - "mode": main_kvargs["mode"], "max_req_num": main_kvargs.get("max_req_num", 1000), "max_seq_length": main_kvargs.get("max_seq_length", 1024 * 5), "is_token_healing": False, @@ -319,20 +338,24 @@ def init_mtp_draft_model(self, main_kvargs: dict): "run_mode": "normal", "main_model": self.model, "mtp_previous_draft_models": self.draft_models.copy(), + "mem_layer_start": mem_layer_start, + "mtp_index": i, } - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) - if mtp_model_cfg["model_type"] == "deepseek_v3": - assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + # Select MTP model class based on model type + model_type = mtp_model_cfg.get("model_type", "") + if model_type == "deepseek_v3": self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) - elif mtp_model_cfg["model_type"] == "qwen3_moe": + elif model_type == "qwen3_moe": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) - elif mtp_model_cfg["model_type"] == "mistral": + elif model_type == "mistral": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) + elif model_type == "qwen3_next": + self.draft_models.append(Qwen3NextMTPModel(mtp_model_kvargs)) else: - assert False, f"error mtp mode {mtp_model_cfg['model_type']}" + raise ValueError(f"Unsupported MTP model type: {model_type}") self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index f3450261b..970db7a0c 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -24,6 +24,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from .control_state import ControlState logger = init_logger(__name__) @@ -50,6 +51,14 @@ def __init__(self) -> None: self.classed_req_strict_prefill = False return + def _maybe_insert_hybrid_radix_cache(self, run_reqs: List[InferReq]): + # Insert hybrid radix cache entries if applicable, use for hybrid attention models. 
+ if self.use_buffer_manager and self.radix_cache is not None: + torch.cuda.synchronize() + g_infer_state_lock.acquire() + self.radix_cache.insert_for_hybrid_radix_cache(run_reqs) + g_infer_state_lock.release() + def infer_loop(self): torch.cuda.set_device(get_current_device_id()) try: @@ -136,6 +145,9 @@ def prefill_normal( extra_post_req_handle_func=self.extra_post_req_handle_func, nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) + + self._maybe_insert_hybrid_radix_cache(run_reqs) + # 第四阶段 event_pack.notify_pre_post_handle() return @@ -219,6 +231,8 @@ def prefill_mtp( nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) + self._maybe_insert_hybrid_radix_cache(run_reqs) + # 第四阶段 event_pack.notify_pre_post_handle() return @@ -258,6 +272,24 @@ def decode_mtp( key="mtp_accept_len", gpu_tensor=mtp_accept_len, ) + + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 (accept_len == 1 means buffer[0] is already correct) + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = model_input.b_req_idx[b_req_mtp_start_loc[mask]] + # Source: the accepted buffer (at index accept_len - 1) + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + # Destination: buffer[0] for each request + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + # P2P copy both conv_states and ssm_states + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() @@ -298,7 +330,7 @@ def decode_mtp( need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0) self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu) - select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu") + select_mask = accepted_index_cpu.to(dtype=torch.bool) self._post_handle( run_reqs=verify_ok_reqs, next_token_ids=next_token_ids_cpu[select_mask], @@ -399,7 +431,7 @@ def _draft_decode_eagle( draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 - draft_model_input.max_len_in_batch += 1 + draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] draft_model_input.mem_indexes = torch.cat( [draft_model_input.mem_indexes.view(-1, self.mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index df10a6d4e..c5dd76822 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -454,6 +454,20 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): gpu_tensor=mtp_accept_len, ) + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + dst_buffer_indexes = 
g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() @@ -591,7 +605,7 @@ def _draft_decode_eagle( draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) # update the meta info of the inference draft_model_input.b_seq_len += 1 - draft_model_input.max_len_in_batch += 1 + draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * real_req_num : (_step + 1) * real_req_num] eagle_mem_indexes_i = F.pad( input=eagle_mem_indexes_i, @@ -767,6 +781,20 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf ) all_next_token_ids.append(next_token_ids) + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() @@ -955,7 +983,7 @@ def _draft_decode_eagle_overlap( ) draft_model_input0.b_seq_len += 1 - draft_model_input0.max_len_in_batch += 1 + draft_model_input0.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes0[_step * real_req_num0 : (_step + 1) * real_req_num0] eagle_mem_indexes_i = F.pad( input=eagle_mem_indexes_i, @@ -969,7 +997,7 @@ def _draft_decode_eagle_overlap( ).view(-1) draft_model_input1.b_seq_len += 1 - draft_model_input1.max_len_in_batch += 1 + draft_model_input1.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes1[_step * real_req_num1 : (_step + 1) * real_req_num1] eagle_mem_indexes_i = F.pad( input=eagle_mem_indexes_i, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 6465995c4..03ac4cfb0 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -73,7 +73,6 @@ def padded_prepare_prefill_inputs( max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) max_q_seq_len = max(b_q_seq_len) - max_len_in_batch = max(b_q_seq_len) input_ids = np.concatenate(input_ids, dtype=np.int64) input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cpu") @@ -102,7 +101,6 @@ def padded_prepare_prefill_inputs( model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, max_cache_len=max_cache_len, @@ -150,6 +148,7 @@ def padded_prepare_decode_inputs( seq_len = req.get_cur_total_len() assert req.cur_kv_len == seq_len - 1 b_seq_len.append(seq_len) + b_q_seq_len.append(1) total_token_num += seq_len b_mtp_index.append(0) batch_multimodal_params.append(req.multimodal_params) @@ -160,32 +159,30 @@ def padded_prepare_decode_inputs( total_token_num += seq_len 
b_req_idx.append(req.req_idx) b_seq_len.append(seq_len) + b_q_seq_len.append(1) b_mtp_index.append(step + 1) batch_multimodal_params.append(req.multimodal_params) - b_q_seq_len.append(req.mtp_step + 1) - # padding fake req for decode for _ in range(padded_req_num): seq_len = 2 total_token_num += seq_len b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) b_seq_len.append(seq_len) + b_q_seq_len.append(1) b_mtp_index.append(0) batch_multimodal_params.append({"images": [], "audios": []}) for step in range(args_mtp_step): seq_len += 1 total_token_num += seq_len b_seq_len.append(seq_len) + b_q_seq_len.append(1) b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) b_mtp_index.append(step + 1) batch_multimodal_params.append({"images": [], "audios": []}) - b_q_seq_len.append(1 + args_mtp_step) - max_kv_seq_len = max(b_seq_len) max_q_seq_len = max(b_q_seq_len) - max_len_in_batch = max(b_seq_len) b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") @@ -210,7 +207,6 @@ def padded_prepare_decode_inputs( model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, input_ids=None, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index bdb36054b..4eb8c7e1e 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -50,7 +50,6 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) - max_len_in_batch = max(b_q_seq_len) max_q_seq_len = max(b_q_seq_len) input_ids = np.concatenate(input_ids, dtype=np.int64) @@ -72,7 +71,6 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, max_cache_len=max_cache_len, @@ -95,7 +93,6 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]: run_reqs: List[InferReq] = [] total_token_num = 0 - max_len_in_batch = 0 b_req_idx = [] b_mtp_index = [] b_seq_len = [] @@ -107,8 +104,8 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In seq_len = req.get_cur_total_len() assert req.cur_kv_len == seq_len - 1, f"{req.cur_kv_len} {seq_len}" b_seq_len.append(seq_len) + b_q_seq_len.append(1) total_token_num += seq_len - max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. 
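# --- Editor's illustrative sketch (not part of the patch) -------------------
# The hunks above rework decode-input preparation so that every row in the
# flattened batch contributes exactly one query token (b_q_seq_len == 1) and
# records its MTP draft step in b_mtp_index, while max_len_in_batch is dropped
# in favour of max_q_seq_len / max_kv_seq_len. The standalone sketch below
# mirrors that layout; the request indices and lengths are made up, and this is
# not code from the repository.
mtp_step = 2
reqs = [(0, 17), (1, 9)]  # hypothetical (req_idx, current total length) pairs

b_req_idx, b_seq_len, b_q_seq_len, b_mtp_index = [], [], [], []
total_token_num = 0
for req_idx, seq_len in reqs:
    # main decode position of the request
    b_req_idx.append(req_idx)
    b_seq_len.append(seq_len)
    b_q_seq_len.append(1)
    b_mtp_index.append(0)
    total_token_num += seq_len
    # one extra row per MTP draft step, again with a single query token
    for step in range(mtp_step):
        seq_len += 1
        b_req_idx.append(req_idx)
        b_seq_len.append(seq_len)
        b_q_seq_len.append(1)
        b_mtp_index.append(step + 1)
        total_token_num += seq_len

max_q_seq_len = max(b_q_seq_len)  # always 1 for decode under this layout
max_kv_seq_len = max(b_seq_len)   # replaces the removed max_len_in_batch
assert len(b_req_idx) == len(reqs) * (mtp_step + 1)
# ----------------------------------------------------------------------------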
@@ -118,10 +115,9 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In seq_len += 1 b_seq_len.append(seq_len) total_token_num += seq_len - max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(step + 1) multimodal_params.append(req.multimodal_params) - b_q_seq_len.append(req.mtp_step + 1) + b_q_seq_len.append(1) max_kv_seq_len = max(b_seq_len) max_q_seq_len = max(b_q_seq_len) @@ -146,7 +142,6 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, input_ids=None, diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index cd48a355b..09d7a680f 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -81,11 +81,14 @@ def calcu_kernel_best_vsm_count(kernel, num_warps): return num_sm +@lru_cache(maxsize=1) +def is_musa(): + return hasattr(torch.version, "musa") and torch.version.musa is not None + + @lru_cache(maxsize=None) def get_current_device_name(): - import torch - - if torch.cuda.is_available(): + if torch.cuda.is_available() or is_musa(): device = torch.cuda.current_device() gpu_name = torch.cuda.get_device_name(device) # 4090 trans to 4090 D @@ -103,8 +106,6 @@ def init_p2p(device_index): """ torch 调用跨卡的to操作后,triton编译的算子便能自动操作跨卡tensor。 """ - import torch - num_gpus = torch.cuda.device_count() tensor = torch.zeros((1,)) tensor = tensor.to(f"cuda:{device_index}") @@ -127,8 +128,26 @@ def has_nvlink(): result = result.decode("utf-8") # Check if the output contains 'NVLink' return any(f"NV{i}" in result for i in range(1, 8)) + except FileNotFoundError: + # nvidia-smi is not installed, assume no NVLink + return False + except subprocess.CalledProcessError: + # If there's an error while executing nvidia-smi, assume no NVLink + return False + + +def has_mtlink(): + try: + # Call mthreads-gmi to get the topology matrix + result = subprocess.check_output(["mthreads-gmi", "topo", "--matrix"]) + result = result.decode("utf-8") + # Check if the output contains 'MTLink' + return any(f"MT{i}" in result for i in range(1, 8)) + except FileNotFoundError: + # mthreads-gmi is not installed, assume no MTLink + return False except subprocess.CalledProcessError: - # If there's an error (e.g., nvidia-smi is not installed or another issue), assume no NVLink + # If there's an error while executing mthreads-gmi, assume no MTLink return False diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index b5822a342..0a70f1dfa 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -26,10 +26,16 @@ def get_unique_server_name(): def set_cuda_arch(args): if not torch.cuda.is_available(): return - if args.enable_flashinfer_prefill or args.enable_flashinfer_decode: + + from lightllm.server.core.objs.start_args_type import StartArgs + + args: StartArgs = args + + if "flashinfer" in args.llm_prefill_att_backend or "flashinfer" in args.llm_decode_att_backend: capability = torch.cuda.get_device_capability() arch = f"{capability[0]}.{capability[1]}" os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" + return def set_env_start_args(args): @@ -209,7 +215,7 @@ def get_diverse_max_batch_shared_group_size() -> int: @lru_cache(maxsize=None) def enable_diverse_mode_gqa_decode_fast_kernel() -> bool: - return get_env_start_args().diverse_mode and 
"ppl_int8kv_flashdecoding_diverse" in get_env_start_args().mode + return get_env_start_args().diverse_mode and "int8kv" == get_env_start_args().llm_kv_type @lru_cache(maxsize=None) diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 4875b4eee..3256fdd1f 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -19,13 +19,11 @@ from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.kv_cache_mem_manager import ( MemoryManager, - INT8KVMemoryManager, CalibrationFP8KVMemoryManager, ExportCalibrationMemoryManager, PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, - Deepseek2FP8KVMemoryManager, ) from typing import List, Tuple, Optional @@ -77,17 +75,6 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": scale_head_dim=0, scale_data_type=get_llm_data_type(), ) - elif mem_manager_class is Deepseek2FP8KVMemoryManager: - cpu_cache_meta = CpuKVCacheMeta( - page_num=0, - token_page_size=args.cpu_cache_token_page_size, - layer_num=get_layer_num(args.model_dir), - num_heads=1, - head_dim=512 + 64 + 2, - data_type=torch.uint8, - scale_head_dim=0, - scale_data_type=get_llm_data_type(), - ) elif mem_manager_class is MemoryManager: cpu_cache_meta = CpuKVCacheMeta( page_num=0, diff --git a/test/acc/test_deepseekr1.sh b/test/acc/test_deepseekr1.sh index e167303a3..180d2d4e2 100644 --- a/test/acc/test_deepseekr1.sh +++ b/test/acc/test_deepseekr1.sh @@ -1,4 +1,4 @@ -LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --enable_fa3 +LOADWORKER=18 python -m lightllm.server.api_server --batch_max_tokens 6000 --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 diff --git a/test/acc/test_deepseekr1_mtp.sh b/test/acc/test_deepseekr1_mtp.sh index 046314a72..7eaffd499 100644 --- a/test/acc/test_deepseekr1_mtp.sh +++ b/test/acc/test_deepseekr1_mtp.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --mem_fraction 0.75 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --mem_fraction 0.75 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh index 2ea5f7438..0467f76e6 100644 --- a/test/acc/test_deepseekr1_mtp_ep.sh +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir 
/mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_qwen2.sh b/test/acc/test_qwen2.sh index 265d679e8..bb5603b5b 100644 --- a/test/acc/test_qwen2.sh +++ b/test/acc/test_qwen2.sh @@ -1,5 +1,5 @@ # first -LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d --tp 2 --port 8089 --enable_fa3 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d --tp 2 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # second HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-Math-7B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_qwen3.sh b/test/acc/test_qwen3.sh index c0da5ec96..36a3c9680 100644 --- a/test/acc/test_qwen3.sh +++ b/test/acc/test_qwen3.sh @@ -1,5 +1,5 @@ # first -LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --enable_fa3 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # second HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/benchmark/qwen3next/benchmark_gdn_decode.py b/test/benchmark/qwen3next/benchmark_gdn_decode.py new file mode 100644 index 000000000..e68d79397 --- /dev/null +++ b/test/benchmark/qwen3next/benchmark_gdn_decode.py @@ -0,0 +1,140 @@ +"""Benchmark script for Qwen3Next GDN decode performance.""" + +import torch +import time +from typing import Callable + + +def benchmark_kernel( + fn: Callable, + warmup: int = 10, + iterations: int = 100, + sync: bool = True, +) -> float: + """Benchmark a kernel function.""" + # Warmup + for _ in range(warmup): + fn() + if sync: + torch.cuda.synchronize() + + # Benchmark + if sync: + torch.cuda.synchronize() + start = time.perf_counter() + + for _ in range(iterations): + fn() + if sync: + torch.cuda.synchronize() + + elapsed = time.perf_counter() - start + return elapsed / iterations * 1000 # ms + + +def main(): + """Run benchmarks.""" + if not torch.cuda.is_available(): + print("CUDA not available. 
This benchmark requires a GPU.") + return + + print("Qwen3Next GDN Decode Benchmarks") + print("=" * 50) + + # Test parameters matching real model + batch_size = 32 + mtp_size = 2 + total_tokens = batch_size * mtp_size + dim = 384 # typical qkv dim + num_heads = 8 + + device = "cuda" + dtype = torch.bfloat16 + + # Test data + mixed_qkv = torch.randn(total_tokens, dim, device=device, dtype=dtype) + + # Benchmark: strided slice + contiguous + def bench_contiguous(): + for step in range(mtp_size): + _ = mixed_qkv[step::mtp_size].contiguous() + + # Benchmark: copy to pre-allocated buffer + work_buffer = torch.empty(batch_size, dim, device=device, dtype=dtype) + + def bench_copy_to_buffer(): + for step in range(mtp_size): + work_buffer.copy_(mixed_qkv[step::mtp_size]) + + time_contiguous = benchmark_kernel(bench_contiguous) + time_copy = benchmark_kernel(bench_copy_to_buffer) + + print(f"\n1. MTP Decode Buffer Strategy:") + print(f" Strided .contiguous(): {time_contiguous:.3f} ms") + print(f" Copy to buffer: {time_copy:.3f} ms") + print(f" Speedup: {time_contiguous / time_copy:.2f}x") + + # Benchmark: torch.cat elimination + print(f"\n2. QKV Concatenation:") + q = torch.randn(batch_size, dim // 3, device=device, dtype=dtype) + k = torch.randn(batch_size, dim // 3, device=device, dtype=dtype) + v = torch.randn(batch_size, dim // 3, device=device, dtype=dtype) + + def bench_torch_cat(): + return torch.cat([q, k, v], dim=-1) + + # Pre-concatenated (simulating the optimization) + qkv_pre = torch.empty(batch_size, dim, device=device, dtype=dtype) + + def bench_pre_concat(): + qkv_pre[:, : dim // 3] = q + qkv_pre[:, dim // 3 : 2 * dim // 3] = k + qkv_pre[:, 2 * dim // 3 :] = v + return qkv_pre + + time_cat = benchmark_kernel(bench_torch_cat) + time_pre = benchmark_kernel(bench_pre_concat) + + print(f" torch.cat(): {time_cat:.3f} ms") + print(f" Pre-allocated: {time_pre:.3f} ms") + print(f" Speedup: {time_cat / time_pre:.2f}x") + + # Benchmark: Fused gating kernel + print(f"\n3. 
Fused Gating Kernel:") + try: + from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import ( + fused_gdn_gating, + ) + from lightllm.models.qwen3next.triton_kernel.fused_qkv_gating import ( + fused_gdn_gating_v2, + ) + + a = torch.randn(batch_size, num_heads, device=device, dtype=dtype) + b = torch.randn(batch_size, num_heads, device=device, dtype=dtype) + A_log = torch.randn(num_heads, device=device, dtype=torch.float32) + dt_bias = torch.randn(num_heads, device=device, dtype=torch.float32) + + def bench_original_gating(): + return fused_gdn_gating(A_log, a, b, dt_bias) + + g_out = torch.empty(1, batch_size, num_heads, device=device, dtype=torch.float32) + beta_out = torch.empty(1, batch_size, num_heads, device=device, dtype=torch.float32) + + def bench_v2_gating(): + return fused_gdn_gating_v2(a, b, A_log, dt_bias, g_out, beta_out) + + time_orig = benchmark_kernel(bench_original_gating) + time_v2 = benchmark_kernel(bench_v2_gating) + + print(f" Original (allocates): {time_orig:.3f} ms") + print(f" V2 (pre-alloc): {time_v2:.3f} ms") + print(f" Speedup: {time_orig / time_v2:.2f}x") + except ImportError as e: + print(f" Skipped (import error): {e}") + + print(f"\n" + "=" * 50) + print("Benchmark complete.") + + +if __name__ == "__main__": + main() diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 3fc7ee4b4..7f1c2b493 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -41,7 +41,10 @@ def test_model_inference(args): "run_mode": "normal", "max_seq_length": args.max_req_total_len, "disable_cudagraph": args.disable_cudagraph, - "mode": args.mode, + "llm_prefill_att_backend": args.llm_prefill_att_backend, + "llm_decode_att_backend": args.llm_decode_att_backend, + "llm_kv_type": args.llm_kv_type, + "llm_kv_quant_group_size": args.llm_kv_quant_group_size, } proc = multiprocessing.Process( target=tppart_model_infer, @@ -73,7 +76,6 @@ def overlap_prefill( ): _0_batch_size = batch_size // 2 _0_total_token_num = total_token_num // 2 - _0_max_len_in_batch = max_len_in_batch _0_input_ids = input_ids[: total_token_num // 2] _0_mem_indexes = mem_indexes[: total_token_num // 2] _0_b_req_idx = b_req_idx[: batch_size // 2] @@ -83,7 +85,6 @@ def overlap_prefill( micro_batch1 = ModelInput( batch_size=_0_batch_size, total_token_num=_0_total_token_num, - max_len_in_batch=_0_max_len_in_batch, input_ids=_0_input_ids, b_req_idx=_0_b_req_idx, b_mtp_index=_0_b_mtp_index, @@ -96,7 +97,6 @@ def overlap_prefill( _1_batch_size = batch_size - batch_size // 2 _1_total_token_num = total_token_num - total_token_num // 2 - _1_max_len_in_batch = max_len_in_batch _1_input_ids = input_ids[total_token_num // 2 :] _1_mem_indexes = mem_indexes[total_token_num // 2 :] _1_b_req_idx = b_req_idx[batch_size // 2 :] @@ -107,7 +107,6 @@ def overlap_prefill( micro_batch2 = ModelInput( batch_size=_1_batch_size, total_token_num=_1_total_token_num, - max_len_in_batch=_1_max_len_in_batch, input_ids=_1_input_ids, b_req_idx=_1_b_req_idx, b_mtp_index=_1_b_mtp_index, @@ -129,7 +128,6 @@ def overlap_decode( ): _0_batch_size = batch_size // 2 _0_total_token_num = total_token_num // 2 - _0_max_len_in_batch = max_len_in_batch _0_input_ids = input_ids[: batch_size // 2] _0_mem_indexes = mem_indexes[: batch_size // 2] _0_b_req_idx = b_req_idx[: batch_size // 2] @@ -138,7 +136,6 @@ def overlap_decode( micro_batch1 = ModelInput( batch_size=_0_batch_size, total_token_num=_0_total_token_num, - 
max_len_in_batch=_0_max_len_in_batch, input_ids=_0_input_ids, b_req_idx=_0_b_req_idx, b_mtp_index=_0_b_mtp_index, @@ -149,7 +146,6 @@ def overlap_decode( _1_batch_size = batch_size - batch_size // 2 _1_total_token_num = total_token_num - total_token_num // 2 - _1_max_len_in_batch = max_len_in_batch _1_input_ids = input_ids[batch_size // 2 :] _1_mem_indexes = mem_indexes[batch_size // 2 :] _1_b_req_idx = b_req_idx[batch_size // 2 :] @@ -159,7 +155,6 @@ def overlap_decode( micro_batch2 = ModelInput( batch_size=_1_batch_size, total_token_num=_1_total_token_num, - max_len_in_batch=_1_max_len_in_batch, input_ids=_1_input_ids, b_req_idx=_1_b_req_idx, b_mtp_index=_1_b_mtp_index, @@ -191,7 +186,6 @@ def prefill( model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_len_in_batch, max_kv_seq_len=max_len_in_batch, max_cache_len=0, @@ -217,7 +211,6 @@ def decode( model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 942af0f88..07ad52a13 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -129,7 +129,6 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=input_len, input_ids=test_data, mem_indexes=mem_indexes, b_req_idx=b_req_idx, @@ -197,7 +196,6 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ model_input = ModelInput( batch_size=batch_size * (len(draft_models) + 1), total_token_num=nopad_total_token_num, - max_len_in_batch=nopad_max_len_in_batch, input_ids=decode_input_ids, mem_indexes=mem_indexes, b_req_idx=nopad_b_seq_idx, diff --git a/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py b/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py index d391d3065..f32b09344 100644 --- a/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py +++ b/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py @@ -4,7 +4,7 @@ import torch.multiprocessing as mp from typing import List from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage1 import ( +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage2 import ( flash_decode_stage1, GQADiverseDecodeStage1KernelConfig, ) diff --git a/test/kernel/llama_gqa_diverse_decode_stage2_tuning.py b/test/kernel/llama_gqa_diverse_decode_stage2_tuning.py new file mode 100644 index 000000000..13c8945e5 --- /dev/null +++ b/test/kernel/llama_gqa_diverse_decode_stage2_tuning.py @@ -0,0 +1,296 @@ +import torch +import os +import torch.multiprocessing as mp +from typing import List +from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage2 import ( + flash_decode_stage2, + GQADiverseDecodeStage2KernelConfig, +) +from lightllm.utils.watchdog_utils import Watchdog + +logger = init_logger(__name__) + + +def set_seed(): + import torch + import random + import numpy as np + + seed = 42 + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + if torch.cuda.is_available(): + 
torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + return + + +@torch.no_grad() +def test_decode_attentions( + block_seq: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + test_count: int = 20, + **run_config, +): + set_seed() + shared_seq_len = 0 + num_heads = 32 + kv_head_num = 8 + head_dim = 128 + max_len_in_batch = 8192 + quant_group_size = 8 + + args = [] + for _ in range(test_count): + q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=dtype, device="cuda") / 10 + kv_shape = (batch_size * seq_len, kv_head_num, head_dim) + kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) + k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + k_scale = torch.ones(size=kv_scale_shape, dtype=dtype, device="cuda") + v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + v_scale = torch.ones(size=kv_scale_shape, dtype=dtype, device="cuda") + Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, seq_len + ) + B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda") + mid_out = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2, head_dim), + dtype=q.dtype, + device="cuda", + ) + mid_out_logsumexp = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2), + dtype=q.dtype, + device="cuda", + ) + arg_list, kwargs = ( + q, + k, + k_scale, + v, + v_scale, + Req_to_tokens, + B_req_idx, + b_seq_len, + b_shared_seq_len, + max_len_in_batch, + mid_out, + mid_out_logsumexp, + block_seq, + ), dict(run_config=run_config) + args.append((arg_list, kwargs)) + + graph = torch.cuda.CUDAGraph() + arg_list, kwargs = args[0] + flash_decode_stage2(*arg_list, **kwargs) + with torch.cuda.graph(graph): + for index in range(test_count): + arg_list, kwargs = args[index] + flash_decode_stage2(*arg_list, **kwargs) + + graph.replay() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + + cost_time = start_event.elapsed_time(end_event=end_event) + + logger.info(f"bf16 {seq_len} cost time: {cost_time} ms") + return cost_time + + +def worker( + block_seq: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + test_count: int, + test_configs, + queue, +): + dog = Watchdog(timeout=10) + dog.start() + + try: + for index in range(len(test_configs)): + tuning_config = test_configs[index] + cost_time = test_decode_attentions( + block_seq=block_seq, + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + test_count=test_count, + **tuning_config, + ) + dog.heartbeat() + queue.put(cost_time) + except Exception as ex: + logger.error(str(ex) + f" config {tuning_config} batch_size {batch_size} seq_len {seq_len} dtype {dtype}") + import sys + import traceback + + traceback.print_exc() + sys.exit(-1) + pass + + +def get_test_configs(split_id, split_count): + index = 0 + for block_n in [16, 32, 64]: + for num_warps in [ + 2, + 4, + 8, + 16, + ]: + for num_stages in [ + 1, + 2, + 3, + 4, + 5, + 7, + 9, + 10, + 11, + ]: + t_config = { + "BLOCK_N": block_n, + "num_warps": num_warps, + "num_stages": num_stages, + } + if index % split_count == split_id: + yield t_config + 
index += 1 + else: + index += 1 + + +def tuning_configs( + device_id: int, + device_count: int, + block_seq: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + test_count: int, +): + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) + best_config, best_cost_time = None, 10000000 + queue = mp.Queue() + test_configs = [] + for t_config in get_test_configs(device_id, device_count): + test_configs.append(t_config) + if len(test_configs) < 64: + continue + + p = mp.Process( + target=worker, + args=( + block_seq, + batch_size, + seq_len, + dtype, + test_count, + test_configs, + queue, + ), + ) + p.start() + p.join() + + while len(test_configs) != 0: + try: + cost_time = queue.get_nowait() + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") + if cost_time < best_cost_time: + best_config = test_configs[0] + best_cost_time = cost_time + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + except: + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + break + + while len(test_configs) != 0: + p = mp.Process( + target=worker, + args=( + block_seq, + batch_size, + seq_len, + dtype, + test_count, + test_configs, + queue, + ), + ) + p.start() + p.join() + + while len(test_configs) != 0: + try: + cost_time = queue.get_nowait() + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") + if cost_time < best_cost_time: + best_config = test_configs[0] + best_cost_time = cost_time + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + except: + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + break + + logger.info(f"{best_config} best cost: {best_cost_time}") + return best_config, best_cost_time + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + + from lightllm.utils.tuning_utils import mp_tuning + import collections + + block_seq = 256 + batch_sizes = [8, 16, 32, 64] + seq_lens = [32, 64, 128, 256] + num_heads = 32 + kv_head_num = 8 + q_head_dim = 128 + gqa_group_size = num_heads // kv_head_num + + store_json_ans = collections.defaultdict(dict) + + for seq_len in seq_lens: + for batch_size in batch_sizes: + ans = mp_tuning( + tuning_configs, + { + "block_seq": block_seq, + "batch_size": batch_size, + "seq_len": seq_len, + "dtype": torch.bfloat16, + "test_count": 1, + }, + ) + store_json_ans[seq_len][batch_size] = ans + + GQADiverseDecodeStage2KernelConfig.save_config( + gqa_group_size=gqa_group_size, + q_head_dim=q_head_dim, + block_seq=block_seq, + out_dtype=str(torch.bfloat16), + config_json=store_json_ans, + ) diff --git a/test/models/__init__.py b/test/models/__init__.py new file mode 100644 index 000000000..e47ecaea5 --- /dev/null +++ b/test/models/__init__.py @@ -0,0 +1 @@ +# Models test package diff --git a/test/models/qwen3next/__init__.py b/test/models/qwen3next/__init__.py new file mode 100644 index 000000000..dbdf1f4b8 --- /dev/null +++ b/test/models/qwen3next/__init__.py @@ -0,0 +1 @@ +# Qwen3Next model tests diff --git a/test/models/qwen3next/test_fused_qkv_gating.py b/test/models/qwen3next/test_fused_qkv_gating.py new file mode 100644 index 000000000..c0fcb6f9e --- /dev/null +++ b/test/models/qwen3next/test_fused_qkv_gating.py @@ -0,0 +1,161 @@ +"""Tests for fused QKV gating kernel.""" + +import pytest +import torch + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestFusedGDNGating: + """Test fused GDN gating kernel correctness.""" + + def 
test_fused_gating_matches_reference(self): + """Verify fused kernel matches reference implementation.""" + from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating + from lightllm.models.qwen3next.triton_kernel.fused_qkv_gating import fused_gdn_gating_v2 + + batch_size = 32 + num_heads = 8 + + a = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + b = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + A_log = torch.randn(num_heads, device="cuda", dtype=torch.float32) + dt_bias = torch.randn(num_heads, device="cuda", dtype=torch.float32) + + # Reference + g_ref, beta_ref = fused_gdn_gating(A_log, a, b, dt_bias) + + # Fused v2 with pre-allocated outputs + g_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + beta_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + fused_gdn_gating_v2(a, b, A_log, dt_bias, g_out, beta_out) + + # Note: Reference kernel has precision loss (converts to bfloat16 before storing) + # so we use slightly relaxed tolerances + torch.testing.assert_close(g_out, g_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(beta_out, beta_ref, rtol=5e-3, atol=5e-3) + + def test_fused_gating_various_batch_sizes(self): + """Test fused kernel with various batch sizes.""" + from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating + from lightllm.models.qwen3next.triton_kernel.fused_qkv_gating import fused_gdn_gating_v2 + + num_heads = 16 + + for batch_size in [1, 4, 16, 32, 64, 128]: + a = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + b = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + A_log = torch.randn(num_heads, device="cuda", dtype=torch.float32) + dt_bias = torch.randn(num_heads, device="cuda", dtype=torch.float32) + + # Reference + g_ref, beta_ref = fused_gdn_gating(A_log, a, b, dt_bias) + + # Fused v2 with pre-allocated outputs + g_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + beta_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + fused_gdn_gating_v2(a, b, A_log, dt_bias, g_out, beta_out) + + torch.testing.assert_close(g_out, g_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(beta_out, beta_ref, rtol=5e-3, atol=5e-3) + + def test_fused_gating_various_head_counts(self): + """Test fused kernel with various head counts.""" + from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating + from lightllm.models.qwen3next.triton_kernel.fused_qkv_gating import fused_gdn_gating_v2 + + batch_size = 32 + + for num_heads in [4, 8, 16, 32, 64]: + a = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + b = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + A_log = torch.randn(num_heads, device="cuda", dtype=torch.float32) + dt_bias = torch.randn(num_heads, device="cuda", dtype=torch.float32) + + # Reference + g_ref, beta_ref = fused_gdn_gating(A_log, a, b, dt_bias) + + # Fused v2 with pre-allocated outputs + g_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + beta_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + fused_gdn_gating_v2(a, b, A_log, dt_bias, g_out, beta_out) + + torch.testing.assert_close(g_out, g_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(beta_out, beta_ref, rtol=5e-3, atol=5e-3) + + def test_fused_gating_float16_input(self): + """Test fused 
kernel with float16 input tensors.""" + from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating + from lightllm.models.qwen3next.triton_kernel.fused_qkv_gating import fused_gdn_gating_v2 + + batch_size = 32 + num_heads = 8 + + a = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.float16) + b = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.float16) + A_log = torch.randn(num_heads, device="cuda", dtype=torch.float32) + dt_bias = torch.randn(num_heads, device="cuda", dtype=torch.float32) + + # Reference + g_ref, beta_ref = fused_gdn_gating(A_log, a, b, dt_bias) + + # Fused v2 with pre-allocated outputs + g_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + beta_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + fused_gdn_gating_v2(a, b, A_log, dt_bias, g_out, beta_out) + + torch.testing.assert_close(g_out, g_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(beta_out, beta_ref, rtol=5e-3, atol=5e-3) + + def test_fused_gating_returns_same_tensors(self): + """Test that fused_gdn_gating_v2 returns the same tensors passed in.""" + from lightllm.models.qwen3next.triton_kernel.fused_qkv_gating import fused_gdn_gating_v2 + + batch_size = 32 + num_heads = 8 + + a = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + b = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + A_log = torch.randn(num_heads, device="cuda", dtype=torch.float32) + dt_bias = torch.randn(num_heads, device="cuda", dtype=torch.float32) + + # Pre-allocated outputs + g_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + beta_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + + # Capture data_ptr before call + g_ptr = g_out.data_ptr() + beta_ptr = beta_out.data_ptr() + + g_ret, beta_ret = fused_gdn_gating_v2(a, b, A_log, dt_bias, g_out, beta_out) + + # Verify same tensors returned (not new allocations) + assert g_ret.data_ptr() == g_ptr, "g should be same tensor" + assert beta_ret.data_ptr() == beta_ptr, "beta should be same tensor" + assert g_ret is g_out, "Should return same tensor object" + assert beta_ret is beta_out, "Should return same tensor object" + + def test_fused_gating_custom_beta_threshold(self): + """Test fused kernel with custom beta and threshold parameters.""" + from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating + from lightllm.models.qwen3next.triton_kernel.fused_qkv_gating import fused_gdn_gating_v2 + + batch_size = 32 + num_heads = 8 + beta_const = 2.0 + threshold = 10.0 + + a = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + b = torch.randn(batch_size, num_heads, device="cuda", dtype=torch.bfloat16) + A_log = torch.randn(num_heads, device="cuda", dtype=torch.float32) + dt_bias = torch.randn(num_heads, device="cuda", dtype=torch.float32) + + # Reference + g_ref, beta_ref = fused_gdn_gating(A_log, a, b, dt_bias, beta=beta_const, threshold=threshold) + + # Fused v2 with pre-allocated outputs + g_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + beta_out = torch.empty(1, batch_size, num_heads, device="cuda", dtype=torch.float32) + fused_gdn_gating_v2(a, b, A_log, dt_bias, g_out, beta_out, beta_const=beta_const, threshold=threshold) + + torch.testing.assert_close(g_out, g_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(beta_out, beta_ref, rtol=5e-3, atol=5e-3) diff --git 
a/test/models/qwen3next/test_gdn_mtp_decode.py b/test/models/qwen3next/test_gdn_mtp_decode.py new file mode 100644 index 000000000..1ddd1f5fd --- /dev/null +++ b/test/models/qwen3next/test_gdn_mtp_decode.py @@ -0,0 +1,87 @@ +"""Tests for Qwen3Next GDN MTP decode optimization.""" +import pytest +import torch + + +def create_mock_infer_state(batch_size: int, mtp_step: int, device: str = "cuda"): + """Create a mock infer state for testing.""" + mtp_size = mtp_step + 1 + + class MockMemManager: + def get_mamba_cache(self, layer_num): + # conv_states: [num_buffers, dim, conv_width-1] + # ssm_states: [num_buffers, num_heads, key_dim, value_dim] + conv_states = torch.randn(batch_size * mtp_size * 2, 384, 3, device=device, dtype=torch.bfloat16) + ssm_states = torch.randn(batch_size * mtp_size * 2, 8, 128, 64, device=device, dtype=torch.float32) + return conv_states, ssm_states + + class MockInferState: + def __init__(self): + self.mem_manager = MockMemManager() + self.mtp_buffer_idx_list = torch.stack( + [torch.arange(batch_size, device=device, dtype=torch.int32) + i * batch_size for i in range(mtp_size)] + ) + self.b_buffer_idx = self.mtp_buffer_idx_list.flatten() + + return MockInferState() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestMTPDecodeOptimization: + """Test MTP decode memory optimization.""" + + def test_strided_slice_is_not_contiguous(self): + """Verify that strided slices are not contiguous (baseline understanding).""" + mtp_size = 2 + batch_size = 4 + dim = 128 + + # Interleaved layout: [step0_batch0, step0_batch1, ..., step1_batch0, ...] + mixed_qkv = torch.randn(batch_size * mtp_size, dim, device="cuda") + + # Strided slice for step 0 + slice_step0 = mixed_qkv[0::mtp_size] + assert not slice_step0.is_contiguous(), "Strided slice should not be contiguous" + + def test_contiguous_buffer_reuse(self): + """Test that pre-allocated contiguous buffer can be reused across steps.""" + mtp_size = 2 + batch_size = 4 + dim = 128 + + mixed_qkv = torch.randn(batch_size * mtp_size, dim, device="cuda") + work_buffer = torch.empty(batch_size, dim, device="cuda") + + for step_idx in range(mtp_size): + # Copy strided data to contiguous buffer + work_buffer.copy_(mixed_qkv[step_idx::mtp_size]) + assert work_buffer.is_contiguous() + + # Simulate in-place operation + work_buffer.mul_(2.0) + + # Copy back + mixed_qkv[step_idx::mtp_size].copy_(work_buffer) + + # Verify all data was modified + assert torch.allclose(mixed_qkv, mixed_qkv) # Basic sanity check + + def test_output_direct_write_vs_copy(self): + """Test that direct slice assignment works for output tensor.""" + mtp_size = 2 + batch_size = 4 + num_heads = 8 + head_dim = 64 + total_tokens = batch_size * mtp_size + + # Pre-allocated output + core_attn_out = torch.empty(total_tokens, 1, num_heads, head_dim, device="cuda") + + for step_idx in range(mtp_size): + # Simulate kernel output (batch_size, 1, num_heads, head_dim) + step_output = torch.randn(batch_size, 1, num_heads, head_dim, device="cuda") + + # Direct assignment to strided view (this is what we want to verify works) + core_attn_out[step_idx::mtp_size] = step_output + + assert core_attn_out.shape == (total_tokens, 1, num_heads, head_dim) diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md index f5dae19b9..e00af2713 100644 --- a/test/start_scripts/README.md +++ b/test/start_scripts/README.md @@ -108,7 +108,6 @@ sh multi_pd_master/pd_decode.sh - `--model_dir`: Model file path - `--tp`: Tensor parallelism degree - `--dp`: 
Data parallelism degree -- `--enable_fa3`: Enable Flash Attention 3.0 - `--nnodes`: Total number of nodes - `--node_rank`: Current node rank - `--nccl_host`: NCCL communication host address diff --git a/test/start_scripts/draft.sh b/test/start_scripts/draft.sh index 866f5f2fa..04f573cd6 100644 --- a/test/start_scripts/draft.sh +++ b/test/start_scripts/draft.sh @@ -3,7 +3,7 @@ LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /mtc/models/qwen3-8b --tp 2 --dp 1 --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 \ --batch_max_tokens 4096 --chunked_prefill_size 2048 \ --max_total_token_num 20000 \ ---mode "ppl_int8kv_flashdecoding" | tee log.txt +--llm_kv_type int8kv | tee log.txt # 精度评测命令 @@ -16,7 +16,7 @@ HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions \ LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /mtc/DeepSeek-R1 \ --tp 8 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --batch_max_tokens 4096 --chunked_prefill_size 2048 \ --max_total_token_num 20000 \ --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 diff --git a/test/start_scripts/multi_node_ep_node0.sh b/test/start_scripts/multi_node_ep_node0.sh index 3a139968a..68f80b39d 100644 --- a/test/start_scripts/multi_node_ep_node0.sh +++ b/test/start_scripts/multi_node_ep_node0.sh @@ -6,7 +6,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_node_ep_node1.sh b/test/start_scripts/multi_node_ep_node1.sh index b24a59868..10aee8528 100644 --- a/test/start_scripts/multi_node_ep_node1.sh +++ b/test/start_scripts/multi_node_ep_node1.sh @@ -6,7 +6,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_node_tp_node0.sh b/test/start_scripts/multi_node_tp_node0.sh index b86bdeb35..d750da93c 100644 --- a/test/start_scripts/multi_node_tp_node0.sh +++ b/test/start_scripts/multi_node_tp_node0.sh @@ -5,7 +5,7 @@ export nccl_host=$1 LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_node_tp_node1.sh b/test/start_scripts/multi_node_tp_node1.sh index 378977ab2..cb495496e 100644 --- a/test/start_scripts/multi_node_tp_node1.sh +++ b/test/start_scripts/multi_node_tp_node1.sh @@ -5,7 +5,7 @@ export nccl_host=$1 LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_pd_master/pd_decode.sh b/test/start_scripts/multi_pd_master/pd_decode.sh index 4cefef6fb..2b7bb80d7 100644 --- a/test/start_scripts/multi_pd_master/pd_decode.sh +++ b/test/start_scripts/multi_pd_master/pd_decode.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --nccl_port 12322 \ --tp 8 \ 
--dp 8 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --config_server_host $config_server_host \ --config_server_port 60088 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/multi_pd_master/pd_prefill.sh b/test/start_scripts/multi_pd_master/pd_prefill.sh index b845da435..eaa343ef6 100644 --- a/test/start_scripts/multi_pd_master/pd_prefill.sh +++ b/test/start_scripts/multi_pd_master/pd_prefill.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --tp 8 \ --dp 8 \ --nccl_port 2732 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ --config_server_port 60088 diff --git a/test/start_scripts/single_node_ep.sh b/test/start_scripts/single_node_ep.sh index cad172d51..7406d9462 100644 --- a/test/start_scripts/single_node_ep.sh +++ b/test/start_scripts/single_node_ep.sh @@ -3,7 +3,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ ---enable_fa3 +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \ diff --git a/test/start_scripts/single_node_tp.sh b/test/start_scripts/single_node_tp.sh index 1fb461bb1..ee10b6c10 100644 --- a/test/start_scripts/single_node_tp.sh +++ b/test/start_scripts/single_node_tp.sh @@ -2,7 +2,7 @@ LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ ---enable_fa3 +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \ diff --git a/test/start_scripts/single_node_tp_cpu_cache_enable.sh b/test/start_scripts/single_node_tp_cpu_cache_enable.sh index 3caabb59b..47da83dbe 100644 --- a/test/start_scripts/single_node_tp_cpu_cache_enable.sh +++ b/test/start_scripts/single_node_tp_cpu_cache_enable.sh @@ -3,7 +3,7 @@ LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /mtc/models/qwen3-8b --tp 2 --dp 1 --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 \ --batch_max_tokens 4096 --chunked_prefill_size 2048 \ --max_total_token_num 20000 \ ---mode "ppl_int8kv_flashdecoding" | tee log.txt +--llm_kv_type int8kv | tee log.txt # 精度评测命令 diff --git a/test/start_scripts/single_pd_master/pd_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh index ae16b96ad..36804dd11 100644 --- a/test/start_scripts/single_pd_master/pd_decode.sh +++ b/test/start_scripts/single_pd_master/pd_decode.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8121 \ --nccl_port 12322 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_decode.sh b/test/start_scripts/single_pd_master/pd_nixl_decode.sh index 1b43c11cc..5fb34a973 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_decode.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_decode.sh @@ -18,7 +18,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host
$host \ --port 8121 \ --nccl_port 12322 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh index 303de2975..5a37df0b1 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh @@ -19,7 +19,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8019 \ --nccl_port 2732 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh index f6e2e4b68..b94a1f8cc 100644 --- a/test/start_scripts/single_pd_master/pd_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_prefill.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8019 \ --nccl_port 2732 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 diff --git a/test/test_api/test_generate_api.py b/test/test_api/test_generate_api.py index 05fbda44e..4ea74b7f6 100644 --- a/test/test_api/test_generate_api.py +++ b/test/test_api/test_generate_api.py @@ -19,7 +19,7 @@ def run(self): print("Error:", response.status_code, response.text) -url = "http://localhost:8000/generate" +url = "http://localhost:8089/generate" headers = {"Content-Type": "application/json"} for i in range(1): diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py similarity index 81% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py index 67463fa7a..a01bbf32d 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py @@ -1,4 +1,5 @@ import pytest + import torch from lightllm.utils.light_utils import light_ops @@ -21,7 +22,7 @@ class MockInferState: def __init__( self, batch_size, - max_len_in_batch, + max_kv_seq_len, req_to_tokens, b_req_idx, b_seq_len, @@ -29,7 +30,7 @@ def __init__( b_mark_shared_group=None, ): self.batch_size = batch_size - self.max_len_in_batch = max_len_in_batch + self.max_kv_seq_len = max_kv_seq_len self.req_manager = MockReqManager(req_to_tokens) self.b_req_idx = b_req_idx self.b_seq_len = b_seq_len @@ -39,30 +40,32 @@ def __init__( # @pytest.mark.parametrize("shared_seq_len", [512]) @pytest.mark.parametrize("shared_seq_len", [0, 77, 256, 311, 512, 550]) -def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len): +@pytest.mark.parametrize("batch_size", list(range(6, 121, 6))) +def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len, batch_size): """ - 测试 ppl_int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding + 测试 int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding 与 ppl_int8kv_flash_decoding (baseline) 的对比。 """ - from 
lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse import ( + + from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse import ( token_decode_attention_flash_decoding as diverse_attention, ) - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import ( + from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( token_decode_attention_flash_decoding as baseline_attention, ) - batch_size = 6 num_heads = 32 kv_head_num = 8 mark_shared_group_size = 3 - seq_len = 1024 + seq_len = 3547 head_dim = 128 quant_group_size = 8 + max_len_in_batch = 8192 test_dtype = torch.bfloat16 # 创建测试数据 - kv_shape = (batch_size * seq_len, kv_head_num, head_dim) - kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) + kv_shape = (batch_size * max_len_in_batch, kv_head_num, head_dim) + kv_scale_shape = (batch_size * max_len_in_batch, kv_head_num, head_dim // quant_group_size) q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") @@ -73,7 +76,9 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le cache_v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") cache_v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") / 100.0 - req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len) + req_to_tokens = torch.arange(0, max_len_in_batch * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, max_len_in_batch + ) for i in range(batch_size): if i % mark_shared_group_size != 0: req_to_tokens[i, :shared_seq_len] = req_to_tokens[i - 1, :shared_seq_len] @@ -87,7 +92,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 baseline 的 infer_state (不需要 b_shared_seq_len) baseline_infer_state = MockInferState( batch_size=batch_size, - max_len_in_batch=seq_len, + max_kv_seq_len=max_len_in_batch, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -96,7 +101,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 diverse 的 infer_state diverse_infer_state = MockInferState( batch_size=batch_size, - max_len_in_batch=seq_len, + max_kv_seq_len=max_len_in_batch, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -108,8 +113,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le baseline_out = baseline_attention( q=q.clone(), infer_state=baseline_infer_state, - q_head_num=num_heads, - head_dim=head_dim, cache_k=cache_k, cache_k_scale=cache_k_scale, cache_v=cache_v, @@ -120,8 +123,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le diverse_out = diverse_attention( q=q.clone(), infer_state=diverse_infer_state, - q_head_num=num_heads, - head_dim=head_dim, cache_k=cache_k, cache_k_scale=cache_k_scale, cache_v=cache_v, @@ -129,7 +130,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le alloc_tensor_func=alloc_tensor_func, ) - print(f"\nshared_seq_len={shared_seq_len}") + print(f"\nshared_seq_len={shared_seq_len}\nbatch_size={batch_size}") print(f"baseline_out: {baseline_out[0, 0, :4]}") print(f"diverse_out: {diverse_out[0, 0, :4]}") print(f"max diff: {(baseline_out - diverse_out).abs().max()}") diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py 
b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage1.py similarity index 50% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage1.py index 30e83b88b..f3cb8de46 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage1.py @@ -1,39 +1,48 @@ import pytest import torch -from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage1 import ( + flash_decode_stage1, +) -@pytest.fixture -def setup_tensors(): - batch_size = 4 - num_heads = 4 - kv_head_num = 1 - seq_len = 256 +def create_tensors( + batch_size=4, + num_heads=4, + kv_head_num=1, + seq_len=256, + max_len_in_batch=8192, + max_batch_group_size=4, + kv_len=None, + req_to_tokens_len=None, +): head_dim = 128 - max_len_in_batch = seq_len block_seq = 256 - max_batch_group_size = 4 quant_group_size = 8 test_dtype = torch.bfloat16 - kv_shape = (batch_size * seq_len, kv_head_num, head_dim) - kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) + kv_len = max_len_in_batch if kv_len is None else kv_len + req_to_tokens_len = max_len_in_batch if req_to_tokens_len is None else req_to_tokens_len + + kv_shape = (batch_size * kv_len, kv_head_num, head_dim) + kv_scale_shape = (batch_size * kv_len, kv_head_num, head_dim // quant_group_size) q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") - Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len) + Req_to_tokens = torch.arange(0, req_to_tokens_len * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, req_to_tokens_len + ) B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") b_shared_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda") mid_out = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" ) mid_out_logsumexp = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2), dtype=q.dtype, device="cuda" + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2), dtype=q.dtype, device="cuda" ) return { @@ -54,6 +63,11 @@ def setup_tensors(): } +@pytest.fixture +def setup_tensors(): + return create_tensors() + + def test_flash_decode_stage1_execution(setup_tensors): flash_decode_stage1( q=setup_tensors["q"], @@ -81,7 +95,7 @@ def test_flash_decode_stage1_execution(setup_tensors): new_k = k.to(q.dtype) new_v = v.to(q.dtype) - from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import ( + from 
lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( flash_decode_stage1 as gqa_flash_decode_stage1, ) @@ -104,3 +118,71 @@ def test_flash_decode_stage1_execution(setup_tensors): assert torch.allclose( setup_tensors["mid_out_logsumexp"], true_mid_out_logsumexp, atol=1e-2 ), "LogSumExp output does not match expected values" + + +def autotune_and_benchmark(): + import triton + + batch_sizes = [8, 16, 32, 64] + seq_lens = [1024, 2048, 4096] + + results = [] + for batch in batch_sizes: + for seq in seq_lens: + # Clear GPU cache to reduce CUDA Graph capture failures. + torch.cuda.empty_cache() + + setup_tensors = create_tensors( + batch_size=batch, + num_heads=32, + kv_head_num=8, + seq_len=seq, + max_len_in_batch=8192, + max_batch_group_size=8, + kv_len=seq, + req_to_tokens_len=seq, + ) + + def fn_triton(st=setup_tensors): + return flash_decode_stage1( + q=st["q"], + k=st["k"], + k_scale=st["k_scale"], + v=st["v"], + v_scale=st["v_scale"], + Req_to_tokens=st["Req_to_tokens"], + B_req_idx=st["B_req_idx"], + b_shared_seq_len=st["b_shared_seq_len"], + b_mark_shared_group=st["b_mark_shared_group"], + max_len_in_batch=st["max_len_in_batch"], + mid_out=st["mid_out"], + mid_out_logsumexp=st["mid_out_logsumexp"], + block_seq=st["block_seq"], + max_batch_group_size=st["max_batch_group_size"], + ) + + ms_triton = triton.testing.do_bench_cudagraph(fn_triton, rep=100) + + results.append( + { + "batch_size": batch, + "seq_len": seq, + "triton_ms": ms_triton, + } + ) + print(results[-1]) + + del setup_tensors + + print(f"\n{'='*80}") + print("SUMMARY - Performance Comparison") + print(f"{'='*80}") + print(f"{'batch_size':<8} {'seq_len':<12} {'triton_ms':<12}") + print(f"{'-'*80}") + for r in results: + print(f"{r['batch_size']:<8} {r['seq_len']:<12} {r['triton_ms']:<12.3f}") + print(f"{'='*80}") + + +if __name__ == "__main__": + autotune_and_benchmark() diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py new file mode 100644 index 000000000..c7d444254 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py @@ -0,0 +1,293 @@ +import pytest +import torch +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage2 import ( + flash_decode_stage2, +) + + +def create_tensors( + shared_seq_len, + batch_size=4, + seq_len=256, + max_len_in_batch=8192, + max_batch_group_size=4, + kv_len=None, + req_to_tokens_len=None, +): + num_heads = 32 + kv_head_num = 8 + head_dim = 128 + block_seq = 256 + quant_group_size = 8 + + test_dtype = torch.bfloat16 + + kv_len = max_len_in_batch if kv_len is None else kv_len + req_to_tokens_len = max_len_in_batch if req_to_tokens_len is None else req_to_tokens_len + + kv_shape = (batch_size * kv_len, kv_head_num, head_dim) + kv_scale_shape = (batch_size * kv_len, kv_head_num, head_dim // quant_group_size) + + q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") + k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") + v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") + Req_to_tokens = 
torch.arange(0, req_to_tokens_len * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, req_to_tokens_len + ) + B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda") + b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda") + mid_out = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" + ) + mid_out_logsumexp = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2), dtype=q.dtype, device="cuda" + ) + + return { + "q": q, + "k": k, + "k_scale": k_scale, + "v": v, + "v_scale": v_scale, + "Req_to_tokens": Req_to_tokens, + "B_req_idx": B_req_idx, + "b_seq_len": b_seq_len, + "b_shared_seq_len": b_shared_seq_len, + "b_mark_shared_group": b_mark_shared_group, + "max_len_in_batch": max_len_in_batch, + "mid_out": mid_out, + "mid_out_logsumexp": mid_out_logsumexp, + "block_seq": block_seq, + "max_batch_group_size": max_batch_group_size, + "head_dim": head_dim, + } + + +@pytest.mark.parametrize("shared_seq_len", [0, 47, 77, 128, 200, 255]) +def test_flash_decode_stage2_execution(shared_seq_len): + setup_tensors = create_tensors(shared_seq_len) + + flash_decode_stage2( + q=setup_tensors["q"], + k=setup_tensors["k"], + k_scale=setup_tensors["k_scale"], + v=setup_tensors["v"], + v_scale=setup_tensors["v_scale"], + Req_to_tokens=setup_tensors["Req_to_tokens"], + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=setup_tensors["b_seq_len"], + b_shared_seq_len=setup_tensors["b_shared_seq_len"], + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=setup_tensors["mid_out"], + mid_out_logsumexp=setup_tensors["mid_out_logsumexp"], + block_seq=setup_tensors["block_seq"], + ) + seq_block_idx = (setup_tensors["b_shared_seq_len"][0].item() + setup_tensors["block_seq"] - 1) // setup_tensors[ + "block_seq" + ] + mid_out = setup_tensors["mid_out"][:, :, seq_block_idx:, :] + mid_out_logsumexp = setup_tensors["mid_out_logsumexp"][:, :, seq_block_idx:] + + q = setup_tensors["q"] + k = setup_tensors["k"] + v = setup_tensors["v"] + true_mid_out = torch.zeros_like(mid_out) + true_mid_out_logsumexp = torch.zeros_like(mid_out_logsumexp) + new_q = q + new_k = k.to(q.dtype) + new_v = v.to(q.dtype) + + b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"] + req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :] + + from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( + flash_decode_stage1 as gqa_flash_decode_stage1, + ) + + gqa_flash_decode_stage1( + q=new_q, + k=new_k, + v=new_v, + Req_to_tokens=req_to_tokens, + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=b_seq_len, + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=true_mid_out, + mid_out_logsumexp=true_mid_out_logsumexp, + block_seq=setup_tensors["block_seq"], + ) + print(f"\nshared_seq_len={shared_seq_len}") + print(f"mid_out: {mid_out[0:4, 0, 0, 0]}") + print(f"true_mid_out: {true_mid_out[0:4, 0, 0, 0]}") + abs_diff = (mid_out - true_mid_out).abs() + max_diff = abs_diff.max() + max_diff_idx = abs_diff.argmax() + max_diff_idx_unraveled = torch.unravel_index(max_diff_idx, abs_diff.shape) + mid_out_value = mid_out[max_diff_idx_unraveled] + true_mid_out_value = true_mid_out[max_diff_idx_unraveled] + 
print(f"max abs diff: {max_diff}, mid_out value: {mid_out_value}, " f"true_mid_out value: {true_mid_out_value}") + + assert torch.allclose( + mid_out[0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0], atol=1e-2 + ), f"Mid output does not match expected values for shared_seq_len={shared_seq_len}" + assert torch.allclose( + mid_out_logsumexp, true_mid_out_logsumexp, atol=1e-2 + ), f"LogSumExp output does not match expected values for shared_seq_len={shared_seq_len}" + + +if __name__ == "__main__": + import importlib + import triton + from lightllm.utils.light_utils import light_ops + + batch_sizes = [8, 16, 32, 64] + seq_lens = [32, 64, 128, 256] + + results = [] + for batch in batch_sizes: + for seq in seq_lens: + # Clear GPU cache to reduce CUDA Graph capture failures. + torch.cuda.empty_cache() + + setup_tensors = create_tensors( + shared_seq_len=0, + batch_size=batch, + seq_len=seq, + max_len_in_batch=8192, + kv_len=seq, + req_to_tokens_len=seq, + ) + + # Outputs for CUDA implementation + mid_out_cuda = setup_tensors["mid_out"].clone() + mid_out_logsumexp_cuda = setup_tensors["mid_out_logsumexp"].clone() + + # Outputs for Triton implementation + mid_out_triton = setup_tensors["mid_out"].clone() + mid_out_logsumexp_triton = setup_tensors["mid_out_logsumexp"].clone() + + # Run CUDA to get reference + light_ops.group8_int8kv_flashdecoding_diverse_stage2( + setup_tensors["block_seq"], + mid_out_cuda, + mid_out_logsumexp_cuda, + 1.0 / (setup_tensors["head_dim"] ** 0.5), + setup_tensors["q"], + setup_tensors["k"], + setup_tensors["k_scale"], + setup_tensors["v"], + setup_tensors["v_scale"], + setup_tensors["Req_to_tokens"], + setup_tensors["B_req_idx"], + setup_tensors["b_seq_len"], + setup_tensors["b_shared_seq_len"], + setup_tensors["max_len_in_batch"], + ) + + # Run Triton + flash_decode_stage2( + q=setup_tensors["q"], + k=setup_tensors["k"], + k_scale=setup_tensors["k_scale"], + v=setup_tensors["v"], + v_scale=setup_tensors["v_scale"], + Req_to_tokens=setup_tensors["Req_to_tokens"], + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=setup_tensors["b_seq_len"], + b_shared_seq_len=setup_tensors["b_shared_seq_len"], + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=mid_out_triton, + mid_out_logsumexp=mid_out_logsumexp_triton, + block_seq=setup_tensors["block_seq"], + ) + + # Compare results + diff_mid_out = torch.abs(mid_out_cuda - mid_out_triton) + diff_logsumexp = torch.abs(mid_out_logsumexp_cuda - mid_out_logsumexp_triton) + max_diff_out = diff_mid_out.max().item() + max_diff_logsumexp = diff_logsumexp.max().item() + mean_diff_out = diff_mid_out.mean().item() + mean_diff_logsumexp = diff_logsumexp.mean().item() + + cos_sim_out = torch.nn.functional.cosine_similarity( + mid_out_cuda.flatten(), mid_out_triton.flatten(), dim=0 + ).item() + cos_sim_logsumexp = torch.nn.functional.cosine_similarity( + mid_out_logsumexp_cuda.flatten(), mid_out_logsumexp_triton.flatten(), dim=0 + ).item() + + print(f"\n[batch={batch}, seq={seq}] Consistency check:") + print(" mid_out:") + print(f" max_diff: {max_diff_out:.6f}, mean_diff: {mean_diff_out:.6f}, cosine_sim: {cos_sim_out:.8f}") + print(" logsumexp:") + print( + f" max_diff: {max_diff_logsumexp:.6f}, " + f"mean_diff: {mean_diff_logsumexp:.6f}, " + f"cosine_sim: {cos_sim_logsumexp:.8f}" + ) + + # Performance + fn_cuda = lambda: light_ops.group8_int8kv_flashdecoding_diverse_stage2( + setup_tensors["block_seq"], + setup_tensors["mid_out"], + setup_tensors["mid_out_logsumexp"], + 1.0 / (setup_tensors["head_dim"] ** 0.5), + setup_tensors["q"], + 
setup_tensors["k"], + setup_tensors["k_scale"], + setup_tensors["v"], + setup_tensors["v_scale"], + setup_tensors["Req_to_tokens"], + setup_tensors["B_req_idx"], + setup_tensors["b_seq_len"], + setup_tensors["b_shared_seq_len"], + setup_tensors["max_len_in_batch"], + ) + ms_cuda = triton.testing.do_bench_cudagraph(fn_cuda, rep=100) + + fn_triton = lambda: flash_decode_stage2( + q=setup_tensors["q"], + k=setup_tensors["k"], + k_scale=setup_tensors["k_scale"], + v=setup_tensors["v"], + v_scale=setup_tensors["v_scale"], + Req_to_tokens=setup_tensors["Req_to_tokens"], + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=setup_tensors["b_seq_len"], + b_shared_seq_len=setup_tensors["b_shared_seq_len"], + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=setup_tensors["mid_out"], + mid_out_logsumexp=setup_tensors["mid_out_logsumexp"], + block_seq=setup_tensors["block_seq"], + ) + ms_triton = triton.testing.do_bench_cudagraph(fn_triton, rep=100) + + results.append( + { + "batch_size": batch, + "seq_len": seq, + "triton_ms": ms_triton, + "cuda_ms": ms_cuda, + } + ) + print(results[-1]) + + del setup_tensors + + print(f"\n{'='*80}") + print("SUMMARY - Performance Comparison") + print(f"{'='*80}") + print(f"{'batch_size':<8} {'seq_len':<12} {'triton_ms':<12} {'cuda_ms':<12} {'vs cuda':<10}") + print(f"{'-'*80}") + for r in results: + vs_cuda = f"{r['cuda_ms']/r['triton_ms']:.2f}x" + emoji = "🎉" if r["triton_ms"] < r["cuda_ms"] else "" + print( + f"{r['batch_size']:<8} {r['seq_len']:<12} {r['triton_ms']:<12.3f} {r['cuda_ms']:<12.3f}" + f"{vs_cuda:<10} {emoji}" + ) + print(f"{'='*80}") diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage3.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage3.py similarity index 79% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage3.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage3.py index b406e2dcf..c1a0ca1e5 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage3.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage3.py @@ -1,6 +1,8 @@ import pytest import torch -from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage3 import ( + flash_diverse_decode_stage3, +) @pytest.mark.parametrize( @@ -23,7 +25,10 @@ def test_flash_diverse_decode_stage3(batch, head_num, seq_len, shared_seq_len, b flash_diverse_decode_stage3(mid_out, mid_out_logexpsum, B_Seqlen, b_shared_seq_len, out, block_seq) true_out = torch.zeros_like(out) - from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2 + + from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding_stage2 import ( + flash_decode_stage2, + ) flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, true_out, block_seq) diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py new file mode 100644 index 000000000..d1a53f873 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py @@ -0,0 +1,104 @@ +import torch +import time 
+import pytest +from lightllm.common.basemodel.triton_kernel.att.prefill_att.context_flashattention_nopad import ( + context_attention_fwd_contiguous_kv, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_kv_start_loc, b_seq_len, b_prompt_cache_len): + batch = b_start_loc.shape[0] + + for i in range(batch): + start_loc = b_start_loc[i] + kv_start_loc = b_kv_start_loc[i] + seq_len = b_seq_len[i] + prompt_cache_len = b_prompt_cache_len[i] + cur_q = q[start_loc : start_loc + seq_len - prompt_cache_len, :, :] + cur_q = cur_q.clone().to(torch.float32) + cur_k = k[kv_start_loc : (kv_start_loc + seq_len), :, :] + cur_k = cur_k.clone().to(torch.float32) + + cur_v = v[kv_start_loc : (kv_start_loc + seq_len), :, :] + cur_v = cur_v.clone().to(torch.float32) + + dk = cur_q.shape[-1] + cur_q = cur_q.permute(1, 0, 2) + cur_k = cur_k.permute(1, 2, 0) + cur_v = cur_v.permute(1, 0, 2) + dk = cur_q.shape[-1] + + p = torch.matmul(cur_q, cur_k) / torch.sqrt(torch.tensor(dk, dtype=torch.float32)) + + q_index = (torch.arange(cur_q.shape[1]).to(p.device) + prompt_cache_len).view(-1, 1) + k_index = torch.arange(seq_len).to(p.device).view(1, -1) + + p[:, (q_index < k_index)] = float("-inf") + + s = torch.nn.functional.softmax(p, dim=-1) + + o[start_loc : start_loc + seq_len - prompt_cache_len, :, :] = torch.matmul(s, cur_v).transpose(0, 1) + + +@pytest.mark.parametrize( + "B, H, N_CTX, D_HEAD, prompt_cache_len", + [ + (b, H, N_CTX, D_HEAD, prompt_cache_len) + for b in [1, 2, 4] + for H in [1, 8] + for N_CTX in [3, 10, 1024] + for D_HEAD in [64, 128] + for prompt_cache_len in [0, 56, 200] + ], +) +def test_context_attention_fwd_contiguous_kv(B, H, N_CTX, D_HEAD, prompt_cache_len): + dtype = torch.float16 + # prompt_cache_len is taken from the parametrize values above + if prompt_cache_len >= N_CTX - 1: + return + + q = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + kv = torch.empty((B * N_CTX, 2 * H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + k = kv[:, :H, :] + v = kv[:, H:, :] + + o = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda") + torch_o = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda") + + max_q_input_len = N_CTX - prompt_cache_len + + b_seq_len = torch.ones((B,), dtype=torch.int32, device="cuda") + b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + # b_seq_len and b_prompt_cache_len are filled per request in the loop below + b_prompt_cache_len = torch.zeros(B, dtype=torch.int32, device="cuda") + + for i in range(B): + b_seq_len[i] = N_CTX + if i != 0: + b_start_loc[i] = b_start_loc[i - 1] + N_CTX - prompt_cache_len + b_prompt_cache_len[i] = prompt_cache_len + + b_kv_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len + torch_context_attention_fwd2(q, k, v, torch_o, b_start_loc, b_kv_start_loc, b_seq_len, b_prompt_cache_len) + context_attention_fwd_contiguous_kv( + q=q, + k=k, + v=v, + o=o, + b_start_loc=b_start_loc, + b_kv_start_loc=b_kv_start_loc, + b_seq_len=b_seq_len, + max_q_input_len=max_q_input_len, + b_prompt_cache_len=b_prompt_cache_len, + ) + + assert torch.allclose(torch_o, o, atol=1e-2, rtol=0) + cos = torch.nn.CosineSimilarity(0) + assert cos(o.flatten().float(), torch_o.flatten().float()) > 0.99 + + +if __name__ == "__main__": + pytest.main() diff --git a/unit_tests/models/llama/test_context_flashattention_nopad.py
b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py similarity index 92% rename from unit_tests/models/llama/test_context_flashattention_nopad.py rename to unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py index f24ab619b..541594306 100644 --- a/unit_tests/models/llama/test_context_flashattention_nopad.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py @@ -5,12 +5,11 @@ import torch.nn.functional as F import flashinfer from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( +from lightllm.common.basemodel.triton_kernel.att.prefill_att.context_flashattention_nopad import ( context_attention_fwd, context_attention_fwd_no_prompt_cache, ) from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -54,14 +53,14 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state = LlamaInferStateInfo() infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX + infer_state.max_q_seq_len = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) + infer_state.req_manager = type("Object", (), {})() infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len infer_state.b_ready_cache_len = b_ready_cache_len - infer_state.b_start_loc = q_start_loc + infer_state.b_q_start_loc = q_start_loc context_attention_fwd( q, @@ -69,10 +68,10 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): kv[:, KV_HEADS:, :], o, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_q_start_loc, infer_state.b_seq_len, infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, + infer_state.max_q_seq_len, infer_state.req_manager.req_to_token_indexs, ) @@ -127,7 +126,11 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): "batch, seqlen, q_heads, kv_heads, head_dim", [ (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] + for a in [ + 1, + 16, + 32, + ] for b in [16, 32, 512, 1024] for c in [28] for d in [4] @@ -149,18 +152,18 @@ def test_context_attention_fwd_no_prompt_cache(batch, seqlen, q_heads, kv_heads, infer_state = LlamaInferStateInfo() infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX + infer_state.max_q_seq_len = N_CTX infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc + infer_state.b_q_start_loc = b_start_loc context_attention_fwd_no_prompt_cache( q, k, v, o, - infer_state.b_start_loc, + infer_state.b_q_start_loc, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_q_seq_len, ) head_dim = HEAD_DIM diff --git a/unit_tests/models/deepseek2/test_destindex_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py similarity index 93% rename from unit_tests/models/deepseek2/test_destindex_copy_kv.py rename to unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py index 1379dc72d..ed0c6e369 100644 --- a/unit_tests/models/deepseek2/test_destindex_copy_kv.py +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py @@ -1,6 +1,6 @@ import torch import pytest -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv +from 
lightllm.common.basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv from lightllm.utils.log_utils import init_logger import torch.nn.functional as F diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py new file mode 100644 index 000000000..83537ec70 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py @@ -0,0 +1,62 @@ +import torch +import pytest +import numpy as np +from typing import Tuple +from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv, dequantize_int4kv +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def test_quanted_and_dequant(): + """Test quantization followed by dequantization.""" + batch_size = 1 + seq_len = 8 + head_num = 4 + k_head_num = 2 + v_head_num = 2 + assert k_head_num + v_head_num == head_num + head_dim = 64 + quant_group_size = 8 + + # Create original data + original_kv = torch.randn(batch_size * seq_len, head_num, head_dim, dtype=torch.float32).clamp_(-1, 1).cuda() + dest_loc = torch.arange(batch_size * seq_len, dtype=torch.int64).cuda() + + # Quantize + group_count = head_dim // quant_group_size + kv_buffer = torch.zeros(batch_size * seq_len, head_num, head_dim // 2, dtype=torch.int8).cuda() + kv_scale_buffer = torch.zeros(batch_size * seq_len, head_num, group_count, dtype=torch.float32).cuda() + destindex_copy_int4kv(original_kv, dest_loc, kv_buffer, kv_scale_buffer, quant_group_size) + + # Dequantize + req_to_token_indexs = torch.arange(seq_len, dtype=torch.int64).view(1, -1).cuda() + b_seq_len = torch.tensor([seq_len], dtype=torch.int32).cuda() + b_req_idx = torch.tensor([0], dtype=torch.int32).cuda() + b_kv_start_loc = torch.tensor([0], dtype=torch.int32).cuda() + + recovered_kv = torch.zeros(batch_size * seq_len, head_num, head_dim, dtype=torch.float32).cuda() + + dequantize_int4kv( + k=kv_buffer[:, 0:k_head_num, :], + k_scale=kv_scale_buffer[:, 0:k_head_num, :], + v=kv_buffer[:, k_head_num:, :], + v_scale=kv_scale_buffer[:, k_head_num:, :], + req_to_token_indexs=req_to_token_indexs, + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=recovered_kv[:, :k_head_num, :], + v_out=recovered_kv[:, k_head_num:, :], + max_len_in_batch=seq_len, + quant_group_size=quant_group_size, + ) + + logger.info("Round-trip test completed!") + assert torch.allclose(recovered_kv, original_kv, atol=2 / 14, rtol=0) + cos = torch.nn.CosineSimilarity(0) + assert cos(recovered_kv.flatten().float(), original_kv.flatten().float()) > 0.99 + + +if __name__ == "__main__": + pytest.main() diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py new file mode 100644 index 000000000..149c9894a --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py @@ -0,0 +1,86 @@ +import torch +import time +import pytest +from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import ( + dequantize_int8kv, + destindex_copy_quantize_kv, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def torch_dequant(kv, kv_scale, b_req_idx, b_seq_len, req_to_token_indexs, odtype, group_quant_size): + batch = b_req_idx.shape[0] + tmp_out = [] + for i in range(batch): + req_idx = b_req_idx[i] + seq_len = b_seq_len[i] + 
kv_loc = req_to_token_indexs[req_idx, :seq_len] + head_num = kv.shape[1] + cur_kv = kv[kv_loc, :, :].reshape(seq_len, head_num, -1, group_quant_size).to(odtype) + cur_scale = kv_scale[kv_loc, :, :].reshape(seq_len, head_num, -1, 1) + out = cur_kv * cur_scale + tmp_out.append(out.reshape(seq_len, head_num, -1)) + return torch.cat(tmp_out, dim=0) + + +@pytest.mark.parametrize( + "B, H, N_CTX, D_HEAD, group_quant_size", + [ + (b, H, N_CTX, D_HEAD, group_quant_size) + for b in [1, 2, 4] + for H in [1, 8] + for N_CTX in [3, 10, 1024] + for D_HEAD in [64, 128] + for group_quant_size in [8, 16] + ], +) +def test_dequantize_int8kv(B, H, N_CTX, D_HEAD, group_quant_size): + dtype = torch.bfloat16 + kv = torch.empty((B * N_CTX, 2 * H, D_HEAD), dtype=torch.int8, device="cuda").random_(-10, 10) + kv_scale = torch.randn((B * N_CTX, 2 * H, D_HEAD // group_quant_size), dtype=dtype, device="cuda") + out = torch.empty((B * N_CTX, 2 * H, D_HEAD), dtype=dtype, device="cuda") + req_to_token_indexs = torch.empty((B, N_CTX), dtype=torch.int32, device="cuda") + max_input_len = N_CTX + b_seq_len = torch.ones((B,), dtype=torch.int32, device="cuda") + b_seq_len.fill_(N_CTX) + b_req_idx = torch.arange(0, B, dtype=torch.int32, device="cuda") + req_to_token_indexs.view(-1)[:] = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") + b_kv_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len + + k = kv[:, :H, :] + v = kv[:, H:, :] + k_scale = kv_scale[:, :H, :] + v_scale = kv_scale[:, H:, :] + + ground_out = torch_dequant( + kv=kv, + kv_scale=kv_scale, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + req_to_token_indexs=req_to_token_indexs, + odtype=out.dtype, + group_quant_size=group_quant_size, + ) + dequantize_int8kv( + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + req_to_token_indexs=req_to_token_indexs, + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=out[:, :H, :], + v_out=out[:, H:, :], + max_len_in_batch=max_input_len, + quant_group_size=group_quant_size, + ) + assert torch.allclose(out, ground_out, atol=1e-2, rtol=0) + cos = torch.nn.CosineSimilarity(0) + assert cos(out.flatten().float(), ground_out.flatten().float()) > 0.99 + + +if __name__ == "__main__": + pytest.main() diff --git a/unit_tests/models/deepseek2/test_gqa_flash_decoding.py b/unit_tests/common/basemodel/triton_kernel/mla_att/decode_att/test_gqa_flash_decoding.py similarity index 92% rename from unit_tests/models/deepseek2/test_gqa_flash_decoding.py rename to unit_tests/common/basemodel/triton_kernel/mla_att/decode_att/test_gqa_flash_decoding.py index d0bc670ec..a5ac9708d 100644 --- a/unit_tests/models/deepseek2/test_gqa_flash_decoding.py +++ b/unit_tests/common/basemodel/triton_kernel/mla_att/decode_att/test_gqa_flash_decoding.py @@ -5,9 +5,10 @@ import torch.nn.functional as F import flashinfer from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding +from lightllm.common.basemodel.triton_kernel.mla_att.decode_att.gqa_flash_decoding import ( + gqa_token_decode_attention_flash_decoding, +) from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -53,7 +54,7 @@ def test_gqa_flash_decoding(batch, seqlen, heads, nope_head, rope_head): infer_state.batch_size = Z infer_state.max_len_in_batch = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = 
ReqManager(Z, N_CTX, None) + infer_state.req_manager = type("Object", (), {})() infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len @@ -67,10 +68,6 @@ def test_gqa_flash_decoding(batch, seqlen, heads, nope_head, rope_head): kv_nope, kv_rope, infer_state, - H, - D_HEAD, - ROPE_HEAD, - D_HEAD, sm_scale, o, ) diff --git a/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py b/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py index 536cad90f..0afcd5558 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py +++ b/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py @@ -18,10 +18,10 @@ def test_add_in_place(): assert input.item() == 3, "最终值应为 3" -@pytest.mark.timeout(2) -def test_wait_timeout(): - input = torch.zeros((1,), device="cuda", dtype=torch.int32) - wait_value(input, 4) +# @pytest.mark.timeout(2) +# def test_wait_timeout(): +# input = torch.zeros((1,), device="cuda", dtype=torch.int32) +# wait_value(input, 4) if __name__ == "__main__": diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py index 5c3ca89c6..41bc217b9 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py +++ b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py @@ -1,21 +1,9 @@ import torch import pytest -import easydict from lightllm.common.basemodel.triton_kernel.gen_decode_params import gen_decode_params -from lightllm.utils.envs_utils import set_env_start_args def test_gen_decode_params_basic(): - set_env_start_args( - easydict.EasyDict( - { - "mtp_step": 0, - "enable_flashinfer_prefill": False, - "enable_flashinfer_decode": False, - } - ) - ) - b_seq_len = torch.ones((9,), dtype=torch.int64, device="cuda") * 8192 ( b_q_seq_len, diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py b/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py index 99971dea2..e9d019327 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py +++ b/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py @@ -25,6 +25,7 @@ def test_token_id_counter(): for _ in range(100): token_id_counter(prompt_ids=test_prompt_ids, out_token_id_counter=test_token_id_counter) end_event.record() + end_event.synchronize() logger.info(f"test_token_id_count cost time: {start_event.elapsed_time(end_event)} ms") diff --git a/unit_tests/models/deepseek2/test_repack_kv_index.py b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py similarity index 96% rename from unit_tests/models/deepseek2/test_repack_kv_index.py rename to unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py index f9e5928a9..b5184d3ca 100644 --- a/unit_tests/models/deepseek2/test_repack_kv_index.py +++ b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py @@ -1,7 +1,7 @@ import torch import pytest from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index logger = init_logger(__name__) diff --git a/unit_tests/common/fused_moe/test_deepep.py b/unit_tests/common/fused_moe/test_deepep.py index c846be096..45778244b 100644 --- a/unit_tests/common/fused_moe/test_deepep.py +++ b/unit_tests/common/fused_moe/test_deepep.py @@ -1,12 +1,13 @@ +import pytest + 
+pytest.skip(reason="need special env, install deep_ep and deep_gemm", allow_module_level=True) + import os import torch import torch.distributed as dist -import pytest import deep_ep import random import numpy as np -from deep_ep import Buffer, EventOverlap -from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor from lightllm.common.fused_moe.grouped_fused_moe_ep import fused_experts_impl from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from typing import Tuple @@ -25,6 +26,8 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape + from deep_gemm import ceil_div + x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) diff --git a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py index eba15b2a1..671805a3d 100644 --- a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py +++ b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py @@ -1,6 +1,18 @@ import torch -import time import pytest + + +def is_fp8_native_supported(): + """检查是否为 H100/B200 等原生支持 FP8 的硬件 (SM90+)""" + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 9 + + +if not is_fp8_native_supported(): + pytest.skip(reason="not support fp8 test in this gpu card", allow_module_level=True) + import random from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd diff --git a/unit_tests/common/fused_moe/test_softmax_topk.py b/unit_tests/common/fused_moe/test_softmax_topk.py index 262c37a0f..6252dfa8c 100755 --- a/unit_tests/common/fused_moe/test_softmax_topk.py +++ b/unit_tests/common/fused_moe/test_softmax_topk.py @@ -9,7 +9,10 @@ def benchmark(M, N, K, renorm, runs): - import sgl_kernel as sgl_ops + try: + import sgl_kernel as sgl_ops + except Exception as e: + pytest.skip(f"no sgl_kernel error: {str(e)}", allow_module_level=True) gating = torch.randn(M, N, device="cuda", dtype=torch.float32) torch.cuda.synchronize() diff --git a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py index 1ddb20b63..2c0b7bf76 100644 --- a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py +++ b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py @@ -4,6 +4,18 @@ from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +def is_fp8_native_supported(): + """检查是否为 H100/B200 等原生支持 FP8 的硬件 (SM90+)""" + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 9 + + +if not is_fp8_native_supported(): + pytest.skip("not support fp8 in this gpu card", allow_module_level=True) + + @pytest.mark.parametrize("M", [1, 2, 4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("N,K", [(2048, 2048), (4096, 5120), (8192, 4096)]) @pytest.mark.parametrize("output_dtype", [torch.bfloat16]) diff --git a/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py b/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py deleted file mode 100644 index 4f9c0a337..000000000 --- a/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py +++ /dev/null 
@@ -1,48 +0,0 @@ -import torch -import pytest -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 -from lightllm.utils.log_utils import init_logger -import torch.nn.functional as F - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -@pytest.mark.parametrize( - "batch, seqlen, heads, nope_head, rope_head, copy_len", - [ - (a, b, c, d, e, f) - for a in [1, 16, 32, 128, 512] - for b in [1024, 2048] - for c in [1] - for d in [512] - for e in [64] - for f in [10, 20, 100, 1024] - ], -) -def test_destindex_copy_kv_fp8(batch, seqlen, heads, nope_head, rope_head, copy_len): - B, N_CTX, H, NOPE_HEAD, ROPE_HEAD, COPY_LEN = batch, seqlen, heads, nope_head, rope_head, copy_len - dtype = torch.bfloat16 - NUM = COPY_LEN - dest_loc = torch.arange(NUM).cuda() - kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() - out = torch.zeros((B * N_CTX, H, NOPE_HEAD + ROPE_HEAD + 2), dtype=torch.uint8).cuda() - - fp8_type = torch.float8_e4m3fn - kv_nope = kv[:, :, :NOPE_HEAD] - kv_rope = kv[:, :, NOPE_HEAD:] - O_nope = out[:, :, :NOPE_HEAD].view(fp8_type) - O_rope = out[:, :, NOPE_HEAD:-2].view(fp8_type) - O_scale = out[:, :, -2:].view(dtype) - destindex_copy_kv_fp8(kv_nope, kv_rope, dest_loc, O_nope, O_rope, O_scale) - - cos1 = F.cosine_similarity(O_nope[:NUM].to(dtype) * O_scale[:NUM], kv_nope).mean() - cos2 = F.cosine_similarity(O_rope[:NUM].to(dtype) * O_scale[:NUM], kv_rope).mean() - assert cos1 > 0.98 - assert cos2 > 0.98 diff --git a/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py b/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py deleted file mode 100644 index 72d9d9acc..000000000 --- a/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import pytest -import numpy as np -import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.common.req_manager import ReqManager - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -@pytest.mark.parametrize( - "batch, seqlen, heads, nope_head, rope_head", - [(a, b, c, d, e) for a in [1, 16, 32, 128] for b in [16, 32, 512, 2048] for c in [16] for d in [512] for e in [64]], -) -def test_gqa_flash_decoding_fp8(batch, seqlen, heads, nope_head, rope_head): - Z, N_CTX, H, D_HEAD, ROPE_HEAD = batch, seqlen, heads, nope_head, rope_head - dtype = torch.bfloat16 - sm_scale = 1.0 / ((D_HEAD + ROPE_HEAD) ** 0.5) - q = torch.randn((Z, H, D_HEAD), dtype=dtype, device="cuda") - q_rope = torch.randn((Z, H, ROPE_HEAD), dtype=dtype, device="cuda") - - kv = torch.randn((Z * N_CTX, 1, D_HEAD + ROPE_HEAD), dtype=dtype, device="cuda") - kv_scale = torch.randn((Z * N_CTX, 1, 1), dtype=dtype, device="cuda") - kv_fp8 = kv.to(torch.float8_e4m3fn) - - req_to_token_indexs = torch.zeros((10, Z * N_CTX), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") - - 
b_seq_len[0] = N_CTX - b_req_idx[0] = 0 - req_to_token_indexs[0][:N_CTX] = torch.tensor(np.arange(N_CTX), dtype=torch.int32).cuda() - - o = torch.empty((Z, H, D_HEAD), dtype=dtype, device="cuda") - o1 = torch.empty((Z, H, D_HEAD), dtype=dtype, device="cuda") - - infer_state = Deepseek2InferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - - kv_nope = kv_fp8[:, :, :D_HEAD].to(dtype) * kv_scale - kv_rope = kv_fp8[:, :, D_HEAD:].to(dtype) * kv_scale - gqa_token_decode_attention_flash_decoding( - q, - q_rope, - kv_nope, - kv_rope, - infer_state, - H, - D_HEAD, - ROPE_HEAD, - D_HEAD, - sm_scale, - o, - ) - - kv_nope_fp8 = kv_fp8[:, :, :D_HEAD] - kv_rope_fp8 = kv_fp8[:, :, D_HEAD:] - gqa_token_decode_attention_flash_decoding_fp8( - q, q_rope, kv_nope_fp8, kv_rope_fp8, kv_scale, infer_state, H, D_HEAD, ROPE_HEAD, D_HEAD, sm_scale, o1 - ) - - cos_sim = F.cosine_similarity(o, o1).mean() - assert cos_sim > 0.99 diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py deleted file mode 100644 index 737bb655b..000000000 --- a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import time -import pytest -import triton as tl -import numpy as np -import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.sgl_utils import flash_attn_with_kvcache -from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): - device = kv_buffer.device - B = seq_lens.size(0) - min_fp8 = torch.finfo(torch.float8_e4m3fn).min - max_fp8 = torch.finfo(torch.float8_e4m3fn).max - _, S_max, H, D = kv_buffer.shape - seq_range = torch.arange(S_max, device=device)[None, :] - valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) - masked = kv_buffer * valid_mask - max_per_bh = masked.abs().amax(dim=(1, 3)) # [B, H] - scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32) - scales_exp = scales.view(B, 1, H, 1) - q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) - return q, scales - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_context_attention_fwd_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs 
= torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") - if N_CTX > 1: - b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - q_lens = b_seq_len - b_ready_cache_len - q_start_loc = q_lens.cumsum(0) - q_lens - - q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_ready_cache_len = b_ready_cache_len - infer_state.b_start_loc = q_start_loc - - context_attention_fwd( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - req_to_token_indexs, - ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32, device="cuda") - page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) - - q_starts = torch.zeros((Z + 1,)).int().cuda() - q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) - kv_starts = torch.zeros_like(q_starts) - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - - k_cache = kv[:, :KV_HEADS, :] - v_cache = kv[:, KV_HEADS:, :] - # o1 = flash_attn_with_kvcache( - # q=q, - # k_cache=k_cache.reshape(-1, 1, kv_heads, head_dim), - # v_cache=v_cache.reshape(-1, 1, kv_heads, head_dim), - # page_table=page_table, - # cache_seqlens=infer_state.b_seq_len, - # cu_seqlens_q=q_starts, - # cu_seqlens_k_new=kv_starts, - # max_seqlen_q=N_CTX, - # causal=True, - # window_size=(-1, -1), - # softcap=0.0, - # return_softmax_lse=False, - # ) - - q, q_scale = q_per_head_fp8_quant(q.view(q.shape[0], kv_heads, -1), q_lens, q_starts) - k, k_scale = kv_quantize_per_head_fp8(k_cache[page_table], b_seq_len) - v, v_scale = kv_quantize_per_head_fp8(v_cache[page_table], b_seq_len) - o1 = flash_attn_with_kvcache( - q=q.view(-1, q_heads, head_dim), - k_cache=k.view(-1, N_CTX, kv_heads, head_dim).to(torch.float8_e4m3fn), - v_cache=v.view(-1, N_CTX, kv_heads, head_dim).to(torch.float8_e4m3fn), - # page_table=page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=q_starts, - cu_seqlens_k_new=kv_starts, - max_seqlen_q=N_CTX, - causal=True, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale.view(batch_size, kv_heads), - k_descale=k_scale.view(batch_size, kv_heads), - v_descale=v_scale.view(batch_size, kv_heads), - return_softmax_lse=False, - ) - - # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1.item() == 1 - - -if __name__ == "__main__": - test_context_attention_fwd_fa3_fp8(32, 16384, 32, 4, 128) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py 
b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py deleted file mode 100644 index 5ee2306ad..000000000 --- a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py +++ /dev/null @@ -1,145 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -import flashinfer -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - -if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant -else: - scaled_fp8_quant = None - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_context_attention_fwd_flashinfer_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 64 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") - if N_CTX > 1: - b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - q_lens = b_seq_len - b_ready_cache_len - q_start_loc = q_lens.cumsum(0) - q_lens - kv_start_loc = b_seq_len.cumsum(0) - b_seq_len - - q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_ready_cache_len = b_ready_cache_len - infer_state.b_start_loc = q_start_loc - - context_attention_fwd( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - req_to_token_indexs, - ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_size = 1 - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) - q_starts = torch.zeros((Z + 1,)).int().cuda() - q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) - kv_starts = torch.zeros_like(q_starts) - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - q_indptr = q_starts.int() - kv_indptr = kv_starts.int() - kv_indices = torch.arange(Z * 
N_CTX).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, kv_start_loc): - kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] - kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) - wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, - qo_indptr_buf=q_indptr, - paged_kv_indptr_buf=kv_indptr, - paged_kv_indices_buf=kv_indices, - paged_kv_last_page_len_buf=kv_last_page_len_buffer, - ) - kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32) - k_cache = kv[:, :KV_HEADS, :].contiguous() - v_cache = kv[:, KV_HEADS:, :].contiguous() - k, k_scale = scaled_fp8_quant(k_cache.view(1, -1)) - v, v_scale = scaled_fp8_quant(v_cache.view(1, -1)) - wrapper.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - q_heads, - kv_heads, - head_dim, - page_size, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=0.0, - q_data_type=q.dtype, - kv_data_type=torch.float8_e4m3fn, - ) - wrapper.run( - q, - (k.view(-1, 1, kv_heads, head_dim), v.view(-1, 1, kv_heads, head_dim)), - k_scale=k_scale, - v_scale=v_scale, - out=o1, - return_lse=False, - ) - - # assert torch.allclose(o, o1, atol=1e-2, rtol=2e-1) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1 == 1 - - -if __name__ == "__main__": - test_context_attention_fwd_flashinfer_fp8(16, 1024, 28, 4, 128) diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py deleted file mode 100644 index 2ba085cc9..000000000 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest -import torch -from lightllm.utils.light_utils import light_ops - - -def create_tensors(shared_seq_len): - batch_size = 4 - num_heads = 32 - kv_head_num = 8 - seq_len = 256 - head_dim = 128 - max_len_in_batch = seq_len - block_seq = 256 - max_batch_group_size = 4 - quant_group_size = 8 - - test_dtype = torch.bfloat16 - - kv_shape = (batch_size * seq_len, kv_head_num, head_dim) - kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) - - q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") - k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") - k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") - v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") - v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") - Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len) - B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") - b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") - b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda") - b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda") - mid_out = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" - ) - mid_out_logsumexp = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2), dtype=q.dtype, device="cuda" - ) - - return { - "q": q, - "k": k, - "k_scale": k_scale, - "v": v, - "v_scale": v_scale, - "Req_to_tokens": Req_to_tokens, - "B_req_idx": B_req_idx, - "b_seq_len": b_seq_len, - "b_shared_seq_len": b_shared_seq_len, 
- "b_mark_shared_group": b_mark_shared_group, - "max_len_in_batch": max_len_in_batch, - "mid_out": mid_out, - "mid_out_logsumexp": mid_out_logsumexp, - "block_seq": block_seq, - "max_batch_group_size": max_batch_group_size, - "head_dim": head_dim, - } - - -@pytest.mark.parametrize("shared_seq_len", [0, 47, 77, 128, 200, 255]) -def test_flash_decode_stage2_execution(shared_seq_len): - setup_tensors = create_tensors(shared_seq_len) - - light_ops.group8_int8kv_flashdecoding_diverse_stage2( - setup_tensors["block_seq"], - setup_tensors["mid_out"], - setup_tensors["mid_out_logsumexp"], - 1.0 / (setup_tensors["head_dim"] ** 0.5), - setup_tensors["q"], - setup_tensors["k"], - setup_tensors["k_scale"], - setup_tensors["v"], - setup_tensors["v_scale"], - setup_tensors["Req_to_tokens"], - setup_tensors["B_req_idx"], - setup_tensors["b_seq_len"], - setup_tensors["b_shared_seq_len"], - setup_tensors["max_len_in_batch"], - ) - seq_block_idx = (setup_tensors["b_shared_seq_len"][0].item() + setup_tensors["block_seq"] - 1) // setup_tensors[ - "block_seq" - ] - mid_out = setup_tensors["mid_out"][:, :, seq_block_idx:, :] - mid_out_logsumexp = setup_tensors["mid_out_logsumexp"][:, :, seq_block_idx:] - - q = setup_tensors["q"] - k = setup_tensors["k"] - v = setup_tensors["v"] - true_mid_out = torch.zeros_like(mid_out) - true_mid_out_logsumexp = torch.zeros_like(mid_out_logsumexp) - new_q = q - new_k = k.to(q.dtype) - new_v = v.to(q.dtype) - - b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"] - req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :] - - from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import ( - flash_decode_stage1 as gqa_flash_decode_stage1, - ) - - gqa_flash_decode_stage1( - q=new_q, - k=new_k, - v=new_v, - Req_to_tokens=req_to_tokens, - B_req_idx=setup_tensors["B_req_idx"], - B_Seqlen=b_seq_len, - max_len_in_batch=setup_tensors["max_len_in_batch"], - mid_out=true_mid_out, - mid_out_logsumexp=true_mid_out_logsumexp, - block_seq=setup_tensors["block_seq"], - ) - print(f"\nshared_seq_len={shared_seq_len}") - print(f"mid_out: {mid_out[0:4, 0, 0, 0]}") - print(f"true_mid_out: {true_mid_out[0:4, 0, 0, 0]}") - abs_diff = (mid_out - true_mid_out).abs() - max_diff = abs_diff.max() - max_diff_idx = abs_diff.argmax() - max_diff_idx_unraveled = torch.unravel_index(max_diff_idx, abs_diff.shape) - mid_out_value = mid_out[max_diff_idx_unraveled] - true_mid_out_value = true_mid_out[max_diff_idx_unraveled] - print(f"max abs diff: {max_diff}, mid_out value: {mid_out_value}, " f"true_mid_out value: {true_mid_out_value}") - - assert torch.allclose( - mid_out[0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0], atol=1e-2 - ), f"Mid output does not match expected values for shared_seq_len={shared_seq_len}" - assert torch.allclose( - mid_out_logsumexp, true_mid_out_logsumexp, atol=1e-2 - ), f"LogSumExp output does not match expected values for shared_seq_len={shared_seq_len}" diff --git a/unit_tests/models/llama/test_token_attention_nopad.py b/unit_tests/models/llama/test_token_attention_nopad.py deleted file mode 100644 index 1bbb29166..000000000 --- a/unit_tests/models/llama/test_token_attention_nopad.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -import flashinfer -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager -from 
lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state): - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, q_h, h_dim) - - att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() - - token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - v, - o, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_token_attention_nopad(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX - b_start_loc = torch.arange(Z).cuda().int() * N_CTX - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - - o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc - - ref_token_attention_nopad( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - Q_HEADS, - HEAD_DIM, - infer_state, - ) - # gqa_decode_attention_fwd( - # q, - # kv[:,:KV_HEADS,:], - # kv[:,KV_HEADS:,:], - # o, - # infer_state.req_manager.req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_seq_len, - # ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_size = 1 - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) - kv_starts = torch.zeros((Z + 1,)).int().cuda() - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - kv_indptr = kv_starts - kv_indices = torch.arange(Z * N_CTX).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] - kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) - wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - 
workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=True, - paged_kv_indptr_buffer=kv_indptr, - paged_kv_indices_buffer=kv_indices, - paged_kv_last_page_len_buffer=kv_last_page_len_buffer, - ) - kv_last_page_len_buffer = torch.full((batch_size,), page_size, dtype=torch.int32) - wrapper.plan( - kv_indptr, - kv_indices, - kv_last_page_len_buffer, - q_heads, - kv_heads, - head_dim, - page_size, - q_data_type=dtype, - non_blocking=True, - ) - kv = kv.unsqueeze(1) - wrapper.run(q, (kv[:, :, :KV_HEADS, :], kv[:, :, KV_HEADS:, :]), out=o1, return_lse=False) - - cos_sim1 = F.cosine_similarity(o, o1).mean() - assert cos_sim1 == 1.0 diff --git a/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py b/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py deleted file mode 100644 index a7f48ab89..000000000 --- a/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py +++ /dev/null @@ -1,187 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd -from lightllm.utils.sgl_utils import flash_attn_with_kvcache -from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): - device = kv_buffer.device - B = seq_lens.size(0) - min_fp8 = torch.finfo(torch.float8_e4m3fn).min - max_fp8 = torch.finfo(torch.float8_e4m3fn).max - _, S_max, H, D = kv_buffer.shape - seq_range = torch.arange(S_max, device=device)[None, :] - valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) - masked = kv_buffer * valid_mask - max_per_bh = masked.float().abs().amax(dim=(1, 3)) # [B, H] - scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)) - scales_exp = scales.view(B, 1, H, 1) - q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) - return q, scales - - -def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, q_h, h_dim) - - att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() - - token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - v, - o, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_token_attention_nopad_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - 
dtype = torch.bfloat16 - q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_start_loc = b_seq_len.cumsum(0) - b_seq_len - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - - o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc - - ref_token_attention_nopad( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - Q_HEADS, - HEAD_DIM, - infer_state, - req_to_token_indexs, - ) - # gqa_decode_attention_fwd( - # q, - # kv[:,:KV_HEADS,:], - # kv[:,KV_HEADS:,:], - # o, - # req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_seq_len, - # ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - kv_starts = torch.zeros((Z + 1,)).int().cuda() - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - q_starts = torch.arange(0, Z + 1).int().cuda() - page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32).to(0) - page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) - - k_cache = kv[:, :KV_HEADS, :].contiguous() - v_cache = kv[:, KV_HEADS:, :].contiguous() - # o1 = flash_attn_with_kvcache( - # q=q, - # k_cache=k_cache[page_table].view(-1, N_CTX, kv_heads, head_dim), - # v_cache=v_cache[page_table].view(-1, N_CTX, kv_heads, head_dim), - # # page_table=page_table, - # cache_seqlens=infer_state.b_seq_len, - # cu_seqlens_q=q_starts, - # cu_seqlens_k_new=kv_starts, - # max_seqlen_q=1, - # causal=False, - # window_size=(-1, -1), - # softcap=0.0, - # return_softmax_lse=False, - # ) - - q, q_scale = scaled_fp8_quant(q.view(batch_size * kv_heads, -1), use_per_token_if_dynamic=True) - k, k_scale = kv_quantize_per_head_fp8(k_cache[page_table], b_seq_len) - v, v_scale = kv_quantize_per_head_fp8(v_cache[page_table], b_seq_len) - o1 = flash_attn_with_kvcache( - q=q.view(-1, q_heads, head_dim), - k_cache=k.view(-1, N_CTX, kv_heads, head_dim), - v_cache=v.view(-1, N_CTX, kv_heads, head_dim), - # page_table=page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=q_starts, - cu_seqlens_k_new=kv_starts, - max_seqlen_q=1, - causal=False, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale.view(batch_size, kv_heads), - k_descale=k_scale.view(batch_size, kv_heads), - v_descale=v_scale.view(batch_size, kv_heads), - return_softmax_lse=False, - ) - - # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1 == 1 - - -if __name__ == "__main__": - test_token_attention_nopad_fa3_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py deleted file mode 100644 index 
5c0e595b9..000000000 --- a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -import flashinfer -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd -from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, q_h, h_dim) - - att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() - - token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - v, - o, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_token_attention_nopad_flashinfer_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_start_loc = b_seq_len.cumsum(0) - b_seq_len - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - - o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc - - ref_token_attention_nopad( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - Q_HEADS, - HEAD_DIM, - infer_state, - req_to_token_indexs, - ) - # gqa_decode_attention_fwd( - # q, - # kv[:,:KV_HEADS,:], - # kv[:,KV_HEADS:,:], - # o, - # req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_seq_len, - # ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - 
page_size = 1
-    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0)
-    kv_starts = torch.zeros((Z + 1,)).int().cuda()
-    kv_starts[1:] = torch.cumsum(b_seq_len, dim=0)
-    kv_indptr = kv_starts
-    kv_indices = torch.arange(Z * N_CTX).cuda().int()
-    for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc):
-        kv_indices[start : start + sl] = req_to_token_indexs[b][:sl]
-    kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32)
-    wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer,
-        "NHD",
-        use_cuda_graph=True,
-        use_tensor_cores=True,
-        paged_kv_indptr_buffer=kv_indptr,
-        paged_kv_indices_buffer=kv_indices,
-        paged_kv_last_page_len_buffer=kv_last_page_len_buffer,
-    )
-    kv_last_page_len_buffer = torch.full((batch_size,), page_size, dtype=torch.int32)
-    k_cache = kv[:, :KV_HEADS, :].contiguous()
-    v_cache = kv[:, KV_HEADS:, :].contiguous()
-    k, k_scale = scaled_fp8_quant(k_cache.view(1, -1))
-    v, v_scale = scaled_fp8_quant(v_cache.view(1, -1))
-    wrapper.plan(
-        kv_indptr,
-        kv_indices,
-        kv_last_page_len_buffer,
-        q_heads,
-        kv_heads,
-        head_dim,
-        page_size,
-        q_data_type=dtype,
-        kv_data_type=torch.float8_e4m3fn,
-        non_blocking=True,
-    )
-    wrapper.run(
-        q,
-        (k.view(-1, 1, kv_heads, head_dim), v.view(-1, 1, kv_heads, head_dim)),
-        k_scale=k_scale,
-        v_scale=v_scale,
-        out=o1,
-        return_lse=False,
-    )
-
-    cos_sim1 = F.cosine_similarity(o, o1).mean()
-    print(cos_sim1)
-    assert cos_sim1 == 1.0
-
-
-if __name__ == "__main__":
-    test_token_attention_nopad_flashinfer_fp8(16, 16384, 28, 4, 128)
diff --git a/unit_tests/models/qwen3-vl/test_deepstack_emb.py b/unit_tests/models/qwen3-vl/test_deepstack_emb.py
index 2f929fe0d..f629a1635 100644
--- a/unit_tests/models/qwen3-vl/test_deepstack_emb.py
+++ b/unit_tests/models/qwen3-vl/test_deepstack_emb.py
@@ -50,7 +50,7 @@ def test_deepstack_same_image_twice():
         deepstack_embs=deepstack_embs,
         img_token_lens=img_token_lens,
         img_start_token_ids=img_start_token_ids,
-        img_start_locs=img_start_locs,
+        img_start_locs_in_cache=img_start_locs,
     )
 
     # 7. 看看相同图片两段上的增量
diff --git a/unit_tests/server/core/objs/test_req.py b/unit_tests/server/core/objs/test_req.py
index 45fa7967f..1c946531c 100644
--- a/unit_tests/server/core/objs/test_req.py
+++ b/unit_tests/server/core/objs/test_req.py
@@ -1,6 +1,22 @@
 import pytest
-
+import easydict
 from lightllm.server.core.objs.req import Req, TokenHealingReq, ChunkedPrefillReq, SamplingParams
+from lightllm.utils.envs_utils import set_env_start_args
+
+
+@pytest.fixture(scope="module", autouse=True)
+def setup_module_env():
+    set_env_start_args(
+        easydict.EasyDict(
+            {
+                "mtp_step": 0,
+                "llm_prefill_att_backend": ["None"],
+                "llm_decode_att_backend": ["None"],
+                "cpu_cache_token_page_size": 256,
+                "enable_cpu_cache": False,
+            }
+        )
+    )
 
 
 @pytest.fixture
diff --git a/unit_tests/server/core/objs/test_shm_req_manager.py b/unit_tests/server/core/objs/test_shm_req_manager.py
index 1d1ae2ef1..e26f128d5 100644
--- a/unit_tests/server/core/objs/test_shm_req_manager.py
+++ b/unit_tests/server/core/objs/test_shm_req_manager.py
@@ -14,8 +14,11 @@ def setup_env():
             running_max_req_size=10,
             disable_chunked_prefill=True,
             token_healing_mode=False,
-            enable_flashinfer_prefill=False,
-            enable_flashinfer_decode=False,
+            mtp_step=0,
+            llm_prefill_att_backend=["None"],
+            llm_decode_att_backend=["None"],
+            cpu_cache_token_page_size=256,
+            enable_cpu_cache=False,
         )
     )
     # clear the lru_cache if used