Tzu-Ling Kan 2026-01-13 21:17:40 +08:00 committed by GitHub
commit d4247d2d61
3 changed files with 22 additions and 19 deletions

View File

@@ -31,7 +31,7 @@ from ..executor import (DetokenizedGenerationResultBase, GenerationExecutor,
                          GenerationResult, IterationResult, LoRARequest,
                          PostprocWorkerConfig, PromptAdapterRequest)
 from ..executor.postproc_worker import PostprocParams
-from ..executor.utils import (create_mpi_comm_session,
+from ..executor.utils import (RequestError, create_mpi_comm_session,
                               get_spawn_proxy_process_env)
 from ..inputs import (PromptInputs, create_input_processor,
                       create_input_processor_with_hash, get_cache_salt_id,
@@ -686,7 +686,7 @@ class BaseLLM:
         if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
             max_num_tokens = self.args.max_num_tokens
             if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
-                raise ValueError(
+                raise RequestError(
                     f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "
                     f"max_num_tokens ({max_num_tokens})")
         return
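
Callers that previously caught ValueError when a prompt exceeded max_num_tokens on the PyTorch backend now need to catch RequestError instead. A minimal caller-side sketch, using a placeholder model path and token budget (neither taken from this diff); only LLM and RequestError come from the imports shown above:

from tensorrt_llm import LLM
from tensorrt_llm.executor import RequestError

# Placeholder model path and a deliberately small token budget (assumptions).
llm = LLM(model="/path/to/model", max_num_tokens=100)
try:
    # A prompt longer than max_num_tokens is rejected at submission time.
    llm.generate("A " * 200)
except RequestError as err:
    # Before this commit the PyTorch backend surfaced this as ValueError;
    # both backends now raise RequestError.
    print(f"Request rejected: {err}")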

View File

@@ -2393,7 +2393,8 @@ def test_llm_chunked_prefill():
               enable_chunked_prefill=False,
               fast_build=True)
-    with pytest.raises(ValueError):
+    # max_num_tokens validation now raises RequestError consistently
+    with pytest.raises(RequestError):
         output = llm.generate_async(
             "A " * build_config.max_num_tokens,
             sampling_params=sampling_params,
@@ -2436,13 +2437,9 @@ def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
     )
     prompt = 'A ' * 65  # the minimum max_num_tokens is 64

-    if pytorch_backend:
-        # pytorch backend will raise ValueError for max_num_tokens
-        with pytest.raises(ValueError):
-            llm.generate(prompt)
-    else:
-        with pytest.raises(RequestError):
-            llm.generate(prompt)
+    # Both backends now consistently raise RequestError for max_num_tokens validation
+    with pytest.raises(RequestError):
+        llm.generate(prompt)


 def test_llm_capture_request_error():
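
With both backends expecting the same exception, the helper above can be exercised by a single parametrized wrapper. A sketch whose test name is hypothetical (the helper signature is the one from this diff):

import pytest

@pytest.mark.parametrize("pytorch_backend", [True, False])
def test_llm_capture_request_error_both_backends(pytorch_backend):
    # RequestError is expected regardless of backend after this change.
    _test_llm_capture_request_error(pytorch_backend=pytorch_backend, tp_size=1)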

View File

@@ -8,7 +8,7 @@ import pytest
 from tensorrt_llm import LLM
 from tensorrt_llm.disaggregated_params import DisaggregatedParams
-from tensorrt_llm.executor import GenerationExecutorWorker
+from tensorrt_llm.executor import GenerationExecutorWorker, RequestError
 from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
 from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig
 from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig
@@ -830,10 +830,13 @@ class TestLlmError:
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)
-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()


 class FailingExecutorWorker(GenerationExecutorWorker):
@@ -962,10 +965,13 @@ class TestLlmError:
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)
-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()

     @skip_ray
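
The try/finally blocks added above guarantee that llm.shutdown() runs even when the assertion fails. The same cleanup can be expressed once as a fixture; a sketch with a hypothetical fixture name and a placeholder model path:

import pytest
from tensorrt_llm import LLM
from tensorrt_llm.executor import RequestError

@pytest.fixture
def small_budget_llm():
    # Placeholder model path; max_num_tokens mirrors the tests above.
    llm = LLM(model="/path/to/model", max_num_tokens=100)
    try:
        yield llm
    finally:
        # Same guarantee as the try/finally in the diff: always shut down.
        llm.shutdown()

def test_over_budget_prompt_raises(small_budget_llm):
    with pytest.raises(RequestError, match="should not exceed max_num_tokens"):
        # 101 token ids against a 100-token budget, mirroring the tests above.
        small_budget_llm.generate([[42] * 101])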