Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Merge febc86ced0 into 6df2c8a074
This commit is contained in commit d4247d2d61
@@ -31,7 +31,7 @@ from ..executor import (DetokenizedGenerationResultBase, GenerationExecutor,
                         GenerationResult, IterationResult, LoRARequest,
                         PostprocWorkerConfig, PromptAdapterRequest)
 from ..executor.postproc_worker import PostprocParams
-from ..executor.utils import (create_mpi_comm_session,
+from ..executor.utils import (RequestError, create_mpi_comm_session,
                               get_spawn_proxy_process_env)
 from ..inputs import (PromptInputs, create_input_processor,
                       create_input_processor_with_hash, get_cache_salt_id,
@@ -686,7 +686,7 @@ class BaseLLM:
         if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
             max_num_tokens = self.args.max_num_tokens
             if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
-                raise ValueError(
+                raise RequestError(
                     f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "
                     f"max_num_tokens ({max_num_tokens})")
         return
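With this change, a prompt that overflows max_num_tokens on the pytorch backend surfaces as a per-request RequestError instead of a ValueError. A minimal caller-side sketch of handling the new error type; the model path, token budget, and prompt below are illustrative assumptions, not taken from this diff:

# Sketch only: catch the per-request error the changed check now raises.
from tensorrt_llm import LLM
from tensorrt_llm.executor import RequestError

llm = LLM(model="/path/to/model",  # placeholder path, not from the diff
          max_num_tokens=100)      # deliberately small budget for illustration
try:
    # A prompt longer than max_num_tokens now fails with RequestError
    # rather than ValueError on the pytorch backend.
    llm.generate("A " * 200)
except RequestError as err:
    print(f"Request rejected: {err}")
finally:
    llm.shutdown()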
@@ -2393,7 +2393,8 @@ def test_llm_chunked_prefill():
               enable_chunked_prefill=False,
               fast_build=True)

-    with pytest.raises(ValueError):
+    # max_num_tokens validation now raises RequestError consistently
+    with pytest.raises(RequestError):
         output = llm.generate_async(
             "A " * build_config.max_num_tokens,
             sampling_params=sampling_params,
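As this test shows, the length check runs when the request is submitted, so generate_async itself raises and no result object is returned to await. A hedged sketch of the same behavior outside pytest; the model path, budget, and sampling parameters are assumptions:

# Sketch, not from the diff: the overflow is rejected at submission time,
# so generate_async raises before returning a result to wait on.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.executor import RequestError

llm = LLM(model="/path/to/model", max_num_tokens=64)  # placeholder path/budget
try:
    try:
        llm.generate_async("A " * 200,
                           sampling_params=SamplingParams(max_tokens=8))
    except RequestError as err:
        print(f"Rejected at submission: {err}")
finally:
    llm.shutdown()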
@@ -2436,13 +2437,9 @@ def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
     )

     prompt = 'A ' * 65  # the minimum max_num_tokens is 64
-    if pytorch_backend:
-        # pytorch backend will raise ValueError for max_num_tokens
-        with pytest.raises(ValueError):
-            llm.generate(prompt)
-    else:
-        with pytest.raises(RequestError):
-            llm.generate(prompt)
+    # Both backends now consistently raise RequestError for max_num_tokens validation
+    with pytest.raises(RequestError):
+        llm.generate(prompt)


 def test_llm_capture_request_error():
@@ -8,7 +8,7 @@ import pytest

 from tensorrt_llm import LLM
 from tensorrt_llm.disaggregated_params import DisaggregatedParams
-from tensorrt_llm.executor import GenerationExecutorWorker
+from tensorrt_llm.executor import GenerationExecutorWorker, RequestError
 from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
 from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig
 from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig
@@ -830,10 +830,13 @@ class TestLlmError:
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)

-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()


 class FailingExecutorWorker(GenerationExecutorWorker):
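The try/finally added here (and in the next hunk) guarantees the executor is shut down even when the assertion inside pytest.raises fails. A hedged alternative that centralizes the same cleanup in a pytest fixture; the fixture name and model path are assumptions, not part of this diff:

# Sketch only: a fixture that releases the LLM executor after each test,
# mirroring the try/finally llm.shutdown() pattern added in this diff.
import pytest
from tensorrt_llm import LLM

@pytest.fixture
def small_budget_llm():
    llm = LLM(model="/path/to/model",  # placeholder path, not from the diff
              max_num_tokens=100)      # same small budget the tests use
    try:
        yield llm
    finally:
        llm.shutdown()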
@@ -962,10 +965,13 @@ class TestLlmError:
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)

-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()


 @skip_ray