diff --git a/docs/source/torch/features/feature_combination_matrix.md b/docs/source/torch/features/feature_combination_matrix.md
index 6990c61e18..f25d4bc487 100644
--- a/docs/source/torch/features/feature_combination_matrix.md
+++ b/docs/source/torch/features/feature_combination_matrix.md
@@ -15,4 +15,4 @@
 | KV Cache Reuse | Yes | Yes | Yes | Untested | Yes | Untested | Yes | No | Yes | Yes | --- | | | |
 | Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
 | Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | |
-| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
+| Guided Decoding | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
index fa95a0a7a1..664f1aaca8 100644
--- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
+++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -39,7 +39,7 @@ class GuidedDecoder:
                 guided_decoding_config, vocab_size_padded)
         else:
             raise ValueError(
-                f"invalid guided decoding backend: {self.guided_decoding_backend}"
+                f"Invalid guided decoding backend: {self.guided_decoding_backend}"
             )
         logger.info(
             f"Guided decoder initialized with backend: {self.guided_decoding_backend}"
         )
@@ -71,7 +71,7 @@ class GuidedDecoder:
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)
 
-    def _is_matcher_init(self, llm_req: LlmRequest) -> bool:
+    def _require_matcher_init(self, llm_req: LlmRequest) -> bool:
         if llm_req.guided_decoding_params is None:
             return False
         if llm_req.py_is_draft:
@@ -79,7 +79,7 @@ class GuidedDecoder:
         # The request is in the last chunk of a context forward step.
         return llm_req.is_context_init_state and llm_req.is_last_context_chunk
 
-    def _is_matcher_in_progress(self, llm_req: LlmRequest) -> bool:
+    def _require_matcher_advance(self, llm_req: LlmRequest) -> bool:
         if llm_req.guided_decoding_params is None:
             return False
         if llm_req.py_is_draft:
@@ -102,12 +102,17 @@ class GuidedDecoder:
             self.num_advanced_tokens[slot] = 0
             self.num_guided_tokens[slot] = 0
 
-            if self._is_matcher_init(llm_req):
+            matcher_init: bool = self._require_matcher_init(llm_req)
+            matcher_advance: bool = self._require_matcher_advance(llm_req)
+            if not (matcher_init or matcher_advance):
+                continue
+
+            if matcher_init:
                 matcher = self.grammar_matcher_factory.create(
                     llm_req.guided_decoding_params)
                 self.grammar_matchers[slot] = matcher
 
-            elif self._is_matcher_in_progress(llm_req):
+            if matcher_advance:
                 matcher = self.grammar_matchers[slot]
                 # The last new token must be acceptable unless the matcher is terminated in a drafting loop.
                 if llm_req.py_is_draft and (matcher.is_terminated()
@@ -127,9 +132,6 @@ class GuidedDecoder:
                         f"Request {llm_req.py_request_id} failed to accept last new token: {last_new_token}."
                     )
 
-            else:
-                continue
-
             self.num_advanced_tokens[slot] += 1
             if not matcher.is_terminated():
                 matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0)
@@ -244,3 +246,19 @@ class GuidedDecoder:
             # Reset the drafting states.
             self.num_advanced_draft_tokens[slot] = 0
             self.is_draft_terminated[slot] = False
+
+    @nvtx_range("GuidedDecoder.init_disagg_gen_requests")
+    def init_disagg_gen_requests(self,
+                                 scheduled_requests: ScheduledRequests) -> None:
+        """Initialize the grammar matchers for disagg gen requests.
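+
+        For disaggregated serving, a generation request arrives at the gen
+        instance without having run its context phase there, so its grammar
+        matcher is created at the first generation forward step instead of at
+        the last context chunk.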
+        """
+        for llm_req in scheduled_requests.generation_requests:
+            if llm_req.guided_decoding_params is None:
+                continue
+            assert not llm_req.py_is_draft
+            slot: int = llm_req.py_seq_slot
+            if llm_req.context_phase_params is not None and llm_req.py_decoding_iter == 1:
+                # The request is in the first generation forward step at the disagg gen instance.
+                self.grammar_matchers[
+                    slot] = self.grammar_matcher_factory.create(
+                        llm_req.guided_decoding_params)
diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index bd72fcc53e..a068327b6d 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -329,7 +329,9 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
         self.py_return_generation_logits = return_generation_logits
         self.py_return_logits_device_memory = return_logits_device_memory
         self.py_is_draft = is_draft
+        # The request's sequence slot ID, an index between 0 (inclusive) and max_batch_size (exclusive).
         self.py_seq_slot = seq_slot
+        # If the request is a draft request, target_seq_slot is the sequence slot ID of its target request.
         self.py_target_seq_slot = target_seq_slot
 
         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index ff341ebcb7..af0e113e26 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -749,6 +749,9 @@ class PyExecutor:
                 if self._need_return_logits(scheduled_batch):
                     logits_host = batch_outputs["logits"].to(
                         "cpu", non_blocking=True)
+                if self.kv_cache_transceiver and self.guided_decoder:
+                    self.guided_decoder.init_disagg_gen_requests(
+                        scheduled_batch)
                 self._execute_guided_decoder(
                     scheduled_batch, batch_outputs['logits'])
 
@@ -939,6 +942,10 @@ class PyExecutor:
                 self._handle_first_token_response(scheduled_batch)
 
                 self.resource_manager.prepare_resources(scheduled_batch)
+
+                if self.kv_cache_transceiver and self.guided_decoder:
+                    self.guided_decoder.init_disagg_gen_requests(
+                        scheduled_batch)
                 if self.drafter is not None and self.use_spec_decode:
                     if self.guided_decoder is not None:
                         self.guided_decoder.rollback_rejected_tokens(
@@ -1063,6 +1070,9 @@ class PyExecutor:
 
             if self.previous_batch is not None:
                 self._update_requests(self.previous_batch.sample_state)
+            if self.kv_cache_transceiver and self.guided_decoder:
+                self.guided_decoder.init_disagg_gen_requests(
+                    scheduled_batch)
             self._execute_guided_decoder(scheduled_batch,
                                          batch_outputs['logits'])
 
diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py
index f35675ddd6..e0801302eb 100644
--- a/tests/integration/defs/accuracy/test_disaggregated_serving.py
+++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -4,6 +4,7 @@
 # Please take a look at the existing test_llm_api_pytorch.py file for reference.
 import concurrent
 import contextlib
+import json
 import os
 import tempfile
 import time
@@ -19,12 +20,13 @@ import yaml
 from tensorrt_llm.executor.result import GenerationResultBase
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs
+from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
 
 from ..conftest import (get_device_count, llm_models_root,
                         parametrize_with_ids, skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
-                            get_accuracy_task)
+from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
+                            LlmapiAccuracyTestHarness, get_accuracy_task)
 
 
 class Result(GenerationResultBase):
@@ -43,7 +45,7 @@ class Result(GenerationResultBase):
         return self
 
 
-DuckLLM = namedtuple('DuckLLM', ['args', 'generate_async'])
+DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async'])
 
 
 class MyThreadPoolExecutor(ThreadPoolExecutor):
@@ -162,17 +164,35 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
 
     def send_request(prompt: str, sampling_params: SamplingParams,
                      streaming: bool):
-        response = client.completions.create(
-            model=model_name,
-            prompt=prompt,
-            stream=streaming,
-            **({
-                "max_tokens": sampling_params.max_tokens,
-                "temperature": sampling_params.temperature,
-                "top_p": sampling_params.top_p,
-                "stop": sampling_params.stop,
-                "seed": sampling_params.seed
-            } if sampling_params else {}))
+        kwargs = {}
+        if sampling_params is not None:
+            kwargs.update(max_tokens=sampling_params.max_tokens,
+                          temperature=sampling_params.temperature,
+                          top_p=sampling_params.top_p,
+                          stop=sampling_params.stop,
+                          seed=sampling_params.seed)
+            if (guided_decoding_params :=
+                    sampling_params.guided_decoding) is not None:
+                extra_body = {}
+                if (schema := guided_decoding_params.json) is not None:
+                    extra_body.update(response_format={
+                        "type": "json",
+                        "schema": json.loads(schema)
+                    })
+                elif guided_decoding_params.json_object:
+                    extra_body.update(
+                        response_format={"type": "json_object"})
+                else:
+                    # TODO: Support other guided decoding types
+                    raise ValueError(
+                        f"Unsupported guided decoding params: {guided_decoding_params}."
+                    )
+                kwargs.update(extra_body=extra_body)
+
+        response = client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             stream=streaming,
+                                             **kwargs)
         result = Result(id=0,
                         sampling_params=sampling_params,
                         outputs=[
@@ -192,8 +212,10 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
             thread_pool.futures.append(future)
             return future
 
+    tokenizer = load_hf_tokenizer(model_name)
+
     try:
-        yield DuckLLM(args, generate_async)
+        yield DuckLLM(args, tokenizer, generate_async)
     finally:
         ctx_server.terminate()
         gen_server.terminate()
@@ -394,6 +416,95 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
+    def test_guided_decoding(self, backend: str, mocker):
+        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        gen_server_config = {
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = JsonModeEval(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
+    def test_guided_decoding_with_eagle3(self, backend: str, mocker):
+        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 3,
+            "speculative_model_dir":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": {
+                "free_gpu_memory_fraction": 0.8,
+            },
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": {
+                "free_gpu_memory_fraction": 0.8,
+            },
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = JsonModeEval(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(2)
     @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
                              ids=["tp1pp2", "tp2pp1", "tp2pp2"])
diff --git a/tests/integration/test_lists/qa/llm_function_full.txt b/tests/integration/test_lists/qa/llm_function_full.txt
index 3abdb1dbb5..4f707478e7 100644
--- a/tests/integration/test_lists/qa/llm_function_full.txt
+++ b/tests/integration/test_lists/qa/llm_function_full.txt
@@ -448,6 +448,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True]
@@ -520,6 +524,10 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2]
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index 3a8e6aa9c9..ee9426dd75 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -42,6 +42,8 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]
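The new disaggregated guided-decoding tests exercise the feature through the server's OpenAI-compatible completions endpoint, passing the constraint via extra_body in the same shape that send_request() builds above. For illustration only, a minimal client-side sketch of that request shape, assuming a disaggregated server already listening on localhost:8000 and a placeholder model name:

import json

from openai import OpenAI

# Placeholder endpoint and API key, matching the port used by the test configs.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

# JSON schema that the server-side guided decoder (xgrammar or llguidance)
# enforces on the generated tokens.
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"}
    },
    "required": ["name", "age"]
}

response = client.completions.create(
    model="llama-3.1-8b-instruct",  # placeholder model name
    prompt="Return a JSON object for a person named Alice, age 30.",
    max_tokens=64,
    temperature=0.0,
    # Same response_format shape as the test helper: {"type": "json", "schema": ...}
    extra_body={"response_format": {"type": "json", "schema": schema}},
)
print(json.loads(response.choices[0].text))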