TensorRT-LLMs/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
Netanel Haber 3c52ac098f
feat: allocate minimal blocks per window size (#3028)
* implement variable window attention by breaking the block manager into window block managers per window size
* revert isCyclic to be true if the min attention window is reached, not per window size
* add explanatory comment to mCyclicThreshold
* load correct gemma config
* don't shadow inputLength in addSequence - it should remain the function scope input length between window size loop iterations
* fix KVCacheManagerVariableWindowAttentionWithReuseTest for multiple window block managers
* if TYPE_CHECKING
* set temp_attention_window_inputs to None explicitly
* pass dtype as well
* test_gemma variable sliding window attention
* allot a fraction of primary/secondaryBlocks to different window size heaps, depending on the window size's total contribution to the kvcache size (i.e., including all layers)
* remove || mEnableBlockReuse which erroneously triggers beamsearch code for cyclic variable attention window code
* turn off request delaying for MaxUtil
* make comments better
* windowSizesTotalSum using std::accumulate
* fix error handling of forwardAsync - forwardAsync catch-all catch cleanup code that runs terminateRequest can also fail and must be caught
* fix comments
* remove assert that kills disagg tests, since it isn't necessary
* fix corrupted expression: 'isNewTask && (peftCacheManager ?' -> '(isNewTask && peftCacheManager) ?' which caused boolean algebra. Main is correct
* add Gemma3 to SUPPORTED_HF_ARCHITECTURES
* support Gemma3
* finally fix test_gemma - always spread at least {} into generate_summary_cmd, never None
* fix kvfactor field for deepseek
* fix comment
* fix gemma-3 entries in testlist to include vswa
* only quantize gemma2 VSWA
* remove misleading comment
* fix test_gemma
* in sendRequestInfo, fromOldAllocatedBlockIds->fromOldAllocatedBlockIds, like in main
* fix: disable KV cache reuse if using attention sink (#3021)
  * fix: disable KV cache reuse if using attention sink
  * fix: disable KV cache reuse if sink bubble
  * add comment

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-04-17 16:04:57 +08:00


/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include "common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/runtime/common.h"
#include <optional>
#include <tuple>
#include <variant>
namespace tensorrt_llm::batch_manager
{
namespace kv_cache_manager
{
class BaseKVCacheManager;
}
class BasePeftCacheManager;
} // namespace tensorrt_llm::batch_manager
namespace tensorrt_llm::batch_manager
{
using tensorrt_llm::runtime::SizeType32;
using common::OptionalRef;
/// @brief This scheduler takes into account the given request capacity and the KV cache capacity.
/// Depending on the CapacitySchedulerPolicy it will schedule already started and new requests,
/// or even pause previously started requests.
class BaseCapacityScheduler
{
public:
    explicit BaseCapacityScheduler(LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
        : mNoScheduleUntilState(noScheduleUntilState)
        , mNoScheduleAfterState(noScheduleAfterState)
    {
    }

    [[nodiscard]] LlmRequestState constexpr getNoScheduleUntilState() const noexcept
    {
        return mNoScheduleUntilState;
    }

    [[nodiscard]] LlmRequestState constexpr getNoScheduleAfterState() const noexcept
    {
        return mNoScheduleAfterState;
    }

private:
    /// The state until/after which the scheduler should not schedule requests
    LlmRequestState mNoScheduleUntilState;
    LlmRequestState mNoScheduleAfterState;
};
/// @brief Schedule up to maxNumRequests requests
class MaxRequestsScheduler : public BaseCapacityScheduler
{
public:
    explicit MaxRequestsScheduler(SizeType32 maxNumRequests,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    /// @brief Takes as input a sorted list of requests and outputs a sorted list of requests
    ///        to update for this current iteration, and a list of requests to pause
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;

private:
    SizeType32 mMaxNumRequests;
};
/// @brief   Schedule requests using the MAX_UTILIZATION policy
/// @details Try reserving resources to advance requests by one step,
///          may pause previously started requests.
class MaxUtilizationScheduler : public BaseCapacityScheduler
{
public:
    MaxUtilizationScheduler(SizeType32 maxNumRequests, bool manyMicroBatches,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
        kv_cache_manager::BaseKVCacheManager& kvCacheManager, OptionalRef<BasePeftCacheManager const> peftCacheManager,
        RequestList const& activeRequests) const;

private:
    SizeType32 mMaxNumRequests;
    /// @brief Boolean that indicates if multiple micro batches might be in flight
    bool mManyMicroBatches;
};
/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
{
public:
    GuaranteedNoEvictScheduler(SizeType32 maxNumRequests,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
        kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
        OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
        OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;

protected:
    template <bool StaticBatchScheduling>
    [[nodiscard]] std::tuple<RequestVector, RequestVector> impl(
        kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
        OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
        OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;

private:
    SizeType32 mMaxNumRequests;
};
/// @brief Schedule requests using the STATIC_BATCH policy
class StaticBatchScheduler : public GuaranteedNoEvictScheduler
{
public:
    StaticBatchScheduler(SizeType32 maxNumRequests,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
        kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
        OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
        OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
};
/// @brief Holds the scheduler selected by the CapacitySchedulerPolicy and forwards scheduling calls to it
class CapacityScheduler : public Algorithm
{
public:
    constexpr static auto name{"CapacityScheduler"};

    explicit CapacityScheduler(SizeType32 maxNumRequests, executor::CapacitySchedulerPolicy capacitySchedulerPolicy,
        bool hasKvCacheManager, std::optional<bool> manyMicroBatches = std::nullopt,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    /**
     * @brief Schedules requests following the selected policy.
     *
     * @param activeRequests The active requests to consider for scheduling.
     * @param kvCacheManager Required by MaxUtilizationScheduler (as a ref) and by GuaranteedNoEvictScheduler and
     * StaticBatchScheduler (as a const ref).
     * @param peftCacheManager Optional; used by MaxUtilizationScheduler, GuaranteedNoEvictScheduler and
     * StaticBatchScheduler.
     * @param crossKvCacheManager Optional; used by GuaranteedNoEvictScheduler and StaticBatchScheduler.
     * @return A std::tuple<RequestVector, RequestVector, RequestVector> holding fittingRequests,
     * fittingDisaggInitRequests and pausedRequests, in that order.
     */
    [[nodiscard]] std::tuple<RequestVector, RequestVector, RequestVector> operator()(RequestList const& activeRequests,
        OptionalRef<kv_cache_manager::BaseKVCacheManager> kvCacheManager = std::nullopt,
        OptionalRef<BasePeftCacheManager const> peftCacheManager = std::nullopt,
        OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager = std::nullopt) const;

private:
    std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler,
        StaticBatchScheduler>
        mScheduler;
};
} // namespace tensorrt_llm::batch_manager
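
The header declares `CapacityScheduler` as a wrapper around a `std::variant` of policy schedulers whose call operators take different arguments. The corresponding .cpp file is not shown here, so the following is only a minimal, standalone sketch of how such a variant could be dispatched with `std::visit`. Every name in it (`ToyState`, `ToyRequest`, `ToyMaxRequests`, `ToyBudgeted`, `ToyCapacityScheduler`, `inScheduleWindow`) is hypothetical and is not part of TensorRT-LLM. The half-open state window and the empty result for `std::monostate` are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <variant>
#include <vector>

// Hypothetical stand-ins for LlmRequestState and the request containers.
enum class ToyState : int { kUnknown = 0, kContextInit = 1, kGeneration = 2, kComplete = 3 };

struct ToyRequest
{
    int id;
    ToyState state;
};

using ToyRequests = std::vector<ToyRequest>;
using ToyResult = std::tuple<ToyRequests, ToyRequests>; // (scheduled, paused)

// Assumption: a request is schedulable once it has reached `until` and has not yet reached `after`.
constexpr bool inScheduleWindow(
    ToyState state, ToyState until = ToyState::kContextInit, ToyState after = ToyState::kComplete) noexcept
{
    return static_cast<int>(state) >= static_cast<int>(until) && static_cast<int>(state) < static_cast<int>(after);
}

// Toy policy that only needs the request list.
struct ToyMaxRequests
{
    std::size_t maxNumRequests;

    ToyResult operator()(ToyRequests const& active) const
    {
        ToyRequests scheduled;
        for (auto const& req : active)
        {
            if (scheduled.size() >= maxNumRequests)
            {
                break;
            }
            if (inScheduleWindow(req.state))
            {
                scheduled.push_back(req);
            }
        }
        return std::make_tuple(scheduled, ToyRequests{});
    }
};

// Toy policy that needs an extra resource argument: each scheduled request costs one block.
struct ToyBudgeted
{
    std::size_t maxNumRequests;

    ToyResult operator()(int freeBlocks, ToyRequests const& active) const
    {
        ToyRequests scheduled;
        ToyRequests paused;
        for (auto const& req : active)
        {
            if (!inScheduleWindow(req.state))
            {
                continue;
            }
            if (scheduled.size() < maxNumRequests && freeBlocks > 0)
            {
                scheduled.push_back(req);
                --freeBlocks;
            }
            else if (req.state == ToyState::kGeneration)
            {
                paused.push_back(req); // already started but does not fit in this iteration
            }
        }
        return std::make_tuple(scheduled, paused);
    }
};

// Facade: one call operator, forwarding only the arguments each alternative needs.
class ToyCapacityScheduler
{
public:
    ToyCapacityScheduler() = default; // holds std::monostate, i.e. no policy selected

    explicit ToyCapacityScheduler(ToyMaxRequests scheduler)
        : mScheduler(scheduler)
    {
    }

    explicit ToyCapacityScheduler(ToyBudgeted scheduler)
        : mScheduler(scheduler)
    {
    }

    ToyResult operator()(ToyRequests const& active, int freeBlocks = 0) const
    {
        return std::visit(
            [&](auto const& scheduler) -> ToyResult
            {
                using T = std::decay_t<decltype(scheduler)>;
                if constexpr (std::is_same_v<T, ToyMaxRequests>)
                {
                    return scheduler(active);
                }
                else if constexpr (std::is_same_v<T, ToyBudgeted>)
                {
                    return scheduler(freeBlocks, active);
                }
                else
                {
                    return std::make_tuple(ToyRequests{}, ToyRequests{}); // std::monostate: nothing scheduled
                }
            },
            mScheduler);
    }

private:
    std::variant<std::monostate, ToyMaxRequests, ToyBudgeted> mScheduler;
};
```

Keeping `std::monostate` as the first alternative makes the facade default-constructible before a policy is chosen. A real implementation might prefer to report an error in that branch rather than return empty results.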