Chore: clean up _gather_dp_requests_num method of PyExecutor (#4571)

Remove the _gather_dp_requests_num helper from PyExecutor: instead of caching per-rank active-request counts in self.all_ranks_num_active_requests and refreshing the cache after every executor-loop step, gather the counts on demand in _fetch_new_requests, the only place they are consumed.

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
commit 1e55d616da
parent 3549b68c1c
Author: QI JUN
Date:   2025-05-23 08:37:39 +08:00 (committed by GitHub)

@@ -225,9 +225,6 @@ class PyExecutor:
        # _executor_loop private data
        self.max_num_active_requests = model_engine.get_max_num_sequences()
        self.active_requests: List[LlmRequest] = []
        self.all_ranks_num_active_requests = [
            0
        ] * self.dist.tp_size if self.enable_attention_dp else []
        self.expected_num_active_requests = 0
        self.has_context_request = False
        self.ctx_in_transmission_requests = []
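The cached field removed above is replaced by gathering the counts where they are used. For orientation, a hedged before/after sketch of the pattern (not part of the diff): PyExecutor internals are reduced to a stub, and FakeDist is an assumed stand-in for the real tensor-parallel communicator, whose tp_allgather returns one value per rank.

from typing import List

class FakeDist:
    """Stand-in for the TP communicator; the real tp_allgather exchanges
    the value across ranks, this single-process stub just replicates it."""

    def __init__(self, tp_size: int):
        self.tp_size = tp_size

    def tp_allgather(self, value: int) -> List[int]:
        return [value] * self.tp_size

class ExecutorSketch:
    def __init__(self, dist: FakeDist):
        self.dist = dist
        self.active_requests: List[object] = []

    # After the cleanup: gather per-rank counts on demand where they are
    # consumed, instead of caching them in self.all_ranks_num_active_requests
    # via a separate _gather_dp_requests_num() call after every loop step.
    def count_active_requests(self):
        all_ranks = self.dist.tp_allgather(len(self.active_requests))
        return all_ranks, sum(all_ranks)

executor = ExecutorSketch(FakeDist(tp_size=4))
executor.active_requests = [object()] * 3
all_ranks, total = executor.count_active_requests()
assert all_ranks == [3, 3, 3, 3] and total == 12

With the counts held in a local variable, the executor loops below can simply drop their _gather_dp_requests_num() calls.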
@@ -794,7 +791,6 @@
                # march forward in microbatch slots
                microbatch_id = (microbatch_id + 1) % self.num_micro_batches
                self._gather_dp_requests_num()

                if self.enable_iter_perf_stats and previous_batch is not None:
                    self._process_iter_stats(finished_requests,
@@ -917,8 +913,6 @@
                if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                    self._terminate_ctx_finished_requests()
                self._gather_dp_requests_num()

                if self.enable_iter_perf_stats:
                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
                        'num_ctx_tokens']
@@ -1081,7 +1075,6 @@
                    iter_start_time=iter_start_time,
                    iter_stats=iter_stats,
                    ctx_transmission_reqs=ctx_transmission_reqs)
                self._gather_dp_requests_num()

                if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                    self._terminate_ctx_finished_requests()
@@ -1187,7 +1180,11 @@
    @nvtx_range("_fetch_new_requests")
    def _fetch_new_requests(self):
        if self.enable_attention_dp:
            total_num_active_requests = sum(self.all_ranks_num_active_requests)
            all_ranks_num_active_requests = []
            responses_list = self.dist.tp_allgather(len(self.active_requests))
            for num_active_requests in responses_list:
                all_ranks_num_active_requests.append(num_active_requests)
            total_num_active_requests = sum(all_ranks_num_active_requests)
            total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
        else:
            total_num_active_requests = len(self.active_requests)
@@ -1229,13 +1226,13 @@
            self.expected_num_active_requests = max(
                (total_num_active_requests + num_new_requests_all_ranks +
                 self.dist.tp_size - 1) // self.dist.tp_size,
                max(self.all_ranks_num_active_requests),
                max(all_ranks_num_active_requests),
            )
            self.has_context_request = False
            new_requests_cur_rank = []
            if new_requests != [] and new_requests[
                    0] != None and self.expected_num_active_requests > self.all_ranks_num_active_requests[
                    0] != None and self.expected_num_active_requests > all_ranks_num_active_requests[
                        self.dist.tp_rank]:
                # Balance context tokens across ranks
                HeapVal = namedtuple(
@@ -1249,8 +1246,7 @@
                )
                all_ranks_new_requests_heap = [
                    HeapVal(0, self.expected_num_active_requests - val, tp_rank, [])
                    for tp_rank, val in enumerate(
                        self.all_ranks_num_active_requests)
                    for tp_rank, val in enumerate(all_ranks_num_active_requests)
                ]
                new_requests_cur_rank = all_ranks_new_requests_heap[
                    self.dist.tp_rank].request_list
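The hunk above builds one HeapVal per rank and pops the least-loaded rank for each incoming request. A runnable sketch of that balancing idea follows (not part of the diff); the per-request prompt_tokens attribute and the standalone balance_context_requests helper are illustrative assumptions, not the executor's actual API.

import heapq
from collections import namedtuple
from types import SimpleNamespace

# Mirrors the HeapVal layout in the diff above.
HeapVal = namedtuple('HeapVal',
                     ['num_tokens', 'num_slots', 'tp_rank', 'request_list'])

def balance_context_requests(new_requests, expected_num_active_requests,
                             all_ranks_num_active_requests):
    # One entry per rank; the min-heap pops the rank with the fewest
    # scheduled context tokens (ties broken by fewer remaining slots).
    all_ranks_heap = [
        HeapVal(0, expected_num_active_requests - val, tp_rank, [])
        for tp_rank, val in enumerate(all_ranks_num_active_requests)
    ]
    per_rank = {val.tp_rank: val.request_list for val in all_ranks_heap}
    heap = [val for val in all_ranks_heap if val.num_slots > 0]
    heapq.heapify(heap)
    for req in new_requests:
        if not heap:
            break  # every rank is at capacity
        val = heapq.heappop(heap)
        val.request_list.append(req)
        val = val._replace(num_tokens=val.num_tokens + len(req.prompt_tokens),
                           num_slots=val.num_slots - 1)
        if val.num_slots > 0:
            heapq.heappush(heap, val)
    return per_rank

# Two ranks already holding [0, 1] requests, expecting 3 active each:
reqs = [SimpleNamespace(prompt_tokens=[0] * n) for n in (512, 8, 256, 64)]
per_rank = balance_context_requests(reqs, 3, [0, 1])
assert sum(len(requests) for requests in per_rank.values()) == 4

Each rank runs the same deterministic loop over the same inputs and then keeps only its own request_list, as the hunk's last two lines show, so no extra communication is needed to agree on the assignment.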
@@ -1294,15 +1290,6 @@
            new_requests_cur_rank = new_requests
        return new_requests_cur_rank

    @nvtx_range("_gather_dp_requests_num")
    def _gather_dp_requests_num(self):
        if self.enable_attention_dp:
            gather_active_requests = []
            responses_list = self.dist.tp_allgather(len(self.active_requests))
            for num_active_requests in responses_list:
                gather_active_requests.append(num_active_requests)
            self.all_ranks_num_active_requests = gather_active_requests

    def _add_kv_cache_events(self):
        kv_cache_manager = self.resource_manager.resource_managers.get(
            "kv_cache_manager")