fix: Limit llama4 context length to 8k (#3778)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Mike Iovine 2025-04-23 08:55:10 -07:00 committed by GitHub
parent 49262a62a5
commit 0bc520f15e
2 changed files with 14 additions and 2 deletions


@@ -846,6 +846,15 @@ class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
             hidden_size=model_config.pretrained_config.hidden_size,
             vocab_size=model_config.pretrained_config.vocab_size)
+
+    def infer_max_seq_len(self):
+        # TODO: increase to support 10M context length. There are two blockers
+        # right now:
+        # 1. We need to implement chunked attention.
+        # 2. CUDA graph warmup will crash when the cached context is that long.
+        #    This only affects the TRTLLM backend; flashinfer is fine. It is
+        #    most likely an issue with the kernel.
+        return 8192
 
     def load_weights(self, weights: Dict):
         new_weights = {}
         for key, tensor in weights.items():
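
For context, a minimal standalone sketch of the pattern this hunk implements: the model overrides infer_max_seq_len() and a runtime clamps the usable context length to that value. All names below (FakeConfig, BaseModel, Llama4Like, resolve_max_seq_len) are illustrative assumptions, not the actual TensorRT-LLM classes or executor API.

# Illustrative sketch only: hypothetical names, not the real TensorRT-LLM API.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeConfig:
    # Nominal 10M-token window mentioned in the TODO above.
    max_position_embeddings: int = 10_000_000


class BaseModel:
    def __init__(self, config: FakeConfig):
        self.config = config

    def infer_max_seq_len(self) -> int:
        # Default behaviour: trust the checkpoint's advertised window.
        return self.config.max_position_embeddings


class Llama4Like(BaseModel):
    def infer_max_seq_len(self) -> int:
        # Override: cap at 8k until chunked attention lands and the
        # CUDA-graph warmup crash on the TRTLLM backend is fixed
        # (mirrors the TODO in the diff above).
        return 8192


def resolve_max_seq_len(model: BaseModel, requested: Optional[int]) -> int:
    # A runtime would clamp any user-requested context length to the model cap.
    cap = model.infer_max_seq_len()
    return cap if requested is None else min(requested, cap)


if __name__ == "__main__":
    print(resolve_max_seq_len(Llama4Like(FakeConfig()), requested=1_000_000))  # -> 8192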


@@ -16,7 +16,9 @@ from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 @pytest.mark.parametrize("backend", ["TRTLLM", "FLASHINFER"],
                          ids=["trtllm", "flashinfer"])
 @pytest.mark.parametrize("tp_size", [8], ids=["tp8"])
-def test_llama4(model_name, backend, tp_size):
+@pytest.mark.parametrize("use_cuda_graph", [True, False],
+                         ids=["enable_graph", "disable_graph"])
+def test_llama4(model_name, backend, tp_size, use_cuda_graph):
     prompts = [
         "The president of the United States is",
     ]
@@ -25,7 +27,8 @@ def test_llama4(model_name, backend, tp_size):
         " the head of state and head of government of the",
     ]
-    pytorch_config = PyTorchConfig(attn_backend=backend, )
+    pytorch_config = PyTorchConfig(attn_backend=backend,
+                                   use_cuda_graph=use_cuda_graph)
     model_dir = str(llm_models_root() / "llama4-models" / model_name)
     llm = LLM(
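
For reference, stacked pytest.mark.parametrize decorators multiply out, so adding the use_cuda_graph axis doubles the collected llama4 cases per backend. A minimal standalone sketch of that cross product (generic test name, not the file touched here):

# Standalone illustration of stacked parametrize decorators (generic names).
import pytest


@pytest.mark.parametrize("backend", ["TRTLLM", "FLASHINFER"],
                         ids=["trtllm", "flashinfer"])
@pytest.mark.parametrize("use_cuda_graph", [True, False],
                         ids=["enable_graph", "disable_graph"])
def test_backend_graph_matrix(backend, use_cuda_graph):
    # Stacked parametrize decorators take the cross product, so pytest
    # collects four cases here: (TRTLLM, True), (TRTLLM, False),
    # (FLASHINFER, True), (FLASHINFER, False).
    assert backend in ("TRTLLM", "FLASHINFER")
    assert isinstance(use_cuda_graph, bool)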