Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
feat: Add support for disaggregation with pp with pytorch backend (#6369)
Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Signed-off-by: raayandhar <rdhar@nvidia.com>
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Signed-off-by: pcastonguay <55748270+pcastonguay@users.noreply.github.com>
Co-authored-by: raayandhar <rdhar@nvidia.com>
Co-authored-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
parent a2514d93fc
commit e7ae5e2824
@@ -840,6 +840,8 @@ void CacheFormatter::unformat(TransferSession& session)
    if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
    {
        TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers");
        TLLM_LOG_WARNING("self: %zu dest %zu", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
            destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
        return false;
    }
    int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();
@@ -71,7 +71,10 @@ def clear_folder(folder_path):
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            rmtree(item_path)
        else:
            os.remove(item_path)
            try:
                os.remove(item_path)
            except (OSError, IOError) as e:
                print(f"Failed to remove {item_path}: {e}", file=sys.stderr)


def sysconfig_scheme(override_vars=None):
@@ -96,13 +96,13 @@ class BindKvCacheTransceiver(KvCacheTransceiver):
                 attention_type: AttentionTypeCpp,
                 cache_transceiver_config: CacheTransceiverConfig):
        world_config = mapping_to_world_config(mapping)
        num_kv_heads_per_layer = kv_cache_manager.num_kv_heads_per_layer
        total_num_kv_heads_per_layer = kv_cache_manager.total_num_kv_heads_per_layer
        head_dim = kv_cache_manager.head_dim
        tokens_per_block = kv_cache_manager.tokens_per_block
        dtype = kv_cache_manager.dtype

        self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
                                        num_kv_heads_per_layer, head_dim,
                                        total_num_kv_heads_per_layer, head_dim,
                                        tokens_per_block, world_config, dtype,
                                        attention_type,
                                        cache_transceiver_config)
@@ -122,6 +122,7 @@ class BatchState:
@dataclasses.dataclass
class BatchStatePP(BatchState):
    microbatch_id: int = -1
    scheduled_ctx_reqs: list[LlmRequest] = None


class PyExecutor:
@@ -641,6 +642,7 @@ class PyExecutor:
        return False

    def _executor_loop_pp(self):
        logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
        torch.cuda.set_device(self.device_id)
        microbatch_id = 0
        with self._profiler() as profile_step:
@@ -654,6 +656,9 @@ class PyExecutor:
                if self.should_stop_processing:
                    break

                if self.kv_cache_transceiver:
                    self._check_disagg_gen_transfer_status()

                if self.enable_iter_perf_stats:
                    iter_stats = self._get_init_iter_stats(
                        len(new_requests),
@@ -662,9 +667,23 @@ class PyExecutor:

                self._pad_attention_dp_dummy_request()

                scheduled_batch, _, _ = self._schedule()
                scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
                )

                if self.kv_cache_transceiver:
                    # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
                    self._prepare_disagg_gen_init(
                        fitting_disagg_gen_init_requests)

                    if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
                        logger.warning(
                            "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
                        )
                        self.kv_cache_transceiver.check_context_transfer_status(
                            1)

                self.num_scheduled_requests = scheduled_batch.batch_size

                logger.debug(
                    f'has {len(self.active_requests)} active_request, '
                    f'scheduled {len(scheduled_batch.context_requests)} context requests and '
@@ -677,7 +696,7 @@ class PyExecutor:
                    can_queue = 0 not in tp_batch_sizes
                else:
                    can_queue = scheduled_batch.batch_size > 0
                if not can_queue:
                if not can_queue and not self.kv_cache_transceiver:
                    assert len(self.inflight_req_ids) > 0, (
                        "fail to schedule any pending request, probably run out of resource"
                    )
@@ -686,8 +705,28 @@ class PyExecutor:
                    self.micro_batches[microbatch_id] = None
                else:
                    self._add_inflight_ids(scheduled_batch)

                    if self.kv_cache_transceiver:
                        # For generation requests which have completed KV cache transfer
                        self._prepare_disagg_gen_transmission_complete(
                            scheduled_batch)

                    self.resource_manager.prepare_resources(scheduled_batch)

                    # The generation requests that do not have batch_idx
                    # need to be in front of the batch due to the assumptions
                    # made in model_engine.py::_forward_step. This is only important
                    # for disaggregated serving. For non-disaggregated serving,
                    # the generation requests always have batch_idx.
                    scheduled_batch.generation_requests = sorted(  # stable sort
                        scheduled_batch.generation_requests,
                        key=lambda req: int(req.py_batch_idx is not None),
                    )

                    if self.kv_cache_transceiver:
                        # Return the first token to the client
                        self._handle_first_token_response(scheduled_batch)

                    # Stage 1: Async forward (all ranks) and decoding pass (last rank only)
                    if not self.dist.is_last_pp_rank:
                        sample_state = self._forward_step_inter_pp(
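For illustration, a minimal standalone sketch (not the executor code itself) of the stable-sort trick above: keying on `int(py_batch_idx is not None)` moves requests without a batch index to the front while preserving relative order within each group, because Python's `sorted` is stable. The `Req` class is a hypothetical stand-in for `LlmRequest`.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Req:  # hypothetical stand-in for LlmRequest
    name: str
    py_batch_idx: Optional[int] = None

reqs = [Req("a", 0), Req("b", None), Req("c", 1), Req("d", None)]

# Stable sort: requests without a batch index (key 0) come first,
# requests that already have one (key 1) follow; original order is kept.
ordered = sorted(reqs, key=lambda r: int(r.py_batch_idx is not None))

print([r.name for r in ordered])  # ['b', 'd', 'a', 'c']
```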
@@ -715,6 +754,7 @@ class PyExecutor:
                        iter_start_time=iter_start_time,
                        iter_stats=iter_stats,
                        microbatch_id=microbatch_id,
                        scheduled_ctx_reqs=scheduled_batch.context_requests,
                    )

                    self.micro_batches[microbatch_id] = batch_state
@@ -779,6 +819,11 @@ class PyExecutor:
                if previous_batch is not None:
                    with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
                        self._update_requests(previous_batch.sample_state)

                        if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
                            self._send_disagg_ctx_cache(
                                previous_batch.scheduled_ctx_reqs)

                        self._handle_canceled_requests()
                        finished_requests = self._handle_responses()
                        previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
@@ -787,6 +832,9 @@ class PyExecutor:
                        self._remove_inflight_ids(previous_scheduled_batch)
                    self.micro_batches[prev_microbatch_id] = None

                    if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                        self._terminate_ctx_finished_requests()

                # march forward in microbatch slots
                microbatch_id = (microbatch_id + 1) % self.num_micro_batches
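As a rough illustration of the microbatch bookkeeping in this loop (a sketch with assumed names, not the PyExecutor implementation): each in-flight batch is parked in a fixed-size slot list, and the slot index advances round-robin via `(microbatch_id + 1) % num_micro_batches`, so the batch launched two iterations ago is completed when its slot comes around again.

```python
# Minimal sketch of a round-robin microbatch slot buffer (hypothetical names).
num_micro_batches = 2
micro_batches = [None] * num_micro_batches  # one slot per in-flight microbatch

microbatch_id = 0
for step in range(5):
    # Finish whatever was parked in this slot on an earlier iteration.
    previous_batch = micro_batches[microbatch_id]
    if previous_batch is not None:
        print(f"step {step}: completing {previous_batch}")

    # Park the batch launched this iteration in the current slot.
    micro_batches[microbatch_id] = f"batch-{step}"

    # March forward in microbatch slots.
    microbatch_id = (microbatch_id + 1) % num_micro_batches
```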
@@ -155,18 +155,33 @@ class KVCacheManager(BaseResourceManager):
                (num_kv_heads + tp_size - 1) // tp_size
                for _ in range(self.num_local_layers)
            ]
            self.total_num_kv_heads_per_layer = [
                (num_kv_heads + tp_size - 1) // tp_size
                for _ in range(self.num_layers)
            ]
        else:
            assert len(num_kv_heads) == self.num_layers

            def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
                                             kv_head: Optional[int]):
                if kv_head is not None:
                    num_kv_heads_per_layer.append(
                        (kv_head + tp_size - 1) // tp_size)
                else:
                    num_kv_heads_per_layer.append(0)

            self.num_kv_heads_per_layer = []
            if self.num_local_layers > 0:
                for i in self.pp_layers:
                    kv_head = num_kv_heads[i]
                    if kv_head is not None:
                        self.num_kv_heads_per_layer.append(
                            (kv_head + tp_size - 1) // tp_size)
                    else:
                        self.num_kv_heads_per_layer.append(0)
                    append_to_kv_heads_per_layer(self.num_kv_heads_per_layer,
                                                 kv_head)

            self.total_num_kv_heads_per_layer = []
            for i in range(self.num_layers):
                kv_head = num_kv_heads[i]
                append_to_kv_heads_per_layer(self.total_num_kv_heads_per_layer,
                                             kv_head)

        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
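A hedged standalone sketch of the per-layer head bookkeeping added above: the local list covers only the layers owned by this pipeline rank, while the total list covers every layer of the model, both using ceiling division by the TP size (a `None` entry means the layer has no KV heads). The helper name and example values below are illustrative, not the KVCacheManager API.

```python
from typing import List, Optional

def heads_for_layers(num_kv_heads: List[Optional[int]],
                     layer_ids: List[int],
                     tp_size: int) -> List[int]:
    """Ceil-divide each layer's KV head count by tp_size; None maps to 0."""
    out = []
    for i in layer_ids:
        kv_head = num_kv_heads[i]
        out.append((kv_head + tp_size - 1) // tp_size if kv_head is not None else 0)
    return out

# Example: 4 layers, pp_size=2, this rank owns layers [2, 3], tp_size=2.
num_kv_heads = [8, 8, 4, None]
local = heads_for_layers(num_kv_heads, [2, 3], tp_size=2)          # [2, 0]
total = heads_for_layers(num_kv_heads, list(range(4)), tp_size=2)  # [4, 4, 2, 0]
print(local, total)
```

With pipeline parallelism the cache transceiver needs the total (all-layer) view so that context and generation instances with different pp splits can still match up layer-by-layer, which is why the binding above switches from the local list to `total_num_kv_heads_per_layer`.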
@@ -735,3 +735,14 @@ class LlmapiAccuracyTestHarness:
        logger.set_level("info")
        yield
        logger.set_level(original_level)


def get_accuracy_task(dataset_name: str):
    try:
        task_class = globals()[dataset_name]
        if issubclass(task_class, AccuracyTask):
            return task_class
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}.")
    except KeyError:
        raise ValueError(f"Not registered dataset: {dataset_name}.")
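The helper above resolves a dataset name to its task class through the module's `globals()`. A self-contained sketch of the same lookup pattern, with placeholder classes rather than the real AccuracyTask hierarchy:

```python
class AccuracyTask:            # placeholder base class
    pass

class GSM8K(AccuracyTask):     # placeholder registered task
    pass

def get_task(name: str):
    try:
        cls = globals()[name]  # look the name up in this module's namespace
    except KeyError:
        raise ValueError(f"Not registered dataset: {name}.")
    if not (isinstance(cls, type) and issubclass(cls, AccuracyTask)):
        raise ValueError(f"Unknown dataset: {name}.")
    return cls

print(get_task("GSM8K"))  # <class '__main__.GSM8K'>
```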
@@ -20,9 +20,11 @@ from tensorrt_llm.executor.result import GenerationResultBase
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
from tensorrt_llm.llmapi.llm_args import LlmArgs

from ..conftest import llm_models_root, parametrize_with_ids, skip_pre_hopper
from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
                        skip_pre_hopper)
from ..trt_test_alternative import popen
from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
                            get_accuracy_task)


class Result(GenerationResultBase):
@@ -71,6 +73,12 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
    temp_dir = tempfile.TemporaryDirectory()
    disaggregated_serving_config_path = os.path.join(
        temp_dir.name, "disaggregated_serving_config.yaml")

    if tensor_parallel_size > 1:
        print(
            f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
        )

    with open(disaggregated_serving_config_path, "w") as f:
        yaml.dump(disaggregated_server_config, f)
    ctx_server_config_path = os.path.join(temp_dir.name,
@@ -88,27 +96,40 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
    trtllm_serve_path = "trtllm-serve"
    # Common arguments for both servers
    common_args = [
        trtllm_serve_path, model_name, "--host", "localhost", "--backend",
        "pytorch"
        trtllm_serve_path,
        model_name,
        "--host",
        "localhost",
        "--backend",
        "pytorch",
    ]
    gen_tp, gen_pp = gen_server_config.get(
        "tensor_parallel_size",
        tensor_parallel_size), gen_server_config.get("pipeline_parallel_size",
                                                     1)
    ctx_tp, ctx_pp = ctx_server_config.get(
        "tensor_parallel_size",
        tensor_parallel_size), ctx_server_config.get("pipeline_parallel_size",
                                                     1)

    if tensor_parallel_size > 1:
        common_args.append(f"--tp_size={tensor_parallel_size}")
    ctx_total_gpus = ctx_tp * ctx_pp
    gen_total_gpus = gen_tp * gen_pp

    env_ctx = os.environ.copy()
    env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(
        map(str, range(tensor_parallel_size)))
    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))

    env_gen = os.environ.copy()
    env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
    env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
        map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
        map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
    ctx_server_args = common_args + [
        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
        f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
    ]
    gen_server_args = common_args + [
        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
        "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
        f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
    ]
    if "max_num_tokens" in ctx_server_config:
        ctx_server_args.append(
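A small sketch of the GPU assignment logic introduced above (example values, not tied to a real machine): the context server gets the first ctx_tp*ctx_pp devices and the generation server gets the next gen_tp*gen_pp, so the two server processes never share a GPU.

```python
import os

ctx_tp, ctx_pp = 1, 2   # example context-server parallelism
gen_tp, gen_pp = 2, 1   # example generation-server parallelism

ctx_total_gpus = ctx_tp * ctx_pp
gen_total_gpus = gen_tp * gen_pp

env_ctx = os.environ.copy()
env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))

env_gen = os.environ.copy()
env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
    map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))

print(env_ctx["CUDA_VISIBLE_DEVICES"])  # "0,1"
print(env_gen["CUDA_VISIBLE_DEVICES"])  # "2,3"
```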
@@ -182,6 +203,56 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
        disaggregated_server.wait()


def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
                      ctx_tp: int, gen_pp: int, gen_tp: int,
                      test_set: LlmapiAccuracyTestHarness):
    if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
        pytest.fail(
            f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
        )

    kv_cache_config = {
        "free_gpu_memory_fraction": 0.5,
        "enable_block_reuse": False
    }
    ctx_server_config = {
        "pipeline_parallel_size": ctx_pp,
        "tensor_parallel_size": ctx_tp,
        "disable_overlap_scheduler": True,
        "kv_cache_config": kv_cache_config,
        "cache_transceiver_config": {
            "backend": "default"
        }
    }
    gen_server_config = {
        "tensor_parallel_size": gen_tp,
        "pipeline_parallel_size": gen_pp,
        "disable_overlap_scheduler": True,
        "kv_cache_config": kv_cache_config,
        "cache_transceiver_config": {
            "backend": "default"
        }
    }
    disaggregated_server_config = {
        "hostname": "localhost",
        "port": 8000,
        "backend": "pytorch",
        "context_servers": {
            "num_instances": 1,
            "urls": ["localhost:8001"]
        },
        "generation_servers": {
            "num_instances": 1,
            "urls": ["localhost:8002"]
        }
    }
    with launch_disaggregated_llm(disaggregated_server_config,
                                  ctx_server_config, gen_server_config,
                                  model_path) as llm:
        task = test_set(model_name)
        task.evaluate(llm)


@pytest.mark.timeout(3600)
class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
    MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
@@ -315,6 +386,20 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
        task = GSM8K(self.MODEL_NAME)
        task.evaluate(llm)

    @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
                             ids=["tp1pp2", "tp2pp1", "tp2pp2"])
    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
    def test_tp_pp_symmetric(self, tp, pp, testset):
        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
                                 tp, get_accuracy_task(testset))

    @parametrize_with_ids("ctx_pp", [2, 4])
    @parametrize_with_ids("gen_tp", [1, 2])
    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
    def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
                                 gen_tp, get_accuracy_task(testset))


@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600)
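For the parametrizations above, a quick sketch of the device budget each combination implies (symmetric cases reuse the same tp/pp on both sides; asymmetric cases pair a ctx_pp with a gen_tp), matching the `get_device_count()` guard in `run_parallel_test`. The helper below is illustrative only.

```python
def gpus_needed(ctx_tp, ctx_pp, gen_tp, gen_pp):
    # Context and generation servers use disjoint GPU sets.
    return ctx_tp * ctx_pp + gen_tp * gen_pp

# Symmetric: ctx and gen both use (tp, pp).
for tp, pp in [(1, 2), (2, 1), (2, 2)]:
    print(f"tp{tp}pp{pp}: {gpus_needed(tp, pp, tp, pp)} GPUs")   # 4, 4, 8

# Asymmetric: ctx uses pp only, gen uses tp only.
for ctx_pp in (2, 4):
    for gen_tp in (1, 2):
        print(f"ctx_pp={ctx_pp}, gen_tp={gen_tp}: "
              f"{gpus_needed(1, ctx_pp, gen_tp, 1)} GPUs")       # 3 to 6
```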
@@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
  num_instances: 1
  max_batch_size: 1
  max_num_tokens: 3000
  max_seq_len: 4096
  tensor_parallel_size: 1
  pipeline_parallel_size: 2
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 2
  max_batch_size: 256
  max_num_tokens: 4096
  max_seq_len: 4096
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8002"
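As a sanity check on configs like the one above (and the four variants that follow), a hedged sketch that loads the YAML and reports how many GPUs each server block needs. It requires PyYAML; the file name is a placeholder.

```python
import yaml

with open("disagg_config_ctxpp2_genpp2.yaml") as f:  # placeholder path
    cfg = yaml.safe_load(f)

for section in ("context_servers", "generation_servers"):
    block = cfg[section]
    tp = block.get("tensor_parallel_size", 1)
    pp = block.get("pipeline_parallel_size", 1)
    print(f"{section}: tp={tp} x pp={pp} -> {tp * pp} GPU(s) per instance, "
          f"{block.get('num_instances', 1)} instance(s)")
```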
@@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
  num_instances: 1
  max_batch_size: 1
  max_num_tokens: 3000
  max_seq_len: 4096
  tensor_parallel_size: 1
  pipeline_parallel_size: 2
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 2
  pipeline_parallel_size: 1
  max_batch_size: 256
  max_num_tokens: 4096
  max_seq_len: 4096
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8002"
@@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
  num_instances: 1
  max_batch_size: 1
  max_num_tokens: 3000
  max_seq_len: 4096
  tensor_parallel_size: 1
  pipeline_parallel_size: 4
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 4
  max_batch_size: 256
  max_num_tokens: 4096
  max_seq_len: 4096
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8002"
@@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
  num_instances: 1
  max_batch_size: 1
  max_num_tokens: 3000
  max_seq_len: 4096
  tensor_parallel_size: 2
  pipeline_parallel_size: 1
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 2
  max_batch_size: 256
  max_num_tokens: 4096
  max_seq_len: 4096
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8002"
@@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
  num_instances: 1
  max_batch_size: 1
  max_num_tokens: 3000
  max_seq_len: 4096
  tensor_parallel_size: 2
  pipeline_parallel_size: 2
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  tensor_parallel_size: 2
  pipeline_parallel_size: 2
  max_batch_size: 256
  max_num_tokens: 4096
  max_seq_len: 4096
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  disable_overlap_scheduler: True
  cache_transceiver_config:
    backend: default
  urls:
    - "localhost:8002"
@@ -59,6 +59,16 @@ def get_test_config(test_desc, example_dir, test_root):
        "conditional": (2,
                        f"{test_configs_root}/disagg_config_conditional.yaml"),
        "ngram": (2, f"{test_configs_root}/disagg_config_ngram.yaml"),
        "ctxpp2_genpp2":
        (4, f"{test_configs_root}/disagg_config_ctxpp2_genpp2.yaml"),
        "ctxtp2_genpp2":
        (4, f"{test_configs_root}/disagg_config_ctxtp2_genpp2.yaml"),
        "ctxpp2_gentp2":
        (4, f"{test_configs_root}/disagg_config_ctxpp2_gentp2.yaml"),
        "ctxtp2pp2_gentp2pp2":
        (8, f"{test_configs_root}/disagg_config_ctxtp2pp2_gentp2pp2.yaml"),
        "ctxpp4_genpp4":
        (8, f"{test_configs_root}/disagg_config_ctxpp4_genpp4.yaml"),
        "deepseek_v3_lite_fp8_mpi":
        (4,
         f"{test_configs_root}/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml"
@@ -540,6 +550,108 @@ def test_disaggregated_ngram(disaggregated_test_root, llm_venv,
                           cwd=llm_venv.get_working_directory())


@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_disaggregated_ctxpp2_genpp2(disaggregated_test_root, llm_venv,
                                     disaggregated_example_root,
                                     llama_model_root):
    src_dst_dict = {
        llama_model_root:
        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    }
    for src, dst in src_dst_dict.items():
        if not os.path.islink(dst):
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            os.symlink(src, dst, target_is_directory=True)
    run_disaggregated_test(disaggregated_example_root,
                           "ctxpp2_genpp2",
                           env=llm_venv._new_env,
                           cwd=llm_venv.get_working_directory())


@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_disaggregated_ctxtp2_genpp2(disaggregated_test_root, llm_venv,
                                     disaggregated_example_root,
                                     llama_model_root):
    src_dst_dict = {
        llama_model_root:
        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    }
    for src, dst in src_dst_dict.items():
        if not os.path.islink(dst):
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            os.symlink(src, dst, target_is_directory=True)
    run_disaggregated_test(disaggregated_example_root,
                           "ctxtp2_genpp2",
                           env=llm_venv._new_env,
                           cwd=llm_venv.get_working_directory())


@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_disaggregated_ctxpp2_gentp2(disaggregated_test_root, llm_venv,
                                     disaggregated_example_root,
                                     llama_model_root):
    src_dst_dict = {
        llama_model_root:
        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    }
    for src, dst in src_dst_dict.items():
        if not os.path.islink(dst):
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            os.symlink(src, dst, target_is_directory=True)
    run_disaggregated_test(disaggregated_example_root,
                           "ctxpp2_gentp2",
                           env=llm_venv._new_env,
                           cwd=llm_venv.get_working_directory())


@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_disaggregated_ctxtp2pp2_gentp2pp2(disaggregated_test_root, llm_venv,
                                           disaggregated_example_root,
                                           llama_model_root):
    pytest.skip(f"8 GPU test times out currently, skipping")
    src_dst_dict = {
        llama_model_root:
        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    }
    for src, dst in src_dst_dict.items():
        if not os.path.islink(dst):
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            os.symlink(src, dst, target_is_directory=True)
    run_disaggregated_test(disaggregated_example_root,
                           "ctxtp2pp2_gentp2pp2",
                           env=llm_venv._new_env,
                           cwd=llm_venv.get_working_directory())


@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_disaggregated_ctxpp4_genpp4(disaggregated_test_root, llm_venv,
                                     disaggregated_example_root,
                                     llama_model_root):
    pytest.skip(f"8 GPU test times out currently, skipping")
    src_dst_dict = {
        llama_model_root:
        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    }
    for src, dst in src_dst_dict.items():
        if not os.path.islink(dst):
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            os.symlink(src, dst, target_is_directory=True)
    run_disaggregated_test(disaggregated_example_root,
                           "ctxpp4_genpp4",
                           env=llm_venv._new_env,
                           cwd=llm_venv.get_working_directory())


@skip_no_hopper
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
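The tests above all repeat the same checkpoint-linking preamble. A standalone sketch of that pattern with illustrative paths: create the parent directory if needed, then symlink the real model directory into the working directory under the name the disaggregated config expects.

```python
import os

src_dst_dict = {
    # illustrative source path -> path the YAML config expects
    "/models/TinyLlama-1.1B-Chat-v1.0":
    "/workspace/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
    if not os.path.islink(dst):
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        os.symlink(src, dst, target_is_directory=True)
```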
@@ -30,12 +30,23 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
- disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp1]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
- test_e2e.py::test_ptp_quickstart_advanced_bs1
- condition:
    ranges:
@@ -23,6 +23,14 @@ l0_dgx_h200:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2pp2_gentp2pp2[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_genpp4[TinyLlama-1.1B-Chat-v1.0]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout]
- unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora