fix: Fix poor generation with FP8 Gemma3 1B checkpoint (#6499)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
brb-nv 2025-07-31 17:18:23 -07:00 committed by GitHub
parent 8cf3faa26a
commit 2eca0d5925
5 changed files with 39 additions and 10 deletions


@@ -158,25 +158,29 @@ class Gemma3Attention(Attention):

 class Gemma3MLP(nn.Module):

-    def __init__(self, config: Gemma3TextConfig):
+    def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
         super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
+        self.config = model_config.pretrained_config
+        self.hidden_size = self.config.hidden_size
+        self.intermediate_size = self.config.intermediate_size
+        self.dtype = self.config.torch_dtype
+        self.quant_config = model_config.get_quant_config()
         self.gate_proj = Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False,
-                                dtype=self.dtype)
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
         self.up_proj = Linear(self.hidden_size,
                               self.intermediate_size,
                               bias=False,
-                              dtype=self.dtype)
+                              dtype=self.dtype,
+                              quant_config=self.quant_config)
         self.down_proj = Linear(self.intermediate_size,
                                 self.hidden_size,
                                 bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
+        self.act_fn = ACT2FN[self.config.hidden_activation]

     @torch.inference_mode()
     def forward(self, x):

@@ -202,7 +206,7 @@ class Gemma3DecoderLayer(DecoderLayer):
             is_sliding=is_sliding,
         )
-        self.mlp = Gemma3MLP(config)
+        self.mlp = Gemma3MLP(model_config=model_config)
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
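
For context on the MLP change: before this commit, Gemma3MLP was built from the bare Gemma3TextConfig, so its Linear projections never saw the checkpoint's quantization settings and were created as ordinary unquantized layers even for an FP8 checkpoint, which is the likely source of the poor generation the commit title refers to. Below is a toy sketch of the threading pattern using stub classes; field names and sizes are illustrative, not TensorRT-LLM's actual Linear/ModelConfig implementation.

# Toy sketch of the pattern above, with stub classes; NOT the real
# TensorRT-LLM Linear/ModelConfig implementation.
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantConfig:
    quant_algo: Optional[str] = None  # e.g. "FP8"


@dataclass
class ModelConfig:
    pretrained_config: object = None
    quant_config: Optional[QuantConfig] = None

    def get_quant_config(self) -> QuantConfig:
        return self.quant_config or QuantConfig()


class Linear:
    """Stand-in for a quantization-aware linear layer."""

    def __init__(self, in_features: int, out_features: int, *,
                 quant_config: Optional[QuantConfig] = None):
        # With an FP8 quant config, a real layer would allocate FP8 weight
        # storage plus scales so a prequantized checkpoint loads correctly;
        # without it, the layer falls back to plain unquantized weights.
        self.is_fp8 = bool(quant_config and quant_config.quant_algo == "FP8")


class ToyMLP:
    def __init__(self, model_config: ModelConfig,
                 hidden: int = 1152, inter: int = 6912):  # illustrative sizes
        quant_config = model_config.get_quant_config()
        # The fix: every projection receives the quant config, not just dtype.
        self.gate_proj = Linear(hidden, inter, quant_config=quant_config)
        self.up_proj = Linear(hidden, inter, quant_config=quant_config)
        self.down_proj = Linear(inter, hidden, quant_config=quant_config)


mlp = ToyMLP(ModelConfig(quant_config=QuantConfig(quant_algo="FP8")))
assert mlp.down_proj.is_fp8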


@@ -1,5 +1,8 @@
 google/gemma-3-1b-it:
   - accuracy: 22.988
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 22.988
 google/gemma-3-27b-it:
   - accuracy: 28.90
 gpt2:


@@ -100,6 +100,11 @@ mistralai/Mistral-Small-3.1-24B-Instruct-2503:
   - accuracy: 81.7
 google/gemma-2-9b-it:
   - accuracy: 73.05
+google/gemma-3-1b-it:
+  - accuracy: 39.0
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 39.0
 google/gemma-3-27b-it:
   - accuracy: 77.80
 Qwen/Qwen2-0.5B-Instruct:
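
The two accuracy-reference entries above follow the same shape: each model name maps to a list in which a bare accuracy entry is the default-precision baseline, and an entry carrying quant_algo / kv_cache_quant_algo holds the threshold for that quantized variant. A rough lookup sketch for such entries follows; the accuracy harness's real matching logic may differ.

# Rough lookup sketch for reference entries like the ones above; not the
# accuracy harness's actual implementation.
from typing import List, Optional


def pick_reference(entries: List[dict],
                   quant_algo: Optional[str] = None,
                   kv_cache_quant_algo: Optional[str] = None) -> float:
    for entry in entries:
        if (entry.get("quant_algo") == quant_algo
                and entry.get("kv_cache_quant_algo") == kv_cache_quant_algo):
            return entry["accuracy"]
    raise KeyError("no reference accuracy for this configuration")


# MMLU entries added above for google/gemma-3-1b-it.
gemma3_1b_mmlu = [
    {"accuracy": 39.0},  # bf16 baseline
    {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8", "accuracy": 39.0},
]
assert pick_reference(gemma3_1b_mmlu) == 39.0
assert pick_reference(gemma3_1b_mmlu, "FP8", "FP8") == 39.0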


@@ -604,6 +604,22 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
             task.evaluate(llm)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    def test_fp8_prequantized(self):
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        enable_partial_reuse=False,
+                                        dtype="fp8")
+        prequantized_model_path = f"{llm_models_root()}/gemma/gemma-3-1b-it-fp8/"
+        with LLM(prequantized_model_path,
+                 kv_cache_config=kv_cache_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
     def test_auto_dtype_vswa(self):
         # NOTE: Test with VSWA kv cache config.
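
For anyone wanting to exercise the same path outside the test harness, here is a standalone usage sketch mirroring test_fp8_prequantized; the checkpoint path is illustrative and the import locations assume the public tensorrt_llm LLM API.

# Usage sketch mirroring test_fp8_prequantized above; checkpoint path is
# illustrative, import locations assume the public LLM API.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# WAR from the test above: disable KV cache reuse because of gaps in kernel
# support for Gemma3's non-inclusive sliding window size.
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                enable_partial_reuse=False,
                                dtype="fp8")

with LLM("/models/gemma/gemma-3-1b-it-fp8",
         kv_cache_config=kv_cache_config) as llm:
    for output in llm.generate(["The capital of France is"]):
        print(output.outputs[0].text)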


@@ -191,6 +191,7 @@ l0_h100:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
+- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
 - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
 - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
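
To run just the newly listed case locally, something along these lines should work from the integration test defs directory (invocation details are an assumption, not part of this change).

# Hypothetical local invocation of the newly listed test case.
import pytest

pytest.main([
    "accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized",
    "-v",
])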