diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index 37a2b21ff9..5ee79f9572 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -351,13 +351,14 @@ def maybe_pad_for_cuda_graph(func):
         def _call_func():
             return func(self, scheduled_requests, resource_manager, *args, **kwargs)

-        # check if we use cuda graph and we can run it
-        if not (self.cuda_graph_used and scheduled_requests.can_run_cuda_graph):
-            return _call_func()
+        # check conditions for current rank
+        can_run_cuda_graph = self.cuda_graph_used and scheduled_requests.can_run_cuda_graph
+        batch_size = scheduled_requests.batch_size

         # generate a persistent dummy request right away to ensure we can reserve the necessary
-        # resources (kv page and slot)
-        if self.padding_dummy_request is None:
+        # resources (kv page and slot) the first time we can actually run cuda graph according to
+        # this rank
+        if can_run_cuda_graph and self.padding_dummy_request is None:
             self.padding_dummy_request = _generate_dummy_request(
                 resource_manager,
                 request_id=CUDA_GRAPH_DUMMY_REQUEST_ID,
@@ -367,20 +368,48 @@ def maybe_pad_for_cuda_graph(func):
                 max_beam_width=self.max_beam_width,
             )

-        # check closest cuda graph batch size
-        closest_cg_bs = _round_up_to_closest(
-            self.cuda_graph_batch_sizes, scheduled_requests.batch_size
-        )
+        # check if we can pad the batch based on the availability of the dummy request
+        can_pad = self.padding_dummy_request is not None

-        # check if we need to pad
-        num_padding = closest_cg_bs - scheduled_requests.batch_size
+        # in attention DP mode, we check all ranks
+        if self.enable_attention_dp and self.mapping.tp_size > 1:
+            assert self.dist is not None, "Distributed object is required for attention DP mode"
+            all_rank_info = self.dist.tp_allgather([can_run_cuda_graph, can_pad, batch_size])
+        else:
+            all_rank_info = [[can_run_cuda_graph, can_pad, batch_size]]

-        if num_padding <= 0:
+        # now let's check if we can in principle run cuda graph across all ranks
+        can_run_cuda_graph_all = all(r_info[0] for r_info in all_rank_info)
+
+        if not can_run_cuda_graph_all:
             return _call_func()

-        # check if we have a dummy request to use
-        if self.padding_dummy_request is None:
-            ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
+        # get closest cudagraph batch size based on max_batch_size across ALL ranks
+        # NOTE: we assume uniform cudagraph batch sizes across all ranks ensuring all ranks get the
+        # same closest cudagraph batch size here based on the max batch size across all ranks
+        max_batch_size = max(r_info[2] for r_info in all_rank_info)
+        cg_batch_size = _round_up_to_closest(self.cuda_graph_batch_sizes, max_batch_size)
+
+        if cg_batch_size is None:
+            return _call_func()
+
+        # let's check if all ranks can pad the batch if they need to
+        can_pad_all = all(r_info[1] or (r_info[2] == cg_batch_size) for r_info in all_rank_info)
+
+        # fall back if we cannot run cudagraph due to padding issues
+        if not can_pad_all:
+            return _call_func()
+
+        # check actual amount of padding needed
+        num_padding = cg_batch_size - batch_size
+
+        # we should only hit this point for either of these conditions
+        assert num_padding == 0 or (num_padding > 0 and self.padding_dummy_request is not None), (
+            "Padding should not be needed or available at this point"
+        )
+
+        # no padding needed on current rank
+        if num_padding == 0:
             return _call_func()

         # pad the scheduled requests with the dummy request
@@ -411,7 +440,12 @@ class ADEngine(ModelEngine):
         return self.cache_seq_interface.device

     @classmethod
-    def build_from_config(cls, ad_config: LlmArgs, mapping: Optional[Mapping] = None):
+    def build_from_config(
+        cls,
+        ad_config: LlmArgs,
+        mapping: Optional[Mapping] = None,
+        dist: Optional[Distributed] = None,
+    ):
         """Build the ADEngine using the LlmArgs that gets passed through from the LLM."""
         max_batch_size = ad_config.max_batch_size

@@ -453,6 +487,7 @@ class ADEngine(ModelEngine):
             device,
             ad_config=ad_config,
             mapping=mapping,
+            dist=dist,
             reporting_info=reporting_info,
         )

@@ -464,6 +499,7 @@ class ADEngine(ModelEngine):
         device: DeviceLikeType,
         ad_config: Optional[LlmArgs] = None,
         mapping: Optional[Mapping] = None,
+        dist: Optional[Distributed] = None,
         reporting_info: ReportingInfo = ReportingInfo(),
     ) -> None:
         """Initialize the engine with model and sequence information."""
@@ -484,7 +520,7 @@ class ADEngine(ModelEngine):
         self.iter_states = {}

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
-        self.enable_attention_dp = False
+        self.enable_attention_dp = mapping.enable_attention_dp if mapping else False

         if ad_config is not None:
             self.max_beam_width = ad_config.max_beam_width
@@ -537,6 +573,7 @@ class ADEngine(ModelEngine):

         # Reuse _execute_logit_post_processors from PyTorchModelEngine
         self.mapping = mapping
+        self.dist = dist
         self._execute_logit_post_processors = types.MethodType(
             PyTorchModelEngine._execute_logit_post_processors, self
         )
@@ -1005,13 +1042,23 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     # initialize process groups
     world_size = mpi_world_size()
     rank = mpi_rank()
-    dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
+    enable_attention_dp = ad_config.transforms.get("detect_sharding", {}).get(
+        "enable_attention_dp", False
+    )
+    dist_mapping = Mapping(
+        rank=rank,
+        world_size=world_size,
+        tp_size=world_size,
+        enable_attention_dp=enable_attention_dp,
+    )
     dist = Distributed.get(dist_mapping)
     ad_logger.set_rank(rank)
     torch.cuda.set_device(rank)
     port = dist.broadcast(get_free_port())  # use MPI broadcast to pick a free port
     initialize_or_skip(rank, world_size, port)

+    ad_logger.info(f"{dist_mapping=}, {dist=}, {port=}")
+
     # Setup AutoTuner with distributed state for allreduce autotuning
     AutoTuner.get().setup_distributed_state(dist_mapping)

@@ -1030,7 +1077,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )

     # initialize model engine
-    engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)
+    engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping, dist=dist)

     spec_config = ad_config.speculative_config
     if spec_config is not None and not (
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
index 2df3ed61b4..e14e71f218 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -151,6 +151,11 @@ class ShardingTransformConfig(TransformConfig):

     process_grid: Dict[ShardingDim, int] = Field(default_factory=dict)

+    enable_attention_dp: bool = Field(
+        default=False,
+        description="When True, skip TP sharding as attention data parallelism is enabled.",
+    )
+
     def validate_config(self, sources: Union[ShardingSource, List[ShardingSource]] = None) -> bool:
         init_process_grid_from_config(self)
         if sources is None:
@@ -738,8 +743,9 @@ class Sharding(BaseTransform):
             f"Using allreduce strategy: {config.allreduce_strategy.name}, dist backend: {config.dist_backend}"
         )

-        if world_size < 2:
-            ad_logger.info("Skipping sharding for single device")
+        if world_size < 2 or config.enable_attention_dp:
+            reason = "single device" if world_size < 2 else "attention DP enabled"
+            ad_logger.info(f"Skipping sharding: {reason}")
             return gm, TransformInfo(
                 skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
             )
diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
index 7b293abf4c..df673f5bd3 100644
--- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
+++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -81,6 +81,24 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=sampling_params)

+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.skip_less_device(2)
+    @pytest.mark.parametrize("world_size", [2, 4])
+    def test_attention_dp(self, world_size):
+        """Test attention data parallelism mode where TP sharding is disabled."""
+        kwargs = self.get_default_kwargs(enable_chunked_prefill=True)
+        # Enable attention DP - this disables TP sharding
+        kwargs["transforms"]["detect_sharding"] = {"enable_attention_dp": True}
+        sampling_params = self.get_default_sampling_params()
+        with AutoDeployLLM(model=self.MODEL_PATH,
+                           tokenizer=self.MODEL_PATH,
+                           world_size=world_size,
+                           **kwargs) as llm:
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+

 class TestNemotronH(LlmapiAccuracyTestHarness):
     MODEL_NAME = "nvidia/Nemotron-H-8B-Base-8K"
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index 398a1827fb..6d02460179 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -333,3 +333,4 @@ l0_dgx_h100:
   - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
   - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[4]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[8]
+  - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_attention_dp[4]
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py
index fa38c4476b..63d54aa851 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py
@@ -1,4 +1,3 @@
-from types import SimpleNamespace
 from typing import List, Optional, Type

 import pytest
@@ -9,6 +8,7 @@ from tensorrt_llm import SamplingParams
 from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
 from tensorrt_llm._torch.auto_deploy.shim.demollm import DemoEngine
+from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests


 class TransformerLikeModelwithFakeCachePool(nn.Module):
@@ -196,19 +196,22 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):

     # No-chunk: whole prompt in one request
     req_full = _DummyRequest(tokens=tokens, begin=0, size=len(tokens), seq_slot=0)
-    scheduled_full = SimpleNamespace(context_requests=[req_full], generation_requests=[])
-    logits_full_last = engine.forward(scheduled_full, resource_manager)["logits"][-1]
+    scheduled_requests = ScheduledRequests()
+    scheduled_requests.context_requests.append(req_full)
+    logits_full_last = engine.forward(scheduled_requests, resource_manager)["logits"][-1]

     # Chunked: split into two context chunks
     split = len(tokens) // 2
     req_part1 = _DummyRequest(tokens=tokens, begin=0, size=split, seq_slot=0)
     req_part2 = _DummyRequest(tokens=tokens, begin=split, size=len(tokens) - split, seq_slot=0)
-    scheduled_part1 = SimpleNamespace(context_requests=[req_part1], generation_requests=[])
-    scheduled_part2 = SimpleNamespace(context_requests=[req_part2], generation_requests=[])
+    scheduled_requests_part1 = ScheduledRequests()
+    scheduled_requests_part1.context_requests.append(req_part1)
+    scheduled_requests_part2 = ScheduledRequests()
+    scheduled_requests_part2.context_requests.append(req_part2)

     # Run first chunk (ignored output), then compare second chunk logits to full
-    _ = engine.forward(scheduled_part1, resource_manager)
-    logits_chunked_last = engine.forward(scheduled_part2, resource_manager)["logits"][-1]
+    _ = engine.forward(scheduled_requests_part1, resource_manager)
+    logits_chunked_last = engine.forward(scheduled_requests_part2, resource_manager)["logits"][-1]

     torch.testing.assert_close(logits_full_last, logits_chunked_last)  # , atol=1e-5)
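
Usage sketch (not part of the patch): the snippet below mirrors TestLlama3_1_8B::test_attention_dp above to show how a caller opts into attention DP. The only knob introduced by this diff is transforms["detect_sharding"]["enable_attention_dp"], which create_autodeploy_executor reads when building the Mapping and which makes the Sharding transform skip TP sharding. The AutoDeployLLM import path, the model name, and the generate() call are assumptions for illustration; in the test, the flag is merged into the harness's default transforms config rather than passed as the full transforms dict.

# Hypothetical standalone usage, adapted from the integration test above.
# The import path and model name are assumptions, not confirmed by this diff.
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM  # assumed import path

llm = AutoDeployLLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    world_size=2,  # in create_autodeploy_executor, tp_size is set to world_size
    # Enable attention data parallelism: the Sharding transform then skips TP sharding and
    # the executor builds the Mapping with enable_attention_dp=True.
    transforms={"detect_sharding": {"enable_attention_dp": True}},
)
outputs = llm.generate(["Hello, world!"], sampling_params=SamplingParams(max_tokens=32))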