[https://nvbugs/5569713][fix] Disable fp8 deep gemm for EXAONE-4.0-32B-FP8 (#8429)

Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
Authored by JunyiXu-nv on 2025-10-22 00:37:56 +08:00; committed by GitHub
parent f256eb9063
commit 0acdecb2c3
3 changed files with 56 additions and 3 deletions
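
For context, a minimal sketch of the affected path, assuming a local FP8 block-scaled EXAONE-4.0-32B checkpoint directory (the path below is illustrative, not from this commit): before this change, such a checkpoint could route its FP8 block-scales GEMMs through the deep_gemm kernel and hit the illegal memory access noted in the diff below.

# Hedged sketch, not part of the commit: load the affected checkpoint via the LLM API.
from tensorrt_llm import LLM

llm = LLM(model="/path/to/EXAONE-4.0-32B-FP8")  # hypothetical local checkpoint path
outputs = llm.generate(["The capital of South Korea is"])
print(outputs[0].outputs[0].text)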


@@ -5,6 +5,7 @@ from torch import nn
from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.quantization import QuantAlgo
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
@@ -54,7 +55,8 @@ class Exaone4Attention(QKNormRoPEAttention):
def __init__(self,
model_config: ModelConfig[Exaone4Config],
layer_idx: Optional[int] = None,
fuse_qk_norm_rope: bool = False):
fuse_qk_norm_rope: bool = False,
disable_deep_gemm: bool = False):
config = model_config.pretrained_config
self.attention_window_size = None
@@ -88,6 +90,7 @@ class Exaone4Attention(QKNormRoPEAttention):
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
disable_deep_gemm=disable_deep_gemm,
)
def forward(
@@ -128,9 +131,17 @@ class Exaone4DecoderLayer(DecoderLayer):
self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
)
disable_deep_gemm = False
quant_config = getattr(model_config, "quant_config", None)
if quant_config is not None:
# EXAONE4 fp8 has an illegal memory access issue with deep_gemm.
disable_deep_gemm = getattr(quant_config, "quant_algo",
None) == QuantAlgo.FP8_BLOCK_SCALES
self.self_attn = Exaone4Attention(
model_config,
layer_idx=layer_idx,
disable_deep_gemm=disable_deep_gemm,
)
self.mlp = GatedMLP(
@@ -140,6 +151,7 @@ class Exaone4DecoderLayer(DecoderLayer):
dtype=config.torch_dtype,
config=model_config,
layer_idx=layer_idx,
disable_deep_gemm=disable_deep_gemm,
)
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
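
Read in isolation, the gating logic added above amounts to the sketch below; the helper name is hypothetical, and only the QuantAlgo import introduced by this commit is assumed.

from tensorrt_llm.quantization import QuantAlgo

def should_disable_deep_gemm(model_config) -> bool:
    """Mirror of the check Exaone4DecoderLayer now performs: FP8 block-scaled
    checkpoints skip deep_gemm because it currently triggers an illegal
    memory access for EXAONE4 fp8."""
    quant_config = getattr(model_config, "quant_config", None)
    if quant_config is None:
        return False
    return getattr(quant_config, "quant_algo", None) == QuantAlgo.FP8_BLOCK_SCALES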


@@ -73,6 +73,7 @@ l0_b200:
- unittest/_torch/modeling -k "modeling_llama"
- unittest/_torch/modeling -k "modeling_mixtral"
- unittest/_torch/modeling -k "modeling_gpt_oss"
- unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8
- unittest/_torch/auto_deploy/unit/singlegpu -k "not test_trtllm_bench_backend_comparison"
- condition:
ranges:
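
To exercise the newly listed FP8 case locally, something like the following should work; running from the tests/ directory is an assumption based on how the test-db paths are written.

import pytest

# Select only the FP8 variant of the new load test (path relative to tests/).
pytest.main([
    "unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8",
    "-s",
])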


@@ -1,3 +1,6 @@
import json
import os
import shutil
import unittest
from copy import deepcopy
from dataclasses import dataclass
@@ -51,8 +54,9 @@ EXAONE4_SINGLE_LAYER_CONFIG = {
"max_position_embeddings": 131072,
"model_type": "exaone4",
"num_attention_heads": 40,
"num_hidden_layers":
4, #NOTE: For testing, we use 4 instead of 64(all layers)
# NOTE: For testing, we use 32 instead of 64(all layers)
# Increase from 4 to 32 to trigger the deep_gemm kernel issue
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"pad_token_id": 0,
"rms_norm_eps": 1e-05,
@@ -74,6 +78,15 @@ EXAONE4_SINGLE_LAYER_CONFIG = {
"attn_implementation": "flash_attention_2"
}
EXAONE4_FP8_QUANT_CONFIG = {
"quantization_config": {
"activation_scheme": "dynamic",
"modules_to_not_convert": None,
"quant_method": "fp8",
"weight_block_size": [128, 128]
},
}
@dataclass(repr=False)
class Scenario:
@@ -390,3 +403,30 @@ class TestEXAONE4(unittest.TestCase):
if graph_runner is not None:
graph_runner.clear()
kv_cache_manager.shutdown()
@parameterized.expand([None, "FP8"])
def test_llm_load(self, quant_algo):
def dump_config_json(dst_dir, config):
if os.path.exists(dst_dir):
shutil.rmtree(dst_dir)
os.makedirs(dst_dir)
dst_path = os.path.join(dst_dir, 'config.json')
with open(dst_path, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2, ensure_ascii=False)
config_dict = deepcopy(EXAONE4_SINGLE_LAYER_CONFIG)
if quant_algo == "FP8":
if getSMVersion() < 89:
self.skipTest(
"This test is not supported in pre-Ada architecture")
config_dict.update(EXAONE4_FP8_QUANT_CONFIG)
tmp_model_dir = f"/tmp/exaone4_llm_load_test_model"
dump_config_json(tmp_model_dir, config_dict)
try:
tensorrt_llm.LLM(model=tmp_model_dir, load_format="dummy")
except Exception:
raise RuntimeError("Failed to load model.")
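
As a usage note, the config-dump plus dummy-load pattern used by the test above can also serve as a standalone smoke test; the sketch below assumes the EXAONE4_SINGLE_LAYER_CONFIG and EXAONE4_FP8_QUANT_CONFIG dicts defined earlier in this test module.

import json
import os
import tempfile

import tensorrt_llm

# Merge the base single-layer config with the FP8 block-scales quantization block.
config = {**EXAONE4_SINGLE_LAYER_CONFIG, **EXAONE4_FP8_QUANT_CONFIG}

model_dir = tempfile.mkdtemp(prefix="exaone4_fp8_smoke_")
with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2, ensure_ascii=False)

# load_format="dummy" materializes random weights, so no real checkpoint is needed.
llm = tensorrt_llm.LLM(model=model_dir, load_format="dummy")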