Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[nvbug/5374773] chore: Add a runtime flag to enable fail fast when attn window is too large to fit at least one sequence in KV cache (#5974)
Signed-off-by: moraxu <mguzek@nvidia.com>
parent c35c78ff58
commit 08d57123f9
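
The diff below threads a new failFastOnAttentionWindowTooLarge / fail_fast_on_attention_window_too_large flag from the C++ ExecutorConfig up through the pybind11 bindings, the LLM API arguments, trtllm-serve, and the model runners. As a rough usage sketch from the Python LLM API (assuming the LLM constructor forwards BaseLlmArgs fields as keyword arguments, which is how the other fields touched by this diff are consumed):

# Illustrative sketch only: enabling the new fail-fast behavior from the LLM API.
# The keyword argument corresponds to the BaseLlmArgs field added in this commit;
# whether LLM() accepts it directly depends on how LlmArgs kwargs are forwarded.
from tensorrt_llm import LLM

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder checkpoint
    # Raise a runtime error instead of silently clamping the attention window
    # when not even a single sequence fits in the KV cache.
    fail_fast_on_attention_window_too_large=True,
)
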
@@ -1484,7 +1484,8 @@ public:
         std::optional<GuidedDecodingConfig> guidedDecodingConfig = std::nullopt,
         std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs = std::nullopt,
         std::optional<CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt,
-        bool gatherGenerationLogits = false, bool promptTableOffloading = false, bool enableTrtOverlap = false);
+        bool gatherGenerationLogits = false, bool promptTableOffloading = false, bool enableTrtOverlap = false,
+        bool failFastOnAttentionWindowTooLarge = false);

     [[nodiscard]] SizeType32 getMaxBeamWidth() const;
     [[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@@ -1519,6 +1520,7 @@ public:
     [[nodiscard]] bool getPromptTableOffloading() const;
     [[nodiscard]] std::optional<CacheTransceiverConfig> getCacheTransceiverConfig() const;
     [[nodiscard]] bool getEnableTrtOverlap() const;
+    [[nodiscard]] bool getFailFastOnAttentionWindowTooLarge() const;

     void setMaxBeamWidth(SizeType32 maxBeamWidth);
     void setMaxBatchSize(SizeType32 maxBatchSize);
@@ -1548,6 +1550,7 @@ public:
     void setPromptTableOffloading(bool promptTableOffloading);
     void setCacheTransceiverConfig(CacheTransceiverConfig const& cacheTransceiverConfig);
     void setEnableTrtOverlap(bool enableTrtOverlap);
+    void setFailFastOnAttentionWindowTooLarge(bool failFastOnAttentionWindowTooLarge);

 private:
     friend class Serialization;
@@ -1634,6 +1637,10 @@ private:

     /// @brief Controls whether preparation and TRT engine execution should be overlapped.
     bool mEnableTrtOverlap{false};
+
+    /// @brief Controls whether to fail fast when attention window is too large to fit even a single sequence in the KV
+    /// cache.
+    bool mFailFastOnAttentionWindowTooLarge{false};
 };

 struct KVCacheCreatedData

@@ -296,7 +296,6 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer

     auto const [freePrimaryMemBytes, freeSecondaryMemBytes]
         = BaseKVCacheManager::calculateFreeMemBytes(mRuntime->getBufferManager(), kvCacheConfig);
-
     if (mModelConfig.useCrossAttention())
     {
         TLLM_CHECK_WITH_INFO(kvCacheConfig.getCrossKvCacheFraction().has_value(),
@@ -304,10 +303,11 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
         auto const crossKvCacheFraction = kvCacheConfig.getCrossKvCacheFraction().value();
         mKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kSELF,
             freePrimaryMemBytes * (1.0f - crossKvCacheFraction),
-            freeSecondaryMemBytes * (1.0f - crossKvCacheFraction), cacheTransPreAllocaSize);
-        mCrossKvCacheManager
-            = createKvCacheManager(kvCacheConfig, KvCacheType::kCROSS, freePrimaryMemBytes * crossKvCacheFraction,
-                freeSecondaryMemBytes * crossKvCacheFraction, cacheTransPreAllocaSize);
+            freeSecondaryMemBytes * (1.0f - crossKvCacheFraction), cacheTransPreAllocaSize,
+            executorConfig.getFailFastOnAttentionWindowTooLarge());
+        mCrossKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kCROSS,
+            freePrimaryMemBytes * crossKvCacheFraction, freeSecondaryMemBytes * crossKvCacheFraction,
+            cacheTransPreAllocaSize, executorConfig.getFailFastOnAttentionWindowTooLarge());
         TLLM_LOG_INFO("This is an Encoder-Decoder model, set %0.1f cross KV cache fraction based on the config.",
             crossKvCacheFraction);
     }
@@ -315,8 +315,8 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
     {
         TLLM_CHECK_WITH_INFO(!kvCacheConfig.getCrossKvCacheFraction().has_value(),
             "Do not set crossKvCacheFraction for decoder-only model");
-        mKvCacheManager = createKvCacheManager(
-            kvCacheConfig, KvCacheType::kSELF, freePrimaryMemBytes, freeSecondaryMemBytes, cacheTransPreAllocaSize);
+        mKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kSELF, freePrimaryMemBytes,
+            freeSecondaryMemBytes, cacheTransPreAllocaSize, executorConfig.getFailFastOnAttentionWindowTooLarge());
     }

     mCacheTransceiver
@@ -550,7 +550,8 @@ void TrtGptModelInflightBatching::reshapeKvTensors(OffsetTableDimensions const&
 using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;

 std::pair<BlocksPerWindow, std::vector<SizeType32>>
-TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWindow const& blocksPerWindow)
+TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(
+    BlocksPerWindow const& blocksPerWindow, bool const failFastOnAttentionWindowTooLarge)
 {
     // At this point, we can only validate that the cheapest sequence in terms of kv-cache resources still fits. More
     // validation is needed on a per-request basis, once the prompt / output lengths and the actual beam width are
@@ -591,6 +592,16 @@ TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWi
     }
     TLLM_LOG_WARNING("maxAttentionWindowVec too large to fit at least one sequence in kvCache. Old: %s, New: %s",
         common::vec2str(getMaxAttentionWindowVec()).c_str(), common::vec2str(newMaxAttentionWindowVec).c_str());
+
+    if (failFastOnAttentionWindowTooLarge)
+    {
+        throw std::runtime_error(
+            "Attention window too large to fit even a single sequence in the KV cache. Failing fast rather than "
+            "attempting an adjustment of the window sizes. "
+            "Old: "
+            + common::vec2str(getMaxAttentionWindowVec()) + ", New: " + common::vec2str(newMaxAttentionWindowVec));
+    }
+
     setMaxAttentionWindowVec(newMaxAttentionWindowVec);
     if (getMaxSequenceLen() > getMaxAttentionWindow())
     {
@@ -613,7 +624,7 @@ TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWi

 std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::createKvCacheManager(
     KvCacheConfig const& kvCacheConfig, KvCacheType kvCacheType, uint64_t freePrimaryMemBytes,
-    uint64_t freeSecondaryMemBytes, size_t extraCostMemory)
+    uint64_t freeSecondaryMemBytes, size_t extraCostMemory, bool const failFastOnAttentionWindowTooLarge)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     bool isCrossAttention = kvCacheType == KvCacheType::kCROSS;
@@ -657,7 +668,8 @@ std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::c
     // and user also didn't provide maxAttentionWindow, which leads it to be equal to maxSeqLen
     if (kvCacheType == KvCacheType::kSELF)
     {
-        std::tie(blocksPerWindow, maxAttentionWindowVec) = clampWindowSizesToFitAtLeastOneSequence(blocksPerWindow);
+        std::tie(blocksPerWindow, maxAttentionWindowVec)
+            = clampWindowSizesToFitAtLeastOneSequence(blocksPerWindow, failFastOnAttentionWindowTooLarge);
     }

     kv_cache_manager::TempAttentionWindowInputs tempAttentionWindowInputs;

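The change above is the core of the feature: clampWindowSizesToFitAtLeastOneSequence still logs the warning, but when the new flag is set it throws instead of clamping the window sizes. A minimal, self-contained Python sketch of that control flow (the clamping rule and all names here are illustrative, not the real implementation):

def clamp_or_fail(max_attention_windows, max_fitting_window, fail_fast=False):
    """Illustrative stand-in for the C++ logic: clamp windows that cannot fit
    at least one sequence, or raise when fail-fast is requested."""
    new_windows = [min(w, max_fitting_window) for w in max_attention_windows]
    if new_windows == max_attention_windows:
        return max_attention_windows  # everything already fits; nothing to do
    print(f"WARNING: attention window too large. Old: {max_attention_windows}, New: {new_windows}")
    if fail_fast:
        raise RuntimeError(
            "Attention window too large to fit even a single sequence in the KV cache. "
            f"Failing fast rather than adjusting. Old: {max_attention_windows}, New: {new_windows}")
    return new_windows


print(clamp_or_fail([8192, 4096], max_fitting_window=2048))  # clamps to [2048, 2048]
try:
    clamp_or_fail([8192, 4096], max_fitting_window=2048, fail_fast=True)
except RuntimeError as err:
    print(f"fail-fast: {err}")
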
@@ -280,7 +280,8 @@ private:
     void createBuffers(executor::DecodingConfig const& decodingConfig,
         std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs);
     std::unique_ptr<KVCacheManager> createKvCacheManager(KvCacheConfig const& kvCacheConfig, KvCacheType kvCacheType,
-        uint64_t freePrimaryMemBytes, uint64_t freeSecondaryMemBytes, size_t extraCostMemory);
+        uint64_t freePrimaryMemBytes, uint64_t freeSecondaryMemBytes, size_t extraCostMemory,
+        bool const failFastOnAttentionWindowTooLarge = false);
     void createRnnStateManager();
     void createCustomAllReduceWorkspace();
     void createRuntimePerfKnobsTensor(executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig);
@@ -378,9 +379,11 @@ private:
     /// window.
     ///
     /// @param blocksPerWindow map of window size to number of blocks.
+    /// @param failFastOnAttentionWindowTooLarge if true, the function will report a runtime error if the attention
+    /// window is too large to fit even a single sequence in the KV cache.
     /// @return pair of new blocks per window and new maxAttentionWindowVec
     [[nodiscard]] std::pair<BlocksPerWindow, std::vector<SizeType32>> clampWindowSizesToFitAtLeastOneSequence(
-        BlocksPerWindow const& blocksPerWindow);
+        BlocksPerWindow const& blocksPerWindow, bool const failFastOnAttentionWindowTooLarge = false);

     /// @brief Change the speculative decoding mode.
     void changeSpecDecMode(ScheduledRequests const& scheduledRequests);

@@ -34,7 +34,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
     std::optional<SpeculativeDecodingConfig> specDecConfig, std::optional<GuidedDecodingConfig> guidedDecodingConfig,
     std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs,
     std::optional<CacheTransceiverConfig> cacheTransceiverConfig, bool gatherGenerationLogits,
-    bool promptTableOffloading, bool enableTrtOverlap)
+    bool promptTableOffloading, bool enableTrtOverlap, bool failFastOnAttentionWindowTooLarge)
     : mMaxBeamWidth(maxBeamWidth)
     , mSchedulerConfig(std::move(schedulerConfig))
     , mKvCacheConfig(std::move(kvCacheConfig))
@@ -63,6 +63,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
     , mGatherGenerationLogits(gatherGenerationLogits)
     , mPromptTableOffloading(promptTableOffloading)
     , mEnableTrtOverlap(enableTrtOverlap)
+    , mFailFastOnAttentionWindowTooLarge(failFastOnAttentionWindowTooLarge)
 {
     TLLM_CHECK(iterStatsMaxIterations >= 0);
     TLLM_CHECK(requestStatsMaxIterations >= 0);
@@ -222,6 +223,11 @@ bool ExecutorConfig::getEnableTrtOverlap() const
     return mEnableTrtOverlap;
 }

+bool ExecutorConfig::getFailFastOnAttentionWindowTooLarge() const
+{
+    return mFailFastOnAttentionWindowTooLarge;
+}
+
 // setters

 void ExecutorConfig::setMaxBeamWidth(SizeType32 maxBeamWidth)
@@ -371,4 +377,9 @@ void ExecutorConfig::setEnableTrtOverlap(bool enableTrtOverlap)
     mEnableTrtOverlap = enableTrtOverlap;
 }

+void ExecutorConfig::setFailFastOnAttentionWindowTooLarge(bool failFastOnAttentionWindowTooLarge)
+{
+    mFailFastOnAttentionWindowTooLarge = failFastOnAttentionWindowTooLarge;
+}
+
 } // namespace tensorrt_llm::executor

@@ -459,7 +459,7 @@ void initConfigBindings(pybind11::module_& m)
             c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(),
             c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(),
             c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(),
-            c.getPromptTableOffloading(), c.getEnableTrtOverlap());
+            c.getPromptTableOffloading(), c.getEnableTrtOverlap(), c.getFailFastOnAttentionWindowTooLarge());
         auto pickle_tuple = py::make_tuple(cpp_states, py::getattr(self, "__dict__"));
         return pickle_tuple;
     };
@@ -472,7 +472,7 @@ void initConfigBindings(pybind11::module_& m)

         // Restore C++ data
         auto cpp_states = state[0].cast<py::tuple>();
-        if (cpp_states.size() != 28)
+        if (cpp_states.size() != 29)
         {
             throw std::runtime_error("Invalid cpp_states!");
         }
@@ -505,7 +505,8 @@ void initConfigBindings(pybind11::module_& m)
             cpp_states[24].cast<std::optional<tle::CacheTransceiverConfig>>(), // CacheTransceiverConfig
             cpp_states[25].cast<bool>(), // GatherGenerationLogits
             cpp_states[26].cast<bool>(), // PromptTableOffloading
-            cpp_states[27].cast<bool>() // EnableTrtOverlap
+            cpp_states[27].cast<bool>(), // EnableTrtOverlap
+            cpp_states[28].cast<bool>() // FailFastOnAttentionWindowTooLarge
         );

         auto py_state = state[1].cast<py::dict>();
@@ -542,7 +543,8 @@ void initConfigBindings(pybind11::module_& m)
             std::optional<tle::CacheTransceiverConfig>, // CacheTransceiverConfig
             bool, // GatherGenerationLogits
             bool, // PromptTableOffloading
-            bool // EnableTrtOverlap
+            bool, // EnableTrtOverlap
+            bool // FailFastOnAttentionWindowTooLarge
             >(),
         py::arg("max_beam_width") = 1, py::arg_v("scheduler_config", tle::SchedulerConfig(), "SchedulerConfig()"),
         py::arg_v("kv_cache_config", tle::KvCacheConfig(), "KvCacheConfig()"),
@@ -563,7 +565,7 @@ void initConfigBindings(pybind11::module_& m)
         py::arg("spec_dec_config") = py::none(), py::arg("guided_decoding_config") = py::none(),
         py::arg("additional_model_outputs") = py::none(), py::arg("cache_transceiver_config") = py::none(),
         py::arg("gather_generation_logits") = false, py::arg("mm_embedding_offloading") = false,
-        py::arg("enable_trt_overlap") = false)
+        py::arg("enable_trt_overlap") = false, py::arg("fail_fast_on_attention_window_too_large") = false)
         .def_property("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth)
         .def_property("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize)
         .def_property("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens)
@@ -613,6 +615,9 @@ void initConfigBindings(pybind11::module_& m)
             &tle::ExecutorConfig::setPromptTableOffloading)
         .def_property(
             "enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap)
+        .def_property("fail_fast_on_attention_window_too_large",
+            &tle::ExecutorConfig::getFailFastOnAttentionWindowTooLarge,
+            &tle::ExecutorConfig::setFailFastOnAttentionWindowTooLarge)
         .def(py::pickle(executorConfigGetState, executorConfigSetState));
 }

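With the binding in place, the flag is reachable from Python both as a constructor keyword and as a property. A small sketch, assuming the usual tensorrt_llm.bindings.executor import path:

# Sketch: exercising the new ExecutorConfig keyword argument / property from Python.
import tensorrt_llm.bindings.executor as trtllm

config = trtllm.ExecutorConfig(
    max_beam_width=1,
    fail_fast_on_attention_window_too_large=True,  # new py::arg, defaults to False
)
# The new def_property gives read/write access as well:
config.fail_fast_on_attention_window_too_large = True
assert config.fail_fast_on_attention_window_too_large
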
@@ -106,6 +106,13 @@ def parse_arguments(args=None):
         default=False,
         action='store_true',
         help="Run several 10 iterations to profile the inference latencies.")
+    parser.add_argument(
+        '--fail_fast_on_attention_window_too_large',
+        action='store_true',
+        default=False,
+        help=
+        'Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache.'
+    )

     parser = add_common_args(parser)

@@ -455,6 +462,8 @@ def main(args):
             gpu_weights_percent=args.gpu_weights_percent,
             max_output_len=args.max_output_len,
             enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc,
+            fail_fast_on_attention_window_too_large=args.
+            fail_fast_on_attention_window_too_large,
         )
     if args.medusa_choices is not None:
         args.medusa_choices = ast.literal_eval(args.medusa_choices)
@@ -549,6 +558,8 @@ def main(args):
             eagle_choices=args.eagle_choices,
             return_all_generated_tokens=args.return_all_generated_tokens,
             input_token_extra_ids=input_token_extra_ids,
+            fail_fast_on_attention_window_too_large=args.
+            fail_fast_on_attention_window_too_large,
             language_adapter_uids=args.language_task_uids)
         torch.cuda.synchronize()

@@ -680,7 +691,9 @@ def main(args):
                     return_dict=True,
                     return_all_generated_tokens=args.
                     return_all_generated_tokens,
-                    input_token_extra_ids=input_token_extra_ids)
+                    input_token_extra_ids=input_token_extra_ids,
+                    fail_fast_on_attention_window_too_large=args.
+                    fail_fast_on_attention_window_too_large)
                 torch.cuda.synchronize()
                 tensorrt_llm.profiler.stop("tmp")

@@ -84,6 +84,7 @@ def get_llm_args(model: str,
                  num_postprocess_workers: int = 0,
                  trust_remote_code: bool = False,
                  reasoning_parser: Optional[str] = None,
+                 fail_fast_on_attention_window_too_large: bool = False,
                  **llm_args_extra_dict: Any):

     if gpus_per_node is None:
@@ -107,24 +108,44 @@ def get_llm_args(model: str,
     )

     llm_args = {
-        "model": model,
-        "scheduler_config": scheduler_config,
-        "tokenizer": tokenizer,
-        "tensor_parallel_size": tensor_parallel_size,
-        "pipeline_parallel_size": pipeline_parallel_size,
-        "moe_expert_parallel_size": moe_expert_parallel_size,
-        "gpus_per_node": gpus_per_node,
-        "trust_remote_code": trust_remote_code,
-        "build_config": build_config,
-        "max_batch_size": max_batch_size,
-        "max_num_tokens": max_num_tokens,
-        "max_beam_width": max_beam_width,
-        "max_seq_len": max_seq_len,
-        "kv_cache_config": kv_cache_config,
-        "backend": backend if backend == "pytorch" else None,
-        "num_postprocess_workers": num_postprocess_workers,
-        "postprocess_tokenizer_dir": tokenizer or model,
-        "reasoning_parser": reasoning_parser,
+        "model":
+        model,
+        "scheduler_config":
+        scheduler_config,
+        "tokenizer":
+        tokenizer,
+        "tensor_parallel_size":
+        tensor_parallel_size,
+        "pipeline_parallel_size":
+        pipeline_parallel_size,
+        "moe_expert_parallel_size":
+        moe_expert_parallel_size,
+        "gpus_per_node":
+        gpus_per_node,
+        "trust_remote_code":
+        trust_remote_code,
+        "build_config":
+        build_config,
+        "max_batch_size":
+        max_batch_size,
+        "max_num_tokens":
+        max_num_tokens,
+        "max_beam_width":
+        max_beam_width,
+        "max_seq_len":
+        max_seq_len,
+        "kv_cache_config":
+        kv_cache_config,
+        "backend":
+        backend if backend == "pytorch" else None,
+        "num_postprocess_workers":
+        num_postprocess_workers,
+        "postprocess_tokenizer_dir":
+        tokenizer or model,
+        "reasoning_parser":
+        reasoning_parser,
+        "fail_fast_on_attention_window_too_large":
+        fail_fast_on_attention_window_too_large,
     }

     return llm_args, llm_args_extra_dict
@@ -249,16 +270,23 @@ def launch_server(host: str,
     default=None,
     help="Server role. Specify this value only if running in disaggregated mode."
 )
-def serve(model: str, tokenizer: Optional[str], host: str, port: int,
-          log_level: str, backend: str, max_beam_width: int,
-          max_batch_size: int, max_num_tokens: int, max_seq_len: int,
-          tp_size: int, pp_size: int, ep_size: Optional[int],
-          cluster_size: Optional[int], gpus_per_node: Optional[int],
-          kv_cache_free_gpu_memory_fraction: float,
-          num_postprocess_workers: int, trust_remote_code: bool,
-          extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
-          metadata_server_config_file: Optional[str],
-          server_role: Optional[str]):
+@click.option(
+    "--fail_fast_on_attention_window_too_large",
+    is_flag=True,
+    default=False,
+    help=
+    "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
+)
+def serve(
+        model: str, tokenizer: Optional[str], host: str, port: int,
+        log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
+        max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int,
+        ep_size: Optional[int], cluster_size: Optional[int],
+        gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
+        num_postprocess_workers: int, trust_remote_code: bool,
+        extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
+        metadata_server_config_file: Optional[str], server_role: Optional[str],
+        fail_fast_on_attention_window_too_large: bool):
     """Running an OpenAI API compatible server

     MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -281,7 +309,9 @@ def serve(model: str, tokenizer: Optional[str], host: str, port: int,
         free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
         num_postprocess_workers=num_postprocess_workers,
         trust_remote_code=trust_remote_code,
-        reasoning_parser=reasoning_parser)
+        reasoning_parser=reasoning_parser,
+        fail_fast_on_attention_window_too_large=
+        fail_fast_on_attention_window_too_large)

     llm_args_extra_dict = {}
     if extra_llm_api_options is not None:

@@ -779,7 +779,9 @@ class _TrtLLM(BaseLLM):
                 or tllm.BatchingType.INFLIGHT,
                 max_batch_size=max_batch_size,
                 max_num_tokens=max_num_tokens,
-                gather_generation_logits=self.args.gather_generation_logits)
+                gather_generation_logits=self.args.gather_generation_logits,
+                fail_fast_on_attention_window_too_large=getattr(
+                    self.args, 'fail_fast_on_attention_window_too_large', False))

             # also set executor_config.max_seq_len in TRT workflow, to deduce default max_tokens
             if max_seq_len is not None:
@@ -920,7 +922,9 @@ class _TorchLLM(BaseLLM):
                 or tllm.BatchingType.INFLIGHT,
                 max_batch_size=max_batch_size,
                 max_num_tokens=max_num_tokens,
-                gather_generation_logits=self.args.gather_generation_logits)
+                gather_generation_logits=self.args.gather_generation_logits,
+                fail_fast_on_attention_window_too_large=getattr(
+                    self.args, 'fail_fast_on_attention_window_too_large', False))

             if self.args.kv_cache_config is not None:
                 self._executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(

@@ -998,6 +998,12 @@ class BaseLlmArgs(BaseModel):
         description="The format to load the model.",
         json_schema_extra={"type": "Literal['auto', 'dummy']"})

+    fail_fast_on_attention_window_too_large: bool = Field(
+        default=False,
+        description=
+        "Fail fast when attention window is too large to fit even a single sequence in the KV cache."
+    )
+
     # LoRA arguments
     enable_lora: bool = Field(default=False, description="Enable LoRA.")

@@ -646,6 +646,7 @@ class ModelRunner(ModelRunnerMixin):
                  gpu_weights_percent: float = 1,
                  enable_context_fmha_fp32_acc: Optional[bool] = None,
                  multi_block_mode: Optional[bool] = None,
+                 fail_fast_on_attention_window_too_large: bool = False,
                  ) -> 'ModelRunner':
         """
         Create a ModelRunner instance from an engine directory.
@@ -667,6 +668,9 @@ class ModelRunner(ModelRunnerMixin):
                 Stream to use.
             multi_block_mode (bool):
                 Whether to distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel.
+            fail_fast_on_attention_window_too_large (bool):
+                Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache.
+                Note: This parameter is only applicable to C++ runtime (ModelRunnerCpp).
         Returns:
             ModelRunner: An instance of ModelRunner.
         """

@@ -124,6 +124,7 @@ class ModelRunnerCpp(ModelRunnerMixin):
                 gather_generation_logits: bool = False,
                 use_variable_beam_width_search: bool = False,
                 mm_embedding_offloading: bool = False,
+                fail_fast_on_attention_window_too_large: bool = False,
                 ) -> 'ModelRunnerCpp':
         """
         Create a ModelRunnerCpp instance from an engine directory.
@@ -197,6 +198,8 @@ class ModelRunnerCpp(ModelRunnerMixin):
                 The mode to run the model-runner, Leader mode by default.
             gather_generation_logits (bool):
                 Enable gathering generation logits.
+            fail_fast_on_attention_window_too_large (bool):
+                Whether to fail fast if the attention window(s) are too large to fit even a single sequence in the KVCache.
         Returns:
             ModelRunnerCpp: An instance of ModelRunnerCpp.
         """
@@ -398,6 +401,7 @@ class ModelRunnerCpp(ModelRunnerMixin):
         trtllm_config.enable_chunked_context = enable_chunked_context
         trtllm_config.extended_runtime_perf_knob_config = extended_runtime_perf_knob_config
         trtllm_config.mm_embedding_offloading = mm_embedding_offloading
+        trtllm_config.fail_fast_on_attention_window_too_large = fail_fast_on_attention_window_too_large
         if is_orchestrator_mode:
            communication_mode = trtllm.CommunicationMode.ORCHESTRATOR
            path = str(Path(__file__).parent.parent / 'bin' / 'executorWorker')

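At the runner level, the flag is simply forwarded into the underlying executor config, as the hunk above shows. A minimal sketch of passing it through ModelRunnerCpp.from_dir (the engine directory is a placeholder):

from tensorrt_llm.runtime import ModelRunnerCpp

runner = ModelRunnerCpp.from_dir(
    engine_dir="/path/to/engine_dir",  # placeholder path
    # With the flag set, an oversized attention window raises a runtime error
    # instead of being clamped to fit the KV cache.
    fail_fast_on_attention_window_too_large=True,
)
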
@@ -53,6 +53,10 @@ methods:
       reasoning_parser:
         annotation: Optional[str]
         default: null
+      # Runtime behavior
+      fail_fast_on_attention_window_too_large:
+        annotation: bool
+        default: false
       garbage_collection_gen0_threshold:
         annotation: int
         default: 20000