[TRTLLM-6825][fix] Update lora for phi4-mm (#6817)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Wanli Jiang 2025-08-22 10:00:04 +08:00 committed by GitHub
parent c5036cb536
commit 07c711eb1f
8 changed files with 34 additions and 26 deletions

@@ -611,23 +611,21 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
@staticmethod
def lora_config(model_dir: str):
_lora_config = LoraConfig(
lora_dir=[
f"{model_dir}/vision-lora",
f"{model_dir}/speech-lora",
],
lora_target_modules=[
"attn_qkv",
"attn_dense",
"mlp_h_to_4h",
"mlp_gate_up",
"mlp_4h_to_h",
],
trtllm_modules_to_hf_modules={
"attn_qkv": "qkv_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_up_proj",
"mlp_gate_up": "gate_up_proj",
"mlp_4h_to_h": "down_proj",
},
max_lora_rank=320, # Max rank for Phi4MM.
swap_gate_up_proj_lora_b_weight=
False, # Disable swap gate_up_proj.lora_B.weight for Phi4MM.
)
return _lora_config
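For context, here is a minimal usage sketch of how this per-model LoRA config typically reaches the runtime. It is not part of this diff; the import paths, the checkpoint location, and the lora_config keyword of the LLM constructor are assumptions for illustration.

# Hypothetical sketch: reuse the Phi4MM LoRA config above with the LLM API.
# Import paths, the checkpoint location, and the `lora_config=` keyword are
# illustrative assumptions, not something this commit adds.
from tensorrt_llm import LLM
from tensorrt_llm._torch.models.modeling_phi4mm import Phi4MMForCausalLM  # path assumed

model_dir = "/models/Phi-4-multimodal-instruct"  # assumed local checkpoint path

# Picks up the vision/speech adapters, max_lora_rank=320, and the new
# swap_gate_up_proj_lora_b_weight=False setting from the method above.
lora_config = Phi4MMForCausalLM.lora_config(model_dir)

llm = LLM(model=model_dir, lora_config=lora_config)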

@@ -514,7 +514,8 @@ def create_py_executor_instance(
resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
model_engine.set_lora_model_config(
lora_config.lora_target_modules,
lora_config.trtllm_modules_to_hf_modules)
lora_config.trtllm_modules_to_hf_modules,
lora_config.swap_gate_up_proj_lora_b_weight)
max_num_sequences = executor_config.max_batch_size * mapping.pp_size

@@ -468,13 +468,16 @@ class PyTorchModelEngine(ModelEngine):
def runtime_draft_len(self):
return self.max_draft_len if self.enable_spec_decode else 0
def set_lora_model_config(self, lora_target_modules: list[str],
trtllm_modules_to_hf_modules: dict[str, str]):
def set_lora_model_config(self,
lora_target_modules: list[str],
trtllm_modules_to_hf_modules: dict[str, str],
swap_gate_up_proj_lora_b_weight: bool = True):
self.lora_model_config = LoraModelConfig(
lora_target_modules=lora_target_modules,
trtllm_modules_to_hf_modules=trtllm_modules_to_hf_modules,
hidden_size=self.model.config.hidden_size,
dtype=torch_dtype_to_str(self.model.config.torch_dtype))
dtype=torch_dtype_to_str(self.model.config.torch_dtype),
swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)
@property
def use_mrope(self):

@@ -1040,7 +1040,8 @@ class PeftCacheManager(BaseResourceManager):
self._lora_model_config = LoraModelConfig(
lora_config.lora_target_modules,
lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
binding_to_str_dtype(model_config.data_type))
binding_to_str_dtype(model_config.data_type),
lora_config.swap_gate_up_proj_lora_b_weight)
self._lora_manager = LoraManager()
def add_request_peft(self, request: LlmRequest):

@@ -88,6 +88,7 @@ class LoraConfig(DictConversion):
trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
max_loras: Optional[int] = None
max_cpu_loras: Optional[int] = None
swap_gate_up_proj_lora_b_weight: bool = True
def __post_init__(self):
assert self.lora_ckpt_source in [
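The new field defaults to True, so models that never set it keep today's behavior; only configs that opt out explicitly (as the Phi4MM config above does) skip the swap. A small sketch under that assumption, with an assumed import path:

# Default-vs-override behavior of the new LoraConfig field; the import path
# is an assumption, the field itself is what this commit adds.
from tensorrt_llm.lora_manager import LoraConfig  # path assumed

default_cfg = LoraConfig(lora_dir=["/path/to/adapter"])
assert default_cfg.swap_gate_up_proj_lora_b_weight is True  # unchanged default

phi4mm_cfg = LoraConfig(
    lora_dir=["/path/to/vision-lora", "/path/to/speech-lora"],
    max_lora_rank=320,
    swap_gate_up_proj_lora_b_weight=False,  # Phi4MM opts out of the swap
)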

@@ -243,6 +243,7 @@ class LoraModelConfig:
trtllm_modules_to_hf_modules: dict[str, str]
hidden_size: int
dtype: str
swap_gate_up_proj_lora_b_weight: bool = True
class HfLoraLoader:
@@ -968,16 +969,17 @@ class LoraManager(object):
)
hf_modules = set(hf_modules_to_trtllm_modules.keys())
def preprocess_lora_weights(lora_model):
def preprocess_lora_weights(lora_model, model_config):
# Swap weights of gate_up_proj
for key, value in lora_model.items():
if "gate_up_proj.lora_B.weight" in key:
original_weights = value.contiguous().clone()
half_split = original_weights.shape[0] // 2
first_half = original_weights[:half_split, :]
second_half = original_weights[half_split:, :]
value = torch.cat((second_half, first_half), dim=0)
lora_model[key] = value
if getattr(model_config, "swap_gate_up_proj_lora_b_weight", True):
for key, value in lora_model.items():
if "gate_up_proj.lora_B.weight" in key:
original_weights = value.contiguous().clone()
half_split = original_weights.shape[0] // 2
first_half = original_weights[:half_split, :]
second_half = original_weights[half_split:, :]
value = torch.cat((second_half, first_half), dim=0)
lora_model[key] = value
return lora_model
def load_from_model_dir(uid, model_dir, hf_config):
@@ -989,7 +991,7 @@ class LoraManager(object):
lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
if lora_model is None:
raise ValueError(f"Failed to load adapter_model from {model_dir}")
lora_model = preprocess_lora_weights(lora_model)
lora_model = preprocess_lora_weights(lora_model, model_config)
all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
rank = int(hf_config["r"])
rs_lora = bool(hf_config.get("use_rslora", False))
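For readers unfamiliar with the behavior being toggled here: by default the loader swaps the two halves of every gate_up_proj.lora_B weight so that HF adapters line up with the fused gate/up layout the runtime expects; Phi-4-MM's adapters evidently already match that layout, so the commit makes the swap conditional and turns it off for Phi4MM. Below is a standalone toy sketch of the transformation; shapes and values are made up, and only the halving and concatenation mirror preprocess_lora_weights above.

# Toy demonstration of the swap controlled by swap_gate_up_proj_lora_b_weight.
import torch

lora_b = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # rows 0-1 = first half, rows 2-3 = second half

half = lora_b.shape[0] // 2
swapped = torch.cat((lora_b[half:, :], lora_b[:half, :]), dim=0)

print(lora_b[:, 0].tolist())   # [0.0, 2.0, 4.0, 6.0]
print(swapped[:, 0].tolist())  # [4.0, 6.0, 0.0, 2.0] -> halves exchanged

# With swap_gate_up_proj_lora_b_weight=False (the Phi4MM setting), the weight
# is left exactly as stored in the adapter checkpoint.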

@@ -191,15 +191,17 @@ def get_model_yaml_config(model_label: str,
}
if 'phi_4_multimodal_instruct' in model_label:
lora_config['lora_config']['lora_target_modules'] = [
"attn_qkv", "attn_dense", "mlp_h_to_4h", "mlp_4h_to_h"
"attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"
]
lora_config['lora_config']['trtllm_modules_to_hf_modules'] = {
"attn_qkv": "qkv_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_up_proj",
"mlp_gate_up": "gate_up_proj",
"mlp_4h_to_h": "down_proj"
}
lora_config['lora_config']['max_lora_rank'] = 320
lora_config['lora_config'][
'swap_gate_up_proj_lora_b_weight'] = False
base_config.update(lora_config)
kv_cache_config = base_config.get('kv_cache_config', KvCacheConfig())
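Taken together, the test helper now merges a phi_4_multimodal_instruct override equivalent to the following Python dict (values copied from the hunk above; the surrounding YAML plumbing is unchanged):

# Shape of the lora_config override merged into base_config for
# phi_4_multimodal_instruct after this change.
phi4mm_lora_overrides = {
    "lora_config": {
        "lora_target_modules": ["attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"],
        "trtllm_modules_to_hf_modules": {
            "attn_qkv": "qkv_proj",
            "attn_dense": "o_proj",
            "mlp_gate_up": "gate_up_proj",
            "mlp_4h_to_h": "down_proj",
        },
        "max_lora_rank": 320,
        "swap_gate_up_proj_lora_b_weight": False,
    }
}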

@@ -2404,15 +2404,15 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
}
expected_keywords = {
"image": [
["image", "depicts", "mountain", "half", "rock"],
["road", "car", "lane", "traffic", "bus"],
["object", "mountain", "weather", "clear", "clouds"],
["traffic", "road", "vehicles", "cars", "bus"],
],
"audio": [
["what", "is", "the", "traffic", "sign", "in", "image"],
["what", "is", "shown", "in", "this", "image"],
],
"image_audio": [
["image", "depicts", "Grand", "rock", "scene"],
["image", "depicts", "scenic", "famous", "landmark"],
],
}