[nvbugs/5345391] fix: chunked prefill + overlap scheduling (#5761)

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Robin Kobus authored 2025-07-09 17:59:45 +02:00 (committed by GitHub)
commit fd94d3cbf5, parent 2e21e3421f
3 changed files with 7 additions and 12 deletions


@@ -1083,11 +1083,8 @@ class PyExecutor:
                 scheduled_batch.context_requests
             ) if self.kv_cache_transceiver else []
-            has_previous_batch = self.previous_batch is not None
-            if has_previous_batch:
-                previous_batch_size = self.previous_batch.sample_state.scheduled_requests.batch_size
-                if previous_batch_size > 0:  # first previous batch size is 0
-                    self._process_previous_batch()
+            if self.previous_batch is not None:
+                self._process_previous_batch()
             self.previous_batch: Optional[BatchState] = None
             # Separate chunked requests so we can handle them in _update_requests w/o relying on the request state.
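The executor hunk above can drop its batch-size special case because the sampler (next hunk) now tolerates an empty batch itself. Below is a minimal, self-contained sketch of the overlap-scheduling pattern being simplified; all names (BatchState, process_previous) are hypothetical stand-ins, not the executor's actual code. Iteration N processes the sampled results of iteration N-1 while N's work is in flight, and the first iteration has no previous batch.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class BatchState:           # stand-in for the executor's batch state
        batch_size: int         # 0 when the previous step sampled nothing

    def process_previous(state: BatchState) -> None:
        if state.batch_size == 0:   # mirrors the sampler's new early return
            return                  # nothing was sampled, nothing to update
        print(f"updating {state.batch_size} requests")

    previous_batch: Optional[BatchState] = None
    for size in [0, 2, 3]:          # batch sizes of three successive iterations
        # launch the current iteration's forward pass (elided), then overlap:
        if previous_batch is not None:   # the only check the executor needs now
            process_previous(previous_batch)
        previous_batch = BatchState(batch_size=size)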


@@ -690,7 +690,9 @@ class TRTLLMSampler(Sampler):
         assert isinstance(state, SampleStateTRTLLM)
         scheduled_requests = state.scheduled_requests
-        assert scheduled_requests.batch_size > 0
+        if scheduled_requests.batch_size == 0:
+            return
         beam_width = self.beam_width(scheduled_requests.all_requests)
         sampler_event = state.sampler_event
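This hunk turns the empty batch from an invariant violation into an expected no-op: per the comment removed in the executor hunk, the first previous batch under overlap scheduling has size 0, so the sampler's update path now returns early instead of asserting.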


@@ -61,14 +61,10 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device_memory(32000)
     @parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
     def test_chunked_prefill(self, attn_backend):
-        pytorch_config = dict(
-            attn_backend=attn_backend,
-            # https://nvbugspro.nvidia.com/bug/5345391
-            disable_overlap_scheduler=True)
         llm = LLM(self.MODEL_PATH,
+                  attn_backend=attn_backend,
                   enable_chunked_prefill=True,
-                  max_num_tokens=512,
-                  **pytorch_config)
+                  max_num_tokens=512)
         with llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
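With the fix, the test exercises chunked prefill under the overlap scheduler's default (enabled) setting rather than disabling it as a workaround. A usage sketch under stated assumptions: the `tensorrt_llm.LLM` entry point is the LLM API the test harness wraps, and the model path is an illustrative placeholder, not taken from this diff.

    from tensorrt_llm import LLM  # assumed LLM API entry point

    # Chunked prefill with the overlap scheduler left at its default;
    # disable_overlap_scheduler=True is no longer needed as a workaround.
    llm = LLM("meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model path
              enable_chunked_prefill=True,
              max_num_tokens=512)  # caps tokens per iteration, bounding chunk size
    with llm:
        output = llm.generate("The capital of France is")
        print(output.outputs[0].text)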