Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-05 02:31:33 +08:00
[https://nvbugs/5670108][fix] Fix overlap scheduler race condition in… (#10610)
Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
This commit is contained in:
parent c6163e2b70
commit 3c8ed19440
@@ -606,11 +606,19 @@ class KVCacheManager(BaseResourceManager):
         self.update_kv_cache_draft_token_location(scheduled_batch,
                                                   attn_metadata,
                                                   kv_cache_dtype_byte_size)
-        # rewind kv cache
+
+        # Rewind KV cache for requests with rejected draft tokens.
+        # Skip:
+        # - GENERATION_COMPLETE: finished requests.
+        # - CONTEXT_INIT: requests whose state was reset after being paused with their KV cache freed.
+        #   With the overlap scheduler, the scheduler pauses a request and frees its KV cache at iteration N
+        #   while the previous batch (N-1) is still trying to update the KV cache after the forward pass.
         for request in scheduled_batch.generation_requests:
-            if request.state != LlmRequestState.GENERATION_COMPLETE:
-                if request.py_rewind_len > 0:
-                    self.rewind_kv_cache(request, request.py_rewind_len)
+            if request.state in (LlmRequestState.GENERATION_COMPLETE,
+                                 LlmRequestState.CONTEXT_INIT):
+                continue
+            if request.py_rewind_len > 0:
+                self.rewind_kv_cache(request, request.py_rewind_len)

         # For context requests, we store the blocks for reuse.
         for request in scheduled_batch.context_requests:
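The new guard exists because of the race described in the comments: with the overlap scheduler, a request can be paused and have its KV cache freed at iteration N while iteration N-1 is still finishing its post-forward KV-cache update, so a CONTEXT_INIT request must not be rewound. A minimal, self-contained sketch of that guard follows; it is illustrative only, and the enum members and field names are stand-ins that mirror the diff rather than the repository's actual definitions.

    # Minimal sketch of the rewind guard; the types below are stand-ins, not repo code.
    from enum import Enum, auto

    class LlmRequestState(Enum):
        CONTEXT_INIT = auto()            # paused and reset; KV cache already freed
        GENERATION_IN_PROGRESS = auto()
        GENERATION_COMPLETE = auto()     # finished; nothing left to rewind

    def should_rewind(state: LlmRequestState, rewind_len: int) -> bool:
        # Rewinding a request whose KV cache was freed at iteration N (pause) while
        # batch N-1 is still finishing would touch released blocks, hence the skip.
        if state in (LlmRequestState.GENERATION_COMPLETE, LlmRequestState.CONTEXT_INIT):
            return False
        return rewind_len > 0

    assert should_rewind(LlmRequestState.GENERATION_IN_PROGRESS, 2)
    assert not should_rewind(LlmRequestState.CONTEXT_INIT, 2)  # the race the fix avoids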
@@ -459,7 +459,18 @@ class BenchRunner:
                  extra_llm_api_options: Optional[str] = None,
                  use_mpirun: bool = False,
                  concurrency: Optional[int] = None,
-                 num_requests: int = 10):
+                 num_requests: int = 10,
+                 ep_size: Optional[int] = None,
+                 max_batch_size: Optional[int] = None,
+                 max_num_tokens: Optional[int] = None,
+                 warmup: Optional[int] = None,
+                 eos_id: Optional[int] = None,
+                 kv_cache_free_gpu_mem_fraction: Optional[float] = None,
+                 scheduler_policy: Optional[str] = None,
+                 input_mean: int = 128,
+                 output_mean: int = 128,
+                 input_stdev: int = 0,
+                 output_stdev: int = 0):

         llm_models = llm_models_root()
         assert llm_models is not None
@@ -486,6 +497,17 @@ class BenchRunner:
         self.engine_path = None
         self.concurrency = concurrency
         self.num_requests = num_requests
+        self.ep_size = ep_size
+        self.max_batch_size = max_batch_size
+        self.max_num_tokens = max_num_tokens
+        self.warmup = warmup
+        self.eos_id = eos_id
+        self.kv_cache_free_gpu_mem_fraction = kv_cache_free_gpu_mem_fraction
+        self.scheduler_policy = scheduler_policy
+        self.input_mean = input_mean
+        self.output_mean = output_mean
+        self.input_stdev = input_stdev
+        self.output_stdev = output_stdev

     def __call__(self):
         self.prepare_dataset()
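All of the new constructor arguments default to None (or keep the previous literal defaults), so existing callers are unaffected. Taken together with the hunks below, each new knob is simply threaded through to one CLI flag. As a rough reference assembled from this diff (a sketch, not code that exists in the repository), the mapping is:

    # Rough mapping of the new BenchRunner knobs to the flags they feed
    # (see the prepare_dataset and benchmark-command hunks below); sketch only.
    NEW_KNOB_TO_FLAG = {
        "ep_size": "--ep",
        "max_batch_size": "--max_batch_size",
        "max_num_tokens": "--max_num_tokens",
        "warmup": "--warmup",
        "eos_id": "--eos_id",
        "kv_cache_free_gpu_mem_fraction": "--kv_cache_free_gpu_mem_fraction",
        "scheduler_policy": "--scheduler_policy",
        "input_mean": "--input-mean",      # dataset generation
        "output_mean": "--output-mean",    # dataset generation
        "input_stdev": "--input-stdev",    # dataset generation
        "output_stdev": "--output-stdev",  # dataset generation
    }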
@@ -505,17 +527,18 @@ class BenchRunner:
             f"{self.dataset_path}",
             "token-norm-dist",
             "--input-mean",
-            "128",
+            str(self.input_mean),
             "--output-mean",
-            "128",
+            str(self.output_mean),
             "--input-stdev",
-            "0",
+            str(self.input_stdev),
             "--output-stdev",
-            "0",
+            str(self.output_stdev),
             "--num-requests",
             str(self.num_requests),
         ]
+        print(f"Running command: {' '.join(command)}")
         check_call(" ".join(command), shell=True, env=self.llm_venv._new_env)

     def build_engine(self):
         if self.skip_engine_build:
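With the hard-coded "128"/"0" literals replaced by str(self.input_mean) and friends, the synthetic dataset shape becomes configurable per test. For the DeepSeek-R1 MTP test added further down (input_mean=1000, output_mean=10000, zero stdev, 512 requests), the tail of the generated prepare-dataset command would look roughly like the sketch below; the command prefix and any model/tokenizer arguments are omitted because they are not part of this hunk.

    # Roughly the token-norm-dist arguments produced for the new test (sketch only).
    command_tail = [
        "token-norm-dist",
        "--input-mean", "1000",
        "--output-mean", "10000",
        "--input-stdev", "0",
        "--output-stdev", "0",
        "--num-requests", "512",
    ]
    print(" ".join(command_tail))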
@@ -559,11 +582,25 @@ class BenchRunner:
             benchmark_cmd += " --backend tensorrt"

         if self.extra_llm_api_options:
-            benchmark_cmd += f" --config {self.extra_llm_api_options}"
+            benchmark_cmd += f" --extra_llm_api_options {self.extra_llm_api_options}"
         if self.concurrency:
             benchmark_cmd += f" --concurrency {self.concurrency}"
         if self.num_requests:
             benchmark_cmd += f" --num_requests {self.num_requests}"
+        if self.ep_size is not None:
+            benchmark_cmd += f" --ep {self.ep_size}"
+        if self.max_batch_size is not None:
+            benchmark_cmd += f" --max_batch_size {self.max_batch_size}"
+        if self.max_num_tokens is not None:
+            benchmark_cmd += f" --max_num_tokens {self.max_num_tokens}"
+        if self.warmup is not None:
+            benchmark_cmd += f" --warmup {self.warmup}"
+        if self.eos_id is not None:
+            benchmark_cmd += f" --eos_id {self.eos_id}"
+        if self.kv_cache_free_gpu_mem_fraction is not None:
+            benchmark_cmd += f" --kv_cache_free_gpu_mem_fraction {self.kv_cache_free_gpu_mem_fraction}"
+        if self.scheduler_policy is not None:
+            benchmark_cmd += f" --scheduler_policy {self.scheduler_policy}"

         benchmark_output = check_output(benchmark_cmd,
                                         shell=True,
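Each new option is appended only when it is not None, so runs that do not set it keep the previous command line. The repeated if-not-None blocks could equivalently be written as a small data-driven helper; the sketch below shows that formulation (not the code in this change) with the values used by the new test, and the base command string is a placeholder rather than the actual benchmark invocation.

    from typing import Optional

    def append_optional_flags(cmd: str,
                              ep_size: Optional[int] = None,
                              max_batch_size: Optional[int] = None,
                              max_num_tokens: Optional[int] = None,
                              warmup: Optional[int] = None,
                              eos_id: Optional[int] = None,
                              kv_cache_free_gpu_mem_fraction: Optional[float] = None,
                              scheduler_policy: Optional[str] = None) -> str:
        # Append a flag only when its value is set, mirroring the is-not-None checks above.
        for flag, value in [
            ("--ep", ep_size),
            ("--max_batch_size", max_batch_size),
            ("--max_num_tokens", max_num_tokens),
            ("--warmup", warmup),
            ("--eos_id", eos_id),
            ("--kv_cache_free_gpu_mem_fraction", kv_cache_free_gpu_mem_fraction),
            ("--scheduler_policy", scheduler_policy),
        ]:
            if value is not None:
                cmd += f" {flag} {value}"
        return cmd

    # Values from the new test below; "<benchmark command>" is a placeholder prefix.
    print(append_optional_flags("<benchmark command>", ep_size=8, max_batch_size=512,
                                max_num_tokens=4608, warmup=0, eos_id=1,
                                kv_cache_free_gpu_mem_fraction=0.1,
                                scheduler_policy="max_utilization"))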
@@ -2417,6 +2454,92 @@ def test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus(
         _check_mem_usage(running_log, [50.0, 0, 0, 0], 8)


+@skip_pre_blackwell
+@pytest.mark.skip_less_device_memory(140000)
+@pytest.mark.skip_less_device(8)
+def test_deepseek_r1_mtp_bench(llm_root, llm_venv):
+    """
+    Test DeepSeek-R1 FP4 with MTP speculative decoding using BenchRunner.
+    The goal is to test the bug fix for https://nvbugs/5670108.
+    Average input sequence length: 1k, average output sequence length: 10k.
+    """
+    model_name = "nvidia/DeepSeek-R1-FP4"
+    model_path = "DeepSeek-R1/DeepSeek-R1-FP4"
+    print(f"Testing {model_name} with MTP speculative decoding.")
+
+    # Create extra_llm_api_options YAML with MTP config
+    extra_config = {
+        "print_iter_log": True,
+        "enable_layerwise_nvtx_marker": False,
+        "disable_overlap_scheduler": False,
+        "enable_iter_perf_stats": True,
+        "enable_chunked_prefill": False,
+        "stream_interval": 20,
+        "scheduler_config": {
+            "capacity_scheduler_policy": "MAX_UTILIZATION",
+            "context_chunking_policy": "FIRST_COME_FIRST_SERVED",
+        },
+        "kv_cache_config": {
+            "free_gpu_memory_fraction": 0.1,
+            "enable_block_reuse": False,
+            "dtype": "fp8",
+        },
+        "enable_attention_dp": True,
+        "moe_config": {
+            "backend": "WIDEEP",
+        },
+        "cuda_graph_config": {
+            "enable_padding": True,
+            "batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512],
+        },
+        "attention_dp_config": {
+            "enable_balance": True,
+            "batching_wait_iters": 10,
+            "timeout_iters": 500,
+        },
+        "speculative_config": {
+            "decoding_type": "MTP",
+            "num_nextn_predict_layers": 1,
+        },
+    }
+
+    temp_dir = tempfile.gettempdir()
+    extra_config_path = os.path.join(temp_dir, "deepseek_r1_mtp_config.yaml")
+    with open(extra_config_path, 'w') as f:
+        yaml.dump(extra_config, f)
+
+    try:
+        runner = BenchRunner(
+            llm_root=llm_root,
+            llm_venv=llm_venv,
+            model_name=model_name,
+            model_subdir=model_path,
+            streaming=False,
+            use_pytorch_backend=True,
+            use_mpirun=False,
+            tp_size=8,
+            ep_size=8,
+            concurrency=512,
+            num_requests=512,
+            max_batch_size=512,
+            max_num_tokens=4608,
+            warmup=0,
+            eos_id=1,
+            kv_cache_free_gpu_mem_fraction=0.1,
+            scheduler_policy="max_utilization",
+            extra_llm_api_options=extra_config_path,
+            input_mean=1000,
+            output_mean=10000,
+            input_stdev=0,
+            output_stdev=0,
+        )
+        result = runner()
+        print(f"Benchmark result: {result}")
+    finally:
+        if os.path.exists(extra_config_path):
+            os.remove(extra_config_path)
+
+
 @pytest.mark.skip_less_device_memory(80000)
 @pytest.mark.parametrize("model_name,model_path,gpu_count", [
     ("Llama3.1-70B-BF16", "llama-3.1-model/Meta-Llama-3.1-70B", 8),
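The test writes the extra-options YAML to a fixed file name under tempfile.gettempdir() and deletes it in the finally block. An alternative pattern, shown here only as a hedged sketch of a design choice and not as what the change does, is to let tempfile allocate a unique path and to pass sort_keys=False so the dumped YAML keeps the dict's insertion order (PyYAML sorts keys alphabetically by default):

    # Alternative sketch: tempfile picks a unique path; sort_keys=False preserves
    # the dict's insertion order in the dumped YAML.
    import os
    import tempfile

    import yaml

    extra_config = {"print_iter_log": True, "stream_interval": 20}  # trimmed example

    fd, extra_config_path = tempfile.mkstemp(suffix=".yaml", prefix="deepseek_r1_mtp_")
    try:
        with os.fdopen(fd, "w") as f:
            yaml.dump(extra_config, f, sort_keys=False)
        # ... run the benchmark with extra_llm_api_options=extra_config_path ...
    finally:
        os.remove(extra_config_path)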
@@ -113,6 +113,7 @@ l0_dgx_b200:
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_pp4_mtp1] TIMEOUT (60)
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv] TIMEOUT (60)
       - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_fp8[latency_moe_deepgemm] TIMEOUT (60)
+      - test_e2e.py::test_deepseek_r1_mtp_bench TIMEOUT (60) # Cover https://nvbugs/5670108
 - condition:
     ranges:
       system_gpu_count: