[None][fix] Refactoring to avoid circular import when importing torch models (#6720)
Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
commit 7ab8112450 (parent c9fe07ede6)
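The fix follows a standard recipe for breaking Python import cycles: hoist the dependency-light, widely shared definitions (here `LoraConfig` and a few pure helper functions) into a new leaf module, `tensorrt_llm/lora_helper.py`, and defer the heavyweight imports into function bodies. A minimal, hypothetical sketch of that layout (simplified names, not the actual TensorRT-LLM sources):

```python
# helper.py -- leaf module: imports nothing from the rest of the package,
# so any other module may import it without creating a cycle.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Config:
    ckpt_source: str = "hf"
    target_modules: List[str] = field(default_factory=list)


def use(model, config: Config) -> None:
    # Deferred import: the (hypothetical) manager module pulls in torch
    # model code, which itself imports helper.py; importing it lazily
    # inside the function body keeps the import graph acyclic.
    from manager import load  # hypothetical heavyweight module
    load(model, config)
```

The diff below applies exactly this shape: call sites that only need `LoraConfig` now import it from `lora_helper`, while `lora_manager` keeps the loaders and is imported lazily inside `use_lora`.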
@@ -33,7 +33,7 @@ The PyTorch backend provides LoRA support, allowing you to:
 ```python
 from tensorrt_llm import LLM
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.sampling_params import SamplingParams
 
 
@@ -5,7 +5,7 @@ from huggingface_hub import snapshot_download
 
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import LoRARequest
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 
 
 def main():
@@ -33,6 +33,7 @@ import sys
 # otherwise `MemoryError: std::bad_alloc` pattern error will be raised.
 import xgrammar  # noqa
 
+import tensorrt_llm._torch.models as torch_models
 import tensorrt_llm.functional as functional
 import tensorrt_llm.math_utils as math_utils
 import tensorrt_llm.models as models
@@ -82,6 +83,7 @@ __all__ = [
     'default_trtnet',
     'precision',
     'net_guard',
+    'torch_models',
     'Network',
     'Mapping',
     'MnnvlMemory',
@@ -22,7 +22,7 @@ from ...executor.request import LoRARequest
 from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                        register_input_processor)
 from ...logger import logger
-from ...lora_manager import LoraConfig
+from ...lora_helper import LoraConfig
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from tensorrt_llm import logger
+import tensorrt_llm.logger as trtllm_logger
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.quantization.utils.fp4_utils import (
     float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices,
@@ -743,7 +743,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
             if int(name.split(".")[0]) not in expert_ids:
                 continue
             weight_name = name.replace("weight_scale_inv", "weight")
-            logger.debug(f"Resmoothing {weight_name}")
+            trtllm_logger.logger.debug(f"Resmoothing {weight_name}")
             weight = weights[weight_name][:]
             scale = weights[name][:]
             weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
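The logger change above is itself a cycle-avoidance move: `from tensorrt_llm import logger` requires the attribute to already be bound on the (possibly partially initialized) package object at import time, whereas `import tensorrt_llm.logger as trtllm_logger` only registers the submodule and defers every attribute lookup to call time. A generic, hedged sketch of that difference (not TensorRT-LLM code):

```python
import importlib


def get_logger_module(pkg_name: str):
    # Equivalent of "import <pkg>.logger as alias": only the submodule
    # needs to be importable; attributes like alias.logger are resolved
    # later, at call time, which tolerates partially initialized cycles.
    # Assumes the named package exposes a "logger" submodule.
    return importlib.import_module(f"{pkg_name}.logger")
```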
@@ -13,9 +13,9 @@ from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       load_torch_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_manager import load_torch_lora
 from tensorrt_llm.mapping import Mapping
 
 from ..model_config import ModelConfig
@@ -27,7 +27,8 @@ from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraModelConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo
 from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2
@@ -13,7 +13,7 @@ from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization import QuantAlgo
 
@@ -10,7 +10,8 @@ import torch
 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
-from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.sampling_params import SamplingParams
 
 from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
@@ -36,7 +36,7 @@ from .bindings import KVCacheType
 from .functional import PositionEmbeddingType
 from .graph_rewriting import optimize
 from .logger import logger
-from .lora_manager import LoraConfig
+from .lora_helper import LoraConfig
 from .models import PretrainedConfig, PretrainedModel
 from .models.modeling_utils import SpeculativeDecodingMode, optimize_model
 from .network import Network, net_guard
@@ -31,7 +31,8 @@ from tensorrt_llm.auto_parallel.cluster_info import cluster_infos
 from tensorrt_llm.bindings import KVCacheType
 from tensorrt_llm.builder import BuildConfig, Engine, build
 from tensorrt_llm.logger import logger, severity_map
-from tensorrt_llm.lora_manager import LoraConfig, LoraManager
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.models import MODEL_MAP, PretrainedConfig
 from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 from tensorrt_llm.plugin import PluginConfig, add_plugin_argument
@@ -1,6 +1,11 @@
 from dataclasses import dataclass
 from typing import List, Optional
 
+# isort: off
+# needed before trying to import bindings to load tensorrt_libs
+import tensorrt as trt  # noqa
+# isort: on
+
 from tensorrt_llm.bindings import executor as tllme
 
 
@@ -15,7 +15,7 @@ import torch
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 from tensorrt_llm.logger import logger, set_level
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 
 from .._utils import mpi_world_size
 from ..bindings import executor as tllm
@@ -24,7 +24,8 @@ from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
                             print_traceback_on_error)
-from ..lora_manager import LoraConfig, LoraManager
+from ..lora_helper import LoraConfig
+from ..lora_manager import LoraManager
 from ..metrics import RequestEventTiming
 from ..prompt_adapter_manager import PromptAdapterManager
 from ..runtime import ModelConfig
@@ -12,7 +12,7 @@ from typing import Any, List, Optional
 import filelock
 
 import tensorrt_llm
-from tensorrt_llm import BuildConfig
+from tensorrt_llm.builder import BuildConfig
 from tensorrt_llm.llmapi.utils import enable_llm_debug, print_colored
 from tensorrt_llm.logger import logger
 
@@ -19,8 +19,8 @@ from pydantic import PrivateAttr, field_validator, model_validator
 from strenum import StrEnum
 from transformers import PreTrainedTokenizerBase
 
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
 
 from .._utils import mpi_rank
 from ..auto_parallel import AutoParallelConfig, infer_cluster_config
tensorrt_llm/lora_helper.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+from ._utils import DictConversion
+
+
+def get_missing_qkv_modules_from_lora_modules(
+        lora_target_modules: List[str]) -> List[str]:
+    """Get missing QKV modules from LoRA target modules.
+
+    In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
+    all disabled at the same time. However, some lora checkpoints (e.g. BART) only contain two of them,
+    so we use zero tensor to fill the missing ones.
+    """
+    missing_qkv_modules = []
+    if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
+        for lora_module in ["attn_q", "attn_k", "attn_v"]:
+            if lora_module not in lora_target_modules:
+                missing_qkv_modules.append(lora_module)
+    if any(x in lora_target_modules
+           for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
+        for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
+            if lora_module not in lora_target_modules:
+                missing_qkv_modules.append(lora_module)
+    return missing_qkv_modules
+
+
+def get_default_trtllm_modules_to_hf_modules():
+    """Get default mapping from TensorRT-LLM module names to HuggingFace module names."""
+    return {
+        "attn_q": "q_proj",
+        "attn_k": "k_proj",
+        "attn_v": "v_proj",
+        "attn_dense": "o_proj",
+        "mlp_h_to_4h": "gate_proj",
+        "mlp_4h_to_h": "down_proj",
+        "mlp_gate": "up_proj",
+        "mlp_gate_up": "gate_up_proj",
+        "moe_h_to_4h": "w1",
+        "moe_4h_to_h": "w2",
+        "moe_gate": "w3",
+        "moe_router": "gate",
+    }
+
+
+def use_lora(
+    model,
+    lora_config: "LoraConfig",
+    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
+):
+    """Use LoRA with the given model and configuration.
+
+    This function is a wrapper that delegates to the appropriate loading function
+    based on the LoRA checkpoint source.
+    """
+    if lora_config.lora_ckpt_source == "nemo":
+        from .lora_manager import load_nemo_lora
+        load_nemo_lora(model, lora_config)
+    elif lora_config.lora_ckpt_source == "hf":
+        from .lora_manager import load_hf_lora
+        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
+    else:
+        raise ValueError(
+            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
+
+
+@dataclass
+class LoraConfig(DictConversion):
+    lora_dir: List[str] = field(default_factory=list)
+    lora_ckpt_source: str = "hf"
+    max_lora_rank: int = 64
+    lora_target_modules: List[str] = field(default_factory=list)
+    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
+    max_loras: Optional[int] = None
+    max_cpu_loras: Optional[int] = None
+
+    def __post_init__(self):
+        assert self.lora_ckpt_source in [
+            "hf", "nemo"
+        ], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
+            )
+
+    @property
+    def missing_qkv_modules(self) -> List[str]:
+        return get_missing_qkv_modules_from_lora_modules(
+            self.lora_target_modules)
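The new module is intentionally self-contained (only `dataclasses`, `typing`, and `DictConversion`), so importing it never drags in the LoRA loaders, and `use_lora` defers `lora_manager` imports until call time. A hedged usage sketch grounded in the code above; the adapter directory is a hypothetical placeholder:

```python
from tensorrt_llm.lora_helper import (LoraConfig,
                                      get_default_trtllm_modules_to_hf_modules)

# Hypothetical adapter directory; any HF-format LoRA checkpoint would go here.
config = LoraConfig(lora_dir=["/path/to/hf_lora_ckpt"],
                    lora_ckpt_source="hf",
                    max_lora_rank=8,
                    lora_target_modules=["attn_q", "attn_v"])

# "attn_k" is reported missing because q/k/v must be enabled together;
# the loaders fill missing ones with zero tensors.
print(config.missing_qkv_modules)                             # ['attn_k']
print(get_default_trtllm_modules_to_hf_modules()["attn_q"])   # 'q_proj'
```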
@@ -5,7 +5,7 @@ import re
 import tarfile
 import warnings
 from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
@@ -16,8 +16,13 @@ import yaml
 
 from tensorrt_llm.bindings import internal as tb_internal
 
-from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
+from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
 from .layers.linear import ColumnLinear
+from .lora_helper import (
+    LoraConfig,
+    get_default_trtllm_modules_to_hf_modules,
+    get_missing_qkv_modules_from_lora_modules,
+)
 from .mapping import Mapping
 from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp
 
@@ -232,26 +237,6 @@ def norm_dora_magnitude(
     return norm_m
 
 
-@dataclass
-class LoraConfig(DictConversion):
-    lora_dir: List[str] = field(default_factory=list)
-    lora_ckpt_source: str = "hf"
-    max_lora_rank: int = 64
-    lora_target_modules: List[str] = field(default_factory=list)
-    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
-    max_loras: int | None = None
-    max_cpu_loras: int | None = None
-
-    def __post_init__(self):
-        assert self.lora_ckpt_source in ["hf", "nemo"], (
-            f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
-        )
-
-    @property
-    def missing_qkv_modules(self) -> List[str]:
-        return LoraManager.get_missing_qkv_modules(self.lora_target_modules)
-
-
 @dataclass
 class LoraModelConfig:
     lora_target_modules: list[str]
@@ -430,23 +415,6 @@ def load_nemo_lora(model, lora_config: LoraConfig):
     lora_config.lora_target_modules = lora_loader.lora_target_modules
 
 
-def get_default_trtllm_modules_to_hf_modules():
-    return {
-        "attn_q": "q_proj",
-        "attn_k": "k_proj",
-        "attn_v": "v_proj",
-        "attn_dense": "o_proj",
-        "mlp_h_to_4h": "gate_proj",
-        "mlp_4h_to_h": "down_proj",
-        "mlp_gate": "up_proj",
-        "mlp_gate_up": "gate_up_proj",
-        "moe_h_to_4h": "w1",
-        "moe_4h_to_h": "w2",
-        "moe_gate": "w3",
-        "moe_router": "gate",
-    }
-
-
 def load_torch_hf_lora(lora_config: LoraConfig):
     """This is a shortned version of load_hf_lora that is used for torch models.
 
@@ -628,19 +596,6 @@ def load_hf_lora(
     ).to(torch_dtype)
 
 
-def use_lora(
-    model,
-    lora_config: LoraConfig,
-    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
-):
-    if lora_config.lora_ckpt_source == "nemo":
-        load_nemo_lora(model, lora_config)
-    elif lora_config.lora_ckpt_source == "hf":
-        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
-    else:
-        raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
-
-
 def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]:
     """Unpack model config and weights from a NeMo .nemo archive file.
 
@@ -762,21 +717,8 @@ class LoraManager(object):
         )
 
     @staticmethod
-    def get_missing_qkv_modules(lora_target_modules):
-        # In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
-        # all disabled at the same time.
-        # However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor
-        # to fill the missing ones.
-        missing_qkv_modules = []
-        if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
-            for lora_module in ["attn_q", "attn_k", "attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
-            for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        return missing_qkv_modules
+    def get_missing_qkv_modules(lora_target_modules: List[str]) -> List[str]:
+        return get_missing_qkv_modules_from_lora_modules(lora_target_modules)
 
     @property
     def missing_qkv_modules(self) -> List[str]:
@@ -36,9 +36,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                  LanguageAdapterConfig, LayerNorm, LoraParams,
                                  PromptTuningEmbedding, RmsNorm)
 # yapf: enable
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       use_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules,
+                                      use_lora)
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel
 from tensorrt_llm.module import Module, ModuleList
@@ -28,7 +28,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, LayerNormType,
 from ...layers import (Attention, AttentionMaskType, AttentionParams,
                        ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
                        LoraParams, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
@@ -21,7 +21,7 @@ from ...functional import (Tensor, is_gated_activation, non_gated_version, recv,
 from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, LayerNorm, MoeConfig,
                        PositionEmbeddingType)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization import QuantMode
@@ -18,7 +18,7 @@ from ..._utils import pad_vocab_size
 from ...functional import Tensor, recv, send
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, MoeConfig, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
@@ -25,7 +25,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor,
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, FusedGatedMLP, GatedMLP,
                        PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization.functional import fused_layernorm
@@ -32,9 +32,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                  ColumnLinear, Embedding, FusedGatedMLP,
                                  GatedMLP, GroupNorm, KeyValueCacheParams,
                                  LayerNorm, LoraParams, RmsNorm)
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       use_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules,
+                                      use_lora)
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
 from tensorrt_llm.models.modeling_utils import PretrainedModel, QuantConfig
@@ -20,7 +20,7 @@ from ..._utils import pad_vocab_size
 from ...functional import Tensor
 from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, LayerNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
@@ -8,7 +8,7 @@ from ...functional import PositionEmbeddingType, Tensor
 from ...layers import (MLP, MOE, Attention, AttentionMaskType,
                        BlockSparseAttnParams, ColumnLinear, Embedding,
                        LayerNorm, MoeConfig, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
@@ -26,8 +26,8 @@ from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, RmsNorm, SharedMoE)
 from ...layers.moe import MOEWeightWrapper
 from ...logger import logger
-from ...lora_manager import (LoraConfig,
-                             get_default_trtllm_modules_to_hf_modules, use_lora)
+from ...lora_helper import (LoraConfig,
+                            get_default_trtllm_modules_to_hf_modules, use_lora)
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization import QuantAlgo
@@ -15,7 +15,7 @@
 
 from typing import Optional
 
-from .lora_manager import LoraConfig
+from .lora_helper import LoraConfig
 from .mapping import Mapping
 from .plugin.plugin import PluginConfig
 
@@ -17,7 +17,7 @@ from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 
 DataType = tensorrt_llm.bindings.DataType
 LoraModule = tensorrt_llm.bindings.LoraModule
@@ -42,7 +42,7 @@ from tensorrt_llm.llmapi.llm_utils import (BuildConfig, QuantAlgo, QuantConfig,
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase, TransformersTokenizer,
                                            load_hf_tokenizer)
 from tensorrt_llm.llmapi.utils import get_total_gpu_memory
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.models.automodel import AutoConfig, AutoModelForCausalLM
 from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 from tensorrt_llm.sampling_params import (BatchedLogitsProcessor,
@@ -12,7 +12,7 @@ from tensorrt_llm._tensorrt_engine import LLM
 from tensorrt_llm.executor import GenerationExecutorProxy
 from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig, SamplingParams
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models import PretrainedConfig
 from tensorrt_llm.models.llama.model import LLaMAForCausalLM
@@ -4,7 +4,7 @@ import pytest
 from .test_llm import tinyllama_logits_processor_test_harness
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness
 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
 from .test_llm import _test_llm_capture_request_error
@@ -25,7 +25,7 @@ from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb)
 from utils.llm_data import llm_models_root
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo
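One migration note for downstream code: because `lora_manager` itself now imports `LoraConfig` from `lora_helper`, the old import path should continue to resolve, but the new leaf module is the cheaper dependency. A hedged compatibility sketch:

```python
try:
    # New, dependency-light location introduced by this commit.
    from tensorrt_llm.lora_helper import LoraConfig
except ImportError:
    # Older TensorRT-LLM releases that predate the refactor.
    from tensorrt_llm.lora_manager import LoraConfig
```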