Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[TRTLLM-6825][fix] Update lora for phi4-mm (#6817)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
This commit is contained in:
parent c5036cb536
commit 07c711eb1f
@@ -611,23 +611,21 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
     @staticmethod
     def lora_config(model_dir: str):
         _lora_config = LoraConfig(
             lora_dir=[
                 f"{model_dir}/vision-lora",
                 f"{model_dir}/speech-lora",
             ],
             lora_target_modules=[
                 "attn_qkv",
                 "attn_dense",
-                "mlp_h_to_4h",
+                "mlp_gate_up",
                 "mlp_4h_to_h",
             ],
             trtllm_modules_to_hf_modules={
                 "attn_qkv": "qkv_proj",
                 "attn_dense": "o_proj",
-                "mlp_h_to_4h": "gate_up_proj",
+                "mlp_gate_up": "gate_up_proj",
                 "mlp_4h_to_h": "down_proj",
             },
             max_lora_rank=320,  # Max rank for Phi4MM.
+            swap_gate_up_proj_lora_b_weight=
+            False,  # Disable swap gate_up_proj.lora_B.weight for Phi4MM.
         )
         return _lora_config
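The Phi4MM helper now targets `mlp_gate_up` (mapped to HF `gate_up_proj`) instead of `mlp_h_to_4h`, caps the adapter rank at 320, and opts out of the gate/up weight swap. The two `lora_dir` entries follow the `vision-lora`/`speech-lora` layout of the Hugging Face Phi-4-multimodal-instruct checkpoint, and every target module needs a counterpart in `trtllm_modules_to_hf_modules`. A minimal standalone sketch of that invariant (plain Python, no TensorRT-LLM import; `check_mapping` is a hypothetical helper, not library API):

```python
# Standalone sketch: mirror the Phi4MM LoRA module mapping from the hunk above
# and sanity-check it. `check_mapping` is a hypothetical helper, not TRT-LLM API.
LORA_TARGET_MODULES = ["attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"]
TRTLLM_TO_HF = {
    "attn_qkv": "qkv_proj",
    "attn_dense": "o_proj",
    "mlp_gate_up": "gate_up_proj",
    "mlp_4h_to_h": "down_proj",
}


def check_mapping(targets: list[str], mapping: dict[str, str]) -> None:
    """Every LoRA target module must map to an HF module name to load from."""
    missing = [m for m in targets if m not in mapping]
    if missing:
        raise ValueError(f"No HF mapping for TRT-LLM modules: {missing}")


check_mapping(LORA_TARGET_MODULES, TRTLLM_TO_HF)
print("Phi4MM LoRA module mapping is consistent.")
```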
@@ -514,7 +514,8 @@ def create_py_executor_instance(
         resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
         model_engine.set_lora_model_config(
             lora_config.lora_target_modules,
-            lora_config.trtllm_modules_to_hf_modules)
+            lora_config.trtllm_modules_to_hf_modules,
+            lora_config.swap_gate_up_proj_lora_b_weight)

     max_num_sequences = executor_config.max_batch_size * mapping.pp_size
@@ -468,13 +468,16 @@ class PyTorchModelEngine(ModelEngine):
     def runtime_draft_len(self):
         return self.max_draft_len if self.enable_spec_decode else 0

-    def set_lora_model_config(self, lora_target_modules: list[str],
-                              trtllm_modules_to_hf_modules: dict[str, str]):
+    def set_lora_model_config(self,
+                              lora_target_modules: list[str],
+                              trtllm_modules_to_hf_modules: dict[str, str],
+                              swap_gate_up_proj_lora_b_weight: bool = True):
         self.lora_model_config = LoraModelConfig(
             lora_target_modules=lora_target_modules,
             trtllm_modules_to_hf_modules=trtllm_modules_to_hf_modules,
             hidden_size=self.model.config.hidden_size,
-            dtype=torch_dtype_to_str(self.model.config.torch_dtype))
+            dtype=torch_dtype_to_str(self.model.config.torch_dtype),
+            swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)

     @property
     def use_mrope(self):
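The new keyword argument defaults to `True` on both the engine method and the config classes, so existing callers keep the old swap behavior and only Phi4MM opts out. A stripped-down sketch of the plumbing, using hypothetical stand-in classes rather than the real `PyTorchModelEngine` and `LoraModelConfig` (the `hidden_size` and `dtype` values are made up):

```python
# Stripped-down sketch of the plumbing; ToyModelEngine/ToyLoraModelConfig are
# hypothetical stand-ins, not the real PyTorchModelEngine/LoraModelConfig.
from dataclasses import dataclass


@dataclass
class ToyLoraModelConfig:
    lora_target_modules: list[str]
    trtllm_modules_to_hf_modules: dict[str, str]
    hidden_size: int
    dtype: str
    swap_gate_up_proj_lora_b_weight: bool = True  # default keeps old behavior


class ToyModelEngine:

    def __init__(self, hidden_size: int, dtype: str):
        self.hidden_size, self.dtype = hidden_size, dtype
        self.lora_model_config = None

    def set_lora_model_config(self,
                              lora_target_modules: list[str],
                              trtllm_modules_to_hf_modules: dict[str, str],
                              swap_gate_up_proj_lora_b_weight: bool = True):
        # Mirror of the real method: forward the flag into the model config.
        self.lora_model_config = ToyLoraModelConfig(
            lora_target_modules=lora_target_modules,
            trtllm_modules_to_hf_modules=trtllm_modules_to_hf_modules,
            hidden_size=self.hidden_size,
            dtype=self.dtype,
            swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)


engine = ToyModelEngine(hidden_size=3072, dtype="bfloat16")  # made-up values
engine.set_lora_model_config(["mlp_gate_up"], {"mlp_gate_up": "gate_up_proj"},
                             swap_gate_up_proj_lora_b_weight=False)
assert engine.lora_model_config.swap_gate_up_proj_lora_b_weight is False
```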
@@ -1040,7 +1040,8 @@ class PeftCacheManager(BaseResourceManager):
         self._lora_model_config = LoraModelConfig(
             lora_config.lora_target_modules,
             lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
-            binding_to_str_dtype(model_config.data_type))
+            binding_to_str_dtype(model_config.data_type),
+            lora_config.swap_gate_up_proj_lora_b_weight)
         self._lora_manager = LoraManager()

     def add_request_peft(self, request: LlmRequest):
@@ -88,6 +88,7 @@ class LoraConfig(DictConversion):
     trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
     max_loras: Optional[int] = None
     max_cpu_loras: Optional[int] = None
+    swap_gate_up_proj_lora_b_weight: bool = True

     def __post_init__(self):
         assert self.lora_ckpt_source in [
@@ -243,6 +243,7 @@ class LoraModelConfig:
    trtllm_modules_to_hf_modules: dict[str, str]
    hidden_size: int
    dtype: str
+    swap_gate_up_proj_lora_b_weight: bool = True


class HfLoraLoader:
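Because `LoraConfig` derives from `DictConversion`, the flag can also be flipped from a dict or YAML override, which is how the perf-test configuration further below disables it. A rough sketch of that override path, with a hypothetical `ToyLoraConfig.from_dict` standing in for what `DictConversion` provides:

```python
# Sketch of flipping the new default-True field from a dict/YAML override.
# ToyLoraConfig.from_dict only mimics what DictConversion provides; it is not
# the real LoraConfig class.
from dataclasses import dataclass, field, fields
from typing import Dict, List


@dataclass
class ToyLoraConfig:
    lora_target_modules: List[str] = field(default_factory=list)
    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
    max_lora_rank: int = 64  # made-up default for this toy class
    swap_gate_up_proj_lora_b_weight: bool = True  # new field, defaults to True

    @classmethod
    def from_dict(cls, config: dict) -> "ToyLoraConfig":
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in config.items() if k in known})


cfg = ToyLoraConfig.from_dict({
    "lora_target_modules": ["attn_qkv", "mlp_gate_up"],
    "max_lora_rank": 320,
    "swap_gate_up_proj_lora_b_weight": False,  # Phi4MM-style override
})
assert cfg.swap_gate_up_proj_lora_b_weight is False
```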
@@ -968,16 +969,17 @@ class LoraManager(object):
             )
             hf_modules = set(hf_modules_to_trtllm_modules.keys())

-        def preprocess_lora_weights(lora_model):
+        def preprocess_lora_weights(lora_model, model_config):
             # Swap weights of gate_up_proj
-            for key, value in lora_model.items():
-                if "gate_up_proj.lora_B.weight" in key:
-                    original_weights = value.contiguous().clone()
-                    half_split = original_weights.shape[0] // 2
-                    first_half = original_weights[:half_split, :]
-                    second_half = original_weights[half_split:, :]
-                    value = torch.cat((second_half, first_half), dim=0)
-                    lora_model[key] = value
+            if getattr(model_config, "swap_gate_up_proj_lora_b_weight", True):
+                for key, value in lora_model.items():
+                    if "gate_up_proj.lora_B.weight" in key:
+                        original_weights = value.contiguous().clone()
+                        half_split = original_weights.shape[0] // 2
+                        first_half = original_weights[:half_split, :]
+                        second_half = original_weights[half_split:, :]
+                        value = torch.cat((second_half, first_half), dim=0)
+                        lora_model[key] = value
             return lora_model

         def load_from_model_dir(uid, model_dir, hf_config):
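By default the loader still reorders the two output halves of every `gate_up_proj.lora_B.weight` (HF checkpoints fuse the gate and up projections into one tensor, and the swap presumably matches the fused ordering the runtime expects for most models); the `getattr(..., True)` fallback preserves that behavior when the attribute is missing, while Phi4MM sets the flag to `False` so its adapters load as stored. Below is an isolated demonstration of just the tensor operation, with a made-up shape:

```python
# Isolated demonstration of the gate_up_proj lora_B row swap shown above; the
# tensor shape here is made up, only the swap logic matches the loader.
import torch


def swap_gate_up_halves(weight: torch.Tensor) -> torch.Tensor:
    """Swap the first and second halves of lora_B's output rows."""
    half = weight.shape[0] // 2
    return torch.cat((weight[half:, :], weight[:half, :]), dim=0)


lora_b = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # (2 * inter, rank)
swapped = swap_gate_up_halves(lora_b)
# The swap is its own inverse: applying it twice restores the original tensor.
assert torch.equal(swap_gate_up_halves(swapped), lora_b)
```

With `swap_gate_up_proj_lora_b_weight=False`, as the Phi4MM config above sets, this preprocessing branch is skipped entirely.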
@@ -989,7 +991,7 @@ class LoraManager(object):
             lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
             if lora_model is None:
                 raise ValueError(f"Failed to load adapter_model from {model_dir}")
-            lora_model = preprocess_lora_weights(lora_model)
+            lora_model = preprocess_lora_weights(lora_model, model_config)
             all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
             rank = int(hf_config["r"])
             rs_lora = bool(hf_config.get("use_rslora", False))
@@ -191,15 +191,17 @@ def get_model_yaml_config(model_label: str,
         }
         if 'phi_4_multimodal_instruct' in model_label:
             lora_config['lora_config']['lora_target_modules'] = [
-                "attn_qkv", "attn_dense", "mlp_h_to_4h", "mlp_4h_to_h"
+                "attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"
             ]
             lora_config['lora_config']['trtllm_modules_to_hf_modules'] = {
                 "attn_qkv": "qkv_proj",
                 "attn_dense": "o_proj",
-                "mlp_h_to_4h": "gate_up_proj",
+                "mlp_gate_up": "gate_up_proj",
                 "mlp_4h_to_h": "down_proj"
             }
             lora_config['lora_config']['max_lora_rank'] = 320
+            lora_config['lora_config'][
+                'swap_gate_up_proj_lora_b_weight'] = False
             base_config.update(lora_config)

     kv_cache_config = base_config.get('kv_cache_config', KvCacheConfig())
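For reference, the `lora_config` block this branch now builds for a `phi_4_multimodal_instruct` label looks roughly like the sketch below; the surrounding perf-test configuration keys are omitted and the exact structure may differ from the real harness.

```python
# Rough illustration of the lora_config block assembled above for a
# phi_4_multimodal_instruct model label; details are approximate.
lora_config = {
    'lora_config': {
        'lora_target_modules': [
            "attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"
        ],
        'trtllm_modules_to_hf_modules': {
            "attn_qkv": "qkv_proj",
            "attn_dense": "o_proj",
            "mlp_gate_up": "gate_up_proj",
            "mlp_4h_to_h": "down_proj",
        },
        'max_lora_rank': 320,
        'swap_gate_up_proj_lora_b_weight': False,
    }
}
base_config = {}
base_config.update(lora_config)
assert base_config['lora_config']['swap_gate_up_proj_lora_b_weight'] is False
```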
@@ -2404,15 +2404,15 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
     }
     expected_keywords = {
         "image": [
-            ["image", "depicts", "mountain", "half", "rock"],
-            ["road", "car", "lane", "traffic", "bus"],
+            ["object", "mountain", "weather", "clear", "clouds"],
+            ["traffic", "road", "vehicles", "cars", "bus"],
         ],
         "audio": [
             ["what", "is", "the", "traffic", "sign", "in", "image"],
             ["what", "is", "shown", "in", "this", "image"],
         ],
         "image_audio": [
-            ["image", "depicts", "Grand", "rock", "scene"],
+            ["image", "depicts", "scenic", "famous", "landmark"],
         ],
     }
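The updated keyword lists define what the quick-start test expects to see in the generated answers for each modality; the matching logic itself is not part of this diff. A hypothetical checker showing how such lists are commonly consumed (the threshold and case-insensitive substring matching are assumptions):

```python
# Hypothetical checker showing how expected-keyword lists like the ones above
# are typically consumed; the real test's matching logic is not part of this
# diff, so the threshold and case-insensitive matching are assumptions.
def keywords_hit_ratio(output: str, keywords: list[str]) -> float:
    text = output.lower()
    hits = sum(1 for kw in keywords if kw.lower() in text)
    return hits / len(keywords)


expected = ["object", "mountain", "weather", "clear", "clouds"]
sample_output = ("The main object is a mountain under clear weather, "
                 "with a few clouds in the sky.")
assert keywords_hit_ratio(sample_output, expected) >= 0.6
```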