Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[TRTLLM-8070][test] add generation logits case for llama3 (#7759)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
Commit 26d50eb539 (parent e0423bfaab)
@@ -3270,3 +3270,84 @@ def test_multi_nodes_eval(llm_venv, model_path, tp_size, pp_size, ep_size,
    if os.environ.get("SLURM_PROCID", '0') == '0':
        mmlu_accuracy = get_mmlu_accuracy(output)
        assert mmlu_accuracy > mmlu_threshold, f"MMLU accuracy {mmlu_accuracy} is less than threshold {mmlu_threshold}"


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("return_generation_logits", [True, False])
@pytest.mark.parametrize("model_path", [
    ("llama-3.1-model/Llama-3.1-8B-Instruct"),
    pytest.param("llama-3.3-models/Llama-3.3-70B-Instruct",
                 marks=pytest.mark.skip_less_device(8)),
])
def test_llmapi_generation_logits(llm_venv, model_path,
                                  return_generation_logits):
    """
    RCCA: https://nvbugspro.nvidia.com/bug/5501805
    """

    import asyncio

    from tensorrt_llm import LLM, SamplingParams

    seq_len, max_tokens = 131072, 100000
    if return_generation_logits:
        # use short seq_len and max_tokens for testing when return_generation_logits is True
        seq_len, max_tokens = 1024, 1000
    tp_size = 8 if "70B" in model_path else 1
    # Model parameters
    params = {
        "cuda_graph_config": {
            "batch_sizes": [512]
        },
        "enable_chunked_prefill": True,
        "guided_decoding_backend": "xgrammar",
        "kv_cache_config": {
            "cross_kv_cache_fraction": None,
            "enable_block_reuse": False,
            "free_gpu_memory_fraction": 0.9,
            "max_attention_window": None
        },
        "max_seq_len": seq_len,
        "tensor_parallel_size": tp_size,
    }

    # Sampling parameters
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        return_context_logits=False,
        return_generation_logits=return_generation_logits,
    )

    # Test prompt (token IDs)
    prompt = [
        128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790,
        220, 2366, 18, 198, 15724, 2696, 25, 220, 2545, 17907, 220, 2366, 20,
        271, 67, 10319, 7422, 389, 128009, 128006, 882, 128007, 271, 3923, 374,
        701, 836, 30, 128009, 128006, 78191, 128007, 271
    ]

    async def async_generation_test():
        """Async generation test function"""
        model_path_full = f"{llm_models_root()}/{model_path}"
        llm = LLM(**params, model=model_path_full, tokenizer=model_path_full)

        try:
            outputs = []
            async for output in llm.generate_async(
                    prompt,
                    sampling_params,
                    streaming=True,
            ):
                outputs.append(output)
                print(f"Generated: {output}")

            # Verify that we got some output
            assert len(outputs) > 0, "No output generated"
            print(f"Successfully generated {len(outputs)} streaming outputs")

        finally:
            llm.shutdown()

    # Run the async test
    loop = asyncio.get_event_loop()
    loop.run_until_complete(async_generation_test())
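The streaming loop in the test only asserts that some output was produced. When return_generation_logits is True, each streamed result should also carry per-token logits that can be sanity-checked. The sketch below shows what such a stricter check could look like; it assumes the LLM API's CompletionOutput exposes a generation_logits tensor of shape (num_generated_tokens, vocab_size) when the flag is set, which should be verified against the installed tensorrt_llm version, and it is an illustration rather than part of the commit.

# Sketch: a stricter streaming check that also inspects generation logits.
# Assumes CompletionOutput has a `generation_logits` tensor of shape
# (num_generated_tokens, vocab_size) when return_generation_logits=True.
async def check_streaming_logits(llm, prompt, sampling_params, max_tokens):
    outputs = []
    async for output in llm.generate_async(prompt,
                                           sampling_params,
                                           streaming=True):
        outputs.append(output)
        logits = getattr(output.outputs[0], "generation_logits", None)
        if sampling_params.return_generation_logits and logits is not None:
            # One row of logits per generated token, one column per vocab id.
            assert logits.shape[0] <= max_tokens
            assert logits.shape[-1] > 0
    assert len(outputs) > 0, "No output generated"
    return outputs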
@@ -173,3 +173,7 @@ test_e2e.py::test_ptp_quickstart_advanced[Nemotron4_4B-BF16-nemotron/Minitron-4B
test_e2e.py::test_ptp_quickstart_advanced[Nemotron-H-8B-Nemotron-H-8B-Base-8K]
test_e2e.py::test_ptp_quickstart_advanced_8gpus[DeepSeek-V3-671B-FP8-DeepSeek-V3-0324]
test_e2e.py::test_trtllm_benchmark_serving[gpt_oss/gpt-oss-20b]
test_e2e.py::test_llmapi_generation_logits[llama-3.1-model/Llama-3.1-8B-Instruct-True]
test_e2e.py::test_llmapi_generation_logits[llama-3.1-model/Llama-3.1-8B-Instruct-False]
test_e2e.py::test_llmapi_generation_logits[llama-3.3-models/Llama-3.3-70B-Instruct-True]
test_e2e.py::test_llmapi_generation_logits[llama-3.3-models/Llama-3.3-70B-Instruct-False]
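The new list entries map one-to-one to the pytest parametrizations above (model path plus the True/False value of return_generation_logits). A single case can also be selected programmatically through pytest's Python entry point; the snippet below is a sketch that assumes test_e2e.py is reachable from the working directory and that the referenced model checkpoints are available to the test environment.

# Sketch: run one of the new parametrized cases via pytest.main.
# Assumes the working directory contains test_e2e.py and the model
# checkpoints required by the llm_venv fixture are present.
import sys

import pytest

node_id = ("test_e2e.py::test_llmapi_generation_logits"
           "[llama-3.1-model/Llama-3.1-8B-Instruct-True]")
sys.exit(pytest.main([node_id, "-v"]))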