mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-23 12:12:39 +08:00
fix sm check of kv reuse and chunked context
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
parent
886437db3a
commit
b782b6ed68
@ -9,7 +9,7 @@ import torch
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
|
||||
from tensorrt_llm._utils import get_sm_version
|
||||
from tensorrt_llm._utils import get_sm_family, get_sm_version
|
||||
from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
|
||||
from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
|
||||
from tensorrt_llm.logger import logger
|
||||
@ -282,9 +282,9 @@ def create_py_executor(
|
||||
)
|
||||
|
||||
if executor_config.kv_cache_config.enable_block_reuse and not (
|
||||
get_sm_version() >= 90 and get_sm_version() <= 100):
|
||||
get_sm_version() == 90 or get_sm_family() == 100):
|
||||
logger.warning(
|
||||
f"KV cache reuse for MLA can only be enabled on SM90/SM100, "
|
||||
f"KV cache reuse for MLA can only be enabled on SM90/SM100f, "
|
||||
f"disable enable_block_reuse for SM{get_sm_version()}")
|
||||
executor_config.kv_cache_config.enable_block_reuse = False
|
||||
|
||||
@ -297,10 +297,10 @@ def create_py_executor(
|
||||
f"disable enable_block_reuse for KV cache quant algorithm: {kv_cache_quant_algo}"
|
||||
)
|
||||
executor_config.kv_cache_config.enable_block_reuse = False
|
||||
if executor_config.enable_chunked_context and not (get_sm_version()
|
||||
if executor_config.enable_chunked_context and not (get_sm_family()
|
||||
== 100):
|
||||
logger.warning(
|
||||
"Chunked Prefill for MLA can only be enabled on SM100, "
|
||||
"Chunked Prefill for MLA can only be enabled on SM100f, "
|
||||
f"disable enable_block_reuse for SM{get_sm_version()}")
|
||||
executor_config.enable_chunked_context = False
|
||||
model_engine.attn_runtime_features.chunked_prefill = False
|
||||
|
||||
@ -676,6 +676,14 @@ def get_sm_version():
|
||||
return prop.major * 10 + prop.minor
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_sm_family():
    """Return the SM family number for the current device.

    SM versions 100 and 103 are both reported as family 100; every other
    SM version is returned unchanged. The result is memoized, so
    ``get_sm_version()`` is consulted only on the first call.
    """
    version = get_sm_version()
    # Both 100 and 103 collapse into the single family id 100.
    return 100 if version in (100, 103) else version
|
||||
|
||||
|
||||
def is_trace_enabled(env_var: str):
|
||||
value = os.environ.get(env_var, "-1")
|
||||
if value == "ALL":
|
||||
|
||||
Loading…
Reference in New Issue
Block a user