[TRTLLM-6854][feat] Enable guided decoding with disagg serving (#6704)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Author: Enwei Zhu, 2025-08-08 12:10:36 +08:00 (committed by GitHub)
parent 1cf669496a
commit aee828d98a
7 changed files with 175 additions and 24 deletions

View File

@@ -15,4 +15,4 @@
| KV Cache Reuse | Yes | Yes | Yes | Untested | Yes | Untested | Yes | No | Yes | Yes | --- | | | |
| Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | |
-| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
+| Guided Decoding | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
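With this change, the feature combination matrix marks Guided Decoding as supported with disaggregated serving. It is enabled by setting a guided_decoding_backend on both the context and generation servers. A minimal configuration sketch in Python, mirroring the test configs added further below (the "default" cache transceiver backend and the disabled overlap scheduler on the context server are taken from those tests; the backend choice is illustrative):

ctx_server_config = {
    "disable_overlap_scheduler": True,               # context (prefill) server
    "guided_decoding_backend": "xgrammar",           # or "llguidance"
    "cache_transceiver_config": {"backend": "default"},
}
gen_server_config = {
    "guided_decoding_backend": "xgrammar",           # the tests use the same backend on both servers
    "cache_transceiver_config": {"backend": "default"},
}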

View File

@@ -39,7 +39,7 @@ class GuidedDecoder:
guided_decoding_config, vocab_size_padded)
else:
raise ValueError(
f"invalid guided decoding backend: {self.guided_decoding_backend}"
f"Invalid guided decoding backend: {self.guided_decoding_backend}"
)
logger.info(
f"Guided decoder initialized with backend: {self.guided_decoding_backend}"
@@ -71,7 +71,7 @@ class GuidedDecoder:
def bitmask_size(self) -> int:
return math.ceil(self.vocab_size_padded / 32)
-def _is_matcher_init(self, llm_req: LlmRequest) -> bool:
+def _require_matcher_init(self, llm_req: LlmRequest) -> bool:
if llm_req.guided_decoding_params is None:
return False
if llm_req.py_is_draft:
@@ -79,7 +79,7 @@ class GuidedDecoder:
# The request is in the last chunk of a context forward step.
return llm_req.is_context_init_state and llm_req.is_last_context_chunk
-def _is_matcher_in_progress(self, llm_req: LlmRequest) -> bool:
+def _require_matcher_advance(self, llm_req: LlmRequest) -> bool:
if llm_req.guided_decoding_params is None:
return False
if llm_req.py_is_draft:
@@ -102,12 +102,17 @@ class GuidedDecoder:
self.num_advanced_tokens[slot] = 0
self.num_guided_tokens[slot] = 0
-if self._is_matcher_init(llm_req):
+matcher_init: bool = self._require_matcher_init(llm_req)
+matcher_advance: bool = self._require_matcher_advance(llm_req)
+if not (matcher_init or matcher_advance):
+continue
+if matcher_init:
matcher = self.grammar_matcher_factory.create(
llm_req.guided_decoding_params)
self.grammar_matchers[slot] = matcher
-elif self._is_matcher_in_progress(llm_req):
+if matcher_advance:
matcher = self.grammar_matchers[slot]
# The last new token must be acceptable unless the matcher is terminated in a drafting loop.
if llm_req.py_is_draft and (matcher.is_terminated()
@@ -127,9 +132,6 @@ class GuidedDecoder:
f"Request {llm_req.py_request_id} failed to accept last new token: {last_new_token}."
)
-else:
-continue
self.num_advanced_tokens[slot] += 1
if not matcher.is_terminated():
matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0)
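The bitmask filled here is the per-slot buffer sized by the bitmask_size property above: one allow/deny bit per padded vocab token, packed into 32-bit words. A quick sanity check of the arithmetic, with a hypothetical padded vocab size:

import math

vocab_size_padded = 128256  # hypothetical padded vocab size
bitmask_size = math.ceil(vocab_size_padded / 32)  # one int32 word covers 32 tokens
assert bitmask_size == 4008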
@@ -244,3 +246,19 @@ class GuidedDecoder:
# Reset the drafting states.
self.num_advanced_draft_tokens[slot] = 0
self.is_draft_terminated[slot] = False
@nvtx_range("GuidedDecoder.init_disagg_gen_requests")
def init_disagg_gen_requests(self,
scheduled_requests: ScheduledRequests) -> None:
"""Initialize the grammar matchers for disagg gen requests.
"""
for llm_req in scheduled_requests.generation_requests:
if llm_req.guided_decoding_params is None:
continue
assert not llm_req.py_is_draft
slot: int = llm_req.py_seq_slot
if llm_req.context_phase_params is not None and llm_req.py_decoding_iter == 1:
# The request is in the first generation forward step at the disagg gen instance.
self.grammar_matchers[
slot] = self.grammar_matcher_factory.create(
llm_req.guided_decoding_params)
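This hook exists because the grammar matcher is host-side state that does not travel with the KV cache: a request prefilled on a context instance arrives at the generation instance without a matcher, so one is created at its first generation step there. A condensed sketch of the gating condition (llm_req stands in for any LlmRequest; an illustration, not the class's API):

def needs_disagg_matcher_init(llm_req) -> bool:
    # Mirrors the check above: the request carries context-phase params
    # (it was prefilled on a context instance) and this is its first
    # generation forward step on the generation instance.
    return (llm_req.guided_decoding_params is not None
            and llm_req.context_phase_params is not None
            and llm_req.py_decoding_iter == 1)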

View File

@@ -329,7 +329,9 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.py_return_generation_logits = return_generation_logits
self.py_return_logits_device_memory = return_logits_device_memory
self.py_is_draft = is_draft
+# The request's sequence slot ID, an index between 0 (inclusive) and max_batch_size (exclusive).
self.py_seq_slot = seq_slot
+# If the request is a draft request, target_seq_slot is the sequence slot ID of its target request.
self.py_target_seq_slot = target_seq_slot
# TODO: remove this when using DynamicDecodeOp in the PyTorch flow.
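These slot IDs are what GuidedDecoder uses to index its per-request state (grammar matchers, bitmask buffers, token counters). A minimal sketch of the pattern, with max_batch_size as an assumed configuration value; the draft-to-target resolution shown is one plausible use of py_target_seq_slot:

max_batch_size = 8  # assumed configuration value
grammar_matchers = [None] * max_batch_size  # one entry per sequence slot

def state_slot(llm_req) -> int:
    # A draft request shares the slot-indexed state of its target request.
    return llm_req.py_target_seq_slot if llm_req.py_is_draft else llm_req.py_seq_slot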

View File

@@ -749,6 +749,9 @@ class PyExecutor:
if self._need_return_logits(scheduled_batch):
logits_host = batch_outputs["logits"].to(
"cpu", non_blocking=True)
+if self.kv_cache_transceiver and self.guided_decoder:
+self.guided_decoder.init_disagg_gen_requests(
+scheduled_batch)
self._execute_guided_decoder(
scheduled_batch, batch_outputs['logits'])
@@ -939,6 +942,10 @@ class PyExecutor:
self._handle_first_token_response(scheduled_batch)
self.resource_manager.prepare_resources(scheduled_batch)
+if self.kv_cache_transceiver and self.guided_decoder:
+self.guided_decoder.init_disagg_gen_requests(
+scheduled_batch)
if self.drafter is not None and self.use_spec_decode:
if self.guided_decoder is not None:
self.guided_decoder.rollback_rejected_tokens(
@@ -1063,6 +1070,9 @@ class PyExecutor:
if self.previous_batch is not None:
self._update_requests(self.previous_batch.sample_state)
+if self.kv_cache_transceiver and self.guided_decoder:
+self.guided_decoder.init_disagg_gen_requests(
+scheduled_batch)
self._execute_guided_decoder(scheduled_batch,
batch_outputs['logits'])
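All three call sites follow the same pattern: when both a KV cache transceiver (i.e. disaggregated serving) and a guided decoder are configured, matchers for freshly transferred requests are initialized immediately before the guided decoding step that consumes them. A condensed sketch of that ordering (illustrative only; the real executor loop does much more):

def _guided_step(self, scheduled_batch, batch_outputs):
    if self.kv_cache_transceiver and self.guided_decoder:
        # New in this commit: create grammar matchers for requests that
        # just arrived from a context instance.
        self.guided_decoder.init_disagg_gen_requests(scheduled_batch)
    # Existing step: build token bitmasks and apply them to the batch logits.
    self._execute_guided_decoder(scheduled_batch, batch_outputs["logits"])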

View File

@@ -4,6 +4,7 @@
# Please take a look at the existing test_llm_api_pytorch.py file for reference.
import concurrent
import contextlib
+import json
import os
import tempfile
import time
@@ -19,12 +20,13 @@ import yaml
from tensorrt_llm.executor.result import GenerationResultBase
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
from tensorrt_llm.llmapi.llm_args import LlmArgs
+from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
skip_pre_hopper)
from ..trt_test_alternative import popen
-from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
-get_accuracy_task)
+from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
+LlmapiAccuracyTestHarness, get_accuracy_task)
class Result(GenerationResultBase):
@@ -43,7 +45,7 @@ class Result(GenerationResultBase):
return self
-DuckLLM = namedtuple('DuckLLM', ['args', 'generate_async'])
+DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async'])
class MyThreadPoolExecutor(ThreadPoolExecutor):
@@ -162,17 +164,35 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
def send_request(prompt: str, sampling_params: SamplingParams,
streaming: bool):
-response = client.completions.create(
-model=model_name,
-prompt=prompt,
-stream=streaming,
-**({
-"max_tokens": sampling_params.max_tokens,
-"temperature": sampling_params.temperature,
-"top_p": sampling_params.top_p,
-"stop": sampling_params.stop,
-"seed": sampling_params.seed
-} if sampling_params else {}))
+kwargs = {}
+if sampling_params is not None:
+kwargs.update(max_tokens=sampling_params.max_tokens,
+temperature=sampling_params.temperature,
+top_p=sampling_params.top_p,
+stop=sampling_params.stop,
+seed=sampling_params.seed)
+if (guided_decoding_params :=
+sampling_params.guided_decoding) is not None:
+extra_body = {}
+if (schema := guided_decoding_params.json) is not None:
+extra_body.update(response_format={
+"type": "json",
+"schema": json.loads(schema)
+})
+elif guided_decoding_params.json_object:
+extra_body.update(
+response_format={"type": "json_object"})
+else:
+# TODO: Support other guided decoding types
+raise ValueError(
+f"Unsupported guided decoding params: {guided_decoding_params}."
+)
+kwargs.update(extra_body=extra_body)
+response = client.completions.create(model=model_name,
+prompt=prompt,
+stream=streaming,
+**kwargs)
result = Result(id=0,
sampling_params=sampling_params,
outputs=[
@@ -192,8 +212,10 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
thread_pool.futures.append(future)
return future
+tokenizer = load_hf_tokenizer(model_name)
try:
-yield DuckLLM(args, generate_async)
+yield DuckLLM(args, tokenizer, generate_async)
finally:
ctx_server.terminate()
gen_server.terminate()
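On the client side, send_request now maps SamplingParams.guided_decoding onto the OpenAI-style response_format extra body. A sketch of what a test caller passes in (assuming GuidedDecodingParams is importable from tensorrt_llm.llmapi; the schema is illustrative):

import json

from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams

schema = json.dumps({
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
})
sampling_params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json=schema))
# send_request above turns this into:
#   extra_body={"response_format": {"type": "json", "schema": json.loads(schema)}}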
@@ -394,6 +416,95 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
ctx_server_config = {
"disable_overlap_scheduler": True,
"guided_decoding_backend": backend,
"cache_transceiver_config": {
"backend": "default"
}
}
gen_server_config = {
"guided_decoding_backend": backend,
"cache_transceiver_config": {
"backend": "default"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding_with_eagle3(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
speculative_decoding_config = {
"decoding_type": "Eagle",
"max_draft_len": 3,
"speculative_model_dir":
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
"eagle3_one_model": False
}
ctx_server_config = {
"disable_overlap_scheduler": True,
"speculative_config": speculative_decoding_config,
"kv_cache_config": {
"free_gpu_memory_fraction": 0.8,
},
"guided_decoding_backend": backend,
"cache_transceiver_config": {
"backend": "default"
}
}
gen_server_config = {
"disable_overlap_scheduler": True,
"speculative_config": speculative_decoding_config,
"kv_cache_config": {
"free_gpu_memory_fraction": 0.8,
},
"guided_decoding_backend": backend,
"cache_transceiver_config": {
"backend": "default"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
ids=["tp1pp2", "tp2pp1", "tp2pp2"])

View File

@@ -448,6 +448,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True]
@@ -520,6 +524,10 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2]

View File

@@ -42,6 +42,8 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
+- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
+- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]