From b782b6ed68efc1b24d018e94e44bf9476cb42c86 Mon Sep 17 00:00:00 2001 From: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com> Date: Wed, 6 Aug 2025 14:22:50 +0800 Subject: [PATCH] fix sm check of kv reuse and chunked context Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py | 10 +++++----- tensorrt_llm/_utils.py | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index bcd006be71..a0194bc6db 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -9,7 +9,7 @@ import torch import tensorrt_llm from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType -from tensorrt_llm._utils import get_sm_version +from tensorrt_llm._utils import get_sm_family, get_sm_version from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig from tensorrt_llm.logger import logger @@ -282,9 +282,9 @@ def create_py_executor( ) if executor_config.kv_cache_config.enable_block_reuse and not ( - get_sm_version() >= 90 and get_sm_version() <= 100): + get_sm_version() == 90 or get_sm_family() == 100): logger.warning( - f"KV cache reuse for MLA can only be enabled on SM90/SM100, " + f"KV cache reuse for MLA can only be enabled on SM90/SM100f, " f"disable enable_block_reuse for SM{get_sm_version()}") executor_config.kv_cache_config.enable_block_reuse = False @@ -297,10 +297,10 @@ def create_py_executor( f"disable enable_block_reuse for KV cache quant algorithm: {kv_cache_quant_algo}" ) executor_config.kv_cache_config.enable_block_reuse = False - if executor_config.enable_chunked_context and not (get_sm_version() + if executor_config.enable_chunked_context and not (get_sm_family() == 100): logger.warning( - "Chunked Prefill for MLA can only be enabled on SM100, " + "Chunked Prefill for MLA can only be enabled on SM100f, " f"disable enable_block_reuse for SM{get_sm_version()}") executor_config.enable_chunked_context = False model_engine.attn_runtime_features.chunked_prefill = False diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 75be272791..0098983c6f 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -676,6 +676,14 @@ def get_sm_version(): return prop.major * 10 + prop.minor +@lru_cache(maxsize=1) +def get_sm_family(): + sm_version = get_sm_version() + if sm_version == 100 or sm_version == 103: + return 100 + return sm_version + + def is_trace_enabled(env_var: str): value = os.environ.get(env_var, "-1") if value == "ALL":