Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][fix] Refactoring to avoid circular import when importing torch models (#6720)
Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
Parent commit: c9fe07ede6
Commit: 7ab8112450
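The core of the change: the lightweight LoRA pieces (`LoraConfig` and its helper functions) move into a new `tensorrt_llm/lora_helper.py`, and the heavyweight `lora_manager` is imported lazily inside the functions that actually load a checkpoint, so importing the torch models no longer closes the import cycle. Below is a self-contained sketch of that deferred-import pattern; the `dump_config` helper and the use of `json` as the stand-in module are illustrative, not part of the codebase.

```python
# Illustration of the deferred-import pattern this commit relies on: the
# cycle-prone (or heavyweight) module is imported inside the function that
# needs it, so merely importing this file never triggers that import.
# Here `json` stands in for the heavy module; in the commit,
# tensorrt_llm.lora_manager plays that role.


def dump_config(config: dict) -> str:
    # Import happens at call time, not at module-import time.
    import json
    return json.dumps(config, indent=2)


if __name__ == "__main__":
    print(dump_config({"lora_ckpt_source": "hf", "max_lora_rank": 64}))
```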
@@ -33,7 +33,7 @@ The PyTorch backend provides LoRA support, allowing you to:

 ```python
 from tensorrt_llm import LLM
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.sampling_params import SamplingParams
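For orientation, here is a sketch of how the imports in this documentation snippet typically fit together in the PyTorch backend. The keyword arguments (`lora_config=` on `LLM`, `lora_request=` on `generate`), the adapter name, and all paths are assumptions for illustration only; they are not taken from this diff.

```python
# Hypothetical end-to-end sketch built on the imports above; paths, adapter
# names, and keyword arguments are illustrative assumptions.
from tensorrt_llm import LLM
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.sampling_params import SamplingParams

lora_config = LoraConfig(lora_dir=["/path/to/lora"], max_lora_rank=8)
llm = LLM(model="/path/to/base_model", lora_config=lora_config)

# Arguments: lora_name, lora_int_id, lora_path
lora_request = LoRARequest("my-adapter", 1, "/path/to/lora")
outputs = llm.generate(["What is LoRA?"],
                       SamplingParams(max_tokens=32),
                       lora_request=lora_request)
print(outputs[0].outputs[0].text)
```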
@@ -5,7 +5,7 @@ from huggingface_hub import snapshot_download

 from tensorrt_llm import LLM
 from tensorrt_llm.executor import LoRARequest
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig


 def main():
@@ -33,6 +33,7 @@ import sys
 # otherwise `MemoryError: std::bad_alloc` pattern error will be raised.
 import xgrammar # noqa

+import tensorrt_llm._torch.models as torch_models
 import tensorrt_llm.functional as functional
 import tensorrt_llm.math_utils as math_utils
 import tensorrt_llm.models as models
@@ -82,6 +83,7 @@ __all__ = [
     'default_trtnet',
     'precision',
     'net_guard',
+    'torch_models',
     'Network',
     'Mapping',
     'MnnvlMemory',
@@ -22,7 +22,7 @@ from ...executor.request import LoRARequest
 from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                        register_input_processor)
 from ...logger import logger
-from ...lora_manager import LoraConfig
+from ...lora_helper import LoraConfig
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from tensorrt_llm import logger
+import tensorrt_llm.logger as trtllm_logger
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.quantization.utils.fp4_utils import (
     float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices,
@@ -743,7 +743,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
             if int(name.split(".")[0]) not in expert_ids:
                 continue
             weight_name = name.replace("weight_scale_inv", "weight")
-            logger.debug(f"Resmoothing {weight_name}")
+            trtllm_logger.logger.debug(f"Resmoothing {weight_name}")
             weight = weights[weight_name][:]
             scale = weights[name][:]
             weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
@@ -13,9 +13,9 @@ from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       load_torch_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_manager import load_torch_lora
 from tensorrt_llm.mapping import Mapping

 from ..model_config import ModelConfig
@@ -27,7 +27,8 @@ from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraModelConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo
 from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2
@@ -13,7 +13,7 @@ from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization import QuantAlgo
@@ -10,7 +10,8 @@ import torch
 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
-from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.sampling_params import SamplingParams

 from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
@@ -36,7 +36,7 @@ from .bindings import KVCacheType
 from .functional import PositionEmbeddingType
 from .graph_rewriting import optimize
 from .logger import logger
-from .lora_manager import LoraConfig
+from .lora_helper import LoraConfig
 from .models import PretrainedConfig, PretrainedModel
 from .models.modeling_utils import SpeculativeDecodingMode, optimize_model
 from .network import Network, net_guard
@@ -31,7 +31,8 @@ from tensorrt_llm.auto_parallel.cluster_info import cluster_infos
 from tensorrt_llm.bindings import KVCacheType
 from tensorrt_llm.builder import BuildConfig, Engine, build
 from tensorrt_llm.logger import logger, severity_map
-from tensorrt_llm.lora_manager import LoraConfig, LoraManager
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.models import MODEL_MAP, PretrainedConfig
 from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 from tensorrt_llm.plugin import PluginConfig, add_plugin_argument
@@ -1,6 +1,11 @@
 from dataclasses import dataclass
 from typing import List, Optional

+# isort: off
+# needed before trying to import bindings to load tensorrt_libs
+import tensorrt as trt # noqa
+# isort: on
+
 from tensorrt_llm.bindings import executor as tllme
@@ -15,7 +15,7 @@ import torch

 from tensorrt_llm.inputs.multimodal import MultimodalParams
 from tensorrt_llm.logger import logger, set_level
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig

 from .._utils import mpi_world_size
 from ..bindings import executor as tllm
@@ -24,7 +24,8 @@ from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
                             print_traceback_on_error)
-from ..lora_manager import LoraConfig, LoraManager
+from ..lora_helper import LoraConfig
+from ..lora_manager import LoraManager
 from ..metrics import RequestEventTiming
 from ..prompt_adapter_manager import PromptAdapterManager
 from ..runtime import ModelConfig
@@ -12,7 +12,7 @@ from typing import Any, List, Optional
 import filelock

 import tensorrt_llm
-from tensorrt_llm import BuildConfig
+from tensorrt_llm.builder import BuildConfig
 from tensorrt_llm.llmapi.utils import enable_llm_debug, print_colored
 from tensorrt_llm.logger import logger
@@ -19,8 +19,8 @@ from pydantic import PrivateAttr, field_validator, model_validator
 from strenum import StrEnum
 from transformers import PreTrainedTokenizerBase

-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)

 from .._utils import mpi_rank
 from ..auto_parallel import AutoParallelConfig, infer_cluster_config
tensorrt_llm/lora_helper.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Dict, List, Optional

from ._utils import DictConversion


def get_missing_qkv_modules_from_lora_modules(
        lora_target_modules: List[str]) -> List[str]:
    """Get missing QKV modules from LoRA target modules.

    In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
    all disabled at the same time. However, some lora checkpoints (e.g. BART) only contain two of them,
    so we use zero tensor to fill the missing ones.
    """
    missing_qkv_modules = []
    if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
        for lora_module in ["attn_q", "attn_k", "attn_v"]:
            if lora_module not in lora_target_modules:
                missing_qkv_modules.append(lora_module)
    if any(x in lora_target_modules
           for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
        for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
            if lora_module not in lora_target_modules:
                missing_qkv_modules.append(lora_module)
    return missing_qkv_modules


def get_default_trtllm_modules_to_hf_modules():
    """Get default mapping from TensorRT-LLM module names to HuggingFace module names."""
    return {
        "attn_q": "q_proj",
        "attn_k": "k_proj",
        "attn_v": "v_proj",
        "attn_dense": "o_proj",
        "mlp_h_to_4h": "gate_proj",
        "mlp_4h_to_h": "down_proj",
        "mlp_gate": "up_proj",
        "mlp_gate_up": "gate_up_proj",
        "moe_h_to_4h": "w1",
        "moe_4h_to_h": "w2",
        "moe_gate": "w3",
        "moe_router": "gate",
    }


def use_lora(
    model,
    lora_config: "LoraConfig",
    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
    """Use LoRA with the given model and configuration.

    This function is a wrapper that delegates to the appropriate loading function
    based on the LoRA checkpoint source.
    """
    if lora_config.lora_ckpt_source == "nemo":
        from .lora_manager import load_nemo_lora
        load_nemo_lora(model, lora_config)
    elif lora_config.lora_ckpt_source == "hf":
        from .lora_manager import load_hf_lora
        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
    else:
        raise ValueError(
            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")


@dataclass
class LoraConfig(DictConversion):
    lora_dir: List[str] = field(default_factory=list)
    lora_ckpt_source: str = "hf"
    max_lora_rank: int = 64
    lora_target_modules: List[str] = field(default_factory=list)
    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
    max_loras: Optional[int] = None
    max_cpu_loras: Optional[int] = None

    def __post_init__(self):
        assert self.lora_ckpt_source in [
            "hf", "nemo"
        ], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
            )

    @property
    def missing_qkv_modules(self) -> List[str]:
        return get_missing_qkv_modules_from_lora_modules(
            self.lora_target_modules)
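Taken on its own, the new module gives LoRA configuration a dependency-light home. A quick sketch exercising it directly; the paths and module lists are illustrative values, and the printed results follow from the code above:

```python
from tensorrt_llm.lora_helper import (
    LoraConfig,
    get_default_trtllm_modules_to_hf_modules,
    get_missing_qkv_modules_from_lora_modules,
)

# Only attn_q is targeted, so attn_k/attn_v are reported missing and will be
# zero-filled by the loader.
print(get_missing_qkv_modules_from_lora_modules(["attn_q", "mlp_gate"]))
# ['attn_k', 'attn_v']

# Default TRT-LLM -> HF module-name mapping used when none is supplied.
print(get_default_trtllm_modules_to_hf_modules()["attn_dense"])  # o_proj

config = LoraConfig(lora_dir=["/path/to/lora"],
                    max_lora_rank=8,
                    lora_target_modules=["attn_q", "attn_k", "attn_v"])
print(config.missing_qkv_modules)  # [] -- all three attention projections present
```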
@@ -5,7 +5,7 @@ import re
 import tarfile
 import warnings
 from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
@@ -16,8 +16,13 @@ import yaml

 from tensorrt_llm.bindings import internal as tb_internal

-from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
+from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
 from .layers.linear import ColumnLinear
+from .lora_helper import (
+    LoraConfig,
+    get_default_trtllm_modules_to_hf_modules,
+    get_missing_qkv_modules_from_lora_modules,
+)
 from .mapping import Mapping
 from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp
@@ -232,26 +237,6 @@ def norm_dora_magnitude(
     return norm_m


-@dataclass
-class LoraConfig(DictConversion):
-    lora_dir: List[str] = field(default_factory=list)
-    lora_ckpt_source: str = "hf"
-    max_lora_rank: int = 64
-    lora_target_modules: List[str] = field(default_factory=list)
-    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
-    max_loras: int | None = None
-    max_cpu_loras: int | None = None
-
-    def __post_init__(self):
-        assert self.lora_ckpt_source in ["hf", "nemo"], (
-            f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
-        )
-
-    @property
-    def missing_qkv_modules(self) -> List[str]:
-        return LoraManager.get_missing_qkv_modules(self.lora_target_modules)
-
-
 @dataclass
 class LoraModelConfig:
     lora_target_modules: list[str]
@@ -430,23 +415,6 @@ def load_nemo_lora(model, lora_config: LoraConfig):
         lora_config.lora_target_modules = lora_loader.lora_target_modules


-def get_default_trtllm_modules_to_hf_modules():
-    return {
-        "attn_q": "q_proj",
-        "attn_k": "k_proj",
-        "attn_v": "v_proj",
-        "attn_dense": "o_proj",
-        "mlp_h_to_4h": "gate_proj",
-        "mlp_4h_to_h": "down_proj",
-        "mlp_gate": "up_proj",
-        "mlp_gate_up": "gate_up_proj",
-        "moe_h_to_4h": "w1",
-        "moe_4h_to_h": "w2",
-        "moe_gate": "w3",
-        "moe_router": "gate",
-    }
-
-
 def load_torch_hf_lora(lora_config: LoraConfig):
     """This is a shortned version of load_hf_lora that is used for torch models.
@@ -628,19 +596,6 @@ def load_hf_lora(
     ).to(torch_dtype)


-def use_lora(
-    model,
-    lora_config: LoraConfig,
-    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
-):
-    if lora_config.lora_ckpt_source == "nemo":
-        load_nemo_lora(model, lora_config)
-    elif lora_config.lora_ckpt_source == "hf":
-        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
-    else:
-        raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
-
-
 def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]:
     """Unpack model config and weights from a NeMo .nemo archive file.
@@ -762,21 +717,8 @@ class LoraManager(object):
         )

     @staticmethod
-    def get_missing_qkv_modules(lora_target_modules):
-        # In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
-        # all disabled at the same time.
-        # However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor
-        # to fill the missing ones.
-        missing_qkv_modules = []
-        if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
-            for lora_module in ["attn_q", "attn_k", "attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
-            for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        return missing_qkv_modules
+    def get_missing_qkv_modules(lora_target_modules: List[str]) -> List[str]:
+        return get_missing_qkv_modules_from_lora_modules(lora_target_modules)

     @property
     def missing_qkv_modules(self) -> List[str]:
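Because `lora_manager` now imports these names from `lora_helper` (see the import hunk above) and the static method simply forwards to the shared helper, existing call sites that still import `LoraConfig` from `lora_manager` should keep resolving. A small equivalence check, assuming a working install:

```python
import tensorrt_llm.lora_helper as lora_helper
import tensorrt_llm.lora_manager as lora_manager

# The dataclass moved, but lora_manager re-exposes it through its own import,
# so the old import path still points at the same class object.
assert lora_manager.LoraConfig is lora_helper.LoraConfig

# LoraManager.get_missing_qkv_modules now just forwards to the shared helper.
mods = ["attn_q"]
assert (lora_manager.LoraManager.get_missing_qkv_modules(mods)
        == lora_helper.get_missing_qkv_modules_from_lora_modules(mods))
```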
@@ -36,9 +36,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                  LanguageAdapterConfig, LayerNorm, LoraParams,
                                  PromptTuningEmbedding, RmsNorm)
 # yapf: enable
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       use_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules,
+                                      use_lora)
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel
 from tensorrt_llm.module import Module, ModuleList
@@ -28,7 +28,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, LayerNormType,
 from ...layers import (Attention, AttentionMaskType, AttentionParams,
                        ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
                        LoraParams, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
@@ -21,7 +21,7 @@ from ...functional import (Tensor, is_gated_activation, non_gated_version, recv,
 from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, LayerNorm, MoeConfig,
                        PositionEmbeddingType)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization import QuantMode
@@ -18,7 +18,7 @@ from ..._utils import pad_vocab_size
 from ...functional import Tensor, recv, send
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, MoeConfig, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
@@ -25,7 +25,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor,
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, FusedGatedMLP, GatedMLP,
                        PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization.functional import fused_layernorm
@@ -32,9 +32,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                  ColumnLinear, Embedding, FusedGatedMLP,
                                  GatedMLP, GroupNorm, KeyValueCacheParams,
                                  LayerNorm, LoraParams, RmsNorm)
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       use_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules,
+                                      use_lora)
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
 from tensorrt_llm.models.modeling_utils import PretrainedModel, QuantConfig
@@ -20,7 +20,7 @@ from ..._utils import pad_vocab_size
 from ...functional import Tensor
 from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, LayerNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
@@ -8,7 +8,7 @@ from ...functional import PositionEmbeddingType, Tensor
 from ...layers import (MLP, MOE, Attention, AttentionMaskType,
                        BlockSparseAttnParams, ColumnLinear, Embedding,
                        LayerNorm, MoeConfig, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
@@ -26,8 +26,8 @@ from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, RmsNorm, SharedMoE)
 from ...layers.moe import MOEWeightWrapper
 from ...logger import logger
-from ...lora_manager import (LoraConfig,
-                             get_default_trtllm_modules_to_hf_modules, use_lora)
+from ...lora_helper import (LoraConfig,
+                            get_default_trtllm_modules_to_hf_modules, use_lora)
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization import QuantAlgo
@@ -15,7 +15,7 @@

 from typing import Optional

-from .lora_manager import LoraConfig
+from .lora_helper import LoraConfig
 from .mapping import Mapping
 from .plugin.plugin import PluginConfig
@@ -17,7 +17,7 @@ from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig

 DataType = tensorrt_llm.bindings.DataType
 LoraModule = tensorrt_llm.bindings.LoraModule
@@ -42,7 +42,7 @@ from tensorrt_llm.llmapi.llm_utils import (BuildConfig, QuantAlgo, QuantConfig,
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase, TransformersTokenizer,
                                            load_hf_tokenizer)
 from tensorrt_llm.llmapi.utils import get_total_gpu_memory
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.models.automodel import AutoConfig, AutoModelForCausalLM
 from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 from tensorrt_llm.sampling_params import (BatchedLogitsProcessor,
@@ -12,7 +12,7 @@ from tensorrt_llm._tensorrt_engine import LLM
 from tensorrt_llm.executor import GenerationExecutorProxy
 from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig, SamplingParams
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models import PretrainedConfig
 from tensorrt_llm.models.llama.model import LLaMAForCausalLM
@@ -4,7 +4,7 @@ import pytest
 from .test_llm import tinyllama_logits_processor_test_harness
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness
 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
 from .test_llm import _test_llm_capture_request_error
@@ -25,7 +25,7 @@ from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb)
 from utils.llm_data import llm_models_root
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo