mirror of https://github.com/NVIDIA/TensorRT-LLM.git

commit 0bc520f15e
parent 49262a62a5

fix: Limit llama4 context length to 8k (#3778)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
@@ -846,6 +846,15 @@ class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
             hidden_size=model_config.pretrained_config.hidden_size,
             vocab_size=model_config.pretrained_config.vocab_size)
 
+    def infer_max_seq_len(self):
+        # TODO: increase to support 10M context length. There are two blockers
+        # right now:
+        # 1. We need to implement chunked attention.
+        # 2. CUDA graph warmup will crash when the cached context is that long.
+        # This only affects the TRTLLM backend; flashinfer is fine. It is
+        # most likely an issue with the kernel.
+        return 8192
+
     def load_weights(self, weights: Dict):
         new_weights = {}
         for key, tensor in weights.items():
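The override gives the model a way to cap the usable context below whatever the checkpoint config advertises (the TODO refers to Llama 4's 10M-token window). A minimal sketch of the clamping pattern such an override enables; the `resolve_max_seq_len` helper and its `requested` parameter are hypothetical illustrations, not TensorRT-LLM executor API:

    from typing import Optional

    class Llama4Stub:
        # Stand-in for the real model class; mirrors the override above.
        def infer_max_seq_len(self) -> int:
            return 8192

    def resolve_max_seq_len(model, requested: Optional[int]) -> int:
        # Hypothetical helper: never exceed the model-inferred cap.
        cap = model.infer_max_seq_len()
        return cap if requested is None else min(requested, cap)

    # A 10M-token request would be clamped to the 8k limit from this commit.
    assert resolve_max_seq_len(Llama4Stub(), requested=10_000_000) == 8192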
@@ -16,7 +16,9 @@ from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 @pytest.mark.parametrize("backend", ["TRTLLM", "FLASHINFER"],
                          ids=["trtllm", "flashinfer"])
 @pytest.mark.parametrize("tp_size", [8], ids=["tp8"])
-def test_llama4(model_name, backend, tp_size):
+@pytest.mark.parametrize("use_cuda_graph", [True, False],
+                         ids=["enable_graph", "disable_graph"])
+def test_llama4(model_name, backend, tp_size, use_cuda_graph):
     prompts = [
         "The president of the United States is",
     ]
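Stacking the new decorator on the existing two makes pytest expand the test over the cross product: 2 backends × 1 TP size × 2 graph modes = 4 cases, each named by joining the per-parameter `ids` with '-'. A standalone sketch of the same expansion pattern; the test name here is illustrative, not the repo's:

    import pytest

    @pytest.mark.parametrize("backend", ["TRTLLM", "FLASHINFER"],
                             ids=["trtllm", "flashinfer"])
    @pytest.mark.parametrize("tp_size", [8], ids=["tp8"])
    @pytest.mark.parametrize("use_cuda_graph", [True, False],
                             ids=["enable_graph", "disable_graph"])
    def test_expansion(backend, tp_size, use_cuda_graph):
        # pytest generates 4 cases; each sees one combination of values.
        assert backend in ("TRTLLM", "FLASHINFER") and tp_size == 8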
@@ -25,7 +27,8 @@ def test_llama4(model_name, backend, tp_size):
         " the head of state and head of government of the",
     ]
 
-    pytorch_config = PyTorchConfig(attn_backend=backend, )
+    pytorch_config = PyTorchConfig(attn_backend=backend,
+                                   use_cuda_graph=use_cuda_graph)
     model_dir = str(llm_models_root() / "llama4-models" / model_name)
 
     llm = LLM(
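The dangling-comma call site becomes an explicit two-field config, with both fields threaded through from the test parameters. A minimal sketch of the four configs this parametrization produces, assuming only the import and kwargs visible in the diff; all other PyTorchConfig fields stay at their defaults:

    from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

    # One config per (backend, graph-mode) combination exercised by the test.
    configs = [
        PyTorchConfig(attn_backend=backend, use_cuda_graph=use_cuda_graph)
        for backend in ("TRTLLM", "FLASHINFER")
        for use_cuda_graph in (True, False)
    ]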