Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-10053][feat] AutoDeploy: Add Super v3 config file, improve test runtime (#10397)
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Parent: 225d3a9001
Commit: e98c27ee4f
examples/auto_deploy/super_v3.yaml (new file, 52 lines)
@@ -0,0 +1,52 @@
runtime: trtllm
compile_backend: torch-cudagraph
max_batch_size: 384
max_seq_len: 65536 # tunable
enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
  # disable kv_cache reuse since not supported for hybrid/ssm models
  enable_block_reuse: false
transforms:
  detect_sharding:
    allreduce_strategy: SYMM_MEM
    sharding_dims: ['ep', 'bmm']
    manual_config:
      head_dim: 128
      tp_plan:
        # mamba SSM layer
        "in_proj": "mamba"
        "out_proj": "rowwise"
        # attention layer
        "q_proj": "colwise"
        "k_proj": "colwise"
        "v_proj": "colwise"
        "o_proj": "rowwise"
        # NOTE: consider not sharding shared experts and/or
        # latent projections at all, keeping them replicated.
        # To do so, comment out the corresponding entries.
        # moe layer: SHARED experts
        "up_proj": "colwise"
        "down_proj": "rowwise"
        # MoLE: latent projections: simple shard
        "fc1_latent_proj": "gather"
        "fc2_latent_proj": "gather"
  multi_stream_moe:
    stage: compile
    enabled: false
  # tunable mamba cache dtype
  # --> use float32 for accuracy and default (null) for speed
  insert_cached_ssm_attention:
    cache_config:
      # mamba_dtype: float32
      mamba_dtype: null
  gather_logits_before_lm_head:
    # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
    enabled: true
  fuse_mamba_a_log:
    stage: post_load_fusion
    enabled: true
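For context, a minimal sketch of how a config file like super_v3.yaml could be loaded and layered with per-run overrides, much as the accuracy test below overrides max_seq_len and max_batch_size. The file path and the idea that the merged dict is handed to the AutoDeploy entry point as keyword arguments are assumptions for illustration; the actual entry point is not part of this diff.

```python
# Minimal sketch, assuming the YAML path below and that the downstream consumer
# (AutoDeploy LLM / example runner) accepts these keys as keyword arguments.
import yaml  # PyYAML


def load_autodeploy_config(path="examples/auto_deploy/super_v3.yaml", **overrides):
    """Load the Super v3 defaults and layer per-run overrides on top."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    # A shallow update covers top-level knobs such as max_batch_size or
    # max_seq_len; nested sections like `transforms` would need a deep merge.
    cfg.update(overrides)
    return cfg


# Example: shrink the config for a quick functional run, mirroring how the
# accuracy test picks a smaller max_seq_len than the deployment default.
cfg = load_autodeploy_config(max_batch_size=32, max_seq_len=8192)
print(cfg["compile_backend"], cfg["max_batch_size"], cfg["max_seq_len"])
```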
@@ -234,8 +234,16 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
 
 
 class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
+    """Accuracy regression tests for Nemotron Super V3.
+
+    Runs the model via AutoDeploy and verifies benchmark performance on MMLU and GSM8K
+    """
+
     MODEL_NAME = "nvidia/Nemotron-Super-V3"
     MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Super-3-120B-A12B-dev"
+    # Set minimum possible seq len + small buffer, for test speed & memory usage
+    MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN,
+                      GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN)
 
     def get_default_kwargs(self):
         return {
@@ -243,10 +251,10 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
             "trust_remote_code": True,
             "skip_loading_weights": False,
             "compile_backend": "torch-cudagraph",
-            "free_mem_ratio": 0.5,  # maybe we can increase
+            "free_mem_ratio": 0.9,
             "max_batch_size": 128,
-            "max_seq_len": 8192,
-            "max_num_tokens": 8192,
+            "max_seq_len": self.MAX_SEQ_LEN,
+            "max_num_tokens": self.MAX_SEQ_LEN,
             "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
             "transforms": {
                 "detect_sharding": {
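To make the test-runtime change concrete, here is a small illustrative sketch of the new MAX_SEQ_LEN derivation. The MMLU/GSM8K length constants below are placeholders, not the real values from the accuracy task classes.

```python
# Placeholder benchmark length constants (hypothetical values for illustration only).
class MMLU:
    MAX_INPUT_LEN = 4094   # hypothetical
    MAX_OUTPUT_LEN = 2     # hypothetical


class GSM8K:
    MAX_INPUT_LEN = 4096   # hypothetical
    MAX_OUTPUT_LEN = 256   # hypothetical


# The test now derives max_seq_len/max_num_tokens from the benchmarks it runs
# instead of a fixed 8192, so sequence-length-dependent memory is sized to what
# MMLU and GSM8K actually need.
MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN,
                  GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN)
print(MAX_SEQ_LEN)  # 4352 with the placeholder numbers above
```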