Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
chore: clean code of PyExecutor (#6445)
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
parent d6eed1b624
commit 1f39a11af0
@@ -701,11 +701,8 @@ class PyExecutor:
             if self._need_return_logits(scheduled_batch):
                 logits_host = batch_outputs["logits"].to(
                     "cpu", non_blocking=True)

-            if self.guided_decoder is not None:
-                self.guided_decoder.build(scheduled_batch)
-                self.guided_decoder.execute(
-                    scheduled_batch, batch_outputs['logits'])
+            self._execute_guided_decoder(
+                scheduled_batch, batch_outputs['logits'])

             sample_state = self._sample_async(
                 scheduled_batch, batch_outputs)
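Aside: the context line above copies the logits to the host with non_blocking=True. A minimal sketch (toy setup, not part of this commit) of why such a copy is only safe to read after the copy stream has been synchronized:

import torch

if torch.cuda.is_available():
    logits = torch.randn(4, 8, device="cuda")
    # Enqueue an asynchronous device-to-host copy on the current stream.
    logits_host = logits.to("cpu", non_blocking=True)
    # Reading logits_host before the stream drains would race the copy.
    torch.cuda.current_stream().synchronize()
    print(logits_host.sum())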
@@ -844,6 +841,11 @@ class PyExecutor:
                 f'{len(scheduled_batch.generation_requests)} generation requests')
         return scheduled_batch, iter_stats

+    def _execute_guided_decoder(self, scheduled_batch, logits):
+        if self.guided_decoder is not None:
+            self.guided_decoder.build(scheduled_batch)
+            self.guided_decoder.execute(scheduled_batch, logits)
+
     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
         with self._profiler() as profile_step:
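The new helper above centralizes the None-check plus build/execute sequence that the other hunks had been repeating inline. A minimal self-contained sketch of the same extract-method pattern, using toy stand-ins rather than the real TensorRT-LLM classes:

class ToyGuidedDecoder:
    def build(self, scheduled_batch):
        print(f"build grammar matchers for {scheduled_batch}")

    def execute(self, scheduled_batch, logits):
        print(f"apply token masks to {logits}")

class ToyExecutor:
    def __init__(self, guided_decoder=None):
        self.guided_decoder = guided_decoder

    def _execute_guided_decoder(self, scheduled_batch, logits):
        # One shared guard: a missing decoder makes the call a no-op,
        # so call sites no longer repeat the None check.
        if self.guided_decoder is not None:
            self.guided_decoder.build(scheduled_batch)
            self.guided_decoder.execute(scheduled_batch, logits)

ToyExecutor(ToyGuidedDecoder())._execute_guided_decoder("batch0", "logits0")
ToyExecutor()._execute_guided_decoder("batch0", "logits0")  # no-op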
@@ -879,11 +881,8 @@ class PyExecutor:
                     scheduled_batch, self.resource_manager)

                 batch_outputs = self._forward_step(scheduled_batch)

-                if self.guided_decoder is not None:
-                    self.guided_decoder.build(scheduled_batch)
-                    self.guided_decoder.execute(scheduled_batch,
-                                                batch_outputs['logits'])
+                self._execute_guided_decoder(scheduled_batch,
+                                             batch_outputs['logits'])

                 sample_state = self._sample_async(scheduled_batch,
                                                   batch_outputs)
@@ -891,11 +890,9 @@ class PyExecutor:
                 self._update_request_states(scheduled_batch)
                 self._update_requests(sample_state)

-                ctx_transmission_reqs = self._send_disagg_ctx_cache(
-                    scheduled_batch.context_requests
-                ) if self.kv_cache_transceiver else []
-
+                if self.kv_cache_transceiver:
+                    ctx_transmission_reqs = self._send_disagg_ctx_cache(
+                        scheduled_batch.context_requests)
                     # For context only req in transmission, we reset the state since sampler might have changed it
                     for req in ctx_transmission_reqs:
                         req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
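This hunk swaps a conditional expression for an explicit if block, so the disagg-cache send and the state-reset loop run only when a KV-cache transceiver is configured; previously the loop simply iterated over an empty list. A small sketch with hypothetical stand-in names (the real code may establish the empty default elsewhere):

def before(send, transceiver, context_requests):
    ctx_transmission_reqs = send(context_requests) if transceiver else []
    for req in ctx_transmission_reqs:  # zero iterations when the list is []
        req["state"] = "DISAGG_CONTEXT_TRANS_IN_PROGRESS"
    return ctx_transmission_reqs

def after(send, transceiver, context_requests):
    ctx_transmission_reqs = []  # assumed default; not shown in this hunk
    if transceiver:
        ctx_transmission_reqs = send(context_requests)
        for req in ctx_transmission_reqs:
            req["state"] = "DISAGG_CONTEXT_TRANS_IN_PROGRESS"
    return ctx_transmission_reqs

reqs = [{"state": "GENERATION_COMPLETE"}]
assert before(lambda r: r, True, reqs) == after(lambda r: r, True, reqs)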
@@ -997,10 +994,8 @@ class PyExecutor:
             if self.previous_batch is not None:
                 self._update_requests(self.previous_batch.sample_state)

-            if self.guided_decoder is not None:
-                self.guided_decoder.build(scheduled_batch)
-                self.guided_decoder.execute(scheduled_batch,
-                                            batch_outputs['logits'])
+            self._execute_guided_decoder(scheduled_batch,
+                                         batch_outputs['logits'])

             sample_state = self._sample_async(scheduled_batch,
                                               batch_outputs)