[None][chore] Add unit test for Gemma3 lora (#6560)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
brb-nv 2025-08-04 01:56:57 -07:00 committed by GitHub
parent 3916dbd98b
commit 87e4e9f468
3 changed files with 62 additions and 12 deletions


@@ -14,7 +14,6 @@ from ..attention_backend import AttentionMetadata, FlashInferAttentionMetadata
from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
@@ -105,9 +104,6 @@ class Gemma3Attention(Attention):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
attention_mask_data: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -121,9 +117,6 @@ class Gemma3Attention(Attention):
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=attention_mask,
mrope_config=mrope_config,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
attention_window_size=self.attention_window_size,
attention_mask_data=attention_mask_data,
**kwargs)
@@ -209,7 +202,6 @@ class Gemma3DecoderLayer(DecoderLayer):
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
attention_mask_data: Optional[torch.Tensor] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
@@ -222,14 +214,14 @@
attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
is not None else PredefinedAttentionMask.CAUSAL,
attention_mask_data=attention_mask_data,
lora_params=lora_params,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states, lora_params=lora_params)
hidden_states = self.mlp(hidden_states,
lora_params=kwargs.get("lora_params", None))
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
@@ -270,7 +262,6 @@ class Gemma3TextModel(DecoderModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
local_attention_mask_data: Optional[torch.Tensor] = None,
global_attention_mask_data: Optional[torch.Tensor] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -291,7 +282,7 @@ class Gemma3TextModel(DecoderModel):
attention_mask_data=local_attention_mask_data
if decoder_layer.self_attn.is_sliding else
global_attention_mask_data,
lora_params=lora_params,
**kwargs,
)
hidden_states = self.norm(hidden_states)
@@ -465,6 +456,7 @@ class Gemma3ForCausalLM(DecoderModelForCausalLM[Gemma3TextModel,
inputs_embeds=inputs_embeds,
local_attention_mask_data=local_attention_mask_data,
global_attention_mask_data=global_attention_mask_data,
**kwargs,
)
return self.logits_processor.forward(

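Taken together, the changes above drop the explicitly threaded mrope_config, all_reduce_params, and lora_params arguments from Gemma3's attention, decoder layer, and text model; lora_params now rides along in **kwargs from Gemma3ForCausalLM.forward down to the attention and MLP calls (see kwargs.get("lora_params", None) in the decoder layer). Below is a minimal sketch of that forwarding pattern, using illustrative Toy* classes rather than the real TRT-LLM modules:

# Sketch only: toy stand-ins for the **kwargs-forwarding pattern, not the TRT-LLM API.
from typing import Optional

import torch
import torch.nn as nn


class ToyAttention(nn.Module):

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # Optional extras such as lora_params are read out only where needed.
        lora_params: Optional[dict] = kwargs.get("lora_params", None)
        return hidden_states if lora_params is None else hidden_states + 1.0


class ToyDecoderLayer(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.self_attn = ToyAttention()

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # The layer no longer names lora_params in its signature; it simply
        # forwards whatever extra keyword arguments it received.
        return self.self_attn(hidden_states, **kwargs)


x = torch.zeros(1, 4)
layer = ToyDecoderLayer()
out_base = layer(x)                           # no LoRA arguments supplied
out_lora = layer(x, lora_params={"rank": 8})  # threaded through **kwargs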

@@ -576,6 +576,7 @@ test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora
examples/test_medusa.py::test_codellama_medusa_1gpu[CodeLlama-7b-Instruct]
examples/test_medusa.py::test_mistral_medusa_1gpu[mistral-7b-v0.1]
examples/test_medusa.py::test_qwen_medusa_1gpu[qwen_7b_chat]

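The entry above registers the new unit test in this pre-merge test list. As a hedged local-run example (the tests/ prefix is an assumption about the repository layout, inferred from the entry above), the test can be invoked on its own through pytest's Python entry point:

# Assumed path; adjust to wherever test_llm_pytorch.py lives in your checkout.
import pytest

pytest.main([
    "tests/unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora",
    "-v",
])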

@@ -1,6 +1,7 @@
import pytest
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from tensorrt_llm.sampling_params import SamplingParams
@@ -499,6 +500,62 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
assert len(outputs) == 2


def test_gemma3_1b_instruct_multi_lora() -> None:
    model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it"
    target_modules = ['attn_q', 'attn_k', 'attn_v']

    # Set up temporary directory for LoRA adapters
    with tempfile.TemporaryDirectory() as lora_dir:
        print("Creating dummy LoRAs...")
        model = AutoModelForCausalLM.from_pretrained(model_dir,
                                                     torch_dtype=torch.bfloat16,
                                                     device_map="auto")
        hf_modules = ["q_proj", "k_proj", "v_proj"]
        peft_lora_config = PeftLoraConfig(r=8,
                                          target_modules=hf_modules,
                                          bias="none",
                                          task_type="CAUSAL_LM")

        lora_paths = []
        for i in range(2):
            lora_model = get_peft_model(model, peft_lora_config)
            for param in lora_model.parameters():
                param.data.zero_()
            lora_path = f"{lora_dir}/lora_{i}"
            lora_model.save_pretrained(lora_path)
            lora_paths.append(lora_path)

        trtllm_lora_config = LoraConfig(lora_dir=lora_paths,
                                        lora_target_modules=target_modules,
                                        max_lora_rank=8,
                                        max_loras=2,
                                        max_cpu_loras=2)

        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support
        # for Gemma3's non-inclusive sliding window size.
        kv_cache_config = KvCacheConfig(
            enable_block_reuse=False,
            enable_partial_reuse=False,
        )
        llm = LLM(model_dir,
                  lora_config=trtllm_lora_config,
                  kv_cache_config=kv_cache_config)

        prompts = [
            "Is it ok to fill diesel in a petrol car?",
            "What is the capital of France?",
        ]
        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
        lora_requests = [lora_req1, lora_req2]
        sampling_params = SamplingParams(max_tokens=200)

        outputs = llm.generate(prompts,
                               sampling_params,
                               lora_request=lora_requests)
        assert len(outputs) == 2


@pytest.mark.parametrize(
    "lora_rank,max_lora_rank,description",
    [