PyTorch PP + attention DP support (#3044)

Signed-off-by: Aurelien Chartier <achartier@nvidia.com>
Aurelien Chartier 2025-03-27 09:11:19 -07:00 committed by GitHub
parent ec03159e60
commit 3de82c41cd
4 changed files with 62 additions and 58 deletions

View File

@@ -834,11 +834,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)
assert not (
model_config.mapping.has_pp()
and model_config.mapping.enable_attention_dp
), "Pipeline parallelism and attention DP cannot be used together"
self.model_nextn = 0
if model_config.spec_config is not None:
model_nextn = model_config.spec_config.num_nextn_predict_layers
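
The assertion removed above was the guard against combining pipeline parallelism with attention DP for DeepSeek-V3. A hedged sketch of the condition it used to reject follows; the Mapping import path matches the library layout, but passing enable_attention_dp as a constructor keyword is an assumption (the model code only reads it as an attribute of model_config.mapping).

```python
# Hedged sketch of the combination the deleted assertion used to reject.
# Sizes are illustrative only; enable_attention_dp as a keyword is assumed.
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=8, tp_size=4, pp_size=2, enable_attention_dp=True)
# Both of these can now be true at the same time for DeepSeek-V3:
print(mapping.has_pp(), mapping.enable_attention_dp)
```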

View File

@@ -581,9 +581,8 @@ class PyTorchModelEngine(ModelEngine):
batch_size = scheduled_requests.batch_size
new_batch_size = batch_size
if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
graph_batch_size = self.dist.allgather(
graph_batch_size = self.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
graph_batch_size = self._post_allgather_pp(graph_batch_size)
all_can_graph = all(graph_batch[0]
for graph_batch in graph_batch_size)
if all_can_graph:
@@ -666,9 +665,8 @@ class PyTorchModelEngine(ModelEngine):
can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = len(batch.generation_requests)
if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
all_can_graph_batch = self.dist.allgather(
all_can_graph_batch = self.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
all_can_graph_batch = self._post_allgather_pp(all_can_graph_batch)
is_all_gen_only = all(all_can_graph[0]
for all_can_graph in all_can_graph_batch)
all_batch_size_equal = all(
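
The allgather → tp_allgather change that recurs through this file replaces a world-wide gather plus _post_allgather_pp filtering with a gather scoped to the current PP stage's TP group. Below is a minimal sketch of such a tp_allgather, assuming torch.distributed object collectives and a mapping.tp_group rank list; it is not the actual TensorRT-LLM implementation.

```python
import torch.distributed as dist


class TPDist:
    """Minimal sketch, not the TensorRT-LLM dist wrapper."""

    def __init__(self, mapping):
        self.mapping = mapping
        # torch.distributed.new_group must be entered by every rank in the job,
        # even ranks that are not members; mapping.tp_group is assumed to list
        # the global ranks of this PP stage's TP group.
        self._tp_pg = dist.new_group(ranks=mapping.tp_group)

    def tp_allgather(self, obj):
        # One entry per TP rank, ordered by position in tp_group, so no
        # per-PP-stage filtering (the old _post_allgather_pp) is needed.
        out = [None] * self.mapping.tp_size
        dist.all_gather_object(out, obj, group=self._tp_pg)
        return out
```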
@@ -1097,14 +1095,6 @@ class PyTorchModelEngine(ModelEngine):
'mrope_config': mrope_config
}
if self.mapping.has_pp():
pipeline_interface = None
if self.mapping.pp_rank > 0:
pipeline_interface = self.model.create_pipeline_interface(
inputs['input_ids'].shape[0])
pipeline_interface.recv()
inputs['pipeline_interface'] = pipeline_interface
if spec_metadata is not None:
total_draft_lens = sum(draft_lens)
if len(draft_tokens) > 0:
@@ -1124,12 +1114,10 @@ class PyTorchModelEngine(ModelEngine):
# support attention dp
if self.enable_attention_dp:
if spec_metadata is not None:
all_rank_num_tokens = self.dist.allgather([
all_rank_num_tokens = self.dist.tp_allgather([
attn_metadata.num_tokens, spec_metadata.num_tokens,
len(sequence_lengths)
])
all_rank_num_tokens = self._post_allgather_pp(
all_rank_num_tokens)
attn_all_rank_num_tokens = [
item[0] for item in all_rank_num_tokens
]
@@ -1141,12 +1129,18 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
spec_metadata.all_rank_num_seqs = all_rank_num_seqs
else:
all_rank_num_tokens = self.dist.allgather(
all_rank_num_tokens = self.dist.tp_allgather(
attn_metadata.num_tokens)
all_rank_num_tokens = self._post_allgather_pp(
all_rank_num_tokens)
attn_metadata.all_rank_num_tokens = all_rank_num_tokens
if self.mapping.has_pp():
pipeline_interface = None
if self.mapping.pp_rank > 0:
pipeline_interface = self.model.create_pipeline_interface(
inputs['input_ids'].shape[0])
pipeline_interface.recv()
inputs['pipeline_interface'] = pipeline_interface
num_generation_tokens = len(generation_requests) + len(
extend_requests) + sum(draft_lens)
self.iter_states['num_ctx_requests'] = num_ctx_requests
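
The if self.mapping.has_pp(): block is moved below the attention-DP gathers, presumably so that every rank issues the TP collectives before blocking on the PP recv. A hypothetical condensation of the reordered flow (function name and signature are made up):

```python
def attach_parallel_inputs(inputs, attn_metadata, mapping, dist, model,
                           enable_attention_dp):
    """Hypothetical condensation of the reordered flow, not the real method."""
    if enable_attention_dp:
        # TP collective first, so every PP stage issues it in the same order.
        attn_metadata.all_rank_num_tokens = dist.tp_allgather(
            attn_metadata.num_tokens)
    if mapping.has_pp():
        pipeline_interface = None
        if mapping.pp_rank > 0:
            # Blocking receive from the previous PP stage.
            pipeline_interface = model.create_pipeline_interface(
                inputs['input_ids'].shape[0])
            pipeline_interface.recv()
        inputs['pipeline_interface'] = pipeline_interface
    return inputs
```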
@@ -1220,14 +1214,6 @@ class PyTorchModelEngine(ModelEngine):
'multi_modal_data': multi_modal_data
}
if self.mapping.has_pp():
pipeline_interface = None
if self.mapping.pp_rank > 0:
pipeline_interface = self.model.create_pipeline_interface(
inputs['input_ids'].shape[0])
pipeline_interface.recv()
inputs['pipeline_interface'] = pipeline_interface
if spec_metadata is not None:
total_draft_lens = sum(draft_lens)
spec_metadata.draft_tokens = self.draft_tokens_cuda[:
@@ -1243,12 +1229,10 @@ class PyTorchModelEngine(ModelEngine):
# support attention dp
if self.enable_attention_dp:
if spec_metadata is not None:
all_rank_num_tokens = self.dist.allgather([
all_rank_num_tokens = self.dist.tp_allgather([
attn_metadata.num_tokens, spec_metadata.num_tokens,
len(sequence_lengths)
])
all_rank_num_tokens = self._post_allgather_pp(
all_rank_num_tokens)
attn_all_rank_num_tokens = [
item[0] for item in all_rank_num_tokens
]
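
When speculative decoding is active, each TP rank contributes an (attn_tokens, spec_tokens, num_seqs) triple to the gather and the result is split back into per-field lists, as in both hunks above. A standalone illustration with made-up values:

```python
# The triples stand in for what tp_allgather would return, one entry per
# TP rank (values are made up).
gathered = [(128, 8, 4), (96, 8, 3), (160, 8, 5), (112, 8, 4)]

attn_all_rank_num_tokens = [item[0] for item in gathered]  # [128, 96, 160, 112]
spec_all_rank_num_tokens = [item[1] for item in gathered]  # [8, 8, 8, 8]
all_rank_num_seqs = [item[2] for item in gathered]         # [4, 3, 5, 4]
```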
@@ -1260,11 +1244,18 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
spec_metadata.all_rank_num_seqs = all_rank_num_seqs
else:
all_rank_num_tokens = self.dist.allgather(
all_rank_num_tokens = self.dist.tp_allgather(
attn_metadata.num_tokens)
all_rank_num_tokens = self._post_allgather_pp(
all_rank_num_tokens)
attn_metadata.all_rank_num_tokens = all_rank_num_tokens
if self.mapping.has_pp():
pipeline_interface = None
if self.mapping.pp_rank > 0:
pipeline_interface = self.model.create_pipeline_interface(
inputs['input_ids'].shape[0])
pipeline_interface.recv()
inputs['pipeline_interface'] = pipeline_interface
return inputs, None
def _prepare_star_attention_inputs(self,
@@ -1477,8 +1468,8 @@ class PyTorchModelEngine(ModelEngine):
attn_metadata.prepare()
if self.enable_attention_dp:
all_rank_num_tokens = self.dist.allgather(attn_metadata.num_tokens)
all_rank_num_tokens = self._post_allgather_pp(all_rank_num_tokens)
all_rank_num_tokens = self.dist.tp_allgather(
attn_metadata.num_tokens)
attn_metadata.all_rank_num_tokens = all_rank_num_tokens
return {
@@ -1657,12 +1648,6 @@ class PyTorchModelEngine(ModelEngine):
else:
return {'hidden_states': hidden_states}
def _post_allgather_pp(self, allgather_result):
if self.mapping.has_pp():
return [allgather_result[rank] for rank in self.mapping.tp_group]
else:
return allgather_result
def _init_userbuffers(self, hidden_size, quant_config, dtype):
# No quant, do not allow UB
if self.mapping.tp_size <= 1:

View File

@@ -185,7 +185,7 @@ class PyExecutor:
self.active_requests = []
self.all_ranks_num_active_requests = [
0
] * self.dist.world_size if self.enable_attention_dp else []
] * self.dist.tp_size if self.enable_attention_dp else []
self.expected_num_active_requests = 0
self.has_context_request = False
self.ctx_in_transmission_requests = []
@@ -490,8 +490,17 @@ class PyExecutor:
self._merge_dummy_request(num_dummy_request)
scheduled_batch, _, _ = self._schedule()
if scheduled_batch.batch_size == 0:
assert len(self.inflight_req_ids) > 0, (
if self.enable_attention_dp:
tp_batch_sizes = self.dist.tp_allgather(
scheduled_batch.batch_size)
can_queue = 0 not in tp_batch_sizes
else:
can_queue = scheduled_batch.batch_size > 0
if not can_queue:
assert self.enable_attention_dp or len(
self.inflight_req_ids
) > 0, (
"fail to schedule any pending request, probably run out of resource"
)
self.micro_batches[microbatch_id] = None
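
With attention DP, every TP rank schedules its own requests, so the ranks must agree on whether to launch the microbatch: a rank that scheduled work still skips the iteration when any peer scheduled nothing, keeping all ranks on the same sequence of collectives. The logic inlined here, and again in the second executor loop below, reads like the following helper (the name _can_queue is made up):

```python
def _can_queue(self, scheduled_batch):
    """Hypothetical helper mirroring the inlined logic in both executor loops."""
    if self.enable_attention_dp:
        # Every TP rank reports its local batch size; launch only if all of
        # them scheduled at least one request, so collectives stay matched.
        tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
        return 0 not in tp_batch_sizes
    return scheduled_batch.batch_size > 0
```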
@@ -517,6 +526,7 @@ class PyExecutor:
prev_microbatch_id = (microbatch_id +
1) % self.num_micro_batches
previous_batch = self.micro_batches[prev_microbatch_id]
# Stage 2: Handle previous batch that only processed forward_step
if previous_batch is not None:
previous_scheduled_batch, previous_new_tensors_host = previous_batch
@@ -530,6 +540,8 @@ class PyExecutor:
previous_scheduled_batch) # unlock inflight requests
microbatch_id = prev_microbatch_id
self._gather_dp_requests_num()
if self.enable_iter_perf_stats:
iter_end_time = time.time()
iter_latency_ms = iter_end_time - iter_start_time
@@ -575,8 +587,18 @@ class PyExecutor:
self._merge_dummy_request(num_dummy_request)
scheduled_batch, _, _ = self._schedule()
if scheduled_batch.batch_size == 0:
assert len(self.inflight_req_ids) > 0, (
if self.enable_attention_dp:
tp_batch_sizes = self.dist.tp_allgather(
scheduled_batch.batch_size)
can_queue = 0 not in tp_batch_sizes
else:
can_queue = scheduled_batch.batch_size > 0
if not can_queue:
assert self.enable_attention_dp or len(
self.inflight_req_ids
) > 0, (
"fail to schedule any pending request, probably run out of resource"
)
self.micro_batches[microbatch_id] = None
@@ -657,6 +679,8 @@ class PyExecutor:
# march forward in microbatch slots
microbatch_id = (microbatch_id + 1) % self.num_micro_batches
self._gather_dp_requests_num()
if self.enable_iter_perf_stats:
iter_end_time = time.time()
iter_latency_ms = iter_end_time - iter_start_time
@@ -1035,7 +1059,7 @@ class PyExecutor:
@nvtx_range("_fetch_adp_new_requests")
def _fetch_adp_new_requests(self):
total_num_active_requests = sum(self.all_ranks_num_active_requests)
total_max_num_active_requests = self.dist.world_size * self.max_num_active_requests
total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
timeout = None if total_num_active_requests == 0 else datetime.timedelta(
0)
new_requests = []
@@ -1048,14 +1072,14 @@ class PyExecutor:
num_new_requests_all_ranks = len(new_requests)
self.expected_num_active_requests = max(
(total_num_active_requests + num_new_requests_all_ranks +
self.dist.world_size - 1) // self.dist.world_size,
self.dist.tp_size - 1) // self.dist.tp_size,
max(self.all_ranks_num_active_requests),
)
self.has_context_request = False
new_requests_cur_rank = []
if new_requests != [] and new_requests[
0] != None and self.expected_num_active_requests > self.all_ranks_num_active_requests[
self.dist.rank]:
self.dist.tp_rank]:
# Balance context tokens across ranks
HeapVal = namedtuple(
'HeapVal',
@@ -1071,7 +1095,7 @@ class PyExecutor:
for idx, val in enumerate(self.all_ranks_num_active_requests)
]
new_requests_cur_rank = all_ranks_new_requests_heap[
self.dist.rank].request_list
self.dist.tp_rank].request_list
all_ranks_new_requests_heap = [
val for val in all_ranks_new_requests_heap
if val.num_requests > 0
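
The heap-based balancing above assigns each incoming request to the currently least-loaded TP rank. Below is a simplified, runnable sketch of the idea keyed by request count only; the real HeapVal fields and the context-token weighting are not fully visible in this hunk, so the names are assumptions.

```python
import heapq
from collections import namedtuple

# Assumed field layout; the real namedtuple in the diff is truncated here.
HeapVal = namedtuple('HeapVal', ['num_requests', 'rank', 'request_list'])


def balance_new_requests(new_requests, all_ranks_num_active_requests):
    """Greedy balancing: always push the next request onto the least-loaded rank."""
    heap = [
        HeapVal(num_active, rank, [])
        for rank, num_active in enumerate(all_ranks_num_active_requests)
    ]
    heapq.heapify(heap)
    for request in new_requests:
        val = heapq.heappop(heap)
        val.request_list.append(request)
        heapq.heappush(heap, val._replace(num_requests=val.num_requests + 1))
    return {val.rank: val.request_list for val in heap}


# Example: four TP ranks with uneven existing load.
print(balance_new_requests(['r0', 'r1', 'r2'], [2, 0, 1, 0]))
```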
@@ -1111,8 +1135,8 @@ class PyExecutor:
def _gather_dp_requests_num(self):
if self.enable_attention_dp:
gather_active_requests = []
resonses_list = self.dist.allgather(len(self.active_requests))
for num_active_requests in resonses_list:
responses_list = self.dist.tp_allgather(len(self.active_requests))
for num_active_requests in responses_list:
gather_active_requests.append(num_active_requests)
self.all_ranks_num_active_requests = gather_active_requests
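
The bookkeeping switches from world_size/rank to tp_size/tp_rank because with pipeline parallelism the world spans tp_size * pp_size ranks, while attention DP only replicates the request pool across the TP group of each PP stage. A small sketch of the rank arithmetic, assuming a TP-major layout (the actual layout used by the Mapping class is not shown in this diff):

```python
# Assumption: ranks are laid out TP-major, i.e. consecutive ranks share a
# PP stage.  With tp_size=4 and pp_size=2 the 8 ranks split into two TP groups.
tp_size, pp_size = 4, 2
world_size = tp_size * pp_size

for rank in range(world_size):
    tp_rank = rank % tp_size
    pp_rank = rank // tp_size
    print(f"rank {rank}: tp_rank={tp_rank}, pp_rank={pp_rank}")

# Per-TP-group bookkeeping is therefore sized tp_size (4), not world_size (8):
all_ranks_num_active_requests = [0] * tp_size
```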

View File

@@ -78,9 +78,9 @@ def test_deepseek(model_name, backend, quant, tp_size, pp_size, ep_size,
if mtp_nextn > 0 and getSMVersion() < 100:
pytest.skip(f"Only Blackwell MLA kernel can support MTP now")
if pp_size > 1 and (enable_dp or mtp_nextn > 0):
if pp_size > 1 and mtp_nextn > 0:
pytest.skip(
"Hang issue with DP attention / MTP + PP: https://nvbugspro.nvidia.com/bug/5170160"
"PP + MTP is not supported: https://nvbugspro.nvidia.com/bug/5170160"
)
if pp_size > 2 and enable_cuda_graph and enable_overlap_scheduler:
pytest.skip(