fix conflicts

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
yewentao256
2026-06-03 17:24:48 +00:00
372 changed files with 15679 additions and 4257 deletions
+23
View File
@@ -0,0 +1,23 @@
name: vllm_rocm_ci
job_dirs:
- ".buildkite/hardware_tests"
run_all_patterns:
- "docker/Dockerfile.rocm"
- "docker/Dockerfile.rocm_base"
- "docker/ci-rocm.hcl"
- "docker/docker-bake-rocm.hcl"
- ".buildkite/hardware_tests/amd.yaml"
- ".buildkite/scripts/ci-bake-rocm.sh"
- ".buildkite/scripts/hardware_ci/run-amd-test.py"
- ".buildkite/scripts/hardware_ci/run-amd-test.sh"
- "CMakeLists.txt"
- "requirements/common.txt"
- "requirements/rocm.txt"
- "requirements/build/rocm.txt"
- "requirements/test/rocm.txt"
- "setup.py"
- "csrc/"
- "cmake/"
run_all_exclude_patterns:
- "csrc/cpu/"
- "cmake/cpu_extension.cmake"
+66 -35
View File
@@ -1,42 +1,73 @@
group: Hardware - AMD Build
group: Hardware - AMD Build
steps:
- label: "AMD: :docker: build image"
key: image-build-amd
# Ensure ci_base is up-to-date before building the test image.
# Compares a content hash of ci_base-affecting files against the remote
# image label. If hashes match the build is skipped (< 30 s); if they
# differ ci_base is rebuilt and pushed automatically.
- label: "AMD: :docker: ensure ci_base"
key: ensure-ci-base-amd
depends_on: []
device: amd_cpu
no_plugin: true
commands:
- >
docker build
--build-arg max_jobs=16
--build-arg REMOTE_VLLM=1
--build-arg ARG_PYTORCH_ROCM_ARCH='gfx90a;gfx942;gfx950'
--build-arg VLLM_BRANCH=$BUILDKITE_COMMIT
--tag "rocm/vllm-ci:${BUILDKITE_COMMIT}"
-f docker/Dockerfile.rocm
--target test
--no-cache
--progress plain .
- |
docker run --rm --network=none --entrypoint /bin/bash "rocm/vllm-ci:${BUILDKITE_COMMIT}" -ec '
if [ ! -d /vllm-workspace ]; then echo Missing directory: /vllm-workspace >&2; exit 1; fi
if [ ! -d /vllm-workspace/tests ]; then echo Missing directory: /vllm-workspace/tests >&2; exit 1; fi
if [ ! -d /vllm-workspace/src/vllm ]; then echo Missing directory: /vllm-workspace/src/vllm >&2; exit 1; fi
if [ ! -x /vllm-workspace/src/vllm/vllm-rs ]; then echo Missing executable: /vllm-workspace/src/vllm/vllm-rs >&2; exit 1; fi
command -v python3
command -v uv
command -v pytest
if ! command -v amd-smi >/dev/null 2>&1 && ! command -v rocminfo >/dev/null 2>&1; then
echo No ROCm CLI found in image >&2
exit 1
fi
python3 - <<PY
import torch, vllm
print(torch.__version__)
print(vllm.__version__)
PY
echo AMD image smoke OK
'
- docker push "rocm/vllm-ci:${BUILDKITE_COMMIT}"
- bash .buildkite/scripts/ci-bake-rocm.sh ci-base-rocm-ci-with-deps
env:
DOCKER_BUILDKIT: "1"
VLLM_BAKE_FILE: "docker/docker-bake-rocm.hcl"
PYTORCH_ROCM_ARCH: "gfx90a;gfx942;gfx950"
REMOTE_VLLM: "1"
VLLM_BRANCH: "$BUILDKITE_COMMIT"
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 1
- exit_status: -10 # Agent was lost
limit: 1
- label: "AMD: :docker: build test image and artifacts"
key: image-build-amd
depends_on:
- ensure-ci-base-amd
device: amd_cpu
no_plugin: true
commands:
- |
if [[ "${ROCM_CI_ARTIFACT_ONLY:-0}" == "1" ]]; then
echo "ROCM_CI_ARTIFACT_ONLY=1; building ROCm wheel artifact only"
IMAGE_TAG="" bash .buildkite/scripts/ci-bake-rocm.sh test-rocm-ci-with-artifacts
else
bash .buildkite/scripts/ci-bake-rocm.sh test-rocm-ci-with-wheel
fi
- |
docker run --rm --network=none --entrypoint /bin/bash "rocm/vllm-ci:${BUILDKITE_COMMIT}" -ec '
if [ ! -d /vllm-workspace ]; then echo Missing directory: /vllm-workspace >&2; exit 1; fi
if [ ! -d /vllm-workspace/tests ]; then echo Missing directory: /vllm-workspace/tests >&2; exit 1; fi
if [ ! -d /vllm-workspace/src/vllm ]; then echo Missing directory: /vllm-workspace/src/vllm >&2; exit 1; fi
if [ ! -x /vllm-workspace/src/vllm/vllm-rs ]; then echo Missing executable: /vllm-workspace/src/vllm/vllm-rs >&2; exit 1; fi
command -v python3
command -v uv
command -v pytest
if ! command -v amd-smi >/dev/null 2>&1 && ! command -v rocminfo >/dev/null 2>&1; then
echo No ROCm CLI found in image >&2
exit 1
fi
python3 - <<PY
import torch, vllm
print(torch.__version__)
print(vllm.__version__)
PY
echo AMD image smoke OK
'
env:
DOCKER_BUILDKIT: "1"
VLLM_BAKE_FILE: "docker/docker-bake-rocm.hcl"
PYTORCH_ROCM_ARCH: "gfx90a;gfx942;gfx950"
IMAGE_TAG: "rocm/vllm-ci:$BUILDKITE_COMMIT"
REMOTE_VLLM: "1"
VLLM_BRANCH: "$BUILDKITE_COMMIT"
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 1
- exit_status: -10 # Agent was lost
limit: 1
+3 -1
View File
@@ -16,6 +16,7 @@ steps:
- tests/kernels/test_onednn.py
- tests/kernels/test_awq_int4_to_int8.py
- tests/kernels/quantization/test_cpu_fp8_scaled_mm.py
- tests/kernels/mamba/cpu/test_cpu_gdn_ops.py
commands:
- |
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 30m "
@@ -24,7 +25,8 @@ steps:
pytest -x -v -s tests/kernels/moe/test_cpu_quant_fused_moe.py
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py
pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py"
pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py
pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py"
- label: CPU-Compatibility Tests
depends_on: []
@@ -2,7 +2,6 @@
{
"test_name": "latency_llama8B_tp1",
"environment_variables": {
"VLLM_RPC_TIMEOUT": 100000,
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
"VLLM_CPU_KVCACHE_SPACE": 40
@@ -2,7 +2,6 @@
{
"test_name": "latency_llama8B_tp2",
"environment_variables": {
"VLLM_RPC_TIMEOUT": 100000,
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
"VLLM_CPU_SGL_KERNEL": 1,
@@ -13,7 +13,6 @@
200
],
"server_environment_variables": {
"VLLM_RPC_TIMEOUT": 100000,
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
"VLLM_CPU_SGL_KERNEL": 1,
@@ -5,7 +5,6 @@
],
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
"server_environment_variables": {
"VLLM_RPC_TIMEOUT": 100000,
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120
},
"server_parameters": {
@@ -9,7 +9,6 @@
128
],
"server_environment_variables": {
"VLLM_RPC_TIMEOUT": 100000,
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
"VLLM_CPU_SGL_KERNEL": 1,
@@ -5,7 +5,6 @@
],
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
"server_environment_variables": {
"VLLM_RPC_TIMEOUT": 100000,
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
"VLLM_CPU_SGL_KERNEL": 1,
@@ -5,7 +5,6 @@
],
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
"server_environment_variables": {
"VLLM_RPC_TIMEOUT": 100000,
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
"VLLM_CPU_SGL_KERNEL": 1,
@@ -2,7 +2,6 @@
{
"test_name": "throughput_llama8B_tp1",
"environment_variables": {
"VLLM_RPC_TIMEOUT": 100000,
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
"VLLM_CPU_KVCACHE_SPACE": 40
@@ -2,7 +2,6 @@
{
"test_name": "throughput_llama8B_tp2",
"environment_variables": {
"VLLM_RPC_TIMEOUT": 100000,
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
"VLLM_CPU_SGL_KERNEL": 1,
File diff suppressed because it is too large Load Diff
+121 -11
View File
@@ -52,6 +52,108 @@ cleanup_network() {
fi
}
prepare_artifact_image() {
if [[ "${VLLM_CI_USE_ARTIFACTS:-0}" != "1" ]]; then
return 1
fi
if ! command -v buildkite-agent >/dev/null 2>&1; then
echo "buildkite-agent not found; cannot download ROCm wheel artifact"
return 1
fi
local artifact_glob="${VLLM_CI_ARTIFACT_GLOB:-artifacts/vllm-rocm-install/vllm-rocm-install.tar.gz}"
local archive=""
local metadata_file=""
local base_image="${VLLM_CI_BASE_IMAGE:-rocm/vllm-dev:ci_base}"
local artifact_image=""
local artifact_key=""
local base_digest=""
local wheel_dir=""
local context_dir=""
local workspace_dir=""
artifact_work_dir=$(mktemp -d -t vllm-rocm-artifact.XXXXXX)
wheel_dir="${artifact_work_dir}/wheels"
context_dir="${artifact_work_dir}/context"
workspace_dir="${context_dir}/workspace"
mkdir -p "${wheel_dir}" "${context_dir}/wheels" "${workspace_dir}"
echo "--- Downloading ROCm wheel artifact"
if ! buildkite-agent artifact download "${artifact_glob}" "${artifact_work_dir}"; then
echo "Failed to download ${artifact_glob}"
return 1
fi
buildkite-agent artifact download \
"artifacts/vllm-rocm-install/ci-base-image.txt" \
"${artifact_work_dir}" >/dev/null 2>&1 || true
archive=$(find "${artifact_work_dir}" -name "vllm-rocm-install.tar.gz" -type f | head -1)
if [[ -z "${archive}" || ! -f "${archive}" ]]; then
echo "ROCm wheel artifact archive was not found"
return 1
fi
metadata_file=$(find "${artifact_work_dir}" -name "ci-base-image.txt" -type f | head -1)
if [[ -n "${metadata_file}" && -s "${metadata_file}" ]]; then
base_image=$(tr -d '[:space:]' < "${metadata_file}")
fi
echo "--- Preparing local ROCm test image"
echo "Base image: ${base_image}"
docker pull "${base_image}" || return 1
base_digest=$(
docker image inspect \
--format='{{if .RepoDigests}}{{index .RepoDigests 0}}{{else}}{{.Id}}{{end}}' \
"${base_image}" 2>/dev/null || printf '%s' "${base_image}"
)
artifact_key=$(
{
printf 'base-image:%s\n' "${base_digest}"
sha256sum "${archive}"
} | sha256sum | cut -c1-24
)
artifact_image="rocm/vllm-ci-artifact:${artifact_key}"
if docker image inspect "${artifact_image}" >/dev/null 2>&1; then
echo "Using existing local ROCm artifact image: ${artifact_image}"
image_name="${artifact_image}"
return 0
fi
tar -xzf "${archive}" -C "${wheel_dir}" || return 1
if ! ls "${wheel_dir}"/*.whl >/dev/null 2>&1; then
echo "ROCm wheel artifact did not contain a wheel"
return 1
fi
if [[ ! -d "${wheel_dir}/tests" ]]; then
echo "ROCm wheel artifact did not contain the test workspace"
return 1
fi
cp "${wheel_dir}"/*.whl "${context_dir}/wheels/" || return 1
tar -C "${wheel_dir}" --exclude='*.whl' -cf - . \
| tar -C "${workspace_dir}" -xf - || return 1
cat > "${context_dir}/Dockerfile" <<'EOF'
ARG BASE_IMAGE
FROM ${BASE_IMAGE}
COPY wheels/ /tmp/vllm-wheels/
COPY workspace/ /vllm-workspace/
RUN python3 -m pip install --no-deps --force-reinstall /tmp/vllm-wheels/*.whl \
&& rm -rf /tmp/vllm-wheels
WORKDIR /vllm-workspace
EOF
echo "--- Building local ROCm test image"
docker build \
--pull=false \
--build-arg "BASE_IMAGE=${base_image}" \
-t "${artifact_image}" \
"${context_dir}" || return 1
image_name="${artifact_image}"
return 0
}
is_multi_node() {
local cmds="$1"
# Primary signal: NUM_NODES environment variable set by the pipeline
@@ -243,22 +345,30 @@ report_docker_usage
# --- Pull test image ---
echo "--- Pulling container"
image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}"
image_name="${VLLM_CI_FALLBACK_IMAGE:-rocm/vllm-ci:${BUILDKITE_COMMIT:-local}}"
artifact_work_dir=""
container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
docker pull "${image_name}"
remove_docker_container() {
# docker run uses --rm, so the container is normally already gone when the
# EXIT trap runs. Cleanup is best-effort and must not affect the test result.
docker rm -f "${container_name}" >/dev/null 2>&1 || true
if docker container inspect "${container_name}" >/dev/null 2>&1; then
docker rm -f "${container_name}" || true
fi
if [[ "${VLLM_CI_REMOVE_TEST_IMAGE:-0}" == "1" ]]; then
docker image rm -f "${image_name}" || true
else
# Keep images by default so later jobs on the same AMD node can reuse layers.
echo "Keeping ROCm test image locally: ${image_name}"
fi
if [[ -n "${artifact_work_dir}" ]]; then
rm -rf "${artifact_work_dir}"
fi
}
trap remove_docker_container EXIT
on_exit() {
local exit_code=$?
remove_docker_container
exit "$exit_code"
}
trap on_exit EXIT
if ! prepare_artifact_image; then
echo "Using full ROCm CI image: ${image_name}"
docker pull "${image_name}" || exit 1
fi
# --- Prepare commands ---
echo "--- Running container"
@@ -37,7 +37,8 @@ function cpu_tests() {
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/core/test_cpu_activation.py
pytest -x -v -s tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic"
pytest -x -v -s tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic
pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py"
# skip tests requiring model downloads if HF_TOKEN is not set
# due to rate-limits
+39
View File
@@ -0,0 +1,39 @@
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
set -euo pipefail
REQUIREMENTS_FILE="${KV_CONNECTORS_REQUIREMENTS:-/vllm-workspace/requirements/kv_connectors.txt}"
uv pip install --system -r "${REQUIREMENTS_FILE}"
NIXL_METADATA=$(python3 - <<'PY'
import importlib.metadata as metadata
import torch
cuda_version = torch.version.cuda
if cuda_version is None:
raise SystemExit("torch.version.cuda is not set")
print(cuda_version.split(".", 1)[0], metadata.version("nixl"))
PY
)
read -r CUDA_MAJOR NIXL_VERSION <<<"${NIXL_METADATA}"
# nixl>=1.1.0 can install multiple CUDA wheel variants. Keep only the variant
# matching this CI image so nixl_ep_cpp links against the available libcudart.
uv pip uninstall --system nixl-cu12 nixl-cu13 2>/dev/null || true
uv pip install --system --no-deps "nixl-cu${CUDA_MAJOR}==${NIXL_VERSION}"
python3 - <<'PY'
import importlib.metadata as metadata
for package_name in ("nixl", "nixl-cu12", "nixl-cu13"):
try:
version = metadata.version(package_name)
except metadata.PackageNotFoundError:
version = "not installed"
print(f"{package_name}: {version}")
PY
+11 -13
View File
@@ -1238,14 +1238,11 @@ steps:
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/
- tests/entrypoints/rpc
- tests/entrypoints/serve/instrumentator
- tests/tool_use
- tests/entrypoints/serve
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/serve/instrumentator
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
- pytest -v -s tool_use
- pytest -v -s entrypoints/serve --ignore=entrypoints/serve/dev/rpc
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/serve/dev/rpc
- label: Entrypoints Integration (API Server openai - Part 1) # TBD
timeout_in_minutes: 180
@@ -1276,11 +1273,13 @@ steps:
- tests/entrypoints/openai
- tests/entrypoints/test_chat_utils
- tests/entrypoints/generate
- tests/tool_use
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/openai/completion --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/generate
- pytest -v -s tool_use
- label: Entrypoints Integration (API Server openai - Part 3) # TBD
timeout_in_minutes: 180
@@ -1370,7 +1369,7 @@ steps:
- vllm/platforms/rocm.py
commands:
- pytest -v -s entrypoints/openai/tool_parsers
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/serve/instrumentator --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- label: OpenAI API correctness # TBD
timeout_in_minutes: 180
@@ -2747,14 +2746,11 @@ steps:
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/
- tests/entrypoints/rpc
- tests/entrypoints/serve/instrumentator
- tests/tool_use
- tests/entrypoints/serve
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/serve/instrumentator
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
- pytest -v -s tool_use
- pytest -v -s entrypoints/serve --ignore=entrypoints/serve/dev/rpc
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/serve/dev/rpc
- label: Entrypoints Integration (API Server openai - Part 1) # TBD
timeout_in_minutes: 180
@@ -2785,11 +2781,13 @@ steps:
- tests/entrypoints/openai
- tests/entrypoints/test_chat_utils
- tests/entrypoints/generate
- tests/tool_use
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/openai/completion --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/generate
- pytest -v -s tool_use
- label: Entrypoints Integration (API Server openai - Part 3) # TBD
timeout_in_minutes: 180
+9 -9
View File
@@ -11,7 +11,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: Distributed FlashInfer NixlConnector PD accuracy (4 GPUs)
key: distributed-flashinfer-nixlconnector-pd-accuracy-4-gpus
@@ -22,7 +22,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- FLASHINFER=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: DP EP Distributed NixlConnector PD accuracy tests (4 GPUs)
@@ -34,7 +34,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- DP_EP=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: CrossLayer KV layout Distributed NixlConnector PD accuracy tests (4 GPUs)
@@ -46,7 +46,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- CROSS_LAYERS_BLOCKS=True bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: Hybrid SSM NixlConnector PD accuracy tests (4 GPUs)
@@ -58,7 +58,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- HYBRID_SSM=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
- label: MultiConnector (Nixl+Offloading) PD accuracy (2 GPUs)
@@ -73,7 +73,7 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/offloading/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- bash v1/kv_connector/nixl_integration/run_multi_connector_accuracy_test.sh
- label: NixlConnector PD + Spec Decode acceptance (2 GPUs)
@@ -87,7 +87,7 @@ steps:
- vllm/v1/worker/kv_connector_model_runner_mixin.py
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- bash v1/kv_connector/nixl_integration/config_sweep_spec_decode_test.sh
- label: MultiConnector (Nixl+Offloading) PD edge cases (2 GPUs)
@@ -102,5 +102,5 @@ steps:
- vllm/distributed/kv_transfer/kv_connector/v1/offloading/
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- bash v1/kv_connector/nixl_integration/run_multi_connector_edge_case_test.sh
+6 -8
View File
@@ -11,7 +11,7 @@ steps:
- tests/entrypoints/
commands:
- pytest -v -s entrypoints/openai/tool_parsers
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/rpc --ignore=entrypoints/sleep --ignore=entrypoints/serve/instrumentator --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- label: Entrypoints Integration (LLM)
key: entrypoints-integration-llm
@@ -61,10 +61,12 @@ steps:
- tests/entrypoints/openai
- tests/entrypoints/test_chat_utils
- tests/entrypoints/generate
- tests/tool_use
commands:
- pytest -v -s entrypoints/openai/completion --ignore=entrypoints/openai/completion/test_tensorizer_entrypoint.py
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/generate
- pytest -v -s tool_use
mirror:
amd:
device: mi325_1
@@ -100,14 +102,11 @@ steps:
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- vllm/
- tests/entrypoints/rpc
- tests/entrypoints/serve/instrumentator
- tests/tool_use
- tests/entrypoints/serve
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/serve/instrumentator
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/rpc
- pytest -v -s tool_use
- pytest -v -s entrypoints/serve --ignore=entrypoints/serve/dev/rpc
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/serve/dev/rpc
mirror:
amd:
device: mi325_1
@@ -155,6 +154,5 @@ steps:
source_file_dependencies:
- csrc/
- vllm/entrypoints/openai/
- vllm/model_executor/models/whisper.py
commands: # LMEval
- pytest -s entrypoints/openai/correctness/
+2 -1
View File
@@ -86,7 +86,7 @@ steps:
- tests/v1/metrics
- tests/entrypoints/openai/correctness/test_lmeval.py
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash /vllm-workspace/.buildkite/scripts/install-kv-connectors.sh
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# split the test to avoid interference
- pytest -v -s -m 'not cpu_test' v1/core
@@ -281,6 +281,7 @@ steps:
- vllm/model_executor/layers/quantization/quark/
- vllm/multimodal/
- vllm/outputs.py
- vllm/parser/
- vllm/platforms/
- vllm/pooling_params.py
- vllm/ray/
@@ -94,11 +94,13 @@ steps:
- vllm/v1/worker/gpu_worker.py
- tests/distributed/test_pipeline_parallel.py
- tests/distributed/test_pp_cudagraph.py
- tests/v1/distributed/test_pp_dp_v2.py
commands:
- set -x
- export VLLM_USE_V2_MODEL_RUNNER=1
- pytest -v -s distributed/test_pipeline_parallel.py -k "not ray and not Jamba"
- pytest -v -s distributed/test_pp_cudagraph.py -k "not ray"
- pytest -v -s v1/distributed/test_pp_dp_v2.py
- label: Model Runner V2 Spec Decode
device: h200_35gb
+4 -4
View File
@@ -45,19 +45,19 @@ steps:
- vllm/entrypoints/serve/
- vllm/v1/engine/
- tests/utils.py
# - tests/entrypoints/rpc/test_collective_rpc.py
# - tests/entrypoints/serve/dev/rpc/test_collective_rpc.py
- tests/entrypoints/serve/disagg/test_serving_tokens.py
- tests/entrypoints/serve/instrumentator/test_basic.py
- tests/entrypoints/serve/instrumentator/test_metrics.py
# - tests/entrypoints/serve/instrumentator/test_sleep.py
# - tests/entrypoints/serve/dev/test_sleep.py
commands:
- export VLLM_USE_RUST_FRONTEND=1
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# - pytest -v -s entrypoints/rpc/test_collective_rpc.py
# - pytest -v -s entrypoints/serve/dev/rpc/test_collective_rpc.py
- pytest -v -s entrypoints/serve/instrumentator/test_basic.py -k "not show_version and not server_load"
- pytest -v -s entrypoints/serve/disagg/test_serving_tokens.py -k "not stream and not lora and not test_generate_logprobs and not stop_string_workflow"
- pytest -v -s entrypoints/serve/instrumentator/test_metrics.py -k "text and not show and not run_batch and not test_metrics_counts and not test_metrics_exist"
# - pytest -v -s entrypoints/serve/instrumentator/test_sleep.py
# - pytest -v -s entrypoints/serve/dev/test_sleep.py
- label: Rust Frontend Core Correctness
timeout_in_minutes: 30
+7
View File
@@ -33,3 +33,10 @@ share/python-wheels/
*.egg
MANIFEST
rust/target/
# Not needed in Docker builds
docs/
.github/
.pre-commit-config.yaml
.clang-format
.gitattributes
format.sh
+1 -1
View File
@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Add label
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
github.rest.issues.addLabels({
+3 -3
View File
@@ -14,7 +14,7 @@ jobs:
steps:
- name: Label issues based on keywords
id: label-step
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
// Configuration: Add new labels and keywords here
@@ -315,7 +315,7 @@ jobs:
- name: CC users for labeled issues
if: steps.label-step.outputs.labels_added != '[]'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
// Configuration: Map labels to GitHub users to CC
@@ -392,7 +392,7 @@ jobs:
- name: Request missing ROCm info from issue author
if: contains(steps.label-step.outputs.labels_added, 'rocm') && contains(toJSON(github.event.issue.labels.*.name), 'bug')
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
const body = (context.payload.issue.body || '').toLowerCase();
+2 -2
View File
@@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Update PR description
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
const { owner, repo } = context.repo;
@@ -55,7 +55,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Post welcome comment for first-time contributors
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
const { owner, repo } = context.repo;
+1 -1
View File
@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check PR label and author merge count
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
const { data: pr } = await github.rest.pulls.get({
+1 -1
View File
@@ -21,7 +21,7 @@ repos:
rev: v21.1.2
hooks:
- id: clang-format
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*'
exclude: 'csrc/(moe/topk_softmax_kernels.cu|libtorch_stable/quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*'
types_or: [c++, cuda]
args: [--style=file, --verbose]
- repo: https://github.com/DavidAnson/markdownlint-cli2
+15 -12
View File
@@ -112,6 +112,8 @@ endif()
#
# spinloop extension (pure CXX; must stay above the non-CUDA device branch so
# CPU builds define the target before the early return)
# This extension requires SABI 3.11 since it relies on Py_buffer support. Loading
# failure is handled gracefully on vLLM side for lower Python versions.
#
set(VLLM_SPINLOOP_EXT_SRC "csrc/spinloop.cpp")
set(SPINLOOP_COMPILE_FLAGS "")
@@ -309,14 +311,9 @@ set(VLLM_EXT_SRC
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/torch_bindings.cpp"
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/minimax_reduce_rms_kernel.cu")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
@@ -503,12 +500,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS)
set(SRCS
"csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
"csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu")
"csrc/libtorch_stable/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
"csrc/libtorch_stable/moe/mxfp8_moe/mxfp8_experts_quant.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1")
message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}")
else()
@@ -598,7 +595,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if (VLLM_GPU_LANG STREQUAL "HIP")
# Add QuickReduce kernels
# Add QuickReduce kernels (ROCm-only; not part of stable ABI migration).
list(APPEND VLLM_EXT_SRC
"csrc/custom_quickreduce.cu"
)
@@ -633,6 +630,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/activation_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/gptq/q_gemm.cu"
"csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu"
"csrc/libtorch_stable/pos_encoding_kernels.cu"
@@ -647,7 +646,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/attention/paged_attention_v1.cu"
"csrc/libtorch_stable/attention/paged_attention_v2.cu"
"csrc/libtorch_stable/cache_kernels.cu"
"csrc/libtorch_stable/cache_kernels_fused.cu")
"csrc/libtorch_stable/cache_kernels.cu"
"csrc/libtorch_stable/cache_kernels_fused.cu"
"csrc/libtorch_stable/custom_all_reduce.cu"
"csrc/libtorch_stable/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -659,7 +661,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/permute_cols.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu"
"csrc/libtorch_stable/minimax_reduce_rms_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_STABLE_EXT_SRC}"
+24 -1
View File
@@ -369,6 +369,18 @@ else()
add_compile_definitions(-DVLLM_NUMA_DISABLED)
endif()
# check if the pytorch wheel ships libopenblas.so.
set(VLLM_OPENBLAS_LIB "")
if (NOT ENABLE_X86_ISA)
file(GLOB _VLLM_TORCH_OPENBLAS_LIBS
"${TORCH_INSTALL_PREFIX}/lib/libopenblas*.so*")
# Note: we don't link openblas directly to _C extension, as it's available through libtorch.so
if (_VLLM_TORCH_OPENBLAS_LIBS)
list(GET _VLLM_TORCH_OPENBLAS_LIBS 0 VLLM_OPENBLAS_LIB)
message(STATUS "CPU OpenBLAS library: ${VLLM_OPENBLAS_LIB}")
endif()
endif()
#
# Generate CPU attention dispatch header
#
@@ -387,6 +399,7 @@ endif()
#
set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
"csrc/cpu/sgl-kernels/fla.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/spec_decode_utils.cpp"
"csrc/cpu/layernorm.cpp"
@@ -410,6 +423,12 @@ if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
${VLLM_EXT_SRC})
endif()
if (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
${VLLM_EXT_SRC})
endif()
if(USE_ONEDNN)
set(VLLM_EXT_SRC
"csrc/cpu/dnnl_kernels.cpp"
@@ -418,7 +437,6 @@ endif()
if (ENABLE_X86_ISA)
set(VLLM_EXT_SRC_SGL
"csrc/cpu/sgl-kernels/fla.cpp"
"csrc/cpu/sgl-kernels/conv.cpp"
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
@@ -430,6 +448,7 @@ if (ENABLE_X86_ISA)
"csrc/cpu/sgl-kernels/moe_fp8.cpp")
set(VLLM_EXT_SRC_AVX512
"csrc/cpu/sgl-kernels/fla.cpp"
"csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
@@ -446,6 +465,7 @@ if (ENABLE_X86_ISA)
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
set(VLLM_EXT_SRC_AVX2
"csrc/cpu/sgl-kernels/fla.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/spec_decode_utils.cpp"
"csrc/cpu/cpu_attn.cpp"
@@ -519,6 +539,9 @@ else()
USE_SABI 3
WITH_SOABI
)
if (VLLM_OPENBLAS_LIB)
target_compile_definitions(_C PRIVATE VLLM_HAS_OPENBLAS)
endif()
endif()
message(STATUS "Enabling C extension.")
@@ -31,7 +31,7 @@ endif()
if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(
vllm-flash-attn SOURCE_DIR
vllm-flash-attn SOURCE_DIR
${VLLM_FLASH_ATTN_SRC_DIR}
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
@@ -39,7 +39,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG bce29425653ec0fbc579d329883030e832d15ada
GIT_TAG dd62dac706b1cf7895bd99b18c6cb7e7e117ee25
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
+31 -10
View File
@@ -14,7 +14,18 @@ import argparse
import os
import shutil
from torch.utils.hipify.hipify_python import hipify
from torch.utils.hipify.hipify_python import get_hip_file_path, hipify
def _expected_hip_build_path(source_abs: str, output_directory: str) -> str:
"""Match torch.utils.hipify.hipify_python.preprocessor fout_path naming."""
rel = os.path.relpath(source_abs, output_directory)
return os.path.abspath(
os.path.join(
output_directory, get_hip_file_path(rel, is_pytorch_extension=True)
)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -53,7 +64,11 @@ if __name__ == "__main__":
hipify_result = hipify(
project_directory=args.project_dir,
output_directory=args.output_dir,
header_include_dirs=[],
# Hipify resolves quoted includes next to the including file first; vLLM
# uses paths relative to csrc/ (e.g. "libtorch_stable/torch_utils.h"
# from quantization/w8a8/fp8/*.cu). Without an include root here, those
# headers are never found and are not hipified or rewritten in dependents.
header_include_dirs=["."],
includes=includes,
extra_files=extra_files,
show_detailed=True,
@@ -64,14 +79,20 @@ if __name__ == "__main__":
hipified_sources = []
for source in args.sources:
s_abs = os.path.abspath(source)
hipified_s_abs = (
hipify_result[s_abs].hipified_path
if (
s_abs in hipify_result
and hipify_result[s_abs].hipified_path is not None
)
else s_abs
)
if s_abs in hipify_result and hipify_result[s_abs].hipified_path is not None:
path = hipify_result[s_abs].hipified_path
# PyTorch skips writing when is_pytorch_extension and text unchanged;
# hipified_path then stays *.cu. CMake expects *.hip under output_dir.
if s_abs.endswith(".cu") and path.endswith(".cu"):
dest = _expected_hip_build_path(s_abs, args.output_dir)
if os.path.normpath(path) != os.path.normpath(dest):
os.makedirs(os.path.dirname(dest), exist_ok=True)
shutil.copy2(path, dest)
hipified_s_abs = dest
else:
hipified_s_abs = path
else:
hipified_s_abs = s_abs
hipified_sources.append(hipified_s_abs)
assert len(hipified_sources) == len(args.sources)
+8
View File
@@ -81,6 +81,14 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
set_property(GLOBAL APPEND PROPERTY VLLM_HIPIFY_ALL_SRCS ${SRCS})
set_property(GLOBAL APPEND PROPERTY VLLM_HIPIFY_ALL_BYPRODUCTS ${HIP_SRCS})
# Chain hipify targets so they run sequentially. Parallel hipify
# invocations race on shutil.copytree, overwriting .hip files
# produced by another target back to .cu originals.
if (DEFINED _VLLM_LAST_HIPIFY_TARGET)
add_dependencies(hipify${NAME} ${_VLLM_LAST_HIPIFY_TARGET})
endif()
set(_VLLM_LAST_HIPIFY_TARGET "hipify${NAME}" PARENT_SCOPE)
# Swap out original extension sources with hipified sources.
list(APPEND HIP_SRCS ${CXX_SRCS})
set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
+49 -1
View File
@@ -30,7 +30,12 @@
}()
namespace {
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
enum class FusedMOEAct {
SiluAndMul,
SwigluOAIAndMul,
GeluAndMul,
GeluTanhAndMul,
};
FusedMOEAct get_act_type(const std::string& act) {
if (act == "silu") {
@@ -39,6 +44,8 @@ FusedMOEAct get_act_type(const std::string& act) {
return FusedMOEAct::SwigluOAIAndMul;
} else if (act == "gelu") {
return FusedMOEAct::GeluAndMul;
} else if (act == "gelu_tanh") {
return FusedMOEAct::GeluTanhAndMul;
} else {
TORCH_CHECK(false, "Invalid act type: " + act);
}
@@ -143,6 +150,44 @@ void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
}
}
template <typename scalar_t>
void gelu_tanh_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride,
const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
const int32_t dim = n_size / 2;
float* __restrict__ gate = input;
float* __restrict__ up = input + dim;
vec_op::FP32Vec16 one_vec(1.0);
vec_op::FP32Vec16 w1_vec(0.7978845608028654);
vec_op::FP32Vec16 w2_vec(0.5);
vec_op::FP32Vec16 w3_vec(0.044715);
alignas(64) float temp[16];
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < dim; n += 16) {
vec_op::FP32Vec16 gate_vec(gate + n);
vec_op::FP32Vec16 up_vec(up + n);
auto gate_pow3_vec = gate_vec * gate_vec * gate_vec;
auto inner_vec = w1_vec * (gate_vec + w3_vec * gate_pow3_vec);
inner_vec.save(temp);
for (int32_t i = 0; i < 16; ++i) {
temp[i] = std::tanh(temp[i]);
}
vec_op::FP32Vec16 tanh_vec(temp);
auto gelu_tanh = gate_vec * w2_vec * (one_vec + tanh_vec);
auto gated_output_fp32 = up_vec * gelu_tanh;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n);
}
gate += input_stride;
up += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
float* __restrict__ input,
@@ -160,6 +205,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
case FusedMOEAct::GeluAndMul:
gelu_and_mul(input, output, m, n, input_stride, output_stride);
return;
case FusedMOEAct::GeluTanhAndMul:
gelu_tanh_and_mul(input, output, m, n, input_stride, output_stride);
return;
default:
TORCH_CHECK(false, "Unsupported act type.");
}
+106 -1
View File
@@ -89,6 +89,35 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
}
};
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
ss16x8x2_t reg;
explicit FP16Vec16(const void* ptr) {
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
}
explicit FP16Vec16(bool, const void* ptr) : FP16Vec16(ptr) {}
explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const {
vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr);
}
void save(void* ptr, int elem_num) const {
int num = std::max(0, std::min(elem_num, VEC_ELEM_NUM));
if (num <= 8) {
vec_xst_len(reg.val[0], (signed short*)ptr, num * 2);
} else {
vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst_len(reg.val[1], (signed short*)ptr + 8, (num - 8) * 2);
}
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
@@ -100,6 +129,8 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
}
explicit BF16Vec16(bool, const void* ptr) : BF16Vec16(ptr) {}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
@@ -379,6 +410,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
reg.val[3] = vec_xl(48, ptr);
}
explicit FP32Vec16(bool, const float* ptr) : FP32Vec16(ptr) {}
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16& data) {
@@ -402,6 +435,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
reg.val[3] = data.reg.val[1];
}
explicit FP32Vec16(const FP16Vec16& v);
explicit FP32Vec16(const BF16Vec16& v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
@@ -735,6 +769,40 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
#endif
}
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
alignas(16) float temp_fp32[16];
alignas(16) c10::Half temp_fp16[16];
vec_xst(v.reg.val[0], 0, temp_fp32);
vec_xst(v.reg.val[1], 16, temp_fp32);
vec_xst(v.reg.val[2], 32, temp_fp32);
vec_xst(v.reg.val[3], 48, temp_fp32);
for (int i = 0; i < 16; i++) {
temp_fp16[i] = c10::Half(temp_fp32[i]);
}
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)temp_fp16);
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)temp_fp16);
}
inline FP32Vec16::FP32Vec16(const FP16Vec16& v) {
alignas(16) c10::Half temp_fp16[16];
alignas(16) float temp_fp32[16];
vec_xst(v.reg.val[0], 0, (signed short*)temp_fp16);
vec_xst(v.reg.val[1], 16, (signed short*)temp_fp16);
for (int i = 0; i < 16; i++) {
temp_fp32[i] = float(temp_fp16[i]);
}
reg.val[0] = vec_xl(0, temp_fp32);
reg.val[1] = vec_xl(16, temp_fp32);
reg.val[2] = vec_xl(32, temp_fp32);
reg.val[3] = vec_xl(48, temp_fp32);
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
#ifdef _ARCH_PWR10
__vector signed short ret[4];
@@ -794,6 +862,43 @@ inline void prefetch(const void* addr) {
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
}
}; // namespace vec_op
struct INT8Vec64 {
__vector signed char data[4];
INT8Vec64() = default;
explicit INT8Vec64(const int8_t* ptr) {
data[0] = vec_xl(0, ptr);
data[1] = vec_xl(16, ptr);
data[2] = vec_xl(32, ptr);
data[3] = vec_xl(48, ptr);
}
explicit INT8Vec64(bool, const int8_t* ptr) : INT8Vec64(ptr) {}
void save(int8_t* ptr) const {
vec_xst(data[0], 0, ptr);
vec_xst(data[1], 16, ptr);
vec_xst(data[2], 32, ptr);
vec_xst(data[3], 48, ptr);
}
void save(int8_t* ptr, int elem_num) const {
if (elem_num <= 0) return;
int full_vecs = elem_num / 16;
for (int i = 0; i < full_vecs && i < 4; i++) {
vec_xst(data[i], i * 16, ptr);
}
int remaining = elem_num % 16;
if (remaining > 0 && full_vecs < 4) {
vec_xst_len(data[full_vecs], ptr + full_vecs * 16, remaining);
}
}
void nt_save(int8_t* ptr) const { save(ptr); }
};
} // namespace vec_op
#endif
+82
View File
@@ -0,0 +1,82 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <ATen/native/CPUBlas.h>
// Unlike brgemm, PyTorch does not publicly expose at::native::cpublas::gemm
// If OpenBLS is available in the PyTorch wheel, we rely on it for fast
// bf16:bf16->fp32 GEMMs Otherwise, we fall back to PyTorch reference BLAS path.
#if defined(VLLM_HAS_OPENBLAS)
extern "C" void sbgemm_(char* transa, char* transb, int* m, int* n, int* k,
float* alpha, const at::BFloat16* a, int* lda,
const at::BFloat16* b, int* ldb, float* beta, float* c,
int* ldc);
extern "C" void sgemm_(char* transa, char* transb, int* m, int* n, int* k,
float* alpha, const float* a, int* lda, const float* b,
int* ldb, float* beta, float* c, int* ldc);
inline char blas_transpose(at::native::TransposeType trans) {
switch (trans) {
case at::native::TransposeType::NoTranspose:
return 'n';
case at::native::TransposeType::Transpose:
return 't';
case at::native::TransposeType::ConjTranspose:
return 'c';
}
return 'n';
}
inline void blas_gemm(at::native::TransposeType transa,
at::native::TransposeType transb, int64_t m, int64_t n,
int64_t k, float alpha, const at::BFloat16* a,
int64_t lda, const at::BFloat16* b, int64_t ldb,
float beta, float* c, int64_t ldc) {
char transa_ = blas_transpose(transa);
char transb_ = blas_transpose(transb);
int m_ = static_cast<int>(m);
int n_ = static_cast<int>(n);
int k_ = static_cast<int>(k);
int lda_ = static_cast<int>(lda);
int ldb_ = static_cast<int>(ldb);
int ldc_ = static_cast<int>(ldc);
sbgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha, a, &lda_, b, &ldb_, &beta,
c, &ldc_);
}
inline void blas_gemm(at::native::TransposeType transa,
at::native::TransposeType transb, int64_t m, int64_t n,
int64_t k, float alpha, const float* a, int64_t lda,
const float* b, int64_t ldb, float beta, float* c,
int64_t ldc) {
char transa_ = blas_transpose(transa);
char transb_ = blas_transpose(transb);
int m_ = static_cast<int>(m);
int n_ = static_cast<int>(n);
int k_ = static_cast<int>(k);
int lda_ = static_cast<int>(lda);
int ldb_ = static_cast<int>(ldb);
int ldc_ = static_cast<int>(ldc);
sgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha, a, &lda_, b, &ldb_, &beta,
c, &ldc_);
}
inline void blas_gemm(at::native::TransposeType, at::native::TransposeType,
int64_t, int64_t, int64_t, float, const at::Half*,
int64_t, const at::Half*, int64_t, float, float*,
int64_t) {
TORCH_CHECK(false, "CPU OpenBLAS hgemm is not available.");
}
#else
template <typename scalar_t>
inline void blas_gemm(at::native::TransposeType transa,
at::native::TransposeType transb, int64_t m, int64_t n,
int64_t k, float alpha, const scalar_t* a, int64_t lda,
const scalar_t* b, int64_t ldb, float beta, float* c,
int64_t ldc) {
auto gemm = at::native::cpublas::gemm_no_downcast_stub.DEFAULT;
gemm(c10::CppTypeToScalarType<scalar_t>::value, transa, transb, m, n, k,
at::Scalar(alpha), a, lda, b, ldb, at::Scalar(beta), c, ldc);
}
#endif
+278 -141
View File
@@ -301,25 +301,42 @@ void chunk_gated_delta_rule_kernel_impl(
// attn = k_beta @ key.transpose(-1, -2)
// attn: [B, HV, num_chunk, chunk_size, chunk_size]
// transpose and pack for key
pack_vnni<scalar_t>(
/* dst */ k_transpose,
/* src */ curr_k_pad,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ chunk_size);
// k_beta @ key.transpose(-1, -2)
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ chunk_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ chunk_size,
/* ldc */ chunk_size,
/* add_C */ false,
/* A */ curr_k_beta,
/* B */ k_transpose,
/* C */ curr_attn);
if constexpr (brgemm_supported()) {
pack_vnni<scalar_t>(
/* dst */ k_transpose,
/* src */ curr_k_pad,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ chunk_size);
// k_beta @ key.transpose(-1, -2)
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ chunk_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ chunk_size,
/* ldc */ chunk_size,
/* add_C */ false,
/* A */ curr_k_beta,
/* B */ k_transpose,
/* C */ curr_attn);
} else {
blas_gemm(
at::native::TransposeType::Transpose,
at::native::TransposeType::NoTranspose,
chunk_size,
chunk_size,
qk_head_size,
1.0f,
curr_k_pad,
qk_head_size,
curr_k_beta,
qk_head_size,
0.0f,
curr_attn,
chunk_size);
}
// attn = attn * decay_mask
for (int64_t m = 0; m < chunk_size; m++) {
at::vec::map2<float>(
@@ -413,25 +430,42 @@ void chunk_gated_delta_rule_kernel_impl(
// k_beta_g = k_beta * g: [B, HV, num_chunk, chunk_size, EK]
// k_cumdecay: [B, HV, num_chunk, chunk_size, EK]
// pack for value
pack_vnni2<scalar_t>(
/* dst */ v_pack,
/* src */ curr_v_beta,
/* N */ chunk_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// value = attn @ v_beta
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ curr_attn_reduced,
/* B */ v_pack,
/* C */ curr_value);
if constexpr (brgemm_supported()) {
pack_vnni2<scalar_t>(
/* dst */ v_pack,
/* src */ curr_v_beta,
/* N */ chunk_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// value = attn @ v_beta
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ curr_attn_reduced,
/* B */ v_pack,
/* C */ curr_value);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
chunk_size,
chunk_size,
1.0f,
curr_v_beta,
v_head_size,
curr_attn_reduced,
chunk_size,
0.0f,
curr_value,
v_head_size);
}
// k_beta_g = k_beta * g.exp().unsqueeze(-1)
for (int64_t j = 0; j < chunk_size; j++) {
int64_t i = 0;
@@ -445,25 +479,42 @@ void chunk_gated_delta_rule_kernel_impl(
}
}
// pack for k_beta_g
pack_vnni2<scalar_t>(
/* dst */ k_beta_g_pack,
/* src */ k_beta_g,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ qk_head_size);
// k_cumdecay = attn @ k_beta_g
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ qk_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ qk_head_size,
/* ldc */ qk_head_size,
/* add_C */ false,
/* A */ curr_attn_reduced,
/* B */ k_beta_g_pack,
/* C */ k_cumdecay);
if constexpr (brgemm_supported()) {
pack_vnni2<scalar_t>(
/* dst */ k_beta_g_pack,
/* src */ k_beta_g,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ qk_head_size);
// k_cumdecay = attn @ k_beta_g
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ qk_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ qk_head_size,
/* ldc */ qk_head_size,
/* add_C */ false,
/* A */ curr_attn_reduced,
/* B */ k_beta_g_pack,
/* C */ k_cumdecay);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
qk_head_size,
chunk_size,
chunk_size,
1.0f,
k_beta_g,
qk_head_size,
curr_attn_reduced,
chunk_size,
0.0f,
k_cumdecay,
qk_head_size);
}
for (int i = 0; i < chunk_size; i++) {
at::vec::map<scalar_t>(
[](fVec x) { return x; },
@@ -551,25 +602,42 @@ void chunk_gated_delta_rule_kernel_impl(
// attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
// k_transpose_i = k_i.transpose(-1, -2)
pack_vnni<scalar_t>(
/* dst */ k_transpose_i,
/* src */ k_i,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ chunk_size);
// attn_i = q_i @ k_transpose_i
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ chunk_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ chunk_size,
/* ldc */ chunk_size,
/* add_C */ false,
/* A */ q_i,
/* B */ k_transpose_i,
/* C */ attn_i);
if constexpr (brgemm_supported()) {
pack_vnni<scalar_t>(
/* dst */ k_transpose_i,
/* src */ k_i,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ chunk_size);
// attn_i = q_i @ k_transpose_i
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ chunk_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ chunk_size,
/* ldc */ chunk_size,
/* add_C */ false,
/* A */ q_i,
/* B */ k_transpose_i,
/* C */ attn_i);
} else {
blas_gemm(
at::native::TransposeType::Transpose,
at::native::TransposeType::NoTranspose,
chunk_size,
chunk_size,
qk_head_size,
1.0f,
k_i,
qk_head_size,
q_i,
qk_head_size,
0.0f,
attn_i,
chunk_size);
}
// attn_i = attn_i * decay_mask_i
for (int64_t m = 0; m < chunk_size; m++) {
auto attn_i_m = attn_i + m * chunk_size;
@@ -609,28 +677,45 @@ void chunk_gated_delta_rule_kernel_impl(
}
// pack for curr_last_recurrent_state
pack_vnni2<scalar_t>(
/* dst */ curr_last_recurrent_state_pack_reduced,
/* src */ curr_last_recurrent_state_reduced,
/* N */ qk_head_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
if constexpr (brgemm_supported()) {
pack_vnni2<scalar_t>(
/* dst */ curr_last_recurrent_state_pack_reduced,
/* src */ curr_last_recurrent_state_reduced,
/* N */ qk_head_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// v_prime = k_cumdecay_i @ curr_last_recurrent_state: [chunk_size, EV]
// k_cumdecay_i: [chunk_size, EK]
// curr_last_recurrent_state: [EK, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ k_cumdecay_i_reduced,
/* B */ curr_last_recurrent_state_pack_reduced,
/* C */ v_prime);
// v_prime = k_cumdecay_i @ curr_last_recurrent_state: [chunk_size, EV]
// k_cumdecay_i: [chunk_size, EK]
// curr_last_recurrent_state: [EK, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ k_cumdecay_i_reduced,
/* B */ curr_last_recurrent_state_pack_reduced,
/* C */ v_prime);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
chunk_size,
qk_head_size,
1.0f,
curr_last_recurrent_state_reduced,
v_head_size,
k_cumdecay_i_reduced,
qk_head_size,
0.0f,
v_prime,
v_head_size);
}
// v_new = v_prime = v_i - v_prime
// v_i: [chunk_size, EV]
@@ -663,41 +748,75 @@ void chunk_gated_delta_rule_kernel_impl(
}
// attn_inter = qg @ curr_last_recurrent_state: [chunk_size, EV]
// curr_last_recurrent_state: [EK, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ qg,
/* B */ curr_last_recurrent_state_pack_reduced,
/* C */ attn_inter);
if constexpr (brgemm_supported()) {
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ qg,
/* B */ curr_last_recurrent_state_pack_reduced,
/* C */ attn_inter);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
chunk_size,
qk_head_size,
1.0f,
curr_last_recurrent_state_reduced,
v_head_size,
qg,
qk_head_size,
0.0f,
attn_inter,
v_head_size);
}
// core_attn_out[:, :, i] = attn_inter + attn_i @ v_new
// pack for v_prime
pack_vnni2<scalar_t>(
/* dst */ v_prime_pack_reduced,
/* src */ v_prime_reduced,
/* N */ chunk_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// attn_inter = attn_inter + attn_i @ v_new: [chunk_size, EV]
// attn_i: [chunk_size, chunk_size]
// v_new: [chunk_size, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ true,
/* A */ attn_i_reduced,
/* B */ v_prime_pack_reduced,
/* C */ attn_inter);
if constexpr (brgemm_supported()) {
pack_vnni2<scalar_t>(
/* dst */ v_prime_pack_reduced,
/* src */ v_prime_reduced,
/* N */ chunk_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// attn_inter = attn_inter + attn_i @ v_new: [chunk_size, EV]
// attn_i: [chunk_size, chunk_size]
// v_new: [chunk_size, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ true,
/* A */ attn_i_reduced,
/* B */ v_prime_pack_reduced,
/* C */ attn_inter);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
chunk_size,
chunk_size,
1.0f,
v_prime_reduced,
v_head_size,
attn_i_reduced,
chunk_size,
1.0f,
attn_inter,
v_head_size);
}
// core_attn_out[:, :, i] = attn_inter
for (int64_t m = 0; m < chunk_size; m++) {
@@ -762,17 +881,34 @@ void chunk_gated_delta_rule_kernel_impl(
/* ld_dst */ chunk_size);
// kgv = kg.transpose(-1, -2) @ v_new
// v_new: [chunk_size, EV]
at::native::cpublas::brgemm(
/* M */ qk_head_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ kg_transpose,
/* B */ v_prime_pack_reduced,
/* C */ kgv);
if constexpr (brgemm_supported()) {
at::native::cpublas::brgemm(
/* M */ qk_head_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ kg_transpose,
/* B */ v_prime_pack_reduced,
/* C */ kgv);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
qk_head_size,
chunk_size,
1.0f,
v_prime_reduced,
v_head_size,
kg_transpose,
chunk_size,
0.0f,
kgv,
v_head_size);
}
// last_recurrent_state = 1) + 2)
for (int64_t m = 0; m < qk_head_size; m++) {
at::vec::map2<float>(
@@ -921,7 +1057,8 @@ void fused_sigmoid_gating_delta_rule_update_kernel_impl(
float k_scale = use_qk_l2norm_in_kernel ? qk_scale_buf[k_scale_offset] : 1.0f;
int64_t v_offset = si * v_strideS + bi * v_strideB + ni * v_strideH;
int64_t o_offset = ((bi * seq_len + si) * v_num_heads + ni) * v_head_dim;
float beta_val = 1 / (1 + std::exp(-b_ptr[ni]));
// See: https://github.com/sgl-project/sglang/pull/26634
float beta_val = 1 / (1 + std::exp(-b_ptr[bi * v_num_heads + ni]));
fVec beta_vec = fVec(beta_val);
int64_t dvi = 0;
for (; dvi <= v_head_dim - VecSize; dvi += VecSize) {
+18 -7
View File
@@ -4,9 +4,12 @@
// clang-format off
#pragma once
#include <ATen/native/CPUBlas.h>
#include "common.h"
#include "blas_gemm.h"
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
#define CPU_CAPABILITY_AVX512
#endif
// amx-bf16
#define TILE_M 16
@@ -21,31 +24,39 @@ constexpr int block_size_n() {
return 2 * TILE_N;
}
constexpr bool brgemm_supported() {
#if defined(CPU_CAPABILITY_AVX512)
return true;
#else
return false;
#endif
}
// define threshold using brgemm (intel AMX)
template <typename T>
inline bool can_use_brgemm(int M);
template <>
inline bool can_use_brgemm<at::BFloat16>(int M) {
return M > 4;
return brgemm_supported() && M > 4;
}
template <>
inline bool can_use_brgemm<at::Half>(int M) {
return true;
return brgemm_supported();
}
// this requires PyTorch 2.7 or above
template <>
inline bool can_use_brgemm<int8_t>(int M) {
return M > 4;
return brgemm_supported() && M > 4;
}
template <>
inline bool can_use_brgemm<uint8_t>(int M) {
return M > 4;
return brgemm_supported() && M > 4;
}
template <>
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
return M > 4;
return brgemm_supported() && M > 4;
}
// work around compiler internal error
+2
View File
@@ -11,7 +11,9 @@
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#if defined(CPU_CAPABILITY_AVX512)
#include <immintrin.h>
#endif
namespace {
using namespace at::vec;
+10 -8
View File
@@ -5,7 +5,7 @@
#include <sys/stat.h>
#include <unistd.h>
#ifdef __aarch64__
#if defined(__aarch64__) || defined(__powerpc64__)
#include <atomic>
#endif
@@ -38,7 +38,7 @@ struct KernelVecType<c10::Half> {
};
struct ThreadSHMContext {
#ifdef __aarch64__
#if defined(__aarch64__) || defined(__powerpc64__)
// memory model is weaker on AArch64, so we use atomic variables for
// consumer (load-acquire) and producer (store-release) to make sure
// that a stamp cannot be ready before the corresponding data is ready.
@@ -75,7 +75,7 @@ struct ThreadSHMContext {
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
TORCH_CHECK((size_t)this % 64 == 0);
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
#ifdef __aarch64__
#if defined(__aarch64__) || defined(__powerpc64__)
_curr_thread_stamp[0].store(1, std::memory_order_relaxed);
_curr_thread_stamp[1].store(1, std::memory_order_relaxed);
_ready_thread_stamp[0].store(0, std::memory_order_relaxed);
@@ -124,7 +124,7 @@ struct ThreadSHMContext {
}
char get_curr_stamp(int idx) const {
#ifdef __aarch64__
#if defined(__aarch64__) || defined(__powerpc64__)
return _curr_thread_stamp[idx].load(std::memory_order_acquire);
#else
return _curr_thread_stamp[idx];
@@ -132,7 +132,7 @@ struct ThreadSHMContext {
}
char get_ready_stamp(int idx) const {
#ifdef __aarch64__
#if defined(__aarch64__) || defined(__powerpc64__)
return _ready_thread_stamp[idx].load(std::memory_order_acquire);
#else
return _ready_thread_stamp[idx];
@@ -140,7 +140,7 @@ struct ThreadSHMContext {
}
void next_stamp() {
#ifdef __aarch64__
#if defined(__aarch64__) || defined(__powerpc64__)
_curr_thread_stamp[local_stamp_buffer_idx].fetch_add(
1, std::memory_order_release);
#else
@@ -150,7 +150,7 @@ struct ThreadSHMContext {
}
void commit_ready_stamp() {
#ifdef __aarch64__
#if defined(__aarch64__) || defined(__powerpc64__)
_ready_thread_stamp[local_stamp_buffer_idx].store(
_curr_thread_stamp[local_stamp_buffer_idx].load(
std::memory_order_relaxed),
@@ -186,8 +186,10 @@ struct ThreadSHMContext {
break;
}
++_spinning_count;
#ifdef __aarch64__
#if defined(__aarch64__)
__asm__ __volatile__("yield");
#elif defined(__powerpc64__)
__asm__ __volatile__("or 1,1,1");
#else
_mm_pause();
#endif // __aarch64__
+21 -20
View File
@@ -378,7 +378,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
// SHM CCL
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
defined(__powerpc64__)
ops.def(
"init_shm_manager(str name, int group_size, int rank, int thread_num) -> "
"int",
@@ -447,6 +448,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"bool is_vnni) -> Tensor");
ops.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);
// Adapted from sglang: casual_conv1d kernels
ops.def("causal_conv1d_weight_pack(Tensor weight) -> Tensor");
ops.impl("causal_conv1d_weight_pack", torch::kCPU,
&causal_conv1d_weight_pack);
ops.def(
"causal_conv1d_fwd_cpu(Tensor x, Tensor weight, Tensor? bias, Tensor? "
"conv_states, Tensor? query_start_loc,"
"Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation, "
"int pad_slot_id, bool is_vnni) -> "
"Tensor");
ops.impl("causal_conv1d_fwd_cpu", torch::kCPU, &causal_conv1d_fwd_cpu);
ops.def(
"causal_conv1d_update_cpu(Tensor x, Tensor(a!) conv_states, Tensor "
"weight, Tensor? bias, bool silu_activation,"
"Tensor? cache_seqlens, Tensor? conv_state_indices, int pad_slot_id, "
"bool is_vnni) -> Tensor");
ops.impl("causal_conv1d_update_cpu", torch::kCPU, &causal_conv1d_update_cpu);
#endif
// Adapted from sglang: GDN kernels
ops.def(
"chunk_gated_delta_rule_cpu(Tensor query, Tensor key, Tensor value, "
@@ -470,25 +490,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> (Tensor, Tensor)");
ops.impl("fused_gdn_gating_cpu", torch::kCPU, &fused_gdn_gating_cpu);
// Adapted from sglang: casual_conv1d kernels
ops.def("causal_conv1d_weight_pack(Tensor weight) -> Tensor");
ops.impl("causal_conv1d_weight_pack", torch::kCPU,
&causal_conv1d_weight_pack);
ops.def(
"causal_conv1d_fwd_cpu(Tensor x, Tensor weight, Tensor? bias, Tensor? "
"conv_states, Tensor? query_start_loc,"
"Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation, "
"int pad_slot_id, bool is_vnni) -> "
"Tensor");
ops.impl("causal_conv1d_fwd_cpu", torch::kCPU, &causal_conv1d_fwd_cpu);
ops.def(
"causal_conv1d_update_cpu(Tensor x, Tensor(a!) conv_states, Tensor "
"weight, Tensor? bias, bool silu_activation,"
"Tensor? cache_seqlens, Tensor? conv_state_indices, int pad_slot_id, "
"bool is_vnni) -> Tensor");
ops.impl("causal_conv1d_update_cpu", torch::kCPU, &causal_conv1d_update_cpu);
#endif
// CPU attention kernels
ops.def(
"get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
@@ -1,7 +1,11 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include "torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include "custom_all_reduce.cuh"
@@ -11,7 +15,7 @@ using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
torch::stable::Tensor& rank_data, int64_t rank,
bool fully_connected) {
int world_size = fake_ipc_ptrs.size();
if (world_size > 8)
@@ -25,9 +29,9 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
rank_data.numel(), rank, world_size,
fully_connected);
return (fptr_t) new vllm::CustomAllreduce(
ipc_ptrs, rank_data.mutable_data_ptr(), rank_data.numel(), rank,
world_size, fully_connected);
}
/**
@@ -46,10 +50,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
bool _is_weak_contiguous(torch::stable::Tensor& t) {
if (t.is_contiguous()) {
return true;
}
int64_t storage_nbytes = 0;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_storage_size(t.get(), &storage_nbytes));
return storage_nbytes - t.storage_offset() * t.element_size() ==
static_cast<int64_t>(t.numel() * t.element_size());
}
/**
@@ -59,42 +67,45 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
torch::stable::Tensor& out, fptr_t _reg_buffer,
int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
const torch::stable::accelerator::DeviceGuard device_guard(
inp.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(inp.get_device_index());
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_is_weak_contiguous(out));
TORCH_CHECK(_is_weak_contiguous(inp));
STD_TORCH_CHECK((inp.scalar_type()) == (out.scalar_type()));
STD_TORCH_CHECK((inp.numel()) == (out.numel()));
STD_TORCH_CHECK(_is_weak_contiguous(out));
STD_TORCH_CHECK(_is_weak_contiguous(inp));
auto input_size = inp.numel() * inp.element_size();
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
if (reg_buffer) {
TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
cudaMemcpyDeviceToDevice, stream));
STD_TORCH_CHECK((input_size) <= (reg_buffer_sz_bytes));
STD_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.const_data_ptr(), input_size,
cudaMemcpyDeviceToDevice, stream));
} else {
reg_buffer = inp.data_ptr();
reg_buffer = inp.mutable_data_ptr();
}
switch (out.scalar_type()) {
case at::ScalarType::Float: {
case torch::headeronly::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
reinterpret_cast<float*>(out.data_ptr()),
reinterpret_cast<float*>(out.mutable_data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
case torch::headeronly::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
reinterpret_cast<half*>(out.mutable_data_ptr()),
out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
case torch::headeronly::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
reinterpret_cast<nv_bfloat16*>(out.mutable_data_ptr()), out.numel());
break;
}
#endif
@@ -112,7 +123,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
STD_TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
void* ipc_ptrs[8];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
@@ -143,47 +154,49 @@ void register_graph_buffers(fptr_t _fa,
fa->register_graph_buffers(bytes, offsets);
}
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
std::tuple<fptr_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
int64_t size) {
auto device_index = c10::cuda::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
int device_index;
STD_CUDA_CHECK(cudaGetDevice(&device_index));
const torch::stable::accelerator::DeviceGuard device_guard(device_index);
void* buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = c10::cuda::getCurrentCUDAStream().stream();
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
const cudaStream_t stream = get_current_cuda_stream(device_index);
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Allocate buffer
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK(
STD_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
#else
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
STD_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
#endif
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
STD_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
STD_CUDA_CHECK(cudaStreamSynchronize(stream));
STD_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Create IPC memhandle for the allocated buffer.
// Will use it in open_mem_handle.
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handle =
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
AT_CUDA_CHECK(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
auto handle = torch::stable::empty(
{static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))},
torch::headeronly::ScalarType::Byte, std::nullopt,
torch::stable::Device(torch::stable::DeviceType::CPU));
STD_CUDA_CHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t*)handle.mutable_data_ptr(), buffer));
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
fptr_t open_mem_handle(torch::stable::Tensor& mem_handle) {
void* ipc_ptr;
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
STD_CUDA_CHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr,
*((const cudaIpcMemHandle_t*)mem_handle.const_data_ptr()),
cudaIpcMemLazyEnablePeerAccess));
return reinterpret_cast<fptr_t>(ipc_ptr);
}
void free_shared_buffer(fptr_t buffer) {
AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
STD_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
@@ -2,7 +2,7 @@
#include <torch/csrc/stable/tensor.h>
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
#include "broadcast_load_epilogue_c2x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
@@ -28,7 +28,20 @@
* [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token
*/
#include "torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include <cmath>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef USE_ROCM
#include <cuda_fp8.h>
#else
@@ -37,14 +50,6 @@
#include <cuda_runtime.h>
#include <type_traits>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/cuda.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef FINAL_MASK
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
@@ -70,7 +75,7 @@ namespace deepseek_v4_fused_ops {
namespace {
inline int getSMVersion() {
auto* props = at::cuda::getCurrentDeviceProperties();
auto* props = get_device_prop();
return props->major * 10 + props->minor;
}
} // namespace
@@ -564,7 +569,7 @@ static void launchFusedDeepseekV4Templated(
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
// unavailable there. Refuse the launch loudly instead of silently
// skipping the work.
TORCH_CHECK(
STD_TORCH_CHECK(
sm_version >= 80,
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ "
"(Ampere or newer); got sm_",
@@ -635,7 +640,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
DISPATCH(64)
DISPATCH(128)
default:
TORCH_CHECK(false,
STD_TORCH_CHECK(false,
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert: "
"unsupported num_heads_q_padded=",
num_heads_q_padded,
@@ -650,71 +655,80 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
// ────────────────────────────────────────────────────────────────────────────
// Torch op wrapper
// ────────────────────────────────────────────────────────────────────────────
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor const& q_in, // [N, num_heads_q, 512] bf16
torch::Tensor const& kv, // [N, 512] bf16 (read-only)
torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8
torch::Tensor const& slot_mapping, // [N] int64
torch::Tensor const& position_ids, // [N] int64
torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
int64_t q_head_padded, // padded Q head count for output
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::stable::Tensor const& q_in, // [N, num_heads_q, 512] bf16
torch::stable::Tensor const& kv, // [N, 512] bf16 (read-only)
torch::stable::Tensor& k_cache, // [num_blocks, block_bytes] uint8
torch::stable::Tensor const& slot_mapping, // [N] int64
torch::stable::Tensor const& position_ids, // [N] int64
torch::stable::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16
int64_t q_head_padded, // padded Q head count for output
double eps, int64_t cache_block_size) {
TORCH_CHECK(q_in.is_cuda() && q_in.is_contiguous(),
"q_in must be contiguous CUDA");
TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA");
TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA");
TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64,
"slot_mapping must be int64 CUDA");
TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64,
"position_ids must be int64 CUDA");
TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
"q_in shape [N, num_heads_q, 512]");
TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
TORCH_CHECK(q_in.dtype() == kv.dtype(), "q_in and kv dtype must match");
TORCH_CHECK(q_head_padded >= q_in.size(1),
"q_head_padded must be >= q_in.size(1) (num_heads_q)");
TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8");
TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
"cos_sin_cache shape [max_pos, 64]");
TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32,
"cos_sin_cache must be float32");
STD_TORCH_CHECK(q_in.device().is_cuda() && q_in.is_contiguous(),
"q_in must be contiguous CUDA");
STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
"kv must be contiguous CUDA");
STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
slot_mapping.scalar_type() ==
torch::headeronly::ScalarType::Long,
"slot_mapping must be int64 CUDA");
STD_TORCH_CHECK(position_ids.device().is_cuda() &&
position_ids.scalar_type() ==
torch::headeronly::ScalarType::Long,
"position_ids must be int64 CUDA");
STD_TORCH_CHECK(cos_sin_cache.device().is_cuda(), "cos_sin_cache must be CUDA");
STD_TORCH_CHECK(q_in.dim() == 3 && q_in.size(2) == 512,
"q_in shape [N, num_heads_q, 512]");
STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
STD_TORCH_CHECK(q_in.scalar_type() == kv.scalar_type(),
"q_in and kv dtype must match");
STD_TORCH_CHECK(q_head_padded >= q_in.size(1),
"q_head_padded must be >= q_in.size(1) (num_heads_q)");
STD_TORCH_CHECK(k_cache.scalar_type() == torch::headeronly::ScalarType::Byte,
"k_cache must be uint8");
STD_TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
"cos_sin_cache shape [max_pos, 64]");
STD_TORCH_CHECK(cos_sin_cache.scalar_type() ==
torch::headeronly::ScalarType::Float,
"cos_sin_cache must be float32");
// With DP padding, slot_mapping can be shorter than q/kv/positions.
// Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them);
// KV quant+insert runs only on the first slot_mapping.size(0) rows.
int const num_tokens_full = static_cast<int>(q_in.size(0));
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
static_cast<int>(position_ids.size(0)) == num_tokens_full,
"q/kv/position_ids row counts must match");
TORCH_CHECK(num_tokens_insert <= num_tokens_full,
"slot_mapping must not exceed q row count");
STD_TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
static_cast<int>(position_ids.size(0)) == num_tokens_full,
"q/kv/position_ids row counts must match");
STD_TORCH_CHECK(num_tokens_insert <= num_tokens_full,
"slot_mapping must not exceed q row count");
int const num_heads_q = static_cast<int>(q_in.size(1));
int const num_heads_q_padded = static_cast<int>(q_head_padded);
int const cache_block_size_i = static_cast<int>(cache_block_size);
int const kv_block_stride = static_cast<int>(k_cache.stride(0));
at::cuda::OptionalCUDAGuard device_guard(device_of(q_in));
auto stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
q_in.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(q_in.get_device_index());
// Allocate the padded q output. The kernel writes every element (live
// region gets RMSNorm+RoPE; pad region gets zeros), so `empty` is safe.
torch::Tensor q_out = torch::empty(
{q_in.size(0), q_head_padded, q_in.size(2)}, q_in.options());
auto q_out = torch::stable::new_empty(
q_in, {q_in.size(0), q_head_padded, q_in.size(2)}, q_in.scalar_type());
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
q_in.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] {
using qkv_scalar_t = scalar_t;
vllm::deepseek_v4_fused_ops::
launchFusedDeepseekV4QNormRopeKVRopeQuantInsert<qkv_scalar_t>(
reinterpret_cast<qkv_scalar_t const*>(q_in.data_ptr()),
reinterpret_cast<qkv_scalar_t*>(q_out.data_ptr()),
reinterpret_cast<qkv_scalar_t const*>(kv.data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.data_ptr()),
reinterpret_cast<int64_t const*>(slot_mapping.data_ptr()),
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
cos_sin_cache.data_ptr<float>(), static_cast<float>(eps),
reinterpret_cast<qkv_scalar_t const*>(q_in.const_data_ptr()),
reinterpret_cast<qkv_scalar_t*>(q_out.mutable_data_ptr()),
reinterpret_cast<qkv_scalar_t const*>(kv.const_data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
slot_mapping.const_data_ptr<int64_t>(),
position_ids.const_data_ptr<int64_t>(),
cos_sin_cache.const_data_ptr<float>(), static_cast<float>(eps),
num_tokens_full, num_tokens_insert, num_heads_q,
num_heads_q_padded, cache_block_size_i, kv_block_stride,
stream);
@@ -20,7 +20,7 @@
#include "torch_utils.h"
#include "../async_util.cuh"
#include "async_util.cuh"
#include "../cuda_compat.h"
#include "../type_convert.cuh"
#include "dispatch_utils.h"
@@ -15,16 +15,19 @@
* limitations under the License.
*/
#include "torch_utils.h"
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <torch/cuda.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "cuda_utils.h"
#include "core/registration.h"
#include "minimax_reduce_rms_kernel.h"
#include <algorithm>
@@ -611,7 +614,7 @@ int get_sm_count() {
static int sm_count = 0;
if (sm_count == 0) {
int device_id;
CUDA_CHECK(cudaGetDevice(&device_id));
STD_CUDA_CHECK(cudaGetDevice(&device_id));
cudaDeviceProp device_prop;
cudaGetDeviceProperties(&device_prop, device_id);
sm_count = device_prop.multiProcessorCount;
@@ -621,13 +624,13 @@ int get_sm_count() {
inline int getSMVersion(bool queryRealSmArch = false) {
int device{-1};
CUDA_CHECK(cudaGetDevice(&device));
STD_CUDA_CHECK(cudaGetDevice(&device));
int sm_major = 0;
int sm_minor = 0;
CUDA_CHECK(cudaDeviceGetAttribute(&sm_major,
cudaDevAttrComputeCapabilityMajor, device));
CUDA_CHECK(cudaDeviceGetAttribute(&sm_minor,
cudaDevAttrComputeCapabilityMinor, device));
STD_CUDA_CHECK(cudaDeviceGetAttribute(
&sm_major, cudaDevAttrComputeCapabilityMajor, device));
STD_CUDA_CHECK(cudaDeviceGetAttribute(
&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
int sm = sm_major * 10 + sm_minor;
if (sm == 121 && !queryRealSmArch) {
return 120;
@@ -639,7 +642,7 @@ template <typename KernelFunc>
int get_max_active_blocks(KernelFunc kernel, int block_size,
int dynamic_smem = 0) {
int max_active = 0;
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active, kernel, block_size, dynamic_smem));
return std::max(max_active, 1);
}
@@ -678,27 +681,27 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) {
cfg.attrs = attribute;
cfg.numAttrs = SM >= 90 ? 2 : 0;
CUDA_CHECK(cudaLaunchKernelEx(
STD_CUDA_CHECK(cudaLaunchKernelEx(
&cfg, minimax_reduce_rms_kernel_lamport<DType, NRanks>, params));
}
template <typename DType, int NRanks, int OriginQDim, int OriginKDim>
void minimax_reduce_rms_kernel_launcher_float4(
MiniMaxReduceRMSParams const& params) {
TORCH_CHECK(params.size_q % params.hidden_dim == 0);
TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.size_q % params.hidden_dim == 0);
STD_TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
if (params.stride_q > 0) {
TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
}
TORCH_CHECK(params.allreduce_in_k != nullptr,
"float4 QK kernel requires K input");
TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
TORCH_CHECK(params.size_q / params.hidden_dim ==
params.size_k / params.hidden_dim_k);
STD_TORCH_CHECK(params.allreduce_in_k != nullptr,
"float4 QK kernel requires K input");
STD_TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
STD_TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
STD_TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.size_q / params.hidden_dim ==
params.size_k / params.hidden_dim_k);
if (params.stride_k > 0) {
TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
STD_TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
}
int token_num = params.size_q / params.hidden_dim;
@@ -746,7 +749,7 @@ void minimax_reduce_rms_kernel_launcher_float4(
cfg.attrs = attribute;
cfg.numAttrs = SM >= 90 ? 2 : 0;
CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
STD_CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
}
template <int NRanks>
@@ -759,21 +762,21 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
(params.hidden_dim * params.nranks == 6144) &&
(params.hidden_dim_k * params.nranks == 1024);
if (params.dtype == at::ScalarType::Half) {
if (params.dtype == torch::headeronly::ScalarType::Half) {
if (use_float4) {
minimax_reduce_rms_kernel_launcher_float4<half, NRanks, 6144, 1024>(
params);
} else {
minimax_reduce_rms_kernel_launcher<half, NRanks>(params);
}
} else if (params.dtype == at::ScalarType::BFloat16) {
} else if (params.dtype == torch::headeronly::ScalarType::BFloat16) {
if (use_float4) {
minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144,
1024>(params);
} else {
minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params);
}
} else if (params.dtype == at::ScalarType::Float) {
} else if (params.dtype == torch::headeronly::ScalarType::Float) {
if (use_float4) {
minimax_reduce_rms_kernel_launcher_float4<float, NRanks, 6144, 1024>(
params);
@@ -781,7 +784,7 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
minimax_reduce_rms_kernel_launcher<float, NRanks>(params);
}
} else {
TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
STD_TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
}
}
@@ -795,16 +798,18 @@ void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) {
} else if (params.nranks == 16) {
dispatch_dtype<16>(params);
} else {
TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
STD_TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
}
}
} // namespace tensorrt_llm
} // namespace vllm
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
torch::Tensor const& norm_weight,
torch::Tensor workspace, int64_t const rank,
int64_t const nranks, double const eps) {
torch::stable::Tensor minimax_allreduce_rms(
torch::stable::Tensor const& input,
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
int64_t const rank, int64_t const nranks, double const eps) {
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto allreduce_params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
allreduce_params.nranks = static_cast<int>(nranks);
@@ -815,12 +820,12 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
allreduce_params.stride_q = allreduce_params.hidden_dim;
allreduce_params.workspace =
reinterpret_cast<void**>(workspace.mutable_data_ptr());
allreduce_params.allreduce_in = input.data_ptr();
allreduce_params.rms_gamma = norm_weight.data_ptr();
allreduce_params.allreduce_in = const_cast<void*>(input.const_data_ptr());
allreduce_params.rms_gamma = const_cast<void*>(norm_weight.const_data_ptr());
allreduce_params.rms_eps = static_cast<float>(eps);
allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
allreduce_params.stream = get_current_cuda_stream(input.get_device_index());
torch::Tensor rms_norm_out = torch::empty_like(input);
torch::stable::Tensor rms_norm_out = torch::stable::empty_like(input);
allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr();
vllm::tensorrt_llm::minimax_reduce_rms_op(allreduce_params);
@@ -828,26 +833,33 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
return rms_norm_out;
}
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
int64_t const q_size, int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps) {
TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
TORCH_CHECK(qkv.is_contiguous(),
"minimax_allreduce_rms_qk: qkv must be contiguous");
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
torch::stable::Tensor const& norm_weight_q,
torch::stable::Tensor const& norm_weight_k,
torch::stable::Tensor workspace, int64_t const q_size,
int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps) {
STD_TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
STD_TORCH_CHECK(qkv.is_contiguous(),
"minimax_allreduce_rms_qk: qkv must be contiguous");
int64_t qkv_dim = qkv.size(-1);
TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
"minimax_allreduce_rms_qk: qkv last dim must equal "
"q_size + 2 * kv_size");
TORCH_CHECK(rank < nranks,
"minimax_allreduce_rms_qk: rank must be less than nranks");
STD_TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
"minimax_allreduce_rms_qk: qkv last dim must equal "
"q_size + 2 * kv_size");
STD_TORCH_CHECK(rank < nranks,
"minimax_allreduce_rms_qk: rank must be less than nranks");
const torch::stable::accelerator::DeviceGuard device_guard(
qkv.get_device_index());
int64_t num_tokens = qkv.size(0);
int elem_bytes = qkv.element_size();
torch::Tensor q_out = torch::empty({num_tokens, q_size}, qkv.options());
torch::Tensor k_out = torch::empty({num_tokens, kv_size}, qkv.options());
torch::stable::Tensor q_out =
torch::stable::new_empty(qkv, {num_tokens, q_size}, qkv.scalar_type());
torch::stable::Tensor k_out =
torch::stable::new_empty(qkv, {num_tokens, kv_size}, qkv.scalar_type());
auto params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
params.nranks = static_cast<int>(nranks);
@@ -863,13 +875,14 @@ std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
params.stride_k_out = 0; // k_out is contiguous; kernel uses hidden_dim_k
params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
uint8_t* base = static_cast<uint8_t*>(qkv.data_ptr());
uint8_t* base =
const_cast<uint8_t*>(static_cast<const uint8_t*>(qkv.const_data_ptr()));
params.allreduce_in = base;
params.allreduce_in_k = base + q_size * elem_bytes;
params.rms_gamma = norm_weight_q.data_ptr();
params.rms_gamma_k = norm_weight_k.data_ptr();
params.rms_gamma = const_cast<void*>(norm_weight_q.const_data_ptr());
params.rms_gamma_k = const_cast<void*>(norm_weight_k.const_data_ptr());
params.rms_eps = static_cast<float>(eps);
params.stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
params.stream = get_current_cuda_stream(qkv.get_device_index());
params.rms_norm_out = q_out.mutable_data_ptr();
params.rms_norm_out_k = k_out.mutable_data_ptr();
@@ -0,0 +1,69 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& sfa,
const torch::stable::Tensor& sfb,
torch::stable::Tensor& d,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have shape (num_experts, 3)");
STD_TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32");
STD_TORCH_CHECK(
expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
"expert_offsets must be int32");
STD_TORCH_CHECK(
blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
"blockscale_offsets must be int32");
STD_TORCH_CHECK(a.dim() == 2,
"a must be a 2D tensor of shape (num_tokens, k)");
STD_TORCH_CHECK(b.dim() == 3,
"b must be a 3D tensor of shape (num_experts, k, n)");
STD_TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
"k should align 128");
STD_TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
STD_TORCH_CHECK(a.stride(1) == 1, "a must be row major");
STD_TORCH_CHECK(b.stride(1) == 1, "b must be column major");
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
auto stream = get_current_cuda_stream(a.get_device_index());
if (d.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
blockscale_offsets, stream);
} else if (d.scalar_type() == torch::headeronly::ScalarType::Half) {
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
blockscale_offsets, stream);
} else {
STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
STD_TORCH_CHECK(false,
"No implemented cutlass_mxfp8_grouped_mm for "
"current device");
#endif
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_mxfp8_grouped_mm", TORCH_BOX(&cutlass_mxfp8_grouped_mm));
}
@@ -4,9 +4,9 @@
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/Exception.h>
#include <cassert>
#include <iostream>
@@ -15,18 +15,22 @@
#include "cute/tensor.hpp"
#include "cutlass_mxfp8_grouped_mm_functor.cuh"
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
#include "libtorch_stable/torch_utils.h"
namespace expert_specialization {
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm_pre_compute(
torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs,
torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a,
torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa,
torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
torch::stable::Tensor& a_ptrs, torch::stable::Tensor& b_ptrs,
torch::stable::Tensor& sfa_ptrs, torch::stable::Tensor& sfb_ptrs,
torch::stable::Tensor& d_ptrs, torch::stable::Tensor& stride_a,
torch::stable::Tensor& stride_b, torch::stable::Tensor& stride_d,
torch::stable::Tensor& layout_sfa, torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
const torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor<GemmTraits>;
using ElementA = typename OffsetFunctor::ElementA;
using ElementB = typename OffsetFunctor::ElementB;
@@ -42,10 +46,10 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
using StrideB = typename StrideFunctor::StrideB;
using StrideD = typename StrideFunctor::StrideD;
int num_experts = (int)expert_offsets.size(0);
TORCH_CHECK(num_experts <= 1024,
"Number of experts cannot exceed 1024, the maximum number of "
"threads per block.");
int num_experts = static_cast<int>(expert_offsets.size(0));
STD_TORCH_CHECK(num_experts <= 1024,
"Number of experts cannot exceed 1024, the maximum number of "
"threads per block.");
OffsetFunctor offset_functor(
reinterpret_cast<int*>(expert_offsets.data_ptr()),
@@ -72,13 +76,18 @@ void cutlass_mxfp8_grouped_mm_pre_compute(
}
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm(
const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs,
const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs,
const torch::Tensor& d_ptrs, const torch::Tensor& stride_a,
const torch::Tensor& stride_b, const torch::Tensor& stride_d,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes, cudaStream_t stream) {
void cutlass_mxfp8_grouped_mm(const torch::stable::Tensor& a_ptrs,
const torch::stable::Tensor& b_ptrs,
const torch::stable::Tensor& sfa_ptrs,
const torch::stable::Tensor& sfb_ptrs,
const torch::stable::Tensor& d_ptrs,
const torch::stable::Tensor& stride_a,
const torch::stable::Tensor& stride_b,
const torch::stable::Tensor& stride_d,
const torch::stable::Tensor& layout_sfa,
const torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& problem_sizes,
cudaStream_t stream) {
using Gemm = typename GemmTraits::Gemm;
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
@@ -93,13 +102,12 @@ void cutlass_mxfp8_grouped_mm(
typename GemmTraits::ProblemShape::UnderlyingProblemShape;
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = c10::cuda::current_device();
hw_info.sm_count =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
hw_info.device_id = d_ptrs.get_device_index();
hw_info.sm_count = get_device_prop()->multiProcessorCount;
hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster;
hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster;
int num_experts = (int)problem_sizes.size(0);
int num_experts = static_cast<int>(problem_sizes.size(0));
UnderlyingProblemShape* underlying_problem_shape =
reinterpret_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
@@ -127,44 +135,55 @@ void cutlass_mxfp8_grouped_mm(
Gemm gemm;
auto can_implement_status = gemm.can_implement(arguments);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
STD_TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
torch::TensorOptions options_uint8 =
torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device());
size_t workspace_size = gemm.get_workspace_size(arguments);
torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
torch::stable::Tensor workspace = torch::stable::empty(
{static_cast<int64_t>(workspace_size)},
torch::headeronly::ScalarType::Byte, std::nullopt, d_ptrs.device());
auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM");
status = gemm.run(stream, nullptr, true); // Enable PDL
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
const torch::Tensor& sfb, torch::Tensor& d,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
int num_experts = (int)problem_sizes.size(0);
torch::TensorOptions options_int64 =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::TensorOptions options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(a.device());
const torch::stable::Tensor& a, const torch::stable::Tensor& b,
const torch::stable::Tensor& sfa, const torch::stable::Tensor& sfb,
torch::stable::Tensor& d, const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets, cudaStream_t stream) {
int num_experts = static_cast<int>(problem_sizes.size(0));
auto device = a.device();
torch::Tensor a_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64);
torch::Tensor d_ptrs = torch::empty(num_experts, options_int64);
torch::stable::Tensor a_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor sfa_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor sfb_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor d_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::Tensor stride_a = torch::empty(num_experts, options_int64);
torch::Tensor stride_b = torch::empty(num_experts, options_int64);
torch::Tensor stride_d = torch::empty(num_experts, options_int64);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32);
torch::stable::Tensor stride_a = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor stride_b = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor stride_d = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor layout_sfa =
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
std::nullopt, device);
torch::stable::Tensor layout_sfb =
torch::stable::empty({num_experts, 5}, torch::headeronly::ScalarType::Int,
std::nullopt, device);
using GemmTraits = CutlassMxfp8GroupedMmGemmTraits<MMA1SMConfig, OutType>;
cutlass_mxfp8_grouped_mm_pre_compute<GemmTraits>(
@@ -176,4 +195,4 @@ void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
layout_sfa, layout_sfb, problem_sizes, stream);
}
} // namespace expert_specialization
} // namespace expert_specialization
@@ -0,0 +1,66 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "mxfp8_experts_quant.cuh"
void mxfp8_experts_quant(const torch::stable::Tensor& input,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets,
torch::stable::Tensor& quant_output,
torch::stable::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
STD_TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
STD_TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
STD_TORCH_CHECK(input.stride(1) == 1, "input must be row major");
STD_TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32");
STD_TORCH_CHECK(
expert_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
"expert_offsets must be int32");
STD_TORCH_CHECK(
blockscale_offsets.scalar_type() == torch::headeronly::ScalarType::Int,
"blockscale_offsets must be int32");
auto groups = problem_sizes.size(0);
STD_TORCH_CHECK(
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
"expert_offsets must be 1D and have size equal to the number of groups");
STD_TORCH_CHECK(
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
"blockscale_offsets must be 1D and have size equal to the number of "
"groups");
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
if (input.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
scale_factor);
} else if (input.scalar_type() == torch::headeronly::ScalarType::Half) {
expert_specialization::launch_mxfp8_experts_quant<__half>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
scale_factor);
} else {
STD_TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
STD_TORCH_CHECK(false,
"No implemented mxfp8_experts_quant for "
"current device");
#endif
}
// Registered here (not torch_bindings.cpp) because ENABLE_ES_MXFP8_GROUPED_MM
// is applied only under COMPILE_LANGUAGE:CUDA.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("mxfp8_experts_quant", TORCH_BOX(&mxfp8_experts_quant));
}
@@ -4,16 +4,19 @@
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/all.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/Exception.h>
#include <cuda/ptx>
#include "cute/tensor.hpp"
#include "libtorch_stable/torch_utils.h"
namespace expert_specialization {
@@ -356,12 +359,12 @@ __global__ void mxfp8_experts_quant_kernel(
}
template <typename T_IN>
void launch_mxfp8_experts_quant(const torch::Tensor& input,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets,
torch::Tensor& quant_output,
torch::Tensor& scale_factor) {
void launch_mxfp8_experts_quant(const torch::stable::Tensor& input,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& blockscale_offsets,
torch::stable::Tensor& quant_output,
torch::stable::Tensor& scale_factor) {
ThrLayout thr_layout{};
ValLayout val_layout{};
SfR2SThrLayout r2s_thr_layout{};
@@ -386,19 +389,18 @@ void launch_mxfp8_experts_quant(const torch::Tensor& input,
CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4)
int max_active_blocks_per_sm = -1;
AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
STD_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_per_sm,
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
decltype(tiled_copy_r2g),
decltype(tiled_copy_r2s)>,
THREAD_BLOCK_SIZE, 0));
dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount *
max_active_blocks_per_sm,
dim3 grid(get_device_prop()->multiProcessorCount * max_active_blocks_per_sm,
1, 1);
dim3 block(THREAD_BLOCK_SIZE, 1, 1);
int num_experts = (int)problem_sizes.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
int num_experts = static_cast<int>(problem_sizes.size(0));
auto stream = get_current_cuda_stream(input.get_device_index());
mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
decltype(tiled_copy_r2g), decltype(tiled_copy_r2s)>
<<<grid, block, 0, stream>>>(
+45 -4
View File
@@ -3,10 +3,6 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#ifndef USE_ROCM
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
torch::stable::Tensor const& perm);
void per_token_group_quant_fp8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s,
@@ -28,6 +24,10 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
int64_t group_size, double eps, double int8_min,
double int8_max);
#ifndef USE_ROCM
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
torch::stable::Tensor const& perm);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
@@ -231,6 +231,27 @@ void fused_qk_norm_rope(torch::stable::Tensor& qkv, int64_t num_heads_q,
torch::stable::Tensor& position_ids,
int64_t forced_token_heads_per_warp);
torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::stable::Tensor const& q_in, torch::stable::Tensor const& kv,
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
torch::stable::Tensor const& position_ids,
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
double eps, int64_t cache_block_size);
#ifndef USE_ROCM
torch::stable::Tensor minimax_allreduce_rms(
torch::stable::Tensor const& input,
torch::stable::Tensor const& norm_weight, torch::stable::Tensor workspace,
int64_t const rank, int64_t const nranks, double const eps);
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
minimax_allreduce_rms_qk(torch::stable::Tensor qkv,
torch::stable::Tensor const& norm_weight_q,
torch::stable::Tensor const& norm_weight_k,
torch::stable::Tensor workspace, int64_t const q_size,
int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps);
#endif
// Sampler kernels (shared CUDA/ROCm)
void apply_repetition_penalties_(
torch::stable::Tensor& logits, const torch::stable::Tensor& prompt_mask,
@@ -273,6 +294,26 @@ void selective_scan_fwd(
const std::optional<torch::stable::Tensor>& cu_chunk_seqlen,
const std::optional<torch::stable::Tensor>& last_chunk_indices);
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::stable::Tensor& rank_data, int64_t rank,
bool fully_connected);
void all_reduce(fptr_t _fa, torch::stable::Tensor& inp,
torch::stable::Tensor& out, fptr_t reg_buffer,
int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
std::tuple<int64_t, torch::stable::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::stable::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
// Activation kernels (shared CUDA/ROCm)
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void silu_and_mul_clamp(torch::stable::Tensor& out,
@@ -25,7 +25,7 @@
#include <cuda_fp8.h>
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
#include "libtorch_stable/launch_bounds_utils.h"
// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
@@ -34,7 +34,7 @@
static_assert(CVT_FP4_ELTS_PER_THREAD == 16,
"MXFP4 experts quant requires PACK16 mode (CUDA >= 12.9)");
#include "launch_bounds_utils.h"
#include "libtorch_stable/launch_bounds_utils.h"
namespace vllm {
@@ -26,7 +26,7 @@
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
#include "launch_bounds_utils.h"
#include "libtorch_stable/launch_bounds_utils.h"
namespace vllm {
@@ -26,7 +26,7 @@
#include "../../cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
#include "libtorch_stable/launch_bounds_utils.h"
// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
@@ -7,14 +7,11 @@
#include <torch/csrc/stable/ops.h>
// NOTE: These headers are intentionally kept in csrc/quantization/gguf/ (not
// moved to libtorch_stable) to avoid unnecessary reformatting that would break
// git rename detection and pollute blame history.
#include "../../../quantization/gguf/ggml-common.h"
#include "../../../quantization/gguf/vecdotq.cuh"
#include "../../../quantization/gguf/dequantize.cuh"
#include "../../../quantization/gguf/mmvq.cuh"
#include "../../../quantization/gguf/mmq.cuh"
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"
#include "moe.cuh"
#include "moe_vec.cuh"
@@ -7,7 +7,11 @@
#include <cmath>
#include <cuda_fp8.h>
#ifdef USE_ROCM
#include <hip/hip_fp8.h>
#else
#include <cuda_fp8.h>
#endif
#include "libtorch_stable/quantization/vectorization.cuh"
#include "libtorch_stable/quantization/vectorization_utils.cuh"
@@ -15,12 +19,23 @@
#include "libtorch_stable/torch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val) {
#ifdef USE_ROCM
// 16-thread logical groups may pack up to four per 64-lane wavefront; use a
// 64-bit mask and explicit width so shuffles stay within each group.
const int lane_in_wave = threadIdx.x % warpSize;
const unsigned long long mask = 0xFFFFull << ((lane_in_wave / 16) * 16);
val = fmaxf(val, __shfl_xor_sync(mask, val, 8, 16));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4, 16));
val = fmaxf(val, __shfl_xor_sync(mask, val, 2, 16));
val = fmaxf(val, __shfl_xor_sync(mask, val, 1, 16));
#else
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
#endif
return val;
}
@@ -103,10 +118,18 @@ __device__ __forceinline__ float LoadRegisterGroupAndComputeAbsmax(
}
__device__ __forceinline__ float GroupReduceMax8(float val) {
#ifdef USE_ROCM
const int lane_in_wave = threadIdx.x % warpSize;
const unsigned long long mask = 0xFFull << (lane_in_wave & ~7);
val = fmaxf(val, __shfl_xor_sync(mask, val, 4, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 2, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 1, 8));
#else
unsigned mask = 0xffu << (threadIdx.x & 24u);
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
#endif
return val;
}
@@ -684,15 +707,12 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "per_token_group_quant_8bit_packed_register", ([&] {
if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) {
LAUNCH_REG_KERNEL(scalar_t, __nv_fp8_e4m3);
} else if (dst_type == torch::headeronly::ScalarType::Char) {
if (dst_type == torch::headeronly::ScalarType::Char) {
LAUNCH_REG_KERNEL(scalar_t, int8_t);
} else {
STD_TORCH_CHECK(
false,
"per_token_group_quant_8bit_packed only supports FP8/INT8 "
"outputs.");
VLLM_STABLE_DISPATCH_FP8_TYPES(
dst_type, "per_token_group_quant_8bit_packed_fp8",
([&] { LAUNCH_REG_KERNEL(scalar_t, fp8_t); }));
}
}));
+1 -1
View File
@@ -7,7 +7,7 @@
#include "torch_utils.h"
#ifndef USE_ROCM
#include "../persistent_topk.cuh"
#include "persistent_topk.cuh"
#endif
namespace {
+73 -10
View File
@@ -7,11 +7,6 @@
// Note: We register under namespace "_C" so ops are accessible as
// torch.ops._C.<op_name> for compatibility with existing code.
STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
#ifndef USE_ROCM
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
#ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
// The dummy arguments are here so we can correctly fuse with RMSNorm.
ops.def(
@@ -32,6 +27,11 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
#ifndef USE_ROCM
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
#ifndef USE_ROCM
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
@@ -337,6 +337,24 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"bool is_neox, Tensor position_ids, "
"int forced_token_heads_per_warp=-1) -> ()");
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor q_in, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
#ifndef USE_ROCM
ops.def(
"minimax_allreduce_rms("
"Tensor input, Tensor norm_weight, Tensor workspace, "
"int rank, int nranks, float eps) -> Tensor");
ops.def(
"minimax_allreduce_rms_qk("
"Tensor qkv, Tensor norm_weight_q, Tensor norm_weight_k, "
"Tensor workspace, int q_size, int kv_size, int rank, int nranks, "
"float eps) -> (Tensor, Tensor)");
#endif
// Apply repetition penalties to logits in-place.
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
@@ -508,11 +526,6 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
#ifndef USE_ROCM
ops.impl("permute_cols", TORCH_BOX(&permute_cols));
#endif
#ifndef USE_ROCM
// Per-token group quantization
ops.impl("per_token_group_fp8_quant", TORCH_BOX(&per_token_group_quant_fp8));
ops.impl("per_token_group_fp8_quant_packed",
@@ -520,6 +533,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("per_token_group_quant_int8",
TORCH_BOX(&per_token_group_quant_int8));
#ifndef USE_ROCM
ops.impl("permute_cols", TORCH_BOX(&permute_cols));
#endif
#ifndef USE_ROCM
// CUTLASS scaled_mm ops
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
@@ -571,6 +589,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
// Positional encoding kernels (shared CUDA/ROCm)
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
#ifndef USE_ROCM
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
#endif
// Sampler kernels (shared CUDA/ROCm)
ops.impl("apply_repetition_penalties_",
@@ -725,6 +749,45 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
}
STABLE_TORCH_LIBRARY_FRAGMENT(_C_custom_ar, custom_ar) {
custom_ar.def(
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool fully_connected) -> int");
custom_ar.def(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.def("dispose(int fa) -> ()");
custom_ar.def("meta_size() -> int");
custom_ar.def("register_buffer(int fa, int[] ipc_tensors) -> ()");
custom_ar.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
custom_ar.def(
"register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
custom_ar.def("allocate_shared_buffer_and_handle(int size) -> (int, Tensor)");
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int");
custom_ar.def("free_shared_buffer(int ptr) -> ()");
}
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CUDA, custom_ar) {
custom_ar.impl("init_custom_ar", TORCH_BOX(&init_custom_ar));
custom_ar.impl("all_reduce", TORCH_BOX(&all_reduce));
}
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CPU, custom_ar) {
custom_ar.impl("open_mem_handle", TORCH_BOX(&open_mem_handle));
}
STABLE_TORCH_LIBRARY_IMPL(_C_custom_ar, CompositeExplicitAutograd, custom_ar) {
custom_ar.impl("dispose", TORCH_BOX(&dispose));
custom_ar.impl("meta_size", TORCH_BOX(&meta_size));
custom_ar.impl("register_buffer", TORCH_BOX(&register_buffer));
custom_ar.impl("get_graph_buffer_ipc_meta",
TORCH_BOX(&get_graph_buffer_ipc_meta));
custom_ar.impl("register_graph_buffers", TORCH_BOX(&register_graph_buffers));
custom_ar.impl("allocate_shared_buffer_and_handle",
TORCH_BOX(&allocate_shared_buffer_and_handle));
custom_ar.impl("free_shared_buffer", TORCH_BOX(&free_shared_buffer));
}
STABLE_TORCH_LIBRARY_IMPL(_C_cache_ops, CPU, ops) {
ops.impl("swap_blocks_batch", TORCH_BOX(&swap_blocks_batch));
}
+1 -5
View File
@@ -6,11 +6,7 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>
#ifndef USE_ROCM
#include <cuda_runtime.h>
#else
#include <hip/hip_runtime.h>
#endif
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <deque>
+2 -2
View File
@@ -19,7 +19,7 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/types.h>
#include <torch/headeronly/core/ScalarType.h>
namespace vllm {
namespace tensorrt_llm {
@@ -51,7 +51,7 @@ static constexpr int kElemsPerAccess = ElemsPerAccess<DType>::value;
struct MiniMaxReduceRMSParams {
int nranks{};
int rank{};
at::ScalarType dtype{at::ScalarType::Undefined};
torch::headeronly::ScalarType dtype{torch::headeronly::ScalarType::Undefined};
int size_q{};
int hidden_dim{};
int size_k{};
@@ -1,60 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
#include <torch/all.h>
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& sfa,
const torch::Tensor& sfb, torch::Tensor& d,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
"expert_offsets must be int32");
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
"blockscale_offsets must be int32");
TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
TORCH_CHECK(b.dim() == 3,
"b must be a 3D tensor of shape (num_experts, k, n)");
TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
"k should align 128");
TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
TORCH_CHECK(b.strides()[1] == 1, "b must be column major");
auto stream = at::cuda::getCurrentCUDAStream();
if (d.dtype() == torch::kBFloat16) {
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
blockscale_offsets, stream);
} else if (d.dtype() == torch::kFloat16) {
expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
blockscale_offsets, stream);
} else {
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
TORCH_CHECK(false,
"No implemented cutlass_mxfp8_grouped_mm for "
"current device");
#endif
}
#include "core/registration.h"
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm);
}
-60
View File
@@ -1,60 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
#include <torch/all.h>
#include "mxfp8_experts_quant.cuh"
void mxfp8_experts_quant(const torch::Tensor& input,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets,
torch::Tensor& quant_output,
torch::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
"expert_offsets must be int32");
TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
"blockscale_offsets must be int32");
auto groups = problem_sizes.size(0);
TORCH_CHECK(
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
"expert_offsets must be 1D and have size equal to the number of groups");
TORCH_CHECK(
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
"blockscale_offsets must be 1D and have size equal to the number of "
"groups");
auto stream = at::cuda::getCurrentCUDAStream();
if (input.dtype() == torch::kBFloat16) {
expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
scale_factor);
} else if (input.dtype() == torch::kFloat16) {
expert_specialization::launch_mxfp8_experts_quant<__half>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
scale_factor);
} else {
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
TORCH_CHECK(false,
"No implemented mxfp8_experts_quant for "
"current device");
#endif
}
#include "core/registration.h"
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("mxfp8_experts_quant", mxfp8_experts_quant);
}
-36
View File
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor const& q_in, torch::Tensor const& kv, torch::Tensor& k_cache,
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
int64_t cache_block_size);
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
@@ -107,24 +101,6 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
int64_t activation_kind);
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
bool fully_connected);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
#ifdef USE_ROCM
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
std::optional<int64_t> qr_max_size = std::nullopt);
@@ -135,15 +111,3 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#endif
#ifndef USE_ROCM
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
torch::Tensor const& norm_weight,
torch::Tensor workspace, int64_t const rank,
int64_t const nranks, double const eps);
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
torch::Tensor qkv, torch::Tensor const& norm_weight_q,
torch::Tensor const& norm_weight_k, torch::Tensor workspace,
int64_t const q_size, int64_t const kv_size, int64_t const rank,
int64_t const nranks, double const eps);
#endif
+20 -78
View File
@@ -55,14 +55,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor q_in, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
// kernel launch. Registered in _C_stable_libtorch.
// Quantization ops
#ifndef USE_ROCM
@@ -163,34 +156,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// conditionally compiled so impl registration is in source file
#endif
#ifndef USE_ROCM
ops.def(
"minimax_allreduce_rms("
"Tensor input,"
"Tensor norm_weight,"
"Tensor workspace,"
"int rank,"
"int nranks,"
"float eps) -> Tensor");
ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms);
ops.def(
"minimax_allreduce_rms_qk("
"Tensor qkv,"
"Tensor norm_weight_q,"
"Tensor norm_weight_k,"
"Tensor workspace,"
"int q_size,"
"int kv_size,"
"int rank,"
"int nranks,"
"float eps) -> (Tensor, Tensor)");
ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk);
// conditionally compiled so impl in source file
#endif
}
#ifdef USE_ROCM
TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Quick Reduce all-reduce kernels (ROCm-only; stays on legacy _C).
custom_ar.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
custom_ar.def("init_custom_qr", &init_custom_qr);
custom_ar.def("qr_destroy", &qr_destroy);
custom_ar.def("qr_get_handle", &qr_get_handle);
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
custom_ar.def("qr_max_size", &qr_max_size);
}
#endif
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils
@@ -205,48 +191,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
&get_max_shared_memory_per_block_device_attribute);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def(
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool fully_connected) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
custom_ar.def("dispose", &dispose);
custom_ar.def("meta_size", &meta_size);
custom_ar.def("register_buffer", &register_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.def("register_graph_buffers", &register_graph_buffers);
custom_ar.def("allocate_shared_buffer_and_handle",
&allocate_shared_buffer_and_handle);
custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle);
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
custom_ar.def("free_shared_buffer", &free_shared_buffer);
#ifdef USE_ROCM
// Quick Reduce all-reduce kernels
custom_ar.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
custom_ar.def("init_custom_qr", &init_custom_qr);
custom_ar.def("qr_destroy", &qr_destroy);
custom_ar.def("qr_get_handle", &qr_get_handle);
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
// Max input size in bytes
custom_ar.def("qr_max_size", &qr_max_size);
#endif
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
+2 -2
View File
@@ -50,7 +50,7 @@ struct _typeConvert<float> {
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
struct _typeConvert<torch::headeronly::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
@@ -73,7 +73,7 @@ struct _typeConvert<c10::Half> {
// CUDA_ARCH < 800 does not have BF16 support
// ROCm 7.0+ supports bfloat16
template <>
struct _typeConvert<c10::BFloat16> {
struct _typeConvert<torch::headeronly::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
+2 -2
View File
@@ -757,10 +757,10 @@ RUN --mount=type=cache,target=/opt/uv/cache \
# Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
# https://docs.flashinfer.ai/installation.html
# From versions.json: .flashinfer.version
ARG FLASHINFER_VERSION=0.6.11.post2
ARG FLASHINFER_VERSION=0.6.12
RUN --mount=type=cache,target=/opt/uv/cache \
uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
--index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# ============================================================
# OPENAI API SERVER DEPENDENCIES
+2 -2
View File
@@ -256,13 +256,13 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
# build flashinfer for torch nightly from source around 10 mins
# release version: v0.6.11.post2
# release version: v0.6.12
# todo(elainewy): cache flashinfer build result for faster build
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
echo "git clone flashinfer..." \
&& git clone --depth 1 --branch v0.6.11.post2 --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& git clone --depth 1 --branch v0.6.12 --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& cd flashinfer \
&& git submodule update --init --recursive \
&& echo "finish git clone flashinfer..." \
+200 -108
View File
@@ -2,6 +2,7 @@
ARG REMOTE_VLLM="0"
ARG COMMON_WORKDIR=/app
ARG BASE_IMAGE=rocm/vllm-dev:base
ARG CI_BASE_IMAGE=rocm/vllm-dev:ci_base
# NIC backend for MoRI RDMA support.
# By default (all), drivers and userspace libraries for all supported NIC types
# (ainic and bnxt) are installed; MoRI selects the appropriate one at runtime.
@@ -16,7 +17,8 @@ ARG NIC_BACKEND=all
ARG AINIC_VERSION=1.117.3-hydra
ARG UBUNTU_CODENAME=jammy
# Sccache configuration (only used in release pipeline)
# Sccache configuration. Release builds use this today; CI can opt in when a
# shared S3-compatible cache backend is available.
ARG USE_SCCACHE
ARG SCCACHE_DOWNLOAD_URL
ARG SCCACHE_ENDPOINT
@@ -29,12 +31,16 @@ FROM ${BASE_IMAGE} AS base
ARG ARG_PYTORCH_ROCM_ARCH
ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}
# Install some basic utilities
# Install build dependencies and utilities
RUN apt-get update -q -y && apt-get install -q -y \
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev \
apt-transport-https ca-certificates wget curl \
libnuma-dev
RUN python3 -m pip install --upgrade pip
libnuma-dev ccache mold
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade pip
# Note: mold is installed but not set as the system default linker because
# some packages use JIT compilation at runtime with flags mold does not support.
# Build stages opt in via LDFLAGS="-fuse-ld=mold".
# Remove sccache only if not using sccache (it exists in base image from Dockerfile.rocm_base)
ARG USE_SCCACHE
RUN if [ "$USE_SCCACHE" != "1" ]; then \
@@ -55,6 +61,12 @@ ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
# ccache directory - persisted across layer rebuilds via cache mounts.
ENV CCACHE_DIR=/root/.cache/ccache
ENV CCACHE_COMPILERCHECK=content
# Empty by default so build steps fall back to $(nproc); CI can override.
ARG max_jobs
ENV MAX_JOBS=${max_jobs}
# Install sccache if USE_SCCACHE is enabled (for release builds)
ARG USE_SCCACHE
@@ -86,6 +98,7 @@ RUN if [ "$USE_SCCACHE" = "1" ]; then \
ARG USE_SCCACHE
ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET_NAME}}
ENV SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION_NAME}}
ENV SCCACHE_ENDPOINT=${USE_SCCACHE:+${SCCACHE_ENDPOINT}}
ENV SCCACHE_S3_NO_CREDENTIALS=${USE_SCCACHE:+${SCCACHE_S3_NO_CREDENTIALS}}
ENV SCCACHE_IDLE_TIMEOUT=${USE_SCCACHE:+0}
@@ -114,8 +127,7 @@ FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm
# -----------------------
# Rust build stage
# Builds the `vllm-rs` frontend in a dedicated stage so the wheel build stages
# don't need the rust toolchain or protoc. Runs in parallel with the main wheel
# build for faster end-to-end builds.
# don't need the rust toolchain or protoc.
FROM fetch_vllm AS rust-build
ARG COMMON_WORKDIR
@@ -144,24 +156,74 @@ ENV RUSTUP_MAX_RETRIES=10
# layer for later COPY --from=rust-build.
RUN --mount=type=cache,id=vllm-rocm-cargo-registry,target=/root/.cargo/registry,sharing=locked \
--mount=type=cache,id=vllm-rocm-cargo-git,target=/root/.cargo/git,sharing=locked \
--mount=type=cache,id=vllm-rocm-cargo-target,target=${COMMON_WORKDIR}/vllm/rust/target,sharing=locked \
cd ${COMMON_WORKDIR}/vllm \
&& VLLM_RS_TARGET_PATH=/tmp/vllm-rs bash build_rust.sh \
&& test -x /tmp/vllm-rs
# -----------------------
# vLLM build stages
# vLLM native build stages
#
# csrc-build intentionally copies only files that affect ROCm native extension
# compilation. That keeps unrelated CI/test/docs edits from invalidating the
# expensive HIP/C++ build layer.
FROM base AS csrc-build
ARG COMMON_WORKDIR
WORKDIR ${COMMON_WORKDIR}/vllm
COPY requirements/rocm.txt requirements/rocm.txt
COPY requirements/common.txt requirements/common.txt
RUN --mount=type=cache,id=vllm-rocm-uv,target=/root/.cache/uv \
uv pip install --system -r requirements/rocm.txt
# pyproject.toml is bind-mounted in the RUN step so metadata-only changes do
# not invalidate the expensive native build layer.
COPY setup.py CMakeLists.txt ./
COPY cmake cmake/
COPY csrc csrc/
COPY vllm/envs.py vllm/envs.py
COPY vllm/__init__.py vllm/__init__.py
ENV VLLM_TARGET_DEVICE=rocm
ENV SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0+rocm.csrc.build"
RUN --mount=type=bind,source=pyproject.toml,target=${COMMON_WORKDIR}/vllm/pyproject.toml \
--mount=type=cache,id=vllm-rocm-ccache,target=/root/.cache/ccache \
export CCACHE_BASEDIR="$PWD" \
&& echo "=== ccache stats before ROCm native build ===" \
&& (ccache --show-stats || true) \
&& (ccache --zero-stats || true) \
&& EFFECTIVE_MAX_JOBS="${MAX_JOBS:-$(nproc)}" \
&& echo "Building ROCm native extension wheel with MAX_JOBS=${EFFECTIVE_MAX_JOBS}" \
&& LDFLAGS="-fuse-ld=mold" MAX_JOBS="${EFFECTIVE_MAX_JOBS}" python3 setup.py bdist_wheel --dist-dir=dist \
&& test -d dist \
&& ls dist/*.whl >/dev/null \
&& echo "=== ccache stats after ROCm native build ===" \
&& (ccache --show-stats || true)
# Build the full vLLM ROCm wheel by reusing the native extension wheel from
# csrc-build. This stage still rebuilds for Python/package changes, but skips
# the expensive HIP/C++ compile when native inputs are unchanged.
FROM fetch_vllm AS build_vllm
ARG COMMON_WORKDIR
ENV VLLM_TARGET_DEVICE=rocm
COPY --from=csrc-build ${COMMON_WORKDIR}/vllm/dist /precompiled-wheels
# Drop the pre-built rust frontend binary into the source tree. setup.py
# detects it and ships it as-is, skipping the local cargo build.
COPY --from=rust-build /tmp/vllm-rs ${COMMON_WORKDIR}/vllm/vllm/vllm-rs
# Build vLLM (setup.py auto-detects sccache in PATH)
RUN cd vllm \
&& python3 -m pip install -r requirements/rocm.txt \
&& python3 setup.py clean --all \
&& python3 setup.py bdist_wheel --dist-dir=dist
RUN --mount=type=cache,id=vllm-rocm-uv,target=/root/.cache/uv \
cd vllm \
&& uv pip install --system -r requirements/rocm.txt \
&& export VLLM_USE_PRECOMPILED=1 \
&& export VLLM_PRECOMPILED_WHEEL_LOCATION="$(ls /precompiled-wheels/*.whl)" \
&& export VLLM_DOCKER_BUILD_CONTEXT=1 \
&& echo "Packaging vLLM ROCm wheel using precompiled extensions from ${VLLM_PRECOMPILED_WHEEL_LOCATION}" \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& test -d dist \
&& ls dist/*.whl >/dev/null
FROM scratch AS export_vllm
ARG COMMON_WORKDIR
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
@@ -171,6 +233,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/pyproject.toml /pyproject.toml
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1
# RIXL/UCX build stages
@@ -201,14 +264,17 @@ RUN apt-get -y update && apt-get -y install autoconf libtool pkg-config \
ibverbs-providers \
&& rm -rf /var/lib/apt/lists/*
RUN uv pip install --system meson auditwheel patchelf tomlkit
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system meson auditwheel patchelf tomlkit
RUN cd /usr/local/src && \
RUN --mount=type=cache,target=/root/.cache/ccache \
cd /usr/local/src && \
git clone ${UCX_REPO} && \
cd ucx && \
git checkout ${UCX_BRANCH} && \
./autogen.sh && \
mkdir build && cd build && \
CC="ccache gcc" CXX="ccache g++" \
../configure \
--prefix=/usr/local/ucx \
--enable-shared \
@@ -220,20 +286,22 @@ RUN cd /usr/local/src && \
--with-verbs \
--with-dm \
--enable-mt && \
make -j && \
make -j$(nproc) && \
make install
ENV PATH=/usr/local/ucx/bin:$PATH
ENV LD_LIBRARY_PATH=${UCX_HOME}/lib:${LD_LIBRARY_PATH}
RUN git clone ${RIXL_REPO} /opt/rixl && \
RUN --mount=type=cache,target=/root/.cache/ccache \
git clone ${RIXL_REPO} /opt/rixl && \
cd /opt/rixl && \
git checkout ${RIXL_BRANCH} && \
CC="ccache gcc" CXX="ccache g++" \
meson setup build --prefix=${RIXL_HOME} \
-Ducx_path=${UCX_HOME} \
-Drocm_path=${ROCM_PATH} && \
cd build && \
ninja && \
ninja -j$(nproc) && \
ninja install
# Generate RIXL wheel
@@ -250,30 +318,44 @@ RUN cd /opt/rixl && \
--ucx-plugins-dir ${UCX_HOME}/lib/ucx \
--nixl-plugins-dir ${RIXL_HOME}/lib/x86_64-linux-gnu/plugins
# DeepEP build stage
FROM base AS build_deep
# ROCShmem build stage - split from DeepEP so changing DEEPEP_BRANCH does not
# invalidate the slow ROCShmem build.
FROM base AS build_rocshmem
ARG ROCSHMEM_BRANCH="f0acb0c6"
ARG ROCSHMEM_REPO="https://github.com/ROCm/rocm-systems.git"
ARG DEEPEP_BRANCH="a9ea9774"
ARG DEEPEP_REPO="https://github.com/ROCm/DeepEP.git"
ARG DEEPEP_NIC="cx7"
# DeepEP only supports gfx942 and gfx950; build ROCShmem for the same set so
# it can be linked against DeepEP without arch mismatches.
ARG DEEPEP_ROCM_ARCH="gfx942;gfx950"
ENV ROCM_PATH=/opt/rocm
ENV ROCSHMEM_DIR=/opt/rocshmem
RUN git clone ${ROCSHMEM_REPO} \
RUN --mount=type=cache,target=/root/.cache/ccache \
git clone --no-checkout --filter=blob:none ${ROCSHMEM_REPO} \
&& cd rocm-systems \
&& git sparse-checkout set --cone projects/rocshmem \
&& git checkout ${ROCSHMEM_BRANCH} \
&& mkdir -p projects/rocshmem/build \
&& cd projects/rocshmem/build \
&& INSTALL_PREFIX=${ROCSHMEM_DIR} \
../scripts/build_configs/all_backends -DUSE_EXTERNAL_MPI=OFF
&& CC="ccache gcc" CXX="ccache g++" INSTALL_PREFIX=${ROCSHMEM_DIR} \
bash ../scripts/build_configs/all_backends \
-DROCM_PATH=${ROCM_PATH} \
-DGPU_TARGETS="${DEEPEP_ROCM_ARCH}" \
-DUSE_EXTERNAL_MPI=OFF
# Build DeepEP wheel.
# DeepEP looks for rocshmem at ROCSHMEM_DIR.
RUN git clone ${DEEPEP_REPO} \
# DeepEP build stage - depends on ROCShmem, builds the HIP kernel wheel.
FROM build_rocshmem AS build_deepep
ARG DEEPEP_BRANCH="a9ea9774"
ARG DEEPEP_REPO="https://github.com/ROCm/DeepEP.git"
ARG DEEPEP_NIC="cx7"
# Build DeepEP wheel. DeepEP looks for rocshmem at ROCSHMEM_DIR.
# DeepEP only supports gfx942 and gfx950, so avoid gfx90a in the default list.
RUN --mount=type=cache,target=/root/.cache/ccache \
export PYTORCH_ROCM_ARCH="gfx942;gfx950" \
&& git clone ${DEEPEP_REPO} \
&& cd DeepEP \
&& git checkout ${DEEPEP_BRANCH} \
&& python3 setup.py --variant rocm --rocm-explicit-ctx --nic ${DEEPEP_NIC} bdist_wheel --dist-dir=/app/deep_install
&& LDFLAGS="-fuse-ld=mold" MAX_JOBS="${MAX_JOBS:-$(nproc)}" python3 setup.py --variant rocm --rocm-explicit-ctx --nic ${DEEPEP_NIC} bdist_wheel --dist-dir=/app/deep_install
# MoRI runtime dependencies live in Dockerfile.rocm so NIC backend changes do
# not force users to rebuild the long-lived Dockerfile.rocm_base image.
@@ -372,8 +454,9 @@ RUN if [ "$GIT_REPO_CHECK" != "0" ]; then \
# Extract version from git BEFORE any modifications (pin_rocm_dependencies.py modifies requirements/rocm.txt)
# This ensures setuptools_scm sees clean repo state for version detection
RUN --mount=type=bind,source=.git,target=vllm/.git \
--mount=type=cache,target=/root/.cache/uv \
cd vllm \
&& pip install setuptools_scm regex \
&& uv pip install --system setuptools_scm regex \
&& VLLM_VERSION=$(python3 -c "import setuptools_scm; print(setuptools_scm.get_version())") \
&& echo "Detected vLLM version: ${VLLM_VERSION}" \
&& echo "${VLLM_VERSION}" > /tmp/vllm_version.txt
@@ -409,18 +492,20 @@ RUN echo "Pinning vLLM dependencies to custom wheel versions..." \
&& python3 /tmp/pin_rocm_dependencies.py /install ${COMMON_WORKDIR}/vllm/requirements/rocm.txt
# Install dependencies using custom wheels from /install
RUN cd vllm \
RUN --mount=type=cache,target=/root/.cache/uv \
cd vllm \
&& echo "Building vLLM with custom wheels from /install" \
&& python3 -m pip install --find-links /install -r requirements/rocm.txt \
&& python3 setup.py clean --all
&& uv pip install --system --find-links /install -r requirements/rocm.txt
# Build wheel using pre-extracted version to avoid dirty state from modified requirements/rocm.txt
# (setup.py auto-detects sccache in PATH)
# (setup.py auto-detects ccache/sccache in PATH)
RUN --mount=type=bind,source=.git,target=vllm/.git \
--mount=type=cache,id=vllm-rocm-ccache,target=/root/.cache/ccache \
cd vllm \
&& export CCACHE_BASEDIR="$PWD" \
&& export SETUPTOOLS_SCM_PRETEND_VERSION=$(cat /tmp/vllm_version.txt) \
&& echo "Building wheel with version: ${SETUPTOOLS_SCM_PRETEND_VERSION}" \
&& python3 setup.py bdist_wheel --dist-dir=dist
&& MAX_JOBS="${MAX_JOBS:-$(nproc)}" python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_vllm_wheel_release
ARG COMMON_WORKDIR
@@ -431,112 +516,118 @@ COPY --from=build_vllm_wheel_release ${COMMON_WORKDIR}/vllm/tests /tests
COPY --from=build_vllm_wheel_release ${COMMON_WORKDIR}/vllm/examples /examples
COPY --from=build_vllm_wheel_release ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/
COPY --from=build_vllm_wheel_release ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
COPY --from=build_vllm_wheel_release ${COMMON_WORKDIR}/vllm/pyproject.toml /pyproject.toml
COPY --from=build_vllm_wheel_release ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1
# -----------------------
# Test vLLM image
FROM mori_base AS test
# CI base image (Tier 1) - stable, rarely changing CI dependencies.
# Per-PR test builds pull this as CI_BASE_IMAGE so the test stage only layers
# in the vLLM artifacts for the current commit.
FROM mori_base AS ci_base
ARG COMMON_WORKDIR
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
# Install vLLM using uv (inherited from base stage)
# Note: No -U flag to avoid upgrading PyTorch ROCm to CUDA version
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
--mount=type=cache,target=/root/.cache/uv \
cd /install \
&& uv pip install --system -r requirements/rocm.txt \
&& uv pip install --system -r requirements/test/rocm.txt \
&& pip uninstall -y vllm \
&& uv pip install --system *.whl
# Persist the built wheel in the image so python_only_compile_rocm.sh can
# reinstall it after removing compilers. The bind-mounted /install contents
# above are not available once that RUN step completes.
COPY --from=export_vllm /*.whl /opt/vllm-wheels/
# Update rdma-core to support latest rocshmem
# Update rdma-core to support latest rocshmem.
ARG DEEPEP_NIC
RUN if [ "${DEEPEP_NIC}" = "cx7" ] || [ "${DEEPEP_NIC}" = "io" ]; then \
git clone --branch v62.0 --depth 1 https://github.com/linux-rdma/rdma-core.git /tmp/rdma-core && \
cd /tmp/rdma-core && \
mkdir -p build && cd build && \
cmake -GNinja -DCMAKE_INSTALL_PREFIX=/usr -DNO_MAN_PAGES=1 .. && \
ninja && ninja install && ldconfig && rm -rf /tmp/rdma-core; \
ninja && ninja install && ldconfig && rm -rf /tmp/rdma-core; \
fi
# Install RIXL wheel
# Install RIXL + DeepEP wheels.
RUN --mount=type=bind,from=build_rixl,src=/app/install,target=/rixl_install \
uv pip install --system /rixl_install/*.whl
--mount=type=bind,from=build_deepep,src=/app/deep_install,target=/deep_install \
uv pip install --system /rixl_install/*.whl /deep_install/*.whl
# Install DeepEP wheel
RUN --mount=type=bind,from=build_deep,src=/app/deep_install,target=/deep_install \
uv pip install --system /deep_install/*.whl
COPY --from=build_deep /opt/rocshmem /opt/rocshmem
# Copy ROCShmem runtime libraries.
COPY --from=build_rocshmem /opt/rocshmem /opt/rocshmem
# RIXL/MoRIIO runtime dependencies (RDMA userspace libraries)
RUN apt-get update -q -y && apt-get install -q -y \
# RDMA userspace libraries plus FFmpeg dev libs needed by torchcodec.
RUN apt-get update -q -y && apt-get install -q -y --no-install-recommends \
librdmacm1 \
libibverbs1 \
ibverbs-providers \
ibverbs-utils \
pkg-config ffmpeg libavcodec-dev libavformat-dev libavutil-dev \
libswscale-dev libavdevice-dev libavfilter-dev libswresample-dev \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /vllm-workspace
ARG COMMON_WORKDIR
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
# install development dependencies (for testing)
RUN cd /vllm-workspace \
&& python3 -m pip install -e tests/vllm_test_utils \
&& python3 -m pip install pytest-shard
# enable fast downloads from hf (for testing)
ENV HF_XET_HIGH_PERFORMANCE=1
# increase timeout for hf downloads (for testing)
ENV HF_HUB_DOWNLOAD_TIMEOUT 60
# install audio decode package `torchcodec` from source (required due to
# ROCm and torch version mismatch) for tests with datasets package
# Install torchcodec from source for ROCm/torch ABI compatibility.
COPY tools/install_torchcodec_rocm.sh /tmp/install_torchcodec.sh
RUN bash /tmp/install_torchcodec.sh \
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=cache,target=/root/.cache/pip \
--mount=type=cache,target=/root/.cache/torchcodec-wheels \
bash /tmp/install_torchcodec.sh \
&& rm /tmp/install_torchcodec.sh \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# Copy in the v1 package (for python-only install test group)
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
# Pre-install shared ROCm runtime dependencies.
COPY requirements/common.txt requirements/rocm.txt /tmp/ci-base-requirements/
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r /tmp/ci-base-requirements/rocm.txt \
&& rm -rf /tmp/ci-base-requirements
# Set MIOPEN ENVS to resolve performance regressions in MIOpen 3D convolution kernel
# Enable fast and less brittle model downloads in tests.
ENV HF_XET_HIGH_PERFORMANCE=1
ENV HF_HUB_DOWNLOAD_TIMEOUT=60
# Pre-install vLLM test dependencies.
COPY requirements/test/rocm.txt /tmp/rocm-test-reqs.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r /tmp/rocm-test-reqs.txt
# Rebuild fastsafetensors from source so its C++ extension is compiled with
# USE_ROCM and can detect libamdhip64.so at runtime.
RUN --mount=type=cache,target=/root/.cache/pip \
FASTSAFETENSORS_REQ="$(grep -E '^fastsafetensors(==| @ )' /tmp/rocm-test-reqs.txt | head -1)" \
&& test -n "${FASTSAFETENSORS_REQ}" \
&& python3 -m pip install --force-reinstall --no-deps \
--no-binary fastsafetensors "${FASTSAFETENSORS_REQ}" \
&& rm /tmp/rocm-test-reqs.txt
# Set MIOPEN ENVS to resolve performance regressions in MIOpen 3D convolution kernel.
# See: https://github.com/pytorch/pytorch/issues/169857
ENV MIOPEN_DEBUG_CONV_DIRECT=0
ENV MIOPEN_DEBUG_CONV_GEMM=0
# Use legacy IPC mode for HSA to avoid GPU memory pinning issues with UCX rocm_ipc
# Use legacy IPC mode for HSA to avoid GPU memory pinning issues with UCX rocm_ipc.
# See: https://github.com/ROCm/rocm-libraries/issues/6266
ENV HSA_ENABLE_IPC_MODE_LEGACY=1
# Source code is used in the `python_only_compile.sh` test
# We hide it inside `src/` so that this source code
# will not be imported by other tests
RUN mkdir src && mv vllm src/vllm
# ROCm profiler limits workaround.
RUN echo "ROCTRACER_MAX_EVENTS=10000000" > ${COMMON_WORKDIR}/libkineto.conf
ENV KINETO_CONFIG="${COMMON_WORKDIR}/libkineto.conf"
# This is a workaround to ensure pytest exits with the correct status code in CI tests.
RUN printf '%s\n' \
'import os' \
'' \
'_exit_code = 1' \
'' \
'def pytest_sessionfinish(session, exitstatus):' \
' global _exit_code' \
' _exit_code = int(exitstatus)' \
'' \
'def pytest_unconfigure(config):' \
' import sys' \
' sys.stdout.flush()' \
' sys.stderr.flush()' \
' os._exit(_exit_code)' \
> /vllm-workspace/conftest.py
# Install vllm_test_utils in ci_base for ci_base + wheel parity.
COPY tests/vllm_test_utils /tmp/vllm_test_utils
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system /tmp/vllm_test_utils \
&& rm -rf /tmp/vllm_test_utils
# -----------------------
# Test vLLM image (Tier 2) - vLLM-only layer on top of ci_base.
FROM ${CI_BASE_IMAGE} AS test
ARG COMMON_WORKDIR
# Install the vLLM wheel (--no-deps: all deps already in ci_base).
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
--mount=type=cache,target=/root/.cache/uv \
cd /install \
&& uv pip install --system --no-deps *.whl
# Store the vLLM wheel in the image for python-only install tests.
COPY --from=export_vllm /*.whl /opt/vllm-wheels/
WORKDIR /vllm-workspace
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
# Copy in the v1 package (for python-only install test group).
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
# Hide source under src/ so it won't shadow the installed package in tests.
RUN mkdir src && mv vllm src/vllm
# -----------------------
# Final vLLM image
@@ -553,6 +644,7 @@ RUN rm -f /usr/bin/sccache || true \
# This prevents S3 bucket config from leaking into production images
ENV SCCACHE_BUCKET=
ENV SCCACHE_REGION=
ENV SCCACHE_ENDPOINT=
ENV SCCACHE_S3_NO_CREDENTIALS=
ENV SCCACHE_IDLE_TIMEOUT=
+1 -1
View File
@@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="v0.1.13"
ARG AITER_BRANCH="v0.1.13.post1"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG MORI_BRANCH="v1.1.0"
ARG MORI_REPO="https://github.com/ROCm/mori.git"
+376
View File
@@ -0,0 +1,376 @@
# ci-rocm.hcl - CI-specific configuration for vLLM ROCm Docker builds
#
# This file lives in the vLLM repo at docker/ci-rocm.hcl so ROCm Docker
# build mechanics can evolve with Dockerfile.rocm and docker-bake-rocm.hcl.
# Used with: docker buildx bake -f docker/docker-bake-rocm.hcl -f docker/ci-rocm.hcl test-rocm-ci
#
# Registry cache: Docker Hub (rocm/vllm-ci-cache) is used exclusively.
# AMD build agents already have Docker Hub credentials (they push the test
# image to rocm/vllm-ci), so no additional credential setup is required.
# ROCm CI uses Docker Hub for BuildKit layer cache by default. A separate
# compiler cache can be enabled with USE_SCCACHE=1 when AMD provides a shared
# S3-compatible cache endpoint.
# CI metadata
variable "BUILDKITE_COMMIT" {
default = ""
}
variable "BUILDKITE_BUILD_NUMBER" {
default = ""
}
variable "BUILDKITE_BUILD_ID" {
default = ""
}
variable "PARENT_COMMIT" {
default = ""
}
# Merge-base of HEAD with main - provides a more stable cache fallback than
# parent commit for long-lived PRs. Mirrors the VLLM_MERGE_BASE_COMMIT
# pattern used in the shared ci.hcl file. Auto-computed by ci-bake-rocm.sh
# when unset.
variable "VLLM_MERGE_BASE_COMMIT" {
default = ""
}
# Bridge to vLLM's COMMIT variable for OCI labels
variable "COMMIT" {
default = BUILDKITE_COMMIT
}
# Image tags (set by CI)
variable "IMAGE_TAG" {
default = ""
}
variable "IMAGE_TAG_LATEST" {
default = ""
}
# ROCm-specific GPU architecture targets
variable "PYTORCH_ROCM_ARCH" {
default = "gfx90a;gfx942;gfx950"
}
# Pre-built CI base image (Tier 1). Per-PR builds pull this instead of
# rebuilding RIXL/DeepEP/torchcodec from scratch. The ci_base stage in
# Dockerfile.rocm inherits from base, so CI_BASE_IMAGE only affects the test
# stage and is irrelevant when building --target ci_base itself.
variable "CI_BASE_IMAGE" {
default = "rocm/vllm-dev:ci_base"
}
# Leave CI_MAX_JOBS empty so the Dockerfile falls back to $(nproc) and uses
# the full builder parallelism. Operators can still override this per build.
variable "CI_MAX_JOBS" {
default = ""
}
# Upstream dependency commit pins -- extracted from Dockerfile.rocm by
# ci-bake-rocm.sh at build time. Empty defaults are safe: the cache
# functions produce no entries when the variable is empty.
variable "RIXL_BRANCH" {
default = ""
}
variable "UCX_BRANCH" {
default = ""
}
variable "ROCSHMEM_BRANCH" {
default = ""
}
variable "DEEPEP_BRANCH" {
default = ""
}
variable "RIXL_CACHE_KEY" {
default = ""
}
variable "ROCSHMEM_CACHE_KEY" {
default = ""
}
variable "DEEPEP_CACHE_KEY" {
default = ""
}
# Docker Hub registry cache for AMD builds.
#
# A separate repo (rocm/vllm-ci-cache) is used for BuildKit layer cache.
# Final-image cache exports use mode=min to reduce the volume of data pushed.
# Source-scoped csrc cache exports default to mode=max so fresh workers can
# recover more of the native build graph when ROCm extension inputs change.
# NOTE: mode=min still includes all layers referenced by the final image
# manifest, including inherited base layers (~7.25GB ROCm runtime).
# Docker Hub auto-creates the repo on first push.
#
# Final-image cache stays commit-scoped. Branch-to-branch reuse for the test
# image comes from importing the parent and merge-base commit cache refs.
#
# The source-scoped native cache is exported both per-commit and per-branch so
# ROCm extension rebuilds are shareable within the same commit reruns and across
# consecutive commits on the same branch without depending on a single global
# latest tag.
variable "DOCKERHUB_CACHE_REPO" {
default = "rocm/vllm-ci-cache"
}
variable "DOCKERHUB_CACHE_TO" {
default = ""
}
variable "ROCM_CACHE_BRANCH_TAG" {
default = ""
}
variable "ROCM_CACHE_UPSTREAM_BRANCH_TAG" {
default = ""
}
variable "ROCM_CSRC_CACHE_TO_MODE" {
default = "max"
}
variable "ROCM_FINAL_CACHE_TO_MODE" {
default = "min"
}
# Functions
function "get_cache_from_rocm" {
params = []
result = compact([
# Exact commit hit - fastest cache on re-runs of the same commit
BUILDKITE_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocm-${BUILDKITE_COMMIT}" : "",
# Parent commit - useful cache for incremental changes
PARENT_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocm-${PARENT_COMMIT}" : "",
# Merge-base with main - stable fallback for long-lived or rebased PRs;
# maps to a real main-branch commit whose cache layers are likely warm
VLLM_MERGE_BASE_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocm-${VLLM_MERGE_BASE_COMMIT}" : "",
# Import the source-scoped native build cache as well so builds whose
# Python/package layers changed can still reuse compiled ROCm objects.
BUILDKITE_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-${BUILDKITE_COMMIT}" : "",
PARENT_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-${PARENT_COMMIT}" : "",
VLLM_MERGE_BASE_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-${VLLM_MERGE_BASE_COMMIT}" : "",
ROCM_CACHE_BRANCH_TAG != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-branch-${ROCM_CACHE_BRANCH_TAG}" : "",
ROCM_CACHE_UPSTREAM_BRANCH_TAG != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-branch-${ROCM_CACHE_UPSTREAM_BRANCH_TAG}" : "",
# Branch-scoped full image cache - fallback when parent-commit cache is evicted
ROCM_CACHE_BRANCH_TAG != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocm-branch-${ROCM_CACHE_BRANCH_TAG}" : "",
ROCM_CACHE_UPSTREAM_BRANCH_TAG != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocm-branch-${ROCM_CACHE_UPSTREAM_BRANCH_TAG}" : "",
])
}
function "get_cache_to_rocm" {
params = []
result = compact([
# Commit-scoped cache for exact re-runs.
BUILDKITE_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocm-${BUILDKITE_COMMIT},mode=${ROCM_FINAL_CACHE_TO_MODE}" : "",
# Branch-scoped cache so later commits on the same branch can reuse the full
# image layers when the parent-commit cache is evicted. Unlike the old
# rocm-latest tag (which caused duplicate exporter 400s), this is per-branch.
ROCM_CACHE_BRANCH_TAG != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocm-branch-${ROCM_CACHE_BRANCH_TAG},mode=${ROCM_FINAL_CACHE_TO_MODE}" : "",
])
}
function "get_cache_from_rocm_csrc" {
params = []
result = compact([
BUILDKITE_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-${BUILDKITE_COMMIT}" : "",
PARENT_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-${PARENT_COMMIT}" : "",
VLLM_MERGE_BASE_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-${VLLM_MERGE_BASE_COMMIT}" : "",
ROCM_CACHE_BRANCH_TAG != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-branch-${ROCM_CACHE_BRANCH_TAG}" : "",
ROCM_CACHE_UPSTREAM_BRANCH_TAG != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-branch-${ROCM_CACHE_UPSTREAM_BRANCH_TAG}" : "",
])
}
function "get_cache_to_rocm_csrc" {
params = []
result = compact([
# Export the exact-commit native cache for same-commit reruns.
BUILDKITE_COMMIT != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-${BUILDKITE_COMMIT},mode=${ROCM_CSRC_CACHE_TO_MODE}" : "",
# Export the branch-scoped native cache so later commits on the same branch
# can reuse compiled ROCm objects even when the exact parent cache is absent.
ROCM_CACHE_BRANCH_TAG != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:csrc-rocm-branch-${ROCM_CACHE_BRANCH_TAG},mode=${ROCM_CSRC_CACHE_TO_MODE}" : "",
])
}
# Cache functions for upstream dependency stages (RIXL/UCX, ROCShmem, DeepEP).
# These stages are pinned to specific upstream commit hashes, so cache keys use
# those hashes rather than the Buildkite commit. This means the cache persists
# across all vLLM commits as long as the upstream dependency pins don't change.
function "get_cache_from_rocm_deps" {
params = []
result = compact([
RIXL_CACHE_KEY != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rixl-rocm-${RIXL_CACHE_KEY}" : (RIXL_BRANCH != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rixl-rocm-${RIXL_BRANCH}-ucx-${UCX_BRANCH}" : ""),
ROCSHMEM_CACHE_KEY != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocshmem-rocm-${ROCSHMEM_CACHE_KEY}" : (ROCSHMEM_BRANCH != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocshmem-rocm-${ROCSHMEM_BRANCH}" : ""),
DEEPEP_CACHE_KEY != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:deepep-rocm-${DEEPEP_CACHE_KEY}" : (DEEPEP_BRANCH != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:deepep-rocm-${DEEPEP_BRANCH}-rocshmem-${ROCSHMEM_BRANCH}" : ""),
])
}
function "get_cache_to_rocm_rixl" {
params = []
result = compact([
RIXL_CACHE_KEY != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rixl-rocm-${RIXL_CACHE_KEY},mode=min" : (RIXL_BRANCH != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rixl-rocm-${RIXL_BRANCH}-ucx-${UCX_BRANCH},mode=min" : ""),
])
}
function "get_cache_to_rocm_rocshmem" {
params = []
result = compact([
ROCSHMEM_CACHE_KEY != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocshmem-rocm-${ROCSHMEM_CACHE_KEY},mode=min" : (ROCSHMEM_BRANCH != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:rocshmem-rocm-${ROCSHMEM_BRANCH},mode=min" : ""),
])
}
function "get_cache_to_rocm_deepep" {
params = []
result = compact([
DEEPEP_CACHE_KEY != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:deepep-rocm-${DEEPEP_CACHE_KEY},mode=min" : (DEEPEP_BRANCH != "" ? "type=registry,ref=${DOCKERHUB_CACHE_REPO}:deepep-rocm-${DEEPEP_BRANCH}-rocshmem-${ROCSHMEM_BRANCH},mode=min" : ""),
])
}
# CI targets
target "_ci-rocm" {
annotations = [
"manifest:vllm.buildkite.build_number=${BUILDKITE_BUILD_NUMBER}",
"manifest:vllm.buildkite.build_id=${BUILDKITE_BUILD_ID}",
]
args = {
ARG_PYTORCH_ROCM_ARCH = PYTORCH_ROCM_ARCH
CI_BASE_IMAGE = CI_BASE_IMAGE
max_jobs = CI_MAX_JOBS
}
}
target "test-rocm-ci" {
inherits = ["_common-rocm", "_ci-rocm", "_labels"]
target = "test"
cache-from = get_cache_from_rocm()
cache-to = get_cache_to_rocm()
tags = compact([
IMAGE_TAG,
IMAGE_TAG_LATEST,
])
output = ["type=registry"]
}
# Cache-only target for the source-scoped ROCm native build stage.
# This persists the csrc-build stage in the registry cache even though the
# final test image only consumes it indirectly while packaging the wheel.
target "csrc-rocm-ci" {
inherits = ["_common-rocm", "_ci-rocm"]
target = "csrc-build"
cache-from = get_cache_from_rocm_csrc()
cache-to = get_cache_to_rocm_csrc()
output = ["type=cacheonly"]
}
# Keep wheel export on the same CI graph as the test image build so the
# shared build_vllm/export_vllm stages resolve identically within one bake
# invocation. Without this, export-wheel-rocm uses the plain local target
# args while test-rocm-ci uses CI-only args, which can lead to separate
# cache lineages and inconsistent export_vllm results.
target "export-wheel-rocm" {
inherits = ["_common-rocm", "_ci-rocm"]
target = "export_vllm"
cache-from = get_cache_from_rocm()
cache-to = get_cache_to_rocm()
output = ["type=local,dest=./wheel-export"]
}
# Artifact-only vLLM build. GPU test jobs consume this artifact on top of
# ci_base, avoiding a per-commit multi-GB image push/pull.
group "test-rocm-ci-with-artifacts" {
targets = ["csrc-rocm-ci", "export-wheel-rocm"]
}
# Full test image + wheel export. Kept for fallback/debugging when a pushed
# per-commit image is useful.
group "test-rocm-ci-with-wheel" {
targets = ["csrc-rocm-ci", "test-rocm-ci", "export-wheel-rocm"]
}
# Image tags for the ci_base build. ci-bake-rocm.sh rewrites CI_BASE_IMAGE_TAG
# to the primary tag for this build. Non-nightly builds use a commit-scoped tag
# and also publish a content tag for reuse. NIGHTLY=1 builds on the stable branch
# can additionally set CI_BASE_IMAGE_TAG_STABLE to refresh rocm/vllm-dev:ci_base.
variable "CI_BASE_IMAGE_TAG" {
default = "rocm/vllm-dev:ci_base"
}
variable "CI_BASE_IMAGE_TAG_CONTENT" {
default = ""
}
variable "CI_BASE_IMAGE_TAG_STABLE" {
default = ""
}
# Cache-only targets for upstream dependency stages. These persist each stage
# in the registry cache keyed by its upstream commit hash. When ci_base rebuilds
# (e.g., requirements change), these stages are cache hits if their upstream
# pins haven't changed -- saving ~35min of compilation.
target "rixl-rocm-ci" {
inherits = ["_common-rocm", "_ci-rocm"]
target = "build_rixl"
cache-from = get_cache_from_rocm_deps()
cache-to = get_cache_to_rocm_rixl()
output = ["type=cacheonly"]
}
target "rocshmem-rocm-ci" {
inherits = ["_common-rocm", "_ci-rocm"]
target = "build_rocshmem"
cache-from = get_cache_from_rocm_deps()
cache-to = get_cache_to_rocm_rocshmem()
output = ["type=cacheonly"]
}
target "deepep-rocm-ci" {
inherits = ["_common-rocm", "_ci-rocm"]
target = "build_deepep"
cache-from = get_cache_from_rocm_deps()
cache-to = get_cache_to_rocm_deepep()
output = ["type=cacheonly"]
}
# Builds only the ci_base stage (RIXL, DeepEP, torchcodec, etc.)
# Invoked by the ensure-ci-base step when the content hash of ci_base-affecting
# files drifts from the remote image label. Per-PR builds then pull the result
# as CI_BASE_IMAGE instead of rebuilding those slow layers on every commit.
# Uses inline cache metadata on the ci_base image itself instead of exporting a
# separate registry cache artifact.
target "ci-base-rocm-ci" {
inherits = ["_common-rocm", "_ci-rocm", "_labels"]
target = "ci_base"
cache-from = concat(
compact([
CI_BASE_IMAGE_TAG != "" ? "type=registry,ref=${CI_BASE_IMAGE_TAG}" : "",
CI_BASE_IMAGE_TAG_CONTENT != "" ? "type=registry,ref=${CI_BASE_IMAGE_TAG_CONTENT}" : "",
CI_BASE_IMAGE_TAG_STABLE != "" ? "type=registry,ref=${CI_BASE_IMAGE_TAG_STABLE}" : "",
]),
# Import upstream dependency caches so RIXL/ROCShmem/DeepEP stages
# are cache hits even when ci_base itself needs rebuilding.
get_cache_from_rocm_deps(),
)
cache-to = ["type=inline"]
tags = compact([CI_BASE_IMAGE_TAG, CI_BASE_IMAGE_TAG_CONTENT, CI_BASE_IMAGE_TAG_STABLE])
output = ["type=registry"]
}
# Group for ci_base builds -- exports dependency stage caches alongside the
# ci_base image so future rebuilds can reuse them independently.
group "ci-base-rocm-ci-with-deps" {
targets = ["rixl-rocm-ci", "rocshmem-rocm-ci", "deepep-rocm-ci", "ci-base-rocm-ci"]
}
+143
View File
@@ -0,0 +1,143 @@
# docker-bake-rocm.hcl - vLLM ROCm Docker build configuration
#
# This file lives in the vLLM repo at docker/docker-bake-rocm.hcl
# Equivalent of docker-bake.hcl for ROCm builds.
#
# Usage:
# docker buildx bake -f docker/docker-bake-rocm.hcl # Build test (default)
# docker buildx bake -f docker/docker-bake-rocm.hcl final-rocm # Build final image
# docker buildx bake -f docker/docker-bake-rocm.hcl --print # Show resolved config
#
# CI usage (with the vLLM-owned CI overlay):
# docker buildx bake -f docker/docker-bake-rocm.hcl -f docker/ci-rocm.hcl test-rocm-ci
variable "MAX_JOBS" {
# Empty string lets the Dockerfile fall back to $(nproc) via
# MAX_JOBS="${MAX_JOBS:-$(nproc)}" in each RUN step, which uses all
# available cores on whatever machine the build runs on.
# Override with --set '*.args.max_jobs=8' for local builds on small machines.
default = ""
}
variable "PYTORCH_ROCM_ARCH" {
default = "gfx90a;gfx942;gfx950"
}
variable "COMMIT" {
default = ""
}
# Content hash of ci_base-affecting files. Computed by ci-bake-rocm.sh and
# embedded as a label so future builds can compare without rebuilding.
variable "CI_BASE_CONTENT_HASH" {
default = ""
}
# REMOTE_VLLM=0: use local source via Docker build context (ONBUILD COPY ./ vllm/)
# REMOTE_VLLM=1: clone from GitHub at VLLM_BRANCH (standalone builds without local source)
variable "REMOTE_VLLM" {
default = "0"
}
variable "VLLM_BRANCH" {
default = "main"
}
# CI_BASE_IMAGE: pre-built ci_base image for per-PR test builds.
# Defaults to the local "ci_base" stage for standalone/local builds.
# CI overrides this to "rocm/vllm-dev:ci_base" via environment variable.
variable "CI_BASE_IMAGE" {
default = "rocm/vllm-dev:ci_base"
}
# Upstream dependency commit pins. Plain local bake builds use the Dockerfile
# ARG defaults. ci-bake-rocm.sh resolves those defaults (plus any env
# overrides) and writes a small HCL override before invoking CI targets.
variable "RIXL_BRANCH" {
default = ""
}
variable "UCX_BRANCH" {
default = ""
}
variable "ROCSHMEM_BRANCH" {
default = ""
}
variable "DEEPEP_BRANCH" {
default = ""
}
group "default" {
targets = ["test-rocm"]
}
target "_common-rocm" {
dockerfile = "docker/Dockerfile.rocm"
context = "."
args = {
max_jobs = MAX_JOBS
ARG_PYTORCH_ROCM_ARCH = PYTORCH_ROCM_ARCH
REMOTE_VLLM = REMOTE_VLLM
VLLM_BRANCH = VLLM_BRANCH
CI_BASE_IMAGE = CI_BASE_IMAGE
}
}
target "_labels" {
labels = {
"org.opencontainers.image.source" = "https://github.com/vllm-project/vllm"
"org.opencontainers.image.vendor" = "vLLM"
"org.opencontainers.image.title" = "vLLM ROCm"
"org.opencontainers.image.description" = "vLLM: A high-throughput and memory-efficient inference and serving engine for LLMs (ROCm)"
"org.opencontainers.image.licenses" = "Apache-2.0"
"org.opencontainers.image.revision" = COMMIT
}
annotations = [
"manifest:org.opencontainers.image.revision=${COMMIT}",
]
}
target "test-rocm" {
inherits = ["_common-rocm", "_labels"]
target = "test"
tags = ["rocm/vllm:test"]
output = ["type=docker"]
}
# CI base image target - builds only the ci_base stage (RIXL, DeepEP,
# torchcodec, requirements, etc.). Used by the weekly scheduled build and
# the auto-rebuild trigger when requirements change in a PR.
target "ci-base-rocm" {
inherits = ["_common-rocm", "_labels"]
target = "ci_base"
labels = {
"vllm.ci_base.content_hash" = CI_BASE_CONTENT_HASH
}
tags = ["rocm/vllm-dev:ci_base"]
output = ["type=docker"]
}
# Wheel export target - extracts the built vLLM wheel + test workspace
# to local disk. Used by CI to upload the wheel as a Buildkite artifact
# so test jobs can assemble images locally from ci_base + wheel instead
# of pulling the full large image from Docker Hub.
#
# Usage:
# docker buildx bake -f docker/docker-bake-rocm.hcl export-wheel-rocm
# # Creates ./wheel-export/*.whl, ./wheel-export/requirements/, etc.
#
# After a full bake build, BuildKit cache makes this nearly instant.
target "export-wheel-rocm" {
inherits = ["_common-rocm"]
target = "export_vllm"
output = ["type=local,dest=./wheel-export"]
}
target "final-rocm" {
inherits = ["_common-rocm", "_labels"]
target = "final"
tags = ["rocm/vllm:latest"]
output = ["type=docker"]
}
+1 -1
View File
@@ -68,7 +68,7 @@
"default": "true"
},
"FLASHINFER_VERSION": {
"default": "0.6.11.post2"
"default": "0.6.12"
},
"GDRCOPY_CUDA_VERSION": {
"default": "12.8"
+8 -2
View File
@@ -246,6 +246,12 @@ Every image listed in "image_files" is added to the request in the listed order
The "image" shorthand accepts the same values as "image_files". The "image_url" field accepts either an OpenAI-style object with a "url" field or a URL string.
By default, image references are sent to the serving endpoint as provided, with local image paths converted to `file://` URLs.
If the benchmark client should load local and HTTP(S) images before sending requests, pass `--custom-ensure-client-side-data` to encode them as base64 data URLs on the client side.
Existing `data:image/...` URLs are already self-contained and are kept unchanged.
```bash
# need a model with vision capability here
vllm serve Qwen/Qwen2-VL-7B-Instruct
@@ -253,13 +259,13 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
```bash
# run benchmarking script
vllm bench serve--save-result --save-detailed \
vllm bench serve --save-result --save-detailed \
--backend openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name custom_image \
--dataset-path <path-to-your-image-data-jsonl> \
--allowed-local-media-path /path/to/image/folder
--custom-ensure-client-side-data
```
Note that we need to use the `openai-chat` backend and `/v1/chat/completions` endpoint for multimodal inputs.
+1 -2
View File
@@ -35,8 +35,7 @@ Traces can be visualized using <https://ui.perfetto.dev/>.
!!! tip
To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100.
Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes.
`export VLLM_RPC_TIMEOUT=1800000`
The engine client waits for this flush to complete without timing out, so simply allow the stop call to run to completion.
### Example commands and usage
+12 -16
View File
@@ -17,6 +17,7 @@ The encoder CUDA Graph system uses a **budget-based capture/replay** strategy, m
* [EncoderCudaGraphManager][vllm.v1.worker.encoder_cudagraph.EncoderCudaGraphManager]: orchestrates capture, replay, greedy packing, and data-parallel execution for encoder CUDA Graphs.
* [SupportsEncoderCudaGraph][vllm.model_executor.models.interfaces.SupportsEncoderCudaGraph]: a runtime-checkable protocol that models implement to opt-in to encoder CUDA Graphs.
* [EncoderItemSpec][vllm.v1.worker.encoder_cudagraph_defs.EncoderItemSpec]: describes a single encoder input item (image or video) with its input size and output token count.
* [BudgetGraphMetadata][vllm.v1.worker.encoder_cudagraph.BudgetGraphMetadata]: holds the captured CUDA Graph and its associated I/O buffers for a single token budget level.
### Budget-based graph capture
@@ -30,8 +31,7 @@ class BudgetGraphMetadata:
max_batch_size: int
max_frames_per_batch: int
graph: torch.cuda.CUDAGraph
input_buffer: torch.Tensor # e.g. pixel_values
metadata_buffers: dict[str, torch.Tensor] # e.g. embeddings, seq metadata
input_buffers: dict[str, torch.Tensor] # e.g. pixel_values, embeddings, seq metadata
output_buffer: torch.Tensor # encoder hidden states
```
@@ -43,8 +43,8 @@ When a batch of images arrives, the manager sorts images by output token count (
For each graph replay:
1. Zero the pre-allocated `input_buffer`, then copy input tensors (e.g., `pixel_values`) into it.
2. Zero `metadata_buffers`, then slice-copy precomputed values (e.g., rotary embeddings, sequence metadata).
1. Call `prepare_encoder_cudagraph_replay_buffers()` to compute buffer values (including `pixel_values` and precomputed metadata) from actual batch inputs.
2. Zero the pre-allocated `input_buffers`, then slice-copy the replay values into them.
3. Replay the CUDA Graph.
4. Clone outputs from `output_buffer` (cloning is necessary since the buffer is reused across replays).
@@ -65,19 +65,15 @@ Following <https://github.com/vllm-project/vllm/pull/35963> (ViT full CUDA graph
Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGraph][vllm.model_executor.models.interfaces.SupportsEncoderCudaGraph] protocol. This protocol encapsulates all model-specific logic so that the manager remains model-agnostic. The protocol defines the following methods:
* `get_encoder_cudagraph_config()` — returns static configuration (supported modalities, input key, buffer keys, output hidden size).
* `get_encoder_cudagraph_config()` — returns static configuration (supported modalities, buffer keys, output hidden size, padding logics, max frames per video).
* `get_encoder_cudagraph_budget_range(vllm_config)` — returns `(min_budget, max_budget)` for auto-inference of token budgets.
* `get_encoder_cudagraph_num_items(mm_kwargs)` — returns the number of items (e.g. images) in the batch.
* `get_encoder_cudagraph_per_item_output_tokens(mm_kwargs)` — returns per-item output token counts, used for greedy packing.
* `get_encoder_cudagraph_per_item_input_sizes(mm_kwargs)` — returns per-item input sizes (e.g. patch counts), used for DP load balancing.
* `get_encoder_cudagraph_item_specs(mm_kwargs)` — returns `list[EncoderItemSpec]` describing each item with its input size and output token count. Replaces the former three separate methods (`get_num_items`, `get_per_item_output_tokens`, `get_per_item_input_sizes`).
* `select_encoder_cudagraph_items(mm_kwargs, indices)` — extracts a sub-batch of items by index, used during greedy packing and DP sharding.
* `prepare_encoder_cudagraph_capture_inputs(...)` — creates dummy inputs for graph capture.
* `prepare_encoder_cudagraph_replay_buffers(...)` — computes new buffer values from actual batch inputs before replay.
* `encoder_cudagraph_forward(...)` — forward pass using precomputed buffers (called during capture and replay).
* `encoder_eager_forward(...)` — fallback eager forward when no graph fits.
* `get_input_modality(...)` - return the modality of the inputs.
* `get_max_frames_per_video()` - return model-specific max frames per video.
* `postprocess_encoder_output(...)` - post process encoder output, directly call scatter_output_slices by default
* `prepare_encoder_cudagraph_capture_inputs(...)` — creates dummy inputs for graph capture. Returns `EncoderCudaGraphCaptureInputs` with a single `values: dict[str, torch.Tensor]` that contains all buffers to be recorded into the graph.
* `prepare_encoder_cudagraph_replay_buffers(mm_kwargs, max_batch_size, max_frames_per_batch)` — computes buffer values from actual batch inputs. Returns `EncoderCudaGraphReplayBuffers` with a `values` dict whose keys match `buffer_keys` in the config.
* `encoder_cudagraph_forward(inputs: dict[str, torch.Tensor])` — forward pass accepting only fixed-shaped input tensors (the captured `values` dict). Called during both capture and replay. The `pixel_values` tensor is included in `inputs` alongside metadata buffers.
* `encoder_eager_forward(mm_kwargs)` — fallback eager forward when no graph fits.
* `postprocess_encoder_output(...)` — post-process encoder output, delegates to `scatter_output_slices` by default.
!!! note
The `SupportsEncoderCudaGraph` protocol is designed to be model-agnostic. New vision encoder models can opt-in by implementing the protocol methods without modifying the manager.
@@ -103,7 +99,7 @@ Three fields in `CompilationConfig` control encoder CUDA Graphs:
* `cudagraph_mm_encoder` (`bool`, default `False`) — enable CUDA Graph capture for multimodal encoder. When enabled, captures the full encoder forward as a CUDA Graph for each token budget level.
* `encoder_cudagraph_token_budgets` (`list[int]`, default `[]`) — token budget levels for capture. If empty (default), auto-inferred from model architecture as power-of-2 levels. User-provided values override auto-inference.
* `encoder_cudagraph_max_vision_items_per_batch` (`int`, default `0`) — maximum number of images/videos per batch during capture. If 0 (default), auto-inferred as `max_budget // min_budget`.
* `encoder_cudagraph_max_frames_per_batch` (`int`, default `None`) — maximum number of video frames per batch during capture. If `None` (default), auto-inferred as `encoder_cudagraph_max_vision_items_per_batch * max_frames_per_video` (`max_frames_per_video` is a model-specific value according to its `processing_info`). If we limit the video count per prompt to `0`, it will also be set to `0` (i.e., fall back to image-only mode).
* `encoder_cudagraph_max_frames_per_batch` (`int`, default `None`) — maximum number of video frames per batch during capture. If `None` (default), auto-inferred as `encoder_cudagraph_max_vision_items_per_batch * max_frames_per_video` (`max_frames_per_video` is a model-specific value from `EncoderCudaGraphConfig`, computed by `get_max_frames_per_video()` on the model). If we limit the video count per prompt to `0`, it will also be set to `0` (i.e., fall back to image-only mode).
## Usage guide
+32 -2
View File
@@ -100,14 +100,44 @@ For further details on renderer APIs, please refer to [this page](renderer.md).
- `/version` - Version information
- `/load` - Server load metrics
## Sleep Mode APIs
## Server in development mode
When using the flag VLLM_SERVER_DEV_MODE=1, you enable development endpoints.
**SECURITY WARNING: These endpoints should NOT be used in production!**
### Cache Management APIs
- `/reset_prefix_cache` - Reset prefix cache (can disrupt service)
- `/reset_mm_cache` - Reset multimodal cache (can disrupt service)
- `/reset_encoder_cache` - Reset encoder cache (can disrupt service)
### Weight Transfer APIs (RL Training)
For further details on Weight Transfer, please refer to [this page](../../training/weight_transfer/README.md).
- `/pause` - Pause generation (causes denial of service)
- `/resume` - Resume generation
- `/is_paused` - Check if generation is paused
- `/init_weight_transfer_engine` - Initialize weight transfer engine for RLHF
- `/update_weights` - Update model weights (can alter model behavior)
- `/get_world_size` - Get distributed world size
### Collective RPC
- `/collective_rpc` - Execute arbitrary RPC methods on the engine (extremely dangerous)
### Server info
- `/server_info` - Get detailed server configuration
### Sleep Mode APIs
For further details on sleep mode, please refer to [this page](../../features/sleep_mode.md).
- `/sleep` - Put engine to sleep (causes denial of service)
- `/wake_up` - Wake engine from sleep
- `/is_sleeping` - Check if engine is sleeping
- `/collective_rpc` - Execute arbitrary RPC methods on the engine (extremely dangerous)
## Chat Template
+1
View File
@@ -156,5 +156,6 @@ from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory
engine = WeightTransferEngineFactory.create_engine(
config=weight_transfer_config,
parallel_config=parallel_config,
model=model,
)
```
+2 -2
View File
@@ -9,8 +9,8 @@ torchaudio==2.11.0
# These must be updated alongside torch
torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.11.post2
flashinfer-cubin==0.6.11.post2
flashinfer-python==0.6.12
flashinfer-cubin==0.6.12
apache-tvm-ffi==0.1.9
tilelang==0.1.9
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
+1
View File
@@ -5622,6 +5622,7 @@ dependencies = [
"expect-test",
"futures",
"half",
"indexmap 2.13.0",
"itertools 0.14.0",
"llm-multimodal",
"minijinja",
+2 -1
View File
@@ -43,6 +43,7 @@ half = { version = "2.7.1", features = ["bytemuck"] }
hex = "0.4.3"
hf-hub = { version = "0.5.0", features = ["tokio"] }
http-body = "1.0.1"
indexmap = "2.13.0"
itertools = "0.14.0"
libc = "0.2.177"
llm-multimodal = { git = "https://github.com/vllm-project/llm-multimodal", rev = "5b558989844d1c7af3e43d0f604069ffd9c06320" }
@@ -69,7 +70,7 @@ rustc-hash = "1.1.0"
serde = { version = "1.0.228", features = ["derive"] }
serde-json-fmt = "0.1.0"
serde_default = "0.2.0"
serde_json = { version = "1.0.145", features = ["arbitrary_precision", "preserve_order"] }
serde_json = { version = "1.0.145", features = ["preserve_order"] }
serde_repr = "0.1.20"
serde_tuple = "1.1.3"
serde_with = "3.18.0"
+1
View File
@@ -10,6 +10,7 @@ asynk-strim-attr.workspace = true
easy-ext.workspace = true
futures.workspace = true
half.workspace = true
indexmap.workspace = true
itertools.workspace = true
llm-multimodal.workspace = true
minijinja.workspace = true
+1
View File
@@ -189,6 +189,7 @@ impl ChatLlm {
cache_salt: request.cache_salt,
add_special_tokens: request.add_special_tokens,
data_parallel_rank: request.data_parallel_rank,
lora_request: request.lora_request,
};
let decoded_stream = self.text.generate(text_request).await?.map_err(Error::from).boxed();
+2 -2
View File
@@ -6,10 +6,10 @@ use std::convert::Infallible;
use std::fmt;
use std::str::FromStr;
use serde_with::DeserializeFromStr;
use serde_with::{DeserializeFromStr, SerializeDisplay};
/// Specify which reasoning or tool-call parser implementation to use.
#[derive(Debug, Clone, PartialEq, Eq, Default, DeserializeFromStr)]
#[derive(Debug, Clone, PartialEq, Eq, Default, DeserializeFromStr, SerializeDisplay)]
pub enum ParserSelection {
/// Use model-based auto-detection.
#[default]
+2 -2
View File
@@ -5,7 +5,7 @@ use std::str::FromStr;
use minijinja::machinery::ast::{Expr, ForLoop, Set, Stmt};
use minijinja::machinery::{WhitespaceConfig, parse};
use minijinja::syntax::SyntaxConfig;
use serde_with::DeserializeFromStr;
use serde_with::{DeserializeFromStr, SerializeDisplay};
/// Chat template content format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
@@ -18,7 +18,7 @@ pub enum ChatTemplateContentFormat {
}
/// Configurable chat-template content format selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr, SerializeDisplay)]
pub enum ChatTemplateContentFormatOption {
/// Detect the format from the template source.
#[default]
+33 -7
View File
@@ -1,7 +1,7 @@
use std::collections::HashMap;
use serde::Serialize;
use serde_json::Value;
use serde_json::Value as JsonValue;
use thiserror_ext::AsReport as _;
use tracing::{info, trace, warn};
use vllm_text::Prompt;
@@ -13,6 +13,7 @@ use self::format::{
ChatTemplateContentFormat, ChatTemplateContentFormatOption as ContentFormatOption,
};
use self::template::{CompiledChatTemplate, TemplateContext};
use self::value::{TemplateValue, to_template_value};
use super::{ChatRenderer, RenderedPrompt};
use crate::error::Result;
use crate::request::{ChatContent, ChatContentPart, ChatMessage, ChatRequest};
@@ -24,6 +25,7 @@ mod error;
mod format;
mod template;
mod tojson;
mod value;
pub use template::{load_chat_template, resolve_chat_template};
@@ -38,7 +40,7 @@ pub struct MultimodalRenderInfo {
/// state.
pub struct HfChatRenderer {
default_template: Option<CompiledChatTemplate>,
default_template_kwargs: HashMap<String, Value>,
default_template_kwargs: HashMap<String, JsonValue>,
content_format: ContentFormatOption,
special_tokens: Option<HfSpecialTokens>,
multimodal: Option<MultimodalRenderInfo>,
@@ -48,7 +50,7 @@ impl HfChatRenderer {
/// Create a renderer from the given template string.
pub fn new(
template: Option<String>,
default_template_kwargs: HashMap<String, Value>,
default_template_kwargs: HashMap<String, JsonValue>,
content_format: ContentFormatOption,
) -> Result<Self> {
Ok(Self {
@@ -245,7 +247,7 @@ struct TemplateToolCall {
#[derive(Debug, Serialize)]
struct TemplateToolFunction {
name: String,
arguments: Value,
arguments: TemplateValue,
}
#[derive(Debug, Serialize)]
@@ -259,7 +261,7 @@ pub(super) struct TemplateTool {
struct TemplateToolDefinition {
name: String,
description: Option<String>,
parameters: Value,
parameters: TemplateValue,
strict: Option<bool>,
}
@@ -345,13 +347,14 @@ fn to_template_tool_calls(
let mut tool_calls = Vec::new();
for tool_call in content.tool_calls() {
let arguments = serde_json::from_str::<Value>(&tool_call.arguments).map_err(|error| {
let arguments = serde_json::from_str(&tool_call.arguments).map_err(|error| {
Error::ChatTemplate(format!(
"assistant tool call `{}` has invalid JSON arguments: {}",
tool_call.id,
error.as_report()
))
})?;
let arguments = to_template_value(arguments);
tool_calls.push(TemplateToolCall {
id: tool_call.id.clone(),
@@ -434,7 +437,7 @@ fn to_template_tools(tools: &[ChatTool]) -> Vec<TemplateTool> {
function: TemplateToolDefinition {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.parameters.clone(),
parameters: to_template_value(tool.parameters.clone()),
strict: tool.strict,
},
})
@@ -909,6 +912,29 @@ mod tests {
assert_eq!(rendered, "get_weather|Paris|call_1|Sunny");
}
#[test]
fn chat_template_tool_call_argument_items_method_is_not_shadowed_by_field() {
let request = sample_request(vec![ChatMessage::assistant_blocks(vec![
AssistantContentBlock::ToolCall(crate::AssistantToolCall {
id: "call_1".to_string(),
name: "add".to_string(),
arguments: r#"{"items":"operands","x":2,"y":1.0}"#.to_string(),
}),
])]);
let rendered = render(
Some(
"{%- set arguments = messages[0].tool_calls[0].function.arguments -%}
{%- for key, value in arguments.items() -%}{{ key }}={{ value }};{%- endfor -%}
|{{ arguments['items'] }}",
),
&request,
)
.unwrap();
assert_eq!(rendered, "items=operands;x=2;y=1.0;|operands");
}
#[test]
fn qwen35_template_renders_prefilled_reasoning_start_when_thinking_enabled() {
let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]);
+18 -2
View File
@@ -208,11 +208,27 @@ mod tests {
}
#[test]
fn tojson_preserves_arbitrary_precision_number_spelling() {
fn tojson_uses_standard_serde_json_number_spelling() {
let payload = serde_json::from_str(r#"{"x":2,"y":1.00}"#).unwrap();
let rendered = render("{{ payload|tojson }}", payload);
assert_eq!(rendered, "{\"x\": 2, \"y\": 1.00}");
// TODO: we cannot preserve the original number precision by enabling `serde_json`'s
// `arbitrary_precision` feature, otherwise the following test
// `serialized_json_numbers_do_not_leak_serde_private_representation` will fail.
// See issue: https://github.com/mitsuhiko/minijinja/issues/641
assert_eq!(rendered, "{\"x\": 2, \"y\": 1.0}");
}
#[test]
fn serialized_json_numbers_do_not_leak_serde_private_representation() {
let payload: serde_json::Value = serde_json::from_str(r#"{"x":2,"y":1.00}"#).unwrap();
let rendered = render("{{ payload }}", payload);
// TODO: we cannot preserve the original number precision by enabling `serde_json`'s
// `arbitrary_precision` feature, otherwise this will fail.
// See issue: https://github.com/mitsuhiko/minijinja/issues/641
assert!(!rendered.contains("$serde_json::private::Number"));
assert_eq!(rendered, r#"{"x": 2, "y": 1.0}"#);
}
#[test]
+77
View File
@@ -0,0 +1,77 @@
use std::sync::Arc;
use indexmap::IndexMap;
use minijinja::value::{Enumerator, Object, ObjectExt, ObjectRepr};
use minijinja::{Error as TemplateError, ErrorKind as TemplateErrorKind, State};
use serde::Serialize;
use serde_json::Value as JsonValue;
/// A wrapper around `minijinja::Value` that can be constructed with `to_template_value` and used
/// as a value in the chat template.
#[derive(Debug, Serialize)]
#[serde(transparent)]
pub(super) struct TemplateValue(minijinja::Value);
pub(super) fn to_template_value(value: JsonValue) -> TemplateValue {
TemplateValue(match value {
JsonValue::Array(values) => values
.into_iter()
.map(to_template_value)
.map(|value| value.0)
.collect::<minijinja::Value>(),
JsonValue::Object(values) => minijinja::Value::from_object(TemplateMap(
values
.into_iter()
.map(|(key, value)| (key, to_template_value(value).0))
.collect(),
)),
// For primitive values, directly convert them to `minijinja::Value` using `from_serialize`.
value => minijinja::Value::from_serialize(value),
})
}
/// A custom map type that always returns `UnknownMethod` for method calls, so that pycompat can
/// always handle dict methods through the unknown-method callback.
///
/// Use `IndexMap` to preserve the original key order when iterating.
///
/// MiniJinja's default map can resolve a same-named field before Python dict methods. HF templates
/// commonly call `dict.items()`, which would fail if the map had an `items` field.
/// See issue: https://github.com/mitsuhiko/minijinja/issues/903
#[derive(Debug)]
struct TemplateMap(IndexMap<String, minijinja::Value>);
impl Object for TemplateMap {
fn repr(self: &Arc<Self>) -> ObjectRepr {
ObjectRepr::Map
}
fn get_value(self: &Arc<Self>, key: &minijinja::Value) -> Option<minijinja::Value> {
self.0.get(key.as_str()?).cloned()
}
fn get_value_by_str(self: &Arc<Self>, key: &str) -> Option<minijinja::Value> {
self.0.get(key).cloned()
}
fn enumerate(self: &Arc<Self>) -> Enumerator {
self.mapped_rev_enumerator(|this| {
Box::new(this.0.keys().map(|key| minijinja::Value::from(key.as_str())))
})
}
fn enumerator_len(self: &Arc<Self>) -> Option<usize> {
Some(self.0.len())
}
fn call_method(
self: &Arc<Self>,
_state: &State<'_, '_>,
_method: &str,
_args: &[minijinja::Value],
) -> std::result::Result<minijinja::Value, TemplateError> {
// Always return `UnknownMethod` for method calls,
// so that pycompat can handle dict methods through the unknown-method callback.
Err(TemplateError::from(TemplateErrorKind::UnknownMethod))
}
}
+2 -2
View File
@@ -1,10 +1,10 @@
use std::fmt;
use std::str::FromStr;
use serde_with::DeserializeFromStr;
use serde_with::{DeserializeFromStr, SerializeDisplay};
/// Specify which chat renderer implementation to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr, SerializeDisplay)]
pub enum RendererSelection {
/// Use model-based auto-detection.
#[default]

Some files were not shown because too many files have changed in this diff Show More