Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][chore] Add unit test for Gemma3 lora (#6560)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
parent 3916dbd98b
commit 87e4e9f468
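The new case can presumably be run on its own with pytest, e.g. pytest tests/unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora (the tests/unittest prefix is an assumption based on the test-list entry below), provided the Gemma3 1B checkpoint is available under llm_models_root().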
@@ -14,7 +14,6 @@ from ..attention_backend import AttentionMetadata, FlashInferAttentionMetadata
from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
                                           PositionalEmbeddingParams,
                                           PredefinedAttentionMask, RopeParams)
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
@@ -105,9 +104,6 @@ class Gemma3Attention(Attention):
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
        mrope_config: Optional[dict] = None,
        all_reduce_params: Optional[AllReduceParams] = None,
        lora_params: Optional[dict] = None,
        attention_mask_data: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
@@ -121,9 +117,6 @@ class Gemma3Attention(Attention):
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            attention_mask=attention_mask,
            mrope_config=mrope_config,
            all_reduce_params=all_reduce_params,
            lora_params=lora_params,
            attention_window_size=self.attention_window_size,
            attention_mask_data=attention_mask_data,
            **kwargs)
@@ -209,7 +202,6 @@ class Gemma3DecoderLayer(DecoderLayer):
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor] = None,
        attention_mask_data: Optional[torch.Tensor] = None,
        lora_params: Optional[dict] = None,
        **kwargs,
    ) -> torch.Tensor:

@@ -222,14 +214,14 @@ class Gemma3DecoderLayer(DecoderLayer):
            attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
            is not None else PredefinedAttentionMask.CAUSAL,
            attention_mask_data=attention_mask_data,
            lora_params=lora_params,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states, lora_params=lora_params)
        hidden_states = self.mlp(hidden_states,
                                 lora_params=kwargs.get("lora_params", None))
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

@@ -270,7 +262,6 @@ class Gemma3TextModel(DecoderModel):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        local_attention_mask_data: Optional[torch.Tensor] = None,
        global_attention_mask_data: Optional[torch.Tensor] = None,
        lora_params: Optional[dict] = None,
        **kwargs,
    ) -> torch.Tensor:
        if (input_ids is None) ^ (inputs_embeds is not None):
@@ -291,7 +282,7 @@ class Gemma3TextModel(DecoderModel):
                attention_mask_data=local_attention_mask_data
                if decoder_layer.self_attn.is_sliding else
                global_attention_mask_data,
                lora_params=lora_params,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
@@ -465,6 +456,7 @@ class Gemma3ForCausalLM(DecoderModelForCausalLM[Gemma3TextModel,
            inputs_embeds=inputs_embeds,
            local_attention_mask_data=local_attention_mask_data,
            global_attention_mask_data=global_attention_mask_data,
            **kwargs,
        )

        return self.logits_processor.forward(
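Several of the hunks above route LoRA parameters through **kwargs rather than a dedicated argument at every level (note the kwargs.get("lora_params", None) in the MLP call). A minimal standalone sketch of that pass-through pattern, using toy classes rather than the real TensorRT-LLM modules:

# Toy illustration (not the TensorRT-LLM classes) of passing lora_params through
# **kwargs: the caller attaches lora_params once, and each layer either forwards
# **kwargs unchanged or pulls the value out where it is actually consumed.
from typing import Optional


class ToyMLP:

    def forward(self,
                hidden_states: list,
                lora_params: Optional[dict] = None) -> list:
        # A real MLP would apply the LoRA delta here when lora_params is set.
        return hidden_states


class ToyDecoderLayer:

    def __init__(self) -> None:
        self.mlp = ToyMLP()

    def forward(self, hidden_states: list, **kwargs) -> list:
        # The layer itself declares no lora_params argument; it reads the value
        # from kwargs when calling the MLP, mirroring the diff above.
        return self.mlp.forward(hidden_states,
                                lora_params=kwargs.get("lora_params", None))


layer = ToyDecoderLayer()
out = layer.forward([0.0, 1.0], lora_params={"adapter_id": 0})

Only the module that actually consumes lora_params has to name it; intermediate forward methods simply forward **kwargs unchanged.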
@@ -576,6 +576,7 @@ test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora
examples/test_medusa.py::test_codellama_medusa_1gpu[CodeLlama-7b-Instruct]
examples/test_medusa.py::test_mistral_medusa_1gpu[mistral-7b-v0.1]
examples/test_medusa.py::test_qwen_medusa_1gpu[qwen_7b_chat]
@@ -1,6 +1,7 @@
import pytest

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from tensorrt_llm.sampling_params import SamplingParams
@@ -499,6 +500,62 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
    assert len(outputs) == 2


def test_gemma3_1b_instruct_multi_lora() -> None:
    model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it"

    target_modules = ['attn_q', 'attn_k', 'attn_v']

    # Set up temporary directory for LoRA adapters
    with tempfile.TemporaryDirectory() as lora_dir:
        print("Creating dummy LoRAs...")

        model = AutoModelForCausalLM.from_pretrained(model_dir,
                                                     torch_dtype=torch.bfloat16,
                                                     device_map="auto")
        hf_modules = ["q_proj", "k_proj", "v_proj"]
        peft_lora_config = PeftLoraConfig(r=8,
                                          target_modules=hf_modules,
                                          bias="none",
                                          task_type="CAUSAL_LM")
        lora_paths = []
        for i in range(2):
            lora_model = get_peft_model(model, peft_lora_config)
            for param in lora_model.parameters():
                param.data.zero_()
            lora_path = f"{lora_dir}/lora_{i}"
            lora_model.save_pretrained(lora_path)
            lora_paths.append(lora_path)

        trtllm_lora_config = LoraConfig(lora_dir=lora_paths,
                                        lora_target_modules=target_modules,
                                        max_lora_rank=8,
                                        max_loras=2,
                                        max_cpu_loras=2)
        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
        kv_cache_config = KvCacheConfig(
            enable_block_reuse=False,
            enable_partial_reuse=False,
        )
        llm = LLM(model_dir,
                  lora_config=trtllm_lora_config,
                  kv_cache_config=kv_cache_config)

        prompts = [
            "Is it ok to fill diesel in a petrol car?",
            "What is the capital of France?",
        ]
        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
        lora_requests = [lora_req1, lora_req2]
        sampling_params = SamplingParams(max_tokens=200)

        outputs = llm.generate(prompts,
                               sampling_params,
                               lora_request=lora_requests)

        assert len(outputs) == 2


@pytest.mark.parametrize(
    "lora_rank,max_lora_rank,description",
    [
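A side note on the "dummy LoRAs" created in the test above: all parameters are zeroed before the adapters are saved, so the saved LoRA factors are all zero, the low-rank correction B @ A vanishes, and the adapted projections reduce to the base weights. A small illustration of that in plain PyTorch (not part of the commit):

# Illustration only: with zeroed LoRA factors the low-rank correction is exactly zero.
import torch

torch.manual_seed(0)
W = torch.randn(16, 16)   # frozen base projection weight
A = torch.zeros(8, 16)    # LoRA down-projection (rank 8), zeroed like in the test
B = torch.zeros(16, 8)    # LoRA up-projection, zeroed as well
x = torch.randn(4, 16)

base_out = x @ W.T
lora_out = x @ W.T + (x @ A.T) @ B.T  # correction term contributes nothing
assert torch.equal(base_out, lora_out)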