Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[nvbugs/5345391] fix: chunked prefill + overlap scheduling (#5761)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
parent 2e21e3421f
commit fd94d3cbf5
@@ -1083,11 +1083,8 @@ class PyExecutor:
                 scheduled_batch.context_requests
             ) if self.kv_cache_transceiver else []
 
-            has_previous_batch = self.previous_batch is not None
-            if has_previous_batch:
-                previous_batch_size = self.previous_batch.sample_state.scheduled_requests.batch_size
-                if previous_batch_size > 0:  # first previous batch size is 0
-                    self._process_previous_batch()
+            if self.previous_batch is not None:
+                self._process_previous_batch()
                 self.previous_batch: Optional[BatchState] = None
 
             # Separate chunked requests so we can handle them in _update_requests w/o relying on the request state.
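The executor change above drops the size-based guard: the previous batch is now processed and cleared whenever one exists, because with chunked prefill an empty previous batch is a legal state that still needs to be consumed. Below is a minimal runnable sketch of this overlap-scheduling pattern; Executor, BatchState, and step are illustrative stand-ins, not the real PyExecutor API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BatchState:
    batch_size: int  # stand-in for sample_state.scheduled_requests.batch_size


class Executor:
    def __init__(self) -> None:
        self.previous_batch: Optional[BatchState] = None

    def _process_previous_batch(self) -> None:
        # Finalize host-side bookkeeping for the batch launched last iteration.
        print(f"finalized previous batch (size={self.previous_batch.batch_size})")

    def step(self, new_batch: BatchState) -> None:
        # (launch new_batch on the device asynchronously: elided)
        # The fix: handle the previous batch whenever one exists instead of
        # skipping it when its recorded size is 0; with chunked prefill an
        # empty previous batch is legal and must still be cleared.
        if self.previous_batch is not None:
            self._process_previous_batch()
            self.previous_batch = None
        self.previous_batch = new_batch


executor = Executor()
for size in (0, 2, 3):  # the first previous batch can have batch_size 0
    executor.step(BatchState(batch_size=size))
```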
@@ -690,7 +690,9 @@ class TRTLLMSampler(Sampler):
         assert isinstance(state, SampleStateTRTLLM)
 
         scheduled_requests = state.scheduled_requests
-        assert scheduled_requests.batch_size > 0
+        if scheduled_requests.batch_size == 0:
+            return
+
         beam_width = self.beam_width(scheduled_requests.all_requests)
         sampler_event = state.sampler_event
 
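The matching sampler change: update_requests must now tolerate an empty batch handed back from the overlap path, so the hard assert becomes an early return. A minimal sketch under the same caveat (ScheduledRequests and update_requests are illustrative stand-ins, not the real TRTLLMSampler API):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ScheduledRequests:
    all_requests: List[str] = field(default_factory=list)

    @property
    def batch_size(self) -> int:
        return len(self.all_requests)


def update_requests(scheduled_requests: ScheduledRequests) -> None:
    # Early return instead of `assert scheduled_requests.batch_size > 0`:
    # an empty previous batch is a legal state, not a programming error.
    if scheduled_requests.batch_size == 0:
        return
    for request in scheduled_requests.all_requests:
        pass  # append sampled tokens, advance request state, etc. (elided)


update_requests(ScheduledRequests())           # no-op instead of AssertionError
update_requests(ScheduledRequests(["req_a"]))  # normal update path
```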
@@ -61,14 +61,10 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device_memory(32000)
     @parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
     def test_chunked_prefill(self, attn_backend):
-        pytorch_config = dict(
-            attn_backend=attn_backend,
-            # https://nvbugspro.nvidia.com/bug/5345391
-            disable_overlap_scheduler=True)
         llm = LLM(self.MODEL_PATH,
+                  attn_backend=attn_backend,
                   enable_chunked_prefill=True,
-                  max_num_tokens=512,
-                  **pytorch_config)
+                  max_num_tokens=512)
         with llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
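With both fixes in place, the accuracy test drops the workaround that pinned disable_overlap_scheduler=True, so chunked prefill is now exercised with overlap scheduling at its default. The user-facing effect, sketched with the standard tensorrt_llm.LLM entry point and a placeholder model path:

```python
from tensorrt_llm import LLM  # public LLM API entry point

# Previously, chunked prefill required disable_overlap_scheduler=True as a
# workaround (the removed pytorch_config above). With the fix, the default
# overlap scheduling can stay on:
llm = LLM("path/to/Llama-3.1-8B-Instruct",  # placeholder model path
          enable_chunked_prefill=True,
          max_num_tokens=512)  # caps tokens per step; long prefills are chunked
```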