Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5569713][fix] Disable fp8 deep gemm for EXAONE-4.0-32B-FP8 (#8429)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
parent f256eb9063
commit 0acdecb2c3
@@ -5,6 +5,7 @@ from torch import nn
 from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.quantization import QuantAlgo

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import (PositionalEmbeddingParams,
@@ -54,7 +55,8 @@ class Exaone4Attention(QKNormRoPEAttention):
     def __init__(self,
                  model_config: ModelConfig[Exaone4Config],
                  layer_idx: Optional[int] = None,
-                 fuse_qk_norm_rope: bool = False):
+                 fuse_qk_norm_rope: bool = False,
+                 disable_deep_gemm: bool = False):
         config = model_config.pretrained_config

         self.attention_window_size = None
@@ -88,6 +90,7 @@ class Exaone4Attention(QKNormRoPEAttention):
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            disable_deep_gemm=disable_deep_gemm,
         )

     def forward(
@@ -128,9 +131,17 @@ class Exaone4DecoderLayer(DecoderLayer):
         self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
         )

+        disable_deep_gemm = False
+        quant_config = getattr(model_config, "quant_config", None)
+        if quant_config is not None:
+            # EXAONE4 fp8 has an illegal memory access issue with deep_gemm.
+            disable_deep_gemm = getattr(quant_config, "quant_algo",
+                                        None) == QuantAlgo.FP8_BLOCK_SCALES
+
         self.self_attn = Exaone4Attention(
             model_config,
             layer_idx=layer_idx,
+            disable_deep_gemm=disable_deep_gemm,
         )

         self.mlp = GatedMLP(
@@ -140,6 +151,7 @@ class Exaone4DecoderLayer(DecoderLayer):
             dtype=config.torch_dtype,
             config=model_config,
             layer_idx=layer_idx,
+            disable_deep_gemm=disable_deep_gemm,
         )

         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
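Together, these hunks thread a new disable_deep_gemm flag from Exaone4DecoderLayer into both the attention and MLP sub-modules. The downstream kernel selection is not part of this diff; the sketch below is only a minimal illustration of the pattern, and the class and method names in it are hypothetical, not the actual TensorRT-LLM linear-layer internals.

# Illustrative sketch only: shows how a per-module disable_deep_gemm flag can
# select between GEMM backends. Names here are hypothetical and do not
# correspond to the actual TensorRT-LLM modules touched by this commit.
import torch


class BlockScaledFP8Linear(torch.nn.Module):

    def __init__(self, in_features: int, out_features: int,
                 disable_deep_gemm: bool = False):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features),
                                         requires_grad=False)
        # Set to True for models where deep_gemm is known to misbehave
        # (here: EXAONE4 with FP8 block scales).
        self.disable_deep_gemm = disable_deep_gemm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.disable_deep_gemm:
            # Fallback path; stands in for the non-deep_gemm FP8 GEMM kernel.
            return x @ self.weight.t()
        # Default path would dispatch to the deep_gemm block-scale kernel;
        # modeled by the same matmul here to keep the sketch runnable.
        return x @ self.weight.t()


layer = BlockScaledFP8Linear(16, 32, disable_deep_gemm=True)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 32])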
@@ -73,6 +73,7 @@ l0_b200:
   - unittest/_torch/modeling -k "modeling_llama"
   - unittest/_torch/modeling -k "modeling_mixtral"
   - unittest/_torch/modeling -k "modeling_gpt_oss"
+  - unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8
   - unittest/_torch/auto_deploy/unit/singlegpu -k "not test_trtllm_bench_backend_comparison"
 - condition:
     ranges:
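The new entry schedules only the FP8 variant of the load test on the l0_b200 stage. Assuming a working TensorRT-LLM development checkout, the same test node can also be run directly; the path below is relative to the tests tree and may need adjusting to your layout.

# Run the FP8 load test by its pytest node id.
import pytest

pytest.main([
    "unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8",
    "-q",
])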
@@ -1,3 +1,6 @@
+import json
+import os
+import shutil
 import unittest
 from copy import deepcopy
 from dataclasses import dataclass
@@ -51,8 +54,9 @@ EXAONE4_SINGLE_LAYER_CONFIG = {
     "max_position_embeddings": 131072,
     "model_type": "exaone4",
     "num_attention_heads": 40,
-    "num_hidden_layers":
-    4,  #NOTE: For testing, we use 4 instead of 64(all layers)
+    # NOTE: For testing, we use 32 instead of 64(all layers)
+    # Increase from 4 to 32 to trigger the deep_gemm kernel issue
+    "num_hidden_layers": 32,
     "num_key_value_heads": 8,
     "pad_token_id": 0,
     "rms_norm_eps": 1e-05,
@@ -74,6 +78,15 @@ EXAONE4_SINGLE_LAYER_CONFIG = {
     "attn_implementation": "flash_attention_2"
 }

+EXAONE4_FP8_QUANT_CONFIG = {
+    "quantization_config": {
+        "activation_scheme": "dynamic",
+        "modules_to_not_convert": None,
+        "quant_method": "fp8",
+        "weight_block_size": [128, 128]
+    },
+}
+

 @dataclass(repr=False)
 class Scenario:
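EXAONE4_FP8_QUANT_CONFIG mirrors an HF-style fp8 checkpoint config; the weight_block_size entry is what makes it block-scaled FP8, which is the quant algo the new disable_deep_gemm check in Exaone4DecoderLayer compares against. The sketch below is a simplified, hypothetical illustration of that mapping, not the actual TensorRT-LLM config parser.

# Illustrative only: maps an HF-style quantization_config to a QuantAlgo value.
# The real mapping is done inside TensorRT-LLM's model-config handling.
from typing import Optional

from tensorrt_llm.quantization import QuantAlgo


def infer_quant_algo(quantization_config: dict) -> Optional[QuantAlgo]:
    if quantization_config.get("quant_method") != "fp8":
        return None
    # Per-block weight scales => FP8_BLOCK_SCALES, the case this commit
    # excludes from the deep_gemm path for EXAONE4.
    if quantization_config.get("weight_block_size"):
        return QuantAlgo.FP8_BLOCK_SCALES
    return QuantAlgo.FP8


print(infer_quant_algo({"quant_method": "fp8", "weight_block_size": [128, 128]}))
# -> QuantAlgo.FP8_BLOCK_SCALES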
@@ -390,3 +403,30 @@ class TestEXAONE4(unittest.TestCase):
         if graph_runner is not None:
             graph_runner.clear()
         kv_cache_manager.shutdown()
+
+    @parameterized.expand([None, "FP8"])
+    def test_llm_load(self, quant_algo):
+
+        def dump_config_json(dst_dir, config):
+            if os.path.exists(dst_dir):
+                shutil.rmtree(dst_dir)
+            os.makedirs(dst_dir)
+
+            dst_path = os.path.join(dst_dir, 'config.json')
+            with open(dst_path, 'w', encoding='utf-8') as f:
+                json.dump(config, f, indent=2, ensure_ascii=False)
+
+        config_dict = deepcopy(EXAONE4_SINGLE_LAYER_CONFIG)
+        if quant_algo == "FP8":
+            if getSMVersion() < 89:
+                self.skipTest(
+                    "This test is not supported in pre-Ada architecture")
+
+            config_dict.update(EXAONE4_FP8_QUANT_CONFIG)
+
+        tmp_model_dir = f"/tmp/exaone4_llm_load_test_model"
+        dump_config_json(tmp_model_dir, config_dict)
+        try:
+            tensorrt_llm.LLM(model=tmp_model_dir, load_format="dummy")
+        except Exception:
+            raise RuntimeError("Failed to load model.")