Chore: refine shutdown signal of PyExecutor (#4614)

* refine shutdown signal of PyExecutor

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>

* clean

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>

* fix ci

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>

* fix ci

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>

---------

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>
QI JUN 2025-05-26 11:14:54 +08:00 committed by GitHub
parent 2fee408536
commit 4a81991b65


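The change replaces the old shutdown convention, where a bare None was pushed into the request queue (and cancel requests travelled as bare ints), with a single sentinel request id: every queue item is now a tuple whose first element is a request id, and shutdown is the tuple (SHUTDOWN_REQUEST_ID, ). Below is a minimal standalone sketch of the new convention; the helper names (enqueue_request, enqueue_shutdown, is_shutdown) are illustrative only and are not part of the PyExecutor API.

import queue

SHUTDOWN_REQUEST_ID = -1  # sentinel request id introduced by this commit

request_queue = queue.Queue()

def enqueue_request(req_id, payload):
    # ordinary work item: (request id, request payload)
    request_queue.put((req_id, payload))

def enqueue_shutdown():
    # shutdown is just another tuple, mirroring request_queue.put((SHUTDOWN_REQUEST_ID, ))
    request_queue.put((SHUTDOWN_REQUEST_ID, ))

def is_shutdown(item):
    # consumers only inspect the first element, replacing the old `item is None` checks
    return item[0] == SHUTDOWN_REQUEST_ID

enqueue_request(0, "dummy request")
enqueue_shutdown()
assert not is_shutdown(request_queue.get())
assert is_shutdown(request_queue.get())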
@@ -48,13 +48,7 @@ PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC"
 # Set to a path to save detailed tracing of PyTorch operations.
 PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"
 
-
-def _is_executor_request(req_queue_item) -> bool:
-    return isinstance(req_queue_item, tuple)
-
-
-def _is_cancel_request(req_queue_item) -> bool:
-    return isinstance(req_queue_item, int)
+SHUTDOWN_REQUEST_ID = -1
 
 
 def _get_from_request_queue(request_queue, timeout: datetime.timedelta,
@@ -71,8 +65,8 @@ def _get_from_request_queue(request_queue, timeout: datetime.timedelta,
         while req_count < max_req_count:
             queue_item = request_queue.get_nowait()
             items.append(queue_item)
-            if _is_executor_request(queue_item):
-                # if it is a request (not a finish signal or cancel signal)
+            if queue_item[0] != SHUTDOWN_REQUEST_ID:
+                # if it is a request, not the shutdown signal
                 req_count += 1
     except queue.Empty:
         pass
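For context, _get_from_request_queue drains the queue non-blockingly and stops once max_req_count real requests have been collected; with the new sentinel, a shutdown tuple is still appended to the batch but no longer counted. A simplified, self-contained sketch of that drain loop (the real function also handles a blocking first get with a timeout, which is omitted here):

import queue

SHUTDOWN_REQUEST_ID = -1

def drain_request_queue(request_queue, max_req_count):
    # Collect queued items until the queue is empty or max_req_count real
    # requests have been gathered; a shutdown tuple is kept in the batch so
    # downstream code can see it, but it does not count as a request.
    items, req_count = [], 0
    try:
        while req_count < max_req_count:
            queue_item = request_queue.get_nowait()
            items.append(queue_item)
            if queue_item[0] != SHUTDOWN_REQUEST_ID:
                req_count += 1
    except queue.Empty:
        pass
    return items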
@@ -350,7 +344,7 @@ class PyExecutor:
         """
         try:
             self.enqueue_lock.acquire()
-            self.request_queue.put(None)
+            self.request_queue.put((SHUTDOWN_REQUEST_ID, ))
             self.active = False
         finally:
             self.enqueue_lock.release()
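The shutdown path above enqueues the sentinel while holding enqueue_lock and then flips active, so no new request can land behind the shutdown tuple. A rough sketch of that pattern, assuming (as the surrounding code suggests but the diff does not show) that the enqueue path checks active under the same lock; the class here is illustrative, not PyExecutor:

import queue
import threading

SHUTDOWN_REQUEST_ID = -1

class EnqueueSketch:
    def __init__(self):
        self.request_queue = queue.Queue()
        self.enqueue_lock = threading.Lock()
        self.active = True

    def enqueue_request(self, req_id, req):
        with self.enqueue_lock:
            if not self.active:
                # assumption: requests arriving after shutdown are rejected
                raise RuntimeError("executor is shut down")
            self.request_queue.put((req_id, req))

    def shutdown(self):
        with self.enqueue_lock:
            # the sentinel is guaranteed to be the last item ever enqueued
            self.request_queue.put((SHUTDOWN_REQUEST_ID, ))
            self.active = False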
@@ -1243,8 +1237,8 @@ class PyExecutor:
         self.has_context_request = False
         new_requests_cur_rank = []
 
-        if new_requests != [] and new_requests[
-                0] != None and self.expected_num_active_requests > all_ranks_num_active_requests[
+        if new_requests != [] and new_requests[0][
+                0] != SHUTDOWN_REQUEST_ID and self.expected_num_active_requests > all_ranks_num_active_requests[
                     self.dist.tp_rank]:
             # Balance context tokens across ranks
             HeapVal = namedtuple(
@@ -1298,7 +1292,7 @@ class PyExecutor:
         self.num_fetch_requests_cur_rank = self.num_fetch_requests_cur_rank + len(
             new_requests_cur_rank)
 
-        if len(new_requests) == 1 and new_requests[0] is None:
+        if len(new_requests) == 1 and new_requests[0][0] == SHUTDOWN_REQUEST_ID:
             new_requests_cur_rank = new_requests
 
         return new_requests_cur_rank
@@ -1313,15 +1307,12 @@ class PyExecutor:
 
     def _merge_tp_requests(self, new_requests: List[ExecutorRequest]):
         for request in new_requests:
-            if request is None:
+            if request[0] == SHUTDOWN_REQUEST_ID:
                 return True
 
         for req_item in new_requests:
-            if _is_executor_request(req_item):
-                req_id, exe_req = req_item
-                req = executor_request_to_llm_request(req_id, exe_req)
-                self.active_requests.append(req)
-            elif _is_cancel_request(req_item):
-                self.canceled_req_ids.insert(req_item)
+            req_id, exe_req = req_item
+            req = executor_request_to_llm_request(req_id, exe_req)
+            self.active_requests.append(req)
 
         return False
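With cancel requests no longer riding this queue as bare ints, the merge loop above can unpack every item unconditionally once it has scanned for the shutdown tuple. A condensed sketch of that control flow; build_request stands in for executor_request_to_llm_request and is an assumption of the sketch:

SHUTDOWN_REQUEST_ID = -1

def merge_tp_requests_sketch(new_requests, active_requests, build_request):
    # Returning True tells the caller to shut down; otherwise every item is
    # a (req_id, executor_request) tuple and becomes an active request.
    for item in new_requests:
        if item[0] == SHUTDOWN_REQUEST_ID:
            return True
    for req_id, exe_req in new_requests:
        active_requests.append(build_request(req_id, exe_req))
    return False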
@@ -1348,7 +1339,7 @@ class PyExecutor:
         """
         req_id_to_obj = {}
         for item in requests:
-            if item is None:
+            if item[0] == SHUTDOWN_REQUEST_ID:
                 continue
             req_id, req = item[:2]
             obj = getattr(req, attribute_name, None)
@@ -1362,7 +1353,7 @@ class PyExecutor:
         to each request.
         """
         for item in requests:
-            if item is None:
+            if item[0] == SHUTDOWN_REQUEST_ID:
                 continue
             req_id, req = item[:2]
             py_obj = py_request_objects.get(req_id)
@@ -1410,61 +1401,57 @@ class PyExecutor:
     def _merge_star_attention_requests(self,
                                        new_requests: List[ExecutorRequest]):
         for request in new_requests:
-            if request is None:
+            if request[0] == SHUTDOWN_REQUEST_ID:
                 return True
 
         for req_item in new_requests:
-            if _is_executor_request(req_item):
-                req_id, exe_req, query_token_ids = req_item
-                ctx_len0 = len(exe_req.input_token_ids)
-                ctx_blocks, position_blocks, last_block_padding_num = [
-                    exe_req.input_token_ids
-                ], [[i for i in range(ctx_len0)]], 0
-                ctx_blocks, position_blocks, last_block_padding_num = self._partition_context(
-                    exe_req.input_token_ids)
-                if self.dist.cp_rank == self.dist.cp_size - 1 and last_block_padding_num > 0:
-                    ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
-                    position_blocks[-1] = position_blocks[
-                        -1][:-last_block_padding_num]
-                # if has query
-                if query_token_ids:
-                    ctx_blocks.append(query_token_ids)
-                    position_blocks.append([
-                        i for i in range(ctx_len0, ctx_len0 +
-                                         len(query_token_ids))
-                    ])
+            req_id, exe_req, query_token_ids = req_item
+            ctx_len0 = len(exe_req.input_token_ids)
+            ctx_blocks, position_blocks, last_block_padding_num = [
+                exe_req.input_token_ids
+            ], [[i for i in range(ctx_len0)]], 0
+            ctx_blocks, position_blocks, last_block_padding_num = self._partition_context(
+                exe_req.input_token_ids)
+            if self.dist.cp_rank == self.dist.cp_size - 1 and last_block_padding_num > 0:
+                ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
+                position_blocks[-1] = position_blocks[
+                    -1][:-last_block_padding_num]
+            # if has query
+            if query_token_ids:
+                ctx_blocks.append(query_token_ids)
+                position_blocks.append([
+                    i for i in range(ctx_len0, ctx_len0 + len(query_token_ids))
+                ])
 
-                # insert the dummy block to align the number of ctx iterations of each rank
-                block_size = self.dist.cp_config['block_size']
-                total_blocks = (ctx_len0 + block_size - 1) // block_size
-                num_blocks_per_rank = (
-                    total_blocks + self.dist.cp_size -
-                    1) // self.dist.cp_size + 1  # 1 for query block
-                if len(ctx_blocks) == num_blocks_per_rank:
-                    ctx_blocks.insert(1, [])
-                    position_blocks.insert(1, [])
-                elif len(ctx_blocks) == num_blocks_per_rank + 1:
-                    # anchor + ctx_blocks + qry_block
-                    pass
-                else:
-                    print(
-                        f'rank = {self.dist.cp_rank}, len(ctx_blocks) = {len(ctx_blocks)}, num_blocks_per_rank = {num_blocks_per_rank}'
-                    )
-                    assert False, 'invalid context partition'
+            # insert the dummy block to align the number of ctx iterations of each rank
+            block_size = self.dist.cp_config['block_size']
+            total_blocks = (ctx_len0 + block_size - 1) // block_size
+            num_blocks_per_rank = (
+                total_blocks + self.dist.cp_size -
+                1) // self.dist.cp_size + 1  # 1 for query block
+            if len(ctx_blocks) == num_blocks_per_rank:
+                ctx_blocks.insert(1, [])
+                position_blocks.insert(1, [])
+            elif len(ctx_blocks) == num_blocks_per_rank + 1:
+                # anchor + ctx_blocks + qry_block
+                pass
+            else:
+                print(
+                    f'rank = {self.dist.cp_rank}, len(ctx_blocks) = {len(ctx_blocks)}, num_blocks_per_rank = {num_blocks_per_rank}'
+                )
+                assert False, 'invalid context partition'
 
-                # fake data for scheduler
-                ctx_blocks_list = [0] * (block_size +
-                                         self.dist.cp_config['cp_anchor_size'])
+            # fake data for scheduler
+            ctx_blocks_list = [0] * (block_size +
+                                     self.dist.cp_config['cp_anchor_size'])
 
-                req = executor_request_to_llm_request(req_id, exe_req,
-                                                      ctx_blocks_list)
-                req.gen_iters = 0
-                req.ctx_iters = 0
-                req.ctx_blocks = ctx_blocks
-                req.ctx_position_blocks = position_blocks
-                req.query_id = query_token_ids
-                self.active_requests.append(req)
-            elif _is_cancel_request(req_item):
-                self.canceled_req_ids.insert(req_item)
+            req = executor_request_to_llm_request(req_id, exe_req,
+                                                  ctx_blocks_list)
+            req.gen_iters = 0
+            req.ctx_iters = 0
+            req.ctx_blocks = ctx_blocks
+            req.ctx_position_blocks = position_blocks
+            req.query_id = query_token_ids
+            self.active_requests.append(req)
 
         return False
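The ceiling-division arithmetic in _merge_star_attention_requests decides how many context blocks each context-parallel rank should hold, with one extra slot reserved for the query block. A small worked example with made-up numbers (not taken from the commit):

# ctx_len0 tokens are split into block_size-token context blocks and spread
# over cp_size ranks; each rank also gets one block for the query.
ctx_len0, block_size, cp_size = 1000, 256, 2

total_blocks = (ctx_len0 + block_size - 1) // block_size           # ceil(1000 / 256) = 4
num_blocks_per_rank = (total_blocks + cp_size - 1) // cp_size + 1  # ceil(4 / 2) + 1 = 3

print(total_blocks, num_blocks_per_rank)  # 4 3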