diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 374836dde8..ad35dcfebc 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -701,11 +701,8 @@ class PyExecutor:
                         if self._need_return_logits(scheduled_batch):
                             logits_host = batch_outputs["logits"].to(
                                 "cpu", non_blocking=True)
-
-                        if self.guided_decoder is not None:
-                            self.guided_decoder.build(scheduled_batch)
-                            self.guided_decoder.execute(
-                                scheduled_batch, batch_outputs['logits'])
+                        self._execute_guided_decoder(
+                            scheduled_batch, batch_outputs['logits'])
 
                         sample_state = self._sample_async(
                             scheduled_batch, batch_outputs)
@@ -844,6 +841,11 @@ class PyExecutor:
             f'{len(scheduled_batch.generation_requests)} generation requests')
         return scheduled_batch, iter_stats
 
+    def _execute_guided_decoder(self, scheduled_batch, logits):
+        if self.guided_decoder is not None:
+            self.guided_decoder.build(scheduled_batch)
+            self.guided_decoder.execute(scheduled_batch, logits)
+
     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
         with self._profiler() as profile_step:
@@ -879,11 +881,8 @@ class PyExecutor:
                         scheduled_batch, self.resource_manager)
 
                     batch_outputs = self._forward_step(scheduled_batch)
-
-                    if self.guided_decoder is not None:
-                        self.guided_decoder.build(scheduled_batch)
-                        self.guided_decoder.execute(scheduled_batch,
-                                                    batch_outputs['logits'])
+                    self._execute_guided_decoder(scheduled_batch,
+                                                 batch_outputs['logits'])
 
                     sample_state = self._sample_async(scheduled_batch,
                                                       batch_outputs)
@@ -891,11 +890,8 @@ class PyExecutor:
                     self._update_request_states(scheduled_batch)
                     self._update_requests(sample_state)
-
-                    ctx_transmission_reqs = self._send_disagg_ctx_cache(
-                        scheduled_batch.context_requests
-                    ) if self.kv_cache_transceiver else []
-
                     if self.kv_cache_transceiver:
+                        ctx_transmission_reqs = self._send_disagg_ctx_cache(
+                            scheduled_batch.context_requests)
                         # For context only req in transmission, we reset the state since sampler might have changed it
                         for req in ctx_transmission_reqs:
                             req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
@@ -997,10 +994,8 @@ class PyExecutor:
                     if self.previous_batch is not None:
                         self._update_requests(self.previous_batch.sample_state)
 
-                    if self.guided_decoder is not None:
-                        self.guided_decoder.build(scheduled_batch)
-                        self.guided_decoder.execute(scheduled_batch,
-                                                    batch_outputs['logits'])
+                    self._execute_guided_decoder(scheduled_batch,
+                                                 batch_outputs['logits'])
 
                     sample_state = self._sample_async(scheduled_batch,
                                                       batch_outputs)
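
Note: the change above deduplicates the same None-guard plus build()/execute() sequence across three executor loops into a single _execute_guided_decoder helper. For reference, a minimal standalone sketch of the extracted-helper pattern; the GuidedDecoder and Executor scaffolding here is simplified and hypothetical, and only the guard plus build()/execute() call shape is taken from the diff:

    # Minimal sketch of the extracted-helper pattern applied in this diff.
    # The classes below are simplified stand-ins, not the real PyExecutor.
    from typing import Any, Optional


    class GuidedDecoder:
        """Stand-in for the guided decoder used by PyExecutor."""

        def build(self, scheduled_batch: Any) -> None:
            # Prepare per-request constraint state for this batch.
            ...

        def execute(self, scheduled_batch: Any, logits: Any) -> None:
            # Apply the constraints to the batch logits.
            ...


    class Executor:
        def __init__(self, guided_decoder: Optional[GuidedDecoder] = None):
            self.guided_decoder = guided_decoder

        def _execute_guided_decoder(self, scheduled_batch: Any,
                                    logits: Any) -> None:
            # Single guarded entry point: each executor loop calls this
            # instead of repeating the None check and build/execute pair.
            if self.guided_decoder is not None:
                self.guided_decoder.build(scheduled_batch)
                self.guided_decoder.execute(scheduled_batch, logits)

Keeping the guard in one place means any future executor loop gets identical guided-decoding behavior by calling the helper, rather than copying the four-line block and risking drift.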