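"""Unit tests for the PyExecutor KV-cache transceiver.

They exercise a single-process context -> generation KV-cache transfer over
the NIXL and UCX backends (MHA and MLA attention, FP8/FP16/BF16 caches) and
cancellation of a request whose cache is still being transmitted.
"""
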
import time

import pytest
import torch

import tensorrt_llm
import tensorrt_llm.bindings
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.distributed import MPIDist
from tensorrt_llm._torch.pyexecutor.kv_cache_transceiver import \
    create_kv_cache_transceiver
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
                                                        LlmRequestState)
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.sampling_params import SamplingParams

AttentionTypeCpp = tensorrt_llm.bindings.internal.batch_manager.AttentionType
LlmRequestType = tensorrt_llm.bindings.internal.batch_manager.LlmRequestType
DataType = tensorrt_llm.bindings.DataType


def create_kv_cache_manager(mapping, dtype):
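    """Build a minimal single-layer KVCacheManager (256 tokens, 8 tokens per block)."""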
    return KVCacheManager(
        trtllm.KvCacheConfig(
            max_tokens=256,
            enable_block_reuse=False,
        ),
        tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
        num_layers=1,
        num_kv_heads=1,
        head_dim=1,
        tokens_per_block=8,
        max_seq_len=256,
        max_batch_size=1,
        mapping=mapping,
        dtype=dtype)


def fill_kv_cache_buffer(kv_cache_manager):
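    """Fill the KV-cache buffer returned by get_buffers(0) with random values."""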
    with torch.no_grad():
        buffer = kv_cache_manager.get_buffers(0)
        random_values = torch.rand(buffer.shape,
                                   dtype=torch.float32,
                                   device=buffer.device)
        buffer.copy_(random_values)


@pytest.fixture(scope="function")
def ctx_gen_kv_cache_dtype(request):
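    """Resolve a parametrize id to a (ctx, gen) pair of KV-cache dtypes."""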
    if request.param == "ctx_fp8_gen_fp8":
        return DataType.FP8, DataType.FP8
    elif request.param == "ctx_fp16_gen_fp16":
        return DataType.HALF, DataType.HALF
    elif request.param == "ctx_bf16_gen_bf16":
        return DataType.BF16, DataType.BF16
    else:
        raise ValueError(f"Invalid config: {request.param}")


@pytest.mark.timeout(120)
@pytest.mark.parametrize(
    "ctx_gen_kv_cache_dtype",
    ["ctx_fp8_gen_fp8", "ctx_fp16_gen_fp16", "ctx_bf16_gen_bf16"],
    ids=["ctx_fp8_gen_fp8", "ctx_fp16_gen_fp16", "ctx_bf16_gen_bf16"],
    indirect=True)
@pytest.mark.parametrize("attention_type",
                         [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                         ids=["mha", "mla"])
@pytest.mark.parametrize("backend", ["NIXL", "UCX"], ids=["NIXL", "UCX"])
def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
                                             attention_type, backend):
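    """Transfer KV-cache blocks from a context-side transceiver to a
    generation-side transceiver within a single process, then check that the
    received buffer matches the sent one exactly."""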
    # Init kv_cache manager and cache transceiver
    mapping = Mapping(world_size=1, rank=0)
    ctx_kv_cache_dtype, gen_kv_cache_dtype = ctx_gen_kv_cache_dtype
    kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
    kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)

    cache_transceiver_config = CacheTransceiverConfig(backend=backend,
                                                      max_tokens_in_buffer=512)
    dist = MPIDist(mapping=mapping)
    kv_cache_transceiver_ctx = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_ctx, attention_type,
        cache_transceiver_config)

    kv_cache_transceiver_gen = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_gen, attention_type,
        cache_transceiver_config)

    fill_kv_cache_buffer(kv_cache_manager_ctx)

    # init ctx request
    sampling_params = SamplingParams()
    ctx_request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=tensorrt_llm.bindings.SamplingConfig(
            sampling_params._get_sampling_config()),
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)

    kv_cache_manager_ctx.impl.add_sequence(ctx_request.py_request_id,
                                           ctx_request.prompt_len, 1,
                                           ctx_request)
    # send ctx request
    kv_cache_transceiver_ctx.respond_and_send_async(ctx_request)

    # init gen request
    gen_request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=tensorrt_llm.bindings.SamplingConfig(
            sampling_params._get_sampling_config()),
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
        context_phase_params=ctx_request.context_phase_params)

    kv_cache_manager_gen.impl.add_sequence(gen_request.py_request_id,
                                           gen_request.prompt_len, 1,
                                           gen_request)
    # send gen request
    kv_cache_transceiver_gen.request_and_receive_async(gen_request)

    kv_cache_transceiver_ctx.check_context_transfer_status(1)
    kv_cache_transceiver_gen.check_gen_transfer_status(1)

    assert torch.equal(
        kv_cache_manager_gen.get_buffers(0),
        kv_cache_manager_ctx.get_buffers(0)), "different kv-cache values"


@pytest.mark.timeout(120)
@pytest.mark.parametrize("attention_type",
                         [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                         ids=["mha", "mla"])
def test_cancel_request_in_transmission(attention_type):
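    """Cancel a context-side transfer while it is in flight; the paired
    generation-side request should then end in DISAGG_TRANS_ERROR."""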
    # Init kv_cache manager and cache transceiver
    mapping = Mapping(world_size=1, rank=0)
    dist = MPIDist(mapping=mapping)
    ctx_kv_cache_dtype, gen_kv_cache_dtype = DataType.HALF, DataType.HALF
    kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
    kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)

    cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
                                                      max_tokens_in_buffer=512)

    kv_cache_transceiver_ctx = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_ctx, attention_type,
        cache_transceiver_config)

    kv_cache_transceiver_gen = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_gen, attention_type,
        cache_transceiver_config)

    fill_kv_cache_buffer(kv_cache_manager_ctx)

    # init ctx request
    sampling_params = SamplingParams()
    ctx_request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=tensorrt_llm.bindings.SamplingConfig(
            sampling_params._get_sampling_config()),
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)

    kv_cache_manager_ctx.impl.add_sequence(ctx_request.py_request_id,
                                           ctx_request.prompt_len, 1,
                                           ctx_request)
    # send ctx request
    kv_cache_transceiver_ctx.respond_and_send_async(ctx_request)

    # wait for ctx request to be sent
    time.sleep(2)

    # cancel ctx request
    is_cancelled = kv_cache_transceiver_ctx.cancel_request(ctx_request)
    assert is_cancelled

    # init gen request
    gen_request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=tensorrt_llm.bindings.SamplingConfig(
            sampling_params._get_sampling_config()),
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
        context_phase_params=ctx_request.context_phase_params)

    kv_cache_manager_gen.impl.add_sequence(gen_request.py_request_id,
                                           gen_request.prompt_len, 1,
                                           gen_request)
    # send gen request
    kv_cache_transceiver_gen.request_and_receive_async(gen_request)

    # Block the main thread due to the async operation
    time.sleep(2)
    assert gen_request.state == LlmRequestState.DISAGG_TRANS_ERROR
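

# Example invocation (the file path is a placeholder; the selectors come from
# the parametrize ids above):
#   pytest -v <this_file>.py -k "NIXL and mha and ctx_fp16_gen_fp16"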