[Feature] Support custom callable proposer backend for speculative decoding (#39487)

Signed-off-by: 524031910363 <hyzhyzsh@sjtu.edu.cn>
Signed-off-by: CynicDora <hyzhyzsh@sjtu.edu.cn>
This commit is contained in:
CynicDora
2026-05-14 00:53:01 +08:00
committed by GitHub
parent e35c0d4c63
commit 256dbcaabf
7 changed files with 261 additions and 11 deletions
@@ -15,6 +15,7 @@ vLLM supports a variety of methods of speculative decoding. Model-based methods
- [Multi-Layer Perceptron](mlp.md)
- [N-Gram](n_gram.md)
- [Suffix Decoding](suffix.md)
- [Custom Proposer Backend (Experimental)](#custom-proposer-backend-experimental)
## Method Selection at a Glance
@@ -30,11 +31,22 @@ depend on your model family, traffic pattern, hardware, and sampling settings.
| MLP speculator | Medium to high gain | Medium gain | Good when compatible MLP speculators are available. |
| N-gram | Low to medium gain | Medium gain | Lightweight and easy to enable. |
| Suffix decoding | Low to medium gain | Medium gain | No extra draft model; dynamic speculation depth. |
| Custom Proposer | Varies | Varies | Bring your own proposer class (experimental). |
For reproducible measurements in your environment, use
[`examples/features/speculative_decoding/spec_decode_offline.py`](../../../examples/features/speculative_decoding/spec_decode_offline.py)
or the [benchmark CLI guide](../../benchmarking/cli.md).
## Custom Proposer Backend (Experimental)
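You can plug in your own proposer class for speculative decoding by setting the method to `custom_class` and setting `model` to the fully qualified, importable path of your class (for example `package.module.ClassName`).
Your custom class must accept a `VllmConfig` as its only constructor argument and implement a `propose` method. vLLM calls it as `propose(sampled_token_ids, num_tokens_no_spec, token_ids_cpu, slot_mappings=...)` and uses the return value as the draft token IDs for each request.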
**Example configuration:**
- `speculative_config.method = "custom_class"`
- `speculative_config.model = "your_module.YourCustomProposerClass"`
## `--speculative-config` schema
Use `--speculative-config` to pass speculative decoding settings as a JSON
+121
View File
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration test for custom proposer class in speculative decoding.
Usage:
.venv/bin/python test_custom_proposer.py
"""
import os
import torch
from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
MODEL_ID = "facebook/opt-125m"
NUM_SPEC_TOKENS = 5
class DummyDraftProposer:
    """Custom proposer class that repeats the last token of each sequence.

    This demonstrates the class-based custom proposer interface: the class is
    constructed from a ``VllmConfig`` and exposes a ``propose`` method that
    returns one list of draft token IDs per request.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Initialize the custom proposer.

        Args:
            vllm_config: vLLM configuration containing model and speculative
                settings.
        """
        self.num_speculative_tokens = (
            vllm_config.speculative_config.num_speculative_tokens
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        print(
            f"[DummyDraftProposer.__init__] num_speculative_tokens="
            f"{self.num_speculative_tokens}, max_model_len={self.max_model_len}"
        )

    def propose(
        self,
        sampled_token_ids: list[list[int]],
        num_tokens_no_spec: int,
        token_ids_cpu: torch.Tensor,
        slot_mappings: torch.Tensor | None = None,
    ) -> list[list[int]]:
        """Generate draft tokens by repeating the last token of each sequence.

        A request with no sampled tokens gets an empty draft list instead of
        raising ``IndexError``, and an empty batch yields an empty result.

        Args:
            sampled_token_ids: Recently sampled token IDs per request.
            num_tokens_no_spec: Number of non-speculative tokens per request
                (unused here).
                NOTE(review): annotated as ``int`` although described as
                per-request -- confirm the real type against the caller.
            token_ids_cpu: Full token IDs tensor on CPU (unused here).
            slot_mappings: Slot mapping for KV cache (optional, unused here).

        Returns:
            List of draft token sequences, one entry per request, in the same
            order as ``sampled_token_ids``.
        """
        # Cross-process flag to prove this method was executed; the driver
        # script checks for this file after generation.
        with open("proposer_called.flag", "w", encoding="utf-8") as f:
            f.write("called")

        batch_size = len(sampled_token_ids)
        # Keep one entry per request so the output stays aligned with the
        # batch; requests without a sampled token simply get no drafts.
        drafts = [
            [seq[-1]] * self.num_speculative_tokens if seq else []
            for seq in sampled_token_ids
        ]
        # Report the configured draft width rather than len(drafts[0]), which
        # would fail on an empty batch and is identical for non-empty drafts.
        print(
            f"[DummyDraftProposer.propose] batch_size={batch_size}, "
            f"num_speculative_tokens={self.num_speculative_tokens}, "
            f"drafts_shape={len(drafts)}x{self.num_speculative_tokens}"
        )
        return drafts
if __name__ == "__main__":
    # Script entry point: run a short greedy generation with the dummy
    # proposer plugged in, then check that its propose() hook really ran.
    banner = "=" * 60
    print(banner)
    print("Custom Proposer Backend Integration Test")
    print(banner)

    # A flag left over from an earlier failed run would make the final check
    # pass vacuously, so clear it before starting the engine.
    if os.path.exists("proposer_called.flag"):
        os.remove("proposer_called.flag")

    # Only the dotted class path is supplied; no explicit method is set.
    spec_settings = {
        "model": f"{__name__}.DummyDraftProposer",
        "num_speculative_tokens": NUM_SPEC_TOKENS,
    }
    engine = LLM(
        model=MODEL_ID,
        speculative_config=spec_settings,
        gpu_memory_utilization=0.4,
        enforce_eager=True,
    )

    test_prompts = ["Hello, my name is", "The future of AI is"]
    greedy_params = SamplingParams(max_tokens=32, temperature=0.0)

    print(f"\nRunning generate with {len(test_prompts)} prompt(s)...\n")
    results = engine.generate(test_prompts, greedy_params)

    for result in results:
        print(f"Prompt: {result.prompt!r}")
        print(f"Generated text: {result.outputs[0].text!r}")
        print("-" * 60)

    # Verify the custom proposer's propose() was actually called across
    # processes: the proposer writes this flag file on every call.
    flag_present = os.path.exists("proposer_called.flag")
    assert flag_present, "The custom proposer's propose() method was never called!"
    os.remove("proposer_called.flag")

    print("✓ Custom proposer was actively used during generation!")
    print("Test completed successfully.")
+1 -2
View File
@@ -98,7 +98,6 @@ def mypy(
def main():
ci = sys.argv[1] == "1"
python_version = sys.argv[2]
file_groups = group_files(sys.argv[3:])
@@ -107,7 +106,7 @@ def main():
returncode = 0
for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip"
follow_imports = None if file_group == "" else "skip"
if changed_files:
returncode |= mypy(
changed_files, python_version, follow_imports, file_group
+38 -2
View File
@@ -62,6 +62,7 @@ SpeculativeMethod = Literal[
"mlp_speculator",
"draft_model",
"suffix",
"custom_class",
EagleModelTypes,
NgramGPUTypes,
]
@@ -511,7 +512,16 @@ class SpeculativeConfig:
# default.
# infer method from user args
if self.method is None:
# Check if the model field contains a custom module path (e.g., 'pkg.Mod')
if (
self.model is not None
and "." in self.model
and not self.model.startswith(("http://", "https://", "file://"))
and "/" not in self.model # not a HuggingFace repo (org/model)
):
# Treat as a custom class path
self.method = "custom_class"
elif self.method is None:
if self.model in ("ngram", "[ngram]"):
self.method = "ngram"
else:
@@ -545,6 +555,14 @@ class SpeculativeConfig:
self.model = "suffix"
elif self.method == "extract_hidden_states":
self.model = "extract_hidden_states"
elif self.method == "custom_class":
# method was set explicitly, but model should already contain the
# custom module path. If not, this is a configuration error.
if self.model is None:
raise ValueError(
"method='custom_class' requires 'model' to contain the "
"custom proposer module path (e.g., 'my_module.MyProposer')."
)
else:
raise ValueError(
"num_speculative_tokens was provided but without speculative model."
@@ -588,6 +606,18 @@ class SpeculativeConfig:
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
self._validate_suffix_decoding()
elif self.method == "custom_class":
# Custom class proposer does not need a draft model.
# It will dynamically load the user-provided class at runtime.
logger.warning_once(
"Using a custom class-based proposer backend. This is an "
"experimental feature and the proposer interface is subject to "
"breaking changes in future vLLM releases."
)
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
elif self.method == "extract_hidden_states":
from vllm.transformers_utils.configs.extract_hidden_states import (
ExtractHiddenStatesConfig,
@@ -1063,7 +1093,13 @@ class SpeculativeConfig:
method = self.method
model = (
None
if method in ("ngram", "suffix", "extract_hidden_states")
if method
in (
"ngram",
"suffix",
"extract_hidden_states",
"custom_class",
)
else self.draft_model_config.model
)
num_spec_tokens = self.num_speculative_tokens
-5
View File
@@ -1613,11 +1613,6 @@ class EngineArgs:
) -> SpeculativeConfig | None:
"""Initializes and returns a SpeculativeConfig object based on
`speculative_config`.
This function utilizes `speculative_config` to create a
SpeculativeConfig object. The `speculative_config` can either be
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine.
"""
if self.speculative_config is None:
return None
+73
View File
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from vllm.config import VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
def create_custom_proposer(vllm_config: VllmConfig):
    """Load and instantiate a user-provided proposer class.

    The class path is read from ``speculative_config.model``
    (e.g., ``"my_module.MyCustomProposer"``). The class is
    imported, instantiated with *vllm_config*, and returned
    directly so the caller can use it without any wrapper.

    The returned object must expose a callable ``propose`` method.

    Args:
        vllm_config: Engine configuration; its ``speculative_config.model``
            holds the dotted ``module.ClassName`` path of the proposer.

    Returns:
        The instantiated user proposer object.

    Raises:
        ValueError: The path is not of the form ``module.ClassName`` (no dot,
            or an empty module or class part).
        ImportError: The module part of the path cannot be imported.
        AttributeError: The module has no such attribute, or the instance has
            no callable ``propose``.
        RuntimeError: The class constructor raised when called with
            *vllm_config*.
    """
    # Internal invariants: config validation is expected to have populated
    # these before the model runner asks for a proposer.
    assert vllm_config.speculative_config is not None
    spec_config = vllm_config.speculative_config
    backend = spec_config.model
    assert backend is not None

    # rpartition yields an empty module part when there is no dot at all, so
    # one check covers "NoDots", ".OnlyClass" and "only_module." alike. Without
    # it, ".OnlyClass" would surface importlib's unrelated "Empty module name".
    module_path, _, class_name = backend.rpartition(".")
    if not module_path or not class_name:
        raise ValueError(
            f"Invalid custom proposer module path '{backend}'. "
            "It must be a full module path (e.g., 'module.MyProposerClass')."
        )

    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        raise ImportError(
            f"Cannot import module '{module_path}' for custom proposer '{backend}': {e}"
        ) from e

    user_class = getattr(module, class_name, None)
    if user_class is None:
        raise AttributeError(
            f"Module '{module_path}' has no attribute '{class_name}' "
            f"(speculative_config.model='{backend}')"
        )

    # User code runs here, so any failure is wrapped with a hint about the
    # expected constructor signature while keeping the original cause chained.
    try:
        instance = user_class(vllm_config)
    except Exception as e:
        raise RuntimeError(
            f"Failed to instantiate custom proposer class '{backend}': {e}. "
            "The class constructor must accept VllmConfig as argument."
        ) from e

    if not hasattr(instance, "propose"):
        raise AttributeError(
            f"Custom proposer class '{backend}' must have a 'propose' method."
        )

    if not callable(instance.propose):
        raise AttributeError(
            f"Custom proposer class '{backend}' has a 'propose' attribute "
            "but it is not callable."
        )

    logger.info(
        "Loaded custom proposer class '%s' with num_speculative_tokens=%d",
        backend,
        spec_config.num_speculative_tokens,
    )
    return instance
+16 -2
View File
@@ -169,6 +169,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.custom_class_proposer import create_custom_proposer
from vllm.v1.spec_decode.dflash import DFlashProposer
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
@@ -531,7 +532,11 @@ class GPUModelRunner(
| ExtractHiddenStatesProposer
| Gemma4Proposer
)
if self.speculative_config.method == "ngram":
if self.speculative_config.method == "custom_class":
self.drafter = create_custom_proposer( # type: ignore[assignment]
self.vllm_config
)
elif self.speculative_config.method == "ngram":
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
self.drafter = NgramProposer(self.vllm_config)
@@ -4614,6 +4619,14 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu,
slot_mappings=slot_mappings,
)
elif spec_config.method == "custom_class":
assert isinstance(sampled_token_ids, list)
draft_token_ids = cast(Any, self.drafter).propose(
sampled_token_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
slot_mappings=slot_mappings,
)
elif spec_config.use_ngram_gpu():
assert isinstance(self.drafter, NgramProposerGPU)
(
@@ -4885,7 +4898,8 @@ class GPUModelRunner(
)
if hasattr(self, "drafter"):
logger.info_once("Loading drafter model...")
self.drafter.load_model(self.model)
if hasattr(self.drafter, "load_model"):
self.drafter.load_model(self.model)
if (
hasattr(self.drafter, "model")
and is_mixture_of_experts(self.drafter.model)