[TRTLLM-6406] feat: Enable guided decoding with overlap scheduler (#6000)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Parent: 44c70c88f9
Commit: 21efb50068
@@ -54,8 +54,8 @@ void logitsBitmask(std::vector<torch::Tensor> const& logits, std::vector<torch::
         bitmaskPtrsHost[i] = reinterpret_cast<uint64_t>(bitmask[i].data_ptr());
     }
 
-    auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA);
-    auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA);
+    auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);
+    auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);
 
     auto stream = at::cuda::getCurrentCUDAStream(logits[0].get_device()).stream();
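The hunk above makes the two pointer-table uploads asynchronous with respect to the host. A minimal PyTorch-level sketch of the same pattern, assuming a pinned host buffer (a non-blocking copy from pageable memory silently degrades to a synchronous one); the tensor name and size are illustrative, not from the diff:

```python
import torch

# Pointer table staged on the host in pinned memory, then uploaded asynchronously.
ptrs_host = torch.empty(64, dtype=torch.int64, pin_memory=True)
ptrs_dev = ptrs_host.to("cuda", non_blocking=True)  # enqueued on the current CUDA stream

# The copy is stream-ordered: a kernel launched afterwards on the same stream
# (as the bitmask kernel is here) sees the data without an explicit synchronization.
```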
@@ -15,4 +15,4 @@
 | KV Cache Reuse | Yes | Yes | Yes | Untested | Untested | Untested | Yes | No | Yes | Yes | --- | | | |
 | Slide Window Attention | Yes | Yes | Yes | Untested | Untested | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
 | Logits Post Processor | No | Yes | Yes | No | Untested | No | No | No | Yes | Yes | Yes | Yes | --- | |
-| Guided Decoding | No | Yes | Yes | Untested | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
+| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
@@ -7,12 +7,9 @@ from tensorrt_llm.llmapi import GuidedDecodingParams
 
 
 def main():
-    # Specify the guided decoding backend; xgrammar is supported currently.
-    llm = LLM(
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        guided_decoding_backend='xgrammar',
-        disable_overlap_scheduler=True  # Not supported by xgrammar mode
-    )
+    # Specify the guided decoding backend; xgrammar and llguidance are supported currently.
+    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+              guided_decoding_backend='xgrammar')
 
     # An example from json-mode-eval
     schema = '{"title": "WirelessAccessPoint", "type": "object", "properties": {"ssid": {"title": "SSID", "type": "string"}, "securityProtocol": {"title": "SecurityProtocol", "type": "string"}, "bandwidth": {"title": "Bandwidth", "type": "string"}}, "required": ["ssid", "securityProtocol", "bandwidth"]}'
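With the overlap-scheduler restriction lifted, the example only needs the backend name. A hedged end-to-end sketch of how such an example might continue; the prompt string and `max_tokens` value are illustrative, and the `GuidedDecodingParams`/`SamplingParams` usage is assumed to follow the existing llm-api pattern rather than copied from this file:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import GuidedDecodingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
          guided_decoding_backend='xgrammar')

# The json-mode-eval schema shown in the hunk above.
schema = '{"title": "WirelessAccessPoint", "type": "object", "properties": {"ssid": {"title": "SSID", "type": "string"}, "securityProtocol": {"title": "SecurityProtocol", "type": "string"}, "bandwidth": {"title": "Bandwidth", "type": "string"}}, "required": ["ssid", "securityProtocol", "bandwidth"]}'

# Constrain generation to the schema; the prompt is illustrative.
output = llm.generate(
    "Describe an office wireless access point as JSON.",
    sampling_params=SamplingParams(
        max_tokens=128,
        guided_decoding=GuidedDecodingParams(json=schema)))
print(output.outputs[0].text)
```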
@@ -21,6 +21,7 @@ from ..model_config import ModelConfig
 from ..speculative import get_spec_decoder
 from .config import PyTorchConfig
 from .config_utils import is_mla, is_nemotron_hybrid
+from .guided_decoder import GuidedDecoder
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
 from .llm_request import ExecutorResponse
 from .model_engine import PyTorchModelEngine
@@ -414,19 +415,12 @@ def create_py_executor_instance(
         start_worker,
         sampler,
         drafter,
+        guided_decoder: Optional[GuidedDecoder] = None,
         lora_config: Optional[LoraConfig] = None,
         garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
 
     spec_config = model_engine.spec_config
-    if mapping.is_last_pp_rank(
-    ) and executor_config.guided_decoding_config is not None:
-        if spec_config is not None:
-            raise ValueError(
-                "Guided decoding is not supported with speculative decoding.")
-        if not pytorch_backend_config.disable_overlap_scheduler:
-            raise ValueError(
-                "Guided decoding is not supported with overlap scheduler.")
 
     logger.info(
         f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}"
@@ -543,6 +537,7 @@ def create_py_executor_instance(
         if spec_config is not None else 0,
         kv_cache_transceiver=kv_cache_transceiver,
         draft_model_engine=draft_model_engine,
+        guided_decoder=guided_decoder,
         start_worker=start_worker,
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
@@ -3,11 +3,11 @@ from typing import List, Optional
 
 import torch
 
+from ..._utils import nvtx_range
 from ...bindings.executor import GuidedDecodingConfig
 from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory,
                               LLGuidanceMatcherFactory, XGrammarMatcherFactory)
 from .scheduler import ScheduledRequests
-from .seq_slot_manager import SeqSlotManager
 
 
 class GuidedDecoder:
@@ -49,12 +49,12 @@ class GuidedDecoder:
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)
 
-    def build(self, scheduled_requests: ScheduledRequests,
-              resource_manager: SeqSlotManager) -> None:
+    @nvtx_range("GuidedDecoder.build")
+    def build(self, scheduled_requests: ScheduledRequests) -> None:
         for llm_req in scheduled_requests.all_requests():
             if llm_req.guided_decoding_params is None:
                 continue
-            slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
+            slot = llm_req.py_seq_slot
             if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
                 self.grammar_matchers[
                     slot] = self.grammar_matcher_factory.create(
@@ -75,8 +75,9 @@ class GuidedDecoder:
                 self.bitmask[slot].copy_(self.bitmask_host[slot],
                                          non_blocking=True)
 
+    @nvtx_range("GuidedDecoder.execute")
     def execute(self, scheduled_requests: ScheduledRequests,
-                logits: torch.Tensor, resource_manager: SeqSlotManager) -> None:
+                logits: torch.Tensor) -> None:
         assert logits.size(0) == len(scheduled_requests.context_requests) + len(
             scheduled_requests.generation_requests)
         torch.cuda.current_stream().wait_stream(self._stream)
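The `wait_stream` call above orders the asynchronous bitmask uploads, issued on the decoder's private stream, before the masking kernel on the main stream. A minimal sketch of that producer/consumer stream pattern in plain PyTorch; the buffer size is illustrative:

```python
import torch

side_stream = torch.cuda.Stream()
bitmask_host = torch.zeros(4008, dtype=torch.int32, pin_memory=True)
bitmask_dev = torch.zeros(4008, dtype=torch.int32, device="cuda")

with torch.cuda.stream(side_stream):
    # Producer: asynchronous H2D copy enqueued on the side stream.
    bitmask_dev.copy_(bitmask_host, non_blocking=True)

# Consumer: make the current stream wait on the side stream, then any kernel
# launched on the current stream can safely read bitmask_dev.
torch.cuda.current_stream().wait_stream(side_stream)
```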
@@ -88,8 +89,7 @@ class GuidedDecoder:
             if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
                 continue
             batched_logits.append(logits[i])
-            slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
-            batched_bitmask.append(self.bitmask[slot])
+            batched_bitmask.append(self.bitmask[llm_req.py_seq_slot])
 
         if len(batched_logits) > 0:
             torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
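`torch.ops.trtllm.logits_bitmask` applies a packed 32-bits-per-word vocabulary mask to each selected logits row. A rough pure-PyTorch equivalent for a single row, assuming bit j of word i governs token 32*i + j and a set bit means the token is allowed (the exact bit layout is an assumption here; the fused op additionally batches over rows via the pointer tables shown in the C++ hunk):

```python
import torch

def apply_token_bitmask(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    """Mask disallowed tokens in place: logits [V] float, bitmask [ceil(V/32)] int32."""
    vocab = logits.numel()
    words = bitmask.to(torch.int64) & 0xFFFFFFFF            # view the words as unsigned
    bit_idx = torch.arange(vocab, device=logits.device)
    allowed = (words[bit_idx // 32] >> (bit_idx % 32)) & 1  # extract one bit per token
    logits.masked_fill_(allowed == 0, float("-inf"))
```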
@@ -21,7 +21,6 @@ from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
                                  torch_dtype_to_str, trace_func)
-from tensorrt_llm.bindings.executor import GuidedDecodingConfig
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
@@ -53,7 +52,6 @@ from ..utils import (get_model_extra_attrs, set_torch_compiling,
 from .config import LoadFormat, PyTorchConfig
 from .config_utils import is_mla
 from .cuda_graph_runner import DecodingCUDAGraphRunner
-from .guided_decoder import GuidedDecoder
 from .layerwise_nvtx_marker import LayerwiseNvtxMarker
 from .resource_manager import (BaseResourceManager, KVCacheManager,
                                ResourceManager, ResourceManagerType)
@@ -258,7 +256,6 @@ class PyTorchModelEngine(ModelEngine):
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
         dist: Optional[MPIDist] = None,
         spec_config: Optional["DecodingBaseConfig"] = None,
-        guided_decoding_config: Optional[GuidedDecodingConfig] = None,
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
     ):
@@ -313,13 +310,6 @@ class PyTorchModelEngine(ModelEngine):
         self.dtype = self.model.config.torch_dtype
         self._init_model_capacity()
 
-        self.guided_decoder: Optional[GuidedDecoder] = None
-        if self.mapping.is_last_pp_rank(
-        ) and guided_decoding_config is not None:
-            self.guided_decoder = GuidedDecoder(guided_decoding_config,
-                                                self.batch_size,
-                                                self.model.vocab_size_padded)
-
         self._torch_compile_backend = None
 
         try:
@@ -2091,18 +2081,6 @@ class PyTorchModelEngine(ModelEngine):
             with MoeLoadBalancerIterContext(moe_load_balancer):
                 outputs = maybe_graph.run(inputs)
 
-        # Note: To overlap the CPU and GPU computation as much as possible,
-        # guided_decoder.build should be called immediately after the launch of the single step;
-        # while guided_decoder.execute should be called right before the samplings.
-        # We can insert other CPU computation between them in the future.
-        if self.mapping.is_last_pp_rank(
-        ) and self.guided_decoder is not None:
-            seq_slot_manager = resource_manager.get_resource_manager(
-                ResourceManagerType.SEQ_SLOT_MANAGER)
-            self.guided_decoder.build(scheduled_requests, seq_slot_manager)
-            self.guided_decoder.execute(scheduled_requests,
-                                        outputs['logits'], seq_slot_manager)
-
         self._execute_logit_post_processors(scheduled_requests, outputs)
 
         return outputs
@@ -31,6 +31,7 @@ from tensorrt_llm.logger import logger
 
 from ..distributed import Distributed
 from ..speculative.drafter import Drafter
+from .guided_decoder import GuidedDecoder
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
                           LlmResponse, executor_request_to_llm_request)
@@ -204,6 +205,7 @@ class PyExecutor:
                  max_draft_len: int = 0,
                  kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
                  draft_model_engine: Optional[ModelEngine] = None,
+                 guided_decoder: Optional[GuidedDecoder] = None,
                  garbage_collection_gen0_threshold: Optional[int] = None,
                  start_worker: bool = True):
         super(PyExecutor, self).__init__()
@@ -225,6 +227,7 @@ class PyExecutor:
         self.enable_attention_dp = model_engine.enable_attention_dp
         self.sampler = sampler
         self.drafter = drafter
+        self.guided_decoder = guided_decoder
         self.dist = dist
         self.disable_overlap_scheduler = disable_overlap_scheduler
@@ -801,6 +804,12 @@ class PyExecutor:
             if self._need_return_logits(scheduled_batch):
                 logits_host = batch_outputs["logits"].to(
                     "cpu", non_blocking=True)
 
+            if self.guided_decoder is not None:
+                self.guided_decoder.build(scheduled_batch)
+                self.guided_decoder.execute(
+                    scheduled_batch, batch_outputs['logits'])
+
             sample_state = self._sample_async(
                 scheduled_batch, batch_outputs)
             sample_state.host.logits = logits_host
@@ -978,6 +987,11 @@ class PyExecutor:
 
             batch_outputs = self._forward_step(scheduled_batch)
 
+            if self.guided_decoder is not None:
+                self.guided_decoder.build(scheduled_batch)
+                self.guided_decoder.execute(scheduled_batch,
+                                            batch_outputs['logits'])
+
             sample_state = self._sample_async(scheduled_batch,
                                               batch_outputs)
@@ -1126,6 +1140,14 @@ class PyExecutor:
                 batch_outputs = self._forward_step(scheduled_batch,
                                                    previous_tensors_device)
 
+                if self.previous_batch is not None:
+                    self._update_requests(self.previous_batch.sample_state)
+
+                if self.guided_decoder is not None:
+                    self.guided_decoder.build(scheduled_batch)
+                    self.guided_decoder.execute(scheduled_batch,
+                                                batch_outputs['logits'])
+
                 sample_state = self._sample_async(scheduled_batch,
                                                   batch_outputs)
                 assert sample_state is not None, "Sampling failed"
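This is the overlap-scheduler path the commit enables: the forward step of the current batch is launched first, the CPU then updates the previous batch's requests and advances the grammar matchers while the GPU is busy, and the fresh logits are masked only right before sampling (matching the intent of the comment removed from the model engine above). A condensed sketch of that per-iteration ordering, reusing the method names from the hunk, with bookkeeping and error handling omitted:

```python
# Launch GPU work for the current batch first.
batch_outputs = self._forward_step(scheduled_batch, previous_tensors_device)

# Host-side work that overlaps with the in-flight forward step.
if self.previous_batch is not None:
    self._update_requests(self.previous_batch.sample_state)  # accept previously sampled tokens
if self.guided_decoder is not None:
    self.guided_decoder.build(scheduled_batch)                # advance matchers, fill bitmasks
    # Mask the new logits just before sampling from them (stream-ordered after the forward).
    self.guided_decoder.execute(scheduled_batch, batch_outputs['logits'])

sample_state = self._sample_async(scheduled_batch, batch_outputs)
```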
@@ -1159,8 +1181,6 @@ class PyExecutor:
             self._terminate_ctx_finished_requests()
 
     def _process_previous_batch(self):
-        self._update_requests(self.previous_batch.sample_state)
-
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
                 req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
@@ -24,6 +24,7 @@ from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
                     create_py_executor_instance, instantiate_sampler, is_mla)
 from .config import PyTorchConfig
 from .config_utils import is_mla
+from .guided_decoder import GuidedDecoder
 from .model_engine import PyTorchModelEngine
 from .py_executor import PyExecutor
@@ -237,7 +238,6 @@ def create_py_executor(
         attn_runtime_features=attn_runtime_features,
         dist=dist,
         spec_config=spec_config,
-        guided_decoding_config=executor_config.guided_decoding_config,
         lora_config=lora_config,
         checkpoint_loader=executor_config.checkpoint_loader,
     )
@@ -344,6 +344,17 @@ def create_py_executor(
     sampler = instantiate_sampler(model_engine, executor_config,
                                   pytorch_backend_config, mapping)
 
+    guided_decoder: Optional[GuidedDecoder] = None
+    if executor_config.guided_decoding_config is not None:
+        if spec_config is not None:
+            raise ValueError(
+                "Guided decoding is not supported with speculative decoding.")
+        if mapping.is_last_pp_rank():
+            guided_decoder = GuidedDecoder(
+                executor_config.guided_decoding_config,
+                executor_config.max_batch_size,
+                model_engine.model.vocab_size_padded)
+
     resources = {}
     estimating_kv_cache = False
     kv_cache_creator = None
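The decoder is now constructed once here, on the last pipeline-parallel rank, sized by `max_batch_size` slots and the padded vocabulary. A back-of-the-envelope sketch of the bitmask storage this implies, using the `ceil(vocab_size_padded / 32)` formula from the GuidedDecoder hunk; the batch and vocabulary numbers are illustrative, not from the diff:

```python
import math

max_batch_size = 64          # illustrative
vocab_size_padded = 128256   # illustrative, e.g. a Llama-3-sized padded vocabulary

words_per_slot = math.ceil(vocab_size_padded / 32)     # int32 words per sequence slot
bytes_per_copy = max_batch_size * words_per_slot * 4   # the diff shows a device bitmask
print(words_per_slot, bytes_per_copy)                  # 4008 words, ~1.0 MB per copy,
                                                       # plus a host-side staging copy
```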
@@ -388,6 +399,7 @@ def create_py_executor(
             start_worker=False,
             sampler=sampler,
             drafter=drafter,
+            guided_decoder=guided_decoder,
             lora_config=lora_config,
             garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
         )
@@ -430,6 +442,7 @@ def create_py_executor(
                 start_worker=False,
                 sampler=sampler,
                 drafter=drafter,
+                guided_decoder=guided_decoder,
                 lora_config=lora_config,
                 garbage_collection_gen0_threshold=
                 garbage_collection_gen0_threshold,
@@ -287,7 +287,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
         llm = LLM(self.MODEL_PATH,
                   guided_decoding_backend=backend,
-                  disable_overlap_scheduler=True,
                   cuda_graph_config=CudaGraphConfig())
         with llm:
             task = JsonModeEval(self.MODEL_NAME)
@@ -300,7 +299,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
         with LLM(self.MODEL_PATH,
                  guided_decoding_backend=backend,
-                 disable_overlap_scheduler=True,
                  cuda_graph_config=CudaGraphConfig(),
                  tensor_parallel_size=2,
                  pipeline_parallel_size=2) as llm:
@@ -23,10 +23,7 @@ def temp_extra_llm_api_options_file(request):
     temp_dir = tempfile.gettempdir()
     temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
     try:
-        extra_llm_api_options_dict = {
-            "guided_decoding_backend": "xgrammar",
-            "disable_overlap_scheduler": True,
-        }
+        extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
 
         with open(temp_file_path, 'w') as f:
             yaml.dump(extra_llm_api_options_dict, f)
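After the simplification, the fixture writes a one-line options file. A quick check of what `yaml.dump` produces for the new dict:

```python
import yaml

print(yaml.dump({"guided_decoding_backend": "xgrammar"}), end="")
# guided_decoding_backend: xgrammar
```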