[https://nvbugs/5277592][fix] fix cuda graph padding for spec decoding (only for 0.20) (#5058)

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Fanrong Li 2025-06-11 02:14:14 +08:00 committed by GitHub
parent b626186241
commit bfa3b59bb6
5 changed files with 59 additions and 7 deletions


@@ -1114,13 +1114,18 @@ class PyTorchModelEngine(ModelEngine):
             new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
             next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]
-        # Requests with draft tokens are treated like extend requests.
+        # Requests with draft tokens are treated like extend requests. CUDA graph dummy extend
+        # requests should be at the end of extend_requests.
         extend_requests = []
+        extend_cuda_graph_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                extend_requests.append(request)
+                if request.is_cuda_graph_dummy:
+                    extend_cuda_graph_dummy_requests.append(request)
+                else:
+                    extend_requests.append(request)
             else:
                 generation_requests.append(request)
@@ -1130,6 +1135,7 @@ class PyTorchModelEngine(ModelEngine):
                     torch.tensor([mrope_position_deltas],
                                  dtype=torch.int32).to('cuda',
                                                        non_blocking=True))
+        extend_requests = extend_cuda_graph_dummy_requests + extend_requests
         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
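
The two hunks above collect CUDA graph dummy extend requests in a separate list and then merge them back at a fixed position relative to the real extend requests, so the dummy slots form one contiguous block whose results the sampler can skip. A minimal standalone sketch of that pattern; SimpleRequest, partition_generation_requests, and have_device_draft_tokens are illustrative names for this sketch, not the TensorRT-LLM types:

from dataclasses import dataclass, field
from typing import List


# Illustrative stand-in for a generation request; not the real request class.
@dataclass
class SimpleRequest:
    request_id: int
    py_draft_tokens: List[int] = field(default_factory=list)
    is_cuda_graph_dummy: bool = False


def partition_generation_requests(generation_batch, have_device_draft_tokens):
    """Split a padded generation batch the way the hunks above do: dummy extend
    requests are collected separately and then merged back with the real extend
    requests, so their results stay contiguous."""
    extend_requests = []
    extend_cuda_graph_dummy_requests = []
    generation_requests = []
    for request in generation_batch:
        if len(request.py_draft_tokens) > 0 or have_device_draft_tokens:
            if request.is_cuda_graph_dummy:
                extend_cuda_graph_dummy_requests.append(request)
            else:
                extend_requests.append(request)
        else:
            generation_requests.append(request)
    # Mirrors the assignment added in the second hunk.
    extend_requests = extend_cuda_graph_dummy_requests + extend_requests
    return extend_requests, generation_requests


# Example: 2 real spec-decoding requests padded with 2 CUDA graph dummies.
batch = [
    SimpleRequest(0, is_cuda_graph_dummy=True),
    SimpleRequest(1, is_cuda_graph_dummy=True),
    SimpleRequest(2, py_draft_tokens=[11, 12]),
    SimpleRequest(3, py_draft_tokens=[21, 22]),
]
extend, plain = partition_generation_requests(batch, have_device_draft_tokens=True)
print([r.request_id for r in extend])  # [0, 1, 2, 3] -> dummy slots form a contiguous block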


@@ -263,6 +263,12 @@ class MTPSampler(TorchSampler):
                 request.py_decoding_iter += 1
                 idx += 1
 
+        # skip the results of cuda graph dummy requests
+        if idx == 0:
+            num_cuda_graph_dummy_requests = len(new_tokens_list) - len(
+                state.scheduled_requests.generation_requests)
+            idx += num_cuda_graph_dummy_requests
         for request in state.scheduled_requests.generation_requests:
             assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
             assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
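
On the sampler side, the fix computes how many entries of new_tokens_list belong to dummy requests (the difference between the padded result count and the number of real generation requests) and advances the read index past them before assigning tokens. A sketch with toy data; assign_new_tokens is a hypothetical helper for illustration, not the MTPSampler API:

def assign_new_tokens(new_tokens_list, generation_requests, idx=0):
    """Map sampled tokens back to real requests, skipping CUDA graph dummy slots.

    new_tokens_list has one entry per slot in the padded batch. When no context
    requests were processed first (idx == 0), the leading slots belong to CUDA
    graph dummy requests and carry no real output, so they are skipped.
    """
    results = {}
    if idx == 0:
        num_cuda_graph_dummy_requests = len(new_tokens_list) - len(generation_requests)
        idx += num_cuda_graph_dummy_requests
    for request_id in generation_requests:
        results[request_id] = new_tokens_list[idx]
        idx += 1
    return results


# Padded batch has 4 slots but only 2 real generation requests -> skip the 2 dummy slots.
print(assign_new_tokens([0, 0, 1234, 5678], ["req-a", "req-b"]))
# {'req-a': 1234, 'req-b': 5678}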


@@ -568,22 +568,59 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
             task.evaluate(llm)
     @pytest.mark.skip_device_not_contain(["H100"])
-    def test_fp8_block_scales_cuda_graph_padding(self):
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         # OOM on H100 with default free_gpu_memory_fraction=0.9
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
         pytorch_config = PyTorchConfig(disable_overlap_scheduler=False,
                                        use_cuda_graph=True,
                                        cuda_graph_max_batch_size=512,
                                        cuda_graph_padding_enabled=True)
         llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                   kv_cache_config=kv_cache_config,
-                  pytorch_backend_config=pytorch_config)
+                  pytorch_backend_config=pytorch_config,
+                  speculative_config=mtp_config)
         assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
         with llm:
             task = CnnDailymail(self.MODEL_NAME)
             task.evaluate(llm)
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
+
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids("attention_dp", [False, True])
+    def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
+                                                       attention_dp):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        pytorch_config = PyTorchConfig(
+            disable_overlap_scheduler=False,
+            use_cuda_graph=True,
+            cuda_graph_padding_enabled=True,
+        )
+        quant_config = QuantConfig()
+        quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
+                  tensor_parallel_size=4,
+                  kv_cache_config=kv_cache_config,
+                  pytorch_backend_config=pytorch_config,
+                  quant_config=quant_config,
+                  enable_attention_dp=attention_dp,
+                  speculative_config=mtp_config)
+        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(4)
     @pytest.mark.skip_device_not_contain(["H100", "H200"])
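
The new tests enable cuda_graph_padding_enabled together with MTP speculative decoding, which is the combination the fix targets. The padding idea the fix relies on: CUDA graphs are captured only for a fixed set of batch sizes, so a smaller decode batch is padded with dummy requests up to the nearest captured size and the dummy outputs are discarded. A toy, standalone illustration of that rounding; the captured sizes and both helpers below are assumptions for the sketch, not TensorRT-LLM APIs:

import bisect

# Assumed set of batch sizes with captured CUDA graphs (illustrative values only).
CAPTURED_BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]


def padded_batch_size(num_real_requests, captured_sizes=CAPTURED_BATCH_SIZES):
    """Round a decode batch up to the nearest captured CUDA graph batch size."""
    i = bisect.bisect_left(captured_sizes, num_real_requests)
    if i == len(captured_sizes):
        return None  # batch too large for any captured graph: run without padding
    return captured_sizes[i]


def pad_batch(real_requests):
    """Return (dummy_slots, real_requests); dummy outputs are discarded downstream."""
    target = padded_batch_size(len(real_requests))
    if target is None:
        return [], real_requests
    num_dummies = target - len(real_requests)
    return ["<cuda_graph_dummy>"] * num_dummies, real_requests


dummies, real = pad_batch(["req-a", "req-b", "req-c"])
print(len(dummies), len(real))  # 1 3 -> padded up to the captured size 4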


@@ -101,6 +101,8 @@ l0_dgx_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=0]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=2]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8]


@@ -51,7 +51,8 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
-  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu[DeepSeek-V3-Lite-fp8]