From 3de82c41cdcb4b553fe742ecf2a36a7f451734bc Mon Sep 17 00:00:00 2001
From: Aurelien Chartier
Date: Thu, 27 Mar 2025 09:11:19 -0700
Subject: [PATCH] Pytorch PP + attention DP support (#3044)

Signed-off-by: Aurelien Chartier
---
 .../_torch/models/modeling_deepseekv3.py      |  5 --
 .../_torch/pyexecutor/model_engine.py         | 65 +++++++------------
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 46 +++++++++----
 .../multi_gpu_modeling/test_deepseek.py       |  4 +-
 4 files changed, 62 insertions(+), 58 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index a131a97598..9c43e07fd3 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -834,11 +834,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
             hidden_size=model_config.pretrained_config.hidden_size,
             vocab_size=model_config.pretrained_config.vocab_size)
 
-        assert not (
-            model_config.mapping.has_pp()
-            and model_config.mapping.enable_attention_dp
-        ), "Pipeline parallelism and attention DP cannot be used together"
-
         self.model_nextn = 0
         if model_config.spec_config is not None:
             model_nextn = model_config.spec_config.num_nextn_predict_layers
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index c2a9faeaf1..265d0dcdfb 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -581,9 +581,8 @@ class PyTorchModelEngine(ModelEngine):
         batch_size = scheduled_requests.batch_size
         new_batch_size = batch_size
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
-            graph_batch_size = self.dist.allgather(
+            graph_batch_size = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
-            graph_batch_size = self._post_allgather_pp(graph_batch_size)
             all_can_graph = all(graph_batch[0]
                                 for graph_batch in graph_batch_size)
             if all_can_graph:
@@ -666,9 +665,8 @@ class PyTorchModelEngine(ModelEngine):
         can_run_cuda_graph = batch.can_run_cuda_graph
         batch_size = len(batch.generation_requests)
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
-            all_can_graph_batch = self.dist.allgather(
+            all_can_graph_batch = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
-            all_can_graph_batch = self._post_allgather_pp(all_can_graph_batch)
             is_all_gen_only = all(all_can_graph[0]
                                   for all_can_graph in all_can_graph_batch)
             all_batch_size_equal = all(
@@ -1097,14 +1095,6 @@ class PyTorchModelEngine(ModelEngine):
             'mrope_config': mrope_config
         }
 
-        if self.mapping.has_pp():
-            pipeline_interface = None
-            if self.mapping.pp_rank > 0:
-                pipeline_interface = self.model.create_pipeline_interface(
-                    inputs['input_ids'].shape[0])
-                pipeline_interface.recv()
-            inputs['pipeline_interface'] = pipeline_interface
-
         if spec_metadata is not None:
             total_draft_lens = sum(draft_lens)
             if len(draft_tokens) > 0:
@@ -1124,12 +1114,10 @@ class PyTorchModelEngine(ModelEngine):
         # support attention dp
         if self.enable_attention_dp:
             if spec_metadata is not None:
-                all_rank_num_tokens = self.dist.allgather([
+                all_rank_num_tokens = self.dist.tp_allgather([
                     attn_metadata.num_tokens, spec_metadata.num_tokens,
                     len(sequence_lengths)
                 ])
-                all_rank_num_tokens = self._post_allgather_pp(
-                    all_rank_num_tokens)
                 attn_all_rank_num_tokens = [
                     item[0] for item in all_rank_num_tokens
                 ]
@@ -1141,12 +1129,18 @@ class PyTorchModelEngine(ModelEngine):
                 spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
                 spec_metadata.all_rank_num_seqs = all_rank_num_seqs
             else:
-                all_rank_num_tokens = self.dist.allgather(
+                all_rank_num_tokens = self.dist.tp_allgather(
                     attn_metadata.num_tokens)
-                all_rank_num_tokens = self._post_allgather_pp(
-                    all_rank_num_tokens)
                 attn_metadata.all_rank_num_tokens = all_rank_num_tokens
 
+        if self.mapping.has_pp():
+            pipeline_interface = None
+            if self.mapping.pp_rank > 0:
+                pipeline_interface = self.model.create_pipeline_interface(
+                    inputs['input_ids'].shape[0])
+                pipeline_interface.recv()
+            inputs['pipeline_interface'] = pipeline_interface
+
         num_generation_tokens = len(generation_requests) + len(
             extend_requests) + sum(draft_lens)
         self.iter_states['num_ctx_requests'] = num_ctx_requests
@@ -1220,14 +1214,6 @@ class PyTorchModelEngine(ModelEngine):
             'multi_modal_data': multi_modal_data
         }
 
-        if self.mapping.has_pp():
-            pipeline_interface = None
-            if self.mapping.pp_rank > 0:
-                pipeline_interface = self.model.create_pipeline_interface(
-                    inputs['input_ids'].shape[0])
-                pipeline_interface.recv()
-            inputs['pipeline_interface'] = pipeline_interface
-
         if spec_metadata is not None:
             total_draft_lens = sum(draft_lens)
             spec_metadata.draft_tokens = self.draft_tokens_cuda[:
@@ -1243,12 +1229,10 @@ class PyTorchModelEngine(ModelEngine):
         # support attention dp
         if self.enable_attention_dp:
             if spec_metadata is not None:
-                all_rank_num_tokens = self.dist.allgather([
+                all_rank_num_tokens = self.dist.tp_allgather([
                     attn_metadata.num_tokens, spec_metadata.num_tokens,
                     len(sequence_lengths)
                 ])
-                all_rank_num_tokens = self._post_allgather_pp(
-                    all_rank_num_tokens)
                 attn_all_rank_num_tokens = [
                     item[0] for item in all_rank_num_tokens
                 ]
@@ -1260,11 +1244,18 @@ class PyTorchModelEngine(ModelEngine):
                 spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
                 spec_metadata.all_rank_num_seqs = all_rank_num_seqs
             else:
-                all_rank_num_tokens = self.dist.allgather(
+                all_rank_num_tokens = self.dist.tp_allgather(
                     attn_metadata.num_tokens)
-                all_rank_num_tokens = self._post_allgather_pp(
-                    all_rank_num_tokens)
                 attn_metadata.all_rank_num_tokens = all_rank_num_tokens
+
+        if self.mapping.has_pp():
+            pipeline_interface = None
+            if self.mapping.pp_rank > 0:
+                pipeline_interface = self.model.create_pipeline_interface(
+                    inputs['input_ids'].shape[0])
+                pipeline_interface.recv()
+            inputs['pipeline_interface'] = pipeline_interface
+
         return inputs, None
 
     def _prepare_star_attention_inputs(self,
@@ -1477,8 +1468,8 @@ class PyTorchModelEngine(ModelEngine):
         attn_metadata.prepare()
 
         if self.enable_attention_dp:
-            all_rank_num_tokens = self.dist.allgather(attn_metadata.num_tokens)
-            all_rank_num_tokens = self._post_allgather_pp(all_rank_num_tokens)
+            all_rank_num_tokens = self.dist.tp_allgather(
+                attn_metadata.num_tokens)
             attn_metadata.all_rank_num_tokens = all_rank_num_tokens
 
         return {
@@ -1657,12 +1648,6 @@ class PyTorchModelEngine(ModelEngine):
         else:
             return {'hidden_states': hidden_states}
 
-    def _post_allgather_pp(self, allgather_result):
-        if self.mapping.has_pp():
-            return [allgather_result[rank] for rank in self.mapping.tp_group]
-        else:
-            return allgather_result
-
     def _init_userbuffers(self, hidden_size, quant_config, dtype):
         # No quant, do not allow UB
         if self.mapping.tp_size <= 1:
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index cb9ccdecc1..c64af5fd11 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -185,7 +185,7 @@ class PyExecutor:
         self.active_requests = []
         self.all_ranks_num_active_requests = [
             0
-        ] * self.dist.world_size if self.enable_attention_dp else []
+        ] * self.dist.tp_size if self.enable_attention_dp else []
         self.expected_num_active_requests = 0
         self.has_context_request = False
         self.ctx_in_transmission_requests = []
@@ -490,8 +490,17 @@
                     self._merge_dummy_request(num_dummy_request)
                 scheduled_batch, _, _ = self._schedule()
 
-                if scheduled_batch.batch_size == 0:
-                    assert len(self.inflight_req_ids) > 0, (
+                if self.enable_attention_dp:
+                    tp_batch_sizes = self.dist.tp_allgather(
+                        scheduled_batch.batch_size)
+                    can_queue = 0 not in tp_batch_sizes
+                else:
+                    can_queue = scheduled_batch.batch_size > 0
+
+                if not can_queue:
+                    assert self.enable_attention_dp or len(
+                        self.inflight_req_ids
+                    ) > 0, (
                         "fail to schedule any pending request, probably run out of resource"
                     )
                     self.micro_batches[microbatch_id] = None
@@ -517,6 +526,7 @@
                 prev_microbatch_id = (microbatch_id + 1) % self.num_micro_batches
                 previous_batch = self.micro_batches[prev_microbatch_id]
 
+                # Stage 2: Handle previous batch that only processed forward_step
                 if previous_batch is not None:
                     previous_scheduled_batch, previous_new_tensors_host = previous_batch
 
@@ -530,6 +540,8 @@
                         previous_scheduled_batch)  # unlock inflight requests
                     microbatch_id = prev_microbatch_id
 
+                self._gather_dp_requests_num()
+
                 if self.enable_iter_perf_stats:
                     iter_end_time = time.time()
                     iter_latency_ms = iter_end_time - iter_start_time
@@ -575,8 +587,18 @@
                     self._merge_dummy_request(num_dummy_request)
                 scheduled_batch, _, _ = self._schedule()
-                if scheduled_batch.batch_size == 0:
-                    assert len(self.inflight_req_ids) > 0, (
+
+                if self.enable_attention_dp:
+                    tp_batch_sizes = self.dist.tp_allgather(
+                        scheduled_batch.batch_size)
+                    can_queue = 0 not in tp_batch_sizes
+                else:
+                    can_queue = scheduled_batch.batch_size > 0
+
+                if not can_queue:
+                    assert self.enable_attention_dp or len(
+                        self.inflight_req_ids
+                    ) > 0, (
                         "fail to schedule any pending request, probably run out of resource"
                     )
                     self.micro_batches[microbatch_id] = None
 
@@ -657,6 +679,8 @@
                 # march forward in microbatch slots
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches
 
+                self._gather_dp_requests_num()
+
                 if self.enable_iter_perf_stats:
                     iter_end_time = time.time()
                     iter_latency_ms = iter_end_time - iter_start_time
@@ -1035,7 +1059,7 @@
     @nvtx_range("_fetch_adp_new_requests")
     def _fetch_adp_new_requests(self):
         total_num_active_requests = sum(self.all_ranks_num_active_requests)
-        total_max_num_active_requests = self.dist.world_size * self.max_num_active_requests
+        total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
         timeout = None if total_num_active_requests == 0 else datetime.timedelta(
             0)
         new_requests = []
@@ -1048,14 +1072,14 @@
         num_new_requests_all_ranks = len(new_requests)
         self.expected_num_active_requests = max(
             (total_num_active_requests + num_new_requests_all_ranks +
-             self.dist.world_size - 1) // self.dist.world_size,
+             self.dist.tp_size - 1) // self.dist.tp_size,
             max(self.all_ranks_num_active_requests),
         )
         self.has_context_request = False
         new_requests_cur_rank = []
         if new_requests != [] and new_requests[
                 0] != None and self.expected_num_active_requests > self.all_ranks_num_active_requests[
-                    self.dist.rank]:
+                    self.dist.tp_rank]:
             # Balance context tokens across ranks
             HeapVal = namedtuple(
                 'HeapVal',
@@ -1071,7 +1095,7 @@
                 for idx, val in enumerate(self.all_ranks_num_active_requests)
             ]
             new_requests_cur_rank = all_ranks_new_requests_heap[
-                self.dist.rank].request_list
+                self.dist.tp_rank].request_list
             all_ranks_new_requests_heap = [
                 val for val in all_ranks_new_requests_heap
                 if val.num_requests > 0
@@ -1111,8 +1135,8 @@
     def _gather_dp_requests_num(self):
         if self.enable_attention_dp:
             gather_active_requests = []
-            resonses_list = self.dist.allgather(len(self.active_requests))
-            for num_active_requests in resonses_list:
+            responses_list = self.dist.tp_allgather(len(self.active_requests))
+            for num_active_requests in responses_list:
                 gather_active_requests.append(num_active_requests)
             self.all_ranks_num_active_requests = gather_active_requests
 
diff --git a/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py b/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py
index ed593b2b0a..6344c082c0 100644
--- a/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py
+++ b/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py
@@ -78,9 +78,9 @@ def test_deepseek(model_name, backend, quant, tp_size, pp_size, ep_size,
     if mtp_nextn > 0 and getSMVersion() < 100:
         pytest.skip(f"Only Blackwell MLA kernel can support MTP now")
 
-    if pp_size > 1 and (enable_dp or mtp_nextn > 0):
+    if pp_size > 1 and mtp_nextn > 0:
         pytest.skip(
-            "Hang issue with DP attention / MTP + PP: https://nvbugspro.nvidia.com/bug/5170160"
+            "PP + MTP is not supported: https://nvbugspro.nvidia.com/bug/5170160"
         )
     if pp_size > 2 and enable_cuda_graph and enable_overlap_scheduler:
         pytest.skip(
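
Note (illustration, not part of the patch): with attention DP enabled, the executor loops in py_executor.py now only queue a microbatch when every TP rank has scheduled at least one request, so all attention-DP ranks step through the forward pass together. The sketch below mirrors that gating decision under stated assumptions; simulated_tp_allgather and the sample batch sizes are hypothetical stand-ins for self.dist.tp_allgather() and the per-rank scheduler output.

    # Sketch only: mirrors the can_queue check added above; not the real executor code.
    from typing import List

    def simulated_tp_allgather(per_rank_batch_sizes: List[int]) -> List[int]:
        # Stand-in for self.dist.tp_allgather(): each TP rank contributes its
        # local batch size and receives the list from all ranks.
        return list(per_rank_batch_sizes)

    def can_queue_microbatch(local_batch_size: int, enable_attention_dp: bool,
                             per_rank_batch_sizes: List[int]) -> bool:
        if enable_attention_dp:
            tp_batch_sizes = simulated_tp_allgather(per_rank_batch_sizes)
            # One empty rank means the whole TP group skips this microbatch.
            return 0 not in tp_batch_sizes
        return local_batch_size > 0

    print(can_queue_microbatch(4, True, [4, 0, 3, 2]))  # False: one rank scheduled nothing
    print(can_queue_microbatch(4, True, [4, 1, 3, 2]))  # True: every rank has work
    print(can_queue_microbatch(0, False, []))           # False: nothing scheduled locally

When the microbatch is skipped, the relaxed assert tolerates an empty local schedule on attention-DP ranks, and the added _gather_dp_requests_num() calls refresh all_ranks_num_active_requests after each previous batch completes so _fetch_adp_new_requests() can rebalance new requests across the tp_size ranks.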