[TRTLLM-5974][feat] Support disaggregated serving in TRTLLM Sampler (#5328)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Daniel Cámpora authored 2025-06-25 17:41:36 +02:00; committed by GitHub
parent c5ae3272b9
commit 205c97a4ae
5 changed files with 155 additions and 5 deletions


@@ -0,0 +1,87 @@
from tensorrt_llm.bindings.executor import FinishReason


class FinishedState:
    # State flags
    FINISHED_EOS = 1 << 0
    FINISHED_STOP_WORDS = 1 << 1
    FINISHED_MAX_LENGTH = 1 << 2
    FINISHED = FINISHED_EOS | FINISHED_STOP_WORDS | FINISHED_MAX_LENGTH
    SKIP_DECODING = 1 << 3

    def __init__(self, state=0):
        self._state = state

    @classmethod
    def empty(cls):
        return cls(0)

    @classmethod
    def finished(cls):
        return cls(cls.FINISHED)

    @classmethod
    def skip_decoding(cls):
        return cls(cls.SKIP_DECODING)

    @classmethod
    def finished_eos(cls):
        return cls(cls.FINISHED_EOS)

    @classmethod
    def finished_max_length(cls):
        return cls(cls.FINISHED_MAX_LENGTH)

    @classmethod
    def finished_stop_words(cls):
        return cls(cls.FINISHED_STOP_WORDS)

    def set_finished_eos(self):
        self._state |= self.FINISHED_EOS

    @property
    def is_finished_eos(self):
        return self._any_bit_set(self.FINISHED_EOS)

    def set_finished_stop_words(self):
        self._state |= self.FINISHED_STOP_WORDS

    @property
    def is_finished_stop_words(self):
        return self._any_bit_set(self.FINISHED_STOP_WORDS)

    def set_finished_max_length(self):
        self._state |= self.FINISHED_MAX_LENGTH

    @property
    def is_finished_max_length(self):
        return self._any_bit_set(self.FINISHED_MAX_LENGTH)

    def set_finished(self):
        self._state |= self.FINISHED

    @property
    def is_finished(self):
        return self._any_bit_set(self.FINISHED)

    def set_skip_decoding(self):
        self._state |= self.SKIP_DECODING

    @property
    def is_skip_decoding(self):
        return self._any_bit_set(self.SKIP_DECODING)

    def to_finish_reason(self):
        if self.is_finished_eos:
            return FinishReason.END_ID
        if self.is_finished_stop_words:
            return FinishReason.STOP_WORDS
        if self.is_finished_max_length:
            return FinishReason.LENGTH
        return FinishReason.NOT_FINISHED

    def to_underlying(self):
        return self._state

    def _any_bit_set(self, bits):
        return (self._state & bits) != 0
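For reference, a minimal usage sketch of the class above (illustrative only, not part of the diff; it uses nothing beyond the names defined in this file):

    # Build a state, set a flag, and translate it to the executor's FinishReason.
    state = FinishedState.empty()
    state.set_finished_eos()
    assert state.is_finished and state.is_finished_eos
    assert state.to_finish_reason() == FinishReason.END_ID

    # A raw decoder flag value can be wrapped the same way.
    assert FinishedState(FinishedState.FINISHED_MAX_LENGTH).to_finish_reason() == FinishReason.LENGTH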


@@ -1535,6 +1535,9 @@ class PyExecutor:
            self.resource_manager.resource_managers[
                ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
                    disagg_gen_init_to_prepare)
+            self.resource_manager.resource_managers[
+                ResourceManagerType.SEQ_SLOT_MANAGER].prepare_resources(
+                    disagg_gen_init_to_prepare)
            # Trigger KV cache exchange for new disagg_gen_init_requests
            self._recv_disagg_gen_cache(fitting_disagg_gen_init_requests)
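The added lines give disaggregated generation-init requests a sequence slot in addition to their KV cache pages. A rough illustration of why a slot is needed before sampling (hypothetical values; the indexing formula is the one used in the sampler hunk below):

    # The TRTLLM sampler reads per-request results from flat host buffers
    # indexed by sequence slot and beam, so every request must own a slot.
    seq_slot, beam_width, beam = 3, 1, 0
    flat_index = seq_slot * beam_width + beam  # index into e.g. finish_reasons_host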


@@ -23,6 +23,7 @@ from tensorrt_llm.bindings.internal.runtime import (BufferManager, DecoderState,
from tensorrt_llm.executor.result import Logprob
from tensorrt_llm.mapping import Mapping
+from .finish_reason import FinishedState
from .llm_request import LlmRequest, LlmRequestState
from .scheduler import ScheduledRequests
@@ -648,6 +649,7 @@ class TRTLLMSampler(Sampler):
            for beam in range(beam_width):
                seq_len = sequence_lengths_host_data[seq_slot * beam_width +
                                                     beam].item()
+                seq_len = seq_len + 1 if self.is_trt_overlap else seq_len
                num_new_tokens[beam] = min(
                    num_generated_tokens,
                    seq_len - request.get_num_tokens(beam))
@@ -678,9 +680,10 @@
                    state.host.cum_log_probs[seq_slot * beam_width +
                                             beam].item())
-                finish_reason = finish_reasons_host[seq_slot * beam_width +
-                                                    beam].item()
-                request.set_finished_reason(FinishReason(finish_reason), beam)
+                finish_reason = FinishedState(
+                    finish_reasons_host[seq_slot * beam_width +
+                                        beam].item()).to_finish_reason()
+                request.set_finished_reason(finish_reason, beam)
                if request.py_return_log_probs:
                    request.py_result.append_log_probs([log_probs], cum_log_probs)
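The change above treats finish_reasons_host as holding the decoder's packed FinishedState bit flags rather than FinishReason enum values, so the raw value is now translated instead of being cast directly. A self-contained illustration (the absolute import path for the new module is an assumption):

    from tensorrt_llm.bindings.executor import FinishReason
    from tensorrt_llm._torch.pyexecutor.finish_reason import FinishedState  # assumed path

    raw_state = FinishedState.FINISHED_MAX_LENGTH  # bit flag 1 << 2, as written by the decoder
    assert FinishedState(raw_state).to_finish_reason() == FinishReason.LENGTH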


@@ -0,0 +1,35 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
free_gpu_memory_fraction: 0.2
context_servers:
  num_instances: 1
  max_batch_size: 1
  max_num_tokens: 3000
  max_seq_len: 4096
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  enable_trtllm_sampler: True
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  use_cuda_graph: False
  disable_overlap_scheduler: True
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  max_batch_size: 256
  max_num_tokens: 4096
  max_seq_len: 4096
  enable_trtllm_sampler: True
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  use_cuda_graph: False
  disable_overlap_scheduler: False
  urls:
    - "localhost:8002"


@@ -50,6 +50,8 @@ def get_test_config(test_desc, example_dir, test_root):
        (2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
        "mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
        "overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
+        "trtllm_sampler":
+        (2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
        "load_balance":
        (4, f"{test_configs_root}/disagg_config_load_balance.yaml"),
        "cache_aware_balance":
@@ -179,7 +181,7 @@ def run_disaggregated_test(example_dir,
                                     poll_procs=[workers_proc, server_proc])
        # Run the chat completion endpoint test only for TinyLlama
-        if test_desc == "overlap":
+        if test_desc == "overlap" or test_desc == "trtllm_sampler":
            chat_client_cmd = client_cmd + [
                '-e', 'chat', '-o', 'output_chat.json'
            ]
@@ -198,7 +200,7 @@
        not_expected_strings = ["Berlin Berlin"]
        output_files = ['output.json', 'output_streaming.json']
-        if test_desc == "overlap":
+        if test_desc == "overlap" or test_desc == "trtllm_sampler":
            # Disable streaming chat completion for overlap test
            # due to bug
            output_files.extend(['output_chat.json'])
@@ -420,6 +422,26 @@ def test_disaggregated_overlap(disaggregated_test_root, llm_venv,
                           cwd=llm_venv.get_working_directory())
+
+
+@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
+                         indirect=True)
+def test_disaggregated_trtllm_sampler(disaggregated_test_root, llm_venv,
+                                      disaggregated_example_root,
+                                      llama_model_root):
+    src_dst_dict = {
+        llama_model_root:
+        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_test(disaggregated_example_root,
+                           "trtllm_sampler",
+                           env=llm_venv._new_env,
+                           cwd=llm_venv.get_working_directory())
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_disaggregated_load_balance(disaggregated_test_root, llm_venv,