[TRTLLM-5974][feat] Support disaggregated serving in TRTLLM Sampler (#5328)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
commit 205c97a4ae
parent c5ae3272b9
tensorrt_llm/_torch/pyexecutor/finish_reason.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from tensorrt_llm.bindings.executor import FinishReason


class FinishedState:
    # State flags
    FINISHED_EOS = 1 << 0
    FINISHED_STOP_WORDS = 1 << 1
    FINISHED_MAX_LENGTH = 1 << 2
    FINISHED = FINISHED_EOS | FINISHED_STOP_WORDS | FINISHED_MAX_LENGTH
    SKIP_DECODING = 1 << 3

    def __init__(self, state=0):
        self._state = state

    @classmethod
    def empty(cls):
        return cls(0)

    @classmethod
    def finished(cls):
        return cls(cls.FINISHED)

    @classmethod
    def skip_decoding(cls):
        return cls(cls.SKIP_DECODING)

    @classmethod
    def finished_eos(cls):
        return cls(cls.FINISHED_EOS)

    @classmethod
    def finished_max_length(cls):
        return cls(cls.FINISHED_MAX_LENGTH)

    @classmethod
    def finished_stop_words(cls):
        return cls(cls.FINISHED_STOP_WORDS)

    def set_finished_eos(self):
        self._state |= self.FINISHED_EOS

    @property
    def is_finished_eos(self):
        return self._any_bit_set(self.FINISHED_EOS)

    def set_finished_stop_words(self):
        self._state |= self.FINISHED_STOP_WORDS

    @property
    def is_finished_stop_words(self):
        return self._any_bit_set(self.FINISHED_STOP_WORDS)

    def set_finished_max_length(self):
        self._state |= self.FINISHED_MAX_LENGTH

    @property
    def is_finished_max_length(self):
        return self._any_bit_set(self.FINISHED_MAX_LENGTH)

    def set_finished(self):
        self._state |= self.FINISHED

    @property
    def is_finished(self):
        return self._any_bit_set(self.FINISHED)

    def set_skip_decoding(self):
        self._state |= self.SKIP_DECODING

    @property
    def is_skip_decoding(self):
        return self._any_bit_set(self.SKIP_DECODING)

    def to_finish_reason(self):
        if self.is_finished_eos:
            return FinishReason.END_ID
        if self.is_finished_stop_words:
            return FinishReason.STOP_WORDS
        if self.is_finished_max_length:
            return FinishReason.LENGTH
        return FinishReason.NOT_FINISHED

    def to_underlying(self):
        return self._state

    def _any_bit_set(self, bits):
        return (self._state & bits) != 0
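For illustration (not part of the diff), a minimal usage sketch of the new class; flag values and the decoding priority follow the code above:

# Minimal usage sketch for FinishedState; behavior follows the class above.
from tensorrt_llm._torch.pyexecutor.finish_reason import FinishedState

state = FinishedState.empty()
assert not state.is_finished

# Flags are independent bits, so several finish conditions can be set at once.
state.set_finished_eos()
state.set_finished_max_length()
assert state.is_finished

# to_finish_reason() collapses the bitmask into a single FinishReason,
# with EOS taking priority over stop words, then max length.
reason = state.to_finish_reason()  # FinishReason.END_ID here

# to_underlying() exposes the raw integer, which round-trips through the
# constructor, making the state convenient to ship between processes.
assert FinishedState(state.to_underlying()).is_finished_eos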
@@ -1535,6 +1535,9 @@ class PyExecutor:
         self.resource_manager.resource_managers[
             ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
                 disagg_gen_init_to_prepare)
+        self.resource_manager.resource_managers[
+            ResourceManagerType.SEQ_SLOT_MANAGER].prepare_resources(
+                disagg_gen_init_to_prepare)

         # Trigger KV cache exchange for new disagg_gen_init_requests
         self._recv_disagg_gen_cache(fitting_disagg_gen_init_requests)
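Condensed: a disaggregated generation-init request now reserves a sequence slot as well as KV cache blocks before the cache exchange starts. A standalone sketch of that ordering (toy types; the real managers live in tensorrt_llm/_torch/pyexecutor):

# Toy sketch of the preparation order the hunk above establishes.
from enum import Enum, auto

class ResourceManagerType(Enum):
    KV_CACHE_MANAGER = auto()
    SEQ_SLOT_MANAGER = auto()

class ToyManager:
    def __init__(self, name):
        self.name = name

    def prepare_resources(self, requests):
        print(f"{self.name}: prepared {len(requests)} request(s)")

managers = {t: ToyManager(t.name) for t in ResourceManagerType}
disagg_gen_init_to_prepare = ["req0"]

# KV cache blocks and a sequence slot are both reserved before the
# KV cache exchange with the context server is triggered.
for t in (ResourceManagerType.KV_CACHE_MANAGER,
          ResourceManagerType.SEQ_SLOT_MANAGER):
    managers[t].prepare_resources(disagg_gen_init_to_prepare)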
@@ -23,6 +23,7 @@ from tensorrt_llm.bindings.internal.runtime import (BufferManager, DecoderState,
 from tensorrt_llm.executor.result import Logprob
 from tensorrt_llm.mapping import Mapping

+from .finish_reason import FinishedState
 from .llm_request import LlmRequest, LlmRequestState
 from .scheduler import ScheduledRequests
@@ -648,6 +649,7 @@ class TRTLLMSampler(Sampler):
             for beam in range(beam_width):
                 seq_len = sequence_lengths_host_data[seq_slot * beam_width +
                                                      beam].item()
+                seq_len = seq_len + 1 if self.is_trt_overlap else seq_len
                 num_new_tokens[beam] = min(
                     num_generated_tokens,
                     seq_len - request.get_num_tokens(beam))
@@ -678,9 +680,10 @@
                     state.host.cum_log_probs[seq_slot * beam_width +
                                              beam].item())

-                finish_reason = finish_reasons_host[seq_slot * beam_width +
-                                                    beam].item()
-                request.set_finished_reason(FinishReason(finish_reason), beam)
+                finish_reason = FinishedState(
+                    finish_reasons_host[seq_slot * beam_width +
+                                        beam].item()).to_finish_reason()
+                request.set_finished_reason(finish_reason, beam)

                 if request.py_return_log_probs:
                     request.py_result.append_log_probs([log_probs], cum_log_probs)
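The motivation for this hunk: finish_reasons_host holds the decoder's FinishedState bitmask, not a FinishReason enum value, so constructing FinishReason directly from it mislabels the finish reason. An illustrative sketch (not part of the diff; values follow the class added above):

# Illustrative: decode the raw state through FinishedState instead of
# casting it to FinishReason directly.
from tensorrt_llm.bindings.executor import FinishReason
from tensorrt_llm._torch.pyexecutor.finish_reason import FinishedState

raw = FinishedState.FINISHED_MAX_LENGTH  # bit 2 -> integer 4

# Old path: FinishReason(4) treats the bitmask as an enum member, which is
# not FinishReason.LENGTH (and may not even exist, depending on version).
# New path: decode the bits first.
assert FinishedState(raw).to_finish_reason() == FinishReason.LENGTH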
@@ -0,0 +1,35 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
free_gpu_memory_fraction: 0.2
context_servers:
  num_instances: 1
  max_batch_size: 1
  max_num_tokens: 3000
  max_seq_len: 4096
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  enable_trtllm_sampler: True
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  use_cuda_graph: False
  disable_overlap_scheduler: True
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  max_batch_size: 256
  max_num_tokens: 4096
  max_seq_len: 4096
  enable_trtllm_sampler: True
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  use_cuda_graph: False
  disable_overlap_scheduler: False
  urls:
    - "localhost:8002"
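For illustration (not part of the diff), a quick sanity check of the config shape above with PyYAML; the filename is assumed to match the entry added to the test registry below:

# Sanity-check sketch for the disaggregated config above.
import yaml

with open("disagg_config_trtllm_sampler.yaml") as f:
    cfg = yaml.safe_load(f)

# Both roles enable the TRTLLM sampler; only the context servers run with
# the overlap scheduler disabled.
for role in ("context_servers", "generation_servers"):
    assert cfg[role]["enable_trtllm_sampler"]
    assert cfg[role]["kv_cache_config"]["free_gpu_memory_fraction"] == 0.2
assert cfg["context_servers"]["disable_overlap_scheduler"] is True
assert cfg["generation_servers"]["disable_overlap_scheduler"] is False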
@@ -50,6 +50,8 @@ def get_test_config(test_desc, example_dir, test_root):
         (2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
         "mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
         "overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
+        "trtllm_sampler":
+        (2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
         "load_balance":
         (4, f"{test_configs_root}/disagg_config_load_balance.yaml"),
         "cache_aware_balance":
@@ -179,7 +181,7 @@ def run_disaggregated_test(example_dir,
                                poll_procs=[workers_proc, server_proc])

     # Run the chat completion endpoint test only for TinyLlama
-    if test_desc == "overlap":
+    if test_desc == "overlap" or test_desc == "trtllm_sampler":
         chat_client_cmd = client_cmd + [
             '-e', 'chat', '-o', 'output_chat.json'
         ]
@@ -198,7 +200,7 @@ def run_disaggregated_test(example_dir,
     not_expected_strings = ["Berlin Berlin"]

     output_files = ['output.json', 'output_streaming.json']
-    if test_desc == "overlap":
+    if test_desc == "overlap" or test_desc == "trtllm_sampler":
         # Disable streaming chat completion for overlap test
         # due to bug
         output_files.extend(['output_chat.json'])
@@ -420,6 +422,26 @@ def test_disaggregated_overlap(disaggregated_test_root, llm_venv,
                            cwd=llm_venv.get_working_directory())


+@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
+                         indirect=True)
+def test_disaggregated_trtllm_sampler(disaggregated_test_root, llm_venv,
+                                      disaggregated_example_root,
+                                      llama_model_root):
+    src_dst_dict = {
+        llama_model_root:
+        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_test(disaggregated_example_root,
+                           "trtllm_sampler",
+                           env=llm_venv._new_env,
+                           cwd=llm_venv.get_working_directory())
+
+
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_disaggregated_load_balance(disaggregated_test_root, llm_venv,
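One detail worth noting in the new test: the symlink step exists so that the model path in the YAML config above, TinyLlama/TinyLlama-1.1B-Chat-v1.0, resolves to a local copy of the checkpoint inside the working directory. A standalone sketch of that pattern (helper name is illustrative, not from the diff):

# Standalone sketch of the model-linking pattern used by the new test:
# link the real model directory into the working directory so the relative
# model path in the YAML config resolves locally.
import os

def link_model(model_root: str, workdir: str) -> str:
    # link_model is a hypothetical helper; the test inlines this logic.
    dst = os.path.join(workdir, "TinyLlama", "TinyLlama-1.1B-Chat-v1.0")
    if not os.path.islink(dst):
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        os.symlink(model_root, dst, target_is_directory=True)
    return dst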