Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[https://nvbugs/5510879][fix] Fix fused LoRA adapter module weight split with TP>1 in the PyTorch & TRT-python flows (#8063)
Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
parent a1ed03fe8a
commit fac47e2826
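Before the diff, a minimal standalone sketch of the idea this commit implements (function and variable names here are illustrative only, not the repository's API): HF stores a fused LoRA B weight as torch.cat([Wq, Wk, Wv]) along the output dimension, but with TP>1 each rank later takes one contiguous 1/tp_size chunk, so the parts must first be interleaved rank by rank. The real fix does this inside LoraManager.load_from_hf via the interleave_fused_lora_weights_for_tp / prepare_fused_lora_modules_for_tp helpers shown in the diff below, and the reordered full weight is then handed to the C++ PEFT cache manager, which expects the non-split weights.

# Illustrative-only sketch, assuming a fused QKV LoRA B weight laid out HF-style
# as torch.cat([Wq, Wk, Wv]) on dim 0.
from typing import List

import torch


def interleave_for_tp(weight: torch.Tensor, tp_size: int,
                      part_sizes: List[int]) -> torch.Tensor:
    """Reorder fused parts so a contiguous per-rank row split stays correct."""
    assert weight.shape[0] == sum(part_sizes)
    # Split into the fused parts, e.g. [Wq, Wk, Wv].
    offsets = [sum(part_sizes[:i]) for i in range(len(part_sizes))]
    parts = [weight.narrow(0, off, size) for off, size in zip(offsets, part_sizes)]
    # Split each part into tp_size chunks, then regroup them rank by rank:
    # [Wq_r0, Wk_r0, Wv_r0, ..., Wq_rN, Wk_rN, Wv_rN].
    chunks_per_part = [torch.chunk(p, tp_size, dim=0) for p in parts]
    per_rank = [torch.cat([chunks[r] for chunks in chunks_per_part], dim=0)
                for r in range(tp_size)]
    return torch.cat(per_rank, dim=0)


if __name__ == "__main__":
    tp_size, q_size, kv_size, lora_rank = 2, 8, 4, 3
    wq, wk, wv = (torch.randn(s, lora_rank) for s in (q_size, kv_size, kv_size))
    fused = torch.cat([wq, wk, wv], dim=0)
    reordered = interleave_for_tp(fused, tp_size, [q_size, kv_size, kv_size])
    # Rank 0's contiguous chunk of the reordered weight is exactly its q/k/v shards.
    rank0 = torch.chunk(reordered, tp_size, dim=0)[0]
    expected = torch.cat([wq[:q_size // tp_size],
                          wk[:kv_size // tp_size],
                          wv[:kv_size // tp_size]], dim=0)
    assert torch.equal(rank0, expected)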
@@ -13,6 +13,7 @@ from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
+from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
from tensorrt_llm.sampling_params import SamplingParams

from ..._utils import binding_to_str_dtype, get_size_in_bytes, nvtx_range

@@ -32,7 +33,7 @@ BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager
KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig
CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
-ModelConfig = tensorrt_llm.bindings.ModelConfig
+ModelConfigCpp = tensorrt_llm.bindings.ModelConfig
DataType = tensorrt_llm.bindings.DataType
KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager
RequestList = list[LlmRequest]

@@ -160,7 +161,7 @@ class KVCacheManager(BaseResourceManager):
spec_config: Optional["DecodingBaseConfig"] = None,
layer_mask: Optional[List[bool]] = None,
max_num_tokens: int = 8192,
-model_config: Optional[ModelConfig] = None,
+model_config: Optional[ModelConfigCpp] = None,
max_beam_width: int = 1,
is_draft: bool = False,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,

@@ -371,7 +372,7 @@ class KVCacheManager(BaseResourceManager):

@classmethod
def from_model_config(cls,
-model_config: ModelConfig,
+model_config: ModelConfigCpp,
kv_cache_config: KvCacheConfigCpp,
mapping: Mapping,
kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF,

@@ -772,7 +773,7 @@ class KVCacheManager(BaseResourceManager):
window_size_to_layers: Dict[int, List[int]],
max_attention_window_vec: List[int],
kv_cache_config: KvCacheConfigCpp,
-model_config: ModelConfig,
+model_config: ModelConfigCpp,
pool_memory_bytes: int,
kv_factor: int,
dtype: DataType,

@@ -887,7 +888,7 @@ class KVCacheManager(BaseResourceManager):
def calculate_max_num_blocks_from_cpp(
self,
kv_cache_config: KvCacheConfigCpp,
-model_config: ModelConfig,
+model_config: ModelConfigCpp,
extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
"""
This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks.

@@ -1133,7 +1134,7 @@ class PeftCacheManager(BaseResourceManager):
def __init__(self,
peft_cache_config: PeftCacheConfig,
lora_config: LoraConfig,
-model_config: ModelConfig,
+model_config: ModelConfigCpp,
world_config: WorldConfig | None = None):
import tensorrt_llm.bindings as _tb

@@ -1169,7 +1170,20 @@ class PeftCacheManager(BaseResourceManager):
lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
binding_to_str_dtype(model_config.data_type),
lora_config.swap_gate_up_proj_lora_b_weight)
-self._lora_manager = LoraManager()
+mapping = Mapping(
+world_size=world_config.size,
+rank=world_config.rank,
+tp_size=world_config.tensor_parallelism,
+pp_size=world_config.pipeline_parallelism,
+gpus_per_node=world_config.gpus_per_node,
+)
+self._lora_manager = LoraManager(
+mapping=mapping,
+model_config=ModelConfigPython.from_model_config_cpp(model_config),
+cpp_peft_cache_manager=self.impl)

+def get_lora_manager(self) -> LoraManager:
+return self._lora_manager

def add_request_peft(self, request: LlmRequest):
if request.lora_task_id is not None:

@@ -1183,7 +1197,6 @@ class PeftCacheManager(BaseResourceManager):
self._lora_manager.load_from_ckpt(
[request.py_lora_path],
model_config=self._lora_model_config,
-runtime_mapping=None,
uids=[request.lora_task_id],
ckpt_source=self._lora_config.lora_ckpt_source)
request.lora_weights = self._lora_manager.cpp_lora_weights[

@@ -42,7 +42,7 @@ import torch
import tensorrt as trt
# isort: on

-from tensorrt_llm.bindings import DataType, GptJsonConfig
+from tensorrt_llm.bindings import DataType, GptJsonConfig, LayerType
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.logger import logger

@@ -198,6 +198,10 @@ _binding_dtype_bits = {
}


+def binding_layer_type_to_str(layer_type: LayerType) -> str:
+return layer_type.name.lower()


def binding_to_str_dtype(binding_dtype) -> str:
ret = _binding_to_str_dtype.get(binding_dtype)
assert ret is not None, f'Unsupported binding dtype: {binding_dtype}'

@@ -205,7 +205,10 @@ class BaseWorker(GenerationExecutor):
# point in the TRT flow is currently not supported (it's at the CPP
# Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
# optimization is not available in TRT-python flow.
-self._lora_manager = LoraManager(cpp_peft_cache_manager=None)
+self._lora_manager = LoraManager(
+mapping=engine_config.pretrained_config.mapping,
+model_config=self._runtime_model_config,
+cpp_peft_cache_manager=None)
if engine_config.build_config.max_prompt_embedding_table_size > 0:
self._prompt_adapter_manager = PromptAdapterManager()

@@ -216,8 +219,7 @@ class BaseWorker(GenerationExecutor):
ResourceManagerType
peft_cache_manager = self.engine.resource_manager.resource_managers.get(
ResourceManagerType.PEFT_CACHE_MANAGER)
-self._lora_manager = LoraManager(
-cpp_peft_cache_manager=peft_cache_manager.impl)
+self._lora_manager = peft_cache_manager.get_lora_manager()
lora_model_config = self.engine.model_engine.lora_model_config
assert lora_model_config is not None
self._lora_model_config = lora_model_config

@@ -302,7 +304,6 @@ class BaseWorker(GenerationExecutor):
[lora_request.path],
model_config=self._runtime_model_config if
self._runtime_model_config is not None else self._lora_model_config,
-runtime_mapping=None,
uids=[adapter_id],
ckpt_source=lora_request.ckpt_source)
return adapter_id in newly_loaded_uids

@@ -46,6 +46,7 @@ def get_default_trtllm_modules_to_hf_modules():
"attn_q": "q_proj",
"attn_k": "k_proj",
"attn_v": "v_proj",
+"attn_qkv": "qkv_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_proj",
"mlp_4h_to_h": "down_proj",

@@ -1,4 +1,5 @@
import io
+import itertools
import json
import logging
import re

@@ -660,11 +661,17 @@ class LoraManager(object):
}

def __init__(
-self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None
+self,
+*,
+mapping: Mapping,
+model_config: "ModelConfig",
+cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None,
):
"""Constructor.

Args:
+mapping (Mapping): Parallelism related information.
+model_config (ModelConfig): model configuration python class instance.
cpp_peft_cache_manager (PeftCacheManager, optional): used by is_adapter_in_cpu_cache method, that's used for
a performance optimization with LoRA of not sending the LoRA adapter weights with every LLM request when
the adapter is already loaded in the LoRA CPU cache.

@@ -704,6 +711,8 @@ class LoraManager(object):
self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu
self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu
self.lora_target_modules: List[str] = []
+self._mapping = mapping
+self._model_config = model_config
self._cpp_peft_cache_manager = cpp_peft_cache_manager

def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:

@@ -730,7 +739,6 @@ class LoraManager(object):
self,
model_dirs_or_files: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
-runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
ckpt_source: str = "hf",
) -> List[str]:

@@ -743,7 +751,6 @@ class LoraManager(object):
return self.load_from_hf(
model_dirs=model_dirs_or_files,
model_config=model_config,
-runtime_mapping=runtime_mapping,
uids=uids,
)
elif ckpt_source == "nemo":

@@ -754,7 +761,6 @@ class LoraManager(object):
return self.load_from_nemo(
model_files=nemo_files,
model_config=model_config,
-runtime_mapping=runtime_mapping,
uids=uids,
)
else:

@@ -764,7 +770,6 @@ class LoraManager(object):
self,
model_files: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
-runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
) -> List[str]:
"""Returns the adapter UIDs that were loaded by this call.

@@ -772,11 +777,6 @@ class LoraManager(object):
Note that when an adapter was already loaded before this call, it would not be
included in the returned list of UIDs.
"""
-if runtime_mapping is None:
-runtime_mapping = Mapping()
-tp_size = runtime_mapping.tp_size
-tp_rank = runtime_mapping.tp_rank

if uids is None:
uids = [self._generate_uid() for _ in range(len(model_files))]
assert len(uids) == len(model_files)

@@ -829,10 +829,6 @@ class LoraManager(object):

t_in = all_lora_weights[layer_idx]["in"]
t_out = all_lora_weights[layer_idx]["out"]
-assert t_out.shape[0] % tp_size == 0
-t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[
-tp_rank
-].contiguous()
else:
t_in = None
t_out = None

@@ -882,7 +878,6 @@ class LoraManager(object):
self,
model_dirs: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
-runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
component: Optional[str] = None,
) -> List[str]:

@@ -939,11 +934,6 @@ class LoraManager(object):
...

"""
-if runtime_mapping is None:
-runtime_mapping = Mapping()
-tp_size = runtime_mapping.tp_size
-tp_rank = runtime_mapping.tp_rank

if uids is None:
uids = [self._generate_uid() for _ in range(len(model_dirs))]
assert len(uids) == len(model_dirs)

@@ -983,6 +973,70 @@ class LoraManager(object):
lora_model[key] = value
return lora_model

+def interleave_fused_lora_weights_for_tp(
+weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int]
+) -> List[torch.Tensor]:
+"""Interleaves fused LoRA modules weights for TP.
+e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
+torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
+where N=TP size.
+""" # noqa: D205
+assert weight.shape[rank_dim] == sum(part_sizes)

+# Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv.
+weight_parts = [
+weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i])
+for i in range(len(part_sizes))
+]
+for i in range(len(part_sizes)):
+assert weight_parts[i].shape[rank_dim] % tp_size == 0

+# Split each part into tp_size chunks.
+# e.g. [Wq, Wk, Wv] -> [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
+# where N is TP size, for attn_qkv.
+weight_parts_tp_weights = [
+torch.split(
+weight_parts[i], weight_parts[i].shape[rank_dim] // tp_size, dim=rank_dim
+)
+for i in range(len(part_sizes))
+]

+# Interleave the parts across TP ranks and flatten the list of lists into a single list.
+# e.g. [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
+# -> [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] where N is TP size, for attn_qkv.
+return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights)))

+def prepare_fused_lora_modules_for_tp(
+lora_module: str, t_out: torch.Tensor, rank_dim: int
+) -> torch.Tensor:
+"""Reorders fused LoRA modules weights for TP. This is required since HF stores the parts weights
+sequentially, whereas with TP>1 we need them to be interleaved so they would be sharded correctly.

+See interleave_fused_lora_weights_for_tp for more details.
+""" # noqa: D205
+tp_size = self._mapping.tp_size
+if tp_size == 1:
+return t_out
+part_sizes = []
+if lora_module == "mlp_gate_up":
+assert t_out.shape[rank_dim] % 2 == 0
+half_size = t_out.shape[rank_dim] // 2
+part_sizes = [half_size, half_size]
+elif lora_module == "attn_qkv":
+# The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already
+# divided by tp_size in tensorrt_llm/_torch/model_config.py::ModelConfig.get_bindings_model_config
+q_size = self._model_config.head_size * self._model_config.num_heads * tp_size
+kv_size = self._model_config.head_size * self._model_config.num_kv_heads * tp_size
+part_sizes = [q_size, kv_size, kv_size]

+if part_sizes:
+interleaved_parts = interleave_fused_lora_weights_for_tp(
+t_out, rank_dim, tp_size, part_sizes
+)
+# Concatenate them all after interleaving, as the CPP implementation expects the full non-split weights.
+t_out = torch.cat(interleaved_parts, dim=rank_dim)
+return t_out

def load_from_model_dir(uid, model_dir, hf_config):
if uid not in self._cpp_lora_weights:
self._cpp_lora_weights[uid] = [] # Will be converted to tensor later

@@ -1060,36 +1114,9 @@ class LoraManager(object):
t_mag = module_weights.get("magnitude", None)

is_dora = t_mag is not None

-if lora_module in ["moe_router", "mlp_router"]:
-pass
-elif "moe" in lora_module and runtime_mapping.has_moe_ep():
-pass
-elif lora_module in [
-"attn_dense",
-"cross_attn_dense",
-"mlp_4h_to_h",
-"moe_4h_to_h",
-]:
-# split by row
-dim = 2 if has_expert_indices else 1
-assert t_in.shape[dim] % tp_size == 0
-t_in = torch.split(t_in, t_in.shape[dim] // tp_size, dim=dim)[
-tp_rank
-].contiguous()
-else:
-# split by column
-dim = 1 if has_expert_indices else 0
-assert t_out.shape[dim] % tp_size == 0
-t_out = torch.split(t_out, t_out.shape[dim] // tp_size, dim=dim)[
-tp_rank
-].contiguous()
-if dim == 0 and is_dora and t_mag is not None:
-t_mag = torch.split(t_mag, t_mag.shape[0] // tp_size, dim=0)[
-tp_rank
-].contiguous()

rank_dim = 1 if has_expert_indices else 0
+t_out = prepare_fused_lora_modules_for_tp(lora_module, t_out, rank_dim)

effective_rank = t_in.shape[rank_dim]

t_in = t_in.cuda().contiguous()

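A worked example of the attn_qkv part sizes used above, with hypothetical GQA numbers (not taken from the diff): the bindings ModelConfig stores num_heads and num_kv_heads already divided by tp_size, so prepare_fused_lora_modules_for_tp multiplies them back by tp_size to recover the full fused Wq/Wk/Wv row counts.

# Hypothetical model: 32 query heads, 8 KV heads, head_size 128, TP=4.
tp_size = 4
head_size = 128
num_heads = 32 // tp_size      # 8: per-rank value, as stored in the bindings ModelConfig
num_kv_heads = 8 // tp_size    # 2: per-rank value

# Multiply back by tp_size to get the full (non-split) fused part sizes
# of the HF LoRA B weight, as the diff's code does.
q_size = head_size * num_heads * tp_size       # 4096 rows for Wq
kv_size = head_size * num_kv_heads * tp_size   # 1024 rows for each of Wk and Wv
part_sizes = [q_size, kv_size, kv_size]

assert part_sizes == [4096, 1024, 1024]
assert sum(part_sizes) == head_size * (32 + 2 * 8)  # matches the fused weight's dim 0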
@@ -174,12 +174,14 @@ class EncDecModelRunner:

# encoder lora manager setup
if self.encoder_model_config.lora_plugin:
-self.encoder_lora_manager = LoraManager()
+self.encoder_lora_manager = LoraManager(
+mapping=self.encoder_runtime_mapping,
+model_config=self.encoder_model_config,
+)
# TODO: this is only for bart
self.encoder_lora_manager.load_from_hf(
model_dirs=lora_dir,
model_config=self.encoder_model_config,
-runtime_mapping=self.encoder_runtime_mapping,
component='encoder',
)
else:

@@ -197,12 +199,14 @@ class EncDecModelRunner:

# decoder lora manager setup
if self.decoder_model_config.lora_plugin:
-self.decoder_lora_manager = LoraManager()
+self.decoder_lora_manager = LoraManager(
+mapping=self.decoder_runtime_mapping,
+model_config=self.decoder_model_config,
+)
# TODO: this is only for bart
self.decoder_lora_manager.load_from_hf(
model_dirs=lora_dir,
model_config=self.decoder_model_config,
-runtime_mapping=self.decoder_runtime_mapping,
component='decoder',
)
else:

@@ -40,7 +40,8 @@ from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import \
PoolsKVCacheManager
from tensorrt_llm.runtime.redrafter_utils import *

-from .._utils import (pad_vocab_size, str_dtype_to_torch, torch_to_numpy,
+from .._utils import (binding_layer_type_to_str, binding_to_str_dtype,
+pad_vocab_size, str_dtype_to_torch, torch_to_numpy,
trt_dtype_to_torch)
from ..bindings import KVCacheType, ipc_nvls_allocate, ipc_nvls_free
from ..layers import LanguageAdapterConfig

@@ -653,6 +654,41 @@ class ModelConfig:
# language adapter
language_adapter_config: Optional[LanguageAdapterConfig] = None

+@classmethod
+def from_model_config_cpp(cls, model_config_cpp) -> 'ModelConfig':
+"""Create a partially initialized ModelConfig instance from a given ModelConfig CPP binding instance.

+Note that each of these classes have fields that don't exist in the other, so the created ModelConfigPython
+won't have all of its fields initialized.
+"""
+return cls(
+max_batch_size=model_config_cpp.max_batch_size,
+max_beam_width=model_config_cpp.max_beam_width,
+vocab_size=model_config_cpp.vocab_size,
+num_layers=model_config_cpp.num_layers(),
+num_heads=model_config_cpp.num_heads,
+num_kv_heads=model_config_cpp.num_kv_heads(0),
+hidden_size=model_config_cpp.hidden_size,
+remove_input_padding=model_config_cpp.use_packed_input,
+kv_cache_type=model_config_cpp.kv_cache_type,
+cross_attention=model_config_cpp.use_cross_attention,
+head_size=model_config_cpp.head_size,
+max_prompt_embedding_table_size=model_config_cpp.
+max_prompt_embedding_table_size,
+quant_mode=QuantMode(model_config_cpp.quant_mode.value),
+gather_context_logits=model_config_cpp.compute_context_logits,
+gather_generation_logits=model_config_cpp.compute_generation_logits,
+gpt_attention_plugin=model_config_cpp.use_gpt_attention_plugin,
+dtype=binding_to_str_dtype(model_config_cpp.data_type),
+num_kv_heads_per_layer=model_config_cpp.num_kv_heads_per_layer,
+tokens_per_block=model_config_cpp.tokens_per_block,
+lora_plugin=model_config_cpp.use_lora_plugin,
+layer_types=[
+binding_layer_type_to_str(lt)
+for lt in model_config_cpp.layer_types
+],
+)


@dataclass
class SamplingConfig:

@@ -611,11 +611,11 @@ class ModelRunner(ModelRunnerMixin):
session.runtime._set_weight_streaming(gpu_weights_percent)

if session.use_lora_plugin:
-lora_manager = LoraManager()
+lora_manager = LoraManager(mapping=runtime_mapping,
+model_config=model_config)
if lora_dir is not None:
lora_manager.load_from_ckpt(lora_dir,
model_config=model_config,
-runtime_mapping=runtime_mapping,
ckpt_source=lora_ckpt_source)
else:
lora_manager = None

@@ -720,11 +720,11 @@ class ModelRunner(ModelRunnerMixin):
debug_mode=debug_mode,
stream=stream)
if session.use_lora_plugin:
-lora_manager = LoraManager()
+lora_manager = LoraManager(mapping=runtime_mapping,
+model_config=model_config)
if lora_dir is not None:
lora_manager.load_from_ckpt(lora_dir,
model_config=model_config,
-runtime_mapping=runtime_mapping,
ckpt_source=lora_ckpt_source)
else:
lora_manager = None

@@ -32,8 +32,9 @@ from ..builder import EngineConfig
from ..layers import MropeParams
from ..logger import logger
from ..mapping import Mapping
-from .generation import (LogitsProcessor, LoraManager, SamplingConfig,
-StoppingCriteria)
+from .generation import LogitsProcessor, LoraManager
+from .generation import ModelConfig as ModelConfigPython
+from .generation import SamplingConfig, StoppingCriteria
from .model_runner import ModelRunnerMixin, _engine_config_to_model_config

_bindings_dtype_to_torch_dtype_dict = {

@@ -277,7 +278,11 @@ class ModelRunnerCpp(ModelRunnerMixin):

engine_config = EngineConfig.from_json_file(f"{engine_dir}/config.json")
if model_config.use_lora_plugin and rank == 0:
-lora_manager = LoraManager()
+mapping = _world_config_to_mapping(world_config)
+lora_manager = LoraManager(
+mapping=mapping,
+model_config=ModelConfigPython.from_model_config_cpp(
+model_config))
if lora_dir is None:
config_lora_dir = engine_config.build_config.lora_config.lora_dir
if len(config_lora_dir) > 0:

@@ -292,7 +297,6 @@ class ModelRunnerCpp(ModelRunnerMixin):
# For Executor, only rank 0 can enqueue requests, and should hold all lora weights
lora_manager.load_from_ckpt(lora_dir,
model_config=runtime_model_config,
-runtime_mapping=None,
ckpt_source=lora_ckpt_source)
else:
raise RuntimeError(

@@ -2,7 +2,7 @@ import json
import tarfile
import tempfile
from pathlib import Path
-from typing import OrderedDict, Type
+from typing import List, OrderedDict, Type

import torch
from utils.llm_data import llm_models_root

@@ -11,10 +11,62 @@ from utils.util import duplicate_list_to_length, flatten_list, similar
from tensorrt_llm import SamplingParams
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi.llm import BaseLLM
+from tensorrt_llm.lora_helper import LoraConfig

+_RU_LORA_ADAPTER_PROMPTS = [
+"Назови главную площадь в центре Москвы.",
+"Напиши полное предложение, описывающее, что в музее не хватает женских скульптур. Используй фразу \"не хватает\".",
+"Что означает выражение \"водить за нос\"? Объясни в двух словах.",
+]


+def _generate_phi3_response_lora_fused_modules(llm_class: Type[BaseLLM],
+prompts: List[str],
+**extra_llm_kwargs) -> List[str]:
+"""Generates responses with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter.
+The used LoRA adapter has fused attention QKV and fused MLP gate up proj modules.
+Returns the generated texts.
+""" # noqa: D205
+hf_model_dir = f"{llm_models_root()}/Phi-3/Phi-3-mini-4k-instruct"
+hf_lora_dir = f"{llm_models_root()}/lora/phi/Phi-3-mini-4k-instruct-ru-lora"

+lora_req = LoRARequest("ru-lora", 0, hf_lora_dir)
+sampling_params = SamplingParams(max_tokens=20)

+lora_config = LoraConfig(lora_dir=[hf_lora_dir],
+max_lora_rank=16,
+max_loras=2,
+max_cpu_loras=2)

+lora_requests = [lora_req] * len(prompts)
+with llm_class(hf_model_dir, lora_config=lora_config,
+**extra_llm_kwargs) as llm:
+outputs = llm.generate(prompts,
+sampling_params,
+lora_request=lora_requests)

+return [output.outputs[0].text for output in outputs]


+def check_phi3_lora_fused_modules_output_tp2_identical_to_tp1(
+llm_class: Type[BaseLLM], **extra_llm_kwargs) -> None:
+"""Tests the output with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter with TP=2 is identical to
+the output with TP=1.
+That LoRA adapter has fused attention QKV and fused MLP gate up proj modules.
+""" # noqa: D205
+extra_llm_kwargs["tensor_parallel_size"] = 1
+outputs_tp1 = _generate_phi3_response_lora_fused_modules(
+llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs)

+extra_llm_kwargs["tensor_parallel_size"] = 2
+outputs_tp2 = _generate_phi3_response_lora_fused_modules(
+llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs)

+assert outputs_tp1 == outputs_tp2


def check_llama_7b_multi_unique_lora_adapters_from_request(
-lora_adapter_count_per_call: list[int], repeat_calls: int,
+lora_adapter_count_per_call: List[int], repeat_calls: int,
repeats_per_call: int, llm_class: Type[BaseLLM], **llm_kwargs):
"""Calls llm.generate s.t. for each C in lora_adapter_count_per_call, llm.generate is called with C requests
repeated 'repeats_per_call' times, where each request is configured with a unique LoRA adapter ID.

@@ -5,7 +5,7 @@ from .test_llm import tinyllama_logits_processor_test_harness, llama_model_path
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.lora_helper import LoraConfig
-from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness
+from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness, check_phi3_lora_fused_modules_output_tp2_identical_to_tp1
from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
from .test_llm import _test_llm_capture_request_error
from utils.util import skip_ray

@@ -62,6 +62,15 @@ def test_llama_7b_multi_lora_tp2():
cuda_graph_config=None)


+@pytest.mark.gpu2
+def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None:
+check_phi3_lora_fused_modules_output_tp2_identical_to_tp1(
+LLM,
+# Disable CUDA graph
+# TODO: remove this once we have a proper fix for CUDA graph in LoRA
+cuda_graph_config=None)


@pytest.mark.skip(reason="https://nvbugs/5560921")
@skip_ray
@pytest.mark.gpu2