mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00

add disagg_request_id in OpenAI protocol
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>

parent 207cce4ba5
commit ce44fabe8e

@@ -582,7 +582,7 @@ void initRequestBindings(nb::module_& m)
     };
     auto requestSetstate = [](tle::Request& self, nb::tuple const& state)
     {
-        if (state.size() != 34)
+        if (state.size() != 35)
         {
             throw std::runtime_error("Invalid Request state!");
         }

@@ -148,7 +148,7 @@ public:
          allottedTimeMs, //
          contextPhaseParams, //
          cacheSaltID, //
-         arrivalTime, //
+         arrivalTime //
          )
     {
     }

@@ -531,11 +531,11 @@ void initRequestBindings(pybind11::module_& m)
             self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
             self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
             self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
-            self.getGuidedDecodingParams(), self.getCacheSaltID());
+            self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
     };
     auto requestSetstate = [](py::tuple const& state)
     {
-        if (state.size() != 34)
+        if (state.size() != 35)
         {
             throw std::runtime_error("Invalid Request state!");
         }

@@ -640,8 +640,9 @@ void initRequestBindings(pybind11::module_& m)
             py::arg("guided_decoding_params") = py::none(),
             py::arg("language_adapter_uid") = py::none(),
             py::arg("allotted_time_ms") = py::none(),
-            py::arg("cache_salt_id") = py::none()
-            ) // clang-format on
+            py::arg("cache_salt_id") = py::none(),
+            py::arg("disagg_request_id") = py::none()
+            ) // clang-format on
         .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
         .def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
         .def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)

@@ -688,6 +689,7 @@ void initRequestBindings(pybind11::module_& m)
         .def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
         .def_property(
             "context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
+        .def_property("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
         .def(py::pickle(requestGetstate, requestSetstate));

     request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;

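Taken together, the binding changes make disagg_request_id a constructor keyword, a read/write property, and the 35th entry in the Request pickle state. A minimal sketch of how this surfaces on the Python side, assuming a standard tensorrt_llm install; the token ids and id values here are placeholders:

import pickle

from tensorrt_llm.bindings import executor as tle

# disagg_request_id is now accepted at construction time...
req = tle.Request(input_token_ids=[1, 2, 3], max_tokens=8, disagg_request_id=42)

# ...exposed as a mutable property (backed by get/setDisaggRequestId)...
req.disagg_request_id = 43

# ...and carried in the getstate/setstate tuple, so it survives pickling.
clone = pickle.loads(pickle.dumps(req))
assert clone.disagg_request_id == 43
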
@@ -347,7 +347,7 @@ class DemoGenerationExecutor(GenerationExecutor):

     def submit(self, request: GenerationRequest) -> GenerationResult:
         # set request id if necessary
-        client_id = self._get_next_client_id()
+        client_id = request.id if request.id is not None else self._get_next_client_id()
         if request.id is None:
             request.set_id(client_id)

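The executor change lets a pre-assigned request id (for example, one seeded from disagg_request_id) take precedence over a freshly minted one. The same precedence as a standalone sketch, with a hypothetical helper and placeholder ids:

def pick_client_id(request_id, next_client_id):
    # Reuse the caller-supplied id when present; otherwise allocate a new one.
    return request_id if request_id is not None else next_client_id

assert pick_client_id(12345, 7) == 12345  # pre-assigned disaggregated id wins
assert pick_client_id(None, 7) == 7       # fall back to the local counter
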
@@ -21,6 +21,8 @@ class DisaggregatedParams:
         ctx_request_id (int): The context request id
         opaque_state(bytes): Any additional state needing to be exchanged between context and gen instances
         draft_tokens (List[int]): The draft tokens of the generation request
+        disagg_request_id (int): The disaggregated request id; if set, both the context and generation
+            requests use it as the underlying request id.

         multimodal_embedding_handles (List[Dict[str, Any]]): The resulting multimodal embedding handles from ViT.
         multimodal_hashes (List[List[int]]): The multimodal hashes of each multimodal item in the request.

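A hedged sketch of constructing the llmapi dataclass with the field documented above; the id is a placeholder, and the context_only usage mirrors the service hunk below:

from tensorrt_llm.llmapi import DisaggregatedParams

params = DisaggregatedParams(
    request_type="context_only",  # run the context phase only
    disagg_request_id=12345,      # shared id reused by both ctx and gen requests
)
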
@@ -147,7 +147,7 @@ class OpenAIDisaggregatedService(OpenAIService):
     ) -> UCompletionRequest:
         ctx_request = copy.deepcopy(request)
         ctx_request.disaggregated_params = DisaggregatedParams(
-            request_type="context_only", ctx_request_id=disagg_request_id
+            request_type="context_only", disagg_request_id=disagg_request_id
         )
         ctx_request.stream = False
         ctx_request.stream_options = None

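Reduced to a self-contained sketch, the service-side preparation clones the incoming request, pins it to the context phase under the shared id, and disables streaming; SimpleNamespace stands in for UCompletionRequest and the protocol params model:

import copy
from types import SimpleNamespace

def make_context_request(request, disagg_request_id):
    ctx_request = copy.deepcopy(request)
    ctx_request.disaggregated_params = SimpleNamespace(
        request_type="context_only", disagg_request_id=disagg_request_id)
    ctx_request.stream = False         # the context phase never streams
    ctx_request.stream_options = None
    return ctx_request

ctx = make_context_request(SimpleNamespace(stream=True, stream_options={}), 12345)
assert ctx.disaggregated_params.disagg_request_id == 12345
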
@@ -37,7 +37,6 @@ from typing_extensions import Annotated, Required, TypeAlias, TypedDict
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
-from tensorrt_llm.llmapi.disagg_utils import MIN_GLOBAL_ID


 def _logit_bias_to_embedding_bias(logit_bias: Optional[Dict[str, float]],

@@ -118,6 +117,7 @@ class DisaggregatedParams(OpenAIBaseModel):
     ctx_request_id: Optional[int] = None
     encoded_opaque_state: Optional[str] = None
     draft_tokens: Optional[List[int]] = None
+    disagg_request_id: Optional[int] = None


 class ErrorResponse(OpenAIBaseModel):

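To see what the extended protocol model puts on the wire, here is a hedged, self-contained stand-in; the real class derives from OpenAIBaseModel, the field names follow the hunk above, and the values are placeholders:

from typing import List, Optional

from pydantic import BaseModel


class DisaggregatedParams(BaseModel):
    request_type: Optional[str] = None
    first_gen_tokens: Optional[List[int]] = None
    ctx_request_id: Optional[int] = None
    encoded_opaque_state: Optional[str] = None
    draft_tokens: Optional[List[int]] = None
    disagg_request_id: Optional[int] = None  # new in this commit


wire = DisaggregatedParams(request_type="context_only", disagg_request_id=12345)
print(wire.model_dump_json(exclude_none=True))
# {"request_type":"context_only","disagg_request_id":12345}
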
@@ -1001,18 +1001,14 @@ def to_disaggregated_params(
         ctx_request_id=tllm_disagg_params.ctx_request_id,
         encoded_opaque_state=encode_opaque_state(
             tllm_disagg_params.opaque_state),
-        draft_tokens=tllm_disagg_params.draft_tokens)
+        draft_tokens=tllm_disagg_params.draft_tokens,
+        disagg_request_id=tllm_disagg_params.disagg_request_id)


 def to_llm_disaggregated_params(
         disaggregated_params: DisaggregatedParams) -> LlmDisaggregatedParams:
     if disaggregated_params is None:
         return None
-    disagg_request_id = None
-    # If ctx_request_id is greater than or equal to MIN_GLOBAL_ID, use it as disagg_request_id
-    # then both the ctx and gen requests will use it as underlying request id.
-    if disaggregated_params.ctx_request_id is not None and disaggregated_params.ctx_request_id >= MIN_GLOBAL_ID:
-        disagg_request_id = disaggregated_params.ctx_request_id
     return LlmDisaggregatedParams(
         request_type=disaggregated_params.request_type,
         first_gen_tokens=disaggregated_params.first_gen_tokens,

@@ -1020,7 +1016,7 @@ def to_llm_disaggregated_params(
         opaque_state=decode_opaque_state(
             disaggregated_params.encoded_opaque_state),
         draft_tokens=disaggregated_params.draft_tokens,
-        disagg_request_id=disagg_request_id)
+        disagg_request_id=disaggregated_params.disagg_request_id)


 UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]

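After this commit both converters simply copy the field: to_disaggregated_params forwards it outward and to_llm_disaggregated_params forwards it back, with no MIN_GLOBAL_ID inference in between. A condensed round-trip sketch with plain dicts standing in for the real parameter objects:

def to_openai_params(tllm):
    # executor/llmapi -> OpenAI protocol: pass the shared id through verbatim
    return {"draft_tokens": tllm["draft_tokens"],
            "disagg_request_id": tllm["disagg_request_id"]}

def to_llm_params(openai):
    # OpenAI protocol -> llmapi: likewise a direct copy, no id heuristics
    return {"draft_tokens": openai["draft_tokens"],
            "disagg_request_id": openai["disagg_request_id"]}

llm_params = {"draft_tokens": [101, 102], "disagg_request_id": 12345}
assert to_llm_params(to_openai_params(llm_params)) == llm_params  # lossless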