[https://nvbugs/5661741][feat] Add 250K-token NVFP4 MoE + PDL regression tests (#10911)
Signed-off-by: yingguo-trt <244492186+yingguo-trt@users.noreply.github.com>
parent 2d8245d125
commit c8f1745a6e
@@ -12,12 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import sys
from unittest import mock

import pytest
import torch
from datasets import load_dataset
from mpi4py.futures import MPIPoolExecutor


@@ -3151,6 +3153,207 @@ class TestKimiK2(LlmapiAccuracyTestHarness):
        task = GSM8K(model_name)
        task.evaluate(llm)

    @skip_pre_blackwell
    @pytest.mark.skip_less_device(8)
    @pytest.mark.skip_less_device_memory(183000)
    @pytest.mark.timeout(14400)
    @pytest.mark.filterwarnings(
        "ignore:.*Calling super.*encode.*add_special_tokens.*:UserWarning")
    @pytest.mark.filterwarnings(
        "ignore:.*configuration is not supported by the fused routing kernel.*:UserWarning"
    )
    def test_nvfp4_longseq_trtllm_moe_stress(self, mocker):
        """
        Long-sequence MoE stress test with PDL enabled.
        RCCA: https://nvbugspro.nvidia.com/bug/5661741
        """
        patch_mpi_pool_session_for_env(mocker, {"TRTLLM_ENABLE_PDL": "1"})
        model_path = f"{llm_models_root()}/Kimi-K2-Thinking-NVFP4"
        target_len = 250000
        kv_cache_config = KvCacheConfig(
            dtype="fp8",
            free_gpu_memory_fraction=0.75,
            enable_block_reuse=True,
            enable_partial_reuse=False,
            event_buffer_max_size=1024,
        )

        with LLM(
                model_path,
                tensor_parallel_size=8,
                moe_expert_parallel_size=4,
                moe_config=MoeConfig(backend="TRTLLM"),
                enable_chunked_prefill=True,
                trust_remote_code=True,
                kv_cache_config=kv_cache_config,
                max_num_tokens=8192,
                max_seq_len=262144,
                max_batch_size=32,
                enable_attention_dp=True,
        ) as llm:
            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4

            # Build long token sequences from dataset
            tokenizer = llm.tokenizer
            dataset_path = f"{llm_models_root()}/datasets/Crystalcareai/Code-feedback-sharegpt-renamed"
            dataset = load_dataset(dataset_path, split="train[:2000]")
            long_token_list = []
            for row in dataset:
                msg = row["messages"][0]["value"]
                tokens = tokenizer.encode(msg, add_special_tokens=False)
                if not tokens:
                    continue
                repeat = target_len // len(tokens) + 1
                long_tokens = (tokens * repeat)[:target_len]
                long_token_list.append(long_tokens)
            assert len(long_token_list) > 0, "No valid samples found"

            samples_per_batch = 8
            sampling_params_greedy = SamplingParams(max_tokens=8)
            sampling_params_sampling = SamplingParams(max_tokens=8,
                                                      temperature=0.8,
                                                      top_p=0.95)

            num_samples = len(long_token_list)
            max_batch_count = 15

            for batch_idx in range(max_batch_count):
                start_idx = (batch_idx * samples_per_batch) % num_samples
                indices = [(start_idx + i) % num_samples
                           for i in range(samples_per_batch)]
                batch_inputs = [long_token_list[i] for i in indices]

                for output in llm.generate(
                        batch_inputs, sampling_params=sampling_params_greedy):
                    token_ids = output.outputs[0].token_ids
                    assert len(token_ids) > 0
                    assert not all(tid == 0 for tid in token_ids)

                for output in llm.generate(
                        batch_inputs, sampling_params=sampling_params_sampling):
                    token_ids = output.outputs[0].token_ids
                    assert len(token_ids) > 0
                    assert not all(tid == 0 for tid in token_ids)

    @skip_pre_blackwell
    @pytest.mark.skip_less_device(8)
    @pytest.mark.skip_less_device_memory(183000)
    @pytest.mark.timeout(14400)
    @pytest.mark.filterwarnings(
        "ignore:.*Calling super.*encode.*add_special_tokens.*:UserWarning")
    @pytest.mark.filterwarnings(
        "ignore:.*configuration is not supported by the fused routing kernel.*:UserWarning"
    )
    def test_nvfp4_longseq_trtllm_moe_async_cancel(self, mocker):
        """
        Long-sequence MoE async streaming test with cancellation.
        RCCA: https://nvbugspro.nvidia.com/bug/5661741
        """
        patch_mpi_pool_session_for_env(mocker, {"TRTLLM_ENABLE_PDL": "1"})
        model_path = f"{llm_models_root()}/Kimi-K2-Thinking-NVFP4"
        target_len = 250000
        kv_cache_config = KvCacheConfig(
            dtype="fp8",
            free_gpu_memory_fraction=0.75,
            enable_block_reuse=True,
            enable_partial_reuse=False,
            event_buffer_max_size=1024,
        )

        with LLM(
                model_path,
                tensor_parallel_size=8,
                moe_expert_parallel_size=4,
                moe_config=MoeConfig(backend="TRTLLM"),
                enable_chunked_prefill=True,
                trust_remote_code=True,
                kv_cache_config=kv_cache_config,
                max_num_tokens=4096,
                max_seq_len=262144,
                max_batch_size=8,
                enable_attention_dp=True,
        ) as llm:
            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4

            # Build long token sequences from dataset
            tokenizer = llm.tokenizer
            dataset_path = f"{llm_models_root()}/datasets/Crystalcareai/Code-feedback-sharegpt-renamed"
            dataset = load_dataset(dataset_path, split="train[:2000]")
            long_token_list = []
            for row in dataset:
                msg = row["messages"][0]["value"]
                tokens = tokenizer.encode(msg, add_special_tokens=False)
                if not tokens:
                    continue
                repeat = target_len // len(tokens) + 1
                long_tokens = (tokens * repeat)[:target_len]
                long_token_list.append(long_tokens)
            assert len(long_token_list) > 0, "No valid samples found"

            async_batch_size = 6
            num_async_batches = 3
            cancel_ratio = 0.5
            num_samples = len(long_token_list)

            async def handle_one_request(async_gen, should_cancel):
                chunks_received = 0
                max_chunks_before_cancel = 5
                try:
                    async for chunk in async_gen:
                        chunks_received += 1
                        if chunk.outputs:
                            token_ids = chunk.outputs[0].token_ids
                            assert len(token_ids) > 0
                            assert not all(tid == 0 for tid in token_ids)
                        if should_cancel and chunks_received >= max_chunks_before_cancel:
                            break
                except Exception:
                    if not should_cancel:
                        raise

            async def run_streaming_with_cancellation():
                for async_batch_idx in range(num_async_batches):
                    start_idx = (async_batch_idx *
                                 async_batch_size) % num_samples
                    indices = [(start_idx + i) % num_samples
                               for i in range(async_batch_size)]
                    batch_inputs = [long_token_list[i] for i in indices]

                    sampling_params = SamplingParams(max_tokens=50,
                                                     temperature=0.8,
                                                     top_p=0.95)
                    async_results = [
                        llm.generate_async(inp,
                                           sampling_params=sampling_params,
                                           streaming=True)
                        for inp in batch_inputs
                    ]

                    tasks = [
                        asyncio.create_task(
                            handle_one_request(
                                gen, idx < async_batch_size * cancel_ratio))
                        for idx, gen in enumerate(async_results)
                    ]

                    await asyncio.wait_for(asyncio.gather(*tasks), timeout=300)

            asyncio.run(run_streaming_with_cancellation())

            # Verify LLM still works after cancellations (bug 5661741 symptom check)
            verify_batch_size = 4
            verify_inputs = [
                long_token_list[i % num_samples]
                for i in range(verify_batch_size)
            ]
            verify_params = SamplingParams(max_tokens=16)

            for output in llm.generate(verify_inputs,
                                       sampling_params=verify_params):
                token_ids = output.outputs[0].token_ids
                assert len(token_ids) > 0
                assert not all(tid == 0 for tid in token_ids)


class TestMinitron4BBaseInstruct(LlmapiAccuracyTestHarness):
    MODEL_NAME = "nvidia/Nemotron-Mini-4B-Instruct"

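Aside (not part of the diff): the async test above "cancels" a request simply by breaking out of the async-for loop over the streaming generator, then confirms the engine still produces sane tokens afterwards. The sketch below reproduces that consumption pattern in isolation against a dummy async generator; fake_stream, its timing, and the chunk counts are invented for illustration and are not TensorRT-LLM APIs.

import asyncio


# Stand-in for a streaming generation call such as
# llm.generate_async(..., streaming=True). Purely illustrative.
async def fake_stream(request_id: int, total_chunks: int = 20):
    for i in range(total_chunks):
        await asyncio.sleep(0.01)
        yield f"req{request_id}-chunk{i}"


async def consume(stream, should_cancel: bool, cancel_after: int = 5) -> int:
    # Breaking out of the async-for loop early is the "cancellation":
    # the producer is simply abandoned mid-stream.
    received = 0
    async for _chunk in stream:
        received += 1
        if should_cancel and received >= cancel_after:
            break
    return received


async def main():
    cancel_ratio = 0.5
    streams = [fake_stream(i) for i in range(6)]
    tasks = [
        asyncio.create_task(consume(s, idx < len(streams) * cancel_ratio))
        for idx, s in enumerate(streams)
    ]
    # Mirrors the test's asyncio.wait_for(asyncio.gather(*tasks), timeout=...)
    counts = await asyncio.wait_for(asyncio.gather(*tasks), timeout=30)
    print(counts)  # e.g. [5, 5, 5, 20, 20, 20]


asyncio.run(main())

In the real test, the point of the follow-up llm.generate() batch is to confirm that, after those abandoned streams, the engine still returns non-empty token IDs that are not all zeros (the "Verify LLM still works after cancellations" block above).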
@@ -3,3 +3,5 @@ stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_360
stress_test/stress_test.py::test_run_stress_test[DeepSeek-R1_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy]
disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-deepseek_r1_v2_fp4_stress]
disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-gpt_oss_120b_stress]
accuracy/test_llm_api_pytorch.py::TestKimiK2::test_nvfp4_longseq_trtllm_moe_stress
accuracy/test_llm_api_pytorch.py::TestKimiK2::test_nvfp4_longseq_trtllm_moe_async_cancel
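Aside (not part of the diff): the two new node IDs can presumably be selected the same way as the existing entries in this list, by passing the full pytest node ID (for example accuracy/test_llm_api_pytorch.py::TestKimiK2::test_nvfp4_longseq_trtllm_moe_stress) on the command line from the integration test directory. Both tests are gated by skip_pre_blackwell, skip_less_device(8), and skip_less_device_memory(183000), so they only run on 8-GPU Blackwell machines with roughly 183 GB of memory per device.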