mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00

fix ci failures

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>

parent 4ad2427e02
commit c71c8922be
@@ -929,8 +929,8 @@ void initRequestBindings(nb::module_& m)
         {
             throw std::runtime_error("Invalid Request state!");
         }
-        new (&response) tle::Response(
-            nb::cast<SizeType32>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<SizeType32>(state[2]));
+        new (&response)
+            tle::Response(nb::cast<IdType>(state[0]), nb::cast<tle::Result>(state[1]), nb::cast<IdType>(state[2]));
     };

     nb::class_<tle::Response>(m, "Response")
@@ -870,7 +870,7 @@ void initRequestBindings(pybind11::module_& m)
             throw std::runtime_error("Invalid Request state!");
         }
         return std::make_unique<tle::Response>(
-            state[0].cast<SizeType32>(), state[1].cast<tle::Result>(), state[2].cast<SizeType32>());
+            state[0].cast<IdType>(), state[1].cast<tle::Result>(), state[2].cast<IdType>());
     };

     py::class_<tle::Response>(m, "Response")
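Note: the two binding hunks above change the unpickling casts for tle::Response from SizeType32 to IdType, so request and client ids keep their full width across a pickle round trip. A minimal Python sketch of why the width matters, assuming IdType is a wider integer type than SizeType32 (the values below are illustrative, not taken from the codebase):

import struct

# Hypothetical client id that no longer fits in a signed 32-bit slot.
client_id = 2**40 + 7

try:
    struct.pack("<i", client_id)          # 32-bit slot, analogous to SizeType32
except struct.error as exc:
    print("32-bit pack failed:", exc)

packed = struct.pack("<q", client_id)     # 64-bit slot, analogous to IdType
(restored,) = struct.unpack("<q", packed)
assert restored == client_id              # id survives the round trip intact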
@@ -347,7 +347,12 @@ class DemoGenerationExecutor(GenerationExecutor):

     def submit(self, request: GenerationRequest) -> GenerationResult:
         # set request id if necessary
-        client_id = request.id if request.id is not None else self._get_next_client_id()
+        client_id = self._get_client_id(request)
+        if request.id is None:
+            request.set_id(client_id)

@@ -442,8 +442,6 @@ class BaseWorker(GenerationExecutor):
                 == "context_and_generation"
             ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:<backend_type>` in config file for disaggregated serving"
             request_type = request.disaggregated_params.get_request_type()
-            if request.disaggregated_params.ctx_request_id is not None and request.disaggregated_params.ctx_request_id > 0:
-                client_id = request.disaggregated_params.ctx_request_id
             if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
                 context_phase_params = request.disaggregated_params.get_context_phase_params(
                 )
@@ -617,8 +615,7 @@ class BaseWorker(GenerationExecutor):
                 "To fix this, ensure that the llm.generate(...) method is "
                 "guarded with the `if __name__ == '__main__':` block.")

-        client_id = request.id if request.id is not None else self._get_next_client_id(
-        )
+        client_id = self._get_client_id(request)
         if request.id is None:
             request.set_id(client_id)

@@ -213,7 +213,12 @@ class GenerationExecutor(ABC):

         return futures

-    def _get_next_client_id(self):
+    def _get_client_id(self, request: GenerationRequest) -> int:
+        if request.id is not None:
+            return request.id
+        if request.disaggregated_params and isinstance(
+                request.disaggregated_params.ctx_request_id, int):
+            return request.disaggregated_params.ctx_request_id
         self._last_client_id = get_local_request_id(self._last_client_id)
         return self._last_client_id
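Note: the renamed helper centralizes id selection for every executor: an id already set on the request wins, then an integer ctx_request_id carried by disaggregated_params, and only then is a fresh local id minted. A self-contained sketch of that resolution order, with stand-in dataclasses and a simple counter in place of get_local_request_id (all names below are illustrative, not the actual TensorRT-LLM classes):

from dataclasses import dataclass
from itertools import count
from typing import Optional

@dataclass
class DisaggregatedParams:      # illustrative stand-in
    ctx_request_id: Optional[int] = None

@dataclass
class GenerationRequest:        # illustrative stand-in
    id: Optional[int] = None
    disaggregated_params: Optional[DisaggregatedParams] = None

_local_ids = count(1)           # stands in for get_local_request_id()

def get_client_id(request: GenerationRequest) -> int:
    # 1. An id the caller already assigned always wins.
    if request.id is not None:
        return request.id
    # 2. For disaggregated generation, reuse the context server's request id.
    if request.disaggregated_params and isinstance(
            request.disaggregated_params.ctx_request_id, int):
        return request.disaggregated_params.ctx_request_id
    # 3. Otherwise mint a fresh local id.
    return next(_local_ids)

assert get_client_id(GenerationRequest(id=5)) == 5
assert get_client_id(GenerationRequest(
    disaggregated_params=DisaggregatedParams(ctx_request_id=42))) == 42
assert get_client_id(GenerationRequest()) == 1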
@@ -342,11 +342,8 @@ class GenerationExecutorProxy(GenerationExecutor):
         self._start_dispatch_threads()

-        if request.disaggregated_params is not None \
-                and request.disaggregated_params.ctx_request_id is not None:
-            request.set_id(request.disaggregated_params.ctx_request_id)
-        else:
-            request.set_id(self._get_next_client_id())
+        if request.id is None:
+            request.set_id(self._get_client_id(request))
         logprob_params = self._get_logprob_params(request)

         result = GenerationResult(
@@ -233,7 +233,8 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
         Low-level API to the executor. Return a "future" GenerationResult
         which can be waited. Forwards the request to the workers through RPC.
         """
-        request.set_id(self._get_next_client_id())
+        if request.id is None:
+            request.set_id(self._get_client_id(request))
         logprob_params = self._get_logprob_params(request)

         with nvtx_range_debug("rpc_submit"):
@@ -80,7 +80,8 @@ class RpcExecutorMixin:
         atexit.register(self.shutdown)

     def submit(self, request: GenerationRequest) -> GenerationResult:
-        request.set_id(self._get_next_client_id())
+        if request.id is None:
+            request.set_id(self._get_client_id(request))
         logprob_params = self._get_logprob_params(request)

         # submit is a fire-and-forget operation, don't need to wait for response
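Note: across the proxy, Ray, and RPC executors the submit path now only assigns an id when the request arrives without one, so an id propagated from an upstream component (for example the context phase of disaggregated serving) is no longer overwritten. A before/after sketch of just that guard, with a minimal request stand-in (the helper names below are illustrative):

class _Req:                      # minimal request stand-in
    def __init__(self, id=None):
        self.id = id
    def set_id(self, id):
        self.id = id

def submit_old(request, fresh_id):
    request.set_id(fresh_id)     # old behaviour: always overwrite
    return request.id

def submit_new(request, fresh_id):
    if request.id is None:       # new behaviour: keep an upstream-assigned id
        request.set_id(fresh_id)
    return request.id

ctx_assigned = 7_000_000_001     # e.g. an id inherited from the context phase
assert submit_old(_Req(id=ctx_assigned), fresh_id=1) == 1
assert submit_new(_Req(id=ctx_assigned), fresh_id=1) == ctx_assigned
assert submit_new(_Req(), fresh_id=1) == 1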
@@ -124,7 +124,7 @@ class OpenAIDisaggregatedService(OpenAIService):
                 ctx_req, server=reserved_ctx_server, hooks=hooks
             )
             await self._verify_ctx_response(ctx_response)
-        gen_req = self._get_gen_request(request, ctx_response, request_id)
+        gen_req = self._get_gen_request(request, ctx_response)
         if ctx_response is None or self._need_gen(ctx_response):
             return await self._gen_client.send_request(
                 gen_req, server=reserved_gen_server, hooks=hooks
@@ -157,12 +157,8 @@ class OpenAIDisaggregatedService(OpenAIService):
         self,
         request: UCompletionRequest,
         ctx_response: UCompletionResponse,
-        ctx_request_id: int,
     ) -> UCompletionRequest:
         request.disaggregated_params = ctx_response.choices[0].disaggregated_params
-        assert (
-            request.disaggregated_params.ctx_request_id == ctx_request_id or ctx_request_id is None
-        )
         request.disaggregated_params.request_type = "generation_only"
         # Replace the string prompt with prompt_tokens_ids
         if isinstance(request, CompletionRequest):
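Note: with the ctx_request_id parameter and its assert removed, _get_gen_request only has to copy the context response's disaggregated_params onto the generation request; the ctx_request_id embedded there is later picked up by _get_client_id. A rough end-to-end sketch of that flow under stand-in types (none of the dataclasses below are the real service or protocol classes):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class DisaggParams:              # illustrative stand-in
    ctx_request_id: int
    request_type: str = "context_only"

@dataclass
class Choice:
    disaggregated_params: DisaggParams

@dataclass
class CtxResponse:
    choices: List[Choice]

@dataclass
class GenRequest:
    prompt: str
    disaggregated_params: Optional[DisaggParams] = None

def get_gen_request(request: GenRequest, ctx_response: CtxResponse) -> GenRequest:
    # Adopt the context phase's params wholesale; no separate ctx_request_id
    # argument or consistency assert is needed any more.
    request.disaggregated_params = ctx_response.choices[0].disaggregated_params
    request.disaggregated_params.request_type = "generation_only"
    return request

ctx_resp = CtxResponse(choices=[Choice(DisaggParams(ctx_request_id=123))])
gen_req = get_gen_request(GenRequest(prompt="hello"), ctx_resp)
assert gen_req.disaggregated_params.ctx_request_id == 123
assert gen_req.disaggregated_params.request_type == "generation_only"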