mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Feature] Support custom callable proposer backend for speculative decoding (#39487)
Signed-off-by: 524031910363 <hyzhyzsh@sjtu.edu.cn> Signed-off-by: CynicDora <hyzhyzsh@sjtu.edu.cn>
This commit is contained in:
@@ -15,6 +15,7 @@ vLLM supports a variety of methods of speculative decoding. Model-based methods
|
||||
- [Multi-Layer Perceptron](mlp.md)
|
||||
- [N-Gram](n_gram.md)
|
||||
- [Suffix Decoding](suffix.md)
|
||||
- [Custom Proposer Backend (Experimental)](#custom-proposer-backend-experimental)
|
||||
|
||||
## Method Selection at a Glance
|
||||
|
||||
@@ -30,11 +31,22 @@ depend on your model family, traffic pattern, hardware, and sampling settings.
|
||||
| MLP speculator | Medium to high gain | Medium gain | Good when compatible MLP speculators are available. |
|
||||
| N-gram | Low to medium gain | Medium gain | Lightweight and easy to enable. |
|
||||
| Suffix decoding | Low to medium gain | Medium gain | No extra draft model; dynamic speculation depth. |
|
||||
| Custom Proposer | Varies | Varies | Bring your own proposer class (experimental). |
|
||||
|
||||
For reproducible measurements in your environment, use
|
||||
[`examples/features/speculative_decoding/spec_decode_offline.py`](../../../examples/features/speculative_decoding/spec_decode_offline.py)
|
||||
or the [benchmark CLI guide](../../benchmarking/cli.md).
|
||||
|
||||
## Custom Proposer Backend (Experimental)
|
||||
|
||||
You can plug in your own custom proposer class for speculative decoding by setting the method to `custom_class` and providing the full module path to your class.
|
||||
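Your custom class is instantiated with a single `VllmConfig` argument and must implement a `propose` method. vLLM calls it as `propose(sampled_token_ids, num_tokens_no_spec, token_ids_cpu, slot_mappings=...)` and uses the return value as the draft token IDs for each request.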
|
||||
|
||||
**Example configuration:**
|
||||
|
||||
- `speculative_config.method = "custom_class"`
|
||||
- `speculative_config.model = "your_module.YourCustomProposerClass"`
|
||||
|
||||
## `--speculative-config` schema
|
||||
|
||||
Use `--speculative-config` to pass speculative decoding settings as a JSON
|
||||
|
||||
Executable
+121
@@ -0,0 +1,121 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Integration test for custom proposer class in speculative decoding.
|
||||
|
||||
Usage:
|
||||
.venv/bin/python test_custom_proposer.py
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
MODEL_ID = "facebook/opt-125m"
|
||||
NUM_SPEC_TOKENS = 5
|
||||
|
||||
|
||||
class DummyDraftProposer:
|
||||
"""Custom proposer class that repeats the last token of each sequence.
|
||||
|
||||
This demonstrates the class-based custom proposer interface.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
"""Initialize the custom proposer.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM configuration containing model and speculative settings.
|
||||
"""
|
||||
self.num_speculative_tokens = (
|
||||
vllm_config.speculative_config.num_speculative_tokens
|
||||
)
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
print(
|
||||
f"[DummyDraftProposer.__init__] num_speculative_tokens="
|
||||
f"{self.num_speculative_tokens}, max_model_len={self.max_model_len}"
|
||||
)
|
||||
|
||||
def propose(
|
||||
self,
|
||||
sampled_token_ids: list[list[int]],
|
||||
num_tokens_no_spec: int,
|
||||
token_ids_cpu: torch.Tensor,
|
||||
slot_mappings: torch.Tensor | None = None,
|
||||
) -> list[list[int]]:
|
||||
"""Generate draft tokens by repeating the last token of each sequence.
|
||||
|
||||
Args:
|
||||
sampled_token_ids: Recently sampled token IDs per request.
|
||||
num_tokens_no_spec: Number of non-speculative tokens per request.
|
||||
token_ids_cpu: Full token IDs tensor on CPU.
|
||||
slot_mappings: Slot mapping for KV cache (optional).
|
||||
|
||||
Returns:
|
||||
List of draft token sequences for each request.
|
||||
"""
|
||||
# Cross-process flag to prove this method was executed
|
||||
with open("proposer_called.flag", "w") as f:
|
||||
f.write("called")
|
||||
|
||||
batch_size = len(sampled_token_ids)
|
||||
last_tokens = [seq[-1] for seq in sampled_token_ids]
|
||||
drafts = [[t] * self.num_speculative_tokens for t in last_tokens]
|
||||
print(
|
||||
f"[DummyDraftProposer.propose] batch_size={batch_size}, "
|
||||
f"num_speculative_tokens={self.num_speculative_tokens}, "
|
||||
f"drafts_shape={len(drafts)}x{len(drafts[0])}"
|
||||
)
|
||||
return drafts
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Marker file written by DummyDraftProposer.propose(). A file is used
    # (rather than an in-memory flag) so the check still works when the
    # proposer runs in a different process than this script. The path is
    # relative, so both sides must share the same working directory.
    flag_path = "proposer_called.flag"

    print("=" * 60)
    print("Custom Proposer Backend Integration Test")
    print("=" * 60)

    # Cleanup any leftover flag from previous failed runs
    if os.path.exists(flag_path):
        os.remove(flag_path)

    try:
        # No explicit "method" is given; only the dotted class path is passed
        # as "model" together with the number of speculative tokens.
        llm = LLM(
            model=MODEL_ID,
            speculative_config={
                "model": f"{__name__}.DummyDraftProposer",
                "num_speculative_tokens": NUM_SPEC_TOKENS,
            },
            gpu_memory_utilization=0.4,
            enforce_eager=True,
        )

        prompts = [
            "Hello, my name is",
            "The future of AI is",
        ]

        sampling_params = SamplingParams(
            max_tokens=32,
            temperature=0.0,
        )

        print(f"\nRunning generate with {len(prompts)} prompt(s)...\n")
        outputs = llm.generate(prompts, sampling_params)

        for output in outputs:
            prompt = output.prompt
            generated = output.outputs[0].text
            print(f"Prompt: {prompt!r}")
            print(f"Generated text: {generated!r}")
            print("-" * 60)

        # Verify the custom proposer's propose() was actually called across
        # processes
        assert os.path.exists(flag_path), (
            "The custom proposer's propose() method was never called!"
        )
    finally:
        # Remove the marker on every exit path so a failed run does not leave
        # a stale file behind in the working directory.
        if os.path.exists(flag_path):
            os.remove(flag_path)

    print("✓ Custom proposer was actively used during generation!")
    print("Test completed successfully.")
|
||||
@@ -98,7 +98,6 @@ def mypy(
|
||||
|
||||
|
||||
def main():
|
||||
ci = sys.argv[1] == "1"
|
||||
python_version = sys.argv[2]
|
||||
file_groups = group_files(sys.argv[3:])
|
||||
|
||||
@@ -107,7 +106,7 @@ def main():
|
||||
|
||||
returncode = 0
|
||||
for file_group, changed_files in file_groups.items():
|
||||
follow_imports = None if ci and file_group == "" else "skip"
|
||||
follow_imports = None if file_group == "" else "skip"
|
||||
if changed_files:
|
||||
returncode |= mypy(
|
||||
changed_files, python_version, follow_imports, file_group
|
||||
|
||||
@@ -62,6 +62,7 @@ SpeculativeMethod = Literal[
|
||||
"mlp_speculator",
|
||||
"draft_model",
|
||||
"suffix",
|
||||
"custom_class",
|
||||
EagleModelTypes,
|
||||
NgramGPUTypes,
|
||||
]
|
||||
@@ -511,7 +512,16 @@ class SpeculativeConfig:
|
||||
# default.
|
||||
|
||||
# infer method from user args
|
||||
if self.method is None:
|
||||
# Check if the model field contains a custom module path (e.g., 'pkg.Mod')
|
||||
if (
|
||||
self.model is not None
|
||||
and "." in self.model
|
||||
and not self.model.startswith(("http://", "https://", "file://"))
|
||||
and "/" not in self.model # not a HuggingFace repo (org/model)
|
||||
):
|
||||
# Treat as a custom class path
|
||||
self.method = "custom_class"
|
||||
elif self.method is None:
|
||||
if self.model in ("ngram", "[ngram]"):
|
||||
self.method = "ngram"
|
||||
else:
|
||||
@@ -545,6 +555,14 @@ class SpeculativeConfig:
|
||||
self.model = "suffix"
|
||||
elif self.method == "extract_hidden_states":
|
||||
self.model = "extract_hidden_states"
|
||||
elif self.method == "custom_class":
|
||||
# method was set explicitly, but model should already contain the
|
||||
# custom module path. If not, this is a configuration error.
|
||||
if self.model is None:
|
||||
raise ValueError(
|
||||
"method='custom_class' requires 'model' to contain the "
|
||||
"custom proposer module path (e.g., 'my_module.MyProposer')."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens was provided but without speculative model."
|
||||
@@ -588,6 +606,18 @@ class SpeculativeConfig:
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
elif self.method == "suffix":
|
||||
self._validate_suffix_decoding()
|
||||
elif self.method == "custom_class":
|
||||
# Custom class proposer does not need a draft model.
|
||||
# It will dynamically load the user-provided class at runtime.
|
||||
logger.warning_once(
|
||||
"Using a custom class-based proposer backend. This is an "
|
||||
"experimental feature and the proposer interface is subject to "
|
||||
"breaking changes in future vLLM releases."
|
||||
)
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
self.draft_model_config = self.target_model_config
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
elif self.method == "extract_hidden_states":
|
||||
from vllm.transformers_utils.configs.extract_hidden_states import (
|
||||
ExtractHiddenStatesConfig,
|
||||
@@ -1063,7 +1093,13 @@ class SpeculativeConfig:
|
||||
method = self.method
|
||||
model = (
|
||||
None
|
||||
if method in ("ngram", "suffix", "extract_hidden_states")
|
||||
if method
|
||||
in (
|
||||
"ngram",
|
||||
"suffix",
|
||||
"extract_hidden_states",
|
||||
"custom_class",
|
||||
)
|
||||
else self.draft_model_config.model
|
||||
)
|
||||
num_spec_tokens = self.num_speculative_tokens
|
||||
|
||||
@@ -1613,11 +1613,6 @@ class EngineArgs:
|
||||
) -> SpeculativeConfig | None:
|
||||
"""Initializes and returns a SpeculativeConfig object based on
|
||||
`speculative_config`.
|
||||
|
||||
This function utilizes `speculative_config` to create a
|
||||
SpeculativeConfig object. The `speculative_config` can either be
|
||||
provided as a JSON string input via CLI arguments or directly as a
|
||||
dictionary from the engine.
|
||||
"""
|
||||
if self.speculative_config is None:
|
||||
return None
|
||||
|
||||
Executable
+73
@@ -0,0 +1,73 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_custom_proposer(vllm_config: "VllmConfig"):
    """Load and instantiate a user-provided proposer class.

    The class path is read from ``speculative_config.model``
    (e.g., ``"my_module.MyCustomProposer"``). The class is
    imported, instantiated with *vllm_config*, and returned
    directly so the caller can use it without any wrapper.

    The returned object must expose a callable ``propose`` method.

    Args:
        vllm_config: Configuration whose ``speculative_config.model`` holds
            the absolute dotted path ``<module path>.<class name>``.

    Returns:
        The instance produced by calling the user class with *vllm_config*.

    Raises:
        ValueError: If the path is not an absolute dotted path with a
            non-empty module part and a non-empty class name.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module lacks the named attribute, or the
            instance has no callable ``propose``.
        RuntimeError: If calling the class with *vllm_config* raises.
    """
    assert vllm_config.speculative_config is not None
    spec_config = vllm_config.speculative_config

    backend = spec_config.model
    assert backend is not None

    # Split on the last dot. Reject paths with no dot, an empty module part
    # (".Cls"), an empty class name ("mod.") or a relative module
    # ("..pkg.Cls"); otherwise these surface later as unrelated errors from
    # importlib or getattr.
    module_path, dot, class_name = backend.rpartition(".")
    if not dot or not module_path or not class_name or module_path.startswith("."):
        raise ValueError(
            f"Invalid custom proposer module path '{backend}'. "
            "It must be a full module path (e.g., 'module.MyProposerClass')."
        )

    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        raise ImportError(
            f"Cannot import module '{module_path}' "
            f"for custom proposer '{backend}': {e}"
        ) from e

    user_class = getattr(module, class_name, None)
    if user_class is None:
        raise AttributeError(
            f"Module '{module_path}' has no attribute '{class_name}' "
            f"(speculative_config.model='{backend}')"
        )

    # User code runs here, so any failure is wrapped with the offending path
    # and the original exception is kept as the cause.
    try:
        instance = user_class(vllm_config)
    except Exception as e:
        raise RuntimeError(
            f"Failed to instantiate custom proposer class '{backend}': {e}. "
            "The class constructor must accept VllmConfig as argument."
        ) from e

    if not hasattr(instance, "propose"):
        raise AttributeError(
            f"Custom proposer class '{backend}' must have a 'propose' method."
        )
    if not callable(instance.propose):
        raise AttributeError(
            f"Custom proposer class '{backend}' has a 'propose' attribute "
            "but it is not callable."
        )

    # %s instead of %d: identical output for ints, and a non-int value does
    # not turn this informational log line into a formatting error.
    logger.info(
        "Loaded custom proposer class '%s' with num_speculative_tokens=%s",
        backend,
        spec_config.num_speculative_tokens,
    )

    return instance
|
||||
@@ -169,6 +169,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
from vllm.v1.spec_decode.custom_class_proposer import create_custom_proposer
|
||||
from vllm.v1.spec_decode.dflash import DFlashProposer
|
||||
from vllm.v1.spec_decode.draft_model import DraftModelProposer
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
@@ -531,7 +532,11 @@ class GPUModelRunner(
|
||||
| ExtractHiddenStatesProposer
|
||||
| Gemma4Proposer
|
||||
)
|
||||
if self.speculative_config.method == "ngram":
|
||||
if self.speculative_config.method == "custom_class":
|
||||
self.drafter = create_custom_proposer( # type: ignore[assignment]
|
||||
self.vllm_config
|
||||
)
|
||||
elif self.speculative_config.method == "ngram":
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
|
||||
self.drafter = NgramProposer(self.vllm_config)
|
||||
@@ -4614,6 +4619,14 @@ class GPUModelRunner(
|
||||
self.input_batch.token_ids_cpu,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
elif spec_config.method == "custom_class":
|
||||
assert isinstance(sampled_token_ids, list)
|
||||
draft_token_ids = cast(Any, self.drafter).propose(
|
||||
sampled_token_ids,
|
||||
self.input_batch.num_tokens_no_spec,
|
||||
self.input_batch.token_ids_cpu,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
elif spec_config.use_ngram_gpu():
|
||||
assert isinstance(self.drafter, NgramProposerGPU)
|
||||
(
|
||||
@@ -4885,7 +4898,8 @@ class GPUModelRunner(
|
||||
)
|
||||
if hasattr(self, "drafter"):
|
||||
logger.info_once("Loading drafter model...")
|
||||
self.drafter.load_model(self.model)
|
||||
if hasattr(self.drafter, "load_model"):
|
||||
self.drafter.load_model(self.model)
|
||||
if (
|
||||
hasattr(self.drafter, "model")
|
||||
and is_mixture_of_experts(self.drafter.model)
|
||||
|
||||
Reference in New Issue
Block a user