From 4a81991b65b6517d2ebc1a819184b9aac6522b34 Mon Sep 17 00:00:00 2001
From: QI JUN <22017000+QiJune@users.noreply.github.com>
Date: Mon, 26 May 2025 11:14:54 +0800
Subject: [PATCH] Chore: refine shutdown signal of PyExecutor (#4614)

* refine shutdown signal of PyExecutor

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>

* clean

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>

* fix ci

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>

* fix ci

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>

---------

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 131 ++++++++----------
 1 file changed, 59 insertions(+), 72 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index fd4c634f1d..636cb04933 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -48,13 +48,7 @@
 PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC"
 # Set to a path to save detailed tracing of PyTorch operations.
 PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"
-
-def _is_executor_request(req_queue_item) -> bool:
-    return isinstance(req_queue_item, tuple)
-
-
-def _is_cancel_request(req_queue_item) -> bool:
-    return isinstance(req_queue_item, int)
+SHUTDOWN_REQUEST_ID = -1
 
 
 def _get_from_request_queue(request_queue, timeout: datetime.timedelta,
@@ -71,8 +65,8 @@
         while req_count < max_req_count:
             queue_item = request_queue.get_nowait()
             items.append(queue_item)
-            if _is_executor_request(queue_item):
-                # if it is request, (Not finish signal or cancel signal)
+            if queue_item[0] != SHUTDOWN_REQUEST_ID:
+                # if it is request, not shutdown signal
                 req_count += 1
     except queue.Empty:
         pass
@@ -350,7 +344,7 @@ class PyExecutor:
         """
         try:
             self.enqueue_lock.acquire()
-            self.request_queue.put(None)
+            self.request_queue.put((SHUTDOWN_REQUEST_ID, ))
             self.active = False
         finally:
             self.enqueue_lock.release()
@@ -1243,8 +1237,8 @@ class PyExecutor:
 
         self.has_context_request = False
         new_requests_cur_rank = []
-        if new_requests != [] and new_requests[
-                0] != None and self.expected_num_active_requests > all_ranks_num_active_requests[
-                    self.dist.tp_rank]:
+        if new_requests != [] and new_requests[0][
+                0] != SHUTDOWN_REQUEST_ID and self.expected_num_active_requests > all_ranks_num_active_requests[
+                    self.dist.tp_rank]:
             # Balance context tokens across ranks
             HeapVal = namedtuple(
@@ -1298,7 +1292,7 @@ class PyExecutor:
         self.num_fetch_requests_cur_rank = self.num_fetch_requests_cur_rank + len(
             new_requests_cur_rank)
-        if len(new_requests) == 1 and new_requests[0] is None:
+        if len(new_requests) == 1 and new_requests[0][0] == SHUTDOWN_REQUEST_ID:
             new_requests_cur_rank = new_requests
 
         return new_requests_cur_rank
 
@@ -1313,15 +1307,12 @@ class PyExecutor:
 
     def _merge_tp_requests(self, new_requests: List[ExecutorRequest]):
         for request in new_requests:
-            if request is None:
+            if request[0] == SHUTDOWN_REQUEST_ID:
                 return True
         for req_item in new_requests:
-            if _is_executor_request(req_item):
-                req_id, exe_req = req_item
-                req = executor_request_to_llm_request(req_id, exe_req)
-                self.active_requests.append(req)
-            elif _is_cancel_request(req_item):
-                self.canceled_req_ids.insert(req_item)
+            req_id, exe_req = req_item
+            req = executor_request_to_llm_request(req_id, exe_req)
+            self.active_requests.append(req)
 
         return False
 
@@ -1348,7 +1339,7 @@ class PyExecutor:
         """
         req_id_to_obj = {}
         for item in requests:
-            if item is None:
+            if item[0] == SHUTDOWN_REQUEST_ID:
                 continue
             req_id, req = item[:2]
             obj = getattr(req, attribute_name, None)
@@ -1362,7 +1353,7 @@ class PyExecutor:
         to each request.
         """
         for item in requests:
-            if item is None:
+            if item[0] == SHUTDOWN_REQUEST_ID:
                 continue
             req_id, req = item[:2]
             py_obj = py_request_objects.get(req_id)
@@ -1410,61 +1401,57 @@ class PyExecutor:
 
     def _merge_star_attention_requests(self,
                                        new_requests: List[ExecutorRequest]):
         for request in new_requests:
-            if request is None:
+            if request[0] == SHUTDOWN_REQUEST_ID:
                 return True
         for req_item in new_requests:
-            if _is_executor_request(req_item):
-                req_id, exe_req, query_token_ids = req_item
-                ctx_len0 = len(exe_req.input_token_ids)
-                ctx_blocks, position_blocks, last_block_padding_num = [
-                    exe_req.input_token_ids
-                ], [[i for i in range(ctx_len0)]], 0
-                ctx_blocks, position_blocks, last_block_padding_num = self._partition_context(
-                    exe_req.input_token_ids)
-                if self.dist.cp_rank == self.dist.cp_size - 1 and last_block_padding_num > 0:
-                    ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
-                    position_blocks[-1] = position_blocks[
-                        -1][:-last_block_padding_num]
-                #if has query
-                if query_token_ids:
-                    ctx_blocks.append(query_token_ids)
-                    position_blocks.append([
-                        i for i in range(ctx_len0, ctx_len0 +
-                                         len(query_token_ids))
-                    ])
+            req_id, exe_req, query_token_ids = req_item
+            ctx_len0 = len(exe_req.input_token_ids)
+            ctx_blocks, position_blocks, last_block_padding_num = [
+                exe_req.input_token_ids
+            ], [[i for i in range(ctx_len0)]], 0
+            ctx_blocks, position_blocks, last_block_padding_num = self._partition_context(
+                exe_req.input_token_ids)
+            if self.dist.cp_rank == self.dist.cp_size - 1 and last_block_padding_num > 0:
+                ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
+                position_blocks[-1] = position_blocks[
+                    -1][:-last_block_padding_num]
+            #if has query
+            if query_token_ids:
+                ctx_blocks.append(query_token_ids)
+                position_blocks.append([
+                    i for i in range(ctx_len0, ctx_len0 + len(query_token_ids))
+                ])
 
-                # insert the dummy block to align the number of ctx iterations of each rank
-                block_size = self.dist.cp_config['block_size']
-                total_blocks = (ctx_len0 + block_size - 1) // block_size
-                num_blocks_per_rank = (
-                    total_blocks + self.dist.cp_size -
-                    1) // self.dist.cp_size + 1  # 1 for query block
-                if len(ctx_blocks) == num_blocks_per_rank:
-                    ctx_blocks.insert(1, [])
-                    position_blocks.insert(1, [])
-                elif len(ctx_blocks) == num_blocks_per_rank + 1:
-                    # anchor + ctx_blocks + qry_block
-                    pass
-                else:
-                    print(
-                        f'rank = {self.dist.cp_rank}, len(ctx_blocks) = {len(ctx_blocks) }, num_blocks_per_rank = {num_blocks_per_rank}'
-                    )
-                    assert False, f'invalid context partition'
+            # insert the dummy block to align the number of ctx iterations of each rank
+            block_size = self.dist.cp_config['block_size']
+            total_blocks = (ctx_len0 + block_size - 1) // block_size
+            num_blocks_per_rank = (
+                total_blocks + self.dist.cp_size -
+                1) // self.dist.cp_size + 1  # 1 for query block
+            if len(ctx_blocks) == num_blocks_per_rank:
+                ctx_blocks.insert(1, [])
+                position_blocks.insert(1, [])
+            elif len(ctx_blocks) == num_blocks_per_rank + 1:
+                # anchor + ctx_blocks + qry_block
+                pass
+            else:
+                print(
+                    f'rank = {self.dist.cp_rank}, len(ctx_blocks) = {len(ctx_blocks) }, num_blocks_per_rank = {num_blocks_per_rank}'
+                )
+                assert False, f'invalid context partition'
 
-                # fake data for scheduler
-                ctx_blocks_list = [0] * (block_size +
-                                         self.dist.cp_config['cp_anchor_size'])
+            # fake data for scheduler
+            ctx_blocks_list = [0] * (block_size +
+                                     self.dist.cp_config['cp_anchor_size'])
 
-                req = executor_request_to_llm_request(req_id, exe_req,
-                                                      ctx_blocks_list)
-                req.gen_iters = 0
-                req.ctx_iters = 0
-                req.ctx_blocks = ctx_blocks
-                req.ctx_position_blocks = position_blocks
-                req.query_id = query_token_ids
-                self.active_requests.append(req)
-            elif _is_cancel_request(req_item):
-                self.canceled_req_ids.insert(req_item)
+            req = executor_request_to_llm_request(req_id, exe_req,
+                                                  ctx_blocks_list)
+            req.gen_iters = 0
+            req.ctx_iters = 0
+            req.ctx_blocks = ctx_blocks
+            req.ctx_position_blocks = position_blocks
+            req.query_id = query_token_ids
+            self.active_requests.append(req)
 
         return False
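Note (not part of the patch): before this change the request queue mixed item types, with tuples for requests, bare ints for cancellations, and None as the shutdown signal. After the change the shutdown signal is itself a tuple, (SHUTDOWN_REQUEST_ID, ), so every queue item is a tuple and consumers only need to inspect item[0]. The sketch below illustrates that convention in isolation; the queue name and payload strings are invented for illustration and are not from the repository.

# Minimal sketch (assumed names, not from the repository) of the queue
# convention introduced above: work items are (request_id, payload) tuples and
# shutdown is the 1-tuple (SHUTDOWN_REQUEST_ID, ), so consumers check item[0].
import queue

SHUTDOWN_REQUEST_ID = -1  # same sentinel value the patch defines

request_queue: queue.Queue = queue.Queue()
request_queue.put((0, "first request"))      # hypothetical payloads
request_queue.put((1, "second request"))
request_queue.put((SHUTDOWN_REQUEST_ID, ))   # shutdown signal

while True:
    item = request_queue.get()
    if item[0] == SHUTDOWN_REQUEST_ID:       # shutdown: stop consuming
        break
    req_id, payload = item
    print(f"processing request {req_id}: {payload}")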