Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
Pytorch PP + attention DP support (#3044)
Signed-off-by: Aurelien Chartier <achartier@nvidia.com>
Parent: ec03159e60, commit: 3de82c41cd
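
The heart of this change: collectives used by attention data parallelism (DP) now run over the tensor-parallel (TP) group instead of the whole world, which is what lets them coexist with pipeline parallelism (PP). Below is a minimal sketch of what a TP-scoped allgather can look like; it assumes one torch.distributed process group per TP group and is illustrative only, not TensorRT-LLM's actual Distributed class.

    # Hypothetical sketch only: a TP-scoped allgather built on torch.distributed,
    # assuming a process group that contains exactly the ranks of one TP group.
    import torch.distributed as dist

    def tp_allgather(obj, tp_process_group):
        """Gather a picklable object from every rank in the TP group only."""
        out = [None] * dist.get_world_size(group=tp_process_group)
        dist.all_gather_object(out, obj, group=tp_process_group)
        return out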
@@ -834,11 +834,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
             hidden_size=model_config.pretrained_config.hidden_size,
             vocab_size=model_config.pretrained_config.vocab_size)
 
-        assert not (
-            model_config.mapping.has_pp()
-            and model_config.mapping.enable_attention_dp
-        ), "Pipeline parallelism and attention DP cannot be used together"
-
         self.model_nextn = 0
         if model_config.spec_config is not None:
             model_nextn = model_config.spec_config.num_nextn_predict_layers
@@ -581,9 +581,8 @@ class PyTorchModelEngine(ModelEngine):
         batch_size = scheduled_requests.batch_size
         new_batch_size = batch_size
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
-            graph_batch_size = self.dist.allgather(
+            graph_batch_size = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
-            graph_batch_size = self._post_allgather_pp(graph_batch_size)
             all_can_graph = all(graph_batch[0]
                                 for graph_batch in graph_batch_size)
             if all_can_graph:
@@ -666,9 +665,8 @@ class PyTorchModelEngine(ModelEngine):
         can_run_cuda_graph = batch.can_run_cuda_graph
         batch_size = len(batch.generation_requests)
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
-            all_can_graph_batch = self.dist.allgather(
+            all_can_graph_batch = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
-            all_can_graph_batch = self._post_allgather_pp(all_can_graph_batch)
             is_all_gen_only = all(all_can_graph[0]
                                   for all_can_graph in all_can_graph_batch)
             all_batch_size_equal = all(
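
With attention DP, each TP rank schedules its own batch, so the decision to run CUDA graphs has to be agreed across the TP group. The two hunks above exchange (can_run_cuda_graph, batch_size) via tp_allgather; the sketch below condenses the decision they feed (a hypothetical standalone helper, not the PyTorchModelEngine methods themselves).

    from typing import List, Tuple

    def cuda_graph_decision(per_rank: List[Tuple[bool, int]]):
        """per_rank holds (can_run_cuda_graph, batch_size) from tp_allgather."""
        all_can_graph = all(flag for flag, _ in per_rank)
        batch_sizes = [bs for _, bs in per_rank]
        all_batch_size_equal = all(bs == batch_sizes[0] for bs in batch_sizes)
        return all_can_graph, all_batch_size_equal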
@@ -1097,14 +1095,6 @@ class PyTorchModelEngine(ModelEngine):
             'mrope_config': mrope_config
         }
 
-        if self.mapping.has_pp():
-            pipeline_interface = None
-            if self.mapping.pp_rank > 0:
-                pipeline_interface = self.model.create_pipeline_interface(
-                    inputs['input_ids'].shape[0])
-                pipeline_interface.recv()
-            inputs['pipeline_interface'] = pipeline_interface
-
         if spec_metadata is not None:
             total_draft_lens = sum(draft_lens)
             if len(draft_tokens) > 0:
@@ -1124,12 +1114,10 @@ class PyTorchModelEngine(ModelEngine):
         # support attention dp
         if self.enable_attention_dp:
             if spec_metadata is not None:
-                all_rank_num_tokens = self.dist.allgather([
+                all_rank_num_tokens = self.dist.tp_allgather([
                     attn_metadata.num_tokens, spec_metadata.num_tokens,
                     len(sequence_lengths)
                 ])
-                all_rank_num_tokens = self._post_allgather_pp(
-                    all_rank_num_tokens)
                 attn_all_rank_num_tokens = [
                     item[0] for item in all_rank_num_tokens
                 ]
@@ -1141,12 +1129,18 @@ class PyTorchModelEngine(ModelEngine):
                 spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
                 spec_metadata.all_rank_num_seqs = all_rank_num_seqs
             else:
-                all_rank_num_tokens = self.dist.allgather(
+                all_rank_num_tokens = self.dist.tp_allgather(
                     attn_metadata.num_tokens)
-                all_rank_num_tokens = self._post_allgather_pp(
-                    all_rank_num_tokens)
                 attn_metadata.all_rank_num_tokens = all_rank_num_tokens
 
+        if self.mapping.has_pp():
+            pipeline_interface = None
+            if self.mapping.pp_rank > 0:
+                pipeline_interface = self.model.create_pipeline_interface(
+                    inputs['input_ids'].shape[0])
+                pipeline_interface.recv()
+            inputs['pipeline_interface'] = pipeline_interface
+
         num_generation_tokens = len(generation_requests) + len(
             extend_requests) + sum(draft_lens)
         self.iter_states['num_ctx_requests'] = num_ctx_requests
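
On PP ranks other than the first, the engine creates a pipeline interface sized by this rank's number of input tokens and receives the previous stage's activations through it; the hunks above move that step to after the attention-DP tp_allgather. A minimal sketch of such an interface is below, assuming point-to-point torch.distributed calls; create_pipeline_interface()/recv() here are stand-ins, not TensorRT-LLM's implementation.

    # Hypothetical stand-in for a pipeline interface: a receive buffer shaped by
    # this rank's number of input tokens, filled from the previous pipeline stage.
    import torch
    import torch.distributed as dist

    class PipelineInterface:
        def __init__(self, num_tokens, hidden_size, prev_pp_rank,
                     dtype=torch.bfloat16, device="cuda"):
            self.hidden_states = torch.empty(num_tokens, hidden_size,
                                             dtype=dtype, device=device)
            self.prev_pp_rank = prev_pp_rank

        def recv(self):
            # Blocking receive of the previous stage's activations.
            dist.recv(self.hidden_states, src=self.prev_pp_rank)
            return self.hidden_states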
@@ -1220,14 +1214,6 @@ class PyTorchModelEngine(ModelEngine):
             'multi_modal_data': multi_modal_data
         }
 
-        if self.mapping.has_pp():
-            pipeline_interface = None
-            if self.mapping.pp_rank > 0:
-                pipeline_interface = self.model.create_pipeline_interface(
-                    inputs['input_ids'].shape[0])
-                pipeline_interface.recv()
-            inputs['pipeline_interface'] = pipeline_interface
-
         if spec_metadata is not None:
             total_draft_lens = sum(draft_lens)
             spec_metadata.draft_tokens = self.draft_tokens_cuda[:
@@ -1243,12 +1229,10 @@ class PyTorchModelEngine(ModelEngine):
         # support attention dp
         if self.enable_attention_dp:
             if spec_metadata is not None:
-                all_rank_num_tokens = self.dist.allgather([
+                all_rank_num_tokens = self.dist.tp_allgather([
                     attn_metadata.num_tokens, spec_metadata.num_tokens,
                     len(sequence_lengths)
                 ])
-                all_rank_num_tokens = self._post_allgather_pp(
-                    all_rank_num_tokens)
                 attn_all_rank_num_tokens = [
                     item[0] for item in all_rank_num_tokens
                 ]
@@ -1260,11 +1244,18 @@ class PyTorchModelEngine(ModelEngine):
                 spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
                 spec_metadata.all_rank_num_seqs = all_rank_num_seqs
             else:
-                all_rank_num_tokens = self.dist.allgather(
+                all_rank_num_tokens = self.dist.tp_allgather(
                     attn_metadata.num_tokens)
-                all_rank_num_tokens = self._post_allgather_pp(
-                    all_rank_num_tokens)
                 attn_metadata.all_rank_num_tokens = all_rank_num_tokens
+
+        if self.mapping.has_pp():
+            pipeline_interface = None
+            if self.mapping.pp_rank > 0:
+                pipeline_interface = self.model.create_pipeline_interface(
+                    inputs['input_ids'].shape[0])
+                pipeline_interface.recv()
+            inputs['pipeline_interface'] = pipeline_interface
+
         return inputs, None
 
     def _prepare_star_attention_inputs(self,
@@ -1477,8 +1468,8 @@ class PyTorchModelEngine(ModelEngine):
 
         attn_metadata.prepare()
         if self.enable_attention_dp:
-            all_rank_num_tokens = self.dist.allgather(attn_metadata.num_tokens)
-            all_rank_num_tokens = self._post_allgather_pp(all_rank_num_tokens)
+            all_rank_num_tokens = self.dist.tp_allgather(
+                attn_metadata.num_tokens)
            attn_metadata.all_rank_num_tokens = all_rank_num_tokens
 
         return {
@@ -1657,12 +1648,6 @@ class PyTorchModelEngine(ModelEngine):
         else:
             return {'hidden_states': hidden_states}
 
-    def _post_allgather_pp(self, allgather_result):
-        if self.mapping.has_pp():
-            return [allgather_result[rank] for rank in self.mapping.tp_group]
-        else:
-            return allgather_result
-
     def _init_userbuffers(self, hidden_size, quant_config, dtype):
         # No quant, do not allow UB
         if self.mapping.tp_size <= 1:
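
For reference, the helper removed above was the old workaround: a world-wide allgather followed by slicing the result down to this rank's TP group. With tp_allgather only the TP group participates in the collective, so the post-filtering disappears. The removed logic, condensed:

    def post_allgather_pp(allgather_result, tp_group, has_pp):
        # Old behaviour: keep only the entries belonging to this rank's TP group.
        return [allgather_result[r] for r in tp_group] if has_pp else allgather_result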
@@ -185,7 +185,7 @@ class PyExecutor:
         self.active_requests = []
         self.all_ranks_num_active_requests = [
             0
-        ] * self.dist.world_size if self.enable_attention_dp else []
+        ] * self.dist.tp_size if self.enable_attention_dp else []
         self.expected_num_active_requests = 0
         self.has_context_request = False
         self.ctx_in_transmission_requests = []
@@ -490,8 +490,17 @@ class PyExecutor:
                 self._merge_dummy_request(num_dummy_request)
             scheduled_batch, _, _ = self._schedule()
 
-            if scheduled_batch.batch_size == 0:
-                assert len(self.inflight_req_ids) > 0, (
+            if self.enable_attention_dp:
+                tp_batch_sizes = self.dist.tp_allgather(
+                    scheduled_batch.batch_size)
+                can_queue = 0 not in tp_batch_sizes
+            else:
+                can_queue = scheduled_batch.batch_size > 0
+
+            if not can_queue:
+                assert self.enable_attention_dp or len(
+                    self.inflight_req_ids
+                ) > 0, (
                     "fail to schedule any pending request, probably run out of resource"
                 )
                 self.micro_batches[microbatch_id] = None
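
Previously an empty schedule was treated as an error unless requests were already in flight. With attention DP a rank may legitimately schedule nothing while its peers still have work, so batch sizes are exchanged over the TP group and the microbatch is queued only when every rank has at least one request. A sketch of the rule (hypothetical free function; batch sizes assumed to come from tp_allgather):

    from typing import List, Optional

    def can_queue_microbatch(local_batch_size: int,
                             tp_batch_sizes: Optional[List[int]] = None) -> bool:
        if tp_batch_sizes is not None:
            # Attention DP: every TP rank must have scheduled something.
            return 0 not in tp_batch_sizes
        # Otherwise only the local schedule matters.
        return local_batch_size > 0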
@@ -517,6 +526,7 @@ class PyExecutor:
             prev_microbatch_id = (microbatch_id +
                                   1) % self.num_micro_batches
             previous_batch = self.micro_batches[prev_microbatch_id]
+
             # Stage 2: Handle previous batch that only processed forward_step
             if previous_batch is not None:
                 previous_scheduled_batch, previous_new_tensors_host = previous_batch
@@ -530,6 +540,8 @@ class PyExecutor:
                     previous_scheduled_batch) # unlock inflight requests
                 microbatch_id = prev_microbatch_id
 
+            self._gather_dp_requests_num()
+
             if self.enable_iter_perf_stats:
                 iter_end_time = time.time()
                 iter_latency_ms = iter_end_time - iter_start_time
@@ -575,8 +587,18 @@ class PyExecutor:
                 self._merge_dummy_request(num_dummy_request)
 
             scheduled_batch, _, _ = self._schedule()
-            if scheduled_batch.batch_size == 0:
-                assert len(self.inflight_req_ids) > 0, (
+
+            if self.enable_attention_dp:
+                tp_batch_sizes = self.dist.tp_allgather(
+                    scheduled_batch.batch_size)
+                can_queue = 0 not in tp_batch_sizes
+            else:
+                can_queue = scheduled_batch.batch_size > 0
+
+            if not can_queue:
+                assert self.enable_attention_dp or len(
+                    self.inflight_req_ids
+                ) > 0, (
                     "fail to schedule any pending request, probably run out of resource"
                 )
                 self.micro_batches[microbatch_id] = None
@@ -657,6 +679,8 @@ class PyExecutor:
             # march forward in microbatch slots
             microbatch_id = (microbatch_id + 1) % self.num_micro_batches
 
+            self._gather_dp_requests_num()
+
             if self.enable_iter_perf_stats:
                 iter_end_time = time.time()
                 iter_latency_ms = iter_end_time - iter_start_time
@@ -1035,7 +1059,7 @@ class PyExecutor:
     @nvtx_range("_fetch_adp_new_requests")
     def _fetch_adp_new_requests(self):
         total_num_active_requests = sum(self.all_ranks_num_active_requests)
-        total_max_num_active_requests = self.dist.world_size * self.max_num_active_requests
+        total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
         timeout = None if total_num_active_requests == 0 else datetime.timedelta(
             0)
         new_requests = []
@@ -1048,14 +1072,14 @@ class PyExecutor:
         num_new_requests_all_ranks = len(new_requests)
         self.expected_num_active_requests = max(
             (total_num_active_requests + num_new_requests_all_ranks +
-             self.dist.world_size - 1) // self.dist.world_size,
+             self.dist.tp_size - 1) // self.dist.tp_size,
             max(self.all_ranks_num_active_requests),
         )
         self.has_context_request = False
         new_requests_cur_rank = []
         if new_requests != [] and new_requests[
                 0] != None and self.expected_num_active_requests > self.all_ranks_num_active_requests[
-                    self.dist.rank]:
+                    self.dist.tp_rank]:
             # Balance context tokens across ranks
             HeapVal = namedtuple(
                 'HeapVal',
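
The capacity math in _fetch_adp_new_requests now spreads new work over the TP group rather than the whole world: the expected number of active requests per rank is a ceiling division over tp_size, clamped from below by the busiest rank. A small worked sketch of the same arithmetic (hypothetical free function):

    def expected_active_requests(all_ranks_active, num_new_requests, tp_size):
        total_active = sum(all_ranks_active)
        ceil_share = (total_active + num_new_requests + tp_size - 1) // tp_size
        return max(ceil_share, max(all_ranks_active))

    # e.g. 3 requests active on ranks [2, 1, 0] plus 4 new ones over tp_size=3
    # -> ceil(7 / 3) = 3 expected active requests per rank.
    assert expected_active_requests([2, 1, 0], 4, 3) == 3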
@@ -1071,7 +1095,7 @@ class PyExecutor:
             for idx, val in enumerate(self.all_ranks_num_active_requests)
         ]
         new_requests_cur_rank = all_ranks_new_requests_heap[
-            self.dist.rank].request_list
+            self.dist.tp_rank].request_list
         all_ranks_new_requests_heap = [
             val for val in all_ranks_new_requests_heap
             if val.num_requests > 0
@@ -1111,8 +1135,8 @@ class PyExecutor:
     def _gather_dp_requests_num(self):
         if self.enable_attention_dp:
             gather_active_requests = []
-            resonses_list = self.dist.allgather(len(self.active_requests))
-            for num_active_requests in resonses_list:
+            responses_list = self.dist.tp_allgather(len(self.active_requests))
+            for num_active_requests in responses_list:
                 gather_active_requests.append(num_active_requests)
             self.all_ranks_num_active_requests = gather_active_requests
 
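
Besides switching to tp_allgather (and fixing the resonses_list typo), the gather above amounts to collecting one active-request count per TP rank; a condensed sketch, assuming tp_allgather returns a per-rank list:

    def gather_dp_requests_num(dist, active_requests):
        # One entry per TP rank: how many requests that rank currently has active.
        return list(dist.tp_allgather(len(active_requests)))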
@@ -78,9 +78,9 @@ def test_deepseek(model_name, backend, quant, tp_size, pp_size, ep_size,
     if mtp_nextn > 0 and getSMVersion() < 100:
         pytest.skip(f"Only Blackwell MLA kernel can support MTP now")
 
-    if pp_size > 1 and (enable_dp or mtp_nextn > 0):
+    if pp_size > 1 and mtp_nextn > 0:
         pytest.skip(
-            "Hang issue with DP attention / MTP + PP: https://nvbugspro.nvidia.com/bug/5170160"
+            "PP + MTP is not supported: https://nvbugspro.nvidia.com/bug/5170160"
         )
     if pp_size > 2 and enable_cuda_graph and enable_overlap_scheduler:
         pytest.skip(