mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-05 02:31:33 +08:00
[TRTLLM-10733][feat] Make TRTLLM MOE the default one for GPTOSS on Blackwell (#11074)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
This commit is contained in:
parent
ef268e2062
commit
4f0c1b2489
@ -28,8 +28,7 @@ There are multiple MOE backends inside TensorRT LLM. Here are the support matrix
|
||||
| B200/GB200/B300/GB300 | MXFP8 | MXFP4 | TRTLLM | Low Latency and Max Throughput |
|
||||
| H200 | BF16 | MXFP4 | TRITON | Low Latency and Max Throughput |
|
||||
|
||||
The default moe backend is `CUTLASS`, so for the best possible perf, one must set the `moe_config.backend` explicitly to run the model.
|
||||
For Blackwell, `CUTLASS` was better for max throughput at first but now we have optimized `TRTLLM` moe to be universally faster. For Hopper, Triton is the faster backend.
|
||||
For Blackwell, the default MoE backend is `TRTLLM`. For Hopper, the default MoE backend is `TRITON`. They are recommended for the best perf. Users don't need to explicitly set `moe_config.backend`.
|
||||
|
||||
## Deployment Steps
|
||||
|
||||
|
||||
@ -55,13 +55,17 @@ def add_llm_args(parser):
|
||||
'VANILLA', 'TRTLLM', 'FLASHINFER',
|
||||
'FLASHINFER_STAR_ATTENTION'
|
||||
])
|
||||
parser.add_argument('--moe_backend',
|
||||
type=str,
|
||||
default='CUTLASS',
|
||||
choices=[
|
||||
'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
|
||||
'DEEPGEMM', 'CUTEDSL', 'TRITON'
|
||||
])
|
||||
parser.add_argument(
|
||||
'--moe_backend',
|
||||
type=str,
|
||||
default='AUTO',
|
||||
choices=[
|
||||
'AUTO', 'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', 'DEEPGEMM',
|
||||
'CUTEDSL', 'TRITON'
|
||||
],
|
||||
help=
|
||||
'MoE backend to use. AUTO selects default backend based on model. It currently doesn\'t always give the best choice for all scenarios. The capabilities of auto selection will be improved in future releases.'
|
||||
)
|
||||
parser.add_argument('--enable_attention_dp',
|
||||
default=False,
|
||||
action='store_true')
|
||||
|
||||
@ -229,6 +229,32 @@ class ModelConfig(Generic[TConfig]):
|
||||
# TODO: should be 'not model_type == ModelType.ENCODER_ONLY'
|
||||
# once ModelType is used in pytorch flow.
|
||||
|
||||
@staticmethod
|
||||
def resolve_moe_backend(moe_backend: str, architecture: str) -> str:
|
||||
"""Resolve AUTO moe_backend to a specific backend based on model architecture.
|
||||
|
||||
Args:
|
||||
moe_backend: The configured moe_backend (may be "AUTO")
|
||||
architecture: The model architecture name (e.g., "GptOssForCausalLM")
|
||||
|
||||
Returns:
|
||||
Resolved backend name (never "AUTO")
|
||||
"""
|
||||
if moe_backend.upper() != "AUTO":
|
||||
return moe_backend
|
||||
|
||||
if architecture == "GptOssForCausalLM":
|
||||
sm_version = get_sm_version()
|
||||
# Select the best performing backend based on SM version
|
||||
if 100 <= sm_version < 120: # Blackwell
|
||||
return "TRTLLM"
|
||||
elif 90 <= sm_version < 100: # Hopper
|
||||
return "TRITON"
|
||||
else:
|
||||
return "CUTLASS" # Fallback to CUTLASS for other SM versions (e.g., SM120)
|
||||
|
||||
return "CUTLASS"
|
||||
|
||||
@staticmethod
|
||||
def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
|
||||
moe_backend):
|
||||
@ -566,7 +592,12 @@ class ModelConfig(Generic[TConfig]):
|
||||
|
||||
quant_config = QuantConfig()
|
||||
layer_quant_config = None
|
||||
moe_backend = kwargs.get('moe_backend', 'CUTLASS')
|
||||
moe_backend = kwargs.get('moe_backend', 'AUTO')
|
||||
# Resolve AUTO to specific backend based on model architecture
|
||||
architecture = pretrained_config.architectures[
|
||||
0] if pretrained_config.architectures else ""
|
||||
moe_backend = cls.resolve_moe_backend(moe_backend, architecture)
|
||||
kwargs['moe_backend'] = moe_backend
|
||||
|
||||
# quantized ckpt in modelopt format
|
||||
if quant_config_file := cached_file(checkpoint_dir,
|
||||
|
||||
@ -444,10 +444,13 @@ class MoeConfig(StrictBaseModel):
|
||||
"""
|
||||
Configuration for MoE.
|
||||
"""
|
||||
backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM",
|
||||
"VANILLA",
|
||||
"TRITON"] = Field(default='CUTLASS',
|
||||
description="MoE backend to use.")
|
||||
backend: Literal[
|
||||
"AUTO", "CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM", "VANILLA",
|
||||
"TRITON"] = Field(
|
||||
default='AUTO',
|
||||
description="MoE backend to use. "
|
||||
"AUTO selects default backend based on model. It currently doesn\'t always give the best choice for all scenarios. The capabilities of auto selection will be improved in future releases."
|
||||
)
|
||||
|
||||
max_num_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user