# tests/unittest/others/test_kv_cache_transceiver.py

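"""Unit tests for the PyExecutor KV cache transceiver.

Both tests run in a single process: a context-only request and a
generation-only request are created against two independent KV cache
managers, and the transceiver is used to ship the context-side KV cache to
the generation side (or, in the cancellation test, to abort an in-flight
transfer).
"""
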
import time

import pytest
import torch

import tensorrt_llm
import tensorrt_llm.bindings
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.distributed import MPIDist
from tensorrt_llm._torch.pyexecutor.kv_cache_transceiver import \
    create_kv_cache_transceiver
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
                                                        LlmRequestState)
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.sampling_params import SamplingParams

AttentionTypeCpp = tensorrt_llm.bindings.internal.batch_manager.AttentionType
LlmRequestType = tensorrt_llm.bindings.internal.batch_manager.LlmRequestType
DataType = tensorrt_llm.bindings.DataType

def create_kv_cache_manager(mapping, dtype):
    return KVCacheManager(
        trtllm.KvCacheConfig(
            max_tokens=256,
            enable_block_reuse=False,
        ),
        tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
        num_layers=1,
        num_kv_heads=1,
        head_dim=1,
        tokens_per_block=8,
        max_seq_len=256,
        max_batch_size=1,
        mapping=mapping,
        dtype=dtype)


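# Fill the KV cache pool with random values so a completed transfer can be
# verified by comparing the ctx and gen buffers element-wise.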
def fill_kv_cache_buffer(kv_cache_manager):
    with torch.no_grad():
        buffer = kv_cache_manager.get_buffers(0)
        random_values = torch.rand(buffer.shape,
                                   dtype=torch.float32,
                                   device=buffer.device)
        buffer.copy_(random_values)


@pytest.fixture(scope="function")
def ctx_gen_kv_cache_dtype(request):
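    """Map the parametrized id to (context, generation) KV cache dtypes."""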
    if request.param == "ctx_fp8_gen_fp8":
        return DataType.FP8, DataType.FP8
    elif request.param == "ctx_fp16_gen_fp16":
        return DataType.HALF, DataType.HALF
    elif request.param == "ctx_bf16_gen_bf16":
        return DataType.BF16, DataType.BF16
    else:
        raise ValueError(f"Invalid config: {request.param}")


@pytest.mark.timeout(120)
@pytest.mark.parametrize(
    "ctx_gen_kv_cache_dtype",
    ["ctx_fp8_gen_fp8", "ctx_fp16_gen_fp16", "ctx_bf16_gen_bf16"],
    ids=["ctx_fp8_gen_fp8", "ctx_fp16_gen_fp16", "ctx_bf16_gen_bf16"],
    indirect=True)
@pytest.mark.parametrize("attention_type",
                         [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                         ids=["mha", "mla"])
@pytest.mark.parametrize("backend", ["NIXL", "UCX"], ids=["NIXL", "UCX"])
def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
                                             attention_type, backend):
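    """Send the KV cache of a context-only request to a generation-only
    request within a single process and check that both caches match."""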
    # Init kv_cache manager and cache transceiver
    mapping = Mapping(world_size=1, rank=0)
    ctx_kv_cache_dtype, gen_kv_cache_dtype = ctx_gen_kv_cache_dtype
    kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
    kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)
    cache_transceiver_config = CacheTransceiverConfig(backend=backend,
                                                      max_tokens_in_buffer=512)
    dist = MPIDist(mapping=mapping)
    kv_cache_transceiver_ctx = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_ctx, attention_type,
        cache_transceiver_config)
    kv_cache_transceiver_gen = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_gen, attention_type,
        cache_transceiver_config)
    fill_kv_cache_buffer(kv_cache_manager_ctx)
    # init ctx request
    sampling_params = SamplingParams()
    ctx_request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=tensorrt_llm.bindings.SamplingConfig(
            sampling_params._get_sampling_config()),
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)
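    # Register the sequence with the ctx KV cache manager so cache blocks are
    # allocated for the prompt before the transfer is initiated.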
    kv_cache_manager_ctx.impl.add_sequence(ctx_request.py_request_id,
                                           ctx_request.prompt_len, 1,
                                           ctx_request)
    # send ctx request
    kv_cache_transceiver_ctx.respond_and_send_async(ctx_request)
    # init gen request
    gen_request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=tensorrt_llm.bindings.SamplingConfig(
            sampling_params._get_sampling_config()),
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
        context_phase_params=ctx_request.context_phase_params)
    kv_cache_manager_gen.impl.add_sequence(gen_request.py_request_id,
                                           gen_request.prompt_len, 1,
                                           gen_request)
    # send gen request
    kv_cache_transceiver_gen.request_and_receive_async(gen_request)
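    # Wait for the transfers to finish on both sides; the argument appears to
    # be the minimum number of completed requests to wait for.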
    kv_cache_transceiver_ctx.check_context_transfer_status(1)
    kv_cache_transceiver_gen.check_gen_transfer_status(1)
    assert torch.equal(
        kv_cache_manager_gen.get_buffers(0),
        kv_cache_manager_ctx.get_buffers(0)), "different kv-cache values"


@pytest.mark.timeout(120)
@pytest.mark.parametrize("attention_type",
                         [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                         ids=["mha", "mla"])
def test_cancel_request_in_transmission(attention_type):
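    """Cancel a context-side transfer while it is in flight and verify that
    the matching generation-side request ends up in the error state."""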
    # Init kv_cache manager and cache transceiver
    mapping = Mapping(world_size=1, rank=0)
    dist = MPIDist(mapping=mapping)
    ctx_kv_cache_dtype, gen_kv_cache_dtype = DataType.HALF, DataType.HALF
    kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
    kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)
    cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
                                                      max_tokens_in_buffer=512)
    kv_cache_transceiver_ctx = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_ctx, attention_type,
        cache_transceiver_config)
    kv_cache_transceiver_gen = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_gen, attention_type,
        cache_transceiver_config)
    fill_kv_cache_buffer(kv_cache_manager_ctx)
    # init ctx request
    sampling_params = SamplingParams()
    ctx_request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=tensorrt_llm.bindings.SamplingConfig(
            sampling_params._get_sampling_config()),
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)
    kv_cache_manager_ctx.impl.add_sequence(ctx_request.py_request_id,
                                           ctx_request.prompt_len, 1,
                                           ctx_request)
    # send ctx request
    kv_cache_transceiver_ctx.respond_and_send_async(ctx_request)
    # wait for ctx request to be sent
    time.sleep(2)
    # cancel ctx request
    is_cancelled = kv_cache_transceiver_ctx.cancel_request(ctx_request)
    assert is_cancelled
    # init gen request
    gen_request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=tensorrt_llm.bindings.SamplingConfig(
            sampling_params._get_sampling_config()),
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
        context_phase_params=ctx_request.context_phase_params)
    kv_cache_manager_gen.impl.add_sequence(gen_request.py_request_id,
                                           gen_request.prompt_len, 1,
                                           gen_request)
    # send gen request
    kv_cache_transceiver_gen.request_and_receive_async(gen_request)
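    # The context side has already cancelled its transfer, so the generation
    # side is expected to fail and mark the request with a transfer error.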
    # Block the main thread so the asynchronous transfer has time to settle
    time.sleep(2)
    assert gen_request.state == LlmRequestState.DISAGG_TRANS_ERROR