diff --git a/.buildkite/hardware_tests/cpu.yaml b/.buildkite/hardware_tests/cpu.yaml index e48fa4869ae..3db49d579e3 100644 --- a/.buildkite/hardware_tests/cpu.yaml +++ b/.buildkite/hardware_tests/cpu.yaml @@ -28,18 +28,19 @@ steps: pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py" -- label: CPU-Compatibility Tests - depends_on: [] - device: intel_cpu - no_plugin: true - source_file_dependencies: - - cmake/cpu_extension.cmake - - setup.py - - vllm/platforms/cpu.py - commands: - - | - bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m " - bash .buildkite/scripts/hardware_ci/run-cpu-compatibility-test.sh" +# Note: SDE can't be downloaded from CI host because of AWS WAF +# - label: CPU-Compatibility Tests +# depends_on: [] +# device: intel_cpu +# no_plugin: true +# source_file_dependencies: +# - cmake/cpu_extension.cmake +# - setup.py +# - vllm/platforms/cpu.py +# commands: +# - | +# bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m " +# bash .buildkite/scripts/hardware_ci/run-cpu-compatibility-test.sh" - label: CPU-Language Generation and Pooling Model Tests depends_on: [] diff --git a/.buildkite/intel_jobs/test-intel.yaml b/.buildkite/intel_jobs/test-intel.yaml index 805b7e54f12..63ce93c4810 100644 --- a/.buildkite/intel_jobs/test-intel.yaml +++ b/.buildkite/intel_jobs/test-intel.yaml @@ -40,7 +40,8 @@ steps: python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192 && python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 && python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel && - python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --max-model-len 8192 + python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --max-model-len 8192 && + VLLM_XPU_FUSED_MOE_USE_REF=1 python3 examples/basic/offline_inference/generate.py --model Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --enforce-eager -tp 2 ' - label: "XPU V1 test" depends_on: diff --git a/.github/mergify.yml b/.github/mergify.yml index 6caec515d32..a5d4e609474 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -36,18 +36,6 @@ pull_request_rules: For future commits, `pre-commit` will run automatically on changed files before each commit. - > [!TIP] - >
- > Is mypy failing? - >
- > mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally: - > - > ```bash - > # For mypy (substitute "3.10" with the failing version if needed) - > pre-commit run --hook-stage manual mypy-3.10 - > ``` - >
- - name: comment-dco-failure description: Comment on PR when DCO check fails conditions: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c11a80683f8..dff099e3697 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -148,33 +148,27 @@ repos: language: python entry: python tools/pre_commit/generate_nightly_torch_test.py files: ^requirements/test/cuda\.(in|txt)$ - - id: mypy-local - name: Run mypy locally for lowest supported Python version - entry: python tools/pre_commit/mypy.py 0 "3.10" - stages: [pre-commit] # Don't run in CI + - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.10 + entry: python tools/pre_commit/mypy.py "3.10" <<: &mypy_common language: python types_or: [python, pyi] require_serial: true - additional_dependencies: ["mypy[faster-cache]==1.19.1", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.10 - entry: python tools/pre_commit/mypy.py 1 "3.10" - <<: *mypy_common - stages: [manual] # Only run in CI + additional_dependencies: ["mypy==1.20.2", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.11 - entry: python tools/pre_commit/mypy.py 1 "3.11" + entry: python tools/pre_commit/mypy.py "3.11" <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.12 - entry: python tools/pre_commit/mypy.py 1 "3.12" + entry: python tools/pre_commit/mypy.py "3.12" <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.13 - entry: python tools/pre_commit/mypy.py 1 "3.13" + entry: python tools/pre_commit/mypy.py "3.13" <<: *mypy_common stages: [manual] # Only run in CI - id: shellcheck diff --git a/AGENTS.md b/AGENTS.md index 6566523f48e..441b8d9fb73 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -98,11 +98,13 @@ pre-commit run --all-files pre-commit run ruff-check --all-files # Run mypy as it is in CI: -pre-commit run mypy-3.10 --all-files --hook-stage manual +pre-commit run mypy-3.12 --all-files --hook-stage manual ``` The line length limit for Python code is 88 characters. If you are not sure, use pre-commit to check. +Use [Google-style docstrings](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) (`Args:`/`Returns:`/`Raises:` sections), not reStructuredText/Sphinx fields (`:param:`, `:return:`, `:rtype:`). + ### Commit messages Add attribution using commit trailers such as `Co-authored-by:` (other projects use `Assisted-by:` or `Generated-by:`). For example: diff --git a/CMakeLists.txt b/CMakeLists.txt index 0652a5f066e..cd4a9c0b590 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,10 +307,7 @@ endif() # set(VLLM_EXT_SRC - "csrc/cuda_view.cu" - "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu" "csrc/quantization/activation_kernels.cu" - "csrc/cuda_utils_kernels.cu" "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -346,9 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() FetchContent_MakeAvailable(cutlass) - list(APPEND VLLM_EXT_SRC - "csrc/cutlass_extensions/common.cpp") - set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" CUDA_ARCHS "${CUDA_ARCHS}") @@ -627,6 +621,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") # set(VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/torch_bindings.cpp" + "csrc/libtorch_stable/cuda_view.cu" "csrc/libtorch_stable/activation_kernels.cu" "csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu" "csrc/libtorch_stable/quantization/w8a8/fp8/common.cu" @@ -639,6 +634,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") "csrc/libtorch_stable/layernorm_kernels.cu" "csrc/libtorch_stable/layernorm_quant_kernels.cu" "csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" + "csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu" "csrc/libtorch_stable/attention/merge_attn_states.cu" "csrc/libtorch_stable/sampler.cu" "csrc/libtorch_stable/topk.cu" @@ -653,8 +649,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC - "csrc/cuda_utils_kernels.cu" - "csrc/cutlass_extensions/common.cpp" + "csrc/libtorch_stable/cuda_utils_kernels.cu" + "csrc/libtorch_stable/cutlass_extensions/common.cpp" "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu" @@ -1076,11 +1072,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") WITH_SOABI) # Set TORCH_TARGET_VERSION for stable ABI compatibility. - # This ensures we only use C-shim APIs available in PyTorch 2.10. + # This ensures we only use C-shim APIs available in PyTorch 2.11. # _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION - # which is currently set to 2.10. + # which is currently set to 2.11. target_compile_definitions(_C_stable_libtorch PRIVATE - TORCH_TARGET_VERSION=0x020A000000000000ULL) + TORCH_TARGET_VERSION=0x020B000000000000ULL) # Needed to use cuda/hip APIs from C-shim if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu deleted file mode 100644 index 73b368cb600..00000000000 --- a/csrc/cuda_view.cu +++ /dev/null @@ -1,59 +0,0 @@ -#include -#include -#include - -// This function assumes that `cpu_tensor` is a CPU tensor, -// and that UVA (Unified Virtual Addressing) is enabled. -torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) { - TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU"); - - // handle empty tensor - if (cpu_tensor.numel() == 0) { - return torch::empty(cpu_tensor.sizes(), - cpu_tensor.options().device(torch::kCUDA)); - } - - if (cpu_tensor.is_pinned()) { - // If CPU tensor is pinned, directly get the device pointer. - void* host_ptr = const_cast(cpu_tensor.data_ptr()); - void* device_ptr = nullptr; - cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); - TORCH_CHECK(err == cudaSuccess, - "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err)); - - return torch::from_blob( - device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(), - [base = cpu_tensor](void*) {}, // keep cpu tensor alive - cpu_tensor.options().device(torch::kCUDA)); - } - - // If CPU tensor is not pinned, allocate a new pinned memory buffer. - torch::Tensor contiguous_cpu = cpu_tensor.contiguous(); - size_t nbytes = contiguous_cpu.nbytes(); - - void* host_ptr = nullptr; - cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped); - if (err != cudaSuccess) { - AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err)); - } - - err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes, - cudaMemcpyDefault); - if (err != cudaSuccess) { - cudaFreeHost(host_ptr); - AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err)); - } - - void* device_ptr = nullptr; - err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); - if (err != cudaSuccess) { - cudaFreeHost(host_ptr); - AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err)); - } - - auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); }; - - return torch::from_blob(device_ptr, contiguous_cpu.sizes(), - contiguous_cpu.strides(), deleter, - contiguous_cpu.options().device(torch::kCUDA)); -} \ No newline at end of file diff --git a/csrc/cuda_utils_kernels.cu b/csrc/libtorch_stable/cuda_utils_kernels.cu similarity index 100% rename from csrc/cuda_utils_kernels.cu rename to csrc/libtorch_stable/cuda_utils_kernels.cu diff --git a/csrc/libtorch_stable/cuda_view.cu b/csrc/libtorch_stable/cuda_view.cu new file mode 100644 index 00000000000..7bf8267470e --- /dev/null +++ b/csrc/libtorch_stable/cuda_view.cu @@ -0,0 +1,76 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// This function assumes that `cpu_tensor` is a CPU tensor, +// and that UVA (Unified Virtual Addressing) is enabled. +torch::stable::Tensor get_cuda_view_from_cpu_tensor( + torch::stable::Tensor& cpu_tensor) { + STD_TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU"); + + const auto dtype = cpu_tensor.scalar_type(); + const auto layout = cpu_tensor.layout(); + const torch::stable::Device cuda_dev(torch::headeronly::DeviceType::CUDA); + + // handle empty tensor + if (cpu_tensor.numel() == 0) { + return torch::stable::empty(cpu_tensor.sizes(), dtype, layout, cuda_dev); + } + + std::array is_pinned_stack{ + torch::stable::detail::from(cpu_tensor), + torch::stable::detail::from(std::nullopt)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::is_pinned", "", is_pinned_stack.data(), TORCH_ABI_VERSION)); + if (torch::stable::detail::to(is_pinned_stack[0])) { + // If CPU tensor is pinned, directly get the device pointer. + void* host_ptr = const_cast(cpu_tensor.mutable_data_ptr()); + void* device_ptr = nullptr; + cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); + STD_TORCH_CHECK(err == cudaSuccess, "cudaHostGetDevicePointer failed: ", + cudaGetErrorString(err)); + + return torch::stable::from_blob( + device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(), cuda_dev, dtype, + [base = cpu_tensor](void*) {}); // keep cpu tensor alive + } + + // If CPU tensor is not pinned, allocate a new pinned memory buffer. + torch::stable::Tensor contiguous_cpu = torch::stable::contiguous(cpu_tensor); + size_t nbytes = contiguous_cpu.numel() * contiguous_cpu.element_size(); + + void* host_ptr = nullptr; + cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaHostAlloc failed: ", cudaGetErrorString(err)); + } + + err = cudaMemcpy(host_ptr, contiguous_cpu.const_data_ptr(), nbytes, + cudaMemcpyDefault); + if (err != cudaSuccess) { + cudaFreeHost(host_ptr); + STD_TORCH_CHECK(false, "cudaMemcpy failed: ", cudaGetErrorString(err)); + } + + void* device_ptr = nullptr; + err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0); + if (err != cudaSuccess) { + cudaFreeHost(host_ptr); + STD_TORCH_CHECK( + false, "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err)); + } + + auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); }; + + return torch::stable::from_blob(device_ptr, contiguous_cpu.sizes(), + contiguous_cpu.strides(), cuda_dev, + contiguous_cpu.scalar_type(), deleter); +} diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/libtorch_stable/cutlass_extensions/common.cpp similarity index 90% rename from csrc/cutlass_extensions/common.cpp rename to csrc/libtorch_stable/cutlass_extensions/common.cpp index 3d2093ab942..5bc9463bfa6 100644 --- a/csrc/cutlass_extensions/common.cpp +++ b/csrc/libtorch_stable/cutlass_extensions/common.cpp @@ -1,4 +1,4 @@ -#include "cutlass_extensions/common.hpp" +#include "common.hpp" int32_t get_sm_version_num() { int32_t major_capability, minor_capability; diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/libtorch_stable/cutlass_extensions/common.hpp similarity index 100% rename from csrc/cutlass_extensions/common.hpp rename to csrc/libtorch_stable/cutlass_extensions/common.hpp diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index 0a991de76ff..536693fee96 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -164,6 +164,10 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel, #endif +// CPU tensor -> CUDA UVA view (shared CUDA/ROCm) +torch::stable::Tensor get_cuda_view_from_cpu_tensor( + torch::stable::Tensor& cpu_tensor); + // Attention kernels (shared CUDA/ROCm) void merge_attn_states( torch::stable::Tensor& output, @@ -215,6 +219,13 @@ void rms_norm_per_block_quant(torch::stable::Tensor& out, std::optional residual, int64_t group_size, bool is_scale_transposed); +void silu_and_mul_per_block_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor& scales, + int64_t group_size, + std::optional scale_ub, + bool is_scale_transposed); + // Positional encoding kernels (shared CUDA/ROCm) void rotary_embedding(torch::stable::Tensor& positions, torch::stable::Tensor& query, diff --git a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu index 1091d9d1230..53ffe521363 100644 --- a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu @@ -18,7 +18,7 @@ #include #include "libtorch_stable/torch_utils.h" #include "cutlass_extensions/torch_utils.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "get_group_starts.cuh" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" diff --git a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu index c2b8c0c00de..502f430b30b 100644 --- a/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -21,7 +21,7 @@ #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/mixed_dtype_utils.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include diff --git a/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu b/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu index 8a493fdf22c..04e98b6076a 100644 --- a/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu +++ b/csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu @@ -12,7 +12,7 @@ #include -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index b22308d25ca..88caf03fda3 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -20,7 +20,7 @@ #include -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu index 8d4ba1accc7..e1e7e7a74da 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu @@ -18,7 +18,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "nvfp4_utils.cuh" #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu index d7b2a18e29c..bfb526fcd40 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -18,7 +18,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D, diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu index fc83c6e8d34..86355bf7060 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -18,7 +18,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cutlass/cutlass.h" diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu index 2baa00caa82..7adba6308fa 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu @@ -18,7 +18,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "cutlass/cutlass.h" diff --git a/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu b/csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu similarity index 63% rename from csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu rename to csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu index d5c76232599..b32a7bd271f 100644 --- a/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu +++ b/csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu @@ -1,11 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright contributors to the vLLM project -#include -#include +#include "../../torch_utils.h" #include "../../dispatch_utils.h" -#include "libtorch_stable/quantization/fused_kernels/quant_conversions.cuh" +#include "quant_conversions.cuh" namespace vllm { @@ -105,64 +104,70 @@ __global__ void silu_and_mul_per_block_quant_kernel( } // namespace vllm -void silu_and_mul_per_block_quant(torch::Tensor& out, - torch::Tensor const& input, - torch::Tensor& scales, int64_t group_size, - std::optional scale_ub, +void silu_and_mul_per_block_quant(torch::stable::Tensor& out, + torch::stable::Tensor const& input, + torch::stable::Tensor& scales, + int64_t group_size, + std::optional scale_ub, bool is_scale_transposed) { - static c10::ScalarType kFp8Type = is_fp8_ocp() - ? c10::ScalarType::Float8_e4m3fn - : c10::ScalarType::Float8_e4m3fnuz; + static torch::headeronly::ScalarType kFp8Type = + is_fp8_ocp() ? torch::headeronly::ScalarType::Float8_e4m3fn + : torch::headeronly::ScalarType::Float8_e4m3fnuz; - TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); - TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); - TORCH_CHECK( - input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16, + STD_TORCH_CHECK(out.scalar_type() == kFp8Type || + out.scalar_type() == torch::headeronly::ScalarType::Char); + STD_TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + STD_TORCH_CHECK( + input.scalar_type() == torch::headeronly::ScalarType::Half || + input.scalar_type() == torch::headeronly::ScalarType::BFloat16, "Input must be FP16 or BF16"); - TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32"); - TORCH_CHECK(group_size == 128 || group_size == 64, - "Unsupported group size: ", group_size); + STD_TORCH_CHECK(scales.scalar_type() == torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(group_size == 128 || group_size == 64, + "Unsupported group size: ", group_size); if (scale_ub.has_value()) { - TORCH_CHECK(out.dtype() == kFp8Type); + STD_TORCH_CHECK(out.scalar_type() == kFp8Type); } int32_t hidden_size = out.size(-1); auto num_tokens = input.size(0); int32_t num_groups = hidden_size / group_size; - TORCH_CHECK(input.size(-1) == hidden_size * 2, - "input last dim must be 2x output hidden_size"); - TORCH_CHECK(hidden_size % group_size == 0, - "hidden_size must be divisible by group_size"); + STD_TORCH_CHECK(input.size(-1) == hidden_size * 2, + "input last dim must be 2x output hidden_size"); + STD_TORCH_CHECK(hidden_size % group_size == 0, + "hidden_size must be divisible by group_size"); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(input.get_device_index()); dim3 grid(num_tokens, num_groups); dim3 block(group_size); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "silu_and_mul_per_block_quant", [&] { using scalar_in_t = scalar_t; - VLLM_DISPATCH_QUANT_TYPES( + VLLM_STABLE_DISPATCH_QUANT_TYPES( out.scalar_type(), "silu_and_mul_per_block_quant", [&] { using scalar_out_t = scalar_t; - VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] { - VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] { - vllm::silu_and_mul_per_block_quant_kernel< - scalar_in_t, scalar_out_t, transpose_scale, gs> - <<>>( - out.data_ptr(), - scales.data_ptr(), - input.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, - hidden_size); - }); + VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, gs, [&] { + VLLM_STABLE_DISPATCH_BOOL( + is_scale_transposed, transpose_scale, [&] { + vllm::silu_and_mul_per_block_quant_kernel< + scalar_in_t, scalar_out_t, transpose_scale, gs> + <<>>( + out.mutable_data_ptr(), + scales.mutable_data_ptr(), + input.const_data_ptr(), + scale_ub.has_value() + ? scale_ub->const_data_ptr() + : nullptr, + hidden_size); + }); }); }); }); -} \ No newline at end of file +} diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh index ae40c0989e0..1eed7579924 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh @@ -20,7 +20,7 @@ #include "cutlass/util/packed_stride.hpp" #include "core/math.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" // clang-format on namespace vllm::c3x { diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh index 952931103c6..4cb591be056 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh @@ -15,7 +15,7 @@ #include "cutlass/gemm/collective/collective_builder.hpp" #include "core/math.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" // clang-format on /* diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp index adb3de50fc1..913436186c3 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp @@ -1,7 +1,7 @@ #include #include #include "cuda_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" template void dispatch_scaled_mm(torch::stable::Tensor& c, diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh index 49df3fa4e7f..b523d7baeaa 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh @@ -8,7 +8,7 @@ #include #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" #include "get_group_starts.cuh" using namespace cute; diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh index 6eb2c051d00..7846e609fe7 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cuh @@ -23,7 +23,7 @@ #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" #include "core/math.hpp" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu index 2e5bbca4700..0f9873cbf88 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -4,7 +4,7 @@ #include "libtorch_stable/torch_utils.h" -#include "cutlass_extensions/common.hpp" +#include "libtorch_stable/cutlass_extensions/common.hpp" void cutlass_scaled_mm_sm75(torch::stable::Tensor& c, torch::stable::Tensor const& a, diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index 511a788eeae..d376fd43aa7 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -1,4 +1,5 @@ #include "ops.h" +#include "cuda_utils.h" #include "core/registration.h" #include @@ -27,6 +28,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { "output_s, int group_size, float eps, float int8_min, float int8_max) -> " "()"); + ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor"); + #ifndef USE_ROCM ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); #endif @@ -321,6 +324,16 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { "Tensor? scale_ub, Tensor!? residual, int group_size, " "bool is_scale_transposed) -> ()"); + // Fused SiLU+Mul + per-block quantization + ops.def( + "silu_and_mul_per_block_quant(" + "Tensor! out, " + "Tensor input, " + "Tensor! scales, " + "int group_size, " + "Tensor? scale_ub=None, " + "bool is_scale_transposed=False) -> ()"); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( @@ -599,6 +612,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { ops.impl("rms_norm_dynamic_per_token_quant", TORCH_BOX(&rms_norm_dynamic_per_token_quant)); ops.impl("rms_norm_per_block_quant", TORCH_BOX(&rms_norm_per_block_quant)); + ops.impl("silu_and_mul_per_block_quant", + TORCH_BOX(&silu_and_mul_per_block_quant)); // Positional encoding kernels (shared CUDA/ROCm) ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding)); @@ -661,6 +676,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2)); } +STABLE_TORCH_LIBRARY_IMPL(_C, CPU, ops) { + ops.impl("get_cuda_view_from_cpu_tensor", + TORCH_BOX(&get_cuda_view_from_cpu_tensor)); +} + // These capability-check functions take only primitive args (no tensors), so // there is no device to dispatch on. CompositeExplicitAutograd makes them // available for all backends. This is the stable ABI equivalent of calling @@ -681,6 +701,19 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) { ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size)); } +STABLE_TORCH_LIBRARY_FRAGMENT(_C_cuda_utils, cuda_utils) { + cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int"); + cuda_utils.def( + "get_max_shared_memory_per_block_device_attribute(int device_id) -> int"); +} + +STABLE_TORCH_LIBRARY_IMPL(_C_cuda_utils, CompositeExplicitAutograd, + cuda_utils) { + cuda_utils.impl("get_device_attribute", TORCH_BOX(&get_device_attribute)); + cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", + TORCH_BOX(&get_max_shared_memory_per_block_device_attribute)); +} + // Cache ops STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) { // Swap in (out) the cache blocks from src to dst. diff --git a/csrc/ops.h b/csrc/ops.h index ed2fca26b0d..11b704aff29 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); -void silu_and_mul_per_block_quant(torch::Tensor& out, - torch::Tensor const& input, - torch::Tensor& scales, int64_t group_size, - std::optional scale_ub, - bool is_scale_transposed); - // rotary_embedding also exist in csrc/libtorch_stable/ops.h (torch::stable // ABI for CUDA). It remains here because the CPU build still uses these // torch::Tensor declarations. @@ -84,8 +78,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& seq_lens, torch::Tensor const& page_table, double scale); -torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); - void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, std::optional const& azp); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3351638f574..f29b941affd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -2,7 +2,6 @@ // cache.h, which is no longer included here after cache ops moved to // _C_stable_libtorch). #include -#include "cuda_utils.h" #include "ops.h" #include "core/registration.h" #include @@ -32,27 +31,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); - ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor"); - ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU, - &get_cuda_view_from_cpu_tensor); - // Activation ops (quantized only — basic ops moved to _C_stable_libtorch) ops.def( "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); - // Fused SiLU+Mul + per-block quantization - ops.def( - "silu_and_mul_per_block_quant(" - "Tensor! out, " - "Tensor input, " - "Tensor! scales, " - "int group_size, " - "Tensor? scale_ub=None, " - "bool is_scale_transposed=False) -> ()"); - ops.impl("silu_and_mul_per_block_quant", torch::kCUDA, - &silu_and_mul_per_block_quant); - // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one // kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4 @@ -178,18 +161,4 @@ TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { } #endif -TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { - // Cuda utils - - // Gets the specified device attribute. - cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int"); - cuda_utils.impl("get_device_attribute", &get_device_attribute); - - // Gets the maximum shared memory per block device attribute. - cuda_utils.def( - "get_max_shared_memory_per_block_device_attribute(int device_id) -> int"); - cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute); -} - REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 5bf789a0919..80aec64ee5b 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -296,7 +296,7 @@ llm = LLM(model="Qwen/Qwen3-8B") The `fastokens` Python package (>= 0.2.0) must be installed; if it isn't, vLLM raises a clear `ImportError` at tokenizer load. The override applies to any `--tokenizer-mode` that ends up loading an HF fast tokenizer (`hf`, -`deepseek_v32`, `deepseek_v4`, `qwen_vl`, …). Modes that don't use the HF +`deepseek_v32`, `deepseek_v4`, `qwen_vl`, …). Models that don't use the HF fast tokenizer (`mistral`, `grok2`, `kimi_audio`) ignore the flag. Tokenizer-bound workloads — long shared prefixes, bursty short prompts, diff --git a/docs/contributing/README.md b/docs/contributing/README.md index 9b5e26d0fed..3fc8b6dd52b 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -101,7 +101,7 @@ vLLM's `pre-commit` hooks will now run automatically every time you commit. Some `pre-commit` hooks only run in CI. If you need to, you can run them locally with: ```bash - pre-commit run --hook-stage manual mypy-3.10 + pre-commit run --hook-stage manual mypy-3.11 ``` ### Documentation diff --git a/docs/design/nixl_kv_cache_lease.md b/docs/design/nixl_kv_cache_lease.md index a3fdaafe345..aa7683bb9e1 100644 --- a/docs/design/nixl_kv_cache_lease.md +++ b/docs/design/nixl_kv_cache_lease.md @@ -128,7 +128,7 @@ The lease mechanism is controlled through `kv_connector_extra_config` in `--kv-t vllm serve \ --kv-transfer-config '{ "kv_connector": "NixlConnector", - "kv_role": "kv_both", + "kv_role": "kv_producer", "kv_connector_extra_config": {"kv_lease_duration": 60} }' ``` diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index cb5a3dca035..0f0cbd55354 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -50,7 +50,7 @@ To select a different backend, set `kv_connector_extra_config.backends` in `--kv vllm serve \ --kv-transfer-config '{ "kv_connector":"NixlConnector", - "kv_role":"kv_both", + "kv_role":"kv_producer", "kv_connector_extra_config":{"backends":["LIBFABRIC"]} }' ``` @@ -60,7 +60,7 @@ You can also pass JSON keys individually using dotted arguments, and you can app ```bash vllm serve \ --kv-transfer-config.kv_connector NixlConnector \ - --kv-transfer-config.kv_role kv_both \ + --kv-transfer-config.kv_role kv_producer \ --kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC ``` @@ -81,7 +81,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ vllm serve Qwen/Qwen3-0.6B \ --port 8100 \ --enforce-eager \ - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}' + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer","kv_load_failure_policy":"fail"}' ``` ### Consumer (Decoder) Configuration @@ -96,7 +96,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \ vllm serve Qwen/Qwen3-0.6B \ --port 8200 \ --enforce-eager \ - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}' + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer","kv_load_failure_policy":"fail"}' ``` ### Proxy Server @@ -212,10 +212,21 @@ sequenceDiagram Enable bidirectional KV transfer by setting `bidirectional_kv_xfer` in `kv_connector_extra_config` on **both** P and D instances: ```bash +# Prefill instance vllm serve \ --kv-transfer-config '{ "kv_connector": "NixlConnector", - "kv_role": "kv_both", + "kv_role": "kv_producer", + "kv_connector_extra_config": { + "bidirectional_kv_xfer": true + } + }' + +# Decode instance +vllm serve \ + --kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_consumer", "kv_connector_extra_config": { "bidirectional_kv_xfer": true } @@ -359,11 +370,10 @@ For multi-host DP deployment, only need to provide the host/port of the head ins - **kv_producer**: For prefiller instances that generate KV caches - **kv_consumer**: For decoder instances that consume KV caches from prefiller -- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined. +- **kv_both** (deprecated): Previously used as a catch-all when the role was not predetermined. This value is now deprecated for NixlConnector and will be removed in a future release. -!!! tip - NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`). - Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior. +!!! warning + `kv_role="kv_both"` is deprecated for NixlConnector. Please set `kv_role="kv_producer"` for prefill instances and `kv_role="kv_consumer"` for decode instances. See [#33702](https://github.com/vllm-project/vllm/issues/33702) for details. ### KV Load Failure Policy diff --git a/docs/features/speculative_decoding/README.md b/docs/features/speculative_decoding/README.md index 768e9f78d40..58d1df9dced 100644 --- a/docs/features/speculative_decoding/README.md +++ b/docs/features/speculative_decoding/README.md @@ -169,7 +169,7 @@ speculative decoding, breaking down the guarantees into three key areas: > distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252) > - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling > without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, - > provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](/tests/v1/spec_decode). + > provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](../../../tests/v1/spec_decode). > verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291) 3. **vLLM Logprob Stability** diff --git a/mkdocs.yaml b/mkdocs.yaml index 1fee824f3b2..970bf963309 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -83,22 +83,22 @@ plugins: - "re:vllm\\._.*" # Internal modules - "vllm.third_party" - "vllm.vllm_flash_attn" - - "re:vllm\\.grpc\\..*_pb2.*" # Auto-generated protobuf files + - "vllm.transformers_utils.configs" + - "vllm.transformers_utils.processors" - !ENV [API_AUTONAV_EXCLUDE, "re:^$"] # Match nothing by default - mkdocstrings: handlers: python: options: - show_symbol_type_heading: true - show_symbol_type_toc: true - filters: - - "!.*_pb2_grpc" # Exclude auto-generated gRPC stubs - summary: - modules: true - show_signature_annotations: true - separate_signature: true + filters: [] show_overloads: true signature_crossrefs: true + # Recommendations from api-autonav + docstring_section_style: list + parameter_headings: true + show_symbol_type_heading: true + show_symbol_type_toc: true + summary: true inventories: - https://docs.python.org/3/objects.inv - https://typing-extensions.readthedocs.io/en/latest/objects.inv diff --git a/requirements/common.txt b/requirements/common.txt index 8141dc8ea6b..8b37f3cd30c 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -32,7 +32,7 @@ partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.17.0 -mistral_common[image] >= 1.11.2 +mistral_common[image] >= 1.11.3 opencv-python-headless >= 4.13.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/requirements/test/cuda.in b/requirements/test/cuda.in index 344a58ec1bb..8d7ad7d0aa2 100644 --- a/requirements/test/cuda.in +++ b/requirements/test/cuda.in @@ -31,7 +31,7 @@ torchaudio==2.11.0 torchvision==0.26.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.11.2 # required for voxtral test +mistral_common[image,audio] >= 1.11.3 # required for voxtral test num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py opencv-python-headless >= 4.13.0 # required for video test diff --git a/requirements/test/cuda.txt b/requirements/test/cuda.txt index 7d847d10577..a3e1466c763 100644 --- a/requirements/test/cuda.txt +++ b/requirements/test/cuda.txt @@ -409,7 +409,7 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.11.2 +mistral-common==1.11.3 # via # -c requirements/common.txt # -r requirements/test/cuda.in diff --git a/requirements/test/nightly-torch.txt b/requirements/test/nightly-torch.txt index 89fd4ea9b43..10eb7a62191 100644 --- a/requirements/test/nightly-torch.txt +++ b/requirements/test/nightly-torch.txt @@ -23,7 +23,7 @@ jiwer # required for audio tests timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.11.2 # required for voxtral test +mistral_common[image,audio] >= 1.11.3 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.13.0 # required for video test datamodel_code_generator # required for minicpm3 test diff --git a/requirements/test/rocm.in b/requirements/test/rocm.in index 0a615831774..ed10270f565 100644 --- a/requirements/test/rocm.in +++ b/requirements/test/rocm.in @@ -30,7 +30,7 @@ tblib # for pickling test exceptions timm>=1.0.17 # required for internvl and gemma3n-mm test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio]>=1.11.2 # required for voxtral test +mistral_common[image,audio]>=1.11.3 # required for voxtral test num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py opencv-python-headless>=4.13.0 # required for video test diff --git a/requirements/test/rocm.txt b/requirements/test/rocm.txt index e0232d8b6d3..eced7117116 100644 --- a/requirements/test/rocm.txt +++ b/requirements/test/rocm.txt @@ -512,7 +512,7 @@ mcp==1.27.0 # via -r requirements/test/../common.txt mdurl==0.1.2 # via markdown-it-py -mistral-common==1.11.2 +mistral-common==1.11.3 # via # -c requirements/common.txt # -r requirements/test/../common.txt diff --git a/requirements/test/xpu.txt b/requirements/test/xpu.txt index 5581d0a079c..6242ff420f6 100644 --- a/requirements/test/xpu.txt +++ b/requirements/test/xpu.txt @@ -266,7 +266,7 @@ mbstrdecoder==1.1.4 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.11.2 +mistral-common==1.11.3 # via # -c requirements/common.txt # -r requirements/test/xpu.in diff --git a/rust/src/chat/src/backend/hf.rs b/rust/src/chat/src/backend/hf.rs index 6c3dddc8729..77ed24de854 100644 --- a/rust/src/chat/src/backend/hf.rs +++ b/rust/src/chat/src/backend/hf.rs @@ -38,13 +38,17 @@ impl HfChatBackend { ) -> Result { let model_config = load_model_config(files.config_path.as_deref())?; let model_type = model_config.model_type().unwrap_or_default(); - let multimodal_model_info = MultimodalModelInfo::from_paths( - model_id.clone(), - (!model_type.is_empty()).then_some(model_type.to_string()), - files.config_path.as_deref(), - files.preprocessor_config_path.as_deref(), - tokenizer.clone(), - )?; + let multimodal_model_info = if options.language_model_only { + None + } else { + MultimodalModelInfo::from_paths( + model_id.clone(), + (!model_type.is_empty()).then_some(model_type.to_string()), + files.config_path.as_deref(), + files.preprocessor_config_path.as_deref(), + tokenizer.clone(), + )? + }; let multimodal_render_info = resolve_multimodal_render_info(multimodal_model_info.as_ref()); let renderer = options.renderer.resolve(model_type); @@ -225,6 +229,7 @@ mod tests { "test-model".to_string(), LoadModelBackendsOptions { renderer, + language_model_only: false, chat_template_content_format: Default::default(), chat_template: None, default_chat_template_kwargs: HashMap::new(), @@ -267,6 +272,54 @@ mod tests { assert_eq!(prompt, "hello"); } + #[test] + fn language_model_only_skips_multimodal_preprocessor_config() { + let mut files = resolved_files( + r#"{"model_type":"deepseek_v0_vl"}"#, + r#"{"chat_template":"{{ messages[0].content }}"}"#, + ); + let preprocessor_config_path = files + .config_path + .as_ref() + .unwrap() + .parent() + .unwrap() + .join("preprocessor_config.json"); + write_json(&preprocessor_config_path, r#"{"size":[672,672]}"#); + files.preprocessor_config_path = Some(preprocessor_config_path); + + let backend = HfChatBackend::from_resolved_model_files( + files.clone(), + "test-model".to_string(), + LoadModelBackendsOptions { + language_model_only: true, + chat_template_content_format: Default::default(), + chat_template: None, + default_chat_template_kwargs: HashMap::new(), + ..Default::default() + }, + test_tokenizer(), + ) + .unwrap(); + + assert!(backend.multimodal_model_info().is_none()); + + let error = HfChatBackend::from_resolved_model_files( + files, + "test-model".to_string(), + LoadModelBackendsOptions { + chat_template_content_format: Default::default(), + chat_template: None, + default_chat_template_kwargs: HashMap::new(), + ..Default::default() + }, + test_tokenizer(), + ) + .err() + .expect("invalid preprocessor config should fail without language_model_only"); + assert!(error.to_string().contains("failed to parse preprocessor_config.json")); + } + #[test] fn explicit_deepseek_renderer_overrides_generic_model_type() { let prompt = render_prompt( diff --git a/rust/src/chat/src/backend/mod.rs b/rust/src/chat/src/backend/mod.rs index f49ca673704..be609ba5d9e 100644 --- a/rust/src/chat/src/backend/mod.rs +++ b/rust/src/chat/src/backend/mod.rs @@ -60,6 +60,9 @@ pub type DynChatTextBackend = Arc; pub struct LoadModelBackendsOptions { /// Which chat renderer implementation to use. pub renderer: RendererSelection, + /// Disable frontend-side multimodal preprocessing and render the model as + /// language-only. + pub language_model_only: bool, /// How to serialize `message.content` when rendering the chat template. pub chat_template_content_format: ChatTemplateContentFormatOption, /// Optional server-default chat template override, provided either as an diff --git a/rust/src/cmd/src/cli.rs b/rust/src/cmd/src/cli.rs index ee7848fe0be..312d9365a0f 100644 --- a/rust/src/cmd/src/cli.rs +++ b/rust/src/cmd/src/cli.rs @@ -116,6 +116,10 @@ pub struct SharedRuntimeArgs { #[arg(long = "tokenizer-mode", default_value_t)] #[serde(default, rename = "tokenizer_mode")] pub renderer: RendererSelection, + /// Disable multimodal inputs and treat the model as language-only. + #[arg(long)] + #[serde(default)] + pub language_model_only: bool, /// Override the maximum model context length. When set, the frontend uses /// this value instead of the model's `max_position_embeddings` from /// `config.json`. @@ -243,6 +247,7 @@ impl SharedRuntimeArgs { tool_call_parser: self.tool_call_parser, reasoning_parser: self.reasoning_parser, renderer: self.renderer, + language_model_only: self.language_model_only, chat_template: self.chat_template, default_chat_template_kwargs: self.default_chat_template_kwargs, chat_template_content_format: self.chat_template_content_format, @@ -284,6 +289,7 @@ impl SharedRuntimeArgs { tool_call_parser: self.tool_call_parser, reasoning_parser: self.reasoning_parser, renderer: self.renderer, + language_model_only: self.language_model_only, chat_template: self.chat_template, default_chat_template_kwargs: self.default_chat_template_kwargs, chat_template_content_format: self.chat_template_content_format, @@ -419,6 +425,7 @@ impl ServeArgs { self.managed_engine.clone().into_config( self.runtime.model.clone(), self.runtime.max_model_len, + self.runtime.language_model_only, handshake_port, ) } diff --git a/rust/src/cmd/src/cli/tests.rs b/rust/src/cmd/src/cli/tests.rs index ea867e4673a..c56db543355 100644 --- a/rust/src/cmd/src/cli/tests.rs +++ b/rust/src/cmd/src/cli/tests.rs @@ -34,6 +34,7 @@ fn serve_args_forward_python_flags_with_separator() { tool_call_parser: Auto, reasoning_parser: Auto, renderer: Auto, + language_model_only: false, max_model_len: Some( 512, ), @@ -263,6 +264,7 @@ fn frontend_args_accept_json() { tool_call_parser: Auto, reasoning_parser: Auto, renderer: Auto, + language_model_only: false, max_model_len: None, grpc_port: None, shutdown_timeout: 0, @@ -321,7 +323,7 @@ fn frontend_args_json_accepts_supported_non_default_fields() { "--output-address", "ipc:///tmp/output.sock", "--args-json", - r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_ready_timeout_secs":42,"tool_call_parser":"hermes","reasoning_parser":"qwen3_thinking","tokenizer_mode":"deepseek_v32","max_model_len":8192,"shutdown_timeout":3}"#, + r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_ready_timeout_secs":42,"tool_call_parser":"hermes","reasoning_parser":"qwen3_thinking","tokenizer_mode":"deepseek_v32","language_model_only":true,"max_model_len":8192,"shutdown_timeout":3}"#, ]) .unwrap(); @@ -338,6 +340,7 @@ fn frontend_args_json_accepts_supported_non_default_fields() { ParserSelection::Explicit("qwen3_thinking".to_string()) ); assert_eq!(args.runtime.renderer, RendererSelection::DeepSeekV32); + assert!(args.runtime.language_model_only); assert_eq!(args.runtime.max_model_len, Some(8192)); assert_eq!(args.runtime.shutdown_timeout, 3); } @@ -662,6 +665,7 @@ fn serve_args_accept_handshake_aliases() { tool_call_parser: Auto, reasoning_parser: Auto, renderer: Auto, + language_model_only: false, max_model_len: None, grpc_port: None, shutdown_timeout: 0, @@ -783,6 +787,7 @@ fn serve_frontend_config_uses_dp_address_as_advertised_host() { tool_call_parser: Auto, reasoning_parser: Auto, renderer: Auto, + language_model_only: false, chat_template: None, default_chat_template_kwargs: None, chat_template_content_format: Auto, @@ -846,6 +851,7 @@ fn serve_frontend_config_keeps_tcp_transport_for_non_local_only_topology() { tool_call_parser: Auto, reasoning_parser: Auto, renderer: Auto, + language_model_only: false, chat_template: None, default_chat_template_kwargs: None, chat_template_content_format: Auto, @@ -924,6 +930,7 @@ fn frontend_config_uses_external_coordinator_when_coordinator_address_is_present tool_call_parser: Auto, reasoning_parser: Auto, renderer: Auto, + language_model_only: false, chat_template: None, default_chat_template_kwargs: None, chat_template_content_format: Auto, diff --git a/rust/src/engine-core-client/src/client/imp.rs b/rust/src/engine-core-client/src/client/imp.rs index 9a66ad84cc1..44eae7e3e54 100644 --- a/rust/src/engine-core-client/src/client/imp.rs +++ b/rust/src/engine-core-client/src/client/imp.rs @@ -1,5 +1,4 @@ use std::collections::BTreeMap; -use std::slice; use std::sync::Arc; use arc_swap::ArcSwapOption; @@ -253,33 +252,47 @@ pub(crate) async fn run_abort_loop( inner: Arc, mut abort_rx: mpsc::UnboundedReceiver, ) { - // TODO: receive and abort requests in batch - while let Some(AbortRequest { request_id, cause }) = abort_rx.recv().await { - let Some(engine_id) = inner.take_auto_abort_target(&request_id) else { - debug!(request_id, "skip auto-abort for inactive request"); - continue; - }; + // Coalesce bursts of auto-aborts into a single Abort message per engine. + // A dropped-stream storm (e.g. many clients disconnecting at once under + // high concurrency) would otherwise issue one engine round-trip per + // request. `recv_many` returns as soon as at least one item is ready, so a + // lone abort is still forwarded promptly. + const MAX_DRAIN: usize = 1024; + let mut batch: Vec = Vec::new(); - match cause { - AbortCause::DroppedStream => { - info!(request_id, "auto-aborting request due to dropped stream") - } - AbortCause::StopStringMatched => { - debug!( - request_id, - "auto-aborting request due to stop string matched" - ) + while abort_rx.recv_many(&mut batch, MAX_DRAIN).await > 0 { + let mut by_engine: BTreeMap> = BTreeMap::new(); + + for AbortRequest { request_id, cause } in batch.drain(..) { + let Some(engine_id) = inner.take_auto_abort_target(&request_id) else { + debug!(request_id, "skip auto-abort for inactive request"); + continue; + }; + + match cause { + AbortCause::DroppedStream => { + info!(request_id, "auto-aborting request due to dropped stream") + } + AbortCause::StopStringMatched => { + debug!( + request_id, + "auto-aborting request due to stop string matched" + ) + } } + + by_engine.entry(engine_id).or_default().push(request_id); } - if let Err(error) = inner.do_abort_requests(&engine_id, slice::from_ref(&request_id)).await - { - warn!( - request_id, - ?engine_id, - error = %error.as_report(), - "failed to auto-abort dropped request stream" - ); + for (engine_id, request_ids) in by_engine { + if let Err(error) = inner.do_abort_requests(&engine_id, &request_ids).await { + warn!( + ?engine_id, + ?request_ids, + error = %error.as_report(), + "failed to auto-abort request streams" + ); + } } } } diff --git a/rust/src/engine-core-client/src/tests/client.rs b/rust/src/engine-core-client/src/tests/client.rs index 9a92ffe447e..32530d6d385 100644 --- a/rust/src/engine-core-client/src/tests/client.rs +++ b/rust/src/engine-core-client/src/tests/client.rs @@ -1225,6 +1225,86 @@ async fn dropping_a_live_stream_triggers_abort() { client.shutdown().await.unwrap(); } +#[tokio::test] +async fn dropping_multiple_live_streams_aborts_all_in_a_burst() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-burst".to_vec(); + let request_ids = ["req-1", "req-2", "req-3"]; + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + for _ in 0..3 { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + } + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![ + request_output("req-1", vec![99], None), + request_output("req-2", vec![99], None), + request_output("req-3", vec![99], None), + ], + ..Default::default() + }, + ) + .await; + + let abort = + timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap(); + assert_eq!(abort[0].as_ref(), &[0x01]); + let ids: Vec = rmp_serde::from_slice(&abort[1]).unwrap(); + assert_eq!( + ids, + vec![ + "req-1".to_string(), + "req-2".to_string(), + "req-3".to_string() + ] + ); + assert!( + timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err() + ); + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + // Open every request first so all three adds reach the engine before it + // emits outputs, then drain the first token from each stream. + let mut streams = Vec::new(); + for id in request_ids { + streams.push(client.call(sample_request_with_id(id)).await.unwrap()); + } + for stream in streams.iter_mut() { + let first = timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap(); + assert_eq!(first.new_token_ids, vec![99]); + } + // Drop the whole burst back-to-back so the abort worker can batch them. + drop(streams); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn dispatcher_failure_propagates_to_streams_and_future_calls() { init_tracing(); diff --git a/rust/src/managed-engine/src/cli.rs b/rust/src/managed-engine/src/cli.rs index 302737dbd88..d70870dc32a 100644 --- a/rust/src/managed-engine/src/cli.rs +++ b/rust/src/managed-engine/src/cli.rs @@ -71,6 +71,7 @@ impl ManagedEngineArgs { self, model: String, max_model_len: Option, + language_model_only: bool, handshake_port: u16, ) -> ManagedEngineConfig { let mut python_args = self.python_args; @@ -79,6 +80,9 @@ impl ManagedEngineArgs { python_args.push("--max-model-len".to_string()); python_args.push(max_model_len.to_string()); } + if language_model_only { + python_args.push("--language-model-only".to_string()); + } if let Some(data_parallel_size_local) = self.data_parallel_size_local { python_args.push("--data-parallel-size-local".to_string()); python_args.push(data_parallel_size_local.to_string()); diff --git a/rust/src/server/examples/external_engine_openai_qwen.rs b/rust/src/server/examples/external_engine_openai_qwen.rs index 50d6fc1be40..33da868876b 100644 --- a/rust/src/server/examples/external_engine_openai_qwen.rs +++ b/rust/src/server/examples/external_engine_openai_qwen.rs @@ -64,6 +64,7 @@ async fn main() -> Result<()> { tool_call_parser: ParserSelection::Auto, reasoning_parser: ParserSelection::Auto, renderer: RendererSelection::Auto, + language_model_only: false, chat_template: None, default_chat_template_kwargs: None, chat_template_content_format: ChatTemplateContentFormatOption::Auto, diff --git a/rust/src/server/src/config.rs b/rust/src/server/src/config.rs index f1599d18793..3018f57377a 100644 --- a/rust/src/server/src/config.rs +++ b/rust/src/server/src/config.rs @@ -53,6 +53,9 @@ pub struct Config { pub reasoning_parser: ParserSelection, /// Chat renderer selection. pub renderer: RendererSelection, + /// Disable frontend-side multimodal preprocessing and render the model as + /// language-only. + pub language_model_only: bool, /// Server-default chat template override, as a file path or inline /// template. pub chat_template: Option, diff --git a/rust/src/server/src/lib.rs b/rust/src/server/src/lib.rs index 8d779da132f..52ac7633292 100644 --- a/rust/src/server/src/lib.rs +++ b/rust/src/server/src/lib.rs @@ -42,6 +42,7 @@ async fn build_state(config: &Config) -> Result> { &config.model, LoadModelBackendsOptions { renderer: config.renderer, + language_model_only: config.language_model_only, chat_template: config.chat_template.clone(), chat_template_content_format: config.chat_template_content_format, default_chat_template_kwargs: config diff --git a/rust/src/server/src/routes/openai/chat_completions.rs b/rust/src/server/src/routes/openai/chat_completions.rs index 543a7e806c4..0cccd8bf4ab 100644 --- a/rust/src/server/src/routes/openai/chat_completions.rs +++ b/rust/src/server/src/routes/openai/chat_completions.rs @@ -85,6 +85,7 @@ pub async fn chat_completions( log_request, prepared.include_usage, prepared.requested_logprobs, + prepared.include_reasoning, prepared.echo, prepared.return_token_ids, prepared.return_tokens_as_token_ids, @@ -100,6 +101,7 @@ pub async fn chat_completions( created, prepared.requested_logprobs, prepared.include_prompt_logprobs, + prepared.include_reasoning, prepared.echo, prepared.return_token_ids, prepared.return_tokens_as_token_ids, @@ -134,6 +136,7 @@ async fn collect_chat_completion( created: u64, requested_logprobs: bool, include_prompt_logprobs: bool, + include_reasoning: bool, echo: Option, return_token_ids: bool, return_tokens_as_token_ids: bool, @@ -157,6 +160,11 @@ async fn collect_chat_completion( } = collected; let stop_reason = finish_reason.as_stop_reason().map(stop_reason_to_json); let saw_tool_calls = message.tool_calls().next().is_some(); + let reasoning = message.reasoning(); + // Output logprobs and token IDs cover the complete generated token stream. + // When reasoning is hidden, omit them rather than leaking hidden reasoning + // tokens through per-token metadata. + let include_output_metadata = include_reasoning || reasoning.is_none(); let finish_reason = chat_finish_reason_to_openai(&finish_reason, saw_tool_calls)?.to_string(); let tool_calls = message .tool_calls() @@ -169,7 +177,7 @@ async fn collect_chat_completion( }, }) .collect::>(); - let logprobs = if requested_logprobs { + let logprobs = if requested_logprobs && include_output_metadata { Some(decoded_logprobs_to_openai_chat( logprobs.as_ref().ok_or_else(|| { server_error!("chat response requested logprobs but generation returned none") @@ -207,12 +215,12 @@ async fn collect_chat_completion( None => Some(message.text()).filter(|t| !t.is_empty()), }, tool_calls: Some(tool_calls).filter(|calls| !calls.is_empty()), - reasoning: message.reasoning(), + reasoning: if include_reasoning { reasoning } else { None }, }, logprobs, finish_reason: Some(finish_reason), stop_reason, - token_ids: return_token_ids.then_some(token_ids), + token_ids: (return_token_ids && include_output_metadata).then_some(token_ids), }], usage: Some(usage), system_fingerprint: None, @@ -232,12 +240,18 @@ async fn chat_completion_chunk_stream( log_request: bool, include_usage: bool, requested_logprobs: bool, + include_reasoning: bool, echo: Option, return_token_ids: bool, return_tokens_as_token_ids: bool, mut y: TryYielder, ) -> Result<(), ApiError> { let mut saw_tool_calls = false; + // `LogprobsDelta` is emitted after all chat events for one decoded update. + // If that update contains hidden reasoning, including delimiter-only block + // starts or ends, omit its token metadata as well as its visible delta. + let mut inside_hidden_reasoning = false; + let mut suppress_current_update_metadata = false; // If the client requested logprobs or token_ids, we need to buffer chunks until // we receive the separate `LogprobsDelta` event, so that we can emit one @@ -268,29 +282,44 @@ async fn chat_completion_chunk_stream( } } Ok(ChatEvent::BlockDelta { kind, delta, .. }) => { - if let Some(pending_chunk) = pending_chunk.as_mut() { - pending_chunk.push_block_delta(kind, delta); + let include_delta = + include_reasoning || !matches!(kind, AssistantBlockKind::Reasoning); + if include_delta { + if let Some(pending_chunk) = pending_chunk.as_mut() { + pending_chunk.push_block_delta(kind, delta); + } else { + y.yield_ok(block_delta_chunk( + &request_id, + &response_model, + created, + kind, + delta, + )) + .await; + } } else { - y.yield_ok(block_delta_chunk( - &request_id, - &response_model, - created, - kind, - delta, - )) - .await; + suppress_current_update_metadata = true; } } Ok(ChatEvent::LogprobsDelta { logprobs, token_ids, }) => { - let openai_logprobs = logprobs - .as_ref() - .map(|lp| decoded_logprobs_to_openai_chat(lp, return_tokens_as_token_ids)) - .transpose()?; - let openai_token_ids = - return_token_ids.then_some(token_ids).filter(|t| !t.is_empty()); + let include_metadata = + !suppress_current_update_metadata && !inside_hidden_reasoning; + suppress_current_update_metadata = false; + let openai_logprobs = if include_metadata { + logprobs + .as_ref() + .map(|lp| decoded_logprobs_to_openai_chat(lp, return_tokens_as_token_ids)) + .transpose()? + } else { + None + }; + let openai_token_ids = include_metadata + .then_some(token_ids) + .and_then(|token_ids| return_token_ids.then_some(token_ids)) + .filter(|t| !t.is_empty()); if let Some(pending_chunk) = pending_chunk.as_mut() { pending_chunk.logprobs = openai_logprobs; pending_chunk.token_ids = openai_token_ids; @@ -311,9 +340,17 @@ async fn chat_completion_chunk_stream( } Ok(ChatEvent::BlockStart { kind, .. }) => { debug!(?kind, "starting new block"); + if !include_reasoning && matches!(kind, AssistantBlockKind::Reasoning) { + inside_hidden_reasoning = true; + suppress_current_update_metadata = true; + } } Ok(ChatEvent::BlockEnd { .. }) => { debug!("ending current block"); + if inside_hidden_reasoning { + inside_hidden_reasoning = false; + suppress_current_update_metadata = true; + } } Ok(ChatEvent::ToolCallStart { index, id, name }) => { let tool_index = index as u32; @@ -763,7 +800,9 @@ fn stop_reason_to_json(stop_reason: &StopReason) -> Value { mod tests { use futures::{StreamExt as _, stream}; use serde_json::json; - use vllm_chat::{AssistantBlockKind, AssistantToolCall, ChatEvent, FinishReason}; + use vllm_chat::{ + AssistantBlockKind, AssistantContentBlock, AssistantToolCall, ChatEvent, FinishReason, + }; use vllm_engine_core_client::protocol::StopReason; use vllm_text::{DecodedLogprobs, DecodedPositionLogprobs, DecodedTokenLogprob}; @@ -895,6 +934,7 @@ mod tests { false, false, true, + true, None, false, false, @@ -958,6 +998,7 @@ mod tests { false, false, true, + true, None, false, false, @@ -976,6 +1017,292 @@ mod tests { assert!(chunks[1].choices[0].logprobs.is_some()); } + #[tokio::test] + async fn chunk_stream_omits_reasoning_delta_when_disabled() { + let stream = stream::iter(vec![ + Ok(ChatEvent::Start { + prompt_token_ids: vec![].into(), + prompt_logprobs: None, + }), + Ok(ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Reasoning, + delta: "think".to_string(), + }), + Ok(ChatEvent::BlockDelta { + index: 1, + kind: AssistantBlockKind::Text, + delta: "answer".to_string(), + }), + Ok(ChatEvent::Done { + message: Default::default(), + prompt_token_count: 1, + output_token_count: 2, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let chunks = chat_completion_chunk_stream( + stream, + "chatcmpl-1".to_string(), + "model".to_string(), + 1, + false, + false, + false, + false, + None, + false, + false, + ) + .collect::>() + .await + .into_iter() + .collect::, _>>() + .expect("stream chunks"); + + assert_eq!(chunks.len(), 3); + assert_eq!( + chunks[1].choices[0].delta.content.as_deref(), + Some("answer") + ); + assert!( + chunks + .iter() + .all(|chunk| chunk.choices.iter().all(|choice| choice.delta.reasoning.is_none())) + ); + } + + #[tokio::test] + async fn chunk_stream_omits_logprobs_for_suppressed_reasoning() { + let stream = stream::iter(vec![ + Ok(ChatEvent::Start { + prompt_token_ids: vec![].into(), + prompt_logprobs: None, + }), + Ok(ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Reasoning, + delta: "think".to_string(), + }), + Ok(ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 11, + token: "think".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }), + token_ids: vec![11], + }), + Ok(ChatEvent::BlockDelta { + index: 1, + kind: AssistantBlockKind::Text, + delta: "answer".to_string(), + }), + Ok(ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 22, + token: "answer".to_string(), + logprob: -0.2, + rank: 1, + }], + }], + }), + token_ids: vec![22], + }), + Ok(ChatEvent::Done { + message: Default::default(), + prompt_token_count: 1, + output_token_count: 2, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let chunks = chat_completion_chunk_stream( + stream, + "chatcmpl-1".to_string(), + "model".to_string(), + 1, + false, + false, + true, + false, + None, + true, + false, + ) + .collect::>() + .await + .into_iter() + .collect::, _>>() + .expect("stream chunks"); + + assert_eq!(chunks.len(), 3); + let choice = &chunks[1].choices[0]; + assert_eq!(choice.delta.content.as_deref(), Some("answer")); + assert_eq!(choice.token_ids.as_deref(), Some(&[22][..])); + let logprobs = choice.logprobs.as_ref().expect("answer logprobs"); + let content = logprobs.content.as_ref().expect("logprobs content"); + assert_eq!(content[0].token, "answer"); + assert!(chunks.iter().all(|chunk| { + chunk.choices.iter().all(|choice| { + choice.delta.reasoning.is_none() + && choice.token_ids.as_deref() != Some(&[11][..]) + && choice + .logprobs + .as_ref() + .and_then(|logprobs| logprobs.content.as_ref()) + .is_none_or(|content| content.iter().all(|entry| entry.token != "think")) + }) + })); + } + + #[tokio::test] + async fn chunk_stream_omits_logprobs_for_hidden_reasoning_delimiters() { + let stream = stream::iter(vec![ + Ok(ChatEvent::Start { + prompt_token_ids: vec![].into(), + prompt_logprobs: None, + }), + Ok(ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Reasoning, + }), + Ok(ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 11, + token: "".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }), + token_ids: vec![11], + }), + Ok(ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Reasoning, + delta: "think".to_string(), + }), + Ok(ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 12, + token: "think".to_string(), + logprob: -0.2, + rank: 1, + }], + }], + }), + token_ids: vec![12], + }), + Ok(ChatEvent::BlockEnd { + index: 0, + block: AssistantContentBlock::Reasoning { + text: "think".to_string(), + }, + }), + Ok(ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 13, + token: "".to_string(), + logprob: -0.3, + rank: 1, + }], + }], + }), + token_ids: vec![13], + }), + Ok(ChatEvent::BlockStart { + index: 1, + kind: AssistantBlockKind::Text, + }), + Ok(ChatEvent::BlockDelta { + index: 1, + kind: AssistantBlockKind::Text, + delta: "answer".to_string(), + }), + Ok(ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 22, + token: "answer".to_string(), + logprob: -0.4, + rank: 1, + }], + }], + }), + token_ids: vec![22], + }), + Ok(ChatEvent::Done { + message: Default::default(), + prompt_token_count: 1, + output_token_count: 4, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let chunks = chat_completion_chunk_stream( + stream, + "chatcmpl-1".to_string(), + "model".to_string(), + 1, + false, + false, + true, + false, + None, + true, + false, + ) + .collect::>() + .await + .into_iter() + .collect::, _>>() + .expect("stream chunks"); + + assert_eq!(chunks.len(), 3); + let choice = &chunks[1].choices[0]; + assert_eq!(choice.delta.content.as_deref(), Some("answer")); + assert_eq!(choice.token_ids.as_deref(), Some(&[22][..])); + let logprobs = choice.logprobs.as_ref().expect("answer logprobs"); + let content = logprobs.content.as_ref().expect("logprobs content"); + assert_eq!(content[0].token, "answer"); + assert!(chunks.iter().all(|chunk| { + chunk.choices.iter().all(|choice| { + choice.delta.reasoning.is_none() + && !choice + .token_ids + .as_ref() + .is_some_and(|ids| matches!(ids.as_slice(), [11] | [12] | [13])) + && choice + .logprobs + .as_ref() + .and_then(|logprobs| logprobs.content.as_ref()) + .is_none_or(|content| { + content.iter().all(|entry| { + !matches!(entry.token.as_str(), "" | "think" | "") + }) + }) + }) + })); + } + #[tokio::test] async fn chunk_stream_preserves_tool_call_index_and_omits_id_from_arguments_delta() { let stream = stream::iter(vec![ @@ -1017,6 +1344,7 @@ mod tests { false, false, false, + true, None, false, false, diff --git a/rust/src/server/src/routes/openai/chat_completions/convert.rs b/rust/src/server/src/routes/openai/chat_completions/convert.rs index 2701bef809c..ba75c9c1a8c 100644 --- a/rust/src/server/src/routes/openai/chat_completions/convert.rs +++ b/rust/src/server/src/routes/openai/chat_completions/convert.rs @@ -29,6 +29,8 @@ pub struct PreparedRequest { pub requested_logprobs: bool, /// Whether the caller requested top-level prompt logprobs. pub include_prompt_logprobs: bool, + /// Whether to include reasoning content in OpenAI responses. + pub include_reasoning: bool, /// Lowered chat request for `vllm-chat`. pub chat_request: ChatRequest, /// Last assistant-role message content to echo back when `echo=true`. @@ -57,6 +59,7 @@ pub(crate) fn prepare_chat_request( .as_ref() .map(|request| request.lora_name.clone()) .unwrap_or_else(|| lora_resolution.model_names.first().cloned().unwrap_or_default()); + let include_reasoning = request.include_reasoning; let echo = request .echo .then(|| extract_last_assistant_content(&request.messages)) @@ -146,6 +149,7 @@ pub(crate) fn prepare_chat_request( include_usage, requested_logprobs, include_prompt_logprobs, + include_reasoning, chat_request, echo, return_token_ids: request.return_token_ids.unwrap_or(false), @@ -480,6 +484,23 @@ mod tests { assert_eq!(prepared.chat_request.tool_choice, ChatToolChoice::Auto); } + #[test] + fn prepare_chat_request_preserves_include_reasoning_false() { + let request = ChatCompletionRequest { + include_reasoning: false, + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert!(!prepared.include_reasoning); + } + #[test] fn prepare_chat_request_preserves_sampling_passthrough_fields() { let request = ChatCompletionRequest { diff --git a/rust/src/server/src/routes/openai/chat_completions/validate.rs b/rust/src/server/src/routes/openai/chat_completions/validate.rs index fbd10eea0cb..379f2c2d39b 100644 --- a/rust/src/server/src/routes/openai/chat_completions/validate.rs +++ b/rust/src/server/src/routes/openai/chat_completions/validate.rs @@ -120,12 +120,6 @@ pub(super) fn validate_request_compat( "thinking_token_budget", "thinking_token_budget is not supported.", )?; - if !request.include_reasoning { - bail_invalid_request!( - param = "include_reasoning", - "include_reasoning is not supported." - ); - } reject_non_default( request.media_io_kwargs.as_ref(), "media_io_kwargs", @@ -312,6 +306,17 @@ mod tests { .expect("reasoning_effort should be accepted"); } + #[test] + fn validate_request_compat_accepts_include_reasoning_false() { + let request = ChatCompletionRequest { + include_reasoning: false, + ..base_request() + }; + + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("include_reasoning=false should be accepted"); + } + #[test] fn validate_request_compat_rejects_top_logprobs_without_logprobs() { let request = ChatCompletionRequest { diff --git a/rust/src/server/src/routes/tests.rs b/rust/src/server/src/routes/tests.rs index a3e437e0480..0a269f4ce17 100644 --- a/rust/src/server/src/routes/tests.rs +++ b/rust/src/server/src/routes/tests.rs @@ -3492,6 +3492,170 @@ async fn reasoning_blocks_are_mapped_to_reasoning_sse_chunks() { assert!(text.contains("\"content\":\"answer\""), "{text}"); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn include_reasoning_false_suppresses_reasoning_in_non_stream_chat() { + let (app, engine_task) = test_app_with_backend_and_stream_output_specs( + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")), + vec![ + (bytes_to_token_ids(b"think"), None), + ( + bytes_to_token_ids(b"answer"), + Some(EngineCoreFinishReason::Length), + ), + ], + ) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": false, + "include_reasoning": false, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let json: serde_json::Value = serde_json::from_str(&text).expect("decode json"); + + assert_eq!(json["choices"][0]["message"]["content"], "answer"); + assert!( + json["choices"][0]["message"] + .as_object() + .is_some_and(|message| !message.contains_key("reasoning")), + "{text}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn include_reasoning_false_suppresses_non_stream_output_metadata() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-hidden-reasoning-logprobs".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + let reasoning_token_ids = bytes_to_token_ids(b"think"); + let answer_token_ids = bytes_to_token_ids(b"answer"); + + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![ + request_output_with_logprobs( + &request.request_id, + reasoning_token_ids.clone(), + None, + None, + Some(sample_logprobs_for_tokens(&reasoning_token_ids)), + None, + ), + request_output_with_logprobs( + &request.request_id, + answer_token_ids.clone(), + Some(EngineCoreFinishReason::Length), + None, + Some(sample_logprobs_for_tokens(&answer_token_ids)), + None, + ), + ], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + }, + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend( + test_llm(client), + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")), + ); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": false, + "include_reasoning": false, + "logprobs": true, + "return_token_ids": true, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let json: serde_json::Value = serde_json::from_str(&text).expect("decode json"); + let choice = json["choices"][0].as_object().expect("choice object"); + + assert_eq!(json["choices"][0]["message"]["content"], "answer"); + assert!( + json["choices"][0]["message"] + .as_object() + .is_some_and(|message| !message.contains_key("reasoning")), + "{text}" + ); + assert!(!choice.contains_key("logprobs"), "{text}"); + assert!(!choice.contains_key("token_ids"), "{text}"); + assert!(json["prompt_token_ids"].is_array(), "{text}"); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn tool_calls_are_mapped_to_tool_call_sse_chunks() { diff --git a/rust/src/tokenizer/src/incremental.rs b/rust/src/tokenizer/src/incremental.rs index 7a025d35e5c..462475fc918 100644 --- a/rust/src/tokenizer/src/incremental.rs +++ b/rust/src/tokenizer/src/incremental.rs @@ -115,6 +115,8 @@ impl IncrementalDecoder for DecodeStream<'_, T> { fn next_chunk(&mut self) -> Option { let cutoff = self.cumulative_output.len().saturating_sub(self.min_bytes_to_buffer); + // Ensure we split at a utf-8 char boundary. + let cutoff = self.cumulative_output.floor_char_boundary(cutoff); (cutoff > self.output_index).then(|| { let chunk = self.cumulative_output[self.output_index..cutoff].to_string(); self.output_index = cutoff; @@ -356,4 +358,27 @@ mod tests { assert_eq!(last_chunk.as_deref(), Some("lo!")); assert_eq!(full_text, "Hello!"); } + + #[test] + fn next_chunk_cutoff_respects_char_boundary() { + // Regression: next_chunk's cutoff (len - min_bytes_to_buffer) must be + // aligned to a UTF-8 char boundary like push_token/flush; otherwise + // streaming multi-byte output (CJK/emoji) with a hold-back buffer (set + // by a stop string) panics slicing cumulative_output mid-character. + let backend = Utf8Backend; + let mut decoder = backend.create_decode_stream(&[], false, 2); + let mut out = String::new(); + for byte in "你好A".bytes() { + decoder.push_token(u32::from(byte)).unwrap(); + if let Some(chunk) = decoder.next_chunk() { + out.push_str(&chunk); + } + } + let (last_chunk, full_text) = decoder.flush(None).unwrap(); + if let Some(chunk) = last_chunk { + out.push_str(&chunk); + } + assert_eq!(full_text, "你好A"); + assert_eq!(out, "你好A"); + } } diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 4cb199b5897..8e20b704fac 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -565,6 +565,8 @@ def test_size_used_in_multiple_consumer_subgraphs(): torch._dynamo.mark_dynamic(x, 0) torch._dynamo.mark_dynamic(y, 0) torch.compile(model_fn, backend=capturing_backend)(x, y) + assert captured_graph is not None, "Graph should be captured by backend" + assert captured_inputs is not None, "Example inputs should be captured by backend" split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) diff --git a/tests/distributed/test_dcp_a2a.py b/tests/distributed/test_dcp_a2a.py index d80ed36be65..5ab0f3de97b 100644 --- a/tests/distributed/test_dcp_a2a.py +++ b/tests/distributed/test_dcp_a2a.py @@ -15,6 +15,7 @@ import pytest import torch import torch.distributed as dist +import vllm.envs as envs from vllm.config.parallel import ParallelConfig from vllm.utils.network_utils import get_open_port from vllm.utils.system_utils import update_environment_variables @@ -379,7 +380,13 @@ def _distributed_packed_a2a_worker(env: dict[str, str]) -> None: update_environment_variables(env) local_rank = int(env["LOCAL_RANK"]) torch.accelerator.set_device_index(local_rank) - dist.init_process_group(backend="nccl") + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + device_id=torch.device(f"cuda:{local_rank}"), + ) + else: + dist.init_process_group(backend="nccl") use_workspace = env.get("USE_WORKSPACE") == "1" if use_workspace: from vllm.v1.worker.workspace import init_workspace_manager diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index a1d5355d446..d7b04f68091 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -9,6 +9,7 @@ import pytest import torch import torch.distributed +import vllm.envs as envs from tests.utils import ensure_current_vllm_config from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator @@ -82,11 +83,18 @@ def test_pynccl(): @worker_fn_wrapper def multiple_allreduce_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") - groups = [ - torch.distributed.new_group(ranks=[0, 1], backend="gloo"), - torch.distributed.new_group(ranks=[2, 3], backend="gloo"), - ] - group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + # Eager-init path: parent PG has bound_device_id + a CPU backend, + # so split_group is supported. + group = torch.distributed.split_group( + split_ranks=[[0, 1], [2, 3]], backend="cpu:gloo,cuda:nccl" + ) + else: + groups = [ + torch.distributed.new_group(ranks=[0, 1], backend="gloo"), + torch.distributed.new_group(ranks=[2, 3], backend="gloo"), + ] + group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) # two groups can communicate independently @@ -339,11 +347,16 @@ def test_pynccl_send_recv(): @worker_fn_wrapper def multiple_send_recv_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") - groups = [ - torch.distributed.new_group(ranks=[0, 2], backend="gloo"), - torch.distributed.new_group(ranks=[1, 3], backend="gloo"), - ] - group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + group = torch.distributed.split_group( + split_ranks=[[0, 2], [1, 3]], backend="cpu:gloo,cuda:nccl" + ) + else: + groups = [ + torch.distributed.new_group(ranks=[0, 2], backend="gloo"), + torch.distributed.new_group(ranks=[1, 3], backend="gloo"), + ] + group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) if torch.distributed.get_rank() == 0: tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index a9591f96a78..86eb82c962e 100644 --- a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -9,6 +9,7 @@ import ray import torch import torch.distributed as dist +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.device_communicators.quick_all_reduce import ( @@ -397,13 +398,27 @@ def qr_variable_input(rank, world_size): ranks = [] for i in range(world_size): ranks.append(i) - dist.init_process_group( - backend="nccl", - init_method="tcp://127.0.0.1:29500", - rank=rank, - world_size=world_size, - ) - cpu_group = torch.distributed.new_group(ranks, backend="nccl") + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method="tcp://127.0.0.1:29500", + rank=rank, + world_size=world_size, + device_id=device, + ) + else: + dist.init_process_group( + backend="nccl", + init_method="tcp://127.0.0.1:29500", + rank=rank, + world_size=world_size, + ) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + cpu_group = torch.distributed.split_group( + split_ranks=[ranks], backend="cpu:gloo,cuda:nccl" + ) + else: + cpu_group = torch.distributed.new_group(ranks, backend="nccl") handle = ops.qr_get_handle(_ptr) world_size = dist.get_world_size(group=cpu_group) diff --git a/tests/distributed/test_split_group.py b/tests/distributed/test_split_group.py new file mode 100644 index 00000000000..54586c9e370 --- /dev/null +++ b/tests/distributed/test_split_group.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for split_group in GroupCoordinator. + +These tests verify that: +1. split_group is used for both device and CPU group creation. +2. Multiple subgroups work correctly with split_group. +3. Both GPU and CPU all-reduce work on split groups. +""" + +import os +from typing import Any + +import multiprocess as mp +import pytest +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.distributed.parallel_state import ( + GroupCoordinator, + init_distributed_environment, +) +from vllm.utils.system_utils import update_environment_variables + +# The whole module exercises the split_group code path, which is opt-in +# behind VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1. +pytestmark = pytest.mark.skipif( + not envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP, + reason=("VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1 not set; split_group path is opt-in."), +) + +mp.set_start_method("spawn", force=True) + + +def distributed_run(fn, world_size): + number_of_processes = world_size + processes: list[mp.Process] = [] + for i in range(number_of_processes): + env: dict[str, str] = {} + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12346" + # propagate the opt-in flag to the spawned child workers + env["VLLM_DISTRIBUTED_USE_SPLIT_GROUP"] = "1" + p = mp.Process(target=fn, args=(env,)) + processes.append(p) + p.start() + + for p in processes: + p.join() + + for p in processes: + assert p.exitcode == 0 + + +def worker_fn_wrapper(fn): + def wrapped_fn(env): + update_environment_variables(env) + local_rank = os.environ["LOCAL_RANK"] + device = torch.device(f"cuda:{local_rank}") + torch.accelerator.set_device_index(device) + init_distributed_environment() + fn() + + return wrapped_fn + + +def _verify_device_group(coordinator: GroupCoordinator): + """Verify device group works via all-reduce.""" + local_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{local_rank}") + tensor = torch.ones(16, 16, dtype=torch.float32, device=device) + torch.distributed.all_reduce(tensor, group=coordinator.device_group) + torch.accelerator.synchronize() + expected = coordinator.world_size + assert torch.all(tensor == expected).cpu().item(), ( + f"Device group all-reduce failed: expected {expected}, " + f"got {tensor.flatten()[0].item()}" + ) + + +def _verify_cpu_group(coordinator: GroupCoordinator): + """Verify CPU group works via all-reduce.""" + tensor = torch.ones(16, dtype=torch.float32) + torch.distributed.all_reduce(tensor, group=coordinator.cpu_group) + expected = coordinator.world_size + assert torch.all(tensor == expected).cpu().item(), ( + f"CPU group all-reduce failed: expected {expected}, " + f"got {tensor.flatten()[0].item()}" + ) + + +# --------------------------------------------------------------------------- +# Test 1: Basic split_group path with 2 GPUs +# --------------------------------------------------------------------------- +@worker_fn_wrapper +def split_group_basic_worker(): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + group_ranks = [list(range(world_size))] + + coordinator = GroupCoordinator( + group_ranks=group_ranks, + local_rank=rank, + torch_distributed_backend="nccl", + use_device_communicator=False, + group_name="test_split_basic", + ) + + _verify_device_group(coordinator) + _verify_cpu_group(coordinator) + + +@pytest.mark.skipif( + torch.accelerator.device_count() < 2, + reason="Need at least 2 GPUs to run the test.", +) +def test_split_group_basic(): + """Test basic GroupCoordinator creation with split_group.""" + distributed_run(split_group_basic_worker, 2) + + +# --------------------------------------------------------------------------- +# Test 2: Multiple subgroups with split_group (4 GPUs) +# --------------------------------------------------------------------------- +@worker_fn_wrapper +def split_group_multiple_subgroups_worker(): + rank = torch.distributed.get_rank() + group_ranks = [[0, 1], [2, 3]] + + coordinator = GroupCoordinator( + group_ranks=group_ranks, + local_rank=rank, + torch_distributed_backend="nccl", + use_device_communicator=False, + group_name="test_split_multi", + ) + + assert coordinator.world_size == 2 + + _verify_device_group(coordinator) + _verify_cpu_group(coordinator) + + if rank in [0, 1]: + assert coordinator.ranks == [0, 1] + else: + assert coordinator.ranks == [2, 3] + + +@pytest.mark.skipif( + torch.accelerator.device_count() < 4, + reason="Need at least 4 GPUs to run the test.", +) +def test_split_group_multiple_subgroups(): + """Test GroupCoordinator with multiple independent subgroups.""" + distributed_run(split_group_multiple_subgroups_worker, 4) + + +# --------------------------------------------------------------------------- +# Test 3: split_group contract — every parent rank must enter with the same +# ``split_ranks``. NCCL happens to produce +# correct subgroups for disjoint partitions because the wrapper hashes +# ``my_group`` to derive the comm-split color, but the contract violation is +# real and would break under non-partition / non-NCCL backends. This test +# captures the actual ``split_ranks`` argument passed on every rank and +# asserts they match. +# --------------------------------------------------------------------------- +@worker_fn_wrapper +def split_group_contract_worker(): + rank = torch.distributed.get_rank() + group_ranks = [[0, 1], [2, 3]] + + captured: list[list[list[int]]] = [] + original_split_group = torch.distributed.split_group + + def capturing_split_group(*args, split_ranks=None, **kwargs): + captured.append([list(g) for g in split_ranks]) + return original_split_group(*args, split_ranks=split_ranks, **kwargs) + + torch.distributed.split_group = capturing_split_group + try: + GroupCoordinator( + group_ranks=group_ranks, + local_rank=rank, + torch_distributed_backend="nccl", + use_device_communicator=False, + group_name="test_split_contract", + ) + finally: + torch.distributed.split_group = original_split_group + + # GroupCoordinator builds two subgroups (device + cpu) per coordinator, + # so every rank must have made exactly two split_group calls. + if len(captured) != 2: + raise AssertionError( + f"rank {rank} expected 2 split_group calls (device + cpu), " + f"got {len(captured)}: {captured}" + ) + + world_size = torch.distributed.get_world_size() + for call_idx in range(2): + gathered: list[Any] = [None] * world_size + torch.distributed.all_gather_object(gathered, captured[call_idx]) + # Normalize for stable comparison (sort each subgroup and the outer list). + norm = [ + sorted([sorted(sg) for sg in per_rank_args]) for per_rank_args in gathered + ] + reference = norm[0] + for r, args in enumerate(norm): + if args != reference: + raise AssertionError( + f"split_group contract violation on call #{call_idx}: " + f"rank {r} passed split_ranks={gathered[r]}, but rank 0 " + f"passed split_ranks={gathered[0]}. PyTorch requires every " + "parent rank to enter split_group with the same split_ranks." + ) + + +@pytest.mark.skipif( + torch.accelerator.device_count() < 4, + reason="Need at least 4 GPUs to run the test.", +) +def test_split_group_contract_same_split_ranks_on_all_ranks(): + """All parent ranks must call torch.distributed.split_group with the same + ``split_ranks`` argument. This catches the bug where each rank passed + only its own subgroup (``split_ranks=[ranks]``), which NCCL forgives for + disjoint partitions but is a documented contract violation. + """ + distributed_run(split_group_contract_worker, 4) diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index e72f00bc91e..670df2759b0 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -5,13 +5,26 @@ import os import random +import torch import torch.distributed as dist +import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.distributed.parallel_state import get_world_group -# Let PyTorch choose the WORLD backend for the current device type. -dist.init_process_group() +# By default, let PyTorch choose the WORLD backend for the current device +# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1, +# use the explicit eager-init pattern required by `split_group` (mixed +# cpu:gloo,cuda:nccl backend + device_id binding). +if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.accelerator.set_device_index(local_rank) + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + device_id=torch.device(f"cuda:{local_rank}"), + ) +else: + dist.init_process_group() # Create prompts prompts = [ diff --git a/tests/distributed/test_torchrun_example_moe.py b/tests/distributed/test_torchrun_example_moe.py index 969b5e92e3f..6f0957ed026 100644 --- a/tests/distributed/test_torchrun_example_moe.py +++ b/tests/distributed/test_torchrun_example_moe.py @@ -5,13 +5,26 @@ import os import random +import torch import torch.distributed as dist +import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.distributed.parallel_state import get_tp_group, get_world_group -# Let PyTorch choose the WORLD backend for the current device type. -dist.init_process_group() +# By default, let PyTorch choose the WORLD backend for the current device +# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1, +# use the explicit eager-init pattern required by `split_group` (mixed +# cpu:gloo,cuda:nccl backend + device_id binding). +if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.accelerator.set_device_index(local_rank) + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + device_id=torch.device(f"cuda:{local_rank}"), + ) +else: + dist.init_process_group() # Create prompts prompts = [ diff --git a/tests/entrypoints/openai/chat_completion/test_chat.py b/tests/entrypoints/openai/chat_completion/test_chat.py index 6703095aec4..16a3cd857cb 100644 --- a/tests/entrypoints/openai/chat_completion/test_chat.py +++ b/tests/entrypoints/openai/chat_completion/test_chat.py @@ -808,14 +808,20 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA "logprobs": False, } - chat_completion = await client.chat.completions.create(**request_args) + # Use raw HTTP for both endpoints so we compare server responses + # directly, without the openai SDK injecting extra fields + # (e.g. `moderation` added in newer SDK versions). + chat_response = requests.post( + server.url_for("v1/chat/completions"), json=request_args + ) + chat_response.raise_for_status() invocation_response = requests.post( server.url_for("invocations"), json=request_args ) invocation_response.raise_for_status() - chat_output = chat_completion.model_dump() + chat_output = chat_response.json() invocation_output = invocation_response.json() assert chat_output.keys() == invocation_output.keys() diff --git a/tests/kernels/attention/test_lightning_attn.py b/tests/kernels/attention/test_lightning_attn.py index 46757cc10b6..61e13166808 100644 --- a/tests/kernels/attention/test_lightning_attn.py +++ b/tests/kernels/attention/test_lightning_attn.py @@ -5,8 +5,16 @@ import pytest import torch from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton +from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed +DEVICE = current_platform.device_type + +pytestmark = pytest.mark.skipif( + not (current_platform.is_cuda_alike() or current_platform.is_xpu()), + reason="Lightning attention Triton kernels require CUDA/ROCm or XPU.", +) + NUM_HEADS = [4, 8] HEAD_SIZES = [64] BATCH_SIZES = [1, 2] @@ -121,7 +129,7 @@ def test_linear_decode_forward_triton( head_size: int, dtype: torch.dtype, ): - torch.set_default_device("cuda") + torch.set_default_device(DEVICE) set_random_seed(42) base = 0.01 q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) @@ -129,16 +137,16 @@ def test_linear_decode_forward_triton( v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) kv_caches = base * torch.randn( - batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE ) kv_caches_copy = kv_caches.clone() - slope_rate = torch.zeros(num_heads, device="cuda") + slope_rate = torch.zeros(num_heads, device=DEVICE) for h in range(num_heads): slope_rate[h] = 0.1 * (h + 1) - slot_idx = torch.arange(batch_size, device="cuda") + slot_idx = torch.arange(batch_size, device=DEVICE) triton_output = linear_decode_forward_triton( q, k, v, kv_caches, slope_rate, slot_idx @@ -162,7 +170,7 @@ def test_linear_decode_forward_triton_with_padding( head_size: int, dtype: torch.dtype, ): - torch.set_default_device("cuda") + torch.set_default_device(DEVICE) set_random_seed(42) batch_size = 4 @@ -172,16 +180,16 @@ def test_linear_decode_forward_triton_with_padding( v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) kv_caches = base * torch.randn( - batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE ) kv_caches_copy = kv_caches.clone() - slope_rate = torch.zeros(num_heads, device="cuda") + slope_rate = torch.zeros(num_heads, device=DEVICE) for h in range(num_heads): slope_rate[h] = 0.1 * (h + 1) - slot_idx = torch.tensor([0, 1, -1, 2], device="cuda") + slot_idx = torch.tensor([0, 1, -1, 2], device=DEVICE) triton_output = linear_decode_forward_triton( q, k, v, kv_caches, slope_rate, slot_idx @@ -224,7 +232,7 @@ def test_lightning_attention_reference( seq_len: int, dtype: torch.dtype, ): - torch.set_default_device("cuda") + torch.set_default_device(DEVICE) set_random_seed(42) base = 0.01 @@ -232,12 +240,12 @@ def test_lightning_attention_reference( k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) - ed = torch.zeros(num_heads, device="cuda") + ed = torch.zeros(num_heads, device=DEVICE) for h in range(num_heads): ed[h] = 0.1 * (h + 1) kv_history = base * torch.randn( - batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE ) kv_history_clone = kv_history.clone() diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py index 07f244451b4..44e2ee836b9 100644 --- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -10,6 +10,7 @@ import torch from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from typing_extensions import ParamSpec +import vllm.envs as envs from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import ( cleanup_dist_env_and_memory, @@ -60,7 +61,15 @@ def _set_vllm_config( tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size, pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size, ) - cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo") + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + cpu_group = torch.distributed.split_group( + split_ranks=[list(range(world_size))], + group_desc="moe_test_cpu", + ) + else: + cpu_group = torch.distributed.new_group( + list(range(world_size)), backend="gloo" + ) return cpu_group diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 1380281bb2e..e3315142a9b 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -205,7 +205,10 @@ def run_with_expert_maps( w2 = kwargs["w2"] a = kwargs["hidden_states"] moe_config = make_dummy_moe_config( - num_experts=w2.shape[0], + max_num_tokens=kwargs.get("hidden_states").shape[0], + experts_per_token=kwargs.get("topk_ids").shape[1], + num_experts=num_experts, + num_local_experts=num_local_experts, hidden_dim=w2.shape[1], intermediate_size_per_partition=w2.shape[2], in_dtype=a.dtype, @@ -258,23 +261,27 @@ def run_8_bit( a1_scale=None, ) + num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined] + with_ep = num_local_experts is not None or num_local_experts == num_experts + kwargs = { "hidden_states": moe_tensors.a, "w1": moe_tensors.w1_q, # type: ignore[union-attr] "w2": moe_tensors.w2_q, # type: ignore[union-attr] "topk_weights": topk_weights, "topk_ids": topk_ids, - "global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr] + "global_num_experts": num_experts, "activation": MoEActivation.SILU, "expert_map": None, "apply_router_weight_on_input": False, } - num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined] - with_ep = num_local_experts is not None or num_local_experts == num_experts if not with_ep: moe_config = make_dummy_moe_config( - num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr] + max_num_tokens=moe_tensors.a.shape[0], + experts_per_token=topk_ids.shape[1], + num_experts=num_experts, + num_local_experts=num_local_experts, hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr] intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr] in_dtype=moe_tensors.a.dtype, @@ -581,6 +588,7 @@ def test_run_cutlass_moe_fp8( per_out_channel, False, topk_weights, + None, ) workspace13.random_() diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 452bf64ed98..efb1e2f2969 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -14,6 +14,7 @@ import torch.distributed from torch.distributed import ProcessGroup from typing_extensions import ParamSpec +import vllm.envs as envs from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -375,7 +376,13 @@ def _test_deepep_deepgemm_moe( w1_scale = w1_scale.to(device=device) w2_scale = w2_scale.to(device=device) - pg = torch.distributed.new_group(list(range(pgi.world_size))) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + pg = torch.distributed.split_group( + split_ranks=[list(range(pgi.world_size))], + group_desc="deepep_deepgemm_test", + ) + else: + pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, pgi.rank) block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)] diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 5e0303c3df7..83cd2f09d1e 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -10,6 +10,7 @@ import pytest import torch.distributed from torch.distributed import ProcessGroup +import vllm.envs as envs from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config @@ -375,7 +376,13 @@ def _deep_ep_moe( w1_scale = w1_scale.to(device=device_idx) w2_scale = w2_scale.to(device=device_idx) - pg = torch.distributed.new_group(list(range(pgi.world_size))) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + pg = torch.distributed.split_group( + split_ranks=[list(range(pgi.world_size))], + group_desc="deepep_test", + ) + else: + pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, low_latency_mode) with set_current_vllm_config(VllmConfig()): diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 3503ce4cdeb..ebb99576756 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -49,10 +49,12 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor: def make_dummy_moe_config( num_experts: int = 1, + num_local_experts: int | None = None, experts_per_token: int = 1, hidden_dim: int = 1, intermediate_size_per_partition: int = 1, in_dtype: torch.dtype = torch.bfloat16, + max_num_tokens: int = 512, ) -> FusedMoEConfig: """ This is a dummy config for the mk constructor interface @@ -66,14 +68,16 @@ def make_dummy_moe_config( experts_per_token=experts_per_token, hidden_dim=hidden_dim, intermediate_size_per_partition=intermediate_size_per_partition, - num_local_experts=num_experts, + num_local_experts=num_local_experts + if num_local_experts is not None + else num_experts, num_logical_experts=num_experts, moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), activation=MoEActivation.SILU, in_dtype=in_dtype, device="cuda", routing_method=RoutingMethodType.TopK, - max_num_tokens=512, + max_num_tokens=max_num_tokens, ) diff --git a/tests/kernels/quantization/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py index 337bc177e6d..6572a7efd22 100644 --- a/tests/kernels/quantization/test_awq_triton.py +++ b/tests/kernels/quantization/test_awq_triton.py @@ -13,9 +13,15 @@ from vllm.model_executor.layers.quantization.awq_triton import ( awq_dequantize_triton, awq_gemm_triton, ) +from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed -device = "cuda" +pytestmark = pytest.mark.skipif( + not (current_platform.is_cuda_alike() or current_platform.is_xpu()), + reason="AWQ Triton kernels require CUDA/ROCm or XPU.", +) + +device = current_platform.device_type def reverse_awq_order(t: torch.Tensor): diff --git a/tests/model_executor/test_eagle_quantization.py b/tests/model_executor/test_eagle_quantization.py index 481715da9cd..72c189d8331 100644 --- a/tests/model_executor/test_eagle_quantization.py +++ b/tests/model_executor/test_eagle_quantization.py @@ -100,32 +100,6 @@ def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) -> assert output.shape == (2, output_size) -def test_kv_cache_scale_name_handling(): - # Mock a quant config that supports cache scales - mock_quant_config = Mock() - mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale") - - # Condition check in load_weights - name = "layers.0.self_attn.k_proj.weight" - scale_name = mock_quant_config.get_cache_scale(name) - - # Check if get_cache_scale is called and returns expected value - mock_quant_config.get_cache_scale.assert_called_once_with(name) - assert scale_name == "layers.0.self_attn.kv_scale" - - -def test_kv_cache_scale_name_no_scale(): - # Mock a quant config that returns None for get_cache_scale - mock_quant_config = Mock() - mock_quant_config.get_cache_scale = Mock(return_value=None) - - name = "layers.0.mlp.gate_proj.weight" - scale_name = mock_quant_config.get_cache_scale(name) - - # Should return None for weights that don't have cache scales - assert scale_name is None - - def test_maybe_remap_kv_scale_name(): from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name @@ -183,33 +157,3 @@ def test_eagle3_lm_head_receives_quant_config(): assert call_kwargs["quant_config"] is mock_quant_config, ( "ParallelLMHead must receive the draft model's quant_config" ) - - -def test_load_weights_kv_scale_handling(): - kv_scale_param = Mock() - kv_scale_param.weight_loader = Mock() - - params_dict = { - "layers.0.self_attn.kv_scale": kv_scale_param, - } - - mock_quant_config = Mock() - mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale") - - # Load_weights logic for KV cache scales - name = "layers.0.self_attn.k_proj.weight" - loaded_weight_tensor = torch.tensor([1.0, 2.0]) - - if mock_quant_config is not None: - scale_name = mock_quant_config.get_cache_scale(name) - if scale_name: - param = params_dict[scale_name] - assert param is kv_scale_param - weight_to_load = ( - loaded_weight_tensor - if loaded_weight_tensor.dim() == 0 - else loaded_weight_tensor[0] - ) - - assert scale_name == "layers.0.self_attn.kv_scale" - assert weight_to_load == loaded_weight_tensor[0] diff --git a/tests/models/language/pooling/test_colbert.py b/tests/models/language/pooling/test_colbert.py index 10c229fe063..b6ff9b9ffbd 100644 --- a/tests/models/language/pooling/test_colbert.py +++ b/tests/models/language/pooling/test_colbert.py @@ -14,7 +14,7 @@ from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score # ----------------------------------------------------------------------- # Model definitions: (model_name, colbert_dim, extra vllm_runner kwargs) # ----------------------------------------------------------------------- -COLBERT_MODELS = { +COLBERT_MODELS: dict[str, dict] = { "bert": { "model": "answerdotai/answerai-colbert-small-v1", "colbert_dim": 96, diff --git a/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py b/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py new file mode 100644 index 00000000000..92017e95cb7 --- /dev/null +++ b/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import csv +import importlib +import importlib.util +import os + +import pytest +import torch + +from tests.utils import TestFP8Layer +from vllm._aiter_ops import rocm_aiter_ops +from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( + AiterHipbMMPerTokenFp8ScaledMMLinearKernel, +) +from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import ( + FP8ScaledMMLinearLayerConfig, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8DynamicTokenSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) +from vllm.platforms import current_platform + +aiter_available = importlib.util.find_spec("aiter") is not None + +pytestmark = [ + pytest.mark.skipif( + not ( + current_platform.is_rocm() + and current_platform.supports_fp8() + and aiter_available + ), + reason="Requires ROCm + FP8 support + aiter", + ), + pytest.mark.usefixtures("default_vllm_config"), +] + + +@pytest.fixture +def enable_hipb_mm_kernel(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1") + rocm_aiter_ops.refresh_env_variables() + yield + rocm_aiter_ops.refresh_env_variables() + + +def _make_config( + *, + weight_quant_key=kFp8StaticChannelSym, + out_dtype: torch.dtype = torch.bfloat16, + weight_shape: tuple[int, int] = (512, 4096), +) -> FP8ScaledMMLinearLayerConfig: + return FP8ScaledMMLinearLayerConfig( + weight_quant_key=weight_quant_key, + activation_quant_key=kFp8DynamicTokenSym, + weight_shape=weight_shape, + input_dtype=torch.bfloat16, + out_dtype=out_dtype, + ) + + +def _find_csv_row(path: str, m: int, n: int, k: int) -> dict | None: + if not os.path.exists(path): + return None + + with open(path, newline="") as f: + reader = csv.DictReader(f, skipinitialspace=True) + for row in reader: + try: + if ( + int(row.get("m", -1)) == m + and int(row.get("n", -1)) == n + and int(row.get("k", -1)) == k + ): + return dict(row) + except (TypeError, ValueError): + continue + return None + + +def _skip_if_no_hipb_mm_solution(exc: RuntimeError) -> None: + if "hipblasLtMatmulAlgoGetHeuristic found 0 valid solutions" in str(exc): + pytest.skip( + "hipb_mm bpreshuffle path has no valid hipBLASLt solution on " + "this ROCm stack." + ) + + +def _check_bpreshuffle_runtime_support(weight_shape: tuple[int, int], num_tokens: int): + import aiter + from aiter.ops.shuffle import shuffle_weight + + x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device="cuda") + w = torch.randn(weight_shape, dtype=torch.bfloat16, device="cuda") + + aiter.hipb_create_extension() + x_q, x_scale = aiter.pertoken_quant(x, quant_dtype=current_platform.fp8_dtype()) + w_q, w_scale = aiter.pertoken_quant(w, quant_dtype=current_platform.fp8_dtype()) + + try: + aiter.hipb_mm( + x_q, + shuffle_weight(w_q, layout=(16, 16)).t(), + solution_index=-1, + out_dtype=torch.bfloat16, + scaleA=x_scale, + scaleB=w_scale.t().contiguous(), + scaleOut=None, + bpreshuffle=True, + ) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + + +def test_hipb_mm_kernel_requires_hipbmm_flag(monkeypatch: pytest.MonkeyPatch): + # The kernel rejects when `is_hip_fp8bmm_enabled()` is False. That helper + # requires AITER + AITER_LINEAR + MI3xx, so dropping AITER_LINEAR exercises + # the rejection branch. + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0") + monkeypatch.delenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", raising=False) + rocm_aiter_ops.refresh_env_variables() + + is_supported, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.is_supported() + + assert not is_supported + assert reason == ( + "requires setting `VLLM_ROCM_USE_AITER=1`, " + "`VLLM_ROCM_USE_AITER_LINEAR=1`, " + "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`." + ) + + +def test_hipb_mm_flag_enables_hip_online_tuning( + monkeypatch: pytest.MonkeyPatch, +): + import vllm.envs as envs_mod + import vllm.platforms.rocm as rocm_mod + + # The rocm.py gate requires all three AITER flags (and MI3xx) to auto-set + # HIP_ONLINE_TUNING. + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1") + + try: + importlib.reload(envs_mod) + importlib.reload(rocm_mod) + assert envs_mod.VLLM_ROCM_USE_AITER + assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR + assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM + assert os.environ.get("HIP_ONLINE_TUNING") == "1" + finally: + monkeypatch.undo() + os.environ.pop("HIP_ONLINE_TUNING", None) + importlib.reload(envs_mod) + importlib.reload(rocm_mod) + rocm_aiter_ops.refresh_env_variables() + + +def test_hipb_mm_kernel_can_implement_success(enable_hipb_mm_kernel): + can_implement, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement( + _make_config() + ) + + assert can_implement + assert reason is None + + +@pytest.mark.parametrize( + ("config", "expected_reason"), + [ + ( + _make_config(weight_quant_key=kFp8StaticTensorSym), + "requires per token activation scales and per channel weight scales.", + ), + ( + _make_config(out_dtype=torch.float16), + "requires bfloat16 output dtype.", + ), + ( + _make_config(weight_shape=(8, 4090)), + "requires N >= 16 and both N and K divisible by 16, " + "received N=8 and K=4090.", + ), + ], +) +def test_hipb_mm_kernel_can_implement_rejects_unsupported_configs( + enable_hipb_mm_kernel, + config: FP8ScaledMMLinearLayerConfig, + expected_reason: str, +): + can_implement, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement( + config + ) + + assert not can_implement + assert reason == expected_reason + + +def test_hipb_mm_kernel_process_weights_after_loading_shuffles_weights( + enable_hipb_mm_kernel, +): + weight_shape = (512, 4096) + kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel( + _make_config(weight_shape=weight_shape), + layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"), + ) + + layer = torch.nn.Module() + layer.weight = torch.nn.Parameter( + torch.rand(weight_shape, device="cuda").to(current_platform.fp8_dtype()).t(), + requires_grad=False, + ) + layer.weight_scale = torch.nn.Parameter( + torch.rand((weight_shape[0], 1), dtype=torch.float32, device="cuda"), + requires_grad=False, + ) + layer.input_scale = None + layer.input_scale_ub = None + + original_weight = layer.weight.detach().clone() + original_weight_scale = layer.weight_scale.detach().clone() + + kernel.process_weights_after_loading(layer) + + # process_weights_after_loading now pre-applies the transposes that used + # to live in _rocm_aiter_hipb_mm_fp8_impl, so the stored weight is the + # shuffled tensor with a trailing `.t()` view, and the stored weight scale + # is its transposed-contiguous form. + expected_weight = rocm_aiter_ops.shuffle_weight( + original_weight.t().contiguous() + ).t() + torch.testing.assert_close(layer.weight, expected_weight) + + expected_weight_scale = original_weight_scale.t().contiguous() + torch.testing.assert_close(layer.weight_scale, expected_weight_scale) + + +def test_hipb_mm_kernel_forward_matches_raw_aiter_hipb_mm(enable_hipb_mm_kernel): + import aiter + + weight_shape = (512, 4096) + _check_bpreshuffle_runtime_support(weight_shape, num_tokens=32) + + layer = TestFP8Layer( + weight_shape=weight_shape, + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticChannelSym, + input_dtype=torch.bfloat16, + out_dtype=torch.bfloat16, + device=torch.device("cuda"), + force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel, + ) + + # hipb_mm uses a transposed-result GEMM internally, so the flattened token + # count becomes the effective N dimension passed into hipBLASLt. Keep it + # aligned to avoid heuristic failures for tiny N. + x = torch.randn(2, 16, weight_shape[1], dtype=torch.bfloat16, device="cuda") + bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device="cuda") + + try: + out = layer(x, bias) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + + x_2d = x.view(-1, x.shape[-1]) + x_q, x_scale = layer.kernel.quant_fp8( + x_2d, + layer.input_scale, + layer.input_scale_ub, + ) + try: + # process_weights_after_loading already applies the trailing `.t()` on + # the shuffled weight and the `.t().contiguous()` on the weight scale, + # so the raw aiter call uses them directly. + expected = aiter.hipb_mm( + x_q, + layer.weight, + solution_index=-1, + bias=bias, + out_dtype=torch.bfloat16, + scaleA=x_scale, + scaleB=layer.weight_scale, + scaleOut=None, + bpreshuffle=True, + ).view(*out.shape) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + + assert isinstance(layer.kernel, AiterHipbMMPerTokenFp8ScaledMMLinearKernel) + assert out.shape == (2, 16, weight_shape[0]) + torch.testing.assert_close(out, expected) + + +def test_hipb_mm_kernel_forward_accuracy(enable_hipb_mm_kernel): + """Kernel output should match a dequantized fp32 reference within + fp8 per-token / per-channel quantization noise.""" + weight_shape = (512, 4096) # (N, K) + num_tokens = 32 + _check_bpreshuffle_runtime_support(weight_shape, num_tokens=num_tokens) + + fp8_dtype = current_platform.fp8_dtype() + fp8_max = torch.finfo(fp8_dtype).max + device = torch.device("cuda") + + # Build a bf16 weight and quantize per output channel (one scale per row). + w_bf16 = torch.randn(weight_shape, dtype=torch.bfloat16, device=device) + w_amax = w_bf16.abs().amax(dim=1, keepdim=True).to(torch.float32) + w_scale = (w_amax / fp8_max).clamp(min=1e-12) + w_fp8 = (w_bf16.to(torch.float32) / w_scale).clamp(-fp8_max, fp8_max).to(fp8_dtype) + w_dequant = w_fp8.to(torch.float32) * w_scale + + bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device=device) + + layer = torch.nn.Module() + # Pre-`process_weights_after_loading` convention: weight stored as the + # `[K, N]` view of the fp8 tensor. + layer.weight = torch.nn.Parameter(w_fp8.t(), requires_grad=False) + layer.weight_scale = torch.nn.Parameter(w_scale, requires_grad=False) + layer.input_scale = None + layer.input_scale_ub = None + + kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel( + _make_config(weight_shape=weight_shape), + layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"), + ) + kernel.process_weights_after_loading(layer) + + x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device=device) + + try: + out = kernel.apply_weights(layer, x, bias) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + + # Reference: quantize x per-token the same way the kernel does, then run + # the matmul in fp32 against the dequantized weight. This isolates plumbing + # / reduction bugs from inherent fp8 quantization noise. + x_amax = x.abs().amax(dim=1, keepdim=True).to(torch.float32) + x_scale_ref = (x_amax / fp8_max).clamp(min=1e-12) + x_q = (x.to(torch.float32) / x_scale_ref).clamp(-fp8_max, fp8_max).to(fp8_dtype) + x_dequant = x_q.to(torch.float32) * x_scale_ref + expected = (x_dequant @ w_dequant.t() + bias.to(torch.float32)).to(torch.bfloat16) + + assert out.shape == (num_tokens, weight_shape[0]) + # K=4096 fp8 reduction leaves room for accumulation order drift and + # catastrophic cancellation on near-zero outputs; tolerances are loose + # enough to absorb that but tight enough to catch wrong layouts, missing + # bias, swapped scales, etc. + torch.testing.assert_close(out, expected, atol=5.0, rtol=0.1) + + +def test_hipb_mm_kernel_online_tuning_writes_csv( + enable_hipb_mm_kernel, + monkeypatch: pytest.MonkeyPatch, + tmp_path, +): + weight_shape = (256, 4096) + cache_file = tmp_path / "hip_online_tuning_res.csv" + + _check_bpreshuffle_runtime_support(weight_shape, num_tokens=16) + + monkeypatch.setenv("HIP_ONLINE_TUNING", "1") + monkeypatch.chdir(tmp_path) + + layer = TestFP8Layer( + weight_shape=weight_shape, + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticChannelSym, + input_dtype=torch.bfloat16, + out_dtype=torch.bfloat16, + device=torch.device("cuda"), + force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel, + ) + + # The effective heuristic N dimension is the flattened token count. + x = torch.randn(16, weight_shape[1], dtype=torch.bfloat16, device="cuda") + try: + out = layer(x) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + torch.accelerator.synchronize() + + assert out.shape == (16, weight_shape[0]) + assert cache_file.exists() + + # hipb_mm records the internal GEMM dimensions used by hipBLASLt after its + # transposed-result transformation. + row = _find_csv_row( + str(cache_file), + m=weight_shape[0], + n=x.shape[0], + k=weight_shape[1], + ) + assert row is not None diff --git a/tests/tokenizers_/test_mistral.py b/tests/tokenizers_/test_mistral.py index 2023337e857..47abbd81289 100644 --- a/tests/tokenizers_/test_mistral.py +++ b/tests/tokenizers_/test_mistral.py @@ -797,11 +797,11 @@ class TestMistralTokenizer: True, ( [1, 3, 23325, 2294, 1686, 4, 23325], - [1, 3, 22177, 4304, 2662, 4, 22177, 2], + [1, 3, 22177, 4304, 2662, 4, 22177], ), ( "[INST]▁Hello▁world▁![/INST]▁Hello", - ("[INST]Hello world ![/INST]Hello"), + "[INST]Hello world ![/INST]Hello", ), ), ], diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index bde246c9b66..0e7f6af7e38 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -49,11 +49,13 @@ else KV_EXTRA_CONFIG='' fi -# Build the kv-transfer-config once +# Build the kv-transfer-config for P and D if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then - KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' + KV_CONFIG_P='{"kv_connector":"NixlConnector","kv_role":"kv_producer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' + KV_CONFIG_D='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' else - KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" + KV_CONFIG_P="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" + KV_CONFIG_D="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" fi # Models to run @@ -159,7 +161,7 @@ run_tests_for_model() { --block-size ${PREFILL_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ - --kv-transfer-config '$KV_CONFIG'" + --kv-transfer-config '$KV_CONFIG_P'" if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS" for arg in "${extra_args[@]}"; do @@ -207,7 +209,7 @@ run_tests_for_model() { --enforce-eager \ --block-size ${DECODE_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ - --kv-transfer-config '$KV_CONFIG'" + --kv-transfer-config '$KV_CONFIG_D'" if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS" for arg in "${extra_args[@]}"; do diff --git a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh index bc90680a533..10e119d48a9 100755 --- a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh +++ b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh @@ -193,9 +193,11 @@ run_test_for_device() { local kv_device=$1 if [[ "$kv_device" == "cuda" ]]; then - local kv_config='{"kv_connector":"NixlConnector","kv_role":"kv_both"}' + local kv_config_p='{"kv_connector":"NixlConnector","kv_role":"kv_producer"}' + local kv_config_d='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}' else - local kv_config="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${kv_device}\"}" + local kv_config_p="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"${kv_device}\"}" + local kv_config_d="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"${kv_device}\"}" fi echo "" @@ -248,7 +250,7 @@ run_test_for_device() { --block-size ${BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ - --kv-transfer-config "$kv_config" \ + --kv-transfer-config "$kv_config_p" \ --speculative-config "$PREFILL_SPEC_CONFIG" \ --attention-backend $ATTENTION_BACKEND \ ${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} & @@ -287,7 +289,7 @@ run_test_for_device() { --block-size ${BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $DECODER_TP_SIZE \ - --kv-transfer-config "$kv_config" \ + --kv-transfer-config "$kv_config_d" \ --speculative-config "$DECODE_SPEC_CONFIG" \ --attention-backend $ATTENTION_BACKEND \ ${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} & diff --git a/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py b/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py new file mode 100644 index 00000000000..4a2ca6d2721 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_handshake_pp_aggregation.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace +from typing import Any + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorHandshakeMetadata, +) +from vllm.v1.engine import core as engine_core_module + +pytestmark = pytest.mark.cpu_test + + +class _Metadata(KVConnectorHandshakeMetadata): + pass + + +class _FakeExecutor: + handshake_metadata_src: ( + list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None + ) + last_instance: "_FakeExecutor | None" = None + + def __init__( + self, + vllm_config: Any, + ) -> None: + del vllm_config + self.handshake_metadata = self.handshake_metadata_src + self.handshake_calls = 0 + _FakeExecutor.last_instance = self + + def get_kv_connector_handshake_metadata( + self, + ) -> list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None: + self.handshake_calls += 1 + return self.handshake_metadata + + def init_kv_output_aggregator(self, connector: KVConnectorBase_V1) -> None: + pass + + +def _run_engine_core_handshake( + monkeypatch: pytest.MonkeyPatch, + connector: KVConnectorBase_V1, + *, + handshake_metadata: ( + list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None + ), +) -> _FakeExecutor: + class _FakeScheduler: + def __init__(self, **kwargs: Any) -> None: + self.connector = connector + + def get_kv_connector(self) -> KVConnectorBase_V1: + return connector + + _FakeExecutor.handshake_metadata_src = handshake_metadata + _FakeExecutor.last_instance = None + + monkeypatch.setattr("vllm.plugins.load_general_plugins", lambda: None) + monkeypatch.setattr( + engine_core_module.EngineCore, + "_initialize_kv_caches", + lambda self, vllm_config: SimpleNamespace(kv_cache_groups=[object()]), + ) + monkeypatch.setattr( + engine_core_module, + "StructuredOutputManager", + lambda vllm_config: object(), + ) + monkeypatch.setattr( + engine_core_module, + "resolve_kv_cache_block_sizes", + lambda kv_cache_config, vllm_config: (16, 16), + ) + monkeypatch.setattr( + engine_core_module, + "MULTIMODAL_REGISTRY", + SimpleNamespace(engine_receiver_cache_from_config=lambda vllm_config: None), + ) + monkeypatch.setattr(engine_core_module, "freeze_gc_heap", lambda: None) + monkeypatch.setattr( + engine_core_module, "maybe_attach_gc_debug_callback", lambda: None + ) + monkeypatch.setattr(engine_core_module, "enable_envs_cache", lambda: None) + monkeypatch.setattr(engine_core_module, "get_hash_fn_by_name", lambda name: None) + monkeypatch.setattr(engine_core_module, "init_none_hash", lambda hash_fn: None) + monkeypatch.setattr( + engine_core_module, "get_request_block_hasher", lambda *args: None + ) + + vllm_config = SimpleNamespace( + parallel_config=SimpleNamespace(data_parallel_rank_local=0), + scheduler_config=SimpleNamespace( + get_scheduler_cls=lambda: _FakeScheduler, + enable_chunked_prefill=False, + async_scheduling=False, + ), + speculative_config=None, + ec_transfer_config=None, + max_concurrent_batches=1, + model_config=SimpleNamespace(runner_type="generate"), + cache_config=SimpleNamespace( + enable_prefix_caching=False, + prefix_caching_hash_algo="builtin", + ), + ) + + engine_core_module.EngineCore(vllm_config, _FakeExecutor, log_stats=False) + assert _FakeExecutor.last_instance is not None + return _FakeExecutor.last_instance + + +class _LegacyConnector(KVConnectorBase_V1): + def __init__(self) -> None: + self.legacy_metadata: dict[int, KVConnectorHandshakeMetadata] | None = None + + def start_load_kv(self, forward_context: Any, **kwargs: Any) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: Any, + attn_metadata: Any, + **kwargs: Any, + ) -> None: + pass + + def wait_for_save(self) -> None: + pass + + def get_num_new_matched_tokens( + self, request: Any, num_computed_tokens: int + ) -> tuple[int | None, bool]: + return 0, False + + def update_state_after_alloc( + self, request: Any, blocks: Any, num_external_tokens: int + ) -> None: + pass + + def build_connector_meta(self, scheduler_output: Any) -> Any: + raise NotImplementedError + + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata] + ) -> None: + self.legacy_metadata = metadata + + +class _PPAwareConnector(_LegacyConnector): + def __init__(self) -> None: + super().__init__() + self.pp_aware_metadata: ( + dict[tuple[int, int], KVConnectorHandshakeMetadata] | None + ) = None + + def set_xfer_handshake_metadata_pp_aware( + self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata] + ) -> None: + self.pp_aware_metadata = metadata + + +def test_engine_unwraps_handshake_metadata_for_legacy_connector( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Engine core always asks workers for `(pp_rank, tp_rank)`-keyed metadata, + then unwraps to `{tp_rank: metadata}` for a connector that has not opted + into PP-aware handshake (single-PP producer, all `pp_rank == 0`).""" + metadata_0 = _Metadata() + metadata_1 = _Metadata() + connector = _LegacyConnector() + + executor = _run_engine_core_handshake( + monkeypatch, + connector, + handshake_metadata=[ + {(0, 0): metadata_0}, + None, + {(0, 1): metadata_1}, + ], + ) + + assert executor.handshake_calls == 1 + assert connector.legacy_metadata == {0: metadata_0, 1: metadata_1} + + +def test_engine_rejects_pp_producer_for_legacy_connector( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A connector that has not opted into PP-aware handshake must not silently + drop metadata from `pp_rank > 0`; engine core init raises instead.""" + connector = _LegacyConnector() + + with pytest.raises(ValueError, match="does not support PP-disaggregated"): + _run_engine_core_handshake( + monkeypatch, + connector, + handshake_metadata=[{(0, 0): _Metadata()}, {(1, 0): _Metadata()}], + ) + + +def test_engine_passes_handshake_metadata_through_for_pp_aware_connector( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A PP-aware connector receives the full `(pp_rank, tp_rank)`-keyed dict + unchanged.""" + metadata_0 = _Metadata() + metadata_1 = _Metadata() + connector = _PPAwareConnector() + + executor = _run_engine_core_handshake( + monkeypatch, + connector, + handshake_metadata=[{(0, 0): metadata_0}, {(1, 0): metadata_1}], + ) + + assert executor.handshake_calls == 1 + assert connector.legacy_metadata is None + assert connector.pp_aware_metadata == { + (0, 0): metadata_0, + (1, 0): metadata_1, + } diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index d1f3a81ca96..f78037a1431 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -261,11 +261,12 @@ def test_multi_example_connector_consistency(): storage1_scheduler_events = _ignore_event_collection(events["storage1-SCHEDULER"]) storage2_scheduler_events = _ignore_event_collection(events["storage2-SCHEDULER"]) # First event is bind_gpu_block_pool from initialization, then - # set_xfer_handshake_metadata, then on_new_request when the request is enqueued, - # then get_num_new_matched_tokens and update_state_after_alloc from generate(). + # set_xfer_handshake_metadata_pp_aware, then on_new_request when the request is + # enqueued, then get_num_new_matched_tokens and update_state_after_alloc from + # generate(). assert storage1_scheduler_events[:6] == [ "bind_gpu_block_pool", - "set_xfer_handshake_metadata", + "set_xfer_handshake_metadata_pp_aware", "on_new_request", "get_num_new_matched_tokens 0", "update_state_after_alloc num_blocks=[0] 0", @@ -285,7 +286,7 @@ def test_multi_example_connector_consistency(): ] assert storage2_scheduler_events[:6] == [ "bind_gpu_block_pool", - "set_xfer_handshake_metadata", + "set_xfer_handshake_metadata_pp_aware", "on_new_request", "get_num_new_matched_tokens 0", "update_state_after_alloc num_blocks=[0] 0", @@ -1057,7 +1058,7 @@ def test_multi_connector_mixed_hma_disables_hybrid_kv_cache(monkeypatch): "connectors": [ { "kv_connector": "NixlConnector", - "kv_role": "kv_both", + "kv_role": "kv_consumer", }, { "kv_connector": "MockConnector", diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 1a7c35cacb8..6f6d8b1ca98 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1363,7 +1363,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): timeout = 6 kv_transfer_config = KVTransferConfig( kv_connector="NixlConnector", - kv_role="kv_both", + kv_role="kv_consumer", kv_connector_extra_config={"kv_lease_duration": timeout}, ) llm_kwargs = { @@ -2737,3 +2737,50 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) f"got {notif!r} (expected {expected_notif!r}, " f"buggy form would be {bad_notif!r})" ) + + +def test_kv_both_deprecation_warning(default_vllm_config, dist_init): + """kv_role='kv_both' should emit a deprecation log warning.""" + from unittest.mock import patch + + from vllm.logger import _print_warning_once + + _print_warning_once.cache_clear() + + vllm_config = create_vllm_config(kv_role="kv_both") + + with patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger" + ) as mock_logger: + mock_logger.warning_once = mock_logger.warning_once + NixlConnector( + vllm_config, + KVConnectorRole.WORKER, + make_kv_cache_config(block_size=16), + ) + + mock_logger.warning_once.assert_called_once() + msg = mock_logger.warning_once.call_args[0][0] + assert "kv_role='kv_both'" in msg + assert "deprecated" in msg + + +def test_explicit_kv_role_no_deprecation_warning(default_vllm_config, dist_init): + """kv_role='kv_consumer' or 'kv_producer' should NOT emit a warning.""" + from unittest.mock import patch + + for role in ("kv_consumer", "kv_producer"): + vllm_config = create_vllm_config(kv_role=role) + with patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger" + ) as mock_logger: + NixlConnector( + vllm_config, + KVConnectorRole.WORKER, + make_kv_cache_config(block_size=16), + ) + + ( + mock_logger.warning_once.assert_not_called(), + (f"kv_role={role!r} should not emit deprecation warning"), + ) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 80088809469..6e399db7b14 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -453,7 +453,7 @@ def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size): """ kv_transfer_config = KVTransferConfig( kv_connector="NixlConnector", - kv_role="kv_both", + kv_role="kv_consumer", ) block_size = 16 llm_kwargs = { diff --git a/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py b/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py index 3a3ef2a88a6..c3adc05e3ef 100644 --- a/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py +++ b/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py @@ -75,7 +75,7 @@ def test_gpu_memory_rixl_hma(model_name, sw_size): "gpu_memory_utilization": 0.5, "kv_transfer_config": KVTransferConfig( kv_connector="NixlConnector", - kv_role="kv_both", + kv_role="kv_consumer", ), "max_model_len": 2048, "disable_hybrid_kv_cache_manager": False, diff --git a/tests/v1/kv_connector/unit/test_transfer_topology_sharded.py b/tests/v1/kv_connector/unit/test_transfer_topology_sharded.py new file mode 100644 index 00000000000..ac00bb48128 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_transfer_topology_sharded.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.utils import ( + EngineTransferInfo, + TransferTopology, +) + +pytestmark = pytest.mark.cpu_test + + +class _FakeAttentionBackend: + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, int, int, int, int]: + return (2, num_blocks, num_kv_heads, block_size, head_size) + + +def _make_topology( + *, + tp_rank: int = 1, + tp_size: int = 4, + total_num_kv_heads: int = 8, +) -> TransferTopology: + return TransferTopology( + tp_rank=tp_rank, + tp_size=tp_size, + block_size=16, + engine_id="local-engine", + is_mla=False, + is_mamba=False, + total_num_kv_heads=total_num_kv_heads, + attn_backends=[_FakeAttentionBackend], + ) + + +def test_legacy_register_remote_engine_uses_pp_rank_zero() -> None: + topology = _make_topology() + info = EngineTransferInfo( + remote_tp_size=2, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + ) + + registered = topology.register_remote_engine("remote-engine", info) + + assert registered == info + assert registered.remote_pp_rank == 0 + assert topology.get_engine_info("remote-engine") == info + assert topology._engines[("remote-engine", 0)] == info + assert topology.target_remote_ranks("remote-engine") == [0] + + +def test_register_remote_engine_stores_pp_ranks_separately() -> None: + topology = _make_topology(tp_rank=0, tp_size=2) + + info_0 = EngineTransferInfo( + remote_tp_size=2, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + remote_pp_rank=0, + start_layer=0, + end_layer=16, + ) + info_1 = EngineTransferInfo( + remote_tp_size=1, + remote_block_len=512, + remote_block_size=8, + remote_physical_blocks_per_logical=2, + remote_pp_rank=1, + start_layer=16, + end_layer=32, + ) + + registered_0 = topology.register_remote_engine("remote-engine", info_0) + registered_1 = topology.register_remote_engine("remote-engine", info_1) + + assert registered_0 == info_0 + assert registered_1 == info_1 + assert topology.get_engine_info("remote-engine") == info_0 + assert topology.get_engine_info("remote-engine", 0) == info_0 + assert topology.get_engine_info("remote-engine", 1) == info_1 + assert set(topology._engines) == { + ("remote-engine", 0), + ("remote-engine", 1), + } + + +def test_helpers_use_requested_pp_rank() -> None: + topology = _make_topology(tp_rank=1, tp_size=2, total_num_kv_heads=2) + topology.register_remote_engine( + "remote-engine", + EngineTransferInfo( + remote_tp_size=1, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + remote_pp_rank=0, + start_layer=0, + end_layer=8, + ), + ) + topology.register_remote_engine( + "remote-engine", + EngineTransferInfo( + remote_tp_size=4, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + remote_pp_rank=1, + start_layer=8, + end_layer=16, + ), + ) + + assert not topology.is_kv_replicated("remote-engine", 0) + assert topology.is_kv_replicated("remote-engine", 1) + assert topology.replicates_kv_cache("remote-engine", 1) + assert topology.target_remote_ranks("remote-engine", 0) == [0] + assert topology.target_remote_ranks("remote-engine", 1) == [2, 3] + assert "remote_pp=1" in topology.describe("remote-engine", 1) + + +def test_engine_info_fields_have_backward_compatible_defaults() -> None: + topology = _make_topology() + info = EngineTransferInfo( + remote_tp_size=2, + remote_block_len=1024, + remote_block_size=16, + remote_physical_blocks_per_logical=1, + ) + + registered = topology.register_remote_engine("remote-engine", info) + + assert topology.get_engine_info("remote-engine") == registered + assert registered.remote_pp_rank == 0 + assert registered.start_layer == 0 + assert registered.end_layer == 0 diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 1b892849d90..9d801f772a6 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -103,7 +103,7 @@ def create_vllm_config( kv_load_failure_policy: Literal["recompute", "fail"] = "fail", kv_connector: str = "NixlConnector", kv_connector_module_path: str | None = None, - kv_role: str = "kv_both", + kv_role: str = "kv_consumer", disable_hybrid_kv_cache_manager: bool | None = None, ) -> VllmConfig: """Initialize VllmConfig For Testing.""" diff --git a/tests/v1/kv_offload/tiering/test_obj_tier.py b/tests/v1/kv_offload/tiering/test_obj_tier.py new file mode 100644 index 00000000000..ed112738f5a --- /dev/null +++ b/tests/v1/kv_offload/tiering/test_obj_tier.py @@ -0,0 +1,365 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Mock-based unit tests for ObjectStoreSecondaryTierManager. + +These tests replace the NIXL backend with an in-memory mock so they run +without S3 credentials or a live object store. They verify the manager's +state machine: job submission, transfer completion polling, and lookup. +""" + +import uuid +from collections.abc import Callable +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np +import torch + +from vllm.v1.kv_offload.base import OffloadKey, ReqContext, make_offload_key +from vllm.v1.kv_offload.tiering.base import JobMetadata, JobResult +from vllm.v1.kv_offload.tiering.obj.manager import ObjectStoreSecondaryTierManager + +# --------------------------------------------------------------------------- +# Shared stubs +# --------------------------------------------------------------------------- + + +def _make_vllm_config(): + return SimpleNamespace( + model_config=SimpleNamespace(model="test/model"), + cache_config=SimpleNamespace(block_size=16, cache_dtype="float16"), + parallel_config=SimpleNamespace( + tensor_parallel_size=1, + pipeline_parallel_size=1, + prefill_context_parallel_size=1, + decode_context_parallel_size=1, + rank=0, + ), + ) + + +_OFFLOADING_SPEC = SimpleNamespace( + vllm_config=_make_vllm_config(), + kv_cache_config=SimpleNamespace(kv_cache_groups=[]), +) + +_STORE_CONFIG = { + "bucket": "mock-bucket", + "endpoint_override": "mock:9000", + "access_key": "mock-access", + "secret_key": "mock-secret", +} + +_BLOCK_ELEMENTS = 256 +_DTYPE = torch.float32 +_RUN_PREFIX = f"test/{uuid.uuid4().hex[:8]}" +_CTX = ReqContext(req_id="test-req") + + +def key(n: int) -> OffloadKey: + return make_offload_key(n.to_bytes(8, "big"), 0) + + +def make_job( + job_id: int, + keys: list[OffloadKey], + block_ids: list[int] | None = None, +) -> JobMetadata: + if block_ids is None: + block_ids = list(range(len(keys))) + return JobMetadata( + job_id=job_id, + keys=keys, + block_ids=np.array(block_ids, dtype=np.int64), + is_promotion=False, + req_context=_CTX, + ) + + +# --------------------------------------------------------------------------- +# Mock NIXL agent +# --------------------------------------------------------------------------- + + +class MockNixlAgent: + """In-memory NIXL agent. Tracks stored object keys and simulates async + transfers: transfer() returns PROC, check_xfer_state() returns DONE and + commits the write to the in-memory key set. + + The four methods overridden by tests (register_memory, make_prepped_xfer, + check_xfer_state, query_memory) are stored as Callable instance attributes + so mypy allows reassignment in tests. + """ + + # Callable attributes — tests may reassign these on instances. + register_memory: Callable + make_prepped_xfer: Callable + check_xfer_state: Callable + query_memory: Callable + + def __init__(self): + self._stored_obj_keys: set[str] = set() + # handle_id -> (op, [obj_keys]) + self._pending: dict[int, tuple[str, list[str]]] = {} + self._handle_counter = 0 + self._last_obj_keys: list[str] = [] + # Bind default implementations as instance attributes. + self.register_memory = self._register_memory + self.make_prepped_xfer = self._make_prepped_xfer + self.check_xfer_state = self._check_xfer_state + self.query_memory = self._query_memory + + def create_backend(self, backend_type, params): + pass + + def _register_memory(self, descs, mem_type=None, backends=None): + mock = MagicMock() + mock.trim.return_value = MagicMock() + # Capture obj_keys from OBJ 4-tuples: (addr, len, dev_id, obj_key) + if mem_type == "OBJ" and descs: + self._last_obj_keys = [d[3] for d in descs if d[3]] + return mock + + def deregister_memory(self, desc): + pass + + def prep_xfer_dlist(self, agent_name, descs, mem_type=None, backends=None): + return MagicMock() + + def _make_prepped_xfer( + self, + op, + local_handle, + local_indices, + remote_handle, + remote_indices, + notif_msg=b"", + backends=None, + skip_desc_merge=False, + ): + handle = MagicMock() + handle._id = self._handle_counter + self._pending[self._handle_counter] = (op, list(self._last_obj_keys)) + self._handle_counter += 1 + return handle + + def transfer(self, handle): + return "PROC" + + def _check_xfer_state(self, handle): + entry = self._pending.pop(handle._id, None) + if entry: + op, obj_keys = entry + if op == "WRITE": + self._stored_obj_keys.update(obj_keys) + return "DONE" + + def release_xfer_handle(self, handle): + pass + + def release_dlist_handle(self, handle): + pass + + def _query_memory(self, queries, mem_type, agent_name): + return [object() if q[3] in self._stored_obj_keys else None for q in queries] + + +# --------------------------------------------------------------------------- +# Fixture +# --------------------------------------------------------------------------- + + +def _make_tier( + num_blocks: int = 4, +) -> tuple[ObjectStoreSecondaryTierManager, MockNixlAgent]: + """Create a tier backed by a fresh MockNixlAgent.""" + mock_agent = MockNixlAgent() + tensor = torch.zeros((num_blocks, _BLOCK_ELEMENTS), dtype=_DTYPE) + view = memoryview(tensor.numpy()) + with ( + patch("vllm.v1.kv_offload.tiering.obj.manager.nixl_agent_config"), + patch( + "vllm.v1.kv_offload.tiering.obj.manager.nixl_agent", + return_value=mock_agent, + ), + ): + tier = ObjectStoreSecondaryTierManager( + offloading_spec=_OFFLOADING_SPEC, + primary_kv_view=view, + tier_type="obj", + store_config=_STORE_CONFIG, + prefix=_RUN_PREFIX, + ) + return tier, mock_agent + + +def drain( + tier: ObjectStoreSecondaryTierManager, max_rounds: int = 20 +) -> list[JobResult]: + """Poll get_finished_jobs() until all in-flight jobs resolve.""" + results: list[JobResult] = [] + for _ in range(max_rounds): + results.extend(tier.get_finished_jobs()) + if not tier._transfers: + break + return results + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestMockObjTierBasic: + def setup_method(self): + self.tier, self.agent = _make_tier(num_blocks=4) + + def test_lookup_empty_tier(self): + assert self.tier.lookup(key(1), _CTX) is False + + def test_store_and_lookup(self): + self.tier.submit_store(make_job(1, [key(1)], [0])) + results = drain(self.tier) + assert len(results) == 1 + assert results[0].success + assert self.tier.lookup(key(1), _CTX) is True + + def test_lookup_unrelated_key_returns_false(self): + self.tier.submit_store(make_job(1, [key(1)], [0])) + drain(self.tier) + assert self.tier.lookup(key(999), _CTX) is False + + def test_store_then_load_roundtrip(self): + self.tier.submit_store(make_job(1, [key(1), key(2)], [0, 1])) + results = drain(self.tier) + assert results[0].success + + self.tier.submit_load(make_job(2, [key(1), key(2)], [0, 1])) + results = drain(self.tier) + assert len(results) == 1 + assert results[0].success + + def test_multiple_jobs_tracked_independently(self): + self.tier.submit_store(make_job(1, [key(1)], [0])) + self.tier.submit_store(make_job(2, [key(2)], [1])) + results = drain(self.tier) + assert len(results) == 2 + assert all(r.success for r in results) + + def test_failed_transfer_reported(self): + self.agent.check_xfer_state = lambda h: "ERR" + self.tier.submit_store(make_job(1, [key(1)], [0])) + results = drain(self.tier) + assert len(results) == 1 + assert not results[0].success + + def test_pending_transfer_not_returned_until_done(self): + # First poll returns PROC; second poll returns DONE. + call_count = [0] + original = self.agent.check_xfer_state + + def delayed(h): + call_count[0] += 1 + return "PROC" if call_count[0] == 1 else original(h) + + self.agent.check_xfer_state = delayed + + self.tier.submit_store(make_job(1, [key(1)], [0])) + assert list(self.tier.get_finished_jobs()) == [] + results = list(self.tier.get_finished_jobs()) + assert len(results) == 1 + assert results[0].success + + +class TestMockObjTierMultiBlock: + def test_store_multiple_blocks(self): + tier, _ = _make_tier(num_blocks=8) + keys = [key(i) for i in range(8)] + tier.submit_store(make_job(1, keys, list(range(8)))) + results = drain(tier) + assert len(results) == 1 + assert results[0].success + assert all(tier.lookup(k, _CTX) for k in keys) + + def test_partial_block_lookup(self): + tier, _ = _make_tier(num_blocks=4) + tier.submit_store(make_job(1, [key(0), key(1)], [0, 1])) + drain(tier) + assert tier.lookup(key(0), _CTX) is True + assert tier.lookup(key(1), _CTX) is True + assert tier.lookup(key(2), _CTX) is False + + +class TestMockObjTierFailures: + def test_lookup_exception_returns_false(self): + tier, agent = _make_tier(num_blocks=4) + agent.query_memory = lambda *a, **k: (_ for _ in ()).throw( + RuntimeError("backend error") + ) + assert tier.lookup(key(1), _CTX) is False + + def test_submit_store_register_memory_failure_reported_in_get_finished(self): + tier, agent = _make_tier(num_blocks=4) + agent.register_memory = lambda *a, **k: None + tier.submit_store(make_job(1, [key(1)], [0])) + results = list(tier.get_finished_jobs()) + assert len(results) == 1 + assert results[0].job_id == 1 + assert not results[0].success + + def test_submit_load_register_memory_failure_reported_in_get_finished(self): + tier, agent = _make_tier(num_blocks=4) + agent.register_memory = lambda *a, **k: None + tier.submit_load(make_job(2, [key(1)], [0])) + results = list(tier.get_finished_jobs()) + assert len(results) == 1 + assert results[0].job_id == 2 + assert not results[0].success + + def test_submit_store_make_prepped_xfer_failure_reported_in_get_finished(self): + tier, agent = _make_tier(num_blocks=4) + agent.make_prepped_xfer = lambda *a, **k: None + tier.submit_store(make_job(3, [key(1)], [0])) + results = list(tier.get_finished_jobs()) + assert len(results) == 1 + assert results[0].job_id == 3 + assert not results[0].success + + def test_failure_and_success_both_returned_by_get_finished(self): + # One job fails at submission, another succeeds in flight. + tier, agent = _make_tier(num_blocks=4) + original_register = agent.register_memory + call_count = [0] + + def register_once_fail(*a, **k): + call_count[0] += 1 + return None if call_count[0] == 1 else original_register(*a, **k) + + agent.register_memory = register_once_fail + + tier.submit_store(make_job(1, [key(1)], [0])) # fails immediately + tier.submit_store(make_job(2, [key(2)], [1])) # succeeds + results = drain(tier) + assert len(results) == 2 + by_id = {r.job_id: r for r in results} + assert not by_id[1].success + assert by_id[2].success + + +class TestMockObjTierShutdown: + def test_shutdown_clears_in_flight_transfers(self): + tier, agent = _make_tier(num_blocks=4) + # Keep transfer in flight by never completing it + agent.check_xfer_state = lambda h: "PROC" + tier.submit_store(make_job(1, [key(1)], [0])) + assert len(tier._transfers) == 1 + tier.shutdown() + assert len(tier._transfers) == 0 + assert tier._dram_prepped_handle is None + assert tier._primary_reg is None + + def test_shutdown_idempotent(self): + tier, _ = _make_tier(num_blocks=4) + tier.shutdown() + tier.shutdown() # must not raise diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index ae0cbeab53b..e628f903792 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -544,13 +544,15 @@ def native_sample_recovered_tokens( target_probs: torch.Tensor, # [num_tokens, vocab_size] sampling_metadata: SamplingMetadata, device: torch.device, + use_fp64_gumbel: bool = False, ) -> torch.Tensor: batch_size = len(num_draft_tokens) vocab_size = target_probs.shape[-1] + q_dtype = torch.float64 if use_fp64_gumbel else torch.float32 q = torch.empty( (batch_size, vocab_size), - dtype=torch.float32, + dtype=q_dtype, device=device, ) q.exponential_() @@ -935,6 +937,67 @@ def test_sample_recovered_tokens( assert torch.equal(recovered_token_ids, ref_recovered_token_ids) +def test_sample_recovered_tokens_uses_fp64_exponential_race_when_requested(): + batch_size = 2 + vocab_size = 64 + max_spec_len = 2 + num_tokens = batch_size * max_spec_len + + draft_probs = torch.rand( + num_tokens, + vocab_size, + dtype=torch.float32, + device=DEVICE_TYPE, + ) + draft_probs = F.softmax(draft_probs, dim=-1) + target_probs = torch.rand( + num_tokens, + vocab_size, + dtype=torch.float32, + device=DEVICE_TYPE, + ) + target_probs = F.softmax(target_probs, dim=-1) + draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32) + + generators = { + i: torch.Generator(device=DEVICE_TYPE).manual_seed(i) for i in range(batch_size) + } + sampling_metadata = create_sampling_metadata( + all_greedy=False, + temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE), + generators=generators, + ) + spec_decode_metadata = create_spec_decode_metadata( + draft_token_ids.reshape(batch_size, max_spec_len).tolist(), + target_probs.log(), + ) + + expected = native_sample_recovered_tokens( + max_spec_len, + spec_decode_metadata.num_draft_tokens, + spec_decode_metadata.cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + sampling_metadata, + device=torch.device(DEVICE_TYPE), + use_fp64_gumbel=True, + ) + actual = sample_recovered_tokens( + max_spec_len, + spec_decode_metadata.num_draft_tokens, + spec_decode_metadata.cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + sampling_metadata, + device=torch.device(DEVICE_TYPE), + use_fp64_gumbel=True, + ) + + assert torch.equal(actual, expected) + + ########################### Tests for Synthetic Rejection Sampling ######### diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index a80fddc9235..554649b5e19 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -6,7 +6,12 @@ from torch import Generator from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch +from vllm.utils.torch_utils import set_random_seed +from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p_pytorch, + random_sample, +) +from vllm.v1.sample.sampler import Sampler DEVICE_TYPE = current_platform.device_type @@ -38,6 +43,10 @@ def _flashinfer_topk_topp_supported() -> bool: FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported() +def _seed_default_generator(seed: int) -> None: + set_random_seed(seed) + + @pytest.fixture(autouse=True) def reset_default_device(): """ @@ -49,6 +58,35 @@ def reset_default_device(): torch.set_default_device(original_device) +def test_sampler_threads_fp64_gumbel_to_topk_topp_sampler(): + sampler = Sampler(use_fp64_gumbel=True) + + assert sampler.topk_topp_sampler.use_fp64_gumbel + + +def test_random_sample_uses_fp64_exponential_race_when_requested(): + torch.set_default_device(DEVICE_TYPE) + probs = torch.tensor( + [ + [0.70, 0.20, 0.10], + [0.05, 0.15, 0.80], + [0.25, 0.25, 0.50], + ], + dtype=torch.float32, + device=DEVICE_TYPE, + ) + + _seed_default_generator(12345) + q = torch.empty(probs.shape, dtype=torch.float64, device=probs.device) + q.exponential_() + expected = q.reciprocal_().mul_(probs).argmax(dim=-1).view(-1) + + _seed_default_generator(12345) + actual = random_sample(probs.clone(), {}, use_fp64_gumbel=True) + + assert torch.equal(actual, expected) + + def test_topk_impl_equivalence(): torch.set_default_device(DEVICE_TYPE) generator = Generator(device=DEVICE_TYPE).manual_seed(33) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index c13de6d4f71..32f9dcc86ab 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -1034,7 +1034,8 @@ def test_propose_stores_probabilistic_draft_probs(monkeypatch): proposer.model = model_mock proposer._draft_attn_layer_names = {"layer.0"} - def fake_compute_probs(logits, sampling_metadata): + def fake_compute_probs(logits, sampling_metadata, use_fp64_gumbel): + assert use_fp64_gumbel == proposer.use_fp64_gumbel probs = torch.softmax(logits, dim=-1) return probs.argmax(dim=-1), probs diff --git a/tests/v1/spec_decode/test_llm_base_proposer_sampling.py b/tests/v1/spec_decode/test_llm_base_proposer_sampling.py new file mode 100644 index 00000000000..9c7ec760ebb --- /dev/null +++ b/tests/v1/spec_decode/test_llm_base_proposer_sampling.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed +from vllm.v1.sample.logits_processor import LogitsProcessors +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.llm_base_proposer import ( + compute_probs_and_sample_next_token, +) + +DEVICE_TYPE = current_platform.device_type + + +def _seed_default_generator(seed: int) -> None: + set_random_seed(seed) + + +def _make_sampling_metadata(batch_size: int) -> SamplingMetadata: + return SamplingMetadata( + temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE), + all_greedy=False, + all_random=True, + top_p=None, + top_k=None, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=torch.empty(0, device=DEVICE_TYPE), + presence_penalties=torch.empty(0, device=DEVICE_TYPE), + repetition_penalties=torch.empty(0, device=DEVICE_TYPE), + output_token_ids=[[] for _ in range(batch_size)], + spec_token_ids=[[] for _ in range(batch_size)], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessors(), + ) + + +def test_compute_probs_and_sample_next_token_uses_fp64_exponential_race(): + batch_size = 4 + vocab_size = 32 + generator = torch.Generator(device=DEVICE_TYPE).manual_seed(11) + logits = torch.randn( + batch_size, + vocab_size, + dtype=torch.float32, + device=DEVICE_TYPE, + generator=generator, + ) + metadata = _make_sampling_metadata(batch_size) + + _seed_default_generator(12345) + probs = logits.softmax(dim=-1, dtype=torch.float32) + q = torch.empty(probs.shape, dtype=torch.float64, device=probs.device) + q.exponential_() + expected_ids = q.reciprocal_().mul_(probs).argmax(dim=-1).view(-1) + + _seed_default_generator(12345) + actual_ids, actual_probs = compute_probs_and_sample_next_token( + logits.clone(), + metadata, + use_fp64_gumbel=True, + ) + + assert torch.equal(actual_ids, expected_ids) + assert torch.allclose(actual_probs, probs) diff --git a/tools/gumbel_precision/prove_exponential_race_precision.py b/tools/gumbel_precision/prove_exponential_race_precision.py new file mode 100644 index 00000000000..2af8f40fa74 --- /dev/null +++ b/tools/gumbel_precision/prove_exponential_race_precision.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""CUDA proof for fp32 exponential-race tail truncation. + +This script is intentionally not a unit test. It is a reproducible, GPU-only +statistical proof for the hidden Gumbel-max idiom: + + q.exponential_() + sample = (probs / q).argmax() + +For q ~ Exp(1), this is equivalent to argmax(log(probs) + Gumbel). On CUDA, +fp32 exponential samples inherit a 24-bit uniform lower-tail cutoff, so very +small q values are impossible. The many-tail experiment below chooses a case +where a correct sampler should select a low-probability tail token dozens of +times, while fp32 q cannot select one. +""" + +from __future__ import annotations + +import argparse +import math +import time + +import torch + + +def _seed(seed: int) -> None: + torch.manual_seed(seed) + + +def measure_exponential_lower_tail( + *, + device: torch.device, + samples: int, + chunk_size: int, + seed: int, +) -> None: + threshold = 2.0**-24 + print(f"lower-tail threshold: {threshold:.18e}") + for dtype in (torch.float32, torch.float64): + _seed(seed) + count_below = 0 + min_q = float("inf") + max_q = 0.0 + start = time.perf_counter() + remaining = samples + while remaining > 0: + n = min(chunk_size, remaining) + q = torch.empty((n,), dtype=dtype, device=device) + q.exponential_() + count_below += int((q < threshold).sum().item()) + min_q = min(min_q, float(q.min().item())) + max_q = max(max_q, float(q.max().item())) + remaining -= n + torch.accelerator.synchronize() + elapsed = time.perf_counter() - start + print( + f"{dtype}: samples={samples} count(q < 2^-24)={count_below} " + f"min={min_q:.18e} max={max_q:.6f} elapsed={elapsed:.2f}s" + ) + + +def run_many_tail_race( + *, + device: torch.device, + trials: int, + num_tail_tokens: int, + gap: float, + chunk_trials: int, + seed: int, +) -> None: + p_tail = math.exp(-gap) + expected_tail_hits = ( + trials * (num_tail_tokens * p_tail) / (1.0 + num_tail_tokens * p_tail) + ) + print( + "many-tail race: " + f"trials={trials} num_tail_tokens={num_tail_tokens} gap={gap} " + f"expected_tail_hits={expected_tail_hits:.4f}" + ) + + for dtype in (torch.float32, torch.float64): + _seed(seed) + hits = 0 + p0 = torch.tensor(1.0, dtype=dtype, device=device) + pt = torch.tensor(p_tail, dtype=dtype, device=device) + start = time.perf_counter() + remaining = trials + while remaining > 0: + batch = min(chunk_trials, remaining) + q0 = torch.empty((batch,), dtype=dtype, device=device) + q0.exponential_() + qt = torch.empty((batch, num_tail_tokens), dtype=dtype, device=device) + qt.exponential_() + head_score = p0 / q0 + tail_score = (pt / qt).amax(dim=-1) + hits += int((tail_score > head_score).sum().item()) + remaining -= batch + torch.accelerator.synchronize() + elapsed = time.perf_counter() - start + print(f"{dtype}: tail_hits={hits} elapsed={elapsed:.2f}s") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--lower-tail-samples", type=int, default=200_000_000) + parser.add_argument("--lower-tail-chunk-size", type=int, default=10_000_000) + parser.add_argument("--race-trials", type=int, default=100_000) + parser.add_argument("--race-tail-tokens", type=int, default=262_144) + parser.add_argument("--race-gap", type=float, default=20.5) + parser.add_argument("--race-chunk-trials", type=int, default=64) + parser.add_argument("--seed", type=int, default=2026) + args = parser.parse_args() + + if not torch.accelerator.is_available(): + raise RuntimeError("CUDA is required for this proof.") + + device = torch.accelerator.current_accelerator() + if device.type != "cuda": + raise RuntimeError("CUDA is required for this proof.") + + print(f"torch={torch.__version__} cuda={torch.version.cuda}") + print(f"device={device}") + measure_exponential_lower_tail( + device=device, + samples=args.lower_tail_samples, + chunk_size=args.lower_tail_chunk_size, + seed=args.seed, + ) + run_many_tail_race( + device=device, + trials=args.race_trials, + num_tail_tokens=args.race_tail_tokens, + gap=args.race_gap, + chunk_trials=args.race_chunk_trials, + seed=args.seed, + ) + + +if __name__ == "__main__": + main() diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 22855080824..a174208da4c 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -8,11 +8,9 @@ on files that have been changed. It groups files into different mypy calls based on their directory to avoid import following issues. Usage: - python tools/pre_commit/mypy.py + python tools/pre_commit/mypy.py Args: - ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to - "silent" for the main group of files. python_version: Python version to use (e.g., "3.10") or "local" to use the local Python version. changed_files: List of changed files to check. @@ -98,8 +96,8 @@ def mypy( def main(): - python_version = sys.argv[2] - file_groups = group_files(sys.argv[3:]) + python_version = sys.argv[1] + file_groups = group_files(sys.argv[2:]) if python_version == "local": python_version = f"{sys.version_info.major}.{sys.version_info.minor}" diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 5a8b690433c..eb12bedd7bf 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -27,6 +27,16 @@ except ImportError: # on ROCm the fp8_dtype always calls is_fp8_fnuz # which is a host op, so we cache it once here. FP8_DTYPE = current_platform.fp8_dtype() +_HIPB_MM_INITIALIZED_DEVICES: set[int] = set() + + +def _ensure_hipb_mm_extension_initialized() -> None: + import aiter + + device = torch.accelerator.current_device_index() + if device not in _HIPB_MM_INITIALIZED_DEVICES: + aiter.hipb_create_extension() + _HIPB_MM_INITIALIZED_DEVICES.add(device) def is_aiter_found() -> bool: @@ -625,6 +635,43 @@ def _rocm_aiter_preshuffled_per_token_w8a8_gemm_fake( return torch.empty(m, n, dtype=output_dtype, device=A.device) +def _rocm_aiter_hipb_mm_fp8_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + from aiter import hipb_mm + + _ensure_hipb_mm_extension_initialized() + return hipb_mm( + A, + B, + solution_index=-1, + bias=bias, + out_dtype=output_dtype, + scaleA=As, + scaleB=Bs, + scaleOut=None, + bpreshuffle=True, + ) + + +def _rocm_aiter_hipb_mm_fp8_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[1] + return torch.empty(m, n, dtype=output_dtype, device=A.device) + + def _rocm_aiter_triton_gemm_a8w8_blockscale_impl( A: torch.Tensor, B: torch.Tensor, @@ -1308,6 +1355,7 @@ class rocm_aiter_ops: # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM + _LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE @@ -1340,6 +1388,7 @@ class rocm_aiter_ops: cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM + cls._LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS @@ -1512,6 +1561,13 @@ class rocm_aiter_ops: return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950() + @classmethod + @if_aiter_supported + def is_linear_hipbmm_enabled(cls) -> bool: + from vllm.platforms.rocm import on_mi3xx + + return cls.is_linear_enabled() and on_mi3xx() and cls._LINEAR_HIPBMM_ENABLED + @classmethod @if_aiter_supported def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: @@ -1668,6 +1724,12 @@ class rocm_aiter_ops: fake_impl=_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake, ) + direct_register_custom_op( + op_name="rocm_aiter_hipb_mm_fp8", + op_func=_rocm_aiter_hipb_mm_fp8_impl, + fake_impl=_rocm_aiter_hipb_mm_fp8_fake, + ) + direct_register_custom_op( op_name="rocm_aiter_triton_gemm_a8w8_blockscale", op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl, @@ -1858,6 +1920,17 @@ class rocm_aiter_ops: A, B, As, Bs, bias, output_dtype ) + @staticmethod + def hipb_mm_fp8( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_hipb_mm_fp8(A, B, As, Bs, bias, output_dtype) + @staticmethod def triton_gemm_a8w8_blockscale( A: torch.Tensor, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f12d128f083..bd8a19b6d2d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -868,9 +868,10 @@ def cutlass_scaled_mm_azp( bias: torch.Tensor | None = None, ) -> torch.Tensor: """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. + Args: + azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + azp: Only set in the per-token case. Per-token if set. """ assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 @@ -3886,9 +3887,12 @@ def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor: Note that sylvester hadamard transforms are also symmetric, which means that this function is also applies the (transpose <=> inverse) transform. - :param x: value to be transformed inplace - :param inplace: modify value in place - :return: value after transformation + Args: + x: value to be transformed inplace + inplace: modify value in place + + Returns: + value after transformation """ return torch.ops._C.hadacore_transform(x, inplace) diff --git a/vllm/compilation/passes/inductor_pass.py b/vllm/compilation/passes/inductor_pass.py index 8a0d5326dd9..29410f960cd 100644 --- a/vllm/compilation/passes/inductor_pass.py +++ b/vllm/compilation/passes/inductor_pass.py @@ -82,11 +82,12 @@ class InductorPass(CustomGraphPass): # type: ignore[misc] def hash_source(*srcs: str | Any) -> str: """ Utility method to hash the sources of functions or objects. - :param srcs: strings or objects to add to the hash. - Objects and functions have their source inspected. - Results are cached by resolved types to avoid repeated - inspect.getsource() calls. - :return: + + Args: + srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + Results are cached by resolved types to avoid repeated + inspect.getsource() calls. """ # Resolve instances to their class for a hashable cache key. cache_key = tuple( @@ -99,7 +100,9 @@ class InductorPass(CustomGraphPass): # type: ignore[misc] def hash_dict(dict_: dict[Any, Any]) -> str: """ Utility method to hash a dictionary, can alternatively be used for uuid. - :return: A sha256 hash of the json rep of the dictionary. + + Returns: + A sha256 hash of the json rep of the dictionary. """ encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py index 2887c19ad4a..c0643a916b3 100644 --- a/vllm/compilation/passes/utility/fix_functionalization.py +++ b/vllm/compilation/passes/utility/fix_functionalization.py @@ -276,9 +276,11 @@ class FixFunctionalizationPass(VllmInductorPass): """ Replace mutated getitem users of the auto-functionalized node with the mutated arguments. - :param node: The auto-functionalized node - :param mutated_args: The mutated arguments, indexed by getitem index. - If the value of an arg is a string, `node.kwargs[arg]` is used. + + Args: + node: The auto-functionalized node + mutated_args: The mutated arguments, indexed by getitem index. + If the value of an arg is a string, `node.kwargs[arg]` is used. """ for idx, user in self.getitem_users(node).items(): # Some functionalized nodes may return both a result at getitem[0] @@ -317,10 +319,11 @@ class FixFunctionalizationPass(VllmInductorPass): as node.kwargs cannot be used. See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 - :param graph: Graph to insert the defunctionalized node into - :param node: The auto-functionalized node to defunctionalize - :param args: If we cannot use kwargs, specify args directly. - If an arg is a string, `node.kwargs[arg]` is used. + Args: + graph: Graph to insert the defunctionalized node into + node: The auto-functionalized node to defunctionalize + args: If we cannot use kwargs, specify args directly. + If an arg is a string, `node.kwargs[arg]` is used. """ # noqa: E501 assert is_func(node, auto_functionalized), ( f"node must be auto-functionalized, is {node} instead" diff --git a/vllm/compilation/passes/utility/noop_elimination.py b/vllm/compilation/passes/utility/noop_elimination.py index 5f7d47ad6f8..80bf8ecc603 100644 --- a/vllm/compilation/passes/utility/noop_elimination.py +++ b/vllm/compilation/passes/utility/noop_elimination.py @@ -108,9 +108,13 @@ class NoOpEliminationPass(VllmInductorPass): def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool: """ This function checks if two dimensions are equivalent. - :param dim: The dimension arg to reshape/slice - :param i_dim: The corresponding dimension in the input tensor - :return: Are the dimensions equivalent? + + Args: + dim: The dimension arg to reshape/slice + i_dim: The corresponding dimension in the input tensor + + Returns: + Are the dimensions equivalent? There are two cases in which the dimensions are equivalent: 1. The dimensions are equal (both integers) diff --git a/vllm/config/model.py b/vllm/config/model.py index 67040a423b7..f648a69e10e 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -235,9 +235,10 @@ class ModelConfig: temperature and top_k/top_p. """ use_fp64_gumbel: bool = False - """Whether to use FP64 (instead of FP32) for the Gumbel noise used by the - sampler. FP64 reduces the chance of ties in Gumbel-max sampling at the cost - of significantly lower kernel throughput on most GPUs.""" + """Whether to use FP64 (instead of FP32) random noise for Gumbel-max and + equivalent exponential-race sampling. FP64 preserves lower-tail sampling + events that fp32 uniform/exponential draws can truncate, at the cost of + significantly lower throughput on most GPUs.""" disable_sliding_window: bool = False """Whether to disable sliding window. If True, we will disable the sliding window functionality of the model, capping to sliding window size. If the diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 6edd69a949e..a652c5a6c73 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -180,8 +180,9 @@ class CuMemAllocator: All data in the memory allocation with the specified tag will be offloaded to CPU memory, and others will be discarded. - :param offload_tags: The tags of the memory allocation that will be - offloaded. The rest of the memory allocation will be discarded. + Args: + offload_tags: The tags of the memory allocation that will be + offloaded. The rest of the memory allocation will be discarded. """ if offload_tags is None: # by default, allocated tensors are offloaded @@ -230,9 +231,10 @@ class CuMemAllocator: All data that is previously offloaded will be loaded back to GPU memory, and the rest of the data will have empty memory. - :param tags: The tags of the memory allocation that will be loaded - back to GPU memory. If None, all memory allocation will be loaded - back to GPU memory. + Args: + tags: The tags of the memory allocation that will be loaded + back to GPU memory. If None, all memory allocation will be loaded + back to GPU memory. """ for ptr, data in self.pointer_to_data.items(): if tags is None or data.tag in tags: @@ -255,8 +257,9 @@ class CuMemAllocator: All memory allocation created inside the context will be allocated in the memory pool, and has the specified tag. - :param tag: The tag of the memory allocation. If None, the default tag - will be used. + Args: + tag: The tag of the memory allocation. If None, the default tag + will be used. """ if tag is None: tag = CuMemAllocator.default_tag diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index ee21185969f..c9bc8909519 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -132,7 +132,8 @@ class KVEventAggregator: """ Add events from a worker batch. - :param events: List of KVCacheEvent objects. + Args: + events: List of KVCacheEvent objects. """ if not isinstance(events, list): raise TypeError("events must be a list of KVCacheEvent.") @@ -142,7 +143,8 @@ class KVEventAggregator: """ Return events that appeared in all workers. - :return: List of events present in all workers. + Returns: + List of events present in all workers. """ return [ event @@ -154,7 +156,8 @@ class KVEventAggregator: """ Return all events for all workers. - :return: List of events for all workers. + Returns: + List of events for all workers. """ return list(self._event_counter.elements()) @@ -168,7 +171,8 @@ class KVEventAggregator: """ Increment the number of workers contributing events. - :param count: Number to increment the workers by. + Args: + count: Number to increment the workers by. """ if count <= 0: raise ValueError("count must be positive.") @@ -184,7 +188,8 @@ class KVEventAggregator: """ Return the number of workers. - :return: int number of workers. + Returns: + int number of workers. """ return self._num_workers diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index fafc1f45724..0ab694b7e73 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -368,7 +368,7 @@ def get_current_attn_backend( class EngineTransferInfo: """Common per-remote-engine transfer state, computed at handshake. - Stored per ``engine_id`` inside ``TransferTopology._engines``. + Stored per ``(engine_id, pp_rank)`` inside ``TransferTopology._engines``. """ remote_tp_size: int @@ -382,6 +382,15 @@ class EngineTransferInfo: remote_physical_blocks_per_logical: int """Physical blocks per logical block.""" + remote_pp_rank: int = 0 + """Remote producer PP rank for this engine.""" + + start_layer: int = 0 + """Global index of the first layer owned by this PP rank.""" + + end_layer: int = 0 + """Exclusive global index after the last layer owned by this PP rank.""" + # ---- Transfer topology ---- @@ -403,7 +412,7 @@ class TransferTopology: def __post_init__(self): self.local_physical_heads = max(1, self.total_num_kv_heads // self.tp_size) - self._engines: dict[EngineId, EngineTransferInfo] = {} + self._engines: dict[tuple[EngineId, int], EngineTransferInfo] = {} # Figure out whether the first dimension of the cache is K/V # or num_blocks. @@ -461,13 +470,16 @@ class TransferTopology: f"Cannot register local engine {self.engine_id} as remote. " f"Local identity is set via __init__ params." ) - if remote_engine_id in self._engines: - return self._engines[remote_engine_id] - self._engines[remote_engine_id] = info + engine_key = (remote_engine_id, info.remote_pp_rank) + if engine_key in self._engines: + return self._engines[engine_key] + self._engines[engine_key] = info return info - def get_engine_info(self, remote_engine_id: EngineId) -> EngineTransferInfo: - return self._engines[remote_engine_id] + def get_engine_info( + self, remote_engine_id: EngineId, remote_pp_rank: int = 0 + ) -> EngineTransferInfo: + return self._engines[(remote_engine_id, remote_pp_rank)] # ============================================================ # Layout properties @@ -528,15 +540,22 @@ class TransferTopology: ) return self.block_size // remote_block_size - def is_kv_replicated(self, remote_engine_id: EngineId) -> bool: + def is_kv_replicated( + self, remote_engine_id: EngineId, remote_pp_rank: int = 0 + ) -> bool: """Whether the KV cache is replicated across TP workers due to the number of TP workers being greater than the number of KV heads. """ - return self._engines[remote_engine_id].remote_tp_size > self.total_num_kv_heads + return ( + self._engines[(remote_engine_id, remote_pp_rank)].remote_tp_size + > self.total_num_kv_heads + ) - def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: + def replicates_kv_cache( + self, remote_engine_id: EngineId, remote_pp_rank: int = 0 + ) -> bool: # MLA is always replicated as the hidden dim can't be split. - return self.is_mla or self.is_kv_replicated(remote_engine_id) + return self.is_mla or self.is_kv_replicated(remote_engine_id, remote_pp_rank) @property def local_replicates_kv_cache(self) -> bool: @@ -555,12 +574,14 @@ class TransferTopology: abs_ratio = -tp_ratio return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)] - def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]: + def target_remote_ranks( + self, remote_engine_id: EngineId, remote_pp_rank: int = 0 + ) -> list[int]: """Get the remote TP rank(s) that the current local TP rank will read from. When remote tp_size > local tp_size, reads from multiple remote ranks. """ - info = self._engines[remote_engine_id] + info = self._engines[(remote_engine_id, remote_pp_rank)] tp_ratio = self.tp_ratio(info.remote_tp_size) if tp_ratio > 0: return [self.tp_rank // tp_ratio] @@ -593,15 +614,16 @@ class TransferTopology: # Regular case: backends like FA register K/V in separate regions return cache if self.split_k_and_v else [cache] - def describe(self, remote_engine_id: EngineId) -> str: + def describe(self, remote_engine_id: EngineId, remote_pp_rank: int = 0) -> str: """One-line summary of transfer config for logging.""" - info = self._engines[remote_engine_id] + info = self._engines[(remote_engine_id, remote_pp_rank)] return ( f"TransferTopology(" f"tp_ratio={self.tp_ratio(info.remote_tp_size)}, " f"num_kv_heads={self.total_num_kv_heads if not self.is_mla else 1}, " f"local_tp={self.tp_size}, " f"remote_tp={info.remote_tp_size}, " + f"remote_pp={remote_pp_rank}, " f"local_rank={self.tp_rank}, " f"remote_block_len={info.remote_block_len})" ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index fb5658da887..71d89f43a79 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -644,6 +644,23 @@ class KVConnectorBase_V1(ABC): """ return None + def set_xfer_handshake_metadata_pp_aware( + self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata] + ) -> None: + """ + Set handshake metadata keyed by (pp_rank, tp_rank). + - Default implementation assumes pp_rank is always 0 + - PP-aware connectors override this to consume all PP producer shards. + """ + if any(pp_rank != 0 for pp_rank, _ in metadata): + raise ValueError( + f"{type(self).__name__} received pp_rank > 0 handshake metadata " + "but does not support PP-disaggregated KV transfer." + ) + self.set_xfer_handshake_metadata( + {tp_rank: meta for (_, tp_rank), meta in metadata.items()} + ) + @classmethod def build_prom_metrics( cls, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 35cd7060691..d16fbee585a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -439,13 +439,12 @@ def _init_lmcache_engine( `LMCACHE_CONFIG_FILE` to load the configuration file. If that environment variable is not set, this function will return None. - :param lmcache_config: The LMCache configuration. - :type lmcache_config: LMCacheEngineConfig - :param vllm_config: The vLLM configuration. - :type vllm_config: VllmConfig + Args: + lmcache_config: The LMCache configuration. + vllm_config: The vLLM configuration. - :return: The initialized LMCache engine - :rtype: LMCacheEngine + Returns: + The initialized LMCache engine """ if curr_engine := LMCacheEngineBuilder.get(ENGINE_NAME): return curr_engine diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 73418104bea..46354337e65 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -471,6 +471,12 @@ class MultiConnector(KVConnectorBase_V1, SupportsHMA): for c in self._connectors: c.set_xfer_handshake_metadata(metadata) + def set_xfer_handshake_metadata_pp_aware( + self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata] + ) -> None: + for c in self._connectors: + c.set_xfer_handshake_metadata_pp_aware(metadata) + def _aggregate_request_finished( self, request: "Request", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py index 187322b4ae4..dad81e84c45 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py @@ -94,6 +94,15 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA): super().__init__(vllm_config, role, kv_cache_config) assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None + + if vllm_config.kv_transfer_config.kv_role == "kv_both": + logger.warning_once( + "Using kv_role='kv_both' with NixlConnector is deprecated " + "and will be removed in a future release. Please set " + "kv_role='kv_producer' for prefill instances and " + "kv_role='kv_consumer' for decode instances. " + ) + self.kv_cache_config = kv_cache_config self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id self.kv_transfer_config = vllm_config.kv_transfer_config diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 8775e519d99..8bd6e92157a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -227,6 +227,67 @@ def patched_fused_scaled_matmul_reduce_scatter_fake( return res +def _platform_device_type() -> str: + """Return the device-type string (e.g. ``"cuda"``, ``"xpu"``, ``"cpu"``) + for the current platform, in the form expected by + ``torch.distributed.init_process_group(backend=...)``. + """ + from vllm.platforms import current_platform + + if current_platform.is_cuda_alike(): + return "cuda" + elif current_platform.is_xpu(): + return "xpu" + elif current_platform.is_out_of_tree(): + return current_platform.device_name + else: + return "cpu" + + +def _device_backend_str(torch_distributed_backend: str | Backend) -> str: + """Normalize ``torch_distributed_backend`` to the ``":"`` + format required by ``split_group``'s ``backend`` argument. + + Accepts either a bare backend name (e.g. ``"nccl"``) or an already-prefixed + string (e.g. ``"cuda:nccl"``). + """ + backend_str = str(torch_distributed_backend) + if ":" in backend_str: + return backend_str + return f"{_platform_device_type()}:{backend_str}" + + +def _create_subgroups_split_group( + group_ranks: list[list[int]], + group_name: str, + torch_distributed_backend: str | Backend, +) -> tuple[ProcessGroup, ProcessGroup]: + """Create the device + CPU subgroups for ``GroupCoordinator`` via + ``torch.distributed.split_group``. + + ``split_group`` is collective on the parent group, so every parent rank + must enter with the same ``split_ranks`` definition. Each rank receives + the subgroup it belongs to. + """ + device_backend_str = _device_backend_str(torch_distributed_backend) + self_device_group = torch.distributed.split_group( + split_ranks=group_ranks, + group_desc=f"{group_name}:device", + backend=device_backend_str, + ) + # CPU subgroup: split_group requires the requested backend filter to + # include the parent's default device type (= the device the parent PG + # was bound to via ``device_id``), so a cpu-only filter is rejected. + # Include the device backend in the filter; only the gloo backend is + # actually used for CPU collectives on this group. + self_cpu_group = torch.distributed.split_group( + split_ranks=group_ranks, + group_desc=f"{group_name}:cpu", + backend=f"cpu:gloo,{device_backend_str}", + ) + return self_device_group, self_cpu_group + + def patched_fused_scaled_matmul_reduce_scatter( A: torch.Tensor, B: torch.Tensor, @@ -335,26 +396,39 @@ class GroupCoordinator: self_device_group = None self_cpu_group = None - from vllm.distributed.utils import get_cpu_distributed_timeout_or_none - - timeout = get_cpu_distributed_timeout_or_none() - - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend + # VLLM_DISTRIBUTED_USE_SPLIT_GROUP gates the new ``split_group`` + # codepath. Default (False) preserves the legacy ``new_group`` path. + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + self_device_group, self_cpu_group = _create_subgroups_split_group( + group_ranks, group_name, torch_distributed_backend ) - # a group with `gloo` backend, to allow direct coordination between - # processes through the CPU. - with suppress_stdout(): - cpu_group = torch.distributed.new_group( - ranks, backend="gloo", timeout=timeout + for ranks in group_ranks: + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + break + else: + from vllm.distributed.utils import get_cpu_distributed_timeout_or_none + + timeout = get_cpu_distributed_timeout_or_none() + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend ) - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self_device_group = device_group - self_cpu_group = cpu_group + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + with suppress_stdout(): + cpu_group = torch.distributed.new_group( + ranks, backend="gloo", timeout=timeout + ) + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self_device_group = device_group + self_cpu_group = cpu_group assert self_cpu_group is not None assert self_device_group is not None @@ -1348,6 +1422,62 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def _init_process_group_for_split_group( + *, + backend: str, + distributed_init_method: str, + world_size: int, + rank: int, + local_rank: int, + timeout: timedelta | None, +) -> None: + """Initialize the default PG with both CPU (gloo) and device (e.g. nccl) + backends and an eager ``device_id`` binding so that subgroups can be + created via ``split_group`` (which requires the parent communicator to + be eagerly initialized). Falls back to ``gloo`` on CPU-only systems. + """ + if torch.accelerator.is_available() and backend != "gloo": + init_backend = "cpu:gloo,cuda:nccl" + device_id: torch.device | None = torch.device(f"cuda:{local_rank}") + else: + init_backend = "gloo" + device_id = None + torch.distributed.init_process_group( + backend=init_backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + timeout=timeout, + device_id=device_id, + ) + + +def _validate_default_pg_for_split_group() -> None: + """When an external launcher (e.g. ``torchrun``) initialized the default + PG, ``GroupCoordinator`` cannot patch in additional backends or change + the eager-init behavior — ``split_group`` only selects subsets of an + existing parent. Validate that the parent has both ``device_id`` and a + CPU (gloo) backend, and emit a descriptive error pointing at the exact + init call to update otherwise. + """ + default_pg = torch.distributed.distributed_c10d._get_default_group() + assert default_pg.bound_device_id is not None, ( + "External launcher initialized the default process group " + "without device_id. vLLM requires the default PG to be device-" + "bound for split_group. Pass device_id=torch.device(f'cuda:" + "{local_rank}') to torch.distributed.init_process_group()." + ) + try: + default_pg._get_backend(torch.device("cpu")) + except RuntimeError as e: + raise RuntimeError( + "External launcher initialized the default process group " + "without a CPU (gloo) backend. vLLM requires both CPU and " + "device backends. Pass backend='cpu:gloo,cuda:nccl' to " + "torch.distributed.init_process_group()." + ) from e + + def _init_elastic_ep_world( config, local_rank: int, backend: str, rank: int, world_size: int ) -> None: @@ -1456,14 +1586,33 @@ def init_distributed_environment( "Fallback Gloo backend is not available." ) backend = "gloo" - # this backend is used for WORLD - torch.distributed.init_process_group( - backend=backend, - init_method=distributed_init_method, - world_size=world_size, - rank=rank, - timeout=timeout, - ) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + # split_group needs local_rank early to compute device_id for + # the eager init. local_rank is not available in torch + # ProcessGroup, see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + local_rank = ( + int(envs.LOCAL_RANK) + if distributed_init_method == "env://" + else rank + ) + _init_process_group_for_split_group( + backend=backend, + distributed_init_method=distributed_init_method, + world_size=world_size, + rank=rank, + local_rank=local_rank, + timeout=timeout, + ) + else: + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + timeout=timeout, + ) if enable_elastic_ep: tp_pp_cpu_group = torch.distributed.new_group( backend="gloo", timeout=timeout @@ -1476,6 +1625,9 @@ def init_distributed_environment( "Elastic EP is not yet supported with multi-node TP/PP" ) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP and torch.accelerator.is_available(): + _validate_default_pg_for_split_group() + # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -1903,6 +2055,10 @@ def destroy_distributed_environment(): def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + logger.debug( + "[shutdown] Distributed: cleanup start shutdown_ray=%s", + shutdown_ray, + ) # Reset environment variable cache envs.disable_envs_cache() @@ -1937,6 +2093,8 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): "torch._C._host_emptyCache() only available in Pytorch >=2.5" ) + logger.debug_once("[shutdown] Distributed: cleanup complete") + def in_the_same_node_as( pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0 diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index a560db87ea2..08a3ab58c78 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -95,6 +95,9 @@ async def serve_http( shutdown_event = asyncio.Event() def signal_handler() -> None: + if shutdown_event.is_set(): + return + logger.info_once("[shutdown] API server: shutdown triggered") shutdown_event.set() async def dummy_shutdown() -> None: @@ -108,12 +111,21 @@ async def serve_http( engine_client = app.state.engine_client timeout = engine_client.vllm_config.shutdown_timeout + mode = "abort" if timeout == 0 else "drain" + + logger.info( + "[shutdown] API server: stopping engine client mode=%s timeout=%ss", + mode, + timeout, + ) await loop.run_in_executor( None, partial(engine_client.shutdown, timeout=timeout) ) + logger.info_once("[shutdown] API server: engine client stopped") server.should_exit = True + logger.info_once("[shutdown] API server: signalling HTTP server shutdown") server_task.cancel() watchdog_task.cancel() if ssl_cert_refresher: @@ -134,7 +146,7 @@ async def serve_http( process, " ".join(process.cmdline()), ) - logger.info("Shutting down FastAPI HTTP server.") + logger.info_once("[shutdown] API server: shutting down FastAPI HTTP server") return server.shutdown() finally: shutdown_task.cancel() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 7297243f918..892e5035ab6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -556,7 +556,8 @@ class LLM(BeamSearchOfflineMixin, PoolingOfflineMixin, OfflineInferenceMixin): and returns their outputs. Use after enqueue() to get results. Args: - output_type: The expected output type, defaults to RequestOutput. + output_type: The expected output type(s). If not provided, accepts + both RequestOutput and PoolingRequestOutput. use_tqdm: If True, shows a tqdm progress bar. Returns: diff --git a/vllm/envs.py b/vllm/envs.py index bb3bb34284b..a32f055a028 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -63,6 +63,7 @@ if TYPE_CHECKING: VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False + VLLM_DISTRIBUTED_USE_SPLIT_GROUP: bool = False VLLM_XLA_USE_SPMD: bool = False VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") @@ -114,6 +115,7 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True + VLLM_ROCM_USE_AITER_LINEAR_HIPBMM: bool = False VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_AITER_MOE_DISPATCH_POLICY: int = 0 VLLM_ROCM_USE_AITER_RMSNORM: bool = True @@ -877,6 +879,13 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": lambda: bool( int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "1")) ), + # When True, GroupCoordinator constructs its CPU/device subgroups via + # ``torch.distributed.split_group(backend=...)`` + # and ``init_distributed_environment`` initializes the default PG with + # mixed ``cpu:gloo,cuda:nccl`` backend + eager ``device_id`` binding. + "VLLM_DISTRIBUTED_USE_SPLIT_GROUP": lambda: bool( + int(os.getenv("VLLM_DISTRIBUTED_USE_SPLIT_GROUP", "0")) + ), # Use dedicated multiprocess context for workers. # Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": env_with_choices( @@ -1109,6 +1118,9 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ROCM_USE_AITER_LINEAR": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1") ), + "VLLM_ROCM_USE_AITER_LINEAR_HIPBMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "False").lower() in ("true", "1") + ), # Whether to use aiter moe ops. # By default is enabled. "VLLM_ROCM_USE_AITER_MOE": lambda: ( diff --git a/vllm/ir/op.py b/vllm/ir/op.py index 8e82b5d8c7e..742d3f33ff8 100644 --- a/vllm/ir/op.py +++ b/vllm/ir/op.py @@ -113,11 +113,14 @@ def register_op( """ Register a new vLLM IR op. - :param f: the native implementation of the op - :param name: the name of the op, defaults to the function name - :param activations: list of activation params, defaults to params starting with 'x' - :param allow_inplace: add a maybe_inplace overload that allows inplace impls - :return: the IrOp object if f is provided, otherwise a decorator + Args: + f: the native implementation of the op + name: the name of the op, defaults to the function name + activations: list of activation params, defaults to params starting with 'x' + allow_inplace: add a maybe_inplace overload that allows inplace impls + + Returns: + the IrOp object if f is provided, otherwise a decorator Example usage: ```python @@ -245,14 +248,17 @@ class IrOp: supported: bool = True, supports_args: Callable[..., bool] | None = None, inplace: bool = False, - ): + ) -> Callable[[Callable[..., Any]], "IrOpImpl"]: """ Register an implementation for this custom op. - :param provider: The name of the provider, must be unique. - :param supported: Static support check, use this to check platform support. - :param supports_args: Dynamic arg support check, used for types and shapes. - :param inplace: Does this op reuse activation input memory for outputs - :return: A decorator that registers the implementation. + Args: + provider: The name of the provider, must be unique. + supported: Static support check, use this to check platform support. + supports_args: Dynamic arg support check, used for types and shapes. + inplace: Does this op reuse activation input memory for outputs + + Returns: + A decorator that registers the implementation. The decorated function must have the same semantics and signature as the native implementation. diff --git a/vllm/ir/util.py b/vllm/ir/util.py index ac8a06155da..e9240f487ac 100644 --- a/vllm/ir/util.py +++ b/vllm/ir/util.py @@ -12,9 +12,9 @@ from typing import Any def hash_source(*srcs: str | Any) -> str: """ Utility method to hash the sources of functions or objects. - :param srcs: strings or objects to add to the hash. - Objects and functions have their source inspected. - :return: + Args: + srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. """ hasher = hashlib.sha256() for src in srcs: diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index cd2c9eb01be..8162acd5e8d 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -128,6 +128,7 @@ from vllm.model_executor.kernels.linear.scaled_mm import ( ) from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( AiterFp8BlockScaledMMKernel, + AiterHipbMMPerTokenFp8ScaledMMLinearKernel, AiterInt8ScaledMMLinearKernel, AiterPerTokenFp8ScaledMMLinearKernel, AiterPreshuffledPerTokenFp8ScaledMMLinearKernel, @@ -285,6 +286,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = ChannelWiseTorchFP8ScaledMMLinearKernel, ], PlatformEnum.ROCM: [ + AiterHipbMMPerTokenFp8ScaledMMLinearKernel, AiterPreshuffledPerTokenFp8ScaledMMLinearKernel, AiterPerTokenFp8ScaledMMLinearKernel, ROCmFP8ScaledMMLinearKernel, @@ -1024,6 +1026,7 @@ __all__ = [ "FP8ScaledMMLinearLayerConfig", "Int8ScaledMMLinearLayerConfig", "ScaledMMLinearLayerConfig", + "AiterHipbMMPerTokenFp8ScaledMMLinearKernel", "AiterPreshuffledPerTokenFp8ScaledMMLinearKernel", "AiterPerTokenFp8ScaledMMLinearKernel", "NvFp4LinearKernel", diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 5ded5ca798a..1b39491ab34 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -212,6 +212,99 @@ class AiterPreshuffledPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): ) +class AiterHipbMMPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not current_platform.is_rocm(): + return False, "requires ROCm." + + if not rocm_aiter_ops.is_linear_hipbmm_enabled(): + return ( + False, + "requires setting `VLLM_ROCM_USE_AITER=1`, " + "`VLLM_ROCM_USE_AITER_LINEAR=1`, " + "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`.", + ) + try: + import aiter # noqa: F401 + except Exception: + return False, "requires aiter library to be installed." + + if not hasattr(aiter, "hipb_mm"): + return False, "requires aiter hipb_mm support." + + return True, None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + is_ptpc = ( + c.activation_quant_key.scale.group_shape.is_per_token() + and c.weight_quant_key.scale.group_shape.is_per_channel() + ) + if c.weight_shape is None: + return False, "weight_shape is required for Aiter kernels" + N, K = c.weight_shape + + if c.out_dtype is not torch.bfloat16: + return False, "requires bfloat16 output dtype." + + if not is_ptpc: + return ( + False, + "requires per token activation scales and per channel weight scales.", + ) + + if not (N >= 16 and N % 16 == 0 and K % 16 == 0): + return ( + False, + "requires N >= 16 and both N and K divisible by 16, " + f"received N={N} and K={K}.", + ) + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + w_name, w_s_name, *_ = self.layer_param_names + w, w_s, *_ = self._get_layer_params(layer) + + # Pre-apply the transposes that used to live in + # _rocm_aiter_hipb_mm_fp8_impl so the kernel can consume B/Bs directly. + # The `.t()` on the shuffled weight is kept as a non-contiguous view — + # materializing it with `.contiguous()` would re-arrange the bytes and + # break the `bpreshuffle` layout. + shuffled_w = rocm_aiter_ops.shuffle_weight(w.t().contiguous()) + replace_parameter( + layer, + w_name, + torch.nn.Parameter(shuffled_w.t(), requires_grad=False), + ) + + if w_s.ndim > 1: + replace_parameter( + layer, + w_s_name, + torch.nn.Parameter(w_s.t().contiguous(), requires_grad=False), + ) + + def apply_scaled_mm( + self, + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, + ) -> torch.Tensor: + output_shape[-1] = B.shape[1] + return rocm_aiter_ops.hipb_mm_fp8(A, B, As, Bs, bias, out_dtype).view( + *output_shape + ) + + class AiterPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): @classmethod def is_supported( diff --git a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py index d8570049af2..fa91804f35c 100644 --- a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py @@ -379,8 +379,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular): topk_ids, activation, global_num_experts, - # the fp8 cutlass experts use their own expert map. - None, + expert_map, self.w1_scale, self.w2_scale, a1q_scale, diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py index a65873aca49..a412a6936d3 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxint4_moe.py @@ -30,6 +30,8 @@ class TrtLlmMxint4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic): self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, + max_num_tokens: int | None = None, + num_dispatchers: int | None = None, ): super().__init__(moe_config, quant_config) self.topk = moe_config.experts_per_token diff --git a/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py b/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py index 82969dd8e25..94208326461 100644 --- a/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/xpu_moe.py @@ -14,7 +14,9 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, + kFp8Dynamic128Sym, kFp8DynamicTensorSym, + kFp8Static128BlockSym, kFp8StaticTensorSym, kInt4Static, kMxfp4Static, @@ -62,6 +64,7 @@ class XPUExperts(mk.FusedMoEExpertsModular): self.is_fp8 = False self.is_int4 = False self.is_mxfp4 = False + self.is_block_fp8 = False self.is_mxfp8 = False self.fused_moe_impl: XpuFusedMoe | None = None @@ -171,6 +174,7 @@ class XPUExperts(mk.FusedMoEExpertsModular): is_int4=self.is_int4, is_mxfp4=self.is_mxfp4, is_mxfp8=self.is_mxfp8, + is_block_fp8=self.is_block_fp8, ) assert self.fused_moe_impl is not None self.fused_moe_impl.apply( @@ -238,6 +242,33 @@ class XPUExpertsMxfp8(XPUExpertsFp8): return (weight_key, activation_key) in SUPPORTED_W_A +class XPUExpertsBlockFp8(XPUExperts): + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + max_num_tokens: int | None = None, + num_dispatchers: int | None = None, + ): + super().__init__( + moe_config, + quant_config, + max_num_tokens, + num_dispatchers, + ) + self.is_block_fp8 = True + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + class XPUExpertsWNA16(XPUExperts): """W4A16 INT4-symmetric MoE backed by `xpu_fused_moe(is_int4=True)`. diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4ff43ce21b8..2f4401563b6 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -599,12 +599,14 @@ class FusedMoE(PluggableLayer): ): """ Load grouped weight scales for group quantization or model weights - :param shard_dim: dimension to shard - :param expert_data: parameter for a particular expert - :param shard_id: either w1, w2, or w3 - :param loaded_weight: checkpoint weight to load into the param - :param tp_rank: tensor parallel rank - :param load_full_w2: whether or not the w2 loaded should be sharded. + + Args: + shard_dim: dimension to shard + expert_data: parameter for a particular expert + shard_id: either w1, w2, or w3 + loaded_weight: checkpoint weight to load into the param + tp_rank: tensor parallel rank + load_full_w2: whether or not the w2 loaded should be sharded. """ if shard_id == "w2": # In the case where we have actorder/g_idx, we do not partition the diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 0a2e3846dd9..cce3245b75c 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -184,11 +184,12 @@ def backend_to_kernel_cls( elif backend == Fp8MoeBackend.XPU: from vllm.model_executor.layers.fused_moe.experts.xpu_moe import ( + XPUExpertsBlockFp8, XPUExpertsFp8, XPUExpertsMxfp8, ) - return [XPUExpertsFp8, XPUExpertsMxfp8] + return [XPUExpertsFp8, XPUExpertsMxfp8, XPUExpertsBlockFp8] elif backend == Fp8MoeBackend.CPU: from vllm.model_executor.layers.fused_moe.experts.cpu_moe import ( diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py index ef7a2745a06..d3ea7fb211a 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -4,6 +4,7 @@ import torch from einops import rearrange +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.utils import PAD_SLOT_ID @@ -403,13 +404,16 @@ class _attention(torch.autograd.Function): v = v.contiguous() s = s.contiguous() - # Check CUDA compute capability - capability = torch.cuda.get_device_capability() - if capability[0] < 8: - raise RuntimeError( - "Flash attention currently only supported", - "for compute capability >= 80", - ) + # Check CUDA compute capability (Ampere+ required for flash attention + # path). Other accelerators (ROCm, XPU) rely on their own Triton + # backend support and skip this check. + if current_platform.is_cuda(): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError( + "Flash attention currently only supported", + "for compute capability >= 80", + ) # Get input dimensions b, h, n, d = q.shape diff --git a/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py index 59bab27c48c..23d7070cc80 100644 --- a/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn/kimi_gdn_linear_attn.py @@ -85,7 +85,7 @@ direct_register_custom_op( class KimiGatedDeltaNetAttention(GatedDeltaNetAttention): def get_state_dtype( self, - ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]: + ) -> tuple[torch.dtype, torch.dtype]: if self.model_config is None or self.cache_config is None: raise ValueError("model_config and cache_config must be set") return MambaStateDtypeCalculator.kda_state_dtype( @@ -94,7 +94,7 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention): def get_state_shape( self, - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + ) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.kda_state_shape( self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size ) @@ -300,13 +300,13 @@ class KimiGatedDeltaNetAttention(GatedDeltaNetAttention): g1 = g1[:, :num_actual_tokens] beta = beta[:, :num_actual_tokens] - (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches + (conv_state, recurrent_state) = constant_caches # conv_state must be (..., dim, width-1) for the conv kernels. # DS layout stores it that way directly; SD layout needs a transpose. if not is_conv_state_dim_first(): - conv_state_q = conv_state_q.transpose(-1, -2) - conv_state_k = conv_state_k.transpose(-1, -2) - conv_state_v = conv_state_v.transpose(-1, -2) + conv_state = conv_state.transpose(-1, -2) + + conv_state_q, conv_state_k, conv_state_v = conv_state.chunk(3, dim=-2) q_conv_weights = self.q_conv1d.weight.view( self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index c1fd81e40e3..0c86c787917 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -120,9 +120,9 @@ class MambaStateDtypeCalculator: cls, model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType, - ): + ) -> tuple[torch.dtype, torch.dtype]: state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) - return (state_dtype, state_dtype, state_dtype, torch.float32) + return (state_dtype, torch.float32) class MambaStateShapeCalculator: @@ -243,7 +243,7 @@ class MambaStateShapeCalculator: head_k_dim: int | None = None, conv_kernel_size: int = 4, num_spec: int = 0, - ) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]: + ) -> tuple[tuple[int, int], tuple[int, int, int]]: if num_k_heads is None: num_k_heads = num_heads if head_k_dim is None: @@ -252,19 +252,12 @@ class MambaStateShapeCalculator: proj_size = num_heads * head_dim proj_k_size = num_k_heads * head_k_dim + conv_dim = proj_size + 2 * proj_k_size conv_state_shape = cls._orient_conv_shape( - divide(proj_size, tp_world_size), conv_kernel_size - 1 - ) - conv_state_k_shape = cls._orient_conv_shape( - divide(proj_k_size, tp_world_size), conv_kernel_size - 1 + divide(conv_dim, tp_world_size), conv_kernel_size - 1 ) recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim) - return ( - conv_state_shape, - conv_state_k_shape, - conv_state_k_shape, - recurrent_state_shape, - ) + return (conv_state_shape, recurrent_state_shape) @dataclass @@ -365,9 +358,4 @@ class MambaStateCopyFuncCalculator: @classmethod def kda_state_copy_func(cls): - return ( - get_conv_copy_spec, - get_conv_copy_spec, - get_conv_copy_spec, - get_temporal_copy_spec, - ) + return (get_conv_copy_spec, get_temporal_copy_spec) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 344ddd8abd2..5b911114d38 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -162,7 +162,13 @@ class QuantizationConfig(ABC): """ raise NotImplementedError - def get_cache_scale(self, name: str) -> str | None: + def get_cache_scale_mapper(self) -> "WeightsMapper | None": + """Mapping from checkpoint KV-cache scale names to vLLM scale names. + + Returning a mapper here causes `AutoWeightsLoader` to apply it to the + weight stream automatically; individual model `load_weights` methods + do not need to know about KV-cache scales. + """ return None def apply_vllm_mapper( # noqa: B027 @@ -172,8 +178,9 @@ class QuantizationConfig(ABC): Interface for models to update module names referenced in quantization configs in order to reflect the vllm model structure - :param hf_to_vllm_mapper: maps from hf model structure (the assumed - structure of the qconfig) to vllm model structure + Args: + hf_to_vllm_mapper: maps from hf model structure (the assumed + structure of the qconfig) to vllm model structure """ # TODO (@kylesayrs): add implementations for all subclasses pass diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b59e12e8e1b..f4bd57e10e6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -267,8 +267,11 @@ class CompressedTensorsConfig(QuantizationConfig): cls, config: dict[str, Any] ) -> tuple[dict[str, SparsityCompressionConfig], list[str]]: """ - :param config: The `quantization_config` dictionary from config.json - :return: A tuple with two elements + Args: + config: The `quantization_config` dictionary from config.json + + Returns: + A tuple with two elements 1. A dictionary mapping target layer names to their corresponding sparsity_config 2. A list of layer names to ignore for sparsity @@ -296,8 +299,11 @@ class CompressedTensorsConfig(QuantizationConfig): cls, config: dict[str, Any] ) -> QUANTIZATION_SCHEME_MAP_TYPE: """ - :param config: The `quantization_config` dictionary from config.json - :return: A dictionary mapping target layer names to their corresponding + Args: + config: The `quantization_config` dictionary from config.json + + Returns: + A dictionary mapping target layer names to their corresponding quantization_args for weights and input activations """ target_scheme_map: dict[str, Any] = dict() @@ -967,7 +973,9 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): """ Validator for the kv cache scheme. Useful for controlling the kv cache quantization schemes, that are being supported in vLLM - :param kv_cache_scheme: the compressed-tensors kv cache scheme + + Args: + kv_cache_scheme: the compressed-tensors kv cache scheme """ if kv_cache_scheme is None: return diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index 731cba1ba2a..78419a0dd98 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -38,11 +38,11 @@ class CompressedTensorsScheme(ABC): Run the forward pass for the particular scheme. This is where scheme-specific dequant/quant steps/kernels should be applied. - :param layer: torch.nn.Module with the registered weights and - other parameters relevant to the particular scheme. - :param x: input to the layer - :param bias: bias parameter - + Args: + layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + x: input to the layer + bias: bias parameter """ raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index def4797b139..afb899cd6d7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -133,12 +133,11 @@ def find_matched_target( *All* component module names must match in order for a match to be successful. A successful match returns the first component target - :param layer_name: layer name - :param module: torch.nn.Module - :param targets: list of targets to match the layer against - :param fused_mapping: map from fused layer names to its components - :param fused_strategy: either "all" or "any". If using "all", fused - layers match if "all" of its components match + Args: + layer_name: layer name + module: torch.nn.Module + targets: list of targets to match the layer against + fused_mapping: map from fused layer names to its components """ if layer_name is None: @@ -161,9 +160,10 @@ def _find_first_match( exactly or as a regex after 're:'. If check_contains is set to True, additionally checks if the target string is contained within the value. - :param value: string to compare the list of targets against - :param targets: list of targets to match the layer against - :param check_contains: whether or not to do a substring match + Args: + value: string to compare the list of targets against + targets: list of targets to match the layer against + check_contains: whether or not to do a substring match """ for target in targets: @@ -205,9 +205,10 @@ def _match_fused_layer( Implements an "all" matching strategy where a fused layer matches iff "all" of its components match - :param layer_name: layer name - :param target_layers: list of targets to match the layer against - :param fused_mapping: map from fused layer names to its components + Args: + layer_name: layer name + target_layers: list of targets to match the layer against + fused_mapping: map from fused layer names to its components Examples: layer_name = "model.layers.0.self_attn.qkv_proj" diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 41e2e19785c..a6461900d13 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -207,25 +207,18 @@ class Fp8Config(QuantizationConfig): return Fp8KVCacheMethod(self) return None - def get_cache_scale(self, name: str) -> str | None: - """ - Check whether the param name matches the format for k/v cache scales - in compressed-tensors. If this is the case, return its equivalent - param name expected by vLLM + def get_cache_scale_mapper(self) -> "WeightsMapper": + """Map compressed-tensors KV-cache scale names to vLLM names.""" + from vllm.model_executor.models.utils import WeightsMapper - :param name: param name - :return: matching param name for KV cache scale in vLLM - """ - if name.endswith(".output_scale") and ".k_proj" in name: - return name.replace(".k_proj.output_scale", ".attn.k_scale") - if name.endswith(".output_scale") and ".v_proj" in name: - return name.replace(".v_proj.output_scale", ".attn.v_scale") - if name.endswith(".output_scale") and ".q_proj" in name: - return name.replace(".q_proj.output_scale", ".attn.q_scale") - if name.endswith("self_attn.prob_output_scale"): - return name.replace(".prob_output_scale", ".attn.prob_scale") - # If no matches, return None - return None + return WeightsMapper( + orig_to_new_suffix={ + ".k_proj.output_scale": ".attn.k_scale", + ".v_proj.output_scale": ".attn.v_scale", + ".q_proj.output_scale": ".attn.q_scale", + ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale", + } + ) class CopyNumelCounter(TorchDispatchMode): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index dca49d7ed97..7458b70ea81 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -114,8 +114,9 @@ class GGUFConfig(QuantizationConfig): Interface for models to update module names referenced in quantization configs in order to reflect the vllm model structure - :param hf_to_vllm_mapper: maps from hf model structure (the assumed - structure of the qconfig) to vllm model structure + Args: + hf_to_vllm_mapper: maps from hf model structure (the assumed + structure of the qconfig) to vllm model structure """ if self.unquantized_modules is not None: self.unquantized_modules = hf_to_vllm_mapper.apply_list( diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index d7fa6cf2633..e8810919c20 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -46,16 +46,17 @@ class QuantFP8(CustomOp): compile_native: bool = True, ): """ - :param static: static or dynamic quantization - :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, - PER_CHANNEL, or arbitrary block size) - :param num_token_padding: Pad the token dimension of output to this - size - :param tma_aligned_scales: For group quantization, output scales in - TMA-aligned layout - :param column_major_scales: For group quantization, output scales in - column major format - :param compile_native: Manually compile forward_native if compile mode > None + Args: + static: static or dynamic quantization + group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, + PER_CHANNEL, or arbitrary block size) + num_token_padding: Pad the token dimension of output to this + size + tma_aligned_scales: For group quantization, output scales in + TMA-aligned layout + column_major_scales: For group quantization, output scales in + column major format + compile_native: Manually compile forward_native if compile mode > None """ super().__init__(compile_native=compile_native) self.static = static diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 726ac2232af..100632686b0 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -15,6 +15,30 @@ from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales logger = init_logger(__name__) +class KVCacheScaleParameter(torch.nn.Parameter): + """Scalar parameter for KV-cache scales. + + Initialized to -1.0 (an invalid sentinel) so call sites just write + `KVCacheScaleParameter()`. The `weight_loader` accepts shape `()` or + `(1,)` and rejects anything else — per-head scales go through a separate + path (compressed-tensors' `_tp_aware_loader`), not this one. Per-instance + overrides still work because instance attribute assignment shadows this + class-level loader. + """ + + def __new__(cls) -> "KVCacheScaleParameter": + return super().__new__(cls, torch.tensor(-1.0), requires_grad=False) + + @staticmethod + def weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: + if loaded_weight.numel() != 1: + raise ValueError( + f"KV-cache scale expects a scalar weight, got shape " + f"{tuple(loaded_weight.shape)}" + ) + param.data.copy_(loaded_weight.reshape(())) + + class BaseKVCacheMethod(QuantizeMethodBase): """ Quant method that adds `_k_scale` and `_v_scale` attributes to the @@ -23,7 +47,8 @@ class BaseKVCacheMethod(QuantizeMethodBase): - quantize k/v_cache entries before saving them to the cache - dequantize k/v_cache entries before fetching them from the cache - :param quant_config: the appropriate QuantizationConfig + Args: + quant_config: the appropriate QuantizationConfig """ def __init__(self, quant_config: QuantizationConfig): @@ -37,11 +62,11 @@ class BaseKVCacheMethod(QuantizeMethodBase): # Initialize the Q and KV cache scales to -1.0, an invalid value. # If the q and k/v_scales appear in the checkpoint, it will be # overwritten when loading weights. - layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) - layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) - layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.q_scale = KVCacheScaleParameter() + layer.k_scale = KVCacheScaleParameter() + layer.v_scale = KVCacheScaleParameter() # Initialize P = softmax(QK^T) scales - layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.prob_scale = KVCacheScaleParameter() def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index d1f7a169ee7..424fdf2fba0 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -87,8 +87,9 @@ class QuarkConfig(QuantizationConfig): Interface for models to update module names referenced in quantization configs in order to reflect the vllm model structure - :param hf_to_vllm_mapper: maps from hf model structure (the assumed - structure of the qconfig) to vllm model structure + Args: + hf_to_vllm_mapper: maps from hf model structure (the assumed + structure of the qconfig) to vllm model structure """ quant_config_with_hf_to_vllm_mapper: dict[str, Any] = {} @@ -646,26 +647,16 @@ class QuarkConfig(QuantizationConfig): return scheme - def get_cache_scale(self, name: str) -> str | None: - """ - Check whether the param name matches the format for k/v cache scales - in quark. If this is the case, return its equivalent param name - expected by vLLM - - :param name: param name - :return: matching param name for KV cache scale in vLLM - """ - if name.endswith(".output_scale") and ".k_proj" in name: - return name.replace(".k_proj.output_scale", ".attn.k_scale") - if name.endswith(".output_scale") and ".v_proj" in name: - return name.replace(".v_proj.output_scale", ".attn.v_scale") - if name.endswith(".output_scale") and ".q_proj" in name: - return name.replace(".q_proj.output_scale", ".attn.q_scale") - if name.endswith("self_attn.prob_output_scale"): - return name.replace(".prob_output_scale", ".attn.prob_scale") - - # If no matches, return None - return None + def get_cache_scale_mapper(self) -> "WeightsMapper": + """Map Quark KV-cache scale names to vLLM names.""" + return WeightsMapper( + orig_to_new_suffix={ + ".k_proj.output_scale": ".attn.k_scale", + ".v_proj.output_scale": ".attn.v_scale", + ".q_proj.output_scale": ".attn.q_scale", + ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale", + } + ) class QuarkLinearMethod(LinearMethodBase): @@ -734,7 +725,9 @@ class QuarkKVCacheMethod(BaseKVCacheMethod): """ Validator for the kv cache configuration. Useful for controlling the kv cache quantization schemes, that are being supported in vLLM - :param kv_cache_config: the quark kv cache scheme + + Args: + kv_cache_config: the quark kv cache scheme """ if kv_cache_config is None: return diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py index 412a07a85fe..6f8db9ea57d 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py @@ -38,11 +38,11 @@ class QuarkScheme(ABC): Run the forward pass for the particular scheme. This is where scheme-specific dequant/quant steps/kernels should be applied. - :param layer: torch.nn.Module with the registered weights and - other parameters relevant to the particular scheme. - :param x: input to the layer - :param bias: bias parameter - + Args: + layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + x: input to the layer + bias: bias parameter """ raise NotImplementedError diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index 7362abcc8fb..bfaf81ad007 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -251,7 +251,7 @@ class DeepseekV4ScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): inv_freq = self._compute_inv_freq(self.scaling_factor) t = torch.arange( self.max_position_embeddings * self.scaling_factor, - device=current_platform.device_type, + device=inv_freq.device, dtype=torch.float32, ) freqs = torch.einsum("i,j -> ij", t, inv_freq) diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py index 40dd6dc9f39..6cf1c19cba4 100644 --- a/vllm/model_executor/model_loader/reload/layerwise.py +++ b/vllm/model_executor/model_loader/reload/layerwise.py @@ -123,7 +123,8 @@ def initialize_online_processing(layer: torch.nn.Module): Called by either `initialize_layerwise_reload` or an online quantization scheme, prevents double wrapping in the case of online quantization + reloading - :param layer: layer whose parameter weight loaders will be wrapped + Args: + layer: layer whose parameter weight loaders will be wrapped """ info = get_layerwise_info(layer) @@ -222,8 +223,9 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon This function should be applied after `initialize_layerwise_reload` is applied unwrap the layerwise weight loaders. - :param model: model to finalize processing for - :param model_config: config needed for applying processing to attention layers + Args: + model: model to finalize processing for + model_config: config needed for applying processing to attention layers """ if hasattr(model, "_original_do_torchao_reload"): model._do_torchao_reload = model._original_do_torchao_reload diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py index 397a458cbdd..283a98de284 100644 --- a/vllm/model_executor/model_loader/reload/meta.py +++ b/vllm/model_executor/model_loader/reload/meta.py @@ -175,9 +175,12 @@ def get_numel_loaded( """ Determine how many elements would be loaded by a weight loader call. - :param weight loader: used to load weights - :param args: bound arguments to weight loader - :return: number of elements loaded by the weight loader, the return value of the + Args: + weight_loader: used to load weights + args: bound arguments to weight loader + + Returns: + number of elements loaded by the weight loader, the return value of the weight loader """ with CopyCounter() as counter: diff --git a/vllm/model_executor/model_loader/reload/sanitize.py b/vllm/model_executor/model_loader/reload/sanitize.py index 2a6dc7182d0..21c47a2257f 100644 --- a/vllm/model_executor/model_loader/reload/sanitize.py +++ b/vllm/model_executor/model_loader/reload/sanitize.py @@ -20,9 +20,12 @@ def sanitize_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.T tensors will reference layers, and the WeakKeyDictionary will never evict entries, even when the model is deleted. - :param tensor: tensor to be sanitized - :param layer: layer whose references should be removed - :return: sanitized tensor + Args: + tensor: tensor to be sanitized + layer: layer whose references should be removed + + Returns: + sanitized tensor """ for key, value in tensor.__dict__.items(): if isinstance(value, MethodType) and value.__self__ is layer: @@ -38,10 +41,12 @@ def restore_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Te Used by `restore_layer_on_meta` to add back layer references, allowing for proper weight loading. - :param tensor: tensor to be sanitized - :param layer: layer whose references should be removed - :return: sanitized tensor + Args: + tensor: tensor to be sanitized + layer: layer whose references should be removed + Returns: + sanitized tensor """ for key, value in tensor.__dict__.items(): if isinstance(value, MethodType) and value.__self__ is layer_ref_sentinel: diff --git a/vllm/model_executor/model_loader/reload/utils.py b/vllm/model_executor/model_loader/reload/utils.py index 7a3d6873e10..f0078d0f9d8 100644 --- a/vllm/model_executor/model_loader/reload/utils.py +++ b/vllm/model_executor/model_loader/reload/utils.py @@ -49,8 +49,11 @@ def has_device_tensors(bound_args: BoundArguments) -> bool: """ Return True if the loaded weights exist on an accelerator device - :param bound_args: args to load weights - :return: True if weights are on accelerator device + Args: + bound_args: args to load weights + + Returns: + True if weights are on accelerator device """ return any( isinstance(value, torch.Tensor) and value.device.type not in ("meta", "cpu") @@ -62,8 +65,11 @@ def get_info_size(info: LayerReloadingInfo) -> int: """ Calculate the number of bytes used by loaded weights for a given layer - :param info: layerwise info to get size of - :return: number of bytes used by loaded weights + Args: + info: layerwise info to get size of + + Returns: + number of bytes used by loaded weights """ return sum( value.nbytes diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 736b2134604..008abb6fdfe 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -687,7 +687,7 @@ def serialize_vllm_model( serializer = TensorSerializer( stream, encryption=encryption_params, - **tensorizer_config.serialization_kwargs, + **(tensorizer_config.serialization_kwargs or {}), ) serializer.write_module(model) serializer.close() diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index cbb191ebb62..76a83898989 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1541,6 +1541,11 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: if no remapping is needed. None: If the remapped name is not found in params_dict. """ + # Already in vLLM's expected form (e.g. weights pre-renamed by a + # `WeightsMapper` from the quant config). Skip the regex remap, which + # would otherwise double-apply the `.attn` prefix and drop the weight. + if name in params_dict: + return name if name.endswith(".kv_scale"): logger.warning_once( "DEPRECATED. Found kv_scale in the checkpoint. " diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 5905a198b28..0711fb03f84 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -430,18 +430,6 @@ class ApertusModel(nn.Module, EagleModelMixin): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue if "scale" in name or "zero_point" in name: # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index eb8c3e3f65e..d25c954fc19 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -293,18 +293,6 @@ class ArceeModel(nn.Module, EagleModelMixin): if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - if "scale" in name or "zero_point" in name: remapped_name = maybe_remap_kv_scale_name(name, params_dict) if remapped_name is None: diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 55bc64cd94a..e453e6e6cf6 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -363,18 +363,6 @@ class AriaTextModel(LlamaModel, SupportsQuant): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/cohere2_moe.py b/vllm/model_executor/models/cohere2_moe.py index aa8adff188f..16e299ea27a 100644 --- a/vllm/model_executor/models/cohere2_moe.py +++ b/vllm/model_executor/models/cohere2_moe.py @@ -464,18 +464,6 @@ class Cohere2MoeModel(nn.Module): if "rotary_emb.inv_freq" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/cohere_eagle.py b/vllm/model_executor/models/cohere_eagle.py index 5c22d6e34dd..7b57c739ffe 100644 --- a/vllm/model_executor/models/cohere_eagle.py +++ b/vllm/model_executor/models/cohere_eagle.py @@ -150,18 +150,6 @@ class CohereEagleModel(nn.Module): if "rotary_emb.inv_freq" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index e73dfb1f01e..317269ec3b6 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -352,19 +352,6 @@ class CohereModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 6c798bf2f36..32f4f15c7a3 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -394,19 +394,6 @@ class DbrxModel(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - if name.endswith(("w1", "w2", "v1")): name = name + "_weight" for param_name, weight_name in expert_params_mapping: diff --git a/vllm/model_executor/models/deepseek_eagle3.py b/vllm/model_executor/models/deepseek_eagle3.py index 9b96cdec830..dc153ac9e0b 100644 --- a/vllm/model_executor/models/deepseek_eagle3.py +++ b/vllm/model_executor/models/deepseek_eagle3.py @@ -284,19 +284,6 @@ class DeepseekV2Eagle3Model(nn.Module): if "midlayer." in name: name = name.replace("midlayer.", "layers.0.") - # Handle kv cache quantization scales - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - # Remapping the name FP8 kv-scale if "scale" in name: name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index b633fd28508..dca05f72c69 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -391,18 +391,6 @@ class ExaoneModel(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 04708de93d3..e38dbb5ee29 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -389,18 +389,6 @@ class Exaone4Model(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/exaone_moe.py b/vllm/model_executor/models/exaone_moe.py index 80b7e0957e8..3373983f5c9 100644 --- a/vllm/model_executor/models/exaone_moe.py +++ b/vllm/model_executor/models/exaone_moe.py @@ -374,18 +374,6 @@ class ExaoneMoeModel(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 425ecc65195..733eb3ed3c1 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -328,16 +328,6 @@ class Gemma2Model(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache scales for compressed-tensors quantization - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index f61f7c6f780..7bae2b1a5e7 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -386,17 +386,6 @@ class Gemma3Model(nn.Module): ): loaded_weight -= 1 - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache scales for compressed-tensors quantization - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - # Check if this is a scale parameter that needs remapping first if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): # Try to remap the scale name first diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 770424ba0fd..ad8b21d86b4 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -1056,16 +1056,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ): name = f"self_decoder.{name}" - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache scales for compressed-tensors quantization - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 5d0e3efe2e1..d14a767df5f 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -1409,16 +1409,6 @@ class Gemma4Model(nn.Module, EagleModelMixin): params_dict.update(dict(self.named_buffers())) loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): remapped_name = maybe_remap_kv_scale_name(name, params_dict) if remapped_name is not None and remapped_name in params_dict: diff --git a/vllm/model_executor/models/gemma4_unified.py b/vllm/model_executor/models/gemma4_unified.py index 66fc914dc75..9cc0710c4d0 100644 --- a/vllm/model_executor/models/gemma4_unified.py +++ b/vllm/model_executor/models/gemma4_unified.py @@ -80,7 +80,7 @@ class Gemma4UnifiedVisionEmbedder(nn.Module): Pipeline: raw patches → LN₁ → Dense → LN₂ → +factorized_posemb → LN₃. """ - def __init__(self, config, quant_config=None): + def __init__(self, config, quant_config=None, prefix=""): super().__init__() patch_dim = config.model_patch_size**2 * 3 mm_embed_dim = config.mm_embed_dim @@ -91,6 +91,7 @@ class Gemma4UnifiedVisionEmbedder(nn.Module): mm_embed_dim, bias=True, quant_config=quant_config, + prefix=f"{prefix}.patch_dense", gather_output=True, ) self.patch_ln2 = nn.LayerNorm(mm_embed_dim) @@ -267,6 +268,7 @@ class Gemma4UnifiedForConditionalGeneration(Gemma4ForConditionalGeneration): Gemma4UnifiedVisionEmbedder( config.vision_config, quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_embedder"), ) if config.vision_config is not None else None diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index 89447927d5c..4587a692766 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -258,18 +258,6 @@ class Glm4Model(LlamaModel): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue if "scale" in name or "zero_point" in name: # Remapping the name of FP8 kv-scale or zero point. name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/glm_ocr_mtp.py b/vllm/model_executor/models/glm_ocr_mtp.py index 34e602bb669..3d283c101ca 100644 --- a/vllm/model_executor/models/glm_ocr_mtp.py +++ b/vllm/model_executor/models/glm_ocr_mtp.py @@ -166,6 +166,10 @@ class GlmOcrMTP(nn.Module, SupportsPP): return self.model.compute_logits(hidden_states, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + if self.quant_config is not None and ( + cache_scale_mapper := self.quant_config.get_cache_scale_mapper() + ): + weights = cache_scale_mapper.apply(weights) stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -189,19 +193,6 @@ class GlmOcrMTP(nn.Module, SupportsPP): name = self._rewrite_spec_layer_name(spec_layer, name) - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - if "scale" in name or "zero_point" in name: # Remapping the name of FP8 kv-scale or zero point. name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index c29103c6d52..30da9b4dea2 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -254,19 +254,6 @@ class GPTJModel(nn.Module): if "attn.bias" in name or "attn.masked_bias" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index d12db96c5d4..920d43392cf 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -635,52 +635,6 @@ class GptOssModel(nn.Module, EagleModelMixin): moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id) - def kv_cache_scale_loader( - quant_config: QuantizationConfig, - name: str, - params_dict: dict[str, typing.Any], - weight: torch.Tensor, - default_weight_loader: Callable[..., None], - loaded_params: set[str], - ) -> tuple[bool, set[str]]: - """ - Load KV cache output scales. - Returns: - Tuple of (bool, set): - - bool: True if KV-cache scale was loaded into loaded_params - - set: Updated set of loaded_params if True else the original set - """ - # load explicit cached KV output scale from quant_config - if quant_config is not None and ( - scale_name := quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - if weight.numel() != 1: - raise ValueError( - f"KV cache scale '{scale_name}' is expected to be a " - f"scalar, but got a tensor of shape {weight.shape}." - ) - # Ensure weight is a scalar before passing to loader. - weight_loader(param, weight.flatten()[0]) - loaded_params.add(scale_name) - return True, loaded_params - - return False, loaded_params - - load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader( - self.quant_config, - name, - params_dict, - loaded_weight, - default_weight_loader, - loaded_params, - ) - if load_kv_cache_scale_completed: - continue - if ( all(key in name for key in ["input_scale", "mlp.experts"]) and expert_id is not None diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 4b486ede443..2adc29f8d25 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -334,18 +334,6 @@ class GraniteModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index e3585a6dd74..5909604bd54 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -365,19 +365,6 @@ class GraniteMoeModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 1ab069e3ba3..4f861fc5016 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -495,18 +495,6 @@ class GraniteMoeHybridModel(nn.Module): if "A_log" in n: n = n.replace("A_log", "A") - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(n) - ): - # Loading kv cache quantization scales - loaded_weight = p - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - _load(scale_name, loaded_weight) - loaded_params.add(scale_name) - continue - if _load_quant_expert(n, p): continue diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index f06122a7fd1..3fc3d1a2d2c 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -548,18 +548,6 @@ class Grok1Model(nn.Module): for old_pattern, new_pattern in self.weight_name_remapping.items(): if old_pattern in name: name = name.replace(old_pattern, new_pattern) - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index b900c0ed83e..ec3cfbd017b 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -771,15 +771,6 @@ class HunYuanModel(nn.Module, EagleModelMixin): # processed with quantization, LoRA, fine-tuning, etc. if self.config.tie_word_embeddings and "lm_head.weight" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache scales for compressed-tensors quantization - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] - weight_loader(param, loaded_weight) - continue is_found = False for param_name, weight_name, shard_id in stacked_params_mapping: diff --git a/vllm/model_executor/models/hy_v3.py b/vllm/model_executor/models/hy_v3.py index bfff84b8049..6142f2c086a 100644 --- a/vllm/model_executor/models/hy_v3.py +++ b/vllm/model_executor/models/hy_v3.py @@ -531,17 +531,6 @@ class HYV3Model(nn.Module): for name, loaded_weight in weights: if self.config.tie_word_embeddings and "lm_head.weight" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue if "scale" in name: # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/hy_v3_mtp.py b/vllm/model_executor/models/hy_v3_mtp.py index 8594a38c3ab..37ace1dada3 100644 --- a/vllm/model_executor/models/hy_v3_mtp.py +++ b/vllm/model_executor/models/hy_v3_mtp.py @@ -264,6 +264,10 @@ class HYV3MTP(nn.Module): return torch.concat((q, k, v)) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + if self.quant_config is not None and ( + cache_scale_mapper := self.quant_config.get_cache_scale_mapper() + ): + weights = cache_scale_mapper.apply(weights) cla_factor = _get_cla_factor(self.config) stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -336,14 +340,6 @@ class HYV3MTP(nn.Module): continue if self.config.tie_word_embeddings and "lm_head.weight" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] - weight_loader(param, loaded_weight) - continue spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is None: continue diff --git a/vllm/model_executor/models/hyperclovax.py b/vllm/model_executor/models/hyperclovax.py index 3176c428413..2f54f78e758 100644 --- a/vllm/model_executor/models/hyperclovax.py +++ b/vllm/model_executor/models/hyperclovax.py @@ -395,18 +395,6 @@ class HyperCLOVAXModel(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue if "scale" in name or "zero_point" in name: # Remapping the name of FP8 kv-scale or zero point. remapped_name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/iquest_loopcoder.py b/vllm/model_executor/models/iquest_loopcoder.py index 24c004ff4c2..3755cba5d1a 100644 --- a/vllm/model_executor/models/iquest_loopcoder.py +++ b/vllm/model_executor/models/iquest_loopcoder.py @@ -476,18 +476,6 @@ class IQuestLoopCoderModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if "gate_projections" in name: continue diff --git a/vllm/model_executor/models/jais2.py b/vllm/model_executor/models/jais2.py index 4e03eb12ee4..dafa0f03ae9 100644 --- a/vllm/model_executor/models/jais2.py +++ b/vllm/model_executor/models/jais2.py @@ -386,16 +386,6 @@ class Jais2Model(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache scales for compressed-tensors quantization - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue if "scale" in name: name = maybe_remap_kv_scale_name(name, params_dict) if name is None: diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 86d7edc25f9..8a0b7ceea5f 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -790,21 +790,6 @@ class KeyeSiglipVisionModel(nn.Module): continue if "head.mlp" in name or "head.probe" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr( - param, - "weight_loader", - default_weight_loader, - ) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for ( param_name, weight_name, diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py index a891950fa57..307b24ac112 100644 --- a/vllm/model_executor/models/kimi_linear.py +++ b/vllm/model_executor/models/kimi_linear.py @@ -600,7 +600,7 @@ class KimiLinearForCausalLM( def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", - ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]: + ) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.kda_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype ) @@ -608,7 +608,7 @@ class KimiLinearForCausalLM( @classmethod def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + ) -> tuple[tuple[int, ...], tuple[int, ...]]: parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config tp_size = parallel_config.tensor_parallel_size @@ -628,9 +628,7 @@ class KimiLinearForCausalLM( @classmethod def get_mamba_state_copy_func( cls, - ) -> tuple[ - MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc - ]: + ) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: return MambaStateCopyFuncCalculator.kda_state_copy_func() def compute_logits( diff --git a/vllm/model_executor/models/laguna.py b/vllm/model_executor/models/laguna.py index f79f6097c61..5572481565b 100644 --- a/vllm/model_executor/models/laguna.py +++ b/vllm/model_executor/models/laguna.py @@ -724,20 +724,6 @@ class LagunaModel(nn.Module, EagleModelMixin): loaded_params.add(name) continue - # Handle KV cache quantization scales - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - assert loaded_weight.numel() == 1, ( - f"KV scale numel {loaded_weight.numel()} != 1" - ) - loaded_weight = loaded_weight.squeeze() - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - # Handle stacked params (QKV, gate_up for # non-expert layers and shared_expert) for param_name, weight_name, shard_id in stacked_params_mapping: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 3c797d05e93..cf59ccdf750 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -451,18 +451,6 @@ class LlamaModel(nn.Module, EagleModelMixin): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue if "scale" in name or "zero_point" in name: # Remapping the name of FP8 kv-scale or zero point. name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index bfcb72a6a74..07ca2714ed2 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -588,21 +588,6 @@ class Llama4Model(LlamaModel): fused_experts_params = True expert_params_mapping = expert_params_mapping_fused - # If kv cache quantization scales exist and the weight name - # corresponds to one of the kv cache quantization scales, load - # them. - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - # Iterate over stacked_params_mapping to check if the current weight # is one of the stacked parameters. If so, load the weight with the # corresponding shard id. Note that MoE weights are handled @@ -625,9 +610,9 @@ class Llama4Model(LlamaModel): if is_pp_missing_parameter(name, self): continue - # Remap kv cache scale names for ModelOpt checkpoints. - # TODO: ModelOpt should implement get_cache_scale() such that - # kv cache scale name remapping can be done there. + # Remap kv cache scale names for any checkpoint format the + # quant config's `get_cache_scale_mapper` does not cover + # (idempotent for names already renamed by the mapper). if name.endswith("scale"): name = maybe_remap_kv_scale_name(name, params_dict) if name is None: diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 585c8f6dbd2..14842a75fea 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -127,19 +127,6 @@ class LlamaModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - # Handle kv cache quantization scales - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue # Remapping the name FP8 kv-scale or zero point. if "scale" in name or "zero_point" in name: name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 9fd6652aa24..bb1bbb85537 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -267,19 +267,6 @@ class LlamaModel(nn.Module): for name, loaded_weight in weights: if "midlayer." in name: name = name.replace("midlayer.", "layers.0.") - # Handle kv cache quantization scales - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue # Remapping the name FP8 kv-scale or zero point. if "scale" in name or "zero_point" in name: name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index a7699f0d598..4f67d468ace 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -104,18 +104,6 @@ class MiMoModel(Qwen2Model): continue if "rotary_emb.inv_freq" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/mimo_v2.py b/vllm/model_executor/models/mimo_v2.py index 3f466162649..7c6d5363c0a 100644 --- a/vllm/model_executor/models/mimo_v2.py +++ b/vllm/model_executor/models/mimo_v2.py @@ -555,22 +555,6 @@ class MiMoV2Model(nn.Module): if "mtp" in name: continue - if self.quant_config is not None: - cache_scale_name = self.quant_config.get_cache_scale(name) - if cache_scale_name is not None and cache_scale_name in params_dict: - param = params_dict[cache_scale_name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - - kv_scale = loaded_weight - if kv_scale.dim() > 0 and kv_scale.numel() > 1: - kv_scale = kv_scale.view(-1)[0] - - weight_loader(param, kv_scale) - loaded_params.add(cache_scale_name) - continue - expert_matched = False for param_name, weight_name, expert_id, shard_id in expert_params_mapping: if weight_name not in name: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index cbfc254dda3..53c1c87cfce 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -388,19 +388,6 @@ class MixtralModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 15d43a9ddf9..7b2e6b93b27 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -375,18 +375,6 @@ class NemotronModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index f2f3811c064..b974a3eb085 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -334,18 +334,6 @@ class DeciModel(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue if "scale" in name or "zero_point" in name: # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 4491a6a3ea1..541f60c2c40 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -277,7 +277,8 @@ class OlmoModel(nn.Module): inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: """ - :param input_ids: A tensor of shape `(batch_size, seq_len)`. + Args: + input_ids: A tensor of shape `(batch_size, seq_len)`. """ if get_pp_group().is_first_rank: if inputs_embeds is not None: diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 212140fe15e..ad04b258bde 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -314,7 +314,8 @@ class Olmo2Model(nn.Module): inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: """ - :param input_ids: A tensor of shape `(batch_size, seq_len)`. + Args: + input_ids: A tensor of shape `(batch_size, seq_len)`. """ if get_pp_group().is_first_rank: if inputs_embeds is not None: diff --git a/vllm/model_executor/models/ouro.py b/vllm/model_executor/models/ouro.py index 56505ec7be2..503d4b5c834 100644 --- a/vllm/model_executor/models/ouro.py +++ b/vllm/model_executor/models/ouro.py @@ -390,18 +390,6 @@ class OuroModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index cd88009c739..0bf10e3ce77 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -944,21 +944,6 @@ class SiglipVisionModel(nn.Module): continue if "packing_position_embedding" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr( - param, - "weight_loader", - default_weight_loader, - ) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for ( param_name, weight_name, diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 5770420ce56..a49e8ce2e82 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -537,19 +537,6 @@ class PhiMoEModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b83fedc70db..9c39c649708 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -439,18 +439,6 @@ class Qwen2Model(nn.Module, EagleModelMixin): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen3_dflash.py b/vllm/model_executor/models/qwen3_dflash.py index 25f139f26cb..820260f795c 100644 --- a/vllm/model_executor/models/qwen3_dflash.py +++ b/vllm/model_executor/models/qwen3_dflash.py @@ -469,17 +469,6 @@ class DFlashQwen3Model(nn.Module): for name, loaded_weight in weights: if "midlayer." in name: name = name.replace("midlayer.", "layers.0.") - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue if "scale" in name: name = maybe_remap_kv_scale_name(name, params_dict) if name is None: diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 4ec1be3367d..6980184cc8a 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -552,19 +552,6 @@ class Qwen3MoeModel(nn.Module, EagleModelMixin): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - assert loaded_weight.numel() == 1, ( - f"KV scale numel {loaded_weight.numel()} != 1" - ) - loaded_weight = loaded_weight.squeeze() - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue if "scale" in name or "zero_point" in name: name = maybe_remap_kv_scale_name(name, params_dict) if name is None: diff --git a/vllm/model_executor/models/rnj1.py b/vllm/model_executor/models/rnj1.py index f83577b7a39..68c3722e2bc 100644 --- a/vllm/model_executor/models/rnj1.py +++ b/vllm/model_executor/models/rnj1.py @@ -350,16 +350,6 @@ class Rnj1Model(nn.Module): ): loaded_weight -= 1 - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue - if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): remapped_name = maybe_remap_kv_scale_name(name, params_dict) if remapped_name is not None and remapped_name in params_dict: diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index d90174911fb..48147f7334e 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -376,18 +376,6 @@ class SeedOssModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index bff866d0d0c..454a0e97112 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -360,18 +360,6 @@ class SolarModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name) - ): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = ( - loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] - ) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 095d0e363d5..83d113415dc 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -52,6 +52,7 @@ class WeightsMapper: def __or__(self, other: "WeightsMapper") -> "WeightsMapper": """Combine two `WeightsMapper`s by merging their mappings.""" return WeightsMapper( + orig_to_new_regex={**self.orig_to_new_regex, **other.orig_to_new_regex}, orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr}, orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix}, orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix}, @@ -343,6 +344,20 @@ class AutoWeightsLoader: *, mapper: WeightsMapper | None = None, ) -> set[str]: + # Many models store quant_config in the base model instead of the causal model. + # We look at the causal model's direct children for this reason. + modules = (self.module, *self.module.children()) + iterator = (m.quant_config for m in modules if hasattr(m, "quant_config")) + quant_config = next(iterator, None) + cache_scale_mapper = ( + quant_config.get_cache_scale_mapper() if quant_config is not None else None + ) + if cache_scale_mapper is not None: + mapper = ( + mapper | cache_scale_mapper + if mapper is not None + else cache_scale_mapper + ) if mapper is not None: weights = mapper.apply(weights) # filter out weights with first-prefix/substr to skip in name diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 4106672d501..7f96ceda09d 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -3,6 +3,7 @@ from collections.abc import Callable, Hashable from fractions import Fraction +from typing import Any from weakref import WeakValueDictionary import torch @@ -42,10 +43,9 @@ class BasevLLMParameter(Parameter): """ Initialize the BasevLLMParameter - :param data: torch tensor with the parameter data - :param weight_loader: weight loader callable - - :returns: a torch.nn.parameter + Args: + data: torch tensor with the parameter data + weight_loader: weight loader callable """ # During weight loading, we often do something like: @@ -445,15 +445,16 @@ class SharedWeightParameter(BasevLLMParameter): "currently support tensor parallelism" ) - def add_partition(self, index: int, data_key: Hashable, *args, **kwargs): + def add_partition(self, index: int, data_key: Hashable, *args: Any, **kwargs: Any): """ Add a partition to the weight parameter. Partitions whose `data_key` is the same will share tensor data - :param index: index of partition to add - :param data_key: hashable key used to key shared tensors - :param *args: arguments for `torch.empty` - :param **kwargs: keyword arguments for `torch.empty` + Args: + index: index of partition to add + data_key: hashable key used to key shared tensors + *args: arguments for `torch.empty` + **kwargs: keyword arguments for `torch.empty` """ # load (shared) tensor using `data_key` if data_key not in self.tensors_registry: diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index fb724fbe2f1..c32be5768bf 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClam from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear, ) @@ -45,11 +44,7 @@ from vllm.model_executor.models.utils import ( make_layers, maybe_prefix, ) -from vllm.models.deepseek_v4.attention import ( - DeepseekV4Indexer, - DeepseekV4MLA, -) -from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope +from vllm.models.deepseek_v4.amd.rocm import DeepseekV4ROCMAiterMLAAttention from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.import_utils import has_tilelang @@ -225,158 +220,6 @@ class DeepseekV4MoE(nn.Module): return final_hidden_states.view(org_shape) -class DeepseekV4Attention(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer: torch.Tensor | None = None, - aux_stream_list: list[torch.cuda.Stream] | None = None, - ): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - layer_id = extract_layer_index(prefix) - - self.layer_id = layer_id - self.hidden_size = config.hidden_size - self.n_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - assert self.n_heads % tp_size == 0 - - self.n_local_heads = self.n_heads // tp_size - self.q_lora_rank = config.q_lora_rank - self.o_lora_rank = config.o_lora_rank - self.head_dim = config.head_dim - self.rope_head_dim = config.qk_rope_head_dim - self.nope_head_dim = self.head_dim - self.rope_head_dim - self.n_groups = config.o_groups - self.n_local_groups = self.n_groups // tp_size - self.window_size = config.sliding_window - # NOTE(zyongye) Compress ratio can't be 0 - # we do this for because MTP layer is not included - # in the compress ratio list - if layer_id < config.num_hidden_layers: - self.compress_ratio = max(1, config.compress_ratios[layer_id]) - else: - self.compress_ratio = 1 - self.eps = config.rms_norm_eps - self.max_position_embeddings = config.max_position_embeddings - - # Padded to min 64 heads for FlashMLA, initialized to -inf - # (no sink effect). Weight loading fills the first n_local_heads slots. - padded_heads = max(self.n_local_heads, 64) - self.attn_sink = nn.Parameter( - torch.full((padded_heads,), -float("inf"), dtype=torch.float32), - requires_grad=False, - ) - - self.fused_wqa_wkv = MergedColumnParallelLinear( - self.hidden_size, - [self.q_lora_rank, self.head_dim], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.fused_wqa_wkv", - disable_tp=True, # fused ReplicatedLinear - ) - self.q_norm = RMSNorm(self.q_lora_rank, self.eps) - self.wq_b = ColumnParallelLinear( - self.q_lora_rank, - self.n_heads * self.head_dim, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wq_b", - ) - - self.kv_norm = RMSNorm(self.head_dim, self.eps) - self.wo_a = ColumnParallelLinear( - self.n_heads * self.head_dim // self.n_groups, - self.n_groups * self.o_lora_rank, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wo_a", - ) - self.wo_a.is_bmm = True - self.wo_a.bmm_batch_size = self.n_local_groups - self.wo_b = RowParallelLinear( - self.n_groups * self.o_lora_rank, - self.hidden_size, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wo_b", - ) - self.softmax_scale = self.head_dim**-0.5 - self.scale_fmt = config.quantization_config["scale_fmt"] - - self.rope_parameters = config.rope_scaling - - # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it) - self.rotary_emb = build_deepseek_v4_rope( - config, - head_dim=self.head_dim, - rope_head_dim=self.rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - compress_ratio=self.compress_ratio, - ) - - self.indexer = None - if self.compress_ratio == 4: - # Only C4A uses sparse attention and hence has indexer. - self.indexer = DeepseekV4Indexer( - vllm_config, - config=config, - hidden_size=self.hidden_size, - q_lora_rank=self.q_lora_rank, - quant_config=quant_config, - cache_config=vllm_config.cache_config, - topk_indices_buffer=topk_indices_buffer, - compress_ratio=self.compress_ratio, - prefix=f"{prefix}.indexer", - ) - - self.mla_attn = DeepseekV4MLA( - hidden_size=self.hidden_size, - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.softmax_scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - v_head_dim=self.head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.head_dim, - o_lora_rank=self.o_lora_rank, - vllm_config=vllm_config, - fused_wqa_wkv=self.fused_wqa_wkv, - q_norm=self.q_norm, - wq_b=self.wq_b, - kv_norm=self.kv_norm, - wo_a=self.wo_a, - wo_b=self.wo_b, - attn_sink=self.attn_sink, - rotary_emb=self.rotary_emb, - indexer=self.indexer, - indexer_rotary_emb=self.rotary_emb, - topk_indices_buffer=topk_indices_buffer, - aux_stream_list=aux_stream_list, - window_size=self.window_size, - compress_ratio=self.compress_ratio, - cache_config=vllm_config.cache_config, - quant_config=quant_config, - prefix=prefix, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - llama_4_scaling: torch.Tensor | None, - ): - return self.mla_attn(positions, hidden_states, llama_4_scaling) - - class DeepseekV4DecoderLayer(nn.Module): def __init__( self, @@ -395,7 +238,7 @@ class DeepseekV4DecoderLayer(nn.Module): self.hidden_size = config.hidden_size self.rms_norm_eps = config.rms_norm_eps - self.attn = DeepseekV4Attention( + self.attn = DeepseekV4ROCMAiterMLAAttention( vllm_config, prefix=f"{prefix}.attn", topk_indices_buffer=topk_indices_buffer, @@ -601,7 +444,7 @@ class DeepseekV4Model(nn.Module): self.rms_norm_eps = config.rms_norm_eps # Three aux streams: one per non-default input GEMM in - # DeepseekV4MLA.attn_gemm_parallel_execute + # DeepseekV4Attention.attn_gemm_parallel_execute # (compressor kv_score, indexer.weights_proj, indexer.compressor # kv_score). fused_wqa_wkv stays on the default stream. # Disable them on ROCm because of hang issues. @@ -897,7 +740,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", }, orig_to_new_substr={ - ".attn.compressor.": ".attn.mla_attn.compressor.", ".shared_experts.w2": ".shared_experts.down_proj", }, ) diff --git a/vllm/models/deepseek_v4/amd/rocm.py b/vllm/models/deepseek_v4/amd/rocm.py index 7298f18365d..f7fb409af1b 100644 --- a/vllm/models/deepseek_v4/amd/rocm.py +++ b/vllm/models/deepseek_v4/amd/rocm.py @@ -2,15 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, cast +from typing import cast import torch from vllm.forward_context import get_forward_context +from vllm.models.deepseek_v4.attention import DeepseekV4Attention from vllm.models.deepseek_v4.common.ops import dequantize_and_gather_k_cache from vllm.models.deepseek_v4.nvidia.flashmla import ( DeepseekV4FlashMLASparseBackend, - DeepseekV4SparseMLAAttentionImpl, ) from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( @@ -26,16 +26,12 @@ from vllm.v1.attention.backends.mla.sparse_swa import ( ) from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( build_ragged_indices_from_dense, + rocm_inv_rope_einsum, rocm_sparse_attn_decode, rocm_sparse_attn_prefill, ) from vllm.v1.worker.workspace import current_workspace_manager -if TYPE_CHECKING: - from vllm.models.deepseek_v4.attention import ( - DeepseekV4MLAAttention, - ) - def _build_indptr_from_lengths(lengths: torch.Tensor) -> torch.Tensor: lengths = lengths.to(dtype=torch.int32).contiguous() @@ -582,13 +578,9 @@ class DeepseekV4ROCMAiterMLASparseBackend(DeepseekV4FlashMLASparseBackend): def get_builder_cls() -> type["DeepseekV4ROCMAiterMLASparseMetadataBuilder"]: return DeepseekV4ROCMAiterMLASparseMetadataBuilder - @staticmethod - def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]: - return DeepseekV4ROCMAiterMLASparseImpl - -class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): - """ROCm sparse MLA implementation used by DeepSeek V4's custom MLA layer.""" +class DeepseekV4ROCMAiterMLAAttention(DeepseekV4Attention): + """ROCm sparse MLA attention layer for DeepSeek V4.""" backend_cls = DeepseekV4ROCMAiterMLASparseBackend @@ -596,10 +588,21 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): def get_padded_num_q_heads(cls, num_heads: int) -> int: return num_heads - @classmethod - def forward_mqa( # type: ignore[override] - cls, - layer: "DeepseekV4MLAAttention", + def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + # ROCm BF16 reference wo_a path (inverse RoPE + einsum) + wo_b. + z = rocm_inv_rope_einsum( + self.rotary_emb, + o, + positions, + self.rope_head_dim, + self.n_local_groups, + self.o_lora_rank, + self.wo_a, + ) + return self.wo_b(z.flatten(1)) + + def forward_mqa( + self, q: torch.Tensor, kv: torch.Tensor, positions: torch.Tensor, @@ -619,16 +622,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): # Warmup dummy run: no real metadata. Reserve the same bf16 # gather workspace _forward_prefill would; the dequantize / topk # / sparse_fwd kernels are skipped this step. - swa_only = layer.compress_ratio <= 1 + swa_only = self.compress_ratio <= 1 N = ( 0 if swa_only - else (layer.max_model_len + layer.compress_ratio - 1) - // layer.compress_ratio + else (self.max_model_len + self.compress_ratio - 1) + // self.compress_ratio ) - M = N + layer.window_size + layer.max_num_batched_tokens + M = N + self.window_size + self.max_num_batched_tokens current_workspace_manager().get_simultaneous( - ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), ) output.zero_() return @@ -636,25 +639,24 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): assert isinstance(attn_metadata, dict) rocm_metadata = cast( DeepseekV4ROCMAiterMLASparseMetadata | None, - attn_metadata.get(layer.prefix), + attn_metadata.get(self.prefix), ) swa_metadata = cast( DeepseekV4ROCMAiterSparseSWAMetadata | None, - attn_metadata.get(layer.swa_cache_layer.prefix), + attn_metadata.get(self.swa_cache_layer.prefix), ) assert swa_metadata is not None - swa_only = layer.compress_ratio <= 1 - self_kv_cache = layer.kv_cache if not swa_only else None - swa_kv_cache = layer.swa_cache_layer.kv_cache + swa_only = self.compress_ratio <= 1 + self_kv_cache = self.kv_cache if not swa_only else None + swa_kv_cache = self.swa_cache_layer.kv_cache num_decodes = swa_metadata.num_decodes num_prefills = swa_metadata.num_prefills num_decode_tokens = swa_metadata.num_decode_tokens if num_prefills > 0: - cls._forward_prefill( - layer=layer, + self._forward_prefill( q=q[num_decode_tokens:], positions=positions[num_decode_tokens:], compressed_k_cache=self_kv_cache, @@ -664,8 +666,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): swa_metadata=swa_metadata, ) if num_decodes > 0: - cls._forward_decode( - layer=layer, + self._forward_decode( q=q[:num_decode_tokens], kv_cache=self_kv_cache, swa_metadata=swa_metadata, @@ -674,10 +675,8 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): output=output[:num_decode_tokens], ) - @classmethod def _forward_decode( - cls, - layer: "DeepseekV4MLAAttention", + self, q: torch.Tensor, kv_cache: torch.Tensor | None, swa_metadata: DeepseekV4ROCMAiterSparseSWAMetadata, @@ -695,16 +694,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): if not swa_only: assert attn_metadata is not None assert swa_metadata.is_valid_token is not None - block_size = attn_metadata.block_size // layer.compress_ratio + block_size = attn_metadata.block_size // self.compress_ratio is_valid = swa_metadata.is_valid_token[:num_decode_tokens] - if layer.compress_ratio == 4: - assert layer.topk_indices_buffer is not None + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None ( topk_ragged_indices, topk_ragged_indptr, topk_lens, ) = compute_global_topk_ragged_indices_and_indptr( - layer.topk_indices_buffer[:num_decode_tokens], + self.topk_indices_buffer[:num_decode_tokens], swa_metadata.token_to_req_indices, attn_metadata.block_table[:num_decodes], block_size, @@ -719,7 +718,7 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): rocm_sparse_attn_decode( q=q, kv_cache=kv_cache, - swa_k_cache=layer.swa_cache_layer.kv_cache, + swa_k_cache=self.swa_cache_layer.kv_cache, swa_only=swa_only, topk_indices=topk_indices, topk_lens=topk_lens, @@ -729,18 +728,16 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): swa_ragged_indptr=swa_metadata.decode_swa_ragged_indptr, topk_ragged_indices=topk_ragged_indices, topk_ragged_indptr=topk_ragged_indptr, - attn_sink=layer.attn_sink, - scale=layer.scale, - head_dim=layer.head_dim, - nope_head_dim=layer.nope_head_dim, - rope_head_dim=layer.rope_head_dim, + attn_sink=self.attn_sink, + scale=self.scale, + head_dim=self.head_dim, + nope_head_dim=self.nope_head_dim, + rope_head_dim=self.rope_head_dim, output=output, ) - @classmethod def _forward_prefill( - cls, - layer: "DeepseekV4MLAAttention", + self, q: torch.Tensor, positions: torch.Tensor, compressed_k_cache: torch.Tensor | None, @@ -768,34 +765,34 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): prefill_token_base = query_start_loc_cpu[num_decodes] if not swa_only: - if layer.compress_ratio == 4: - assert layer.topk_indices_buffer is not None - topk_indices = layer.topk_indices_buffer[num_decode_tokens:] + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None + topk_indices = self.topk_indices_buffer[num_decode_tokens:] topk_indices = topk_indices[:num_prefill_tokens] else: assert attn_metadata is not None topk_indices = attn_metadata.c128a_prefill_topk_indices assert topk_indices is not None top_k = topk_indices.shape[-1] - N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio + N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio else: - assert layer.topk_indices_buffer is not None - topk_indices = layer.topk_indices_buffer[num_decode_tokens:] + assert self.topk_indices_buffer is not None + topk_indices = self.topk_indices_buffer[num_decode_tokens:] top_k = 0 N = 0 - M = N + layer.window_size + layer.max_num_batched_tokens - num_chunks = (num_prefills + cls.PREFILL_CHUNK_SIZE - 1) // ( - cls.PREFILL_CHUNK_SIZE + M = N + self.window_size + self.max_num_batched_tokens + num_chunks = (num_prefills + self.PREFILL_CHUNK_SIZE - 1) // ( + self.PREFILL_CHUNK_SIZE ) workspace_manager = current_workspace_manager() kv = workspace_manager.get_simultaneous( - ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), )[0] for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE - chunk_end = min(chunk_start + cls.PREFILL_CHUNK_SIZE, num_prefills) + chunk_start = chunk_idx * self.PREFILL_CHUNK_SIZE + chunk_end = min(chunk_start + self.PREFILL_CHUNK_SIZE, num_prefills) chunk_size = chunk_end - chunk_start if not swa_only: assert attn_metadata is not None @@ -804,10 +801,10 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): dequantize_and_gather_k_cache( kv[:chunk_size], compressed_k_cache, - seq_lens=seq_lens[chunk_start:chunk_end] // layer.compress_ratio, + seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio, gather_lens=None, block_table=block_table[chunk_start:chunk_end], - block_size=attn_metadata.block_size // layer.compress_ratio, + block_size=attn_metadata.block_size // self.compress_ratio, offset=0, ) @@ -836,8 +833,8 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): ], seq_lens[chunk_start:chunk_end], gather_lens[chunk_start:chunk_end], - layer.window_size, - layer.compress_ratio, + self.window_size, + self.compress_ratio, top_k, M, N, @@ -847,10 +844,10 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): kv=kv.view(-1, 1, q.shape[-1]), indices=combined_indices, topk_length=combined_lens, - scale=layer.scale, - head_dim=layer.head_dim, - nope_head_dim=layer.nope_head_dim, - rope_head_dim=layer.rope_head_dim, - attn_sink=layer.attn_sink, + scale=self.scale, + head_dim=self.head_dim, + nope_head_dim=self.nope_head_dim, + rope_head_dim=self.rope_head_dim, + attn_sink=self.attn_sink, output=output[query_start:query_end], ) diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 5f13d1bd8d0..29302584880 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -4,8 +4,9 @@ DeepseekV4 MLA Attention Layer """ +from abc import ABC, abstractmethod from collections.abc import Callable -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ -15,16 +16,16 @@ from transformers import DeepseekV2Config, DeepseekV3Config import vllm.envs as envs from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, ReplicatedLinear, + RowParallelLinear, ) from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer from vllm.models.deepseek_v4.common.ops import ( fused_indexer_q_rope_quant, - fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, ) -from vllm.utils.deep_gemm import fp8_einsum -from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum if TYPE_CHECKING: from vllm.v1.attention.backends.mla.sparse_swa import ( @@ -42,14 +43,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.input_quant_fp8 import ( - QuantFP8, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, -) +from vllm.model_executor.models.utils import extract_layer_index +from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope from vllm.models.deepseek_v4.compressor import DeepseekCompressor -from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, @@ -62,187 +58,209 @@ from vllm.v1.attention.backends.mla.indexer import ( from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -if TYPE_CHECKING: - from vllm.models.deepseek_v4.nvidia.flashmla import ( - DeepseekV4SparseMLAAttentionImpl, - ) - logger = init_logger(__name__) -def _resolve_dsv4_backend(vllm_config: VllmConfig | None): - """Return the explicitly-requested DSv4 sparse backend enum, or None.""" - if vllm_config is None: - return None - attn_config = getattr(vllm_config, "attention_config", None) - return getattr(attn_config, "backend", None) if attn_config is not None else None - - -def _select_v4_sparse_impl( - vllm_config: VllmConfig | None = None, -) -> "type[DeepseekV4SparseMLAAttentionImpl]": - """Pick the V4 sparse MLA impl class. - - An explicit ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` selects the - FlashInfer TRTLLM-gen path; otherwise the platform default (FlashMLA on - NVIDIA, ROCm Aiter on AMD) is used. - """ - from vllm.v1.attention.backends.registry import AttentionBackendEnum - - backend = _resolve_dsv4_backend(vllm_config) - if backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4: - from vllm.models.deepseek_v4.nvidia.flashinfer_sparse import ( - DeepseekV4FlashInferMLASparseImpl, - ) - - logger.info_once("Using FLASHINFER_MLA_SPARSE_DSV4 backend.") - return DeepseekV4FlashInferMLASparseImpl - if current_platform.is_rocm(): - from vllm.models.deepseek_v4.amd.rocm import ( - DeepseekV4ROCMAiterMLASparseImpl, - ) - - logger.info_once("Using ROCM_FLASHMLA_SPARSE_DSV4 backend.") - return DeepseekV4ROCMAiterMLASparseImpl - from vllm.models.deepseek_v4.nvidia.flashmla import ( - DeepseekV4FlashMLASparseImpl, - ) - - logger.info_once("Using FLASHMLA_SPARSE_DSV4 backend.") - return DeepseekV4FlashMLASparseImpl - - def _resolve_dsv4_kv_cache_dtype( - backend, + use_flashmla_fp8_layout: bool, kv_cache_dtype: str, cache_config: CacheConfig | None, ) -> tuple[str, torch.dtype]: - """Map ``(backend, --kv-cache-dtype)`` to ``(cache_dtype_str, torch_dtype)``. + """Map ``(layout, --kv-cache-dtype)`` to ``(cache_dtype_str, torch_dtype)``. - FlashInfer V4 reads a contiguous 512-wide KV row (bf16 or per-tensor FP8 - E4M3); FlashMLA V4 reads the legacy UE8M0 paged layout (uint8 / - ``fp8_ds_mla``). For FlashMLA the canonical ``fp8_ds_mla`` string is - written back onto ``cache_config`` so the page-size specs pick the 576B - layout. + Both layouts are paged; they differ in the per-token block format. The + FlashMLA fp8 layout (FlashMLA / ROCm Aiter) is the ``fp8_ds_mla`` format: + UE8M0 block-scaled fp8 packed as ``uint8`` (the canonical ``fp8_ds_mla`` + string is written back onto ``cache_config`` so the page-size specs pick + the 576B per-token slot). Otherwise (FlashInfer) each token's KV row is + stored in its plain element dtype — bf16 or per-tensor FP8 E4M3. """ - from vllm.v1.attention.backends.registry import AttentionBackendEnum + if use_flashmla_fp8_layout: + # fp8_ds_mla block format: UE8M0 block-scaled fp8 packed as uint8. + assert kv_cache_dtype.startswith("fp8"), ( + f"DeepseekV4 FlashMLA fp8 layout only supports fp8 kv-cache, " + f"got {kv_cache_dtype}" + ) + if kv_cache_dtype != "fp8_ds_mla": + if cache_config is not None: + cache_config.cache_dtype = "fp8_ds_mla" + kv_cache_dtype = "fp8_ds_mla" + logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") + return kv_cache_dtype, torch.uint8 - if backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4: - if kv_cache_dtype.startswith("fp8"): - return kv_cache_dtype, torch.float8_e4m3fn - # auto / bfloat16 -> contiguous BF16 cache. - return kv_cache_dtype, torch.bfloat16 - - # FlashMLA (and ROCm Aiter): legacy UE8M0 paged uint8 cache. - assert kv_cache_dtype.startswith("fp8"), ( - f"DeepseekV4 FlashMLA sparse backend only supports fp8 kv-cache, " - f"got {kv_cache_dtype}" - ) - if kv_cache_dtype != "fp8_ds_mla": - if cache_config is not None: - cache_config.cache_dtype = "fp8_ds_mla" - kv_cache_dtype = "fp8_ds_mla" - logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") - return kv_cache_dtype, torch.uint8 + # Plain bf16 / per-tensor fp8 KV row (FlashInfer). + if kv_cache_dtype.startswith("fp8"): + return kv_cache_dtype, torch.float8_e4m3fn + # auto / bfloat16 -> plain bf16 KV row. + return kv_cache_dtype, torch.bfloat16 -class DeepseekV4MLA(nn.Module): +class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC): + """DeepseekV4 MLA attention layer. + + The platform-specific sparse-MLA forward (``forward_mqa`` / + ``get_padded_num_q_heads`` / ``_o_proj`` / ``backend_cls``) is provided by a + subclass — ``DeepseekV4FlashMLAAttention`` / ``DeepseekV4FlashInferMLAAttention`` + (CUDA) or ``DeepseekV4ROCMAiterMLAAttention`` (ROCm) — selected by the + platform-specific deepseek_v4 model module. The base is never instantiated + directly. + """ + + # Provided by the platform subclass. + backend_cls: ClassVar[type[AttentionBackend]] + # KV-cache per-token block format (both layouts are paged). True (default) + # = FlashMLA / ROCm fp8_ds_mla (UE8M0 block-scaled fp8 packed as uint8); + # False = FlashInfer plain bf16 / per-tensor fp8 KV row. + use_flashmla_fp8_layout: ClassVar[bool] = True + # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather + # workspace allocated in _forward_prefill and is also read by the dummy-run + # path to pre-reserve that workspace. + PREFILL_CHUNK_SIZE: ClassVar[int] = 4 + + @classmethod + @abstractmethod + def get_padded_num_q_heads(cls, num_heads: int) -> int: + """Q head count the q/output buffers are allocated at. + + The layer allocates the q/output buffers at + ``[N, get_padded_num_q_heads(n_local_heads), head_dim]``. Must satisfy + ``result >= num_heads``. Backends with no padding constraint return + ``num_heads``. + """ + raise NotImplementedError + + @abstractmethod + def forward_mqa( + self, + q: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + output: torch.Tensor, + ) -> None: + """Platform-specific sparse MLA forward; writes attention into ``output``.""" + raise NotImplementedError + + @abstractmethod + def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Inverse-RoPE + wo_a + wo_b output projection (platform-specific).""" + raise NotImplementedError + def __init__( self, - hidden_size: int, - num_heads: int, - head_dim: int, - scale: float, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: int | None, - kv_lora_rank: int, - o_lora_rank: int | None, vllm_config: VllmConfig, - fused_wqa_wkv: torch.nn.Module, - q_norm: torch.nn.Module, - wq_b: torch.nn.Module, - kv_norm: torch.nn.Module, - wo_a: torch.nn.Module, - wo_b: torch.nn.Module, - attn_sink: torch.nn.Module, - rotary_emb: torch.nn.Module, - indexer: torch.nn.Module | None, - indexer_rotary_emb: torch.nn.Module, - topk_indices_buffer: torch.Tensor | None, - aux_stream_list: list[torch.cuda.Stream] | None, - window_size: int, - compress_ratio: int | None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", + prefix: str, + topk_indices_buffer: torch.Tensor | None = None, + aux_stream_list: list[torch.cuda.Stream] | None = None, ) -> None: super().__init__() - self.hidden_size = hidden_size - self.n_local_heads = num_heads - self.head_dim = head_dim - self.scale = scale - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.window_size = window_size - self.compress_ratio = compress_ratio if compress_ratio is not None else 1 - self.prefix = prefix - - # Extract config from vllm_config config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config tp_size = get_tensor_model_parallel_world_size() + layer_id = extract_layer_index(prefix) - # DeepseekV4-specific attributes (num_heads is already TP-adjusted) - self.eps = config.rms_norm_eps - self.rope_head_dim = config.qk_rope_head_dim - self.nope_head_dim = head_dim - self.rope_head_dim - self.n_local_groups = config.o_groups // tp_size + self.prefix = prefix # Alias for compatibility with compressor + self.hidden_size = config.hidden_size + self.n_heads = config.num_attention_heads + assert self.n_heads % tp_size == 0 + self.n_local_heads = self.n_heads // tp_size + self.q_lora_rank = config.q_lora_rank self.o_lora_rank = config.o_lora_rank + self.head_dim = config.head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.nope_head_dim = self.head_dim - self.rope_head_dim + self.n_groups = config.o_groups + self.n_local_groups = self.n_groups // tp_size + self.window_size = config.sliding_window + # NOTE(zyongye) Compress ratio can't be 0 + # we do this for because MTP layer is not included + # in the compress ratio list + if layer_id < config.num_hidden_layers: + self.compress_ratio = max(1, config.compress_ratios[layer_id]) + else: + self.compress_ratio = 1 + self.eps = config.rms_norm_eps + self.scale = self.head_dim**-0.5 - # Store projection modules - self.fused_wqa_wkv = fused_wqa_wkv - self.q_norm = q_norm - self.wq_b = wq_b - - self.kv_norm = kv_norm - self.wo_a = wo_a - - self._wo_a_act_quant = QuantFP8( - static=False, - group_shape=GroupShape(1, 128), - use_ue8m0=True, + # Padded Q head count is dictated by the platform subclass. + self.padded_heads = self.get_padded_num_q_heads(self.n_local_heads) + # Sink padded to the same head count, initialized to -inf (no sink + # effect). Weight loading fills the first n_local_heads slots. + self.attn_sink = nn.Parameter( + torch.full((self.padded_heads,), -float("inf"), dtype=torch.float32), + requires_grad=False, ) - # Bypass packed-for-deepgemm path — we need FP32 scales (not packed - # INT32) so fp8_einsum can handle layout transform internally. - self._wo_a_act_quant.use_deep_gemm_supported = False - self.wo_b = wo_b - # Pick fp8_einsum recipe based on GPU arch: - # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 - # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1 - cap = current_platform.get_device_capability() - assert cap is not None, "DeepseekV4 attention requires a CUDA device" - self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) - self._tma_aligned_scales = cap.major >= 10 + self.fused_wqa_wkv = MergedColumnParallelLinear( + self.hidden_size, + [self.q_lora_rank, self.head_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fused_wqa_wkv", + disable_tp=True, # fused ReplicatedLinear + ) + self.q_norm = RMSNorm(self.q_lora_rank, self.eps) + self.wq_b = ColumnParallelLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.wq_b", + ) - self.rotary_emb = rotary_emb - self.indexer_rotary_emb = indexer_rotary_emb + self.kv_norm = RMSNorm(self.head_dim, self.eps) + self.wo_a = ColumnParallelLinear( + self.n_heads * self.head_dim // self.n_groups, + self.n_groups * self.o_lora_rank, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.wo_a", + ) + self.wo_a.is_bmm = True + self.wo_a.bmm_batch_size = self.n_local_groups + self.wo_b = RowParallelLinear( + self.n_groups * self.o_lora_rank, + self.hidden_size, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.wo_b", + ) + + # Initialize rotary embedding before the indexer/compressor consume it. + self.rotary_emb = build_deepseek_v4_rope( + config, + head_dim=self.head_dim, + rope_head_dim=self.rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + compress_ratio=self.compress_ratio, + ) + self.indexer_rotary_emb = self.rotary_emb self.topk_indices_buffer = topk_indices_buffer - self.indexer = indexer - - # Per-head RMS normalization for Q (no learnable weights) - self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False) - - # TODO(yifan): currently hardcoded for FP8 sparse, make it more generic - head_bytes = ( - self.nope_head_dim # 448 fp8 NoPE - + self.rope_head_dim * 2 # 64 bf16 RoPE - + self.nope_head_dim // 64 # 7B scale factors - + 1 # 1B pad - ) + self.indexer = None + if self.compress_ratio == 4: + # Only C4A uses sparse attention and hence has indexer. + # aux_stream_list[2] is free here (outer GEMMs joined) for the inner + # overlap of wq_b+fused_indexer_q_rope_quant vs compressor. None on + # ROCm, where aux_stream_list is None. + indexer_aux_stream = ( + aux_stream_list[2] if aux_stream_list is not None else None + ) + self.indexer = DeepseekV4Indexer( + vllm_config, + config=config, + hidden_size=self.hidden_size, + q_lora_rank=self.q_lora_rank, + quant_config=quant_config, + cache_config=cache_config, + topk_indices_buffer=topk_indices_buffer, + compress_ratio=self.compress_ratio, + prefix=f"{prefix}.indexer", + aux_stream=indexer_aux_stream, + ) # Will be None on ROCm for now. self.aux_stream_list = aux_stream_list @@ -252,45 +270,39 @@ class DeepseekV4MLA(nn.Module): self.ln_events = [torch.cuda.Event() for _ in range(4)] assert cache_config is not None, "DeepseekV4 attention requires cache_config" - # Resolve the SWA cache tensor dtype from the selected backend: FlashMLA - # uses the legacy UE8M0 paged uint8 layout; FlashInfer uses a contiguous - # bf16 / per-tensor fp8 row. - backend = _resolve_dsv4_backend(vllm_config) - _, swa_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype( - backend, cache_config.cache_dtype, cache_config + # ---- Attention / KV-cache setup ---- + self.max_num_batched_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens ) + self.max_model_len = vllm_config.model_config.max_model_len + + # Resolve the kv-cache dtype from this backend's block format (a + # ClassVar set by the subclass): fp8_ds_mla (UE8M0 block-scaled fp8 as + # uint8) for FlashMLA / ROCm, vs a plain bf16 / per-tensor fp8 row for + # FlashInfer. The same resolution drives the SWA cache tensor dtype + # below. + self.kv_cache_dtype, self.kv_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype( + self.use_flashmla_fp8_layout, cache_config.cache_dtype, cache_config + ) + self.swa_cache_layer = DeepseekV4SWACache( head_dim=self.head_dim, window_size=self.window_size, - dtype=swa_cache_torch_dtype, + dtype=self.kv_cache_torch_dtype, prefix=f"{prefix}.swa_cache", cache_config=cache_config, ) - self.mla_attn = DeepseekV4MLAAttention( - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - compress_ratio=self.compress_ratio, - window_size=self.window_size, - head_bytes=head_bytes, - swa_cache_layer=self.swa_cache_layer, - attn_sink=attn_sink, # already padded with -inf - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - indexer=self.indexer, - topk_indices_buffer=self.topk_indices_buffer, - ) - # Mirror the inner layer's padded head count (single source of truth). - self.padded_heads = self.mla_attn.padded_heads + # Register with compilation context for metadata lookup. + compilation_config = vllm_config.compilation_config + if prefix and prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + if prefix: + compilation_config.static_forward_context[prefix] = self + self.kv_cache = torch.tensor([]) - # Create the compressor for layers with compress_ratio > 1; after - # creating the DeepseekV4MLAAttention layer to get its cache. + # Create the compressor for layers with compress_ratio > 1; after the + # attention setup above so its KV-cache prefix (self.prefix) is set. self.compressor = None if self.compress_ratio > 1: self.compressor = DeepseekCompressor( @@ -300,7 +312,7 @@ class DeepseekV4MLA(nn.Module): head_dim=self.head_dim, rotate=True, prefix=f"{prefix}.compressor", - k_cache_prefix=self.mla_attn.prefix, + k_cache_prefix=self.prefix, ) def forward( @@ -318,54 +330,38 @@ class DeepseekV4MLA(nn.Module): device=hidden_states.device, ) + # Metadata-independent input GEMMs + RMSNorm stay in the captured + # graph; the metadata-dependent rest (q up-proj + kv-insert, indexer, + # compressor, MLA attention) runs in the eager break. + qr_kv, kv_score, indexer_kv_score, indexer_weights = ( + self.attn_gemm_parallel_execute(hidden_states) + ) + qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) + qr, kv = fused_q_kv_rmsnorm( + qr, + kv, + self.q_norm.weight.data, + self.kv_norm.weight.data, + self.eps, + ) + # attention_impl is wrapped with @eager_break_during_capture: this is # where the breakable cudagraph capture breaks (the attention op runs # eagerly between captured graph segments). - self.attention_impl(hidden_states, positions, o_padded) + self.attention_impl( + hidden_states, + qr, + kv, + kv_score, + indexer_kv_score, + indexer_weights, + positions, + o_padded, + ) o = o_padded[:, : self.n_local_heads, :] - # Keep ROCm on the BF16 reference wo_a path util kernel ready. - if current_platform.is_rocm(): - z = rocm_inv_rope_einsum( - self.rotary_emb, - o, - positions, - self.rope_head_dim, - self.n_local_groups, - self.o_lora_rank, - self.wo_a, - ) - return self.wo_b(z.flatten(1)) - - # O projection: inverse RoPE + FP8 quant + einsum + wo_b - o_fp8, o_scale = fused_inv_rope_fp8_quant( - o, - positions, - self.rotary_emb.cos_sin_cache, - n_groups=self.n_local_groups, - heads_per_group=self.n_local_heads // self.n_local_groups, - nope_dim=self.nope_head_dim, - rope_dim=self.rope_head_dim, - tma_aligned_scales=self._tma_aligned_scales, - ) - - wo_a_fp8 = self.wo_a.weight - wo_a_scale = self.wo_a.weight_scale_inv - - z = torch.empty( - (num_tokens, self.n_local_groups, self.o_lora_rank), - device=o.device, - dtype=torch.bfloat16, - ) - fp8_einsum( - "bhr,hdr->bhd", - (o_fp8, o_scale), - (wo_a_fp8, wo_a_scale), - z, - recipe=self._einsum_recipe, - ) - - return self.wo_b(z.flatten(1)) + # Inverse-RoPE + wo_a + wo_b output projection (platform-specific). + return self._o_proj(o, positions) def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: aux_streams = self.aux_stream_list @@ -431,27 +427,19 @@ class DeepseekV4MLA(nn.Module): def attention_impl( self, hidden_states: torch.Tensor, + qr: torch.Tensor, + kv: torch.Tensor, + kv_score: torch.Tensor, + indexer_kv_score: torch.Tensor, + indexer_weights: torch.Tensor, positions: torch.Tensor, out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place ) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - qr_kv, kv_score, indexer_kv_score, indexer_weights = ( - self.attn_gemm_parallel_execute(hidden_states) - ) - - qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) - qr, kv = fused_q_kv_rmsnorm( - qr, - kv, - self.q_norm.weight.data, - self.kv_norm.weight.data, - self.eps, - ) - # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride - # on the default stream so q stays on its consumer stream (mla_attn + # on the default stream so q stays on its consumer stream (forward_mqa # downstream reads q on default). Indexer/compressor go on aux for # overlap with default's GEMM + cache write. if self.indexer is not None: @@ -514,7 +502,7 @@ class DeepseekV4MLA(nn.Module): # MLA attention writes into the pre-allocated `out` buffer # ([num_tokens, padded_heads, head_dim]). - self.mla_attn(q, kv, positions, output=out) + self.forward_mqa(q, kv, positions, out) def _fused_qnorm_rope_kv_insert( self, @@ -549,7 +537,7 @@ class DeepseekV4MLA(nn.Module): cos_sin_cache = self.rotary_emb.cos_sin_cache cache_dtype = swa_kv_cache.dtype - # kv is unchanged; mla_attn reads kv solely via swa_kv_cache. + # kv is unchanged; attention reads kv solely via swa_kv_cache. if cache_dtype == torch.uint8: # Legacy FlashMLA UE8M0 paged path. Horizontally fused: # Q side: per-head RMSNorm (no weight) + GPT-J RoPE, zero-filling @@ -569,9 +557,10 @@ class DeepseekV4MLA(nn.Module): swa_metadata.block_size, ) - # FlashInfer full-cache path: contiguous [num_blocks, block_size, 512] - # cache (no Q padding). bf16 rewrites q in place; per-tensor fp8 writes a - # separately-allocated fp8 q and quantizes the KV row. + # FlashInfer full-cache path: the [num_blocks, block_size, 512] cache + # stores the KV row in its plain dtype (no Q padding). bf16 rewrites q + # in place; per-tensor fp8 writes a separately-allocated fp8 q and + # quantizes the KV row. block_size = swa_metadata.block_size swa_kv_cache_3d = swa_kv_cache.view(-1, block_size, self.head_dim) if cache_dtype == torch.bfloat16: @@ -597,99 +586,13 @@ class DeepseekV4MLA(nn.Module): swa_metadata.slot_mapping, positions, cos_sin_cache, - self.mla_attn._flashinfer_fp8_kv_scale, - self.mla_attn._flashinfer_fp8_q_scale_inv, + self._flashinfer_fp8_kv_scale, + self._flashinfer_fp8_q_scale_inv, self.eps, block_size, ) return q_fp8 - -class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): - def __init__( - self, - num_heads: int, - head_dim: int, - scale: float, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - q_lora_rank: int | None, - kv_lora_rank: int, - compress_ratio: int, - window_size: int, - head_bytes: int, - swa_cache_layer: DeepseekV4SWACache, - attn_sink: torch.Tensor, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - # Sparse MLA Args - indexer: object | None = None, - topk_indices_buffer: torch.Tensor | None = None, - aux_stream: torch.cuda.Stream | None = None, - **extra_impl_args, - ) -> None: - super().__init__() - vllm_config = get_current_vllm_config() - self.impl_cls = _select_v4_sparse_impl(vllm_config) - self.backend_cls = self.impl_cls.backend_cls - self.num_heads = num_heads - self.num_kv_heads = 1 - self.head_dim = head_dim - self.scale = scale - self.window_size = window_size - self.head_bytes = head_bytes - self.compress_ratio = compress_ratio - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.nope_head_dim = qk_nope_head_dim - self.rope_head_dim = qk_rope_head_dim - self.indexer = indexer - self.topk_indices_buffer = topk_indices_buffer - - self.prefix = prefix # Alias for compatibility with compressor - - self.aux_stream = aux_stream - self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] - - # Padded Q head count is dictated by the selected impl. - self.padded_heads = self.impl_cls.get_padded_num_q_heads(num_heads) - - # Store attention sink - assert attn_sink is not None - self.attn_sink: torch.Tensor = attn_sink - # Store SWA cache - assert swa_cache_layer is not None - self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer - - # Get vllm config for cache setup - self.max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens - ) - self.max_model_len = vllm_config.model_config.max_model_len - - # Resolve the kv-cache dtype from the selected backend. FlashMLA uses - # the legacy UE8M0 paged uint8 (fp8_ds_mla) layout; FlashInfer uses a - # contiguous bf16 / per-tensor fp8 row. - backend = _resolve_dsv4_backend(vllm_config) - kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" - self.kv_cache_dtype, self.kv_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype( - backend, kv_cache_dtype, cache_config - ) - - # Per-impl layer buffers (e.g. FlashInfer FP8 scale buffers). No-op for - # the FlashMLA / ROCm impls. - self.impl_cls.init_layer_buffers(self) - - # Register with compilation context for metadata lookup - compilation_config = vllm_config.compilation_config - if prefix and prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - if prefix: - compilation_config.static_forward_context[prefix] = self - - self.kv_cache = torch.tensor([]) - def get_attn_backend(self) -> type[AttentionBackend]: return self.backend_cls @@ -698,8 +601,9 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): self.compress_ratio <= 1 ): # SWA part. Allocated separately as DeepseekV4SWACache. return None - # FlashMLA uses the UE8M0 paged uint8 layout (576B aligned); FlashInfer - # uses a contiguous bf16 / per-tensor fp8 cache with no extra alignment. + # FlashMLA uses the fp8_ds_mla block format (UE8M0 block-scaled fp8 as + # uint8, 576B aligned); FlashInfer stores a plain bf16 / per-tensor fp8 + # row with no extra alignment. is_flashmla = self.kv_cache_dtype == "fp8_ds_mla" return MLAAttentionSpec( block_size=vllm_config.cache_config.block_size, @@ -712,15 +616,6 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): model_version="deepseek_v4", ) - def forward( - self, - q: torch.Tensor, - kv: torch.Tensor, - positions: torch.Tensor, - output: torch.Tensor, - ) -> None: - self.impl_cls.forward_mqa(self, q, kv, positions, output) - class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): def __init__( diff --git a/vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py b/vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py index 71ea4fe506e..1e119614b0d 100644 --- a/vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py +++ b/vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py @@ -3,28 +3,30 @@ """DeepSeek V4 FlashInfer TRTLLM-gen sparse MLA backend. Uses FlashInfer's public ``trtllm_batch_decode_sparse_mla_dsv4`` launcher with a -contiguous bf16 / per-tensor FP8 KV cache. Shares the V4 sparse-index pipeline -(SWA cache + compressor + indexer, 256-token blocks, head_size 512) with the -FlashMLA V4 backend; only the attention forward differs. +plain bf16 / per-tensor FP8 KV row (vs FlashMLA's packed ``fp8_ds_mla`` block +format). Shares the V4 sparse-index pipeline (SWA cache + compressor + indexer, +256-token blocks, head_size 512) with the FlashMLA V4 backend; only the +attention forward differs. """ -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, ClassVar, cast import torch from vllm.forward_context import get_forward_context +from vllm.models.deepseek_v4.attention import DeepseekV4Attention from vllm.models.deepseek_v4.common.ops import ( build_flashinfer_mixed_sparse_indices, ) -from vllm.models.deepseek_v4.nvidia.flashmla import ( - DeepseekV4FlashMLASparseBackend, - DeepseekV4SparseMLAAttentionImpl, +from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLASparseBackend +from vllm.models.deepseek_v4.nvidia.ops.o_proj import ( + compute_fp8_einsum_recipe, + deep_gemm_fp8_o_proj, ) from vllm.utils.flashinfer import flashinfer_trtllm_batch_decode_sparse_mla_dsv4 from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseMetadata if TYPE_CHECKING: - from vllm.models.deepseek_v4.attention import DeepseekV4MLAAttention from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata # 128 MB TRTLLM-gen workspace, allocated once per device and zero-initialized @@ -51,23 +53,21 @@ class DeepseekV4FlashInferMLASparseBackend(DeepseekV4FlashMLASparseBackend): Inheriting from the FlashMLA V4 backend reuses its ``FlashMLASparseMetadata`` builder (which the V4 sparse-index pipeline needs — the V3.2 FlashInfer builder lacks the ``c128a_*`` fields), 256-token blocks, head_size 512, and - the contiguous (num_blocks, block_size, 512) cache shape for non-``fp8_ds_mla`` - dtypes. + the (num_blocks, block_size, 512) cache shape for non-``fp8_ds_mla`` dtypes. """ @staticmethod def get_name() -> str: return "FLASHINFER_MLA_SPARSE_DSV4" - @staticmethod - def get_impl_cls() -> type["DeepseekV4FlashInferMLASparseImpl"]: - return DeepseekV4FlashInferMLASparseImpl - -class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): - """FlashInfer TRTLLM-gen sparse MLA implementation for DeepSeek V4.""" +class DeepseekV4FlashInferMLAAttention(DeepseekV4Attention): + """FlashInfer TRTLLM-gen sparse MLA attention layer for DeepSeek V4.""" backend_cls = DeepseekV4FlashInferMLASparseBackend + # FlashInfer stores a plain bf16 / per-tensor fp8 KV row, not the FlashMLA + # packed fp8_ds_mla block format (UE8M0 block-scaled fp8 as uint8). + use_flashmla_fp8_layout: ClassVar[bool] = False @classmethod def get_padded_num_q_heads(cls, num_heads: int) -> int: @@ -79,27 +79,44 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): ) return 64 if num_heads <= 64 else 128 - @classmethod - def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None: + def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + return deep_gemm_fp8_o_proj( + o, + positions, + self.rotary_emb.cos_sin_cache, + self.wo_a, + self.wo_b, + n_groups=self.n_local_groups, + heads_per_group=self.n_local_heads // self.n_local_groups, + nope_dim=self.nope_head_dim, + rope_dim=self.rope_head_dim, + o_lora_rank=self.o_lora_rank, + einsum_recipe=self._einsum_recipe, + tma_aligned_scales=self._tma_aligned_scales, + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._einsum_recipe, self._tma_aligned_scales = compute_fp8_einsum_recipe() # Per-tensor FP8 scale buffers + precomputed scalar BMM scales. Only the - # per-tensor FP8 cache path consumes these; bf16 reads ``layer.scale``. - if layer.kv_cache_torch_dtype != torch.float8_e4m3fn: + # per-tensor FP8 cache path consumes these; bf16 reads ``self.scale``. + if self.kv_cache_torch_dtype != torch.float8_e4m3fn: return # TODO: load real per-tensor Q/KV scales from the checkpoint; unit # scales until the scale tensor names are wired. fp8_q_scale = 1.0 fp8_kv_scale = 1.0 - layer.register_buffer( + self.register_buffer( "_flashinfer_fp8_q_scale", torch.tensor([fp8_q_scale], dtype=torch.float32), persistent=False, ) - layer.register_buffer( + self.register_buffer( "_flashinfer_fp8_q_scale_inv", torch.tensor([1.0 / fp8_q_scale], dtype=torch.float32), persistent=False, ) - layer.register_buffer( + self.register_buffer( "_flashinfer_fp8_kv_scale", torch.tensor([fp8_kv_scale], dtype=torch.float32), persistent=False, @@ -107,13 +124,11 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): # TRTLLM-gen takes scalar scale args on a distinct (correct) C++ path # vs 1-elem tensors, so these are Python floats. bmm1 folds the softmax # scale and the Q/KV per-tensor scales; bmm2 is the KV scale. - layer._flashinfer_fp8_bmm1_scale = layer.scale * fp8_q_scale * fp8_kv_scale - layer._flashinfer_fp8_bmm2_scale = fp8_kv_scale + self._flashinfer_fp8_bmm1_scale = self.scale * fp8_q_scale * fp8_kv_scale + self._flashinfer_fp8_bmm2_scale = fp8_kv_scale - @classmethod - def forward_mqa( # type: ignore[override] - cls, - layer: "DeepseekV4MLAAttention", + def forward_mqa( + self, q: torch.Tensor, kv: torch.Tensor, positions: torch.Tensor, @@ -147,21 +162,20 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): assert isinstance(attn_metadata, dict) flashmla_metadata = cast( - FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix) + FlashMLASparseMetadata | None, attn_metadata.get(self.prefix) ) swa_metadata = cast( "DeepseekSparseSWAMetadata | None", - attn_metadata.get(layer.swa_cache_layer.prefix), + attn_metadata.get(self.swa_cache_layer.prefix), ) assert swa_metadata is not None - swa_only = layer.compress_ratio <= 1 + swa_only = self.compress_ratio <= 1 # SWA-only layers don't allocate their own compressed KV cache. - self_kv_cache = layer.kv_cache if not swa_only else None - swa_kv_cache = layer.swa_cache_layer.kv_cache + self_kv_cache = self.kv_cache if not swa_only else None + swa_kv_cache = self.swa_cache_layer.kv_cache - cls._forward( - layer=layer, + self._forward( q=q, kv_cache=self_kv_cache, swa_k_cache=swa_kv_cache, @@ -171,10 +185,8 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): output=output, ) - @classmethod def _build_sparse_index_metadata( - cls, - layer: "DeepseekV4MLAAttention", + self, kv_cache: torch.Tensor | None, swa_k_cache: torch.Tensor, swa_metadata: "DeepseekSparseSWAMetadata", @@ -200,17 +212,17 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): assert swa_metadata.block_table is not None decode_swa_indices = swa_metadata.decode_swa_indices.reshape( - num_decode_tokens, layer.window_size + num_decode_tokens, self.window_size ) decode_compressed_topk_lens = None decode_compressed_indices_are_local = False decode_is_valid_token = None if swa_only: - assert layer.topk_indices_buffer is not None + assert self.topk_indices_buffer is not None compressed_kv_cache = swa_k_cache decode_compressed_indices = None - prefill_topk_indices = layer.topk_indices_buffer[ + prefill_topk_indices = self.topk_indices_buffer[ num_decode_tokens:num_tokens, :0 ] compressed_block_table = None @@ -221,24 +233,24 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): assert attn_metadata is not None compressed_kv_cache = kv_cache compressed_block_table = attn_metadata.block_table[:num_reqs] - compressed_block_size = attn_metadata.block_size // layer.compress_ratio + compressed_block_size = attn_metadata.block_size // self.compress_ratio - if layer.compress_ratio == 4: - assert layer.topk_indices_buffer is not None + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None if num_prefill_tokens > 0: - prefill_topk_indices = layer.topk_indices_buffer[ + prefill_topk_indices = self.topk_indices_buffer[ num_decode_tokens:num_tokens ] top_k = prefill_topk_indices.shape[-1] else: - prefill_topk_indices = layer.topk_indices_buffer[:0, :0] + prefill_topk_indices = self.topk_indices_buffer[:0, :0] top_k = 0 decode_compressed_indices_are_local = True assert swa_metadata.is_valid_token is not None decode_is_valid_token = swa_metadata.is_valid_token[:num_decode_tokens] if num_decode_tokens > 0: - decode_compressed_indices = layer.topk_indices_buffer[ + decode_compressed_indices = self.topk_indices_buffer[ :num_decode_tokens ] else: @@ -284,18 +296,16 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): swa_metadata.block_size, compressed_block_table, compressed_block_size, - layer.window_size, - layer.compress_ratio, + self.window_size, + self.compress_ratio, top_k, decode_compressed_indices_are_local=decode_compressed_indices_are_local, decode_is_valid_token=decode_is_valid_token, ) return compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens - @classmethod def _forward( - cls, - layer: "DeepseekV4MLAAttention", + self, q: torch.Tensor, kv_cache: torch.Tensor | None, swa_k_cache: torch.Tensor, @@ -304,7 +314,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): swa_only: bool, output: torch.Tensor, ) -> None: - assert layer.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn) + assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn) num_decodes = swa_metadata.num_decodes num_prefills = swa_metadata.num_prefills num_decode_tokens = swa_metadata.num_decode_tokens @@ -319,8 +329,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): seq_lens, sparse_indices, sparse_topk_lens, - ) = cls._build_sparse_index_metadata( - layer=layer, + ) = self._build_sparse_index_metadata( kv_cache=kv_cache, swa_k_cache=swa_k_cache, swa_metadata=swa_metadata, @@ -332,12 +341,12 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): # restrict to the real tokens (the launcher validates sparse indices). query = q[:num_tokens] output = output[:num_tokens] - bmm1_scale: float | torch.Tensor = layer.scale + bmm1_scale: float | torch.Tensor = self.scale bmm2_scale: float | torch.Tensor = 1.0 - if layer.kv_cache_torch_dtype == torch.float8_e4m3fn: + if self.kv_cache_torch_dtype == torch.float8_e4m3fn: assert query.dtype == torch.float8_e4m3fn - bmm1_scale = layer._flashinfer_fp8_bmm1_scale - bmm2_scale = layer._flashinfer_fp8_bmm2_scale + bmm1_scale = self._flashinfer_fp8_bmm1_scale + bmm2_scale = self._flashinfer_fp8_bmm2_scale else: assert query.dtype == torch.bfloat16 query = query.contiguous() @@ -376,7 +385,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): out=output[:num_decode_tokens], bmm1_scale=bmm1_scale, bmm2_scale=bmm2_scale, - sinks=layer.attn_sink, + sinks=self.attn_sink, cum_seq_lens_q=decode_cu, max_q_len=int(decode_lens_cpu.max().item()), ) @@ -401,7 +410,7 @@ class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): out=output[num_decode_tokens:num_tokens], bmm1_scale=bmm1_scale, bmm2_scale=bmm2_scale, - sinks=layer.attn_sink, + sinks=self.attn_sink, cum_seq_lens_q=prefill_cu, max_q_len=int(prefill_lens_cpu.max().item()), ) diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index e9b9c678306..5c3969deb80 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -1,22 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import abstractmethod -from typing import TYPE_CHECKING, ClassVar, cast +from typing import TYPE_CHECKING, cast import torch from vllm.forward_context import get_forward_context +from vllm.models.deepseek_v4.attention import DeepseekV4Attention from vllm.models.deepseek_v4.common.ops import ( combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, ) -from vllm.v1.attention.backend import ( - AttentionBackend, - MultipleOf, - SparseMLAAttentionImpl, +from vllm.models.deepseek_v4.nvidia.ops.o_proj import ( + compute_fp8_einsum_recipe, + deep_gemm_fp8_o_proj, ) +from vllm.v1.attention.backend import MultipleOf from vllm.v1.attention.backends.mla.flashmla_sparse import ( FlashMLASparseBackend, FlashMLASparseMetadata, @@ -28,63 +28,9 @@ from vllm.v1.attention.ops.flashmla import ( from vllm.v1.worker.workspace import current_workspace_manager if TYPE_CHECKING: - from vllm.models.deepseek_v4.attention import ( - DeepseekV4MLAAttention, - ) from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata -class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]): - """Abstract parent for DeepseekV4 sparse MLA impls. - - V4 sparse MLA is driven by the layer (``DeepseekV4MLAAttention.forward``) - rather than the v1 framework, so ``forward_mqa`` is overridden with a - classmethod that takes the layer as its first argument. This Liskov-broken - override is intentional: the grandparent's instance-method ``forward_mqa`` - is never called on V4 layers. - """ - - backend_cls: ClassVar[type[AttentionBackend]] - - # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather - # workspace allocated in _forward_prefill and is also read by the V4 layer's - # dummy-run path to pre-reserve that workspace. - PREFILL_CHUNK_SIZE: ClassVar[int] = 4 - - @classmethod - @abstractmethod - def forward_mqa( # type: ignore[override] - cls, - layer: "DeepseekV4MLAAttention", - q: torch.Tensor, - kv: torch.Tensor, - positions: torch.Tensor, - output: torch.Tensor, - ) -> None: - raise NotImplementedError - - @classmethod - @abstractmethod - def get_padded_num_q_heads(cls, num_heads: int) -> int: - """Q head count the backend wants q allocated at. - - The MLA wrapper allocates the q/output buffers at - ``[N, get_padded_num_q_heads(n_local_heads), head_dim]``. Must - satisfy ``result >= num_heads``. Backends with no padding constraint - return ``num_heads``. - """ - raise NotImplementedError - - @classmethod - def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None: - """Register impl-specific buffers on the layer at construction. - - No-op by default; FlashInfer overrides this to register its per-tensor - FP8 scale buffers. - """ - return None - - class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend): @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: @@ -94,10 +40,6 @@ class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend): def get_name() -> str: return "FLASHMLA_SPARSE_DSV4" - @staticmethod - def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]: - return DeepseekV4FlashMLASparseImpl - @classmethod def get_supported_head_sizes(cls) -> list[int]: # DeepSeek V4 layout: 448 NoPE + 64 RoPE = 512 (overrides the @@ -120,11 +62,31 @@ class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend): return (num_blocks, block_size, head_size) -class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): - """FlashMLA sparse MLA implementation for DeepSeek V4's custom MLA layer.""" +class DeepseekV4FlashMLAAttention(DeepseekV4Attention): + """FlashMLA sparse MLA attention layer for DeepSeek V4 (CUDA).""" backend_cls = DeepseekV4FlashMLASparseBackend + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._einsum_recipe, self._tma_aligned_scales = compute_fp8_einsum_recipe() + + def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + return deep_gemm_fp8_o_proj( + o, + positions, + self.rotary_emb.cos_sin_cache, + self.wo_a, + self.wo_b, + n_groups=self.n_local_groups, + heads_per_group=self.n_local_heads // self.n_local_groups, + nope_dim=self.nope_head_dim, + rope_dim=self.rope_head_dim, + o_lora_rank=self.o_lora_rank, + einsum_recipe=self._einsum_recipe, + tma_aligned_scales=self._tma_aligned_scales, + ) + @classmethod def get_padded_num_q_heads(cls, num_heads: int) -> int: # FP8 decode kernel only supports h_q = 64 or 128. @@ -135,10 +97,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): ) return 64 if num_heads <= 64 else 128 - @classmethod - def forward_mqa( # type: ignore[override] - cls, - layer: "DeepseekV4MLAAttention", + def forward_mqa( + self, q: torch.Tensor, kv: torch.Tensor, positions: torch.Tensor, @@ -159,35 +119,35 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): # Warmup dummy run: no real metadata. Reserve the same bf16 # gather workspace _forward_prefill would; the dequantize / topk # / sparse_fwd kernels are skipped this step. - swa_only = layer.compress_ratio <= 1 + swa_only = self.compress_ratio <= 1 N = ( 0 if swa_only - else (layer.max_model_len + layer.compress_ratio - 1) - // layer.compress_ratio + else (self.max_model_len + self.compress_ratio - 1) + // self.compress_ratio ) - M = N + layer.window_size + layer.max_num_batched_tokens + M = N + self.window_size + self.max_num_batched_tokens current_workspace_manager().get_simultaneous( - ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), ) output.zero_() return assert isinstance(attn_metadata, dict) flashmla_metadata = cast( - FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix) + FlashMLASparseMetadata | None, attn_metadata.get(self.prefix) ) swa_metadata = cast( "DeepseekSparseSWAMetadata | None", - attn_metadata.get(layer.swa_cache_layer.prefix), + attn_metadata.get(self.swa_cache_layer.prefix), ) assert swa_metadata is not None - swa_only = layer.compress_ratio <= 1 + swa_only = self.compress_ratio <= 1 # SWA-only layers (compress_ratio <= 1) don't have their own KV cache - # allocation, so layer.kv_cache may be empty after profiling cleanup. - self_kv_cache = layer.kv_cache if not swa_only else None - swa_kv_cache = layer.swa_cache_layer.kv_cache + # allocation, so self.kv_cache may be empty after profiling cleanup. + self_kv_cache = self.kv_cache if not swa_only else None + swa_kv_cache = self.swa_cache_layer.kv_cache # Split prefill and decode num_decodes = swa_metadata.num_decodes @@ -195,8 +155,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): num_decode_tokens = swa_metadata.num_decode_tokens if num_prefills > 0: - cls._forward_prefill( - layer=layer, + self._forward_prefill( q=q[num_decode_tokens:], positions=positions[num_decode_tokens:], compressed_k_cache=self_kv_cache, @@ -206,8 +165,7 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): swa_metadata=swa_metadata, ) if num_decodes > 0: - cls._forward_decode( - layer=layer, + self._forward_decode( q=q[:num_decode_tokens], kv_cache=self_kv_cache, swa_metadata=swa_metadata, @@ -216,10 +174,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): output=output[:num_decode_tokens], ) - @classmethod def _forward_decode( - cls, - layer: "DeepseekV4MLAAttention", + self, q: torch.Tensor, kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1 swa_metadata: "DeepseekSparseSWAMetadata", @@ -235,13 +191,13 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): if not swa_only: assert attn_metadata is not None assert swa_metadata.is_valid_token is not None - block_size = attn_metadata.block_size // layer.compress_ratio + block_size = attn_metadata.block_size // self.compress_ratio is_valid = swa_metadata.is_valid_token[:num_decode_tokens] - if layer.compress_ratio == 4: + if self.compress_ratio == 4: # C4A: local indices differ per layer (filled by Indexer). - assert layer.topk_indices_buffer is not None + assert self.topk_indices_buffer is not None global_indices, topk_lens = compute_global_topk_indices_and_lens( - layer.topk_indices_buffer[:num_decode_tokens], + self.topk_indices_buffer[:num_decode_tokens], swa_metadata.token_to_req_indices, attn_metadata.block_table[:num_decodes], block_size, @@ -258,12 +214,12 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): # We treat queries in the same seq as different queries # and later we only attend by generated indices. - # q arrives pre-padded to layer.padded_heads by the outer wrapper. + # q arrives pre-padded to self.padded_heads by the outer wrapper. q = q.unsqueeze(1) # Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes) # Use unsqueeze to preserve strides (handles padded blocks correctly) - swa_cache = layer.swa_cache_layer.kv_cache.unsqueeze(-2) + swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2) # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) if kv_cache is not None: kv_cache = kv_cache.unsqueeze(-2) @@ -274,20 +230,20 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): # and num_splits via PyTorch's graph-aware allocator so CUDA graph # capture reuses the same addresses on replay); subsequent same-type # layers see have_initialized=True and skip the planner. - if layer.compress_ratio <= 1: + if self.compress_ratio <= 1: tile_metadata = swa_metadata.tile_sched_swaonly - elif layer.compress_ratio == 4: + elif self.compress_ratio == 4: tile_metadata = swa_metadata.tile_sched_c4a - elif layer.compress_ratio == 128: + elif self.compress_ratio == 128: tile_metadata = swa_metadata.tile_sched_c128a else: raise ValueError( - f"Unsupported compress_ratio={layer.compress_ratio}; " + f"Unsupported compress_ratio={self.compress_ratio}; " "expected 1, 4, or 128." ) assert tile_metadata is not None, ( "swa_metadata missing tile_sched entry for " - f"compress_ratio={layer.compress_ratio}; " + f"compress_ratio={self.compress_ratio}; " "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not " "allocate one for this layer type." ) @@ -302,18 +258,16 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): is_fp8_kvcache=True, indices=swa_indices, topk_length=swa_lens, - softmax_scale=layer.scale, - attn_sink=layer.attn_sink, + softmax_scale=self.scale, + attn_sink=self.attn_sink, extra_k_cache=kv_cache if not swa_only else None, extra_indices_in_kvcache=topk_indices, extra_topk_length=topk_lens, out=output.unsqueeze(1), ) - @classmethod def _forward_prefill( - cls, - layer: "DeepseekV4MLAAttention", + self, q: torch.Tensor, positions: torch.Tensor, compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1 @@ -343,9 +297,9 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): prefill_token_base = query_start_loc_cpu[num_decodes] if not swa_only: - if layer.compress_ratio == 4: - assert layer.topk_indices_buffer is not None - topk_indices = layer.topk_indices_buffer[num_decode_tokens:] + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None + topk_indices = self.topk_indices_buffer[num_decode_tokens:] topk_indices = topk_indices[:num_prefill_tokens] else: # C128A: pre-computed during metadata build. @@ -355,16 +309,16 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): # Compressed region must fit the full compressed pool (seq_len // # compress_ratio), not just top_k. top_k bounds how many indices # the indexer selects, not the pool size it indexes into. - N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio + N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio else: # NOTE(woosuk): topk_indices will not be used for SWA-only layers. - assert layer.topk_indices_buffer is not None - topk_indices = layer.topk_indices_buffer[num_decode_tokens:] + assert self.topk_indices_buffer is not None + topk_indices = self.topk_indices_buffer[num_decode_tokens:] top_k = 0 N = 0 - M = N + layer.window_size + layer.max_num_batched_tokens - chunk_size_const = cls.PREFILL_CHUNK_SIZE + M = N + self.window_size + self.max_num_batched_tokens + chunk_size_const = self.PREFILL_CHUNK_SIZE num_chunks = (num_prefills + chunk_size_const - 1) // chunk_size_const workspace_manager = current_workspace_manager() @@ -382,10 +336,10 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): dequantize_and_gather_k_cache( kv[:chunk_size], compressed_k_cache, - seq_lens=seq_lens[chunk_start:chunk_end] // layer.compress_ratio, + seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio, gather_lens=None, block_table=block_table[chunk_start:chunk_end], - block_size=attn_metadata.block_size // layer.compress_ratio, + block_size=attn_metadata.block_size // self.compress_ratio, offset=0, ) @@ -416,8 +370,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): ], seq_lens[chunk_start:chunk_end], gather_lens[chunk_start:chunk_end], - layer.window_size, - layer.compress_ratio, + self.window_size, + self.compress_ratio, top_k, M, N, @@ -426,8 +380,8 @@ class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): q=q[query_start:query_end], kv=kv.view(-1, 1, q.shape[-1]), indices=combined_indices.unsqueeze(1), - sm_scale=layer.scale, - attn_sink=layer.attn_sink, + sm_scale=self.scale, + attn_sink=self.attn_sink, topk_length=combined_lens, out=output[query_start:query_end], ) diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 547048ab58f..00866e6941b 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear, ) @@ -55,13 +54,14 @@ from vllm.model_executor.models.utils import ( maybe_prefix, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.models.deepseek_v4.attention import ( - DeepseekV4Indexer, - DeepseekV4MLA, +from vllm.models.deepseek_v4.attention import DeepseekV4Attention +from vllm.models.deepseek_v4.nvidia.flashinfer_sparse import ( + DeepseekV4FlashInferMLAAttention, ) -from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope +from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLAAttention from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.registry import AttentionBackendEnum class DeepseekV4MLP(nn.Module): @@ -713,163 +713,18 @@ class DeepseekV4MoE(nn.Module): self.experts.finalize_weights() -class DeepseekV4Attention(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer: torch.Tensor | None = None, - aux_stream_list: list[torch.cuda.Stream] | None = None, +def _select_dsv4_attn_cls(vllm_config: VllmConfig) -> type[DeepseekV4Attention]: + """Pick the CUDA sparse-MLA attention class for the configured backend. + + An explicit ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` selects the + FlashInfer TRTLLM-gen path; otherwise the FlashMLA path is used. + """ + if ( + vllm_config.attention_config.backend + == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4 ): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - layer_id = extract_layer_index(prefix) - - self.layer_id = layer_id - self.hidden_size = config.hidden_size - self.n_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - assert self.n_heads % tp_size == 0 - - self.n_local_heads = self.n_heads // tp_size - self.q_lora_rank = config.q_lora_rank - self.o_lora_rank = config.o_lora_rank - self.head_dim = config.head_dim - self.rope_head_dim = config.qk_rope_head_dim - self.nope_head_dim = self.head_dim - self.rope_head_dim - self.n_groups = config.o_groups - self.n_local_groups = self.n_groups // tp_size - self.window_size = config.sliding_window - # NOTE(zyongye) Compress ratio can't be 0 - # we do this for because MTP layer is not included - # in the compress ratio list - if layer_id < config.num_hidden_layers: - self.compress_ratio = max(1, config.compress_ratios[layer_id]) - else: - self.compress_ratio = 1 - self.eps = config.rms_norm_eps - self.max_position_embeddings = config.max_position_embeddings - - # Padded to min 64 heads for FlashMLA, initialized to -inf - # (no sink effect). Weight loading fills the first n_local_heads slots. - padded_heads = max(self.n_local_heads, 64) - self.attn_sink = nn.Parameter( - torch.full((padded_heads,), -float("inf"), dtype=torch.float32), - requires_grad=False, - ) - - self.fused_wqa_wkv = MergedColumnParallelLinear( - self.hidden_size, - [self.q_lora_rank, self.head_dim], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.fused_wqa_wkv", - disable_tp=True, # fused ReplicatedLinear - ) - self.q_norm = RMSNorm(self.q_lora_rank, self.eps) - self.wq_b = ColumnParallelLinear( - self.q_lora_rank, - self.n_heads * self.head_dim, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wq_b", - ) - - self.kv_norm = RMSNorm(self.head_dim, self.eps) - self.wo_a = ColumnParallelLinear( - self.n_heads * self.head_dim // self.n_groups, - self.n_groups * self.o_lora_rank, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wo_a", - ) - self.wo_a.is_bmm = True - self.wo_a.bmm_batch_size = self.n_local_groups - self.wo_b = RowParallelLinear( - self.n_groups * self.o_lora_rank, - self.hidden_size, - bias=False, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.wo_b", - ) - self.softmax_scale = self.head_dim**-0.5 - self.scale_fmt = config.quantization_config["scale_fmt"] - - self.rope_parameters = config.rope_scaling - - # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it) - self.rotary_emb = build_deepseek_v4_rope( - config, - head_dim=self.head_dim, - rope_head_dim=self.rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - compress_ratio=self.compress_ratio, - ) - - self.indexer = None - if self.compress_ratio == 4: - # Only C4A uses sparse attention and hence has indexer. - # aux_stream_list[0] runs indexer.forward() in the wrapper; [2] is - # free here (outer GEMMs joined) for the inner overlap of - # wq_b+fused_indexer_q_rope_quant vs compressor. - indexer_aux_stream = ( - aux_stream_list[2] if aux_stream_list is not None else None - ) - self.indexer = DeepseekV4Indexer( - vllm_config, - config=config, - hidden_size=self.hidden_size, - q_lora_rank=self.q_lora_rank, - quant_config=quant_config, - cache_config=vllm_config.cache_config, - topk_indices_buffer=topk_indices_buffer, - compress_ratio=self.compress_ratio, - prefix=f"{prefix}.indexer", - aux_stream=indexer_aux_stream, - ) - - self.mla_attn = DeepseekV4MLA( - hidden_size=self.hidden_size, - num_heads=self.n_local_heads, - head_dim=self.head_dim, - scale=self.softmax_scale, - qk_nope_head_dim=self.nope_head_dim, - qk_rope_head_dim=self.rope_head_dim, - v_head_dim=self.head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.head_dim, - o_lora_rank=self.o_lora_rank, - vllm_config=vllm_config, - fused_wqa_wkv=self.fused_wqa_wkv, - q_norm=self.q_norm, - wq_b=self.wq_b, - kv_norm=self.kv_norm, - wo_a=self.wo_a, - wo_b=self.wo_b, - attn_sink=self.attn_sink, - rotary_emb=self.rotary_emb, - indexer=self.indexer, - indexer_rotary_emb=self.rotary_emb, - topk_indices_buffer=topk_indices_buffer, - aux_stream_list=aux_stream_list, - window_size=self.window_size, - compress_ratio=self.compress_ratio, - cache_config=vllm_config.cache_config, - quant_config=quant_config, - prefix=prefix, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - llama_4_scaling: torch.Tensor | None, - ): - return self.mla_attn(positions, hidden_states, llama_4_scaling) + return DeepseekV4FlashInferMLAAttention + return DeepseekV4FlashMLAAttention class DeepseekV4DecoderLayer(nn.Module): @@ -886,7 +741,7 @@ class DeepseekV4DecoderLayer(nn.Module): self.hidden_size = config.hidden_size self.rms_norm_eps = config.rms_norm_eps - self.attn = DeepseekV4Attention( + self.attn = _select_dsv4_attn_cls(vllm_config)( vllm_config, prefix=f"{prefix}.attn", topk_indices_buffer=topk_indices_buffer, @@ -1043,7 +898,7 @@ class DeepseekV4Model(nn.Module): self.rms_norm_eps = config.rms_norm_eps # Three aux streams: one per non-default input GEMM in - # DeepseekV4MLA.attn_gemm_parallel_execute + # DeepseekV4Attention.attn_gemm_parallel_execute # (compressor kv_score, indexer.weights_proj, indexer.compressor # kv_score). fused_wqa_wkv stays on the default stream. aux_stream_list = [torch.cuda.Stream() for _ in range(3)] @@ -1341,7 +1196,6 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", }, orig_to_new_substr={ - ".attn.compressor.": ".attn.mla_attn.compressor.", ".shared_experts.w2": ".shared_experts.down_proj", }, ) diff --git a/vllm/models/deepseek_v4/nvidia/ops/o_proj.py b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py new file mode 100644 index 00000000000..a0b4e2c678e --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +import torch.nn as nn + +from vllm.models.deepseek_v4.common.ops import fused_inv_rope_fp8_quant +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import fp8_einsum + + +def compute_fp8_einsum_recipe() -> tuple[tuple[int, int, int], bool]: + """fp8_einsum recipe + scale layout for the current GPU arch. + + SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128. + SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1. + + Returns ``(einsum_recipe, tma_aligned_scales)`` for ``deep_gemm_fp8_o_proj``. + """ + cap = current_platform.get_device_capability() + assert cap is not None, "DeepseekV4 attention requires a CUDA device" + einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) + tma_aligned_scales = cap.major >= 10 + return einsum_recipe, tma_aligned_scales + + +def deep_gemm_fp8_o_proj( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + wo_a: nn.Module, + wo_b: nn.Module, + *, + n_groups: int, + heads_per_group: int, + nope_dim: int, + rope_dim: int, + o_lora_rank: int, + einsum_recipe: tuple[int, int, int], + tma_aligned_scales: bool, +) -> torch.Tensor: + """O projection: inverse RoPE + FP8 quant + einsum + wo_b. + + Shared by the FlashMLA and FlashInfer CUDA backends. ``einsum_recipe`` / + ``tma_aligned_scales`` come from ``compute_fp8_einsum_recipe``. + """ + o_fp8, o_scale = fused_inv_rope_fp8_quant( + o, + positions, + cos_sin_cache, + n_groups=n_groups, + heads_per_group=heads_per_group, + nope_dim=nope_dim, + rope_dim=rope_dim, + tma_aligned_scales=tma_aligned_scales, + ) + z = torch.empty( + (o.shape[0], n_groups, o_lora_rank), + device=o.device, + dtype=torch.bfloat16, + ) + fp8_einsum( + "bhr,hdr->bhd", + (o_fp8, o_scale), + (wo_a.weight, wo_a.weight_scale_inv), + z, + recipe=einsum_recipe, + ) + return wo_b(z.flatten(1)) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index cf4319ac722..b1414665869 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -92,6 +92,11 @@ class CpuPlatform(Platform): return meminfo.total_memory + @classmethod + def mem_get_info(cls) -> tuple[int, int]: + meminfo = get_memory_node_info() + return meminfo.available_memory, meminfo.total_memory + @classmethod def set_device(cls, device: torch.device) -> None: """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 89471e844d8..7f6d8794c28 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -134,6 +134,7 @@ def _sync_hip_cuda_env_vars(): # Sync at import time - catches misconfigurations from process start. _sync_hip_cuda_env_vars() + # AMDSMI utils # Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, # all the related functions work on real physical device ids. @@ -312,6 +313,17 @@ def on_gfx950() -> bool: return _ON_GFX950 +# Enable HIP online tuning early, before hipBLASLt initializes. +# Turn on hipBLASLt online tuning if use AITER hipBLASLt GEMM. +if ( + envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM + and on_mi3xx() +): + os.environ["HIP_ONLINE_TUNING"] = "1" + + @cache def use_rocm_custom_paged_attention( qtype: torch.dtype, diff --git a/vllm/reasoning/cohere_command_reasoning_parser.py b/vllm/reasoning/cohere_command_reasoning_parser.py index b28a59089e7..949c9ff5d99 100644 --- a/vllm/reasoning/cohere_command_reasoning_parser.py +++ b/vllm/reasoning/cohere_command_reasoning_parser.py @@ -90,7 +90,7 @@ MODEL_TO_TAG_STYLE: dict[str, CohereTagStyle] = { tools=COMMAND_A_TOOLS_TAG, ), "Cohere2MoeForCausalLM": CohereTagStyle( - json_tags=(COMMAND_A_JSON_TAG,), + json_tags=(COMMAND_A_JSON_TAG, COMMAND_A_PLUS_JSON_TAG), tools=COMMAND_A_TOOLS_TAG, ), } diff --git a/vllm/transformers_utils/processors/minicpmv.py b/vllm/transformers_utils/processors/minicpmv.py index cc0dee8dacd..03649234eab 100644 --- a/vllm/transformers_utils/processors/minicpmv.py +++ b/vllm/transformers_utils/processors/minicpmv.py @@ -72,8 +72,8 @@ class MiniCPMVProcessor(ProcessorMixin): ) -> MiniCPMVBatchFeature: """Run the vendored MiniCPMV processor on a (text, images) pair. - Only single-sample input is currently supported; batched input is - coming soon. ``images`` is forwarded to the underlying image + Batched inputs are supported following the upstream MiniCPM-V + processor flow. ``images`` is forwarded to the underlying image processor and ``text`` is tokenized with image placeholders replaced by the appropriate slice tokens. Returns a ``MiniCPMVBatchFeature`` with at minimum ``input_ids`` and (when @@ -194,7 +194,7 @@ class MiniCPMVProcessor(ProcessorMixin): image_end_tokens.unsqueeze(-1), ] ) - return input_ids.unsqueeze(0), image_bounds + return input_ids, image_bounds def _convert_images_texts_to_inputs( self, @@ -220,23 +220,41 @@ class MiniCPMVProcessor(ProcessorMixin): image_sizes = images["image_sizes"] tgt_sizes = images["tgt_sizes"] - image_tags = regex.findall(pattern, texts) - assert len(image_tags) == len(image_sizes[0]) - text_chunks = texts.split(pattern) - final_texts = "" - for i in range(len(image_tags)): - placeholder = self.image_processor.get_slice_image_placeholder( - image_sizes[0][i] - ) - final_texts = final_texts + text_chunks[i] + placeholder - final_texts += text_chunks[-1] - input_ids, image_bounds = self._convert(final_texts, max_length) + if isinstance(texts, str): + texts = [texts] + + input_ids_list = [] + image_bounds_list = [] + + for index, text in enumerate(texts): + image_tags = regex.findall(pattern, text) + assert len(image_tags) == len(image_sizes[index]) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(image_tags)): + placeholder = self.image_processor.get_slice_image_placeholder( + image_sizes[index][i] + ) + final_text = final_text + text_chunks[i] + placeholder + final_text += text_chunks[-1] + input_ids, image_bounds = self._convert(final_text, max_length) + input_ids_list.append(input_ids) + image_bounds_list.append(image_bounds) + + padded_input_ids, padding_lengths = self.pad( + input_ids_list, + padding_side="left", + ) + for i, length in enumerate(padding_lengths): + image_bounds_list[i] = image_bounds_list[i] + length + return MiniCPMVBatchFeature( data={ - "input_ids": input_ids, + "input_ids": padded_input_ids, + "attention_mask": padded_input_ids.ne(0), "pixel_values": images_val, "image_sizes": image_sizes, - "image_bound": [image_bounds], + "image_bound": image_bounds_list, "tgt_sizes": tgt_sizes, } ) @@ -249,42 +267,36 @@ class MiniCPMVProcessor(ProcessorMixin): image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) - def pad( - self, - orig_items, - key, - max_length=None, - padding_value=0, - padding_side="left", - ): - if not orig_items: - return torch.empty(0) + # Copied from openbmb/MiniCPM-V-4_5 processing_minicpmv.py. + def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"): + if not inputs: + return torch.empty(0), [] items = [] - if isinstance(orig_items[0][key], list): - assert isinstance(orig_items[0][key][0], torch.Tensor) - for it in orig_items: - for tr in it[key]: - items.append({key: tr}) + if isinstance(inputs[0], list): + assert isinstance(inputs[0][0], torch.Tensor) + for it in inputs: + for tr in it: + items.append(tr) else: - assert isinstance(orig_items[0][key], torch.Tensor) - items = orig_items + assert isinstance(inputs[0], torch.Tensor) + items = inputs batch_size = len(items) - shape = items[0][key].shape + shape = items[0].shape dim = len(shape) - assert dim <= 3 + assert dim <= 2 if max_length is None: max_length = 0 - max_length = max(max_length, max(item[key].shape[-1] for item in items)) - min_length = min(item[key].shape[-1] for item in items) - dtype = items[0][key].dtype + max_length = max(max_length, max(item.shape[-1] for item in items)) + min_length = min(item.shape[-1] for item in items) + dtype = items[0].dtype - if dim == 1: - return torch.cat([item[key] for item in items], dim=0) - elif dim == 2: + if dim == 0: + return torch.stack([item for item in items], dim=0), [0] + elif dim == 1: if max_length == min_length: - return torch.cat([item[key] for item in items], dim=0) + return torch.stack([item for item in items], dim=0), [0] * batch_size tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value else: tensor = ( @@ -292,23 +304,18 @@ class MiniCPMVProcessor(ProcessorMixin): + padding_value ) + padding_lengths = [] for i, item in enumerate(items): - tensor_to_pad = item[key] - if tensor_to_pad.shape[0] != 1: - raise ValueError( - f"Expected leading batch size of 1 for padding, " - f"but got shape {tensor_to_pad.shape}" - ) - squeezed = tensor_to_pad.squeeze(0) - if dim == 2: + if dim == 1: if padding_side == "left": - tensor[i, -squeezed.shape[0] :] = squeezed.clone() + tensor[i, -len(item) :] = item.clone() else: - tensor[i, : squeezed.shape[0]] = squeezed.clone() - elif dim == 3: + tensor[i, : len(item)] = item.clone() + elif dim == 2: if padding_side == "left": - tensor[i, -squeezed.shape[0] :, :] = squeezed.clone() + tensor[i, -len(item) :, :] = item.clone() else: - tensor[i, : squeezed.shape[0], :] = squeezed.clone() + tensor[i, : len(item), :] = item.clone() + padding_lengths.append(tensor.shape[-1] - len(item)) - return tensor + return tensor, padding_lengths diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index 04def3e3769..cd215421a98 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -84,8 +84,11 @@ def maybe_model_redirect(model: str) -> str: """ Use model_redirect to redirect the model name to a local folder. - :param model: hf model name - :return: maybe redirect to a local folder + Args: + model: hf model name + + Returns: + maybe redirect to a local folder """ model_redirect_path = envs.VLLM_MODEL_REDIRECT_PATH diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index af58bfd31a5..4b4a4435b31 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -808,14 +808,17 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]): ) -> torch.Tensor: raise NotImplementedError - def fused_output_quant_supported(self, quant_key: "QuantKey"): + def fused_output_quant_supported(self, quant_key: "QuantKey") -> bool: """ Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization onto implementations that support it. - :param quant_key: QuantKey object that describes the quantization op - :return: is fusion supported for this type of quantization + Args: + quant_key: QuantKey object that describes the quantization op + + Returns: + is fusion supported for this type of quantization """ return False diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 9140a6fccd5..e3173949bf1 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -602,7 +602,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad fp8_use_mixed_batch = ( self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4 ) - # DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not + # DeepseekV4 has its own attention impl (DeepseekV4Attention) that does not # consume fp8_extra_metadata. Skipping the build here avoids a # forced D2H sync on seq_lens that would otherwise fire on every # prefill-bearing step, lifting GPU utilization on long-prefill diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 332350d8380..12fd3a17421 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -8,6 +8,7 @@ from importlib.util import find_spec import torch import torch.nn.functional as F +import vllm.envs as envs from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.forward_context import get_forward_context from vllm.platforms import current_platform @@ -15,6 +16,7 @@ from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import LayerNameType from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.v1.worker.workspace import current_workspace_manager if current_platform.is_rocm(): from vllm.platforms.rocm import _ON_GFX942, _ON_GFX950 @@ -408,8 +410,8 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module = None # if rocm_aiter_ops.is_enabled(): - batch_size, next_n, heads, head_dim = q_fp8.shape - num_blocks, block_size, _, _ = kv_cache_fp8.shape + batch_size, next_n = q_fp8.shape[:2] + block_size = kv_cache_fp8.shape[1] if rocm_aiter_ops.is_enabled(): aiter_paged_mqa_logits_module = paged_mqa_logits_module() @@ -420,12 +422,10 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits ) batch_size, next_n, heads, _ = q_fp8.shape - out_logits = torch.full( - [batch_size * next_n, max_model_len], - float("-inf"), - device="cuda", - dtype=torch.float32, + (out_logits,) = current_workspace_manager().get_simultaneous( + ((batch_size * next_n, max_model_len), torch.float32), ) + out_logits.fill_(float("-inf")) deepgemm_fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, @@ -444,12 +444,10 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1 ) batch_size, next_n, heads, _ = q_fp8.shape - out_qk = torch.full( - (heads, batch_size * next_n, max_model_len), - float("-inf"), - device="cuda", - dtype=torch.float32, + (out_qk,) = current_workspace_manager().get_simultaneous( + ((heads, batch_size * next_n, max_model_len), torch.float32), ) + out_qk.fill_(float("-inf")) deepgemm_fp8_paged_mqa_logits_stage1( q_fp8, kv_cache_fp8, @@ -647,6 +645,43 @@ def rocm_aiter_sparse_attn_indexer( k_cache_prefix = _resolve_layer_name(k_cache_prefix) # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): + # Profiling early-exit: reserve memory to account for runtime + # allocations. Must be in the real impl, not the fake impl — + # torch.compile calls the fake impl under FakeTensor mode where + # workspace manager operations on the locked real workspace + # would corrupt PyTorch's dispatch state. + workspace_manager = current_workspace_manager() + + # Prefill k_fp8 and k_scale buffers, used by + # rocm_aiter_sparse_attn_indexer's prefill path + workspace_manager.get_simultaneous( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) + + # Decode logits buffer, used by rocm_fp8_paged_mqa_logits. + # batch_size * next_n <= hidden_states.shape[0] == max_num_batched_tokens + if _ON_GFX942 or _ON_GFX950: + workspace_manager.get_simultaneous( + ((hidden_states.shape[0], max_model_len), torch.float32), + ) + else: + workspace_manager.get_simultaneous( + ( + (q_fp8.shape[1], hidden_states.shape[0], max_model_len), + torch.float32, + ), + ) + # Transient logits tensor peak memory, produced by + # rocm_fp8_mqa_logits (prefill) and rocm_fp8_paged_mqa_logits + # (decode). Prefill logits are bounded by + # VLLM_SPARSE_INDEXER_MAX_LOGITS_MB via chunking in + # split_indexer_prefill_chunks; decode logits are smaller. + max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + _ = torch.empty( + max_logits_elems, dtype=torch.uint8, device=hidden_states.device + ) + return rocm_aiter_sparse_attn_indexer_fake( hidden_states, k_cache_prefix, @@ -671,7 +706,6 @@ def rocm_aiter_sparse_attn_indexer( has_decode = layer_attn_metadata.num_decodes > 0 has_prefill = layer_attn_metadata.num_prefills > 0 num_decode_tokens = layer_attn_metadata.num_decode_tokens - device = hidden_states.device if k is None else k.device # during speculative decoding, k may be padded to the CUDA graph batch # size while slot_mapping only covers actual tokens. @@ -703,17 +737,15 @@ def rocm_aiter_sparse_attn_indexer( if has_prefill: prefill_metadata = layer_attn_metadata.prefill assert prefill_metadata is not None + + workspace_manager = current_workspace_manager() + k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty( - [chunk.total_seq_lens, head_dim], - device=device, - dtype=fp8_dtype, - ) - k_scale = torch.empty( - [chunk.total_seq_lens, 4], - device=device, - dtype=torch.uint8, - ) + k_fp8 = k_fp8_full[: chunk.total_seq_lens] + k_scale = k_scale_full[: chunk.total_seq_lens] if _ON_GFX942: ops.cp_gather_indexer_k_quant_cache( kv_cache, @@ -731,7 +763,6 @@ def rocm_aiter_sparse_attn_indexer( chunk.cu_seq_lens, token_to_seq=chunk.token_to_seq, ) - logits = rocm_fp8_mqa_logits( q_fp8[chunk.token_start : chunk.token_end], (k_fp8, k_scale.view(torch.float32)), diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 897d063e830..8384def4664 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2056,13 +2056,17 @@ class Scheduler(SchedulerInterface): return spec_decoding_stats def shutdown(self) -> None: + logger.debug_once("[shutdown] Scheduler: start") if self.kv_event_publisher: self.kv_event_publisher.shutdown() if self.connector is not None: self.connector.shutdown() + if self.ec_connector is not None: self.ec_connector.shutdown() + logger.debug_once("[shutdown] Scheduler: complete") + ######################################################################## # KV Connector Related Methods ######################################################################## diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b12aa9d0505..08c814ab34e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -177,13 +177,13 @@ class EngineCore: if xfer_handshake_metadata: # xfer_handshake_metadata is list of dicts from workers - # Each dict already has structure {tp_rank: metadata} + # Each dict already has structure {(pp_rank, tp_rank): metadata} # Merge all worker dicts into a single dict - content: dict[int, Any] = {} + content: dict[tuple[int, int], Any] = {} for worker_dict in xfer_handshake_metadata: if worker_dict is not None: content.update(worker_dict) - kv_connector.set_xfer_handshake_metadata(content) + kv_connector.set_xfer_handshake_metadata_pp_aware(content) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously @@ -608,6 +608,7 @@ class EngineCore: self.abort_requests(request_ids) def shutdown(self): + logger.debug_once("[shutdown] EngineCore: tearing down local resources") self.structured_output_manager.clear_backend() if self.model_executor: self.model_executor.shutdown() @@ -622,6 +623,7 @@ class EngineCore: # Tear down distributed state initialized in this EngineCore process # before it exits and release cached memory. cleanup_dist_env_and_memory() + logger.debug_once("[shutdown] EngineCore: local resource teardown complete") def profile(self, is_start: bool = True, profile_prefix: str | None = None): self.model_executor.profile(is_start, profile_prefix) @@ -1172,6 +1174,11 @@ class EngineCoreProc(EngineCore): signal_callback = SignalCallback(wakeup_engine) def signal_handler(signum, frame): + signal_name = signal.Signals(signum).name + logger.info( + "[shutdown] EngineCore: trigger received signal=%s", + signal_name, + ) engine_core.shutdown_state = EngineShutdownState.REQUESTED signal_callback.trigger() @@ -1181,7 +1188,7 @@ class EngineCoreProc(EngineCore): engine_core.run_busy_loop() except SystemExit: - logger.debug("EngineCore exiting.") + logger.info_once("[shutdown] EngineCore: exiting busy loop") raise except Exception as e: if engine_core is None: @@ -1285,13 +1292,21 @@ class EngineCoreProc(EngineCore): if self.shutdown_state == EngineShutdownState.REQUESTED: shutdown_timeout = self.vllm_config.shutdown_timeout + mode = "abort" if shutdown_timeout == 0 else "drain" - logger.info("Shutdown initiated (timeout=%d)", shutdown_timeout) + logger.info( + "[shutdown] EngineCore: start mode=%s timeout=%ds", + mode, + shutdown_timeout, + ) if shutdown_timeout == 0: num_requests = self.scheduler.get_num_unfinished_requests() if num_requests > 0: - logger.info("Aborting %d requests", num_requests) + logger.info( + "[shutdown] EngineCore: aborting in-flight requests count=%d", + num_requests, + ) aborted_reqs = self.scheduler.finish_requests( None, RequestStatus.FINISHED_ABORTED ) @@ -1300,7 +1315,8 @@ class EngineCoreProc(EngineCore): num_requests = self.scheduler.get_num_unfinished_requests() if num_requests > 0: logger.info( - "Draining %d in-flight requests (timeout=%ds)", + "[shutdown] EngineCore: draining in-flight requests " + "count=%d timeout=%ds", num_requests, shutdown_timeout, ) @@ -1309,7 +1325,10 @@ class EngineCoreProc(EngineCore): # Exit when no work remaining if not self.has_work(): - logger.info("Shutdown complete") + logger.info( + "[shutdown] EngineCore: request processing complete; " + "starting resource teardown" + ) return False return True @@ -1353,7 +1372,10 @@ class EngineCoreProc(EngineCore): if self.shutdown_state == EngineShutdownState.RUNNING: return False - logger.info("Rejecting request %s (server shutting down)", request.request_id) + logger.debug( + "[shutdown] EngineCore: rejecting new request request_id=%s", + request.request_id, + ) self._send_abort_outputs_to_client([request.request_id], request.client_index) return True @@ -1363,7 +1385,10 @@ class EngineCoreProc(EngineCore): if self.shutdown_state == EngineShutdownState.RUNNING: return False - logger.warning("Rejecting utility call %s (server shutting down)", method_name) + logger.warning( + "[shutdown] EngineCore: rejecting utility call method=%s", + method_name, + ) output = UtilityOutput(call_id, failure_message="Server shutting down") self.output_queue.put_nowait( (client_idx, EngineCoreOutputs(utility_output=output)) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 14257b020ee..32f2d091eb3 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -391,6 +391,7 @@ class BackgroundResources: def __call__(self): """Clean up background resources.""" + logger.debug_once("[shutdown] MPClient: background resource cleanup start") self.engine_dead = True if self.engine_manager is not None: self.engine_manager.shutdown() @@ -445,6 +446,8 @@ class BackgroundResources: # Send shutdown signal. shutdown_sender.send(b"") + logger.debug_once("[shutdown] MPClient: background resource cleanup complete") + def validate_alive(self, frames: Sequence[zmq.Frame]): if len(frames) == 1 and (frames[0].buffer == EngineCoreProc.ENGINE_CORE_DEAD): self.engine_dead = True @@ -645,9 +648,15 @@ class MPClient(EngineCoreClient): def shutdown(self, timeout: float | None = None) -> None: """Shutdown engine manager under timeout and clean up resources.""" if self._finalizer.detach() is not None: + timeout_str = "default" if timeout is None else f"{timeout}s" + logger.info("[shutdown] MPClient: start timeout=%s", timeout_str) if self.resources.engine_manager is not None: + logger.info_once("[shutdown] MPClient: stopping engine manager") self.resources.engine_manager.shutdown(timeout=timeout) + logger.info_once("[shutdown] MPClient: engine manager stopped") + logger.info_once("[shutdown] MPClient: cleaning up background resources") self.resources() + logger.info_once("[shutdown] MPClient: complete") def _format_exception(self, e: Exception) -> Exception: """If errored, use EngineDeadError so root cause is clear.""" @@ -687,6 +696,9 @@ class MPClient(EngineCoreClient): if not _self or not _self._finalizer.alive or _self.resources.engine_dead: return _self.resources.engine_dead = True + logger.warning_once( + "[shutdown] MPClient: engine core exited unexpectedly; starting cleanup" + ) _self.shutdown() # Note: For MPClient, we don't have a failure callback mechanism # like MultiprocExecutor, but we set engine_dead flag which will diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 7beef598e27..4063844d469 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -203,7 +203,7 @@ class Executor(ABC): def get_kv_connector_handshake_metadata( self, - ) -> list[dict[int, KVConnectorHandshakeMetadata]]: + ) -> list[dict[tuple[int, int], KVConnectorHandshakeMetadata]]: return self.collective_rpc("get_kv_connector_handshake_metadata") @overload diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index c5766c923c8..66564bebdb6 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -422,27 +422,45 @@ class MultiprocExecutor(Executor): return False active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()] + initial_count = len(active_procs()) + # Give processes time to clean themselves up properly first - logger.debug("Worker Termination: allow workers to gracefully shutdown") + logger.info( + "[shutdown] Executor: waiting for worker exit count=%d", + initial_count, + ) if wait_for_termination(active_procs(), 4): + logger.info_once("[shutdown] Executor: all workers exited gracefully") return # Send SIGTERM if still running - logger.debug("Worker Termination: workers still running sending SIGTERM") - for p in active_procs(): + remaining = active_procs() + logger.warning( + "[shutdown] Executor: workers still running after grace period; " + "sending SIGTERM count=%d", + len(remaining), + ) + for p in remaining: p.terminate() if not wait_for_termination(active_procs(), 4): # Send SIGKILL if still running - logger.debug( - "Worker Termination: resorting to SIGKILL to take down workers" + remaining = active_procs() + logger.warning( + "[shutdown] Executor: workers still running after SIGTERM; " + "sending SIGKILL count=%d", + len(remaining), ) - for p in active_procs(): + for p in remaining: p.kill() def shutdown(self): """Properly shut down the executor and its workers""" if not getattr(self, "shutting_down", False): - logger.debug("Triggering shutdown of workers") + worker_count = len(getattr(self, "workers", None) or []) + logger.debug( + "[shutdown] Executor: start worker_count=%d", + worker_count, + ) self.shutting_down = True # Make sure all the worker processes are terminated first. @@ -468,6 +486,8 @@ class MultiprocExecutor(Executor): mq.shutdown() self.response_mqs = [] + logger.debug_once("[shutdown] Executor: complete") + def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) return @@ -867,7 +887,9 @@ class WorkerProc: if ready_writer is not None: logger.exception("WorkerProc failed to start.") elif shutdown_requested.is_set(): - logger.info("WorkerProc shutting down.") + logger.debug_once( + "[shutdown] WorkerProc: exiting after shutdown request" + ) else: logger.exception("WorkerProc failed.") @@ -879,7 +901,12 @@ class WorkerProc: except SystemExit as e: # SystemExit is raised on SIGTERM or SIGKILL, which usually indicates that # the graceful shutdown process did not succeed - logger.warning("WorkerProc was terminated") + if shutdown_requested.is_set(): + logger.debug_once( + "[shutdown] WorkerProc: terminated by shutdown signal" + ) + else: + logger.warning("WorkerProc was terminated") # SystemExit must never be ignored raise e diff --git a/vllm/v1/kv_offload/tiering/factory.py b/vllm/v1/kv_offload/tiering/factory.py index cbde45dfcf8..be703a03b3d 100644 --- a/vllm/v1/kv_offload/tiering/factory.py +++ b/vllm/v1/kv_offload/tiering/factory.py @@ -63,3 +63,9 @@ SecondaryTierFactory.register_tier( "vllm.v1.kv_offload.tiering.fs.manager", "FileSystemTierManager", ) + +SecondaryTierFactory.register_tier( + "obj", + "vllm.v1.kv_offload.tiering.obj.manager", + "ObjectStoreSecondaryTierManager", +) diff --git a/vllm/v1/kv_offload/tiering/obj/__init__.py b/vllm/v1/kv_offload/tiering/obj/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm/v1/kv_offload/tiering/obj/config.py b/vllm/v1/kv_offload/tiering/obj/config.py new file mode 100644 index 00000000000..5507c6a198e --- /dev/null +++ b/vllm/v1/kv_offload/tiering/obj/config.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Connection configuration for the object store secondary tier.""" + +from dataclasses import dataclass + + +@dataclass +class ObjStoreConfig: + """Connection parameters for an object store backend.""" + + bucket: str + endpoint_override: str + access_key: str + secret_key: str + scheme: str = "http" + ca_bundle: str = "" + + def to_nixl_params(self) -> dict[str, str]: + """Build the NIXL backend params dict.""" + params: dict[str, str] = { + "bucket": self.bucket, + "endpoint_override": self.endpoint_override, + "scheme": self.scheme, + "access_key": self.access_key, + "secret_key": self.secret_key, + } + if self.ca_bundle: + params["ca_bundle"] = self.ca_bundle + return params diff --git a/vllm/v1/kv_offload/tiering/obj/manager.py b/vllm/v1/kv_offload/tiering/obj/manager.py new file mode 100644 index 00000000000..2d7280ae379 --- /dev/null +++ b/vllm/v1/kv_offload/tiering/obj/manager.py @@ -0,0 +1,256 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Object store secondary tier implementation.""" + +import ctypes +from collections.abc import Iterable +from typing import TYPE_CHECKING, NamedTuple + +from vllm.distributed.nixl_utils import NixlWrapper as nixl_agent +from vllm.distributed.nixl_utils import nixl_agent_config +from vllm.logger import init_logger +from vllm.v1.kv_offload.base import OffloadKey, ReqContext +from vllm.v1.kv_offload.file_mapper import FileMapper +from vllm.v1.kv_offload.tiering.base import ( + JobMetadata, + JobResult, + RequestOffloadingContext, + SecondaryTierManager, +) +from vllm.v1.kv_offload.tiering.obj.config import ObjStoreConfig + +if TYPE_CHECKING: + from nixl._api import nixl_prepped_dlist_handle, nixl_xfer_handle + + from vllm.v1.kv_offload.base import OffloadingSpec + +logger = init_logger(__name__) + +NIXL_WRITE = "WRITE" +NIXL_READ = "READ" +NIXL_PROC = "PROC" +NIXL_DONE = "DONE" + +# Device ID for CPU DRAM descriptors. DRAM is not a multi-device resource so +# the device ID is always 0. +NIXL_DEV_ID: int = 0 + +# Fields for NIXL OBJ descriptors: (addr, len, dev_id, obj_key). +# For existence probes addr and len are placeholders — no data is read. +# dev_id=0 is reserved for probes; transfers start from 1. +_PROBE_ADDR: int = 0 +_PROBE_LEN: int = 1 +_PROBE_DEV_ID: int = 0 + + +class TransferEntry(NamedTuple): + xfer_handle: "nixl_xfer_handle" + files_desc: object + obj_handle: "nixl_prepped_dlist_handle" + + +class ObjectStoreSecondaryTierManager(SecondaryTierManager): + """Secondary tier that offloads KV cache blocks to an S3-compatible store. + + Handles CPU DRAM <-> S3 transfers only. GPU <-> CPU is managed by the + primary tier. Object keys are formed as ``{prefix}/{hash_shard}/{hash}.bin``. + """ + + def __init__( + self, + offloading_spec: "OffloadingSpec", + primary_kv_view: memoryview, + tier_type: str, + store_config: dict, + prefix: str = "", + io_threads: int = 4, + ): + super().__init__(offloading_spec, primary_kv_view, tier_type) + agent_config = nixl_agent_config(backends=[]) + self._agent = nixl_agent("ObjAgent", agent_config) + obj_config = ObjStoreConfig(**store_config) + params = {**obj_config.to_nixl_params(), "num_threads": str(io_threads)} + self._agent.create_backend("OBJ", params) + self._transfers: dict[int, TransferEntry] = {} + self._failed_jobs: list[JobResult] = [] + self._primary_reg = None + self._block_size_bytes: int = 0 + root_dir = f"{prefix}/" if prefix else "" + self._file_mapper = FileMapper.from_offloading_spec(root_dir, offloading_spec) + self._next_obj_dev_id: int = 1 # dev_id=0 is reserved for _exists() probes + + self._probe_connectivity() + + base_addr = ctypes.addressof(ctypes.c_char.from_buffer(primary_kv_view)) + assert primary_kv_view.strides is not None + stride = primary_kv_view.strides[0] + self._primary_reg = self._agent.register_memory( + [(base_addr, primary_kv_view.nbytes, NIXL_DEV_ID, "")], "DRAM" + ) + self._block_size_bytes = stride + all_blocks = [ + (base_addr + i * stride, stride, NIXL_DEV_ID) + for i in range(len(primary_kv_view)) + ] + # NIXL_INIT_AGENT marks this as the local side; make_prepped_xfer requires + # local_xfer_side tagged with NIXL_INIT_AGENT and remote_xfer_side tagged + # with the peer agent name ("ObjAgent"). + self._dram_prepped_handle: nixl_prepped_dlist_handle = ( + self._agent.prep_xfer_dlist("NIXL_INIT_AGENT", all_blocks, "DRAM") + ) + + def _probe_connectivity(self) -> None: + """Verify object store connectivity at startup via a NIXL lookup probe. + + Performs a single exists() check against a synthetic key that will + never exist. A True/False result confirms the bucket is reachable; + an exception indicates misconfigured obj store params and raises RuntimeError. + """ + probe_key = "__nixl_probe__/connectivity_test" + try: + self._exists(probe_key) + logger.info("Object store tier connectivity probe succeeded") + except Exception as e: + raise RuntimeError( + f"Object store tier connectivity probe failed — check bucket, " + f"endpoint_override, access_key, secret_key, and scheme. " + f"Error: {e}" + ) from e + + def _exists(self, obj_key: str) -> bool: + results = self._agent.query_memory( + [(_PROBE_ADDR, _PROBE_LEN, _PROBE_DEV_ID, obj_key)], "OBJ", "OBJ" + ) + return results[0] is not None + + def _submit_transfer( + self, + job_id: int, + block_ids: Iterable[int], + obj_keys: Iterable[str], + op: str, + ) -> None: + """Submit an async transfer. op is 'WRITE' (store) or 'READ' (load).""" + block_ids_list = [int(bid) for bid in block_ids] + # The OBJ backend maps devId -> obj_key. All descriptors must have + # unique devIds or later registrations overwrite earlier ones. + nixl_files = [ + (0, self._block_size_bytes, dev_id, key) + for dev_id, key in enumerate(obj_keys, self._next_obj_dev_id) + ] + self._next_obj_dev_id += len(nixl_files) + + files_desc = self._agent.register_memory(nixl_files, "OBJ") + if files_desc is None: + logger.warning("register_memory (OBJ) failed for job %d", job_id) + self._failed_jobs.append(JobResult(job_id=job_id, success=False)) + return + + obj_handle = self._agent.prep_xfer_dlist("ObjAgent", files_desc.trim()) + if not obj_handle: + logger.warning("prep_xfer_dlist (OBJ) failed for job %d", job_id) + self._agent.deregister_memory(files_desc) + self._failed_jobs.append(JobResult(job_id=job_id, success=False)) + return + + xfer_handle = self._agent.make_prepped_xfer( + op, + self._dram_prepped_handle, + block_ids_list, + obj_handle, + list(range(len(nixl_files))), + ) + if not xfer_handle: + logger.warning("make_prepped_xfer failed for job %d", job_id) + self._agent.release_dlist_handle(obj_handle) + self._agent.deregister_memory(files_desc) + self._failed_jobs.append(JobResult(job_id=job_id, success=False)) + return + + state = self._agent.transfer(xfer_handle) + if state == "ERR": + logger.warning("agent.transfer failed for job %d", job_id) + self._agent.release_dlist_handle(obj_handle) + self._agent.deregister_memory(files_desc) + self._agent.release_xfer_handle(xfer_handle) + self._failed_jobs.append(JobResult(job_id=job_id, success=False)) + return + + self._transfers[job_id] = TransferEntry(xfer_handle, files_desc, obj_handle) + + def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None: + try: + return self._exists(self._file_mapper.get_file_name(key)) + except Exception as e: + logger.warning("lookup failed for key %s: %s", key, e) + return False + + def submit_store(self, job_metadata: JobMetadata) -> None: + obj_keys = (self._file_mapper.get_file_name(k) for k in job_metadata.keys) + self._submit_transfer( + job_metadata.job_id, job_metadata.block_ids, obj_keys, NIXL_WRITE + ) + + def submit_load(self, job_metadata: JobMetadata) -> None: + obj_keys = (self._file_mapper.get_file_name(k) for k in job_metadata.keys) + self._submit_transfer( + job_metadata.job_id, job_metadata.block_ids, obj_keys, NIXL_READ + ) + + def on_new_request(self, req_context: ReqContext) -> RequestOffloadingContext: + return RequestOffloadingContext() + + def get_finished_jobs(self) -> Iterable[JobResult]: + """Poll in-flight transfers; return completed (job_id, success) pairs.""" + results: list[JobResult] = self._failed_jobs + self._failed_jobs = [] + for job_id, entry in list(self._transfers.items()): + try: + state = self._agent.check_xfer_state(entry.xfer_handle) + except Exception as exc: + success = False + logger.warning("check_xfer_state raised for job %d: %s", job_id, exc) + else: + if state == NIXL_PROC: + continue + elif state == NIXL_DONE: + success = True + else: + success = False + logger.warning("transfer failed job=%d state=%s", job_id, state) + del self._transfers[job_id] + self._agent.release_xfer_handle(entry.xfer_handle) + self._agent.release_dlist_handle(entry.obj_handle) + self._agent.deregister_memory(entry.files_desc) + results.append(JobResult(job_id=job_id, success=success)) + return results + + def shutdown(self) -> None: + for job_id, entry in self._transfers.items(): + try: + self._agent.release_xfer_handle(entry.xfer_handle) + except Exception as exc: + logger.warning("release_xfer_handle failed for job %d: %s", job_id, exc) + try: + self._agent.release_dlist_handle(entry.obj_handle) + except Exception as exc: + logger.warning( + "release_dlist_handle failed for job %d: %s", job_id, exc + ) + try: + self._agent.deregister_memory(entry.files_desc) + except Exception as exc: + logger.warning("deregister_memory failed for job %d: %s", job_id, exc) + self._transfers.clear() + if self._dram_prepped_handle is not None: + try: + self._agent.release_dlist_handle(self._dram_prepped_handle) + except Exception as exc: + logger.warning("failed to release DRAM prepped handle: %s", exc) + self._dram_prepped_handle = None + if self._primary_reg is not None: + try: + self._agent.deregister_memory(self._primary_reg) + except Exception as exc: + logger.warning("failed to deregister primary buffer: %s", exc) + self._primary_reg = None diff --git a/vllm/v1/kv_offload/tiering/spec.py b/vllm/v1/kv_offload/tiering/spec.py index a4ea46e08eb..f223d81aa5e 100644 --- a/vllm/v1/kv_offload/tiering/spec.py +++ b/vllm/v1/kv_offload/tiering/spec.py @@ -131,8 +131,8 @@ class TieringOffloadingSpec(CPUOffloadingSpec): ) except Exception as e: logger.error( - "Failed to create secondary tier from config %s: %s", - tier_config, + "Failed to create secondary tier from config index %i: %s", + i, e, ) raise diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index baa0e77119b..6f324ce9850 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -75,9 +75,14 @@ class TopKTopPSampler(nn.Module): Implementations may update the logits tensor in-place. """ - def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: + def __init__( + self, + logprobs_mode: LogprobsMode = "raw_logprobs", + use_fp64_gumbel: bool = False, + ) -> None: super().__init__() self.logprobs_mode = logprobs_mode + self.use_fp64_gumbel = use_fp64_gumbel if current_platform.is_cuda(): # FlashInfer doesn't expose post-top-k/top-p logits/logprobs, # so it can't be used when the configured mode requires them. @@ -142,7 +147,10 @@ class TopKTopPSampler(nn.Module): elif self.logprobs_mode == "processed_logprobs": logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators), logits_to_return + return ( + random_sample(probs, generators, self.use_fp64_gumbel), + logits_to_return, + ) def forward_cuda( self, @@ -163,6 +171,8 @@ class TopKTopPSampler(nn.Module): "PyTorch-native implementation." ) return self.forward_native(logits, generators, k, p) + if self.use_fp64_gumbel: + return self.forward_native(logits, generators, k, p) assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), ( "FlashInfer does not support returning logits/logprobs" ) @@ -190,16 +200,16 @@ class TopKTopPSampler(nn.Module): elif self.logprobs_mode == "processed_logprobs": logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) - if len(generators) != logits.shape[0]: + if len(generators) != logits.shape[0] and not self.use_fp64_gumbel: return compiled_random_sample(logits), logits_to_return probs = logits.softmax(dim=-1, dtype=torch.float32) - q = torch.empty_like(probs) + q = empty_exponential_noise_like(probs, self.use_fp64_gumbel) q.exponential_() for i, generator in generators.items(): q[i].exponential_(generator=generator) - return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return + return sample_with_exponential_noise(probs, q), logits_to_return def forward_hip( self, @@ -216,6 +226,8 @@ class TopKTopPSampler(nn.Module): "falling back to PyTorch-native." ) return self.forward_native(logits, generators, k, p) + if self.use_fp64_gumbel: + return self.forward_native(logits, generators, k, p) assert self.logprobs_mode not in ( "processed_logits", "processed_logprobs", @@ -404,16 +416,33 @@ def apply_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor: return logits.masked_fill_(logits < top_k_mask, -float("inf")) +def empty_exponential_noise_like( + probs: torch.Tensor, use_fp64_gumbel: bool +) -> torch.Tensor: + dtype = torch.float64 if use_fp64_gumbel else probs.dtype + return torch.empty(probs.shape, dtype=dtype, device=probs.device) + + +def sample_with_exponential_noise(probs: torch.Tensor, q: torch.Tensor) -> torch.Tensor: + if q.dtype == probs.dtype: + scores = probs.div_(q) + else: + scores = q.reciprocal_() + scores.mul_(probs) + return scores.argmax(dim=-1).view(-1) + + def random_sample( probs: torch.Tensor, generators: dict[int, torch.Generator], + use_fp64_gumbel: bool = False, ) -> torch.Tensor: """Randomly sample from the probabilities. We use this function instead of torch.multinomial because torch.multinomial causes CPU-GPU synchronization. """ - q = torch.empty_like(probs) + q = empty_exponential_noise_like(probs, use_fp64_gumbel) # NOTE(woosuk): To batch-process the requests without their own seeds, # which is the common case, we first assume that every request does # not have its own seed. Then, we overwrite the values for the requests @@ -425,7 +454,7 @@ def random_sample( # one by one. Optimize this. for i, generator in generators.items(): q[i].exponential_(generator=generator) - return probs.div_(q).argmax(dim=-1).view(-1) + return sample_with_exponential_noise(probs, q) def flashinfer_sample( diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 678654cb78a..153677e35fa 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -65,6 +65,7 @@ class RejectionSampler(nn.Module): ): super().__init__() self.sampler = sampler + self.use_fp64_gumbel = getattr(sampler, "use_fp64_gumbel", False) logprobs_mode = self.sampler.logprobs_mode self.is_processed_logprobs_mode = logprobs_mode.startswith("processed") self.is_logits_logprobs_mode = logprobs_mode.endswith("logits") @@ -176,6 +177,7 @@ class RejectionSampler(nn.Module): sampling_metadata, synthetic_mode=self.synthetic_mode, synthetic_conditional_rates=self.synthetic_conditional_rates, + use_fp64_gumbel=self.use_fp64_gumbel, ) logprobs_tensors = None @@ -406,6 +408,7 @@ def rejection_sample( sampling_metadata: SamplingMetadata, synthetic_mode: bool = False, synthetic_conditional_rates: torch.Tensor | None = None, + use_fp64_gumbel: bool = False, ) -> torch.Tensor: assert draft_token_ids.ndim == 1 assert draft_probs is None or draft_probs.ndim == 2 @@ -480,6 +483,7 @@ def rejection_sample( target_probs, sampling_metadata, device, + use_fp64_gumbel, ) # Rejection sampling for random sampling requests. @@ -669,13 +673,15 @@ def sample_recovered_tokens( target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, device: torch.device, + use_fp64_gumbel: bool = False, ) -> torch.Tensor: # NOTE(woosuk): Create only one distribution for each request. batch_size = len(num_draft_tokens) vocab_size = target_probs.shape[-1] + q_dtype = torch.float64 if use_fp64_gumbel else torch.float32 q = torch.empty( (batch_size, vocab_size), - dtype=torch.float32, + dtype=q_dtype, device=device, ) q.exponential_() @@ -699,6 +705,7 @@ def sample_recovered_tokens( vocab_size, BLOCK_SIZE, NO_DRAFT_PROBS=draft_probs is None, + USE_FP64_GUMBEL=use_fp64_gumbel, ) return recovered_token_ids @@ -861,6 +868,7 @@ def sample_recovered_tokens_kernel( vocab_size, BLOCK_SIZE: tl.constexpr, NO_DRAFT_PROBS: tl.constexpr, + USE_FP64_GUMBEL: tl.constexpr, ): req_idx = tl.program_id(0) start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) @@ -877,7 +885,10 @@ def sample_recovered_tokens_kernel( if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + token_idx) - max_val = float("-inf") + if USE_FP64_GUMBEL: + max_val = tl.full((), float("-inf"), tl.float64) + else: + max_val = tl.full((), float("-inf"), tl.float32) recovered_id = 0 for v in range(0, vocab_size, BLOCK_SIZE): vocab_offset = v + tl.arange(0, BLOCK_SIZE) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 9ac3821a326..eadc009c254 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -58,11 +58,16 @@ class Sampler(nn.Module): 9. Return the final `SamplerOutput`. """ - def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"): + def __init__( + self, + logprobs_mode: LogprobsMode = "raw_logprobs", + use_fp64_gumbel: bool = False, + ): super().__init__() - self.topk_topp_sampler = TopKTopPSampler(logprobs_mode) + self.topk_topp_sampler = TopKTopPSampler(logprobs_mode, use_fp64_gumbel) self.pin_memory = is_pin_memory_available() self.logprobs_mode = logprobs_mode + self.use_fp64_gumbel = use_fp64_gumbel def forward( self, diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py index aa1bf270c1c..38db5483937 100644 --- a/vllm/v1/spec_decode/llm_base_proposer.py +++ b/vllm/v1/spec_decode/llm_base_proposer.py @@ -32,6 +32,10 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.topk_topp_sampler import ( + empty_exponential_noise_like, + sample_with_exponential_noise, +) from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.utils import ( @@ -113,6 +117,7 @@ class SpecDecodeBaseProposer: self.use_local_argmax_reduction: bool = ( self.speculative_config.use_local_argmax_reduction ) + self.use_fp64_gumbel = vllm_config.model_config.use_fp64_gumbel self.max_batch_size = vllm_config.scheduler_config.max_num_seqs self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens @@ -409,7 +414,9 @@ class SpecDecodeBaseProposer: return logits.argmax(dim=-1), None if sampling_metadata.all_greedy: return logits.argmax(dim=-1), None - return compute_probs_and_sample_next_token(logits, sampling_metadata) + return compute_probs_and_sample_next_token( + logits, sampling_metadata, self.use_fp64_gumbel + ) def _sample_draft_tokens( self, @@ -1656,6 +1663,7 @@ class SpecDecodeBaseProposer: def compute_probs_and_sample_next_token( logits: torch.Tensor, sampling_metadata: SamplingMetadata, + use_fp64_gumbel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: if sampling_metadata.all_greedy: # For greedy requests, draft_probs is not used in rejection sampling. @@ -1682,11 +1690,11 @@ def compute_probs_and_sample_next_token( # of the generated tokens after rejection sampling. # TODO(woosuk): Consider seeds. - q = torch.empty_like(probs) + q = empty_exponential_noise_like(probs, use_fp64_gumbel) q.exponential_() # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs # will be used later for rejection sampling. - next_token_ids = probs.div(q).argmax(dim=-1).view(-1) + next_token_ids = sample_with_exponential_noise(probs.clone(), q) if not sampling_metadata.all_random: greedy_token_ids = probs.argmax(dim=-1) next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index f11c92a805d..ffdf43f54c3 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -444,6 +444,12 @@ def _shutdown_subprocesses( timeout = 0.0 timeout = max(timeout, 5.0) + logger.debug( + "[shutdown] Subprocess manager: start process_count=%d timeout=%ss", + len(procs), + timeout, + ) + for proc in procs: if proc.is_alive(): proc.terminate() @@ -456,9 +462,18 @@ def _shutdown_subprocesses( if proc.is_alive(): proc.join(remaining) - for proc in procs: - if proc.is_alive() and (pid := proc.pid) is not None: - kill_process_tree(pid) + remaining_pids = [ + proc.pid for proc in procs if proc.is_alive() and proc.pid is not None + ] + if remaining_pids: + logger.warning( + "[shutdown] Subprocess manager: force killing remaining processes count=%d", + len(remaining_pids), + ) + for pid in remaining_pids: + kill_process_tree(pid) + + logger.debug_once("[shutdown] Subprocess manager: complete") def run_api_server_worker_proc( @@ -565,6 +580,12 @@ def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None: # have a user-configured shutdown timeout. timeout = 5.0 + logger.debug( + "[shutdown] Process manager: start process_count=%d timeout=%ss", + len(procs), + timeout, + ) + # Shutdown the process. for proc in procs: if proc.is_alive(): @@ -579,9 +600,18 @@ def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None: if proc.is_alive(): proc.join(remaining) - for proc in procs: - if proc.is_alive() and (pid := proc.pid) is not None: - kill_process_tree(pid) + remaining_pids = [ + proc.pid for proc in procs if proc.is_alive() and proc.pid is not None + ] + if remaining_pids: + logger.warning( + "[shutdown] Process manager: force killing remaining processes count=%d", + len(remaining_pids), + ) + for pid in remaining_pids: + kill_process_tree(pid) + + logger.debug_once("[shutdown] Process manager: complete") def copy_slice( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 261995f4b01..801a8574ac7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -504,7 +504,10 @@ class GPUModelRunner( self.use_async_scheduling = self.scheduler_config.async_scheduling # Sampler - self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) + self.sampler = Sampler( + logprobs_mode=self.model_config.logprobs_mode, + use_fp64_gumbel=self.model_config.use_fp64_gumbel, + ) self.eplb_state: EplbState | None = None self._moe_model: MixtureOfExperts | None = None @@ -1875,9 +1878,8 @@ class GPUModelRunner( SpecDecodeMetadata | None, ]: """ - :return: tuple[ - logits_indices, spec_decode_metadata, - ] + Returns: + tuple[logits_indices, spec_decode_metadata] """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -2202,7 +2204,8 @@ class GPUModelRunner( slot_mappings: dict[int, torch.Tensor] | None = None, ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: """ - :return: tuple[attn_metadata, spec_decode_common_attn_metadata] + Returns: + tuple[attn_metadata, spec_decode_common_attn_metadata] """ # Attention metadata is not needed for attention free models if len(self.kv_cache_config.kv_cache_groups) == 0: @@ -2500,9 +2503,11 @@ class GPUModelRunner( num_common_prefix_blocks: list[int], ) -> list[list[int]] | None: """ - :return: Optional[cascade_attn_prefix_lens] - cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``, - None if we should not use cascade attention + Returns: + Optional[cascade_attn_prefix_lens] + cascade_attn_prefix_lens is 2D: + ``[kv_cache_group_id][attn_group_idx]``, + None if we should not use cascade attention """ use_cascade_attn = False @@ -5321,11 +5326,12 @@ class GPUModelRunner( """ Reload weights from a weights iterator or from disk - :param weights_iterator: weights to load into model - :param weights_path: path to load weights from if weights_iterator is not - provided. Use path of original model if neither is provided. - :param is_checkpoint_format: set to False if weights have already been processed - into kernel format (repacking, renaming, etc.) + Args: + weights_iterator: weights to load into model + weights_path: path to load weights from if weights_iterator is not + provided. Use path of original model if neither is provided. + is_checkpoint_format: set to False if weights have already been + processed into kernel format (repacking, renaming, etc.) """ # TODO(@kylesayrs): generalize to all runners and loaders # argument validation @@ -5787,6 +5793,9 @@ class GPUModelRunner( num_scheduled_tokens, self.query_pos.np ) self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc.np[num_reqs + 1 : num_reqs_padded + 1].fill( + cum_num_tokens[-1] + ) self.query_start_loc.copy_to_gpu() # Sync block table CPU->GPU so cleared rows from diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 1b30a981e21..bf6b287d8de 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -34,6 +34,9 @@ from vllm.distributed.kv_transfer import ( get_kv_transfer_group, has_kv_transfer_group, ) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorHandshakeMetadata, +) from vllm.distributed.parallel_state import ( Handle, get_pp_group, @@ -513,8 +516,13 @@ class Worker(WorkerBase): return int(self.available_kv_cache_memory_bytes) - def get_kv_connector_handshake_metadata(self) -> dict | None: - """Get KV connector metadata from this worker if available.""" + def get_kv_connector_handshake_metadata( + self, + ) -> dict[tuple[int, int], KVConnectorHandshakeMetadata] | None: + """Get KV connector metadata from this worker if available. + + Returned dict is keyed by `(pp_rank, tp_rank)`. + """ if not has_kv_transfer_group(): return None @@ -525,8 +533,9 @@ class Worker(WorkerBase): if (metadata := connector.get_handshake_metadata()) is None: return None + pp_rank = get_pp_group().rank_in_group tp_rank = get_tp_group().rank_in_group - return {tp_rank: metadata} + return {(pp_rank, tp_rank): metadata} def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec()