[https://nvbugs/5510879][fix] Fix fused LoRA adapter module weight split with TP>1 in the PyTorch & TRT-python flows (#8063)

Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
amitz-nv 2025-10-12 22:29:52 +03:00 committed by GitHub
parent a1ed03fe8a
commit fac47e2826
11 changed files with 229 additions and 78 deletions

View File

@ -13,6 +13,7 @@ from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
from tensorrt_llm.sampling_params import SamplingParams
from ..._utils import binding_to_str_dtype, get_size_in_bytes, nvtx_range
@ -32,7 +33,7 @@ BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager
KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig
CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
ModelConfig = tensorrt_llm.bindings.ModelConfig
ModelConfigCpp = tensorrt_llm.bindings.ModelConfig
DataType = tensorrt_llm.bindings.DataType
KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager
RequestList = list[LlmRequest]
@ -160,7 +161,7 @@ class KVCacheManager(BaseResourceManager):
spec_config: Optional["DecodingBaseConfig"] = None,
layer_mask: Optional[List[bool]] = None,
max_num_tokens: int = 8192,
model_config: Optional[ModelConfig] = None,
model_config: Optional[ModelConfigCpp] = None,
max_beam_width: int = 1,
is_draft: bool = False,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
@ -371,7 +372,7 @@ class KVCacheManager(BaseResourceManager):
@classmethod
def from_model_config(cls,
model_config: ModelConfig,
model_config: ModelConfigCpp,
kv_cache_config: KvCacheConfigCpp,
mapping: Mapping,
kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF,
@ -772,7 +773,7 @@ class KVCacheManager(BaseResourceManager):
window_size_to_layers: Dict[int, List[int]],
max_attention_window_vec: List[int],
kv_cache_config: KvCacheConfigCpp,
model_config: ModelConfig,
model_config: ModelConfigCpp,
pool_memory_bytes: int,
kv_factor: int,
dtype: DataType,
@ -887,7 +888,7 @@ class KVCacheManager(BaseResourceManager):
def calculate_max_num_blocks_from_cpp(
self,
kv_cache_config: KvCacheConfigCpp,
model_config: ModelConfig,
model_config: ModelConfigCpp,
extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
"""
This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks.
@ -1133,7 +1134,7 @@ class PeftCacheManager(BaseResourceManager):
def __init__(self,
peft_cache_config: PeftCacheConfig,
lora_config: LoraConfig,
model_config: ModelConfig,
model_config: ModelConfigCpp,
world_config: WorldConfig | None = None):
import tensorrt_llm.bindings as _tb
@ -1169,7 +1170,20 @@ class PeftCacheManager(BaseResourceManager):
lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
binding_to_str_dtype(model_config.data_type),
lora_config.swap_gate_up_proj_lora_b_weight)
self._lora_manager = LoraManager()
mapping = Mapping(
world_size=world_config.size,
rank=world_config.rank,
tp_size=world_config.tensor_parallelism,
pp_size=world_config.pipeline_parallelism,
gpus_per_node=world_config.gpus_per_node,
)
self._lora_manager = LoraManager(
mapping=mapping,
model_config=ModelConfigPython.from_model_config_cpp(model_config),
cpp_peft_cache_manager=self.impl)
def get_lora_manager(self) -> LoraManager:
return self._lora_manager
def add_request_peft(self, request: LlmRequest):
if request.lora_task_id is not None:
@ -1183,7 +1197,6 @@ class PeftCacheManager(BaseResourceManager):
self._lora_manager.load_from_ckpt(
[request.py_lora_path],
model_config=self._lora_model_config,
runtime_mapping=None,
uids=[request.lora_task_id],
ckpt_source=self._lora_config.lora_ckpt_source)
request.lora_weights = self._lora_manager.cpp_lora_weights[

View File

@ -42,7 +42,7 @@ import torch
import tensorrt as trt
# isort: on
from tensorrt_llm.bindings import DataType, GptJsonConfig
from tensorrt_llm.bindings import DataType, GptJsonConfig, LayerType
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.logger import logger
@ -198,6 +198,10 @@ _binding_dtype_bits = {
}
def binding_layer_type_to_str(layer_type: LayerType) -> str:
return layer_type.name.lower()
def binding_to_str_dtype(binding_dtype) -> str:
ret = _binding_to_str_dtype.get(binding_dtype)
assert ret is not None, f'Unsupported binding dtype: {binding_dtype}'

View File

@ -205,7 +205,10 @@ class BaseWorker(GenerationExecutor):
# point in the TRT flow is currently not supported (it's at the CPP
# Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
# optimization is not available in TRT-python flow.
self._lora_manager = LoraManager(cpp_peft_cache_manager=None)
self._lora_manager = LoraManager(
mapping=engine_config.pretrained_config.mapping,
model_config=self._runtime_model_config,
cpp_peft_cache_manager=None)
if engine_config.build_config.max_prompt_embedding_table_size > 0:
self._prompt_adapter_manager = PromptAdapterManager()
@ -216,8 +219,7 @@ class BaseWorker(GenerationExecutor):
ResourceManagerType
peft_cache_manager = self.engine.resource_manager.resource_managers.get(
ResourceManagerType.PEFT_CACHE_MANAGER)
self._lora_manager = LoraManager(
cpp_peft_cache_manager=peft_cache_manager.impl)
self._lora_manager = peft_cache_manager.get_lora_manager()
lora_model_config = self.engine.model_engine.lora_model_config
assert lora_model_config is not None
self._lora_model_config = lora_model_config
@ -302,7 +304,6 @@ class BaseWorker(GenerationExecutor):
[lora_request.path],
model_config=self._runtime_model_config if
self._runtime_model_config is not None else self._lora_model_config,
runtime_mapping=None,
uids=[adapter_id],
ckpt_source=lora_request.ckpt_source)
return adapter_id in newly_loaded_uids

View File

@ -46,6 +46,7 @@ def get_default_trtllm_modules_to_hf_modules():
"attn_q": "q_proj",
"attn_k": "k_proj",
"attn_v": "v_proj",
"attn_qkv": "qkv_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_proj",
"mlp_4h_to_h": "down_proj",

View File

@ -1,4 +1,5 @@
import io
import itertools
import json
import logging
import re
@ -660,11 +661,17 @@ class LoraManager(object):
}
def __init__(
self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None
self,
*,
mapping: Mapping,
model_config: "ModelConfig",
cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None,
):
"""Constructor.
Args:
mapping (Mapping): Parallelism-related information.
model_config (ModelConfig): The Python ModelConfig instance.
cpp_peft_cache_manager (PeftCacheManager, optional): used by the is_adapter_in_cpu_cache method, which enables
a LoRA performance optimization: the adapter weights are not sent with an LLM request when that adapter is
already loaded in the LoRA CPU cache.
@ -704,6 +711,8 @@ class LoraManager(object):
self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu
self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu
self.lora_target_modules: List[str] = []
self._mapping = mapping
self._model_config = model_config
self._cpp_peft_cache_manager = cpp_peft_cache_manager
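For reference, a minimal construction sketch of the new keyword-only signature (illustrative only, not part of this patch; the single-GPU values are hypothetical and model_config is assumed to be an existing tensorrt_llm.runtime.ModelConfig instance):

# Illustrative sketch, not part of the patch: hypothetical single-GPU setup.
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1)
lora_manager = LoraManager(
    mapping=mapping,
    model_config=model_config,      # existing tensorrt_llm.runtime.ModelConfig (assumed in scope)
    cpp_peft_cache_manager=None,    # optional; enables the is_adapter_in_cpu_cache optimization when set
)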
def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
@ -730,7 +739,6 @@ class LoraManager(object):
self,
model_dirs_or_files: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
ckpt_source: str = "hf",
) -> List[str]:
@ -743,7 +751,6 @@ class LoraManager(object):
return self.load_from_hf(
model_dirs=model_dirs_or_files,
model_config=model_config,
runtime_mapping=runtime_mapping,
uids=uids,
)
elif ckpt_source == "nemo":
@ -754,7 +761,6 @@ class LoraManager(object):
return self.load_from_nemo(
model_files=nemo_files,
model_config=model_config,
runtime_mapping=runtime_mapping,
uids=uids,
)
else:
@ -764,7 +770,6 @@ class LoraManager(object):
self,
model_files: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
) -> List[str]:
"""Returns the adapter UIDs that were loaded by this call.
@ -772,11 +777,6 @@ class LoraManager(object):
Note that when an adapter was already loaded before this call, it would not be
included in the returned list of UIDs.
"""
if runtime_mapping is None:
runtime_mapping = Mapping()
tp_size = runtime_mapping.tp_size
tp_rank = runtime_mapping.tp_rank
if uids is None:
uids = [self._generate_uid() for _ in range(len(model_files))]
assert len(uids) == len(model_files)
@ -829,10 +829,6 @@ class LoraManager(object):
t_in = all_lora_weights[layer_idx]["in"]
t_out = all_lora_weights[layer_idx]["out"]
assert t_out.shape[0] % tp_size == 0
t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[
tp_rank
].contiguous()
else:
t_in = None
t_out = None
@ -882,7 +878,6 @@ class LoraManager(object):
self,
model_dirs: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
component: Optional[str] = None,
) -> List[str]:
@ -939,11 +934,6 @@ class LoraManager(object):
...
"""
if runtime_mapping is None:
runtime_mapping = Mapping()
tp_size = runtime_mapping.tp_size
tp_rank = runtime_mapping.tp_rank
if uids is None:
uids = [self._generate_uid() for _ in range(len(model_dirs))]
assert len(uids) == len(model_dirs)
@ -983,6 +973,70 @@ class LoraManager(object):
lora_model[key] = value
return lora_model
def interleave_fused_lora_weights_for_tp(
weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int]
) -> List[torch.Tensor]:
"""Interleaves fused LoRA modules weights for TP.
e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
where N=TP size.
""" # noqa: D205
assert weight.shape[rank_dim] == sum(part_sizes)
# Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv.
weight_parts = [
weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i])
for i in range(len(part_sizes))
]
for i in range(len(part_sizes)):
assert weight_parts[i].shape[rank_dim] % tp_size == 0
# Split each part into tp_size chunks.
# e.g. [Wq, Wk, Wv] -> [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
# where N is tp_size - 1, for attn_qkv.
weight_parts_tp_weights = [
torch.split(
weight_parts[i], weight_parts[i].shape[rank_dim] // tp_size, dim=rank_dim
)
for i in range(len(part_sizes))
]
# Interleave the parts across TP ranks and flatten the list of lists into a single list.
# e.g. [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
# -> [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] where N is tp_size - 1, for attn_qkv.
return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights)))
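As a minimal, self-contained sketch (toy shapes, not part of the patch), the same split/zip/chain steps produce the following reordering for attn_qkv with tp_size=2:

# Illustrative only: fused [Wq; Wk; Wv] rows are reordered so each rank's
# contiguous shard contains its own q, k and v slices.
import itertools
import torch

q, k, v = torch.zeros(8, 4), torch.ones(4, 4), torch.full((4, 4), 2.0)
fused = torch.cat([q, k, v], dim=0)                        # HF layout: [Wq; Wk; Wv]
chunks = [torch.split(p, p.shape[0] // 2, dim=0) for p in (q, k, v)]
interleaved = list(itertools.chain.from_iterable(zip(*chunks)))
reordered = torch.cat(interleaved, dim=0)                  # [Wq0; Wk0; Wv0; Wq1; Wk1; Wv1]
assert reordered.shape == fused.shape
assert torch.equal(reordered[:8], torch.cat([q[:4], k[:2], v[:2]], dim=0))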
def prepare_fused_lora_modules_for_tp(
lora_module: str, t_out: torch.Tensor, rank_dim: int
) -> torch.Tensor:
"""Reorders fused LoRA modules weights for TP. This is required since HF stores the parts weights
sequentially, whereas with TP>1 we need them to be interleaved so they would be sharded correctly.
See interleave_fused_lora_weights_for_tp for more details.
""" # noqa: D205
tp_size = self._mapping.tp_size
if tp_size == 1:
return t_out
part_sizes = []
if lora_module == "mlp_gate_up":
assert t_out.shape[rank_dim] % 2 == 0
half_size = t_out.shape[rank_dim] // 2
part_sizes = [half_size, half_size]
elif lora_module == "attn_qkv":
# The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already
# divided by tp_size in tensorrt_llm/_torch/model_config.py::ModelConfig.get_bindings_model_config
q_size = self._model_config.head_size * self._model_config.num_heads * tp_size
kv_size = self._model_config.head_size * self._model_config.num_kv_heads * tp_size
part_sizes = [q_size, kv_size, kv_size]
if part_sizes:
interleaved_parts = interleave_fused_lora_weights_for_tp(
t_out, rank_dim, tp_size, part_sizes
)
# Concatenate them all after interleaving, as the CPP implementation expects the full non-split weights.
t_out = torch.cat(interleaved_parts, dim=rank_dim)
return t_out
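A hedged arithmetic sketch of the attn_qkv part sizes (hypothetical GQA head counts, not taken from a real model): because the bindings config stores per-rank head counts, the sizes are scaled back up by tp_size to describe the full, unsplit fused weight:

# Hypothetical per-rank values after the division mentioned in the comment above.
head_size, num_heads, num_kv_heads, tp_size = 96, 16, 4, 2
q_size = head_size * num_heads * tp_size       # 96 * 16 * 2 = 3072 output rows of Wq
kv_size = head_size * num_kv_heads * tp_size   # 96 * 4 * 2 = 768 output rows of Wk (and Wv)
part_sizes = [q_size, kv_size, kv_size]        # full fused qkv LoRA-B weight: 4608 output rows
assert sum(part_sizes) == 3072 + 768 + 768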
def load_from_model_dir(uid, model_dir, hf_config):
if uid not in self._cpp_lora_weights:
self._cpp_lora_weights[uid] = [] # Will be converted to tensor later
@ -1060,36 +1114,9 @@ class LoraManager(object):
t_mag = module_weights.get("magnitude", None)
is_dora = t_mag is not None
if lora_module in ["moe_router", "mlp_router"]:
pass
elif "moe" in lora_module and runtime_mapping.has_moe_ep():
pass
elif lora_module in [
"attn_dense",
"cross_attn_dense",
"mlp_4h_to_h",
"moe_4h_to_h",
]:
# split by row
dim = 2 if has_expert_indices else 1
assert t_in.shape[dim] % tp_size == 0
t_in = torch.split(t_in, t_in.shape[dim] // tp_size, dim=dim)[
tp_rank
].contiguous()
else:
# split by column
dim = 1 if has_expert_indices else 0
assert t_out.shape[dim] % tp_size == 0
t_out = torch.split(t_out, t_out.shape[dim] // tp_size, dim=dim)[
tp_rank
].contiguous()
if dim == 0 and is_dora and t_mag is not None:
t_mag = torch.split(t_mag, t_mag.shape[0] // tp_size, dim=0)[
tp_rank
].contiguous()
rank_dim = 1 if has_expert_indices else 0
t_out = prepare_fused_lora_modules_for_tp(lora_module, t_out, rank_dim)
effective_rank = t_in.shape[rank_dim]
t_in = t_in.cuda().contiguous()

View File

@ -174,12 +174,14 @@ class EncDecModelRunner:
# encoder lora manager setup
if self.encoder_model_config.lora_plugin:
self.encoder_lora_manager = LoraManager()
self.encoder_lora_manager = LoraManager(
mapping=self.encoder_runtime_mapping,
model_config=self.encoder_model_config,
)
# TODO: this is only for bart
self.encoder_lora_manager.load_from_hf(
model_dirs=lora_dir,
model_config=self.encoder_model_config,
runtime_mapping=self.encoder_runtime_mapping,
component='encoder',
)
else:
@ -197,12 +199,14 @@ class EncDecModelRunner:
# decoder lora manager setup
if self.decoder_model_config.lora_plugin:
self.decoder_lora_manager = LoraManager()
self.decoder_lora_manager = LoraManager(
mapping=self.decoder_runtime_mapping,
model_config=self.decoder_model_config,
)
# TODO: this is only for bart
self.decoder_lora_manager.load_from_hf(
model_dirs=lora_dir,
model_config=self.decoder_model_config,
runtime_mapping=self.decoder_runtime_mapping,
component='decoder',
)
else:

View File

@ -40,7 +40,8 @@ from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import \
PoolsKVCacheManager
from tensorrt_llm.runtime.redrafter_utils import *
from .._utils import (pad_vocab_size, str_dtype_to_torch, torch_to_numpy,
from .._utils import (binding_layer_type_to_str, binding_to_str_dtype,
pad_vocab_size, str_dtype_to_torch, torch_to_numpy,
trt_dtype_to_torch)
from ..bindings import KVCacheType, ipc_nvls_allocate, ipc_nvls_free
from ..layers import LanguageAdapterConfig
@ -653,6 +654,41 @@ class ModelConfig:
# language adapter
language_adapter_config: Optional[LanguageAdapterConfig] = None
@classmethod
def from_model_config_cpp(cls, model_config_cpp) -> 'ModelConfig':
"""Create a partially initialized ModelConfig instance from a given ModelConfig CPP binding instance.
Note that each of these classes has fields that don't exist in the other, so the created ModelConfigPython
won't have all of its fields initialized.
"""
return cls(
max_batch_size=model_config_cpp.max_batch_size,
max_beam_width=model_config_cpp.max_beam_width,
vocab_size=model_config_cpp.vocab_size,
num_layers=model_config_cpp.num_layers(),
num_heads=model_config_cpp.num_heads,
num_kv_heads=model_config_cpp.num_kv_heads(0),
hidden_size=model_config_cpp.hidden_size,
remove_input_padding=model_config_cpp.use_packed_input,
kv_cache_type=model_config_cpp.kv_cache_type,
cross_attention=model_config_cpp.use_cross_attention,
head_size=model_config_cpp.head_size,
max_prompt_embedding_table_size=model_config_cpp.
max_prompt_embedding_table_size,
quant_mode=QuantMode(model_config_cpp.quant_mode.value),
gather_context_logits=model_config_cpp.compute_context_logits,
gather_generation_logits=model_config_cpp.compute_generation_logits,
gpt_attention_plugin=model_config_cpp.use_gpt_attention_plugin,
dtype=binding_to_str_dtype(model_config_cpp.data_type),
num_kv_heads_per_layer=model_config_cpp.num_kv_heads_per_layer,
tokens_per_block=model_config_cpp.tokens_per_block,
lora_plugin=model_config_cpp.use_lora_plugin,
layer_types=[
binding_layer_type_to_str(lt)
for lt in model_config_cpp.layer_types
],
)
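A small usage sketch (illustrative only; the engine path is hypothetical): parse an engine's config.json into the C++ bindings ModelConfig and derive the partially initialized Python ModelConfig from it:

# Illustrative sketch, not part of the patch.
from pathlib import Path

from tensorrt_llm.bindings import GptJsonConfig
from tensorrt_llm.runtime import ModelConfig as ModelConfigPython

json_config = GptJsonConfig.parse_file(Path("/path/to/engine") / "config.json")
model_config_py = ModelConfigPython.from_model_config_cpp(json_config.model_config)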
@dataclass
class SamplingConfig:

View File

@ -611,11 +611,11 @@ class ModelRunner(ModelRunnerMixin):
session.runtime._set_weight_streaming(gpu_weights_percent)
if session.use_lora_plugin:
lora_manager = LoraManager()
lora_manager = LoraManager(mapping=runtime_mapping,
model_config=model_config)
if lora_dir is not None:
lora_manager.load_from_ckpt(lora_dir,
model_config=model_config,
runtime_mapping=runtime_mapping,
ckpt_source=lora_ckpt_source)
else:
lora_manager = None
@ -720,11 +720,11 @@ class ModelRunner(ModelRunnerMixin):
debug_mode=debug_mode,
stream=stream)
if session.use_lora_plugin:
lora_manager = LoraManager()
lora_manager = LoraManager(mapping=runtime_mapping,
model_config=model_config)
if lora_dir is not None:
lora_manager.load_from_ckpt(lora_dir,
model_config=model_config,
runtime_mapping=runtime_mapping,
ckpt_source=lora_ckpt_source)
else:
lora_manager = None

View File

@ -32,8 +32,9 @@ from ..builder import EngineConfig
from ..layers import MropeParams
from ..logger import logger
from ..mapping import Mapping
from .generation import (LogitsProcessor, LoraManager, SamplingConfig,
StoppingCriteria)
from .generation import LogitsProcessor, LoraManager
from .generation import ModelConfig as ModelConfigPython
from .generation import SamplingConfig, StoppingCriteria
from .model_runner import ModelRunnerMixin, _engine_config_to_model_config
_bindings_dtype_to_torch_dtype_dict = {
@ -277,7 +278,11 @@ class ModelRunnerCpp(ModelRunnerMixin):
engine_config = EngineConfig.from_json_file(f"{engine_dir}/config.json")
if model_config.use_lora_plugin and rank == 0:
lora_manager = LoraManager()
mapping = _world_config_to_mapping(world_config)
lora_manager = LoraManager(
mapping=mapping,
model_config=ModelConfigPython.from_model_config_cpp(
model_config))
if lora_dir is None:
config_lora_dir = engine_config.build_config.lora_config.lora_dir
if len(config_lora_dir) > 0:
@ -292,7 +297,6 @@ class ModelRunnerCpp(ModelRunnerMixin):
# For Executor, only rank 0 can enqueue requests, and should hold all lora weights
lora_manager.load_from_ckpt(lora_dir,
model_config=runtime_model_config,
runtime_mapping=None,
ckpt_source=lora_ckpt_source)
else:
raise RuntimeError(

View File

@ -2,7 +2,7 @@ import json
import tarfile
import tempfile
from pathlib import Path
from typing import OrderedDict, Type
from typing import List, OrderedDict, Type
import torch
from utils.llm_data import llm_models_root
@ -11,10 +11,62 @@ from utils.util import duplicate_list_to_length, flatten_list, similar
from tensorrt_llm import SamplingParams
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi.llm import BaseLLM
from tensorrt_llm.lora_helper import LoraConfig
_RU_LORA_ADAPTER_PROMPTS = [
"Назови главную площадь в центре Москвы.",
"Напиши полное предложение, описывающее, что в музее не хватает женских скульптур. Используй фразу \"не хватает\".",
"Что означает выражение \"водить за нос\"? Объясни в двух словах.",
]
def _generate_phi3_response_lora_fused_modules(llm_class: Type[BaseLLM],
prompts: List[str],
**extra_llm_kwargs) -> List[str]:
"""Generates responses with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter.
The used LoRA adapter has fused attention QKV and fused MLP gate up proj modules.
Returns the generated texts.
""" # noqa: D205
hf_model_dir = f"{llm_models_root()}/Phi-3/Phi-3-mini-4k-instruct"
hf_lora_dir = f"{llm_models_root()}/lora/phi/Phi-3-mini-4k-instruct-ru-lora"
lora_req = LoRARequest("ru-lora", 0, hf_lora_dir)
sampling_params = SamplingParams(max_tokens=20)
lora_config = LoraConfig(lora_dir=[hf_lora_dir],
max_lora_rank=16,
max_loras=2,
max_cpu_loras=2)
lora_requests = [lora_req] * len(prompts)
with llm_class(hf_model_dir, lora_config=lora_config,
**extra_llm_kwargs) as llm:
outputs = llm.generate(prompts,
sampling_params,
lora_request=lora_requests)
return [output.outputs[0].text for output in outputs]
def check_phi3_lora_fused_modules_output_tp2_identical_to_tp1(
llm_class: Type[BaseLLM], **extra_llm_kwargs) -> None:
"""Tests the output with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter with TP=2 is identical to
the output with TP=1.
That LoRA adapter has fused attention QKV and fused MLP gate up proj modules.
""" # noqa: D205
extra_llm_kwargs["tensor_parallel_size"] = 1
outputs_tp1 = _generate_phi3_response_lora_fused_modules(
llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs)
extra_llm_kwargs["tensor_parallel_size"] = 2
outputs_tp2 = _generate_phi3_response_lora_fused_modules(
llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs)
assert outputs_tp1 == outputs_tp2
def check_llama_7b_multi_unique_lora_adapters_from_request(
lora_adapter_count_per_call: list[int], repeat_calls: int,
lora_adapter_count_per_call: List[int], repeat_calls: int,
repeats_per_call: int, llm_class: Type[BaseLLM], **llm_kwargs):
"""Calls llm.generate s.t. for each C in lora_adapter_count_per_call, llm.generate is called with C requests
repeated 'repeats_per_call' times, where each request is configured with a unique LoRA adapter ID.

View File

@ -5,7 +5,7 @@ from .test_llm import tinyllama_logits_processor_test_harness, llama_model_path
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.lora_helper import LoraConfig
from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness
from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness, check_phi3_lora_fused_modules_output_tp2_identical_to_tp1
from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
from .test_llm import _test_llm_capture_request_error
from utils.util import skip_ray
@ -62,6 +62,15 @@ def test_llama_7b_multi_lora_tp2():
cuda_graph_config=None)
@pytest.mark.gpu2
def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None:
check_phi3_lora_fused_modules_output_tp2_identical_to_tp1(
LLM,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
@pytest.mark.skip(reason="https://nvbugs/5560921")
@skip_ray
@pytest.mark.gpu2