add disagg_request_id to OpenAI protocol

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Lizhi Zhou 2026-01-11 21:13:06 -08:00
parent 207cce4ba5
commit ce44fabe8e
7 changed files with 16 additions and 16 deletions

View File

@@ -582,7 +582,7 @@ void initRequestBindings(nb::module_& m)
};
auto requestSetstate = [](tle::Request& self, nb::tuple const& state)
{
-if (state.size() != 34)
+if (state.size() != 35)
{
throw std::runtime_error("Invalid Request state!");
}

View File

@@ -148,7 +148,7 @@ public:
allottedTimeMs, //
contextPhaseParams, //
cacheSaltID, //
-arrivalTime, //
+arrivalTime //
)
{
}

View File

@@ -531,11 +531,11 @@ void initRequestBindings(pybind11::module_& m)
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
-self.getGuidedDecodingParams(), self.getCacheSaltID());
+self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
};
auto requestSetstate = [](py::tuple const& state)
{
-if (state.size() != 34)
+if (state.size() != 35)
{
throw std::runtime_error("Invalid Request state!");
}
@@ -640,8 +640,9 @@ void initRequestBindings(pybind11::module_& m)
py::arg("guided_decoding_params") = py::none(),
py::arg("language_adapter_uid") = py::none(),
py::arg("allotted_time_ms") = py::none(),
py::arg("cache_salt_id") = py::none()
) // clang-format on
py::arg("cache_salt_id") = py::none(),
py::arg("disagg_request_id") = py::none()
) // clang-format on
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
.def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
@@ -688,6 +689,7 @@ void initRequestBindings(pybind11::module_& m)
.def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
.def_property(
"context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
.def_property("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
.def(py::pickle(requestGetstate, requestSetstate));
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
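
The effect of the new binding, as a minimal sketch. It assumes the class is exposed as tensorrt_llm.bindings.executor.Request and that input_token_ids and max_tokens are the only required constructor arguments; the pickle round trip works because the state tuple now carries 35 fields, including the new one:

    import pickle

    import tensorrt_llm.bindings.executor as trtllm

    # Hypothetical values; only disagg_request_id is the new argument from
    # this commit, the rest mirror the properties bound above.
    req = trtllm.Request(input_token_ids=[1, 2, 3], max_tokens=16,
                         disagg_request_id=1001)

    restored = pickle.loads(pickle.dumps(req))  # 35-field getstate/setstate
    assert restored.disagg_request_id == 1001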

View File

@@ -347,7 +347,7 @@ class DemoGenerationExecutor(GenerationExecutor):
def submit(self, request: GenerationRequest) -> GenerationResult:
# set request id if necessary
-client_id = self._get_next_client_id()
+client_id = request.id if request.id is not None else self._get_next_client_id()
if request.id is None:
request.set_id(client_id)

View File

@@ -21,6 +21,8 @@ class DisaggregatedParams:
ctx_request_id (int): The context request id
opaque_state(bytes): Any additional state needing to be exchanged between context and gen instances
draft_tokens (List[int]): The draft tokens of the generation request
+disagg_request_id (int): The disaggregated request id. If set, both the context and generation requests will use it
+as the underlying request id.
multimodal_embedding_handles (List[Dict[str, Any]]): The resulting multimodal embedding handles from ViT.
multimodal_hashes (List[List[int]]): The multimodal hashes of each multimodal item in the request.
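
A minimal usage sketch for the new field, assuming the id value is allocated by the caller (the class and its other fields are as documented above):

    from tensorrt_llm.llmapi import DisaggregatedParams

    # Tag the context leg with a shared id; a generation request created with
    # the same disagg_request_id resolves to the same underlying request id.
    ctx_params = DisaggregatedParams(request_type="context_only",
                                     disagg_request_id=1001)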

View File

@@ -147,7 +147,7 @@ class OpenAIDisaggregatedService(OpenAIService):
) -> UCompletionRequest:
ctx_request = copy.deepcopy(request)
ctx_request.disaggregated_params = DisaggregatedParams(
request_type="context_only", ctx_request_id=disagg_request_id
request_type="context_only", disagg_request_id=disagg_request_id
)
ctx_request.stream = False
ctx_request.stream_options = None
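
The shared id now travels in its own field instead of overloading ctx_request_id; a before/after sketch, where disagg_request_id is the id the service allocated for the paired context and generation requests:

    # before: smuggled through ctx_request_id and recovered on the other
    # side by comparing against MIN_GLOBAL_ID
    DisaggregatedParams(request_type="context_only",
                        ctx_request_id=disagg_request_id)

    # after: an explicit field, no magic threshold
    DisaggregatedParams(request_type="context_only",
                        disagg_request_id=disagg_request_id)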

View File

@@ -37,7 +37,6 @@ from typing_extensions import Annotated, Required, TypeAlias, TypedDict
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
-from tensorrt_llm.llmapi.disagg_utils import MIN_GLOBAL_ID
def _logit_bias_to_embedding_bias(logit_bias: Optional[Dict[str, float]],
@@ -118,6 +117,7 @@ class DisaggregatedParams(OpenAIBaseModel):
ctx_request_id: Optional[int] = None
encoded_opaque_state: Optional[str] = None
draft_tokens: Optional[List[int]] = None
+disagg_request_id: Optional[int] = None
class ErrorResponse(OpenAIBaseModel):
@@ -1001,18 +1001,14 @@ def to_disaggregated_params(
ctx_request_id=tllm_disagg_params.ctx_request_id,
encoded_opaque_state=encode_opaque_state(
tllm_disagg_params.opaque_state),
-draft_tokens=tllm_disagg_params.draft_tokens)
+draft_tokens=tllm_disagg_params.draft_tokens,
+disagg_request_id=tllm_disagg_params.disagg_request_id)
def to_llm_disaggregated_params(
disaggregated_params: DisaggregatedParams) -> LlmDisaggregatedParams:
if disaggregated_params is None:
return None
-disagg_request_id = None
-# If ctx_request_id is greater than or equal to MIN_GLOBAL_ID, use it as disagg_request_id
-# then both the ctx and gen requests will use it as underlying request id.
-if disaggregated_params.ctx_request_id is not None and disaggregated_params.ctx_request_id >= MIN_GLOBAL_ID:
-disagg_request_id = disaggregated_params.ctx_request_id
return LlmDisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
@@ -1020,7 +1016,7 @@ def to_llm_disaggregated_params(
opaque_state=decode_opaque_state(
disaggregated_params.encoded_opaque_state),
draft_tokens=disaggregated_params.draft_tokens,
-disagg_request_id=disagg_request_id)
+disagg_request_id=disaggregated_params.disagg_request_id)
UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
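
End to end, the conversion is now a plain field copy in both directions; a sketch using the protocol-level model defined above (the concrete id values are made up for illustration):

    # OpenAI-protocol object as parsed off the wire
    wire_params = DisaggregatedParams(request_type="context_only",
                                      ctx_request_id=7,
                                      disagg_request_id=1001)

    llm_params = to_llm_disaggregated_params(wire_params)
    assert llm_params.disagg_request_id == 1001  # no MIN_GLOBAL_ID inference

    # The reverse helper, to_disaggregated_params, copies the field back the
    # same way, so the id survives a full context -> generation round trip.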