Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-21 10:15:46 +08:00
[TRTLLM-10305][feat] Support customized seq len larger than model config (#10600)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
parent 4f86c5f5ce
commit 722978b837
@@ -1155,6 +1155,11 @@ class PyTorchModelEngine(ModelEngine):
         release_gc()
 
     def _init_max_seq_len(self):
+        # Allow user to override the inferred max_seq_len with a warning.
+        allow_long_max_model_len = os.getenv(
+            "TLLM_ALLOW_LONG_MAX_MODEL_LEN",
+            "0").lower() in ["1", "true", "yes", "y"]
+
         # For mm_encoder_only mode, infer_max_seq_len() is for LLM decoder models
         if hasattr(self.model, 'infer_max_seq_len'):
             inferred_max_seq_len = self.model.infer_max_seq_len()
@@ -1166,15 +1171,20 @@ class PyTorchModelEngine(ModelEngine):
                 f"max_seq_len is not specified, using inferred value {inferred_max_seq_len}"
             )
             self.max_seq_len = inferred_max_seq_len
 
         elif inferred_max_seq_len < self.max_seq_len:
-            # NOTE: py_executor_creator makes sure that the executor uses this
-            # smaller value as its max_seq_len too.
-            logger.warning(
-                f"Specified {self.max_seq_len=} is larger than what the model can support "
-                f"({inferred_max_seq_len}). Setting max_seq_len to {inferred_max_seq_len}. "
-            )
-            self.max_seq_len = inferred_max_seq_len
+            if allow_long_max_model_len:
+                logger.warning(
+                    f"User specified max_seq_len is larger than the config in the model config file "
+                    f"({inferred_max_seq_len}). Setting max_seq_len to user's specified value {self.max_seq_len}. "
+                )
+            else:
+                # NOTE: py_executor_creator makes sure that the executor uses this
+                # smaller value as its max_seq_len too.
+                logger.warning(
+                    f"Specified {self.max_seq_len=} is larger than what the model can support "
+                    f"({inferred_max_seq_len}). Setting max_seq_len to {inferred_max_seq_len}. "
+                )
+                self.max_seq_len = inferred_max_seq_len
 
     def _infer_max_seq_len_from_config(self) -> int:
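With this change, setting the environment variable TLLM_ALLOW_LONG_MAX_MODEL_LEN to a truthy value ("1", "true", "yes", or "y") keeps a user-specified max_seq_len that exceeds the limit inferred from the model config, instead of clamping it to the inferred value; a warning is logged either way. A minimal usage sketch follows (the LLM constructor and its max_seq_len keyword are taken from the public tensorrt_llm API; the checkpoint path and the 131072 value are placeholders):

import os

# Opt in to sequence lengths beyond what the model config implies.
# When unset (or "0"), the engine clamps max_seq_len to the inferred value.
os.environ["TLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"

from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/model",  # placeholder checkpoint path
    max_seq_len=131072,      # deliberately larger than the config's inferred limit
)

The override is opt-in by design: running past the config's maximum (for example, beyond the context length the model was trained for) can degrade output quality, so the default behavior remains to clamp and warn.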