[CPU Backend] CPU top-k and top-p sampling kernels using Triton (#43633)

Signed-off-by: Li, Tianmu <tianmu.li@intel.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Tianmu Li
2026-05-29 00:02:39 -07:00
committed by GitHub
parent 04516eabc8
commit 94d3f4d205
5 changed files with 23 additions and 6 deletions
+6 -1
View File
@@ -62,11 +62,16 @@ steps:
source_file_dependencies:
- vllm/v1/worker/cpu/
- vllm/v1/worker/gpu/
- vllm/v1/sample/ops/topk_topp_triton.py
- vllm/v1/sample/ops/topk_topp_sampler.py
- tests/v1/sample/test_topk_topp_sampler.py
commands:
- |
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 45m "
uv pip install git+https://github.com/triton-lang/triton-cpu.git@270e696d
VLLM_USE_V2_MODEL_RUNNER=1 pytest -x -v -s tests/models/language/generation/test_granite.py -m cpu_model"
VLLM_USE_V2_MODEL_RUNNER=1 pytest -x -v -s tests/models/language/generation/test_granite.py -m cpu_model
# TODO: move to CPU-Kernel Tests once triton-cpu has a pre-built wheel
pytest -x -v -s tests/v1/sample/test_topk_topp_sampler.py::TestTritonTopkTopp"
- label: CPU-Quantization Model Tests
depends_on: []
+2 -1
View File
@@ -5,6 +5,7 @@ import torch
from torch import Generator
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
DEVICE_TYPE = current_platform.device_type
@@ -151,7 +152,7 @@ def test_flashinfer_sampler():
# =============================================================================
@pytest.mark.skipif("cpu" in DEVICE_TYPE, reason="CUDA/XPU not available")
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available on this platform")
class TestTritonTopkTopp:
"""Tests for the Triton top-k/top-p kernel."""
+4
View File
@@ -408,6 +408,10 @@ class CpuPlatform(Platform):
def support_hybrid_kv_cache(cls) -> bool:
return True
@classmethod
def num_compute_units(cls, device_id: int = 0) -> int:
return torch.get_num_threads()
@classmethod
def import_kernels(cls) -> None:
if Platform.get_cpu_architecture() in (CpuArchEnum.X86,):
+3 -2
View File
@@ -168,7 +168,7 @@ class TopKTopPSampler(nn.Module):
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p_pytorch(logits, k, p, allow_cpu_sync=True)
logits = apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
@@ -310,8 +310,9 @@ def apply_top_k_top_p(
if p is None and k is None:
return logits
# Keep CPU logits on the PyTorch path to avoid invoking Triton kernels.
if current_platform.is_cpu():
if HAS_TRITON:
return apply_top_k_top_p_triton(logits, k, p)
return apply_top_k_top_p_pytorch(logits, k, p, allow_cpu_sync=True)
if HAS_TRITON and logits.shape[0] >= 8:
+8 -2
View File
@@ -1038,6 +1038,12 @@ def apply_top_k_top_p_triton(
else:
normal_cdf_to_sigma_table, percentile_to_std_table = tables
# Smaller tiles compile and run faster on CPU; GPU benefits from larger tiles.
if logits.device.type == "cpu":
block_size, block_size_trunc = 256, 128
else:
block_size, block_size_trunc = 8192, 4096
_topk_topp_kernel[(NUM_PROGRAMS,)](
logits,
logits.stride(0),
@@ -1049,8 +1055,8 @@ def apply_top_k_top_p_triton(
BATCH_SIZE=batch_size,
MASK_VALUE=mask_value,
VOCAB_SIZE=vocab_size,
BLOCK_SIZE=8192,
BLOCK_SIZE_TRUNC=4096,
BLOCK_SIZE=block_size,
BLOCK_SIZE_TRUNC=block_size_trunc,
TOPK_ENABLED=topk_enabled,
TOPP_ENABLED=topp_enabled,
)