Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-6406] feat: Enable guided decoding with overlap scheduler (#6000)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
This commit is contained in:
parent 44c70c88f9
commit 21efb50068
@@ -54,8 +54,8 @@ void logitsBitmask(std::vector<torch::Tensor> const& logits, std::vector<torch::
         bitmaskPtrsHost[i] = reinterpret_cast<uint64_t>(bitmask[i].data_ptr());
     }
 
-    auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA);
-    auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA);
+    auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);
+    auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);
 
     auto stream = at::cuda::getCurrentCUDAStream(logits[0].get_device()).stream();
 
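The hunk above switches the host-to-device pointer-table copies to non-blocking transfers; the same pattern recurs in the Python hunks below (`copy_(..., non_blocking=True)` and `.to("cpu", non_blocking=True)`). A minimal PyTorch sketch of that pattern, assuming a pinned staging buffer and an explicit side stream (the names here are illustrative, not taken from the diff):

import torch

if torch.cuda.is_available():
    # Pinned host memory is what makes a non_blocking copy truly asynchronous.
    host_buf = torch.zeros(1024, dtype=torch.int32).pin_memory()
    device_buf = torch.empty(1024, dtype=torch.int32, device="cuda")

    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        device_buf.copy_(host_buf, non_blocking=True)  # H2D copy enqueued on copy_stream

    # Consumers on the default stream must wait for the copy before reading device_buf,
    # mirroring the wait_stream(self._stream) call in GuidedDecoder.execute below.
    torch.cuda.current_stream().wait_stream(copy_stream)
    print(device_buf.sum().item())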
@@ -15,4 +15,4 @@
 | KV Cache Reuse | Yes | Yes | Yes | Untested | Untested | Untested | Yes | No | Yes | Yes | --- | | | |
 | Slide Window Attention | Yes | Yes | Yes | Untested | Untested | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
 | Logits Post Processor | No | Yes | Yes | No | Untested | No | No | No | Yes | Yes | Yes | Yes | --- | |
-| Guided Decoding | No | Yes | Yes | Untested | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
+| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
@@ -7,12 +7,9 @@ from tensorrt_llm.llmapi import GuidedDecodingParams
 
 def main():
 
-    # Specify the guided decoding backend; xgrammar is supported currently.
-    llm = LLM(
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        guided_decoding_backend='xgrammar',
-        disable_overlap_scheduler=True  # Not supported by xgrammar mode
-    )
+    # Specify the guided decoding backend; xgrammar and llguidance are supported currently.
+    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+              guided_decoding_backend='xgrammar')
 
     # An example from json-mode-eval
     schema = '{"title": "WirelessAccessPoint", "type": "object", "properties": {"ssid": {"title": "SSID", "type": "string"}, "securityProtocol": {"title": "SecurityProtocol", "type": "string"}, "bandwidth": {"title": "Bandwidth", "type": "string"}}, "required": ["ssid", "securityProtocol", "bandwidth"]}'
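For context, a hedged sketch of how the rest of this example presumably uses the schema; the prompt text, the trimmed schema, and the SamplingParams wiring below are illustrative and not part of this diff:

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import GuidedDecodingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
          guided_decoding_backend='xgrammar')

# A trimmed-down JSON schema in the spirit of the one above.
schema = '{"type": "object", "properties": {"ssid": {"type": "string"}}, "required": ["ssid"]}'

# GuidedDecodingParams(json=schema) asks the backend to constrain sampling so the
# generated text is valid JSON matching the schema.
sampling_params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json=schema))

for output in llm.generate(["Describe a wireless access point in JSON:"],
                           sampling_params):
    print(output.outputs[0].text)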
@@ -21,6 +21,7 @@ from ..model_config import ModelConfig
 from ..speculative import get_spec_decoder
 from .config import PyTorchConfig
 from .config_utils import is_mla, is_nemotron_hybrid
+from .guided_decoder import GuidedDecoder
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
 from .llm_request import ExecutorResponse
 from .model_engine import PyTorchModelEngine
@@ -414,19 +415,12 @@ def create_py_executor_instance(
         start_worker,
         sampler,
         drafter,
+        guided_decoder: Optional[GuidedDecoder] = None,
         lora_config: Optional[LoraConfig] = None,
         garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
 
     spec_config = model_engine.spec_config
-    if mapping.is_last_pp_rank(
-    ) and executor_config.guided_decoding_config is not None:
-        if spec_config is not None:
-            raise ValueError(
-                "Guided decoding is not supported with speculative decoding.")
-        if not pytorch_backend_config.disable_overlap_scheduler:
-            raise ValueError(
-                "Guided decoding is not supported with overlap scheduler.")
 
     logger.info(
         f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}"
@@ -543,6 +537,7 @@ def create_py_executor_instance(
         if spec_config is not None else 0,
         kv_cache_transceiver=kv_cache_transceiver,
         draft_model_engine=draft_model_engine,
+        guided_decoder=guided_decoder,
         start_worker=start_worker,
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
 
@@ -3,11 +3,11 @@ from typing import List, Optional
 
 import torch
 
+from ..._utils import nvtx_range
 from ...bindings.executor import GuidedDecodingConfig
 from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory,
                               LLGuidanceMatcherFactory, XGrammarMatcherFactory)
 from .scheduler import ScheduledRequests
-from .seq_slot_manager import SeqSlotManager
 
 
 class GuidedDecoder:
@@ -49,12 +49,12 @@ class GuidedDecoder:
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)
 
-    def build(self, scheduled_requests: ScheduledRequests,
-              resource_manager: SeqSlotManager) -> None:
+    @nvtx_range("GuidedDecoder.build")
+    def build(self, scheduled_requests: ScheduledRequests) -> None:
         for llm_req in scheduled_requests.all_requests():
             if llm_req.guided_decoding_params is None:
                 continue
-            slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
+            slot = llm_req.py_seq_slot
             if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
                 self.grammar_matchers[
                     slot] = self.grammar_matcher_factory.create(
@@ -75,8 +75,9 @@ class GuidedDecoder:
             self.bitmask[slot].copy_(self.bitmask_host[slot],
                                      non_blocking=True)
 
+    @nvtx_range("GuidedDecoder.execute")
     def execute(self, scheduled_requests: ScheduledRequests,
-                logits: torch.Tensor, resource_manager: SeqSlotManager) -> None:
+                logits: torch.Tensor) -> None:
         assert logits.size(0) == len(scheduled_requests.context_requests) + len(
             scheduled_requests.generation_requests)
         torch.cuda.current_stream().wait_stream(self._stream)
@@ -88,8 +89,7 @@ class GuidedDecoder:
             if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
                 continue
             batched_logits.append(logits[i])
-            slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
-            batched_bitmask.append(self.bitmask[slot])
+            batched_bitmask.append(self.bitmask[llm_req.py_seq_slot])
 
         if len(batched_logits) > 0:
             torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
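To make the bitmask layout concrete: each slot's mask packs one bit per vocabulary token into 32-bit words, which is why `bitmask_size` is `ceil(vocab_size_padded / 32)`. A conceptual, pure-PyTorch sketch of what applying such a mask means; the fused `torch.ops.trtllm.logits_bitmask` kernel is the real implementation, and its exact bit convention may differ from this illustration:

import math
import torch

vocab_size_padded = 32064
bitmask_size = math.ceil(vocab_size_padded / 32)   # 32064 tokens -> 1002 int32 words

logits = torch.randn(vocab_size_padded)
bitmask = torch.zeros(bitmask_size, dtype=torch.int32)
bitmask[0] = 0b1111                                # pretend tokens 0..3 are allowed

def apply_bitmask(logits: torch.Tensor, bitmask: torch.Tensor) -> torch.Tensor:
    # Treat bit (t % 32) of word (t // 32) as "token t is allowed"; disallowed
    # tokens get -inf so they can never be sampled.
    token_ids = torch.arange(logits.numel())
    bits = (bitmask[token_ids // 32] >> (token_ids % 32)) & 1
    return logits.masked_fill(bits == 0, float("-inf"))

masked = apply_bitmask(logits, bitmask)
print(masked[:8])                                  # only the first 4 entries stay finite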
@@ -21,7 +21,6 @@ from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
                                  torch_dtype_to_str, trace_func)
-from tensorrt_llm.bindings.executor import GuidedDecodingConfig
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
@@ -53,7 +52,6 @@ from ..utils import (get_model_extra_attrs, set_torch_compiling,
 from .config import LoadFormat, PyTorchConfig
 from .config_utils import is_mla
 from .cuda_graph_runner import DecodingCUDAGraphRunner
-from .guided_decoder import GuidedDecoder
 from .layerwise_nvtx_marker import LayerwiseNvtxMarker
 from .resource_manager import (BaseResourceManager, KVCacheManager,
                                ResourceManager, ResourceManagerType)
@@ -258,7 +256,6 @@ class PyTorchModelEngine(ModelEngine):
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
         dist: Optional[MPIDist] = None,
         spec_config: Optional["DecodingBaseConfig"] = None,
-        guided_decoding_config: Optional[GuidedDecodingConfig] = None,
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
     ):
@@ -313,13 +310,6 @@ class PyTorchModelEngine(ModelEngine):
         self.dtype = self.model.config.torch_dtype
         self._init_model_capacity()
 
-        self.guided_decoder: Optional[GuidedDecoder] = None
-        if self.mapping.is_last_pp_rank(
-        ) and guided_decoding_config is not None:
-            self.guided_decoder = GuidedDecoder(guided_decoding_config,
-                                                self.batch_size,
-                                                self.model.vocab_size_padded)
-
         self._torch_compile_backend = None
 
         try:
@@ -2091,18 +2081,6 @@ class PyTorchModelEngine(ModelEngine):
                 with MoeLoadBalancerIterContext(moe_load_balancer):
                     outputs = maybe_graph.run(inputs)
 
-            # Note: To overlap the CPU and GPU computation as much as possible,
-            # guided_decoder.build should be called immediately after the launch of the single step;
-            # while guided_decoder.execute should be called right before the samplings.
-            # We can insert other CPU computation between them in the future.
-            if self.mapping.is_last_pp_rank(
-            ) and self.guided_decoder is not None:
-                seq_slot_manager = resource_manager.get_resource_manager(
-                    ResourceManagerType.SEQ_SLOT_MANAGER)
-                self.guided_decoder.build(scheduled_requests, seq_slot_manager)
-                self.guided_decoder.execute(scheduled_requests,
-                                            outputs['logits'], seq_slot_manager)
-
             self._execute_logit_post_processors(scheduled_requests, outputs)
 
             return outputs
@@ -31,6 +31,7 @@ from tensorrt_llm.logger import logger
 
 from ..distributed import Distributed
 from ..speculative.drafter import Drafter
+from .guided_decoder import GuidedDecoder
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
                           LlmResponse, executor_request_to_llm_request)
@@ -204,6 +205,7 @@ class PyExecutor:
                  max_draft_len: int = 0,
                  kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
                  draft_model_engine: Optional[ModelEngine] = None,
+                 guided_decoder: Optional[GuidedDecoder] = None,
                  garbage_collection_gen0_threshold: Optional[int] = None,
                  start_worker: bool = True):
         super(PyExecutor, self).__init__()
@@ -225,6 +227,7 @@ class PyExecutor:
         self.enable_attention_dp = model_engine.enable_attention_dp
         self.sampler = sampler
         self.drafter = drafter
+        self.guided_decoder = guided_decoder
         self.dist = dist
         self.disable_overlap_scheduler = disable_overlap_scheduler
 
@@ -801,6 +804,12 @@ class PyExecutor:
             if self._need_return_logits(scheduled_batch):
                 logits_host = batch_outputs["logits"].to(
                     "cpu", non_blocking=True)
+
+            if self.guided_decoder is not None:
+                self.guided_decoder.build(scheduled_batch)
+                self.guided_decoder.execute(
+                    scheduled_batch, batch_outputs['logits'])
+
             sample_state = self._sample_async(
                 scheduled_batch, batch_outputs)
             sample_state.host.logits = logits_host
@@ -978,6 +987,11 @@ class PyExecutor:
 
             batch_outputs = self._forward_step(scheduled_batch)
 
+            if self.guided_decoder is not None:
+                self.guided_decoder.build(scheduled_batch)
+                self.guided_decoder.execute(scheduled_batch,
+                                            batch_outputs['logits'])
+
             sample_state = self._sample_async(scheduled_batch,
                                               batch_outputs)
 
@@ -1126,6 +1140,14 @@ class PyExecutor:
             batch_outputs = self._forward_step(scheduled_batch,
                                                previous_tensors_device)
 
+            if self.previous_batch is not None:
+                self._update_requests(self.previous_batch.sample_state)
+
+            if self.guided_decoder is not None:
+                self.guided_decoder.build(scheduled_batch)
+                self.guided_decoder.execute(scheduled_batch,
+                                            batch_outputs['logits'])
+
             sample_state = self._sample_async(scheduled_batch,
                                               batch_outputs)
             assert sample_state is not None, "Sampling failed"
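This is the overlap-scheduler path the commit targets: the previous batch is now finalized on the CPU, and the grammar matchers advanced in guided_decoder.build, while the GPU is still busy with the forward step that was just launched; the bitmask is only applied right before sampling. A toy, self-contained illustration of that ordering, not TensorRT-LLM code:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(8, 1024, device=device)
w = torch.randn(1024, 32000, device=device)

# 1. Launch the forward step; on CUDA this returns immediately.
logits = x @ w

# 2. CPU work overlapped with the GPU matmul: stands in for
#    _update_requests(previous_batch.sample_state) and guided_decoder.build().
allowed = torch.zeros(32000, dtype=torch.bool)
allowed[:100] = True            # pretend the grammar currently allows tokens 0..99

# 3. Only once the logits are needed, apply the mask and sample
#    (guided_decoder.execute + _sample_async in the real loop).
logits = logits.masked_fill(~allowed.to(device), float("-inf"))
next_tokens = logits.argmax(dim=-1)
print(next_tokens)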
@@ -1159,8 +1181,6 @@ class PyExecutor:
             self._terminate_ctx_finished_requests()
 
     def _process_previous_batch(self):
-        self._update_requests(self.previous_batch.sample_state)
-
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
                 req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
@@ -24,6 +24,7 @@ from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
                     create_py_executor_instance, instantiate_sampler, is_mla)
 from .config import PyTorchConfig
 from .config_utils import is_mla
+from .guided_decoder import GuidedDecoder
 from .model_engine import PyTorchModelEngine
 from .py_executor import PyExecutor
 
@@ -237,7 +238,6 @@ def create_py_executor(
             attn_runtime_features=attn_runtime_features,
             dist=dist,
             spec_config=spec_config,
-            guided_decoding_config=executor_config.guided_decoding_config,
             lora_config=lora_config,
             checkpoint_loader=executor_config.checkpoint_loader,
         )
@@ -344,6 +344,17 @@ def create_py_executor(
     sampler = instantiate_sampler(model_engine, executor_config,
                                   pytorch_backend_config, mapping)
 
+    guided_decoder: Optional[GuidedDecoder] = None
+    if executor_config.guided_decoding_config is not None:
+        if spec_config is not None:
+            raise ValueError(
+                "Guided decoding is not supported with speculative decoding.")
+        if mapping.is_last_pp_rank():
+            guided_decoder = GuidedDecoder(
+                executor_config.guided_decoding_config,
+                executor_config.max_batch_size,
+                model_engine.model.vocab_size_padded)
+
     resources = {}
     estimating_kv_cache = False
     kv_cache_creator = None
@@ -388,6 +399,7 @@ def create_py_executor(
             start_worker=False,
             sampler=sampler,
             drafter=drafter,
+            guided_decoder=guided_decoder,
             lora_config=lora_config,
             garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
         )
@@ -430,6 +442,7 @@ def create_py_executor(
             start_worker=False,
             sampler=sampler,
             drafter=drafter,
+            guided_decoder=guided_decoder,
             lora_config=lora_config,
             garbage_collection_gen0_threshold=
             garbage_collection_gen0_threshold,
@@ -287,7 +287,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
         llm = LLM(self.MODEL_PATH,
                   guided_decoding_backend=backend,
-                  disable_overlap_scheduler=True,
                   cuda_graph_config=CudaGraphConfig())
         with llm:
             task = JsonModeEval(self.MODEL_NAME)
@@ -300,7 +299,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
         with LLM(self.MODEL_PATH,
                  guided_decoding_backend=backend,
-                 disable_overlap_scheduler=True,
                  cuda_graph_config=CudaGraphConfig(),
                  tensor_parallel_size=2,
                  pipeline_parallel_size=2) as llm:
@@ -23,10 +23,7 @@ def temp_extra_llm_api_options_file(request):
     temp_dir = tempfile.gettempdir()
     temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
     try:
-        extra_llm_api_options_dict = {
-            "guided_decoding_backend": "xgrammar",
-            "disable_overlap_scheduler": True,
-        }
+        extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
 
         with open(temp_file_path, 'w') as f:
             yaml.dump(extra_llm_api_options_dict, f)