feat: Support Gemma3-1b-it in Pytorch workflow (#3999)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
brb-nv 2025-05-13 23:02:44 -07:00 committed by GitHub
parent 86ae506b9d
commit 8280c3d4f2
9 changed files with 412 additions and 9 deletions

View File

@@ -47,7 +47,8 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
| Architecture | Model | HuggingFace Example | Modality |
|--------------|-------|---------------------|----------|
| `BertForSequenceClassification` | BERT-based | `textattack/bert-base-uncased-yelp-polarity` | L |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3 `| L |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3` | L |
| `Gemma3ForCausalLM` | Gemma3 | `google/gemma-3-1b-it` | L |
| `LlavaLlamaModel` | VILA | `Efficient-Large-Model/NVILA-8B` | L + V |
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | `llava-hf/llava-v1.6-mistral-7b-hf` | L + V |
| `LlamaForCausalLM` | Llama 3 <br> Llama 3.1 <br> Llama 2 <br> LLaMA | `meta-llama/Meta-Llama-3.1-70B` | L |
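For a quick smoke test of the newly listed model through the PyTorch workflow, something like the sketch below should work. The import path of `LLM` is an assumption (the accuracy test added later in this change uses the same `LLM(...)` entry point); verify it against your TensorRT-LLM version.

# Minimal sketch: generate with google/gemma-3-1b-it via the PyTorch LLM API.
from tensorrt_llm._torch import LLM   # assumed import path; check your installed version
from tensorrt_llm import SamplingParams

if __name__ == "__main__":
    llm = LLM(model="google/gemma-3-1b-it")  # or a local checkpoint directory
    params = SamplingParams(max_tokens=64, temperature=0.0)
    outputs = llm.generate(["Explain sliding-window attention in one sentence."], params)
    print(outputs[0].outputs[0].text)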

View File

@@ -667,6 +667,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
latent_cache: Optional[torch.Tensor] = None,
q_pe: Optional[torch.Tensor] = None,
mrope_config: Optional[dict] = None,
attention_window_size: Optional[int] = None,
**kwargs,
) -> torch.Tensor:
assert isinstance(
@@ -687,7 +688,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
max_sequence_length=metadata.max_seq_len,
max_context_length=min(metadata.max_seq_len - 1,
metadata.max_num_tokens),
attention_window_size=None,
attention_window_size=attention_window_size,
sink_token_length=0,
beam_width=1,
sequence_length=metadata.kv_lens_cuda_runtime,
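The backend previously hardcoded `attention_window_size=None`; it is now threaded through from the caller so that individual layers can request a sliding window. A minimal sketch of the per-layer selection this enables, mirroring the fallback used by `Gemma3Attention.forward` further down:

from typing import Optional

def effective_window(layer_window: Optional[int], max_seq_len: int) -> int:
    # None (global attention) falls back to the full sequence length,
    # matching the fallback in Gemma3Attention.forward below.
    return layer_window or max_seq_len

assert effective_window(None, 4096) == 4096   # global layer: no windowing
assert effective_window(512, 4096) == 512     # sliding layer: 512-token window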

View File

@@ -4,6 +4,7 @@ from .modeling_auto import AutoModelForCausalLM
from .modeling_bert import BertForSequenceClassification
from .modeling_clip import CLIPVisionModel
from .modeling_deepseekv3 import DeepseekV3ForCausalLM
from .modeling_gemma3 import Gemma3ForCausalLM
from .modeling_llama import LlamaForCausalLM
from .modeling_llava_next import LlavaNextModel
from .modeling_mistral import MistralForCausalLM
@@ -27,6 +28,7 @@ __all__ = [
"BertForSequenceClassification",
"CLIPVisionModel",
"DeepseekV3ForCausalLM",
"Gemma3ForCausalLM",
"LlamaForCausalLM",
"LlavaNextModel",
"MistralForCausalLM",

View File

@@ -0,0 +1,378 @@
import math
from typing import Dict, Optional, Tuple
import torch
from torch import nn
from tqdm import tqdm
from transformers import Gemma3TextConfig
from transformers.activations import ACT2FN
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..modules.attention import Attention, QkNormType
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.linear import Linear, TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..pipeline_interface import PipelineInterface
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
duplicate_kv_weight, register_auto_model)
class Gemma3Attention(Attention):
def __init__(
self,
model_config: ModelConfig[Gemma3TextConfig],
layer_idx: Optional[int] = None,
is_sliding: bool = False,
):
self.is_sliding = is_sliding
config = model_config.pretrained_config
rope_params = RopeParams.from_config(config)
self.attention_window_size = None
if is_sliding:
rope_params.theta = 10000
self.attention_window_size = config.sliding_window
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=rope_params,
)
q_scaling = math.sqrt(config.query_pre_attn_scalar) / math.sqrt(
config.head_dim)
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=False,
config=model_config,
qk_norm_type=QkNormType.pre_rope,
q_scaling=q_scaling,
)
self.q_norm = RMSNorm(hidden_size=config.head_dim,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.k_norm = RMSNorm(hidden_size=config.head_dim,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
def forward(
self,
position_ids: Optional[torch.LongTensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
attention_window_size = self.attention_window_size or attn_metadata.max_seq_len
return super().forward(position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=attention_mask,
mrope_config=mrope_config,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
attention_window_size=attention_window_size,
**kwargs)
def apply_qk_norm(self, q, k):
def q_l2norm():
return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
def k_l2norm():
return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
return q, k
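A note on `q_scaling`: assuming the attention backend applies the conventional softmax scale `1 / (q_scaling * sqrt(head_dim))`, setting `q_scaling = sqrt(query_pre_attn_scalar) / sqrt(head_dim)` yields an effective scale of `1 / sqrt(query_pre_attn_scalar)`, which matches HuggingFace Gemma3's query scaling. A quick check with illustrative numbers:

import math

# Illustrative values; read the real ones from the HF config.
head_dim = 256
query_pre_attn_scalar = 256

q_scaling = math.sqrt(query_pre_attn_scalar) / math.sqrt(head_dim)
effective_scale = 1.0 / (q_scaling * math.sqrt(head_dim))

# Matches HF Gemma3's query scaling of query_pre_attn_scalar ** -0.5.
assert math.isclose(effective_scale, query_pre_attn_scalar ** -0.5)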
class Gemma3MLP(nn.Module):
def __init__(self, config: Gemma3TextConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.dtype = config.torch_dtype
self.gate_proj = Linear(self.hidden_size,
self.intermediate_size,
bias=False,
dtype=self.dtype)
self.up_proj = Linear(self.hidden_size,
self.intermediate_size,
bias=False,
dtype=self.dtype)
self.down_proj = Linear(self.intermediate_size,
self.hidden_size,
bias=False,
dtype=self.dtype)
self.act_fn = ACT2FN[config.hidden_activation]
def forward(self, x):
down_proj = self.down_proj(
self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class Gemma3DecoderLayer(DecoderLayer):
def __init__(
self,
model_config: ModelConfig[Gemma3TextConfig],
layer_idx: int,
):
super().__init__()
self.layer_idx = layer_idx
config = model_config.pretrained_config
is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
self.self_attn = Gemma3Attention(
model_config,
layer_idx=layer_idx,
is_sliding=is_sliding,
)
self.mlp = Gemma3MLP(config)
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.pre_feedforward_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.post_feedforward_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
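The forward path above is Gemma3's sandwich normalization (norms both before and after attention and the MLP), and the `is_sliding` expression in `__init__` determines the layer mix: every `sliding_window_pattern`-th layer uses global attention, the rest use a sliding window. A small sketch with illustrative values (the real pattern and layer count come from the HF config):

from typing import List

def layer_kinds(num_layers: int, sliding_window_pattern: int) -> List[str]:
    # Mirrors is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern).
    return [
        "sliding" if (i + 1) % sliding_window_pattern else "global"
        for i in range(num_layers)
    ]

# With an illustrative pattern of 6: every sixth layer is global, the rest slide.
kinds = layer_kinds(num_layers=12, sliding_window_pattern=6)
assert kinds[5] == "global" and kinds[11] == "global"
assert all(k == "sliding" for i, k in enumerate(kinds) if (i + 1) % 6)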
class Gemma3TextModel(DecoderModel):
def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
super().__init__(model_config)
config = self.model_config
self.hidden_size = config.pretrained_config.hidden_size
self.padding_idx = config.pretrained_config.pad_token_id
self.embed_tokens = Embedding(
config.pretrained_config.vocab_size,
config.pretrained_config.hidden_size,
dtype=config.pretrained_config.torch_dtype,
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
self.layers = nn.ModuleList([
Gemma3DecoderLayer(
model_config,
layer_idx,
) for layer_idx in range(config.pretrained_config.num_hidden_layers)
])
self.norm = RMSNorm(hidden_size=config.pretrained_config.hidden_size,
eps=config.pretrained_config.rms_norm_eps,
dtype=config.pretrained_config.torch_dtype)
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = inputs_embeds * math.sqrt(self.hidden_size)
hidden_states = inputs_embeds.to(self.dtype)
for decoder_layer in self.layers:
hidden_states = decoder_layer(position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata)
hidden_states = self.norm(hidden_states)
return hidden_states
@register_auto_model("Gemma3ForCausalLM")
class Gemma3ForCausalLM(DecoderModelForCausalLM[Gemma3TextModel,
Gemma3TextConfig]):
def __init__(
self,
model_config: ModelConfig[Gemma3TextConfig],
):
super().__init__(Gemma3TextModel(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pipeline_interface: Optional[PipelineInterface] = None,
return_context_logits: bool = False,
**kwargs,
) -> torch.Tensor:
if self._supports_pp and self.pp_size > 1:
output = self.model(
input_ids=input_ids,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
pipeline_interface=pipeline_interface,
)
# No need to compute logits for non-last PP ranks
if self.pp_rank < self.pp_size - 1:
return output
else:
output = self.model(
input_ids=input_ids,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
)
return self.logits_processor.forward(
output,
self.lm_head,
attn_metadata,
return_context_logits,
)
# This is a modified version of the load_weights function in modeling_utils.py, with a
# minor change for the Gemma3 RMSNorm (+1 weight offset).
def load_weights(self, weights: Dict):
tp_size = self.model_config.mapping.tp_size
head_dim = getattr(
self.config, "head_dim",
self.config.hidden_size // self.config.num_attention_heads)
def filter_weights(prefix, weights: Dict):
result = {}
for k, v in weights.items():
if k.startswith(prefix):
new_k = k[len(prefix) + 1:]
result[new_k] = v
return result
params_map = {
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
'gate_up_proj': ['gate_proj', 'up_proj']
}
for name, module in tqdm(list(self.named_modules()),
desc="Loading weights"):
if len(module._parameters) > 0:
# Skip loading weights if tie_word_embeddings is enabled and the layer is lm_head.
if self.config.tie_word_embeddings and name.startswith(
"lm_head"):
continue
# Skip loading weights for embedding and lm_head if LoRA is enabled.
if hasattr(
self.model_config, 'lora_config'
) and self.model_config.lora_config is not None and len(
self.model_config.lora_config.lora_dir) == 1 and (
name == "model.embed_tokens" or name == "lm_head"):
continue
names = name.split('.')
if names[-1] in params_map:
module_weights = []
for new_name in params_map[names[-1]]:
fw = filter_weights('.'.join(names[:-1] + [new_name]),
weights)
if new_name in ['k_proj', 'v_proj']:
fw = {
k:
duplicate_kv_weight(
weight=v[:],
head_dim=head_dim,
tensor_parallel_size=tp_size)
if k in ["weight", "bias"] else v
for k, v in fw.items()
}
module_weights.append(fw)
module.load_weights(weights=module_weights)
else:
module_weights = filter_weights(name, weights)
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
# Gemma3 RMSNorm applies a (1 + weight) gain, like LayerNorm-1P, so add 1 when loading.
if 'norm' in names[-1]:
p.data.copy_(module_weights[n][:] + 1)
else:
p.data.copy_(module_weights[n][:])
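The `+1` applied to norm weights bridges a representation difference: HuggingFace's Gemma RMSNorm computes `normed * (1 + weight)`, i.e. the checkpoint stores the gain minus one, while the `RMSNorm` module here multiplies by its weight directly. A quick equivalence check under that assumption:

import torch

def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 8)
ckpt_weight = torch.randn(8) * 0.1                     # gain offset as stored in the HF checkpoint

hf_style = rms_normalize(x) * (1.0 + ckpt_weight)      # HF Gemma RMSNorm: gain is (1 + w)
loaded_weight = ckpt_weight + 1.0                      # what load_weights copies into the module
trtllm_style = rms_normalize(x) * loaded_weight        # plain RMSNorm: gain is the stored weight

assert torch.allclose(hf_style, trtllm_style)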

View File

@@ -42,6 +42,7 @@ class Attention(nn.Module):
dense_bias: Optional[bool] = None,
config: Optional[ModelConfig] = None,
qk_norm_type: QkNormType = QkNormType.none,
q_scaling: float = 1.0,
):
super().__init__()
self.layer_idx = layer_idx
@@ -57,6 +58,7 @@
self.pos_embd_params = pos_embd_params
self.qk_norm_type = qk_norm_type
self.dense_bias = dense_bias
self.q_scaling = q_scaling
if dense_bias is None:
self.dense_bias = bias
@@ -108,6 +110,7 @@
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.o_lora,
)
self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
self.pos_embd_params = pos_embd_params
@@ -138,6 +141,7 @@
if self.enable_rope_fusion else None,
quant_config=self.quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
q_scaling=self.q_scaling,
)
self.support_fused_qkv = self.attn.support_fused_qkv()
@@ -183,6 +187,7 @@
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
attention_window_size: Optional[int] = None,
**kwargs,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
@@ -213,13 +218,15 @@
out_scale = self.o_proj.inv_input_scale
q, k, v = self.convert_qkv(q, k, v)
attn_output = self.attn.forward(q,
k,
v,
attn_metadata,
out_scale=out_scale,
attention_mask=attention_mask,
mrope_config=mrope_config)
attn_output = self.attn.forward(
q,
k,
v,
attn_metadata,
out_scale=out_scale,
attention_mask=attention_mask,
mrope_config=mrope_config,
attention_window_size=attention_window_size)
hidden_states = attn_output
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params,

View File

@@ -1,3 +1,5 @@
google/gemma-3-1b-it:
- accuracy: 22.988
gpt2:
- accuracy: 18.408
- quant_algo: W8A16

View File

@@ -252,6 +252,16 @@ class TestMistral7B(LlmapiAccuracyTestHarness):
task.evaluate(llm)
class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "google/gemma-3-1b-it"
MODEL_PATH = f"{llm_models_root()}/gemma/gemma-3-1b-it/"
def test_auto_dtype(self):
with LLM(self.MODEL_PATH) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
class TestMixtral8x7B(LlmapiAccuracyTestHarness):
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
MODEL_PATH = f"{llm_models_root()}/Mixtral-8x7B-v0.1"

View File

@@ -426,6 +426,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
accuracy/test_llm_api_pytorch.py::TestMistral7B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8-cuda_graph=False]
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8ep4-cuda_graph=True]
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8ep8-cuda_graph=True]

View File

@@ -20,6 +20,7 @@ l0_h100:
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modeling -k "modeling_mixtral"
- unittest/_torch/modeling -k "modeling_nemotron"
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=FLASHINFER-torch_compile=False]