Tzu-Ling Kan 2026-01-13 21:17:40 +08:00 committed by GitHub
commit d4247d2d61
3 changed files with 22 additions and 19 deletions

View File

@@ -31,7 +31,7 @@ from ..executor import (DetokenizedGenerationResultBase, GenerationExecutor,
                          GenerationResult, IterationResult, LoRARequest,
                          PostprocWorkerConfig, PromptAdapterRequest)
 from ..executor.postproc_worker import PostprocParams
-from ..executor.utils import (create_mpi_comm_session,
+from ..executor.utils import (RequestError, create_mpi_comm_session,
                               get_spawn_proxy_process_env)
 from ..inputs import (PromptInputs, create_input_processor,
                       create_input_processor_with_hash, get_cache_salt_id,
@@ -686,7 +686,7 @@ class BaseLLM:
         if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
             max_num_tokens = self.args.max_num_tokens
             if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
-                raise ValueError(
+                raise RequestError(
                     f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "
                     f"max_num_tokens ({max_num_tokens})")
             return

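The llm.py change above is the behavioral core of the commit: the prompt-length check against max_num_tokens in the PyTorch backend now raises RequestError instead of ValueError. Below is a minimal sketch of how client code might handle the new exception type; the model path and token budget are illustrative assumptions, not taken from this commit.

from tensorrt_llm import LLM
from tensorrt_llm.executor import RequestError

# Assumed setup: a small token budget so an over-long prompt trips the check.
llm = LLM(model="/path/to/hf_or_engine_dir", max_num_tokens=64)
try:
    llm.generate("A " * 65)  # prompt longer than the max_num_tokens budget
except RequestError as err:
    # With this commit, the PyTorch backend surfaces the failure as
    # RequestError, matching the other backends.
    print(f"request rejected: {err}")
finally:
    llm.shutdown()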
View File

@@ -2393,7 +2393,8 @@ def test_llm_chunked_prefill():
               enable_chunked_prefill=False,
               fast_build=True)
-    with pytest.raises(ValueError):
+    # max_num_tokens validation now raises RequestError consistently
+    with pytest.raises(RequestError):
         output = llm.generate_async(
             "A " * build_config.max_num_tokens,
             sampling_params=sampling_params,
@@ -2436,13 +2437,9 @@ def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
     )
     prompt = 'A ' * 65  # the minimum max_num_tokens is 64
-    if pytorch_backend:
-        # pytorch backend will raise ValueError for max_num_tokens
-        with pytest.raises(ValueError):
-            llm.generate(prompt)
-    else:
-        with pytest.raises(RequestError):
-            llm.generate(prompt)
+    # Both backends now consistently raise RequestError for max_num_tokens validation
+    with pytest.raises(RequestError):
+        llm.generate(prompt)

 def test_llm_capture_request_error():

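Both test updates above now assert the same exception type regardless of backend. A hedged sketch of what such a check can look like as a standalone test follows; the model fixture name is an assumption, not part of this commit.

import pytest
from tensorrt_llm import LLM
from tensorrt_llm.executor import RequestError

def test_over_long_prompt_raises_request_error(llama_model_path):
    # llama_model_path is an assumed fixture providing a small test model.
    llm = LLM(model=llama_model_path, max_num_tokens=64)
    prompt = "A " * 65  # one word over the 64-token budget
    try:
        # The message text comes from the check in llm.py shown above.
        with pytest.raises(RequestError, match="should not exceed max_num_tokens"):
            llm.generate(prompt)
    finally:
        llm.shutdown()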
View File

@@ -8,7 +8,7 @@ import pytest
 from tensorrt_llm import LLM
 from tensorrt_llm.disaggregated_params import DisaggregatedParams
-from tensorrt_llm.executor import GenerationExecutorWorker
+from tensorrt_llm.executor import GenerationExecutorWorker, RequestError
 from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
 from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig
 from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig
@@ -830,10 +830,13 @@ class TestLlmError:
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)
-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()

 class FailingExecutorWorker(GenerationExecutorWorker):
@@ -962,10 +965,13 @@ class TestLlmError:
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)
-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()

 @skip_ray
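The last two hunks also wrap the assertion in try/finally so llm.shutdown() runs even when the expected RequestError is not raised and pytest.raises fails the test. A sketch of the same guarantee factored into a pytest fixture is below; the fixture name and its model-path argument are illustrative assumptions, not part of this commit.

import random

import pytest
from tensorrt_llm import LLM
from tensorrt_llm.executor import RequestError

@pytest.fixture
def budget_limited_llm(llama_model_path):
    # Assumed fixture: mirrors the max_num_tokens=100 setup in the tests above.
    llm = LLM(model=llama_model_path, max_num_tokens=100)
    try:
        yield llm
    finally:
        llm.shutdown()  # teardown runs whether or not the test body raises

def test_token_ids_exceeding_budget(budget_limited_llm):
    ids = [random.randint(10, 100) for _ in range(101)]  # 101 tokens > 100 budget
    with pytest.raises(RequestError, match="should not exceed max_num_tokens"):
        budget_limited_llm.generate([ids])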