Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
fix: Fix poor generation with FP8 Gemma3 1B checkpoint (#6499)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Parent: 8cf3faa26a
Commit: 2eca0d5925
@@ -158,25 +158,29 @@ class Gemma3Attention(Attention):
 class Gemma3MLP(nn.Module):
 
-    def __init__(self, config: Gemma3TextConfig):
+    def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
         super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
+        self.config = model_config.pretrained_config
+        self.hidden_size = self.config.hidden_size
+        self.intermediate_size = self.config.intermediate_size
+        self.dtype = self.config.torch_dtype
+        self.quant_config = model_config.get_quant_config()
         self.gate_proj = Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False,
-                                dtype=self.dtype)
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
         self.up_proj = Linear(self.hidden_size,
                               self.intermediate_size,
                               bias=False,
-                              dtype=self.dtype)
+                              dtype=self.dtype,
+                              quant_config=self.quant_config)
         self.down_proj = Linear(self.intermediate_size,
                                 self.hidden_size,
                                 bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
+        self.act_fn = ACT2FN[self.config.hidden_activation]
 
     @torch.inference_mode()
     def forward(self, x):
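The hunk above threads the full ModelConfig into Gemma3MLP so that each Linear projection is constructed with the checkpoint's quant_config (FP8 here) rather than plain unquantized weights. Below is a minimal, self-contained sketch of that pattern; QuantConfig, ModelConfig, and SketchMLP are simplified stand-ins for illustration, not TensorRT-LLM's real classes, and the layer sizes are only an example.

from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class QuantConfig:                       # stand-in for the real quant config
    quant_algo: Optional[str] = None     # e.g. "FP8"


@dataclass
class ModelConfig:                       # stand-in for ModelConfig[Gemma3TextConfig]
    pretrained_config: SimpleNamespace
    quant_config: QuantConfig = field(default_factory=QuantConfig)

    def get_quant_config(self) -> QuantConfig:
        return self.quant_config


class SketchMLP(nn.Module):
    """Gemma3-style gated MLP that receives the whole ModelConfig."""

    def __init__(self, model_config: ModelConfig):
        super().__init__()
        cfg = model_config.pretrained_config
        # The real fix forwards this quant_config to each TRT-LLM Linear so the
        # FP8 weights and scales in the checkpoint are used; here we only record
        # it to show the data flow from ModelConfig into the projections.
        self.quant_config = model_config.get_quant_config()
        self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False)
        self.act_fn = nn.GELU(approximate="tanh")  # approximates Gemma3's gelu_pytorch_tanh

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


# Usage: the decoder layer now passes the full model config, not the bare config.
cfg = SimpleNamespace(hidden_size=1152, intermediate_size=6912)  # illustrative sizes
mlp = SketchMLP(ModelConfig(pretrained_config=cfg, quant_config=QuantConfig("FP8")))
print(mlp.quant_config.quant_algo, mlp(torch.randn(2, 1152)).shape)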
@@ -202,7 +206,7 @@ class Gemma3DecoderLayer(DecoderLayer):
             is_sliding=is_sliding,
         )
 
-        self.mlp = Gemma3MLP(config)
+        self.mlp = Gemma3MLP(model_config=model_config)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -1,5 +1,8 @@
 google/gemma-3-1b-it:
   - accuracy: 22.988
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 22.988
 google/gemma-3-27b-it:
   - accuracy: 28.90
 gpt2:
@@ -100,6 +100,11 @@ mistralai/Mistral-Small-3.1-24B-Instruct-2503:
   - accuracy: 81.7
 google/gemma-2-9b-it:
   - accuracy: 73.05
+google/gemma-3-1b-it:
+  - accuracy: 39.0
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 39.0
 google/gemma-3-27b-it:
   - accuracy: 77.80
 Qwen/Qwen2-0.5B-Instruct:
@@ -604,6 +604,22 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
             task.evaluate(llm)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
 
+    def test_fp8_prequantized(self):
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        enable_partial_reuse=False,
+                                        dtype="fp8")
+        prequantized_model_path = f"{llm_models_root()}/gemma/gemma-3-1b-it-fp8/"
+        with LLM(prequantized_model_path,
+                 kv_cache_config=kv_cache_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
     def test_auto_dtype_vswa(self):
         # NOTE: Test with VSWA kv cache config.
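Outside the accuracy harness, the setup the new test exercises can be sketched with the LLM API: load an FP8-prequantized Gemma3 checkpoint with block reuse disabled (the WAR noted in the test) and an fp8 KV cache. The checkpoint path and prompt below are illustrative assumptions, not part of this change.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

# Disable KV-cache reuse (WAR for Gemma3's non-inclusive sliding window) and use an fp8 KV cache.
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                enable_partial_reuse=False,
                                dtype="fp8")

# Hypothetical local path to an FP8-prequantized Gemma3 1B checkpoint.
with LLM("/models/gemma/gemma-3-1b-it-fp8",
         kv_cache_config=kv_cache_config) as llm:
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)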
@@ -191,6 +191,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
   - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]