[TRTLLM-6854][feat] Enable guided decoding with disagg serving (#6704)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
parent 1cf669496a
commit aee828d98a
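For orientation: guided decoding in the LLM API is requested per sample via `GuidedDecodingParams` on `SamplingParams`, with the backend (`xgrammar` or `llguidance`) selected in the server config; this change makes the same feature work when the context and generation phases run on separate servers (disaggregated serving). Below is a minimal single-process sketch of the user-facing parameters only; the import locations and model path are assumptions, and the disaggregated path itself is exercised through the OpenAI-compatible endpoint in the test changes further down.

```python
# Hedged sketch of the user-facing knobs; GuidedDecodingParams import location
# and the model path are assumptions, not taken from this commit.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import GuidedDecodingParams

schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          guided_decoding_backend="xgrammar")  # or "llguidance"
params = SamplingParams(max_tokens=64,
                        guided_decoding=GuidedDecodingParams(json=schema))
output = llm.generate("Return a JSON object with a 'name' field.", params)
print(output.outputs[0].text)  # output is constrained to match the schema
```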
@@ -15,4 +15,4 @@
 | KV Cache Reuse | Yes | Yes | Yes | Untested | Yes | Untested | Yes | No | Yes | Yes | --- | | | |
 | Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
 | Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | |
-| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
+| Guided Decoding | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
@@ -39,7 +39,7 @@ class GuidedDecoder:
                 guided_decoding_config, vocab_size_padded)
         else:
             raise ValueError(
-                f"invalid guided decoding backend: {self.guided_decoding_backend}"
+                f"Invalid guided decoding backend: {self.guided_decoding_backend}"
             )
         logger.info(
             f"Guided decoder initialized with backend: {self.guided_decoding_backend}"
@@ -71,7 +71,7 @@ class GuidedDecoder:
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)
 
-    def _is_matcher_init(self, llm_req: LlmRequest) -> bool:
+    def _require_matcher_init(self, llm_req: LlmRequest) -> bool:
         if llm_req.guided_decoding_params is None:
             return False
         if llm_req.py_is_draft:
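As a quick sanity check on `bitmask_size` above: the grammar bitmask stores one bit per padded vocabulary token, packed into 32-bit words. With an illustrative Llama-3.1-style vocabulary size (the value is not taken from this diff):

```python
import math

vocab_size_padded = 128256                     # illustrative value only
bitmask_words = math.ceil(vocab_size_padded / 32)
print(bitmask_words)                           # 4008 int32 words, i.e. ~16 KB per slot
```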
@@ -79,7 +79,7 @@ class GuidedDecoder:
         # The request is in the last chunk of a context forward step.
         return llm_req.is_context_init_state and llm_req.is_last_context_chunk
 
-    def _is_matcher_in_progress(self, llm_req: LlmRequest) -> bool:
+    def _require_matcher_advance(self, llm_req: LlmRequest) -> bool:
         if llm_req.guided_decoding_params is None:
             return False
         if llm_req.py_is_draft:
@@ -102,12 +102,17 @@ class GuidedDecoder:
             self.num_advanced_tokens[slot] = 0
             self.num_guided_tokens[slot] = 0
 
-            if self._is_matcher_init(llm_req):
+            matcher_init: bool = self._require_matcher_init(llm_req)
+            matcher_advance: bool = self._require_matcher_advance(llm_req)
+            if not (matcher_init or matcher_advance):
+                continue
+
+            if matcher_init:
                 matcher = self.grammar_matcher_factory.create(
                     llm_req.guided_decoding_params)
                 self.grammar_matchers[slot] = matcher
 
-            elif self._is_matcher_in_progress(llm_req):
+            if matcher_advance:
                 matcher = self.grammar_matchers[slot]
                 # The last new token must be acceptable unless the matcher is terminated in a drafting loop.
                 if llm_req.py_is_draft and (matcher.is_terminated()
@@ -127,9 +132,6 @@ class GuidedDecoder:
                         f"Request {llm_req.py_request_id} failed to accept last new token: {last_new_token}."
                     )
 
-            else:
-                continue
-
             self.num_advanced_tokens[slot] += 1
             if not matcher.is_terminated():
                 matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0)
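The net effect of the two hunks above: the old mutually exclusive `if`/`elif` (plus the trailing `else: continue`) becomes two independent predicates, so unguided requests are skipped up front and a request can, when needed, take both the init and the advance path in a single pass. A condensed, hypothetical sketch of the new control flow (simplified names and stand-in objects, not the actual `GuidedDecoder`):

```python
# Hypothetical, self-contained sketch of the decoupled decision; matcher/factory
# stand in for the grammar matcher objects used by the real class.
def build_step(require_init: bool, require_advance: bool, matcher, factory, params,
               last_token: int):
    if not (require_init or require_advance):
        return matcher                      # unguided request: nothing to do
    if require_init:                        # e.g. last chunk of the context phase
        matcher = factory(params)
    if require_advance:                     # generation step: consume the last new token
        assert matcher.accept_token(last_token), "last new token must be acceptable"
    return matcher
```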
@@ -244,3 +246,19 @@ class GuidedDecoder:
         # Reset the drafting states.
         self.num_advanced_draft_tokens[slot] = 0
         self.is_draft_terminated[slot] = False
+
+    @nvtx_range("GuidedDecoder.init_disagg_gen_requests")
+    def init_disagg_gen_requests(self,
+                                 scheduled_requests: ScheduledRequests) -> None:
+        """Initialize the grammar matchers for disagg gen requests.
+        """
+        for llm_req in scheduled_requests.generation_requests:
+            if llm_req.guided_decoding_params is None:
+                continue
+            assert not llm_req.py_is_draft
+            slot: int = llm_req.py_seq_slot
+            if llm_req.context_phase_params is not None and llm_req.py_decoding_iter == 1:
+                # The request is in the first generation forward step at the disagg gen instance.
+                self.grammar_matchers[
+                    slot] = self.grammar_matcher_factory.create(
+                        llm_req.guided_decoding_params)
@@ -329,7 +329,9 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
         self.py_return_generation_logits = return_generation_logits
         self.py_return_logits_device_memory = return_logits_device_memory
         self.py_is_draft = is_draft
+        # The request's sequence slot ID, an index between 0 (inclusive) and max_batch_size (exclusive).
         self.py_seq_slot = seq_slot
+        # If the request is a draft request, target_seq_slot is the sequence slot ID of its target request.
         self.py_target_seq_slot = target_seq_slot
 
         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
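The comments added above document an invariant the guided decoder relies on: `py_seq_slot` indexes fixed-size per-slot state (bitmasks, matcher handles, counters) allocated for `max_batch_size` slots, and a draft request carries its target's slot in `py_target_seq_slot`, presumably so shared per-slot state can be located. A tiny illustration of the indexing pattern, with hypothetical numbers:

```python
# Illustrative only: per-slot state is sized by max_batch_size and indexed by py_seq_slot.
max_batch_size = 8
num_advanced_tokens = [0] * max_batch_size     # one counter per sequence slot

slot = 3                                       # a request's py_seq_slot (0 <= slot < max_batch_size)
num_advanced_tokens[slot] += 1
```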
@@ -749,6 +749,9 @@ class PyExecutor:
                 if self._need_return_logits(scheduled_batch):
                     logits_host = batch_outputs["logits"].to(
                         "cpu", non_blocking=True)
+                if self.kv_cache_transceiver and self.guided_decoder:
+                    self.guided_decoder.init_disagg_gen_requests(
+                        scheduled_batch)
                 self._execute_guided_decoder(
                     scheduled_batch, batch_outputs['logits'])
 
@@ -939,6 +942,10 @@ class PyExecutor:
                 self._handle_first_token_response(scheduled_batch)
 
                 self.resource_manager.prepare_resources(scheduled_batch)
+
+                if self.kv_cache_transceiver and self.guided_decoder:
+                    self.guided_decoder.init_disagg_gen_requests(
+                        scheduled_batch)
                 if self.drafter is not None and self.use_spec_decode:
                     if self.guided_decoder is not None:
                         self.guided_decoder.rollback_rejected_tokens(
@@ -1063,6 +1070,9 @@ class PyExecutor:
             if self.previous_batch is not None:
                 self._update_requests(self.previous_batch.sample_state)
 
+            if self.kv_cache_transceiver and self.guided_decoder:
+                self.guided_decoder.init_disagg_gen_requests(
+                    scheduled_batch)
             self._execute_guided_decoder(scheduled_batch,
                                          batch_outputs['logits'])
 
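All three executor hunks follow the same pattern: when the instance has a KV-cache transceiver (i.e. it can receive disaggregated requests) and guided decoding is enabled, matchers for newly transferred generation requests are initialized right before the guided decoder is applied to that iteration's logits. A schematic of the ordering, with the surrounding loop heavily simplified (the call names mirror the diff):

```python
# Schematic only; not the real PyExecutor loop body.
def guided_decoding_step(executor, scheduled_batch, batch_outputs):
    if executor.kv_cache_transceiver and executor.guided_decoder:
        # Create grammar matchers for requests that just arrived from a context instance.
        executor.guided_decoder.init_disagg_gen_requests(scheduled_batch)
    # Fill and apply the token bitmasks against this iteration's logits.
    executor._execute_guided_decoder(scheduled_batch, batch_outputs["logits"])
```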
@@ -4,6 +4,7 @@
 # Please take a look at the existing test_llm_api_pytorch.py file for reference.
 import concurrent
 import contextlib
+import json
 import os
 import tempfile
 import time
@@ -19,12 +20,13 @@ import yaml
 from tensorrt_llm.executor.result import GenerationResultBase
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs
+from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
 
 from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
                         skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
-                            get_accuracy_task)
+from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
+                            LlmapiAccuracyTestHarness, get_accuracy_task)
 
 
 class Result(GenerationResultBase):
@@ -43,7 +45,7 @@ class Result(GenerationResultBase):
         return self
 
 
-DuckLLM = namedtuple('DuckLLM', ['args', 'generate_async'])
+DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async'])
 
 
 class MyThreadPoolExecutor(ThreadPoolExecutor):
@@ -162,17 +164,35 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
 
     def send_request(prompt: str, sampling_params: SamplingParams,
                      streaming: bool):
-        response = client.completions.create(
-            model=model_name,
-            prompt=prompt,
-            stream=streaming,
-            **({
-                "max_tokens": sampling_params.max_tokens,
-                "temperature": sampling_params.temperature,
-                "top_p": sampling_params.top_p,
-                "stop": sampling_params.stop,
-                "seed": sampling_params.seed
-            } if sampling_params else {}))
+        kwargs = {}
+        if sampling_params is not None:
+            kwargs.update(max_tokens=sampling_params.max_tokens,
+                          temperature=sampling_params.temperature,
+                          top_p=sampling_params.top_p,
+                          stop=sampling_params.stop,
+                          seed=sampling_params.seed)
+            if (guided_decoding_params :=
+                    sampling_params.guided_decoding) is not None:
+                extra_body = {}
+                if (schema := guided_decoding_params.json) is not None:
+                    extra_body.update(response_format={
+                        "type": "json",
+                        "schema": json.loads(schema)
+                    })
+                elif guided_decoding_params.json_object:
+                    extra_body.update(
+                        response_format={"type": "json_object"})
+                else:
+                    # TODO: Support other guided decoding types
+                    raise ValueError(
+                        f"Unsupported guided decoding params: {guided_decoding_params}."
+                    )
+                kwargs.update(extra_body=extra_body)
+
+        response = client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             stream=streaming,
+                                             **kwargs)
         result = Result(id=0,
                         sampling_params=sampling_params,
                         outputs=[
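The rewritten `send_request` translates `GuidedDecodingParams` into the OpenAI-compatible `response_format` extension passed via `extra_body`, so the same constraint can be expressed directly against the disaggregated endpoint. A hedged client-side sketch of the equivalent raw call; the endpoint URL and model name are assumptions taken from the test configuration below, and the `extra_body` shape follows the code above:

```python
# Sketch only; base_url/port and the model name are assumptions.
import json
from openai import OpenAI

schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.completions.create(
    model="llama-3.1-8b-instruct",          # whatever name the disagg proxy registers
    prompt="Return a JSON object with a 'name' field.",
    max_tokens=64,
    extra_body={"response_format": {"type": "json", "schema": json.loads(schema)}},
)
print(response.choices[0].text)
```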
@@ -192,8 +212,10 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
         thread_pool.futures.append(future)
         return future
 
+    tokenizer = load_hf_tokenizer(model_name)
+
     try:
-        yield DuckLLM(args, generate_async)
+        yield DuckLLM(args, tokenizer, generate_async)
     finally:
         ctx_server.terminate()
         gen_server.terminate()
@@ -394,6 +416,95 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
+    def test_guided_decoding(self, backend: str, mocker):
+        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        gen_server_config = {
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = JsonModeEval(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
+    def test_guided_decoding_with_eagle3(self, backend: str, mocker):
+        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 3,
+            "speculative_model_dir":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": {
+                "free_gpu_memory_fraction": 0.8,
+            },
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": {
+                "free_gpu_memory_fraction": 0.8,
+            },
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = JsonModeEval(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(2)
     @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
                              ids=["tp1pp2", "tp2pp1", "tp2pp2"])
@@ -448,6 +448,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True]
@@ -520,6 +524,10 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2]
@@ -42,6 +42,8 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]