fix sm check of kv reuse and chunked context

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
Xiwen Yu 2025-08-06 14:22:50 +08:00
parent 886437db3a
commit b782b6ed68
2 changed files with 13 additions and 5 deletions

View File

@ -9,7 +9,7 @@ import torch
import tensorrt_llm
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm._utils import get_sm_family, get_sm_version
from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
from tensorrt_llm.logger import logger
@ -282,9 +282,9 @@ def create_py_executor(
)
if executor_config.kv_cache_config.enable_block_reuse and not (
get_sm_version() >= 90 and get_sm_version() <= 100):
get_sm_version() == 90 or get_sm_family() == 100):
logger.warning(
f"KV cache reuse for MLA can only be enabled on SM90/SM100, "
f"KV cache reuse for MLA can only be enabled on SM90/SM100f, "
f"disable enable_block_reuse for SM{get_sm_version()}")
executor_config.kv_cache_config.enable_block_reuse = False
@ -297,10 +297,10 @@ def create_py_executor(
f"disable enable_block_reuse for KV cache quant algorithm: {kv_cache_quant_algo}"
)
executor_config.kv_cache_config.enable_block_reuse = False
if executor_config.enable_chunked_context and not (get_sm_version()
if executor_config.enable_chunked_context and not (get_sm_family()
== 100):
logger.warning(
"Chunked Prefill for MLA can only be enabled on SM100, "
"Chunked Prefill for MLA can only be enabled on SM100f, "
f"disable enable_block_reuse for SM{get_sm_version()}")
executor_config.enable_chunked_context = False
model_engine.attn_runtime_features.chunked_prefill = False

View File

@ -676,6 +676,14 @@ def get_sm_version():
return prop.major * 10 + prop.minor
@lru_cache(maxsize=1)
def get_sm_family():
sm_version = get_sm_version()
if sm_version == 100 or sm_version == 103:
return 100
return sm_version
def is_trace_enabled(env_var: str):
value = os.environ.get(env_var, "-1")
if value == "ALL":