chore: clean code of PyExecutor (#6445)

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
This commit is contained in:
QI JUN 2025-07-30 14:11:43 +08:00 committed by GitHub
parent d6eed1b624
commit 1f39a11af0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -701,11 +701,8 @@ class PyExecutor:
if self._need_return_logits(scheduled_batch):
logits_host = batch_outputs["logits"].to(
"cpu", non_blocking=True)
if self.guided_decoder is not None:
self.guided_decoder.build(scheduled_batch)
self.guided_decoder.execute(
scheduled_batch, batch_outputs['logits'])
self._execute_guided_decoder(
scheduled_batch, batch_outputs['logits'])
sample_state = self._sample_async(
scheduled_batch, batch_outputs)
@ -844,6 +841,11 @@ class PyExecutor:
f'{len(scheduled_batch.generation_requests)} generation requests')
return scheduled_batch, iter_stats
def _execute_guided_decoder(self, scheduled_batch, logits):
    """Apply guided (constrained) decoding to the batch's logits, if enabled.

    When ``self.guided_decoder`` is configured, builds the decoder state for
    the scheduled batch and then executes it against ``logits`` (modifying
    them in place to enforce the guidance constraints — presumably; confirm
    against the GuidedDecoder implementation). No-op when no guided decoder
    is configured.

    Args:
        scheduled_batch: The batch of requests scheduled for this step.
        logits: Model output logits for the batch (e.g. ``batch_outputs['logits']``).
    """
    # Extracted helper so the three executor-loop call sites share one
    # guard-and-run path instead of repeating the None check inline.
    if self.guided_decoder is not None:
        self.guided_decoder.build(scheduled_batch)
        self.guided_decoder.execute(scheduled_batch, logits)
def _executor_loop(self):
torch.cuda.set_device(self.device_id)
with self._profiler() as profile_step:
@ -879,11 +881,8 @@ class PyExecutor:
scheduled_batch, self.resource_manager)
batch_outputs = self._forward_step(scheduled_batch)
if self.guided_decoder is not None:
self.guided_decoder.build(scheduled_batch)
self.guided_decoder.execute(scheduled_batch,
batch_outputs['logits'])
self._execute_guided_decoder(scheduled_batch,
batch_outputs['logits'])
sample_state = self._sample_async(scheduled_batch,
batch_outputs)
@ -891,11 +890,9 @@ class PyExecutor:
self._update_request_states(scheduled_batch)
self._update_requests(sample_state)
ctx_transmission_reqs = self._send_disagg_ctx_cache(
scheduled_batch.context_requests
) if self.kv_cache_transceiver else []
if self.kv_cache_transceiver:
ctx_transmission_reqs = self._send_disagg_ctx_cache(
scheduled_batch.context_requests)
# For context only req in transmission, we reset the state since sampler might have changed it
for req in ctx_transmission_reqs:
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
@ -997,10 +994,8 @@ class PyExecutor:
if self.previous_batch is not None:
self._update_requests(self.previous_batch.sample_state)
if self.guided_decoder is not None:
self.guided_decoder.build(scheduled_batch)
self.guided_decoder.execute(scheduled_batch,
batch_outputs['logits'])
self._execute_guided_decoder(scheduled_batch,
batch_outputs['logits'])
sample_state = self._sample_async(scheduled_batch,
batch_outputs)