diff --git a/.buildkite/hardware_tests/cpu.yaml b/.buildkite/hardware_tests/cpu.yaml
index e48fa4869ae..3db49d579e3 100644
--- a/.buildkite/hardware_tests/cpu.yaml
+++ b/.buildkite/hardware_tests/cpu.yaml
@@ -28,18 +28,19 @@ steps:
pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py
pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py"
-- label: CPU-Compatibility Tests
- depends_on: []
- device: intel_cpu
- no_plugin: true
- source_file_dependencies:
- - cmake/cpu_extension.cmake
- - setup.py
- - vllm/platforms/cpu.py
- commands:
- - |
- bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
- bash .buildkite/scripts/hardware_ci/run-cpu-compatibility-test.sh"
+# Note: SDE can't be downloaded from CI host because of AWS WAF
+# - label: CPU-Compatibility Tests
+# depends_on: []
+# device: intel_cpu
+# no_plugin: true
+# source_file_dependencies:
+# - cmake/cpu_extension.cmake
+# - setup.py
+# - vllm/platforms/cpu.py
+# commands:
+# - |
+# bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
+# bash .buildkite/scripts/hardware_ci/run-cpu-compatibility-test.sh"
- label: CPU-Language Generation and Pooling Model Tests
depends_on: []
diff --git a/.buildkite/intel_jobs/test-intel.yaml b/.buildkite/intel_jobs/test-intel.yaml
index 805b7e54f12..63ce93c4810 100644
--- a/.buildkite/intel_jobs/test-intel.yaml
+++ b/.buildkite/intel_jobs/test-intel.yaml
@@ -40,7 +40,8 @@ steps:
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192 &&
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 &&
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel &&
- python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --max-model-len 8192
+ python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --max-model-len 8192 &&
+ VLLM_XPU_FUSED_MOE_USE_REF=1 python3 examples/basic/offline_inference/generate.py --model Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --enforce-eager -tp 2
'
- label: "XPU V1 test"
depends_on:
diff --git a/.github/mergify.yml b/.github/mergify.yml
index 6caec515d32..a5d4e609474 100644
--- a/.github/mergify.yml
+++ b/.github/mergify.yml
@@ -36,18 +36,6 @@ pull_request_rules:
For future commits, `pre-commit` will run automatically on changed files before each commit.
- > [!TIP]
- >
- > Is mypy failing?
- >
- > mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
- >
- > ```bash
- > # For mypy (substitute "3.10" with the failing version if needed)
- > pre-commit run --hook-stage manual mypy-3.10
- > ```
- >
-
- name: comment-dco-failure
description: Comment on PR when DCO check fails
conditions:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c11a80683f8..dff099e3697 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -148,33 +148,27 @@ repos:
language: python
entry: python tools/pre_commit/generate_nightly_torch_test.py
files: ^requirements/test/cuda\.(in|txt)$
- - id: mypy-local
- name: Run mypy locally for lowest supported Python version
- entry: python tools/pre_commit/mypy.py 0 "3.10"
- stages: [pre-commit] # Don't run in CI
+ - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
+ name: Run mypy for Python 3.10
+ entry: python tools/pre_commit/mypy.py "3.10"
<<: &mypy_common
language: python
types_or: [python, pyi]
require_serial: true
- additional_dependencies: ["mypy[faster-cache]==1.19.1", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
- name: Run mypy for Python 3.10
- entry: python tools/pre_commit/mypy.py 1 "3.10"
- <<: *mypy_common
- stages: [manual] # Only run in CI
+ additional_dependencies: ["mypy==1.20.2", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.11
- entry: python tools/pre_commit/mypy.py 1 "3.11"
+ entry: python tools/pre_commit/mypy.py "3.11"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.12
- entry: python tools/pre_commit/mypy.py 1 "3.12"
+ entry: python tools/pre_commit/mypy.py "3.12"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.13
- entry: python tools/pre_commit/mypy.py 1 "3.13"
+ entry: python tools/pre_commit/mypy.py "3.13"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: shellcheck
diff --git a/AGENTS.md b/AGENTS.md
index 6566523f48e..441b8d9fb73 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -98,11 +98,13 @@ pre-commit run --all-files
pre-commit run ruff-check --all-files
# Run mypy as it is in CI:
-pre-commit run mypy-3.10 --all-files --hook-stage manual
+pre-commit run mypy-3.12 --all-files --hook-stage manual
```
The line length limit for Python code is 88 characters. If you are not sure, use pre-commit to check.
+Use [Google-style docstrings](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) (`Args:`/`Returns:`/`Raises:` sections), not reStructuredText/Sphinx fields (`:param:`, `:return:`, `:rtype:`).
+
### Commit messages
Add attribution using commit trailers such as `Co-authored-by:` (other projects use `Assisted-by:` or `Generated-by:`). For example:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0652a5f066e..cd4a9c0b590 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -307,10 +307,7 @@ endif()
#
set(VLLM_EXT_SRC
- "csrc/cuda_view.cu"
- "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
- "csrc/cuda_utils_kernels.cu"
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
@@ -346,9 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
FetchContent_MakeAvailable(cutlass)
- list(APPEND VLLM_EXT_SRC
- "csrc/cutlass_extensions/common.cpp")
-
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
@@ -627,6 +621,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
#
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp"
+ "csrc/libtorch_stable/cuda_view.cu"
"csrc/libtorch_stable/activation_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
@@ -639,6 +634,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/layernorm_kernels.cu"
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
+ "csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/libtorch_stable/attention/merge_attn_states.cu"
"csrc/libtorch_stable/sampler.cu"
"csrc/libtorch_stable/topk.cu"
@@ -653,8 +649,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
- "csrc/cuda_utils_kernels.cu"
- "csrc/cutlass_extensions/common.cpp"
+ "csrc/libtorch_stable/cuda_utils_kernels.cu"
+ "csrc/libtorch_stable/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
@@ -1076,11 +1072,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
WITH_SOABI)
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
- # This ensures we only use C-shim APIs available in PyTorch 2.10.
+ # This ensures we only use C-shim APIs available in PyTorch 2.11.
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
- # which is currently set to 2.10.
+ # which is currently set to 2.11.
target_compile_definitions(_C_stable_libtorch PRIVATE
- TORCH_TARGET_VERSION=0x020A000000000000ULL)
+ TORCH_TARGET_VERSION=0x020B000000000000ULL)
# Needed to use cuda/hip APIs from C-shim
if(VLLM_GPU_LANG STREQUAL "CUDA")
diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu
deleted file mode 100644
index 73b368cb600..00000000000
--- a/csrc/cuda_view.cu
+++ /dev/null
@@ -1,59 +0,0 @@
-#include
-#include
-#include
-
-// This function assumes that `cpu_tensor` is a CPU tensor,
-// and that UVA (Unified Virtual Addressing) is enabled.
-torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
- TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
-
- // handle empty tensor
- if (cpu_tensor.numel() == 0) {
- return torch::empty(cpu_tensor.sizes(),
- cpu_tensor.options().device(torch::kCUDA));
- }
-
- if (cpu_tensor.is_pinned()) {
- // If CPU tensor is pinned, directly get the device pointer.
- void* host_ptr = const_cast(cpu_tensor.data_ptr());
- void* device_ptr = nullptr;
- cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
- TORCH_CHECK(err == cudaSuccess,
- "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
-
- return torch::from_blob(
- device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(),
- [base = cpu_tensor](void*) {}, // keep cpu tensor alive
- cpu_tensor.options().device(torch::kCUDA));
- }
-
- // If CPU tensor is not pinned, allocate a new pinned memory buffer.
- torch::Tensor contiguous_cpu = cpu_tensor.contiguous();
- size_t nbytes = contiguous_cpu.nbytes();
-
- void* host_ptr = nullptr;
- cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
- if (err != cudaSuccess) {
- AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err));
- }
-
- err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes,
- cudaMemcpyDefault);
- if (err != cudaSuccess) {
- cudaFreeHost(host_ptr);
- AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err));
- }
-
- void* device_ptr = nullptr;
- err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
- if (err != cudaSuccess) {
- cudaFreeHost(host_ptr);
- AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
- }
-
- auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
-
- return torch::from_blob(device_ptr, contiguous_cpu.sizes(),
- contiguous_cpu.strides(), deleter,
- contiguous_cpu.options().device(torch::kCUDA));
-}
\ No newline at end of file
diff --git a/csrc/cuda_utils_kernels.cu b/csrc/libtorch_stable/cuda_utils_kernels.cu
similarity index 100%
rename from csrc/cuda_utils_kernels.cu
rename to csrc/libtorch_stable/cuda_utils_kernels.cu
diff --git a/csrc/libtorch_stable/cuda_view.cu b/csrc/libtorch_stable/cuda_view.cu
new file mode 100644
index 00000000000..7bf8267470e
--- /dev/null
+++ b/csrc/libtorch_stable/cuda_view.cu
@@ -0,0 +1,76 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+// This function assumes that `cpu_tensor` is a CPU tensor,
+// and that UVA (Unified Virtual Addressing) is enabled.
+torch::stable::Tensor get_cuda_view_from_cpu_tensor(
+ torch::stable::Tensor& cpu_tensor) {
+ STD_TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
+
+ const auto dtype = cpu_tensor.scalar_type();
+ const auto layout = cpu_tensor.layout();
+ const torch::stable::Device cuda_dev(torch::headeronly::DeviceType::CUDA);
+
+ // handle empty tensor
+ if (cpu_tensor.numel() == 0) {
+ return torch::stable::empty(cpu_tensor.sizes(), dtype, layout, cuda_dev);
+ }
+
+ std::array is_pinned_stack{
+ torch::stable::detail::from(cpu_tensor),
+ torch::stable::detail::from(std::nullopt)};
+ TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
+ "aten::is_pinned", "", is_pinned_stack.data(), TORCH_ABI_VERSION));
+ if (torch::stable::detail::to(is_pinned_stack[0])) {
+ // If CPU tensor is pinned, directly get the device pointer.
+ void* host_ptr = const_cast(cpu_tensor.mutable_data_ptr());
+ void* device_ptr = nullptr;
+ cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
+ STD_TORCH_CHECK(err == cudaSuccess, "cudaHostGetDevicePointer failed: ",
+ cudaGetErrorString(err));
+
+ return torch::stable::from_blob(
+ device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(), cuda_dev, dtype,
+ [base = cpu_tensor](void*) {}); // keep cpu tensor alive
+ }
+
+ // If CPU tensor is not pinned, allocate a new pinned memory buffer.
+ torch::stable::Tensor contiguous_cpu = torch::stable::contiguous(cpu_tensor);
+ size_t nbytes = contiguous_cpu.numel() * contiguous_cpu.element_size();
+
+ void* host_ptr = nullptr;
+ cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
+ if (err != cudaSuccess) {
+ STD_TORCH_CHECK(false, "cudaHostAlloc failed: ", cudaGetErrorString(err));
+ }
+
+ err = cudaMemcpy(host_ptr, contiguous_cpu.const_data_ptr(), nbytes,
+ cudaMemcpyDefault);
+ if (err != cudaSuccess) {
+ cudaFreeHost(host_ptr);
+ STD_TORCH_CHECK(false, "cudaMemcpy failed: ", cudaGetErrorString(err));
+ }
+
+ void* device_ptr = nullptr;
+ err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
+ if (err != cudaSuccess) {
+ cudaFreeHost(host_ptr);
+ STD_TORCH_CHECK(
+ false, "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
+ }
+
+ auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
+
+ return torch::stable::from_blob(device_ptr, contiguous_cpu.sizes(),
+ contiguous_cpu.strides(), cuda_dev,
+ contiguous_cpu.scalar_type(), deleter);
+}
diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/libtorch_stable/cutlass_extensions/common.cpp
similarity index 90%
rename from csrc/cutlass_extensions/common.cpp
rename to csrc/libtorch_stable/cutlass_extensions/common.cpp
index 3d2093ab942..5bc9463bfa6 100644
--- a/csrc/cutlass_extensions/common.cpp
+++ b/csrc/libtorch_stable/cutlass_extensions/common.cpp
@@ -1,4 +1,4 @@
-#include "cutlass_extensions/common.hpp"
+#include "common.hpp"
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/libtorch_stable/cutlass_extensions/common.hpp
similarity index 100%
rename from csrc/cutlass_extensions/common.hpp
rename to csrc/libtorch_stable/cutlass_extensions/common.hpp
diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h
index 0a991de76ff..536693fee96 100644
--- a/csrc/libtorch_stable/ops.h
+++ b/csrc/libtorch_stable/ops.h
@@ -164,6 +164,10 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
#endif
+// CPU tensor -> CUDA UVA view (shared CUDA/ROCm)
+torch::stable::Tensor get_cuda_view_from_cpu_tensor(
+ torch::stable::Tensor& cpu_tensor);
+
// Attention kernels (shared CUDA/ROCm)
void merge_attn_states(
torch::stable::Tensor& output,
@@ -215,6 +219,13 @@ void rms_norm_per_block_quant(torch::stable::Tensor& out,
std::optional residual,
int64_t group_size, bool is_scale_transposed);
+void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
+ torch::stable::Tensor const& input,
+ torch::stable::Tensor& scales,
+ int64_t group_size,
+ std::optional scale_ub,
+ bool is_scale_transposed);
+
// Positional encoding kernels (shared CUDA/ROCm)
void rotary_embedding(torch::stable::Tensor& positions,
torch::stable::Tensor& query,
diff --git a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
index 1091d9d1230..53ffe521363 100644
--- a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
+++ b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
@@ -18,7 +18,7 @@
#include
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
diff --git a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu
index c2b8c0c00de..502f430b30b 100644
--- a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu
+++ b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu
@@ -21,7 +21,7 @@
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include
diff --git a/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu b/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu
index 8a493fdf22c..04e98b6076a 100644
--- a/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu
+++ b/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu
@@ -12,7 +12,7 @@
#include
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
index b22308d25ca..88caf03fda3 100644
--- a/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
@@ -20,7 +20,7 @@
#include
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu
index 8d4ba1accc7..e1e7e7a74da 100644
--- a/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu
+++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu
index d7b2a18e29c..bfb526fcd40 100644
--- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu
+++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu
index fc83c6e8d34..86355bf7060 100644
--- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cutlass/cutlass.h"
diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu
index 2baa00caa82..7adba6308fa 100644
--- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu
+++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cutlass/cutlass.h"
diff --git a/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu b/csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu
similarity index 63%
rename from csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu
rename to csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu
index d5c76232599..b32a7bd271f 100644
--- a/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu
+++ b/csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu
@@ -1,11 +1,10 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-#include
-#include
+#include "../../torch_utils.h"
#include "../../dispatch_utils.h"
-#include "libtorch_stable/quantization/fused_kernels/quant_conversions.cuh"
+#include "quant_conversions.cuh"
namespace vllm {
@@ -105,64 +104,70 @@ __global__ void silu_and_mul_per_block_quant_kernel(
} // namespace vllm
-void silu_and_mul_per_block_quant(torch::Tensor& out,
- torch::Tensor const& input,
- torch::Tensor& scales, int64_t group_size,
- std::optional scale_ub,
+void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
+ torch::stable::Tensor const& input,
+ torch::stable::Tensor& scales,
+ int64_t group_size,
+ std::optional scale_ub,
bool is_scale_transposed) {
- static c10::ScalarType kFp8Type = is_fp8_ocp()
- ? c10::ScalarType::Float8_e4m3fn
- : c10::ScalarType::Float8_e4m3fnuz;
+ static torch::headeronly::ScalarType kFp8Type =
+ is_fp8_ocp() ? torch::headeronly::ScalarType::Float8_e4m3fn
+ : torch::headeronly::ScalarType::Float8_e4m3fnuz;
- TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
- TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
- TORCH_CHECK(
- input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
+ STD_TORCH_CHECK(out.scalar_type() == kFp8Type ||
+ out.scalar_type() == torch::headeronly::ScalarType::Char);
+ STD_TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
+ STD_TORCH_CHECK(
+ input.scalar_type() == torch::headeronly::ScalarType::Half ||
+ input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Input must be FP16 or BF16");
- TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
- TORCH_CHECK(group_size == 128 || group_size == 64,
- "Unsupported group size: ", group_size);
+ STD_TORCH_CHECK(scales.scalar_type() == torch::headeronly::ScalarType::Float);
+ STD_TORCH_CHECK(group_size == 128 || group_size == 64,
+ "Unsupported group size: ", group_size);
if (scale_ub.has_value()) {
- TORCH_CHECK(out.dtype() == kFp8Type);
+ STD_TORCH_CHECK(out.scalar_type() == kFp8Type);
}
int32_t hidden_size = out.size(-1);
auto num_tokens = input.size(0);
int32_t num_groups = hidden_size / group_size;
- TORCH_CHECK(input.size(-1) == hidden_size * 2,
- "input last dim must be 2x output hidden_size");
- TORCH_CHECK(hidden_size % group_size == 0,
- "hidden_size must be divisible by group_size");
+ STD_TORCH_CHECK(input.size(-1) == hidden_size * 2,
+ "input last dim must be 2x output hidden_size");
+ STD_TORCH_CHECK(hidden_size % group_size == 0,
+ "hidden_size must be divisible by group_size");
- const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ const torch::stable::accelerator::DeviceGuard device_guard(
+ input.get_device_index());
+ const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
dim3 grid(num_tokens, num_groups);
dim3 block(group_size);
- VLLM_DISPATCH_FLOATING_TYPES(
+ VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_in_t = scalar_t;
- VLLM_DISPATCH_QUANT_TYPES(
+ VLLM_STABLE_DISPATCH_QUANT_TYPES(
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_out_t = scalar_t;
- VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
- VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
- vllm::silu_and_mul_per_block_quant_kernel<
- scalar_in_t, scalar_out_t, transpose_scale, gs>
- <<>>(
- out.data_ptr(),
- scales.data_ptr(),
- input.data_ptr(),
- scale_ub.has_value() ? scale_ub->data_ptr()
- : nullptr,
- hidden_size);
- });
+ VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
+ VLLM_STABLE_DISPATCH_BOOL(
+ is_scale_transposed, transpose_scale, [&] {
+ vllm::silu_and_mul_per_block_quant_kernel<
+ scalar_in_t, scalar_out_t, transpose_scale, gs>
+ <<>>(
+ out.mutable_data_ptr(),
+ scales.mutable_data_ptr(),
+ input.const_data_ptr(),
+ scale_ub.has_value()
+ ? scale_ub->const_data_ptr()
+ : nullptr,
+ hidden_size);
+ });
});
});
});
-}
\ No newline at end of file
+}
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
index ae40c0989e0..1eed7579924 100644
--- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
@@ -20,7 +20,7 @@
#include "cutlass/util/packed_stride.hpp"
#include "core/math.hpp"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
// clang-format on
namespace vllm::c3x {
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
index 952931103c6..4cb591be056 100644
--- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
@@ -15,7 +15,7 @@
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
// clang-format on
/*
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
index adb3de50fc1..913436186c3 100644
--- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
@@ -1,7 +1,7 @@
#include
#include
#include "cuda_utils.h"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
template
void dispatch_scaled_mm(torch::stable::Tensor& c,
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh
index 49df3fa4e7f..b523d7baeaa 100644
--- a/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh
@@ -8,7 +8,7 @@
#include
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"
using namespace cute;
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh
index 6eb2c051d00..7846e609fe7 100644
--- a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh
@@ -23,7 +23,7 @@
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "core/math.hpp"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
// clang-format on
using namespace cute;
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu
index 2e5bbca4700..0f9873cbf88 100644
--- a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu
@@ -4,7 +4,7 @@
#include "libtorch_stable/torch_utils.h"
-#include "cutlass_extensions/common.hpp"
+#include "libtorch_stable/cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp
index 511a788eeae..d376fd43aa7 100644
--- a/csrc/libtorch_stable/torch_bindings.cpp
+++ b/csrc/libtorch_stable/torch_bindings.cpp
@@ -1,4 +1,5 @@
#include "ops.h"
+#include "cuda_utils.h"
#include "core/registration.h"
#include
@@ -27,6 +28,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
+ ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
+
#ifndef USE_ROCM
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
@@ -321,6 +324,16 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor? scale_ub, Tensor!? residual, int group_size, "
"bool is_scale_transposed) -> ()");
+ // Fused SiLU+Mul + per-block quantization
+ ops.def(
+ "silu_and_mul_per_block_quant("
+ "Tensor! out, "
+ "Tensor input, "
+ "Tensor! scales, "
+ "int group_size, "
+ "Tensor? scale_ub=None, "
+ "bool is_scale_transposed=False) -> ()");
+
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
@@ -599,6 +612,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("rms_norm_dynamic_per_token_quant",
TORCH_BOX(&rms_norm_dynamic_per_token_quant));
ops.impl("rms_norm_per_block_quant", TORCH_BOX(&rms_norm_per_block_quant));
+ ops.impl("silu_and_mul_per_block_quant",
+ TORCH_BOX(&silu_and_mul_per_block_quant));
// Positional encoding kernels (shared CUDA/ROCm)
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
@@ -661,6 +676,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2));
}
+STABLE_TORCH_LIBRARY_IMPL(_C, CPU, ops) {
+ ops.impl("get_cuda_view_from_cpu_tensor",
+ TORCH_BOX(&get_cuda_view_from_cpu_tensor));
+}
+
// These capability-check functions take only primitive args (no tensors), so
// there is no device to dispatch on. CompositeExplicitAutograd makes them
// available for all backends. This is the stable ABI equivalent of calling
@@ -681,6 +701,19 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
}
+STABLE_TORCH_LIBRARY_FRAGMENT(_C_cuda_utils, cuda_utils) {
+ cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
+ cuda_utils.def(
+ "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(_C_cuda_utils, CompositeExplicitAutograd,
+ cuda_utils) {
+ cuda_utils.impl("get_device_attribute", TORCH_BOX(&get_device_attribute));
+ cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
+ TORCH_BOX(&get_max_shared_memory_per_block_device_attribute));
+}
+
// Cache ops
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
// Swap in (out) the cache blocks from src to dst.
diff --git a/csrc/ops.h b/csrc/ops.h
index ed2fca26b0d..11b704aff29 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
-void silu_and_mul_per_block_quant(torch::Tensor& out,
- torch::Tensor const& input,
- torch::Tensor& scales, int64_t group_size,
- std::optional scale_ub,
- bool is_scale_transposed);
-
// rotary_embedding also exist in csrc/libtorch_stable/ops.h (torch::stable
// ABI for CUDA). It remains here because the CPU build still uses these
// torch::Tensor declarations.
@@ -84,8 +78,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table, double scale);
-torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
-
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale,
std::optional const& azp);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 3351638f574..f29b941affd 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -2,7 +2,6 @@
// cache.h, which is no longer included here after cache ops moved to
// _C_stable_libtorch).
#include
-#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
#include
@@ -32,27 +31,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
- ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
- ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
- &get_cuda_view_from_cpu_tensor);
-
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
- // Fused SiLU+Mul + per-block quantization
- ops.def(
- "silu_and_mul_per_block_quant("
- "Tensor! out, "
- "Tensor input, "
- "Tensor! scales, "
- "int group_size, "
- "Tensor? scale_ub=None, "
- "bool is_scale_transposed=False) -> ()");
- ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
- &silu_and_mul_per_block_quant);
-
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4
@@ -178,18 +161,4 @@ TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
}
#endif
-TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
- // Cuda utils
-
- // Gets the specified device attribute.
- cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
- cuda_utils.impl("get_device_attribute", &get_device_attribute);
-
- // Gets the maximum shared memory per block device attribute.
- cuda_utils.def(
- "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
- cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
- &get_max_shared_memory_per_block_device_attribute);
-}
-
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md
index 5bf789a0919..80aec64ee5b 100644
--- a/docs/configuration/optimization.md
+++ b/docs/configuration/optimization.md
@@ -296,7 +296,7 @@ llm = LLM(model="Qwen/Qwen3-8B")
The `fastokens` Python package (>= 0.2.0) must be installed; if it isn't,
vLLM raises a clear `ImportError` at tokenizer load. The override applies to
any `--tokenizer-mode` that ends up loading an HF fast tokenizer (`hf`,
-`deepseek_v32`, `deepseek_v4`, `qwen_vl`, …). Modes that don't use the HF
+`deepseek_v32`, `deepseek_v4`, `qwen_vl`, …). Models that don't use the HF
fast tokenizer (`mistral`, `grok2`, `kimi_audio`) ignore the flag.
Tokenizer-bound workloads — long shared prefixes, bursty short prompts,
diff --git a/docs/contributing/README.md b/docs/contributing/README.md
index 9b5e26d0fed..3fc8b6dd52b 100644
--- a/docs/contributing/README.md
+++ b/docs/contributing/README.md
@@ -101,7 +101,7 @@ vLLM's `pre-commit` hooks will now run automatically every time you commit.
Some `pre-commit` hooks only run in CI. If you need to, you can run them locally with:
```bash
- pre-commit run --hook-stage manual mypy-3.10
+ pre-commit run --hook-stage manual mypy-3.11
```
### Documentation
diff --git a/docs/design/nixl_kv_cache_lease.md b/docs/design/nixl_kv_cache_lease.md
index a3fdaafe345..aa7683bb9e1 100644
--- a/docs/design/nixl_kv_cache_lease.md
+++ b/docs/design/nixl_kv_cache_lease.md
@@ -128,7 +128,7 @@ The lease mechanism is controlled through `kv_connector_extra_config` in `--kv-t
vllm serve \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
- "kv_role": "kv_both",
+ "kv_role": "kv_producer",
"kv_connector_extra_config": {"kv_lease_duration": 60}
}'
```
diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md
index cb5a3dca035..0f0cbd55354 100644
--- a/docs/features/nixl_connector_usage.md
+++ b/docs/features/nixl_connector_usage.md
@@ -50,7 +50,7 @@ To select a different backend, set `kv_connector_extra_config.backends` in `--kv
vllm serve \
--kv-transfer-config '{
"kv_connector":"NixlConnector",
- "kv_role":"kv_both",
+ "kv_role":"kv_producer",
"kv_connector_extra_config":{"backends":["LIBFABRIC"]}
}'
```
@@ -60,7 +60,7 @@ You can also pass JSON keys individually using dotted arguments, and you can app
```bash
vllm serve \
--kv-transfer-config.kv_connector NixlConnector \
- --kv-transfer-config.kv_role kv_both \
+ --kv-transfer-config.kv_role kv_producer \
--kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC
```
@@ -81,7 +81,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
vllm serve Qwen/Qwen3-0.6B \
--port 8100 \
--enforce-eager \
- --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}'
+ --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer","kv_load_failure_policy":"fail"}'
```
### Consumer (Decoder) Configuration
@@ -96,7 +96,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
vllm serve Qwen/Qwen3-0.6B \
--port 8200 \
--enforce-eager \
- --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}'
+ --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer","kv_load_failure_policy":"fail"}'
```
### Proxy Server
@@ -212,10 +212,21 @@ sequenceDiagram
Enable bidirectional KV transfer by setting `bidirectional_kv_xfer` in `kv_connector_extra_config` on **both** P and D instances:
```bash
+# Prefill instance
vllm serve \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
- "kv_role": "kv_both",
+ "kv_role": "kv_producer",
+ "kv_connector_extra_config": {
+ "bidirectional_kv_xfer": true
+ }
+ }'
+
+# Decode instance
+vllm serve \
+ --kv-transfer-config '{
+ "kv_connector": "NixlConnector",
+ "kv_role": "kv_consumer",
"kv_connector_extra_config": {
"bidirectional_kv_xfer": true
}
@@ -359,11 +370,10 @@ For multi-host DP deployment, only need to provide the host/port of the head ins
- **kv_producer**: For prefiller instances that generate KV caches
- **kv_consumer**: For decoder instances that consume KV caches from prefiller
-- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
+- **kv_both** (deprecated): Previously used as a catch-all when the role was not predetermined. This value is now deprecated for NixlConnector and will be removed in a future release.
-!!! tip
- NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
- Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
+!!! warning
+ `kv_role="kv_both"` is deprecated for NixlConnector. Please set `kv_role="kv_producer"` for prefill instances and `kv_role="kv_consumer"` for decode instances. See [#33702](https://github.com/vllm-project/vllm/issues/33702) for details.
### KV Load Failure Policy
diff --git a/docs/features/speculative_decoding/README.md b/docs/features/speculative_decoding/README.md
index 768e9f78d40..58d1df9dced 100644
--- a/docs/features/speculative_decoding/README.md
+++ b/docs/features/speculative_decoding/README.md
@@ -169,7 +169,7 @@ speculative decoding, breaking down the guarantees into three key areas:
> distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
> - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
> without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
- > provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](/tests/v1/spec_decode).
+ > provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](../../../tests/v1/spec_decode).
> verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291)
3. **vLLM Logprob Stability**
diff --git a/mkdocs.yaml b/mkdocs.yaml
index 1fee824f3b2..970bf963309 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -83,22 +83,22 @@ plugins:
- "re:vllm\\._.*" # Internal modules
- "vllm.third_party"
- "vllm.vllm_flash_attn"
- - "re:vllm\\.grpc\\..*_pb2.*" # Auto-generated protobuf files
+ - "vllm.transformers_utils.configs"
+ - "vllm.transformers_utils.processors"
- !ENV [API_AUTONAV_EXCLUDE, "re:^$"] # Match nothing by default
- mkdocstrings:
handlers:
python:
options:
- show_symbol_type_heading: true
- show_symbol_type_toc: true
- filters:
- - "!.*_pb2_grpc" # Exclude auto-generated gRPC stubs
- summary:
- modules: true
- show_signature_annotations: true
- separate_signature: true
+ filters: []
show_overloads: true
signature_crossrefs: true
+ # Recommendations from api-autonav
+ docstring_section_style: list
+ parameter_headings: true
+ show_symbol_type_heading: true
+ show_symbol_type_toc: true
+ summary: true
inventories:
- https://docs.python.org/3/objects.inv
- https://typing-extensions.readthedocs.io/en/latest/objects.inv
diff --git a/requirements/common.txt b/requirements/common.txt
index 8141dc8ea6b..8b37f3cd30c 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -32,7 +32,7 @@ partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0
msgspec
gguf >= 0.17.0
-mistral_common[image] >= 1.11.2
+mistral_common[image] >= 1.11.3
opencv-python-headless >= 4.13.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
diff --git a/requirements/test/cuda.in b/requirements/test/cuda.in
index 344a58ec1bb..8d7ad7d0aa2 100644
--- a/requirements/test/cuda.in
+++ b/requirements/test/cuda.in
@@ -31,7 +31,7 @@ torchaudio==2.11.0
torchvision==0.26.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
-mistral_common[image,audio] >= 1.11.2 # required for voxtral test
+mistral_common[image,audio] >= 1.11.3 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless >= 4.13.0 # required for video test
diff --git a/requirements/test/cuda.txt b/requirements/test/cuda.txt
index 7d847d10577..a3e1466c763 100644
--- a/requirements/test/cuda.txt
+++ b/requirements/test/cuda.txt
@@ -409,7 +409,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
-mistral-common==1.11.2
+mistral-common==1.11.3
# via
# -c requirements/common.txt
# -r requirements/test/cuda.in
diff --git a/requirements/test/nightly-torch.txt b/requirements/test/nightly-torch.txt
index 89fd4ea9b43..10eb7a62191 100644
--- a/requirements/test/nightly-torch.txt
+++ b/requirements/test/nightly-torch.txt
@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
-mistral_common[image,audio] >= 1.11.2 # required for voxtral test
+mistral_common[image,audio] >= 1.11.3 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.13.0 # required for video test
datamodel_code_generator # required for minicpm3 test
diff --git a/requirements/test/rocm.in b/requirements/test/rocm.in
index 0a615831774..ed10270f565 100644
--- a/requirements/test/rocm.in
+++ b/requirements/test/rocm.in
@@ -30,7 +30,7 @@ tblib # for pickling test exceptions
timm>=1.0.17 # required for internvl and gemma3n-mm test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
-mistral_common[image,audio]>=1.11.2 # required for voxtral test
+mistral_common[image,audio]>=1.11.3 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless>=4.13.0 # required for video test
diff --git a/requirements/test/rocm.txt b/requirements/test/rocm.txt
index e0232d8b6d3..eced7117116 100644
--- a/requirements/test/rocm.txt
+++ b/requirements/test/rocm.txt
@@ -512,7 +512,7 @@ mcp==1.27.0
# via -r requirements/test/../common.txt
mdurl==0.1.2
# via markdown-it-py
-mistral-common==1.11.2
+mistral-common==1.11.3
# via
# -c requirements/common.txt
# -r requirements/test/../common.txt
diff --git a/requirements/test/xpu.txt b/requirements/test/xpu.txt
index 5581d0a079c..6242ff420f6 100644
--- a/requirements/test/xpu.txt
+++ b/requirements/test/xpu.txt
@@ -266,7 +266,7 @@ mbstrdecoder==1.1.4
# typepy
mdurl==0.1.2
# via markdown-it-py
-mistral-common==1.11.2
+mistral-common==1.11.3
# via
# -c requirements/common.txt
# -r requirements/test/xpu.in
diff --git a/rust/src/chat/src/backend/hf.rs b/rust/src/chat/src/backend/hf.rs
index 6c3dddc8729..77ed24de854 100644
--- a/rust/src/chat/src/backend/hf.rs
+++ b/rust/src/chat/src/backend/hf.rs
@@ -38,13 +38,17 @@ impl HfChatBackend {
) -> Result {
let model_config = load_model_config(files.config_path.as_deref())?;
let model_type = model_config.model_type().unwrap_or_default();
- let multimodal_model_info = MultimodalModelInfo::from_paths(
- model_id.clone(),
- (!model_type.is_empty()).then_some(model_type.to_string()),
- files.config_path.as_deref(),
- files.preprocessor_config_path.as_deref(),
- tokenizer.clone(),
- )?;
+ let multimodal_model_info = if options.language_model_only {
+ None
+ } else {
+ MultimodalModelInfo::from_paths(
+ model_id.clone(),
+ (!model_type.is_empty()).then_some(model_type.to_string()),
+ files.config_path.as_deref(),
+ files.preprocessor_config_path.as_deref(),
+ tokenizer.clone(),
+ )?
+ };
let multimodal_render_info = resolve_multimodal_render_info(multimodal_model_info.as_ref());
let renderer = options.renderer.resolve(model_type);
@@ -225,6 +229,7 @@ mod tests {
"test-model".to_string(),
LoadModelBackendsOptions {
renderer,
+ language_model_only: false,
chat_template_content_format: Default::default(),
chat_template: None,
default_chat_template_kwargs: HashMap::new(),
@@ -267,6 +272,54 @@ mod tests {
assert_eq!(prompt, "hello");
}
+ #[test]
+ fn language_model_only_skips_multimodal_preprocessor_config() {
+ let mut files = resolved_files(
+ r#"{"model_type":"deepseek_v0_vl"}"#,
+ r#"{"chat_template":"{{ messages[0].content }}"}"#,
+ );
+ let preprocessor_config_path = files
+ .config_path
+ .as_ref()
+ .unwrap()
+ .parent()
+ .unwrap()
+ .join("preprocessor_config.json");
+ write_json(&preprocessor_config_path, r#"{"size":[672,672]}"#);
+ files.preprocessor_config_path = Some(preprocessor_config_path);
+
+ let backend = HfChatBackend::from_resolved_model_files(
+ files.clone(),
+ "test-model".to_string(),
+ LoadModelBackendsOptions {
+ language_model_only: true,
+ chat_template_content_format: Default::default(),
+ chat_template: None,
+ default_chat_template_kwargs: HashMap::new(),
+ ..Default::default()
+ },
+ test_tokenizer(),
+ )
+ .unwrap();
+
+ assert!(backend.multimodal_model_info().is_none());
+
+ let error = HfChatBackend::from_resolved_model_files(
+ files,
+ "test-model".to_string(),
+ LoadModelBackendsOptions {
+ chat_template_content_format: Default::default(),
+ chat_template: None,
+ default_chat_template_kwargs: HashMap::new(),
+ ..Default::default()
+ },
+ test_tokenizer(),
+ )
+ .err()
+ .expect("invalid preprocessor config should fail without language_model_only");
+ assert!(error.to_string().contains("failed to parse preprocessor_config.json"));
+ }
+
#[test]
fn explicit_deepseek_renderer_overrides_generic_model_type() {
let prompt = render_prompt(
diff --git a/rust/src/chat/src/backend/mod.rs b/rust/src/chat/src/backend/mod.rs
index f49ca673704..be609ba5d9e 100644
--- a/rust/src/chat/src/backend/mod.rs
+++ b/rust/src/chat/src/backend/mod.rs
@@ -60,6 +60,9 @@ pub type DynChatTextBackend = Arc;
pub struct LoadModelBackendsOptions {
/// Which chat renderer implementation to use.
pub renderer: RendererSelection,
+ /// Disable frontend-side multimodal preprocessing and render the model as
+ /// language-only.
+ pub language_model_only: bool,
/// How to serialize `message.content` when rendering the chat template.
pub chat_template_content_format: ChatTemplateContentFormatOption,
/// Optional server-default chat template override, provided either as an
diff --git a/rust/src/cmd/src/cli.rs b/rust/src/cmd/src/cli.rs
index ee7848fe0be..312d9365a0f 100644
--- a/rust/src/cmd/src/cli.rs
+++ b/rust/src/cmd/src/cli.rs
@@ -116,6 +116,10 @@ pub struct SharedRuntimeArgs {
#[arg(long = "tokenizer-mode", default_value_t)]
#[serde(default, rename = "tokenizer_mode")]
pub renderer: RendererSelection,
+ /// Disable multimodal inputs and treat the model as language-only.
+ #[arg(long)]
+ #[serde(default)]
+ pub language_model_only: bool,
/// Override the maximum model context length. When set, the frontend uses
/// this value instead of the model's `max_position_embeddings` from
/// `config.json`.
@@ -243,6 +247,7 @@ impl SharedRuntimeArgs {
tool_call_parser: self.tool_call_parser,
reasoning_parser: self.reasoning_parser,
renderer: self.renderer,
+ language_model_only: self.language_model_only,
chat_template: self.chat_template,
default_chat_template_kwargs: self.default_chat_template_kwargs,
chat_template_content_format: self.chat_template_content_format,
@@ -284,6 +289,7 @@ impl SharedRuntimeArgs {
tool_call_parser: self.tool_call_parser,
reasoning_parser: self.reasoning_parser,
renderer: self.renderer,
+ language_model_only: self.language_model_only,
chat_template: self.chat_template,
default_chat_template_kwargs: self.default_chat_template_kwargs,
chat_template_content_format: self.chat_template_content_format,
@@ -419,6 +425,7 @@ impl ServeArgs {
self.managed_engine.clone().into_config(
self.runtime.model.clone(),
self.runtime.max_model_len,
+ self.runtime.language_model_only,
handshake_port,
)
}
diff --git a/rust/src/cmd/src/cli/tests.rs b/rust/src/cmd/src/cli/tests.rs
index ea867e4673a..c56db543355 100644
--- a/rust/src/cmd/src/cli/tests.rs
+++ b/rust/src/cmd/src/cli/tests.rs
@@ -34,6 +34,7 @@ fn serve_args_forward_python_flags_with_separator() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
+ language_model_only: false,
max_model_len: Some(
512,
),
@@ -263,6 +264,7 @@ fn frontend_args_accept_json() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
+ language_model_only: false,
max_model_len: None,
grpc_port: None,
shutdown_timeout: 0,
@@ -321,7 +323,7 @@ fn frontend_args_json_accepts_supported_non_default_fields() {
"--output-address",
"ipc:///tmp/output.sock",
"--args-json",
- r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_ready_timeout_secs":42,"tool_call_parser":"hermes","reasoning_parser":"qwen3_thinking","tokenizer_mode":"deepseek_v32","max_model_len":8192,"shutdown_timeout":3}"#,
+ r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_ready_timeout_secs":42,"tool_call_parser":"hermes","reasoning_parser":"qwen3_thinking","tokenizer_mode":"deepseek_v32","language_model_only":true,"max_model_len":8192,"shutdown_timeout":3}"#,
])
.unwrap();
@@ -338,6 +340,7 @@ fn frontend_args_json_accepts_supported_non_default_fields() {
ParserSelection::Explicit("qwen3_thinking".to_string())
);
assert_eq!(args.runtime.renderer, RendererSelection::DeepSeekV32);
+ assert!(args.runtime.language_model_only);
assert_eq!(args.runtime.max_model_len, Some(8192));
assert_eq!(args.runtime.shutdown_timeout, 3);
}
@@ -662,6 +665,7 @@ fn serve_args_accept_handshake_aliases() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
+ language_model_only: false,
max_model_len: None,
grpc_port: None,
shutdown_timeout: 0,
@@ -783,6 +787,7 @@ fn serve_frontend_config_uses_dp_address_as_advertised_host() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
+ language_model_only: false,
chat_template: None,
default_chat_template_kwargs: None,
chat_template_content_format: Auto,
@@ -846,6 +851,7 @@ fn serve_frontend_config_keeps_tcp_transport_for_non_local_only_topology() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
+ language_model_only: false,
chat_template: None,
default_chat_template_kwargs: None,
chat_template_content_format: Auto,
@@ -924,6 +930,7 @@ fn frontend_config_uses_external_coordinator_when_coordinator_address_is_present
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
+ language_model_only: false,
chat_template: None,
default_chat_template_kwargs: None,
chat_template_content_format: Auto,
diff --git a/rust/src/engine-core-client/src/client/imp.rs b/rust/src/engine-core-client/src/client/imp.rs
index 9a66ad84cc1..44eae7e3e54 100644
--- a/rust/src/engine-core-client/src/client/imp.rs
+++ b/rust/src/engine-core-client/src/client/imp.rs
@@ -1,5 +1,4 @@
use std::collections::BTreeMap;
-use std::slice;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
@@ -253,33 +252,47 @@ pub(crate) async fn run_abort_loop(
inner: Arc,
mut abort_rx: mpsc::UnboundedReceiver,
) {
- // TODO: receive and abort requests in batch
- while let Some(AbortRequest { request_id, cause }) = abort_rx.recv().await {
- let Some(engine_id) = inner.take_auto_abort_target(&request_id) else {
- debug!(request_id, "skip auto-abort for inactive request");
- continue;
- };
+ // Coalesce bursts of auto-aborts into a single Abort message per engine.
+ // A dropped-stream storm (e.g. many clients disconnecting at once under
+ // high concurrency) would otherwise issue one engine round-trip per
+ // request. `recv_many` returns as soon as at least one item is ready, so a
+ // lone abort is still forwarded promptly.
+ const MAX_DRAIN: usize = 1024;
+ let mut batch: Vec = Vec::new();
- match cause {
- AbortCause::DroppedStream => {
- info!(request_id, "auto-aborting request due to dropped stream")
- }
- AbortCause::StopStringMatched => {
- debug!(
- request_id,
- "auto-aborting request due to stop string matched"
- )
+ while abort_rx.recv_many(&mut batch, MAX_DRAIN).await > 0 {
+ let mut by_engine: BTreeMap> = BTreeMap::new();
+
+ for AbortRequest { request_id, cause } in batch.drain(..) {
+ let Some(engine_id) = inner.take_auto_abort_target(&request_id) else {
+ debug!(request_id, "skip auto-abort for inactive request");
+ continue;
+ };
+
+ match cause {
+ AbortCause::DroppedStream => {
+ info!(request_id, "auto-aborting request due to dropped stream")
+ }
+ AbortCause::StopStringMatched => {
+ debug!(
+ request_id,
+ "auto-aborting request due to stop string matched"
+ )
+ }
}
+
+ by_engine.entry(engine_id).or_default().push(request_id);
}
- if let Err(error) = inner.do_abort_requests(&engine_id, slice::from_ref(&request_id)).await
- {
- warn!(
- request_id,
- ?engine_id,
- error = %error.as_report(),
- "failed to auto-abort dropped request stream"
- );
+ for (engine_id, request_ids) in by_engine {
+ if let Err(error) = inner.do_abort_requests(&engine_id, &request_ids).await {
+ warn!(
+ ?engine_id,
+ ?request_ids,
+ error = %error.as_report(),
+ "failed to auto-abort request streams"
+ );
+ }
}
}
}
diff --git a/rust/src/engine-core-client/src/tests/client.rs b/rust/src/engine-core-client/src/tests/client.rs
index 9a92ffe447e..32530d6d385 100644
--- a/rust/src/engine-core-client/src/tests/client.rs
+++ b/rust/src/engine-core-client/src/tests/client.rs
@@ -1225,6 +1225,86 @@ async fn dropping_a_live_stream_triggers_abort() {
client.shutdown().await.unwrap();
}
+#[tokio::test]
+async fn dropping_multiple_live_streams_aborts_all_in_a_burst() {
+ init_tracing();
+ let ipc = IpcNamespace::new().unwrap();
+ let handshake_address = ipc.handshake_endpoint();
+ let engine_id = b"engine-burst".to_vec();
+ let request_ids = ["req-1", "req-2", "req-3"];
+
+ let (shutdown_tx, engine_task) = spawn_mock_engine_task(
+ handshake_address.clone(),
+ engine_id.clone(),
+ |dealer, push| {
+ Box::pin(async move {
+ for _ in 0..3 {
+ let add = recv_engine_message(dealer).await;
+ assert_eq!(add[0].as_ref(), &[0x00]);
+ }
+ send_outputs(
+ push,
+ EngineCoreOutputs {
+ outputs: vec![
+ request_output("req-1", vec![99], None),
+ request_output("req-2", vec![99], None),
+ request_output("req-3", vec![99], None),
+ ],
+ ..Default::default()
+ },
+ )
+ .await;
+
+ let abort =
+ timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap();
+ assert_eq!(abort[0].as_ref(), &[0x01]);
+ let ids: Vec = rmp_serde::from_slice(&abort[1]).unwrap();
+ assert_eq!(
+ ids,
+ vec![
+ "req-1".to_string(),
+ "req-2".to_string(),
+ "req-3".to_string()
+ ]
+ );
+ assert!(
+ timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err()
+ );
+ })
+ },
+ );
+
+ let client = connect_client_with_ipc(
+ handshake_test_config(
+ handshake_address,
+ 1,
+ "test-model",
+ Duration::from_secs(2),
+ 0,
+ None,
+ ),
+ &ipc,
+ )
+ .await;
+
+ // Open every request first so all three adds reach the engine before it
+ // emits outputs, then drain the first token from each stream.
+ let mut streams = Vec::new();
+ for id in request_ids {
+ streams.push(client.call(sample_request_with_id(id)).await.unwrap());
+ }
+ for stream in streams.iter_mut() {
+ let first = timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap();
+ assert_eq!(first.new_token_ids, vec![99]);
+ }
+ // Drop the whole burst back-to-back so the abort worker can batch them.
+ drop(streams);
+
+ let _ = shutdown_tx.send(());
+ engine_task.await.unwrap();
+ client.shutdown().await.unwrap();
+}
+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn dispatcher_failure_propagates_to_streams_and_future_calls() {
init_tracing();
diff --git a/rust/src/managed-engine/src/cli.rs b/rust/src/managed-engine/src/cli.rs
index 302737dbd88..d70870dc32a 100644
--- a/rust/src/managed-engine/src/cli.rs
+++ b/rust/src/managed-engine/src/cli.rs
@@ -71,6 +71,7 @@ impl ManagedEngineArgs {
self,
model: String,
max_model_len: Option,
+ language_model_only: bool,
handshake_port: u16,
) -> ManagedEngineConfig {
let mut python_args = self.python_args;
@@ -79,6 +80,9 @@ impl ManagedEngineArgs {
python_args.push("--max-model-len".to_string());
python_args.push(max_model_len.to_string());
}
+ if language_model_only {
+ python_args.push("--language-model-only".to_string());
+ }
if let Some(data_parallel_size_local) = self.data_parallel_size_local {
python_args.push("--data-parallel-size-local".to_string());
python_args.push(data_parallel_size_local.to_string());
diff --git a/rust/src/server/examples/external_engine_openai_qwen.rs b/rust/src/server/examples/external_engine_openai_qwen.rs
index 50d6fc1be40..33da868876b 100644
--- a/rust/src/server/examples/external_engine_openai_qwen.rs
+++ b/rust/src/server/examples/external_engine_openai_qwen.rs
@@ -64,6 +64,7 @@ async fn main() -> Result<()> {
tool_call_parser: ParserSelection::Auto,
reasoning_parser: ParserSelection::Auto,
renderer: RendererSelection::Auto,
+ language_model_only: false,
chat_template: None,
default_chat_template_kwargs: None,
chat_template_content_format: ChatTemplateContentFormatOption::Auto,
diff --git a/rust/src/server/src/config.rs b/rust/src/server/src/config.rs
index f1599d18793..3018f57377a 100644
--- a/rust/src/server/src/config.rs
+++ b/rust/src/server/src/config.rs
@@ -53,6 +53,9 @@ pub struct Config {
pub reasoning_parser: ParserSelection,
/// Chat renderer selection.
pub renderer: RendererSelection,
+ /// Disable frontend-side multimodal preprocessing and render the model as
+ /// language-only.
+ pub language_model_only: bool,
/// Server-default chat template override, as a file path or inline
/// template.
pub chat_template: Option,
diff --git a/rust/src/server/src/lib.rs b/rust/src/server/src/lib.rs
index 8d779da132f..52ac7633292 100644
--- a/rust/src/server/src/lib.rs
+++ b/rust/src/server/src/lib.rs
@@ -42,6 +42,7 @@ async fn build_state(config: &Config) -> Result> {
&config.model,
LoadModelBackendsOptions {
renderer: config.renderer,
+ language_model_only: config.language_model_only,
chat_template: config.chat_template.clone(),
chat_template_content_format: config.chat_template_content_format,
default_chat_template_kwargs: config
diff --git a/rust/src/server/src/routes/openai/chat_completions.rs b/rust/src/server/src/routes/openai/chat_completions.rs
index 543a7e806c4..0cccd8bf4ab 100644
--- a/rust/src/server/src/routes/openai/chat_completions.rs
+++ b/rust/src/server/src/routes/openai/chat_completions.rs
@@ -85,6 +85,7 @@ pub async fn chat_completions(
log_request,
prepared.include_usage,
prepared.requested_logprobs,
+ prepared.include_reasoning,
prepared.echo,
prepared.return_token_ids,
prepared.return_tokens_as_token_ids,
@@ -100,6 +101,7 @@ pub async fn chat_completions(
created,
prepared.requested_logprobs,
prepared.include_prompt_logprobs,
+ prepared.include_reasoning,
prepared.echo,
prepared.return_token_ids,
prepared.return_tokens_as_token_ids,
@@ -134,6 +136,7 @@ async fn collect_chat_completion(
created: u64,
requested_logprobs: bool,
include_prompt_logprobs: bool,
+ include_reasoning: bool,
echo: Option,
return_token_ids: bool,
return_tokens_as_token_ids: bool,
@@ -157,6 +160,11 @@ async fn collect_chat_completion(
} = collected;
let stop_reason = finish_reason.as_stop_reason().map(stop_reason_to_json);
let saw_tool_calls = message.tool_calls().next().is_some();
+ let reasoning = message.reasoning();
+ // Output logprobs and token IDs cover the complete generated token stream.
+ // When reasoning is hidden, omit them rather than leaking hidden reasoning
+ // tokens through per-token metadata.
+ let include_output_metadata = include_reasoning || reasoning.is_none();
let finish_reason = chat_finish_reason_to_openai(&finish_reason, saw_tool_calls)?.to_string();
let tool_calls = message
.tool_calls()
@@ -169,7 +177,7 @@ async fn collect_chat_completion(
},
})
.collect::>();
- let logprobs = if requested_logprobs {
+ let logprobs = if requested_logprobs && include_output_metadata {
Some(decoded_logprobs_to_openai_chat(
logprobs.as_ref().ok_or_else(|| {
server_error!("chat response requested logprobs but generation returned none")
@@ -207,12 +215,12 @@ async fn collect_chat_completion(
None => Some(message.text()).filter(|t| !t.is_empty()),
},
tool_calls: Some(tool_calls).filter(|calls| !calls.is_empty()),
- reasoning: message.reasoning(),
+ reasoning: if include_reasoning { reasoning } else { None },
},
logprobs,
finish_reason: Some(finish_reason),
stop_reason,
- token_ids: return_token_ids.then_some(token_ids),
+ token_ids: (return_token_ids && include_output_metadata).then_some(token_ids),
}],
usage: Some(usage),
system_fingerprint: None,
@@ -232,12 +240,18 @@ async fn chat_completion_chunk_stream(
log_request: bool,
include_usage: bool,
requested_logprobs: bool,
+ include_reasoning: bool,
echo: Option,
return_token_ids: bool,
return_tokens_as_token_ids: bool,
mut y: TryYielder,
) -> Result<(), ApiError> {
let mut saw_tool_calls = false;
+ // `LogprobsDelta` is emitted after all chat events for one decoded update.
+ // If that update contains hidden reasoning, including delimiter-only block
+ // starts or ends, omit its token metadata as well as its visible delta.
+ let mut inside_hidden_reasoning = false;
+ let mut suppress_current_update_metadata = false;
// If the client requested logprobs or token_ids, we need to buffer chunks until
// we receive the separate `LogprobsDelta` event, so that we can emit one
@@ -268,29 +282,44 @@ async fn chat_completion_chunk_stream(
}
}
Ok(ChatEvent::BlockDelta { kind, delta, .. }) => {
- if let Some(pending_chunk) = pending_chunk.as_mut() {
- pending_chunk.push_block_delta(kind, delta);
+ let include_delta =
+ include_reasoning || !matches!(kind, AssistantBlockKind::Reasoning);
+ if include_delta {
+ if let Some(pending_chunk) = pending_chunk.as_mut() {
+ pending_chunk.push_block_delta(kind, delta);
+ } else {
+ y.yield_ok(block_delta_chunk(
+ &request_id,
+ &response_model,
+ created,
+ kind,
+ delta,
+ ))
+ .await;
+ }
} else {
- y.yield_ok(block_delta_chunk(
- &request_id,
- &response_model,
- created,
- kind,
- delta,
- ))
- .await;
+ suppress_current_update_metadata = true;
}
}
Ok(ChatEvent::LogprobsDelta {
logprobs,
token_ids,
}) => {
- let openai_logprobs = logprobs
- .as_ref()
- .map(|lp| decoded_logprobs_to_openai_chat(lp, return_tokens_as_token_ids))
- .transpose()?;
- let openai_token_ids =
- return_token_ids.then_some(token_ids).filter(|t| !t.is_empty());
+ let include_metadata =
+ !suppress_current_update_metadata && !inside_hidden_reasoning;
+ suppress_current_update_metadata = false;
+ let openai_logprobs = if include_metadata {
+ logprobs
+ .as_ref()
+ .map(|lp| decoded_logprobs_to_openai_chat(lp, return_tokens_as_token_ids))
+ .transpose()?
+ } else {
+ None
+ };
+ let openai_token_ids = include_metadata
+ .then_some(token_ids)
+ .and_then(|token_ids| return_token_ids.then_some(token_ids))
+ .filter(|t| !t.is_empty());
if let Some(pending_chunk) = pending_chunk.as_mut() {
pending_chunk.logprobs = openai_logprobs;
pending_chunk.token_ids = openai_token_ids;
@@ -311,9 +340,17 @@ async fn chat_completion_chunk_stream(
}
Ok(ChatEvent::BlockStart { kind, .. }) => {
debug!(?kind, "starting new block");
+ if !include_reasoning && matches!(kind, AssistantBlockKind::Reasoning) {
+ inside_hidden_reasoning = true;
+ suppress_current_update_metadata = true;
+ }
}
Ok(ChatEvent::BlockEnd { .. }) => {
debug!("ending current block");
+ if inside_hidden_reasoning {
+ inside_hidden_reasoning = false;
+ suppress_current_update_metadata = true;
+ }
}
Ok(ChatEvent::ToolCallStart { index, id, name }) => {
let tool_index = index as u32;
@@ -763,7 +800,9 @@ fn stop_reason_to_json(stop_reason: &StopReason) -> Value {
mod tests {
use futures::{StreamExt as _, stream};
use serde_json::json;
- use vllm_chat::{AssistantBlockKind, AssistantToolCall, ChatEvent, FinishReason};
+ use vllm_chat::{
+ AssistantBlockKind, AssistantContentBlock, AssistantToolCall, ChatEvent, FinishReason,
+ };
use vllm_engine_core_client::protocol::StopReason;
use vllm_text::{DecodedLogprobs, DecodedPositionLogprobs, DecodedTokenLogprob};
@@ -895,6 +934,7 @@ mod tests {
false,
false,
true,
+ true,
None,
false,
false,
@@ -958,6 +998,7 @@ mod tests {
false,
false,
true,
+ true,
None,
false,
false,
@@ -976,6 +1017,292 @@ mod tests {
assert!(chunks[1].choices[0].logprobs.is_some());
}
+ #[tokio::test]
+ async fn chunk_stream_omits_reasoning_delta_when_disabled() {
+ let stream = stream::iter(vec![
+ Ok(ChatEvent::Start {
+ prompt_token_ids: vec![].into(),
+ prompt_logprobs: None,
+ }),
+ Ok(ChatEvent::BlockDelta {
+ index: 0,
+ kind: AssistantBlockKind::Reasoning,
+ delta: "think".to_string(),
+ }),
+ Ok(ChatEvent::BlockDelta {
+ index: 1,
+ kind: AssistantBlockKind::Text,
+ delta: "answer".to_string(),
+ }),
+ Ok(ChatEvent::Done {
+ message: Default::default(),
+ prompt_token_count: 1,
+ output_token_count: 2,
+ finish_reason: FinishReason::stop_eos(),
+ kv_transfer_params: None,
+ }),
+ ]);
+
+ let chunks = chat_completion_chunk_stream(
+ stream,
+ "chatcmpl-1".to_string(),
+ "model".to_string(),
+ 1,
+ false,
+ false,
+ false,
+ false,
+ None,
+ false,
+ false,
+ )
+ .collect::>()
+ .await
+ .into_iter()
+ .collect::, _>>()
+ .expect("stream chunks");
+
+ assert_eq!(chunks.len(), 3);
+ assert_eq!(
+ chunks[1].choices[0].delta.content.as_deref(),
+ Some("answer")
+ );
+ assert!(
+ chunks
+ .iter()
+ .all(|chunk| chunk.choices.iter().all(|choice| choice.delta.reasoning.is_none()))
+ );
+ }
+
+ #[tokio::test]
+ async fn chunk_stream_omits_logprobs_for_suppressed_reasoning() {
+ let stream = stream::iter(vec![
+ Ok(ChatEvent::Start {
+ prompt_token_ids: vec![].into(),
+ prompt_logprobs: None,
+ }),
+ Ok(ChatEvent::BlockDelta {
+ index: 0,
+ kind: AssistantBlockKind::Reasoning,
+ delta: "think".to_string(),
+ }),
+ Ok(ChatEvent::LogprobsDelta {
+ logprobs: Some(DecodedLogprobs {
+ positions: vec![DecodedPositionLogprobs {
+ entries: vec![DecodedTokenLogprob {
+ token_id: 11,
+ token: "think".to_string(),
+ logprob: -0.1,
+ rank: 1,
+ }],
+ }],
+ }),
+ token_ids: vec![11],
+ }),
+ Ok(ChatEvent::BlockDelta {
+ index: 1,
+ kind: AssistantBlockKind::Text,
+ delta: "answer".to_string(),
+ }),
+ Ok(ChatEvent::LogprobsDelta {
+ logprobs: Some(DecodedLogprobs {
+ positions: vec![DecodedPositionLogprobs {
+ entries: vec![DecodedTokenLogprob {
+ token_id: 22,
+ token: "answer".to_string(),
+ logprob: -0.2,
+ rank: 1,
+ }],
+ }],
+ }),
+ token_ids: vec![22],
+ }),
+ Ok(ChatEvent::Done {
+ message: Default::default(),
+ prompt_token_count: 1,
+ output_token_count: 2,
+ finish_reason: FinishReason::stop_eos(),
+ kv_transfer_params: None,
+ }),
+ ]);
+
+ let chunks = chat_completion_chunk_stream(
+ stream,
+ "chatcmpl-1".to_string(),
+ "model".to_string(),
+ 1,
+ false,
+ false,
+ true,
+ false,
+ None,
+ true,
+ false,
+ )
+ .collect::>()
+ .await
+ .into_iter()
+ .collect::, _>>()
+ .expect("stream chunks");
+
+ assert_eq!(chunks.len(), 3);
+ let choice = &chunks[1].choices[0];
+ assert_eq!(choice.delta.content.as_deref(), Some("answer"));
+ assert_eq!(choice.token_ids.as_deref(), Some(&[22][..]));
+ let logprobs = choice.logprobs.as_ref().expect("answer logprobs");
+ let content = logprobs.content.as_ref().expect("logprobs content");
+ assert_eq!(content[0].token, "answer");
+ assert!(chunks.iter().all(|chunk| {
+ chunk.choices.iter().all(|choice| {
+ choice.delta.reasoning.is_none()
+ && choice.token_ids.as_deref() != Some(&[11][..])
+ && choice
+ .logprobs
+ .as_ref()
+ .and_then(|logprobs| logprobs.content.as_ref())
+ .is_none_or(|content| content.iter().all(|entry| entry.token != "think"))
+ })
+ }));
+ }
+
+ #[tokio::test]
+ async fn chunk_stream_omits_logprobs_for_hidden_reasoning_delimiters() {
+ let stream = stream::iter(vec![
+ Ok(ChatEvent::Start {
+ prompt_token_ids: vec![].into(),
+ prompt_logprobs: None,
+ }),
+ Ok(ChatEvent::BlockStart {
+ index: 0,
+ kind: AssistantBlockKind::Reasoning,
+ }),
+ Ok(ChatEvent::LogprobsDelta {
+ logprobs: Some(DecodedLogprobs {
+ positions: vec![DecodedPositionLogprobs {
+ entries: vec![DecodedTokenLogprob {
+ token_id: 11,
+ token: "".to_string(),
+ logprob: -0.1,
+ rank: 1,
+ }],
+ }],
+ }),
+ token_ids: vec![11],
+ }),
+ Ok(ChatEvent::BlockDelta {
+ index: 0,
+ kind: AssistantBlockKind::Reasoning,
+ delta: "think".to_string(),
+ }),
+ Ok(ChatEvent::LogprobsDelta {
+ logprobs: Some(DecodedLogprobs {
+ positions: vec![DecodedPositionLogprobs {
+ entries: vec![DecodedTokenLogprob {
+ token_id: 12,
+ token: "think".to_string(),
+ logprob: -0.2,
+ rank: 1,
+ }],
+ }],
+ }),
+ token_ids: vec![12],
+ }),
+ Ok(ChatEvent::BlockEnd {
+ index: 0,
+ block: AssistantContentBlock::Reasoning {
+ text: "think".to_string(),
+ },
+ }),
+ Ok(ChatEvent::LogprobsDelta {
+ logprobs: Some(DecodedLogprobs {
+ positions: vec![DecodedPositionLogprobs {
+ entries: vec![DecodedTokenLogprob {
+ token_id: 13,
+ token: "".to_string(),
+ logprob: -0.3,
+ rank: 1,
+ }],
+ }],
+ }),
+ token_ids: vec![13],
+ }),
+ Ok(ChatEvent::BlockStart {
+ index: 1,
+ kind: AssistantBlockKind::Text,
+ }),
+ Ok(ChatEvent::BlockDelta {
+ index: 1,
+ kind: AssistantBlockKind::Text,
+ delta: "answer".to_string(),
+ }),
+ Ok(ChatEvent::LogprobsDelta {
+ logprobs: Some(DecodedLogprobs {
+ positions: vec![DecodedPositionLogprobs {
+ entries: vec![DecodedTokenLogprob {
+ token_id: 22,
+ token: "answer".to_string(),
+ logprob: -0.4,
+ rank: 1,
+ }],
+ }],
+ }),
+ token_ids: vec![22],
+ }),
+ Ok(ChatEvent::Done {
+ message: Default::default(),
+ prompt_token_count: 1,
+ output_token_count: 4,
+ finish_reason: FinishReason::stop_eos(),
+ kv_transfer_params: None,
+ }),
+ ]);
+
+ let chunks = chat_completion_chunk_stream(
+ stream,
+ "chatcmpl-1".to_string(),
+ "model".to_string(),
+ 1,
+ false,
+ false,
+ true,
+ false,
+ None,
+ true,
+ false,
+ )
+ .collect::>()
+ .await
+ .into_iter()
+ .collect::, _>>()
+ .expect("stream chunks");
+
+ assert_eq!(chunks.len(), 3);
+ let choice = &chunks[1].choices[0];
+ assert_eq!(choice.delta.content.as_deref(), Some("answer"));
+ assert_eq!(choice.token_ids.as_deref(), Some(&[22][..]));
+ let logprobs = choice.logprobs.as_ref().expect("answer logprobs");
+ let content = logprobs.content.as_ref().expect("logprobs content");
+ assert_eq!(content[0].token, "answer");
+ assert!(chunks.iter().all(|chunk| {
+ chunk.choices.iter().all(|choice| {
+ choice.delta.reasoning.is_none()
+ && !choice
+ .token_ids
+ .as_ref()
+ .is_some_and(|ids| matches!(ids.as_slice(), [11] | [12] | [13]))
+ && choice
+ .logprobs
+ .as_ref()
+ .and_then(|logprobs| logprobs.content.as_ref())
+ .is_none_or(|content| {
+ content.iter().all(|entry| {
+ !matches!(entry.token.as_str(), "" | "think" | "")
+ })
+ })
+ })
+ }));
+ }
+
#[tokio::test]
async fn chunk_stream_preserves_tool_call_index_and_omits_id_from_arguments_delta() {
let stream = stream::iter(vec![
@@ -1017,6 +1344,7 @@ mod tests {
false,
false,
false,
+ true,
None,
false,
false,
diff --git a/rust/src/server/src/routes/openai/chat_completions/convert.rs b/rust/src/server/src/routes/openai/chat_completions/convert.rs
index 2701bef809c..ba75c9c1a8c 100644
--- a/rust/src/server/src/routes/openai/chat_completions/convert.rs
+++ b/rust/src/server/src/routes/openai/chat_completions/convert.rs
@@ -29,6 +29,8 @@ pub struct PreparedRequest {
pub requested_logprobs: bool,
/// Whether the caller requested top-level prompt logprobs.
pub include_prompt_logprobs: bool,
+ /// Whether to include reasoning content in OpenAI responses.
+ pub include_reasoning: bool,
/// Lowered chat request for `vllm-chat`.
pub chat_request: ChatRequest,
/// Last assistant-role message content to echo back when `echo=true`.
@@ -57,6 +59,7 @@ pub(crate) fn prepare_chat_request(
.as_ref()
.map(|request| request.lora_name.clone())
.unwrap_or_else(|| lora_resolution.model_names.first().cloned().unwrap_or_default());
+ let include_reasoning = request.include_reasoning;
let echo = request
.echo
.then(|| extract_last_assistant_content(&request.messages))
@@ -146,6 +149,7 @@ pub(crate) fn prepare_chat_request(
include_usage,
requested_logprobs,
include_prompt_logprobs,
+ include_reasoning,
chat_request,
echo,
return_token_ids: request.return_token_ids.unwrap_or(false),
@@ -480,6 +484,23 @@ mod tests {
assert_eq!(prepared.chat_request.tool_choice, ChatToolChoice::Auto);
}
+ #[test]
+ fn prepare_chat_request_preserves_include_reasoning_false() {
+ let request = ChatCompletionRequest {
+ include_reasoning: false,
+ ..base_request()
+ };
+
+ let prepared = prepare_chat_request(
+ request,
+ &served(&["Qwen/Qwen1.5-0.5B-Chat"]),
+ ResolvedRequestContext::default(),
+ )
+ .expect("request is valid");
+
+ assert!(!prepared.include_reasoning);
+ }
+
#[test]
fn prepare_chat_request_preserves_sampling_passthrough_fields() {
let request = ChatCompletionRequest {
diff --git a/rust/src/server/src/routes/openai/chat_completions/validate.rs b/rust/src/server/src/routes/openai/chat_completions/validate.rs
index fbd10eea0cb..379f2c2d39b 100644
--- a/rust/src/server/src/routes/openai/chat_completions/validate.rs
+++ b/rust/src/server/src/routes/openai/chat_completions/validate.rs
@@ -120,12 +120,6 @@ pub(super) fn validate_request_compat(
"thinking_token_budget",
"thinking_token_budget is not supported.",
)?;
- if !request.include_reasoning {
- bail_invalid_request!(
- param = "include_reasoning",
- "include_reasoning is not supported."
- );
- }
reject_non_default(
request.media_io_kwargs.as_ref(),
"media_io_kwargs",
@@ -312,6 +306,17 @@ mod tests {
.expect("reasoning_effort should be accepted");
}
+ #[test]
+ fn validate_request_compat_accepts_include_reasoning_false() {
+ let request = ChatCompletionRequest {
+ include_reasoning: false,
+ ..base_request()
+ };
+
+ validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
+ .expect("include_reasoning=false should be accepted");
+ }
+
#[test]
fn validate_request_compat_rejects_top_logprobs_without_logprobs() {
let request = ChatCompletionRequest {
diff --git a/rust/src/server/src/routes/tests.rs b/rust/src/server/src/routes/tests.rs
index a3e437e0480..0a269f4ce17 100644
--- a/rust/src/server/src/routes/tests.rs
+++ b/rust/src/server/src/routes/tests.rs
@@ -3492,6 +3492,170 @@ async fn reasoning_blocks_are_mapped_to_reasoning_sse_chunks() {
assert!(text.contains("\"content\":\"answer\""), "{text}");
}
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+#[serial]
+async fn include_reasoning_false_suppresses_reasoning_in_non_stream_chat() {
+ let (app, engine_task) = test_app_with_backend_and_stream_output_specs(
+ Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")),
+ vec![
+ (bytes_to_token_ids(b"think"), None),
+ (
+ bytes_to_token_ids(b"answer"),
+ Some(EngineCoreFinishReason::Length),
+ ),
+ ],
+ )
+ .await;
+
+ let response = app
+ .clone()
+ .call(
+ Request::builder()
+ .method("POST")
+ .uri("/v1/chat/completions")
+ .header("content-type", "application/json")
+ .body(Body::from(
+ json!({
+ "model": "Qwen/Qwen1.5-0.5B-Chat",
+ "stream": false,
+ "include_reasoning": false,
+ "messages": [{"role": "user", "content": "hello"}]
+ })
+ .to_string(),
+ ))
+ .expect("build request"),
+ )
+ .await
+ .expect("call app");
+
+ assert_eq!(response.status(), StatusCode::OK);
+ let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body");
+ engine_task.await.expect("mock engine task");
+ let text = String::from_utf8(body.to_vec()).expect("utf8 body");
+ let json: serde_json::Value = serde_json::from_str(&text).expect("decode json");
+
+ assert_eq!(json["choices"][0]["message"]["content"], "answer");
+ assert!(
+ json["choices"][0]["message"]
+ .as_object()
+ .is_some_and(|message| !message.contains_key("reasoning")),
+ "{text}"
+ );
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+#[serial]
+async fn include_reasoning_false_suppresses_non_stream_output_metadata() {
+ let ipc = IpcNamespace::new().expect("create ipc namespace");
+ let handshake_address = ipc.handshake_endpoint();
+ let engine_id = b"engine-openai-hidden-reasoning-logprobs".to_vec();
+
+ let engine_task = MockEngineTask::new(spawn_mock_engine_task(
+ handshake_address.clone(),
+ engine_id.clone(),
+ |dealer, push| {
+ boxed_test_future(async move {
+ let add = recv_engine_message(dealer).await;
+ let request: EngineCoreRequest =
+ rmp_serde::from_slice(&add[1]).expect("decode request");
+ let reasoning_token_ids = bytes_to_token_ids(b"think");
+ let answer_token_ids = bytes_to_token_ids(b"answer");
+
+ send_outputs(
+ push,
+ EngineCoreOutputs {
+ engine_index: 0,
+ outputs: vec![
+ request_output_with_logprobs(
+ &request.request_id,
+ reasoning_token_ids.clone(),
+ None,
+ None,
+ Some(sample_logprobs_for_tokens(&reasoning_token_ids)),
+ None,
+ ),
+ request_output_with_logprobs(
+ &request.request_id,
+ answer_token_ids.clone(),
+ Some(EngineCoreFinishReason::Length),
+ None,
+ Some(sample_logprobs_for_tokens(&answer_token_ids)),
+ None,
+ ),
+ ],
+ scheduler_stats: None,
+ timestamp: 0.0,
+ utility_output: None,
+ finished_requests: None,
+ wave_complete: None,
+ start_wave: None,
+ },
+ )
+ .await;
+ })
+ },
+ ));
+
+ let client = EngineCoreClient::connect(
+ EngineCoreClientConfig::new_single(handshake_address)
+ .with_model_name("test-model")
+ .with_local_input_output_addresses(
+ Some(ipc.input_endpoint()),
+ Some(ipc.output_endpoint()),
+ ),
+ )
+ .await
+ .expect("connect client");
+ let chat = ChatLlm::from_shared_backend(
+ test_llm(client),
+ Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")),
+ );
+ let mut app = build_router(Arc::new(AppState::new(
+ vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()],
+ chat,
+ )));
+
+ let response = app
+ .call(
+ Request::builder()
+ .method("POST")
+ .uri("/v1/chat/completions")
+ .header("content-type", "application/json")
+ .body(Body::from(
+ json!({
+ "model": "Qwen/Qwen1.5-0.5B-Chat",
+ "stream": false,
+ "include_reasoning": false,
+ "logprobs": true,
+ "return_token_ids": true,
+ "messages": [{"role": "user", "content": "hello"}]
+ })
+ .to_string(),
+ ))
+ .expect("build request"),
+ )
+ .await
+ .expect("call app");
+
+ assert_eq!(response.status(), StatusCode::OK);
+ let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body");
+ engine_task.await.expect("mock engine task");
+ let text = String::from_utf8(body.to_vec()).expect("utf8 body");
+ let json: serde_json::Value = serde_json::from_str(&text).expect("decode json");
+ let choice = json["choices"][0].as_object().expect("choice object");
+
+ assert_eq!(json["choices"][0]["message"]["content"], "answer");
+ assert!(
+ json["choices"][0]["message"]
+ .as_object()
+ .is_some_and(|message| !message.contains_key("reasoning")),
+ "{text}"
+ );
+ assert!(!choice.contains_key("logprobs"), "{text}");
+ assert!(!choice.contains_key("token_ids"), "{text}");
+ assert!(json["prompt_token_ids"].is_array(), "{text}");
+}
+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn tool_calls_are_mapped_to_tool_call_sse_chunks() {
diff --git a/rust/src/tokenizer/src/incremental.rs b/rust/src/tokenizer/src/incremental.rs
index 7a025d35e5c..462475fc918 100644
--- a/rust/src/tokenizer/src/incremental.rs
+++ b/rust/src/tokenizer/src/incremental.rs
@@ -115,6 +115,8 @@ impl IncrementalDecoder for DecodeStream<'_, T> {
fn next_chunk(&mut self) -> Option {
let cutoff = self.cumulative_output.len().saturating_sub(self.min_bytes_to_buffer);
+ // Ensure we split at a utf-8 char boundary.
+ let cutoff = self.cumulative_output.floor_char_boundary(cutoff);
(cutoff > self.output_index).then(|| {
let chunk = self.cumulative_output[self.output_index..cutoff].to_string();
self.output_index = cutoff;
@@ -356,4 +358,27 @@ mod tests {
assert_eq!(last_chunk.as_deref(), Some("lo!"));
assert_eq!(full_text, "Hello!");
}
+
+ #[test]
+ fn next_chunk_cutoff_respects_char_boundary() {
+ // Regression: next_chunk's cutoff (len - min_bytes_to_buffer) must be
+ // aligned to a UTF-8 char boundary like push_token/flush; otherwise
+ // streaming multi-byte output (CJK/emoji) with a hold-back buffer (set
+ // by a stop string) panics slicing cumulative_output mid-character.
+ let backend = Utf8Backend;
+ let mut decoder = backend.create_decode_stream(&[], false, 2);
+ let mut out = String::new();
+ for byte in "你好A".bytes() {
+ decoder.push_token(u32::from(byte)).unwrap();
+ if let Some(chunk) = decoder.next_chunk() {
+ out.push_str(&chunk);
+ }
+ }
+ let (last_chunk, full_text) = decoder.flush(None).unwrap();
+ if let Some(chunk) = last_chunk {
+ out.push_str(&chunk);
+ }
+ assert_eq!(full_text, "你好A");
+ assert_eq!(out, "你好A");
+ }
}
diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py
index 4cb199b5897..8e20b704fac 100644
--- a/tests/compile/test_graph_partition.py
+++ b/tests/compile/test_graph_partition.py
@@ -565,6 +565,8 @@ def test_size_used_in_multiple_consumer_subgraphs():
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(y, 0)
torch.compile(model_fn, backend=capturing_backend)(x, y)
+ assert captured_graph is not None, "Graph should be captured by backend"
+ assert captured_inputs is not None, "Example inputs should be captured by backend"
split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])
diff --git a/tests/distributed/test_dcp_a2a.py b/tests/distributed/test_dcp_a2a.py
index d80ed36be65..5ab0f3de97b 100644
--- a/tests/distributed/test_dcp_a2a.py
+++ b/tests/distributed/test_dcp_a2a.py
@@ -15,6 +15,7 @@ import pytest
import torch
import torch.distributed as dist
+import vllm.envs as envs
from vllm.config.parallel import ParallelConfig
from vllm.utils.network_utils import get_open_port
from vllm.utils.system_utils import update_environment_variables
@@ -379,7 +380,13 @@ def _distributed_packed_a2a_worker(env: dict[str, str]) -> None:
update_environment_variables(env)
local_rank = int(env["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
- dist.init_process_group(backend="nccl")
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ dist.init_process_group(
+ backend="cpu:gloo,cuda:nccl",
+ device_id=torch.device(f"cuda:{local_rank}"),
+ )
+ else:
+ dist.init_process_group(backend="nccl")
use_workspace = env.get("USE_WORKSPACE") == "1"
if use_workspace:
from vllm.v1.worker.workspace import init_workspace_manager
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index a1d5355d446..d7b04f68091 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -9,6 +9,7 @@ import pytest
import torch
import torch.distributed
+import vllm.envs as envs
from tests.utils import ensure_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
@@ -82,11 +83,18 @@ def test_pynccl():
@worker_fn_wrapper
def multiple_allreduce_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
- groups = [
- torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
- torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
- ]
- group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ # Eager-init path: parent PG has bound_device_id + a CPU backend,
+ # so split_group is supported.
+ group = torch.distributed.split_group(
+ split_ranks=[[0, 1], [2, 3]], backend="cpu:gloo,cuda:nccl"
+ )
+ else:
+ groups = [
+ torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
+ torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
+ ]
+ group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
# two groups can communicate independently
@@ -339,11 +347,16 @@ def test_pynccl_send_recv():
@worker_fn_wrapper
def multiple_send_recv_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
- groups = [
- torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
- torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
- ]
- group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ group = torch.distributed.split_group(
+ split_ranks=[[0, 2], [1, 3]], backend="cpu:gloo,cuda:nccl"
+ )
+ else:
+ groups = [
+ torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
+ torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
+ ]
+ group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
if torch.distributed.get_rank() == 0:
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py
index a9591f96a78..86eb82c962e 100644
--- a/tests/distributed/test_quick_all_reduce.py
+++ b/tests/distributed/test_quick_all_reduce.py
@@ -9,6 +9,7 @@ import ray
import torch
import torch.distributed as dist
+import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.quick_all_reduce import (
@@ -397,13 +398,27 @@ def qr_variable_input(rank, world_size):
ranks = []
for i in range(world_size):
ranks.append(i)
- dist.init_process_group(
- backend="nccl",
- init_method="tcp://127.0.0.1:29500",
- rank=rank,
- world_size=world_size,
- )
- cpu_group = torch.distributed.new_group(ranks, backend="nccl")
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ dist.init_process_group(
+ backend="cpu:gloo,cuda:nccl",
+ init_method="tcp://127.0.0.1:29500",
+ rank=rank,
+ world_size=world_size,
+ device_id=device,
+ )
+ else:
+ dist.init_process_group(
+ backend="nccl",
+ init_method="tcp://127.0.0.1:29500",
+ rank=rank,
+ world_size=world_size,
+ )
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ cpu_group = torch.distributed.split_group(
+ split_ranks=[ranks], backend="cpu:gloo,cuda:nccl"
+ )
+ else:
+ cpu_group = torch.distributed.new_group(ranks, backend="nccl")
handle = ops.qr_get_handle(_ptr)
world_size = dist.get_world_size(group=cpu_group)
diff --git a/tests/distributed/test_split_group.py b/tests/distributed/test_split_group.py
new file mode 100644
index 00000000000..54586c9e370
--- /dev/null
+++ b/tests/distributed/test_split_group.py
@@ -0,0 +1,233 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for split_group in GroupCoordinator.
+
+These tests verify that:
+1. split_group is used for both device and CPU group creation.
+2. Multiple subgroups work correctly with split_group.
+3. Both GPU and CPU all-reduce work on split groups.
+"""
+
+import os
+from typing import Any
+
+import multiprocess as mp
+import pytest
+import torch
+import torch.distributed
+
+import vllm.envs as envs
+from vllm.distributed.parallel_state import (
+ GroupCoordinator,
+ init_distributed_environment,
+)
+from vllm.utils.system_utils import update_environment_variables
+
+# The whole module exercises the split_group code path, which is opt-in
+# behind VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1.
+pytestmark = pytest.mark.skipif(
+ not envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP,
+ reason=("VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1 not set; split_group path is opt-in."),
+)
+
+mp.set_start_method("spawn", force=True)
+
+
+def distributed_run(fn, world_size):
+ number_of_processes = world_size
+ processes: list[mp.Process] = []
+ for i in range(number_of_processes):
+ env: dict[str, str] = {}
+ env["RANK"] = str(i)
+ env["LOCAL_RANK"] = str(i)
+ env["WORLD_SIZE"] = str(number_of_processes)
+ env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
+ env["MASTER_ADDR"] = "localhost"
+ env["MASTER_PORT"] = "12346"
+ # propagate the opt-in flag to the spawned child workers
+ env["VLLM_DISTRIBUTED_USE_SPLIT_GROUP"] = "1"
+ p = mp.Process(target=fn, args=(env,))
+ processes.append(p)
+ p.start()
+
+ for p in processes:
+ p.join()
+
+ for p in processes:
+ assert p.exitcode == 0
+
+
+def worker_fn_wrapper(fn):
+ def wrapped_fn(env):
+ update_environment_variables(env)
+ local_rank = os.environ["LOCAL_RANK"]
+ device = torch.device(f"cuda:{local_rank}")
+ torch.accelerator.set_device_index(device)
+ init_distributed_environment()
+ fn()
+
+ return wrapped_fn
+
+
+def _verify_device_group(coordinator: GroupCoordinator):
+ """Verify device group works via all-reduce."""
+ local_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{local_rank}")
+ tensor = torch.ones(16, 16, dtype=torch.float32, device=device)
+ torch.distributed.all_reduce(tensor, group=coordinator.device_group)
+ torch.accelerator.synchronize()
+ expected = coordinator.world_size
+ assert torch.all(tensor == expected).cpu().item(), (
+ f"Device group all-reduce failed: expected {expected}, "
+ f"got {tensor.flatten()[0].item()}"
+ )
+
+
+def _verify_cpu_group(coordinator: GroupCoordinator):
+ """Verify CPU group works via all-reduce."""
+ tensor = torch.ones(16, dtype=torch.float32)
+ torch.distributed.all_reduce(tensor, group=coordinator.cpu_group)
+ expected = coordinator.world_size
+ assert torch.all(tensor == expected).cpu().item(), (
+ f"CPU group all-reduce failed: expected {expected}, "
+ f"got {tensor.flatten()[0].item()}"
+ )
+
+
+# ---------------------------------------------------------------------------
+# Test 1: Basic split_group path with 2 GPUs
+# ---------------------------------------------------------------------------
+@worker_fn_wrapper
+def split_group_basic_worker():
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+ group_ranks = [list(range(world_size))]
+
+ coordinator = GroupCoordinator(
+ group_ranks=group_ranks,
+ local_rank=rank,
+ torch_distributed_backend="nccl",
+ use_device_communicator=False,
+ group_name="test_split_basic",
+ )
+
+ _verify_device_group(coordinator)
+ _verify_cpu_group(coordinator)
+
+
+@pytest.mark.skipif(
+ torch.accelerator.device_count() < 2,
+ reason="Need at least 2 GPUs to run the test.",
+)
+def test_split_group_basic():
+ """Test basic GroupCoordinator creation with split_group."""
+ distributed_run(split_group_basic_worker, 2)
+
+
+# ---------------------------------------------------------------------------
+# Test 2: Multiple subgroups with split_group (4 GPUs)
+# ---------------------------------------------------------------------------
+@worker_fn_wrapper
+def split_group_multiple_subgroups_worker():
+ rank = torch.distributed.get_rank()
+ group_ranks = [[0, 1], [2, 3]]
+
+ coordinator = GroupCoordinator(
+ group_ranks=group_ranks,
+ local_rank=rank,
+ torch_distributed_backend="nccl",
+ use_device_communicator=False,
+ group_name="test_split_multi",
+ )
+
+ assert coordinator.world_size == 2
+
+ _verify_device_group(coordinator)
+ _verify_cpu_group(coordinator)
+
+ if rank in [0, 1]:
+ assert coordinator.ranks == [0, 1]
+ else:
+ assert coordinator.ranks == [2, 3]
+
+
+@pytest.mark.skipif(
+ torch.accelerator.device_count() < 4,
+ reason="Need at least 4 GPUs to run the test.",
+)
+def test_split_group_multiple_subgroups():
+ """Test GroupCoordinator with multiple independent subgroups."""
+ distributed_run(split_group_multiple_subgroups_worker, 4)
+
+
+# ---------------------------------------------------------------------------
+# Test 3: split_group contract — every parent rank must enter with the same
+# ``split_ranks``. NCCL happens to produce
+# correct subgroups for disjoint partitions because the wrapper hashes
+# ``my_group`` to derive the comm-split color, but the contract violation is
+# real and would break under non-partition / non-NCCL backends. This test
+# captures the actual ``split_ranks`` argument passed on every rank and
+# asserts they match.
+# ---------------------------------------------------------------------------
+@worker_fn_wrapper
+def split_group_contract_worker():
+ rank = torch.distributed.get_rank()
+ group_ranks = [[0, 1], [2, 3]]
+
+ captured: list[list[list[int]]] = []
+ original_split_group = torch.distributed.split_group
+
+ def capturing_split_group(*args, split_ranks=None, **kwargs):
+ captured.append([list(g) for g in split_ranks])
+ return original_split_group(*args, split_ranks=split_ranks, **kwargs)
+
+ torch.distributed.split_group = capturing_split_group
+ try:
+ GroupCoordinator(
+ group_ranks=group_ranks,
+ local_rank=rank,
+ torch_distributed_backend="nccl",
+ use_device_communicator=False,
+ group_name="test_split_contract",
+ )
+ finally:
+ torch.distributed.split_group = original_split_group
+
+ # GroupCoordinator builds two subgroups (device + cpu) per coordinator,
+ # so every rank must have made exactly two split_group calls.
+ if len(captured) != 2:
+ raise AssertionError(
+ f"rank {rank} expected 2 split_group calls (device + cpu), "
+ f"got {len(captured)}: {captured}"
+ )
+
+ world_size = torch.distributed.get_world_size()
+ for call_idx in range(2):
+ gathered: list[Any] = [None] * world_size
+ torch.distributed.all_gather_object(gathered, captured[call_idx])
+ # Normalize for stable comparison (sort each subgroup and the outer list).
+ norm = [
+ sorted([sorted(sg) for sg in per_rank_args]) for per_rank_args in gathered
+ ]
+ reference = norm[0]
+ for r, args in enumerate(norm):
+ if args != reference:
+ raise AssertionError(
+ f"split_group contract violation on call #{call_idx}: "
+ f"rank {r} passed split_ranks={gathered[r]}, but rank 0 "
+ f"passed split_ranks={gathered[0]}. PyTorch requires every "
+ "parent rank to enter split_group with the same split_ranks."
+ )
+
+
+@pytest.mark.skipif(
+ torch.accelerator.device_count() < 4,
+ reason="Need at least 4 GPUs to run the test.",
+)
+def test_split_group_contract_same_split_ranks_on_all_ranks():
+ """All parent ranks must call torch.distributed.split_group with the same
+ ``split_ranks`` argument. This catches the bug where each rank passed
+ only its own subgroup (``split_ranks=[ranks]``), which NCCL forgives for
+ disjoint partitions but is a documented contract violation.
+ """
+ distributed_run(split_group_contract_worker, 4)
diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py
index e72f00bc91e..670df2759b0 100644
--- a/tests/distributed/test_torchrun_example.py
+++ b/tests/distributed/test_torchrun_example.py
@@ -5,13 +5,26 @@
import os
import random
+import torch
import torch.distributed as dist
+import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_world_group
-# Let PyTorch choose the WORLD backend for the current device type.
-dist.init_process_group()
+# By default, let PyTorch choose the WORLD backend for the current device
+# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
+# use the explicit eager-init pattern required by `split_group` (mixed
+# cpu:gloo,cuda:nccl backend + device_id binding).
+if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ local_rank = int(os.environ["LOCAL_RANK"])
+ torch.accelerator.set_device_index(local_rank)
+ dist.init_process_group(
+ backend="cpu:gloo,cuda:nccl",
+ device_id=torch.device(f"cuda:{local_rank}"),
+ )
+else:
+ dist.init_process_group()
# Create prompts
prompts = [
diff --git a/tests/distributed/test_torchrun_example_moe.py b/tests/distributed/test_torchrun_example_moe.py
index 969b5e92e3f..6f0957ed026 100644
--- a/tests/distributed/test_torchrun_example_moe.py
+++ b/tests/distributed/test_torchrun_example_moe.py
@@ -5,13 +5,26 @@
import os
import random
+import torch
import torch.distributed as dist
+import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_tp_group, get_world_group
-# Let PyTorch choose the WORLD backend for the current device type.
-dist.init_process_group()
+# By default, let PyTorch choose the WORLD backend for the current device
+# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
+# use the explicit eager-init pattern required by `split_group` (mixed
+# cpu:gloo,cuda:nccl backend + device_id binding).
+if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ local_rank = int(os.environ["LOCAL_RANK"])
+ torch.accelerator.set_device_index(local_rank)
+ dist.init_process_group(
+ backend="cpu:gloo,cuda:nccl",
+ device_id=torch.device(f"cuda:{local_rank}"),
+ )
+else:
+ dist.init_process_group()
# Create prompts
prompts = [
diff --git a/tests/entrypoints/openai/chat_completion/test_chat.py b/tests/entrypoints/openai/chat_completion/test_chat.py
index 6703095aec4..16a3cd857cb 100644
--- a/tests/entrypoints/openai/chat_completion/test_chat.py
+++ b/tests/entrypoints/openai/chat_completion/test_chat.py
@@ -808,14 +808,20 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA
"logprobs": False,
}
- chat_completion = await client.chat.completions.create(**request_args)
+ # Use raw HTTP for both endpoints so we compare server responses
+ # directly, without the openai SDK injecting extra fields
+ # (e.g. `moderation` added in newer SDK versions).
+ chat_response = requests.post(
+ server.url_for("v1/chat/completions"), json=request_args
+ )
+ chat_response.raise_for_status()
invocation_response = requests.post(
server.url_for("invocations"), json=request_args
)
invocation_response.raise_for_status()
- chat_output = chat_completion.model_dump()
+ chat_output = chat_response.json()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
diff --git a/tests/kernels/attention/test_lightning_attn.py b/tests/kernels/attention/test_lightning_attn.py
index 46757cc10b6..61e13166808 100644
--- a/tests/kernels/attention/test_lightning_attn.py
+++ b/tests/kernels/attention/test_lightning_attn.py
@@ -5,8 +5,16 @@ import pytest
import torch
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
+from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
+DEVICE = current_platform.device_type
+
+pytestmark = pytest.mark.skipif(
+ not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
+ reason="Lightning attention Triton kernels require CUDA/ROCm or XPU.",
+)
+
NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
@@ -121,7 +129,7 @@ def test_linear_decode_forward_triton(
head_size: int,
dtype: torch.dtype,
):
- torch.set_default_device("cuda")
+ torch.set_default_device(DEVICE)
set_random_seed(42)
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
@@ -129,16 +137,16 @@ def test_linear_decode_forward_triton(
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(
- batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
+ batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
)
kv_caches_copy = kv_caches.clone()
- slope_rate = torch.zeros(num_heads, device="cuda")
+ slope_rate = torch.zeros(num_heads, device=DEVICE)
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
- slot_idx = torch.arange(batch_size, device="cuda")
+ slot_idx = torch.arange(batch_size, device=DEVICE)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
@@ -162,7 +170,7 @@ def test_linear_decode_forward_triton_with_padding(
head_size: int,
dtype: torch.dtype,
):
- torch.set_default_device("cuda")
+ torch.set_default_device(DEVICE)
set_random_seed(42)
batch_size = 4
@@ -172,16 +180,16 @@ def test_linear_decode_forward_triton_with_padding(
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(
- batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
+ batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
)
kv_caches_copy = kv_caches.clone()
- slope_rate = torch.zeros(num_heads, device="cuda")
+ slope_rate = torch.zeros(num_heads, device=DEVICE)
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
- slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
+ slot_idx = torch.tensor([0, 1, -1, 2], device=DEVICE)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
@@ -224,7 +232,7 @@ def test_lightning_attention_reference(
seq_len: int,
dtype: torch.dtype,
):
- torch.set_default_device("cuda")
+ torch.set_default_device(DEVICE)
set_random_seed(42)
base = 0.01
@@ -232,12 +240,12 @@ def test_lightning_attention_reference(
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
- ed = torch.zeros(num_heads, device="cuda")
+ ed = torch.zeros(num_heads, device=DEVICE)
for h in range(num_heads):
ed[h] = 0.1 * (h + 1)
kv_history = base * torch.randn(
- batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
+ batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
)
kv_history_clone = kv_history.clone()
diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py
index 07f244451b4..44e2ee836b9 100644
--- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py
+++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py
@@ -10,6 +10,7 @@ import torch
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import ParamSpec
+import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (
cleanup_dist_env_and_memory,
@@ -60,7 +61,15 @@ def _set_vllm_config(
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
)
- cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ cpu_group = torch.distributed.split_group(
+ split_ranks=[list(range(world_size))],
+ group_desc="moe_test_cpu",
+ )
+ else:
+ cpu_group = torch.distributed.new_group(
+ list(range(world_size)), backend="gloo"
+ )
return cpu_group
diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py
index 1380281bb2e..e3315142a9b 100644
--- a/tests/kernels/moe/test_cutlass_moe.py
+++ b/tests/kernels/moe/test_cutlass_moe.py
@@ -205,7 +205,10 @@ def run_with_expert_maps(
w2 = kwargs["w2"]
a = kwargs["hidden_states"]
moe_config = make_dummy_moe_config(
- num_experts=w2.shape[0],
+ max_num_tokens=kwargs.get("hidden_states").shape[0],
+ experts_per_token=kwargs.get("topk_ids").shape[1],
+ num_experts=num_experts,
+ num_local_experts=num_local_experts,
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
@@ -258,23 +261,27 @@ def run_8_bit(
a1_scale=None,
)
+ num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
+ with_ep = num_local_experts is not None or num_local_experts == num_experts
+
kwargs = {
"hidden_states": moe_tensors.a,
"w1": moe_tensors.w1_q, # type: ignore[union-attr]
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights,
"topk_ids": topk_ids,
- "global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr]
+ "global_num_experts": num_experts,
"activation": MoEActivation.SILU,
"expert_map": None,
"apply_router_weight_on_input": False,
}
- num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
- with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep:
moe_config = make_dummy_moe_config(
- num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
+ max_num_tokens=moe_tensors.a.shape[0],
+ experts_per_token=topk_ids.shape[1],
+ num_experts=num_experts,
+ num_local_experts=num_local_experts,
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
in_dtype=moe_tensors.a.dtype,
@@ -581,6 +588,7 @@ def test_run_cutlass_moe_fp8(
per_out_channel,
False,
topk_weights,
+ None,
)
workspace13.random_()
diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py
index 452bf64ed98..efb1e2f2969 100644
--- a/tests/kernels/moe/test_deepep_deepgemm_moe.py
+++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -14,6 +14,7 @@ import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
+import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -375,7 +376,13 @@ def _test_deepep_deepgemm_moe(
w1_scale = w1_scale.to(device=device)
w2_scale = w2_scale.to(device=device)
- pg = torch.distributed.new_group(list(range(pgi.world_size)))
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ pg = torch.distributed.split_group(
+ split_ranks=[list(range(pgi.world_size))],
+ group_desc="deepep_deepgemm_test",
+ )
+ else:
+ pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, pgi.rank)
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py
index 5e0303c3df7..83cd2f09d1e 100644
--- a/tests/kernels/moe/test_deepep_moe.py
+++ b/tests/kernels/moe/test_deepep_moe.py
@@ -10,6 +10,7 @@ import pytest
import torch.distributed
from torch.distributed import ProcessGroup
+import vllm.envs as envs
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
@@ -375,7 +376,13 @@ def _deep_ep_moe(
w1_scale = w1_scale.to(device=device_idx)
w2_scale = w2_scale.to(device=device_idx)
- pg = torch.distributed.new_group(list(range(pgi.world_size)))
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ pg = torch.distributed.split_group(
+ split_ranks=[list(range(pgi.world_size))],
+ group_desc="deepep_test",
+ )
+ else:
+ pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, low_latency_mode)
with set_current_vllm_config(VllmConfig()):
diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py
index 3503ce4cdeb..ebb99576756 100644
--- a/tests/kernels/moe/utils.py
+++ b/tests/kernels/moe/utils.py
@@ -49,10 +49,12 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
def make_dummy_moe_config(
num_experts: int = 1,
+ num_local_experts: int | None = None,
experts_per_token: int = 1,
hidden_dim: int = 1,
intermediate_size_per_partition: int = 1,
in_dtype: torch.dtype = torch.bfloat16,
+ max_num_tokens: int = 512,
) -> FusedMoEConfig:
"""
This is a dummy config for the mk constructor interface
@@ -66,14 +68,16 @@ def make_dummy_moe_config(
experts_per_token=experts_per_token,
hidden_dim=hidden_dim,
intermediate_size_per_partition=intermediate_size_per_partition,
- num_local_experts=num_experts,
+ num_local_experts=num_local_experts
+ if num_local_experts is not None
+ else num_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation=MoEActivation.SILU,
in_dtype=in_dtype,
device="cuda",
routing_method=RoutingMethodType.TopK,
- max_num_tokens=512,
+ max_num_tokens=max_num_tokens,
)
diff --git a/tests/kernels/quantization/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py
index 337bc177e6d..6572a7efd22 100644
--- a/tests/kernels/quantization/test_awq_triton.py
+++ b/tests/kernels/quantization/test_awq_triton.py
@@ -13,9 +13,15 @@ from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton,
awq_gemm_triton,
)
+from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
-device = "cuda"
+pytestmark = pytest.mark.skipif(
+ not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
+ reason="AWQ Triton kernels require CUDA/ROCm or XPU.",
+)
+
+device = current_platform.device_type
def reverse_awq_order(t: torch.Tensor):
diff --git a/tests/model_executor/test_eagle_quantization.py b/tests/model_executor/test_eagle_quantization.py
index 481715da9cd..72c189d8331 100644
--- a/tests/model_executor/test_eagle_quantization.py
+++ b/tests/model_executor/test_eagle_quantization.py
@@ -100,32 +100,6 @@ def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) ->
assert output.shape == (2, output_size)
-def test_kv_cache_scale_name_handling():
- # Mock a quant config that supports cache scales
- mock_quant_config = Mock()
- mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")
-
- # Condition check in load_weights
- name = "layers.0.self_attn.k_proj.weight"
- scale_name = mock_quant_config.get_cache_scale(name)
-
- # Check if get_cache_scale is called and returns expected value
- mock_quant_config.get_cache_scale.assert_called_once_with(name)
- assert scale_name == "layers.0.self_attn.kv_scale"
-
-
-def test_kv_cache_scale_name_no_scale():
- # Mock a quant config that returns None for get_cache_scale
- mock_quant_config = Mock()
- mock_quant_config.get_cache_scale = Mock(return_value=None)
-
- name = "layers.0.mlp.gate_proj.weight"
- scale_name = mock_quant_config.get_cache_scale(name)
-
- # Should return None for weights that don't have cache scales
- assert scale_name is None
-
-
def test_maybe_remap_kv_scale_name():
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
@@ -183,33 +157,3 @@ def test_eagle3_lm_head_receives_quant_config():
assert call_kwargs["quant_config"] is mock_quant_config, (
"ParallelLMHead must receive the draft model's quant_config"
)
-
-
-def test_load_weights_kv_scale_handling():
- kv_scale_param = Mock()
- kv_scale_param.weight_loader = Mock()
-
- params_dict = {
- "layers.0.self_attn.kv_scale": kv_scale_param,
- }
-
- mock_quant_config = Mock()
- mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")
-
- # Load_weights logic for KV cache scales
- name = "layers.0.self_attn.k_proj.weight"
- loaded_weight_tensor = torch.tensor([1.0, 2.0])
-
- if mock_quant_config is not None:
- scale_name = mock_quant_config.get_cache_scale(name)
- if scale_name:
- param = params_dict[scale_name]
- assert param is kv_scale_param
- weight_to_load = (
- loaded_weight_tensor
- if loaded_weight_tensor.dim() == 0
- else loaded_weight_tensor[0]
- )
-
- assert scale_name == "layers.0.self_attn.kv_scale"
- assert weight_to_load == loaded_weight_tensor[0]
diff --git a/tests/models/language/pooling/test_colbert.py b/tests/models/language/pooling/test_colbert.py
index 10c229fe063..b6ff9b9ffbd 100644
--- a/tests/models/language/pooling/test_colbert.py
+++ b/tests/models/language/pooling/test_colbert.py
@@ -14,7 +14,7 @@ from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
# -----------------------------------------------------------------------
# Model definitions: (model_name, colbert_dim, extra vllm_runner kwargs)
# -----------------------------------------------------------------------
-COLBERT_MODELS = {
+COLBERT_MODELS: dict[str, dict] = {
"bert": {
"model": "answerdotai/answerai-colbert-small-v1",
"colbert_dim": 96,
diff --git a/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py b/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py
new file mode 100644
index 00000000000..92017e95cb7
--- /dev/null
+++ b/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py
@@ -0,0 +1,406 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import csv
+import importlib
+import importlib.util
+import os
+
+import pytest
+import torch
+
+from tests.utils import TestFP8Layer
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
+ AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
+)
+from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
+ FP8ScaledMMLinearLayerConfig,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ kFp8DynamicTokenSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+)
+from vllm.platforms import current_platform
+
+aiter_available = importlib.util.find_spec("aiter") is not None
+
+pytestmark = [
+ pytest.mark.skipif(
+ not (
+ current_platform.is_rocm()
+ and current_platform.supports_fp8()
+ and aiter_available
+ ),
+ reason="Requires ROCm + FP8 support + aiter",
+ ),
+ pytest.mark.usefixtures("default_vllm_config"),
+]
+
+
+@pytest.fixture
+def enable_hipb_mm_kernel(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1")
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1")
+ rocm_aiter_ops.refresh_env_variables()
+ yield
+ rocm_aiter_ops.refresh_env_variables()
+
+
+def _make_config(
+ *,
+ weight_quant_key=kFp8StaticChannelSym,
+ out_dtype: torch.dtype = torch.bfloat16,
+ weight_shape: tuple[int, int] = (512, 4096),
+) -> FP8ScaledMMLinearLayerConfig:
+ return FP8ScaledMMLinearLayerConfig(
+ weight_quant_key=weight_quant_key,
+ activation_quant_key=kFp8DynamicTokenSym,
+ weight_shape=weight_shape,
+ input_dtype=torch.bfloat16,
+ out_dtype=out_dtype,
+ )
+
+
+def _find_csv_row(path: str, m: int, n: int, k: int) -> dict | None:
+ if not os.path.exists(path):
+ return None
+
+ with open(path, newline="") as f:
+ reader = csv.DictReader(f, skipinitialspace=True)
+ for row in reader:
+ try:
+ if (
+ int(row.get("m", -1)) == m
+ and int(row.get("n", -1)) == n
+ and int(row.get("k", -1)) == k
+ ):
+ return dict(row)
+ except (TypeError, ValueError):
+ continue
+ return None
+
+
+def _skip_if_no_hipb_mm_solution(exc: RuntimeError) -> None:
+ if "hipblasLtMatmulAlgoGetHeuristic found 0 valid solutions" in str(exc):
+ pytest.skip(
+ "hipb_mm bpreshuffle path has no valid hipBLASLt solution on "
+ "this ROCm stack."
+ )
+
+
+def _check_bpreshuffle_runtime_support(weight_shape: tuple[int, int], num_tokens: int):
+ import aiter
+ from aiter.ops.shuffle import shuffle_weight
+
+ x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device="cuda")
+ w = torch.randn(weight_shape, dtype=torch.bfloat16, device="cuda")
+
+ aiter.hipb_create_extension()
+ x_q, x_scale = aiter.pertoken_quant(x, quant_dtype=current_platform.fp8_dtype())
+ w_q, w_scale = aiter.pertoken_quant(w, quant_dtype=current_platform.fp8_dtype())
+
+ try:
+ aiter.hipb_mm(
+ x_q,
+ shuffle_weight(w_q, layout=(16, 16)).t(),
+ solution_index=-1,
+ out_dtype=torch.bfloat16,
+ scaleA=x_scale,
+ scaleB=w_scale.t().contiguous(),
+ scaleOut=None,
+ bpreshuffle=True,
+ )
+ except RuntimeError as exc:
+ _skip_if_no_hipb_mm_solution(exc)
+ raise
+
+
+def test_hipb_mm_kernel_requires_hipbmm_flag(monkeypatch: pytest.MonkeyPatch):
+ # The kernel rejects when `is_hip_fp8bmm_enabled()` is False. That helper
+ # requires AITER + AITER_LINEAR + MI3xx, so dropping AITER_LINEAR exercises
+ # the rejection branch.
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0")
+ monkeypatch.delenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", raising=False)
+ rocm_aiter_ops.refresh_env_variables()
+
+ is_supported, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.is_supported()
+
+ assert not is_supported
+ assert reason == (
+ "requires setting `VLLM_ROCM_USE_AITER=1`, "
+ "`VLLM_ROCM_USE_AITER_LINEAR=1`, "
+ "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`."
+ )
+
+
+def test_hipb_mm_flag_enables_hip_online_tuning(
+ monkeypatch: pytest.MonkeyPatch,
+):
+ import vllm.envs as envs_mod
+ import vllm.platforms.rocm as rocm_mod
+
+ # The rocm.py gate requires all three AITER flags (and MI3xx) to auto-set
+ # HIP_ONLINE_TUNING.
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1")
+ monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1")
+
+ try:
+ importlib.reload(envs_mod)
+ importlib.reload(rocm_mod)
+ assert envs_mod.VLLM_ROCM_USE_AITER
+ assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR
+ assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
+ assert os.environ.get("HIP_ONLINE_TUNING") == "1"
+ finally:
+ monkeypatch.undo()
+ os.environ.pop("HIP_ONLINE_TUNING", None)
+ importlib.reload(envs_mod)
+ importlib.reload(rocm_mod)
+ rocm_aiter_ops.refresh_env_variables()
+
+
+def test_hipb_mm_kernel_can_implement_success(enable_hipb_mm_kernel):
+ can_implement, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement(
+ _make_config()
+ )
+
+ assert can_implement
+ assert reason is None
+
+
+@pytest.mark.parametrize(
+ ("config", "expected_reason"),
+ [
+ (
+ _make_config(weight_quant_key=kFp8StaticTensorSym),
+ "requires per token activation scales and per channel weight scales.",
+ ),
+ (
+ _make_config(out_dtype=torch.float16),
+ "requires bfloat16 output dtype.",
+ ),
+ (
+ _make_config(weight_shape=(8, 4090)),
+ "requires N >= 16 and both N and K divisible by 16, "
+ "received N=8 and K=4090.",
+ ),
+ ],
+)
+def test_hipb_mm_kernel_can_implement_rejects_unsupported_configs(
+ enable_hipb_mm_kernel,
+ config: FP8ScaledMMLinearLayerConfig,
+ expected_reason: str,
+):
+ can_implement, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement(
+ config
+ )
+
+ assert not can_implement
+ assert reason == expected_reason
+
+
+def test_hipb_mm_kernel_process_weights_after_loading_shuffles_weights(
+ enable_hipb_mm_kernel,
+):
+ weight_shape = (512, 4096)
+ kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel(
+ _make_config(weight_shape=weight_shape),
+ layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"),
+ )
+
+ layer = torch.nn.Module()
+ layer.weight = torch.nn.Parameter(
+ torch.rand(weight_shape, device="cuda").to(current_platform.fp8_dtype()).t(),
+ requires_grad=False,
+ )
+ layer.weight_scale = torch.nn.Parameter(
+ torch.rand((weight_shape[0], 1), dtype=torch.float32, device="cuda"),
+ requires_grad=False,
+ )
+ layer.input_scale = None
+ layer.input_scale_ub = None
+
+ original_weight = layer.weight.detach().clone()
+ original_weight_scale = layer.weight_scale.detach().clone()
+
+ kernel.process_weights_after_loading(layer)
+
+ # process_weights_after_loading now pre-applies the transposes that used
+ # to live in _rocm_aiter_hipb_mm_fp8_impl, so the stored weight is the
+ # shuffled tensor with a trailing `.t()` view, and the stored weight scale
+ # is its transposed-contiguous form.
+ expected_weight = rocm_aiter_ops.shuffle_weight(
+ original_weight.t().contiguous()
+ ).t()
+ torch.testing.assert_close(layer.weight, expected_weight)
+
+ expected_weight_scale = original_weight_scale.t().contiguous()
+ torch.testing.assert_close(layer.weight_scale, expected_weight_scale)
+
+
+def test_hipb_mm_kernel_forward_matches_raw_aiter_hipb_mm(enable_hipb_mm_kernel):
+ import aiter
+
+ weight_shape = (512, 4096)
+ _check_bpreshuffle_runtime_support(weight_shape, num_tokens=32)
+
+ layer = TestFP8Layer(
+ weight_shape=weight_shape,
+ activation_quant_key=kFp8DynamicTokenSym,
+ weight_quant_key=kFp8StaticChannelSym,
+ input_dtype=torch.bfloat16,
+ out_dtype=torch.bfloat16,
+ device=torch.device("cuda"),
+ force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
+ )
+
+ # hipb_mm uses a transposed-result GEMM internally, so the flattened token
+ # count becomes the effective N dimension passed into hipBLASLt. Keep it
+ # aligned to avoid heuristic failures for tiny N.
+ x = torch.randn(2, 16, weight_shape[1], dtype=torch.bfloat16, device="cuda")
+ bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device="cuda")
+
+ try:
+ out = layer(x, bias)
+ except RuntimeError as exc:
+ _skip_if_no_hipb_mm_solution(exc)
+ raise
+
+ x_2d = x.view(-1, x.shape[-1])
+ x_q, x_scale = layer.kernel.quant_fp8(
+ x_2d,
+ layer.input_scale,
+ layer.input_scale_ub,
+ )
+ try:
+ # process_weights_after_loading already applies the trailing `.t()` on
+ # the shuffled weight and the `.t().contiguous()` on the weight scale,
+ # so the raw aiter call uses them directly.
+ expected = aiter.hipb_mm(
+ x_q,
+ layer.weight,
+ solution_index=-1,
+ bias=bias,
+ out_dtype=torch.bfloat16,
+ scaleA=x_scale,
+ scaleB=layer.weight_scale,
+ scaleOut=None,
+ bpreshuffle=True,
+ ).view(*out.shape)
+ except RuntimeError as exc:
+ _skip_if_no_hipb_mm_solution(exc)
+ raise
+
+ assert isinstance(layer.kernel, AiterHipbMMPerTokenFp8ScaledMMLinearKernel)
+ assert out.shape == (2, 16, weight_shape[0])
+ torch.testing.assert_close(out, expected)
+
+
+def test_hipb_mm_kernel_forward_accuracy(enable_hipb_mm_kernel):
+ """Kernel output should match a dequantized fp32 reference within
+ fp8 per-token / per-channel quantization noise."""
+ weight_shape = (512, 4096) # (N, K)
+ num_tokens = 32
+ _check_bpreshuffle_runtime_support(weight_shape, num_tokens=num_tokens)
+
+ fp8_dtype = current_platform.fp8_dtype()
+ fp8_max = torch.finfo(fp8_dtype).max
+ device = torch.device("cuda")
+
+ # Build a bf16 weight and quantize per output channel (one scale per row).
+ w_bf16 = torch.randn(weight_shape, dtype=torch.bfloat16, device=device)
+ w_amax = w_bf16.abs().amax(dim=1, keepdim=True).to(torch.float32)
+ w_scale = (w_amax / fp8_max).clamp(min=1e-12)
+ w_fp8 = (w_bf16.to(torch.float32) / w_scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
+ w_dequant = w_fp8.to(torch.float32) * w_scale
+
+ bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device=device)
+
+ layer = torch.nn.Module()
+ # Pre-`process_weights_after_loading` convention: weight stored as the
+ # `[K, N]` view of the fp8 tensor.
+ layer.weight = torch.nn.Parameter(w_fp8.t(), requires_grad=False)
+ layer.weight_scale = torch.nn.Parameter(w_scale, requires_grad=False)
+ layer.input_scale = None
+ layer.input_scale_ub = None
+
+ kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel(
+ _make_config(weight_shape=weight_shape),
+ layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"),
+ )
+ kernel.process_weights_after_loading(layer)
+
+ x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device=device)
+
+ try:
+ out = kernel.apply_weights(layer, x, bias)
+ except RuntimeError as exc:
+ _skip_if_no_hipb_mm_solution(exc)
+ raise
+
+ # Reference: quantize x per-token the same way the kernel does, then run
+ # the matmul in fp32 against the dequantized weight. This isolates plumbing
+ # / reduction bugs from inherent fp8 quantization noise.
+ x_amax = x.abs().amax(dim=1, keepdim=True).to(torch.float32)
+ x_scale_ref = (x_amax / fp8_max).clamp(min=1e-12)
+ x_q = (x.to(torch.float32) / x_scale_ref).clamp(-fp8_max, fp8_max).to(fp8_dtype)
+ x_dequant = x_q.to(torch.float32) * x_scale_ref
+ expected = (x_dequant @ w_dequant.t() + bias.to(torch.float32)).to(torch.bfloat16)
+
+ assert out.shape == (num_tokens, weight_shape[0])
+ # K=4096 fp8 reduction leaves room for accumulation order drift and
+ # catastrophic cancellation on near-zero outputs; tolerances are loose
+ # enough to absorb that but tight enough to catch wrong layouts, missing
+ # bias, swapped scales, etc.
+ torch.testing.assert_close(out, expected, atol=5.0, rtol=0.1)
+
+
+def test_hipb_mm_kernel_online_tuning_writes_csv(
+ enable_hipb_mm_kernel,
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path,
+):
+ weight_shape = (256, 4096)
+ cache_file = tmp_path / "hip_online_tuning_res.csv"
+
+ _check_bpreshuffle_runtime_support(weight_shape, num_tokens=16)
+
+ monkeypatch.setenv("HIP_ONLINE_TUNING", "1")
+ monkeypatch.chdir(tmp_path)
+
+ layer = TestFP8Layer(
+ weight_shape=weight_shape,
+ activation_quant_key=kFp8DynamicTokenSym,
+ weight_quant_key=kFp8StaticChannelSym,
+ input_dtype=torch.bfloat16,
+ out_dtype=torch.bfloat16,
+ device=torch.device("cuda"),
+ force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
+ )
+
+ # The effective heuristic N dimension is the flattened token count.
+ x = torch.randn(16, weight_shape[1], dtype=torch.bfloat16, device="cuda")
+ try:
+ out = layer(x)
+ except RuntimeError as exc:
+ _skip_if_no_hipb_mm_solution(exc)
+ raise
+ torch.accelerator.synchronize()
+
+ assert out.shape == (16, weight_shape[0])
+ assert cache_file.exists()
+
+ # hipb_mm records the internal GEMM dimensions used by hipBLASLt after its
+ # transposed-result transformation.
+ row = _find_csv_row(
+ str(cache_file),
+ m=weight_shape[0],
+ n=x.shape[0],
+ k=weight_shape[1],
+ )
+ assert row is not None
diff --git a/tests/tokenizers_/test_mistral.py b/tests/tokenizers_/test_mistral.py
index 2023337e857..47abbd81289 100644
--- a/tests/tokenizers_/test_mistral.py
+++ b/tests/tokenizers_/test_mistral.py
@@ -797,11 +797,11 @@ class TestMistralTokenizer:
True,
(
[1, 3, 23325, 2294, 1686, 4, 23325],
- [1, 3, 22177, 4304, 2662, 4, 22177, 2],
+ [1, 3, 22177, 4304, 2662, 4, 22177],
),
(
"[INST]▁Hello▁world▁![/INST]▁Hello",
- ("[INST]Hello world ![/INST]Hello"),
+ "[INST]Hello world ![/INST]Hello",
),
),
],
diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index bde246c9b66..0e7f6af7e38 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -49,11 +49,13 @@ else
KV_EXTRA_CONFIG=''
fi
-# Build the kv-transfer-config once
+# Build the kv-transfer-config for P and D
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
- KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
+ KV_CONFIG_P='{"kv_connector":"NixlConnector","kv_role":"kv_producer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
+ KV_CONFIG_D='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
else
- KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
+ KV_CONFIG_P="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
+ KV_CONFIG_D="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
fi
# Models to run
@@ -159,7 +161,7 @@ run_tests_for_model() {
--block-size ${PREFILL_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
- --kv-transfer-config '$KV_CONFIG'"
+ --kv-transfer-config '$KV_CONFIG_P'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
@@ -207,7 +209,7 @@ run_tests_for_model() {
--enforce-eager \
--block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
- --kv-transfer-config '$KV_CONFIG'"
+ --kv-transfer-config '$KV_CONFIG_D'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
diff --git a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
index bc90680a533..10e119d48a9 100755
--- a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
@@ -193,9 +193,11 @@ run_test_for_device() {
local kv_device=$1
if [[ "$kv_device" == "cuda" ]]; then
- local kv_config='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
+ local kv_config_p='{"kv_connector":"NixlConnector","kv_role":"kv_producer"}'
+ local kv_config_d='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}'
else
- local kv_config="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${kv_device}\"}"
+ local kv_config_p="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"${kv_device}\"}"
+ local kv_config_d="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"${kv_device}\"}"
fi
echo ""
@@ -248,7 +250,7 @@ run_test_for_device() {
--block-size ${BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
- --kv-transfer-config "$kv_config" \
+ --kv-transfer-config "$kv_config_p" \
--speculative-config "$PREFILL_SPEC_CONFIG" \
--attention-backend $ATTENTION_BACKEND \
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
@@ -287,7 +289,7 @@ run_test_for_device() {
--block-size ${BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $DECODER_TP_SIZE \
- --kv-transfer-config "$kv_config" \
+ --kv-transfer-config "$kv_config_d" \
--speculative-config "$DECODE_SPEC_CONFIG" \
--attention-backend $ATTENTION_BACKEND \
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
diff --git a/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py b/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py
new file mode 100644
index 00000000000..4a2ca6d2721
--- /dev/null
+++ b/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py
@@ -0,0 +1,232 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+ KVConnectorBase_V1,
+ KVConnectorHandshakeMetadata,
+)
+from vllm.v1.engine import core as engine_core_module
+
+pytestmark = pytest.mark.cpu_test
+
+
+class _Metadata(KVConnectorHandshakeMetadata):
+ pass
+
+
+class _FakeExecutor:
+ handshake_metadata_src: (
+ list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None
+ )
+ last_instance: "_FakeExecutor | None" = None
+
+ def __init__(
+ self,
+ vllm_config: Any,
+ ) -> None:
+ del vllm_config
+ self.handshake_metadata = self.handshake_metadata_src
+ self.handshake_calls = 0
+ _FakeExecutor.last_instance = self
+
+ def get_kv_connector_handshake_metadata(
+ self,
+ ) -> list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None:
+ self.handshake_calls += 1
+ return self.handshake_metadata
+
+ def init_kv_output_aggregator(self, connector: KVConnectorBase_V1) -> None:
+ pass
+
+
+def _run_engine_core_handshake(
+ monkeypatch: pytest.MonkeyPatch,
+ connector: KVConnectorBase_V1,
+ *,
+ handshake_metadata: (
+ list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None
+ ),
+) -> _FakeExecutor:
+ class _FakeScheduler:
+ def __init__(self, **kwargs: Any) -> None:
+ self.connector = connector
+
+ def get_kv_connector(self) -> KVConnectorBase_V1:
+ return connector
+
+ _FakeExecutor.handshake_metadata_src = handshake_metadata
+ _FakeExecutor.last_instance = None
+
+ monkeypatch.setattr("vllm.plugins.load_general_plugins", lambda: None)
+ monkeypatch.setattr(
+ engine_core_module.EngineCore,
+ "_initialize_kv_caches",
+ lambda self, vllm_config: SimpleNamespace(kv_cache_groups=[object()]),
+ )
+ monkeypatch.setattr(
+ engine_core_module,
+ "StructuredOutputManager",
+ lambda vllm_config: object(),
+ )
+ monkeypatch.setattr(
+ engine_core_module,
+ "resolve_kv_cache_block_sizes",
+ lambda kv_cache_config, vllm_config: (16, 16),
+ )
+ monkeypatch.setattr(
+ engine_core_module,
+ "MULTIMODAL_REGISTRY",
+ SimpleNamespace(engine_receiver_cache_from_config=lambda vllm_config: None),
+ )
+ monkeypatch.setattr(engine_core_module, "freeze_gc_heap", lambda: None)
+ monkeypatch.setattr(
+ engine_core_module, "maybe_attach_gc_debug_callback", lambda: None
+ )
+ monkeypatch.setattr(engine_core_module, "enable_envs_cache", lambda: None)
+ monkeypatch.setattr(engine_core_module, "get_hash_fn_by_name", lambda name: None)
+ monkeypatch.setattr(engine_core_module, "init_none_hash", lambda hash_fn: None)
+ monkeypatch.setattr(
+ engine_core_module, "get_request_block_hasher", lambda *args: None
+ )
+
+ vllm_config = SimpleNamespace(
+ parallel_config=SimpleNamespace(data_parallel_rank_local=0),
+ scheduler_config=SimpleNamespace(
+ get_scheduler_cls=lambda: _FakeScheduler,
+ enable_chunked_prefill=False,
+ async_scheduling=False,
+ ),
+ speculative_config=None,
+ ec_transfer_config=None,
+ max_concurrent_batches=1,
+ model_config=SimpleNamespace(runner_type="generate"),
+ cache_config=SimpleNamespace(
+ enable_prefix_caching=False,
+ prefix_caching_hash_algo="builtin",
+ ),
+ )
+
+ engine_core_module.EngineCore(vllm_config, _FakeExecutor, log_stats=False)
+ assert _FakeExecutor.last_instance is not None
+ return _FakeExecutor.last_instance
+
+
+class _LegacyConnector(KVConnectorBase_V1):
+ def __init__(self) -> None:
+ self.legacy_metadata: dict[int, KVConnectorHandshakeMetadata] | None = None
+
+ def start_load_kv(self, forward_context: Any, **kwargs: Any) -> None:
+ pass
+
+ def wait_for_layer_load(self, layer_name: str) -> None:
+ pass
+
+ def save_kv_layer(
+ self,
+ layer_name: str,
+ kv_layer: Any,
+ attn_metadata: Any,
+ **kwargs: Any,
+ ) -> None:
+ pass
+
+ def wait_for_save(self) -> None:
+ pass
+
+ def get_num_new_matched_tokens(
+ self, request: Any, num_computed_tokens: int
+ ) -> tuple[int | None, bool]:
+ return 0, False
+
+ def update_state_after_alloc(
+ self, request: Any, blocks: Any, num_external_tokens: int
+ ) -> None:
+ pass
+
+ def build_connector_meta(self, scheduler_output: Any) -> Any:
+ raise NotImplementedError
+
+ def set_xfer_handshake_metadata(
+ self, metadata: dict[int, KVConnectorHandshakeMetadata]
+ ) -> None:
+ self.legacy_metadata = metadata
+
+
+class _PPAwareConnector(_LegacyConnector):
+ def __init__(self) -> None:
+ super().__init__()
+ self.pp_aware_metadata: (
+ dict[tuple[int, int], KVConnectorHandshakeMetadata] | None
+ ) = None
+
+ def set_xfer_handshake_metadata_pp_aware(
+ self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata]
+ ) -> None:
+ self.pp_aware_metadata = metadata
+
+
+def test_engine_unwraps_handshake_metadata_for_legacy_connector(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Engine core always asks workers for `(pp_rank, tp_rank)`-keyed metadata,
+ then unwraps to `{tp_rank: metadata}` for a connector that has not opted
+ into PP-aware handshake (single-PP producer, all `pp_rank == 0`)."""
+ metadata_0 = _Metadata()
+ metadata_1 = _Metadata()
+ connector = _LegacyConnector()
+
+ executor = _run_engine_core_handshake(
+ monkeypatch,
+ connector,
+ handshake_metadata=[
+ {(0, 0): metadata_0},
+ None,
+ {(0, 1): metadata_1},
+ ],
+ )
+
+ assert executor.handshake_calls == 1
+ assert connector.legacy_metadata == {0: metadata_0, 1: metadata_1}
+
+
+def test_engine_rejects_pp_producer_for_legacy_connector(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """A connector that has not opted into PP-aware handshake must not silently
+ drop metadata from `pp_rank > 0`; engine core init raises instead."""
+ connector = _LegacyConnector()
+
+ with pytest.raises(ValueError, match="does not support PP-disaggregated"):
+ _run_engine_core_handshake(
+ monkeypatch,
+ connector,
+ handshake_metadata=[{(0, 0): _Metadata()}, {(1, 0): _Metadata()}],
+ )
+
+
+def test_engine_passes_handshake_metadata_through_for_pp_aware_connector(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """A PP-aware connector receives the full `(pp_rank, tp_rank)`-keyed dict
+ unchanged."""
+ metadata_0 = _Metadata()
+ metadata_1 = _Metadata()
+ connector = _PPAwareConnector()
+
+ executor = _run_engine_core_handshake(
+ monkeypatch,
+ connector,
+ handshake_metadata=[{(0, 0): metadata_0}, {(1, 0): metadata_1}],
+ )
+
+ assert executor.handshake_calls == 1
+ assert connector.legacy_metadata is None
+ assert connector.pp_aware_metadata == {
+ (0, 0): metadata_0,
+ (1, 0): metadata_1,
+ }
diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py
index d1f3a81ca96..f78037a1431 100644
--- a/tests/v1/kv_connector/unit/test_multi_connector.py
+++ b/tests/v1/kv_connector/unit/test_multi_connector.py
@@ -261,11 +261,12 @@ def test_multi_example_connector_consistency():
storage1_scheduler_events = _ignore_event_collection(events["storage1-SCHEDULER"])
storage2_scheduler_events = _ignore_event_collection(events["storage2-SCHEDULER"])
# First event is bind_gpu_block_pool from initialization, then
- # set_xfer_handshake_metadata, then on_new_request when the request is enqueued,
- # then get_num_new_matched_tokens and update_state_after_alloc from generate().
+ # set_xfer_handshake_metadata_pp_aware, then on_new_request when the request is
+ # enqueued, then get_num_new_matched_tokens and update_state_after_alloc from
+ # generate().
assert storage1_scheduler_events[:6] == [
"bind_gpu_block_pool",
- "set_xfer_handshake_metadata",
+ "set_xfer_handshake_metadata_pp_aware",
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
@@ -285,7 +286,7 @@ def test_multi_example_connector_consistency():
]
assert storage2_scheduler_events[:6] == [
"bind_gpu_block_pool",
- "set_xfer_handshake_metadata",
+ "set_xfer_handshake_metadata_pp_aware",
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
@@ -1057,7 +1058,7 @@ def test_multi_connector_mixed_hma_disables_hybrid_kv_cache(monkeypatch):
"connectors": [
{
"kv_connector": "NixlConnector",
- "kv_role": "kv_both",
+ "kv_role": "kv_consumer",
},
{
"kv_connector": "MockConnector",
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 1a7c35cacb8..6f6d8b1ca98 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -1363,7 +1363,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
timeout = 6
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
- kv_role="kv_both",
+ kv_role="kv_consumer",
kv_connector_extra_config={"kv_lease_duration": timeout},
)
llm_kwargs = {
@@ -2737,3 +2737,50 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
f"got {notif!r} (expected {expected_notif!r}, "
f"buggy form would be {bad_notif!r})"
)
+
+
+def test_kv_both_deprecation_warning(default_vllm_config, dist_init):
+ """kv_role='kv_both' should emit a deprecation log warning."""
+ from unittest.mock import patch
+
+ from vllm.logger import _print_warning_once
+
+ _print_warning_once.cache_clear()
+
+ vllm_config = create_vllm_config(kv_role="kv_both")
+
+ with patch(
+ "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger"
+ ) as mock_logger:
+ mock_logger.warning_once = mock_logger.warning_once
+ NixlConnector(
+ vllm_config,
+ KVConnectorRole.WORKER,
+ make_kv_cache_config(block_size=16),
+ )
+
+ mock_logger.warning_once.assert_called_once()
+ msg = mock_logger.warning_once.call_args[0][0]
+ assert "kv_role='kv_both'" in msg
+ assert "deprecated" in msg
+
+
+def test_explicit_kv_role_no_deprecation_warning(default_vllm_config, dist_init):
+ """kv_role='kv_consumer' or 'kv_producer' should NOT emit a warning."""
+ from unittest.mock import patch
+
+ for role in ("kv_consumer", "kv_producer"):
+ vllm_config = create_vllm_config(kv_role=role)
+ with patch(
+ "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger"
+ ) as mock_logger:
+ NixlConnector(
+ vllm_config,
+ KVConnectorRole.WORKER,
+ make_kv_cache_config(block_size=16),
+ )
+
+ (
+ mock_logger.warning_once.assert_not_called(),
+ (f"kv_role={role!r} should not emit deprecation warning"),
+ )
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py
index 80088809469..6e399db7b14 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py
@@ -453,7 +453,7 @@ def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
"""
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
- kv_role="kv_both",
+ kv_role="kv_consumer",
)
block_size = 16
llm_kwargs = {
diff --git a/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py b/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py
index 3a3ef2a88a6..c3adc05e3ef 100644
--- a/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py
+++ b/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py
@@ -75,7 +75,7 @@ def test_gpu_memory_rixl_hma(model_name, sw_size):
"gpu_memory_utilization": 0.5,
"kv_transfer_config": KVTransferConfig(
kv_connector="NixlConnector",
- kv_role="kv_both",
+ kv_role="kv_consumer",
),
"max_model_len": 2048,
"disable_hybrid_kv_cache_manager": False,
diff --git a/tests/v1/kv_connector/unit/test_transfer_topology_sharded.py b/tests/v1/kv_connector/unit/test_transfer_topology_sharded.py
new file mode 100644
index 00000000000..ac00bb48128
--- /dev/null
+++ b/tests/v1/kv_connector/unit/test_transfer_topology_sharded.py
@@ -0,0 +1,146 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+ EngineTransferInfo,
+ TransferTopology,
+)
+
+pytestmark = pytest.mark.cpu_test
+
+
+class _FakeAttentionBackend:
+ @staticmethod
+ def get_kv_cache_shape(
+ num_blocks: int,
+ block_size: int,
+ num_kv_heads: int,
+ head_size: int,
+ ) -> tuple[int, int, int, int, int]:
+ return (2, num_blocks, num_kv_heads, block_size, head_size)
+
+
+def _make_topology(
+ *,
+ tp_rank: int = 1,
+ tp_size: int = 4,
+ total_num_kv_heads: int = 8,
+) -> TransferTopology:
+ return TransferTopology(
+ tp_rank=tp_rank,
+ tp_size=tp_size,
+ block_size=16,
+ engine_id="local-engine",
+ is_mla=False,
+ is_mamba=False,
+ total_num_kv_heads=total_num_kv_heads,
+ attn_backends=[_FakeAttentionBackend],
+ )
+
+
+def test_legacy_register_remote_engine_uses_pp_rank_zero() -> None:
+ topology = _make_topology()
+ info = EngineTransferInfo(
+ remote_tp_size=2,
+ remote_block_len=1024,
+ remote_block_size=16,
+ remote_physical_blocks_per_logical=1,
+ )
+
+ registered = topology.register_remote_engine("remote-engine", info)
+
+ assert registered == info
+ assert registered.remote_pp_rank == 0
+ assert topology.get_engine_info("remote-engine") == info
+ assert topology._engines[("remote-engine", 0)] == info
+ assert topology.target_remote_ranks("remote-engine") == [0]
+
+
+def test_register_remote_engine_stores_pp_ranks_separately() -> None:
+ topology = _make_topology(tp_rank=0, tp_size=2)
+
+ info_0 = EngineTransferInfo(
+ remote_tp_size=2,
+ remote_block_len=1024,
+ remote_block_size=16,
+ remote_physical_blocks_per_logical=1,
+ remote_pp_rank=0,
+ start_layer=0,
+ end_layer=16,
+ )
+ info_1 = EngineTransferInfo(
+ remote_tp_size=1,
+ remote_block_len=512,
+ remote_block_size=8,
+ remote_physical_blocks_per_logical=2,
+ remote_pp_rank=1,
+ start_layer=16,
+ end_layer=32,
+ )
+
+ registered_0 = topology.register_remote_engine("remote-engine", info_0)
+ registered_1 = topology.register_remote_engine("remote-engine", info_1)
+
+ assert registered_0 == info_0
+ assert registered_1 == info_1
+ assert topology.get_engine_info("remote-engine") == info_0
+ assert topology.get_engine_info("remote-engine", 0) == info_0
+ assert topology.get_engine_info("remote-engine", 1) == info_1
+ assert set(topology._engines) == {
+ ("remote-engine", 0),
+ ("remote-engine", 1),
+ }
+
+
+def test_helpers_use_requested_pp_rank() -> None:
+ topology = _make_topology(tp_rank=1, tp_size=2, total_num_kv_heads=2)
+ topology.register_remote_engine(
+ "remote-engine",
+ EngineTransferInfo(
+ remote_tp_size=1,
+ remote_block_len=1024,
+ remote_block_size=16,
+ remote_physical_blocks_per_logical=1,
+ remote_pp_rank=0,
+ start_layer=0,
+ end_layer=8,
+ ),
+ )
+ topology.register_remote_engine(
+ "remote-engine",
+ EngineTransferInfo(
+ remote_tp_size=4,
+ remote_block_len=1024,
+ remote_block_size=16,
+ remote_physical_blocks_per_logical=1,
+ remote_pp_rank=1,
+ start_layer=8,
+ end_layer=16,
+ ),
+ )
+
+ assert not topology.is_kv_replicated("remote-engine", 0)
+ assert topology.is_kv_replicated("remote-engine", 1)
+ assert topology.replicates_kv_cache("remote-engine", 1)
+ assert topology.target_remote_ranks("remote-engine", 0) == [0]
+ assert topology.target_remote_ranks("remote-engine", 1) == [2, 3]
+ assert "remote_pp=1" in topology.describe("remote-engine", 1)
+
+
+def test_engine_info_fields_have_backward_compatible_defaults() -> None:
+ topology = _make_topology()
+ info = EngineTransferInfo(
+ remote_tp_size=2,
+ remote_block_len=1024,
+ remote_block_size=16,
+ remote_physical_blocks_per_logical=1,
+ )
+
+ registered = topology.register_remote_engine("remote-engine", info)
+
+ assert topology.get_engine_info("remote-engine") == registered
+ assert registered.remote_pp_rank == 0
+ assert registered.start_layer == 0
+ assert registered.end_layer == 0
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index 1b892849d90..9d801f772a6 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -103,7 +103,7 @@ def create_vllm_config(
kv_load_failure_policy: Literal["recompute", "fail"] = "fail",
kv_connector: str = "NixlConnector",
kv_connector_module_path: str | None = None,
- kv_role: str = "kv_both",
+ kv_role: str = "kv_consumer",
disable_hybrid_kv_cache_manager: bool | None = None,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
diff --git a/tests/v1/kv_offload/tiering/test_obj_tier.py b/tests/v1/kv_offload/tiering/test_obj_tier.py
new file mode 100644
index 00000000000..ed112738f5a
--- /dev/null
+++ b/tests/v1/kv_offload/tiering/test_obj_tier.py
@@ -0,0 +1,365 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Mock-based unit tests for ObjectStoreSecondaryTierManager.
+
+These tests replace the NIXL backend with an in-memory mock so they run
+without S3 credentials or a live object store. They verify the manager's
+state machine: job submission, transfer completion polling, and lookup.
+"""
+
+import uuid
+from collections.abc import Callable
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import torch
+
+from vllm.v1.kv_offload.base import OffloadKey, ReqContext, make_offload_key
+from vllm.v1.kv_offload.tiering.base import JobMetadata, JobResult
+from vllm.v1.kv_offload.tiering.obj.manager import ObjectStoreSecondaryTierManager
+
+# ---------------------------------------------------------------------------
+# Shared stubs
+# ---------------------------------------------------------------------------
+
+
+def _make_vllm_config():
+ return SimpleNamespace(
+ model_config=SimpleNamespace(model="test/model"),
+ cache_config=SimpleNamespace(block_size=16, cache_dtype="float16"),
+ parallel_config=SimpleNamespace(
+ tensor_parallel_size=1,
+ pipeline_parallel_size=1,
+ prefill_context_parallel_size=1,
+ decode_context_parallel_size=1,
+ rank=0,
+ ),
+ )
+
+
+_OFFLOADING_SPEC = SimpleNamespace(
+ vllm_config=_make_vllm_config(),
+ kv_cache_config=SimpleNamespace(kv_cache_groups=[]),
+)
+
+_STORE_CONFIG = {
+ "bucket": "mock-bucket",
+ "endpoint_override": "mock:9000",
+ "access_key": "mock-access",
+ "secret_key": "mock-secret",
+}
+
+_BLOCK_ELEMENTS = 256
+_DTYPE = torch.float32
+_RUN_PREFIX = f"test/{uuid.uuid4().hex[:8]}"
+_CTX = ReqContext(req_id="test-req")
+
+
+def key(n: int) -> OffloadKey:
+ return make_offload_key(n.to_bytes(8, "big"), 0)
+
+
+def make_job(
+ job_id: int,
+ keys: list[OffloadKey],
+ block_ids: list[int] | None = None,
+) -> JobMetadata:
+ if block_ids is None:
+ block_ids = list(range(len(keys)))
+ return JobMetadata(
+ job_id=job_id,
+ keys=keys,
+ block_ids=np.array(block_ids, dtype=np.int64),
+ is_promotion=False,
+ req_context=_CTX,
+ )
+
+
+# ---------------------------------------------------------------------------
+# Mock NIXL agent
+# ---------------------------------------------------------------------------
+
+
+class MockNixlAgent:
+ """In-memory NIXL agent. Tracks stored object keys and simulates async
+ transfers: transfer() returns PROC, check_xfer_state() returns DONE and
+ commits the write to the in-memory key set.
+
+ The four methods overridden by tests (register_memory, make_prepped_xfer,
+ check_xfer_state, query_memory) are stored as Callable instance attributes
+ so mypy allows reassignment in tests.
+ """
+
+ # Callable attributes — tests may reassign these on instances.
+ register_memory: Callable
+ make_prepped_xfer: Callable
+ check_xfer_state: Callable
+ query_memory: Callable
+
+ def __init__(self):
+ self._stored_obj_keys: set[str] = set()
+ # handle_id -> (op, [obj_keys])
+ self._pending: dict[int, tuple[str, list[str]]] = {}
+ self._handle_counter = 0
+ self._last_obj_keys: list[str] = []
+ # Bind default implementations as instance attributes.
+ self.register_memory = self._register_memory
+ self.make_prepped_xfer = self._make_prepped_xfer
+ self.check_xfer_state = self._check_xfer_state
+ self.query_memory = self._query_memory
+
+ def create_backend(self, backend_type, params):
+ pass
+
+ def _register_memory(self, descs, mem_type=None, backends=None):
+ mock = MagicMock()
+ mock.trim.return_value = MagicMock()
+ # Capture obj_keys from OBJ 4-tuples: (addr, len, dev_id, obj_key)
+ if mem_type == "OBJ" and descs:
+ self._last_obj_keys = [d[3] for d in descs if d[3]]
+ return mock
+
+ def deregister_memory(self, desc):
+ pass
+
+ def prep_xfer_dlist(self, agent_name, descs, mem_type=None, backends=None):
+ return MagicMock()
+
+ def _make_prepped_xfer(
+ self,
+ op,
+ local_handle,
+ local_indices,
+ remote_handle,
+ remote_indices,
+ notif_msg=b"",
+ backends=None,
+ skip_desc_merge=False,
+ ):
+ handle = MagicMock()
+ handle._id = self._handle_counter
+ self._pending[self._handle_counter] = (op, list(self._last_obj_keys))
+ self._handle_counter += 1
+ return handle
+
+ def transfer(self, handle):
+ return "PROC"
+
+ def _check_xfer_state(self, handle):
+ entry = self._pending.pop(handle._id, None)
+ if entry:
+ op, obj_keys = entry
+ if op == "WRITE":
+ self._stored_obj_keys.update(obj_keys)
+ return "DONE"
+
+ def release_xfer_handle(self, handle):
+ pass
+
+ def release_dlist_handle(self, handle):
+ pass
+
+ def _query_memory(self, queries, mem_type, agent_name):
+ return [object() if q[3] in self._stored_obj_keys else None for q in queries]
+
+
+# ---------------------------------------------------------------------------
+# Fixture
+# ---------------------------------------------------------------------------
+
+
+def _make_tier(
+ num_blocks: int = 4,
+) -> tuple[ObjectStoreSecondaryTierManager, MockNixlAgent]:
+ """Create a tier backed by a fresh MockNixlAgent."""
+ mock_agent = MockNixlAgent()
+ tensor = torch.zeros((num_blocks, _BLOCK_ELEMENTS), dtype=_DTYPE)
+ view = memoryview(tensor.numpy())
+ with (
+ patch("vllm.v1.kv_offload.tiering.obj.manager.nixl_agent_config"),
+ patch(
+ "vllm.v1.kv_offload.tiering.obj.manager.nixl_agent",
+ return_value=mock_agent,
+ ),
+ ):
+ tier = ObjectStoreSecondaryTierManager(
+ offloading_spec=_OFFLOADING_SPEC,
+ primary_kv_view=view,
+ tier_type="obj",
+ store_config=_STORE_CONFIG,
+ prefix=_RUN_PREFIX,
+ )
+ return tier, mock_agent
+
+
+def drain(
+ tier: ObjectStoreSecondaryTierManager, max_rounds: int = 20
+) -> list[JobResult]:
+ """Poll get_finished_jobs() until all in-flight jobs resolve."""
+ results: list[JobResult] = []
+ for _ in range(max_rounds):
+ results.extend(tier.get_finished_jobs())
+ if not tier._transfers:
+ break
+ return results
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+class TestMockObjTierBasic:
+ def setup_method(self):
+ self.tier, self.agent = _make_tier(num_blocks=4)
+
+ def test_lookup_empty_tier(self):
+ assert self.tier.lookup(key(1), _CTX) is False
+
+ def test_store_and_lookup(self):
+ self.tier.submit_store(make_job(1, [key(1)], [0]))
+ results = drain(self.tier)
+ assert len(results) == 1
+ assert results[0].success
+ assert self.tier.lookup(key(1), _CTX) is True
+
+ def test_lookup_unrelated_key_returns_false(self):
+ self.tier.submit_store(make_job(1, [key(1)], [0]))
+ drain(self.tier)
+ assert self.tier.lookup(key(999), _CTX) is False
+
+ def test_store_then_load_roundtrip(self):
+ self.tier.submit_store(make_job(1, [key(1), key(2)], [0, 1]))
+ results = drain(self.tier)
+ assert results[0].success
+
+ self.tier.submit_load(make_job(2, [key(1), key(2)], [0, 1]))
+ results = drain(self.tier)
+ assert len(results) == 1
+ assert results[0].success
+
+ def test_multiple_jobs_tracked_independently(self):
+ self.tier.submit_store(make_job(1, [key(1)], [0]))
+ self.tier.submit_store(make_job(2, [key(2)], [1]))
+ results = drain(self.tier)
+ assert len(results) == 2
+ assert all(r.success for r in results)
+
+ def test_failed_transfer_reported(self):
+ self.agent.check_xfer_state = lambda h: "ERR"
+ self.tier.submit_store(make_job(1, [key(1)], [0]))
+ results = drain(self.tier)
+ assert len(results) == 1
+ assert not results[0].success
+
+ def test_pending_transfer_not_returned_until_done(self):
+ # First poll returns PROC; second poll returns DONE.
+ call_count = [0]
+ original = self.agent.check_xfer_state
+
+ def delayed(h):
+ call_count[0] += 1
+ return "PROC" if call_count[0] == 1 else original(h)
+
+ self.agent.check_xfer_state = delayed
+
+ self.tier.submit_store(make_job(1, [key(1)], [0]))
+ assert list(self.tier.get_finished_jobs()) == []
+ results = list(self.tier.get_finished_jobs())
+ assert len(results) == 1
+ assert results[0].success
+
+
+class TestMockObjTierMultiBlock:
+ def test_store_multiple_blocks(self):
+ tier, _ = _make_tier(num_blocks=8)
+ keys = [key(i) for i in range(8)]
+ tier.submit_store(make_job(1, keys, list(range(8))))
+ results = drain(tier)
+ assert len(results) == 1
+ assert results[0].success
+ assert all(tier.lookup(k, _CTX) for k in keys)
+
+ def test_partial_block_lookup(self):
+ tier, _ = _make_tier(num_blocks=4)
+ tier.submit_store(make_job(1, [key(0), key(1)], [0, 1]))
+ drain(tier)
+ assert tier.lookup(key(0), _CTX) is True
+ assert tier.lookup(key(1), _CTX) is True
+ assert tier.lookup(key(2), _CTX) is False
+
+
+class TestMockObjTierFailures:
+ def test_lookup_exception_returns_false(self):
+ tier, agent = _make_tier(num_blocks=4)
+ agent.query_memory = lambda *a, **k: (_ for _ in ()).throw(
+ RuntimeError("backend error")
+ )
+ assert tier.lookup(key(1), _CTX) is False
+
+ def test_submit_store_register_memory_failure_reported_in_get_finished(self):
+ tier, agent = _make_tier(num_blocks=4)
+ agent.register_memory = lambda *a, **k: None
+ tier.submit_store(make_job(1, [key(1)], [0]))
+ results = list(tier.get_finished_jobs())
+ assert len(results) == 1
+ assert results[0].job_id == 1
+ assert not results[0].success
+
+ def test_submit_load_register_memory_failure_reported_in_get_finished(self):
+ tier, agent = _make_tier(num_blocks=4)
+ agent.register_memory = lambda *a, **k: None
+ tier.submit_load(make_job(2, [key(1)], [0]))
+ results = list(tier.get_finished_jobs())
+ assert len(results) == 1
+ assert results[0].job_id == 2
+ assert not results[0].success
+
+ def test_submit_store_make_prepped_xfer_failure_reported_in_get_finished(self):
+ tier, agent = _make_tier(num_blocks=4)
+ agent.make_prepped_xfer = lambda *a, **k: None
+ tier.submit_store(make_job(3, [key(1)], [0]))
+ results = list(tier.get_finished_jobs())
+ assert len(results) == 1
+ assert results[0].job_id == 3
+ assert not results[0].success
+
+ def test_failure_and_success_both_returned_by_get_finished(self):
+ # One job fails at submission, another succeeds in flight.
+ tier, agent = _make_tier(num_blocks=4)
+ original_register = agent.register_memory
+ call_count = [0]
+
+ def register_once_fail(*a, **k):
+ call_count[0] += 1
+ return None if call_count[0] == 1 else original_register(*a, **k)
+
+ agent.register_memory = register_once_fail
+
+ tier.submit_store(make_job(1, [key(1)], [0])) # fails immediately
+ tier.submit_store(make_job(2, [key(2)], [1])) # succeeds
+ results = drain(tier)
+ assert len(results) == 2
+ by_id = {r.job_id: r for r in results}
+ assert not by_id[1].success
+ assert by_id[2].success
+
+
+class TestMockObjTierShutdown:
+ def test_shutdown_clears_in_flight_transfers(self):
+ tier, agent = _make_tier(num_blocks=4)
+ # Keep transfer in flight by never completing it
+ agent.check_xfer_state = lambda h: "PROC"
+ tier.submit_store(make_job(1, [key(1)], [0]))
+ assert len(tier._transfers) == 1
+ tier.shutdown()
+ assert len(tier._transfers) == 0
+ assert tier._dram_prepped_handle is None
+ assert tier._primary_reg is None
+
+ def test_shutdown_idempotent(self):
+ tier, _ = _make_tier(num_blocks=4)
+ tier.shutdown()
+ tier.shutdown() # must not raise
diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py
index ae0cbeab53b..e628f903792 100644
--- a/tests/v1/sample/test_rejection_sampler.py
+++ b/tests/v1/sample/test_rejection_sampler.py
@@ -544,13 +544,15 @@ def native_sample_recovered_tokens(
target_probs: torch.Tensor, # [num_tokens, vocab_size]
sampling_metadata: SamplingMetadata,
device: torch.device,
+ use_fp64_gumbel: bool = False,
) -> torch.Tensor:
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
+ q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
q = torch.empty(
(batch_size, vocab_size),
- dtype=torch.float32,
+ dtype=q_dtype,
device=device,
)
q.exponential_()
@@ -935,6 +937,67 @@ def test_sample_recovered_tokens(
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
+def test_sample_recovered_tokens_uses_fp64_exponential_race_when_requested():
+ batch_size = 2
+ vocab_size = 64
+ max_spec_len = 2
+ num_tokens = batch_size * max_spec_len
+
+ draft_probs = torch.rand(
+ num_tokens,
+ vocab_size,
+ dtype=torch.float32,
+ device=DEVICE_TYPE,
+ )
+ draft_probs = F.softmax(draft_probs, dim=-1)
+ target_probs = torch.rand(
+ num_tokens,
+ vocab_size,
+ dtype=torch.float32,
+ device=DEVICE_TYPE,
+ )
+ target_probs = F.softmax(target_probs, dim=-1)
+ draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
+
+ generators = {
+ i: torch.Generator(device=DEVICE_TYPE).manual_seed(i) for i in range(batch_size)
+ }
+ sampling_metadata = create_sampling_metadata(
+ all_greedy=False,
+ temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
+ generators=generators,
+ )
+ spec_decode_metadata = create_spec_decode_metadata(
+ draft_token_ids.reshape(batch_size, max_spec_len).tolist(),
+ target_probs.log(),
+ )
+
+ expected = native_sample_recovered_tokens(
+ max_spec_len,
+ spec_decode_metadata.num_draft_tokens,
+ spec_decode_metadata.cu_num_draft_tokens,
+ draft_token_ids,
+ draft_probs,
+ target_probs,
+ sampling_metadata,
+ device=torch.device(DEVICE_TYPE),
+ use_fp64_gumbel=True,
+ )
+ actual = sample_recovered_tokens(
+ max_spec_len,
+ spec_decode_metadata.num_draft_tokens,
+ spec_decode_metadata.cu_num_draft_tokens,
+ draft_token_ids,
+ draft_probs,
+ target_probs,
+ sampling_metadata,
+ device=torch.device(DEVICE_TYPE),
+ use_fp64_gumbel=True,
+ )
+
+ assert torch.equal(actual, expected)
+
+
########################### Tests for Synthetic Rejection Sampling #########
diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py
index a80fddc9235..554649b5e19 100644
--- a/tests/v1/sample/test_topk_topp_sampler.py
+++ b/tests/v1/sample/test_topk_topp_sampler.py
@@ -6,7 +6,12 @@ from torch import Generator
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
+from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.sample.ops.topk_topp_sampler import (
+ apply_top_k_top_p_pytorch,
+ random_sample,
+)
+from vllm.v1.sample.sampler import Sampler
DEVICE_TYPE = current_platform.device_type
@@ -38,6 +43,10 @@ def _flashinfer_topk_topp_supported() -> bool:
FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported()
+def _seed_default_generator(seed: int) -> None:
+ set_random_seed(seed)
+
+
@pytest.fixture(autouse=True)
def reset_default_device():
"""
@@ -49,6 +58,35 @@ def reset_default_device():
torch.set_default_device(original_device)
+def test_sampler_threads_fp64_gumbel_to_topk_topp_sampler():
+ sampler = Sampler(use_fp64_gumbel=True)
+
+ assert sampler.topk_topp_sampler.use_fp64_gumbel
+
+
+def test_random_sample_uses_fp64_exponential_race_when_requested():
+ torch.set_default_device(DEVICE_TYPE)
+ probs = torch.tensor(
+ [
+ [0.70, 0.20, 0.10],
+ [0.05, 0.15, 0.80],
+ [0.25, 0.25, 0.50],
+ ],
+ dtype=torch.float32,
+ device=DEVICE_TYPE,
+ )
+
+ _seed_default_generator(12345)
+ q = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
+ q.exponential_()
+ expected = q.reciprocal_().mul_(probs).argmax(dim=-1).view(-1)
+
+ _seed_default_generator(12345)
+ actual = random_sample(probs.clone(), {}, use_fp64_gumbel=True)
+
+ assert torch.equal(actual, expected)
+
+
def test_topk_impl_equivalence():
torch.set_default_device(DEVICE_TYPE)
generator = Generator(device=DEVICE_TYPE).manual_seed(33)
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index c13de6d4f71..32f9dcc86ab 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -1034,7 +1034,8 @@ def test_propose_stores_probabilistic_draft_probs(monkeypatch):
proposer.model = model_mock
proposer._draft_attn_layer_names = {"layer.0"}
- def fake_compute_probs(logits, sampling_metadata):
+ def fake_compute_probs(logits, sampling_metadata, use_fp64_gumbel):
+ assert use_fp64_gumbel == proposer.use_fp64_gumbel
probs = torch.softmax(logits, dim=-1)
return probs.argmax(dim=-1), probs
diff --git a/tests/v1/spec_decode/test_llm_base_proposer_sampling.py b/tests/v1/spec_decode/test_llm_base_proposer_sampling.py
new file mode 100644
index 00000000000..9c7ec760ebb
--- /dev/null
+++ b/tests/v1/spec_decode/test_llm_base_proposer_sampling.py
@@ -0,0 +1,70 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.platforms import current_platform
+from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.sample.logits_processor import LogitsProcessors
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.llm_base_proposer import (
+ compute_probs_and_sample_next_token,
+)
+
+DEVICE_TYPE = current_platform.device_type
+
+
+def _seed_default_generator(seed: int) -> None:
+ set_random_seed(seed)
+
+
+def _make_sampling_metadata(batch_size: int) -> SamplingMetadata:
+ return SamplingMetadata(
+ temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
+ all_greedy=False,
+ all_random=True,
+ top_p=None,
+ top_k=None,
+ generators={},
+ max_num_logprobs=None,
+ no_penalties=True,
+ prompt_token_ids=None,
+ frequency_penalties=torch.empty(0, device=DEVICE_TYPE),
+ presence_penalties=torch.empty(0, device=DEVICE_TYPE),
+ repetition_penalties=torch.empty(0, device=DEVICE_TYPE),
+ output_token_ids=[[] for _ in range(batch_size)],
+ spec_token_ids=[[] for _ in range(batch_size)],
+ allowed_token_ids_mask=None,
+ bad_words_token_ids={},
+ logitsprocs=LogitsProcessors(),
+ )
+
+
+def test_compute_probs_and_sample_next_token_uses_fp64_exponential_race():
+ batch_size = 4
+ vocab_size = 32
+ generator = torch.Generator(device=DEVICE_TYPE).manual_seed(11)
+ logits = torch.randn(
+ batch_size,
+ vocab_size,
+ dtype=torch.float32,
+ device=DEVICE_TYPE,
+ generator=generator,
+ )
+ metadata = _make_sampling_metadata(batch_size)
+
+ _seed_default_generator(12345)
+ probs = logits.softmax(dim=-1, dtype=torch.float32)
+ q = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
+ q.exponential_()
+ expected_ids = q.reciprocal_().mul_(probs).argmax(dim=-1).view(-1)
+
+ _seed_default_generator(12345)
+ actual_ids, actual_probs = compute_probs_and_sample_next_token(
+ logits.clone(),
+ metadata,
+ use_fp64_gumbel=True,
+ )
+
+ assert torch.equal(actual_ids, expected_ids)
+ assert torch.allclose(actual_probs, probs)
diff --git a/tools/gumbel_precision/prove_exponential_race_precision.py b/tools/gumbel_precision/prove_exponential_race_precision.py
new file mode 100644
index 00000000000..2af8f40fa74
--- /dev/null
+++ b/tools/gumbel_precision/prove_exponential_race_precision.py
@@ -0,0 +1,141 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""CUDA proof for fp32 exponential-race tail truncation.
+
+This script is intentionally not a unit test. It is a reproducible, GPU-only
+statistical proof for the hidden Gumbel-max idiom:
+
+ q.exponential_()
+ sample = (probs / q).argmax()
+
+For q ~ Exp(1), this is equivalent to argmax(log(probs) + Gumbel). On CUDA,
+fp32 exponential samples inherit a 24-bit uniform lower-tail cutoff, so very
+small q values are impossible. The many-tail experiment below chooses a case
+where a correct sampler should select a low-probability tail token dozens of
+times, while fp32 q cannot select one.
+"""
+
+from __future__ import annotations
+
+import argparse
+import math
+import time
+
+import torch
+
+
+def _seed(seed: int) -> None:
+ torch.manual_seed(seed)
+
+
+def measure_exponential_lower_tail(
+ *,
+ device: torch.device,
+ samples: int,
+ chunk_size: int,
+ seed: int,
+) -> None:
+ threshold = 2.0**-24
+ print(f"lower-tail threshold: {threshold:.18e}")
+ for dtype in (torch.float32, torch.float64):
+ _seed(seed)
+ count_below = 0
+ min_q = float("inf")
+ max_q = 0.0
+ start = time.perf_counter()
+ remaining = samples
+ while remaining > 0:
+ n = min(chunk_size, remaining)
+ q = torch.empty((n,), dtype=dtype, device=device)
+ q.exponential_()
+ count_below += int((q < threshold).sum().item())
+ min_q = min(min_q, float(q.min().item()))
+ max_q = max(max_q, float(q.max().item()))
+ remaining -= n
+ torch.accelerator.synchronize()
+ elapsed = time.perf_counter() - start
+ print(
+ f"{dtype}: samples={samples} count(q < 2^-24)={count_below} "
+ f"min={min_q:.18e} max={max_q:.6f} elapsed={elapsed:.2f}s"
+ )
+
+
+def run_many_tail_race(
+ *,
+ device: torch.device,
+ trials: int,
+ num_tail_tokens: int,
+ gap: float,
+ chunk_trials: int,
+ seed: int,
+) -> None:
+ p_tail = math.exp(-gap)
+ expected_tail_hits = (
+ trials * (num_tail_tokens * p_tail) / (1.0 + num_tail_tokens * p_tail)
+ )
+ print(
+ "many-tail race: "
+ f"trials={trials} num_tail_tokens={num_tail_tokens} gap={gap} "
+ f"expected_tail_hits={expected_tail_hits:.4f}"
+ )
+
+ for dtype in (torch.float32, torch.float64):
+ _seed(seed)
+ hits = 0
+ p0 = torch.tensor(1.0, dtype=dtype, device=device)
+ pt = torch.tensor(p_tail, dtype=dtype, device=device)
+ start = time.perf_counter()
+ remaining = trials
+ while remaining > 0:
+ batch = min(chunk_trials, remaining)
+ q0 = torch.empty((batch,), dtype=dtype, device=device)
+ q0.exponential_()
+ qt = torch.empty((batch, num_tail_tokens), dtype=dtype, device=device)
+ qt.exponential_()
+ head_score = p0 / q0
+ tail_score = (pt / qt).amax(dim=-1)
+ hits += int((tail_score > head_score).sum().item())
+ remaining -= batch
+ torch.accelerator.synchronize()
+ elapsed = time.perf_counter() - start
+ print(f"{dtype}: tail_hits={hits} elapsed={elapsed:.2f}s")
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--lower-tail-samples", type=int, default=200_000_000)
+ parser.add_argument("--lower-tail-chunk-size", type=int, default=10_000_000)
+ parser.add_argument("--race-trials", type=int, default=100_000)
+ parser.add_argument("--race-tail-tokens", type=int, default=262_144)
+ parser.add_argument("--race-gap", type=float, default=20.5)
+ parser.add_argument("--race-chunk-trials", type=int, default=64)
+ parser.add_argument("--seed", type=int, default=2026)
+ args = parser.parse_args()
+
+ if not torch.accelerator.is_available():
+ raise RuntimeError("CUDA is required for this proof.")
+
+ device = torch.accelerator.current_accelerator()
+ if device.type != "cuda":
+ raise RuntimeError("CUDA is required for this proof.")
+
+ print(f"torch={torch.__version__} cuda={torch.version.cuda}")
+ print(f"device={device}")
+ measure_exponential_lower_tail(
+ device=device,
+ samples=args.lower_tail_samples,
+ chunk_size=args.lower_tail_chunk_size,
+ seed=args.seed,
+ )
+ run_many_tail_race(
+ device=device,
+ trials=args.race_trials,
+ num_tail_tokens=args.race_tail_tokens,
+ gap=args.race_gap,
+ chunk_trials=args.race_chunk_trials,
+ seed=args.seed,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py
index 22855080824..a174208da4c 100755
--- a/tools/pre_commit/mypy.py
+++ b/tools/pre_commit/mypy.py
@@ -8,11 +8,9 @@ on files that have been changed. It groups files into different mypy calls
based on their directory to avoid import following issues.
Usage:
- python tools/pre_commit/mypy.py
+ python tools/pre_commit/mypy.py
Args:
- ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
- "silent" for the main group of files.
python_version: Python version to use (e.g., "3.10") or "local" to use
the local Python version.
changed_files: List of changed files to check.
@@ -98,8 +96,8 @@ def mypy(
def main():
- python_version = sys.argv[2]
- file_groups = group_files(sys.argv[3:])
+ python_version = sys.argv[1]
+ file_groups = group_files(sys.argv[2:])
if python_version == "local":
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 5a8b690433c..eb12bedd7bf 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -27,6 +27,16 @@ except ImportError:
# on ROCm the fp8_dtype always calls is_fp8_fnuz
# which is a host op, so we cache it once here.
FP8_DTYPE = current_platform.fp8_dtype()
+_HIPB_MM_INITIALIZED_DEVICES: set[int] = set()
+
+
+def _ensure_hipb_mm_extension_initialized() -> None:
+ import aiter
+
+ device = torch.accelerator.current_device_index()
+ if device not in _HIPB_MM_INITIALIZED_DEVICES:
+ aiter.hipb_create_extension()
+ _HIPB_MM_INITIALIZED_DEVICES.add(device)
def is_aiter_found() -> bool:
@@ -625,6 +635,43 @@ def _rocm_aiter_preshuffled_per_token_w8a8_gemm_fake(
return torch.empty(m, n, dtype=output_dtype, device=A.device)
+def _rocm_aiter_hipb_mm_fp8_impl(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ output_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+ from aiter import hipb_mm
+
+ _ensure_hipb_mm_extension_initialized()
+ return hipb_mm(
+ A,
+ B,
+ solution_index=-1,
+ bias=bias,
+ out_dtype=output_dtype,
+ scaleA=As,
+ scaleB=Bs,
+ scaleOut=None,
+ bpreshuffle=True,
+ )
+
+
+def _rocm_aiter_hipb_mm_fp8_fake(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ output_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+ m = A.shape[0]
+ n = B.shape[1]
+ return torch.empty(m, n, dtype=output_dtype, device=A.device)
+
+
def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
A: torch.Tensor,
B: torch.Tensor,
@@ -1308,6 +1355,7 @@ class rocm_aiter_ops:
# TODO: Consolidate under _LINEAR_ENABLED
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
_FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
+ _LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
# TODO: Consolidate under _LINEAR_ENABLED
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
@@ -1340,6 +1388,7 @@ class rocm_aiter_ops:
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
+ cls._LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
@@ -1512,6 +1561,13 @@ class rocm_aiter_ops:
return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950()
+ @classmethod
+ @if_aiter_supported
+ def is_linear_hipbmm_enabled(cls) -> bool:
+ from vllm.platforms.rocm import on_mi3xx
+
+ return cls.is_linear_enabled() and on_mi3xx() and cls._LINEAR_HIPBMM_ENABLED
+
@classmethod
@if_aiter_supported
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
@@ -1668,6 +1724,12 @@ class rocm_aiter_ops:
fake_impl=_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake,
)
+ direct_register_custom_op(
+ op_name="rocm_aiter_hipb_mm_fp8",
+ op_func=_rocm_aiter_hipb_mm_fp8_impl,
+ fake_impl=_rocm_aiter_hipb_mm_fp8_fake,
+ )
+
direct_register_custom_op(
op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
@@ -1858,6 +1920,17 @@ class rocm_aiter_ops:
A, B, As, Bs, bias, output_dtype
)
+ @staticmethod
+ def hipb_mm_fp8(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ output_dtype: torch.dtype = torch.bfloat16,
+ ) -> torch.Tensor:
+ return torch.ops.vllm.rocm_aiter_hipb_mm_fp8(A, B, As, Bs, bias, output_dtype)
+
@staticmethod
def triton_gemm_a8w8_blockscale(
A: torch.Tensor,
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index f12d128f083..bd8a19b6d2d 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -868,9 +868,10 @@ def cutlass_scaled_mm_azp(
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
- :param azp_adj: In the per-tensor case, this should include the azp.
- Always per-channel.
- :param azp: Only set in the per-token case. Per-token if set.
+ Args:
+ azp_adj: In the per-tensor case, this should include the azp.
+ Always per-channel.
+ azp: Only set in the per-token case. Per-token if set.
"""
assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
@@ -3886,9 +3887,12 @@ def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
Note that sylvester hadamard transforms are also symmetric, which means that
this function is also applies the (transpose <=> inverse) transform.
- :param x: value to be transformed inplace
- :param inplace: modify value in place
- :return: value after transformation
+ Args:
+ x: value to be transformed inplace
+ inplace: modify value in place
+
+ Returns:
+ value after transformation
"""
return torch.ops._C.hadacore_transform(x, inplace)
diff --git a/vllm/compilation/passes/inductor_pass.py b/vllm/compilation/passes/inductor_pass.py
index 8a0d5326dd9..29410f960cd 100644
--- a/vllm/compilation/passes/inductor_pass.py
+++ b/vllm/compilation/passes/inductor_pass.py
@@ -82,11 +82,12 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
- :param srcs: strings or objects to add to the hash.
- Objects and functions have their source inspected.
- Results are cached by resolved types to avoid repeated
- inspect.getsource() calls.
- :return:
+
+ Args:
+ srcs: strings or objects to add to the hash.
+ Objects and functions have their source inspected.
+ Results are cached by resolved types to avoid repeated
+ inspect.getsource() calls.
"""
# Resolve instances to their class for a hashable cache key.
cache_key = tuple(
@@ -99,7 +100,9 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
def hash_dict(dict_: dict[Any, Any]) -> str:
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
- :return: A sha256 hash of the json rep of the dictionary.
+
+ Returns:
+ A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py
index 2887c19ad4a..c0643a916b3 100644
--- a/vllm/compilation/passes/utility/fix_functionalization.py
+++ b/vllm/compilation/passes/utility/fix_functionalization.py
@@ -276,9 +276,11 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
Replace mutated getitem users of the auto-functionalized node with the
mutated arguments.
- :param node: The auto-functionalized node
- :param mutated_args: The mutated arguments, indexed by getitem index.
- If the value of an arg is a string, `node.kwargs[arg]` is used.
+
+ Args:
+ node: The auto-functionalized node
+ mutated_args: The mutated arguments, indexed by getitem index.
+ If the value of an arg is a string, `node.kwargs[arg]` is used.
"""
for idx, user in self.getitem_users(node).items():
# Some functionalized nodes may return both a result at getitem[0]
@@ -317,10 +319,11 @@ class FixFunctionalizationPass(VllmInductorPass):
as node.kwargs cannot be used.
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
- :param graph: Graph to insert the defunctionalized node into
- :param node: The auto-functionalized node to defunctionalize
- :param args: If we cannot use kwargs, specify args directly.
- If an arg is a string, `node.kwargs[arg]` is used.
+ Args:
+ graph: Graph to insert the defunctionalized node into
+ node: The auto-functionalized node to defunctionalize
+ args: If we cannot use kwargs, specify args directly.
+ If an arg is a string, `node.kwargs[arg]` is used.
""" # noqa: E501
assert is_func(node, auto_functionalized), (
f"node must be auto-functionalized, is {node} instead"
diff --git a/vllm/compilation/passes/utility/noop_elimination.py b/vllm/compilation/passes/utility/noop_elimination.py
index 5f7d47ad6f8..80bf8ecc603 100644
--- a/vllm/compilation/passes/utility/noop_elimination.py
+++ b/vllm/compilation/passes/utility/noop_elimination.py
@@ -108,9 +108,13 @@ class NoOpEliminationPass(VllmInductorPass):
def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
"""
This function checks if two dimensions are equivalent.
- :param dim: The dimension arg to reshape/slice
- :param i_dim: The corresponding dimension in the input tensor
- :return: Are the dimensions equivalent?
+
+ Args:
+ dim: The dimension arg to reshape/slice
+ i_dim: The corresponding dimension in the input tensor
+
+ Returns:
+ Are the dimensions equivalent?
There are two cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 67040a423b7..f648a69e10e 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -235,9 +235,10 @@ class ModelConfig:
temperature and top_k/top_p.
"""
use_fp64_gumbel: bool = False
- """Whether to use FP64 (instead of FP32) for the Gumbel noise used by the
- sampler. FP64 reduces the chance of ties in Gumbel-max sampling at the cost
- of significantly lower kernel throughput on most GPUs."""
+ """Whether to use FP64 (instead of FP32) random noise for Gumbel-max and
+ equivalent exponential-race sampling. FP64 preserves lower-tail sampling
+ events that fp32 uniform/exponential draws can truncate, at the cost of
+ significantly lower throughput on most GPUs."""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py
index 6edd69a949e..a652c5a6c73 100644
--- a/vllm/device_allocator/cumem.py
+++ b/vllm/device_allocator/cumem.py
@@ -180,8 +180,9 @@ class CuMemAllocator:
All data in the memory allocation with the specified tag will be
offloaded to CPU memory, and others will be discarded.
- :param offload_tags: The tags of the memory allocation that will be
- offloaded. The rest of the memory allocation will be discarded.
+ Args:
+ offload_tags: The tags of the memory allocation that will be
+ offloaded. The rest of the memory allocation will be discarded.
"""
if offload_tags is None:
# by default, allocated tensors are offloaded
@@ -230,9 +231,10 @@ class CuMemAllocator:
All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory.
- :param tags: The tags of the memory allocation that will be loaded
- back to GPU memory. If None, all memory allocation will be loaded
- back to GPU memory.
+ Args:
+ tags: The tags of the memory allocation that will be loaded
+ back to GPU memory. If None, all memory allocation will be loaded
+ back to GPU memory.
"""
for ptr, data in self.pointer_to_data.items():
if tags is None or data.tag in tags:
@@ -255,8 +257,9 @@ class CuMemAllocator:
All memory allocation created inside the context will be allocated
in the memory pool, and has the specified tag.
- :param tag: The tag of the memory allocation. If None, the default tag
- will be used.
+ Args:
+ tag: The tag of the memory allocation. If None, the default tag
+ will be used.
"""
if tag is None:
tag = CuMemAllocator.default_tag
diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py
index ee21185969f..c9bc8909519 100644
--- a/vllm/distributed/kv_events.py
+++ b/vllm/distributed/kv_events.py
@@ -132,7 +132,8 @@ class KVEventAggregator:
"""
Add events from a worker batch.
- :param events: List of KVCacheEvent objects.
+ Args:
+ events: List of KVCacheEvent objects.
"""
if not isinstance(events, list):
raise TypeError("events must be a list of KVCacheEvent.")
@@ -142,7 +143,8 @@ class KVEventAggregator:
"""
Return events that appeared in all workers.
- :return: List of events present in all workers.
+ Returns:
+ List of events present in all workers.
"""
return [
event
@@ -154,7 +156,8 @@ class KVEventAggregator:
"""
Return all events for all workers.
- :return: List of events for all workers.
+ Returns:
+ List of events for all workers.
"""
return list(self._event_counter.elements())
@@ -168,7 +171,8 @@ class KVEventAggregator:
"""
Increment the number of workers contributing events.
- :param count: Number to increment the workers by.
+ Args:
+ count: Number to increment the workers by.
"""
if count <= 0:
raise ValueError("count must be positive.")
@@ -184,7 +188,8 @@ class KVEventAggregator:
"""
Return the number of workers.
- :return: int number of workers.
+ Returns:
+ int number of workers.
"""
return self._num_workers
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index fafc1f45724..0ab694b7e73 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -368,7 +368,7 @@ def get_current_attn_backend(
class EngineTransferInfo:
"""Common per-remote-engine transfer state, computed at handshake.
- Stored per ``engine_id`` inside ``TransferTopology._engines``.
+ Stored per ``(engine_id, pp_rank)`` inside ``TransferTopology._engines``.
"""
remote_tp_size: int
@@ -382,6 +382,15 @@ class EngineTransferInfo:
remote_physical_blocks_per_logical: int
"""Physical blocks per logical block."""
+ remote_pp_rank: int = 0
+ """Remote producer PP rank for this engine."""
+
+ start_layer: int = 0
+ """Global index of the first layer owned by this PP rank."""
+
+ end_layer: int = 0
+ """Exclusive global index after the last layer owned by this PP rank."""
+
# ---- Transfer topology ----
@@ -403,7 +412,7 @@ class TransferTopology:
def __post_init__(self):
self.local_physical_heads = max(1, self.total_num_kv_heads // self.tp_size)
- self._engines: dict[EngineId, EngineTransferInfo] = {}
+ self._engines: dict[tuple[EngineId, int], EngineTransferInfo] = {}
# Figure out whether the first dimension of the cache is K/V
# or num_blocks.
@@ -461,13 +470,16 @@ class TransferTopology:
f"Cannot register local engine {self.engine_id} as remote. "
f"Local identity is set via __init__ params."
)
- if remote_engine_id in self._engines:
- return self._engines[remote_engine_id]
- self._engines[remote_engine_id] = info
+ engine_key = (remote_engine_id, info.remote_pp_rank)
+ if engine_key in self._engines:
+ return self._engines[engine_key]
+ self._engines[engine_key] = info
return info
- def get_engine_info(self, remote_engine_id: EngineId) -> EngineTransferInfo:
- return self._engines[remote_engine_id]
+ def get_engine_info(
+ self, remote_engine_id: EngineId, remote_pp_rank: int = 0
+ ) -> EngineTransferInfo:
+ return self._engines[(remote_engine_id, remote_pp_rank)]
# ============================================================
# Layout properties
@@ -528,15 +540,22 @@ class TransferTopology:
)
return self.block_size // remote_block_size
- def is_kv_replicated(self, remote_engine_id: EngineId) -> bool:
+ def is_kv_replicated(
+ self, remote_engine_id: EngineId, remote_pp_rank: int = 0
+ ) -> bool:
"""Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
- return self._engines[remote_engine_id].remote_tp_size > self.total_num_kv_heads
+ return (
+ self._engines[(remote_engine_id, remote_pp_rank)].remote_tp_size
+ > self.total_num_kv_heads
+ )
- def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
+ def replicates_kv_cache(
+ self, remote_engine_id: EngineId, remote_pp_rank: int = 0
+ ) -> bool:
# MLA is always replicated as the hidden dim can't be split.
- return self.is_mla or self.is_kv_replicated(remote_engine_id)
+ return self.is_mla or self.is_kv_replicated(remote_engine_id, remote_pp_rank)
@property
def local_replicates_kv_cache(self) -> bool:
@@ -555,12 +574,14 @@ class TransferTopology:
abs_ratio = -tp_ratio
return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)]
- def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]:
+ def target_remote_ranks(
+ self, remote_engine_id: EngineId, remote_pp_rank: int = 0
+ ) -> list[int]:
"""Get the remote TP rank(s) that the current local TP rank will
read from. When remote tp_size > local tp_size, reads from
multiple remote ranks.
"""
- info = self._engines[remote_engine_id]
+ info = self._engines[(remote_engine_id, remote_pp_rank)]
tp_ratio = self.tp_ratio(info.remote_tp_size)
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
@@ -593,15 +614,16 @@ class TransferTopology:
# Regular case: backends like FA register K/V in separate regions
return cache if self.split_k_and_v else [cache]
- def describe(self, remote_engine_id: EngineId) -> str:
+ def describe(self, remote_engine_id: EngineId, remote_pp_rank: int = 0) -> str:
"""One-line summary of transfer config for logging."""
- info = self._engines[remote_engine_id]
+ info = self._engines[(remote_engine_id, remote_pp_rank)]
return (
f"TransferTopology("
f"tp_ratio={self.tp_ratio(info.remote_tp_size)}, "
f"num_kv_heads={self.total_num_kv_heads if not self.is_mla else 1}, "
f"local_tp={self.tp_size}, "
f"remote_tp={info.remote_tp_size}, "
+ f"remote_pp={remote_pp_rank}, "
f"local_rank={self.tp_rank}, "
f"remote_block_len={info.remote_block_len})"
)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index fb5658da887..71d89f43a79 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -644,6 +644,23 @@ class KVConnectorBase_V1(ABC):
"""
return None
+ def set_xfer_handshake_metadata_pp_aware(
+ self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata]
+ ) -> None:
+ """
+ Set handshake metadata keyed by (pp_rank, tp_rank).
+ - Default implementation assumes pp_rank is always 0
+ - PP-aware connectors override this to consume all PP producer shards.
+ """
+ if any(pp_rank != 0 for pp_rank, _ in metadata):
+ raise ValueError(
+ f"{type(self).__name__} received pp_rank > 0 handshake metadata "
+ "but does not support PP-disaggregated KV transfer."
+ )
+ self.set_xfer_handshake_metadata(
+ {tp_rank: meta for (_, tp_rank), meta in metadata.items()}
+ )
+
@classmethod
def build_prom_metrics(
cls,
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
index 35cd7060691..d16fbee585a 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@@ -439,13 +439,12 @@ def _init_lmcache_engine(
`LMCACHE_CONFIG_FILE` to load the configuration file. If that environment
variable is not set, this function will return None.
- :param lmcache_config: The LMCache configuration.
- :type lmcache_config: LMCacheEngineConfig
- :param vllm_config: The vLLM configuration.
- :type vllm_config: VllmConfig
+ Args:
+ lmcache_config: The LMCache configuration.
+ vllm_config: The vLLM configuration.
- :return: The initialized LMCache engine
- :rtype: LMCacheEngine
+ Returns:
+ The initialized LMCache engine
"""
if curr_engine := LMCacheEngineBuilder.get(ENGINE_NAME):
return curr_engine
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index 73418104bea..46354337e65 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -471,6 +471,12 @@ class MultiConnector(KVConnectorBase_V1, SupportsHMA):
for c in self._connectors:
c.set_xfer_handshake_metadata(metadata)
+ def set_xfer_handshake_metadata_pp_aware(
+ self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata]
+ ) -> None:
+ for c in self._connectors:
+ c.set_xfer_handshake_metadata_pp_aware(metadata)
+
def _aggregate_request_finished(
self,
request: "Request",
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
index 187322b4ae4..dad81e84c45 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py
@@ -94,6 +94,15 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
+
+ if vllm_config.kv_transfer_config.kv_role == "kv_both":
+ logger.warning_once(
+ "Using kv_role='kv_both' with NixlConnector is deprecated "
+ "and will be removed in a future release. Please set "
+ "kv_role='kv_producer' for prefill instances and "
+ "kv_role='kv_consumer' for decode instances. "
+ )
+
self.kv_cache_config = kv_cache_config
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
self.kv_transfer_config = vllm_config.kv_transfer_config
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 8775e519d99..8bd6e92157a 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -227,6 +227,67 @@ def patched_fused_scaled_matmul_reduce_scatter_fake(
return res
+def _platform_device_type() -> str:
+ """Return the device-type string (e.g. ``"cuda"``, ``"xpu"``, ``"cpu"``)
+ for the current platform, in the form expected by
+ ``torch.distributed.init_process_group(backend=...)``.
+ """
+ from vllm.platforms import current_platform
+
+ if current_platform.is_cuda_alike():
+ return "cuda"
+ elif current_platform.is_xpu():
+ return "xpu"
+ elif current_platform.is_out_of_tree():
+ return current_platform.device_name
+ else:
+ return "cpu"
+
+
+def _device_backend_str(torch_distributed_backend: str | Backend) -> str:
+ """Normalize ``torch_distributed_backend`` to the ``":"``
+ format required by ``split_group``'s ``backend`` argument.
+
+ Accepts either a bare backend name (e.g. ``"nccl"``) or an already-prefixed
+ string (e.g. ``"cuda:nccl"``).
+ """
+ backend_str = str(torch_distributed_backend)
+ if ":" in backend_str:
+ return backend_str
+ return f"{_platform_device_type()}:{backend_str}"
+
+
+def _create_subgroups_split_group(
+ group_ranks: list[list[int]],
+ group_name: str,
+ torch_distributed_backend: str | Backend,
+) -> tuple[ProcessGroup, ProcessGroup]:
+ """Create the device + CPU subgroups for ``GroupCoordinator`` via
+ ``torch.distributed.split_group``.
+
+ ``split_group`` is collective on the parent group, so every parent rank
+ must enter with the same ``split_ranks`` definition. Each rank receives
+ the subgroup it belongs to.
+ """
+ device_backend_str = _device_backend_str(torch_distributed_backend)
+ self_device_group = torch.distributed.split_group(
+ split_ranks=group_ranks,
+ group_desc=f"{group_name}:device",
+ backend=device_backend_str,
+ )
+ # CPU subgroup: split_group requires the requested backend filter to
+ # include the parent's default device type (= the device the parent PG
+ # was bound to via ``device_id``), so a cpu-only filter is rejected.
+ # Include the device backend in the filter; only the gloo backend is
+ # actually used for CPU collectives on this group.
+ self_cpu_group = torch.distributed.split_group(
+ split_ranks=group_ranks,
+ group_desc=f"{group_name}:cpu",
+ backend=f"cpu:gloo,{device_backend_str}",
+ )
+ return self_device_group, self_cpu_group
+
+
def patched_fused_scaled_matmul_reduce_scatter(
A: torch.Tensor,
B: torch.Tensor,
@@ -335,26 +396,39 @@ class GroupCoordinator:
self_device_group = None
self_cpu_group = None
- from vllm.distributed.utils import get_cpu_distributed_timeout_or_none
-
- timeout = get_cpu_distributed_timeout_or_none()
-
- for ranks in group_ranks:
- device_group = torch.distributed.new_group(
- ranks, backend=torch_distributed_backend
+ # VLLM_DISTRIBUTED_USE_SPLIT_GROUP gates the new ``split_group``
+ # codepath. Default (False) preserves the legacy ``new_group`` path.
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ self_device_group, self_cpu_group = _create_subgroups_split_group(
+ group_ranks, group_name, torch_distributed_backend
)
- # a group with `gloo` backend, to allow direct coordination between
- # processes through the CPU.
- with suppress_stdout():
- cpu_group = torch.distributed.new_group(
- ranks, backend="gloo", timeout=timeout
+ for ranks in group_ranks:
+ if self.rank in ranks:
+ self.ranks = ranks
+ self.world_size = len(ranks)
+ self.rank_in_group = ranks.index(self.rank)
+ break
+ else:
+ from vllm.distributed.utils import get_cpu_distributed_timeout_or_none
+
+ timeout = get_cpu_distributed_timeout_or_none()
+
+ for ranks in group_ranks:
+ device_group = torch.distributed.new_group(
+ ranks, backend=torch_distributed_backend
)
- if self.rank in ranks:
- self.ranks = ranks
- self.world_size = len(ranks)
- self.rank_in_group = ranks.index(self.rank)
- self_device_group = device_group
- self_cpu_group = cpu_group
+ # a group with `gloo` backend, to allow direct coordination between
+ # processes through the CPU.
+ with suppress_stdout():
+ cpu_group = torch.distributed.new_group(
+ ranks, backend="gloo", timeout=timeout
+ )
+ if self.rank in ranks:
+ self.ranks = ranks
+ self.world_size = len(ranks)
+ self.rank_in_group = ranks.index(self.rank)
+ self_device_group = device_group
+ self_cpu_group = cpu_group
assert self_cpu_group is not None
assert self_device_group is not None
@@ -1348,6 +1422,62 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
+def _init_process_group_for_split_group(
+ *,
+ backend: str,
+ distributed_init_method: str,
+ world_size: int,
+ rank: int,
+ local_rank: int,
+ timeout: timedelta | None,
+) -> None:
+ """Initialize the default PG with both CPU (gloo) and device (e.g. nccl)
+ backends and an eager ``device_id`` binding so that subgroups can be
+ created via ``split_group`` (which requires the parent communicator to
+ be eagerly initialized). Falls back to ``gloo`` on CPU-only systems.
+ """
+ if torch.accelerator.is_available() and backend != "gloo":
+ init_backend = "cpu:gloo,cuda:nccl"
+ device_id: torch.device | None = torch.device(f"cuda:{local_rank}")
+ else:
+ init_backend = "gloo"
+ device_id = None
+ torch.distributed.init_process_group(
+ backend=init_backend,
+ init_method=distributed_init_method,
+ world_size=world_size,
+ rank=rank,
+ timeout=timeout,
+ device_id=device_id,
+ )
+
+
+def _validate_default_pg_for_split_group() -> None:
+ """When an external launcher (e.g. ``torchrun``) initialized the default
+ PG, ``GroupCoordinator`` cannot patch in additional backends or change
+ the eager-init behavior — ``split_group`` only selects subsets of an
+ existing parent. Validate that the parent has both ``device_id`` and a
+ CPU (gloo) backend, and emit a descriptive error pointing at the exact
+ init call to update otherwise.
+ """
+ default_pg = torch.distributed.distributed_c10d._get_default_group()
+ assert default_pg.bound_device_id is not None, (
+ "External launcher initialized the default process group "
+ "without device_id. vLLM requires the default PG to be device-"
+ "bound for split_group. Pass device_id=torch.device(f'cuda:"
+ "{local_rank}') to torch.distributed.init_process_group()."
+ )
+ try:
+ default_pg._get_backend(torch.device("cpu"))
+ except RuntimeError as e:
+ raise RuntimeError(
+ "External launcher initialized the default process group "
+ "without a CPU (gloo) backend. vLLM requires both CPU and "
+ "device backends. Pass backend='cpu:gloo,cuda:nccl' to "
+ "torch.distributed.init_process_group()."
+ ) from e
+
+
def _init_elastic_ep_world(
config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
@@ -1456,14 +1586,33 @@ def init_distributed_environment(
"Fallback Gloo backend is not available."
)
backend = "gloo"
- # this backend is used for WORLD
- torch.distributed.init_process_group(
- backend=backend,
- init_method=distributed_init_method,
- world_size=world_size,
- rank=rank,
- timeout=timeout,
- )
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
+ # split_group needs local_rank early to compute device_id for
+ # the eager init. local_rank is not available in torch
+ # ProcessGroup, see https://github.com/pytorch/pytorch/issues/122816
+ if local_rank == -1:
+ local_rank = (
+ int(envs.LOCAL_RANK)
+ if distributed_init_method == "env://"
+ else rank
+ )
+ _init_process_group_for_split_group(
+ backend=backend,
+ distributed_init_method=distributed_init_method,
+ world_size=world_size,
+ rank=rank,
+ local_rank=local_rank,
+ timeout=timeout,
+ )
+ else:
+ # this backend is used for WORLD
+ torch.distributed.init_process_group(
+ backend=backend,
+ init_method=distributed_init_method,
+ world_size=world_size,
+ rank=rank,
+ timeout=timeout,
+ )
if enable_elastic_ep:
tp_pp_cpu_group = torch.distributed.new_group(
backend="gloo", timeout=timeout
@@ -1476,6 +1625,9 @@ def init_distributed_environment(
"Elastic EP is not yet supported with multi-node TP/PP"
)
+ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP and torch.accelerator.is_available():
+ _validate_default_pg_for_split_group()
+
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
@@ -1903,6 +2055,10 @@ def destroy_distributed_environment():
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
+ logger.debug(
+ "[shutdown] Distributed: cleanup start shutdown_ray=%s",
+ shutdown_ray,
+ )
# Reset environment variable cache
envs.disable_envs_cache()
@@ -1937,6 +2093,8 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
"torch._C._host_emptyCache() only available in Pytorch >=2.5"
)
+ logger.debug_once("[shutdown] Distributed: cleanup complete")
+
def in_the_same_node_as(
pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index a560db87ea2..08a3ab58c78 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -95,6 +95,9 @@ async def serve_http(
shutdown_event = asyncio.Event()
def signal_handler() -> None:
+ if shutdown_event.is_set():
+ return
+ logger.info_once("[shutdown] API server: shutdown triggered")
shutdown_event.set()
async def dummy_shutdown() -> None:
@@ -108,12 +111,21 @@ async def serve_http(
engine_client = app.state.engine_client
timeout = engine_client.vllm_config.shutdown_timeout
+ mode = "abort" if timeout == 0 else "drain"
+
+ logger.info(
+ "[shutdown] API server: stopping engine client mode=%s timeout=%ss",
+ mode,
+ timeout,
+ )
await loop.run_in_executor(
None, partial(engine_client.shutdown, timeout=timeout)
)
+ logger.info_once("[shutdown] API server: engine client stopped")
server.should_exit = True
+ logger.info_once("[shutdown] API server: signalling HTTP server shutdown")
server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher:
@@ -134,7 +146,7 @@ async def serve_http(
process,
" ".join(process.cmdline()),
)
- logger.info("Shutting down FastAPI HTTP server.")
+ logger.info_once("[shutdown] API server: shutting down FastAPI HTTP server")
return server.shutdown()
finally:
shutdown_task.cancel()
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 7297243f918..892e5035ab6 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -556,7 +556,8 @@ class LLM(BeamSearchOfflineMixin, PoolingOfflineMixin, OfflineInferenceMixin):
and returns their outputs. Use after enqueue() to get results.
Args:
- output_type: The expected output type, defaults to RequestOutput.
+ output_type: The expected output type(s). If not provided, accepts
+ both RequestOutput and PoolingRequestOutput.
use_tqdm: If True, shows a tqdm progress bar.
Returns:
diff --git a/vllm/envs.py b/vllm/envs.py
index bb3bb34284b..a32f055a028 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -63,6 +63,7 @@ if TYPE_CHECKING:
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False
+ VLLM_DISTRIBUTED_USE_SPLIT_GROUP: bool = False
VLLM_XLA_USE_SPMD: bool = False
VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
@@ -114,6 +115,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
+ VLLM_ROCM_USE_AITER_LINEAR_HIPBMM: bool = False
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_AITER_MOE_DISPATCH_POLICY: int = 0
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -877,6 +879,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_RAY_V2_EXECUTOR_BACKEND": lambda: bool(
int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "1"))
),
+ # When True, GroupCoordinator constructs its CPU/device subgroups via
+ # ``torch.distributed.split_group(backend=...)``
+ # and ``init_distributed_environment`` initializes the default PG with
+ # mixed ``cpu:gloo,cuda:nccl`` backend + eager ``device_id`` binding.
+ "VLLM_DISTRIBUTED_USE_SPLIT_GROUP": lambda: bool(
+ int(os.getenv("VLLM_DISTRIBUTED_USE_SPLIT_GROUP", "0"))
+ ),
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
"VLLM_WORKER_MULTIPROC_METHOD": env_with_choices(
@@ -1109,6 +1118,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_USE_AITER_LINEAR": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1")
),
+ "VLLM_ROCM_USE_AITER_LINEAR_HIPBMM": lambda: (
+ os.getenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "False").lower() in ("true", "1")
+ ),
# Whether to use aiter moe ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MOE": lambda: (
diff --git a/vllm/ir/op.py b/vllm/ir/op.py
index 8e82b5d8c7e..742d3f33ff8 100644
--- a/vllm/ir/op.py
+++ b/vllm/ir/op.py
@@ -113,11 +113,14 @@ def register_op(
"""
Register a new vLLM IR op.
- :param f: the native implementation of the op
- :param name: the name of the op, defaults to the function name
- :param activations: list of activation params, defaults to params starting with 'x'
- :param allow_inplace: add a maybe_inplace overload that allows inplace impls
- :return: the IrOp object if f is provided, otherwise a decorator
+ Args:
+ f: the native implementation of the op
+ name: the name of the op, defaults to the function name
+ activations: list of activation params, defaults to params starting with 'x'
+ allow_inplace: add a maybe_inplace overload that allows inplace impls
+
+ Returns:
+ the IrOp object if f is provided, otherwise a decorator
Example usage:
```python
@@ -245,14 +248,17 @@ class IrOp:
supported: bool = True,
supports_args: Callable[..., bool] | None = None,
inplace: bool = False,
- ):
+ ) -> Callable[[Callable[..., Any]], "IrOpImpl"]:
"""
Register an implementation for this custom op.
- :param provider: The name of the provider, must be unique.
- :param supported: Static support check, use this to check platform support.
- :param supports_args: Dynamic arg support check, used for types and shapes.
- :param inplace: Does this op reuse activation input memory for outputs
- :return: A decorator that registers the implementation.
+ Args:
+ provider: The name of the provider, must be unique.
+ supported: Static support check, use this to check platform support.
+ supports_args: Dynamic arg support check, used for types and shapes.
+ inplace: Does this op reuse activation input memory for outputs
+
+ Returns:
+ A decorator that registers the implementation.
The decorated function must have the same semantics and signature as
the native implementation.
diff --git a/vllm/ir/util.py b/vllm/ir/util.py
index ac8a06155da..e9240f487ac 100644
--- a/vllm/ir/util.py
+++ b/vllm/ir/util.py
@@ -12,9 +12,9 @@ from typing import Any
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
- :param srcs: strings or objects to add to the hash.
- Objects and functions have their source inspected.
- :return:
+ Args:
+ srcs: strings or objects to add to the hash.
+ Objects and functions have their source inspected.
"""
hasher = hashlib.sha256()
for src in srcs:
diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py
index cd2c9eb01be..8162acd5e8d 100644
--- a/vllm/model_executor/kernels/linear/__init__.py
+++ b/vllm/model_executor/kernels/linear/__init__.py
@@ -128,6 +128,7 @@ from vllm.model_executor.kernels.linear.scaled_mm import (
)
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterFp8BlockScaledMMKernel,
+ AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
AiterInt8ScaledMMLinearKernel,
AiterPerTokenFp8ScaledMMLinearKernel,
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel,
@@ -285,6 +286,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [
+ AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel,
AiterPerTokenFp8ScaledMMLinearKernel,
ROCmFP8ScaledMMLinearKernel,
@@ -1024,6 +1026,7 @@ __all__ = [
"FP8ScaledMMLinearLayerConfig",
"Int8ScaledMMLinearLayerConfig",
"ScaledMMLinearLayerConfig",
+ "AiterHipbMMPerTokenFp8ScaledMMLinearKernel",
"AiterPreshuffledPerTokenFp8ScaledMMLinearKernel",
"AiterPerTokenFp8ScaledMMLinearKernel",
"NvFp4LinearKernel",
diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py
index 5ded5ca798a..1b39491ab34 100644
--- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py
+++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py
@@ -212,6 +212,99 @@ class AiterPreshuffledPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
)
+class AiterHipbMMPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+ @classmethod
+ def is_supported(
+ cls, compute_capability: int | None = None
+ ) -> tuple[bool, str | None]:
+ if not current_platform.is_rocm():
+ return False, "requires ROCm."
+
+ if not rocm_aiter_ops.is_linear_hipbmm_enabled():
+ return (
+ False,
+ "requires setting `VLLM_ROCM_USE_AITER=1`, "
+ "`VLLM_ROCM_USE_AITER_LINEAR=1`, "
+ "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`.",
+ )
+ try:
+ import aiter # noqa: F401
+ except Exception:
+ return False, "requires aiter library to be installed."
+
+ if not hasattr(aiter, "hipb_mm"):
+ return False, "requires aiter hipb_mm support."
+
+ return True, None
+
+ @classmethod
+ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+ is_ptpc = (
+ c.activation_quant_key.scale.group_shape.is_per_token()
+ and c.weight_quant_key.scale.group_shape.is_per_channel()
+ )
+ if c.weight_shape is None:
+ return False, "weight_shape is required for Aiter kernels"
+ N, K = c.weight_shape
+
+ if c.out_dtype is not torch.bfloat16:
+ return False, "requires bfloat16 output dtype."
+
+ if not is_ptpc:
+ return (
+ False,
+ "requires per token activation scales and per channel weight scales.",
+ )
+
+ if not (N >= 16 and N % 16 == 0 and K % 16 == 0):
+ return (
+ False,
+ "requires N >= 16 and both N and K divisible by 16, "
+ f"received N={N} and K={K}.",
+ )
+
+ return True, None
+
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ w_name, w_s_name, *_ = self.layer_param_names
+ w, w_s, *_ = self._get_layer_params(layer)
+
+ # Pre-apply the transposes that used to live in
+ # _rocm_aiter_hipb_mm_fp8_impl so the kernel can consume B/Bs directly.
+ # The `.t()` on the shuffled weight is kept as a non-contiguous view —
+ # materializing it with `.contiguous()` would re-arrange the bytes and
+ # break the `bpreshuffle` layout.
+ shuffled_w = rocm_aiter_ops.shuffle_weight(w.t().contiguous())
+ replace_parameter(
+ layer,
+ w_name,
+ torch.nn.Parameter(shuffled_w.t(), requires_grad=False),
+ )
+
+ if w_s.ndim > 1:
+ replace_parameter(
+ layer,
+ w_s_name,
+ torch.nn.Parameter(w_s.t().contiguous(), requires_grad=False),
+ )
+
+ def apply_scaled_mm(
+ self,
+ *,
+ A: torch.Tensor,
+ B: torch.Tensor,
+ out_dtype: torch.dtype,
+ As: torch.Tensor,
+ Bs: torch.Tensor,
+ bias: torch.Tensor | None,
+ output_shape: list,
+ ) -> torch.Tensor:
+ output_shape[-1] = B.shape[1]
+ return rocm_aiter_ops.hipb_mm_fp8(A, B, As, Bs, bias, out_dtype).view(
+ *output_shape
+ )
+
+
class AiterPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
diff --git a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py
index d8570049af2..fa91804f35c 100644
--- a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py
@@ -379,8 +379,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
topk_ids,
activation,
global_num_experts,
- # the fp8 cutlass experts use their own expert map.
- None,
+ expert_map,
self.w1_scale,
self.w2_scale,
a1q_scale,
diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py
index a65873aca49..a412a6936d3 100644
--- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py
@@ -30,6 +30,8 @@ class TrtLlmMxint4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
+ max_num_tokens: int | None = None,
+ num_dispatchers: int | None = None,
):
super().__init__(moe_config, quant_config)
self.topk = moe_config.experts_per_token
diff --git a/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py b/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py
index 82969dd8e25..94208326461 100644
--- a/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py
@@ -14,7 +14,9 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
+ kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
+ kFp8Static128BlockSym,
kFp8StaticTensorSym,
kInt4Static,
kMxfp4Static,
@@ -62,6 +64,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
self.is_fp8 = False
self.is_int4 = False
self.is_mxfp4 = False
+ self.is_block_fp8 = False
self.is_mxfp8 = False
self.fused_moe_impl: XpuFusedMoe | None = None
@@ -171,6 +174,7 @@ class XPUExperts(mk.FusedMoEExpertsModular):
is_int4=self.is_int4,
is_mxfp4=self.is_mxfp4,
is_mxfp8=self.is_mxfp8,
+ is_block_fp8=self.is_block_fp8,
)
assert self.fused_moe_impl is not None
self.fused_moe_impl.apply(
@@ -238,6 +242,33 @@ class XPUExpertsMxfp8(XPUExpertsFp8):
return (weight_key, activation_key) in SUPPORTED_W_A
+class XPUExpertsBlockFp8(XPUExperts):
+ def __init__(
+ self,
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
+ max_num_tokens: int | None = None,
+ num_dispatchers: int | None = None,
+ ):
+ super().__init__(
+ moe_config,
+ quant_config,
+ max_num_tokens,
+ num_dispatchers,
+ )
+ self.is_block_fp8 = True
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ SUPPORTED_W_A = [
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+
class XPUExpertsWNA16(XPUExperts):
"""W4A16 INT4-symmetric MoE backed by `xpu_fused_moe(is_int4=True)`.
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 4ff43ce21b8..2f4401563b6 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -599,12 +599,14 @@ class FusedMoE(PluggableLayer):
):
"""
Load grouped weight scales for group quantization or model weights
- :param shard_dim: dimension to shard
- :param expert_data: parameter for a particular expert
- :param shard_id: either w1, w2, or w3
- :param loaded_weight: checkpoint weight to load into the param
- :param tp_rank: tensor parallel rank
- :param load_full_w2: whether or not the w2 loaded should be sharded.
+
+ Args:
+ shard_dim: dimension to shard
+ expert_data: parameter for a particular expert
+ shard_id: either w1, w2, or w3
+ loaded_weight: checkpoint weight to load into the param
+ tp_rank: tensor parallel rank
+ load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2":
# In the case where we have actorder/g_idx, we do not partition the
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
index 0a2e3846dd9..cce3245b75c 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
@@ -184,11 +184,12 @@ def backend_to_kernel_cls(
elif backend == Fp8MoeBackend.XPU:
from vllm.model_executor.layers.fused_moe.experts.xpu_moe import (
+ XPUExpertsBlockFp8,
XPUExpertsFp8,
XPUExpertsMxfp8,
)
- return [XPUExpertsFp8, XPUExpertsMxfp8]
+ return [XPUExpertsFp8, XPUExpertsMxfp8, XPUExpertsBlockFp8]
elif backend == Fp8MoeBackend.CPU:
from vllm.model_executor.layers.fused_moe.experts.cpu_moe import (
diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py
index ef7a2745a06..d3ea7fb211a 100644
--- a/vllm/model_executor/layers/lightning_attn.py
+++ b/vllm/model_executor/layers/lightning_attn.py
@@ -4,6 +4,7 @@
import torch
from einops import rearrange
+from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -403,13 +404,16 @@ class _attention(torch.autograd.Function):
v = v.contiguous()
s = s.contiguous()
- # Check CUDA compute capability
- capability = torch.cuda.get_device_capability()
- if capability[0] < 8:
- raise RuntimeError(
- "Flash attention currently only supported",
- "for compute capability >= 80",
- )
+ # Check CUDA compute capability (Ampere+ required for flash attention
+ # path). Other accelerators (ROCm, XPU) rely on their own Triton
+ # backend support and skip this check.
+ if current_platform.is_cuda():
+ capability = torch.cuda.get_device_capability()
+ if capability[0] < 8:
+ raise RuntimeError(
+ "Flash attention currently only supported",
+ "for compute capability >= 80",
+ )
# Get input dimensions
b, h, n, d = q.shape
diff --git a/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py
index 59bab27c48c..23d7070cc80 100644
--- a/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py
@@ -85,7 +85,7 @@ direct_register_custom_op(
class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
def get_state_dtype(
self,
- ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
+ ) -> tuple[torch.dtype, torch.dtype]:
if self.model_config is None or self.cache_config is None:
raise ValueError("model_config and cache_config must be set")
return MambaStateDtypeCalculator.kda_state_dtype(
@@ -94,7 +94,7 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
def get_state_shape(
self,
- ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+ ) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.kda_state_shape(
self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size
)
@@ -300,13 +300,13 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
g1 = g1[:, :num_actual_tokens]
beta = beta[:, :num_actual_tokens]
- (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
+ (conv_state, recurrent_state) = constant_caches
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
if not is_conv_state_dim_first():
- conv_state_q = conv_state_q.transpose(-1, -2)
- conv_state_k = conv_state_k.transpose(-1, -2)
- conv_state_v = conv_state_v.transpose(-1, -2)
+ conv_state = conv_state.transpose(-1, -2)
+
+ conv_state_q, conv_state_k, conv_state_v = conv_state.chunk(3, dim=-2)
q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py
index c1fd81e40e3..0c86c787917 100644
--- a/vllm/model_executor/layers/mamba/mamba_utils.py
+++ b/vllm/model_executor/layers/mamba/mamba_utils.py
@@ -120,9 +120,9 @@ class MambaStateDtypeCalculator:
cls,
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
- ):
+ ) -> tuple[torch.dtype, torch.dtype]:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
- return (state_dtype, state_dtype, state_dtype, torch.float32)
+ return (state_dtype, torch.float32)
class MambaStateShapeCalculator:
@@ -243,7 +243,7 @@ class MambaStateShapeCalculator:
head_k_dim: int | None = None,
conv_kernel_size: int = 4,
num_spec: int = 0,
- ) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
+ ) -> tuple[tuple[int, int], tuple[int, int, int]]:
if num_k_heads is None:
num_k_heads = num_heads
if head_k_dim is None:
@@ -252,19 +252,12 @@ class MambaStateShapeCalculator:
proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim
+ conv_dim = proj_size + 2 * proj_k_size
conv_state_shape = cls._orient_conv_shape(
- divide(proj_size, tp_world_size), conv_kernel_size - 1
- )
- conv_state_k_shape = cls._orient_conv_shape(
- divide(proj_k_size, tp_world_size), conv_kernel_size - 1
+ divide(conv_dim, tp_world_size), conv_kernel_size - 1
)
recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
- return (
- conv_state_shape,
- conv_state_k_shape,
- conv_state_k_shape,
- recurrent_state_shape,
- )
+ return (conv_state_shape, recurrent_state_shape)
@dataclass
@@ -365,9 +358,4 @@ class MambaStateCopyFuncCalculator:
@classmethod
def kda_state_copy_func(cls):
- return (
- get_conv_copy_spec,
- get_conv_copy_spec,
- get_conv_copy_spec,
- get_temporal_copy_spec,
- )
+ return (get_conv_copy_spec, get_temporal_copy_spec)
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index 344ddd8abd2..5b911114d38 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -162,7 +162,13 @@ class QuantizationConfig(ABC):
"""
raise NotImplementedError
- def get_cache_scale(self, name: str) -> str | None:
+ def get_cache_scale_mapper(self) -> "WeightsMapper | None":
+ """Mapping from checkpoint KV-cache scale names to vLLM scale names.
+
+ Returning a mapper here causes `AutoWeightsLoader` to apply it to the
+ weight stream automatically; individual model `load_weights` methods
+ do not need to know about KV-cache scales.
+ """
return None
def apply_vllm_mapper( # noqa: B027
@@ -172,8 +178,9 @@ class QuantizationConfig(ABC):
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
- :param hf_to_vllm_mapper: maps from hf model structure (the assumed
- structure of the qconfig) to vllm model structure
+ Args:
+ hf_to_vllm_mapper: maps from hf model structure (the assumed
+ structure of the qconfig) to vllm model structure
"""
# TODO (@kylesayrs): add implementations for all subclasses
pass
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index b59e12e8e1b..f4bd57e10e6 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -267,8 +267,11 @@ class CompressedTensorsConfig(QuantizationConfig):
cls, config: dict[str, Any]
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
"""
- :param config: The `quantization_config` dictionary from config.json
- :return: A tuple with two elements
+ Args:
+ config: The `quantization_config` dictionary from config.json
+
+ Returns:
+ A tuple with two elements
1. A dictionary mapping target layer names to their corresponding
sparsity_config
2. A list of layer names to ignore for sparsity
@@ -296,8 +299,11 @@ class CompressedTensorsConfig(QuantizationConfig):
cls, config: dict[str, Any]
) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
- :param config: The `quantization_config` dictionary from config.json
- :return: A dictionary mapping target layer names to their corresponding
+ Args:
+ config: The `quantization_config` dictionary from config.json
+
+ Returns:
+ A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: dict[str, Any] = dict()
@@ -967,7 +973,9 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"""
Validator for the kv cache scheme. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
- :param kv_cache_scheme: the compressed-tensors kv cache scheme
+
+ Args:
+ kv_cache_scheme: the compressed-tensors kv cache scheme
"""
if kv_cache_scheme is None:
return
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
index 731cba1ba2a..78419a0dd98 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@@ -38,11 +38,11 @@ class CompressedTensorsScheme(ABC):
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
- :param layer: torch.nn.Module with the registered weights and
- other parameters relevant to the particular scheme.
- :param x: input to the layer
- :param bias: bias parameter
-
+ Args:
+ layer: torch.nn.Module with the registered weights and
+ other parameters relevant to the particular scheme.
+ x: input to the layer
+ bias: bias parameter
"""
raise NotImplementedError()
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
index def4797b139..afb899cd6d7 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -133,12 +133,11 @@ def find_matched_target(
*All* component module names must match in order for a match to be
successful. A successful match returns the first component target
- :param layer_name: layer name
- :param module: torch.nn.Module
- :param targets: list of targets to match the layer against
- :param fused_mapping: map from fused layer names to its components
- :param fused_strategy: either "all" or "any". If using "all", fused
- layers match if "all" of its components match
+ Args:
+ layer_name: layer name
+ module: torch.nn.Module
+ targets: list of targets to match the layer against
+ fused_mapping: map from fused layer names to its components
"""
if layer_name is None:
@@ -161,9 +160,10 @@ def _find_first_match(
exactly or as a regex after 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
- :param value: string to compare the list of targets against
- :param targets: list of targets to match the layer against
- :param check_contains: whether or not to do a substring match
+ Args:
+ value: string to compare the list of targets against
+ targets: list of targets to match the layer against
+ check_contains: whether or not to do a substring match
"""
for target in targets:
@@ -205,9 +205,10 @@ def _match_fused_layer(
Implements an "all" matching strategy where a fused layer matches iff
"all" of its components match
- :param layer_name: layer name
- :param target_layers: list of targets to match the layer against
- :param fused_mapping: map from fused layer names to its components
+ Args:
+ layer_name: layer name
+ target_layers: list of targets to match the layer against
+ fused_mapping: map from fused layer names to its components
Examples:
layer_name = "model.layers.0.self_attn.qkv_proj"
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 41e2e19785c..a6461900d13 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -207,25 +207,18 @@ class Fp8Config(QuantizationConfig):
return Fp8KVCacheMethod(self)
return None
- def get_cache_scale(self, name: str) -> str | None:
- """
- Check whether the param name matches the format for k/v cache scales
- in compressed-tensors. If this is the case, return its equivalent
- param name expected by vLLM
+ def get_cache_scale_mapper(self) -> "WeightsMapper":
+ """Map compressed-tensors KV-cache scale names to vLLM names."""
+ from vllm.model_executor.models.utils import WeightsMapper
- :param name: param name
- :return: matching param name for KV cache scale in vLLM
- """
- if name.endswith(".output_scale") and ".k_proj" in name:
- return name.replace(".k_proj.output_scale", ".attn.k_scale")
- if name.endswith(".output_scale") and ".v_proj" in name:
- return name.replace(".v_proj.output_scale", ".attn.v_scale")
- if name.endswith(".output_scale") and ".q_proj" in name:
- return name.replace(".q_proj.output_scale", ".attn.q_scale")
- if name.endswith("self_attn.prob_output_scale"):
- return name.replace(".prob_output_scale", ".attn.prob_scale")
- # If no matches, return None
- return None
+ return WeightsMapper(
+ orig_to_new_suffix={
+ ".k_proj.output_scale": ".attn.k_scale",
+ ".v_proj.output_scale": ".attn.v_scale",
+ ".q_proj.output_scale": ".attn.q_scale",
+ ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale",
+ }
+ )
class CopyNumelCounter(TorchDispatchMode):
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index dca49d7ed97..7458b70ea81 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -114,8 +114,9 @@ class GGUFConfig(QuantizationConfig):
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
- :param hf_to_vllm_mapper: maps from hf model structure (the assumed
- structure of the qconfig) to vllm model structure
+ Args:
+ hf_to_vllm_mapper: maps from hf model structure (the assumed
+ structure of the qconfig) to vllm model structure
"""
if self.unquantized_modules is not None:
self.unquantized_modules = hf_to_vllm_mapper.apply_list(
diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py
index d7fa6cf2633..e8810919c20 100644
--- a/vllm/model_executor/layers/quantization/input_quant_fp8.py
+++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py
@@ -46,16 +46,17 @@ class QuantFP8(CustomOp):
compile_native: bool = True,
):
"""
- :param static: static or dynamic quantization
- :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
- PER_CHANNEL, or arbitrary block size)
- :param num_token_padding: Pad the token dimension of output to this
- size
- :param tma_aligned_scales: For group quantization, output scales in
- TMA-aligned layout
- :param column_major_scales: For group quantization, output scales in
- column major format
- :param compile_native: Manually compile forward_native if compile mode > None
+ Args:
+ static: static or dynamic quantization
+ group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
+ PER_CHANNEL, or arbitrary block size)
+ num_token_padding: Pad the token dimension of output to this
+ size
+ tma_aligned_scales: For group quantization, output scales in
+ TMA-aligned layout
+ column_major_scales: For group quantization, output scales in
+ column major format
+ compile_native: Manually compile forward_native if compile mode > None
"""
super().__init__(compile_native=compile_native)
self.static = static
diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index 726ac2232af..100632686b0 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -15,6 +15,30 @@ from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales
logger = init_logger(__name__)
+class KVCacheScaleParameter(torch.nn.Parameter):
+ """Scalar parameter for KV-cache scales.
+
+ Initialized to -1.0 (an invalid sentinel) so call sites just write
+ `KVCacheScaleParameter()`. The `weight_loader` accepts shape `()` or
+ `(1,)` and rejects anything else — per-head scales go through a separate
+ path (compressed-tensors' `_tp_aware_loader`), not this one. Per-instance
+ overrides still work because instance attribute assignment shadows this
+ class-level loader.
+ """
+
+ def __new__(cls) -> "KVCacheScaleParameter":
+ return super().__new__(cls, torch.tensor(-1.0), requires_grad=False)
+
+ @staticmethod
+ def weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
+ if loaded_weight.numel() != 1:
+ raise ValueError(
+ f"KV-cache scale expects a scalar weight, got shape "
+ f"{tuple(loaded_weight.shape)}"
+ )
+ param.data.copy_(loaded_weight.reshape(()))
+
+
class BaseKVCacheMethod(QuantizeMethodBase):
"""
Quant method that adds `_k_scale` and `_v_scale` attributes to the
@@ -23,7 +47,8 @@ class BaseKVCacheMethod(QuantizeMethodBase):
- quantize k/v_cache entries before saving them to the cache
- dequantize k/v_cache entries before fetching them from the cache
- :param quant_config: the appropriate QuantizationConfig
+ Args:
+ quant_config: the appropriate QuantizationConfig
"""
def __init__(self, quant_config: QuantizationConfig):
@@ -37,11 +62,11 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# Initialize the Q and KV cache scales to -1.0, an invalid value.
# If the q and k/v_scales appear in the checkpoint, it will be
# overwritten when loading weights.
- layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
- layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
- layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+ layer.q_scale = KVCacheScaleParameter()
+ layer.k_scale = KVCacheScaleParameter()
+ layer.v_scale = KVCacheScaleParameter()
# Initialize P = softmax(QK^T) scales
- layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+ layer.prob_scale = KVCacheScaleParameter()
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index d1f7a169ee7..424fdf2fba0 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -87,8 +87,9 @@ class QuarkConfig(QuantizationConfig):
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
- :param hf_to_vllm_mapper: maps from hf model structure (the assumed
- structure of the qconfig) to vllm model structure
+ Args:
+ hf_to_vllm_mapper: maps from hf model structure (the assumed
+ structure of the qconfig) to vllm model structure
"""
quant_config_with_hf_to_vllm_mapper: dict[str, Any] = {}
@@ -646,26 +647,16 @@ class QuarkConfig(QuantizationConfig):
return scheme
- def get_cache_scale(self, name: str) -> str | None:
- """
- Check whether the param name matches the format for k/v cache scales
- in quark. If this is the case, return its equivalent param name
- expected by vLLM
-
- :param name: param name
- :return: matching param name for KV cache scale in vLLM
- """
- if name.endswith(".output_scale") and ".k_proj" in name:
- return name.replace(".k_proj.output_scale", ".attn.k_scale")
- if name.endswith(".output_scale") and ".v_proj" in name:
- return name.replace(".v_proj.output_scale", ".attn.v_scale")
- if name.endswith(".output_scale") and ".q_proj" in name:
- return name.replace(".q_proj.output_scale", ".attn.q_scale")
- if name.endswith("self_attn.prob_output_scale"):
- return name.replace(".prob_output_scale", ".attn.prob_scale")
-
- # If no matches, return None
- return None
+ def get_cache_scale_mapper(self) -> "WeightsMapper":
+ """Map Quark KV-cache scale names to vLLM names."""
+ return WeightsMapper(
+ orig_to_new_suffix={
+ ".k_proj.output_scale": ".attn.k_scale",
+ ".v_proj.output_scale": ".attn.v_scale",
+ ".q_proj.output_scale": ".attn.q_scale",
+ ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale",
+ }
+ )
class QuarkLinearMethod(LinearMethodBase):
@@ -734,7 +725,9 @@ class QuarkKVCacheMethod(BaseKVCacheMethod):
"""
Validator for the kv cache configuration. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
- :param kv_cache_config: the quark kv cache scheme
+
+ Args:
+ kv_cache_config: the quark kv cache scheme
"""
if kv_cache_config is None:
return
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
index 412a07a85fe..6f8db9ea57d 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
@@ -38,11 +38,11 @@ class QuarkScheme(ABC):
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
- :param layer: torch.nn.Module with the registered weights and
- other parameters relevant to the particular scheme.
- :param x: input to the layer
- :param bias: bias parameter
-
+ Args:
+ layer: torch.nn.Module with the registered weights and
+ other parameters relevant to the particular scheme.
+ x: input to the layer
+ bias: bias parameter
"""
raise NotImplementedError
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index 7362abcc8fb..bfaf81ad007 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -251,7 +251,7 @@ class DeepseekV4ScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor,
- device=current_platform.device_type,
+ device=inv_freq.device,
dtype=torch.float32,
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py
index 40dd6dc9f39..6cf1c19cba4 100644
--- a/vllm/model_executor/model_loader/reload/layerwise.py
+++ b/vllm/model_executor/model_loader/reload/layerwise.py
@@ -123,7 +123,8 @@ def initialize_online_processing(layer: torch.nn.Module):
Called by either `initialize_layerwise_reload` or an online quantization scheme,
prevents double wrapping in the case of online quantization + reloading
- :param layer: layer whose parameter weight loaders will be wrapped
+ Args:
+ layer: layer whose parameter weight loaders will be wrapped
"""
info = get_layerwise_info(layer)
@@ -222,8 +223,9 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
This function should be applied after `initialize_layerwise_reload` is applied
unwrap the layerwise weight loaders.
- :param model: model to finalize processing for
- :param model_config: config needed for applying processing to attention layers
+ Args:
+ model: model to finalize processing for
+ model_config: config needed for applying processing to attention layers
"""
if hasattr(model, "_original_do_torchao_reload"):
model._do_torchao_reload = model._original_do_torchao_reload
diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py
index 397a458cbdd..283a98de284 100644
--- a/vllm/model_executor/model_loader/reload/meta.py
+++ b/vllm/model_executor/model_loader/reload/meta.py
@@ -175,9 +175,12 @@ def get_numel_loaded(
"""
Determine how many elements would be loaded by a weight loader call.
- :param weight loader: used to load weights
- :param args: bound arguments to weight loader
- :return: number of elements loaded by the weight loader, the return value of the
+ Args:
+ weight_loader: used to load weights
+ args: bound arguments to weight loader
+
+ Returns:
+ number of elements loaded by the weight loader, the return value of the
weight loader
"""
with CopyCounter() as counter:
diff --git a/vllm/model_executor/model_loader/reload/sanitize.py b/vllm/model_executor/model_loader/reload/sanitize.py
index 2a6dc7182d0..21c47a2257f 100644
--- a/vllm/model_executor/model_loader/reload/sanitize.py
+++ b/vllm/model_executor/model_loader/reload/sanitize.py
@@ -20,9 +20,12 @@ def sanitize_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.T
tensors will reference layers, and the WeakKeyDictionary will never evict entries,
even when the model is deleted.
- :param tensor: tensor to be sanitized
- :param layer: layer whose references should be removed
- :return: sanitized tensor
+ Args:
+ tensor: tensor to be sanitized
+ layer: layer whose references should be removed
+
+ Returns:
+ sanitized tensor
"""
for key, value in tensor.__dict__.items():
if isinstance(value, MethodType) and value.__self__ is layer:
@@ -38,10 +41,12 @@ def restore_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Te
Used by `restore_layer_on_meta` to add back layer references, allowing for proper
weight loading.
- :param tensor: tensor to be sanitized
- :param layer: layer whose references should be removed
- :return: sanitized tensor
+ Args:
+ tensor: tensor to be sanitized
+ layer: layer whose references should be removed
+ Returns:
+ sanitized tensor
"""
for key, value in tensor.__dict__.items():
if isinstance(value, MethodType) and value.__self__ is layer_ref_sentinel:
diff --git a/vllm/model_executor/model_loader/reload/utils.py b/vllm/model_executor/model_loader/reload/utils.py
index 7a3d6873e10..f0078d0f9d8 100644
--- a/vllm/model_executor/model_loader/reload/utils.py
+++ b/vllm/model_executor/model_loader/reload/utils.py
@@ -49,8 +49,11 @@ def has_device_tensors(bound_args: BoundArguments) -> bool:
"""
Return True if the loaded weights exist on an accelerator device
- :param bound_args: args to load weights
- :return: True if weights are on accelerator device
+ Args:
+ bound_args: args to load weights
+
+ Returns:
+ True if weights are on accelerator device
"""
return any(
isinstance(value, torch.Tensor) and value.device.type not in ("meta", "cpu")
@@ -62,8 +65,11 @@ def get_info_size(info: LayerReloadingInfo) -> int:
"""
Calculate the number of bytes used by loaded weights for a given layer
- :param info: layerwise info to get size of
- :return: number of bytes used by loaded weights
+ Args:
+ info: layerwise info to get size of
+
+ Returns:
+ number of bytes used by loaded weights
"""
return sum(
value.nbytes
diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py
index 736b2134604..008abb6fdfe 100644
--- a/vllm/model_executor/model_loader/tensorizer.py
+++ b/vllm/model_executor/model_loader/tensorizer.py
@@ -687,7 +687,7 @@ def serialize_vllm_model(
serializer = TensorSerializer(
stream,
encryption=encryption_params,
- **tensorizer_config.serialization_kwargs,
+ **(tensorizer_config.serialization_kwargs or {}),
)
serializer.write_module(model)
serializer.close()
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index cbb191ebb62..76a83898989 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -1541,6 +1541,11 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
if no remapping is needed.
None: If the remapped name is not found in params_dict.
"""
+ # Already in vLLM's expected form (e.g. weights pre-renamed by a
+ # `WeightsMapper` from the quant config). Skip the regex remap, which
+ # would otherwise double-apply the `.attn` prefix and drop the weight.
+ if name in params_dict:
+ return name
if name.endswith(".kv_scale"):
logger.warning_once(
"DEPRECATED. Found kv_scale in the checkpoint. "
diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py
index 5905a198b28..0711fb03f84 100644
--- a/vllm/model_executor/models/apertus.py
+++ b/vllm/model_executor/models/apertus.py
@@ -430,18 +430,6 @@ class ApertusModel(nn.Module, EagleModelMixin):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py
index eb8c3e3f65e..d25c954fc19 100644
--- a/vllm/model_executor/models/arcee.py
+++ b/vllm/model_executor/models/arcee.py
@@ -293,18 +293,6 @@ class ArceeModel(nn.Module, EagleModelMixin):
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
if "scale" in name or "zero_point" in name:
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is None:
diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index 55bc64cd94a..e453e6e6cf6 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -363,18 +363,6 @@ class AriaTextModel(LlamaModel, SupportsQuant):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/cohere2_moe.py b/vllm/model_executor/models/cohere2_moe.py
index aa8adff188f..16e299ea27a 100644
--- a/vllm/model_executor/models/cohere2_moe.py
+++ b/vllm/model_executor/models/cohere2_moe.py
@@ -464,18 +464,6 @@ class Cohere2MoeModel(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
diff --git a/vllm/model_executor/models/cohere_eagle.py b/vllm/model_executor/models/cohere_eagle.py
index 5c22d6e34dd..7b57c739ffe 100644
--- a/vllm/model_executor/models/cohere_eagle.py
+++ b/vllm/model_executor/models/cohere_eagle.py
@@ -150,18 +150,6 @@ class CohereEagleModel(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index e73dfb1f01e..317269ec3b6 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -352,19 +352,6 @@ class CohereModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 6c798bf2f36..32f4f15c7a3 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -394,19 +394,6 @@ class DbrxModel(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
if name.endswith(("w1", "w2", "v1")):
name = name + "_weight"
for param_name, weight_name in expert_params_mapping:
diff --git a/vllm/model_executor/models/deepseek_eagle3.py b/vllm/model_executor/models/deepseek_eagle3.py
index 9b96cdec830..dc153ac9e0b 100644
--- a/vllm/model_executor/models/deepseek_eagle3.py
+++ b/vllm/model_executor/models/deepseek_eagle3.py
@@ -284,19 +284,6 @@ class DeepseekV2Eagle3Model(nn.Module):
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
- # Handle kv cache quantization scales
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
# Remapping the name FP8 kv-scale
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index b633fd28508..dca05f72c69 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -391,18 +391,6 @@ class ExaoneModel(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py
index 04708de93d3..e38dbb5ee29 100644
--- a/vllm/model_executor/models/exaone4.py
+++ b/vllm/model_executor/models/exaone4.py
@@ -389,18 +389,6 @@ class Exaone4Model(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/exaone_moe.py b/vllm/model_executor/models/exaone_moe.py
index 80b7e0957e8..3373983f5c9 100644
--- a/vllm/model_executor/models/exaone_moe.py
+++ b/vllm/model_executor/models/exaone_moe.py
@@ -374,18 +374,6 @@ class ExaoneMoeModel(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 425ecc65195..733eb3ed3c1 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -328,16 +328,6 @@ class Gemma2Model(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache scales for compressed-tensors quantization
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = loaded_weight[0]
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index f61f7c6f780..7bae2b1a5e7 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -386,17 +386,6 @@ class Gemma3Model(nn.Module):
):
loaded_weight -= 1
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache scales for compressed-tensors quantization
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = loaded_weight[0]
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
# Check if this is a scale parameter that needs remapping first
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
# Try to remap the scale name first
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index 770424ba0fd..ad8b21d86b4 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -1056,16 +1056,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
):
name = f"self_decoder.{name}"
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache scales for compressed-tensors quantization
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = loaded_weight[0]
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py
index 5d0e3efe2e1..d14a767df5f 100644
--- a/vllm/model_executor/models/gemma4.py
+++ b/vllm/model_executor/models/gemma4.py
@@ -1409,16 +1409,6 @@ class Gemma4Model(nn.Module, EagleModelMixin):
params_dict.update(dict(self.named_buffers()))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = loaded_weight[0]
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is not None and remapped_name in params_dict:
diff --git a/vllm/model_executor/models/gemma4_unified.py b/vllm/model_executor/models/gemma4_unified.py
index 66fc914dc75..9cc0710c4d0 100644
--- a/vllm/model_executor/models/gemma4_unified.py
+++ b/vllm/model_executor/models/gemma4_unified.py
@@ -80,7 +80,7 @@ class Gemma4UnifiedVisionEmbedder(nn.Module):
Pipeline: raw patches → LN₁ → Dense → LN₂ → +factorized_posemb → LN₃.
"""
- def __init__(self, config, quant_config=None):
+ def __init__(self, config, quant_config=None, prefix=""):
super().__init__()
patch_dim = config.model_patch_size**2 * 3
mm_embed_dim = config.mm_embed_dim
@@ -91,6 +91,7 @@ class Gemma4UnifiedVisionEmbedder(nn.Module):
mm_embed_dim,
bias=True,
quant_config=quant_config,
+ prefix=f"{prefix}.patch_dense",
gather_output=True,
)
self.patch_ln2 = nn.LayerNorm(mm_embed_dim)
@@ -267,6 +268,7 @@ class Gemma4UnifiedForConditionalGeneration(Gemma4ForConditionalGeneration):
Gemma4UnifiedVisionEmbedder(
config.vision_config,
quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "vision_embedder"),
)
if config.vision_config is not None
else None
diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
index 89447927d5c..4587a692766 100644
--- a/vllm/model_executor/models/glm4.py
+++ b/vllm/model_executor/models/glm4.py
@@ -258,18 +258,6 @@ class Glm4Model(LlamaModel):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/glm_ocr_mtp.py b/vllm/model_executor/models/glm_ocr_mtp.py
index 34e602bb669..3d283c101ca 100644
--- a/vllm/model_executor/models/glm_ocr_mtp.py
+++ b/vllm/model_executor/models/glm_ocr_mtp.py
@@ -166,6 +166,10 @@ class GlmOcrMTP(nn.Module, SupportsPP):
return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ if self.quant_config is not None and (
+ cache_scale_mapper := self.quant_config.get_cache_scale_mapper()
+ ):
+ weights = cache_scale_mapper.apply(weights)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -189,19 +193,6 @@ class GlmOcrMTP(nn.Module, SupportsPP):
name = self._rewrite_spec_layer_name(spec_layer, name)
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index c29103c6d52..30da9b4dea2 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -254,19 +254,6 @@ class GPTJModel(nn.Module):
if "attn.bias" in name or "attn.masked_bias" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index d12db96c5d4..920d43392cf 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -635,52 +635,6 @@ class GptOssModel(nn.Module, EagleModelMixin):
moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id)
- def kv_cache_scale_loader(
- quant_config: QuantizationConfig,
- name: str,
- params_dict: dict[str, typing.Any],
- weight: torch.Tensor,
- default_weight_loader: Callable[..., None],
- loaded_params: set[str],
- ) -> tuple[bool, set[str]]:
- """
- Load KV cache output scales.
- Returns:
- Tuple of (bool, set):
- - bool: True if KV-cache scale was loaded into loaded_params
- - set: Updated set of loaded_params if True else the original set
- """
- # load explicit cached KV output scale from quant_config
- if quant_config is not None and (
- scale_name := quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- if weight.numel() != 1:
- raise ValueError(
- f"KV cache scale '{scale_name}' is expected to be a "
- f"scalar, but got a tensor of shape {weight.shape}."
- )
- # Ensure weight is a scalar before passing to loader.
- weight_loader(param, weight.flatten()[0])
- loaded_params.add(scale_name)
- return True, loaded_params
-
- return False, loaded_params
-
- load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader(
- self.quant_config,
- name,
- params_dict,
- loaded_weight,
- default_weight_loader,
- loaded_params,
- )
- if load_kv_cache_scale_completed:
- continue
-
if (
all(key in name for key in ["input_scale", "mlp.experts"])
and expert_id is not None
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index 4b486ede443..2adc29f8d25 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -334,18 +334,6 @@ class GraniteModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index e3585a6dd74..5909604bd54 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -365,19 +365,6 @@ class GraniteMoeModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 1ab069e3ba3..4f861fc5016 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -495,18 +495,6 @@ class GraniteMoeHybridModel(nn.Module):
if "A_log" in n:
n = n.replace("A_log", "A")
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(n)
- ):
- # Loading kv cache quantization scales
- loaded_weight = p
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- _load(scale_name, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
if _load_quant_expert(n, p):
continue
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index f06122a7fd1..3fc3d1a2d2c 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -548,18 +548,6 @@ class Grok1Model(nn.Module):
for old_pattern, new_pattern in self.weight_name_remapping.items():
if old_pattern in name:
name = name.replace(old_pattern, new_pattern)
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py
index b900c0ed83e..ec3cfbd017b 100644
--- a/vllm/model_executor/models/hunyuan_v1.py
+++ b/vllm/model_executor/models/hunyuan_v1.py
@@ -771,15 +771,6 @@ class HunYuanModel(nn.Module, EagleModelMixin):
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache scales for compressed-tensors quantization
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = loaded_weight[0]
- weight_loader(param, loaded_weight)
- continue
is_found = False
for param_name, weight_name, shard_id in stacked_params_mapping:
diff --git a/vllm/model_executor/models/hy_v3.py b/vllm/model_executor/models/hy_v3.py
index bfff84b8049..6142f2c086a 100644
--- a/vllm/model_executor/models/hy_v3.py
+++ b/vllm/model_executor/models/hy_v3.py
@@ -531,17 +531,6 @@ class HYV3Model(nn.Module):
for name, loaded_weight in weights:
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/hy_v3_mtp.py b/vllm/model_executor/models/hy_v3_mtp.py
index 8594a38c3ab..37ace1dada3 100644
--- a/vllm/model_executor/models/hy_v3_mtp.py
+++ b/vllm/model_executor/models/hy_v3_mtp.py
@@ -264,6 +264,10 @@ class HYV3MTP(nn.Module):
return torch.concat((q, k, v))
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+ if self.quant_config is not None and (
+ cache_scale_mapper := self.quant_config.get_cache_scale_mapper()
+ ):
+ weights = cache_scale_mapper.apply(weights)
cla_factor = _get_cla_factor(self.config)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -336,14 +340,6 @@ class HYV3MTP(nn.Module):
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = loaded_weight[0]
- weight_loader(param, loaded_weight)
- continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
diff --git a/vllm/model_executor/models/hyperclovax.py b/vllm/model_executor/models/hyperclovax.py
index 3176c428413..2f54f78e758 100644
--- a/vllm/model_executor/models/hyperclovax.py
+++ b/vllm/model_executor/models/hyperclovax.py
@@ -395,18 +395,6 @@ class HyperCLOVAXModel(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/iquest_loopcoder.py b/vllm/model_executor/models/iquest_loopcoder.py
index 24c004ff4c2..3755cba5d1a 100644
--- a/vllm/model_executor/models/iquest_loopcoder.py
+++ b/vllm/model_executor/models/iquest_loopcoder.py
@@ -476,18 +476,6 @@ class IQuestLoopCoderModel(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if "gate_projections" in name:
continue
diff --git a/vllm/model_executor/models/jais2.py b/vllm/model_executor/models/jais2.py
index 4e03eb12ee4..dafa0f03ae9 100644
--- a/vllm/model_executor/models/jais2.py
+++ b/vllm/model_executor/models/jais2.py
@@ -386,16 +386,6 @@ class Jais2Model(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache scales for compressed-tensors quantization
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = loaded_weight[0]
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 86d7edc25f9..8a0b7ceea5f 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -790,21 +790,6 @@ class KeyeSiglipVisionModel(nn.Module):
continue
if "head.mlp" in name or "head.probe" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(
- param,
- "weight_loader",
- default_weight_loader,
- )
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for (
param_name,
weight_name,
diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py
index a891950fa57..307b24ac112 100644
--- a/vllm/model_executor/models/kimi_linear.py
+++ b/vllm/model_executor/models/kimi_linear.py
@@ -600,7 +600,7 @@ class KimiLinearForCausalLM(
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
- ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
+ ) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.kda_state_dtype(
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
)
@@ -608,7 +608,7 @@ class KimiLinearForCausalLM(
@classmethod
def get_mamba_state_shape_from_config(
cls, vllm_config: "VllmConfig"
- ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+ ) -> tuple[tuple[int, ...], tuple[int, ...]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
tp_size = parallel_config.tensor_parallel_size
@@ -628,9 +628,7 @@ class KimiLinearForCausalLM(
@classmethod
def get_mamba_state_copy_func(
cls,
- ) -> tuple[
- MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
- ]:
+ ) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.kda_state_copy_func()
def compute_logits(
diff --git a/vllm/model_executor/models/laguna.py b/vllm/model_executor/models/laguna.py
index f79f6097c61..5572481565b 100644
--- a/vllm/model_executor/models/laguna.py
+++ b/vllm/model_executor/models/laguna.py
@@ -724,20 +724,6 @@ class LagunaModel(nn.Module, EagleModelMixin):
loaded_params.add(name)
continue
- # Handle KV cache quantization scales
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- assert loaded_weight.numel() == 1, (
- f"KV scale numel {loaded_weight.numel()} != 1"
- )
- loaded_weight = loaded_weight.squeeze()
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
# Handle stacked params (QKV, gate_up for
# non-expert layers and shared_expert)
for param_name, weight_name, shard_id in stacked_params_mapping:
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 3c797d05e93..cf59ccdf750 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -451,18 +451,6 @@ class LlamaModel(nn.Module, EagleModelMixin):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index bfcb72a6a74..07ca2714ed2 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -588,21 +588,6 @@ class Llama4Model(LlamaModel):
fused_experts_params = True
expert_params_mapping = expert_params_mapping_fused
- # If kv cache quantization scales exist and the weight name
- # corresponds to one of the kv cache quantization scales, load
- # them.
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
# Iterate over stacked_params_mapping to check if the current weight
# is one of the stacked parameters. If so, load the weight with the
# corresponding shard id. Note that MoE weights are handled
@@ -625,9 +610,9 @@ class Llama4Model(LlamaModel):
if is_pp_missing_parameter(name, self):
continue
- # Remap kv cache scale names for ModelOpt checkpoints.
- # TODO: ModelOpt should implement get_cache_scale() such that
- # kv cache scale name remapping can be done there.
+ # Remap kv cache scale names for any checkpoint format the
+ # quant config's `get_cache_scale_mapper` does not cover
+ # (idempotent for names already renamed by the mapper).
if name.endswith("scale"):
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index 585c8f6dbd2..14842a75fea 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -127,19 +127,6 @@ class LlamaModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- # Handle kv cache quantization scales
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
# Remapping the name FP8 kv-scale or zero point.
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 9fd6652aa24..bb1bbb85537 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -267,19 +267,6 @@ class LlamaModel(nn.Module):
for name, loaded_weight in weights:
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
- # Handle kv cache quantization scales
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
# Remapping the name FP8 kv-scale or zero point.
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py
index a7699f0d598..4f67d468ace 100644
--- a/vllm/model_executor/models/mimo.py
+++ b/vllm/model_executor/models/mimo.py
@@ -104,18 +104,6 @@ class MiMoModel(Qwen2Model):
continue
if "rotary_emb.inv_freq" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/mimo_v2.py b/vllm/model_executor/models/mimo_v2.py
index 3f466162649..7c6d5363c0a 100644
--- a/vllm/model_executor/models/mimo_v2.py
+++ b/vllm/model_executor/models/mimo_v2.py
@@ -555,22 +555,6 @@ class MiMoV2Model(nn.Module):
if "mtp" in name:
continue
- if self.quant_config is not None:
- cache_scale_name = self.quant_config.get_cache_scale(name)
- if cache_scale_name is not None and cache_scale_name in params_dict:
- param = params_dict[cache_scale_name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
-
- kv_scale = loaded_weight
- if kv_scale.dim() > 0 and kv_scale.numel() > 1:
- kv_scale = kv_scale.view(-1)[0]
-
- weight_loader(param, kv_scale)
- loaded_params.add(cache_scale_name)
- continue
-
expert_matched = False
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
if weight_name not in name:
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index cbfc254dda3..53c1c87cfce 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -388,19 +388,6 @@ class MixtralModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index 15d43a9ddf9..7b2e6b93b27 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -375,18 +375,6 @@ class NemotronModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py
index f2f3811c064..b974a3eb085 100644
--- a/vllm/model_executor/models/nemotron_nas.py
+++ b/vllm/model_executor/models/nemotron_nas.py
@@ -334,18 +334,6 @@ class DeciModel(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 4491a6a3ea1..541f60c2c40 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -277,7 +277,8 @@ class OlmoModel(nn.Module):
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
"""
- :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+ Args:
+ input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index 212140fe15e..ad04b258bde 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -314,7 +314,8 @@ class Olmo2Model(nn.Module):
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
"""
- :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+ Args:
+ input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
diff --git a/vllm/model_executor/models/ouro.py b/vllm/model_executor/models/ouro.py
index 56505ec7be2..503d4b5c834 100644
--- a/vllm/model_executor/models/ouro.py
+++ b/vllm/model_executor/models/ouro.py
@@ -390,18 +390,6 @@ class OuroModel(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index cd88009c739..0bf10e3ce77 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -944,21 +944,6 @@ class SiglipVisionModel(nn.Module):
continue
if "packing_position_embedding" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(
- param,
- "weight_loader",
- default_weight_loader,
- )
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for (
param_name,
weight_name,
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 5770420ce56..a49e8ce2e82 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -537,19 +537,6 @@ class PhiMoEModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index b83fedc70db..9c39c649708 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -439,18 +439,6 @@ class Qwen2Model(nn.Module, EagleModelMixin):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/qwen3_dflash.py b/vllm/model_executor/models/qwen3_dflash.py
index 25f139f26cb..820260f795c 100644
--- a/vllm/model_executor/models/qwen3_dflash.py
+++ b/vllm/model_executor/models/qwen3_dflash.py
@@ -469,17 +469,6 @@ class DFlashQwen3Model(nn.Module):
for name, loaded_weight in weights:
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 4ec1be3367d..6980184cc8a 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -552,19 +552,6 @@ class Qwen3MoeModel(nn.Module, EagleModelMixin):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- assert loaded_weight.numel() == 1, (
- f"KV scale numel {loaded_weight.numel()} != 1"
- )
- loaded_weight = loaded_weight.squeeze()
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
diff --git a/vllm/model_executor/models/rnj1.py b/vllm/model_executor/models/rnj1.py
index f83577b7a39..68c3722e2bc 100644
--- a/vllm/model_executor/models/rnj1.py
+++ b/vllm/model_executor/models/rnj1.py
@@ -350,16 +350,6 @@ class Rnj1Model(nn.Module):
):
loaded_weight -= 1
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = loaded_weight[0]
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
-
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is not None and remapped_name in params_dict:
diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py
index d90174911fb..48147f7334e 100644
--- a/vllm/model_executor/models/seed_oss.py
+++ b/vllm/model_executor/models/seed_oss.py
@@ -376,18 +376,6 @@ class SeedOssModel(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index bff866d0d0c..454a0e97112 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -360,18 +360,6 @@ class SolarModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- if self.quant_config is not None and (
- scale_name := self.quant_config.get_cache_scale(name)
- ):
- # Loading kv cache quantization scales
- param = params_dict[scale_name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- loaded_weight = (
- loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(scale_name)
- continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 095d0e363d5..83d113415dc 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -52,6 +52,7 @@ class WeightsMapper:
def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
"""Combine two `WeightsMapper`s by merging their mappings."""
return WeightsMapper(
+ orig_to_new_regex={**self.orig_to_new_regex, **other.orig_to_new_regex},
orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr},
orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix},
orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix},
@@ -343,6 +344,20 @@ class AutoWeightsLoader:
*,
mapper: WeightsMapper | None = None,
) -> set[str]:
+ # Many models store quant_config in the base model instead of the causal model.
+ # We look at the causal model's direct children for this reason.
+ modules = (self.module, *self.module.children())
+ iterator = (m.quant_config for m in modules if hasattr(m, "quant_config"))
+ quant_config = next(iterator, None)
+ cache_scale_mapper = (
+ quant_config.get_cache_scale_mapper() if quant_config is not None else None
+ )
+ if cache_scale_mapper is not None:
+ mapper = (
+ mapper | cache_scale_mapper
+ if mapper is not None
+ else cache_scale_mapper
+ )
if mapper is not None:
weights = mapper.apply(weights)
# filter out weights with first-prefix/substr to skip in name
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index 4106672d501..7f96ceda09d 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -3,6 +3,7 @@
from collections.abc import Callable, Hashable
from fractions import Fraction
+from typing import Any
from weakref import WeakValueDictionary
import torch
@@ -42,10 +43,9 @@ class BasevLLMParameter(Parameter):
"""
Initialize the BasevLLMParameter
- :param data: torch tensor with the parameter data
- :param weight_loader: weight loader callable
-
- :returns: a torch.nn.parameter
+ Args:
+ data: torch tensor with the parameter data
+ weight_loader: weight loader callable
"""
# During weight loading, we often do something like:
@@ -445,15 +445,16 @@ class SharedWeightParameter(BasevLLMParameter):
"currently support tensor parallelism"
)
- def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
+ def add_partition(self, index: int, data_key: Hashable, *args: Any, **kwargs: Any):
"""
Add a partition to the weight parameter. Partitions whose `data_key`
is the same will share tensor data
- :param index: index of partition to add
- :param data_key: hashable key used to key shared tensors
- :param *args: arguments for `torch.empty`
- :param **kwargs: keyword arguments for `torch.empty`
+ Args:
+ index: index of partition to add
+ data_key: hashable key used to key shared tensors
+ *args: arguments for `torch.empty`
+ **kwargs: keyword arguments for `torch.empty`
"""
# load (shared) tensor using `data_key`
if data_key not in self.tensors_registry:
diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py
index fb724fbe2f1..c32be5768bf 100644
--- a/vllm/models/deepseek_v4/amd/model.py
+++ b/vllm/models/deepseek_v4/amd/model.py
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClam
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
- ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
@@ -45,11 +44,7 @@ from vllm.model_executor.models.utils import (
make_layers,
maybe_prefix,
)
-from vllm.models.deepseek_v4.attention import (
- DeepseekV4Indexer,
- DeepseekV4MLA,
-)
-from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
+from vllm.models.deepseek_v4.amd.rocm import DeepseekV4ROCMAiterMLAAttention
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_tilelang
@@ -225,158 +220,6 @@ class DeepseekV4MoE(nn.Module):
return final_hidden_states.view(org_shape)
-class DeepseekV4Attention(nn.Module):
- def __init__(
- self,
- vllm_config: VllmConfig,
- prefix: str,
- topk_indices_buffer: torch.Tensor | None = None,
- aux_stream_list: list[torch.cuda.Stream] | None = None,
- ):
- super().__init__()
- config = vllm_config.model_config.hf_config
- quant_config = vllm_config.quant_config
- layer_id = extract_layer_index(prefix)
-
- self.layer_id = layer_id
- self.hidden_size = config.hidden_size
- self.n_heads = config.num_attention_heads
- tp_size = get_tensor_model_parallel_world_size()
- assert self.n_heads % tp_size == 0
-
- self.n_local_heads = self.n_heads // tp_size
- self.q_lora_rank = config.q_lora_rank
- self.o_lora_rank = config.o_lora_rank
- self.head_dim = config.head_dim
- self.rope_head_dim = config.qk_rope_head_dim
- self.nope_head_dim = self.head_dim - self.rope_head_dim
- self.n_groups = config.o_groups
- self.n_local_groups = self.n_groups // tp_size
- self.window_size = config.sliding_window
- # NOTE(zyongye) Compress ratio can't be 0
- # we do this for because MTP layer is not included
- # in the compress ratio list
- if layer_id < config.num_hidden_layers:
- self.compress_ratio = max(1, config.compress_ratios[layer_id])
- else:
- self.compress_ratio = 1
- self.eps = config.rms_norm_eps
- self.max_position_embeddings = config.max_position_embeddings
-
- # Padded to min 64 heads for FlashMLA, initialized to -inf
- # (no sink effect). Weight loading fills the first n_local_heads slots.
- padded_heads = max(self.n_local_heads, 64)
- self.attn_sink = nn.Parameter(
- torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
- requires_grad=False,
- )
-
- self.fused_wqa_wkv = MergedColumnParallelLinear(
- self.hidden_size,
- [self.q_lora_rank, self.head_dim],
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.fused_wqa_wkv",
- disable_tp=True, # fused ReplicatedLinear
- )
- self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
- self.wq_b = ColumnParallelLinear(
- self.q_lora_rank,
- self.n_heads * self.head_dim,
- bias=False,
- quant_config=quant_config,
- return_bias=False,
- prefix=f"{prefix}.wq_b",
- )
-
- self.kv_norm = RMSNorm(self.head_dim, self.eps)
- self.wo_a = ColumnParallelLinear(
- self.n_heads * self.head_dim // self.n_groups,
- self.n_groups * self.o_lora_rank,
- bias=False,
- quant_config=quant_config,
- return_bias=False,
- prefix=f"{prefix}.wo_a",
- )
- self.wo_a.is_bmm = True
- self.wo_a.bmm_batch_size = self.n_local_groups
- self.wo_b = RowParallelLinear(
- self.n_groups * self.o_lora_rank,
- self.hidden_size,
- bias=False,
- quant_config=quant_config,
- return_bias=False,
- prefix=f"{prefix}.wo_b",
- )
- self.softmax_scale = self.head_dim**-0.5
- self.scale_fmt = config.quantization_config["scale_fmt"]
-
- self.rope_parameters = config.rope_scaling
-
- # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
- self.rotary_emb = build_deepseek_v4_rope(
- config,
- head_dim=self.head_dim,
- rope_head_dim=self.rope_head_dim,
- max_position_embeddings=self.max_position_embeddings,
- compress_ratio=self.compress_ratio,
- )
-
- self.indexer = None
- if self.compress_ratio == 4:
- # Only C4A uses sparse attention and hence has indexer.
- self.indexer = DeepseekV4Indexer(
- vllm_config,
- config=config,
- hidden_size=self.hidden_size,
- q_lora_rank=self.q_lora_rank,
- quant_config=quant_config,
- cache_config=vllm_config.cache_config,
- topk_indices_buffer=topk_indices_buffer,
- compress_ratio=self.compress_ratio,
- prefix=f"{prefix}.indexer",
- )
-
- self.mla_attn = DeepseekV4MLA(
- hidden_size=self.hidden_size,
- num_heads=self.n_local_heads,
- head_dim=self.head_dim,
- scale=self.softmax_scale,
- qk_nope_head_dim=self.nope_head_dim,
- qk_rope_head_dim=self.rope_head_dim,
- v_head_dim=self.head_dim,
- q_lora_rank=self.q_lora_rank,
- kv_lora_rank=self.head_dim,
- o_lora_rank=self.o_lora_rank,
- vllm_config=vllm_config,
- fused_wqa_wkv=self.fused_wqa_wkv,
- q_norm=self.q_norm,
- wq_b=self.wq_b,
- kv_norm=self.kv_norm,
- wo_a=self.wo_a,
- wo_b=self.wo_b,
- attn_sink=self.attn_sink,
- rotary_emb=self.rotary_emb,
- indexer=self.indexer,
- indexer_rotary_emb=self.rotary_emb,
- topk_indices_buffer=topk_indices_buffer,
- aux_stream_list=aux_stream_list,
- window_size=self.window_size,
- compress_ratio=self.compress_ratio,
- cache_config=vllm_config.cache_config,
- quant_config=quant_config,
- prefix=prefix,
- )
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- llama_4_scaling: torch.Tensor | None,
- ):
- return self.mla_attn(positions, hidden_states, llama_4_scaling)
-
-
class DeepseekV4DecoderLayer(nn.Module):
def __init__(
self,
@@ -395,7 +238,7 @@ class DeepseekV4DecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
self.rms_norm_eps = config.rms_norm_eps
- self.attn = DeepseekV4Attention(
+ self.attn = DeepseekV4ROCMAiterMLAAttention(
vllm_config,
prefix=f"{prefix}.attn",
topk_indices_buffer=topk_indices_buffer,
@@ -601,7 +444,7 @@ class DeepseekV4Model(nn.Module):
self.rms_norm_eps = config.rms_norm_eps
# Three aux streams: one per non-default input GEMM in
- # DeepseekV4MLA.attn_gemm_parallel_execute
+ # DeepseekV4Attention.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
# Disable them on ROCm because of hang issues.
@@ -897,7 +740,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
},
orig_to_new_substr={
- ".attn.compressor.": ".attn.mla_attn.compressor.",
".shared_experts.w2": ".shared_experts.down_proj",
},
)
diff --git a/vllm/models/deepseek_v4/amd/rocm.py b/vllm/models/deepseek_v4/amd/rocm.py
index 7298f18365d..f7fb409af1b 100644
--- a/vllm/models/deepseek_v4/amd/rocm.py
+++ b/vllm/models/deepseek_v4/amd/rocm.py
@@ -2,15 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
-from typing import TYPE_CHECKING, cast
+from typing import cast
import torch
from vllm.forward_context import get_forward_context
+from vllm.models.deepseek_v4.attention import DeepseekV4Attention
from vllm.models.deepseek_v4.common.ops import dequantize_and_gather_k_cache
from vllm.models.deepseek_v4.nvidia.flashmla import (
DeepseekV4FlashMLASparseBackend,
- DeepseekV4SparseMLAAttentionImpl,
)
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
@@ -26,16 +26,12 @@ from vllm.v1.attention.backends.mla.sparse_swa import (
)
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
build_ragged_indices_from_dense,
+ rocm_inv_rope_einsum,
rocm_sparse_attn_decode,
rocm_sparse_attn_prefill,
)
from vllm.v1.worker.workspace import current_workspace_manager
-if TYPE_CHECKING:
- from vllm.models.deepseek_v4.attention import (
- DeepseekV4MLAAttention,
- )
-
def _build_indptr_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
lengths = lengths.to(dtype=torch.int32).contiguous()
@@ -582,13 +578,9 @@ class DeepseekV4ROCMAiterMLASparseBackend(DeepseekV4FlashMLASparseBackend):
def get_builder_cls() -> type["DeepseekV4ROCMAiterMLASparseMetadataBuilder"]:
return DeepseekV4ROCMAiterMLASparseMetadataBuilder
- @staticmethod
- def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]:
- return DeepseekV4ROCMAiterMLASparseImpl
-
-class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
- """ROCm sparse MLA implementation used by DeepSeek V4's custom MLA layer."""
+class DeepseekV4ROCMAiterMLAAttention(DeepseekV4Attention):
+ """ROCm sparse MLA attention layer for DeepSeek V4."""
backend_cls = DeepseekV4ROCMAiterMLASparseBackend
@@ -596,10 +588,21 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
def get_padded_num_q_heads(cls, num_heads: int) -> int:
return num_heads
- @classmethod
- def forward_mqa( # type: ignore[override]
- cls,
- layer: "DeepseekV4MLAAttention",
+ def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ # ROCm BF16 reference wo_a path (inverse RoPE + einsum) + wo_b.
+ z = rocm_inv_rope_einsum(
+ self.rotary_emb,
+ o,
+ positions,
+ self.rope_head_dim,
+ self.n_local_groups,
+ self.o_lora_rank,
+ self.wo_a,
+ )
+ return self.wo_b(z.flatten(1))
+
+ def forward_mqa(
+ self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
@@ -619,16 +622,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# Warmup dummy run: no real metadata. Reserve the same bf16
# gather workspace _forward_prefill would; the dequantize / topk
# / sparse_fwd kernels are skipped this step.
- swa_only = layer.compress_ratio <= 1
+ swa_only = self.compress_ratio <= 1
N = (
0
if swa_only
- else (layer.max_model_len + layer.compress_ratio - 1)
- // layer.compress_ratio
+ else (self.max_model_len + self.compress_ratio - 1)
+ // self.compress_ratio
)
- M = N + layer.window_size + layer.max_num_batched_tokens
+ M = N + self.window_size + self.max_num_batched_tokens
current_workspace_manager().get_simultaneous(
- ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
+ ((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)
output.zero_()
return
@@ -636,25 +639,24 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
assert isinstance(attn_metadata, dict)
rocm_metadata = cast(
DeepseekV4ROCMAiterMLASparseMetadata | None,
- attn_metadata.get(layer.prefix),
+ attn_metadata.get(self.prefix),
)
swa_metadata = cast(
DeepseekV4ROCMAiterSparseSWAMetadata | None,
- attn_metadata.get(layer.swa_cache_layer.prefix),
+ attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
- swa_only = layer.compress_ratio <= 1
- self_kv_cache = layer.kv_cache if not swa_only else None
- swa_kv_cache = layer.swa_cache_layer.kv_cache
+ swa_only = self.compress_ratio <= 1
+ self_kv_cache = self.kv_cache if not swa_only else None
+ swa_kv_cache = self.swa_cache_layer.kv_cache
num_decodes = swa_metadata.num_decodes
num_prefills = swa_metadata.num_prefills
num_decode_tokens = swa_metadata.num_decode_tokens
if num_prefills > 0:
- cls._forward_prefill(
- layer=layer,
+ self._forward_prefill(
q=q[num_decode_tokens:],
positions=positions[num_decode_tokens:],
compressed_k_cache=self_kv_cache,
@@ -664,8 +666,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_metadata=swa_metadata,
)
if num_decodes > 0:
- cls._forward_decode(
- layer=layer,
+ self._forward_decode(
q=q[:num_decode_tokens],
kv_cache=self_kv_cache,
swa_metadata=swa_metadata,
@@ -674,10 +675,8 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
output=output[:num_decode_tokens],
)
- @classmethod
def _forward_decode(
- cls,
- layer: "DeepseekV4MLAAttention",
+ self,
q: torch.Tensor,
kv_cache: torch.Tensor | None,
swa_metadata: DeepseekV4ROCMAiterSparseSWAMetadata,
@@ -695,16 +694,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
if not swa_only:
assert attn_metadata is not None
assert swa_metadata.is_valid_token is not None
- block_size = attn_metadata.block_size // layer.compress_ratio
+ block_size = attn_metadata.block_size // self.compress_ratio
is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
- if layer.compress_ratio == 4:
- assert layer.topk_indices_buffer is not None
+ if self.compress_ratio == 4:
+ assert self.topk_indices_buffer is not None
(
topk_ragged_indices,
topk_ragged_indptr,
topk_lens,
) = compute_global_topk_ragged_indices_and_indptr(
- layer.topk_indices_buffer[:num_decode_tokens],
+ self.topk_indices_buffer[:num_decode_tokens],
swa_metadata.token_to_req_indices,
attn_metadata.block_table[:num_decodes],
block_size,
@@ -719,7 +718,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
rocm_sparse_attn_decode(
q=q,
kv_cache=kv_cache,
- swa_k_cache=layer.swa_cache_layer.kv_cache,
+ swa_k_cache=self.swa_cache_layer.kv_cache,
swa_only=swa_only,
topk_indices=topk_indices,
topk_lens=topk_lens,
@@ -729,18 +728,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_ragged_indptr=swa_metadata.decode_swa_ragged_indptr,
topk_ragged_indices=topk_ragged_indices,
topk_ragged_indptr=topk_ragged_indptr,
- attn_sink=layer.attn_sink,
- scale=layer.scale,
- head_dim=layer.head_dim,
- nope_head_dim=layer.nope_head_dim,
- rope_head_dim=layer.rope_head_dim,
+ attn_sink=self.attn_sink,
+ scale=self.scale,
+ head_dim=self.head_dim,
+ nope_head_dim=self.nope_head_dim,
+ rope_head_dim=self.rope_head_dim,
output=output,
)
- @classmethod
def _forward_prefill(
- cls,
- layer: "DeepseekV4MLAAttention",
+ self,
q: torch.Tensor,
positions: torch.Tensor,
compressed_k_cache: torch.Tensor | None,
@@ -768,34 +765,34 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
prefill_token_base = query_start_loc_cpu[num_decodes]
if not swa_only:
- if layer.compress_ratio == 4:
- assert layer.topk_indices_buffer is not None
- topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
+ if self.compress_ratio == 4:
+ assert self.topk_indices_buffer is not None
+ topk_indices = self.topk_indices_buffer[num_decode_tokens:]
topk_indices = topk_indices[:num_prefill_tokens]
else:
assert attn_metadata is not None
topk_indices = attn_metadata.c128a_prefill_topk_indices
assert topk_indices is not None
top_k = topk_indices.shape[-1]
- N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio
+ N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
else:
- assert layer.topk_indices_buffer is not None
- topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
+ assert self.topk_indices_buffer is not None
+ topk_indices = self.topk_indices_buffer[num_decode_tokens:]
top_k = 0
N = 0
- M = N + layer.window_size + layer.max_num_batched_tokens
- num_chunks = (num_prefills + cls.PREFILL_CHUNK_SIZE - 1) // (
- cls.PREFILL_CHUNK_SIZE
+ M = N + self.window_size + self.max_num_batched_tokens
+ num_chunks = (num_prefills + self.PREFILL_CHUNK_SIZE - 1) // (
+ self.PREFILL_CHUNK_SIZE
)
workspace_manager = current_workspace_manager()
kv = workspace_manager.get_simultaneous(
- ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
+ ((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)[0]
for chunk_idx in range(num_chunks):
- chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE
- chunk_end = min(chunk_start + cls.PREFILL_CHUNK_SIZE, num_prefills)
+ chunk_start = chunk_idx * self.PREFILL_CHUNK_SIZE
+ chunk_end = min(chunk_start + self.PREFILL_CHUNK_SIZE, num_prefills)
chunk_size = chunk_end - chunk_start
if not swa_only:
assert attn_metadata is not None
@@ -804,10 +801,10 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
dequantize_and_gather_k_cache(
kv[:chunk_size],
compressed_k_cache,
- seq_lens=seq_lens[chunk_start:chunk_end] // layer.compress_ratio,
+ seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
gather_lens=None,
block_table=block_table[chunk_start:chunk_end],
- block_size=attn_metadata.block_size // layer.compress_ratio,
+ block_size=attn_metadata.block_size // self.compress_ratio,
offset=0,
)
@@ -836,8 +833,8 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
],
seq_lens[chunk_start:chunk_end],
gather_lens[chunk_start:chunk_end],
- layer.window_size,
- layer.compress_ratio,
+ self.window_size,
+ self.compress_ratio,
top_k,
M,
N,
@@ -847,10 +844,10 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices,
topk_length=combined_lens,
- scale=layer.scale,
- head_dim=layer.head_dim,
- nope_head_dim=layer.nope_head_dim,
- rope_head_dim=layer.rope_head_dim,
- attn_sink=layer.attn_sink,
+ scale=self.scale,
+ head_dim=self.head_dim,
+ nope_head_dim=self.nope_head_dim,
+ rope_head_dim=self.rope_head_dim,
+ attn_sink=self.attn_sink,
output=output[query_start:query_end],
)
diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py
index 5f13d1bd8d0..29302584880 100644
--- a/vllm/models/deepseek_v4/attention.py
+++ b/vllm/models/deepseek_v4/attention.py
@@ -4,8 +4,9 @@
DeepseekV4 MLA Attention Layer
"""
+from abc import ABC, abstractmethod
from collections.abc import Callable
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, ClassVar, cast
import torch
import torch.nn as nn
@@ -15,16 +16,16 @@ from transformers import DeepseekV2Config, DeepseekV3Config
import vllm.envs as envs
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
ReplicatedLinear,
+ RowParallelLinear,
)
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.models.deepseek_v4.common.ops import (
fused_indexer_q_rope_quant,
- fused_inv_rope_fp8_quant,
fused_q_kv_rmsnorm,
)
-from vllm.utils.deep_gemm import fp8_einsum
-from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
if TYPE_CHECKING:
from vllm.v1.attention.backends.mla.sparse_swa import (
@@ -42,14 +43,9 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.input_quant_fp8 import (
- QuantFP8,
-)
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
- GroupShape,
-)
+from vllm.model_executor.models.utils import extract_layer_index
+from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
from vllm.models.deepseek_v4.compressor import DeepseekCompressor
-from vllm.platforms import current_platform
from vllm.utils.multi_stream_utils import (
execute_in_parallel,
maybe_execute_in_parallel,
@@ -62,187 +58,209 @@ from vllm.v1.attention.backends.mla.indexer import (
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
-if TYPE_CHECKING:
- from vllm.models.deepseek_v4.nvidia.flashmla import (
- DeepseekV4SparseMLAAttentionImpl,
- )
-
logger = init_logger(__name__)
-def _resolve_dsv4_backend(vllm_config: VllmConfig | None):
- """Return the explicitly-requested DSv4 sparse backend enum, or None."""
- if vllm_config is None:
- return None
- attn_config = getattr(vllm_config, "attention_config", None)
- return getattr(attn_config, "backend", None) if attn_config is not None else None
-
-
-def _select_v4_sparse_impl(
- vllm_config: VllmConfig | None = None,
-) -> "type[DeepseekV4SparseMLAAttentionImpl]":
- """Pick the V4 sparse MLA impl class.
-
- An explicit ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` selects the
- FlashInfer TRTLLM-gen path; otherwise the platform default (FlashMLA on
- NVIDIA, ROCm Aiter on AMD) is used.
- """
- from vllm.v1.attention.backends.registry import AttentionBackendEnum
-
- backend = _resolve_dsv4_backend(vllm_config)
- if backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4:
- from vllm.models.deepseek_v4.nvidia.flashinfer_sparse import (
- DeepseekV4FlashInferMLASparseImpl,
- )
-
- logger.info_once("Using FLASHINFER_MLA_SPARSE_DSV4 backend.")
- return DeepseekV4FlashInferMLASparseImpl
- if current_platform.is_rocm():
- from vllm.models.deepseek_v4.amd.rocm import (
- DeepseekV4ROCMAiterMLASparseImpl,
- )
-
- logger.info_once("Using ROCM_FLASHMLA_SPARSE_DSV4 backend.")
- return DeepseekV4ROCMAiterMLASparseImpl
- from vllm.models.deepseek_v4.nvidia.flashmla import (
- DeepseekV4FlashMLASparseImpl,
- )
-
- logger.info_once("Using FLASHMLA_SPARSE_DSV4 backend.")
- return DeepseekV4FlashMLASparseImpl
-
-
def _resolve_dsv4_kv_cache_dtype(
- backend,
+ use_flashmla_fp8_layout: bool,
kv_cache_dtype: str,
cache_config: CacheConfig | None,
) -> tuple[str, torch.dtype]:
- """Map ``(backend, --kv-cache-dtype)`` to ``(cache_dtype_str, torch_dtype)``.
+ """Map ``(layout, --kv-cache-dtype)`` to ``(cache_dtype_str, torch_dtype)``.
- FlashInfer V4 reads a contiguous 512-wide KV row (bf16 or per-tensor FP8
- E4M3); FlashMLA V4 reads the legacy UE8M0 paged layout (uint8 /
- ``fp8_ds_mla``). For FlashMLA the canonical ``fp8_ds_mla`` string is
- written back onto ``cache_config`` so the page-size specs pick the 576B
- layout.
+ Both layouts are paged; they differ in the per-token block format. The
+ FlashMLA fp8 layout (FlashMLA / ROCm Aiter) is the ``fp8_ds_mla`` format:
+ UE8M0 block-scaled fp8 packed as ``uint8`` (the canonical ``fp8_ds_mla``
+ string is written back onto ``cache_config`` so the page-size specs pick
+ the 576B per-token slot). Otherwise (FlashInfer) each token's KV row is
+ stored in its plain element dtype — bf16 or per-tensor FP8 E4M3.
"""
- from vllm.v1.attention.backends.registry import AttentionBackendEnum
+ if use_flashmla_fp8_layout:
+ # fp8_ds_mla block format: UE8M0 block-scaled fp8 packed as uint8.
+ assert kv_cache_dtype.startswith("fp8"), (
+ f"DeepseekV4 FlashMLA fp8 layout only supports fp8 kv-cache, "
+ f"got {kv_cache_dtype}"
+ )
+ if kv_cache_dtype != "fp8_ds_mla":
+ if cache_config is not None:
+ cache_config.cache_dtype = "fp8_ds_mla"
+ kv_cache_dtype = "fp8_ds_mla"
+ logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
+ return kv_cache_dtype, torch.uint8
- if backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4:
- if kv_cache_dtype.startswith("fp8"):
- return kv_cache_dtype, torch.float8_e4m3fn
- # auto / bfloat16 -> contiguous BF16 cache.
- return kv_cache_dtype, torch.bfloat16
-
- # FlashMLA (and ROCm Aiter): legacy UE8M0 paged uint8 cache.
- assert kv_cache_dtype.startswith("fp8"), (
- f"DeepseekV4 FlashMLA sparse backend only supports fp8 kv-cache, "
- f"got {kv_cache_dtype}"
- )
- if kv_cache_dtype != "fp8_ds_mla":
- if cache_config is not None:
- cache_config.cache_dtype = "fp8_ds_mla"
- kv_cache_dtype = "fp8_ds_mla"
- logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
- return kv_cache_dtype, torch.uint8
+ # Plain bf16 / per-tensor fp8 KV row (FlashInfer).
+ if kv_cache_dtype.startswith("fp8"):
+ return kv_cache_dtype, torch.float8_e4m3fn
+ # auto / bfloat16 -> plain bf16 KV row.
+ return kv_cache_dtype, torch.bfloat16
-class DeepseekV4MLA(nn.Module):
+class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
+ """DeepseekV4 MLA attention layer.
+
+ The platform-specific sparse-MLA forward (``forward_mqa`` /
+ ``get_padded_num_q_heads`` / ``_o_proj`` / ``backend_cls``) is provided by a
+ subclass — ``DeepseekV4FlashMLAAttention`` / ``DeepseekV4FlashInferMLAAttention``
+ (CUDA) or ``DeepseekV4ROCMAiterMLAAttention`` (ROCm) — selected by the
+ platform-specific deepseek_v4 model module. The base is never instantiated
+ directly.
+ """
+
+ # Provided by the platform subclass.
+ backend_cls: ClassVar[type[AttentionBackend]]
+ # KV-cache per-token block format (both layouts are paged). True (default)
+ # = FlashMLA / ROCm fp8_ds_mla (UE8M0 block-scaled fp8 packed as uint8);
+ # False = FlashInfer plain bf16 / per-tensor fp8 KV row.
+ use_flashmla_fp8_layout: ClassVar[bool] = True
+ # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
+ # workspace allocated in _forward_prefill and is also read by the dummy-run
+ # path to pre-reserve that workspace.
+ PREFILL_CHUNK_SIZE: ClassVar[int] = 4
+
+ @classmethod
+ @abstractmethod
+ def get_padded_num_q_heads(cls, num_heads: int) -> int:
+ """Q head count the q/output buffers are allocated at.
+
+ The layer allocates the q/output buffers at
+ ``[N, get_padded_num_q_heads(n_local_heads), head_dim]``. Must satisfy
+ ``result >= num_heads``. Backends with no padding constraint return
+ ``num_heads``.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def forward_mqa(
+ self,
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ positions: torch.Tensor,
+ output: torch.Tensor,
+ ) -> None:
+ """Platform-specific sparse MLA forward; writes attention into ``output``."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ """Inverse-RoPE + wo_a + wo_b output projection (platform-specific)."""
+ raise NotImplementedError
+
def __init__(
self,
- hidden_size: int,
- num_heads: int,
- head_dim: int,
- scale: float,
- qk_nope_head_dim: int,
- qk_rope_head_dim: int,
- v_head_dim: int,
- q_lora_rank: int | None,
- kv_lora_rank: int,
- o_lora_rank: int | None,
vllm_config: VllmConfig,
- fused_wqa_wkv: torch.nn.Module,
- q_norm: torch.nn.Module,
- wq_b: torch.nn.Module,
- kv_norm: torch.nn.Module,
- wo_a: torch.nn.Module,
- wo_b: torch.nn.Module,
- attn_sink: torch.nn.Module,
- rotary_emb: torch.nn.Module,
- indexer: torch.nn.Module | None,
- indexer_rotary_emb: torch.nn.Module,
- topk_indices_buffer: torch.Tensor | None,
- aux_stream_list: list[torch.cuda.Stream] | None,
- window_size: int,
- compress_ratio: int | None,
- cache_config: CacheConfig | None = None,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
+ prefix: str,
+ topk_indices_buffer: torch.Tensor | None = None,
+ aux_stream_list: list[torch.cuda.Stream] | None = None,
) -> None:
super().__init__()
- self.hidden_size = hidden_size
- self.n_local_heads = num_heads
- self.head_dim = head_dim
- self.scale = scale
-
- self.q_lora_rank = q_lora_rank
- self.kv_lora_rank = kv_lora_rank
- self.window_size = window_size
- self.compress_ratio = compress_ratio if compress_ratio is not None else 1
- self.prefix = prefix
-
- # Extract config from vllm_config
config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ cache_config = vllm_config.cache_config
tp_size = get_tensor_model_parallel_world_size()
+ layer_id = extract_layer_index(prefix)
- # DeepseekV4-specific attributes (num_heads is already TP-adjusted)
- self.eps = config.rms_norm_eps
- self.rope_head_dim = config.qk_rope_head_dim
- self.nope_head_dim = head_dim - self.rope_head_dim
- self.n_local_groups = config.o_groups // tp_size
+ self.prefix = prefix # Alias for compatibility with compressor
+ self.hidden_size = config.hidden_size
+ self.n_heads = config.num_attention_heads
+ assert self.n_heads % tp_size == 0
+ self.n_local_heads = self.n_heads // tp_size
+ self.q_lora_rank = config.q_lora_rank
self.o_lora_rank = config.o_lora_rank
+ self.head_dim = config.head_dim
+ self.rope_head_dim = config.qk_rope_head_dim
+ self.nope_head_dim = self.head_dim - self.rope_head_dim
+ self.n_groups = config.o_groups
+ self.n_local_groups = self.n_groups // tp_size
+ self.window_size = config.sliding_window
+ # NOTE(zyongye) Compress ratio can't be 0
+ # we do this for because MTP layer is not included
+ # in the compress ratio list
+ if layer_id < config.num_hidden_layers:
+ self.compress_ratio = max(1, config.compress_ratios[layer_id])
+ else:
+ self.compress_ratio = 1
+ self.eps = config.rms_norm_eps
+ self.scale = self.head_dim**-0.5
- # Store projection modules
- self.fused_wqa_wkv = fused_wqa_wkv
- self.q_norm = q_norm
- self.wq_b = wq_b
-
- self.kv_norm = kv_norm
- self.wo_a = wo_a
-
- self._wo_a_act_quant = QuantFP8(
- static=False,
- group_shape=GroupShape(1, 128),
- use_ue8m0=True,
+ # Padded Q head count is dictated by the platform subclass.
+ self.padded_heads = self.get_padded_num_q_heads(self.n_local_heads)
+ # Sink padded to the same head count, initialized to -inf (no sink
+ # effect). Weight loading fills the first n_local_heads slots.
+ self.attn_sink = nn.Parameter(
+ torch.full((self.padded_heads,), -float("inf"), dtype=torch.float32),
+ requires_grad=False,
)
- # Bypass packed-for-deepgemm path — we need FP32 scales (not packed
- # INT32) so fp8_einsum can handle layout transform internally.
- self._wo_a_act_quant.use_deep_gemm_supported = False
- self.wo_b = wo_b
- # Pick fp8_einsum recipe based on GPU arch:
- # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
- # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1
- cap = current_platform.get_device_capability()
- assert cap is not None, "DeepseekV4 attention requires a CUDA device"
- self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
- self._tma_aligned_scales = cap.major >= 10
+ self.fused_wqa_wkv = MergedColumnParallelLinear(
+ self.hidden_size,
+ [self.q_lora_rank, self.head_dim],
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fused_wqa_wkv",
+ disable_tp=True, # fused ReplicatedLinear
+ )
+ self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
+ self.wq_b = ColumnParallelLinear(
+ self.q_lora_rank,
+ self.n_heads * self.head_dim,
+ bias=False,
+ quant_config=quant_config,
+ return_bias=False,
+ prefix=f"{prefix}.wq_b",
+ )
- self.rotary_emb = rotary_emb
- self.indexer_rotary_emb = indexer_rotary_emb
+ self.kv_norm = RMSNorm(self.head_dim, self.eps)
+ self.wo_a = ColumnParallelLinear(
+ self.n_heads * self.head_dim // self.n_groups,
+ self.n_groups * self.o_lora_rank,
+ bias=False,
+ quant_config=quant_config,
+ return_bias=False,
+ prefix=f"{prefix}.wo_a",
+ )
+ self.wo_a.is_bmm = True
+ self.wo_a.bmm_batch_size = self.n_local_groups
+ self.wo_b = RowParallelLinear(
+ self.n_groups * self.o_lora_rank,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ return_bias=False,
+ prefix=f"{prefix}.wo_b",
+ )
+
+ # Initialize rotary embedding before the indexer/compressor consume it.
+ self.rotary_emb = build_deepseek_v4_rope(
+ config,
+ head_dim=self.head_dim,
+ rope_head_dim=self.rope_head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ compress_ratio=self.compress_ratio,
+ )
+ self.indexer_rotary_emb = self.rotary_emb
self.topk_indices_buffer = topk_indices_buffer
- self.indexer = indexer
-
- # Per-head RMS normalization for Q (no learnable weights)
- self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
-
- # TODO(yifan): currently hardcoded for FP8 sparse, make it more generic
- head_bytes = (
- self.nope_head_dim # 448 fp8 NoPE
- + self.rope_head_dim * 2 # 64 bf16 RoPE
- + self.nope_head_dim // 64 # 7B scale factors
- + 1 # 1B pad
- )
+ self.indexer = None
+ if self.compress_ratio == 4:
+ # Only C4A uses sparse attention and hence has indexer.
+ # aux_stream_list[2] is free here (outer GEMMs joined) for the inner
+ # overlap of wq_b+fused_indexer_q_rope_quant vs compressor. None on
+ # ROCm, where aux_stream_list is None.
+ indexer_aux_stream = (
+ aux_stream_list[2] if aux_stream_list is not None else None
+ )
+ self.indexer = DeepseekV4Indexer(
+ vllm_config,
+ config=config,
+ hidden_size=self.hidden_size,
+ q_lora_rank=self.q_lora_rank,
+ quant_config=quant_config,
+ cache_config=cache_config,
+ topk_indices_buffer=topk_indices_buffer,
+ compress_ratio=self.compress_ratio,
+ prefix=f"{prefix}.indexer",
+ aux_stream=indexer_aux_stream,
+ )
# Will be None on ROCm for now.
self.aux_stream_list = aux_stream_list
@@ -252,45 +270,39 @@ class DeepseekV4MLA(nn.Module):
self.ln_events = [torch.cuda.Event() for _ in range(4)]
assert cache_config is not None, "DeepseekV4 attention requires cache_config"
- # Resolve the SWA cache tensor dtype from the selected backend: FlashMLA
- # uses the legacy UE8M0 paged uint8 layout; FlashInfer uses a contiguous
- # bf16 / per-tensor fp8 row.
- backend = _resolve_dsv4_backend(vllm_config)
- _, swa_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
- backend, cache_config.cache_dtype, cache_config
+ # ---- Attention / KV-cache setup ----
+ self.max_num_batched_tokens = (
+ vllm_config.scheduler_config.max_num_batched_tokens
)
+ self.max_model_len = vllm_config.model_config.max_model_len
+
+ # Resolve the kv-cache dtype from this backend's block format (a
+ # ClassVar set by the subclass): fp8_ds_mla (UE8M0 block-scaled fp8 as
+ # uint8) for FlashMLA / ROCm, vs a plain bf16 / per-tensor fp8 row for
+ # FlashInfer. The same resolution drives the SWA cache tensor dtype
+ # below.
+ self.kv_cache_dtype, self.kv_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
+ self.use_flashmla_fp8_layout, cache_config.cache_dtype, cache_config
+ )
+
self.swa_cache_layer = DeepseekV4SWACache(
head_dim=self.head_dim,
window_size=self.window_size,
- dtype=swa_cache_torch_dtype,
+ dtype=self.kv_cache_torch_dtype,
prefix=f"{prefix}.swa_cache",
cache_config=cache_config,
)
- self.mla_attn = DeepseekV4MLAAttention(
- num_heads=self.n_local_heads,
- head_dim=self.head_dim,
- scale=self.scale,
- qk_nope_head_dim=self.nope_head_dim,
- qk_rope_head_dim=self.rope_head_dim,
- q_lora_rank=self.q_lora_rank,
- kv_lora_rank=self.kv_lora_rank,
- compress_ratio=self.compress_ratio,
- window_size=self.window_size,
- head_bytes=head_bytes,
- swa_cache_layer=self.swa_cache_layer,
- attn_sink=attn_sink, # already padded with -inf
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=prefix,
- indexer=self.indexer,
- topk_indices_buffer=self.topk_indices_buffer,
- )
- # Mirror the inner layer's padded head count (single source of truth).
- self.padded_heads = self.mla_attn.padded_heads
+ # Register with compilation context for metadata lookup.
+ compilation_config = vllm_config.compilation_config
+ if prefix and prefix in compilation_config.static_forward_context:
+ raise ValueError(f"Duplicate layer name: {prefix}")
+ if prefix:
+ compilation_config.static_forward_context[prefix] = self
+ self.kv_cache = torch.tensor([])
- # Create the compressor for layers with compress_ratio > 1; after
- # creating the DeepseekV4MLAAttention layer to get its cache.
+ # Create the compressor for layers with compress_ratio > 1; after the
+ # attention setup above so its KV-cache prefix (self.prefix) is set.
self.compressor = None
if self.compress_ratio > 1:
self.compressor = DeepseekCompressor(
@@ -300,7 +312,7 @@ class DeepseekV4MLA(nn.Module):
head_dim=self.head_dim,
rotate=True,
prefix=f"{prefix}.compressor",
- k_cache_prefix=self.mla_attn.prefix,
+ k_cache_prefix=self.prefix,
)
def forward(
@@ -318,54 +330,38 @@ class DeepseekV4MLA(nn.Module):
device=hidden_states.device,
)
+ # Metadata-independent input GEMMs + RMSNorm stay in the captured
+ # graph; the metadata-dependent rest (q up-proj + kv-insert, indexer,
+ # compressor, MLA attention) runs in the eager break.
+ qr_kv, kv_score, indexer_kv_score, indexer_weights = (
+ self.attn_gemm_parallel_execute(hidden_states)
+ )
+ qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
+ qr, kv = fused_q_kv_rmsnorm(
+ qr,
+ kv,
+ self.q_norm.weight.data,
+ self.kv_norm.weight.data,
+ self.eps,
+ )
+
# attention_impl is wrapped with @eager_break_during_capture: this is
# where the breakable cudagraph capture breaks (the attention op runs
# eagerly between captured graph segments).
- self.attention_impl(hidden_states, positions, o_padded)
+ self.attention_impl(
+ hidden_states,
+ qr,
+ kv,
+ kv_score,
+ indexer_kv_score,
+ indexer_weights,
+ positions,
+ o_padded,
+ )
o = o_padded[:, : self.n_local_heads, :]
- # Keep ROCm on the BF16 reference wo_a path util kernel ready.
- if current_platform.is_rocm():
- z = rocm_inv_rope_einsum(
- self.rotary_emb,
- o,
- positions,
- self.rope_head_dim,
- self.n_local_groups,
- self.o_lora_rank,
- self.wo_a,
- )
- return self.wo_b(z.flatten(1))
-
- # O projection: inverse RoPE + FP8 quant + einsum + wo_b
- o_fp8, o_scale = fused_inv_rope_fp8_quant(
- o,
- positions,
- self.rotary_emb.cos_sin_cache,
- n_groups=self.n_local_groups,
- heads_per_group=self.n_local_heads // self.n_local_groups,
- nope_dim=self.nope_head_dim,
- rope_dim=self.rope_head_dim,
- tma_aligned_scales=self._tma_aligned_scales,
- )
-
- wo_a_fp8 = self.wo_a.weight
- wo_a_scale = self.wo_a.weight_scale_inv
-
- z = torch.empty(
- (num_tokens, self.n_local_groups, self.o_lora_rank),
- device=o.device,
- dtype=torch.bfloat16,
- )
- fp8_einsum(
- "bhr,hdr->bhd",
- (o_fp8, o_scale),
- (wo_a_fp8, wo_a_scale),
- z,
- recipe=self._einsum_recipe,
- )
-
- return self.wo_b(z.flatten(1))
+ # Inverse-RoPE + wo_a + wo_b output projection (platform-specific).
+ return self._o_proj(o, positions)
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
aux_streams = self.aux_stream_list
@@ -431,27 +427,19 @@ class DeepseekV4MLA(nn.Module):
def attention_impl(
self,
hidden_states: torch.Tensor,
+ qr: torch.Tensor,
+ kv: torch.Tensor,
+ kv_score: torch.Tensor,
+ indexer_kv_score: torch.Tensor,
+ indexer_weights: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place
) -> None:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
- qr_kv, kv_score, indexer_kv_score, indexer_weights = (
- self.attn_gemm_parallel_execute(hidden_states)
- )
-
- qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
- qr, kv = fused_q_kv_rmsnorm(
- qr,
- kv,
- self.q_norm.weight.data,
- self.kv_norm.weight.data,
- self.eps,
- )
-
# wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
- # on the default stream so q stays on its consumer stream (mla_attn
+ # on the default stream so q stays on its consumer stream (forward_mqa
# downstream reads q on default). Indexer/compressor go on aux for
# overlap with default's GEMM + cache write.
if self.indexer is not None:
@@ -514,7 +502,7 @@ class DeepseekV4MLA(nn.Module):
# MLA attention writes into the pre-allocated `out` buffer
# ([num_tokens, padded_heads, head_dim]).
- self.mla_attn(q, kv, positions, output=out)
+ self.forward_mqa(q, kv, positions, out)
def _fused_qnorm_rope_kv_insert(
self,
@@ -549,7 +537,7 @@ class DeepseekV4MLA(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache
cache_dtype = swa_kv_cache.dtype
- # kv is unchanged; mla_attn reads kv solely via swa_kv_cache.
+ # kv is unchanged; attention reads kv solely via swa_kv_cache.
if cache_dtype == torch.uint8:
# Legacy FlashMLA UE8M0 paged path. Horizontally fused:
# Q side: per-head RMSNorm (no weight) + GPT-J RoPE, zero-filling
@@ -569,9 +557,10 @@ class DeepseekV4MLA(nn.Module):
swa_metadata.block_size,
)
- # FlashInfer full-cache path: contiguous [num_blocks, block_size, 512]
- # cache (no Q padding). bf16 rewrites q in place; per-tensor fp8 writes a
- # separately-allocated fp8 q and quantizes the KV row.
+ # FlashInfer full-cache path: the [num_blocks, block_size, 512] cache
+ # stores the KV row in its plain dtype (no Q padding). bf16 rewrites q
+ # in place; per-tensor fp8 writes a separately-allocated fp8 q and
+ # quantizes the KV row.
block_size = swa_metadata.block_size
swa_kv_cache_3d = swa_kv_cache.view(-1, block_size, self.head_dim)
if cache_dtype == torch.bfloat16:
@@ -597,99 +586,13 @@ class DeepseekV4MLA(nn.Module):
swa_metadata.slot_mapping,
positions,
cos_sin_cache,
- self.mla_attn._flashinfer_fp8_kv_scale,
- self.mla_attn._flashinfer_fp8_q_scale_inv,
+ self._flashinfer_fp8_kv_scale,
+ self._flashinfer_fp8_q_scale_inv,
self.eps,
block_size,
)
return q_fp8
-
-class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
- def __init__(
- self,
- num_heads: int,
- head_dim: int,
- scale: float,
- qk_nope_head_dim: int,
- qk_rope_head_dim: int,
- q_lora_rank: int | None,
- kv_lora_rank: int,
- compress_ratio: int,
- window_size: int,
- head_bytes: int,
- swa_cache_layer: DeepseekV4SWACache,
- attn_sink: torch.Tensor,
- cache_config: CacheConfig | None = None,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
- # Sparse MLA Args
- indexer: object | None = None,
- topk_indices_buffer: torch.Tensor | None = None,
- aux_stream: torch.cuda.Stream | None = None,
- **extra_impl_args,
- ) -> None:
- super().__init__()
- vllm_config = get_current_vllm_config()
- self.impl_cls = _select_v4_sparse_impl(vllm_config)
- self.backend_cls = self.impl_cls.backend_cls
- self.num_heads = num_heads
- self.num_kv_heads = 1
- self.head_dim = head_dim
- self.scale = scale
- self.window_size = window_size
- self.head_bytes = head_bytes
- self.compress_ratio = compress_ratio
- self.q_lora_rank = q_lora_rank
- self.kv_lora_rank = kv_lora_rank
- self.nope_head_dim = qk_nope_head_dim
- self.rope_head_dim = qk_rope_head_dim
- self.indexer = indexer
- self.topk_indices_buffer = topk_indices_buffer
-
- self.prefix = prefix # Alias for compatibility with compressor
-
- self.aux_stream = aux_stream
- self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
-
- # Padded Q head count is dictated by the selected impl.
- self.padded_heads = self.impl_cls.get_padded_num_q_heads(num_heads)
-
- # Store attention sink
- assert attn_sink is not None
- self.attn_sink: torch.Tensor = attn_sink
- # Store SWA cache
- assert swa_cache_layer is not None
- self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer
-
- # Get vllm config for cache setup
- self.max_num_batched_tokens = (
- vllm_config.scheduler_config.max_num_batched_tokens
- )
- self.max_model_len = vllm_config.model_config.max_model_len
-
- # Resolve the kv-cache dtype from the selected backend. FlashMLA uses
- # the legacy UE8M0 paged uint8 (fp8_ds_mla) layout; FlashInfer uses a
- # contiguous bf16 / per-tensor fp8 row.
- backend = _resolve_dsv4_backend(vllm_config)
- kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
- self.kv_cache_dtype, self.kv_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
- backend, kv_cache_dtype, cache_config
- )
-
- # Per-impl layer buffers (e.g. FlashInfer FP8 scale buffers). No-op for
- # the FlashMLA / ROCm impls.
- self.impl_cls.init_layer_buffers(self)
-
- # Register with compilation context for metadata lookup
- compilation_config = vllm_config.compilation_config
- if prefix and prefix in compilation_config.static_forward_context:
- raise ValueError(f"Duplicate layer name: {prefix}")
- if prefix:
- compilation_config.static_forward_context[prefix] = self
-
- self.kv_cache = torch.tensor([])
-
def get_attn_backend(self) -> type[AttentionBackend]:
return self.backend_cls
@@ -698,8 +601,9 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
self.compress_ratio <= 1
): # SWA part. Allocated separately as DeepseekV4SWACache.
return None
- # FlashMLA uses the UE8M0 paged uint8 layout (576B aligned); FlashInfer
- # uses a contiguous bf16 / per-tensor fp8 cache with no extra alignment.
+ # FlashMLA uses the fp8_ds_mla block format (UE8M0 block-scaled fp8 as
+ # uint8, 576B aligned); FlashInfer stores a plain bf16 / per-tensor fp8
+ # row with no extra alignment.
is_flashmla = self.kv_cache_dtype == "fp8_ds_mla"
return MLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
@@ -712,15 +616,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
model_version="deepseek_v4",
)
- def forward(
- self,
- q: torch.Tensor,
- kv: torch.Tensor,
- positions: torch.Tensor,
- output: torch.Tensor,
- ) -> None:
- self.impl_cls.forward_mqa(self, q, kv, positions, output)
-
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
def __init__(
diff --git a/vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py b/vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py
index 71ea4fe506e..1e119614b0d 100644
--- a/vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py
+++ b/vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py
@@ -3,28 +3,30 @@
"""DeepSeek V4 FlashInfer TRTLLM-gen sparse MLA backend.
Uses FlashInfer's public ``trtllm_batch_decode_sparse_mla_dsv4`` launcher with a
-contiguous bf16 / per-tensor FP8 KV cache. Shares the V4 sparse-index pipeline
-(SWA cache + compressor + indexer, 256-token blocks, head_size 512) with the
-FlashMLA V4 backend; only the attention forward differs.
+plain bf16 / per-tensor FP8 KV row (vs FlashMLA's packed ``fp8_ds_mla`` block
+format). Shares the V4 sparse-index pipeline (SWA cache + compressor + indexer,
+256-token blocks, head_size 512) with the FlashMLA V4 backend; only the
+attention forward differs.
"""
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, ClassVar, cast
import torch
from vllm.forward_context import get_forward_context
+from vllm.models.deepseek_v4.attention import DeepseekV4Attention
from vllm.models.deepseek_v4.common.ops import (
build_flashinfer_mixed_sparse_indices,
)
-from vllm.models.deepseek_v4.nvidia.flashmla import (
- DeepseekV4FlashMLASparseBackend,
- DeepseekV4SparseMLAAttentionImpl,
+from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLASparseBackend
+from vllm.models.deepseek_v4.nvidia.ops.o_proj import (
+ compute_fp8_einsum_recipe,
+ deep_gemm_fp8_o_proj,
)
from vllm.utils.flashinfer import flashinfer_trtllm_batch_decode_sparse_mla_dsv4
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseMetadata
if TYPE_CHECKING:
- from vllm.models.deepseek_v4.attention import DeepseekV4MLAAttention
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata
# 128 MB TRTLLM-gen workspace, allocated once per device and zero-initialized
@@ -51,23 +53,21 @@ class DeepseekV4FlashInferMLASparseBackend(DeepseekV4FlashMLASparseBackend):
Inheriting from the FlashMLA V4 backend reuses its ``FlashMLASparseMetadata``
builder (which the V4 sparse-index pipeline needs — the V3.2 FlashInfer
builder lacks the ``c128a_*`` fields), 256-token blocks, head_size 512, and
- the contiguous (num_blocks, block_size, 512) cache shape for non-``fp8_ds_mla``
- dtypes.
+ the (num_blocks, block_size, 512) cache shape for non-``fp8_ds_mla`` dtypes.
"""
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA_SPARSE_DSV4"
- @staticmethod
- def get_impl_cls() -> type["DeepseekV4FlashInferMLASparseImpl"]:
- return DeepseekV4FlashInferMLASparseImpl
-
-class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
- """FlashInfer TRTLLM-gen sparse MLA implementation for DeepSeek V4."""
+class DeepseekV4FlashInferMLAAttention(DeepseekV4Attention):
+ """FlashInfer TRTLLM-gen sparse MLA attention layer for DeepSeek V4."""
backend_cls = DeepseekV4FlashInferMLASparseBackend
+ # FlashInfer stores a plain bf16 / per-tensor fp8 KV row, not the FlashMLA
+ # packed fp8_ds_mla block format (UE8M0 block-scaled fp8 as uint8).
+ use_flashmla_fp8_layout: ClassVar[bool] = False
@classmethod
def get_padded_num_q_heads(cls, num_heads: int) -> int:
@@ -79,27 +79,44 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
)
return 64 if num_heads <= 64 else 128
- @classmethod
- def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None:
+ def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ return deep_gemm_fp8_o_proj(
+ o,
+ positions,
+ self.rotary_emb.cos_sin_cache,
+ self.wo_a,
+ self.wo_b,
+ n_groups=self.n_local_groups,
+ heads_per_group=self.n_local_heads // self.n_local_groups,
+ nope_dim=self.nope_head_dim,
+ rope_dim=self.rope_head_dim,
+ o_lora_rank=self.o_lora_rank,
+ einsum_recipe=self._einsum_recipe,
+ tma_aligned_scales=self._tma_aligned_scales,
+ )
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self._einsum_recipe, self._tma_aligned_scales = compute_fp8_einsum_recipe()
# Per-tensor FP8 scale buffers + precomputed scalar BMM scales. Only the
- # per-tensor FP8 cache path consumes these; bf16 reads ``layer.scale``.
- if layer.kv_cache_torch_dtype != torch.float8_e4m3fn:
+ # per-tensor FP8 cache path consumes these; bf16 reads ``self.scale``.
+ if self.kv_cache_torch_dtype != torch.float8_e4m3fn:
return
# TODO: load real per-tensor Q/KV scales from the checkpoint; unit
# scales until the scale tensor names are wired.
fp8_q_scale = 1.0
fp8_kv_scale = 1.0
- layer.register_buffer(
+ self.register_buffer(
"_flashinfer_fp8_q_scale",
torch.tensor([fp8_q_scale], dtype=torch.float32),
persistent=False,
)
- layer.register_buffer(
+ self.register_buffer(
"_flashinfer_fp8_q_scale_inv",
torch.tensor([1.0 / fp8_q_scale], dtype=torch.float32),
persistent=False,
)
- layer.register_buffer(
+ self.register_buffer(
"_flashinfer_fp8_kv_scale",
torch.tensor([fp8_kv_scale], dtype=torch.float32),
persistent=False,
@@ -107,13 +124,11 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# TRTLLM-gen takes scalar scale args on a distinct (correct) C++ path
# vs 1-elem tensors, so these are Python floats. bmm1 folds the softmax
# scale and the Q/KV per-tensor scales; bmm2 is the KV scale.
- layer._flashinfer_fp8_bmm1_scale = layer.scale * fp8_q_scale * fp8_kv_scale
- layer._flashinfer_fp8_bmm2_scale = fp8_kv_scale
+ self._flashinfer_fp8_bmm1_scale = self.scale * fp8_q_scale * fp8_kv_scale
+ self._flashinfer_fp8_bmm2_scale = fp8_kv_scale
- @classmethod
- def forward_mqa( # type: ignore[override]
- cls,
- layer: "DeepseekV4MLAAttention",
+ def forward_mqa(
+ self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
@@ -147,21 +162,20 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
assert isinstance(attn_metadata, dict)
flashmla_metadata = cast(
- FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix)
+ FlashMLASparseMetadata | None, attn_metadata.get(self.prefix)
)
swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
- attn_metadata.get(layer.swa_cache_layer.prefix),
+ attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
- swa_only = layer.compress_ratio <= 1
+ swa_only = self.compress_ratio <= 1
# SWA-only layers don't allocate their own compressed KV cache.
- self_kv_cache = layer.kv_cache if not swa_only else None
- swa_kv_cache = layer.swa_cache_layer.kv_cache
+ self_kv_cache = self.kv_cache if not swa_only else None
+ swa_kv_cache = self.swa_cache_layer.kv_cache
- cls._forward(
- layer=layer,
+ self._forward(
q=q,
kv_cache=self_kv_cache,
swa_k_cache=swa_kv_cache,
@@ -171,10 +185,8 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
output=output,
)
- @classmethod
def _build_sparse_index_metadata(
- cls,
- layer: "DeepseekV4MLAAttention",
+ self,
kv_cache: torch.Tensor | None,
swa_k_cache: torch.Tensor,
swa_metadata: "DeepseekSparseSWAMetadata",
@@ -200,17 +212,17 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
assert swa_metadata.block_table is not None
decode_swa_indices = swa_metadata.decode_swa_indices.reshape(
- num_decode_tokens, layer.window_size
+ num_decode_tokens, self.window_size
)
decode_compressed_topk_lens = None
decode_compressed_indices_are_local = False
decode_is_valid_token = None
if swa_only:
- assert layer.topk_indices_buffer is not None
+ assert self.topk_indices_buffer is not None
compressed_kv_cache = swa_k_cache
decode_compressed_indices = None
- prefill_topk_indices = layer.topk_indices_buffer[
+ prefill_topk_indices = self.topk_indices_buffer[
num_decode_tokens:num_tokens, :0
]
compressed_block_table = None
@@ -221,24 +233,24 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
assert attn_metadata is not None
compressed_kv_cache = kv_cache
compressed_block_table = attn_metadata.block_table[:num_reqs]
- compressed_block_size = attn_metadata.block_size // layer.compress_ratio
+ compressed_block_size = attn_metadata.block_size // self.compress_ratio
- if layer.compress_ratio == 4:
- assert layer.topk_indices_buffer is not None
+ if self.compress_ratio == 4:
+ assert self.topk_indices_buffer is not None
if num_prefill_tokens > 0:
- prefill_topk_indices = layer.topk_indices_buffer[
+ prefill_topk_indices = self.topk_indices_buffer[
num_decode_tokens:num_tokens
]
top_k = prefill_topk_indices.shape[-1]
else:
- prefill_topk_indices = layer.topk_indices_buffer[:0, :0]
+ prefill_topk_indices = self.topk_indices_buffer[:0, :0]
top_k = 0
decode_compressed_indices_are_local = True
assert swa_metadata.is_valid_token is not None
decode_is_valid_token = swa_metadata.is_valid_token[:num_decode_tokens]
if num_decode_tokens > 0:
- decode_compressed_indices = layer.topk_indices_buffer[
+ decode_compressed_indices = self.topk_indices_buffer[
:num_decode_tokens
]
else:
@@ -284,18 +296,16 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_metadata.block_size,
compressed_block_table,
compressed_block_size,
- layer.window_size,
- layer.compress_ratio,
+ self.window_size,
+ self.compress_ratio,
top_k,
decode_compressed_indices_are_local=decode_compressed_indices_are_local,
decode_is_valid_token=decode_is_valid_token,
)
return compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens
- @classmethod
def _forward(
- cls,
- layer: "DeepseekV4MLAAttention",
+ self,
q: torch.Tensor,
kv_cache: torch.Tensor | None,
swa_k_cache: torch.Tensor,
@@ -304,7 +314,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_only: bool,
output: torch.Tensor,
) -> None:
- assert layer.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn)
+ assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn)
num_decodes = swa_metadata.num_decodes
num_prefills = swa_metadata.num_prefills
num_decode_tokens = swa_metadata.num_decode_tokens
@@ -319,8 +329,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
seq_lens,
sparse_indices,
sparse_topk_lens,
- ) = cls._build_sparse_index_metadata(
- layer=layer,
+ ) = self._build_sparse_index_metadata(
kv_cache=kv_cache,
swa_k_cache=swa_k_cache,
swa_metadata=swa_metadata,
@@ -332,12 +341,12 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# restrict to the real tokens (the launcher validates sparse indices).
query = q[:num_tokens]
output = output[:num_tokens]
- bmm1_scale: float | torch.Tensor = layer.scale
+ bmm1_scale: float | torch.Tensor = self.scale
bmm2_scale: float | torch.Tensor = 1.0
- if layer.kv_cache_torch_dtype == torch.float8_e4m3fn:
+ if self.kv_cache_torch_dtype == torch.float8_e4m3fn:
assert query.dtype == torch.float8_e4m3fn
- bmm1_scale = layer._flashinfer_fp8_bmm1_scale
- bmm2_scale = layer._flashinfer_fp8_bmm2_scale
+ bmm1_scale = self._flashinfer_fp8_bmm1_scale
+ bmm2_scale = self._flashinfer_fp8_bmm2_scale
else:
assert query.dtype == torch.bfloat16
query = query.contiguous()
@@ -376,7 +385,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
out=output[:num_decode_tokens],
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
- sinks=layer.attn_sink,
+ sinks=self.attn_sink,
cum_seq_lens_q=decode_cu,
max_q_len=int(decode_lens_cpu.max().item()),
)
@@ -401,7 +410,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
out=output[num_decode_tokens:num_tokens],
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
- sinks=layer.attn_sink,
+ sinks=self.attn_sink,
cum_seq_lens_q=prefill_cu,
max_q_len=int(prefill_lens_cpu.max().item()),
)
diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py
index e9b9c678306..5c3969deb80 100644
--- a/vllm/models/deepseek_v4/nvidia/flashmla.py
+++ b/vllm/models/deepseek_v4/nvidia/flashmla.py
@@ -1,22 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from abc import abstractmethod
-from typing import TYPE_CHECKING, ClassVar, cast
+from typing import TYPE_CHECKING, cast
import torch
from vllm.forward_context import get_forward_context
+from vllm.models.deepseek_v4.attention import DeepseekV4Attention
from vllm.models.deepseek_v4.common.ops import (
combine_topk_swa_indices,
compute_global_topk_indices_and_lens,
dequantize_and_gather_k_cache,
)
-from vllm.v1.attention.backend import (
- AttentionBackend,
- MultipleOf,
- SparseMLAAttentionImpl,
+from vllm.models.deepseek_v4.nvidia.ops.o_proj import (
+ compute_fp8_einsum_recipe,
+ deep_gemm_fp8_o_proj,
)
+from vllm.v1.attention.backend import MultipleOf
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend,
FlashMLASparseMetadata,
@@ -28,63 +28,9 @@ from vllm.v1.attention.ops.flashmla import (
from vllm.v1.worker.workspace import current_workspace_manager
if TYPE_CHECKING:
- from vllm.models.deepseek_v4.attention import (
- DeepseekV4MLAAttention,
- )
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata
-class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
- """Abstract parent for DeepseekV4 sparse MLA impls.
-
- V4 sparse MLA is driven by the layer (``DeepseekV4MLAAttention.forward``)
- rather than the v1 framework, so ``forward_mqa`` is overridden with a
- classmethod that takes the layer as its first argument. This Liskov-broken
- override is intentional: the grandparent's instance-method ``forward_mqa``
- is never called on V4 layers.
- """
-
- backend_cls: ClassVar[type[AttentionBackend]]
-
- # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
- # workspace allocated in _forward_prefill and is also read by the V4 layer's
- # dummy-run path to pre-reserve that workspace.
- PREFILL_CHUNK_SIZE: ClassVar[int] = 4
-
- @classmethod
- @abstractmethod
- def forward_mqa( # type: ignore[override]
- cls,
- layer: "DeepseekV4MLAAttention",
- q: torch.Tensor,
- kv: torch.Tensor,
- positions: torch.Tensor,
- output: torch.Tensor,
- ) -> None:
- raise NotImplementedError
-
- @classmethod
- @abstractmethod
- def get_padded_num_q_heads(cls, num_heads: int) -> int:
- """Q head count the backend wants q allocated at.
-
- The MLA wrapper allocates the q/output buffers at
- ``[N, get_padded_num_q_heads(n_local_heads), head_dim]``. Must
- satisfy ``result >= num_heads``. Backends with no padding constraint
- return ``num_heads``.
- """
- raise NotImplementedError
-
- @classmethod
- def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None:
- """Register impl-specific buffers on the layer at construction.
-
- No-op by default; FlashInfer overrides this to register its per-tensor
- FP8 scale buffers.
- """
- return None
-
-
class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
@@ -94,10 +40,6 @@ class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
def get_name() -> str:
return "FLASHMLA_SPARSE_DSV4"
- @staticmethod
- def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]:
- return DeepseekV4FlashMLASparseImpl
-
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
# DeepSeek V4 layout: 448 NoPE + 64 RoPE = 512 (overrides the
@@ -120,11 +62,31 @@ class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend):
return (num_blocks, block_size, head_size)
-class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
- """FlashMLA sparse MLA implementation for DeepSeek V4's custom MLA layer."""
+class DeepseekV4FlashMLAAttention(DeepseekV4Attention):
+ """FlashMLA sparse MLA attention layer for DeepSeek V4 (CUDA)."""
backend_cls = DeepseekV4FlashMLASparseBackend
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self._einsum_recipe, self._tma_aligned_scales = compute_fp8_einsum_recipe()
+
+ def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ return deep_gemm_fp8_o_proj(
+ o,
+ positions,
+ self.rotary_emb.cos_sin_cache,
+ self.wo_a,
+ self.wo_b,
+ n_groups=self.n_local_groups,
+ heads_per_group=self.n_local_heads // self.n_local_groups,
+ nope_dim=self.nope_head_dim,
+ rope_dim=self.rope_head_dim,
+ o_lora_rank=self.o_lora_rank,
+ einsum_recipe=self._einsum_recipe,
+ tma_aligned_scales=self._tma_aligned_scales,
+ )
+
@classmethod
def get_padded_num_q_heads(cls, num_heads: int) -> int:
# FP8 decode kernel only supports h_q = 64 or 128.
@@ -135,10 +97,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
)
return 64 if num_heads <= 64 else 128
- @classmethod
- def forward_mqa( # type: ignore[override]
- cls,
- layer: "DeepseekV4MLAAttention",
+ def forward_mqa(
+ self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
@@ -159,35 +119,35 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# Warmup dummy run: no real metadata. Reserve the same bf16
# gather workspace _forward_prefill would; the dequantize / topk
# / sparse_fwd kernels are skipped this step.
- swa_only = layer.compress_ratio <= 1
+ swa_only = self.compress_ratio <= 1
N = (
0
if swa_only
- else (layer.max_model_len + layer.compress_ratio - 1)
- // layer.compress_ratio
+ else (self.max_model_len + self.compress_ratio - 1)
+ // self.compress_ratio
)
- M = N + layer.window_size + layer.max_num_batched_tokens
+ M = N + self.window_size + self.max_num_batched_tokens
current_workspace_manager().get_simultaneous(
- ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
+ ((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)
output.zero_()
return
assert isinstance(attn_metadata, dict)
flashmla_metadata = cast(
- FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix)
+ FlashMLASparseMetadata | None, attn_metadata.get(self.prefix)
)
swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
- attn_metadata.get(layer.swa_cache_layer.prefix),
+ attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
- swa_only = layer.compress_ratio <= 1
+ swa_only = self.compress_ratio <= 1
# SWA-only layers (compress_ratio <= 1) don't have their own KV cache
- # allocation, so layer.kv_cache may be empty after profiling cleanup.
- self_kv_cache = layer.kv_cache if not swa_only else None
- swa_kv_cache = layer.swa_cache_layer.kv_cache
+ # allocation, so self.kv_cache may be empty after profiling cleanup.
+ self_kv_cache = self.kv_cache if not swa_only else None
+ swa_kv_cache = self.swa_cache_layer.kv_cache
# Split prefill and decode
num_decodes = swa_metadata.num_decodes
@@ -195,8 +155,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
num_decode_tokens = swa_metadata.num_decode_tokens
if num_prefills > 0:
- cls._forward_prefill(
- layer=layer,
+ self._forward_prefill(
q=q[num_decode_tokens:],
positions=positions[num_decode_tokens:],
compressed_k_cache=self_kv_cache,
@@ -206,8 +165,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
swa_metadata=swa_metadata,
)
if num_decodes > 0:
- cls._forward_decode(
- layer=layer,
+ self._forward_decode(
q=q[:num_decode_tokens],
kv_cache=self_kv_cache,
swa_metadata=swa_metadata,
@@ -216,10 +174,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
output=output[:num_decode_tokens],
)
- @classmethod
def _forward_decode(
- cls,
- layer: "DeepseekV4MLAAttention",
+ self,
q: torch.Tensor,
kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1
swa_metadata: "DeepseekSparseSWAMetadata",
@@ -235,13 +191,13 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
if not swa_only:
assert attn_metadata is not None
assert swa_metadata.is_valid_token is not None
- block_size = attn_metadata.block_size // layer.compress_ratio
+ block_size = attn_metadata.block_size // self.compress_ratio
is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
- if layer.compress_ratio == 4:
+ if self.compress_ratio == 4:
# C4A: local indices differ per layer (filled by Indexer).
- assert layer.topk_indices_buffer is not None
+ assert self.topk_indices_buffer is not None
global_indices, topk_lens = compute_global_topk_indices_and_lens(
- layer.topk_indices_buffer[:num_decode_tokens],
+ self.topk_indices_buffer[:num_decode_tokens],
swa_metadata.token_to_req_indices,
attn_metadata.block_table[:num_decodes],
block_size,
@@ -258,12 +214,12 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# We treat queries in the same seq as different queries
# and later we only attend by generated indices.
- # q arrives pre-padded to layer.padded_heads by the outer wrapper.
+ # q arrives pre-padded to self.padded_heads by the outer wrapper.
q = q.unsqueeze(1)
# Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes)
# Use unsqueeze to preserve strides (handles padded blocks correctly)
- swa_cache = layer.swa_cache_layer.kv_cache.unsqueeze(-2)
+ swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2)
# Reshape KV cache to (num_blocks, block_size, 1, head_bytes)
if kv_cache is not None:
kv_cache = kv_cache.unsqueeze(-2)
@@ -274,20 +230,20 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# and num_splits via PyTorch's graph-aware allocator so CUDA graph
# capture reuses the same addresses on replay); subsequent same-type
# layers see have_initialized=True and skip the planner.
- if layer.compress_ratio <= 1:
+ if self.compress_ratio <= 1:
tile_metadata = swa_metadata.tile_sched_swaonly
- elif layer.compress_ratio == 4:
+ elif self.compress_ratio == 4:
tile_metadata = swa_metadata.tile_sched_c4a
- elif layer.compress_ratio == 128:
+ elif self.compress_ratio == 128:
tile_metadata = swa_metadata.tile_sched_c128a
else:
raise ValueError(
- f"Unsupported compress_ratio={layer.compress_ratio}; "
+ f"Unsupported compress_ratio={self.compress_ratio}; "
"expected 1, 4, or 128."
)
assert tile_metadata is not None, (
"swa_metadata missing tile_sched entry for "
- f"compress_ratio={layer.compress_ratio}; "
+ f"compress_ratio={self.compress_ratio}; "
"DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not "
"allocate one for this layer type."
)
@@ -302,18 +258,16 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
is_fp8_kvcache=True,
indices=swa_indices,
topk_length=swa_lens,
- softmax_scale=layer.scale,
- attn_sink=layer.attn_sink,
+ softmax_scale=self.scale,
+ attn_sink=self.attn_sink,
extra_k_cache=kv_cache if not swa_only else None,
extra_indices_in_kvcache=topk_indices,
extra_topk_length=topk_lens,
out=output.unsqueeze(1),
)
- @classmethod
def _forward_prefill(
- cls,
- layer: "DeepseekV4MLAAttention",
+ self,
q: torch.Tensor,
positions: torch.Tensor,
compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1
@@ -343,9 +297,9 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
prefill_token_base = query_start_loc_cpu[num_decodes]
if not swa_only:
- if layer.compress_ratio == 4:
- assert layer.topk_indices_buffer is not None
- topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
+ if self.compress_ratio == 4:
+ assert self.topk_indices_buffer is not None
+ topk_indices = self.topk_indices_buffer[num_decode_tokens:]
topk_indices = topk_indices[:num_prefill_tokens]
else:
# C128A: pre-computed during metadata build.
@@ -355,16 +309,16 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
# Compressed region must fit the full compressed pool (seq_len //
# compress_ratio), not just top_k. top_k bounds how many indices
# the indexer selects, not the pool size it indexes into.
- N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio
+ N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
else:
# NOTE(woosuk): topk_indices will not be used for SWA-only layers.
- assert layer.topk_indices_buffer is not None
- topk_indices = layer.topk_indices_buffer[num_decode_tokens:]
+ assert self.topk_indices_buffer is not None
+ topk_indices = self.topk_indices_buffer[num_decode_tokens:]
top_k = 0
N = 0
- M = N + layer.window_size + layer.max_num_batched_tokens
- chunk_size_const = cls.PREFILL_CHUNK_SIZE
+ M = N + self.window_size + self.max_num_batched_tokens
+ chunk_size_const = self.PREFILL_CHUNK_SIZE
num_chunks = (num_prefills + chunk_size_const - 1) // chunk_size_const
workspace_manager = current_workspace_manager()
@@ -382,10 +336,10 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
dequantize_and_gather_k_cache(
kv[:chunk_size],
compressed_k_cache,
- seq_lens=seq_lens[chunk_start:chunk_end] // layer.compress_ratio,
+ seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
gather_lens=None,
block_table=block_table[chunk_start:chunk_end],
- block_size=attn_metadata.block_size // layer.compress_ratio,
+ block_size=attn_metadata.block_size // self.compress_ratio,
offset=0,
)
@@ -416,8 +370,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
],
seq_lens[chunk_start:chunk_end],
gather_lens[chunk_start:chunk_end],
- layer.window_size,
- layer.compress_ratio,
+ self.window_size,
+ self.compress_ratio,
top_k,
M,
N,
@@ -426,8 +380,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
- sm_scale=layer.scale,
- attn_sink=layer.attn_sink,
+ sm_scale=self.scale,
+ attn_sink=self.attn_sink,
topk_length=combined_lens,
out=output[query_start:query_end],
)
diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py
index 547048ab58f..00866e6941b 100644
--- a/vllm/models/deepseek_v4/nvidia/model.py
+++ b/vllm/models/deepseek_v4/nvidia/model.py
@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
- ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
@@ -55,13 +54,14 @@ from vllm.model_executor.models.utils import (
maybe_prefix,
)
from vllm.model_executor.utils import set_weight_attrs
-from vllm.models.deepseek_v4.attention import (
- DeepseekV4Indexer,
- DeepseekV4MLA,
+from vllm.models.deepseek_v4.attention import DeepseekV4Attention
+from vllm.models.deepseek_v4.nvidia.flashinfer_sparse import (
+ DeepseekV4FlashInferMLAAttention,
)
-from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
+from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLAAttention
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backends.registry import AttentionBackendEnum
class DeepseekV4MLP(nn.Module):
@@ -713,163 +713,18 @@ class DeepseekV4MoE(nn.Module):
self.experts.finalize_weights()
-class DeepseekV4Attention(nn.Module):
- def __init__(
- self,
- vllm_config: VllmConfig,
- prefix: str,
- topk_indices_buffer: torch.Tensor | None = None,
- aux_stream_list: list[torch.cuda.Stream] | None = None,
+def _select_dsv4_attn_cls(vllm_config: VllmConfig) -> type[DeepseekV4Attention]:
+ """Pick the CUDA sparse-MLA attention class for the configured backend.
+
+ An explicit ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` selects the
+ FlashInfer TRTLLM-gen path; otherwise the FlashMLA path is used.
+ """
+ if (
+ vllm_config.attention_config.backend
+ == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4
):
- super().__init__()
- config = vllm_config.model_config.hf_config
- quant_config = vllm_config.quant_config
- layer_id = extract_layer_index(prefix)
-
- self.layer_id = layer_id
- self.hidden_size = config.hidden_size
- self.n_heads = config.num_attention_heads
- tp_size = get_tensor_model_parallel_world_size()
- assert self.n_heads % tp_size == 0
-
- self.n_local_heads = self.n_heads // tp_size
- self.q_lora_rank = config.q_lora_rank
- self.o_lora_rank = config.o_lora_rank
- self.head_dim = config.head_dim
- self.rope_head_dim = config.qk_rope_head_dim
- self.nope_head_dim = self.head_dim - self.rope_head_dim
- self.n_groups = config.o_groups
- self.n_local_groups = self.n_groups // tp_size
- self.window_size = config.sliding_window
- # NOTE(zyongye) Compress ratio can't be 0
- # we do this for because MTP layer is not included
- # in the compress ratio list
- if layer_id < config.num_hidden_layers:
- self.compress_ratio = max(1, config.compress_ratios[layer_id])
- else:
- self.compress_ratio = 1
- self.eps = config.rms_norm_eps
- self.max_position_embeddings = config.max_position_embeddings
-
- # Padded to min 64 heads for FlashMLA, initialized to -inf
- # (no sink effect). Weight loading fills the first n_local_heads slots.
- padded_heads = max(self.n_local_heads, 64)
- self.attn_sink = nn.Parameter(
- torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
- requires_grad=False,
- )
-
- self.fused_wqa_wkv = MergedColumnParallelLinear(
- self.hidden_size,
- [self.q_lora_rank, self.head_dim],
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.fused_wqa_wkv",
- disable_tp=True, # fused ReplicatedLinear
- )
- self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
- self.wq_b = ColumnParallelLinear(
- self.q_lora_rank,
- self.n_heads * self.head_dim,
- bias=False,
- quant_config=quant_config,
- return_bias=False,
- prefix=f"{prefix}.wq_b",
- )
-
- self.kv_norm = RMSNorm(self.head_dim, self.eps)
- self.wo_a = ColumnParallelLinear(
- self.n_heads * self.head_dim // self.n_groups,
- self.n_groups * self.o_lora_rank,
- bias=False,
- quant_config=quant_config,
- return_bias=False,
- prefix=f"{prefix}.wo_a",
- )
- self.wo_a.is_bmm = True
- self.wo_a.bmm_batch_size = self.n_local_groups
- self.wo_b = RowParallelLinear(
- self.n_groups * self.o_lora_rank,
- self.hidden_size,
- bias=False,
- quant_config=quant_config,
- return_bias=False,
- prefix=f"{prefix}.wo_b",
- )
- self.softmax_scale = self.head_dim**-0.5
- self.scale_fmt = config.quantization_config["scale_fmt"]
-
- self.rope_parameters = config.rope_scaling
-
- # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
- self.rotary_emb = build_deepseek_v4_rope(
- config,
- head_dim=self.head_dim,
- rope_head_dim=self.rope_head_dim,
- max_position_embeddings=self.max_position_embeddings,
- compress_ratio=self.compress_ratio,
- )
-
- self.indexer = None
- if self.compress_ratio == 4:
- # Only C4A uses sparse attention and hence has indexer.
- # aux_stream_list[0] runs indexer.forward() in the wrapper; [2] is
- # free here (outer GEMMs joined) for the inner overlap of
- # wq_b+fused_indexer_q_rope_quant vs compressor.
- indexer_aux_stream = (
- aux_stream_list[2] if aux_stream_list is not None else None
- )
- self.indexer = DeepseekV4Indexer(
- vllm_config,
- config=config,
- hidden_size=self.hidden_size,
- q_lora_rank=self.q_lora_rank,
- quant_config=quant_config,
- cache_config=vllm_config.cache_config,
- topk_indices_buffer=topk_indices_buffer,
- compress_ratio=self.compress_ratio,
- prefix=f"{prefix}.indexer",
- aux_stream=indexer_aux_stream,
- )
-
- self.mla_attn = DeepseekV4MLA(
- hidden_size=self.hidden_size,
- num_heads=self.n_local_heads,
- head_dim=self.head_dim,
- scale=self.softmax_scale,
- qk_nope_head_dim=self.nope_head_dim,
- qk_rope_head_dim=self.rope_head_dim,
- v_head_dim=self.head_dim,
- q_lora_rank=self.q_lora_rank,
- kv_lora_rank=self.head_dim,
- o_lora_rank=self.o_lora_rank,
- vllm_config=vllm_config,
- fused_wqa_wkv=self.fused_wqa_wkv,
- q_norm=self.q_norm,
- wq_b=self.wq_b,
- kv_norm=self.kv_norm,
- wo_a=self.wo_a,
- wo_b=self.wo_b,
- attn_sink=self.attn_sink,
- rotary_emb=self.rotary_emb,
- indexer=self.indexer,
- indexer_rotary_emb=self.rotary_emb,
- topk_indices_buffer=topk_indices_buffer,
- aux_stream_list=aux_stream_list,
- window_size=self.window_size,
- compress_ratio=self.compress_ratio,
- cache_config=vllm_config.cache_config,
- quant_config=quant_config,
- prefix=prefix,
- )
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- llama_4_scaling: torch.Tensor | None,
- ):
- return self.mla_attn(positions, hidden_states, llama_4_scaling)
+ return DeepseekV4FlashInferMLAAttention
+ return DeepseekV4FlashMLAAttention
class DeepseekV4DecoderLayer(nn.Module):
@@ -886,7 +741,7 @@ class DeepseekV4DecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
self.rms_norm_eps = config.rms_norm_eps
- self.attn = DeepseekV4Attention(
+ self.attn = _select_dsv4_attn_cls(vllm_config)(
vllm_config,
prefix=f"{prefix}.attn",
topk_indices_buffer=topk_indices_buffer,
@@ -1043,7 +898,7 @@ class DeepseekV4Model(nn.Module):
self.rms_norm_eps = config.rms_norm_eps
# Three aux streams: one per non-default input GEMM in
- # DeepseekV4MLA.attn_gemm_parallel_execute
+ # DeepseekV4Attention.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
@@ -1341,7 +1196,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
},
orig_to_new_substr={
- ".attn.compressor.": ".attn.mla_attn.compressor.",
".shared_experts.w2": ".shared_experts.down_proj",
},
)
diff --git a/vllm/models/deepseek_v4/nvidia/ops/o_proj.py b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py
new file mode 100644
index 00000000000..a0b4e2c678e
--- /dev/null
+++ b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py
@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+import torch.nn as nn
+
+from vllm.models.deepseek_v4.common.ops import fused_inv_rope_fp8_quant
+from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import fp8_einsum
+
+
+def compute_fp8_einsum_recipe() -> tuple[tuple[int, int, int], bool]:
+ """fp8_einsum recipe + scale layout for the current GPU arch.
+
+ SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128.
+ SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1.
+
+ Returns ``(einsum_recipe, tma_aligned_scales)`` for ``deep_gemm_fp8_o_proj``.
+ """
+ cap = current_platform.get_device_capability()
+ assert cap is not None, "DeepseekV4 attention requires a CUDA device"
+ einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
+ tma_aligned_scales = cap.major >= 10
+ return einsum_recipe, tma_aligned_scales
+
+
+def deep_gemm_fp8_o_proj(
+ o: torch.Tensor,
+ positions: torch.Tensor,
+ cos_sin_cache: torch.Tensor,
+ wo_a: nn.Module,
+ wo_b: nn.Module,
+ *,
+ n_groups: int,
+ heads_per_group: int,
+ nope_dim: int,
+ rope_dim: int,
+ o_lora_rank: int,
+ einsum_recipe: tuple[int, int, int],
+ tma_aligned_scales: bool,
+) -> torch.Tensor:
+ """O projection: inverse RoPE + FP8 quant + einsum + wo_b.
+
+ Shared by the FlashMLA and FlashInfer CUDA backends. ``einsum_recipe`` /
+ ``tma_aligned_scales`` come from ``compute_fp8_einsum_recipe``.
+ """
+ o_fp8, o_scale = fused_inv_rope_fp8_quant(
+ o,
+ positions,
+ cos_sin_cache,
+ n_groups=n_groups,
+ heads_per_group=heads_per_group,
+ nope_dim=nope_dim,
+ rope_dim=rope_dim,
+ tma_aligned_scales=tma_aligned_scales,
+ )
+ z = torch.empty(
+ (o.shape[0], n_groups, o_lora_rank),
+ device=o.device,
+ dtype=torch.bfloat16,
+ )
+ fp8_einsum(
+ "bhr,hdr->bhd",
+ (o_fp8, o_scale),
+ (wo_a.weight, wo_a.weight_scale_inv),
+ z,
+ recipe=einsum_recipe,
+ )
+ return wo_b(z.flatten(1))
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index cf4319ac722..b1414665869 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -92,6 +92,11 @@ class CpuPlatform(Platform):
return meminfo.total_memory
+ @classmethod
+ def mem_get_info(cls) -> tuple[int, int]:
+ meminfo = get_memory_node_info()
+ return meminfo.available_memory, meminfo.total_memory
+
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 89471e844d8..7f6d8794c28 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -134,6 +134,7 @@ def _sync_hip_cuda_env_vars():
# Sync at import time - catches misconfigurations from process start.
_sync_hip_cuda_env_vars()
+
# AMDSMI utils
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
@@ -312,6 +313,17 @@ def on_gfx950() -> bool:
return _ON_GFX950
+# Enable HIP online tuning early, before hipBLASLt initializes.
+# Turn on hipBLASLt online tuning if use AITER hipBLASLt GEMM.
+if (
+ envs.VLLM_ROCM_USE_AITER
+ and envs.VLLM_ROCM_USE_AITER_LINEAR
+ and envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
+ and on_mi3xx()
+):
+ os.environ["HIP_ONLINE_TUNING"] = "1"
+
+
@cache
def use_rocm_custom_paged_attention(
qtype: torch.dtype,
diff --git a/vllm/reasoning/cohere_command_reasoning_parser.py b/vllm/reasoning/cohere_command_reasoning_parser.py
index b28a59089e7..949c9ff5d99 100644
--- a/vllm/reasoning/cohere_command_reasoning_parser.py
+++ b/vllm/reasoning/cohere_command_reasoning_parser.py
@@ -90,7 +90,7 @@ MODEL_TO_TAG_STYLE: dict[str, CohereTagStyle] = {
tools=COMMAND_A_TOOLS_TAG,
),
"Cohere2MoeForCausalLM": CohereTagStyle(
- json_tags=(COMMAND_A_JSON_TAG,),
+ json_tags=(COMMAND_A_JSON_TAG, COMMAND_A_PLUS_JSON_TAG),
tools=COMMAND_A_TOOLS_TAG,
),
}
diff --git a/vllm/transformers_utils/processors/minicpmv.py b/vllm/transformers_utils/processors/minicpmv.py
index cc0dee8dacd..03649234eab 100644
--- a/vllm/transformers_utils/processors/minicpmv.py
+++ b/vllm/transformers_utils/processors/minicpmv.py
@@ -72,8 +72,8 @@ class MiniCPMVProcessor(ProcessorMixin):
) -> MiniCPMVBatchFeature:
"""Run the vendored MiniCPMV processor on a (text, images) pair.
- Only single-sample input is currently supported; batched input is
- coming soon. ``images`` is forwarded to the underlying image
+ Batched inputs are supported following the upstream MiniCPM-V
+ processor flow. ``images`` is forwarded to the underlying image
processor and ``text`` is tokenized with image placeholders
replaced by the appropriate slice tokens. Returns a
``MiniCPMVBatchFeature`` with at minimum ``input_ids`` and (when
@@ -194,7 +194,7 @@ class MiniCPMVProcessor(ProcessorMixin):
image_end_tokens.unsqueeze(-1),
]
)
- return input_ids.unsqueeze(0), image_bounds
+ return input_ids, image_bounds
def _convert_images_texts_to_inputs(
self,
@@ -220,23 +220,41 @@ class MiniCPMVProcessor(ProcessorMixin):
image_sizes = images["image_sizes"]
tgt_sizes = images["tgt_sizes"]
- image_tags = regex.findall(pattern, texts)
- assert len(image_tags) == len(image_sizes[0])
- text_chunks = texts.split(pattern)
- final_texts = ""
- for i in range(len(image_tags)):
- placeholder = self.image_processor.get_slice_image_placeholder(
- image_sizes[0][i]
- )
- final_texts = final_texts + text_chunks[i] + placeholder
- final_texts += text_chunks[-1]
- input_ids, image_bounds = self._convert(final_texts, max_length)
+ if isinstance(texts, str):
+ texts = [texts]
+
+ input_ids_list = []
+ image_bounds_list = []
+
+ for index, text in enumerate(texts):
+ image_tags = regex.findall(pattern, text)
+ assert len(image_tags) == len(image_sizes[index])
+ text_chunks = text.split(pattern)
+ final_text = ""
+ for i in range(len(image_tags)):
+ placeholder = self.image_processor.get_slice_image_placeholder(
+ image_sizes[index][i]
+ )
+ final_text = final_text + text_chunks[i] + placeholder
+ final_text += text_chunks[-1]
+ input_ids, image_bounds = self._convert(final_text, max_length)
+ input_ids_list.append(input_ids)
+ image_bounds_list.append(image_bounds)
+
+ padded_input_ids, padding_lengths = self.pad(
+ input_ids_list,
+ padding_side="left",
+ )
+ for i, length in enumerate(padding_lengths):
+ image_bounds_list[i] = image_bounds_list[i] + length
+
return MiniCPMVBatchFeature(
data={
- "input_ids": input_ids,
+ "input_ids": padded_input_ids,
+ "attention_mask": padded_input_ids.ne(0),
"pixel_values": images_val,
"image_sizes": image_sizes,
- "image_bound": [image_bounds],
+ "image_bound": image_bounds_list,
"tgt_sizes": tgt_sizes,
}
)
@@ -249,42 +267,36 @@ class MiniCPMVProcessor(ProcessorMixin):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
- def pad(
- self,
- orig_items,
- key,
- max_length=None,
- padding_value=0,
- padding_side="left",
- ):
- if not orig_items:
- return torch.empty(0)
+ # Copied from openbmb/MiniCPM-V-4_5 processing_minicpmv.py.
+ def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
+ if not inputs:
+ return torch.empty(0), []
items = []
- if isinstance(orig_items[0][key], list):
- assert isinstance(orig_items[0][key][0], torch.Tensor)
- for it in orig_items:
- for tr in it[key]:
- items.append({key: tr})
+ if isinstance(inputs[0], list):
+ assert isinstance(inputs[0][0], torch.Tensor)
+ for it in inputs:
+ for tr in it:
+ items.append(tr)
else:
- assert isinstance(orig_items[0][key], torch.Tensor)
- items = orig_items
+ assert isinstance(inputs[0], torch.Tensor)
+ items = inputs
batch_size = len(items)
- shape = items[0][key].shape
+ shape = items[0].shape
dim = len(shape)
- assert dim <= 3
+ assert dim <= 2
if max_length is None:
max_length = 0
- max_length = max(max_length, max(item[key].shape[-1] for item in items))
- min_length = min(item[key].shape[-1] for item in items)
- dtype = items[0][key].dtype
+ max_length = max(max_length, max(item.shape[-1] for item in items))
+ min_length = min(item.shape[-1] for item in items)
+ dtype = items[0].dtype
- if dim == 1:
- return torch.cat([item[key] for item in items], dim=0)
- elif dim == 2:
+ if dim == 0:
+ return torch.stack([item for item in items], dim=0), [0]
+ elif dim == 1:
if max_length == min_length:
- return torch.cat([item[key] for item in items], dim=0)
+ return torch.stack([item for item in items], dim=0), [0] * batch_size
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
else:
tensor = (
@@ -292,23 +304,18 @@ class MiniCPMVProcessor(ProcessorMixin):
+ padding_value
)
+ padding_lengths = []
for i, item in enumerate(items):
- tensor_to_pad = item[key]
- if tensor_to_pad.shape[0] != 1:
- raise ValueError(
- f"Expected leading batch size of 1 for padding, "
- f"but got shape {tensor_to_pad.shape}"
- )
- squeezed = tensor_to_pad.squeeze(0)
- if dim == 2:
+ if dim == 1:
if padding_side == "left":
- tensor[i, -squeezed.shape[0] :] = squeezed.clone()
+ tensor[i, -len(item) :] = item.clone()
else:
- tensor[i, : squeezed.shape[0]] = squeezed.clone()
- elif dim == 3:
+ tensor[i, : len(item)] = item.clone()
+ elif dim == 2:
if padding_side == "left":
- tensor[i, -squeezed.shape[0] :, :] = squeezed.clone()
+ tensor[i, -len(item) :, :] = item.clone()
else:
- tensor[i, : squeezed.shape[0], :] = squeezed.clone()
+ tensor[i, : len(item), :] = item.clone()
+ padding_lengths.append(tensor.shape[-1] - len(item))
- return tensor
+ return tensor, padding_lengths
diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py
index 04def3e3769..cd215421a98 100644
--- a/vllm/transformers_utils/utils.py
+++ b/vllm/transformers_utils/utils.py
@@ -84,8 +84,11 @@ def maybe_model_redirect(model: str) -> str:
"""
Use model_redirect to redirect the model name to a local folder.
- :param model: hf model name
- :return: maybe redirect to a local folder
+ Args:
+ model: hf model name
+
+ Returns:
+ maybe redirect to a local folder
"""
model_redirect_path = envs.VLLM_MODEL_REDIRECT_PATH
diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py
index af58bfd31a5..4b4a4435b31 100644
--- a/vllm/v1/attention/backend.py
+++ b/vllm/v1/attention/backend.py
@@ -808,14 +808,17 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]):
) -> torch.Tensor:
raise NotImplementedError
- def fused_output_quant_supported(self, quant_key: "QuantKey"):
+ def fused_output_quant_supported(self, quant_key: "QuantKey") -> bool:
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.
- :param quant_key: QuantKey object that describes the quantization op
- :return: is fusion supported for this type of quantization
+ Args:
+ quant_key: QuantKey object that describes the quantization op
+
+ Returns:
+ is fusion supported for this type of quantization
"""
return False
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 9140a6fccd5..e3173949bf1 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -602,7 +602,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
fp8_use_mixed_batch = (
self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4
)
- # DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not
+ # DeepseekV4 has its own attention impl (DeepseekV4Attention) that does not
# consume fp8_extra_metadata. Skipping the build here avoids a
# forced D2H sync on seq_lens that would otherwise fire on every
# prefill-bearing step, lifting GPU utilization on long-prefill
diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
index 332350d8380..12fd3a17421 100644
--- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -8,6 +8,7 @@ from importlib.util import find_spec
import torch
import torch.nn.functional as F
+import vllm.envs as envs
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
@@ -15,6 +16,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import LayerNameType
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
+from vllm.v1.worker.workspace import current_workspace_manager
if current_platform.is_rocm():
from vllm.platforms.rocm import _ON_GFX942, _ON_GFX950
@@ -408,8 +410,8 @@ def rocm_fp8_paged_mqa_logits(
aiter_paged_mqa_logits_module = None
# if rocm_aiter_ops.is_enabled():
- batch_size, next_n, heads, head_dim = q_fp8.shape
- num_blocks, block_size, _, _ = kv_cache_fp8.shape
+ batch_size, next_n = q_fp8.shape[:2]
+ block_size = kv_cache_fp8.shape[1]
if rocm_aiter_ops.is_enabled():
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
@@ -420,12 +422,10 @@ def rocm_fp8_paged_mqa_logits(
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits
)
batch_size, next_n, heads, _ = q_fp8.shape
- out_logits = torch.full(
- [batch_size * next_n, max_model_len],
- float("-inf"),
- device="cuda",
- dtype=torch.float32,
+ (out_logits,) = current_workspace_manager().get_simultaneous(
+ ((batch_size * next_n, max_model_len), torch.float32),
)
+ out_logits.fill_(float("-inf"))
deepgemm_fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
@@ -444,12 +444,10 @@ def rocm_fp8_paged_mqa_logits(
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
)
batch_size, next_n, heads, _ = q_fp8.shape
- out_qk = torch.full(
- (heads, batch_size * next_n, max_model_len),
- float("-inf"),
- device="cuda",
- dtype=torch.float32,
+ (out_qk,) = current_workspace_manager().get_simultaneous(
+ ((heads, batch_size * next_n, max_model_len), torch.float32),
)
+ out_qk.fill_(float("-inf"))
deepgemm_fp8_paged_mqa_logits_stage1(
q_fp8,
kv_cache_fp8,
@@ -647,6 +645,43 @@ def rocm_aiter_sparse_attn_indexer(
k_cache_prefix = _resolve_layer_name(k_cache_prefix)
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
+ # Profiling early-exit: reserve memory to account for runtime
+ # allocations. Must be in the real impl, not the fake impl —
+ # torch.compile calls the fake impl under FakeTensor mode where
+ # workspace manager operations on the locked real workspace
+ # would corrupt PyTorch's dispatch state.
+ workspace_manager = current_workspace_manager()
+
+ # Prefill k_fp8 and k_scale buffers, used by
+ # rocm_aiter_sparse_attn_indexer's prefill path
+ workspace_manager.get_simultaneous(
+ ((total_seq_lens, head_dim), fp8_dtype),
+ ((total_seq_lens, 4), torch.uint8),
+ )
+
+ # Decode logits buffer, used by rocm_fp8_paged_mqa_logits.
+ # batch_size * next_n <= hidden_states.shape[0] == max_num_batched_tokens
+ if _ON_GFX942 or _ON_GFX950:
+ workspace_manager.get_simultaneous(
+ ((hidden_states.shape[0], max_model_len), torch.float32),
+ )
+ else:
+ workspace_manager.get_simultaneous(
+ (
+ (q_fp8.shape[1], hidden_states.shape[0], max_model_len),
+ torch.float32,
+ ),
+ )
+ # Transient logits tensor peak memory, produced by
+ # rocm_fp8_mqa_logits (prefill) and rocm_fp8_paged_mqa_logits
+ # (decode). Prefill logits are bounded by
+ # VLLM_SPARSE_INDEXER_MAX_LOGITS_MB via chunking in
+ # split_indexer_prefill_chunks; decode logits are smaller.
+ max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
+ _ = torch.empty(
+ max_logits_elems, dtype=torch.uint8, device=hidden_states.device
+ )
+
return rocm_aiter_sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
@@ -671,7 +706,6 @@ def rocm_aiter_sparse_attn_indexer(
has_decode = layer_attn_metadata.num_decodes > 0
has_prefill = layer_attn_metadata.num_prefills > 0
num_decode_tokens = layer_attn_metadata.num_decode_tokens
- device = hidden_states.device if k is None else k.device
# during speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens.
@@ -703,17 +737,15 @@ def rocm_aiter_sparse_attn_indexer(
if has_prefill:
prefill_metadata = layer_attn_metadata.prefill
assert prefill_metadata is not None
+
+ workspace_manager = current_workspace_manager()
+ k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
+ ((total_seq_lens, head_dim), fp8_dtype),
+ ((total_seq_lens, 4), torch.uint8),
+ )
for chunk in prefill_metadata.chunks:
- k_fp8 = torch.empty(
- [chunk.total_seq_lens, head_dim],
- device=device,
- dtype=fp8_dtype,
- )
- k_scale = torch.empty(
- [chunk.total_seq_lens, 4],
- device=device,
- dtype=torch.uint8,
- )
+ k_fp8 = k_fp8_full[: chunk.total_seq_lens]
+ k_scale = k_scale_full[: chunk.total_seq_lens]
if _ON_GFX942:
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
@@ -731,7 +763,6 @@ def rocm_aiter_sparse_attn_indexer(
chunk.cu_seq_lens,
token_to_seq=chunk.token_to_seq,
)
-
logits = rocm_fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 897d063e830..8384def4664 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -2056,13 +2056,17 @@ class Scheduler(SchedulerInterface):
return spec_decoding_stats
def shutdown(self) -> None:
+ logger.debug_once("[shutdown] Scheduler: start")
if self.kv_event_publisher:
self.kv_event_publisher.shutdown()
if self.connector is not None:
self.connector.shutdown()
+
if self.ec_connector is not None:
self.ec_connector.shutdown()
+ logger.debug_once("[shutdown] Scheduler: complete")
+
########################################################################
# KV Connector Related Methods
########################################################################
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index b12aa9d0505..08c814ab34e 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -177,13 +177,13 @@ class EngineCore:
if xfer_handshake_metadata:
# xfer_handshake_metadata is list of dicts from workers
- # Each dict already has structure {tp_rank: metadata}
+ # Each dict already has structure {(pp_rank, tp_rank): metadata}
# Merge all worker dicts into a single dict
- content: dict[int, Any] = {}
+ content: dict[tuple[int, int], Any] = {}
for worker_dict in xfer_handshake_metadata:
if worker_dict is not None:
content.update(worker_dict)
- kv_connector.set_xfer_handshake_metadata(content)
+ kv_connector.set_xfer_handshake_metadata_pp_aware(content)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
@@ -608,6 +608,7 @@ class EngineCore:
self.abort_requests(request_ids)
def shutdown(self):
+ logger.debug_once("[shutdown] EngineCore: tearing down local resources")
self.structured_output_manager.clear_backend()
if self.model_executor:
self.model_executor.shutdown()
@@ -622,6 +623,7 @@ class EngineCore:
# Tear down distributed state initialized in this EngineCore process
# before it exits and release cached memory.
cleanup_dist_env_and_memory()
+ logger.debug_once("[shutdown] EngineCore: local resource teardown complete")
def profile(self, is_start: bool = True, profile_prefix: str | None = None):
self.model_executor.profile(is_start, profile_prefix)
@@ -1172,6 +1174,11 @@ class EngineCoreProc(EngineCore):
signal_callback = SignalCallback(wakeup_engine)
def signal_handler(signum, frame):
+ signal_name = signal.Signals(signum).name
+ logger.info(
+ "[shutdown] EngineCore: trigger received signal=%s",
+ signal_name,
+ )
engine_core.shutdown_state = EngineShutdownState.REQUESTED
signal_callback.trigger()
@@ -1181,7 +1188,7 @@ class EngineCoreProc(EngineCore):
engine_core.run_busy_loop()
except SystemExit:
- logger.debug("EngineCore exiting.")
+ logger.info_once("[shutdown] EngineCore: exiting busy loop")
raise
except Exception as e:
if engine_core is None:
@@ -1285,13 +1292,21 @@ class EngineCoreProc(EngineCore):
if self.shutdown_state == EngineShutdownState.REQUESTED:
shutdown_timeout = self.vllm_config.shutdown_timeout
+ mode = "abort" if shutdown_timeout == 0 else "drain"
- logger.info("Shutdown initiated (timeout=%d)", shutdown_timeout)
+ logger.info(
+ "[shutdown] EngineCore: start mode=%s timeout=%ds",
+ mode,
+ shutdown_timeout,
+ )
if shutdown_timeout == 0:
num_requests = self.scheduler.get_num_unfinished_requests()
if num_requests > 0:
- logger.info("Aborting %d requests", num_requests)
+ logger.info(
+ "[shutdown] EngineCore: aborting in-flight requests count=%d",
+ num_requests,
+ )
aborted_reqs = self.scheduler.finish_requests(
None, RequestStatus.FINISHED_ABORTED
)
@@ -1300,7 +1315,8 @@ class EngineCoreProc(EngineCore):
num_requests = self.scheduler.get_num_unfinished_requests()
if num_requests > 0:
logger.info(
- "Draining %d in-flight requests (timeout=%ds)",
+ "[shutdown] EngineCore: draining in-flight requests "
+ "count=%d timeout=%ds",
num_requests,
shutdown_timeout,
)
@@ -1309,7 +1325,10 @@ class EngineCoreProc(EngineCore):
# Exit when no work remaining
if not self.has_work():
- logger.info("Shutdown complete")
+ logger.info(
+ "[shutdown] EngineCore: request processing complete; "
+ "starting resource teardown"
+ )
return False
return True
@@ -1353,7 +1372,10 @@ class EngineCoreProc(EngineCore):
if self.shutdown_state == EngineShutdownState.RUNNING:
return False
- logger.info("Rejecting request %s (server shutting down)", request.request_id)
+ logger.debug(
+ "[shutdown] EngineCore: rejecting new request request_id=%s",
+ request.request_id,
+ )
self._send_abort_outputs_to_client([request.request_id], request.client_index)
return True
@@ -1363,7 +1385,10 @@ class EngineCoreProc(EngineCore):
if self.shutdown_state == EngineShutdownState.RUNNING:
return False
- logger.warning("Rejecting utility call %s (server shutting down)", method_name)
+ logger.warning(
+ "[shutdown] EngineCore: rejecting utility call method=%s",
+ method_name,
+ )
output = UtilityOutput(call_id, failure_message="Server shutting down")
self.output_queue.put_nowait(
(client_idx, EngineCoreOutputs(utility_output=output))
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 14257b020ee..32f2d091eb3 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -391,6 +391,7 @@ class BackgroundResources:
def __call__(self):
"""Clean up background resources."""
+ logger.debug_once("[shutdown] MPClient: background resource cleanup start")
self.engine_dead = True
if self.engine_manager is not None:
self.engine_manager.shutdown()
@@ -445,6 +446,8 @@ class BackgroundResources:
# Send shutdown signal.
shutdown_sender.send(b"")
+ logger.debug_once("[shutdown] MPClient: background resource cleanup complete")
+
def validate_alive(self, frames: Sequence[zmq.Frame]):
if len(frames) == 1 and (frames[0].buffer == EngineCoreProc.ENGINE_CORE_DEAD):
self.engine_dead = True
@@ -645,9 +648,15 @@ class MPClient(EngineCoreClient):
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown engine manager under timeout and clean up resources."""
if self._finalizer.detach() is not None:
+ timeout_str = "default" if timeout is None else f"{timeout}s"
+ logger.info("[shutdown] MPClient: start timeout=%s", timeout_str)
if self.resources.engine_manager is not None:
+ logger.info_once("[shutdown] MPClient: stopping engine manager")
self.resources.engine_manager.shutdown(timeout=timeout)
+ logger.info_once("[shutdown] MPClient: engine manager stopped")
+ logger.info_once("[shutdown] MPClient: cleaning up background resources")
self.resources()
+ logger.info_once("[shutdown] MPClient: complete")
def _format_exception(self, e: Exception) -> Exception:
"""If errored, use EngineDeadError so root cause is clear."""
@@ -687,6 +696,9 @@ class MPClient(EngineCoreClient):
if not _self or not _self._finalizer.alive or _self.resources.engine_dead:
return
_self.resources.engine_dead = True
+ logger.warning_once(
+ "[shutdown] MPClient: engine core exited unexpectedly; starting cleanup"
+ )
_self.shutdown()
# Note: For MPClient, we don't have a failure callback mechanism
# like MultiprocExecutor, but we set engine_dead flag which will
diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py
index 7beef598e27..4063844d469 100644
--- a/vllm/v1/executor/abstract.py
+++ b/vllm/v1/executor/abstract.py
@@ -203,7 +203,7 @@ class Executor(ABC):
def get_kv_connector_handshake_metadata(
self,
- ) -> list[dict[int, KVConnectorHandshakeMetadata]]:
+ ) -> list[dict[tuple[int, int], KVConnectorHandshakeMetadata]]:
return self.collective_rpc("get_kv_connector_handshake_metadata")
@overload
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index c5766c923c8..66564bebdb6 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -422,27 +422,45 @@ class MultiprocExecutor(Executor):
return False
active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()]
+ initial_count = len(active_procs())
+
# Give processes time to clean themselves up properly first
- logger.debug("Worker Termination: allow workers to gracefully shutdown")
+ logger.info(
+ "[shutdown] Executor: waiting for worker exit count=%d",
+ initial_count,
+ )
if wait_for_termination(active_procs(), 4):
+ logger.info_once("[shutdown] Executor: all workers exited gracefully")
return
# Send SIGTERM if still running
- logger.debug("Worker Termination: workers still running sending SIGTERM")
- for p in active_procs():
+ remaining = active_procs()
+ logger.warning(
+ "[shutdown] Executor: workers still running after grace period; "
+ "sending SIGTERM count=%d",
+ len(remaining),
+ )
+ for p in remaining:
p.terminate()
if not wait_for_termination(active_procs(), 4):
# Send SIGKILL if still running
- logger.debug(
- "Worker Termination: resorting to SIGKILL to take down workers"
+ remaining = active_procs()
+ logger.warning(
+ "[shutdown] Executor: workers still running after SIGTERM; "
+ "sending SIGKILL count=%d",
+ len(remaining),
)
- for p in active_procs():
+ for p in remaining:
p.kill()
def shutdown(self):
"""Properly shut down the executor and its workers"""
if not getattr(self, "shutting_down", False):
- logger.debug("Triggering shutdown of workers")
+ worker_count = len(getattr(self, "workers", None) or [])
+ logger.debug(
+ "[shutdown] Executor: start worker_count=%d",
+ worker_count,
+ )
self.shutting_down = True
# Make sure all the worker processes are terminated first.
@@ -468,6 +486,8 @@ class MultiprocExecutor(Executor):
mq.shutdown()
self.response_mqs = []
+ logger.debug_once("[shutdown] Executor: complete")
+
def check_health(self) -> None:
self.collective_rpc("check_health", timeout=10)
return
@@ -867,7 +887,9 @@ class WorkerProc:
if ready_writer is not None:
logger.exception("WorkerProc failed to start.")
elif shutdown_requested.is_set():
- logger.info("WorkerProc shutting down.")
+ logger.debug_once(
+ "[shutdown] WorkerProc: exiting after shutdown request"
+ )
else:
logger.exception("WorkerProc failed.")
@@ -879,7 +901,12 @@ class WorkerProc:
except SystemExit as e:
# SystemExit is raised on SIGTERM or SIGKILL, which usually indicates that
# the graceful shutdown process did not succeed
- logger.warning("WorkerProc was terminated")
+ if shutdown_requested.is_set():
+ logger.debug_once(
+ "[shutdown] WorkerProc: terminated by shutdown signal"
+ )
+ else:
+ logger.warning("WorkerProc was terminated")
# SystemExit must never be ignored
raise e
diff --git a/vllm/v1/kv_offload/tiering/factory.py b/vllm/v1/kv_offload/tiering/factory.py
index cbde45dfcf8..be703a03b3d 100644
--- a/vllm/v1/kv_offload/tiering/factory.py
+++ b/vllm/v1/kv_offload/tiering/factory.py
@@ -63,3 +63,9 @@ SecondaryTierFactory.register_tier(
"vllm.v1.kv_offload.tiering.fs.manager",
"FileSystemTierManager",
)
+
+SecondaryTierFactory.register_tier(
+ "obj",
+ "vllm.v1.kv_offload.tiering.obj.manager",
+ "ObjectStoreSecondaryTierManager",
+)
diff --git a/vllm/v1/kv_offload/tiering/obj/__init__.py b/vllm/v1/kv_offload/tiering/obj/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/vllm/v1/kv_offload/tiering/obj/config.py b/vllm/v1/kv_offload/tiering/obj/config.py
new file mode 100644
index 00000000000..5507c6a198e
--- /dev/null
+++ b/vllm/v1/kv_offload/tiering/obj/config.py
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Connection configuration for the object store secondary tier."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class ObjStoreConfig:
+ """Connection parameters for an object store backend."""
+
+ bucket: str
+ endpoint_override: str
+ access_key: str
+ secret_key: str
+ scheme: str = "http"
+ ca_bundle: str = ""
+
+ def to_nixl_params(self) -> dict[str, str]:
+ """Build the NIXL backend params dict."""
+ params: dict[str, str] = {
+ "bucket": self.bucket,
+ "endpoint_override": self.endpoint_override,
+ "scheme": self.scheme,
+ "access_key": self.access_key,
+ "secret_key": self.secret_key,
+ }
+ if self.ca_bundle:
+ params["ca_bundle"] = self.ca_bundle
+ return params
diff --git a/vllm/v1/kv_offload/tiering/obj/manager.py b/vllm/v1/kv_offload/tiering/obj/manager.py
new file mode 100644
index 00000000000..2d7280ae379
--- /dev/null
+++ b/vllm/v1/kv_offload/tiering/obj/manager.py
@@ -0,0 +1,256 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Object store secondary tier implementation."""
+
+import ctypes
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, NamedTuple
+
+from vllm.distributed.nixl_utils import NixlWrapper as nixl_agent
+from vllm.distributed.nixl_utils import nixl_agent_config
+from vllm.logger import init_logger
+from vllm.v1.kv_offload.base import OffloadKey, ReqContext
+from vllm.v1.kv_offload.file_mapper import FileMapper
+from vllm.v1.kv_offload.tiering.base import (
+ JobMetadata,
+ JobResult,
+ RequestOffloadingContext,
+ SecondaryTierManager,
+)
+from vllm.v1.kv_offload.tiering.obj.config import ObjStoreConfig
+
+if TYPE_CHECKING:
+ from nixl._api import nixl_prepped_dlist_handle, nixl_xfer_handle
+
+ from vllm.v1.kv_offload.base import OffloadingSpec
+
+logger = init_logger(__name__)
+
+NIXL_WRITE = "WRITE"
+NIXL_READ = "READ"
+NIXL_PROC = "PROC"
+NIXL_DONE = "DONE"
+
+# Device ID for CPU DRAM descriptors. DRAM is not a multi-device resource so
+# the device ID is always 0.
+NIXL_DEV_ID: int = 0
+
+# Fields for NIXL OBJ descriptors: (addr, len, dev_id, obj_key).
+# For existence probes addr and len are placeholders — no data is read.
+# dev_id=0 is reserved for probes; transfers start from 1.
+_PROBE_ADDR: int = 0
+_PROBE_LEN: int = 1
+_PROBE_DEV_ID: int = 0
+
+
+class TransferEntry(NamedTuple):
+ xfer_handle: "nixl_xfer_handle"
+ files_desc: object
+ obj_handle: "nixl_prepped_dlist_handle"
+
+
+class ObjectStoreSecondaryTierManager(SecondaryTierManager):
+ """Secondary tier that offloads KV cache blocks to an S3-compatible store.
+
+ Handles CPU DRAM <-> S3 transfers only. GPU <-> CPU is managed by the
+ primary tier. Object keys are formed as ``{prefix}/{hash_shard}/{hash}.bin``.
+ """
+
+ def __init__(
+ self,
+ offloading_spec: "OffloadingSpec",
+ primary_kv_view: memoryview,
+ tier_type: str,
+ store_config: dict,
+ prefix: str = "",
+ io_threads: int = 4,
+ ):
+ super().__init__(offloading_spec, primary_kv_view, tier_type)
+ agent_config = nixl_agent_config(backends=[])
+ self._agent = nixl_agent("ObjAgent", agent_config)
+ obj_config = ObjStoreConfig(**store_config)
+ params = {**obj_config.to_nixl_params(), "num_threads": str(io_threads)}
+ self._agent.create_backend("OBJ", params)
+ self._transfers: dict[int, TransferEntry] = {}
+ self._failed_jobs: list[JobResult] = []
+ self._primary_reg = None
+ self._block_size_bytes: int = 0
+ root_dir = f"{prefix}/" if prefix else ""
+ self._file_mapper = FileMapper.from_offloading_spec(root_dir, offloading_spec)
+ self._next_obj_dev_id: int = 1 # dev_id=0 is reserved for _exists() probes
+
+ self._probe_connectivity()
+
+ base_addr = ctypes.addressof(ctypes.c_char.from_buffer(primary_kv_view))
+ assert primary_kv_view.strides is not None
+ stride = primary_kv_view.strides[0]
+ self._primary_reg = self._agent.register_memory(
+ [(base_addr, primary_kv_view.nbytes, NIXL_DEV_ID, "")], "DRAM"
+ )
+ self._block_size_bytes = stride
+ all_blocks = [
+ (base_addr + i * stride, stride, NIXL_DEV_ID)
+ for i in range(len(primary_kv_view))
+ ]
+ # NIXL_INIT_AGENT marks this as the local side; make_prepped_xfer requires
+ # local_xfer_side tagged with NIXL_INIT_AGENT and remote_xfer_side tagged
+ # with the peer agent name ("ObjAgent").
+ self._dram_prepped_handle: nixl_prepped_dlist_handle = (
+ self._agent.prep_xfer_dlist("NIXL_INIT_AGENT", all_blocks, "DRAM")
+ )
+
+ def _probe_connectivity(self) -> None:
+ """Verify object store connectivity at startup via a NIXL lookup probe.
+
+ Performs a single exists() check against a synthetic key that will
+ never exist. A True/False result confirms the bucket is reachable;
+ an exception indicates misconfigured obj store params and raises RuntimeError.
+ """
+ probe_key = "__nixl_probe__/connectivity_test"
+ try:
+ self._exists(probe_key)
+ logger.info("Object store tier connectivity probe succeeded")
+ except Exception as e:
+ raise RuntimeError(
+ f"Object store tier connectivity probe failed — check bucket, "
+ f"endpoint_override, access_key, secret_key, and scheme. "
+ f"Error: {e}"
+ ) from e
+
+ def _exists(self, obj_key: str) -> bool:
+ results = self._agent.query_memory(
+ [(_PROBE_ADDR, _PROBE_LEN, _PROBE_DEV_ID, obj_key)], "OBJ", "OBJ"
+ )
+ return results[0] is not None
+
+ def _submit_transfer(
+ self,
+ job_id: int,
+ block_ids: Iterable[int],
+ obj_keys: Iterable[str],
+ op: str,
+ ) -> None:
+ """Submit an async transfer. op is 'WRITE' (store) or 'READ' (load)."""
+ block_ids_list = [int(bid) for bid in block_ids]
+ # The OBJ backend maps devId -> obj_key. All descriptors must have
+ # unique devIds or later registrations overwrite earlier ones.
+ nixl_files = [
+ (0, self._block_size_bytes, dev_id, key)
+ for dev_id, key in enumerate(obj_keys, self._next_obj_dev_id)
+ ]
+ self._next_obj_dev_id += len(nixl_files)
+
+ files_desc = self._agent.register_memory(nixl_files, "OBJ")
+ if files_desc is None:
+ logger.warning("register_memory (OBJ) failed for job %d", job_id)
+ self._failed_jobs.append(JobResult(job_id=job_id, success=False))
+ return
+
+ obj_handle = self._agent.prep_xfer_dlist("ObjAgent", files_desc.trim())
+ if not obj_handle:
+ logger.warning("prep_xfer_dlist (OBJ) failed for job %d", job_id)
+ self._agent.deregister_memory(files_desc)
+ self._failed_jobs.append(JobResult(job_id=job_id, success=False))
+ return
+
+ xfer_handle = self._agent.make_prepped_xfer(
+ op,
+ self._dram_prepped_handle,
+ block_ids_list,
+ obj_handle,
+ list(range(len(nixl_files))),
+ )
+ if not xfer_handle:
+ logger.warning("make_prepped_xfer failed for job %d", job_id)
+ self._agent.release_dlist_handle(obj_handle)
+ self._agent.deregister_memory(files_desc)
+ self._failed_jobs.append(JobResult(job_id=job_id, success=False))
+ return
+
+ state = self._agent.transfer(xfer_handle)
+ if state == "ERR":
+ logger.warning("agent.transfer failed for job %d", job_id)
+ self._agent.release_dlist_handle(obj_handle)
+ self._agent.deregister_memory(files_desc)
+ self._agent.release_xfer_handle(xfer_handle)
+ self._failed_jobs.append(JobResult(job_id=job_id, success=False))
+ return
+
+ self._transfers[job_id] = TransferEntry(xfer_handle, files_desc, obj_handle)
+
+ def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None:
+ try:
+ return self._exists(self._file_mapper.get_file_name(key))
+ except Exception as e:
+ logger.warning("lookup failed for key %s: %s", key, e)
+ return False
+
+ def submit_store(self, job_metadata: JobMetadata) -> None:
+ obj_keys = (self._file_mapper.get_file_name(k) for k in job_metadata.keys)
+ self._submit_transfer(
+ job_metadata.job_id, job_metadata.block_ids, obj_keys, NIXL_WRITE
+ )
+
+ def submit_load(self, job_metadata: JobMetadata) -> None:
+ obj_keys = (self._file_mapper.get_file_name(k) for k in job_metadata.keys)
+ self._submit_transfer(
+ job_metadata.job_id, job_metadata.block_ids, obj_keys, NIXL_READ
+ )
+
+ def on_new_request(self, req_context: ReqContext) -> RequestOffloadingContext:
+ return RequestOffloadingContext()
+
+ def get_finished_jobs(self) -> Iterable[JobResult]:
+ """Poll in-flight transfers; return completed (job_id, success) pairs."""
+ results: list[JobResult] = self._failed_jobs
+ self._failed_jobs = []
+ for job_id, entry in list(self._transfers.items()):
+ try:
+ state = self._agent.check_xfer_state(entry.xfer_handle)
+ except Exception as exc:
+ success = False
+ logger.warning("check_xfer_state raised for job %d: %s", job_id, exc)
+ else:
+ if state == NIXL_PROC:
+ continue
+ elif state == NIXL_DONE:
+ success = True
+ else:
+ success = False
+ logger.warning("transfer failed job=%d state=%s", job_id, state)
+ del self._transfers[job_id]
+ self._agent.release_xfer_handle(entry.xfer_handle)
+ self._agent.release_dlist_handle(entry.obj_handle)
+ self._agent.deregister_memory(entry.files_desc)
+ results.append(JobResult(job_id=job_id, success=success))
+ return results
+
+ def shutdown(self) -> None:
+ for job_id, entry in self._transfers.items():
+ try:
+ self._agent.release_xfer_handle(entry.xfer_handle)
+ except Exception as exc:
+ logger.warning("release_xfer_handle failed for job %d: %s", job_id, exc)
+ try:
+ self._agent.release_dlist_handle(entry.obj_handle)
+ except Exception as exc:
+ logger.warning(
+ "release_dlist_handle failed for job %d: %s", job_id, exc
+ )
+ try:
+ self._agent.deregister_memory(entry.files_desc)
+ except Exception as exc:
+ logger.warning("deregister_memory failed for job %d: %s", job_id, exc)
+ self._transfers.clear()
+ if self._dram_prepped_handle is not None:
+ try:
+ self._agent.release_dlist_handle(self._dram_prepped_handle)
+ except Exception as exc:
+ logger.warning("failed to release DRAM prepped handle: %s", exc)
+ self._dram_prepped_handle = None
+ if self._primary_reg is not None:
+ try:
+ self._agent.deregister_memory(self._primary_reg)
+ except Exception as exc:
+ logger.warning("failed to deregister primary buffer: %s", exc)
+ self._primary_reg = None
diff --git a/vllm/v1/kv_offload/tiering/spec.py b/vllm/v1/kv_offload/tiering/spec.py
index a4ea46e08eb..f223d81aa5e 100644
--- a/vllm/v1/kv_offload/tiering/spec.py
+++ b/vllm/v1/kv_offload/tiering/spec.py
@@ -131,8 +131,8 @@ class TieringOffloadingSpec(CPUOffloadingSpec):
)
except Exception as e:
logger.error(
- "Failed to create secondary tier from config %s: %s",
- tier_config,
+ "Failed to create secondary tier from config index %i: %s",
+ i,
e,
)
raise
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index baa0e77119b..6f324ce9850 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -75,9 +75,14 @@ class TopKTopPSampler(nn.Module):
Implementations may update the logits tensor in-place.
"""
- def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
+ def __init__(
+ self,
+ logprobs_mode: LogprobsMode = "raw_logprobs",
+ use_fp64_gumbel: bool = False,
+ ) -> None:
super().__init__()
self.logprobs_mode = logprobs_mode
+ self.use_fp64_gumbel = use_fp64_gumbel
if current_platform.is_cuda():
# FlashInfer doesn't expose post-top-k/top-p logits/logprobs,
# so it can't be used when the configured mode requires them.
@@ -142,7 +147,10 @@ class TopKTopPSampler(nn.Module):
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
- return random_sample(probs, generators), logits_to_return
+ return (
+ random_sample(probs, generators, self.use_fp64_gumbel),
+ logits_to_return,
+ )
def forward_cuda(
self,
@@ -163,6 +171,8 @@ class TopKTopPSampler(nn.Module):
"PyTorch-native implementation."
)
return self.forward_native(logits, generators, k, p)
+ if self.use_fp64_gumbel:
+ return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
"FlashInfer does not support returning logits/logprobs"
)
@@ -190,16 +200,16 @@ class TopKTopPSampler(nn.Module):
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
- if len(generators) != logits.shape[0]:
+ if len(generators) != logits.shape[0] and not self.use_fp64_gumbel:
return compiled_random_sample(logits), logits_to_return
probs = logits.softmax(dim=-1, dtype=torch.float32)
- q = torch.empty_like(probs)
+ q = empty_exponential_noise_like(probs, self.use_fp64_gumbel)
q.exponential_()
for i, generator in generators.items():
q[i].exponential_(generator=generator)
- return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
+ return sample_with_exponential_noise(probs, q), logits_to_return
def forward_hip(
self,
@@ -216,6 +226,8 @@ class TopKTopPSampler(nn.Module):
"falling back to PyTorch-native."
)
return self.forward_native(logits, generators, k, p)
+ if self.use_fp64_gumbel:
+ return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in (
"processed_logits",
"processed_logprobs",
@@ -404,16 +416,33 @@ def apply_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
return logits.masked_fill_(logits < top_k_mask, -float("inf"))
+def empty_exponential_noise_like(
+ probs: torch.Tensor, use_fp64_gumbel: bool
+) -> torch.Tensor:
+ dtype = torch.float64 if use_fp64_gumbel else probs.dtype
+ return torch.empty(probs.shape, dtype=dtype, device=probs.device)
+
+
+def sample_with_exponential_noise(probs: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+ if q.dtype == probs.dtype:
+ scores = probs.div_(q)
+ else:
+ scores = q.reciprocal_()
+ scores.mul_(probs)
+ return scores.argmax(dim=-1).view(-1)
+
+
def random_sample(
probs: torch.Tensor,
generators: dict[int, torch.Generator],
+ use_fp64_gumbel: bool = False,
) -> torch.Tensor:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
- q = torch.empty_like(probs)
+ q = empty_exponential_noise_like(probs, use_fp64_gumbel)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
@@ -425,7 +454,7 @@ def random_sample(
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
- return probs.div_(q).argmax(dim=-1).view(-1)
+ return sample_with_exponential_noise(probs, q)
def flashinfer_sample(
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 678654cb78a..153677e35fa 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -65,6 +65,7 @@ class RejectionSampler(nn.Module):
):
super().__init__()
self.sampler = sampler
+ self.use_fp64_gumbel = getattr(sampler, "use_fp64_gumbel", False)
logprobs_mode = self.sampler.logprobs_mode
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
@@ -176,6 +177,7 @@ class RejectionSampler(nn.Module):
sampling_metadata,
synthetic_mode=self.synthetic_mode,
synthetic_conditional_rates=self.synthetic_conditional_rates,
+ use_fp64_gumbel=self.use_fp64_gumbel,
)
logprobs_tensors = None
@@ -406,6 +408,7 @@ def rejection_sample(
sampling_metadata: SamplingMetadata,
synthetic_mode: bool = False,
synthetic_conditional_rates: torch.Tensor | None = None,
+ use_fp64_gumbel: bool = False,
) -> torch.Tensor:
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
@@ -480,6 +483,7 @@ def rejection_sample(
target_probs,
sampling_metadata,
device,
+ use_fp64_gumbel,
)
# Rejection sampling for random sampling requests.
@@ -669,13 +673,15 @@ def sample_recovered_tokens(
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
device: torch.device,
+ use_fp64_gumbel: bool = False,
) -> torch.Tensor:
# NOTE(woosuk): Create only one distribution for each request.
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
+ q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
q = torch.empty(
(batch_size, vocab_size),
- dtype=torch.float32,
+ dtype=q_dtype,
device=device,
)
q.exponential_()
@@ -699,6 +705,7 @@ def sample_recovered_tokens(
vocab_size,
BLOCK_SIZE,
NO_DRAFT_PROBS=draft_probs is None,
+ USE_FP64_GUMBEL=use_fp64_gumbel,
)
return recovered_token_ids
@@ -861,6 +868,7 @@ def sample_recovered_tokens_kernel(
vocab_size,
BLOCK_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
+ USE_FP64_GUMBEL: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
@@ -877,7 +885,10 @@ def sample_recovered_tokens_kernel(
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
- max_val = float("-inf")
+ if USE_FP64_GUMBEL:
+ max_val = tl.full((), float("-inf"), tl.float64)
+ else:
+ max_val = tl.full((), float("-inf"), tl.float32)
recovered_id = 0
for v in range(0, vocab_size, BLOCK_SIZE):
vocab_offset = v + tl.arange(0, BLOCK_SIZE)
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 9ac3821a326..eadc009c254 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -58,11 +58,16 @@ class Sampler(nn.Module):
9. Return the final `SamplerOutput`.
"""
- def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
+ def __init__(
+ self,
+ logprobs_mode: LogprobsMode = "raw_logprobs",
+ use_fp64_gumbel: bool = False,
+ ):
super().__init__()
- self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
+ self.topk_topp_sampler = TopKTopPSampler(logprobs_mode, use_fp64_gumbel)
self.pin_memory = is_pin_memory_available()
self.logprobs_mode = logprobs_mode
+ self.use_fp64_gumbel = use_fp64_gumbel
def forward(
self,
diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py
index aa1bf270c1c..38db5483937 100644
--- a/vllm/v1/spec_decode/llm_base_proposer.py
+++ b/vllm/v1/spec_decode/llm_base_proposer.py
@@ -32,6 +32,10 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.topk_topp_sampler import (
+ empty_exponential_noise_like,
+ sample_with_exponential_noise,
+)
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
@@ -113,6 +117,7 @@ class SpecDecodeBaseProposer:
self.use_local_argmax_reduction: bool = (
self.speculative_config.use_local_argmax_reduction
)
+ self.use_fp64_gumbel = vllm_config.model_config.use_fp64_gumbel
self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
@@ -409,7 +414,9 @@ class SpecDecodeBaseProposer:
return logits.argmax(dim=-1), None
if sampling_metadata.all_greedy:
return logits.argmax(dim=-1), None
- return compute_probs_and_sample_next_token(logits, sampling_metadata)
+ return compute_probs_and_sample_next_token(
+ logits, sampling_metadata, self.use_fp64_gumbel
+ )
def _sample_draft_tokens(
self,
@@ -1656,6 +1663,7 @@ class SpecDecodeBaseProposer:
def compute_probs_and_sample_next_token(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
+ use_fp64_gumbel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling.
@@ -1682,11 +1690,11 @@ def compute_probs_and_sample_next_token(
# of the generated tokens after rejection sampling.
# TODO(woosuk): Consider seeds.
- q = torch.empty_like(probs)
+ q = empty_exponential_noise_like(probs, use_fp64_gumbel)
q.exponential_()
# NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
# will be used later for rejection sampling.
- next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
+ next_token_ids = sample_with_exponential_noise(probs.clone(), q)
if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids)
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index f11c92a805d..ffdf43f54c3 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -444,6 +444,12 @@ def _shutdown_subprocesses(
timeout = 0.0
timeout = max(timeout, 5.0)
+ logger.debug(
+ "[shutdown] Subprocess manager: start process_count=%d timeout=%ss",
+ len(procs),
+ timeout,
+ )
+
for proc in procs:
if proc.is_alive():
proc.terminate()
@@ -456,9 +462,18 @@ def _shutdown_subprocesses(
if proc.is_alive():
proc.join(remaining)
- for proc in procs:
- if proc.is_alive() and (pid := proc.pid) is not None:
- kill_process_tree(pid)
+ remaining_pids = [
+ proc.pid for proc in procs if proc.is_alive() and proc.pid is not None
+ ]
+ if remaining_pids:
+ logger.warning(
+ "[shutdown] Subprocess manager: force killing remaining processes count=%d",
+ len(remaining_pids),
+ )
+ for pid in remaining_pids:
+ kill_process_tree(pid)
+
+ logger.debug_once("[shutdown] Subprocess manager: complete")
def run_api_server_worker_proc(
@@ -565,6 +580,12 @@ def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None:
# have a user-configured shutdown timeout.
timeout = 5.0
+ logger.debug(
+ "[shutdown] Process manager: start process_count=%d timeout=%ss",
+ len(procs),
+ timeout,
+ )
+
# Shutdown the process.
for proc in procs:
if proc.is_alive():
@@ -579,9 +600,18 @@ def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None:
if proc.is_alive():
proc.join(remaining)
- for proc in procs:
- if proc.is_alive() and (pid := proc.pid) is not None:
- kill_process_tree(pid)
+ remaining_pids = [
+ proc.pid for proc in procs if proc.is_alive() and proc.pid is not None
+ ]
+ if remaining_pids:
+ logger.warning(
+ "[shutdown] Process manager: force killing remaining processes count=%d",
+ len(remaining_pids),
+ )
+ for pid in remaining_pids:
+ kill_process_tree(pid)
+
+ logger.debug_once("[shutdown] Process manager: complete")
def copy_slice(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 261995f4b01..801a8574ac7 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -504,7 +504,10 @@ class GPUModelRunner(
self.use_async_scheduling = self.scheduler_config.async_scheduling
# Sampler
- self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
+ self.sampler = Sampler(
+ logprobs_mode=self.model_config.logprobs_mode,
+ use_fp64_gumbel=self.model_config.use_fp64_gumbel,
+ )
self.eplb_state: EplbState | None = None
self._moe_model: MixtureOfExperts | None = None
@@ -1875,9 +1878,8 @@ class GPUModelRunner(
SpecDecodeMetadata | None,
]:
"""
- :return: tuple[
- logits_indices, spec_decode_metadata,
- ]
+ Returns:
+ tuple[logits_indices, spec_decode_metadata]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -2202,7 +2204,8 @@ class GPUModelRunner(
slot_mappings: dict[int, torch.Tensor] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
- :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
+ Returns:
+ tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
# Attention metadata is not needed for attention free models
if len(self.kv_cache_config.kv_cache_groups) == 0:
@@ -2500,9 +2503,11 @@ class GPUModelRunner(
num_common_prefix_blocks: list[int],
) -> list[list[int]] | None:
"""
- :return: Optional[cascade_attn_prefix_lens]
- cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
- None if we should not use cascade attention
+ Returns:
+ Optional[cascade_attn_prefix_lens]
+ cascade_attn_prefix_lens is 2D:
+ ``[kv_cache_group_id][attn_group_idx]``,
+ None if we should not use cascade attention
"""
use_cascade_attn = False
@@ -5321,11 +5326,12 @@ class GPUModelRunner(
"""
Reload weights from a weights iterator or from disk
- :param weights_iterator: weights to load into model
- :param weights_path: path to load weights from if weights_iterator is not
- provided. Use path of original model if neither is provided.
- :param is_checkpoint_format: set to False if weights have already been processed
- into kernel format (repacking, renaming, etc.)
+ Args:
+ weights_iterator: weights to load into model
+ weights_path: path to load weights from if weights_iterator is not
+ provided. Use path of original model if neither is provided.
+ is_checkpoint_format: set to False if weights have already been
+ processed into kernel format (repacking, renaming, etc.)
"""
# TODO(@kylesayrs): generalize to all runners and loaders
# argument validation
@@ -5787,6 +5793,9 @@ class GPUModelRunner(
num_scheduled_tokens, self.query_pos.np
)
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
+ self.query_start_loc.np[num_reqs + 1 : num_reqs_padded + 1].fill(
+ cum_num_tokens[-1]
+ )
self.query_start_loc.copy_to_gpu()
# Sync block table CPU->GPU so cleared rows from
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 1b30a981e21..bf6b287d8de 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -34,6 +34,9 @@ from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
)
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+ KVConnectorHandshakeMetadata,
+)
from vllm.distributed.parallel_state import (
Handle,
get_pp_group,
@@ -513,8 +516,13 @@ class Worker(WorkerBase):
return int(self.available_kv_cache_memory_bytes)
- def get_kv_connector_handshake_metadata(self) -> dict | None:
- """Get KV connector metadata from this worker if available."""
+ def get_kv_connector_handshake_metadata(
+ self,
+ ) -> dict[tuple[int, int], KVConnectorHandshakeMetadata] | None:
+ """Get KV connector metadata from this worker if available.
+
+ Returned dict is keyed by `(pp_rank, tp_rank)`.
+ """
if not has_kv_transfer_group():
return None
@@ -525,8 +533,9 @@ class Worker(WorkerBase):
if (metadata := connector.get_handshake_metadata()) is None:
return None
+ pp_rank = get_pp_group().rank_in_group
tp_rank = get_tp_group().rank_in_group
- return {tp_rank: metadata}
+ return {(pp_rank, tp_rank): metadata}
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()