Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
Chore: clean up _gather_dp_requests_num method of PyExecutor (#4571)
clean up _gather_dp_requests_num method of PyExecutor
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
This commit is contained in:
parent 3549b68c1c
commit 1e55d616da
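The cleanup drops the cached self.all_ranks_num_active_requests list and the _gather_dp_requests_num() refreshes scattered through the executor loops; with attention DP enabled, _fetch_new_requests now allgathers the per-rank counts where they are used. A schematic before/after sketch of just that state handling (the class and method names here are reduced illustrations, not the executor's real code):

# Schematic before/after of the state handling this commit changes.
# Only the pieces relevant to the cleanup are shown.

class PyExecutorBefore:
    def __init__(self, dist, enable_attention_dp):
        self.dist = dist
        self.enable_attention_dp = enable_attention_dp
        self.active_requests = []
        # Cached per-rank counts, refreshed by _gather_dp_requests_num()
        # at several points in the executor loops.
        self.all_ranks_num_active_requests = (
            [0] * dist.tp_size if enable_attention_dp else [])

    def _gather_dp_requests_num(self):
        if self.enable_attention_dp:
            self.all_ranks_num_active_requests = list(
                self.dist.tp_allgather(len(self.active_requests)))


class PyExecutorAfter:
    def __init__(self, dist, enable_attention_dp):
        self.dist = dist
        self.enable_attention_dp = enable_attention_dp
        self.active_requests = []
        # No cached copy: _fetch_new_requests gathers the counts on demand.

    def _fetch_new_requests_counts(self):
        if self.enable_attention_dp:
            all_ranks = list(self.dist.tp_allgather(len(self.active_requests)))
            return all_ranks, sum(all_ranks)
        return None, len(self.active_requests)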
@@ -225,9 +225,6 @@ class PyExecutor:
         # _executor_loop private data
         self.max_num_active_requests = model_engine.get_max_num_sequences()
         self.active_requests: List[LlmRequest] = []
-        self.all_ranks_num_active_requests = [
-            0
-        ] * self.dist.tp_size if self.enable_attention_dp else []
         self.expected_num_active_requests = 0
         self.has_context_request = False
         self.ctx_in_transmission_requests = []
@@ -794,7 +791,6 @@ class PyExecutor:
 
             # march forward in microbatch slots
             microbatch_id = (microbatch_id + 1) % self.num_micro_batches
-            self._gather_dp_requests_num()
 
             if self.enable_iter_perf_stats and previous_batch is not None:
                 self._process_iter_stats(finished_requests,
@@ -917,8 +913,6 @@ class PyExecutor:
                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                     self._terminate_ctx_finished_requests()
 
-                self._gather_dp_requests_num()
-
                 if self.enable_iter_perf_stats:
                     iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
                         'num_ctx_tokens']
@@ -1081,7 +1075,6 @@ class PyExecutor:
                     iter_start_time=iter_start_time,
                     iter_stats=iter_stats,
                     ctx_transmission_reqs=ctx_transmission_reqs)
-                self._gather_dp_requests_num()
 
                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                     self._terminate_ctx_finished_requests()
@@ -1187,7 +1180,11 @@ class PyExecutor:
     @nvtx_range("_fetch_new_requests")
     def _fetch_new_requests(self):
         if self.enable_attention_dp:
-            total_num_active_requests = sum(self.all_ranks_num_active_requests)
+            all_ranks_num_active_requests = []
+            responses_list = self.dist.tp_allgather(len(self.active_requests))
+            for num_active_requests in responses_list:
+                all_ranks_num_active_requests.append(num_active_requests)
+            total_num_active_requests = sum(all_ranks_num_active_requests)
             total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
         else:
             total_num_active_requests = len(self.active_requests)
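The hunk above inlines the allgather that _gather_dp_requests_num used to perform. A runnable usage sketch of that gather-on-demand pattern, with a single-process stub standing in for self.dist (the stub class and its preset counts are assumptions for illustration):

# Single-process stand-in for the tensor-parallel communicator, so the
# gather-on-demand pattern from the hunk above can be run in isolation.
class FakeDist:
    def __init__(self, per_rank_active):
        self.per_rank_active = per_rank_active
        self.tp_size = len(per_rank_active)

    def tp_allgather(self, _local_value):
        # A real allgather exchanges each rank's local value; here we just
        # return the preset list so the example runs on one process.
        return list(self.per_rank_active)


def gather_active_request_counts(dist, num_local_active):
    # Mirrors the new inline logic in _fetch_new_requests: build the
    # per-rank list from the allgather result and sum it for the total.
    all_ranks_num_active_requests = []
    for num_active_requests in dist.tp_allgather(num_local_active):
        all_ranks_num_active_requests.append(num_active_requests)
    return all_ranks_num_active_requests, sum(all_ranks_num_active_requests)


if __name__ == "__main__":
    dist = FakeDist(per_rank_active=[3, 5, 2, 4])
    per_rank, total = gather_active_request_counts(dist, num_local_active=3)
    print(per_rank, total)  # [3, 5, 2, 4] 14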
@@ -1229,13 +1226,13 @@ class PyExecutor:
             self.expected_num_active_requests = max(
                 (total_num_active_requests + num_new_requests_all_ranks +
                  self.dist.tp_size - 1) // self.dist.tp_size,
-                max(self.all_ranks_num_active_requests),
+                max(all_ranks_num_active_requests),
             )
 
             self.has_context_request = False
             new_requests_cur_rank = []
             if new_requests != [] and new_requests[
-                    0] != None and self.expected_num_active_requests > self.all_ranks_num_active_requests[
+                    0] != None and self.expected_num_active_requests > all_ranks_num_active_requests[
                         self.dist.tp_rank]:
                 # Balance context tokens across ranks
                 HeapVal = namedtuple(
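The target computed in this hunk spreads existing plus newly fetched requests evenly across tensor-parallel ranks via ceiling division, and never lets the target fall below the busiest rank's current load. A small worked sketch with illustrative counts (tp_size and the per-rank list are assumptions):

# Worked sketch of the expected_num_active_requests target above.
# tp_size and the counts are illustrative assumptions.

all_ranks_num_active_requests = [3, 5, 2, 4]   # one entry per TP rank
num_new_requests_all_ranks = 6                 # newly fetched requests
tp_size = len(all_ranks_num_active_requests)

total_num_active_requests = sum(all_ranks_num_active_requests)  # 14

# Ceiling division spreads (existing + new) requests evenly over the ranks,
# but the target never drops below the busiest rank's current load.
expected_num_active_requests = max(
    (total_num_active_requests + num_new_requests_all_ranks + tp_size - 1)
    // tp_size,
    max(all_ranks_num_active_requests),
)

print(expected_num_active_requests)  # (14 + 6 + 4 - 1) // 4 = 5; max(5, 5) = 5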
@@ -1249,8 +1246,7 @@ class PyExecutor:
             )
             all_ranks_new_requests_heap = [
                 HeapVal(0, self.expected_num_active_requests - val, tp_rank, [])
-                for tp_rank, val in enumerate(
-                    self.all_ranks_num_active_requests)
+                for tp_rank, val in enumerate(all_ranks_num_active_requests)
             ]
             new_requests_cur_rank = all_ranks_new_requests_heap[
                 self.dist.tp_rank].request_list
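The all_ranks_new_requests_heap above seeds one entry per rank so that new context requests can be assigned to the least-loaded rank first. A minimal heapq sketch in the same spirit (the token counts, request names, and the exact tie-breaking are illustrative, not the executor's real scheduling):

# Minimal sketch of heap-based request balancing across TP ranks, in the
# spirit of the all_ranks_new_requests_heap above.
import heapq
from collections import namedtuple

HeapVal = namedtuple("HeapVal",
                     ["num_tokens", "num_free_slots", "rank", "request_list"])

expected_num_active_requests = 5
all_ranks_num_active_requests = [3, 5, 2, 4]      # current load per TP rank

# One heap entry per rank that still has headroom; tuple ordering pops the
# rank with the fewest assigned context tokens first.
heap = [
    HeapVal(0, expected_num_active_requests - val, rank, [])
    for rank, val in enumerate(all_ranks_num_active_requests)
    if expected_num_active_requests - val > 0
]
heapq.heapify(heap)

new_requests = [("req0", 128), ("req1", 64), ("req2", 256), ("req3", 32)]
full = []
for req_id, ctx_tokens in new_requests:
    val = heapq.heappop(heap)                     # least-loaded rank so far
    val.request_list.append(req_id)
    val = val._replace(num_tokens=val.num_tokens + ctx_tokens,
                       num_free_slots=val.num_free_slots - 1)
    if val.num_free_slots > 0:
        heapq.heappush(heap, val)                 # rank still has headroom
    else:
        full.append(val)                          # rank reached its target

for val in sorted(full + heap, key=lambda v: v.rank):
    print(val.rank, val.num_tokens, val.request_list)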
@@ -1294,15 +1290,6 @@ class PyExecutor:
             new_requests_cur_rank = new_requests
         return new_requests_cur_rank
 
-    @nvtx_range("_gather_dp_requests_num")
-    def _gather_dp_requests_num(self):
-        if self.enable_attention_dp:
-            gather_active_requests = []
-            responses_list = self.dist.tp_allgather(len(self.active_requests))
-            for num_active_requests in responses_list:
-                gather_active_requests.append(num_active_requests)
-            self.all_ranks_num_active_requests = gather_active_requests
-
     def _add_kv_cache_events(self):
         kv_cache_manager = self.resource_manager.resource_managers.get(
             "kv_cache_manager")