[#8982][feat] AutoDeploy attention dp support (#10728)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Authored by Lucas Liebenwein on 2026-01-26 09:43:33 -05:00; committed by GitHub
parent ce556290c9
commit 00f341be49
5 changed files with 103 additions and 28 deletions

View File

@@ -351,13 +351,14 @@ def maybe_pad_for_cuda_graph(func):
def _call_func():
return func(self, scheduled_requests, resource_manager, *args, **kwargs)
# check if we use cuda graph and we can run it
if not (self.cuda_graph_used and scheduled_requests.can_run_cuda_graph):
return _call_func()
# check conditions for current rank
can_run_cuda_graph = self.cuda_graph_used and scheduled_requests.can_run_cuda_graph
batch_size = scheduled_requests.batch_size
# generate a persistent dummy request right away to ensure we can reserve the necessary
# resources (kv page and slot)
if self.padding_dummy_request is None:
# resources (kv page and slot) the first time we can actually run cuda graph according to
# this rank
if can_run_cuda_graph and self.padding_dummy_request is None:
self.padding_dummy_request = _generate_dummy_request(
resource_manager,
request_id=CUDA_GRAPH_DUMMY_REQUEST_ID,
@@ -367,20 +368,48 @@ def maybe_pad_for_cuda_graph(func):
max_beam_width=self.max_beam_width,
)
# check closest cuda graph batch size
closest_cg_bs = _round_up_to_closest(
self.cuda_graph_batch_sizes, scheduled_requests.batch_size
)
# check if we can pad the batch based on the availability of the dummy request
can_pad = self.padding_dummy_request is not None
# check if we need to pad
num_padding = closest_cg_bs - scheduled_requests.batch_size
# in attention DP mode, we check all ranks
if self.enable_attention_dp and self.mapping.tp_size > 1:
assert self.dist is not None, "Distributed object is required for attention DP mode"
all_rank_info = self.dist.tp_allgather([can_run_cuda_graph, can_pad, batch_size])
else:
all_rank_info = [[can_run_cuda_graph, can_pad, batch_size]]
if num_padding <= 0:
# now let's check if we can in principle run cuda graph across all ranks
can_run_cuda_graph_all = all(r_info[0] for r_info in all_rank_info)
if not can_run_cuda_graph_all:
return _call_func()
# check if we have a dummy request to use
if self.padding_dummy_request is None:
ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
# get closest cudagraph batch size based on max_batch_size across ALL ranks
# NOTE: we assume uniform cudagraph batch sizes across all ranks ensuring all ranks get the
# same closest cudagraph batch size here based on the max batch size across all ranks
max_batch_size = max(r_info[2] for r_info in all_rank_info)
cg_batch_size = _round_up_to_closest(self.cuda_graph_batch_sizes, max_batch_size)
if cg_batch_size is None:
return _call_func()
# let's check if all ranks can pad the batch if they need to
can_pad_all = all(r_info[1] or (r_info[2] == cg_batch_size) for r_info in all_rank_info)
# fall back if we cannot run cudagraph due to padding issues
if not can_pad_all:
return _call_func()
# check actual amount of padding needed
num_padding = cg_batch_size - batch_size
# we should only hit this point for either of these conditions
assert num_padding == 0 or (num_padding > 0 and self.padding_dummy_request is not None), (
"Padding should not be needed or available at this point"
)
# no padding needed on current rank
if num_padding == 0:
return _call_func()
# pad the scheduled requests with the dummy request
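
The hunk above turns the per-rank padding check into a collective decision for attention DP: each rank allgathers (can_run_cuda_graph, can_pad, batch_size), every rank must be able to run the graph and pad toward the same target, and that target is the closest captured cudagraph batch size at or above the global maximum batch size. Below is a minimal sketch of that decision; the helper names (round_up_to_closest, decide_padding) are illustrative and the allgather result is mocked as a plain list of tuples.

from typing import List, Optional, Tuple


def round_up_to_closest(sizes: List[int], value: int) -> Optional[int]:
    """Smallest captured cudagraph batch size >= value, or None if none is large enough."""
    fitting = [s for s in sorted(sizes) if s >= value]
    return fitting[0] if fitting else None


def decide_padding(
    all_rank_info: List[Tuple[bool, bool, int]],  # (can_run_cuda_graph, can_pad, batch_size)
    cuda_graph_batch_sizes: List[int],
    my_batch_size: int,
) -> Optional[int]:
    """Number of dummy requests this rank should append, or None to fall back to eager."""
    # every rank must be able to run the captured graph
    if not all(info[0] for info in all_rank_info):
        return None
    # all ranks pad toward the same target, derived from the global max batch size
    target = round_up_to_closest(cuda_graph_batch_sizes, max(info[2] for info in all_rank_info))
    if target is None:
        return None
    # a rank that actually needs padding must have a dummy request available
    if not all(info[1] or info[2] == target for info in all_rank_info):
        return None
    return target - my_batch_size


# example: TP=2 with attention DP, local batch sizes 3 and 5, graphs captured for [1, 2, 4, 8]
info = [(True, True, 3), (True, True, 5)]
assert decide_padding(info, [1, 2, 4, 8], my_batch_size=3) == 5  # rank 0 pads 3 -> 8
assert decide_padding(info, [1, 2, 4, 8], my_batch_size=5) == 3  # rank 1 pads 5 -> 8

Because the target is derived from the global maximum and the captured batch sizes are assumed uniform across ranks (see the NOTE in the hunk), every rank either replays a graph of the same size or every rank falls back together.
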
@@ -411,7 +440,12 @@ class ADEngine(ModelEngine):
return self.cache_seq_interface.device
@classmethod
def build_from_config(cls, ad_config: LlmArgs, mapping: Optional[Mapping] = None):
def build_from_config(
cls,
ad_config: LlmArgs,
mapping: Optional[Mapping] = None,
dist: Optional[Distributed] = None,
):
"""Build the ADEngine using the LlmArgs that gets passed through from the LLM."""
max_batch_size = ad_config.max_batch_size
@@ -453,6 +487,7 @@ class ADEngine(ModelEngine):
device,
ad_config=ad_config,
mapping=mapping,
dist=dist,
reporting_info=reporting_info,
)
@@ -464,6 +499,7 @@ class ADEngine(ModelEngine):
device: DeviceLikeType,
ad_config: Optional[LlmArgs] = None,
mapping: Optional[Mapping] = None,
dist: Optional[Distributed] = None,
reporting_info: ReportingInfo = ReportingInfo(),
) -> None:
"""Initialize the engine with model and sequence information."""
@@ -484,7 +520,7 @@ class ADEngine(ModelEngine):
self.iter_states = {}
# NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
self.enable_attention_dp = False
self.enable_attention_dp = mapping.enable_attention_dp if mapping else False
if ad_config is not None:
self.max_beam_width = ad_config.max_beam_width
@@ -537,6 +573,7 @@ class ADEngine(ModelEngine):
# Reuse _execute_logit_post_processors from PyTorchModelEngine
self.mapping = mapping
self.dist = dist
self._execute_logit_post_processors = types.MethodType(
PyTorchModelEngine._execute_logit_post_processors, self
)
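
The types.MethodType call above borrows PyTorchModelEngine._execute_logit_post_processors and binds it to the ADEngine instance instead of inheriting from that engine class. A generic sketch of the idiom with placeholder classes (not the TensorRT-LLM types):

import types


class Donor:
    def describe(self):
        # `self` is whatever instance the function was bound to, not necessarily a Donor
        return f"running on {type(self).__name__}"


class Host:
    pass


h = Host()
# borrow Donor.describe and bind it to the existing Host instance, mirroring how ADEngine
# reuses a PyTorchModelEngine method without inheriting from it
h.describe = types.MethodType(Donor.describe, h)
assert h.describe() == "running on Host"
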
@@ -1005,13 +1042,23 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
# initialize process groups
world_size = mpi_world_size()
rank = mpi_rank()
dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
enable_attention_dp = ad_config.transforms.get("detect_sharding", {}).get(
"enable_attention_dp", False
)
dist_mapping = Mapping(
rank=rank,
world_size=world_size,
tp_size=world_size,
enable_attention_dp=enable_attention_dp,
)
dist = Distributed.get(dist_mapping)
ad_logger.set_rank(rank)
torch.cuda.set_device(rank)
port = dist.broadcast(get_free_port()) # use MPI broadcast to pick a free port
initialize_or_skip(rank, world_size, port)
ad_logger.info(f"{dist_mapping=}, {dist=}, {port=}")
# Setup AutoTuner with distributed state for allreduce autotuning
AutoTuner.get().setup_distributed_state(dist_mapping)
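
create_autodeploy_executor reads the flag from the nested transforms.detect_sharding section with a default at each level, so configs that predate the key keep the old behavior. A small sketch of that lookup, assuming transforms behaves like the plain nested mapping the chained .get() calls above imply:

# hypothetical config dicts; the real ad_config.transforms may wrap these in config objects
legacy_transforms = {}  # no detect_sharding section at all
new_transforms = {"detect_sharding": {"enable_attention_dp": True}}


def read_attention_dp_flag(transforms: dict) -> bool:
    # chained .get() calls keep missing sections from raising KeyError
    return transforms.get("detect_sharding", {}).get("enable_attention_dp", False)


assert read_attention_dp_flag(legacy_transforms) is False
assert read_attention_dp_flag(new_transforms) is True
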
@@ -1030,7 +1077,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
)
# initialize model engine
engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)
engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping, dist=dist)
spec_config = ad_config.speculative_config
if spec_config is not None and not (

View File

@@ -151,6 +151,11 @@ class ShardingTransformConfig(TransformConfig):
process_grid: Dict[ShardingDim, int] = Field(default_factory=dict)
enable_attention_dp: bool = Field(
default=False,
description="When True, skip TP sharding as attention data parallelism is enabled.",
)
def validate_config(self, sources: Union[ShardingSource, List[ShardingSource]] = None) -> bool:
init_process_grid_from_config(self)
if sources is None:
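
The new enable_attention_dp field follows the existing Field(...) pattern in ShardingTransformConfig. A standalone pydantic sketch of how such a flag defaults and overrides, using a hypothetical stand-in model rather than the real config class:

from pydantic import BaseModel, Field


class DetectShardingConfigSketch(BaseModel):  # hypothetical stand-in, not ShardingTransformConfig
    enable_attention_dp: bool = Field(
        default=False,
        description="When True, skip TP sharding as attention data parallelism is enabled.",
    )


assert DetectShardingConfigSketch().enable_attention_dp is False
assert DetectShardingConfigSketch(enable_attention_dp=True).enable_attention_dp is True
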
@@ -738,8 +743,9 @@ class Sharding(BaseTransform):
f"Using allreduce strategy: {config.allreduce_strategy.name}, dist backend: {config.dist_backend}"
)
if world_size < 2:
ad_logger.info("Skipping sharding for single device")
if world_size < 2 or config.enable_attention_dp:
reason = "single device" if world_size < 2 else "attention DP enabled"
ad_logger.info(f"Skipping sharding: {reason}")
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
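
Under attention DP the transform now returns early the same way it already did for a single device: the graph is left unsharded and the skip is recorded in the returned TransformInfo. A compact sketch of just the skip decision, with reason strings matching the log message above:

from typing import Optional


def sharding_skip_reason(world_size: int, enable_attention_dp: bool) -> Optional[str]:
    """Return why TP sharding is skipped, or None if sharding should run."""
    if world_size < 2:
        return "single device"
    if enable_attention_dp:
        return "attention DP enabled"
    return None


assert sharding_skip_reason(1, False) == "single device"
assert sharding_skip_reason(4, True) == "attention DP enabled"
assert sharding_skip_reason(4, False) is None
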

View File

@@ -81,6 +81,24 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("world_size", [2, 4])
def test_attention_dp(self, world_size):
"""Test attention data parallelism mode where TP sharding is disabled."""
kwargs = self.get_default_kwargs(enable_chunked_prefill=True)
# Enable attention DP - this disables TP sharding
kwargs["transforms"]["detect_sharding"] = {"enable_attention_dp": True}
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH,
tokenizer=self.MODEL_PATH,
world_size=world_size,
**kwargs) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
class TestNemotronH(LlmapiAccuracyTestHarness):
MODEL_NAME = "nvidia/Nemotron-H-8B-Base-8K"

View File

@@ -333,3 +333,4 @@ l0_dgx_h100:
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[4]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[8]
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_attention_dp[4]

View File

@@ -1,4 +1,3 @@
from types import SimpleNamespace
from typing import List, Optional, Type
import pytest
@@ -9,6 +8,7 @@ from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
from tensorrt_llm._torch.auto_deploy.shim.demollm import DemoEngine
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
class TransformerLikeModelwithFakeCachePool(nn.Module):
@@ -196,19 +196,22 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
# No-chunk: whole prompt in one request
req_full = _DummyRequest(tokens=tokens, begin=0, size=len(tokens), seq_slot=0)
scheduled_full = SimpleNamespace(context_requests=[req_full], generation_requests=[])
logits_full_last = engine.forward(scheduled_full, resource_manager)["logits"][-1]
scheduled_requests = ScheduledRequests()
scheduled_requests.context_requests.append(req_full)
logits_full_last = engine.forward(scheduled_requests, resource_manager)["logits"][-1]
# Chunked: split into two context chunks
split = len(tokens) // 2
req_part1 = _DummyRequest(tokens=tokens, begin=0, size=split, seq_slot=0)
req_part2 = _DummyRequest(tokens=tokens, begin=split, size=len(tokens) - split, seq_slot=0)
scheduled_part1 = SimpleNamespace(context_requests=[req_part1], generation_requests=[])
scheduled_part2 = SimpleNamespace(context_requests=[req_part2], generation_requests=[])
scheduled_requests_part1 = ScheduledRequests()
scheduled_requests_part1.context_requests.append(req_part1)
scheduled_requests_part2 = ScheduledRequests()
scheduled_requests_part2.context_requests.append(req_part2)
# Run first chunk (ignored output), then compare second chunk logits to full
_ = engine.forward(scheduled_part1, resource_manager)
logits_chunked_last = engine.forward(scheduled_part2, resource_manager)["logits"][-1]
_ = engine.forward(scheduled_requests_part1, resource_manager)
logits_chunked_last = engine.forward(scheduled_requests_part2, resource_manager)["logits"][-1]
torch.testing.assert_close(logits_full_last, logits_chunked_last) # , atol=1e-5)
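
The test now builds real ScheduledRequests objects instead of SimpleNamespace stand-ins, because the padding decorator in the first file reads attributes such as batch_size and can_run_cuda_graph that a bare namespace does not provide. A rough stand-in illustrating the shape the decorator relies on (the real class lives in tensorrt_llm._torch.pyexecutor.scheduler; the property definitions here are assumptions, not the library implementation):

from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class ScheduledRequestsLike:  # illustrative only, not the real ScheduledRequests
    context_requests: List[Any] = field(default_factory=list)
    generation_requests: List[Any] = field(default_factory=list)

    @property
    def batch_size(self) -> int:
        # assumed: total number of scheduled requests across both phases
        return len(self.context_requests) + len(self.generation_requests)

    @property
    def can_run_cuda_graph(self) -> bool:
        # assumed: cuda graphs are only replayed for pure generation batches
        return len(self.context_requests) == 0 and len(self.generation_requests) > 0
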