fix ci failures

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
This commit is contained in:
Lizhi Zhou 2026-01-05 21:52:40 -08:00
parent 4ad2427e02
commit c71c8922be
9 changed files with 18 additions and 21 deletions

View File

@ -929,8 +929,8 @@ void initRequestBindings(nb::module_& m)
{
throw std::runtime_error("Invalid Request state!");
}
new (&response) tle::Response(
nb::cast<SizeType32>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<SizeType32>(state[2]));
new (&response)
tle::Response(nb::cast<IdType>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<IdType>(state[2]));
};
nb::class_<tle::Response>(m, "Response")

View File

@ -870,7 +870,7 @@ void initRequestBindings(pybind11::module_& m)
throw std::runtime_error("Invalid Request state!");
}
return std::make_unique<tle::Response>(
state[0].cast<SizeType32>(), state[1].cast<tle::Result>(), state[2].cast<SizeType32>());
state[0].cast<IdType>(), state[1].cast<tle::Result>(), state[2].cast<IdType>());
};
py::class_<tle::Response>(m, "Response")

View File

@ -347,7 +347,7 @@ class DemoGenerationExecutor(GenerationExecutor):
def submit(self, request: GenerationRequest) -> GenerationResult:
# set request id if necessary
client_id = request.id if request.id is not None else self._get_next_client_id()
client_id = self._get_client_id(request)
if request.id is None:
request.set_id(client_id)

View File

@ -442,8 +442,6 @@ class BaseWorker(GenerationExecutor):
== "context_and_generation"
), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:<backend_type>` in config file for disaggregated serving"
request_type = request.disaggregated_params.get_request_type()
if request.disaggregated_params.ctx_request_id is not None and request.disaggregated_params.ctx_request_id > 0:
client_id = request.disaggregated_params.ctx_request_id
if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
context_phase_params = request.disaggregated_params.get_context_phase_params(
)
@ -617,8 +615,7 @@ class BaseWorker(GenerationExecutor):
"To fix this, ensure that the llm.generate(...) method is "
"guarded with the `if __name__ == '__main__':` block.")
client_id = request.id if request.id is not None else self._get_next_client_id(
)
client_id = self._get_client_id(request)
if request.id is None:
request.set_id(client_id)

View File

@ -213,7 +213,12 @@ class GenerationExecutor(ABC):
return futures
def _get_next_client_id(self):
def _get_client_id(self, request: GenerationRequest) -> int:
    """Pick the client id to associate with ``request``.

    Resolution order:

    1. ``request.id`` when it is already set (not ``None``).
    2. ``request.disaggregated_params.ctx_request_id`` when disaggregated
       params are present (truthy) and that field is an ``int``.
    3. Otherwise a new id obtained from ``get_local_request_id`` seeded
       with ``self._last_client_id``; the new value is stored back on
       ``self._last_client_id`` before being returned.
    """
    preset_id = request.id
    if preset_id is not None:
        return preset_id

    disagg_params = request.disaggregated_params
    if disagg_params:
        ctx_id = disagg_params.ctx_request_id
        if isinstance(ctx_id, int):
            return ctx_id

    # No id supplied by the caller: advance the executor-local counter.
    self._last_client_id = get_local_request_id(self._last_client_id)
    return self._last_client_id

View File

@ -342,11 +342,8 @@ class GenerationExecutorProxy(GenerationExecutor):
self._start_dispatch_threads()
if request.disaggregated_params is not None \
and request.disaggregated_params.ctx_request_id is not None:
request.set_id(request.disaggregated_params.ctx_request_id)
else:
request.set_id(self._get_next_client_id())
if request.id is None:
request.set_id(self._get_client_id(request))
logprob_params = self._get_logprob_params(request)
result = GenerationResult(

View File

@ -233,7 +233,8 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
Low-level API to the executor. Return a "future" GenerationResult
which can be waited. Forwards the request to the workers through RPC.
"""
request.set_id(self._get_next_client_id())
if request.id is None:
request.set_id(self._get_client_id(request))
logprob_params = self._get_logprob_params(request)
with nvtx_range_debug("rpc_submit"):

View File

@ -80,7 +80,8 @@ class RpcExecutorMixin:
atexit.register(self.shutdown)
def submit(self, request: GenerationRequest) -> GenerationResult:
request.set_id(self._get_next_client_id())
if request.id is None:
request.set_id(self._get_client_id(request))
logprob_params = self._get_logprob_params(request)
# submit is a fire-and-forget operation, don't need to wait for response

View File

@ -124,7 +124,7 @@ class OpenAIDisaggregatedService(OpenAIService):
ctx_req, server=reserved_ctx_server, hooks=hooks
)
await self._verify_ctx_response(ctx_response)
gen_req = self._get_gen_request(request, ctx_response, request_id)
gen_req = self._get_gen_request(request, ctx_response)
if ctx_response is None or self._need_gen(ctx_response):
return await self._gen_client.send_request(
gen_req, server=reserved_gen_server, hooks=hooks
@ -157,12 +157,8 @@ class OpenAIDisaggregatedService(OpenAIService):
self,
request: UCompletionRequest,
ctx_response: UCompletionResponse,
ctx_request_id: int,
) -> UCompletionRequest:
request.disaggregated_params = ctx_response.choices[0].disaggregated_params
assert (
request.disaggregated_params.ctx_request_id == ctx_request_id or ctx_request_id is None
)
request.disaggregated_params.request_type = "generation_only"
# Replace the string prompt with prompt_tokens_ids
if isinstance(request, CompletionRequest):