[https://nvbugspro.nvidia.com/bug/5247148][fix] Attention DP with overlap scheduler (#3975)
* fix
  Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
* update multigpu list
  Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
* fix namings
  Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
---------
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
parent 07479eabc7
commit dbcfd78ea8
@@ -517,6 +517,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
     "tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py",
     "tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py",
     "tensorrt_llm/_torch/custom_ops/userbuffers_custom_ops.py",
+    "tensorrt_llm/_torch/pyexecutor/model_engine.py",
     "tensorrt_llm/_torch/pyexecutor/py_executor.py",
     "tensorrt_llm/_torch/models/modeling_deepseekv3.py",
     "tensorrt_llm/_torch/models/modeling_llama.py",
@@ -50,7 +50,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
         self.py_draft_tokens = self.draft_tokens
         self.py_last_draft_tokens = None
         self.py_decoding_iter = 0
-        self.is_dummy = False
+        self.is_attention_dp_dummy = False
+        self.is_cuda_graph_dummy = False


 def convert_wordlist(word_list) -> List[List[int]]:
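The rename splits the former `is_dummy` flag into two: `is_cuda_graph_dummy` marks padding requests added so a batch matches a captured CUDA graph size (see the model-engine hunks below), while `is_attention_dp_dummy` marks placeholder requests that the executor creates for attention data parallelism and finishes right after a step. A minimal sketch of how downstream code can tell them apart; the helper is illustrative and not part of the commit, only the two flag names come from the diff:

# Illustrative helper, not part of the commit: only the two LlmRequest flag
# names come from the diff above. It shows the distinction the rename makes
# explicit between CUDA-graph padding requests and attention-DP dummies.
def is_placeholder_request(request) -> bool:
    """True for requests that never produce user-visible output."""
    return bool(getattr(request, "is_cuda_graph_dummy", False)
                or getattr(request, "is_attention_dp_dummy", False))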
@@ -651,6 +651,8 @@ class PyTorchModelEngine(ModelEngine):
            [max_req_id + i + 1 for i in range(padding_size)],
            is_gen=True,
            max_num_draft_tokens=max_num_draft_tokens)
+        for req in generation_requests:
+            req.is_cuda_graph_dummy = True
         scheduled_requests.generation_requests.extend(generation_requests)
         return generation_requests

@@ -1022,9 +1024,9 @@ class PyTorchModelEngine(ModelEngine):
             # (1) new_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) request.py_batch_idx is None, which means the request has no previous batch.
             # the second condition includes dummy generation requests created for CUDA graph padding.
-            # the dummy generation requests should be at the end of generation_requests.
+            # these dummy generation requests should be at the end of generation_requests.
             # skip adding their input_ids so that new_tokens_device can be aligned to the correct positions.
-            if not request.is_dummy:
+            if not request.is_cuda_graph_dummy:
                 input_ids.append(request.get_last_tokens(0))
                 past_seen_token_num = request.max_beam_num_tokens - 1
             else:
@@ -1274,7 +1274,6 @@ class PyExecutor:
                 if _is_executor_request(req_item):
                     req_id, exe_req = req_item
                     req = executor_request_to_llm_request(req_id, exe_req)
-                    req.is_dummy = False
                     self.active_requests.append(req)
                 elif _is_cancel_request(req_item):
                     self.canceled_req_ids.insert(req_item)
@@ -1290,18 +1289,18 @@ class PyExecutor:
                 if self.has_context_request else self.max_draft_tokens,
             )
             for llm_request in llm_request_list:
-                llm_request.is_dummy = True
+                llm_request.is_attention_dp_dummy = True
             self.active_requests += llm_request_list

     def _finish_dummy_request(self, scheduled_requests: ScheduledRequests):
         for req in scheduled_requests.context_requests:
-            if req.is_dummy:
+            if req.is_attention_dp_dummy:
                 req.state = LlmRequestState.GENERATION_COMPLETE
         for req in scheduled_requests.generation_requests:
-            if req.is_dummy:
+            if req.is_attention_dp_dummy:
                 req.state = LlmRequestState.GENERATION_COMPLETE
         for req in self.active_requests[:]:
-            if req.is_dummy:
+            if req.is_attention_dp_dummy:
                 self.inflight_req_ids.erase(req.request_id)
                 self._terminate_request(req)
                 self.active_requests.remove(req)
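For orientation, the hunk above covers the whole attention-DP dummy lifecycle: the request is flagged with `is_attention_dp_dummy` when created, marked `GENERATION_COMPLETE` after the step, then terminated and dropped from `active_requests`. A condensed sketch of that cleanup; the standalone function and its `executor` parameter are hypothetical, while the attribute, state, and method names are the ones in the diff:

# Hypothetical condensation of _finish_dummy_request from the hunk above.
# `executor` stands in for the PyExecutor instance; LlmRequestState is the
# state enum referenced by the diff.
def finish_attention_dp_dummies(executor, scheduled_requests):
    # Mark scheduled attention-DP dummies as finished so they do no further work.
    for req in (scheduled_requests.context_requests +
                scheduled_requests.generation_requests):
        if req.is_attention_dp_dummy:
            req.state = LlmRequestState.GENERATION_COMPLETE
    # Then drop them from the executor's bookkeeping entirely.
    for req in executor.active_requests[:]:
        if req.is_attention_dp_dummy:
            executor.inflight_req_ids.erase(req.request_id)
            executor._terminate_request(req)
            executor.active_requests.remove(req)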
@@ -1399,7 +1398,6 @@ class PyExecutor:
                     req.ctx_blocks = ctx_blocks
                     req.ctx_position_blocks = position_blocks
                     req.query_id = query_token_ids
-                    req.is_dummy = False
                     self.active_requests.append(req)
                 elif _is_cancel_request(req_item):
                     self.canceled_req_ids.insert(req_item)
@@ -1573,7 +1571,7 @@ class PyExecutor:
                 request.move_to_next_context_chunk()
             if request.get_context_remaining_length() == 0:
                 request.state = LlmRequestState.GENERATION_IN_PROGRESS
-            if request.is_dummy:
+            if request.is_attention_dp_dummy:
                 request.state = LlmRequestState.GENERATION_COMPLETE

     def _update_request_states_star_attention(
@@ -1719,7 +1717,6 @@ class PyExecutor:
                 draft_batch.context_requests.append(new_request)

                 new_request.py_stop_words_list = request.py_stop_words_list
-                new_request.is_dummy = False

         return draft_batch, req_id_to_num_rejected_tokens

@@ -1911,7 +1908,7 @@ class PyExecutor:
         for request in self.active_requests:
             req_id = request.py_request_id
             # no responses for dummy request, and finish it
-            if request.is_dummy:
+            if request.is_attention_dp_dummy:
                 requests_to_terminate.append(request)
                 continue

@@ -323,7 +323,6 @@ class KVCacheManager(BaseResourceManager):
             req.py_prompt_len = req.prompt_len
             if max_num_draft_tokens > 0:
                 req.py_draft_tokens = [0] * max_num_draft_tokens
-            req.is_dummy = True
             requests.append(req)
         return requests

@@ -348,6 +348,24 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.skip_device_not_contain(["H100"])
+    def test_fp8_block_scales_cuda_graph_padding(self):
+        # OOM on H100 with default free_gpu_memory_fraction=0.9
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
+        pytorch_config = PyTorchConfig(enable_overlap_scheduler=True,
+                                       use_cuda_graph=True,
+                                       cuda_graph_max_batch_size=512,
+                                       cuda_graph_padding_enabled=True)
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
+                  kv_cache_config=kv_cache_config,
+                  pytorch_backend_config=pytorch_config)
+        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+        with llm:
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(4)
     @pytest.mark.skip_device_not_contain(["H100"])
     @parametrize_with_ids("attention_dp,cuda_graph,overlap_scheduler",
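Beyond the accuracy harness, the same backend options can be exercised directly through the LLM API. A minimal standalone sketch with the options from the new test; the import paths, model path, and prompt are assumptions for illustration and do not come from the diff:

# Sketch only: the backend options mirror the new test above; import paths,
# the model path, and the prompt are assumptions for illustration.
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig

pytorch_config = PyTorchConfig(enable_overlap_scheduler=True,
                               use_cuda_graph=True,
                               cuda_graph_max_batch_size=512,
                               cuda_graph_padding_enabled=True)
# Lower the KV-cache fraction to avoid OOM, as the test comment notes.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

llm = LLM("DeepSeek-V3-Lite/fp8",  # placeholder model path
          kv_cache_config=kv_cache_config,
          pytorch_backend_config=pytorch_config)
with llm:
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)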
@@ -42,6 +42,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-cuda_graph]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-overlap_scheduler]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu[DeepSeek-V3-Lite-fp8]
@@ -465,7 +465,3 @@ full:B200/perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float1
 accuracy/test_cli_flow.py::TestGpt2Medium::test_fp8_lm_head SKIP (https://nvbugs/5231769)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5236456)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5236456)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5247148)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5247148)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5247148)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5247148)