feat: Add support for disaggregation with pp with pytorch backend (#6369)

Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Signed-off-by: raayandhar <rdhar@nvidia.com>
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Signed-off-by: pcastonguay <55748270+pcastonguay@users.noreply.github.com>
Co-authored-by: raayandhar <rdhar@nvidia.com>
Co-authored-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
pcastonguay 2025-07-30 09:42:13 -04:00 committed by GitHub
parent a2514d93fc
commit e7ae5e2824
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 496 additions and 21 deletions

View File

@ -840,6 +840,8 @@ void CacheFormatter::unformat(TransferSession& session)
if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size()) if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
{ {
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers"); TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers");
TLLM_LOG_WARNING("self: %zu dest %zu", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
return false; return false;
} }
int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(); int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();

View File

@ -71,7 +71,10 @@ def clear_folder(folder_path):
if os.path.isdir(item_path) and not os.path.islink(item_path): if os.path.isdir(item_path) and not os.path.islink(item_path):
rmtree(item_path) rmtree(item_path)
else: else:
os.remove(item_path) try:
os.remove(item_path)
except (OSError, IOError) as e:
print(f"Failed to remove {item_path}: {e}", file=sys.stderr)
def sysconfig_scheme(override_vars=None): def sysconfig_scheme(override_vars=None):

View File

@ -96,13 +96,13 @@ class BindKvCacheTransceiver(KvCacheTransceiver):
attention_type: AttentionTypeCpp, attention_type: AttentionTypeCpp,
cache_transceiver_config: CacheTransceiverConfig): cache_transceiver_config: CacheTransceiverConfig):
world_config = mapping_to_world_config(mapping) world_config = mapping_to_world_config(mapping)
num_kv_heads_per_layer = kv_cache_manager.num_kv_heads_per_layer total_num_kv_heads_per_layer = kv_cache_manager.total_num_kv_heads_per_layer
head_dim = kv_cache_manager.head_dim head_dim = kv_cache_manager.head_dim
tokens_per_block = kv_cache_manager.tokens_per_block tokens_per_block = kv_cache_manager.tokens_per_block
dtype = kv_cache_manager.dtype dtype = kv_cache_manager.dtype
self.impl = CacheTransceiverCpp(kv_cache_manager.impl, self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
num_kv_heads_per_layer, head_dim, total_num_kv_heads_per_layer, head_dim,
tokens_per_block, world_config, dtype, tokens_per_block, world_config, dtype,
attention_type, attention_type,
cache_transceiver_config) cache_transceiver_config)

View File

@ -122,6 +122,7 @@ class BatchState:
@dataclasses.dataclass @dataclasses.dataclass
class BatchStatePP(BatchState): class BatchStatePP(BatchState):
microbatch_id: int = -1 microbatch_id: int = -1
scheduled_ctx_reqs: list[LlmRequest] = None
class PyExecutor: class PyExecutor:
@ -641,6 +642,7 @@ class PyExecutor:
return False return False
def _executor_loop_pp(self): def _executor_loop_pp(self):
logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
torch.cuda.set_device(self.device_id) torch.cuda.set_device(self.device_id)
microbatch_id = 0 microbatch_id = 0
with self._profiler() as profile_step: with self._profiler() as profile_step:
@ -654,6 +656,9 @@ class PyExecutor:
if self.should_stop_processing: if self.should_stop_processing:
break break
if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
if self.enable_iter_perf_stats: if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats( iter_stats = self._get_init_iter_stats(
len(new_requests), len(new_requests),
@ -662,9 +667,23 @@ class PyExecutor:
self._pad_attention_dp_dummy_request() self._pad_attention_dp_dummy_request()
scheduled_batch, _, _ = self._schedule() scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
if self.kv_cache_transceiver:
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
self._prepare_disagg_gen_init(
fitting_disagg_gen_init_requests)
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self.kv_cache_transceiver.check_context_transfer_status(
1)
self.num_scheduled_requests = scheduled_batch.batch_size self.num_scheduled_requests = scheduled_batch.batch_size
logger.debug( logger.debug(
f'has {len(self.active_requests)} active_request, ' f'has {len(self.active_requests)} active_request, '
f'scheduled {len(scheduled_batch.context_requests)} context requests and ' f'scheduled {len(scheduled_batch.context_requests)} context requests and '
@ -677,7 +696,7 @@ class PyExecutor:
can_queue = 0 not in tp_batch_sizes can_queue = 0 not in tp_batch_sizes
else: else:
can_queue = scheduled_batch.batch_size > 0 can_queue = scheduled_batch.batch_size > 0
if not can_queue: if not can_queue and not self.kv_cache_transceiver:
assert len(self.inflight_req_ids) > 0, ( assert len(self.inflight_req_ids) > 0, (
"fail to schedule any pending request, probably run out of resource" "fail to schedule any pending request, probably run out of resource"
) )
@ -686,8 +705,28 @@ class PyExecutor:
self.micro_batches[microbatch_id] = None self.micro_batches[microbatch_id] = None
else: else:
self._add_inflight_ids(scheduled_batch) self._add_inflight_ids(scheduled_batch)
if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
self._prepare_disagg_gen_transmission_complete(
scheduled_batch)
self.resource_manager.prepare_resources(scheduled_batch) self.resource_manager.prepare_resources(scheduled_batch)
# The generation requests that are do not have batch_idx,
# needs to be in front of the batch due to the assumptions
# made in model_engine.py::_forward_step. This is only important
# for disaggregated serving. For non-disaggregated serving,
# the generation requests always have batch_idx.
scheduled_batch.generation_requests = sorted( # stable sort
scheduled_batch.generation_requests,
key=lambda req: int(req.py_batch_idx is not None),
)
if self.kv_cache_transceiver:
# Return the first token to the client
self._handle_first_token_response(scheduled_batch)
# Stage 1: Async forward (all ranks) and decoding pass (last rank only) # Stage 1: Async forward (all ranks) and decoding pass (last rank only)
if not self.dist.is_last_pp_rank: if not self.dist.is_last_pp_rank:
sample_state = self._forward_step_inter_pp( sample_state = self._forward_step_inter_pp(
@ -715,6 +754,7 @@ class PyExecutor:
iter_start_time=iter_start_time, iter_start_time=iter_start_time,
iter_stats=iter_stats, iter_stats=iter_stats,
microbatch_id=microbatch_id, microbatch_id=microbatch_id,
scheduled_ctx_reqs=scheduled_batch.context_requests,
) )
self.micro_batches[microbatch_id] = batch_state self.micro_batches[microbatch_id] = batch_state
@ -779,6 +819,11 @@ class PyExecutor:
if previous_batch is not None: if previous_batch is not None:
with torch.cuda.nvtx.range("_handle_previous_batch_pp"): with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
self._update_requests(previous_batch.sample_state) self._update_requests(previous_batch.sample_state)
if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
self._send_disagg_ctx_cache(
previous_batch.scheduled_ctx_reqs)
self._handle_canceled_requests() self._handle_canceled_requests()
finished_requests = self._handle_responses() finished_requests = self._handle_responses()
previous_scheduled_batch = previous_batch.sample_state.scheduled_requests previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
@ -787,6 +832,9 @@ class PyExecutor:
self._remove_inflight_ids(previous_scheduled_batch) self._remove_inflight_ids(previous_scheduled_batch)
self.micro_batches[prev_microbatch_id] = None self.micro_batches[prev_microbatch_id] = None
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._terminate_ctx_finished_requests()
# march forward in microbatch slots # march forward in microbatch slots
microbatch_id = (microbatch_id + 1) % self.num_micro_batches microbatch_id = (microbatch_id + 1) % self.num_micro_batches

View File

@ -155,18 +155,33 @@ class KVCacheManager(BaseResourceManager):
(num_kv_heads + tp_size - 1) // tp_size (num_kv_heads + tp_size - 1) // tp_size
for _ in range(self.num_local_layers) for _ in range(self.num_local_layers)
] ]
self.total_num_kv_heads_per_layer = [
(num_kv_heads + tp_size - 1) // tp_size
for _ in range(self.num_layers)
]
else: else:
assert len(num_kv_heads) == self.num_layers assert len(num_kv_heads) == self.num_layers
def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
kv_head: Optional[int]):
if kv_head is not None:
num_kv_heads_per_layer.append(
(kv_head + tp_size - 1) // tp_size)
else:
num_kv_heads_per_layer.append(0)
self.num_kv_heads_per_layer = [] self.num_kv_heads_per_layer = []
if self.num_local_layers > 0: if self.num_local_layers > 0:
for i in self.pp_layers: for i in self.pp_layers:
kv_head = num_kv_heads[i] kv_head = num_kv_heads[i]
if kv_head is not None: append_to_kv_heads_per_layer(self.num_kv_heads_per_layer,
self.num_kv_heads_per_layer.append( kv_head)
(kv_head + tp_size - 1) // tp_size)
else: self.total_num_kv_heads_per_layer = []
self.num_kv_heads_per_layer.append(0) for i in range(self.num_layers):
kv_head = num_kv_heads[i]
append_to_kv_heads_per_layer(self.total_num_kv_heads_per_layer,
kv_head)
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
self.head_dim = head_dim self.head_dim = head_dim

View File

@ -735,3 +735,14 @@ class LlmapiAccuracyTestHarness:
logger.set_level("info") logger.set_level("info")
yield yield
logger.set_level(original_level) logger.set_level(original_level)
def get_accuracy_task(dataset_name: str):
try:
task_class = globals()[dataset_name]
if issubclass(task_class, AccuracyTask):
return task_class
else:
raise ValueError(f"Unknown dataset: {dataset_name}.")
except KeyError:
raise ValueError(f"Not registered dataset: {dataset_name}.")

View File

@ -20,9 +20,11 @@ from tensorrt_llm.executor.result import GenerationResultBase
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
from tensorrt_llm.llmapi.llm_args import LlmArgs from tensorrt_llm.llmapi.llm_args import LlmArgs
from ..conftest import llm_models_root, parametrize_with_ids, skip_pre_hopper from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
skip_pre_hopper)
from ..trt_test_alternative import popen from ..trt_test_alternative import popen
from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
get_accuracy_task)
class Result(GenerationResultBase): class Result(GenerationResultBase):
@ -71,6 +73,12 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
temp_dir = tempfile.TemporaryDirectory() temp_dir = tempfile.TemporaryDirectory()
disaggregated_serving_config_path = os.path.join( disaggregated_serving_config_path = os.path.join(
temp_dir.name, "disaggregated_serving_config.yaml") temp_dir.name, "disaggregated_serving_config.yaml")
if tensor_parallel_size > 1:
print(
f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
)
with open(disaggregated_serving_config_path, "w") as f: with open(disaggregated_serving_config_path, "w") as f:
yaml.dump(disaggregated_server_config, f) yaml.dump(disaggregated_server_config, f)
ctx_server_config_path = os.path.join(temp_dir.name, ctx_server_config_path = os.path.join(temp_dir.name,
@ -88,27 +96,40 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
trtllm_serve_path = "trtllm-serve" trtllm_serve_path = "trtllm-serve"
# Common arguments for both servers # Common arguments for both servers
common_args = [ common_args = [
trtllm_serve_path, model_name, "--host", "localhost", "--backend", trtllm_serve_path,
"pytorch" model_name,
"--host",
"localhost",
"--backend",
"pytorch",
] ]
gen_tp, gen_pp = gen_server_config.get(
"tensor_parallel_size",
tensor_parallel_size), gen_server_config.get("pipeline_parallel_size",
1)
ctx_tp, ctx_pp = ctx_server_config.get(
"tensor_parallel_size",
tensor_parallel_size), ctx_server_config.get("pipeline_parallel_size",
1)
if tensor_parallel_size > 1: ctx_total_gpus = ctx_tp * ctx_pp
common_args.append(f"--tp_size={tensor_parallel_size}") gen_total_gpus = gen_tp * gen_pp
env_ctx = os.environ.copy() env_ctx = os.environ.copy()
env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1" env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join( env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))
map(str, range(tensor_parallel_size)))
env_gen = os.environ.copy() env_gen = os.environ.copy()
env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1" env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_gen["CUDA_VISIBLE_DEVICES"] = ",".join( env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, range(tensor_parallel_size, 2 * tensor_parallel_size))) map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
ctx_server_args = common_args + [ ctx_server_args = common_args + [
"--port", "8001", "--extra_llm_api_options", ctx_server_config_path "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
] ]
gen_server_args = common_args + [ gen_server_args = common_args + [
"--port", "8002", "--extra_llm_api_options", gen_server_config_path "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
] ]
if "max_num_tokens" in ctx_server_config: if "max_num_tokens" in ctx_server_config:
ctx_server_args.append( ctx_server_args.append(
@ -182,6 +203,56 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
disaggregated_server.wait() disaggregated_server.wait()
def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
ctx_tp: int, gen_pp: int, gen_tp: int,
test_set: LlmapiAccuracyTestHarness):
if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
pytest.fail(
f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
)
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False
}
ctx_server_config = {
"pipeline_parallel_size": ctx_pp,
"tensor_parallel_size": ctx_tp,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"cache_transceiver_config": {
"backend": "default"
}
}
gen_server_config = {
"tensor_parallel_size": gen_tp,
"pipeline_parallel_size": gen_pp,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"cache_transceiver_config": {
"backend": "default"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
model_path) as llm:
task = test_set(model_name)
task.evaluate(llm)
@pytest.mark.timeout(3600) @pytest.mark.timeout(3600)
class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
@ -315,6 +386,20 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME) task = GSM8K(self.MODEL_NAME)
task.evaluate(llm) task.evaluate(llm)
@pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
ids=["tp1pp2", "tp2pp1", "tp2pp2"])
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_tp_pp_symmetric(self, tp, pp, testset):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
tp, get_accuracy_task(testset))
@parametrize_with_ids("ctx_pp", [2, 4])
@parametrize_with_ids("gen_tp", [1, 2])
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
gen_tp, get_accuracy_task(testset))
@pytest.mark.skip_less_device_memory(140000) @pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600) @pytest.mark.timeout(3600)

View File

@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
num_instances: 1
max_batch_size: 1
max_num_tokens: 3000
max_seq_len: 4096
tensor_parallel_size: 1
pipeline_parallel_size: 2
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 2
max_batch_size: 256
max_num_tokens: 4096
max_seq_len: 4096
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8002"

View File

@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
num_instances: 1
max_batch_size: 1
max_num_tokens: 3000
max_seq_len: 4096
tensor_parallel_size: 1
pipeline_parallel_size: 2
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 2
pipeline_parallel_size: 1
max_batch_size: 256
max_num_tokens: 4096
max_seq_len: 4096
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8002"

View File

@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
num_instances: 1
max_batch_size: 1
max_num_tokens: 3000
max_seq_len: 4096
tensor_parallel_size: 1
pipeline_parallel_size: 4
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 4
max_batch_size: 256
max_num_tokens: 4096
max_seq_len: 4096
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8002"

View File

@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
num_instances: 1
max_batch_size: 1
max_num_tokens: 3000
max_seq_len: 4096
tensor_parallel_size: 2
pipeline_parallel_size: 1
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 2
max_batch_size: 256
max_num_tokens: 4096
max_seq_len: 4096
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8002"

View File

@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
num_instances: 1
max_batch_size: 1
max_num_tokens: 3000
max_seq_len: 4096
tensor_parallel_size: 2
pipeline_parallel_size: 2
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 2
pipeline_parallel_size: 2
max_batch_size: 256
max_num_tokens: 4096
max_seq_len: 4096
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8002"

View File

@ -59,6 +59,16 @@ def get_test_config(test_desc, example_dir, test_root):
"conditional": (2, "conditional": (2,
f"{test_configs_root}/disagg_config_conditional.yaml"), f"{test_configs_root}/disagg_config_conditional.yaml"),
"ngram": (2, f"{test_configs_root}/disagg_config_ngram.yaml"), "ngram": (2, f"{test_configs_root}/disagg_config_ngram.yaml"),
"ctxpp2_genpp2":
(4, f"{test_configs_root}/disagg_config_ctxpp2_genpp2.yaml"),
"ctxtp2_genpp2":
(4, f"{test_configs_root}/disagg_config_ctxtp2_genpp2.yaml"),
"ctxpp2_gentp2":
(4, f"{test_configs_root}/disagg_config_ctxpp2_gentp2.yaml"),
"ctxtp2pp2_gentp2pp2":
(8, f"{test_configs_root}/disagg_config_ctxtp2pp2_gentp2pp2.yaml"),
"ctxpp4_genpp4":
(8, f"{test_configs_root}/disagg_config_ctxpp4_genpp4.yaml"),
"deepseek_v3_lite_fp8_mpi": "deepseek_v3_lite_fp8_mpi":
(4, (4,
f"{test_configs_root}/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml" f"{test_configs_root}/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml"
@ -540,6 +550,108 @@ def test_disaggregated_ngram(disaggregated_test_root, llm_venv,
cwd=llm_venv.get_working_directory()) cwd=llm_venv.get_working_directory())
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_ctxpp2_genpp2(disaggregated_test_root, llm_venv,
disaggregated_example_root,
llama_model_root):
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_test(disaggregated_example_root,
"ctxpp2_genpp2",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_ctxtp2_genpp2(disaggregated_test_root, llm_venv,
disaggregated_example_root,
llama_model_root):
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_test(disaggregated_example_root,
"ctxtp2_genpp2",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_ctxpp2_gentp2(disaggregated_test_root, llm_venv,
disaggregated_example_root,
llama_model_root):
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_test(disaggregated_example_root,
"ctxpp2_gentp2",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_ctxtp2pp2_gentp2pp2(disaggregated_test_root, llm_venv,
disaggregated_example_root,
llama_model_root):
pytest.skip(f"8 GPU test times out currently, skipping")
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_test(disaggregated_example_root,
"ctxtp2pp2_gentp2pp2",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_ctxpp4_genpp4(disaggregated_test_root, llm_venv,
disaggregated_example_root,
llama_model_root):
pytest.skip(f"8 GPU test times out currently, skipping")
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_test(disaggregated_example_root,
"ctxpp4_genpp4",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
@skip_no_hopper @skip_no_hopper
@pytest.mark.skip_less_device(4) @pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'], @pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],

View File

@ -30,12 +30,23 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
- disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2_genpp2[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp1]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
- test_e2e.py::test_ptp_quickstart_advanced_bs1 - test_e2e.py::test_ptp_quickstart_advanced_bs1
- condition: - condition:
ranges: ranges:

View File

@ -23,6 +23,14 @@ l0_dgx_h200:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=False] - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2pp2_gentp2pp2[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_genpp4[TinyLlama-1.1B-Chat-v1.0]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout] - unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout] - unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout]
- unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora - unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora