[TRTLLM-6406] feat: Enable guided decoding with overlap scheduler (#6000)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu 2025-07-17 17:46:10 +08:00 committed by GitHub
parent 44c70c88f9
commit 21efb50068
10 changed files with 53 additions and 55 deletions

View File

@ -54,8 +54,8 @@ void logitsBitmask(std::vector<torch::Tensor> const& logits, std::vector<torch::
bitmaskPtrsHost[i] = reinterpret_cast<uint64_t>(bitmask[i].data_ptr());
}
auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA);
auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA);
auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);
auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);
auto stream = at::cuda::getCurrentCUDAStream(logits[0].get_device()).stream();
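Note: the switch to non_blocking copies only pays off when the staging tensors live in pinned (page-locked) host memory; otherwise the transfer degrades to a synchronous one. A minimal PyTorch sketch of the same idiom (Python rather than the C++ above; the pointer table is a hypothetical stand-in for logitsPtrsHost/bitmaskPtrsHost):

import torch

def upload_pointer_table(num_tensors: int) -> torch.Tensor:
    # Host-side table of device pointers (uint64 in the C++ code; int64 here).
    # pin_memory=True is what lets the non_blocking copy overlap with CPU work.
    ptrs_host = torch.zeros(num_tensors, dtype=torch.int64, pin_memory=True)
    # Asynchronous host-to-device copy; kernels enqueued later on the same
    # CUDA stream still see the completed copy because of stream ordering.
    return ptrs_host.to("cuda", non_blocking=True)

if torch.cuda.is_available():
    device_ptrs = upload_pointer_table(16)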

View File

@ -15,4 +15,4 @@
| KV Cache Reuse | Yes | Yes | Yes | Untested | Untested | Untested | Yes | No | Yes | Yes | --- | | | |
| Slide Window Attention | Yes | Yes | Yes | Untested | Untested | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
| Logits Post Processor | No | Yes | Yes | No | Untested | No | No | No | Yes | Yes | Yes | Yes | --- | |
| Guided Decoding | No | Yes | Yes | Untested | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |

View File

@ -7,12 +7,9 @@ from tensorrt_llm.llmapi import GuidedDecodingParams
def main():
# Specify the guided decoding backend; xgrammar is supported currently.
llm = LLM(
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
guided_decoding_backend='xgrammar',
disable_overlap_scheduler=True # Not supported by xgrammar mode
)
# Specify the guided decoding backend; xgrammar and llguidance are supported currently.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
guided_decoding_backend='xgrammar')
# An example from json-mode-eval
schema = '{"title": "WirelessAccessPoint", "type": "object", "properties": {"ssid": {"title": "SSID", "type": "string"}, "securityProtocol": {"title": "SecurityProtocol", "type": "string"}, "bandwidth": {"title": "Bandwidth", "type": "string"}}, "required": ["ssid", "securityProtocol", "bandwidth"]}'

View File

@ -21,6 +21,7 @@ from ..model_config import ModelConfig
from ..speculative import get_spec_decoder
from .config import PyTorchConfig
from .config_utils import is_mla, is_nemotron_hybrid
from .guided_decoder import GuidedDecoder
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
from .llm_request import ExecutorResponse
from .model_engine import PyTorchModelEngine
@ -414,19 +415,12 @@ def create_py_executor_instance(
start_worker,
sampler,
drafter,
guided_decoder: Optional[GuidedDecoder] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
spec_config = model_engine.spec_config
if mapping.is_last_pp_rank(
) and executor_config.guided_decoding_config is not None:
if spec_config is not None:
raise ValueError(
"Guided decoding is not supported with speculative decoding.")
if not pytorch_backend_config.disable_overlap_scheduler:
raise ValueError(
"Guided decoding is not supported with overlap scheduler.")
logger.info(
f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}"
@ -543,6 +537,7 @@ def create_py_executor_instance(
if spec_config is not None else 0,
kv_cache_transceiver=kv_cache_transceiver,
draft_model_engine=draft_model_engine,
guided_decoder=guided_decoder,
start_worker=start_worker,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)

View File

@ -3,11 +3,11 @@ from typing import List, Optional
import torch
from ..._utils import nvtx_range
from ...bindings.executor import GuidedDecodingConfig
from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory,
LLGuidanceMatcherFactory, XGrammarMatcherFactory)
from .scheduler import ScheduledRequests
from .seq_slot_manager import SeqSlotManager
class GuidedDecoder:
@ -49,12 +49,12 @@ class GuidedDecoder:
def bitmask_size(self) -> int:
return math.ceil(self.vocab_size_padded / 32)
def build(self, scheduled_requests: ScheduledRequests,
resource_manager: SeqSlotManager) -> None:
@nvtx_range("GuidedDecoder.build")
def build(self, scheduled_requests: ScheduledRequests) -> None:
for llm_req in scheduled_requests.all_requests():
if llm_req.guided_decoding_params is None:
continue
slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
slot = llm_req.py_seq_slot
if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
self.grammar_matchers[
slot] = self.grammar_matcher_factory.create(
@ -75,8 +75,9 @@ class GuidedDecoder:
self.bitmask[slot].copy_(self.bitmask_host[slot],
non_blocking=True)
@nvtx_range("GuidedDecoder.execute")
def execute(self, scheduled_requests: ScheduledRequests,
logits: torch.Tensor, resource_manager: SeqSlotManager) -> None:
logits: torch.Tensor) -> None:
assert logits.size(0) == len(scheduled_requests.context_requests) + len(
scheduled_requests.generation_requests)
torch.cuda.current_stream().wait_stream(self._stream)
@ -88,8 +89,7 @@ class GuidedDecoder:
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
continue
batched_logits.append(logits[i])
slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
batched_bitmask.append(self.bitmask[slot])
batched_bitmask.append(self.bitmask[llm_req.py_seq_slot])
if len(batched_logits) > 0:
torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
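For reference, the bitmask packs one bit per vocabulary token into 32-bit words (hence bitmask_size = ceil(vocab_size_padded / 32)), and torch.ops.trtllm.logits_bitmask, the fused CUDA kernel touched in the first hunk, masks out disallowed tokens before sampling. A pure-PyTorch sketch of the per-request semantics, assuming LSB-first bit packing:

import torch

def apply_token_bitmask(logits: torch.Tensor, bitmask: torch.Tensor) -> torch.Tensor:
    # Token i is allowed when bit (i % 32) of word (i // 32) is set.
    vocab_size = logits.numel()
    token_ids = torch.arange(vocab_size, device=logits.device)
    words = bitmask.to(torch.int64)[token_ids // 32]
    allowed = ((words >> (token_ids % 32)) & 1).bool()
    neg_inf = torch.full_like(logits, float("-inf"))
    return torch.where(allowed, logits, neg_inf)

# Toy check: vocab of 4, only tokens 0 and 2 allowed (mask word 0b0101).
logits = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([0b0101], dtype=torch.int32)
print(apply_token_bitmask(logits, mask))  # tensor([1., -inf, 3., -inf])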

View File

@ -21,7 +21,6 @@ from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
torch_dtype_to_str, trace_func)
from tensorrt_llm.bindings.executor import GuidedDecodingConfig
from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
@ -53,7 +52,6 @@ from ..utils import (get_model_extra_attrs, set_torch_compiling,
from .config import LoadFormat, PyTorchConfig
from .config_utils import is_mla
from .cuda_graph_runner import DecodingCUDAGraphRunner
from .guided_decoder import GuidedDecoder
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .resource_manager import (BaseResourceManager, KVCacheManager,
ResourceManager, ResourceManagerType)
@ -258,7 +256,6 @@ class PyTorchModelEngine(ModelEngine):
attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
dist: Optional[MPIDist] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
lora_config: Optional[LoraConfig] = None,
is_draft_model: bool = False,
):
@ -313,13 +310,6 @@ class PyTorchModelEngine(ModelEngine):
self.dtype = self.model.config.torch_dtype
self._init_model_capacity()
self.guided_decoder: Optional[GuidedDecoder] = None
if self.mapping.is_last_pp_rank(
) and guided_decoding_config is not None:
self.guided_decoder = GuidedDecoder(guided_decoding_config,
self.batch_size,
self.model.vocab_size_padded)
self._torch_compile_backend = None
try:
@ -2091,18 +2081,6 @@ class PyTorchModelEngine(ModelEngine):
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = maybe_graph.run(inputs)
# Note: To overlap the CPU and GPU computation as much as possible,
# guided_decoder.build should be called immediately after the launch of the single step;
# while guided_decoder.execute should be called right before the samplings.
# We can insert other CPU computation between them in the future.
if self.mapping.is_last_pp_rank(
) and self.guided_decoder is not None:
seq_slot_manager = resource_manager.get_resource_manager(
ResourceManagerType.SEQ_SLOT_MANAGER)
self.guided_decoder.build(scheduled_requests, seq_slot_manager)
self.guided_decoder.execute(scheduled_requests,
outputs['logits'], seq_slot_manager)
self._execute_logit_post_processors(scheduled_requests, outputs)
return outputs

View File

@ -31,6 +31,7 @@ from tensorrt_llm.logger import logger
from ..distributed import Distributed
from ..speculative.drafter import Drafter
from .guided_decoder import GuidedDecoder
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
LlmResponse, executor_request_to_llm_request)
@ -204,6 +205,7 @@ class PyExecutor:
max_draft_len: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
draft_model_engine: Optional[ModelEngine] = None,
guided_decoder: Optional[GuidedDecoder] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True):
super(PyExecutor, self).__init__()
@ -225,6 +227,7 @@ class PyExecutor:
self.enable_attention_dp = model_engine.enable_attention_dp
self.sampler = sampler
self.drafter = drafter
self.guided_decoder = guided_decoder
self.dist = dist
self.disable_overlap_scheduler = disable_overlap_scheduler
@ -801,6 +804,12 @@ class PyExecutor:
if self._need_return_logits(scheduled_batch):
logits_host = batch_outputs["logits"].to(
"cpu", non_blocking=True)
if self.guided_decoder is not None:
self.guided_decoder.build(scheduled_batch)
self.guided_decoder.execute(
scheduled_batch, batch_outputs['logits'])
sample_state = self._sample_async(
scheduled_batch, batch_outputs)
sample_state.host.logits = logits_host
@ -978,6 +987,11 @@ class PyExecutor:
batch_outputs = self._forward_step(scheduled_batch)
if self.guided_decoder is not None:
self.guided_decoder.build(scheduled_batch)
self.guided_decoder.execute(scheduled_batch,
batch_outputs['logits'])
sample_state = self._sample_async(scheduled_batch,
batch_outputs)
@ -1126,6 +1140,14 @@ class PyExecutor:
batch_outputs = self._forward_step(scheduled_batch,
previous_tensors_device)
if self.previous_batch is not None:
self._update_requests(self.previous_batch.sample_state)
if self.guided_decoder is not None:
self.guided_decoder.build(scheduled_batch)
self.guided_decoder.execute(scheduled_batch,
batch_outputs['logits'])
sample_state = self._sample_async(scheduled_batch,
batch_outputs)
assert sample_state is not None, "Sampling failed"
@ -1159,8 +1181,6 @@ class PyExecutor:
self._terminate_ctx_finished_requests()
def _process_previous_batch(self):
self._update_requests(self.previous_batch.sample_state)
if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
for req in self.previous_batch.ctx_transmission_reqs:
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
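Taken together, the PyExecutor hunks move guided decoding out of the model engine and into the executor loop, preserving the intent of the comment deleted from model_engine.py: the grammar advance (build, CPU work) runs right after the forward launch, and the bitmask application (execute, GPU work) runs right before sampling. A condensed paraphrase of the new overlap-scheduler iteration, not the verbatim code; 'executor' stands in for self:

from typing import Any, Optional

def overlap_step(executor: Any, scheduled_batch: Any,
                 previous_tensors_device: Optional[Any]) -> Any:
    # Launch this step's forward pass; the GPU starts working immediately.
    batch_outputs = executor._forward_step(scheduled_batch,
                                           previous_tensors_device)

    # Finish the previous step's CPU-side bookkeeping while the GPU runs.
    if executor.previous_batch is not None:
        executor._update_requests(executor.previous_batch.sample_state)

    # Advance the grammar matchers on the CPU (build), then mask this step's
    # logits on the GPU (execute) immediately before asynchronous sampling.
    if executor.guided_decoder is not None:
        executor.guided_decoder.build(scheduled_batch)
        executor.guided_decoder.execute(scheduled_batch, batch_outputs['logits'])

    return executor._sample_async(scheduled_batch, batch_outputs)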

View File

@ -24,6 +24,7 @@ from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
create_py_executor_instance, instantiate_sampler, is_mla)
from .config import PyTorchConfig
from .config_utils import is_mla
from .guided_decoder import GuidedDecoder
from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor
@ -237,7 +238,6 @@ def create_py_executor(
attn_runtime_features=attn_runtime_features,
dist=dist,
spec_config=spec_config,
guided_decoding_config=executor_config.guided_decoding_config,
lora_config=lora_config,
checkpoint_loader=executor_config.checkpoint_loader,
)
@ -344,6 +344,17 @@ def create_py_executor(
sampler = instantiate_sampler(model_engine, executor_config,
pytorch_backend_config, mapping)
guided_decoder: Optional[GuidedDecoder] = None
if executor_config.guided_decoding_config is not None:
if spec_config is not None:
raise ValueError(
"Guided decoding is not supported with speculative decoding.")
if mapping.is_last_pp_rank():
guided_decoder = GuidedDecoder(
executor_config.guided_decoding_config,
executor_config.max_batch_size,
model_engine.model.vocab_size_padded)
resources = {}
estimating_kv_cache = False
kv_cache_creator = None
@ -388,6 +399,7 @@ def create_py_executor(
start_worker=False,
sampler=sampler,
drafter=drafter,
guided_decoder=guided_decoder,
lora_config=lora_config,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
)
@ -430,6 +442,7 @@ def create_py_executor(
start_worker=False,
sampler=sampler,
drafter=drafter,
guided_decoder=guided_decoder,
lora_config=lora_config,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold,

View File

@ -287,7 +287,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
llm = LLM(self.MODEL_PATH,
guided_decoding_backend=backend,
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig())
with llm:
task = JsonModeEval(self.MODEL_NAME)
@ -300,7 +299,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
with LLM(self.MODEL_PATH,
guided_decoding_backend=backend,
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(),
tensor_parallel_size=2,
pipeline_parallel_size=2) as llm:

View File

@ -23,10 +23,7 @@ def temp_extra_llm_api_options_file(request):
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {
"guided_decoding_backend": "xgrammar",
"disable_overlap_scheduler": True,
}
extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
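With the overlap-scheduler restriction lifted, the options file this fixture writes reduces to a single key; the test then hands the file to the CLI entry point through its extra-LLM-API-options argument (the flag itself is outside this hunk). A quick sketch of the resulting YAML:

import yaml

# yaml.dump of the new dict yields a one-line options file:
#   guided_decoding_backend: xgrammar
print(yaml.dump({"guided_decoding_backend": "xgrammar"}), end="")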